PyTorch-Dataset和 DataLoader
简单来说:
Dataset是一个仓库或菜单。它定义了数据的来源(如图片文件、文本列表),并告诉程序如何根据索引(像菜单编号)获取一个单独的数据样本。DataLoader是一个高效的服务员和厨房。它从Dataset中取数据,按照你的要求(如批量大小、是否打乱)打包成批,并利用多进程并行加载,确保数据能持续、快速地供给模型。
下图清晰地展示了一条样本从原始数据到最终输入模型的完整旅程,以及Dataset和DataLoader各自在其中扮演的角色:
flowchart LR
A["原始数据<br>(图像、文本等)"] --> B[Dataset 类]
subgraph B [Dataset 定义]
B1["__init__<br>加载数据路径/列表"] --> B2["__getitem__<br>按索引读取并预处理单个样本"]
end
B2 --> C["单个样本<br>(如:Tensor 图像, 标签)"]
C --> D{DataLoader}
subgraph D [DataLoader 批处理]
direction TB
D1[Sampler<br>生成索引顺序] --> D2[Batch 组装]
D3["可选的<br>自定义 collate_fn"] --> D2
end
D2 --> E["一个 Batch<br>(如: Tensor[B,C,H,W], Tensor[B])"]
E --> F[GPU / 模型训练]
1. Dataset:数据的抽象仓库
Dataset 是一个抽象类,你通过继承它来创建自己的数据类。核心是必须实现两个方法:
__len__():返回数据集的样本总数。__getitem__(idx):根据给定的索引idx返回一个样本(如图像张量和它的标签)。
代码示例:创建自定义Dataset
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
class CustomImageDataset(Dataset):
"""一个自定义的图像数据集示例"""
def __init__(self, img_dir, label_file, transform=None):
"""
初始化函数,通常在这里读取数据路径或列表。
Args:
img_dir (str): 图像文件目录
label_file (str): 标签文件路径
transform (callable, optional): 一个可选的图像变换函数
"""
self.img_dir = img_dir
self.transform = transform
# 假设label_file是一个每行"图片名 标签"的文本文件
with open(label_file, 'r') as f:
self.samples = [line.strip().split() for line in f]
def __len__(self):
"""返回数据集的大小"""
return len(self.samples)
def __getitem__(self, idx):
"""根据索引返回一个样本(图像, 标签)"""
img_name, label = self.samples[idx]
img_path = os.path.join(self.img_dir, img_name)
# 1. 加载图像
image = Image.open(img_path).convert('RGB') # 转为RGB
# 2. 应用变换(如调整大小、转为Tensor、归一化)
if self.transform:
image = self.transform(image)
# 3. 将标签转为整数张量
label = torch.tensor(int(label))
return image, label # 返回一个样本对
Dataset 的核心变体
- Map-style Dataset(最常用):就是上面这种。它像一个数组或映射,可以通过
dataset[i]直接随机访问第i个样本。要求数据能全部加载到内存或路径可索引。 - Iterable-style Dataset:像一个数据流,通过
__iter__()方法顺序读取。适用于数据量极大(如大型文本文件、实时日志)且顺序读取的场景。
2. DataLoader:数据加载的引擎
DataLoader 接收一个 Dataset 对象,并封装了复杂的批生成、数据打乱、多进程并行加载等逻辑。
核心参数解析
from torch.utils.data import DataLoader
from torchvision import transforms
# 定义图像变换管道
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(), # 将PIL图像转为[C, H, W]的Tensor,并归一化到[0,1]
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet标准化
])
# 实例化Dataset
dataset = CustomImageDataset(img_dir='./data/images',
label_file='./data/labels.txt',
transform=transform)
# 创建DataLoader
dataloader = DataLoader(
dataset,
batch_size=32, # 💡 核心:批大小。影响内存使用和梯度稳定性。
shuffle=True, # 💡 是否在每个epoch打乱数据。训练集必须为True,验证/测试集为False。
num_workers=4, # 💡 用于数据加载的子进程数。可加速I/O密集型操作,但过多可能适得其反。
pin_memory=True, # 💡 将数据张量锁页到CPU内存,加速到GPU的传输。当使用GPU时建议启用。
drop_last=True # 💡 如果数据集大小不能被batch_size整除,是否丢弃最后一个不完整的批次。
)
# 使用方式:通常在训练循环中迭代
for epoch in range(num_epochs):
for batch_images, batch_labels in dataloader: # dataloader自动返回批数据
# 此时 batch_images 的形状是 [32, 3, 256, 256]
# batch_labels 的形状是 [32]
# 将数据转移到GPU(如果可用)
batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
# ... 后续的训练步骤:前向传播、计算损失、反向传播、优化器更新 ...
DataLoader 的内部齿轮
sampler:决定读取样本的顺序。默认是RandomSampler(当shuffle=True)或SequentialSampler(当shuffle=False)。你可以自定义采样器,例如进行类别平衡采样(WeightedRandomSampler)。batch_sampler:在sampler的基础上,进一步将索引分组为批次。通常不需要直接设置。collate_fn:一个非常重要的函数,它负责将__getitem__返回的多个样本列表整理和打包成一个批量的张量。当你的样本结构复杂或不规则时(如处理不同长度的文本序列),需要自定义此函数。
代码示例:自定义 collate_fn 处理变长文本
def my_collate_fn(batch):
"""
处理变长序列的collate_fn示例。
batch: 一个列表,其元素是Dataset.__getitem__返回的样本,例如 [(text1, label1), (text2, label2), ...]
"""
# 假设每个样本是 (token_ids, label),且token_ids的长度不同
texts, labels = zip(*batch) # 解压成两个列表:texts = [token_ids1, token_ids2, ...], labels = [label1, label2, ...]
# 1. 对labels直接堆叠
labels = torch.stack(labels)
# 2. 对变长文本进行padding
padded_texts = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=0)
# padded_texts的形状现在是 [batch_size, max_seq_len_in_this_batch]
# 3. 还可以同时生成一个“attention mask”来标记非padding部分
attention_mask = (padded_texts != 0).long()
return padded_texts, attention_mask, labels
# 在DataLoader中使用自定义collate_fn
text_dataloader = DataLoader(dataset, batch_size=16, collate_fn=my_collate_fn)
3. 在大模型训练中的最佳实践
-
明确分工:
- 让
Dataset专注于单个样本的读取和轻量预处理(如解码文件、调整大小)。 - 让
DataLoader和collate_fn负责批层次的组装和高效加载(如padding、多进程)。
- 让
-
性能关键:
num_workers:根据你的CPU核心数和任务类型(I/O vs CPU计算)调整。通常设置为[2, 4, 8]进行测试,找到性能峰值点。pin_memory=True:在GPU训练时务必开启,可以显著减少数据从CPU到GPU的传输时间。- 避免在
__getitem__中进行重操作:如果可能,将耗时操作(如读取大文件)提前到__init__中。
-
处理大规模数据:
- 对于超大规模数据集,考虑使用
IterableDataset进行流式加载。 - 使用 分布式训练 时,
DataLoader需要配合DistributedSampler来确保每个GPU进程看到数据的不同部分。
- 对于超大规模数据集,考虑使用
本文来自博客园,作者:ffff5,转载请注明原文链接:https://www.cnblogs.com/ffff5/p/19591229

浙公网安备 33010602011771号