PyTorch 中 DataLoader 的使用

 

from torch.utils.data.dataset import Dataset
import torch
import numpy as np
from torch.utils.data.dataloader import DataLoader

class MyDataSet(Dataset):
    """Map-style dataset pairing each feature row with its label.

    Features are stored as a float32 tensor and labels as a tensor of
    ``label_dtype``; both are built with ``torch.tensor`` at construction.

    Args:
        train_data: Array-like of samples, indexed along its first dimension.
        label_data: Array-like of labels, one per sample (same first-dimension
            length as ``train_data``).
        label_dtype: dtype of the label tensor. Defaults to ``torch.int``
            (int32) for backward compatibility; pass ``torch.long`` when the
            labels are class indices for losses such as ``nn.CrossEntropyLoss``.

    Raises:
        ValueError: if ``train_data`` and ``label_data`` have a different
            number of entries along the first dimension.
    """

    def __init__(self, train_data, label_data, label_dtype=torch.int):
        self.data = torch.tensor(train_data, dtype=torch.float32)
        self.label = torch.tensor(label_data, dtype=label_dtype)
        self.lens = self.data.shape[0]
        # __len__ is driven by the features only: a shorter label array would
        # otherwise fail with IndexError mid-epoch, and a longer one would
        # silently leave labels unused.
        if self.label.shape[0] != self.lens:
            raise ValueError(
                f"train_data has {self.lens} samples but label_data has "
                f"{self.label.shape[0]} labels"
            )

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        return self.data[index], self.label[index]

    def __len__(self):
        """Return the number of samples."""
        return self.lens
# Demo: 20 random samples with 10 features each, plus one integer label each.
# NOTE: np.random.randint's upper bound is exclusive, so labels fall in 0..8;
# use 10 instead of 9 if ten classes (0..9) were intended.
data = np.random.randn(20, 10)
label = np.random.randint(0, 9, size=[20])
train_data = MyDataSet(data, label)
# Shuffled mini-batches of 4 samples; 20 divides evenly, so every batch is full.
train_data_loader = DataLoader(train_data, batch_size=4, shuffle=True)

# The loop variable is named `batch` so it does not rebind the module-level
# `data` array defined above.
for i, batch in enumerate(train_data_loader):
    print(i, batch)

 

posted @ 2022-04-05 13:16  一笑任逍遥  阅读(88)  评论(0)  编辑  收藏  举报