Fork me on Gitee

源码解析一 数据加载 数据增强 datasets.py augmentations.py

参考:https://blog.csdn.net/qq_38253797/article/details/119904518

本文不是简单的照搬,只是做一个引子,原文中比较详细,但是很多东西不是非常重要,本文基于yolov6,因此作者将 数据增强部分合并在了augmentations.py中,详细代码可在这里看

1.数据加载

这个函数会在train.py中被调用,用于生成Trainloader, dataset,testloader:

def create_dataloader(path, imgsz, batch_size, stride, single_cls=False,
                      hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
    """在train.py中被调用,用于生成Trainloader, dataset,testloader
    自定义dataloader函数: 调用LoadImagesAndLabels获取数据集(包括数据增强) + 调用分布式采样器DistributedSampler +
                        自定义InfiniteDataLoader 进行永久持续的采样数据
    :param path: 图片数据加载路径 train/test  如: ../datasets/VOC/images/train2007
    :param imgsz: train/test图片尺寸(数据增强后大小) 640
    :param batch_size: batch size 大小 8/16/32
    :param stride: 模型最大stride=32   [32 16 8]
    :param single_cls: 数据集是否是单类别 默认False
    :param hyp: 超参列表dict 网络训练时的一些超参数,包括学习率等,这里主要用到里面一些关于数据增强(旋转、平移等)的系数
    :param augment: 是否要进行数据增强  True
    :param cache: 是否cache_images False
    :param pad: 设置矩形训练的shape时进行的填充 默认0.0
    :param rect: 是否开启矩形train/test  默认训练集关闭 验证集开启
    :param rank:  多卡训练时的进程编号 rank为进程编号  -1且gpu=1时不进行分布式  -1且多块gpu使用DataParallel模式  默认-1
    :param workers: dataloader的numworks 加载数据时的cpu进程数
    :param image_weights: 训练时是否根据图片样本真实框分布权重来选择图片  默认False
    :param quad: dataloader取数据时, 是否使用collate_fn4代替collate_fn  默认False
    :param prefix: 显示信息   一个标志,多为train/val,处理标签时保存cache文件会用到
    """

LoadImagesAndLabels

class LoadImagesAndLabels(Dataset):
    # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False,
                 image_weights=False, cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        """
        初始化过程并没有什么实质性的操作,更多是一个定义参数的过程(self参数),以便在__getitem()__中进行数据增强操作,所以这部分代码只需要抓住self中的各个变量的含义就算差不多了
        self.img_files: {list: N} 存放着整个数据集图片的相对路径
        self.label_files: {list: N} 存放着整个数据集图片的相对路径
        cache label -> verify_image_label
        self.labels: 如果数据集所有图片中没有一个多边形label  labels存储的label就都是原始label(都是正常的矩形label)
                     否则将所有图片正常gt的label存入labels 不正常gt(存在一个多边形)经过segments2boxes转换为正常的矩形label
        self.shapes: 所有图片的shape
        self.segments: 如果数据集所有图片中没有一个多边形label  self.segments=None
                       否则存储数据集中所有存在多边形gt的图片的所有原始label(肯定有多边形label 也可能有矩形正常label 未知数)
        self.batch: 记载着每张图片属于哪个batch
        self.n: 数据集中所有图片的数量
        self.indices: 记载着所有图片的index
        self.rect=True时self.batch_shapes记载每个batch的shape(同一个batch的图片shape相同)

collate_fn 注意:这个函数一般是当调用了batch_size次 getitem 函数后才会调用一次这个函数,对batch_size张图片和对应的label进行打包。 强烈建议这里大家debug试试这里return的数据是不是我说的这样定义的。

collate_fn4

这里是yolo-v5作者实验性的一个代码 quad-collate function 当train.py的opt参数quad=True 则调用collate_fn4代替collate_fn。 作用:将4张mosaic图片[1, 3, 640, 640]合成一张大的mosaic图片[1, 3, 1280, 1280]。将一个batch的图片每四张处理, 0.5的概率将四张图片拼接到一张大图上训练, 0.5概率直接将某张图片上采样两倍训练。

img2label_paths

这个文件是根据数据集中所有图片的路径找到数据集中所有labels对应的路径。用在LoadImagesAndLabels模块的__init__函数中。

verify_image_label

这个函数用于检查每一张图片和每一张label文件是否完好。

图片文件: 检查内容、格式、大小、完整性

label文件: 检查每个gt必须是矩形(每行都得是5个数 class+xywh) + 标签是否全部>=0 + 标签坐标xywh是否归一化 + 标签中是否有重复的坐标

其他不常用的函数(在detect.py需要使用):

LoadWebcam:webcam数据流

LoadStreams:支持rtsp流推送

load_image:图片或者视频的文件夹载入

工具函数:

可以拆分出来自己使用,也可以在LoadImagesAndLabels函数中,编排自己的规则

flatten_recursive

这个模块是将一个文件路径中的所有文件复制到另一个文件夹中 即将image文件和label文件放到一个新文件夹中

extract_boxes

把目标检测数据集中的每一个gt拆解开 分类别存储到对应的文件当中,

如果目标检测数据集之前以及混合过,那么可以通过这个函数将其分开,不需要手动编写函数

autosplit

这个模块是进行自动划分数据集,weights=(0.9, 0.1, 0.0)为控制比例关系,

当然,这些函数的初衷是为,coco数据集准备的。

dataset_stats

这个模块是统计数据集的信息返回状态字典。包含: 每个类别的图片数量 + 每个类别的实例数量

使用建议:特别是在我们拿到数据集以后,需要了解数据类别的分布关系,经常需要手动编写代码查看分类,这个代码,可以直接统计总类别数目,

在最新的版本中,数据的统计图是通过这个函数得了的,缺点是比较小,最后会生成统计的json文件,带有自动下载工具

使用mosaic

在LoadImagesAndLabels函数中,对载入的图片算法可以使用load_mosaic load_mosaic9 l即为更改载入策略

load_mosaic函数是拼接四张图,而load_mosaic9函数是拼接九张图

这个函数是根据图片index,从self或者从对应图片路径中载入对应index的图片 并将原图中hw中较大者扩展到self.img_size, 较小者同比例扩展。
会被用在LoadImagesAndLabels模块的__getitem__函数和load_mosaic模块中载入对应index的图片

2. 数据增强

datasets.py涉及的数据增强在augmentations 总定义,但导入的这些函数增强方法的配置参数是通过配置文件在这里导入的

from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective

例如:augment_hsv 参数的导入(参数只是作为响应的概率,并不一定作用)

图片的色域增强模块,图片并不发生移动,所有不需要改变label,只需要 img 增强即可

if self.augment:
            # Albumentations
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations
            # HSV color-space
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
            # Flip up-down
            if random.random() < hyp['flipud']:
            # Flip left-right
            if random.random() < hyp['fliplr']:

random_perspective

1.这个函数是进行随机透视变换,对mosaic整合后的图片进行随机旋转、缩放、平移、裁剪,透视变换,并resize为输入大小img_size。

2.在mosaic操作之后进行透视变换/仿射变换:

box_candidates

这个函数用在random_perspective中,是对透视变换后的图片label进行筛选,去除被裁剪过小的框(面积小于裁剪前的area_thr)

还有长和宽必须大于wh_thr个像素,且长宽比范围在(1/ar_thr, ar_thr)之间的限制。

def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):
    """用在random_perspective中 对透视变换后的图片label进行筛选
    去除被裁剪过小的框(面积小于裁剪前的area_thr) 还有长和宽必须大于wh_thr个像素,且长宽比范围在(1/ar_thr, ar_thr)之间的限制
    Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    :params box1: [4, n]
    :params box2: [4, n]
    :params wh_thr: 筛选条件 宽高阈值
    :params ar_thr: 筛选条件 宽高比、高宽比最大值阈值
    :params area_thr: 筛选条件 面积阈值
    :params eps: 1e-16 接近0的数 防止分母为0
    :return i: 筛选结果 [n] 全是True或False   使用比如: box1[i]即可得到i中所有等于True的矩形框 False的矩形框全部删除

replicate

这个函数是随机偏移标签中心,生成新的标签与原标签结合。可以用在load_mosaic里在mosaic操作之后 random_perspective操作之前, 作者默认是关闭的, 自己可以实验一下效果。

def replicate(img, labels):
    """可以用在load_mosaic里在mosaic操作之后 random_perspective操作之前  作者默认是关闭的 自己可以实验一下效果
    随机偏移标签中心,生成新的标签与原标签结合  Replicate labels
    :params img: img4 因为是用在mosaic操作之后 所以size=[2*img_size, 2*img_size]
    :params labels: mosaic整合后图片的所有正常label标签labels4(不正常的会通过segments2boxes将多边形标签转化为正常标签) [N, cls+xyxy]
    :return img: img4 size=[2*img_size, 2*img_size] 不过图片中多了一半的较小gt个数
    :params labels: labels4 不过另外增加了一半的较小label [3/2N, cls+xyxy]

letterbox

对输入的图像进行缩放:原始数据的处理,可见在YOLOV5根本无需考虑图片的缩放问题。

1.load_image将图片从文件中加载出来,并resize到相应的尺寸(最长边等于我们需要的尺寸,最短边等比例缩放);

2.letterbox将之前resize后的图片再pad到我们所需要的放到dataloader中(collate_fn函数)的尺寸(矩形训练要求同一个batch中的图片的尺寸必须保持一致);

3.将label从相对原图尺寸(原文件中图片尺寸)缩放到相对letterbox pad后的图片尺寸。因为前两部分的图片尺寸发生了变化,同样的我们的label也需要发生相应的变化。

cutout

cutout数据增强,给图片随机添加随机大小的方块噪声 ,目的是提高泛化能力和鲁棒性。来自论文: https://arxiv.org/abs/1708.04552。

mixup

这个函数是进行mixup数据增强:按比例融合两张图片。论文:https://arxiv.org/pdf/1710.09412.pdf。

hist_equalize

进行直方图均衡化

posted @ 2021-12-19 23:28  ---dgw博客  阅读(891)  评论(0编辑  收藏  举报