【pytorch】土堆pytorch教程学习(二)加载数据
Pytorch加载数据初认识
pytorch 中加载数据主要涉及两个类:Dataset 和 Dataloader。
-
Dataset提供一种方式去获取数据及其label -
Dataloader构建可迭代的数据装载器,为网络提供不同的数据形式
Dataset
Dataset 实现的功能:
- 获取每个数据及其label
- 获取数据长度
每个数据集都需要继承 torch.utils.data.Dataset 类,并且重写 __getitem__ 和 __len__。
数据存放在 dataset/train里,分为两个目录 ants 和 bees,也分别是数据的标签,如下图所示:

from PIL import Image
from torch.utils.data import Dataset
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 根目录
self.label_dir = label_dir # 标签目录
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path) # 图片路径列表
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.path, img_name)
img = Image.open(img_item_path) # 获取数据
label = self.label_dir # 获取label
return img, label
def __len__(self):
return len(self.img_path) # 获取数据集长度
# test
root_dir = 'dataset/train'
ants_label_dir = 'ants'
bees_label_dir = 'bees'
# 生成两个数据集
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset # 拼接两个数据集
img1, label1 = ants_dataset[0]
img1.show()
print('label1:', label1)
img2, label2 = train_dataset[130]
img2.show()
print('label2:', label2)
本文来自博客园,作者:hzyuan,转载请注明原文链接:https://www.cnblogs.com/hzyuan/p/17344219.html

浙公网安备 33010602011771号