加载和处理不同类型数据。

加载和迭代数据集,此类DataLoader在torch.utils.data模块中。

from torch.utils.data import DataLoader

重要参数说明:

  • dataset:Dataset类,决定数据从哪里读取及如何读取
  • batch_size:批大小
  • num_workers:是否多进程读取数据(减少时间,加速模型训练)
  • shuffle:每个epoch是否打乱
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

代码示例:

import torch
import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


test_data=torchvision.datasets.CIFAR10('./CIFAR',train=False,download=False,transform=torchvision.transforms.ToTensor())

dataloader=DataLoader(dataset=test_data,batch_size=64,shuffle=True)

writer=SummaryWriter('dataloader')
epoch=2
for i in range(epoch):
    step=0
    for data in dataloader:
        img,target=data
        # print(img)
        # print(img.shape)
        # print(target)
        writer.add_image('epoch {}'.format(i),img,step)
        step=step+1
writer.close()

 

 posted on 2024-03-16 16:22  会飞的金鱼  阅读(35)  评论(0)    收藏  举报