torchvision、Dataset与dataloader
Dataloader:
from torch.utils.data import DataLoader tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集, 必须参数 batch_size=2, # 输出的batch大小 shuffle=True, # 数据是否打乱 num_workers=0) # 进程数, 0表示只有主进程 # 以循环形式输出 for data, target in tensor_dataloader: print(data, target) ''' 输出: tensor([[-0.1781, -1.1019, -0.1507], [-0.6170, 0.2366, 0.1006]]) tensor([0, 0]) tensor([[ 0.9451, -0.4923, -1.8178], [-0.4046, -0.5436, -1.7911]]) tensor([0, 0]) tensor([[-0.4561, -1.2480, -0.3051], [-0.9738, 0.9465, 0.4812]]) tensor([1, 0]) tensor([[ 0.0260, 1.5276, 0.1687], [ 1.3692, -0.0170, -1.6831]]) tensor([1, 0]) tensor([[ 0.0515, -0.8892, -0.1699], [ 0.4931, -0.0697, 0.4171]]) tensor([1, 0]) ''' # 输出一个batch print('One batch tensor data: ', iter(tensor_dataloader).next()) ''' 输出: One batch tensor data: [tensor([[ 0.9451, -0.4923, -1.8178], [-0.4046, -0.5436, -1.7911]]), tensor([0, 0])] '''
Dataloader 参数:
- dataset:Dataset 类型,输入的数据集,必须参数;
- batch_size:int 类型,每个 batch 有多少个样本;
- shuffle:bool 类型,在每个 epoch 开始的时候,是否对数据进行重新打乱;
- num_workers:int 类型,加载数据的进程数,0 意味着所有的数据都会被加载进主进程,默认为 0。
Torchvision:
torchvision 库就是常用数据集 + 常见网络模型 + 常用图像处理方法。

浙公网安备 33010602011771号