torch.util.data

python 迭代器和生成器

总结:Python3 迭代器与生成器

迭代器和生成器是 Python 中用于处理集合元素的核心工具,它们提供了高效、灵活的迭代方式。以下是关于迭代器和生成器的详细总结和补充。


1. 迭代器 (Iterator)

什么是迭代器?

  • 迭代器是一个可以记住遍历位置的对象。
  • 它允许逐个访问集合中的元素,直到所有元素被访问完为止。
  • 迭代器只能向前遍历,不能后退。

关键方法

  • iter():创建一个迭代器对象。
  • next():获取迭代器的下一个元素。
    • 如果没有更多元素,则抛出 StopIteration 异常。

示例

# 创建一个列表并生成迭代器
lst = [1, 2, 3, 4]
it = iter(lst)  # 创建迭代器
print(next(it))  # 输出: 1
print(next(it))  # 输出: 2

使用 for 循环遍历迭代器

for x in it:
    print(x, end=" ")

输出:

3 4

手动实现迭代器

通过在类中实现 __iter__()__next__() 方法,可以自定义迭代器。

class MyNumbers:
    def __iter__(self):
        self.a = 1
        return self

    def __next__(self):
        if self.a <= 5:
            x = self.a
            self.a += 1
            return x
        else:
            raise StopIteration

myclass = MyNumbers()
myiter = iter(myclass)

for x in myiter:
    print(x)

输出:

1
2
3
4
5

2. 生成器 (Generator)

什么是生成器?

  • 生成器是一种特殊的函数,使用 yield 关键字逐步返回值。
  • 生成器本质上是一个返回迭代器的函数,适合处理大量数据或需要按需生成值的场景。

特点

  • 惰性计算:每次调用 next() 时,生成器才执行到下一个 yield 语句。
  • 节省内存:生成器不会一次性将所有结果加载到内存中,而是动态生成值。

示例

def countdown(n):
    while n > 0:
        yield n
        n -= 1

gen = countdown(5)
print(next(gen))  # 输出: 5
print(next(gen))  # 输出: 4

使用 for 循环遍历生成器

for value in countdown(5):
    print(value)

输出:

5
4
3
2
1

生成器函数的优势

  • 按需生成值,避免一次性占用大量内存。
  • 简化代码逻辑,特别适合处理复杂迭代过程。

3. 迭代器 vs 生成器

特性 迭代器 生成器
定义方式 实现 __iter__()__next__() 方法 使用 yield 关键字定义函数
返回值 返回一个可迭代对象 返回一个生成器对象
性能 需要显式实现迭代逻辑 自动生成迭代逻辑,代码更简洁
内存效率 需要存储整个集合 按需生成值,节省内存
适用场景 自定义迭代逻辑 动态生成值,特别是大数据集

4. 生成器的经典应用

斐波那契数列

def fibonacci(n):
    a, b, counter = 0, 1, 0
    while counter < n:
        yield a
        a, b = b, a + b
        counter += 1

for num in fibonacci(10):
    print(num, end=" ")

输出:

0 1 1 2 3 5 8 13 21 34

读取大文件

生成器非常适合逐行读取大文件,而不需要一次性加载到内存中:

def read_large_file(file_path):
    with open(file_path, 'r') as file:
        for line in file:
            yield line.strip()

for line in read_large_file('large_file.txt'):
    print(line)

5. 补充知识

yield from

  • yield from 是 Python 3.3 引入的新语法,用于简化生成器的嵌套。
  • 它允许直接从另一个生成器、迭代器或可迭代对象中生成值。

示例:

def generator1():
    yield from range(3)

def generator2():
    yield from generator1()
    yield from range(3, 6)

for value in generator2():
    print(value)

输出:

0
1
2
3
4
5

生成器表达式

  • 类似于列表推导式,但使用圆括号 () 而不是方括号 []
  • 生成器表达式是惰性求值的,适合处理大数据集。

示例:

gen = (x * x for x in range(5))
for value in gen:
    print(value)

输出:

0
1
4
9
16

6. 总结

  • 迭代器:适合自定义迭代逻辑,能够灵活控制遍历过程。
  • 生成器:适合按需生成值的场景,代码简洁且节省内存。
  • 选择依据
    • 如果需要自定义复杂的迭代逻辑,使用迭代器。
    • 如果需要按需生成值(如大数据集),使用生成器。

通过迭代器和生成器,可以显著提升代码的可读性和性能,同时优化内存使用。

如何在 PyTorch 中使用 DatasetDataLoader 来加载和处理数据。


1. 数据集与数据加载器

  • 核心概念

    • torch.utils.data.Dataset:存储样本及其标签的容器。
    • torch.utils.data.DataLoader:为 Dataset 提供一个可迭代的接口,支持批量加载、数据打乱和多进程加速。
  • 优点

    • 数据加载代码与模型训练代码解耦,提升代码的可读性和模块化。
    • 支持预加载数据集(如 FashionMNIST)和自定义数据集。

2. 加载预加载数据集

以 FashionMNIST 为例,展示了如何从 TorchVision 加载数据集:

from torchvision import datasets
from torchvision.transforms import ToTensor

# 加载训练集和测试集
training_data = datasets.FashionMNIST(
    root="data",           # 数据存储路径
    train=True,            # 训练集
    download=True,         # 下载数据
    transform=ToTensor()   # 转换为张量
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,           # 测试集
    download=True,
    transform=ToTensor()
)

3. 可视化数据

通过索引访问数据集,并使用 Matplotlib 可视化样本:

import matplotlib.pyplot as plt

labels_map = {
    0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat",
    5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot"
}

figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

4. 创建自定义数据集

实现一个自定义数据集类,继承 torch.utils.data.Dataset 并实现以下方法:

  • __init__:初始化数据集路径、标注文件及转换操作。
  • __len__:返回数据集的样本数量。
  • __getitem__:根据索引加载并返回样本及其标签。

示例代码:

import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)  # 标注文件
        self.img_dir = img_dir                          # 图像目录
        self.transform = transform                      # 特征转换
        self.target_transform = target_transform        # 标签转换

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])  # 获取图像路径
        image = read_image(img_path)                                         # 读取图像
        label = self.img_labels.iloc[idx, 1]                                # 获取标签
        if self.transform:
            image = self.transform(image)                                   # 应用特征转换
        if self.target_transform:
            label = self.target_transform(label)                            # 应用标签转换
        return image, label

5. 使用 DataLoader 加载数据

DataLoader 提供了对数据集的高效访问,支持批量加载、数据打乱和多进程加速:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

6. 遍历 DataLoader

每次迭代返回一个批次的特征和标签:

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")  # [batch_size, channels, height, width]
print(f"Labels batch shape: {train_labels.size()}")     # [batch_size]

# 可视化第一个样本
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

7. 总结

  • Dataset

    • 存储样本及其标签。
    • 支持预加载数据集和自定义数据集。
    • 自定义数据集需实现 __init____len____getitem__ 方法。
  • DataLoader

    • 提供高效的批量加载、数据打乱和多进程加速。
    • 是训练深度学习模型时的核心工具。

通过合理使用 DatasetDataLoader,可以显著简化数据加载流程,提高代码的可维护性和性能。

torch.utils.data

torch.utils.data 是 PyTorch 数据加载的核心模块,其中最重要的类是 DataLoader。它是一个可迭代对象,用于从数据集中高效地加载数据,并支持多种功能和选项。以下是总结:


核心功能

DataLoader 提供了以下主要功能:

  1. 支持两种数据集类型

    • Map-style 数据集:通过索引访问数据(实现 __getitem____len__ 方法)。
    • Iterable-style 数据集:通过迭代器顺序生成数据(实现 __iter__ 方法)。
  2. 自定义数据加载顺序

    • 通过 shuffle 参数或自定义 Sampler 控制数据加载顺序。
  3. 自动批量处理

    • 使用 batch_size 参数将样本组合成批次。
    • 支持自定义批量逻辑(通过 collate_fn)。
  4. 单进程和多进程数据加载

    • 单进程加载适合小型数据集或调试场景。
    • 多进程加载(通过 num_workers 参数)适合大型数据集和高性能需求。
  5. 自动内存锁定

    • 设置 pin_memory=True 将数据放置在页锁定内存中,加速数据传输到 CUDA GPU。

构造函数参数

DataLoader 的构造函数提供了丰富的参数来配置数据加载行为:

  • 基本参数

    • dataset:要加载的数据集。
    • batch_size:每个批次包含的样本数量。
    • shuffle:是否在每个 epoch 开始时打乱数据。
    • sampler:自定义采样器以控制数据加载顺序。
    • batch_sampler:自定义批量采样器,与 batch_sizeshuffledrop_last 互斥。
  • 多进程相关参数

    • num_workers:用于数据加载的子进程数量。
    • worker_init_fn:初始化每个工作进程的函数。
    • prefetch_factor:每个工作进程预取的批次数。
    • persistent_workers:是否在数据集耗尽后保持工作进程存活。
  • 其他功能参数

    • collate_fn:用于将样本组合成批次的函数。
    • pin_memory:是否启用内存锁定。
    • drop_last:是否丢弃最后一个不完整的批次。
    • timeout:设置从工作进程收集批次的超时时间。

总结

torch.utils.data.DataLoader 是一个功能强大且灵活的工具,能够满足各种数据加载需求。通过合理配置其参数,可以实现高效的批量处理、多进程加载和内存优化,从而显著提升模型训练的效率和性能。

在实际使用中,应根据数据集的类型、任务需求以及硬件资源选择合适的配置,例如是否启用多进程加载、是否使用内存锁定等。
PyTorch 的 DataLoader 是一个非常强大的工具,用于加载数据集并支持高效的批量处理、多线程加载等功能。根据数据集的类型和需求,DataLoader 提供了不同的机制来加载和处理数据。以下是内容的总结与解释:


1. 数据集类型

PyTorch 支持两种主要的数据集类型:

(1) Map-style 数据集

  • 定义:实现 __getitem__()__len__() 方法。
  • 特点
    • 数据集被看作是从索引(或键)到样本的映射。
    • 可以通过索引访问特定样本,例如 dataset[idx]
    • 常见场景:从磁盘文件夹中读取图片及其标签。
  • 适用场景:适合可以随机读取数据的情况。

(2) Iterable-style 数据集

  • 定义:实现 __iter__() 方法,是 IterableDataset 的子类。
  • 特点
    • 数据集被看作是一个可迭代的对象。
    • 不支持随机访问,而是按顺序生成数据流。
    • 常见场景:从数据库、远程服务器或实时日志中读取数据。
  • 注意
    • 在多进程加载时,每个工作进程会复制相同的 IterableDataset 对象,因此需要配置以避免重复数据。

2. 数据加载顺序与采样器

  • Map-style 数据集

    • 使用 Sampler 类控制数据加载顺序。
    • 默认情况下,DataLoader 根据 shuffle 参数自动构造顺序或随机打乱的采样器。
    • 用户可以通过 samplerbatch_sampler 参数自定义采样逻辑。
  • Iterable-style 数据集

    • 数据加载顺序完全由用户定义的迭代器控制。
    • 不支持 Samplerbatch_sampler,因为这类数据集没有索引或键的概念。

3. 批量加载与非批量加载

DataLoader 提供了灵活的方式来处理批量数据加载。

(1) 自动批量加载(默认行为)

  • 参数
    • batch_size:指定批量大小。
    • drop_last:是否丢弃最后一个不完整的批次。
    • collate_fn:用于将单个样本组合成批量的函数。
  • 工作流程
    • 从数据集中获取一批样本。
    • 使用 collate_fn 将这些样本组合成一个批次。
  • 默认行为
    • 如果每个样本是 (image, label) 形式的元组,collate_fn 会将它们组合成 (batched_images, batched_labels)
    • 自动将 NumPy 数组转换为 PyTorch 张量,并添加批量维度。

(2) 禁用自动批量加载

  • 场景
    • 用户希望手动处理批量逻辑。
    • 数据集本身已经返回批量数据。
  • 方法
    • 设置 batch_size=Nonebatch_sampler=None
    • 每次从数据集中获取单个样本,并使用 collate_fn 进行简单处理(如将 NumPy 数组转换为张量)。

4. 自定义 collate_fn

collate_fn 是一个灵活的工具,可以根据需求自定义批量处理逻辑。

(1) 自动批量加载启用时

  • 输入:一批样本的列表。
  • 输出:一个批次的数据。
  • 示例:
    • 如果样本是字典形式,collate_fn 会输出一个字典,其中每个值是批量化的张量。
    • 如果样本是序列形式,可以自定义填充逻辑以处理不同长度的序列。

(2) 自动批量加载禁用时

  • 输入:单个样本。
  • 输出:经过简单处理的样本(如 NumPy 转张量)。

5. 注意事项

  • 多进程加载
    • 对于 IterableDataset,需要确保每个工作进程的数据流不重复。
  • collate_fn 调试
    • 如果 DataLoader 输出的数据维度或类型不符合预期,应检查 collate_fn 的实现。

6. 总结

  • Map-style 数据集适合随机访问数据的场景,而Iterable-style 数据集适合顺序读取数据的场景。
  • 数据加载顺序可以通过 Sampler 或用户定义的迭代器控制。
  • DataLoader 提供了灵活的批量加载机制,默认情况下会自动将样本组合成批次。
  • 使用 collate_fn 可以自定义批量处理逻辑,满足特殊需求。

通过合理选择数据集类型、采样器和批量加载策略,可以高效地加载和处理各种复杂的数据集。

PyTorch 的 DataLoader 提供了单进程和多进程两种数据加载模式,以满足不同场景下的需求。以下是内容的总结与解释:


1. 单进程数据加载(默认模式)

  • 特点
    • 数据加载在初始化 DataLoader 的同一进程中完成。
    • 数据加载可能会阻塞计算代码,因为 Python 的全局解释器锁(GIL)限制了线程间的并行执行。
  • 适用场景
    • 数据集较小且可以完全加载到内存中。
    • 系统资源(如共享内存、文件描述符)有限。
    • 调试时更易读的错误追踪信息。
  • 优点
    • 实现简单,调试方便。
  • 缺点
    • 数据加载可能成为性能瓶颈。

2. 多进程数据加载

通过设置 num_workers 参数为正整数,可以启用多进程数据加载。

(1) 工作原理

  • 每次创建 DataLoader 的迭代器时(例如调用 enumerate(dataloader)),会启动 num_workers 个子进程。
  • 数据集对象、collate_fnworker_init_fn 会被传递给每个子进程,并在子进程中初始化和加载数据。
  • 子进程中的数据加载操作(包括 IO 和转换)是独立于主进程的。

(2) 内存使用注意事项

  • 问题
    • 子进程会复制父进程中所有被访问的 Python 对象,可能导致内存占用显著增加(总内存 = 工作进程数 × 父进程内存)。
    • 如果数据集包含大量数据(如长列表的文件名),或使用了过多的工作进程,可能会导致内存不足。
  • 解决方法
    • 使用非引用计数的对象表示形式(如 Pandas、NumPy 或 PyArrow 对象)来减少内存开销。

(3) 平台相关行为

  • Unix 系统
    • 默认使用 fork() 创建子进程,子进程可以直接访问父进程的地址空间。
  • Windows/MacOS 系统
    • 默认使用 spawn() 创建子进程,子进程会重新运行主脚本并通过序列化(pickle)传递参数。
    • 兼容性建议
      • 将主脚本的主要逻辑放在 if __name__ == '__main__': 块中,避免子进程重复执行代码。
      • 确保自定义的 collate_fnworker_init_fn 和数据集代码是顶层定义,以便在子进程中可用。

3. 数据集类型的多进程处理

  • Map-style 数据集
    • 主进程生成索引并通过采样器发送给工作进程。
    • 随机打乱等操作由主进程控制。
  • Iterable-style 数据集
    • 每个工作进程都会获得数据集对象的一个副本,可能导致数据重复。
    • 解决方法:
      • 使用 torch.utils.data.get_worker_info()worker_init_fn 配置每个副本,确保数据分片独立。
      • 在多进程加载时,drop_last 参数会丢弃每个工作进程最后一个不完整的批次。

4. 随机性管理

  • 默认行为
    • 每个工作进程的 PyTorch 种子设置为 base_seed + worker_id,其中 base_seed 是由主进程生成的随机数。
    • 其他库的种子可能在初始化工作进程时重复,导致每个进程生成相同的随机数。
  • 解决方法
    • worker_init_fn 中,可以使用 torch.utils.data.get_worker_info().seedtorch.initial_seed() 来为其他库设置种子。

5. CUDA 张量与多进程加载

  • 警告
    • 不建议在多进程加载中返回 CUDA 张量,因为 CUDA 和多进程之间存在许多复杂性。
  • 推荐做法
    • 启用自动内存锁定(pin_memory=True),以加速数据传输到 CUDA GPU。

6. 总结

  • 单进程加载
    • 简单易用,适合小型数据集或调试场景。
    • 可能因数据加载阻塞计算而影响性能。
  • 多进程加载
    • 提高性能,适合大型数据集和复杂任务。
    • 需要注意内存占用、平台差异和随机性管理。
  • 最佳实践
    • 根据数据集大小和任务需求选择合适的加载模式。
    • 在多进程加载中,合理配置工作进程数和数据集分片,避免内存和性能问题。
    • 使用 pin_memory=True 加速 GPU 数据传输。

通过合理配置 DataLoader 的参数,可以实现高效、灵活的数据加载流程,从而提升模型训练的效率和稳定性。

这段内容主要讨论了 PyTorch 中 DataLoader 的内存锁定(Memory Pinning)功能,以及如何通过自定义数据类型支持该功能。以下是总结和解释:

内存锁定(Memory Pinning)

  • 概念:内存锁定指的是将数据放置在页锁定(pinned)内存中,这可以显著加快从主机到 GPU 的数据传输速度。
  • 使用方法
    • pin_memory=True 参数传递给 DataLoader,这样从数据集中获取的数据张量将自动置于页锁定内存中,从而实现更快的 CUDA 设备数据传输。

默认内存锁定逻辑

  • 默认情况下,内存锁定逻辑仅识别张量(Tensors)及包含张量的映射和可迭代对象。
  • 如果你的批处理是一个自定义类型(例如,你有一个返回自定义批处理类型的 collate_fn),或批处理中的每个元素都是自定义类型,内存锁定逻辑将不会识别它们,并会直接返回这些批处理或元素而不锁定内存。

为自定义类型启用内存锁定

  • 对于自定义数据类型,可以通过在其上定义一个 pin_memory() 方法来启用内存锁定功能。
  • 示例代码展示了如何为一个自定义批处理类型 SimpleCustomBatch 实现内存锁定逻辑。
class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # 自定义内存锁定方法
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

# 创建数据集并使用 DataLoader 加载数据
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper, pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

其他注意事项

  • 多进程加载警告
    • 当使用 spawn 启动方法时,worker_init_fn 不能是不可序列化的对象,如 lambda 函数。
  • 随机性管理
    • 每个工作进程都有自己的 PyTorch 种子,基于 base_seed + worker_id 设置。但是,对于其他库,种子可能重复,导致生成相同的随机数。可以在 worker_init_fn 中手动设置种子以避免这种情况。
  • in_order 参数
    • 如果设置为 False,可能会损害结果的可重复性,并可能导致不平衡数据训练器接收到的数据分布偏斜。

总结

  • 内存锁定能够加速数据从主机内存到 GPU 的传输过程,特别适合需要高性能数据传输的应用场景。
  • 自定义类型的支持:当使用自定义批处理类型时,需自行实现 pin_memory() 方法以利用内存锁定的优势。
  • 多进程加载与随机性管理:正确配置工作进程的初始化函数和随机种子,确保数据加载过程的稳定性和高效性。
  • 最佳实践:根据具体应用场景选择是否启用内存锁定,特别是在涉及大量数据传输至 GPU 的任务中,合理配置 DataLoader 参数可以大幅提升性能。
posted @ 2025-04-25 18:17  玉米面手雷王  阅读(44)  评论(0)    收藏  举报