Dataset和Dataloader的使用
在深度学习中训练模型都是小批量小批量地优化训练的,即每次都会从原数据集中取出一小批量进行训练,完成一次权重更新后,再从原数据集中取下一个小批量数据,然后再训练再更新。
另外,原数据集往往很大,不可能一次性的全部载入模型,只能一小批一小批地载入。训练完了就扔了,再加载下一小批。
准备数据
import pandas as pd
import numpy as np
data = np.random.rand(128, 3) # 128x3
data = pd.DataFrame(data, columns=['feature_1', 'feature_2', 'label'])
Dataset和Dataloader使用模板
class MyDataset(Dataset):
def __init__(self, data):
super().__init__()
'''
有两种写法:
1、将全部数据都加载进内存里,适用于少量数据;
2、当数据量或者标签量很大时,比如图片,就把这些数据或者标签放到文件或数据库里去,只需在此方法中初始化定义这些文件索引的列表即可。
'''
# 以下2个方法都是魔法方法
def __getitem__(self, index): # 实现索引数据集中的某一个数据
# 表示将来实例化这个对象后,它能支持下标(索引)操作,也就是能通过索引把里面的数据拿出来。
# ...
# 数据解析
# ...
return 某一个数据
def __len__(self): # 返回数据集的长度
return len(self.data)
my_dataset = MyDataset(data)
train_loader = DataLoader(dataset=my_dataset, # 传递数据集
batch_size=32, #一个小批量容量是多少
shuffle=True, # 数据集顺序是否要打乱,一般是要的。测试数据集一般没必要
num_workers=0) # 需要几个进程来一次性读取这个小批量数据
创建一个完整的Dataset,使用上面自己生成的数据集。
from torch.utils.data import Dataset # Dataset是个抽象类,只能用于继承
from torch.utils.data import DataLoader # DataLoader需实例化,用于加载数据
class MyDataset(Dataset):
def __init__(self, data):
super().__init__()
self.data = data
self.features = self.data[['feature_1', 'feature_2']].values # [[]]是返回结果是一个df,[]可能返回一个Serial;
self.label = self.data[['label']].values
# .values 取DataFrame或Series的值,返回值是一个numpy ndarray的副本;而不是对原始数据的引用。这意味着你可以对返回的数组执行任何操作,而不会影响原始的pandas对象。
def __getitem__(self, index): # 参数index必写
return self.features[index], self.label[index]
def __len__(self):
return len(self.data)
# 实例化
my_dataset = MyDataset(data)
train_loader = DataLoader(dataset=my_dataset, # 要传递的数据集
batch_size=32, #一个小批量数据的大小是多少
shuffle=True, # 数据集顺序是否要打乱,一般是要的。测试数据集一般没必要
num_workers=0) # 需要几个进程来一次性读取这个小批量数据
dataloader 本质上是一个可迭代对象!
关于可迭代对象以及迭代器的说明:https://www.cnblogs.com/kphang/p/17026932.html
所以就有这说明里提到的几种方式进行访问。
# 方式1:每一轮就是Dataloader中设置的batch_size的大小。
for batch_features, batch_labels in train_loader:
pass
# 方式2:
for i, batch_data in enumerate(train_loader):
print(i) # 0 1 2 3 # 共4轮,因为batch_size设置的是32,数据项共128个;
在 PyTorch 中,dataloader 会遍历完所有的数据。每次迭代会返回一个批次的数据。你可以使用 for 循环来迭代 dataloader,每次循环会返回一个批次的数据。当 dataloader 迭代完所有数据时,for 循环将会结束。
也可以使用 iter() 函数来获取迭代器,然后使用 next() 函数来迭代数据。例如:
# 方式3:
iterator = iter(dataloader)
while True:
try:
data = next(iterator)
# Process the data
except StopIteration:
break