3.3.2 读取数据集
在 PyTorch 中,DataLoader
本身是一个可迭代对象(Iterable),而不是一个迭代器(Iterator)。两者的关键区别在于:
1. 可迭代对象 vs. 迭代器的区别
-
可迭代对象(Iterable)
任何实现了__iter__()
方法的对象,例如列表、元组、DataLoader
。
特点:每次调用iter(iterable)
会返回一个新的独立迭代器,从头开始遍历数据。 -
迭代器(Iterator)
实现了__iter__()
和__next__()
方法的对象,例如通过iter(iterable)
生成的对象。
特点:保存遍历的进度状态,每次调用next(iterator)
会返回下一个元素。
2. 为什么 DataLoader
需要 iter()
?
-
DataLoader
是可迭代对象:
它实现了__iter__()
方法,但本身不直接保存遍历的进度状态。每次调用iter(data_iter)
会生成一个新的迭代器,用于遍历数据集。 -
直接调用
next(data_iter)
会报错:
因为DataLoader
没有实现__next__()
方法(它不是迭代器),必须显式调用iter()
获取迭代器后才能使用next()
。
3. 代码示例分析
你的代码:
data_iter = load_array((features, labels), batch_size) # 返回 DataLoader 对象
next(iter(data_iter)) # 通过 iter(data_iter) 生成迭代器,再调用 next()
-
iter(data_iter)
:
生成一个新的迭代器对象,从头开始遍历数据。 -
next(iter(data_iter))
:
每次执行时,都会生成一个新的迭代器,并返回第一个批次的数据。
这意味着多次调用next(iter(data_iter))
会始终返回第一个批次,而不是遍历后续批次。
4. 正确用法
如果需要遍历整个数据集,应该先获取迭代器并保存状态:
data_iter = load_array((features, labels), batch_size)
iterator = iter(data_iter) # 生成一个迭代器,保存遍历状态
# 遍历批次
batch1 = next(iterator) # 第一个批次
batch2 = next(iterator) # 第二个批次
5. 常见误区
误区:多次调用 iter(data_iter)
# 错误!每次调用 iter(data_iter) 都会生成新迭代器,导致重复获取第一个批次
batch1 = next(iter(data_iter))
batch2 = next(iter(data_iter)) # 依然是第一个批次!
正确做法:保存迭代器
# 正确!通过保存迭代器状态,遍历后续批次
iterator = iter(data_iter)
batch1 = next(iterator)
batch2 = next(iterator)
6. 总结
DataLoader
是可迭代对象:必须通过iter()
生成迭代器后才能调用next()
。- 迭代器保存遍历状态:多次调用
next(iterator)
会按顺序返回批次。 - 避免重复生成迭代器:多次调用
iter(data_iter)
会重置遍历进度。