3.3.2 读取数据集

在 PyTorch 中,DataLoader 本身是一个可迭代对象(Iterable),而不是一个迭代器(Iterator)。两者的关键区别在于:


1. 可迭代对象 vs. 迭代器的区别

  • 可迭代对象(Iterable)
    任何实现了 __iter__() 方法的对象,例如列表、元组、DataLoader
    特点:每次调用 iter(iterable) 会返回一个新的独立迭代器,从头开始遍历数据。

  • 迭代器(Iterator)
    实现了 __iter__()__next__() 方法的对象,例如通过 iter(iterable) 生成的对象。
    特点:保存遍历的进度状态,每次调用 next(iterator) 会返回下一个元素。


2. 为什么 DataLoader 需要 iter()

  • DataLoader 是可迭代对象
    它实现了 __iter__() 方法,但本身不直接保存遍历的进度状态。每次调用 iter(data_iter) 会生成一个新的迭代器,用于遍历数据集。

  • 直接调用 next(data_iter) 会报错
    因为 DataLoader 没有实现 __next__() 方法(它不是迭代器),必须显式调用 iter() 获取迭代器后才能使用 next()


3. 代码示例分析

你的代码:

data_iter = load_array((features, labels), batch_size)  # 返回 DataLoader 对象
next(iter(data_iter))  # 通过 iter(data_iter) 生成迭代器,再调用 next()
  • iter(data_iter)
    生成一个新的迭代器对象,从头开始遍历数据。

  • next(iter(data_iter))
    每次执行时,都会生成一个新的迭代器,并返回第一个批次的数据。
    这意味着多次调用 next(iter(data_iter))始终返回第一个批次,而不是遍历后续批次。


4. 正确用法

如果需要遍历整个数据集,应该先获取迭代器并保存状态

data_iter = load_array((features, labels), batch_size)
iterator = iter(data_iter)  # 生成一个迭代器,保存遍历状态

# 遍历批次
batch1 = next(iterator)  # 第一个批次
batch2 = next(iterator)  # 第二个批次

5. 常见误区

误区:多次调用 iter(data_iter)

# 错误!每次调用 iter(data_iter) 都会生成新迭代器,导致重复获取第一个批次
batch1 = next(iter(data_iter))
batch2 = next(iter(data_iter))  # 依然是第一个批次!

正确做法:保存迭代器

# 正确!通过保存迭代器状态,遍历后续批次
iterator = iter(data_iter)
batch1 = next(iterator)
batch2 = next(iterator)

6. 总结

  • DataLoader 是可迭代对象:必须通过 iter() 生成迭代器后才能调用 next()
  • 迭代器保存遍历状态:多次调用 next(iterator) 会按顺序返回批次。
  • 避免重复生成迭代器:多次调用 iter(data_iter) 会重置遍历进度。
posted @ 2025-03-08 19:46  最爱丁珰  阅读(31)  评论(0)    收藏  举报