PyTorch Dataset 与 DataLoader 用法（一个示例）

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import numpy as np
import torch

class OwnDataset(Dataset):
    """Map-style dataset that pairs input samples with labels by index.

    Args:
        x: Indexable collection of input samples (e.g. a tensor whose first
            dimension indexes samples).
        y: Indexable collection of labels, aligned element-for-element
            with ``x``.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
    """

    def __init__(self, x, y):
        # Reject misaligned inputs up front: otherwise a shorter ``y`` raises
        # IndexError part-way through an epoch, and a longer ``y`` has its
        # extra labels silently ignored (``__len__`` only looks at ``x``).
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples, i.e. ``len(x)``."""
        return len(self.x)

# Toy data: 100 samples with 2 features each, plus one binary label per sample.
# Cast explicitly before handing the data to PyTorch:
#   * np.random.rand returns float64, which becomes torch.float64, while
#     PyTorch layers (e.g. nn.Linear) default to float32 parameters.
#   * np.random.randint's default integer dtype is platform-dependent
#     (int32 on Windows with NumPy < 2), whereas losses such as
#     CrossEntropyLoss expect int64 (torch.long) class indices.
x_train = np.random.rand(100, 2).astype(np.float32)
y_train = np.random.randint(2, size=100).astype(np.int64)

# torch.from_numpy shares memory with the NumPy arrays (no copy).
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

dataset = OwnDataset(x_train, y_train)
# 100 samples with batch_size=40 gives batches of 40, 40 and 20 per epoch
# (drop_last defaults to False); shuffle=True reshuffles on every epoch.
train_loader = DataLoader(dataset, batch_size=40, shuffle=True)

for epoch in range(2):
    for i, (inputs, labels) in enumerate(train_loader):
        # inputs: (batch, 2) float32, labels: (batch,) int64 -- feed these to a model next.
        print("i:", i)
        print("inputs:", inputs)
        print("len(inputs):", len(inputs))

相关用法请参考官方文档：

https://pytorch.org/docs/stable/data.html?highlight=dataset#torch.utils.data.Dataset

posted @ 2020-12-07 19:27  qiezi_online  阅读(358)  评论(0)  编辑  收藏  举报