GRDN:分组残差密集网络,用于真实图像降噪和基于GAN的真实世界噪声建模

GRDN:分组残差密集网络,用于真实图像降噪和基于GAN的真实世界噪声建模

摘要

随着深度学习体系结构(尤其是卷积神经网络)的发展,有关图像去噪的最新研究已经取得了进展。但是,现实世界中的图像去噪仍然非常具有挑战性,因为不可能获得理想的地面对图像和现实世界中的噪声图像对。由于最近发布了基准数据集,图像去噪社区的兴趣正朝着现实世界中的去噪问题发展。在本文中,我们提出了分组残差密集网络(GRDN),它是最新的残差密集网络(RDN)的扩展和通用体系结构。RDN的核心部分定义为分组残差密集块(GRDB),并用作GRDN的构建模块。我们通过实验表明,通过级联GRDB可以显着改善图像降噪性能。除了网络架构设计之外,我们还开发了一种新的基于对抗网络的真实世界的噪声建模方法。我们通过在NTIRE2019实像去噪挑战赛道2:sRGB中的峰值信噪比和结构相似性方面获得最高分,证明了所提出方法的优越性。.

一、简介

在图像去噪领域,最近的研究表明,基于学习的方法比以前的手工方法(例如块匹配3D(BM3D)[6]及其变体)更加有效。对于基于学习的方法,拥有足够数量的高质量数据集至关重要。由于可以通过在无噪声图像上添加合成噪声来轻松构建一对嘈杂且无噪声的图像,因此,大多数以前的基于学习的方法都专注于经典的高斯去噪任务,并且最关注网络的体系结构设计。尤其是卷积神经网络(CNN)。然而,由于合成噪声图像和真实噪声图像之间的差距,发现使用合成图像训练的CNN在真实噪声图像上表现不佳,有时甚至不如BM3D [22]。

这些作者做出了同样的贡献。通讯作者:S.-W.荣格*

对于现实世界的图像去噪,主要有两种方法。第一种方法是找到一种更好的真实噪声统计模型,而不是加性高斯白噪声[3,8,10,19,23]。特别是,高斯分布和泊松分布的组合显示出可以紧密地模拟依赖信号和不依赖信号的噪声。使用这些新的合成噪点图像训练的网络证明了在消除现实世界的噪点图像方面的优越性。这种方法的一个明显优势是,只需将合成噪声添加到无噪声的地面图像中,我们就可以拥有无限多的训练图像对。但是,是否可以通过统计模型来模拟现实世界的噪声仍然是有争议的。因此,第二种方法是相反的方向。从真实的嘈杂图像中,可以通过反转图像采集过程[1、4、24、22、2]获得几乎无噪声的地面真实图像。据我们所知,智能手机图像去噪数据集(SIDD)[1]是第二种方法中最大的高质量图像数据集之一。但是,提供的图像数量可能不足以训练大型网络,并且没有足够的专业知识,很难从真实的嘈杂图像中生成真实的图像。因此,我们采用第二种方法,但应用了我们自己的基于生成对抗网络(GAN)的数据增强技术来获取更大的数据集。

网络架构当然是最重要的。在基于CNN的图像恢复中,密集残差块(RDB)[33、32]受到了极大关注。在本文中,我们提出了一种称为分组残差密集网络(GRDN)的新体系结构。特别是,提出的体系结构采用了最近的残差密集网络(RDN)作为具有较小修改的组件,并将其定义为分组的残差密集块(GRDB)。通过将GRDB与关注模块进行级联,我们可以获得现实世界中图像去噪任务的最新性能[28]。在NTIRE2019实像去噪挑战-轨道2:sRGB中,我们在39.93 dB的峰值信噪比(PSNR)和0.9736的结构相似度(SSIM)方面取得了最佳性能。

img

​ 图1:提出的网络架构:GRDN

二、相关工作

2.1 影像还原

图像降噪是图像处理中研究最广泛的主题之一。由于深度学习的显着进步,基于CNN的方法现在在图像去噪中占主导地位。但是,大多数以前的基于学习的图像去噪方法都集中在经典的高斯去噪任务上。对于现实世界的图像降噪,第一种方法是通过使用不同的相机设置来捕获一对嘈杂且无噪点的图像[2,22]。在[22]中表明,较早的基于学习的方法与经典方法(如BM3D)可比甚至有时不如BM3D。我们认为这主要是由于训练数据集的质量和数量不足。因此,开发了更丰富和完善的数据集,例如Darmstadt噪声数据集(DND)和SIDD [1],并且最近的基于学习的方法[1、3、10、23]显示出它们优于经典方法在现实世界中的图像去噪。

除了努力生成高质量的数据集以外,还进行了大量研究以找到更好的网络体系结构以进行图像去噪。从CNN的角度来看,为不同的图像恢复任务(如图像去噪,图像去模糊,超分辨率和压缩伪像减少)开发的网络体系结构具有相似性。反复证明,为某种图像还原任务开发的一种体系结构在其他还原任务中也表现良好[30、32、23]。因此,我们检查了为不同图像恢复任务而开发的许多体系结构,尤其是超分辨率[7,13,16,17,14,26,33,31,11,18]。其中,RDN [33,32]和残留信道关注网络(RCAN)[31]与我们的网络体系结构关系最密切。

特别是,我们尝试利用RDN和RCAN中的新颖思想。RCAN在残差(RIR)体系结构中引入了残差,消融研究表明RIR的性能增益最为显着。因此,我们在架构设计中使用了RIR原理。另外,RDN本身是一个图像恢复网络,但是我们将它与修改一起用作我们的网络的组成部分,并构造了一个RDN的级联结构作为我们的图像去噪网络。最近的研究还表明注意模块的有效性。在许多注意力模块中,卷积块注意力模块(CBAM)[28]是一种易于植入的模块,可以顺序地估计通道的注意力和空间的注意力,在一般物体检测和图像分类中显示出了效率,因此我们将CBAM纳入了我们的网络。

2.2. GAN

诸如SIDD和DND之类的可公开获得的真实世界图像降噪数据集中的训练图像数量可能不足以训练深度和广泛的神经网络。扩充这些数据集的一种可行方法是利用GAN的功能[9]。第一种基于GAN的真实世界的噪声建模方法[5]仅使用真实世界的噪声图像训练噪声生成器,其中鉴别器被训练为区分真实和模拟噪声信号。然后,使用噪声发生器将合成的但逼真的噪声添加到无噪声的地面图像中,并使用生成的成对的地面图像和高噪声图像最终训练去噪网络。通过使用GAN生成的数据集,现实世界中的图像降噪性能得到了显着改善。

通过将诸如无噪声图像补丁,ISO和快门速度之类的调节信号作为生成器的附加输入,我们改进了以前基于GAN的实际噪声仿真技术[5]。对无噪声图像斑块进行调节可以帮助生成更逼真的与信号相关的噪声,而其他相机参数可以提高可控性和各种模拟噪声信号。我们还通过使用最新的相对论GAN [12]来更改先前体系结构的鉴别符[5]。与常规GAN不同,相对论GAN的判别器学会了确定真实数据与伪数据之间哪个更为现实。我们的方法与传统相对论GAN的不同之处在于,真实数据和伪数据都被用作输入,以使鉴别器更明确地比较这两个数据。

img

​ 图2:GRDN的组件:(a)RDB和(b)GRDB

三、提出的方法

3.1 图像去噪网络

我们的称为GRDN的图像去噪网络架构如图1所示。我们的设计原则是分配每一层的负担,以便可以更好地训练更深更广的网络。为此,将残余连接应用于四个不同级别。下采样层和上采样层被包括在内,以实现更深,更宽的架构,并且还应用了CBAM [28]。

受RDN [33]的启发,我们使用如图2(a)所示的RDB作为构建模块。在RDN中,来自级联RDB的要素被串联在一起,然后是1×1卷积层。如图2(b)所示,我们将RDN的功能串联部分定义为GRDB,并将其用作GRDN的构建模块。请注意,原始RDN [33]在GRDB之前和之后应用卷积层,并使用全局残差学习进行图像去噪。但是,我们认为RDN给GRDB的最后1×1卷积层带来了沉重的负担。因此,我们改为级联GRDB,以便可以将RDB中的功能分为多个阶段。受包括RDN [33]在内的许多最新图像恢复网络的推动,我们还包括了全局残差连接,因此该网络可以专注于学习噪声图像和真实图像之间的差异。最后,我们将CBAM作为构建模块来进一步提高去噪性能。CBAM块的位置根据经验选择在上卷积层和最后一个卷积层之间。

img

​ 图3:cERGAN生成器

尽管GRDN在结构上比RDN更深[33,32],但我们使用了相同数量的RDB。具体来说,在原始RDN中使用了16个RDB进行图像降噪。我们使用4个GRDB堆栈,每个GRDB包含4个RDB,因此GRDN中有16个RDB。

3.2 基于GAN的真实世界噪声建模

受最新技术[5]的启发,我们开发了自己的发生器和鉴别器用于实际噪声建模。与先前的技术[21]类似,我们使用残差块(ResBlocks)作为生成器的构建模块。但是,我们进行了一些修改以提高实际噪声建模的性能。图3显示了生成器架构。首先,我们包含调节信号:无噪声的图像补丁,ISO,快门速度和智能手机型号,作为发生器的附加输入。对无噪声图像斑块进行调节可以帮助生成更逼真的与信号相关的噪声,而其他与相机相关的参数可以提高可控性和各种模拟噪声信号。为了用这些调节信号训练生成器,我们使用了SIDD [1]的元数据。第二,频谱归一化(SN)[20]在像[29]中所使用的基本卷积单元中在批量归一化之前应用。第三,我们的ResBlock包含剩余缩放比例[25、18、27]。从经验上发现,SN和残留水垢对训练我们的发电机很有用。

img

​ 图4:CERGAN鉴别器

如图4所示,我们的鉴别器架构也不同于以前的基于GAN的噪声仿真技术[5]。增强的超分辨率GAN(ESGAN)[27]表明相对论GAN [12]可有效地生成逼真的图像纹理。与原始GAN [9]不同,相对论GAN的判别者学会了确定真实数据与伪数据之间哪个更为现实。令img表示输入图像x的未变换鉴别器输出。然后可以将标准鉴别符表示为img,σ是S型函数。ESGAN中采用的相对论平均GAN(RaGAN)的定义为:

img

其中,imgimg分别表示真实数据和伪数据,而img表示期望运算符,该期望运算符应用于迷你批处理中的所有数据[27]。定义为条件显式相对论GAN(cERGAN)的拟议网络的鉴别符为

img

其中img表示调节信号。具体来说,我们通过复制值使每个条件数据的大小与训练补丁的大小相同,因此我们的img由4个补丁组成:来自智能手机代码的3个常量补丁(例如Google Pixel = 0,iPhone 7 = 1等),ISO级别,快门速度和一个无噪点的图像补丁。除了img之外,我们还同时使用imgimg作为鉴别符的输入。请注意,ESGAN使用imgimg作为鉴别符的输入。

生成器和鉴别器的损失函数分别表示为imgimg,其定义如下:

img

换句话说,如果第二个输入是img,而第三个输入是img,则辨别器将经过训练以预测接近1的值,即imgimg更现实。如果切换了两个输入,则鉴别器将被训练为预测接近0的值,即img的真实性不如img。训练了生成器以欺骗鉴别器。通过要求网络在真实数据和假数据之间进行显式比较,我们可以模拟更真实的真实噪声。

四、实验

我们使用PyTorch库,Intel i7-8700 @ 3.20GHz,32GB RAM和NVIDIA Titan XP来实现所有模型。

img

​ 表1:图像去噪模型的比较。

4.1 数据集

我们使用了NTIRE 2019实像去噪挑战的训练和验证图像,它是SIDD数据集的子集[1]。让ChDB表示我们用于实验的数据集。具体来说,分别使用320个高分辨率图像和1280个尺寸为256×256的裁剪图像块进行训练和验证。提供的图像是由五个智能手机相机拍摄的-Apple iPhone 7,Google Pixel,三星Galaxy S6 Edge,摩托罗拉Nexus 6和LG G4。由于测试数据集的真实图像不公开,因此我们在本节中使用验证数据集报告图像去噪模型的性能。由于我们注意到地面真实图像中图像边界周围的非边际劣化,因此在生成训练补丁时,我们排除了第一行和最后8行/列。没有应用诸如缩放,翻转和旋转之类的常规数据增强技术。

4.2 图像去噪

4.2.1 实施细节

我们通过两种方式扩充了提供的训练数据集。首先,我们使用作者提供的源代码[10]将合成噪声添加到真实图像中。我们还应用了第3.2节中介绍的基于GAN的噪声模拟器,以生成其他合成噪声图像。

在每个训练批次中,我们随机提取16对真实的图像和嘈杂的图像块。我们使用Adam [15]进行了训练,其img = 0.9,img = 0.999。初始学习率设置为img,然后在每次img迭代时降低到一半。我们使用img损失训练了网络。我们训练了大约5天的模型。

对于上/下卷积层,我们使用了4×4过滤器,对来自RDB的级联特征使用了1×1过滤器。否则,我们使用3×3滤波器。使用零填充,并且未对所有卷积层使用膨胀。每个RDB具有8对卷积层和ReLU激活层。

4.2.2与RDN的比较

首先,我们将GRDN模型与RDN [32]进行了比较。实验结果如表1所示。我们使用ChDB重新训练了RDN。表1中的第一和第二列对应于RDN和建议的GRDN。可以看出,我们模型的PSNR比RDN高0.04 dB。请注意,RDN和GRDN具有相同数量的RDB,因此参数的数量相似。具体来说,我们的基本GRDN模型具有22M参数,而RDN具有21.9M参数。

4.2.3 补丁大小实验

由于原始图像分辨率很高(超过1200万像素),因此需要使用最大可能的色块大小来包含足够的图像内容。因此,我们将补丁大小增加到96×96,这是我们实验环境中最大的大小。通过比较表1的第2列和第5列,我们可以看到,通过增大贴片尺寸可以获得0.22dB的显着性能增益。

4.2.4 CBAM模块的实验

CBAM [28]是一个简单但有效的CNN模块。因为它是轻量级的通用模块,所以可以轻松地将其植入任何CNN架构中,而无需大量增加参数数量。特别是,CBAM可以置于网络的瓶颈处。由于我们具有下采样和上采样层,因此我们检查了CBAM的不同位置和组合。我们得出结论,对于我们的模型,CBAM的最佳位置是在上采样层之后。我们认为,这表明CBAM增强了上采样数据的重要功能。它还有助于为之后的最后一个卷积层构造最终的去噪图像。发现CBAM的有效性取决于网络的复杂性。比较表1的第二列和第三列,CBAM将PSNR提高了0.05 dB。但是,在增加了补丁大小之后,CBAM的增益被稀释了。比较表1的第5列和第6列,CBAM甚至将PSNR降低了0.01 dB。

img

图5:真实世界噪声建模的实验结果:

(a)ChDB中的真实噪声图像块

(b)真实图像块

(c)噪声图像和真实图像块之间的差异

( d)cERGAN产生的噪声补丁

4.2.5超参数调整

我们比较了具有不同数量的过滤器和GRDB的网络。比较表1的第6列和第7列,较浅但较宽的网络性能要好0.02 dB。因此,在我们的硬件限制下,第七列的模型是性能最佳的模型。

4.3 真实世界噪声建模

为了训练cERGAN的生成器和判别器,我们从真实的噪声图像中裁剪了尺寸为48×48的图像块,并从ChDB中裁剪了其真实的图像。我们将批处理大小为32,将Adam优化器与imgimg一起使用。生成器和鉴别器经过了340k次迭代训练。区分器和生成器的初始学习率均设置为0.0002,并且我们在320k次迭代后线性降低了学习率,以使最后一次迭代后学习率变为0。图5示出了由所提出的cERGAN产生的一些噪声图像补丁。从图5和6中可以看出。如图5(c)和(d)所示,建议的cERGAN可以生成接近真实噪声的噪声补丁。

通过比较使用或不使用模拟数据训练的拟议图像降噪网络,可以评估模拟噪声图像的有效性。在这里,测试的网络对应于表1的第4列。我们首先尝试仅使用cERGAN获得的合成的现实噪声图像训练图像去噪网络。在ChDB验证集中获得的平均PSNR为38.63 dB,不如我们仅使用提供的ChDB数据集获得的PSNR(表1中的39.62 dB)。

img

​ 图6:具有不同数据集的图像去噪网络的收敛性分析。

将统计建模的真实世界的噪声添加到ChDB的真实图像中。我们使用这些数据集训练的图像去噪网络仅产生36.17 dB,这表明所提出的基于GAN的噪声建模至少比统计噪声建模方法[10]表现更好。

最后,我们将原始的ChDB数据集与通过建议的cERGAN和常规方法生成的合成数据集[10]相结合。在这里,我们只能测试一种配置:来自ChDB的90%,使用[10]的模拟ChDB的5%和使用cERGAN的模拟ChDB的5%。图6显示使用增强数据集获得的PSNR更加稳定地增加。所得的PSNR为39.64 dB,略高于使用原始数据集获得的PSNR(39.62 dB)。

五. NTIRE2019图像降噪挑战

这项工作被提议参加NTIRE2019实像去噪挑战-Track 2:sRGB。挑战在于开发一种具有最高PSNR和SSIM的图像去噪系统。提交的图像去噪网络对应于表1的第七列。提交的模型中的一个小更改是,我们每2个GRDB都包含了跳过连接。对于训练,(35.49 / 0.9812)(29.86 / 0.9314)(26.32 / 0.7576)(19.05 / 0.3623)(39.11 / 0.9899)(39.59 / 0.9902)(37.05 / 0.9749)(37.13 / 0.9748)我们使用增强的ChDB本节中提到的技术。 4.3。我们的模型在PSNR和SSIM方面均在真实图像降噪方面排名第一。如表2所示,我们的模型优于第二等级方法0.05 dB。

六、结论

在本文中,我们提出了一种用于现实图像降噪的改进网络架构。通过广泛和分层地使用剩余连接,我们的模型获得了最先进的性能。此外,我们开发了一种改进的基于GAN的实际噪声建模方法。

尽管我们只能将拟议的网络评估为现实世界中的图像降噪,但我们认为拟议的网络普遍适用。因此,我们计划将提出的图像去噪网络应用于其他图像恢复任务。我们也不能完全和定量地证明所提出的实际噪声建模方法的有效性。为了更好地进行真实的噪声建模,显然必须进行更精细的设计。我们相信,我们的真实世界噪声建模方法可以扩展到其他真实世界的退化,例如模糊,混叠和雾度,这将在我们的未来工作中得到证明。

七、参考文献

[1] A. Abdelhamed, S. Lin, and M. Brown. A high-quality denoising dataset for smartphone cameras. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 1692–1700, 2018.

[2] J. Anaya and A. Barbu. RENOIR - A benchmark dataset for real noise reduction evaluation. CoRR, abs/1409.8230, 2014.

[3] T. Brooks, B. Mildenhall, T. Xue, J. Chen, D. Sharlet, and J. T. Barron. Unprocessing images for learned raw denoising. CoRR, abs/1811.11127, 2018.

[4] C. Chen, Q. Chen, J. Xu, and V. Koltun. Learning to see in the dark. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 3291–3300, 2018.

[5] J. Chen, J. Chen, H. Chao, and M. Yang. Image blind denoising with generative adversarial network based noise modeling. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 3155–3164, 2018.

[6] K. Dabov, A. Foi, V. Katkovnik, and K. Egiazarian. Image denoising by sparse 3-d transform-domain collaborative filtering. IEEE Trans. Image Process., 16(8):2080–2095, Aug. 2007.

[7] C. Dong, C. C. Loy, K. He, and X. Tang. Learning a deep convolutional network for image super-resolution. In Proceedings of the European Conference on Computer Vision, pages 184–199. Springer, 2014.

[8] A. Foi, M. Trimeche, V. Katkovnik, and K. Egiazarian. Practical Poissonian-Gaussian noise modeling and fitting for single-image raw-data. IEEE Trans. Image Process., 17(10):1737–1754, 2008.

GRDN网络结构代码实现

SubNets.py

import torch
import torch.nn as nn
import torch.nn.functional as F


def weights_init(m):
    """
    custom weights initialization called on netG and netD
    https://github.com/pytorch/examples/blob/master/dcgan/main.py
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

####################################################################################################################


class make_dense(nn.Module):
    def __init__(self, nChannels, nChannels_, growthRate, kernel_size=3):
        super(make_dense, self).__init__()
        self.conv = nn.Conv2d(nChannels_, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                              bias=False)
        self.nChannels = nChannels

    def forward(self, x):
        out = F.relu(self.conv(x))
        out = torch.cat((x, out), 1)
        return out

class make_residual_dense_ver1(nn.Module):
    def __init__(self, nChannels, nChannels_, growthRate, kernel_size=3):
        super(make_residual_dense_ver1, self).__init__()
        self.conv = nn.Conv2d(nChannels_, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                              bias=False)
        self.nChannels_ = nChannels_
        self.nChannels = nChannels
        self.growthrate = growthRate

    def forward(self, x):
        # print('1', x.shape, self.nChannels, self.nChannels_, self.growthrate)
        # print('2', outoflayer.shape)
        # print('3', out.shape, outoflayer.shape)
        # print('4', out.shape)

        outoflayer = F.relu(self.conv(x))
        out = torch.cat((x[:, :self.nChannels, :, :] + outoflayer, x[:, self.nChannels:, :, :]), 1)
        out = torch.cat((out, outoflayer), 1)
        return out

class make_residual_dense_ver2(nn.Module):
    def __init__(self, nChannels, nChannels_, growthRate, kernel_size=3):
        super(make_residual_dense_ver2, self).__init__()
        if nChannels == nChannels_ :
            self.conv = nn.Conv2d(nChannels_, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                                  bias=False)
        else:
            self.conv = nn.Conv2d(nChannels_ + growthRate, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                                  bias=False)

        self.nChannels_ = nChannels_
        self.nChannels = nChannels
        self.growthrate = growthRate

    def forward(self, x):
        # print('1', x.shape, self.nChannels, self.nChannels_, self.growthrate)
        # print('2', outoflayer.shape)
        # print('3', out.shape, outoflayer.shape)
        # print('4', out.shape)

        outoflayer = F.relu(self.conv(x))
        if x.shape[1] == self.nChannels:
            out = torch.cat((x, x + outoflayer), 1)
        else:
            out = torch.cat((x[:, :self.nChannels, :, :], x[:, self.nChannels:self.nChannels + self.growthrate, :, :] + outoflayer, x[:, self.nChannels + self.growthrate:, :, :]), 1)
        out = torch.cat((out, outoflayer), 1)
        return out

class make_dense_LReLU(nn.Module):
    def __init__(self, nChannels, growthRate, kernel_size=3):
        super(make_dense_LReLU, self).__init__()
        self.conv = nn.Conv2d(nChannels, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
                              bias=False)

    def forward(self, x):
        out = F.leaky_relu(self.conv(x))
        out = torch.cat((x, out), 1)
        return out


# Residual dense block (RDB) architecture
class RDB(nn.Module):
    """
    https://github.com/lizhengwei1992/ResidualDenseNetwork-Pytorch
    """

    def __init__(self, nChannels, nDenselayer, growthRate):
        """
        :param nChannels: input feature 의 channel 수
        :param nDenselayer: RDB(residual dense block) 에서 Conv 의 개수
        :param growthRate: Conv 의 output layer 의 수
        """
        super(RDB, self).__init__()
        nChannels_ = nChannels
        modules = []
        for i in range(nDenselayer):
            modules.append(make_dense(nChannels, nChannels_, growthRate))
            nChannels_ += growthRate
        self.dense_layers = nn.Sequential(*modules)

        ###################kingrdb ver2##############################################
        # self.conv_1x1 = nn.Conv2d(nChannels_ + growthRate, nChannels, kernel_size=1, padding=0, bias=False)
        ###################else######################################################
        self.conv_1x1 = nn.Conv2d(nChannels_, nChannels, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        out = self.dense_layers(x)
        out = self.conv_1x1(out)
        # local residual 구조
        out = out + x
        return out

def RDB_Blocks(channels, size):
    bundle = []
    for i in range(size):
        bundle.append(RDB(channels, nDenselayer=8, growthRate=64))  # RDB(input channels,
    return nn.Sequential(*bundle)

####################################################################################################################
# Group of Residual dense block (GRDB) architecture
class GRDB(nn.Module):
    """
    https://github.com/lizhengwei1992/ResidualDenseNetwork-Pytorch
    """

    def __init__(self, numofkernels, nDenselayer, growthRate, numforrg):
        """
        :param nChannels: input feature 의 channel 수
        :param nDenselayer: RDB(residual dense block) 에서 Conv 의 개수
        :param growthRate: Conv 의 output layer 의 수
        """
        super(GRDB, self).__init__()

        modules = []
        for i in range(numforrg):
            modules.append(RDB(numofkernels, nDenselayer=nDenselayer, growthRate=growthRate))
        self.rdbs = nn.Sequential(*modules)
        self.conv_1x1 = nn.Conv2d(numofkernels * numforrg, numofkernels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        out = x
        outputlist = []
        for rdb in self.rdbs:
            output = rdb(out)
            outputlist.append(output)
            out = output
        concat = torch.cat(outputlist, 1)
        out = x + self.conv_1x1(concat)
        return out

# Group of group of Residual dense block (GRDB) architecture
class GGRDB(nn.Module):
    """
    https://github.com/lizhengwei1992/ResidualDenseNetwork-Pytorch
    """

    def __init__(self, numofmodules, numofkernels, nDenselayer, growthRate, numforrg):
        """
        :param nChannels: input feature 의 channel 수
        :param nDenselayer: RDB(residual dense block) 에서 Conv 의 개수
        :param growthRate: Conv 의 output layer 의 수
        """
        super(GGRDB, self).__init__()

        modules = []
        for i in range(numofmodules):
            modules.append(GRDB(numofkernels, nDenselayer=nDenselayer, growthRate=growthRate, numforrg=numforrg))
        self.grdbs = nn.Sequential(*modules)

    def forward(self, x):
        output = x
        for grdb in self.grdbs:
            output = grdb(output)

        return x + output

####################################################################################################################


class ResidualBlock(nn.Module):
    """
    one_to_many 논문에서 제시된 resunit 구조
    """
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        residual = self.bn1(x)
        residual = self.relu1(residual)
        residual = self.conv1(residual)
        residual = self.bn2(residual)
        residual = self.relu2(residual)
        residual = self.conv2(residual)
        return x + residual


def ResidualBlocks(channels, size):
    bundle = []
    for i in range(size):
        bundle.append(ResidualBlock(channels))
    return nn.Sequential(*bundle)

DenoisingMoels.py

from models.subNets import *
from models.cbam import *


class ntire_rdb_gd_rir_ver1(nn.Module):
    def __init__(self, input_channel, numforrg=4, numofrdb=16, numofconv=8, numoffilters=64, t=1):
        super(ntire_rdb_gd_rir_ver1, self).__init__()

        self.numforrg = numforrg  # num of rdb units in one residual group
        self.numofrdb = numofrdb  # num of all rdb units
        self.nDenselayer = numofconv
        self.numofkernels = numoffilters
        self.t = t

        self.layer1 = nn.Conv2d(input_channel, self.numofkernels, kernel_size=3, stride=1, padding=1)
        # self.layer2 = nn.ReLU()
        self.layer3 = nn.Conv2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)

        modules = []
        for i in range(self.numofrdb // self.numforrg):
            modules.append(GRDB(self.numofkernels, self.nDenselayer, self.numofkernels, self.numforrg))
        self.rglayer = nn.Sequential(*modules)

        self.layer7 = nn.ConvTranspose2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)

        # self.layer8 = nn.ReLU()
        self.layer9 = nn.Conv2d(self.numofkernels, input_channel, kernel_size=3, stride=1, padding=1)
        self.cbam = CBAM(self.numofkernels, 16)

    def forward(self, x):
        out = self.layer1(x)
        # out = self.layer2(out)
        out = self.layer3(out)

        # out = self.rglayer(out)
        for grdb in self.rglayer:
            for i in range(self.t):
                out = grdb(out)

        out = self.layer7(out)
        out = self.cbam(out)

        # out = self.layer8(out)
        out = self.layer9(out)

        # global residual 구조
        return out + x

class ntire_rdb_gd_rir_ver2(nn.Module):
    def __init__(self, input_channel, numofmodules=2, numforrg=4, numofrdb=16, numofconv=8, numoffilters=64, t=1):
        super(ntire_rdb_gd_rir_ver2, self).__init__()

        self.numofmodules = numofmodules # num of modules to make residual
        self.numforrg = numforrg  # num of rdb units in one residual group
        self.numofrdb = numofrdb  # num of all rdb units
        self.nDenselayer = numofconv
        self.numofkernels = numoffilters
        self.t = t

        self.layer1 = nn.Conv2d(input_channel, self.numofkernels, kernel_size=3, stride=1, padding=1)
        # self.layer2 = nn.ReLU()
        self.layer3 = nn.Conv2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)

        modules = []
        for i in range(self.numofrdb // (self.numofmodules * self.numforrg)):
            modules.append(GGRDB(self.numofmodules, self.numofkernels, self.nDenselayer, self.numofkernels, self.numforrg))
        for i in range((self.numofrdb % (self.numofmodules * self.numforrg)) // self.numforrg):
            modules.append(GRDB(self.numofkernels, self.nDenselayer, self.numofkernels, self.numforrg))
        self.rglayer = nn.Sequential(*modules)

        self.layer7 = nn.ConvTranspose2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)

        # self.layer8 = nn.ReLU()
        self.layer9 = nn.Conv2d(self.numofkernels, input_channel, kernel_size=3, stride=1, padding=1)
        self.cbam = CBAM(numoffilters, 16)

    def forward(self, x):
        out = self.layer1(x)
        # out = self.layer2(out)
        out = self.layer3(out)

        for grdb in self.rglayer:
            for i in range(self.t):
                out = grdb(out)

        out = self.layer7(out)
        out = self.cbam(out)

        # out = self.layer8(out)
        out = self.layer9(out)

        # global residual 구조
        return out + x



class Generator_one2many_gd_rir_old(nn.Module):
    def __init__(self, input_channel, numforrg=4, numofrdb=16, numofconv=8, numoffilters=64):
        super(Generator_one2many_gd_rir_old, self).__init__()

        self.numforrg = numforrg  # num of rdb units in one residual group
        self.numofrdb = numofrdb  # num of all rdb units
        self.nDenselayer = numofconv
        self.numofkernels = numoffilters

        self.layer1 = nn.Conv2d(input_channel, self.numofkernels, kernel_size=3, stride=1, padding=1)
        self.layer2 = nn.ReLU()
        self.layer3 = nn.Conv2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)

        modules = []
        for i in range(self.numofrdb // self.numforrg):
            modules.append(GRDB(self.numofkernels, self.nDenselayer, self.numofkernels, self.numforrg))
        self.rglayer = nn.Sequential(*modules)

        self.layer7 = nn.ConvTranspose2d(self.numofkernels, self.numofkernels, kernel_size=4, stride=2, padding=1)
        self.layer8 = nn.ReLU()
        self.layer9 = nn.Conv2d(self.numofkernels, input_channel, kernel_size=3, stride=1, padding=1)
        self.cbam = CBAM(self.numofkernels, 16)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)

        out = self.rglayer(out)

        out = self.layer7(out)
        out = self.cbam(out)
        out = self.layer8(out)
        out = self.layer9(out)

        # global residual 구조
        return out + x

cbma.py

import torch
import math
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=False, bn=False, bias=True):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )
            elif pool_type=='lp':
                lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( lp_pool )
            elif pool_type=='lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp( lse_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = F.sigmoid(x_out) # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

def weights_init_rcan(m):
    """
    custom weights initialization called on netG and netD
    https://github.com/pytorch/examples/blob/master/dcgan/main.py
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        if classname.find('BasicConv') != -1:
            m.conv.weight.data.normal_(0.0, 0.02)
            if m.bn != None:
                m.bn.bias.data.fill_(0)
        else:
            m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

DGU-3DMlab1_track1.py

import numpy as np
import cv2
import torch
from models.DenoisingModels import *
from utils.utils import *
from utils.transforms import *
import scipy.io as sio
import time
import tqdm

if __name__ == '__main__':

    print('********************Test code for NTIRE challenge******************')

    # path of input .mat file
    mat_dir = 'mats/BenchmarkNoisyBlocksRaw.mat'

    # Read .mat file
    mat_file = sio.loadmat(mat_dir)

    # get input numpy
    noisyblock = mat_file['BenchmarkNoisyBlocksRaw']
    
    print('input shape', noisyblock.shape)

    # path of saved pkl file of model
    modelpath = 'checkpoints/DGU-3DMlab1_track1.pkl'
    expname = 'DGU-3DMlab1_track1'

    # set gpu
    device = torch.device('cuda:0')

    # make network object
    model = Generator_one2many_gd_rir_old(input_channel=1, numforrg=4, numofrdb=16, numofconv=8, numoffilters=67).to(device)

    # make numpy of output with same shape of input
    resultNP = np.ones(noisyblock.shape)
    print('resultNP.shape', resultNP.shape)

    submitpath = f'results_folder/{expname}'
    make_dirs(submitpath)

    # load checkpoint of the model
    checkpoint = torch.load(modelpath)
    model.load_state_dict(checkpoint['state_dict'])

    transform = ToTensor()
    revtransform = ToImage()

    # pass inputs through model and get outputs
    with torch.no_grad():
        model.eval()
        starttime = time.time()     # check when model starts to process
        for imgidx in tqdm.tqdm(range(noisyblock.shape[0])):
            for patchidx in range(noisyblock.shape[1]):
                img = noisyblock[imgidx][patchidx]   # img shape (256, 256, 3)

                input = transform(img).float()
                input = input.view(1, -1, input.shape[1], input.shape[2]).to(device)

                output = model(input)       # pass input through model

                outimg = revtransform(output)   # transform output tensor to numpy

                # put output patch into result numpy
                resultNP[imgidx][patchidx] = outimg

    # check time after finishing task for all input patches
    endtime = time.time()
    elapsedTime = endtime - starttime   # calculate elapsed time
    print('ended', elapsedTime)
    num_of_pixels = noisyblock.shape[0] * noisyblock.shape[1] * noisyblock.shape[2] * noisyblock.shape[3]
    print('number of pixels', num_of_pixels)
    runtime_per_mega_pixels = (num_of_pixels / 1000000) / elapsedTime
    print('Runtime per mega pixel', runtime_per_mega_pixels)

    # save result numpy as .mat file
    sio.savemat(f'{submitpath}/{expname}', dict([('results', resultNP)]))
posted @ 2021-06-09 10:30  梁君牧  阅读(1447)  评论(0编辑  收藏  举报