【pytorch】土堆pytorch教程学习(五)torchvision 中的数据集的使用
torchvision 中的数据集使用
在torchvision.datasets模块中提供了许多内置的数据集。
内置的数据集有 CIFAR10、MNIST、COCO等,更多可进入 pytorch 官网查看。
所有内置的数据集都继承了 torch.utils.data.Dataset 类,并且实现了 __getitem__ 和 __len__。
所有的数据集几乎都有相似的API。下面以 CIFAR10 数据集的使用为例来认识下内置数据集的用法。
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
'''
dataset = torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
Args:
root(string):数据集存放的根目录。
train(bool):如果True则从训练集创建数据集,False则从测试集创建数据集。
transform(callable):需要对图像进行的转换操作
target_transforms(callable):需要对 target 进行的转换操作
download(bool):True则从网络下载数据集到根目录。如果数据集已经存在,则不再下载。
'''
# 创建训练集
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transform, download=True)
# 创建测试集
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transform, download=True)
img, target = test_set[0] # 取出图像和target
print(img, test_set.classes[target])
# 在 tensorboard 里打开十张图像
writer = SummaryWriter('logs')
for i in range(10):
img, target = test_set[i]
writer.add_image('test_set', img, i)
writer.close()
内置数据集很方便地供我们下载使用。根据源码或者官方文档可以了解到创建数据集所需传入的参数,然后需要关注__getitem__ 方法返回的结果是什么。


自定义数据集
自己定义的数据集可以参照内置数据集,即继承 torch.utils.data.Dataset 类,并且重写 __getitem__ 和 __len__。
__getitem__ 方法的作用是接收一个索引,返回索引对应的样本和标签,这部分是我们自己需要实现的逻辑。
__len__ 方法是返回所有样本的数量。
代码示例:
数据集分为训练集和测试集,存放在 dataset 目录,数据集又分为 ants 和 bees 两个目录,分别存放蚂蚁图片和蜜蜂图像,这两个目录名也分别作为标签。如下图所示:

import os
import torch
from PIL import Image
ants_bees_label = {'ants': 0, 'bees': 1}
class MyDataset(torch.utils.data.Dataset):
def __init__(self, data_dir, transform=None):
"""
初始化函数中调用 get_img_info() 方法获取图像信息
:param data_dir: str, 数据集所在路径
:param transform: torchvision.transforms 数据预处理
"""
# data_info 存储所有的图像路径和标签
self.data_info = self.get_img_info(data_dir)
self.transform = transform
def __getitem__(self, idx):
"""
根据 idx 读取 self.data_info 中路径对应的数据,并进行 transform 操作
:param idx: 下标
:return: 样本和标签组成的元组
"""
# 通过 idx 读取样本
path_img, label = self.data_info[idx]
# 一般将图像转为 ”RGB“ 方便归一化处理 0~255
img = Image.open(path_img).convert('RGB')
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为Tensor等等
# 返回样本和标签组成的元组
return img, label
def __len__(self):
# 返回所有样本的数量
return len(self.data_info)
@staticmethod
def get_img_info(data_dir):
"""
读取每一个图像的路径和对应的标签,组成一个元组,再把所有的元组作为 list返回。
需要注意的是,模仿内置数据集,标签需要映射到 0 开始的整数:labels = {'ants':0, 'bees':1}。
:param data_dir: 训练集或测试集的路径
:return: 返回list
"""
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
# dirs ['ants', 'bees']
for sub_dir in dirs:
# 文件列表
img_names = os.listdir(os.path.join(root, sub_dir))
# 取出jpg结尾的文件
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
# 遍历图像,和标签组成元组,再存入list
for i in range(len(img_names)):
img_name = img_names[i]
# 图像的绝对路径
path_img = os.path.join(root, sub_dir, img_name)
# 映射标签 {'ants':0, 'bees':1}
label = ants_bees_label[sub_dir]
# 保存到 data_info 中
data_info.append((path_img, int(label)))
if len(data_info) == 0:
raise Exception('\ndata_dir:{} is an empty dir! please check your image paths!'.format(data_dir))
return data_info
本文来自博客园,作者:hzyuan,转载请注明原文链接:https://www.cnblogs.com/hzyuan/p/17368576.html

浙公网安备 33010602011771号