Python Class 类中的属性之 "__iter__",背景:小样本学习采样器

可以参考GeeksforGeeks:https://www.geeksforgeeks.org/(这是一个很好用的学习计算机编程语言的网站,有C、C++、Java、Python等各种常用语言)
关于__iter__,请前往: Python中的迭代器

1. "__iter__" 方法是什么?

__iter__是Python类中的一个特殊的方法,也叫魔术方法,它使得这个类可迭代(iterable),当你在该类的对象上调用for循环或者其他迭代方法的时候,就会调用该__iter__方法。这个属性对于深度学习中的采样器sampler有着重要的作用,因为该方法定义了如何从数据集中去采集数据。这里距离用的PyTorch中的sampler采样器,具体来说有一下几点:

  1. 迭代器协议:__iter__方法返回了一个迭代器对象,该对象实现了__next__方法。
  2. 生成器函数: __iter__使用了yield关键字(关于关键字yield,可以参考理解python中的yield关键字:https://zhuanlan.zhihu.com/p/37257918),这使得它称为一个生成器函数,自动实现迭代器协议。
  3. 数据加载过程:PyTorch的 DataLoader会调用采样器的__ier__方法来确定每哥批次应该包含哪些数据样本。

2. 举个例子

下面的__iter__是一个从小样本学习中的一个分布式的采样器中class DistributedCategoriesSampler(Sampler)截取的__iter__方法

def __iter__(self):
        """Random sample a FSL task batch(multi-task).

        Yields:
            torch.Tensor: The stacked tensor of a FSL task batch(multi-task).
        """
        batch = []
        for i_batch in range(self.episode_num):
            classes = torch.randperm(len(self.idx_list))[: self.way_num]
            for c in classes:
                idxes = self.idx_list[c.item()]
                pos = torch.randperm(idxes.size(0))[: self.image_num]
                batch.append(idxes[pos])
            if len(batch) == self.episode_size * self.way_num:
                batch = torch.stack(batch).reshape(-1)
                yield batch
                batch = []

解释

1. 初始化批次容器

batch = []

创建一个空的列表,用于存储当前批次中的样本的索引

2. 生成多个episode(任务)

for i_batch in range(self.episode_num):

循环生成预定数量的episode(由 episode_num 参数指定)

3. 随机选择类别

classes = torch.randperm(len(self.idx_list))[: self.way_num]

使用 torch.randperm 函数生成一个随机排列,然后选择前 way_num 个类别。这实现了N-way分类任务中的"N-way",即每个任务包含N个类别

4. 为每个类别选择样本

for c in classes:
    idxes = self.idx_list[c.item()]
    pos = torch.randperm(idxes.size(0))[: self.image_num]
    batch.append(idxes[pos])

对于每个选定的类别:

  • 获取该类别的所有样本索引( idxes )
  • 随机打乱这些索引并选择前 image_num 个( pos )
  • 将选定的样本索引添加到批次中
    这实现了K-shot学习中的"K-shot"部分,其中 image_num = shot_num + query_num ,即每个类别包含K个支持样本和若干查询样本。

5. 批次完成检查与输出

if len(batch) == self.episode_size * self.way_num:
    batch = torch.stack(batch).reshape(-1)
    yield batch
    batch = []

当批次大小达到预期值( episode_size * way_num )时:

  • 将批次中的所有张量堆叠成一个大张量
  • 重塑为一维张量
  • 使用 yield 关键字返回这个批次,同时保持函数状态
  • 清空批次列表,准备下一个批次

3. 核心技术要点

  1. 任务构建 :代码实现了FSL中的episode-based训练范式,每个episode是一个N-way K-shot任务
  2. 随机采样 :使用 torch.randperm 确保类别和样本的随机性,这对于模型泛化能力至关重要
  3. 批处理 :支持一次生成多个episode(由 episode_size 控制),提高训练效率
  4. 生成器模式 :使用 yield 实现惰性计算,只在需要时生成数据,节省内存
  5. 灵活配置 :通过参数化的方式支持不同的FSL设置(N-way, K-shot)

这个采样器是小样本学习框架中的关键组件,它确保了训练和测试过程中任务的随机性和多样性,同时保持了小样本学习的基本范式,而采样的过程也就是迭代的关键部分,正是__iter__方法在起作用。

posted @ 2025-06-27 15:05  3klxi  阅读(40)  评论(0)    收藏  举报