Pytorch数据处理流程

1. numpy.genfromtxt(path, delimiter=',', dtype=str, skip_header=1)

将数据从csv导入array *类型为string

若数据为图像,还需对图像进行处理(增广)

string--split()--list--np.array()--ndarray--reshape()

最后转换数据类型string to float/int

2. torch.Tensor(x_train)

将np.array转换成torch.Tensor

3. TensorDataset(data,lable) from torch.utils.data

将data和label合并成一个dataset

4. DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8) from torch.utils.data

将dataset导入DataLoader,由torch进行batch分割

5. for i, data in enumerate(DataLoader):

model(data[0])导入模型训练

posted @ 2019-11-12 21:12  Junzhao  阅读(523)  评论(0)    收藏  举报