Pytorch——Dataset&Dataloader
Dataset&Dataloader
在利用 Pytorch 进行深度学习的训练时需要将数据进行打包,这就是 Dataset 与 Dataloader 的作用。
Dataset 将数据进行包装,Dataloader 迭代包装好的数据并输出每次训练所需要的矩阵。
官网教程: Datasets & DataLoaders — PyTorch Tutorials 1.12.1+cu102 documentation
Dataset
若要自定义自己的 Dataset,可直接继承 Dataset 类。
该类中必须有 3 个固定的函数:__init__,__len__,__getitem__
import os import pandas as pd from torchvision.io import read_image class CustomImageDataset(Dataset): def __init__(self, annotations_file, img_dir, transform=None, target_transform=None): self.img_labels = pd.read_csv(annotations_file) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform def __len__(self): return len(self.img_labels) def __getitem__(self, idx): img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) image = read_image(img_path) label = self.img_labels.iloc[idx, 1] if self.transform: image = self.transform(image) if self.target_transform: label = self.target_transform(label) return image, label
__init__ 函数
annotations_file: 一个 csv 文件,形式是第 $i$ 行:文件名,标签
例如第 i 行:image1.png, 0
img_dir: 输入文件所在路径
注意:访问数据时会访问 annotations_file 相对于 img_dir 的路径。
即 img_dir/annotations_file 这种形式
transform: 对输入数据所进行的变换
target_transform: 对label 进行的变换
__len__ 函数
返回输入文件个数
__getitem__ 函数
返回下标为 idx 的输入文件以及输出标签
Dataloader
可以选择 batch_size,是否随机打乱。
在 windows 环境下,num_workes 要等于 0
annotations_file = 'annotations' img_dir = '' dataloader = DataLoader(dataset = Datas(annotations_file, img_dir, transforms), batch_size = 3) for i, batch in enumerate(dataloader): input_x = batch['image'] # 输入数据 labels = batch['label'] # 标签
Dataloader 与 Dataset 打包数据实例
import glob import random from torch.utils.data import Dataset from torch.utils.data import DataLoader import sys import os import sys import numpy as np def transforms(input_x): for i in range(input_x.shape[0]): input_x[i][i] *= 0.1 return input_x class Datas(Dataset): def __init__(self, data_dir, transform = None, mode = 'train', unaligned=False): self.file_A = sorted(glob.glob(os.path.join(data_dir, '%s/A' % mode, '*.npy'))) self.file_B = sorted(glob.glob(os.path.join(data_dir, '%s/B' % mode, '*.npy'))) self.transform = transform self.unaligned = unaligned def __len__(self): return min(len(self.file_A), len(self.file_B)) def __getitem__(self, idx): data_A = np.load(self.file_A[idx % len(self.file_A)]) # unaligned == True: 需要乱序,即对于每一个 A 要随机配对一个 B 的情况。 if self.unaligned: data_B = np.load(self.file_B[random.randint(0, len(self.file_B) - 1)]) else: data_B = np.load(self.file_B[idx % len(self.file_B)]) if self.transform: data_A = self.transform(data_A) data_B = self.transform(data_B) return {'A': data_A, 'B': data_B} root = 'data' dataloader = DataLoader(dataset = Datas(root, transforms, 'train', unaligned=True), batch_size = 3, shuffle=True) for i, batch in enumerate(dataloader): input_A = batch['A'] input_B = batch['B'] # input_A, input_B 即为可以输入的训练数据