TensorDataset和DataLoader
一、TensorDataset
语法:class torch.utils.data.TensorDataset(data_tensor, target_tensor)
作用:包装数据和目标张量(类似Python中的zip()函数),可通过第一维度索引两个张量恢复数据。故要保证两个tensor的第一维度是一致的。
参数:
- data_tensor (Tensor) - 包含样本数据
- target_tensor (Tensor) - 包含样本目标(标签)
二、DataLoader
语法:class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
作用:组合数据集和采样器,并在数据集上提供单进程或多进程迭代器,每次抛出一批数据。
参数:
- dataset (Dataset) – 加载数据的数据集。
- batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
- shuffle (bool, optional) – 设置为
True
时会在每个epoch重新打乱数据(默认: False). - sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略
shuffle
参数。 - num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
- collate_fn (callable, optional) –
- pin_memory (bool, optional) –
- drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
三、Demo
1 # 导入相应的依赖包
2 import torch
3 from torch.utils.data import TensorDataset
4 from torch.utils.data import DataLoader
5
6 #---------------------------step 0 数据准备--------------------------#
7 a = torch.tensor([[1, 1, 1],
8 [2, 2, 2],
9 [3, 3, 3],
10 [4, 4, 4],
11 [5, 5, 5],
12 [6, 6, 6],
13 ])
14 print(a.shape) # torch.Size([5, 3])
15 b = torch.tensor([0, 1, 0, 1, 1, 0]) # torch.Size([5])
16 print(b.shape)
17
18 # --------------------step 1 调用TensorDataset--------------------#
19 dataset = TensorDataset(a, b)
20
21 # ---------------------step 2 索引验证数据--------------------------#
22
23 print(dataset)
24
25
26 # 通过第一维度索引恢复两个张量数据
27 print(dataset[0]) # (tensor([1, 1, 1]), tensor(0))
28 print(dataset[1]) # (tensor([2, 2, 2]), tensor(1))
29 print(dataset[2]) # (tensor([3, 3, 3]), tensor(0))
30 print(dataset[:2]) # (tensor([[1, 1, 1], [2, 2, 2]]), tensor([0, 1]))
31
32 # 循环取数据
33 for x_train, y_label in dataset:
34 print(x_train, y_label)
35
36 #------------------------step 1 调用DataLoader---------------------#
37 # 用DataLoader进行数据封装
38
39 dataloader = DataLoader(dataset=dataset,
40 batch_size=2, # batchsize大小为2
41 shuffle=True) # 顺序打乱
42
43 print(dataloader) # <torch.utils.data.dataloader.DataLoader object at 0x0000022BE7116280>
44
45 #-------------------------step 2 数据验证----------------------------#
46
47 for batch, data in enumerate(dataloader, 1): # enumerate返回两个值一个是序号,一个是数据, 1代表序号从1开始
48 x_data, y_label = data
49 print("batch:{batch}, x_data:{x_data}, y_label:{y_label}".format(batch=batch, x_data=x_data, y_label=y_label))
四、参考文献
https://pytorch-cn.readthedocs.io/zh/latest/package_references/data/
作者:kali
-------------------------------------------
个性签名:纸上学来终觉浅,绝知此事要躬行。
如果觉得这篇文章对你有小小的帮助的话,记得在右下角点个“推荐”哦,博主在此感谢!