TensorDataset和DataLoader

一、TensorDataset

语法:class torch.utils.data.TensorDataset(data_tensortarget_tensor)

作用:包装数据和目标张量(类似Python中的zip()函数),可通过第一维度索引两个张量恢复数据。故要保证两个tensor的第一维度是一致的。

参数:

  • data_tensor (Tensor) - 包含样本数据
  • target_tensor (Tensor) - 包含样本目标(标签)

二、DataLoader

语法:class torch.utils.data.DataLoader(datasetbatch_size=1shuffle=Falsesampler=None, num_workers=0collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

作用:组合数据集和采样器,并在数据集上提供单进程或多进程迭代器,每次抛出一批数据。

参数:

  • dataset (Dataset) – 加载数据的数据集。
  • batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
  • shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
  • sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  • num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
  • collate_fn (callable, optional) –
  • pin_memory (bool, optional) –
  • drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

三、Demo

 1 # 导入相应的依赖包
 2 import torch
 3 from torch.utils.data import TensorDataset
 4 from torch.utils.data import DataLoader
 5 
 6 #---------------------------step 0 数据准备--------------------------#
 7 a = torch.tensor([[1, 1, 1],
 8                   [2, 2, 2],
 9                   [3, 3, 3],
10                   [4, 4, 4],
11                   [5, 5, 5],
12                   [6, 6, 6],
13                  ])
14 print(a.shape)  # torch.Size([5, 3])
15 b = torch.tensor([0, 1, 0, 1, 1, 0])  # torch.Size([5])
16 print(b.shape)
17 
18 # --------------------step 1 调用TensorDataset--------------------#
19 dataset = TensorDataset(a, b)
20 
21 # ---------------------step 2 索引验证数据--------------------------#
22 
23 print(dataset) 
24 
25 
26 # 通过第一维度索引恢复两个张量数据
27 print(dataset[0])     # (tensor([1, 1, 1]), tensor(0))
28 print(dataset[1])     # (tensor([2, 2, 2]), tensor(1))
29 print(dataset[2])     # (tensor([3, 3, 3]), tensor(0))
30 print(dataset[:2])    # (tensor([[1, 1, 1], [2, 2, 2]]), tensor([0, 1]))
31 
32 # 循环取数据
33 for x_train, y_label in dataset:
34     print(x_train, y_label)
35 
36 #------------------------step 1 调用DataLoader---------------------#
37 # 用DataLoader进行数据封装
38 
39 dataloader = DataLoader(dataset=dataset, 
40                         batch_size=2,  # batchsize大小为2
41                         shuffle=True)  # 顺序打乱
42 
43 print(dataloader)   # <torch.utils.data.dataloader.DataLoader object at 0x0000022BE7116280>
44 
45 #-------------------------step 2 数据验证----------------------------#
46 
47 for batch, data in enumerate(dataloader, 1):   # enumerate返回两个值一个是序号,一个是数据, 1代表序号从1开始
48     x_data, y_label = data
49     print("batch:{batch}, x_data:{x_data}, y_label:{y_label}".format(batch=batch, x_data=x_data, y_label=y_label))

四、参考文献

https://pytorch-cn.readthedocs.io/zh/latest/package_references/data/

https://zhuanlan.zhihu.com/p/371516520

posted @ 2023-03-02 17:13  Kruskal  阅读(162)  评论(0编辑  收藏  举报