数据并不会形成训练机器学习算法需要的最终处理过的数据格式。所以使用transform来执行一些数据操作使数据适合训练。
所有的TorchVision数据集有两个参数--transform来修改特征和target_transform来修改标签,接收包含transform逻辑的调用方法。
FashionMNIST特征图是PIL Image格式,标签是整数类型。为了训练,需要将特征转换为归一化tensors,标签转换为独热编码tensors。为了实现这些转变,使用ToTensor和Lambda。
import torch from torchvision import datasets from torchvision.transforms import ToTensor, Lambda ds = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor(), target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) )
ToTensor()
ToTensor将PIL image或者Numpy ndarray转换为FloatTensor。并且标量图片的像素强度值在[0,1]之间。
Lambda Transforms
Lambda转换应用在任何用户定义的lamba函数,定义一个函数将整数转换为独热编码tensor,先创建一个大小为10的0 tensor和值为1,index为label的标量。
target_transform = Lambda(lambda y: torch.zeros( 10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
posted on
浙公网安备 33010602011771号