Python Class 类中的属性之 "__iter__",背景:小样本学习采样器
可以参考GeeksforGeeks:https://www.geeksforgeeks.org/(这是一个很好用的学习计算机编程语言的网站,有C、C++、Java、Python等各种常用语言)
关于__iter__,请前往: Python中的迭代器
1. "__iter__" 方法是什么?
__iter__是Python类中的一个特殊的方法,也叫魔术方法,它使得这个类可迭代(iterable),当你在该类的对象上调用for循环或者其他迭代方法的时候,就会调用该__iter__方法。这个属性对于深度学习中的采样器sampler有着重要的作用,因为该方法定义了如何从数据集中去采集数据。这里距离用的PyTorch中的sampler采样器,具体来说有一下几点:
- 迭代器协议:__iter__方法返回了一个迭代器对象,该对象实现了__next__方法。
- 生成器函数: __iter__使用了
yield关键字(关于关键字yield,可以参考理解python中的yield关键字:https://zhuanlan.zhihu.com/p/37257918),这使得它称为一个生成器函数,自动实现迭代器协议。 - 数据加载过程:PyTorch的
DataLoader会调用采样器的__ier__方法来确定每哥批次应该包含哪些数据样本。
2. 举个例子
下面的__iter__是一个从小样本学习中的一个分布式的采样器中
class DistributedCategoriesSampler(Sampler)截取的__iter__方法
def __iter__(self):
"""Random sample a FSL task batch(multi-task).
Yields:
torch.Tensor: The stacked tensor of a FSL task batch(multi-task).
"""
batch = []
for i_batch in range(self.episode_num):
classes = torch.randperm(len(self.idx_list))[: self.way_num]
for c in classes:
idxes = self.idx_list[c.item()]
pos = torch.randperm(idxes.size(0))[: self.image_num]
batch.append(idxes[pos])
if len(batch) == self.episode_size * self.way_num:
batch = torch.stack(batch).reshape(-1)
yield batch
batch = []
解释
1. 初始化批次容器
batch = []
创建一个空的列表,用于存储当前批次中的样本的索引
2. 生成多个episode(任务)
for i_batch in range(self.episode_num):
循环生成预定数量的episode(由 episode_num 参数指定)
3. 随机选择类别
classes = torch.randperm(len(self.idx_list))[: self.way_num]
使用 torch.randperm 函数生成一个随机排列,然后选择前 way_num 个类别。这实现了N-way分类任务中的"N-way",即每个任务包含N个类别
4. 为每个类别选择样本
for c in classes:
idxes = self.idx_list[c.item()]
pos = torch.randperm(idxes.size(0))[: self.image_num]
batch.append(idxes[pos])
对于每个选定的类别:
- 获取该类别的所有样本索引( idxes )
- 随机打乱这些索引并选择前 image_num 个( pos )
- 将选定的样本索引添加到批次中
这实现了K-shot学习中的"K-shot"部分,其中 image_num = shot_num + query_num ,即每个类别包含K个支持样本和若干查询样本。
5. 批次完成检查与输出
if len(batch) == self.episode_size * self.way_num:
batch = torch.stack(batch).reshape(-1)
yield batch
batch = []
当批次大小达到预期值( episode_size * way_num )时:
- 将批次中的所有张量堆叠成一个大张量
- 重塑为一维张量
- 使用 yield 关键字返回这个批次,同时保持函数状态
- 清空批次列表,准备下一个批次
3. 核心技术要点
- 任务构建 :代码实现了FSL中的episode-based训练范式,每个episode是一个N-way K-shot任务
- 随机采样 :使用 torch.randperm 确保类别和样本的随机性,这对于模型泛化能力至关重要
- 批处理 :支持一次生成多个episode(由 episode_size 控制),提高训练效率
- 生成器模式 :使用 yield 实现惰性计算,只在需要时生成数据,节省内存
- 灵活配置 :通过参数化的方式支持不同的FSL设置(N-way, K-shot)
这个采样器是小样本学习框架中的关键组件,它确保了训练和测试过程中任务的随机性和多样性,同时保持了小样本学习的基本范式,而采样的过程也就是迭代的关键部分,正是__iter__方法在起作用。

浙公网安备 33010602011771号