pytorch加载数据集

分为两种方式

第一种:通过加载pytorch官方常用的数据集

# dataloader = torch.utils.data.DataLoader(
#     datasets.MNIST(
#         "../../data/mnist",
#         train=True,
#         download=True,
#         transform=transforms.Compose(
#             [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
#         ),
#     ),
#     batch_size=opt.batch_size,
#     shuffle=True,
# )

第二种:通过加载本地的数据集

train_loader=datasets.ImageFolder(args.datasets,
                                transform=transforms.Compose([
                                    transforms.Resize(opt.img_size),
                                    transforms.Grayscale(1),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.5], [0.5])]),
                                )
dataloader = torch.utils.data.DataLoader(
        train_loader, batch_size=args.batch_size, shuffle=True)

  

posted @ 2021-05-13 18:57  荼离伤花  阅读(225)  评论(0编辑  收藏  举报