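Dataset and DataLoader
In PyTorch, Dataset and DataLoader are the two classes used to load data: a Dataset describes how to fetch a single sample, and a DataLoader iterates over a Dataset in batches.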
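As a rough illustration of how the two classes fit together, the sketch below wraps a small in-memory dataset in a DataLoader. The toy tensors, the batch size of 4, and the variable names are assumptions made for this example only; they are not part of the ants-and-bees walkthrough that follows.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (an assumption for illustration): 10 samples with 3 features each.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)

# A Dataset answers the question "what is sample i?" ...
toy_dataset = TensorDataset(features, labels)

# ... while a DataLoader decides how samples are grouped into batches.
toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False)

# Collect the size of every batch produced in one pass over the data.
batch_sizes = [len(batch_labels) for _, batch_labels in toy_loader]
print(batch_sizes)  # [4, 4, 2]
```

With 10 samples and a batch size of 4, the last batch is smaller because drop_last defaults to False.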
The dataset
This post uses the ants-and-bees dataset (hymenoptera_data). Its directory structure is as follows:
hymenoptera_data/
|-- train/
| |-- ants/
| | |-- 0013035.jpg
| | |-- 24335309_c5ea483bb8.jpg
| | |-- ... ...
| |-- bees/
| | |-- 16838648_415acd9e3f.jpg
| | |-- ... ...
|-- val/
| |-- ants/
| | |-- 8124241_36b290d372.jpg
| | |-- ... ...
| |-- bees/
| | |-- 26589803_5ba7000313.jpg
| | |-- ... ...
Dataset
First, let's look at the built-in documentation for Dataset:
from torch.utils.data import Dataset
help(Dataset)
---------------------------------
class Dataset(typing.Generic)
| Dataset(*args, **kwds)
|
| An abstract class representing a :class:`Dataset`.
|
| All datasets that represent a map from keys to data samples should subclass
| it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
| data sample for a given key. Subclasses could also optionally overwrite
| :meth:`__len__`, which is expected to return the size of the dataset by many
| :class:`~torch.utils.data.Sampler` implementations and the default options
| of :class:`~torch.utils.data.DataLoader`.
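As the documentation shows, Dataset is an abstract class (an interface). Subclasses must override __getitem__. Overriding __len__ is optional, but many Sampler implementations and the default DataLoader options expect it, so in practice both methods are usually implemented.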
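To make that contract concrete, here is a minimal sketch of a subclass that implements both methods. SquaresDataset and its contents are hypothetical and exist only for illustration; no files are read.

```python
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        # Reject out-of-range indices, as a list would.
        if not 0 <= idx < self.size:
            raise IndexError(idx)
        return idx, idx * idx

    def __len__(self):
        return self.size


squares = SquaresDataset(5)
print(len(squares))  # 5
print(squares[3])    # (3, 9)
```

Indexing calls __getitem__ and len() calls __len__, which is all that a map-style dataset needs to provide.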
Next, we subclass Dataset to load the ants-and-bees dataset:
from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):
    def __init__(self, root, is_ants=True, is_train=True):
        # The class sub-directory name doubles as the label.
        self.label = 'ants' if is_ants else 'bees'
        # Choose the training or the validation split.
        tv = 'train' if is_train else 'val'
        self.path = os.path.join(root, tv, self.label)
        self.img_list = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_list[idx]
        img_path = os.path.join(self.path, img_name)
        img = Image.open(img_path)
        return img, self.label  # return the image at index idx and its label

    def __len__(self):
        return len(self.img_list)
Use MyData to load and read the dataset:
val_ants = MyData(r'D:\dataset\hymenoptera_data', True, False)  # ants in the validation set
print("%d samples in total" % len(val_ants))
print(val_ants[0])  # read the first image (index 0) and its label
val_bees = MyData(r'D:\dataset\hymenoptera_data', False, False)  # bees in the validation set
print("%d samples in total" % len(val_bees))
print(val_bees[0])  # read the first image (index 0) and its label
---------------------------------
70 samples in total
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x375 at 0x1D58C7C0550>, 'ants')
83 samples in total
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=366x500 at 0x1D58C7C08E0>, 'bees')
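Because Dataset already defines the __add__ method, two datasets can be merged directly with the + operator; the result is a ConcatDataset that indexes the first dataset followed by the second.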
val_data = val_ants + val_bees
print("%d samples in total" % len(val_data))
# Label of the last ant sample, then label of the first bee sample.
print(val_data[len(val_ants)-1][1], val_data[len(val_ants)][1])
---------------------------------
153 samples in total
ants bees
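The + operator on a Dataset is shorthand for building a torch.utils.data.ConcatDataset. The sketch below shows the equivalence without needing any image files; ConstantLabelDataset is a hypothetical stand-in for MyData, and its sizes are made up for the example.

```python
from torch.utils.data import ConcatDataset, Dataset


class ConstantLabelDataset(Dataset):
    """Hypothetical stand-in for MyData: every sample carries the same label."""

    def __init__(self, label, size):
        self.label = label
        self.size = size

    def __getitem__(self, idx):
        return idx, self.label

    def __len__(self):
        return self.size


ants = ConstantLabelDataset('ants', 3)
bees = ConstantLabelDataset('bees', 2)

merged = ants + bees                    # goes through Dataset.__add__
explicit = ConcatDataset([ants, bees])  # the same thing, spelled out

print(type(merged).__name__)       # ConcatDataset
print(len(merged), len(explicit))  # 5 5
print(merged[2], merged[3])        # (2, 'ants') (0, 'bees')
```

Indices 0 to 2 map onto the first dataset and indices 3 to 4 onto the second, which is why the label switches exactly at len(ants). ConcatDataset also accepts more than two datasets, which can be tidier than chaining several + operations.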
