PyTorch-Dataset和 DataLoader

简单来说:

  • Dataset 是一个仓库菜单。它定义了数据的来源(如图片文件、文本列表),并告诉程序如何根据索引(像菜单编号)获取一个单独的数据样本。
  • DataLoader 是一个高效的服务员和厨房。它从 Dataset 中取数据,按照你的要求(如批量大小、是否打乱)打包成批,并利用多进程并行加载,确保数据能持续、快速地供给模型。

下图清晰地展示了一条样本从原始数据到最终输入模型的完整旅程,以及DatasetDataLoader各自在其中扮演的角色:

flowchart LR A["原始数据<br>(图像、文本等)"] --> B[Dataset 类] subgraph B [Dataset 定义] B1["__init__<br>加载数据路径/列表"] --> B2["__getitem__<br>按索引读取并预处理单个样本"] end B2 --> C["单个样本<br>(如:Tensor 图像, 标签)"] C --> D{DataLoader} subgraph D [DataLoader 批处理] direction TB D1[Sampler<br>生成索引顺序] --> D2[Batch 组装] D3["可选的<br>自定义 collate_fn"] --> D2 end D2 --> E["一个 Batch<br>(如: Tensor[B,C,H,W], Tensor[B])"] E --> F[GPU / 模型训练]

1. Dataset:数据的抽象仓库

Dataset 是一个抽象类,你通过继承它来创建自己的数据类。核心是必须实现两个方法:

  • __len__():返回数据集的样本总数。
  • __getitem__(idx):根据给定的索引 idx 返回一个样本(如图像张量和它的标签)。

代码示例:创建自定义Dataset

import torch
from torch.utils.data import Dataset
from PIL import Image
import os

class CustomImageDataset(Dataset):
    """一个自定义的图像数据集示例"""
    
    def __init__(self, img_dir, label_file, transform=None):
        """
        初始化函数,通常在这里读取数据路径或列表。
        Args:
            img_dir (str): 图像文件目录
            label_file (str): 标签文件路径
            transform (callable, optional): 一个可选的图像变换函数
        """
        self.img_dir = img_dir
        self.transform = transform
        
        # 假设label_file是一个每行"图片名 标签"的文本文件
        with open(label_file, 'r') as f:
            self.samples = [line.strip().split() for line in f]
    
    def __len__(self):
        """返回数据集的大小"""
        return len(self.samples)
    
    def __getitem__(self, idx):
        """根据索引返回一个样本(图像, 标签)"""
        img_name, label = self.samples[idx]
        img_path = os.path.join(self.img_dir, img_name)
        
        # 1. 加载图像
        image = Image.open(img_path).convert('RGB') # 转为RGB
        
        # 2. 应用变换(如调整大小、转为Tensor、归一化)
        if self.transform:
            image = self.transform(image)
        
        # 3. 将标签转为整数张量
        label = torch.tensor(int(label))
        
        return image, label # 返回一个样本对

Dataset 的核心变体

  • Map-style Dataset(最常用):就是上面这种。它像一个数组或映射,可以通过 dataset[i] 直接随机访问第 i 个样本。要求数据能全部加载到内存或路径可索引。
  • Iterable-style Dataset:像一个数据流,通过 __iter__() 方法顺序读取。适用于数据量极大(如大型文本文件、实时日志)且顺序读取的场景。

2. DataLoader:数据加载的引擎

DataLoader 接收一个 Dataset 对象,并封装了复杂的批生成、数据打乱、多进程并行加载等逻辑。

核心参数解析

from torch.utils.data import DataLoader
from torchvision import transforms

# 定义图像变换管道
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(), # 将PIL图像转为[C, H, W]的Tensor,并归一化到[0,1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet标准化
])

# 实例化Dataset
dataset = CustomImageDataset(img_dir='./data/images',
                             label_file='./data/labels.txt',
                             transform=transform)

# 创建DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=32,          # 💡 核心:批大小。影响内存使用和梯度稳定性。
    shuffle=True,           # 💡 是否在每个epoch打乱数据。训练集必须为True,验证/测试集为False。
    num_workers=4,          # 💡 用于数据加载的子进程数。可加速I/O密集型操作,但过多可能适得其反。
    pin_memory=True,        # 💡 将数据张量锁页到CPU内存,加速到GPU的传输。当使用GPU时建议启用。
    drop_last=True          # 💡 如果数据集大小不能被batch_size整除,是否丢弃最后一个不完整的批次。
)

# 使用方式:通常在训练循环中迭代
for epoch in range(num_epochs):
    for batch_images, batch_labels in dataloader: # dataloader自动返回批数据
        # 此时 batch_images 的形状是 [32, 3, 256, 256]
        # batch_labels 的形状是 [32]
        
        # 将数据转移到GPU(如果可用)
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        
        # ... 后续的训练步骤:前向传播、计算损失、反向传播、优化器更新 ...

DataLoader 的内部齿轮

  1. sampler:决定读取样本的顺序。默认是 RandomSampler(当shuffle=True)或 SequentialSampler(当shuffle=False)。你可以自定义采样器,例如进行类别平衡采样WeightedRandomSampler)。
  2. batch_sampler:在 sampler 的基础上,进一步将索引分组为批次。通常不需要直接设置。
  3. collate_fn:一个非常重要的函数,它负责将 __getitem__ 返回的多个样本列表整理和打包成一个批量的张量。当你的样本结构复杂或不规则时(如处理不同长度的文本序列),需要自定义此函数。

代码示例:自定义 collate_fn 处理变长文本

def my_collate_fn(batch):
    """
    处理变长序列的collate_fn示例。
    batch: 一个列表,其元素是Dataset.__getitem__返回的样本,例如 [(text1, label1), (text2, label2), ...]
    """
    # 假设每个样本是 (token_ids, label),且token_ids的长度不同
    texts, labels = zip(*batch) # 解压成两个列表:texts = [token_ids1, token_ids2, ...], labels = [label1, label2, ...]
    
    # 1. 对labels直接堆叠
    labels = torch.stack(labels)
    
    # 2. 对变长文本进行padding
    padded_texts = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=0)
    # padded_texts的形状现在是 [batch_size, max_seq_len_in_this_batch]
    
    # 3. 还可以同时生成一个“attention mask”来标记非padding部分
    attention_mask = (padded_texts != 0).long()
    
    return padded_texts, attention_mask, labels

# 在DataLoader中使用自定义collate_fn
text_dataloader = DataLoader(dataset, batch_size=16, collate_fn=my_collate_fn)

3. 在大模型训练中的最佳实践

  1. 明确分工

    • Dataset 专注于单个样本的读取和轻量预处理(如解码文件、调整大小)。
    • DataLoadercollate_fn 负责批层次的组装和高效加载(如padding、多进程)。
  2. 性能关键

    • num_workers:根据你的CPU核心数和任务类型(I/O vs CPU计算)调整。通常设置为 [2, 4, 8] 进行测试,找到性能峰值点。
    • pin_memory=True:在GPU训练时务必开启,可以显著减少数据从CPU到GPU的传输时间。
    • 避免在 __getitem__ 中进行重操作:如果可能,将耗时操作(如读取大文件)提前到 __init__ 中。
  3. 处理大规模数据

    • 对于超大规模数据集,考虑使用 IterableDataset 进行流式加载。
    • 使用 分布式训练 时,DataLoader 需要配合 DistributedSampler 来确保每个GPU进程看到数据的不同部分。
posted @ 2026-02-08 12:03  ffff5  阅读(20)  评论(0)    收藏  举报