Pytorch——Dataset&Dataloader
Dataset&Dataloader
在利用 Pytorch 进行深度学习的训练时需要将数据进行打包,这就是 Dataset 与 Dataloader 的作用。
Dataset 将数据进行包装,Dataloader 迭代包装好的数据并输出每次训练所需要的矩阵。
官网教程: Datasets & DataLoaders — PyTorch Tutorials 1.12.1+cu102 documentation
Dataset
若要自定义自己的 Dataset,可直接继承 Dataset 类。
该类中必须有 3 个固定的函数:__init__,__len__,__getitem__
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
__init__ 函数
annotations_file: 一个 csv 文件,形式是第 $i$ 行:文件名,标签
例如第 i 行:image1.png, 0
img_dir: 输入文件所在路径
注意:访问数据时会访问 annotations_file 相对于 img_dir 的路径。
即 img_dir/annotations_file 这种形式
transform: 对输入数据所进行的变换
target_transform: 对label 进行的变换
__len__ 函数
返回输入文件个数
__getitem__ 函数
返回下标为 idx 的输入文件以及输出标签
Dataloader
可以选择 batch_size,是否随机打乱。
在 windows 环境下,num_workes 要等于 0
annotations_file = 'annotations'
img_dir = ''
dataloader = DataLoader(dataset = Datas(annotations_file, img_dir, transforms), batch_size = 3)
for i, batch in enumerate(dataloader):
input_x = batch['image'] # 输入数据
labels = batch['label'] # 标签
Dataloader 与 Dataset 打包数据实例
import glob
import random
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import sys
import os
import sys
import numpy as np
def transforms(input_x):
for i in range(input_x.shape[0]):
input_x[i][i] *= 0.1
return input_x
class Datas(Dataset):
def __init__(self, data_dir, transform = None, mode = 'train', unaligned=False):
self.file_A = sorted(glob.glob(os.path.join(data_dir, '%s/A' % mode, '*.npy')))
self.file_B = sorted(glob.glob(os.path.join(data_dir, '%s/B' % mode, '*.npy')))
self.transform = transform
self.unaligned = unaligned
def __len__(self):
return min(len(self.file_A), len(self.file_B))
def __getitem__(self, idx):
data_A = np.load(self.file_A[idx % len(self.file_A)])
# unaligned == True: 需要乱序,即对于每一个 A 要随机配对一个 B 的情况。
if self.unaligned:
data_B = np.load(self.file_B[random.randint(0, len(self.file_B) - 1)])
else:
data_B = np.load(self.file_B[idx % len(self.file_B)])
if self.transform:
data_A = self.transform(data_A)
data_B = self.transform(data_B)
return {'A': data_A, 'B': data_B}
root = 'data'
dataloader = DataLoader(dataset = Datas(root, transforms, 'train', unaligned=True), batch_size = 3, shuffle=True)
for i, batch in enumerate(dataloader):
input_A = batch['A']
input_B = batch['B']
# input_A, input_B 即为可以输入的训练数据

浙公网安备 33010602011771号