pytorch (四) 数据加载

自定义加载数据

torch.utils.data.Dataset是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的两个方法即可:

  1. __len__:实现len(dataset)返回整个数据集的大小。
  2. __getitem__用来获取一些索引的数据,使dataset[i]返回数据集中第i个样本。
  3. 不覆写这两个方法会直接返回错误。
from torch.utils.data import DataLoader,Dataset
class MyData(Dataset): #继承Dataset
    def __init__(self, root_dir, transform=None): #初始化图片路径,一些变换操作。
        self.root_dir = root_dir   #文件目录
        self.transform = transform #变换
        self.images = os.listdir(self.root_dir)#目录里的所有文件
    
    def __len__(self):#返回整个数据集的大小
        return len(self.images)
    
    def __getitem__(self,index):#根据索引index返回dataset[index]
        image_index = self.images[index]#根据索引index获取该图片
        img_path = os.path.join(self.root_dir, image_index)#获取索引为index的图片的路径名
        img = io.imread(img_path)# 读取该图片
        label = img_path.split('\\')[-1].split('.')[0]# 根据该图片的路径名获取该图片的label
        sample = {'image':img,'label':label}#根据图片和标签创建字典
        
        if self.transform:
            sample = self.transform(sample)#对样本进行变换
        return sample #返回该样本

之后使用torch.utils.data.DataLoader加载数据

data = MyData('path',transform=None)#初始化类,设置数据集所在路径以及变换
dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据

加载时不要涉及预处理,把该预处理的都提前做完。比如resize事先处理完,crop,flip和normalize在加载时候处理。

posted @ 2020-07-23 15:41  木叶流云  阅读(221)  评论(0编辑  收藏  举报