torch.util.data
python 迭代器和生成器
总结:Python3 迭代器与生成器
迭代器和生成器是 Python 中用于处理集合元素的核心工具,它们提供了高效、灵活的迭代方式。以下是关于迭代器和生成器的详细总结和补充。
1. 迭代器 (Iterator)
什么是迭代器?
- 迭代器是一个可以记住遍历位置的对象。
- 它允许逐个访问集合中的元素,直到所有元素被访问完为止。
- 迭代器只能向前遍历,不能后退。
关键方法
iter():创建一个迭代器对象。next():获取迭代器的下一个元素。- 如果没有更多元素,则抛出
StopIteration异常。
- 如果没有更多元素,则抛出
示例
# 创建一个列表并生成迭代器
lst = [1, 2, 3, 4]
it = iter(lst) # 创建迭代器
print(next(it)) # 输出: 1
print(next(it)) # 输出: 2
使用 for 循环遍历迭代器
for x in it:
print(x, end=" ")
输出:
3 4
手动实现迭代器
通过在类中实现 __iter__() 和 __next__() 方法,可以自定义迭代器。
class MyNumbers:
def __iter__(self):
self.a = 1
return self
def __next__(self):
if self.a <= 5:
x = self.a
self.a += 1
return x
else:
raise StopIteration
myclass = MyNumbers()
myiter = iter(myclass)
for x in myiter:
print(x)
输出:
1
2
3
4
5
2. 生成器 (Generator)
什么是生成器?
- 生成器是一种特殊的函数,使用
yield关键字逐步返回值。 - 生成器本质上是一个返回迭代器的函数,适合处理大量数据或需要按需生成值的场景。
特点
- 惰性计算:每次调用
next()时,生成器才执行到下一个yield语句。 - 节省内存:生成器不会一次性将所有结果加载到内存中,而是动态生成值。
示例
def countdown(n):
while n > 0:
yield n
n -= 1
gen = countdown(5)
print(next(gen)) # 输出: 5
print(next(gen)) # 输出: 4
使用 for 循环遍历生成器
for value in countdown(5):
print(value)
输出:
5
4
3
2
1
生成器函数的优势
- 按需生成值,避免一次性占用大量内存。
- 简化代码逻辑,特别适合处理复杂迭代过程。
3. 迭代器 vs 生成器
| 特性 | 迭代器 | 生成器 |
|---|---|---|
| 定义方式 | 实现 __iter__() 和 __next__() 方法 |
使用 yield 关键字定义函数 |
| 返回值 | 返回一个可迭代对象 | 返回一个生成器对象 |
| 性能 | 需要显式实现迭代逻辑 | 自动生成迭代逻辑,代码更简洁 |
| 内存效率 | 需要存储整个集合 | 按需生成值,节省内存 |
| 适用场景 | 自定义迭代逻辑 | 动态生成值,特别是大数据集 |
4. 生成器的经典应用
斐波那契数列
def fibonacci(n):
a, b, counter = 0, 1, 0
while counter < n:
yield a
a, b = b, a + b
counter += 1
for num in fibonacci(10):
print(num, end=" ")
输出:
0 1 1 2 3 5 8 13 21 34
读取大文件
生成器非常适合逐行读取大文件,而不需要一次性加载到内存中:
def read_large_file(file_path):
with open(file_path, 'r') as file:
for line in file:
yield line.strip()
for line in read_large_file('large_file.txt'):
print(line)
5. 补充知识
yield from
yield from是 Python 3.3 引入的新语法,用于简化生成器的嵌套。- 它允许直接从另一个生成器、迭代器或可迭代对象中生成值。
示例:
def generator1():
yield from range(3)
def generator2():
yield from generator1()
yield from range(3, 6)
for value in generator2():
print(value)
输出:
0
1
2
3
4
5
生成器表达式
- 类似于列表推导式,但使用圆括号
()而不是方括号[]。 - 生成器表达式是惰性求值的,适合处理大数据集。
示例:
gen = (x * x for x in range(5))
for value in gen:
print(value)
输出:
0
1
4
9
16
6. 总结
- 迭代器:适合自定义迭代逻辑,能够灵活控制遍历过程。
- 生成器:适合按需生成值的场景,代码简洁且节省内存。
- 选择依据:
- 如果需要自定义复杂的迭代逻辑,使用迭代器。
- 如果需要按需生成值(如大数据集),使用生成器。
通过迭代器和生成器,可以显著提升代码的可读性和性能,同时优化内存使用。
如何在 PyTorch 中使用 Dataset 和 DataLoader 来加载和处理数据。
1. 数据集与数据加载器
-
核心概念:
torch.utils.data.Dataset:存储样本及其标签的容器。torch.utils.data.DataLoader:为Dataset提供一个可迭代的接口,支持批量加载、数据打乱和多进程加速。
-
优点:
- 数据加载代码与模型训练代码解耦,提升代码的可读性和模块化。
- 支持预加载数据集(如 FashionMNIST)和自定义数据集。
2. 加载预加载数据集
以 FashionMNIST 为例,展示了如何从 TorchVision 加载数据集:
from torchvision import datasets
from torchvision.transforms import ToTensor
# 加载训练集和测试集
training_data = datasets.FashionMNIST(
root="data", # 数据存储路径
train=True, # 训练集
download=True, # 下载数据
transform=ToTensor() # 转换为张量
)
test_data = datasets.FashionMNIST(
root="data",
train=False, # 测试集
download=True,
transform=ToTensor()
)
3. 可视化数据
通过索引访问数据集,并使用 Matplotlib 可视化样本:
import matplotlib.pyplot as plt
labels_map = {
0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat",
5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot"
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
4. 创建自定义数据集
实现一个自定义数据集类,继承 torch.utils.data.Dataset 并实现以下方法:
__init__:初始化数据集路径、标注文件及转换操作。__len__:返回数据集的样本数量。__getitem__:根据索引加载并返回样本及其标签。
示例代码:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file) # 标注文件
self.img_dir = img_dir # 图像目录
self.transform = transform # 特征转换
self.target_transform = target_transform # 标签转换
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) # 获取图像路径
image = read_image(img_path) # 读取图像
label = self.img_labels.iloc[idx, 1] # 获取标签
if self.transform:
image = self.transform(image) # 应用特征转换
if self.target_transform:
label = self.target_transform(label) # 应用标签转换
return image, label
5. 使用 DataLoader 加载数据
DataLoader 提供了对数据集的高效访问,支持批量加载、数据打乱和多进程加速:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
6. 遍历 DataLoader
每次迭代返回一个批次的特征和标签:
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}") # [batch_size, channels, height, width]
print(f"Labels batch shape: {train_labels.size()}") # [batch_size]
# 可视化第一个样本
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
7. 总结
-
Dataset:- 存储样本及其标签。
- 支持预加载数据集和自定义数据集。
- 自定义数据集需实现
__init__、__len__和__getitem__方法。
-
DataLoader:- 提供高效的批量加载、数据打乱和多进程加速。
- 是训练深度学习模型时的核心工具。
通过合理使用 Dataset 和 DataLoader,可以显著简化数据加载流程,提高代码的可维护性和性能。
torch.utils.data
torch.utils.data 是 PyTorch 数据加载的核心模块,其中最重要的类是 DataLoader。它是一个可迭代对象,用于从数据集中高效地加载数据,并支持多种功能和选项。以下是总结:
核心功能
DataLoader 提供了以下主要功能:
-
支持两种数据集类型:
- Map-style 数据集:通过索引访问数据(实现
__getitem__和__len__方法)。 - Iterable-style 数据集:通过迭代器顺序生成数据(实现
__iter__方法)。
- Map-style 数据集:通过索引访问数据(实现
-
自定义数据加载顺序:
- 通过
shuffle参数或自定义Sampler控制数据加载顺序。
- 通过
-
自动批量处理:
- 使用
batch_size参数将样本组合成批次。 - 支持自定义批量逻辑(通过
collate_fn)。
- 使用
-
单进程和多进程数据加载:
- 单进程加载适合小型数据集或调试场景。
- 多进程加载(通过
num_workers参数)适合大型数据集和高性能需求。
-
自动内存锁定:
- 设置
pin_memory=True将数据放置在页锁定内存中,加速数据传输到 CUDA GPU。
- 设置
构造函数参数
DataLoader 的构造函数提供了丰富的参数来配置数据加载行为:
-
基本参数:
dataset:要加载的数据集。batch_size:每个批次包含的样本数量。shuffle:是否在每个 epoch 开始时打乱数据。sampler:自定义采样器以控制数据加载顺序。batch_sampler:自定义批量采样器,与batch_size、shuffle和drop_last互斥。
-
多进程相关参数:
num_workers:用于数据加载的子进程数量。worker_init_fn:初始化每个工作进程的函数。prefetch_factor:每个工作进程预取的批次数。persistent_workers:是否在数据集耗尽后保持工作进程存活。
-
其他功能参数:
collate_fn:用于将样本组合成批次的函数。pin_memory:是否启用内存锁定。drop_last:是否丢弃最后一个不完整的批次。timeout:设置从工作进程收集批次的超时时间。
总结
torch.utils.data.DataLoader 是一个功能强大且灵活的工具,能够满足各种数据加载需求。通过合理配置其参数,可以实现高效的批量处理、多进程加载和内存优化,从而显著提升模型训练的效率和性能。
在实际使用中,应根据数据集的类型、任务需求以及硬件资源选择合适的配置,例如是否启用多进程加载、是否使用内存锁定等。
PyTorch 的 DataLoader 是一个非常强大的工具,用于加载数据集并支持高效的批量处理、多线程加载等功能。根据数据集的类型和需求,DataLoader 提供了不同的机制来加载和处理数据。以下是内容的总结与解释:
1. 数据集类型
PyTorch 支持两种主要的数据集类型:
(1) Map-style 数据集
- 定义:实现
__getitem__()和__len__()方法。 - 特点:
- 数据集被看作是从索引(或键)到样本的映射。
- 可以通过索引访问特定样本,例如
dataset[idx]。 - 常见场景:从磁盘文件夹中读取图片及其标签。
- 适用场景:适合可以随机读取数据的情况。
(2) Iterable-style 数据集
- 定义:实现
__iter__()方法,是IterableDataset的子类。 - 特点:
- 数据集被看作是一个可迭代的对象。
- 不支持随机访问,而是按顺序生成数据流。
- 常见场景:从数据库、远程服务器或实时日志中读取数据。
- 注意:
- 在多进程加载时,每个工作进程会复制相同的
IterableDataset对象,因此需要配置以避免重复数据。
- 在多进程加载时,每个工作进程会复制相同的
2. 数据加载顺序与采样器
-
Map-style 数据集:
- 使用
Sampler类控制数据加载顺序。 - 默认情况下,
DataLoader根据shuffle参数自动构造顺序或随机打乱的采样器。 - 用户可以通过
sampler或batch_sampler参数自定义采样逻辑。
- 使用
-
Iterable-style 数据集:
- 数据加载顺序完全由用户定义的迭代器控制。
- 不支持
Sampler或batch_sampler,因为这类数据集没有索引或键的概念。
3. 批量加载与非批量加载
DataLoader 提供了灵活的方式来处理批量数据加载。
(1) 自动批量加载(默认行为)
- 参数:
batch_size:指定批量大小。drop_last:是否丢弃最后一个不完整的批次。collate_fn:用于将单个样本组合成批量的函数。
- 工作流程:
- 从数据集中获取一批样本。
- 使用
collate_fn将这些样本组合成一个批次。
- 默认行为:
- 如果每个样本是
(image, label)形式的元组,collate_fn会将它们组合成(batched_images, batched_labels)。 - 自动将 NumPy 数组转换为 PyTorch 张量,并添加批量维度。
- 如果每个样本是
(2) 禁用自动批量加载
- 场景:
- 用户希望手动处理批量逻辑。
- 数据集本身已经返回批量数据。
- 方法:
- 设置
batch_size=None且batch_sampler=None。 - 每次从数据集中获取单个样本,并使用
collate_fn进行简单处理(如将 NumPy 数组转换为张量)。
- 设置
4. 自定义 collate_fn
collate_fn 是一个灵活的工具,可以根据需求自定义批量处理逻辑。
(1) 自动批量加载启用时
- 输入:一批样本的列表。
- 输出:一个批次的数据。
- 示例:
- 如果样本是字典形式,
collate_fn会输出一个字典,其中每个值是批量化的张量。 - 如果样本是序列形式,可以自定义填充逻辑以处理不同长度的序列。
- 如果样本是字典形式,
(2) 自动批量加载禁用时
- 输入:单个样本。
- 输出:经过简单处理的样本(如 NumPy 转张量)。
5. 注意事项
- 多进程加载:
- 对于
IterableDataset,需要确保每个工作进程的数据流不重复。
- 对于
collate_fn调试:- 如果
DataLoader输出的数据维度或类型不符合预期,应检查collate_fn的实现。
- 如果
6. 总结
- Map-style 数据集适合随机访问数据的场景,而Iterable-style 数据集适合顺序读取数据的场景。
- 数据加载顺序可以通过
Sampler或用户定义的迭代器控制。 DataLoader提供了灵活的批量加载机制,默认情况下会自动将样本组合成批次。- 使用
collate_fn可以自定义批量处理逻辑,满足特殊需求。
通过合理选择数据集类型、采样器和批量加载策略,可以高效地加载和处理各种复杂的数据集。
PyTorch 的 DataLoader 提供了单进程和多进程两种数据加载模式,以满足不同场景下的需求。以下是内容的总结与解释:
1. 单进程数据加载(默认模式)
- 特点:
- 数据加载在初始化
DataLoader的同一进程中完成。 - 数据加载可能会阻塞计算代码,因为 Python 的全局解释器锁(GIL)限制了线程间的并行执行。
- 数据加载在初始化
- 适用场景:
- 数据集较小且可以完全加载到内存中。
- 系统资源(如共享内存、文件描述符)有限。
- 调试时更易读的错误追踪信息。
- 优点:
- 实现简单,调试方便。
- 缺点:
- 数据加载可能成为性能瓶颈。
2. 多进程数据加载
通过设置 num_workers 参数为正整数,可以启用多进程数据加载。
(1) 工作原理
- 每次创建
DataLoader的迭代器时(例如调用enumerate(dataloader)),会启动num_workers个子进程。 - 数据集对象、
collate_fn和worker_init_fn会被传递给每个子进程,并在子进程中初始化和加载数据。 - 子进程中的数据加载操作(包括 IO 和转换)是独立于主进程的。
(2) 内存使用注意事项
- 问题:
- 子进程会复制父进程中所有被访问的 Python 对象,可能导致内存占用显著增加(总内存 = 工作进程数 × 父进程内存)。
- 如果数据集包含大量数据(如长列表的文件名),或使用了过多的工作进程,可能会导致内存不足。
- 解决方法:
- 使用非引用计数的对象表示形式(如 Pandas、NumPy 或 PyArrow 对象)来减少内存开销。
(3) 平台相关行为
- Unix 系统:
- 默认使用
fork()创建子进程,子进程可以直接访问父进程的地址空间。
- 默认使用
- Windows/MacOS 系统:
- 默认使用
spawn()创建子进程,子进程会重新运行主脚本并通过序列化(pickle)传递参数。 - 兼容性建议:
- 将主脚本的主要逻辑放在
if __name__ == '__main__':块中,避免子进程重复执行代码。 - 确保自定义的
collate_fn、worker_init_fn和数据集代码是顶层定义,以便在子进程中可用。
- 将主脚本的主要逻辑放在
- 默认使用
3. 数据集类型的多进程处理
- Map-style 数据集:
- 主进程生成索引并通过采样器发送给工作进程。
- 随机打乱等操作由主进程控制。
- Iterable-style 数据集:
- 每个工作进程都会获得数据集对象的一个副本,可能导致数据重复。
- 解决方法:
- 使用
torch.utils.data.get_worker_info()或worker_init_fn配置每个副本,确保数据分片独立。 - 在多进程加载时,
drop_last参数会丢弃每个工作进程最后一个不完整的批次。
- 使用
4. 随机性管理
- 默认行为:
- 每个工作进程的 PyTorch 种子设置为
base_seed + worker_id,其中base_seed是由主进程生成的随机数。 - 其他库的种子可能在初始化工作进程时重复,导致每个进程生成相同的随机数。
- 每个工作进程的 PyTorch 种子设置为
- 解决方法:
- 在
worker_init_fn中,可以使用torch.utils.data.get_worker_info().seed或torch.initial_seed()来为其他库设置种子。
- 在
5. CUDA 张量与多进程加载
- 警告:
- 不建议在多进程加载中返回 CUDA 张量,因为 CUDA 和多进程之间存在许多复杂性。
- 推荐做法:
- 启用自动内存锁定(
pin_memory=True),以加速数据传输到 CUDA GPU。
- 启用自动内存锁定(
6. 总结
- 单进程加载:
- 简单易用,适合小型数据集或调试场景。
- 可能因数据加载阻塞计算而影响性能。
- 多进程加载:
- 提高性能,适合大型数据集和复杂任务。
- 需要注意内存占用、平台差异和随机性管理。
- 最佳实践:
- 根据数据集大小和任务需求选择合适的加载模式。
- 在多进程加载中,合理配置工作进程数和数据集分片,避免内存和性能问题。
- 使用
pin_memory=True加速 GPU 数据传输。
通过合理配置 DataLoader 的参数,可以实现高效、灵活的数据加载流程,从而提升模型训练的效率和稳定性。
这段内容主要讨论了 PyTorch 中 DataLoader 的内存锁定(Memory Pinning)功能,以及如何通过自定义数据类型支持该功能。以下是总结和解释:
内存锁定(Memory Pinning)
- 概念:内存锁定指的是将数据放置在页锁定(pinned)内存中,这可以显著加快从主机到 GPU 的数据传输速度。
- 使用方法:
- 将
pin_memory=True参数传递给DataLoader,这样从数据集中获取的数据张量将自动置于页锁定内存中,从而实现更快的 CUDA 设备数据传输。
- 将
默认内存锁定逻辑
- 默认情况下,内存锁定逻辑仅识别张量(Tensors)及包含张量的映射和可迭代对象。
- 如果你的批处理是一个自定义类型(例如,你有一个返回自定义批处理类型的
collate_fn),或批处理中的每个元素都是自定义类型,内存锁定逻辑将不会识别它们,并会直接返回这些批处理或元素而不锁定内存。
为自定义类型启用内存锁定
- 对于自定义数据类型,可以通过在其上定义一个
pin_memory()方法来启用内存锁定功能。 - 示例代码展示了如何为一个自定义批处理类型
SimpleCustomBatch实现内存锁定逻辑。
class SimpleCustomBatch:
def __init__(self, data):
transposed_data = list(zip(*data))
self.inp = torch.stack(transposed_data[0], 0)
self.tgt = torch.stack(transposed_data[1], 0)
# 自定义内存锁定方法
def pin_memory(self):
self.inp = self.inp.pin_memory()
self.tgt = self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
# 创建数据集并使用 DataLoader 加载数据
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper, pin_memory=True)
for batch_ndx, sample in enumerate(loader):
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
其他注意事项
- 多进程加载警告:
- 当使用
spawn启动方法时,worker_init_fn不能是不可序列化的对象,如 lambda 函数。
- 当使用
- 随机性管理:
- 每个工作进程都有自己的 PyTorch 种子,基于
base_seed + worker_id设置。但是,对于其他库,种子可能重复,导致生成相同的随机数。可以在worker_init_fn中手动设置种子以避免这种情况。
- 每个工作进程都有自己的 PyTorch 种子,基于
in_order参数:- 如果设置为
False,可能会损害结果的可重复性,并可能导致不平衡数据训练器接收到的数据分布偏斜。
- 如果设置为
总结
- 内存锁定能够加速数据从主机内存到 GPU 的传输过程,特别适合需要高性能数据传输的应用场景。
- 自定义类型的支持:当使用自定义批处理类型时,需自行实现
pin_memory()方法以利用内存锁定的优势。 - 多进程加载与随机性管理:正确配置工作进程的初始化函数和随机种子,确保数据加载过程的稳定性和高效性。
- 最佳实践:根据具体应用场景选择是否启用内存锁定,特别是在涉及大量数据传输至 GPU 的任务中,合理配置
DataLoader参数可以大幅提升性能。

浙公网安备 33010602011771号