了解pytorch读取图片的流程和机制。

1,Dataset类

2,制作图片数据的索引

3,构建自定义Dataset类

Dataset类

pytorch读取图片,主要通过Dataset类。Dataset类作为所有的dataset的基类存在,所有的datasets都需要继承它。

 所有的子类,都要重写getitem函数和len函数。

from torch.utils.data import Dataset
import os
from PIL import Image

class mydata(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(root_dir,label_dir)
        self.img_path=os.listdir(self.path)

    def __getitem__(self, idx):
        img_name=self.img_path[idx]
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        img=Image.open(img_item_path)
        label=self.label_dir
        return img,label   # 返回图片和标签

    def __len__(self):
        return len(self.img_path)

root_dir='dataset/train'
label_dir='ants'
data=mydata(root_dir,label_dir)
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
"""
构建Dataset子类,
pytorch读取图片,主要是通过Dataset类,Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它,类似于c++中的虚基类。

"""

class MyDataset(Dataset): # 继承Dataset类
    def __init__(self, txt_path, transform=None, target_transform=None): # 定义txt_path参数
        fh = open(txt_path, 'r') # 读取txt文件
        imgs = []  # 定义imgs的列表
        for line in fh:
            line = line.rstrip() # 默认删除的是空白符('\n', '\r', '\t', ' ')
            words = line.split() # 默认以空格、换行(\n)、制表符(\t)进行分割,大多是"\"
            imgs.append((words[0], int(words[1]))) # 存放进imgs列表中

        self.imgs = imgs        # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index] # fn代表图片的路径,label代表标签
        img = Image.open(fn).convert('RGB')     # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1   参考:https://blog.csdn.net/icamera0/article/details/50843172

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.imgs)   # 返回图片的长度

 img = Image.open(fn).convert('RGB') 如果不使用.convert('RGB')进行转换的话,读出来的图像是RGBA四通道的,A通道为透明通道,该对深度学习模型训练来说暂时用不到,因此使用convert('RGB')进行通道转换。

 posted on 2024-03-08 20:58  会飞的金鱼  阅读(72)  评论(0)    收藏  举报