2023arXiv_A Fast and Lightweight Network for Low-Light Image Enhancement(FLW)

一. Motivtaion

低光图像增强目前的方法不能同时处理低亮度、低对比度、颜色失真

二.Contribution

1. 提出一个全局组件可以有效提高处理速度

2. 设计了几个新奇的损失函数

三. Network

 对应代码有点出入:

 网络结构解读:

GFE:

获取图像的HSV(Hue:色调,S饱和度,V亮度)中的V通道 :即RGB值的最大值

low_im_filter_max = np.max(low, axis=2, keepdims=True)  # positive (H, W, 1)

对于输入低光照的图像PIL读入图像H W C获取V Channel(HSV)

进入直方图统计,统计在V通道中最大值到最小值的14个区间的值,对于hist[1,1,15]  第三维中位置前12设置为统计的直方图并归一化,位置13设置为低光图像V通道的最小值,位置14设置为低光图像V通道的最大值,位置15设置为正常光图像V通道的平均值,放入MLP中MLP是五个卷积操作,之后放入高阶曲线亮度调整函数(Zero_DCE提出)

LEN网络结构比较简单: 七个卷积加跳连接

四. Loss:

(1)L1损失:均方误差

L1 = torch.abs(R_low - high).mean()

 (2)Lssim损失:使用的是pytorch_ssim中自带的SSIM损失

 (3)作者认为对图像低光增强之后颜色信息应该是能尽量保留的,亦即增强后图像与参考图像(监督图像)的色调和饱和度应该是正比例关系。使用余弦相似度度量增强前后图片的HSV,实际上就是在计算色调、饱和度的相关性。但是在代码实现中1-余弦相似度,但是由于两张图片很接近时,余弦相似度为1从而容易使得Lcolor变为0,容易使得梯度消失,所以在代码实现中最后的损失又加了一个反余弦的计算,避免梯度消失

i,j 是像素坐标, E(I(i,j))是增强后的图像, Y(i,j)是GT

class L_color_zy(nn.Module):

    def __init__(self):
        super(L_color_zy, self).__init__()

    def forward(self, x, y):
        product_separte_color = (x * y).mean(1, keepdim=True)
        x_abs = (x ** 2).mean(1, keepdim=True) ** 0.5
        y_abs = (y ** 2).mean(1, keepdim=True) ** 0.5
        # torch.acos(x,out = None)计算按元素的反余弦值
        loss1 = (1 - product_separte_color / (x_abs * y_abs + 0.00001)).mean() + \
            torch.mean(torch.acos(product_separte_color / (x_abs * y_abs + 0.00001)))

        return loss1

(4)亮度损失:                 c指的是是r,g,b三通道。b(·)代表像素周围的块 。

按照作者原文的解释,也就是参考图像亮的部分增强图像也应该亮。这里对图像明暗程度的衡量并不是一个pixel-wise的概念,而是对每个像素周边的区域的余弦相似度计算。因为线性关系的两者之间可能存在一个常数的差值,因此在计算余弦相似度时需要减去各自区域的最小值(这也是为什么对明暗程度的衡量并不是一个pixel-wise的)。

class L_bright_cosist(nn.Module):

    def __init__(self):
        super(L_bright_cosist, self).__init__()

    def gradient_Consistency_loss_patch(self, x, y):
        # B*C*H*W
        #  torch.abs取绝对值; x.min()返回指定维度的最小数和对应下标,取[0]只取最小值
        # 该行代码的核心功能是计算一个张量沿着第 2 和第 3 维度的最小值,并对其进行取绝对值和截断操作(分离出来)。
        # 使用 detach() 方法将张量从计算图中分离出来,避免梯度回传对该张量产生影响。这种分离通常用于不需要梯度的张量,以节省内存并加速模型训练
        min_x = torch.abs(x.min(2, keepdim=True)[0].min(3, keepdim=True)[0]).detach()  # [1,3,1,1]
        min_y = torch.abs(y.min(2, keepdim=True)[0].min(3, keepdim=True)[0]).detach()
        x = x - min_x
        y = y - min_y
        # B*1*1,3
        # .mean([2, 3])按照高宽取平均值
        product_separte_color = (x * y).mean([2, 3], keepdim=True)
        x_abs = (x ** 2).mean([2, 3], keepdim=True) ** 0.5
        y_abs = (y ** 2).mean([2, 3], keepdim=True) ** 0.5
        loss1 = (1 - product_separte_color / (x_abs * y_abs + 0.00001)).mean() + torch.mean(
            torch.acos(product_separte_color / (x_abs * y_abs + 0.00001)))

        # 按照通道取平均值?
        product_combine_color = torch.mean(product_separte_color, 1, keepdim=True)
        # 通过沿着第 2 和第 3 维度求平均值,可以将 x 和 y 中每个通道的特征图压缩成一个标量,以减少特征图的维度并提取出有意义的统计信息。这种操作通常用于卷积神经网络(CNN)中的池化层或特征融合操作。
        x_abs2 = torch.mean(x_abs ** 2, 1, keepdim=True) ** 0.5
        y_abs2 = torch.mean(y_abs ** 2, 1, keepdim=True) ** 0.5
        loss2 = torch.mean(1 - product_combine_color / (x_abs2 * y_abs2 + 0.00001)) + torch.mean(
            torch.acos(product_combine_color / (x_abs2 * y_abs2 + 0.00001)))
        return loss1 + loss2

    def forward(self, x, y):
        B, C, H, W = x.shape
        loss = self.gradient_Consistency_loss_patch(x, y)
        # loss1 = 0
        # # 0:H // 2 高取0到H/2   H // 2: 高取H/2到H
        # loss1 += self.gradient_Consistency_loss_patch(x[:, :, 0:H // 2, 0:W // 2], y[:, :, 0:H // 2, 0:W // 2])
        # loss1 += self.gradient_Consistency_loss_patch(x[:, :, H // 2:, 0:W // 2], y[:, :, H // 2:, 0:W // 2])
        # loss1 += self.gradient_Consistency_loss_patch(x[:, :, 0:H // 2, W // 2:], y[:, :, 0:H // 2, W // 2:])
        # loss1 += self.gradient_Consistency_loss_patch(x[:, :, H // 2:, W // 2:], y[:, :, H // 2:, W // 2:])

        return loss  # +loss1#+torch.mean(torch.abs(x-y))#+loss1

(5)结构损失:                     倒三角是一个梯度计算

图像的结构在增强前后不应该有改变,也就是内容应当尽量相似。在我们之前已经阅读过的文章中,结构损失也是可以用相关系数来描述的,但是此处作者使用每个像素处的梯度来衡量图像结构之间的差异。与明度相同,结构差异也应当是符合一个线性关系的,同样使用一次函数的公式来描述,也同样需要减去一个误差。

# 这段代码定义了一个名为"L_grad_cosist"的类,该类继承自PyTorch中的nn.Module,用于计算两张图像之间的梯度一致性损失。
# 在初始化函数__init__()中,该程序首先定义了两个卷积核kernel_right和kernel_down,这两个卷积核分别对应于求解水平方向和垂直方向的图像梯度。
# 然后利用这两个卷积核,分别构建了权重矩阵self.weight_right和self.weight_down,并将它们转换成nn.Parameter类型。
# 在梯度计算函数gradient_of_one_channel()中,程序利用了pytorch中的F.conv2d函数对x和y进行卷积操作,并分别计算了原始图像x和增强后的图像y在水平和垂直方向上的梯度。
# 最后,函数返回四个梯度张量值,分别对应于原始图像x在水平和垂直方向上的梯度,以及增强后的图像y在水平和垂直方向上的梯度。
class L_grad_cosist(nn.Module):

    def __init__(self):
        super(L_grad_cosist, self).__init__()
        kernel_right = torch.FloatTensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_down = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).cuda().unsqueeze(0).unsqueeze(0)
        self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
        self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)

    def gradient_of_one_channel(self, x, y): # [1, H, W] R/G/B
        # conv2d 小写 函数卷积操作
        D_org_right = F.conv2d(x, self.weight_right, padding="same")
        D_org_down = F.conv2d(x, self.weight_down, padding="same")
        D_enhance_right = F.conv2d(y, self.weight_right, padding="same")
        D_enhance_down = F.conv2d(y, self.weight_down, padding="same")
        return torch.abs(D_org_right), torch.abs(D_enhance_right), torch.abs(D_org_down), torch.abs(D_enhance_down)

    def gradient_Consistency_loss_patch(self, x, y):
        # B*C*H*W
        min_x = torch.abs(x.min(2, keepdim=True)[0].min(3, keepdim=True)[0]).detach()
        min_y = torch.abs(y.min(2, keepdim=True)[0].min(3, keepdim=True)[0]).detach()
        x = x - min_x
        y = y - min_y
        # B*1*1,3
        product_separte_color = (x * y).mean([2, 3], keepdim=True)
        x_abs = (x ** 2).mean([2, 3], keepdim=True) ** 0.5
        y_abs = (y ** 2).mean([2, 3], keepdim=True) ** 0.5
        loss1 = (1 - product_separte_color / (x_abs * y_abs + 0.00001)).mean() + torch.mean(
            torch.acos(product_separte_color / (x_abs * y_abs + 0.00001)))

        product_combine_color = torch.mean(product_separte_color, 1, keepdim=True)
        x_abs2 = torch.mean(x_abs ** 2, 1, keepdim=True) ** 0.5
        y_abs2 = torch.mean(y_abs ** 2, 1, keepdim=True) ** 0.5
        loss2 = torch.mean(1 - product_combine_color / (x_abs2 * y_abs2 + 0.00001)) + torch.mean(
            torch.acos(product_combine_color / (x_abs2 * y_abs2 + 0.00001)))
        return loss1 + loss2

    def forward(self, x, y):
        # x[:, 0:1, :, :] 取第一个通道
        x_R1, y_R1, x_R2, y_R2 = self.gradient_of_one_channel(x[:, 0:1, :, :], y[:, 0:1, :, :])
        x_G1, y_G1, x_G2, y_G2 = self.gradient_of_one_channel(x[:, 1:2, :, :], y[:, 1:2, :, :])
        x_B1, y_B1, x_B2, y_B2 = self.gradient_of_one_channel(x[:, 2:3, :, :], y[:, 2:3, :, :])
        x = torch.cat([x_R1, x_G1, x_B1, x_R2, x_G2, x_B2], 1)
        y = torch.cat([y_R1, y_G1, y_B1, y_R2, y_G2, y_B2], 1)

        B, C, H, W = x.shape
        loss = self.gradient_Consistency_loss_patch(x, y)
        # loss1 = 0
        # loss1 += self.gradient_Consistency_loss_patch(x[:, :, 0:H // 2, 0:W // 2], y[:, :, 0:H // 2, 0:W // 2])
        # loss1 += self.gradient_Consistency_loss_patch(x[:, :, H // 2:, 0:W // 2], y[:, :, H // 2:, 0:W // 2])
        # loss1 += self.gradient_Consistency_loss_patch(x[:, :, 0:H // 2, W // 2:], y[:, :, 0:H // 2, W // 2:])
        # loss1 += self.gradient_Consistency_loss_patch(x[:, :, H // 2:, W // 2:], y[:, :, H // 2:, W // 2:])

        return loss  # +loss1#+torch.mean(torch.abs(x-y))#+loss1

 

x_abs = (x ** 2).mean([2, 3], keepdim=True) ** 0.5和

 mean的解释代码:

x = torch.arange(18,dtype=float).reshape(1, 2, 3, 3)
print(x)
y = x.mean([2],keepdim=True)
print(y)
x_abs = x.mean([2,3],keepdim=True)
x_abs2 = torch.mean(x ,1, keepdim=True)
print(x_abs)
print(x_abs2)

 

posted @ 2023-05-04 19:43  helloWorldhelloWorld  阅读(262)  评论(0)    收藏  举报