PyTorch 中 DataLoader 的使用
# Minimal demo: wrap in-memory arrays in a map-style Dataset and iterate over
# them in shuffled mini-batches with DataLoader.
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset


class MyDataSet(Dataset):
    """Map-style dataset that pairs each sample with its label.

    Args:
        train_data: Array-like of samples; the first axis indexes samples.
            Converted to a ``torch.float32`` tensor.
        label_data: Array-like of labels, one per sample. Converted to a
            ``torch.int`` tensor.

    Raises:
        ValueError: If ``label_data`` does not hold exactly one entry per
            sample in ``train_data``.
    """

    def __init__(self, train_data, label_data) -> None:
        super().__init__()
        self.data = torch.tensor(train_data, dtype=torch.float32)
        # NOTE(review): torch.int is 32-bit; PyTorch classification losses
        # typically want torch.long targets -- confirm before training with
        # these labels.
        self.label = torch.tensor(label_data, dtype=torch.int)
        self.lens = self.data.shape[0]
        # Fail fast on mismatched inputs instead of hitting an IndexError
        # halfway through an epoch (or silently ignoring surplus labels).
        if tuple(self.label.shape[:1]) != (self.lens,):
            raise ValueError(
                f"train_data has {self.lens} samples but label_data has "
                f"shape {tuple(self.label.shape)}"
            )

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        return self.data[index], self.label[index]

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return self.lens


# 20 random samples with 10 features each, and integer labels drawn from [0, 9).
data = np.random.randn(20, 10)
label = np.random.randint(0, 9, size=[20])

train_data = MyDataSet(data, label)
train_data_loader = DataLoader(train_data, batch_size=4, shuffle=True)

# The loop variable is deliberately not called ``data`` so the module-level
# feature array above is still intact after iteration.
for i, batch in enumerate(train_data_loader):
    print(i, batch)