初识pytorch:数据标准化及数据增强的transforms

transfroms

在 PyTorch 中,torchvision.transforms 是用于数据预处理和数据增强的工具集,主要作用是将原始数据(如图像、文本)转换为适合模型输入的格式,并通过随机变换增加数据多样性,从而提升模型的泛化能力。

transforms的作用
1.数据标准化

将原始数据转换为模型要求的格式,例如:

  • 将图像从 PIL 格式转为 Tensor(ToTensor)
  • 对像素值进行归一化(Normalize),使数据分布更稳定
  • 调整图像尺寸(Resize),确保输入尺寸一致

2.数据增强

通过随机变换生成更多样化的训练样本,减少过拟合,例如:

  • 随机裁剪(RandomCrop)、翻转(RandomHorizontalFlip)
  • 随机调整亮度、对比度(ColorJitter)
  • 随机旋转(RandomRotation)

还是上一下示例代码:

from torchvision import transforms
from PIL import Image

# 定义变换流水线
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整图像大小为 224x224
    transforms.RandomHorizontalFlip(p=0.5),  # 50% 概率水平翻转(数据增强)
    transforms.ToTensor(),  # 转为 Tensor
    transforms.Normalize(  # 归一化(使用 ImageNet 数据集的均值和标准差)
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# 加载图像并应用变换
img = Image.open("test.jpg")  # 原始 PIL 图像
transformed_img = transform(img)  # 经过变换后的 Tensor

print(transformed_img.shape)  # 输出: torch.Size([3, 224, 224])(通道数×高×宽)

训练和测试中transforms的使用区别

  • 训练集:通常加入数据增强(如随机翻转、裁剪),增加数据多样性
  • 测试集:仅进行标准化处理(如 Resize、ToTensor、Normalize),不使用随机变换,确保结果可复现

例如:

# 训练集变换(含数据增强)
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 测试集变换(无随机操作)
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),  # 中心裁剪,而非随机
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

总结:
transforms 是连接原始数据与模型输入的关键环节,其核心价值在于:

1.统一数据格式,使原始数据符合模型输入要求
2.通过数据增强扩展训练样本多样性,提升模型泛化能力
3.简化数据预处理流程,与 Dataset、DataLoader 无缝配合

在实际使用中,需根据数据集特点和模型需求选择合适的变换组合,平衡数据增强效果与计算开销。

posted @ 2025-10-11 21:42  沃德天sama  阅读(22)  评论(0)    收藏  举报
1