Pytorch——Dataset&Dataloader

Dataset&Dataloader

在利用 Pytorch 进行深度学习的训练时需要将数据进行打包,这就是 Dataset 与 Dataloader 的作用。 

Dataset 将数据进行包装,Dataloader 迭代包装好的数据并输出每次训练所需要的矩阵。 

官网教程: Datasets & DataLoaders — PyTorch Tutorials 1.12.1+cu102 documentation

Dataset

若要自定义自己的 Dataset,可直接继承 Dataset 类。

该类中必须有 3 个固定的函数:__init__,__len__,__getitem__    

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

 

__init__ 函数

annotations_file: 一个 csv 文件,形式是第 $i$ 行:文件名,标签

例如第 i 行:image1.png, 0

img_dir: 输入文件所在路径 

注意:访问数据时会访问 annotations_file 相对于 img_dir 的路径。 

即 img_dir/annotations_file 这种形式  

transform: 对输入数据所进行的变换 

target_transform:  对label 进行的变换    

__len__ 函数 

返回输入文件个数  

__getitem__ 函数 

返回下标为 idx 的输入文件以及输出标签 

 

Dataloader 

可以选择 batch_size,是否随机打乱。

在 windows 环境下,num_workes 要等于 0   

annotations_file = 'annotations'  
img_dir = ''     
dataloader = DataLoader(dataset = Datas(annotations_file, img_dir, transforms), batch_size = 3)   
 
for i, batch in enumerate(dataloader): 
    input_x = batch['image']      # 输入数据 
    labels = batch['label']       # 标签   
    

  

Dataloader 与 Dataset 打包数据实例  

import glob 
import random  
from torch.utils.data import Dataset   
from torch.utils.data import DataLoader
import sys 
import os  
import sys  
import numpy as np  

def transforms(input_x):  
    for i in range(input_x.shape[0]):  
        input_x[i][i] *= 0.1    
    return input_x  

class Datas(Dataset):
    def __init__(self, data_dir, transform = None, mode = 'train', unaligned=False):  
        self.file_A = sorted(glob.glob(os.path.join(data_dir, '%s/A' % mode, '*.npy')))
        self.file_B = sorted(glob.glob(os.path.join(data_dir, '%s/B' % mode, '*.npy')))  
        self.transform = transform    
        self.unaligned = unaligned   

    def __len__(self):
        return min(len(self.file_A), len(self.file_B))                 
 
    def __getitem__(self, idx): 
        data_A = np.load(self.file_A[idx % len(self.file_A)])      

        # unaligned == True: 需要乱序,即对于每一个 A 要随机配对一个 B 的情况。        
        if self.unaligned:  
            data_B = np.load(self.file_B[random.randint(0, len(self.file_B) - 1)])   
        else:  
            data_B = np.load(self.file_B[idx % len(self.file_B)])    

        if self.transform:
            data_A = self.transform(data_A)  
            data_B = self.transform(data_B)     
        return {'A': data_A, 'B': data_B}   

root = 'data'                 
dataloader = DataLoader(dataset = Datas(root, transforms, 'train', unaligned=True), batch_size = 3, shuffle=True)   

for i, batch in enumerate(dataloader): 
    input_A = batch['A']  
    input_B = batch['B']  
    # input_A, input_B 即为可以输入的训练数据      
    

  

posted @ 2022-08-28 23:14  guangheli  阅读(79)  评论(0)    收藏  举报