PyTorch - 数据读取机制DataLoader
pytorch的数据读取机制DataLoader包括两个子模块,Sampler模块,主要是生成索引index,DataSet模块,主要是根据索引读取数据。Dataset 类是一个抽象类,它可以用来表示数据集。我们通过继承 Dataset 类来自定义数据集的格式、大小和其它属性,后面就可以供 DataLoader 类直接使用。
在实际项目中,如果数据量很大,考虑到内存有限、I/O 速度等问题,在训练过程中不可能一次性的将所有数据全部加载到内存中,也不能只用一个进程去加载,所以就需要多进程、迭代加载,而 DataLoader 就是基于这些需要被设计出来的。DataLoader 是一个迭代器,最基本的使用方法就是传入一个Dataset 对象,它会根据参数 batch_size 的值生成一个 batch 的数据,节省内存的同时,它还可以实现多进程、数据打乱等处理。
torch.utils.data.Dataset
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
Dataset是用来解决数据从哪里读取以及如何读取的问题。 pytorch给定的Dataset是一个抽象类,所有自定义的Dataset都要继承它,并且复写__getitem__()和__len__()类方法,__getitem__()的作用是接受一个索引,返回一个样本或者标签。下面通过实例构造一个数据集:
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 构造函数
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
# 返回数据集大小
def __len__(self):
return self.data_tensor.size(0)
# 返回索引的数据与标签
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
结合代码可以看到,我们定义了一个名字为 MyDataset 的数据集,在构造函数中,传入 Tensor 类型的数据与标签;在 __len__函数中,直接返回 Tensor 的大小;在 __getitem__ 函数中返回索引的数据与标签。
接下来看如何调用刚才定义的数据集。首先随机生成一个 10*3 维的数据 Tensor,然后生成 10 维的标签 Tensor,与数据 Tensor 相对应。利用这两个 Tensor,生成一个 MyDataset 的对象。查看数据集的大小可以直接用 len() 函数,索引调用数据可以直接使用下标。
# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1
# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
# 查看数据集大小
print('Dataset size:', len(my_dataset))
'''
输出:
Dataset size: 10
'''
# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])
'''
输出:
tensor_data[0]: (tensor([ 0.4931, -0.0697, 0.4171]), tensor(0))
'''
DataLoader的功能是构建可迭代的数据装载器,在训练的时候,每一个for循环,每一次Iteration,就是从DataLoader中获取一个batch_size大小的数据。
torch.utils.data.DataLoader
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,
pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)
DataLoader有很多参数,但常用的有下面五个:
dataset表示Dataset类,它决定了数据从哪读取以及如何读取;batch_size表示批大小;num_works表示是否多进程读取数据;shuffle表示每个epoch是否乱序;drop_last表示当样本数不能被batch_size整除时,是否舍弃最后一批数据;
这里提到了epoch,所有训练样本都已输入到模型之中,称为一个Epoch,也就说将样本都训练一遍,称为一个Epoch。一批样输入到模型中,这样称为一个Iteration。决定一个Epoch有多少个Iteration,为批次大小Batchsize。
举个例子,假如样本总数为80,设置的Batchsize为8,那么输入模型一次就输入8个样本,1 Epoch = 10 Iteration,这是正好可以整除的例子。假如样本总数为87,设置的Batchsize为8,那么是有10个Iteration还是11个Iteration?这就需要看drop_last参数了,如果设置为True,那就是10个Iteration,相当于舍弃了7个样本,如果设置为False,那就是11个Iteration,最后一个Iteration比前面的Iteration样本个数要少。
根据上面Dataset的例子,我们查看一下DataLoader:
from torch.utils.data import DataLoader
tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集, 必须参数
batch_size=2, # 输出的batch大小
shuffle=True, # 数据是否打乱
num_workers=0) # 进程数, 0表示只有主进程
# 以循环形式输出
for data, target in tensor_dataloader:
print(data, target)
'''
输出:
tensor([[-0.1781, -1.1019, -0.1507],
[-0.6170, 0.2366, 0.1006]]) tensor([0, 0])
tensor([[ 0.9451, -0.4923, -1.8178],
[-0.4046, -0.5436, -1.7911]]) tensor([0, 0])
tensor([[-0.4561, -1.2480, -0.3051],
[-0.9738, 0.9465, 0.4812]]) tensor([1, 0])
tensor([[ 0.0260, 1.5276, 0.1687],
[ 1.3692, -0.0170, -1.6831]]) tensor([1, 0])
tensor([[ 0.0515, -0.8892, -0.1699],
[ 0.4931, -0.0697, 0.4171]]) tensor([1, 0])
'''
# 输出一个batch
print('One batch tensor data: ', iter(tensor_dataloader).next())
'''
输出:
One batch tensor data: [tensor([[ 0.9451, -0.4923, -1.8178],
[-0.4046, -0.5436, -1.7911]]), tensor([0, 0])]
'''
最后,通过流程图认识一下DataLoader的数据读取机制:

首先,在for循环中使用了DataLoader,进入DataLoader后,首先根据是否使用多进程DataLoaderIter,做出判断之后单线程还是多线程,接着使用Sampler得索引Index,然后将索引给到DatasetFetcher,在这里面调用Dataset,根据索引,通过getitem得到实际的数据和标签,得到一个batch size大小的数据后,通过collate_fn函数整理成一个Batch Data的形式输入到模型去训练。

浙公网安备 33010602011771号