加载和处理不同类型数据。
加载和迭代数据集,此类DataLoader在torch.utils.data模块中。
from torch.utils.data import DataLoader
重要参数说明:
- dataset:Dataset类,决定数据从哪里读取及如何读取
- batch_size:批大小
- num_workers:是否多进程读取数据(减少时间,加速模型训练)
- shuffle:每个epoch是否打乱
- drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
代码示例:
import torch import torchvision.datasets from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter test_data=torchvision.datasets.CIFAR10('./CIFAR',train=False,download=False,transform=torchvision.transforms.ToTensor()) dataloader=DataLoader(dataset=test_data,batch_size=64,shuffle=True) writer=SummaryWriter('dataloader') epoch=2 for i in range(epoch): step=0 for data in dataloader: img,target=data # print(img) # print(img.shape) # print(target) writer.add_image('epoch {}'.format(i),img,step) step=step+1 writer.close()
posted on
浙公网安备 33010602011771号