自定义Dataset

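Example (assumes `import numpy as np`, `import torch`, and `from torch.utils.data import Dataset, DataLoader`):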

class load_data(Dataset):
    """Map-style Dataset backed by a pair of whitespace-delimited text files.

    Features are read with ``np.loadtxt`` from ``<root>/<dataset>.txt`` and
    labels from ``<root>/<dataset>_label.txt``. Each item is an
    ``(x, y, idx)`` triple of tensors; ``idx`` is the requested index itself,
    so a consumer of shuffled batches can map samples back to their rows.

    Args:
        dataset: Base file name of the dataset, e.g. ``'dblp'``.
        root: Directory holding the text files. Defaults to ``'data'``,
            which reproduces the previously hard-coded location.
    """

    def __init__(self, dataset, root='data'):
        # Features are parsed as floats (float64), labels as integers.
        self.x = np.loadtxt('{}/{}.txt'.format(root, dataset), dtype=float)
        self.y = np.loadtxt('{}/{}_label.txt'.format(root, dataset), dtype=int)

    def __len__(self):
        # Number of samples = number of rows in the feature file.
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return ``(features, label, idx)`` for sample ``idx`` as tensors."""
        # np.array(...) makes a copy, so the returned tensors never share
        # memory with the arrays held by this dataset; it also turns the
        # scalar label and the int index into 0-d arrays for from_numpy.
        return (torch.from_numpy(np.array(self.x[idx])),
                torch.from_numpy(np.array(self.y[idx])),
                torch.from_numpy(np.array(idx)))

 

if __name__ == '__main__':
    # Smoke test: load the 'dblp' split and print the shape of each tensor
    # (features, labels, indices) in every shuffled batch of 256.
    source = load_data('dblp')
    batches = DataLoader(source, batch_size=256, shuffle=True)
    for features, labels, indices in batches:
        for tensor in (features, labels, indices):
            print(tensor.shape)

 

posted @ 2022-04-11 16:07  多发Paper哈  阅读(50)  评论(0)  编辑  收藏  举报
Live2D