Pytorch将数据打包

借助TensorDataset直接将数据包装成dataset类

直接使用 TensorDataset 来将数据包装成Dataset类,再使用dataloader。

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
 
src = torch.sin(torch.arange(1, 1000, 0.1))
trg = torch.cos(torch.arange(1, 1000, 0.1))
 
data = TensorDataset(src, trg)
data_loader = DataLoader(data, batch_size=5, shuffle=False)
for i_batch, batch_data in enumerate(data_loader):
    print(i_batch)  # 打印batch编号
    print(batch_data[0].size())  # 打印该batch里面src
    print(batch_data[1].size())  # 打印该batch里面trg

output:

0
torch.Size([5])
torch.Size([5])
1
torch.Size([5])
torch.Size([5])
...

 

希望后续多看到这种,若有好的资源可以在评论区留言

参考:https://blog.csdn.net/weixin_42468475/article/details/108714940

 

posted @ 2021-10-18 23:09  多发Paper哈  阅读(299)  评论(0编辑  收藏  举报
Live2D