torch的数据集

torch.utils.data.TensorDataset 这个类可以初始化数据集

例子:

import torch
from torch.utils import data
# torch.utils.data.dataset 类的使用
x = torch.arange(12, dtype=torch.float32).reshape(6, 2)
y = torch.arange(6, dtype=torch.float32).reshape(6, 1)
# 初始化数据集,需要两个参数,x是特征,y是标签
torch_dataset = data.TensorDataset(x, y)
 
# 使用data.DataLoader 导入数据集,得到可迭代对象
train_iter = data.DataLoader(
    dataset = torch_dataset,  # 数据集
    batch_size = 2,       # 批量大小
    shuffle=True,         # 是否打乱
    num_workers=2,        # 读取线程
)
# 读取数据
for i in train_iter:
  for y in i:
    print(y)
  print('----------')
posted @ 2023-01-06 10:26  __sunshine  阅读(67)  评论(0)    收藏  举报