from torch.utils.data.dataset import Dataset
import torch
import numpy as np
from torch.utils.data.dataloader import DataLoader
class MyDataSet(Dataset):
    """In-memory map-style dataset pairing feature rows with labels.

    Both inputs are copied into new tensors via ``torch.tensor``; sample ``i``
    is ``(self.data[i], self.label[i])``.

    Args:
        train_data: Array-like of features; the first axis indexes samples.
        label_data: Array-like of labels; its first axis must have the same
            length as ``train_data``'s.
        data_dtype: dtype of the feature tensor (default ``torch.float32``).
        label_dtype: dtype of the label tensor. Defaults to ``torch.int``
            (int32) for backward compatibility; pass ``torch.long`` when the
            labels are class indices for ``nn.CrossEntropyLoss`` /
            ``nn.NLLLoss``, which expect int64 targets.

    Raises:
        ValueError: if either input is a scalar (no sample axis) or the two
            inputs disagree on the number of samples.
    """

    def __init__(self, train_data, label_data, *, data_dtype=torch.float32,
                 label_dtype=torch.int):
        self.data = torch.tensor(train_data, dtype=data_dtype)
        self.label = torch.tensor(label_data, dtype=label_dtype)
        if self.data.dim() == 0 or self.label.dim() == 0:
            raise ValueError(
                "train_data and label_data must have a leading sample axis"
            )
        # Fail fast on mismatched inputs: otherwise surplus labels are silently
        # ignored and missing ones surface as an IndexError mid-iteration.
        if self.data.shape[0] != self.label.shape[0]:
            raise ValueError(
                f"train_data has {self.data.shape[0]} samples but "
                f"label_data has {self.label.shape[0]}"
            )
        self.lens = self.data.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        return self.data[index], self.label[index]

    def __len__(self):
        """Return the number of samples (length of the first axis)."""
        return self.lens
# Demo: wrap random features/labels in MyDataSet and iterate shuffled batches.
NUM_SAMPLES = 20
NUM_FEATURES = 10
# np.random.randint's upper bound is exclusive, so labels are drawn from 0..8.
# NOTE(review): confirm 9 label values (not 10) was intended.
LABEL_HIGH = 9
BATCH_SIZE = 4

data = np.random.randn(NUM_SAMPLES, NUM_FEATURES)
label = np.random.randint(0, LABEL_HIGH, size=[NUM_SAMPLES])
train_data = MyDataSet(data, label)
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# Use a distinct loop name so the `data` feature array above is not clobbered.
for i, batch in enumerate(train_data_loader):
    # `batch` is the DataLoader's collated (features, labels) for this step.
    print(i, batch)