基于PyTorch实现图像去模糊-学习

基于PyTorch实现图像去模糊-学习

任务描述

  • 相机的抖动、快速运动的物体都会导致拍摄出模糊的图像,景深变化也会使图像进一步模糊。
  • 对于传统方法来说,要想估计出每个像素点对应的 “blur kernel” 几乎是不可行的。因此,传统方法常常需要对模糊源作出假设,将 “blur kernel” 参数化。显然,这类方法不足以解决实际中各种复杂因素引起的图像模糊。
  • 卷积神经网络能够从图像中提取出复杂的特征,从而使得模型能够适应各种场景。
  • 本教程以 CVPR2017 的 《Deep Multi-scale Convolutional Neural Network for Dynamic Scene Deblurring》 为例,来完成图像去模糊的任务。

方法概述

  • 利用pytorch深度学习工具实现一个端到端的图像去模糊模型,通过参数设置、加载数据、构建模型、训练模型和测试用例依次实现一个图像去模糊工具,在训练和预处理过程中通过可视化监督训练过程。
  • 模型采用了残差形式的CNN,输入和输出都采用高斯金字塔(Gaussian pyramid)的形式。
  • 整个网络结构由三个相似的CNN构成,分别对应输入金字塔中的每一层。网络最前面是分辨率最低的子网络(coarest level network),在这个子网络最后,是“upconvolution layer”,将重建的低分辨率图像放大为高分辨率图像,然后和高一层的子网络的输入连接在一起,作为上层网络的输入。

image-20211120221821489

%config Completer.use_jedi = False
#!pip install pytorch_msssim -i https://pypi.tuna.tsinghua.edu.cn/simple
# !jupyter nbextension enable --py widgetsnbextension
import torch

import numpy as np

import os
import random
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from tensorboardX import SummaryWriter
from torchsummary import summary
from torch.optim import lr_scheduler
from torch.utils import data
from torchvision import transforms
from tqdm.notebook import tqdm


import pytorch_msssim # 用于计算指标 ssim 和 mssim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

参数设置

class Config():
    def __init__(self,name="Configs"):
        # train set
        self.data_dir = 'datasets/train' # 训练集目录
        self.patch_size = 256  # 输入模型的patch的尺寸
        self.batch_size= 2 #16 # 训练时每个batch中的样本个数
        self.n_threads = 1 # 用于加载数据的线程数
        
        # test set
        self.test_data_dir = 'datasets/test' # 测试集目录
        self.test_batch_size=1 # 测试时的 batch_size
        
        # model
        self.multi = True # 模型采用多尺度方法True
        self.skip = True # 模型采用滑动连接方法
        self.n_resblocks = 3 #9  # resblock的个数
        self.n_feats = 8 #64  #feature map的个数
        
        # optimization 
        self.lr = 1e-4  # 初始学习率
        self.epochs =5 #800 # 训练epoch的数目
        self.lr_step_size = 600 #采用步进学习率策略所用的 step_size
        self.lr_gamma = 0.1 #每 lr_step_size后,学习率变成 lr * lr_gamma
        
        # global
        self.name = name #配置的名称
        self.save_dir = 'temp/result'  # 保存训练过程中所产生数据的目录
        self.save_cp_dir = 'temp/models'  # 保存 checkpoint的目录
        self.imgs_dir = 'datasets/pictures'  # 此 notebook所需的图片目录
        
        
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        if not os.path.exists(self.save_cp_dir):
            os.makedirs(self.save_cp_dir)
#         if not os.path.exists(self.data_dir):
#             os.makedirs(self.data_dir)
#         if not os.path.exists(self.test_data_dir):
#             os.makedirs(self.test_data_dir)

args =  Config(name="image-deblurring")

数据准备

  • 数据集展示
  • 数据增强
  • 构造 dataset类
  • 数据加载 dataloader

数据集展示

sample_idx = 1 # 样本编号
blur_path = os.path.join(args.imgs_dir,f"blur/test{sample_idx}.png")  # 模糊图片
sharp_path = os.path.join(args.imgs_dir,f"sharp/test{sample_idx}.png") # 去模糊图片
blur_img = plt.imread(blur_path)
sharp_img = plt.imread(sharp_path)
plt.figure(figsize=(10,4))
plt.subplot(121)
plt.imshow(blur_img)
plt.subplot(122)
plt.imshow(sharp_img)
plt.show()

image-20211120221933618

数据增强

为了防止过拟合,需要对数据集进行数据增强,增强方式如下所示,对每一个输入图像,都将其进行随机角度旋转,旋转的角度在 [0, 90, 180, 270] 中随机选取。除此之外,考虑到图像质量下降,对 HSV 颜色空间的饱和度乘以 0.8 到 1.2 内的随机数

def augment(img_input, img_target):
    degree = random.choice([0,90,180,270])
    img_input = transforms.functional.rotate(img_input,degree)
    img_target = transforms.functional.rotate(img_target,degree)
    
    # color augmentation
    img_input = transforms.functional.adjust_gamma(img_input,1)
    img_target = transforms.functional.adjust_gamma(img_target,1)
    sat_factor = 1 + (0.2 - 0.4* np.random.rand())
    img_input = transforms.functional.adjust_saturation(img_input,sat_factor)
    img_target = transforms.functional.adjust_saturation(img_target,sat_factor)
    
    return img_input,img_target
img_input = Image.open(blur_path)
img_target = Image.open(sharp_path)

img_aug_input,img_aug_target = augment(img_input,img_target)
plt.figure(figsize=(10,5))
plt.subplot(121)
plt.imshow(img_aug_input)
plt.subplot(122)
plt.imshow(img_aug_target)
plt.show()

image-20211120221955392

构造 dataset类

对每一个输入图像,对齐进行随机裁剪,得到patch_size大小的输入

def getPatch(img_input,img_target,patch_size):
    w,h = img_input.size
    p = patch_size
    x = random.randrange(0,w-p +1)
    y = random.randrange(0,h -p +1)
    
    img_input = img_input.crop((x,y,x+p,y+p))
    img_target = img_target.crop((x,y,x+p,y+p))
    
    return img_input,img_target
class ImgMission(data.Dataset):
    def __init__(self,data_dir, patch_size=256, is_train= False, multi=True):
        super(ImgMission,self).__init__()
        
        self.is_train = is_train  #是否是训练集
        self.patch_size = patch_size # 训练时 patch的尺寸
        self.multi = multi  # 是否采用多尺度因子,默认采用
        
        self.sharp_file_paths = []
        sub_folders = os.listdir(data_dir)
        print(sub_folders)
        
        for folder_name in sub_folders:
            sharp_sub_folder = os.path.join(data_dir,folder_name,'sharp')
            sharp_file_names = os.listdir(sharp_sub_folder)
            # print(sharp_file_names)
            for file_name in sharp_file_names:
                sharp_file_path = os.path.join(sharp_sub_folder,file_name)
                # print(sharp_file_path)
                self.sharp_file_paths.append(sharp_file_path)
                
        self.n_samples = len(self.sharp_file_paths)
        
    def get_img_pair(self,idx):
        sharp_file_path = self.sharp_file_paths[idx]
        blur_file_path = sharp_file_path.replace("sharp","blur")
        # print(blur_file_path)
        img_input = Image.open(blur_file_path).convert('RGB')
        img_target = Image.open(sharp_file_path).convert('RGB')
        
        return img_input,img_target
    
    def __getitem__(self,idx):
        img_input,img_target = self.get_img_pair(idx)
        
        if self.is_train:
            img_input,img_target = getPatch(img_input,img_target, self.patch_size)
            img_input,img_target=  augment(img_input,img_target)
            
            
        # 转换为 tensor类型
        input_b1 = transforms.ToTensor()(img_input)
        target_s1 = transforms.ToTensor()(img_target)
        
        H = input_b1.size()[1]
        W= input_b1.size()[2]
        
        if self.multi:
            input_b1 = transforms.ToPILImage()(input_b1)
            target_s1 = transforms.ToPILImage()(target_s1)
            
            input_b2 = transforms.ToTensor()(transforms.Resize([int(H/2), int(W/2)])(input_b1))
            input_b3 = transforms.ToTensor()(transforms.Resize([int(H/4), int(W/4)])(input_b1))
            
            # 只对训练集进行数据增强
            if self.is_train:
                target_s2 = transforms.ToTensor()(transforms.Resize([int(H/2), int(W/2)])(target_s1))
                target_s3 = transforms.ToTensor()(transforms.Resize([int(H/4), int(W/4)])(target_s1))
            else:
                target_s2 = []
                target_s3 = []
                
            input_b1 = transforms.ToTensor()(input_b1)
            target_s1 = transforms.ToTensor()(target_s1)
            
            return {
                'input_b1': input_b1, # 参照下文的网络结构,输入图像的尺度 1
                'input_b2': input_b2, # 输入图像的尺度 2
                'input_b3': input_b3, # 输入图像的尺度 3
                'target_s1': target_s1, # 目标图像的尺度 1
                'target_s2': target_s2, # 目标图像的尺度 2
                'target_s3': target_s3 # 目标图像的尺度 3
            }
        else:
            return {'input_b1': input_b1, 'target_s1': target_s1}
            
        
        
    def __len__(self):
        return self.n_samples

数据加载 dataloader

def get_dataset(data_dir,patch_size=None, 
                batch_size=1, n_threads=1, 
                is_train=False,multi=False):
    # Dataset实例化
    
#     print(data_dir)
#     print(patch_size)
#     print(is_train)
#     print(multi)

    dataset = ImgMission(data_dir,patch_size=patch_size,
                    is_train=is_train,multi=multi)
    
    # print(dataset)
    # 利用封装好的 dataloader 接口定义训练过程的迭代器
    # 参数num_workers表示进程个数,在jupyter下改为0
    dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,
                                            drop_last=True, shuffle=is_train,
                                             num_workers = 0)
    return dataloader
  • 将训练时的dataloader实例化
data_loader = get_dataset(args.data_dir,
                          patch_size=args.patch_size,
                          batch_size= args.batch_size,
                          n_threads= args.n_threads,
                          is_train=True,
                          multi = args.multi
                         )
['GOPR0372_07_00', 'GOPR0372_07_01', 'GOPR0374_11_00', 'GOPR0374_11_01', 'GOPR0374_11_02', 'GOPR0374_11_03', 'GOPR0378_13_00', 'GOPR0379_11_00', 'GOPR0380_11_00', 'GOPR0384_11_01', 'GOPR0384_11_02', 'GOPR0384_11_03', 'GOPR0384_11_04', 'GOPR0385_11_00', 'GOPR0386_11_00', 'GOPR0477_11_00', 'GOPR0857_11_00', 'GOPR0868_11_01', 'GOPR0868_11_02', 'GOPR0871_11_01', 'GOPR0881_11_00', 'GOPR0884_11_00']

模型构建

  • 模型介绍
  • 模型定义
  • 实例化模型
  • 损失函数和优化器

image-20211120222025823

CONV 表示卷积层,
ResBlock 表示残差模块,
Upconv 表示上采样(也可以用反卷积代替)。
从图中可以看出,该模型使用了 “multi-scale” 的结构,
在输入和输出部分都都采用了高斯金字塔(Gaussian pyramid)的形式(即对原图像进行不同尺度的下采样,从而获得处于不同分辨率的图像)

image-20211120222042862

模型定义

  • default_conv 是模型采用的默认卷积层,
  • UpConv 用于上采样卷积,
  • ResidualBlock 是模型使用的残差模块,
  • SingleScaleNet 是单个尺度网络,
  • MultiScaleNet 将几个 SingleScaleNet 整合成了最终的多尺度网络模型

具体作用

  • default_conv : 网络中默认采用的卷积层,定义之后,避免重复代码
  • UpConv : 上卷积,对应上图中的 Up Conv,将图像的尺度扩大,输入到另一个单尺度网络
  • ResidualBlock : 残差模块,网络模型中采用的残差模块,之所以采用残差模块,是因为网络“只需要需要模糊图像与去模糊图像之间的差异即可”
  • SingleScaleNet : 单尺度模型,一个尺度对应一个单尺度模型实例
  • MultiScaleNet : 多尺度模型,将多个单尺度模型实例组合即可得到上图所示的多尺度去模糊网络
def default_conv(in_channels,out_channels, kernel_size, bias):
    return nn.Conv2d(in_channels,
                    out_channels,
                    kernel_size,
                    padding=(kernel_size // 2),
                    bias=bias)

class UpConv(nn.Module):
    def __init__(self):
        super(UpConv, self).__init__()
        self.body = nn.Sequential(default_conv(3,12,3,True),
                                 nn.PixelShuffle(2),
                                 nn.ReLU(inplace=True))
    def forward(self,x):
            return self.body(x)

class ResidualBlock(nn.Module):
    def __init__(self,n_feats):
        super(ResidualBlock,self).__init__()
        
        modules_body = [
            default_conv(n_feats, n_feats, 3, bias=True),
            nn.ReLU(inplace=True),
            default_conv(n_feats,n_feats,3,bias=True)
        ]
        
        self.body = nn.Sequential(*modules_body)
        
    def forward(self,x):
        res= self.body(x)
        res += x
        return res

class SingleScaleNet(nn.Module):
    def __init__(self,n_feats,n_resblocks, is_skip, n_channels=3):
        super(SingleScaleNet, self).__init__()
        self.is_skip = is_skip
        
        modules_head = [
            default_conv(n_channels,n_feats,5,bias=True),
            nn.ReLU(inplace=True)
        ]
        
        modules_body = [ResidualBlock(n_feats) for _ in range(n_resblocks)]
        modules_tail = [default_conv(n_feats, 3,5,bias=True)]
        
        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)
        
    def forward(self,x):
        x= self.head(x)
        res= self.body(x)
        if self.is_skip:
            res += x
        
        res = self.tail(res)
        return res

class MultiScaleNet(nn.Module):
    def __init__(self,n_feats, n_resblocks ,is_skip):
        super(MultiScaleNet,self).__init__()
        
        self.scale3_net = SingleScaleNet(n_feats,
                                         n_resblocks,
                                         is_skip,
                                         n_channels=3)
        self.upconv3 = UpConv()
        self.scale2_net = SingleScaleNet(n_feats,
                                         n_resblocks,
                                         is_skip,
                                         n_channels=6)
        self.upconv2 = UpConv()
        
        self.scale1_net = SingleScaleNet(n_feats,
                                        n_resblocks,
                                        is_skip,
                                        n_channels=6)
        
    def forward(self,mulscale_input):
        input_b1, input_b2,input_b3 = mulscale_input
        
        output_l3 = self.scale3_net(input_b3)
        output_l3_up = self.upconv3(output_l3)
        
        output_l2 = self.scale2_net(torch.cat((input_b2,output_l3_up),1))
        output_l2_up = self.upconv2(output_l2)
        
        output_l1 = self.scale2_net(torch.cat((input_b1,output_l2_up),1))
        
        return output_l1,output_l2,output_l3

模型实例化

if args.multi:
    my_model = MultiScaleNet(n_feats=args.n_feats,
                            n_resblocks = args.n_resblocks,
                            is_skip= args.skip)
else:
    my_model = SingleScaleNet(n_feats=args.n_feats,
                             n_resblocks=args.n_resblocks,
                             is_skip = args.skip)
if torch.cuda.is_available():
    my_model.cuda()
    loss_function = nn.MSELoss().cuda()
else:
    loss_function = nn.MSELoss()
    
optimizer = optim.Adam(my_model.parameters(),lr=args.lr)
print(my_model)
print(loss_function)
MultiScaleNet(
  (scale3_net): SingleScaleNet(
    (head): Sequential(
      (0): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): ReLU(inplace=True)
    )
    (body): Sequential(
      (0): ResidualBlock(
        (body): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (1): ResidualBlock(
        (body): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (2): ResidualBlock(
        (body): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (tail): Sequential(
      (0): Conv2d(8, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (upconv3): UpConv(
    (body): Sequential(
      (0): Conv2d(3, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): PixelShuffle(upscale_factor=2)
      (2): ReLU(inplace=True)
    )
  )
  (scale2_net): SingleScaleNet(
    (head): Sequential(
      (0): Conv2d(6, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): ReLU(inplace=True)
    )
    (body): Sequential(
      (0): ResidualBlock(
        (body): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (1): ResidualBlock(
        (body): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (2): ResidualBlock(
        (body): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (tail): Sequential(
      (0): Conv2d(8, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (upconv2): UpConv(
    (body): Sequential(
      (0): Conv2d(3, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): PixelShuffle(upscale_factor=2)
      (2): ReLU(inplace=True)
    )
  )
  (scale1_net): SingleScaleNet(
    (head): Sequential(
      (0): Conv2d(6, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): ReLU(inplace=True)
    )
    (body): Sequential(
      (0): ResidualBlock(
        (body): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (1): ResidualBlock(
        (body): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (2): ResidualBlock(
        (body): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
    (tail): Sequential(
      (0): Conv2d(8, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
)
MSELoss()

损失函数和优化器

  • Adam 优化器,初始学习率为 lr,其相对于 SGD,更自动化,实际中需要调整的参数较少,但需要注意的是,其使用内存也比 SGD 要高。
  • 损失函数使用最常见的均方损失函数(MSELoss):
    image-20211120222105417
    其中 \(f^{\prime}(i,j)\)\(f(i,j)\) 分别为模型输出结果图和非模糊图上坐标为 \((i,j)\) 的像素,M,N分别表示图片的长与宽。
  • 具体的,本文所用的多尺度损失函数为:
    image-20211120222124809
    \(f^{\prime}_k\)\(f_k\) 分别表示第 \(k\) 个尺度上的输出结果图和非模糊图。

模型训练

  1. 训练策略
  2. 训练模型
  3. 训练过程可视化

训练策略

  • 在模型训练过程中,随着训练的进行,更新网络参数的步进(学习率)应该越来越小,整体训练过程应该满足 “先粗调后细调”,这就是常说的学习率策略。
  • 本次训练采用的学习率优化策略为 lr_scheduler.StepLR,步进为 lr_step_size,学习率每隔 lr_step_size 个 epoch 乘以 lr_gamma
scheduler = lr_scheduler.StepLR(optimizer,args.lr_step_size,args.lr_gamma)

scheduler
<torch.optim.lr_scheduler.StepLR at 0x290378a0198>

训练模型

在训练开始之前,要先创建一个 SummaryWriter,用来记录和可视化训练过程

writer = SummaryWriter(os.path.join(args.save_dir,"temp/logs/"))

writer
<tensorboardX.writer.SummaryWriter at 0x290378d6e10>
  • 在命令行运行 tensorboard --logdir=experiment/logs 来启动tensorboard。
  • 在训练模型时,每训练完一个 epoch 将模型的参数保存下来,防止训练被意外中断以及方便测试,如果需要不断更新最新的一次训练的参数,可以取消最后一行的注释。
  • 训练过程中,使用 tqdm 的进度条来观察训练过程
bar_format = '{desc}{percentage:3.0f}% | [{elapsed}<{remaining},{rate_fmt}]' # 进度条格式

for epoch in range(args.epochs):
    total_loss = 0
    batch_bar =  tqdm(data_loader, bar_format=bar_format) # 利用tqdm动态显示训练过程
    for batch,images in enumerate(batch_bar):
        my_model.train()
        curr_batch = epoch * data_loader.__len__() + batch # 当前batch在整个训练过程中的索引

        input_b1 = images['input_b1'].to(device) # 原始输入图像
        target_s1 = images['target_s1'].to(device) # 目标非模糊图片

        if args.multi:
            input_b2 = images['input_b2'].to(device)  # level-2 尺度
            target_s2 = images['target_s2'].to(device)

            input_b3 = images['input_b3'].to(device)  # level-3 尺度
            target_s3 = images['target_s3'].to(device)
            output_l1, output_l2, output_l3 = my_model((input_b1,input_b2,input_b3))

            # 损失函数
            loss = (loss_function(output_l1,target_s1) + loss_function(output_l2,target_s2) + loss_function(output_l3, target_s3)) /3 

        else:
            output_l1 = my_model(input_b1)
            loss = loss_function(output_l1,target_s1)

        my_model.zero_grad()
        loss.backward()  #反向传播
        optimizer.step() # 更新权值
        total_loss += loss.item()


        print_str = "|".join([
            "epoch:%3d/%3d" % (epoch + 1, args.epochs),
            "batch:%3d/%3d" % (batch + 1, data_loader.__len__()),
            "loss:%.5f" % (loss.item()),
        ])
        batch_bar.set_description(print_str,refresh=True)  # 更新进度条

        writer.add_scalar('train/batch_loss', loss.item(), curr_batch)

    batch_bar.close()
    scheduler.step() #调整学习率
    loss = total_loss / (batch +1)

    writer.add_scalar('train/batch_loss',loss,epoch)
    torch.save(my_model.state_dict(),os.path.join(args.save_cp_dir, f'Epoch_{epoch}.pt')) # 保存每个 epoch 的参数
#     torch.save(my_model.state_dict(),os.path.join(args.save_cp_dir, f'Epoch_lastest.pt')) # 保存最新的参数

image-20211120222241087

模型评估

  1. 指标介绍
  2. 指标实现

指标介绍

为了评估模型的效果如何,我们通过计算
峰值信噪比(Peak Signal-to-Noise Ratio, PSNR),
结构相似性(Structural Similarity, SSIM)和
多尺度的 SSIM(Multi-Scale SSIM,MSSIM)三个指标来对结果进行分析

PSNR

PSNR 的定义如下:
image-20211120222352850
其中,\(M A X_{I}\)表示图像点颜色的最大数值,如果每个采样点用 8 位表示,则最大数值为 255,\(MSE\)是两个图像之间的均方误差。
PSNR值越大代表模糊图像与参考图像越接近,即去模糊效果越好。

SSIM

SSIM也是衡量两幅图片相似性的指标,其定义如下:
image-20211120222344621
SSIM由模型输出图像 \(x\) 和参考图像 \(y\) 之间的亮度对比($ l(\mathbf{x}, \mathbf{y})\()、对比度对比(\)c(\mathbf{x}, \mathbf{y})\()和结构对比(\)s(\mathbf{x}, \mathbf{y}) \()三部分组成,\)\alpha\(,\)\beta$ 和 \(\gamma\)是各自的权重因子,一般都取为 1:
image-20211120222325752
其中,\(C_{1}\)\(C_{2}\)\(C_{3}\)为常数,是为了避免分母接近于0时造成的不稳定性。\(\mu_{x}\)\(\mu_{y}\) 分别为模型输出图像和参考图像的均值。\(\sigma_{x}\)\(\sigma_{y}\) 分别为模型输出图像和参考图像的标准差。通常取 \(C1=(K1*L)^2\)\(C2=(K2*L)^2\)\(C3=C2/2\),一般地\(K1=0.01\)\(K2=0.03\), \(L=255\)\(L\)是像素值的动态范围,一般都取为255)。
输出图片和目标图片的结构相似值越大,则表示相似性越高,图像去模糊效果越好。
SSIM是一种符合人类直觉的图像质量评价标准。从名字上我们不难发现,这种指标是在致力于向人类的真实感知看齐,详细细节可以参考原论文

MSSIM

MSSIM相当于是在多个尺度上来进行SSIM指标的测试,相对于SSIM,其能更好的衡量图像到观看者的距离、像素信息密集程度等因素对观看者给出的主观评价所产生的影响。
论文中给出的一个例子是,观看者给一个分辨率为1080p的较为模糊的画面的评分可能会比分辨率为720p的较为锐利的画面的评分高。因此在评价图像质量的时候不考虑尺度因素可能会导致得出片面的结果。
MSSIM提出在不同分辨率(尺度)下多次计算结构相似度后综合结果得到最终的评价数值。其计算过程框图如下所示
image-20211120222301708

MSSIM 的详细细节可以参考原论文

指标实现

# class PSNR(nn.Module):
class PSNR(nn.Module):
    def forward(self,img1,img2):
        mse = ((img1 - img2) ** 2).mean() # 输出图像和参考图像的 MSE
        psnr = 10 * torch.log10(1.0 * 1.0 / (mse + 10 ** (-10)))
        return psnr
# SSIM 和 MSSIM 的计算较为复杂,在这里,我们直接调用 pytorch-msssim 的接口来进行计算
ssim = pytorch_msssim.SSIM(data_range=1.0, size_average=True, channel=3)
mssim = pytorch_msssim.MS_SSIM(data_range=1.0, size_average=True, channel=3)
# 实例化
ssim = pytorch_msssim.SSIM(data_range=1.0, size_average=True, channel=3)
mssim = pytorch_msssim.MS_SSIM(data_range=1.0, size_average=True, channel=3)
psnr = PSNR()

模型预测

  1. 绘图函数定义
  2. 模型加载
  3. 数据加载
  4. 模型预测与指标分析
  5. 结果展示与保存

绘图函数定义

def plot_tensor(tensor):
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    ret = transforms.ToPILImage()(tensor.squeeze(0))
    plt.imshow(ret)
    return

模型加载

训练过程中我们保存了多个 checkpoint ,现在对其进行加载和测试。这里我们提供了两种选择 checkpoint 的方式,一种是选择指定 checkpoint,一种是选择最新的 checkpoint。在这里我们以最新的 checkpoint 为例进行测试

# option-A :测试指定epoch
# best_epoch = 100 
# best_cp = f"{args.save_cp_dir}/Epoch_{best_epoch}.pt"

# option-B :测试最终epoch
# best_cp = f"{args.save_cp_dir}/Epoch_lastest.pt"
best_cp = f"{args.save_cp_dir}/Epoch_3.pt"

my_model.to("cuda").load_state_dict(torch.load(best_cp))
my_model = my_model.eval()

数据加载

# 由于此模型采用的是多尺度训练,因此对于单张输入图像,需要对其进行处理,定义加载图像的函数 load_images 为
def load_images(blur_img_path,multi):
    target_s1 = None
    sharp_img_path = blur_img_path.replace("blur","sharp")
    if os.path.exists(sharp_img_path):
        img_target = Image.open(sharp_img_path).convert('RGB')
        target_s1 = transforms.ToTensor()(img_target).unsqueeze(0)
        
    img_input = Image.open(blur_img_path).convert('RGB')  # 转换为image类型 方便进行resize
    input_b1 = transforms.ToTensor()(img_input)
    
    if multi:
        H = input_b1.size()[1]
        W = input_b1.size()[2]
        
        input_b1 = transforms.ToPILImage()(input_b1)
        input_b2 = transforms.ToTensor()(transforms.Resize([int(H/2), int(W/2)])(input_b1)).unsqueeze(0)
        input_b3 = transforms.ToTensor()(transforms.Resize([int(H/4), int(W/4)])(input_b1)).unsqueeze(0)
        
        input_b1 = transforms.ToTensor()(input_b1).unsqueeze(0)
        
        return {'input_b1':input_b1, 'input_b2':input_b2, 'input_b3':input_b3, 'target_s1':target_s1}
    else:
        return {'input_b1':unsqueeze(0), 'target_s1':target_s1}

模型预测与指标分析

模型预测

#目录一
# idx = 1
# blur_img_path = f"datasets/pictures/blur/test{idx}.png"  

# 目录二
idx='000001'
blur_img_path =f"datasets/test/GOPR0384_11_00/blur/{idx}.png"
item = load_images(blur_img_path,args.multi)

input_b1 = item['input_b1'].to(device) 
input_b2 = item['input_b2'].to(device) 
input_b3 = item['input_b3'].to(device) 
target_s1 = item['target_s1'].to(device) 

output_l1,_,_ = my_model((input_b1,input_b2,input_b3))

指标分析

原始模糊图片与不模糊图片之间的指标计算

blur_psnr = psnr(input_b1,target_s1)
blur_ssim = ssim(input_b1,target_s1)
blur_mssim = mssim(input_b1,target_s1)

print(f"原始模糊图片:PSNR={blur_psnr.float()}, SSIM={blur_ssim.float()}, MSSIM={blur_mssim.float()}")
原始模糊图片:PSNR=24.050003051757812, SSIM=0.716961145401001, MSSIM=0.840461015701294

去模糊图片与不模糊的图片之间的指标计算

output_psnr = psnr(output_l1,target_s1)
output_ssim = ssim(output_l1,target_s1)
output_mssim = mssim(output_l1,target_s1)

print(f"网络输出图片:PSNR={output_psnr.float()}, SSIM={output_ssim.float()}, MSSIM={output_mssim.float()}")
网络输出图片:PSNR=24.012224197387695, SSIM=0.7089502811431885, MSSIM=0.8413411974906921

结果展示

plt.figure(figsize=(6,10))
plt.subplot(311)
plot_tensor(input_b1)
plt.subplot(312)
plot_tensor(output_l1)
plt.subplot(313)
plot_tensor(target_s1)

image-20211120222412395

# 将结果保存
save_name = blur_img_path.split("/")[-1]
save_path = os.path.join(args.save_dir,save_name)
save_img = transforms.ToPILImage()(output_l1.squeeze(0))
save_img.save(save_path)
posted @ 2021-11-20 22:30  OCEANEYES.GZY  阅读(1209)  评论(0编辑  收藏  举报