如何更改训练策略——利用torch.utils.data.batchsampler修改batch处理逻辑

问题背景

给了个任务,小老板单独给了个训练集,要按照他创造的mimo策略进行训练/验证。mimo策略其中第一步就是对数据集进行处理,要把每个batch重复n_infers遍,之后组合所有的batch生成一个单独的epoch。

原码是使用torch.utils.dataloader进行数据集加载的,并使用sampler(torch.utils.data.sampler)进行batch采样的策略选取。

所以打算看看能否利用torch直接实现batch的策略,要是不行就得抛弃dataloder,自己写一个数据集加载。

修改过程

torch.dataloader

参考资料:PyTorch入门必学:DataLoader(数据迭代器)参数解析与用法合集_python dataloader-CSDN博客

  1. 参数:

    dataset (必需): 用于加载数据的数据集,通常是torch.utils.data.Dataset的子类实例。
    batch_size (可选): 每个批次的数据样本数。默认值为1。
    shuffle (可选): 是否在每个周期开始时打乱数据。默认为False。
    sampler (可选): 定义从数据集中抽取样本的策略。如果指定,则忽略shuffle参数。
    batch_sampler (可选): 与sampler类似,但一次返回一个批次的索引。不能与batch_size、shuffle和sampler同时使用。
    num_workers (可选): 用于数据加载的子进程数量。默认为0,意味着数据将在主进程中加载。
    collate_fn (可选): 如何将多个数据样本整合成一个批次。通常不需要指定。
    drop_last (可选): 如果数据集大小不能被批次大小整除,是否丢弃最后一个不完整的批次。默认为False。

在这里可以发现,参数上看可以做操作的sampler,batch_sampler这俩个参数,于是细细研究一下。

这两个参数都是源自data.utils.data.sampler和BatchSampler这个俩类。看到一篇不错的文章,讲解了以sampler,batchsampler等几个模块作为基础,dataloader的工作流程(如何从dataset变成一个batch的数据的)[Pytorch] Sampler, DataLoader和数据batch的形成 - 知乎 发现其实从sampler提供idx,给batchsampler进行处理,返回一个batch的数据。于是起手阅读batch_sampler,打算从这开始更改Pytorch Sampler详解 - 知乎

更改大致思路是集成batchsampler并更改iter迭代器,从而获得想要的idx

class CustomBatchSampler(BatchSampler):
    def __init__(self, sampler, batch_size, drop_last, custom_param):
        # 调用父类的构造函数
        super().__init__(sampler, batch_size, drop_last)
        self.custom_param = custom_param  # 添加自定义参数
    
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            #在这里处理batch的序号从而获得想要的batch
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        
        # 最后的批次处理
        if len(batch) > 0 and not self.drop_last:
            yield batch
    
    def __len__(self):
        # 如果需要修改 batch size 或其他因素,可以重写该方法
        return super().__len__()

至此修改完毕

题外话,最好先注意下封装好的数据集在读取batch的时候是不是只有图像这一个数据。本次处理虽然成功,但之后发现这里的batch中不知有image,还有notation等等,所以最后还是再batch中读取图像那一步修改才达成要求

posted @ 2024-11-25 22:57  stt211818  阅读(115)  评论(0)    收藏  举报