PyTorch笔记--MNIST数据集的准备

MNIST数据集是手写数字数据集,它是分类任务的数据集。所有图像是28x28大小的黑白图像,分为训练集和测试集两个数据集,

训练集有60000张图像,测试集有10000张图像,图像的内容为0~9的手写数字。

 

 1 from torchvision.datasets import MNIST
 2 import torchvision.transforms as transforms
 3 from torch.utils.data import DataLoader
 4 
 5 data_train = MNIST('./data',   # MNIST数据集的目录
 6                    download=True,  # 如果没有MNIST数据集,将会从网络上下载
 7                    transform=transforms.Compose([  # 对图片进行的变换(旋转、切割、数据类型转换等)
 8                        transforms.Resize((32, 32)),
 9                        transforms.ToTensor()
10                    ]))
11 data_test = MNIST('./data',
12                   train=False,
13                   download=False, 
14                   transform=transforms.Compose([  
15                       transforms.Resize((32, 32)),
16                       transforms.ToTensor()
17                   ]))
18 data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
19 data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)

 

data_train是训练数据集,data_test是测试数据集。download=True代表如果没有下载MNIST数据集,就会将它下到当前目录的data文件下,如果

已经下载了数据集,就直接从其读取。transform是将MNIST数据集做相应的变换,并且最终以张量输出。

数据加载器用来把训练数据或测试和数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。

data_train_loader是训练数据集的载入器,它载入了数据集data_train。因为train=True,代表该数据集是训练集。False代表该数据集是测试集,

默认值是true。批次大小设置是256(每个批次中的样本数量是256),shuffle=True代表了输入图像会做随机排列,同时工作的进程数为8。

data_test_loader是测试数据集载入器,它载入了数据集data_test,因为在测试数据集中运行模型并不需要记录计算图以及计算反向传播速度,

节约了内存。因此,在代码中相应的批次的大小也有了相应的增加,这里设置的1024(每个批次中的样本数量是1024)。同时工作的进程数为8。

 

MNIST(  # 载入数据集
    root,   # 数据集,根(字符串)–MNIST所在的文件夹
    train=True, # True = 训练集, False = 测试集
    transform=None, # 接收PIL映像并返回转换版本的函数/变换
    target_transform=None, # 接受目标并对其进行转换的函数
    download=False  # 下载(bool,可选)–如果为true,则从internet下载数据集并将其放在根目录中。如果数据集已下载,则不会再次下载。
    )

 

class torchvision.transforms.Compose(transforms)  # 将多个transform组合起来使用 

 

class torch.utils.data.DataLoader(  # 数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
                    dataset, # 加载数据的数据集
                    batch_size=1, # 批训练的数据个数
                    shuffle=False, # 设置为True,在每个epoch重新排列数据
                    sampler=None, # 定义从数据集中提取样本的策略,如果为True,则忽略shuffle参数。
                    num_workers=0, # 用于数据加载的子进程数
                    collate_fn=<function default_collate>, # 合并样本以形成小批量
                    pin_memory=False, # 如果为True,数据加载器在返回前将张量复制到CUDA固定内存中
                    drop_last=False # 如果数据集大小不能被batch_size整除,设置为True可删除最后一个不完整的批处理
                    )

 

posted @ 2021-08-14 13:08  奋斗的小仔  阅读(474)  评论(0)    收藏  举报