Pytorch框架学习---(2)输入数据操作

本节讲述Data如何利用Pytorch提供的DataLoader进行读取,以及Transforms的图片处理方式。 【文中思维导图采用MindMaster软件】

注意:笼统总结Transforms,目前仅具体介绍裁剪、翻转、标准化,后续随着代码需要,再逐步更新。

一. 数据读取(DataLoader和Dataset)

1.DataLoader

  我们采用Pytorch提供的DataLoader进行数据Batch封装,其中需要定义dataset类。

自定义的dataset类需要复写def getitem(self, index):函数!!!

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=Batch_Size,
                          shuffle=True,
                          num_workers=4,  # num_worker = 4 * GPU个数   为了数据进来更快一些
                          pin_memory=True,  # 也是为了数据输入更快,但是会对增加显存负担 !!!
                          drop_last=True)  # droplast:最后一个批次不满足设定数目Batch_Size,则舍弃
for epoch in range(Max_Epoch):
    for i, (inputs, labels) in enumerate(train_loader):  # 每次调用一个batch,后台索引
# 也可以采用next(iter(train_loader)), 读取一个批次

  在网络运行时,我们采用enumerate函数,进行迭代,这里会:

  • 进入DataLoader数据装载器;

  • 判断参数,是否采用多进程处理;

  • 调用Sampler函数,根据输入数据个数(由Dataset类中def len()函数得到),随机获取index索引值;

  • 进入我们定义的Dataset类,调用def getitem(),根据index获取数据,返回;

  • 调用collate_fn()函数整理数据,最终得到Batch。

2.代码(如何将电脑中的数据送入网络?)

注意:这里数据集已经分类好,文件夹已经各自建立,不包含划分数据的函数!!

import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
import torchvision.transforms as transforms

category = {"0": 0, "1": 1, "1_enhanced": 2, "1_enhanced_2": 3, "0_enhanced_1":4}  # 定义标签,"文件夹名":标签

class my_dataset(Dataset):
    '''根据自己的数据,进行读取,Dataset类创建Pytorch数据集类型'''
    '''
    Args:
        data_dir: 数据地址(训练集、验证集、测试集)
        transform: torchvision.transforms(各种变换、以及Totensor)      
    Return:
        read_data  根据dataloader的索引获取数据
        len(self.data_info)  数据个数
    '''

    def __init__(self, data_dir, transform=None):
        self.transforms = transform
        self.data_info = self.get_dataset_info(data_dir)  # 获取所有数据路径和对应的标签,方便dataloader 用index批量处理

    def __getitem__(self, index):  # 当dataloader sampler得到index,根据该index索引dataset中数据
        path_data, label = self.data_info[index]
        read_data = Image.open(path_data).convert("RGB")  # PIL-->RGB(0-256)

        if self.transforms is not None:
            read_data = self.transforms(read_data)

        return read_data, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod  # 定义该函数为静态类型,不用实例化类也可调用
    def get_dataset_info(data_dir):
        data_info = list()  # 最终包含所有图片、标签(每一行)
        for root, dirs, files in os.walk(data_dir):  # 获取当前文件夹的父目录、当前文件夹下所有文件名、所有内部文件
            for sub_dir in dirs:  # 遍历所有类别
                each_cate = os.listdir(os.path.join(root, sub_dir))  # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。

                for i in range(len(each_cate)):  # 遍历每一个类别下的图片数据,将标签一同嵌入
                    each_data_name = each_cate[i]
                    each_data_path = os.path.join(root, sub_dir, each_data_name)
                    each_label = category[sub_dir]

                    data_info.append((each_data_path, int(each_label)))

        return data_info

二.数据预处理(torchvision.transforms)

1.torchvision

2.transforms.Compose([......])组合

  计算机将按照Compose中定义的transforms操作,依次进行数据处理。

train_transforms = transforms.Compose([
    transforms.Resize((75, 75)),
    transforms.ToTensor(),  # (H x W x C) [0, 255] to a torch.FloatTensor (C x H x W) [0.0, 1.0]
    transforms.Normalize(mean=norm_mean,std=norm_std)  # 逐通道归一化,注意通道数
])

3.各种transforms处理方式

  本节目前仅介绍:标准化Normalize、图像裁剪Crop、旋转翻转。

(1)数据标准化

transforms.Normalize(mean, std, inplace=False)  #逐通道对图像进行标准化,mean:(M1,...,Mn) and std: (S1,..,Sn) for n channels
# input[channel] = (input[channel] - mean[channel]) / std[channel]

(2)裁剪

a)从中心进行裁剪

transforms.CenterCrop(size=32)  # 由图像中心进行裁剪,size=32*32

b)随机裁剪

transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')
# 先填充再随机裁剪
# padding:设置填充大小,数值a --> 上下左右填充a个像素,(a,b)--> 左右a上下b, (a,b,c,d) --> 左a上b右c下d
# padding_mode:填充模式:
      # constant:像素值由fill参数设定;
      # edge:由图像边缘像素决定;
      # reflect:镜像填充,最后一个像素不镜像;
      # symmetric:镜像填充,最后一个像素镜像。

c)随机面积、随机长宽比裁剪图片

transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR)
# 先选择scale,再ratio,再判断size,是否需要interpolation进行resized
# scale=(0.08, 1.0):随机裁剪面积比例,范围内随机选
# ratio=(3. / 4., 4. / 3.):随机长宽比
# interpolation:插值方法

d)上下左右中心随机裁剪5张图片

transforms.FiveCrop(size)  # 从上下左右中心各裁剪出五张图片
transforms.TenCrop(size, vertical_flip=False)  # 先进行FiveCrop(),再对五张图片进行水平/垂直镜像,获得10张图片

注意:这里返回的是tuple()类型,需要按行拼接起来,送入下游transforms处理。

>>> transform = Compose([
         >>>    TenCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops

有问题:当采用数据增强时,一方面采用TenCrop形式,另一方面采用其他数据变换,一同送入Dataloader时会产生错误,因为维度不一致,其他数据变换在dataset中为三维(channel,H,W),而TenCrop却是四维(ncrops,channel,H,W),于是当迭代获取Batch时会由于维度不匹配程序报错。
解决方法:【等后续找到再来写,手动狗头微笑】

(3)翻转、旋转

transforms.RandomHorizontalFlip(p=0.5)  # 依概率进行水平(左右)翻转
transforms.RandomVerticalFlip(p=0.5)  # 依概率进行垂直(上下)翻转
transforms.RandomRotation(degrees, resample=False, expand=False, center=None)  # 随机旋转图片
      # degrees:旋转角度,若为a,则在(-a,a)之间二选一,若为(a, b),则二选一
      # expand:是否扩大图片(因为旋转过后可能会丢失图片某一块),仅针对中心点旋转
      # center:旋转点设置,默认中心点

(4)对各种变换的组合--》选择操作(如RandomChoice)

transforms.RandomChoice([transforms1, transforms2, ......])  # 随机挑选一个
transforms.RandomApply([transforms1, transforms2, ......], p=0.5)  # 依概率执行整个一组(要么执行,要么不执行)
transforms.RandomOrder([transforms1, transforms2, ......])  # 对一组操作进行打乱顺序,再去执行这一组

4.自定义Transforms方法

class YourTransforms(object):
    def __init__(self,Arg1,Arg2):
        '''传参数'''
    def __call__(self, x):
        '''定义该Transforms方法'''
        return x
posted @ 2020-06-17 21:52  steven_zhao1001  阅读(2063)  评论(2编辑  收藏  举报