第五章:计算机视觉-项目实战之生成式算法实战:扩散模型

第三部分:生成式算法实战:扩散模型

第二节:DDPM数据读取

在上一节中,我们介绍了如何从零开始训练扩散模型的总体思路与训练框架。本节将深入探讨训练过程中最关键的第一步——数据读取与预处理
无论是DDPM(Denoising Diffusion Probabilistic Model)还是其他生成式模型,数据质量与输入管线的设计,都会直接影响模型的收敛速度与生成效果。


一、数据读取的重要性

扩散模型的核心任务是学习噪声到图像的反向映射,因此它依赖于大量高质量的图像样本。
在训练中,每一张图像都会被多次采样、添加噪声、再进行去噪预测。
因此,一个高效的数据加载系统应当满足以下特征:

  1. 高吞吐量:支持批量加载与GPU并行。

  2. 数据随机化:避免模型过拟合到特定顺序。

  3. 可扩展性:支持多种图像来源(本地文件夹、WebDataset、HuggingFace Datasets等)。

  4. 轻量预处理:在加载阶段完成尺寸缩放、归一化、增强等。


二、数据集结构示例

以CIFAR-10为例,数据组织通常如下:

/data
 ├── train
 │    ├── cat_0001.png
 │    ├── cat_0002.png
 │    ├── dog_0001.png
 │    └── ...
 └── val
      ├── cat_1001.png
      ├── dog_1002.png
      └── ...

当然,对于自定义数据集,只要保证所有图像可被正确读取即可。
DDPM通常不需要标签信息(除非是条件生成,如class-conditional DDPM)。


三、数据读取核心代码实现(PyTorch)

以下为一个简化版的 DDPM数据加载器 实现:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image
import os
# 1. 自定义Dataset类
class DiffusionDataset(Dataset):
    def __init__(self, data_dir, image_size=64):
        self.data_dir = data_dir
        self.image_paths = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".png") or f.endswith(".jpg")]
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # 归一化到[-1,1]
        ])
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        return self.transform(image)
# 2. DataLoader创建
def create_dataloader(data_dir, batch_size=64, image_size=64, num_workers=4):
    dataset = DiffusionDataset(data_dir, image_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    return dataloader
# 使用示例
train_loader = create_dataloader("./data/train", batch_size=128, image_size=64)

四、数据预处理细节说明

操作说明目的
Resize将图像统一缩放到指定尺寸(如64×64)保证批次一致性
CenterCrop居中裁剪(可避免边缘干扰)提高图像稳定性
ToTensor将PIL图像转换为Tensor进入PyTorch计算图
Normalize([0.5],[0.5])将像素值从[0,1]缩放到[-1,1]与扩散模型的噪声范围匹配

五、批次可视化验证

在训练前,我们建议先可视化数据批次,确保数据被正确读取与归一化:

import matplotlib.pyplot as plt
import torchvision
def show_batch(dataloader):
    images = next(iter(dataloader))
    grid = torchvision.utils.make_grid(images[:64], nrow=8, normalize=True)
    plt.figure(figsize=(8,8))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.axis("off")
    plt.show()
show_batch(train_loader)

如果显示出的图像清晰、亮度正常且分布均匀,即可开始训练。


六、与DDPM训练循环对接

在DDPM中,数据加载器的输出直接送入训练主循环:

for epoch in range(num_epochs):
    for images in train_loader:
        images = images.to(device)
        t = torch.randint(0, timesteps, (images.size(0),), device=device).long()
        noisy_images, noise = diffusion.add_noise(images, t)
        predicted_noise = model(noisy_images, t)
        loss = loss_fn(predicted_noise, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

七、扩展与优化

优化方式技术实现效果
多GPU并行加载DistributedSampler大幅提高吞吐量
WebDataset格式支持.tar.tfrecord适合超大规模数据
随机增强加入随机裁剪、翻转、颜色扰动提升模型泛化能力
缓存与预加载使用prefetch_factorpersistent_workers=True避免I/O瓶颈

八、小结

核心要点内容
输入质量决定生成质量模型再强也离不开干净、均衡的数据
高效DataLoader是训练基础优化加载性能可节省大量GPU时间
归一化与尺寸一致性非常重要不同图像尺寸会破坏批次一致性
推荐逐步扩展数据规模先用CIFAR-10调试,再迁移至高分辨率数据集

本节总结

  • 学会了如何构建自定义数据加载类,并使用torchvision工具完成预处理;

  • 了解了数据预处理在DDPM训练流程中的关键作用

  • 掌握了如何将数据管线与DDPM训练主循环衔接;

  • 为下一节的噪声添加与反向去噪过程实现打下基础。