【pytorch】土堆pytorch教程学习(五)torchvision 中的数据集的使用

torchvision 中的数据集使用

torchvision.datasets模块中提供了许多内置的数据集。

内置的数据集有 CIFAR10、MNIST、COCO等,更多可进入 pytorch 官网查看。

所有内置的数据集都继承了 torch.utils.data.Dataset 类,并且实现了 __getitem____len__

所有的数据集几乎都有相似的API。下面以 CIFAR10 数据集的使用为例来认识下内置数据集的用法。

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

'''
dataset = torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
Args:
root(string):数据集存放的根目录。
train(bool):如果True则从训练集创建数据集,False则从测试集创建数据集。
transform(callable):需要对图像进行的转换操作
target_transforms(callable):需要对 target 进行的转换操作
download(bool):True则从网络下载数据集到根目录。如果数据集已经存在,则不再下载。
'''
# 创建训练集
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transform, download=True) 
# 创建测试集
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transform, download=True)

img, target = test_set[0]  # 取出图像和target
print(img, test_set.classes[target])

# 在 tensorboard 里打开十张图像
writer = SummaryWriter('logs')
for i in range(10):
    img, target = test_set[i]
    writer.add_image('test_set', img, i)
writer.close()

内置数据集很方便地供我们下载使用。根据源码或者官方文档可以了解到创建数据集所需传入的参数,然后需要关注__getitem__ 方法返回的结果是什么

自定义数据集

自己定义的数据集可以参照内置数据集,即继承 torch.utils.data.Dataset 类,并且重写 __getitem____len__
__getitem__ 方法的作用是接收一个索引,返回索引对应的样本和标签,这部分是我们自己需要实现的逻辑。
__len__ 方法是返回所有样本的数量。
代码示例:
数据集分为训练集和测试集,存放在 dataset 目录,数据集又分为 ants 和 bees 两个目录,分别存放蚂蚁图片和蜜蜂图像,这两个目录名也分别作为标签。如下图所示:

import os
import torch
from PIL import Image

ants_bees_label = {'ants': 0, 'bees': 1}

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None):
        """
        初始化函数中调用 get_img_info() 方法获取图像信息
        :param data_dir: str, 数据集所在路径
        :param transform: torchvision.transforms 数据预处理
        """
        # data_info 存储所有的图像路径和标签
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, idx):
        """
        根据 idx 读取 self.data_info 中路径对应的数据,并进行 transform 操作
        :param idx: 下标
        :return: 样本和标签组成的元组
        """
        # 通过 idx 读取样本
        path_img, label = self.data_info[idx]
        # 一般将图像转为 ”RGB“ 方便归一化处理 0~255
        img = Image.open(path_img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)  # 在这里做transform,转为Tensor等等
        # 返回样本和标签组成的元组
        return img, label

    def __len__(self):
        # 返回所有样本的数量
        return len(self.data_info)
    @staticmethod
    def get_img_info(data_dir):
        """
        读取每一个图像的路径和对应的标签,组成一个元组,再把所有的元组作为 list返回。
        需要注意的是,模仿内置数据集,标签需要映射到 0 开始的整数:labels = {'ants':0, 'bees':1}。
        :param data_dir: 训练集或测试集的路径
        :return: 返回list
        """
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            # dirs ['ants', 'bees']
            for sub_dir in dirs:
                # 文件列表
                img_names = os.listdir(os.path.join(root, sub_dir))
                # 取出jpg结尾的文件
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                # 遍历图像,和标签组成元组,再存入list
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    # 图像的绝对路径
                    path_img = os.path.join(root, sub_dir, img_name)
                    # 映射标签 {'ants':0, 'bees':1}
                    label = ants_bees_label[sub_dir]
                    # 保存到 data_info 中
                    data_info.append((path_img, int(label)))

        if len(data_info) == 0:
            raise Exception('\ndata_dir:{} is an empty dir! please check your image paths!'.format(data_dir))

        return data_info
posted @ 2023-05-03 01:30  hzyuan  阅读(249)  评论(0)    收藏  举报