Tensorflow版Faster RCNN源码解析(TFFRCNN) (08) roi_data_layer/minibatch.py

本blog为github上CharlesShang/TFFRCNN版源码解析系列代码笔记

---------------个人学习笔记---------------

----------------本文作者疆--------------

------点击此处链接至博客园原文------

 

"""Compute minibatch blobs for training a Fast R-CNN network."""

1.get_minibatch(roidb,num_classes)

训练时(使用或不使用RPN)构造blobs字典作为网络输入。

使用RPN时,blobs字典包含'data'、'gt_boxes'、'gt_ishard'、'dontcare_areas'、'im_info'、'im_name'字段,且仅支持单张图像构成的batch(single batch);此情况下TRAIN.BATCH_SIZE(默认128)以及由它计算出的rois_per_image、fg_rois_per_image均未被使用。

不使用RPN时,blobs字典包含'data'、'rois'、'labels'字段,当cfg.TRAIN.BBOX_REG为True时还包含'bbox_targets'、'bbox_inside_weights'、'bbox_outside_weights'字段,支持多张图像构成batch。

另:roidb中为何记录的是gt roi信息?不使用RPN时各变量的具体意义尚不明确。本函数被roi_data_layer/layer.py中的_get_next_minibatch(...)调用。

# Build the blobs dict fed to the network. When training with an RPN the roidb
# list holds a single entry (one image's roi info), so num_images is 1.
def get_minibatch(roidb, num_classes):
    """Given a roidb, construct a minibatch sampled from it.

    Args:
        roidb: list of per-image roidb dicts. With cfg.TRAIN.HAS_RPN only a
            single image is supported.
        num_classes: number of classes K; only used when training without an
            RPN, to size the N x 4K bbox regression blobs.

    Returns:
        dict of blobs. Always contains 'data'. With an RPN it also contains
        'gt_boxes', 'gt_ishard', 'dontcare_areas', 'im_info' and 'im_name'.
        Without an RPN it contains 'rois' and 'labels', plus 'bbox_targets',
        'bbox_inside_weights' and 'bbox_outside_weights' when
        cfg.TRAIN.BBOX_REG is set.
    """
    num_images = len(roidb)
    # Sample a random scale index for each image in this batch
    # (with the default TRAIN.SCALES = (600,) every index is 0).
    random_scale_inds = npr.randint(0, high=len(cfg.TRAIN.SCALES),
                                    size=num_images)
    # TRAIN.BATCH_SIZE is the number of RoIs per minibatch (default 128).
    assert(cfg.TRAIN.BATCH_SIZE % num_images == 0), \
        'num_images ({}) must divide BATCH_SIZE ({})'. \
        format(num_images, cfg.TRAIN.BATCH_SIZE)
    # Integer division: the assert above guarantees it is exact, and an int is
    # required downstream (sample sizes / slice bounds in _sample_rois).
    rois_per_image = cfg.TRAIN.BATCH_SIZE // num_images
    # Foreground RoIs per image; TRAIN.FG_FRACTION defaults to 0.25 (fg:bg = 1:3).
    # np.round returns a NumPy float, which recent NumPy rejects as a sample
    # size or slice index, so cast to int.
    fg_rois_per_image = int(np.round(cfg.TRAIN.FG_FRACTION * rois_per_image))

    # Image data blob and the scale factor applied to each image.
    im_blob, im_scales = _get_image_blob(roidb, random_scale_inds)
    blobs = {'data': im_blob}

    # TRAIN.HAS_RPN defaults to True.
    if cfg.TRAIN.HAS_RPN:
        # Only one image per batch is supported when using an RPN.
        assert len(im_scales) == 1, "Single batch only"
        assert len(roidb) == 1, "Single batch only"
        # gt boxes: (x1, y1, x2, y2, cls); entries with a non-zero class are gt.
        gt_inds = np.where(roidb[0]['gt_classes'] != 0)[0]
        gt_boxes = np.empty((len(gt_inds), 5), dtype=np.float32)
        # Box coordinates are scaled to the resized image.
        gt_boxes[:, 0:4] = roidb[0]['boxes'][gt_inds, :] * im_scales[0]
        gt_boxes[:, 4] = roidb[0]['gt_classes'][gt_inds]
        blobs['gt_boxes'] = gt_boxes
        blobs['gt_ishard'] = roidb[0]['gt_ishard'][gt_inds]  \
            if 'gt_ishard' in roidb[0] else np.zeros(gt_inds.size, dtype=int)
        # dontcare areas are scaled the same way as the boxes.
        blobs['dontcare_areas'] = roidb[0]['dontcare_areas'] * im_scales[0] \
            if 'dontcare_areas' in roidb[0] else np.zeros([0, 4], dtype=float)
        # im_info: blob dims 1 and 2 (presumably the padded height and width,
        # see utils/blob.py) and the image scale factor.
        blobs['im_info'] = np.array(
            [[im_blob.shape[1], im_blob.shape[2], im_scales[0]]],
            dtype=np.float32)
        # File name (last path component) of the image.
        blobs['im_name'] = os.path.basename(roidb[0]['image'])
    else:
        # Not using an RPN: sample RoIs from every image and build the
        # rois / labels / bbox regression blobs.
        # Shapes: rois N x 5, labels N, bbox_targets and inside weights N x 4K.
        rois_blob = np.zeros((0, 5), dtype=np.float32)
        labels_blob = np.zeros((0), dtype=np.float32)
        bbox_targets_blob = np.zeros((0, 4 * num_classes), dtype=np.float32)
        bbox_inside_blob = np.zeros(bbox_targets_blob.shape, dtype=np.float32)
        # all_overlaps = []
        for im_i in range(num_images):
            labels, overlaps, im_rois, bbox_targets, bbox_inside_weights \
                = _sample_rois(roidb[im_i], fg_rois_per_image, rois_per_image,
                               num_classes)
            # Scale the RoIs to the resized image.
            rois = _project_im_rois(im_rois, im_scales[im_i])
            # Prefix each RoI with the index of its image within the batch.
            batch_ind = im_i * np.ones((rois.shape[0], 1))
            rois_blob_this_image = np.hstack((batch_ind, rois))
            rois_blob = np.vstack((rois_blob, rois_blob_this_image))
            # Add to labels, bbox targets, and bbox loss blobs
            labels_blob = np.hstack((labels_blob, labels))
            bbox_targets_blob = np.vstack((bbox_targets_blob, bbox_targets))
            bbox_inside_blob = np.vstack((bbox_inside_blob, bbox_inside_weights))
            # all_overlaps = np.hstack((all_overlaps, overlaps))
        # For debug visualizations
        # _vis_minibatch(im_blob, rois_blob, labels_blob, all_overlaps)
        blobs['rois'] = rois_blob
        blobs['labels'] = labels_blob
        if cfg.TRAIN.BBOX_REG:
            blobs['bbox_targets'] = bbox_targets_blob
            blobs['bbox_inside_weights'] = bbox_inside_blob
            # Outside weights are 1 wherever an inside weight is set, else 0.
            blobs['bbox_outside_weights'] = \
                (bbox_inside_blob > 0).astype(np.float32)
    return blobs
# -*- coding:utf-8 -*-
# Author: WUJiang
# Demo: how np.where() picks the indices of ground-truth (non-zero class) entries.

import numpy as np

roidb = [{'gt_classes': np.array([0, 1, 2, 3, 4])}]
# The elementwise comparison `!= 0` needs an ndarray: on a plain Python list
# it would produce a single bool instead of a mask.
nonzero_mask = roidb[0]['gt_classes'] != 0
# With one argument np.where returns a tuple of index arrays (one per axis):
# (array([1, 2, 3, 4]),)   -- some platforms also print a dtype, e.g. int64
a = np.where(nonzero_mask)
print(a)
# Take the index array for the first (only) axis: [1 2 3 4]
gt_inds = a[0]
print(gt_inds)

2._sample_rois(roidb,fg_rois_per_image,rois_per_image,num_classes)

训练阶段不使用RPN时,对rois进行foreground(与gt的最大IoU不小于0.5)和background(与gt的最大IoU位于[0.1, 0.5)区间)采样,获取roi相关信息以构造blobs。返回采样roi的labels(前景roi为其最大重叠gt的类别索引,背景roi被置为0)、overlaps(各roi与gt的最大IoU)、rois(坐标数据)、bbox_targets(N*4K的回归目标)和bbox_inside_weights(N*4K的损失权重)。被get_minibatch(...)函数调用(训练阶段不使用RPN时)。

# Sample foreground/background RoIs from one image's roidb entry to build the
# blobs when training without an RPN.
def _sample_rois(roidb, fg_rois_per_image, rois_per_image, num_classes):
    """Generate a random sample of RoIs comprising foreground and background examples.

    Args:
        roidb: roidb dict for a single image; reads 'max_classes',
            'max_overlaps', 'boxes' and 'bbox_targets'.
        fg_rois_per_image: maximum number of foreground RoIs to sample.
        rois_per_image: maximum total number of RoIs to sample.
        num_classes: number of classes K, used to expand the regression
            targets to N x 4K.

    Returns:
        labels: class of max overlap for each sampled RoI, with background
            RoIs clamped to 0.
        overlaps: max overlap of each sampled RoI.
        rois: coordinates of the sampled RoIs.
        bbox_targets: N x 4K regression targets.
        bbox_inside_weights: N x 4K loss weights.
    """
    # label = class RoI has max overlap with
    labels = roidb['max_classes']
    overlaps = roidb['max_overlaps']
    rois = roidb['boxes']

    # Select foreground RoIs as those with >= FG_THRESH overlap (default 0.5).
    fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
    # Guard against the case when an image has fewer than fg_rois_per_image
    # foreground RoIs. Cast to int: the count may arrive as a NumPy float,
    # which recent NumPy rejects as a sample size or slice bound.
    fg_rois_per_this_image = int(np.minimum(fg_rois_per_image, fg_inds.size))
    # Sample foreground regions without replacement
    if fg_inds.size > 0:
        fg_inds = npr.choice(
                fg_inds, size=fg_rois_per_this_image, replace=False)

    # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
    # (defaults 0.1 and 0.5).
    bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                       (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
    # Compute number of background RoIs to take from this image (guarding
    # against there being fewer than desired)
    bg_rois_per_this_image = int(np.minimum(
        rois_per_image - fg_rois_per_this_image, bg_inds.size))
    # Sample background regions without replacement
    if bg_inds.size > 0:
        bg_inds = npr.choice(
                bg_inds, size=bg_rois_per_this_image, replace=False)

    # The indices that we're selecting: foreground first, then background.
    keep_inds = np.append(fg_inds, bg_inds)
    # Select sampled values from various arrays (fancy indexing copies, so
    # the roidb arrays are not modified below).
    labels = labels[keep_inds]
    # Clamp labels for the background RoIs (everything after the fg ones) to 0
    labels[fg_rois_per_this_image:] = 0
    overlaps = overlaps[keep_inds]
    rois = rois[keep_inds]
    # Expand the compact regression targets to the N x 4K layout.
    bbox_targets, bbox_inside_weights = _get_bbox_regression_labels(
            roidb['bbox_targets'][keep_inds, :], num_classes)
    return labels, overlaps, rois, bbox_targets, bbox_inside_weights

3._get_image_blob(roidb,scale_inds)

调用utils/blob.py中的prep_im_for_blob(im, cfg.PIXEL_MEANS, target_size, cfg.TRAIN.MAX_SIZE)和im_list_to_blob(processed_ims)函数,构造网络输入的图像数据blob。
- roidb:一张或多张图像(使用RPN时仅一张)的roi相关信息构成的列表。列表中每个元素为一个字典,记录某张图像的roi相关信息,含'image'、'flipped'、'gt_classes'、'boxes'、'gt_ishard'、'dontcare_areas'等字段。其中'flipped'为True时读入的图像会被水平翻转;boxes的相应翻转推测在构造roidb时已完成。
- scale_inds:各图像缩放尺度索引构成的列表。
- processed_ims和im_scales:分别存储各张缩放后的图像和对应的缩放比例。

函数最后返回图像数据blob及im_scales列表,被get_minibatch(...)函数调用。

def _get_image_blob(roidb, scale_inds):
    """Builds an input blob from the images in the roidb at the specified scales.

    Args:
        roidb: list of per-image roidb dicts; reads 'image' (file path) and
            'flipped'.
        scale_inds: per-image indices into cfg.TRAIN.SCALES.

    Returns:
        (blob, im_scales): the image blob built by im_list_to_blob from the
        processed (mean-subtracted, resized) images, and the list of scale
        factors applied to each image.

    Raises:
        IOError: if an image file cannot be read.
    """
    processed_ims = []
    im_scales = []
    for i, entry in enumerate(roidb):
        im = cv2.imread(entry['image'])
        # cv2.imread returns None instead of raising for a missing/unreadable
        # file; fail here with the path instead of later with a TypeError.
        if im is None:
            raise IOError('Could not read image: {}'.format(entry['image']))
        # Flipped entries use the horizontally mirrored image (width axis reversed).
        if entry['flipped']:
            im = im[:, ::-1, :]
        target_size = cfg.TRAIN.SCALES[scale_inds[i]]
        # Resized image and the scale factor used (utils/blob.py).
        im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, target_size,
                                        cfg.TRAIN.MAX_SIZE)
        im_scales.append(im_scale)
        processed_ims.append(im)

    # Create a blob to hold the input images (utils/blob.py)
    blob = im_list_to_blob(processed_ims)
    return blob, im_scales

4._project_im_rois(im_rois, im_scale_factor)

对roi缩放并返回,被get_minibatch(...)函数调用(训练阶段不使用RPN时)

def _project_im_rois(im_rois, im_scale_factor):
    """Project image RoIs into the rescaled重缩放 training image."""
    rois = im_rois * im_scale_factor
    return rois

5._get_bbox_regression_labels(bbox_target_data,num_classes)

将N*5的bbox_target_data(每行为[类别, 4个回归目标])扩展为N*4K的bbox_targets,并生成同样维度的bbox_inside_weights。这是网络需要的形式:4K中仅对应类别的4个值为有效信息,其余全0。被_sample_rois(...)函数调用(训练阶段不使用RPN时)。

# Expand the compact N x 5 bbox target data (and matching loss weights) into
# the N x 4K layout used by the network.
def _get_bbox_regression_labels(bbox_target_data, num_classes):
    """
    Bounding-box regression targets are stored in a compact form in the roidb.
    This function expands those targets into the 4-of-4*K representation used
    by the network (i.e. only one class has non-zero targets). The loss weights
    are similarly expanded.

    Args:
        bbox_target_data (ndarray): N x 5 array; column 0 is the class index,
            columns 1-4 are the regression targets for that class.
        num_classes (int): number of classes K.

    Returns:
        bbox_targets (ndarray): N x 4K blob of regression targets
        bbox_inside_weights (ndarray): N x 4K blob of loss weights
    """
    clss = bbox_target_data[:, 0]
    bbox_targets = np.zeros((clss.size, 4 * num_classes), dtype=np.float32)
    bbox_inside_weights = np.zeros(bbox_targets.shape, dtype=np.float32)
    # Rows with class 0 keep all-zero targets and weights.
    inds = np.where(clss > 0)[0]
    for ind in inds:
        # The class column shares the array's float dtype; slice bounds must
        # be ints (float indices are rejected by recent NumPy).
        cls = int(clss[ind])
        start = 4 * cls
        end = start + 4
        bbox_targets[ind, start:end] = bbox_target_data[ind, 1:]
        # TRAIN.BBOX_INSIDE_WEIGHTS defaults to (1.0, 1.0, 1.0, 1.0)
        bbox_inside_weights[ind, start:end] = cfg.TRAIN.BBOX_INSIDE_WEIGHTS
    return bbox_targets, bbox_inside_weights

6._vis_minibatch(im_blob,rois_blob,labels_blob,overlaps)

在训练阶段不使用RPN时绘制roi相关信息(如类别、bbox矩形框等)以供代码调试;get_minibatch(...)中对它的调用已被注释掉。

# Draw each sampled RoI (class label and box) for debugging.
def _vis_minibatch(im_blob, rois_blob, labels_blob, overlaps):
    """Visualize a mini-batch for debugging.

    Shows one matplotlib figure per RoI with its box drawn in red, and prints
    the RoI's class label and overlap.

    Args:
        im_blob: image blob from _get_image_blob. NOTE(review): assumed to be
            laid out (N, H, W, C), as get_minibatch's im_info built from
            shape[1] / shape[2] implies; confirm against utils/blob.py.
        rois_blob: R x 5 array of (batch image index, x1, y1, x2, y2).
        labels_blob: R class labels.
        overlaps: R overlap values, printed next to each label.
    """
    import matplotlib.pyplot as plt
    for i in range(rois_blob.shape[0]):
        rois = rois_blob[i, :]
        # The batch image index is stored as a float; NumPy needs an int index.
        im_ind = int(rois[0])
        roi = rois[1:]
        # The blob is already (H, W, C), so no channel transpose is needed.
        im = im_blob[im_ind, :, :, :].copy()
        # Undo the mean subtraction
        im += cfg.PIXEL_MEANS
        # BGR (OpenCV channel order) -> RGB for matplotlib
        im = im[:, :, (2, 1, 0)]
        im = im.astype(np.uint8)
        cls = labels_blob[i]
        plt.imshow(im)
        print('class: {}  overlap: {}'.format(cls, overlaps[i]))
        plt.gca().add_patch(
            plt.Rectangle((roi[0], roi[1]), roi[2] - roi[0],
                          roi[3] - roi[1], fill=False,
                          edgecolor='r', linewidth=3)
            )
        plt.show()
posted @ 2019-08-07 07:45  JiangJ~  阅读(434)  评论(0)  编辑  收藏  举报