PyTorch: ensure all parameters are properly initialized and participate in gradient propagation

Problem

How can we make sure that every trainable parameter in the structures defined in a hand-built network is properly initialized and participates in gradient propagation?

Analysis

Whether a parameter gets initialized and trained depends first of all on whether its module is recognized by PyTorch. Unlike TensorFlow, the flexibility of torch's dynamic graph brings a small inconvenience (a pitfall):
In a custom network class derived from nn.Module, the __init__(self, ...) method only defines the component modules to be used; it does not define the complete graph structure. The dynamic graph's advantage shows up in forward(self, x, ...): that forward function is where the graph is actually built, which differs greatly from the static graphs of TF 1.0. Inside forward you can freely rewire the components defined in __init__, even create new components on the fly, and inspect the output of every node step by step; a static graph cannot offer this flexibility. But for the same reason the optimizer cannot always capture the variables (parameters) the user intends to train: submodules that are not assigned directly as attributes of the nn.Module (for example, modules kept in a plain Python list inside __init__) are never registered, so their parameters stay invisible to the optimizer and training silently goes wrong. PyTorch provides a remedy for this: nn.Module.add_module(name, module) registers a submodule (a component used to build the graph), exposing it to the torch runtime and its parameters to the optimizer, so that even dynamically assembled components have their parameters trained.
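
A minimal sketch of the pitfall and of the add_module() fix (the toy BrokenNet/FixedNet classes and the layer sizes below are invented purely for illustration):

import torch
import torch.nn as nn

class BrokenNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Layers kept in a plain Python list are NOT registered as submodules,
        # so their parameters never appear in .parameters() and the optimizer skips them.
        self.blocks = [nn.Linear(8, 8) for _ in range(3)]

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

class FixedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = [nn.Linear(8, 8) for _ in range(3)]
        # Explicitly register every component so torch (and the optimizer) can see it.
        for i, blk in enumerate(self.blocks):
            self.add_module('block_%d' % i, blk)

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

print(len(list(BrokenNet().parameters())))  # 0 -- nothing for the optimizer to train
print(len(list(FixedNet().parameters())))   # 6 -- 3 weights + 3 biases are now visible

nn.ModuleList would register such a list automatically; add_module() simply makes the registration explicit, which is the approach taken in the solution below.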

Solution

As the code below shows, add_module() is used to register the components of the dynamically built graph so that no parameters are left out.

# Stereo Siamese Net
import sys
sys.path.append('..')
from utils import AdaIN, adain

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T


class StereoSiamNet(nn.Module):
    '''
    Params:
    @input_channel  1 (for gray images) or 3 (for RGB images)
    @texture_encoder_config   [[4, 3, 2], [4, 3, 1]]
    this texture encoder config describes 2 groups of 2-layer conv nets: the first group applies 4 filters with a
    3x3 kernel to the input gray image and then uses 1x1 filters to reduce the channels to 2; a similar group of
    ops is applied after it, finally reducing the channels to 1.
    @style_encoder_config   [[8, 3], [16, 3], [32, 3], [64, 3]]
    this style encoder config describes 4 groups of 2-layer conv nets followed by a global mean pooling and a
    flatten that reshape the 4D feature tensor into a 2D one (1 dim for batch, 1 for channel); the channel count
    of the last group is what the decoder's AdaIN parameter generators take as input.
    @decoder_config [[4, 3, 2], [4, 3, 1]]
    '''
    def __init__(
        self
        , input_channel = 1
        , texture_encoder_config = [[4, 3, 2], [4, 3, 1]]
        , style_encoder_config = [[8, 3], [16, 3], [32, 3], [64, 3]]
        , decoder_config = [[4, 3, 2], [4, 3, 1]]
        ):
        super().__init__()
        # A gray image is sampled and separated into 2 parts.
        # The first part is the detail part containing all the details of textures and objects.
        # The second part is the style part containing the necessary style information of the input image.
        # The 2 parts together reconstruct the input image.
        # *One is able to modify this into a version with RGB/YUV input.

        # As the detail part should not lose details, we set all strides to 1;
        # to ensure it can run on low-computation-power devices, we use a bottleneck convolution style.
        # For example, if the input is a single channel (gray image), we apply a channel multiplication
        # K to get K feature maps, and then shrink the channel number back down with a pointwise (1x1) convolution.
        # This is quite different from Depthwise Separable Convolution. Almost all the computation is
        # spent on NxN (N>=3) filtering, which captures the important patterns while avoiding parameter
        # growth and memory expansion.

        # the first part of encoder: texture encoder
        layers = [] # [[4, 3, 2], [4, 3, 1]]
        assert texture_encoder_config is not None
        assert isinstance(texture_encoder_config, tuple) or isinstance(texture_encoder_config, list)
        for i in range(len(texture_encoder_config)):
            chn_last = 0
            if i==0:
                chn_last = input_channel
            else:
                chn_last = texture_encoder_config[i-1][2] # channels left after the previous group's 1x1 reduction
            head_ = nn.Conv2d(chn_last, texture_encoder_config[i][0], texture_encoder_config[i][1], 1, 'same') # should be NxN big kernel ops
            layers.append(head_)
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            tail_ = nn.Conv2d(texture_encoder_config[i][0], texture_encoder_config[i][2], 1, 1, 'same') # should be 1x1 small kernel ops
            layers.append(tail_)
            layers.append(nn.LeakyReLU(0.01, inplace=True))
        self.texture_encoder = nn.Sequential(*layers)

        # the second part of encoder: style encoder
        layers = [] # [[8, 3], [16, 3], [32, 3], [64, 3]]
        assert style_encoder_config is not None
        assert isinstance(style_encoder_config, tuple) or isinstance(style_encoder_config, list)
        for i in range(len(style_encoder_config)):
            chn_last = 0
            if i==0:
                chn_last = input_channel
            else:
                chn_last = style_encoder_config[i-1][0]
            # padding='same' is not supported for strided convolutions, so pad explicitly by kernel_size // 2
            head_ = nn.Conv2d(chn_last, style_encoder_config[i][0], style_encoder_config[i][1], 2, style_encoder_config[i][1] // 2)
            layers.append(head_)
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            tail_ = nn.Conv2d(style_encoder_config[i][0], style_encoder_config[i][0], 1, 1, 'same')
            layers.append(tail_)
            layers.append(nn.LeakyReLU(0.01, inplace=True))
        # append a global mean pooling layer and a flatten to get a 2D (batch, channel) tensor,
        # which is what the decoder's fully connected AdaIN generators expect
        layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        layers.append(nn.Flatten(1))
        self.style_encoder = nn.Sequential(*layers)

        # the final part is the decoding network
        # restore images from given texture tensors and style tensors
        assert decoder_config is not None
        assert isinstance(decoder_config, tuple) or isinstance(decoder_config, list)
        self.decoder_layers = [] # [[4, 3, 2], [4, 3, 1]]
        self.decoder_miu = [] # AdaIN layers applied to each conv layer
        self.decoder_sigma = [] # AdaIN layers applied to each conv layer
        for i in range(len(decoder_config)):
            chn_last = 0
            if i==0:
                chn_last = texture_encoder_config[-1][-1]
            else:
                chn_last = decoder_config[i-1][-1]
            # generate AdaIN parameters
            fc_miu = nn.Sequential(
                nn.Linear(style_encoder_config[-1][0], 64, True),
                nn.Tanh(),
                nn.Linear(64, 32, True),
                nn.Tanh(),
                nn.Linear(32, chn_last, True),
                nn.Tanh()
            )
            self.decoder_miu.append(fc_miu)

            fc_sigma = nn.Sequential(
                nn.Linear(style_encoder_config[-1][0], 64, True),
                nn.Tanh(),
                nn.Linear(64, 32, True),
                nn.Tanh(),
                nn.Linear(32, chn_last, True),
                nn.Tanh()
            )
            self.decoder_sigma.append(fc_sigma)

            conv_ = nn.Conv2d(chn_last, decoder_config[i][0], decoder_config[i][1], 1, 'same')
            relu_ = nn.LeakyReLU(0.01, inplace=True)
            self.decoder_layers.append(nn.Sequential(conv_, relu_))

            # generate AdaIN parameters
            fc_miu = nn.Sequential(
                nn.Linear(style_encoder_config[-1][0], 64, True),
                nn.Tanh(),
                nn.Linear(64, 32, True),
                nn.Tanh(),
                nn.Linear(32, decoder_config[i][0], True),
                nn.Tanh()
            )
            self.decoder_miu.append(fc_miu)

            fc_sigma = nn.Sequential(
                nn.Linear(style_encoder_config[-1][0], 64, True),
                nn.Tanh(),
                nn.Linear(64, 32, True),
                nn.Tanh(),
                nn.Linear(32, decoder_config[i][0], True),
                nn.Tanh()
            )
            self.decoder_sigma.append(fc_sigma)

            conv_ = nn.Conv2d(decoder_config[i][0], decoder_config[i][2], 1, 1, 'same')
            relu_ = nn.LeakyReLU(0.01, inplace=True)
            self.decoder_layers.append(nn.Sequential(conv_, relu_))
        
        # register those modules
        assert len(self.decoder_layers) == len(self.decoder_miu)
        for i in range(len(self.decoder_layers)):
            self.add_module('decoder_layers_%d' % i, self.decoder_layers[i])
            self.add_module('decoder_miu_%d' % i, self.decoder_miu[i])
            self.add_module('decoder_sigma_%d' % i, self.decoder_sigma[i])

        self.mix_style = AdaIN()

    def forward(self, x1, x2):
        # texture part
        t1 = self.texture_encoder(x1)
        t2 = self.texture_encoder(x2)
        # style part
        # resize the input to 256x192
        x1s = F.interpolate(x1, (256, 192))
        x2s = F.interpolate(x2, (256, 192))
        s1 = self.style_encoder(x1s)
        s2 = self.style_encoder(x2s)
        # cross decoding
        assert len(self.decoder_miu) == len(self.decoder_layers)
        assert len(self.decoder_sigma) == len(self.decoder_layers)
        x1r = None # cross-decoded image
        x2r = None # cross-decoded image
        for i in range(len(self.decoder_layers)):
            # cross decode the left image: mix t2 with s1
            miu1 = self.decoder_miu[i](s1)
            sigma1 = self.decoder_sigma[i](s1)
            if i==0:
                x1r = self.mix_style(t2, miu1, sigma1)
            else:
                x1r = self.mix_style(x1r, miu1, sigma1)
            x1r = self.decoder_layers[i](x1r)
            # cross decode the right image: mix t1 with s2
            miu2 = self.decoder_miu[i](s2)
            sigma2 = self.decoder_sigma[i](s2)
            if i==0:
                x2r = self.mix_style(t1, miu2, sigma2)
            else:
                x2r = self.mix_style(x2r, miu2, sigma2)
            x2r = self.decoder_layers[i](x2r)
        return x1r, x2r

    def dump_info(self):
        return {
            'texture_encoder': self.texture_encoder,
            'style_encoder': self.style_encoder,
            'decoder_layers': self.decoder_layers,
            'decoder_miu': self.decoder_miu,
            'decoder_sigma': self.decoder_sigma,
            'mix_style': self.mix_style
            }

The code above implements a Siamese network, which is typically used for contrastive or self-supervised learning.
Later, during training, make sure the optimizer actually discovers all the parameters by comparing the modules returned by dump_info() with the output of StereoSiamNet().parameters().
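
One possible sanity check (a sketch that simply compares parameter counts, using numel() as a convenient proxy):

# Compare what the optimizer will see with the parameters of every component
# reported by dump_info(); a mismatch means some submodule was not registered.
net = StereoSiamNet()

visible = sum(p.numel() for p in net.parameters())

expected = 0
for name, part in net.dump_info().items():
    modules = part if isinstance(part, (list, tuple)) else [part]
    for m in modules:
        expected += sum(p.numel() for p in m.parameters())

print('visible to the optimizer :', visible)
print('expected from dump_info():', expected)
assert visible == expected, 'some submodules are missing add_module registration'

If the add_module loop in __init__ is removed, the assertion fails because the decoder components vanish from net.parameters().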
