PyTorch torch.utils.data 模块 结构化数据
torch.utils.data 模块中的一些函数,PyTorch 官方文档
1. Dataset 类
Dataset 类创建 Map-style 数据集,通过 __getitem__() 和 __len__() 方法来从数据集中采样,样本可以表示为数据集的索引或键值(indices / keys)的映射(map)。
引入
from torch.utils.data import Dataset
主要作用: 规范化模型的数据,结合 DataLoader 类,根据索引,在每次训练的过程中取出数据,基本结构
class MyDataset(Dataset):
def __init__(self, params):
# 传入必要的参数,原始数据集,等
super(MyDataset, self).__init__() # 父类初始化模型
...
return None
def __len__(self):
# 返回数据集样本总数
return data_len
def __getitem__(self, idx):
# 根据索引 idx,确定每个(或 batch size)需要输入模型的样本
# 返回值可根据具体情况调整
return input, output
1.1 TensorDataset() 函数
对于不需要任何加工的向量,TensorDataset() 函数可以直接将数据(torch.Tensor 数据类型)直接打包成 Dataset 类。
mydataset = TensorDataset(X, Y) # X, Y 应为 torch.Tensor 类型,且数量相等(即X.shape[0]=Y.shape[0])
实例: 获取数据集的大小,从数据集中选取样本
# 加载 iris 数据集
from sklearn.datasets import load_iris
data = load_iris()
X, Y = data.data, data.target
print(X.shape, Y.shape)
import torch
import torch.utils.data as tud
X, Y = torch.tensor(X), torch.tensor(Y, dtype=torch.long) # 数据转换为torch.Tensor 类型
mydataset = tud.TensorDataset(X, Y) # 构建 Dataset 实例
print(mydataset.__len__()) # 获取 Dataset 样本总数,len() 函数也可
print(mydataset.__getitem__(0)) # 获取 Dataset 中第一个样本(索引为 0)
print(mydataset.__getitem__([1, 2, 3])) # 获取 Dataset 中第二、三个样本(索引为 1, 2, 3)
print(len(mydataset)) # 获取 Dataset 样本总数
print(mydataset[0]) # 获取 Dataset 中第一个样本(索引为 0)
print(mydataset[[1, 2, 3]]) # 获取 Dataset 中第二、三、四个样本(索引为 1, 2, 3)
print(mydataset[1:4]) # 获取 Dataset 中第二个到四个(第二、三、四)样本
2. DataLoader 类
引入
from torch.utils.data import DataLoader
主要参数:
-
dataset:上文中Dataset类型 -
batch_size:int类型,默认为 1 -
shuffle:bool类型,默认为False;在每一 epoch 训练模型前,是否 shuffle 数据。 -
drop_last:bool类型,默认为False;是否丢弃(不参与训练)每一 epoch 最后 1 个batch_size的样本。这是由于样本总量(sample size)不能整除 batch size,因此,最后一批的样本数量多数情况下会小于 batch size,设置drop_last=True,这最后这一批次的样本不参与模型训练。 -
collate_fn:
实例:
n_sample = len(mydataset)
batch_size = 16
dataloader = tud.DataLoader(mydataset, batch_size=batch_size) # DataLoader 实例
for i, (X_, Y_) in enumerate(dataloader):
# print(X_, Y_)
# X_, Y_ 为 Dataset 中 __getitem__() 方法的返回值,样本数量(X.shape[0])由 batch_size 决定
print(X_.size(), Y_.size())
# 等价于,但无法 shuffle
for i in range((n_sample - 1) // batch_size + 1):
X_, Y_ = mydataset[i * batch_size : (i + 1) * batch_size]
print(i, X_.size(), Y_.size())
2.1 关于 GPU 加速
若使用 GPU 加速,需要将数据加载到 GPU 设备上。可以在重构 Dataset 类时,在 __getitem__() 方法的最后,将返回的数据加载到 GPU 设备上。
也可以在 DataLoader 类中,过设置 collate_fn 参数实现,代码如下:
import torch
from torch.utils.data.dataloader import default_collate
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
dataloader = tud.DataLoader(
mydataset, batch_size=16, shuffle=True,
collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x))) # 将加载的数据置于 GPU 上
3. 基本工具
3.1 子集提取
以下函数返回的结果均为 Dataset 类(或 Subset 类型),可直接传入 DataLoader 中实现数据加载。
(1) random_split() 函数
random_split(dataset, lengths, generator=<torch._C.Generator object>) :随机划分数据集
主要参数:
datasetlengths:list 类型,每个子集的样本数量generator
实例:
random_split(MyDataset, [3, 7], generator=torch.Generator().manual_seed(42))
实例:
n_sample = X.size()[0]
n_train = int(n_sample*0.7)
n_valid = int(n_sample*0.2)
n_test = n_sample - n_train - n_valid
lens = [n_train, n_valid, n_test]
print(n_sample, lens)
d1, d2, d3 = tud.random_split(mydataset, lens, generator=None)
print(d1.__len__(), d2.__len__(), d3.__len__())
# Output: 105 30 15
由于 PyTorch 未提供非随机(即按顺序)划分样本的方法,可借助 Subset() 函数实现:
d1 = tud.Subset(mydataset, range(n_train))
d2 = tud.Subset(mydataset, range(n_train, n_train + n_valid))
d3 = tud.Subset(mydataset, range(n_train + n_valid, n_sample))
print(len(d1), len(d2), len(d3))
# Output: 105 30 15
(2) ConcatDataset() 函数
ConcatDataset(datasets):合并 dataset
- 参数:
datasets:List of dataset - 也可直接用
+合并两个数据集
实例:
# 方式一
dc1 = d1 + d2 + d3
# 方式二:与方式一等价
dc2 = tud.ConcatDataset([d1, d2, d3])
print(dc1.__len__(), dc1.__len__())
# Output: 105 105
(3) Subset() 函数
Subset(dataset, indices):从 dataset 中提取子集
注意: Dataset.__getitem__() 方法返回的是具体数据(tuple 类型),而 Subset() 函数返回的是 Dataset 类
实例:
d4 = tud.Subset(mydataset, [1,2,3,4])
print(d4.__len__())
# Output: 4
3.2 随机采样
SubsetRandomSampler(indices, generator=None)
WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)
代码实例
参考资料
PyTorch, TORCH.UTILS.DATA, site
5-1, Dataset和DataLoader, 20天吃掉那只Pytorch, site

浙公网安备 33010602011771号