Pytorch数据处理流程
1. numpy.genfromtxt(path, delimiter=',', dtype=str, skip_header=1)
将数据从csv导入array *类型为string
若数据为图像,还需对图像进行处理(增广)
string--split()--list--np.array()--ndarray--reshape()
最后转换数据类型string to float/int
2. torch.Tensor(x_train)
将np.array转换成torch.Tensor
3. TensorDataset(data,lable) from torch.utils.data
将data和label合并成一个dataset
4. DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8) from torch.utils.data
将dataset导入DataLoader,由torch进行batch分割
5. for i, data in enumerate(DataLoader):
model(data[0])导入模型训练

浙公网安备 33010602011771号