PyTorch笔记--CIFAR数据集
CIFAR数据集是物体的分类数据集,包含两个不同的数据子集,分别是CIFAR10和CIFAR100,这两个数据集均有60000张32x32大小
的图像,其中每个数据集被分成训练集和测试集两类,训练集有50000张图像,测试集有10000张图像。对于CIFAR10数据集来说,这些
图像被分成不同的10类,CIFAR100数据集则是将这些图像分成不同的100类。
1 import torch 2 from torchvision import datasets 3 from torchvision import transforms 4 from torch.utils.data import DataLoader 5 6 cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([ 7 transforms.Resize((32, 32)), 8 transforms.ToTensor() 9 ]), download=True) 10 11 cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True) 12 13 cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([ 14 transforms.Resize((32, 32)), 15 transforms.ToTensor() 16 ]), download=True) 17 18 cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
CIFAR10(
root, # 数据集根目录
train=True, # True = 训练集,False=测试集
transform=None, # 接收PIL映像并返回转换版本的函数/变换。
target_transform=None, # 接受目标并对其进行变换的函数
download=False # 下载,如果为true,则从internet下载数据集到根目录中,如果已经下载则不会再次下载。
)
class torch.utils.data.DataLoader( # 数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。 dataset, # 加载数据的数据集 batch_size=1, # 批训练的数据个数 shuffle=False, # 设置为True,在每个epoch重新排列数据 sampler=None, # 定义从数据集中提取样本的策略,如果为True,则忽略shuffle参数。 num_workers=0, # 用于数据加载的子进程数 collate_fn=<function default_collate>, # 合并样本以形成小批量 pin_memory=False, # 如果为True,数据加载器在返回前将张量复制到CUDA固定内存中 drop_last=False # 如果数据集大小不能被batch_size整除,设置为True可删除最后一个不完整的批处理 )

浙公网安备 33010602011771号