数据并不会形成训练机器学习算法需要的最终处理过的数据格式。所以使用transform来执行一些数据操作使数据适合训练。

所有的TorchVision数据集有两个参数--transform来修改特征和target_transform来修改标签,接收包含transform逻辑的调用方法。

FashionMNIST特征图是PIL Image格式,标签是整数类型。为了训练,需要将特征转换为归一化tensors,标签转换为独热编码tensors。为了实现这些转变,使用ToTensor和Lambda。

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

ToTensor将PIL image或者Numpy ndarray转换为FloatTensor。并且标量图片的像素强度值在[0,1]之间。

Lambda Transforms

Lambda转换应用在任何用户定义的lamba函数,定义一个函数将整数转换为独热编码tensor,先创建一个大小为10的0 tensor和值为1,index为label的标量。

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

 

 posted on 2024-03-24 13:53  会飞的金鱼  阅读(19)  评论(0)    收藏  举报