PyTorch: Ensure All Parameters Are Properly Initialized and Participate in Gradient Propagation
Problem
How can we make sure that all trainable parameters of every structure defined in a hand-built network are properly initialized and take part in gradient propagation?
Analysis
Parameter initialization first requires that a module be recognized by PyTorch. Unlike TensorFlow, the flexibility of PyTorch's dynamic graph also brings a small pitfall: in a custom network class that inherits from nn.Module, the __init__(self, ...) method only defines the component modules to be used; it does not define the complete graph structure. The strength of the dynamic graph lies in forward(self, x, ...): this forward-computation method is where the graph is actually built, a major difference from the static graphs of TF 1.x. In forward you can freely rewire the components defined in __init__, add new components on the fly, and inspect the output of every node step by step; a static graph cannot offer this flexibility. But for the same reason, the optimizer may fail to capture the variables (parameters) the user intends to train: even components created in __init__ are sometimes not captured, for example when they are kept in a plain Python list instead of being assigned directly as attributes, and training then silently breaks. PyTorch provides a remedy for this: nn.Module.add_module(name, module) registers a child module (a component used to build the graph), exposing it to the PyTorch runtime and its parameters to the optimizer, so that the parameters of these dynamically assembled components also participate in training.
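A minimal, self-contained sketch of this pitfall (the class names and layer sizes are illustrative only, not part of the network below): submodules held in a plain Python list are invisible to parameters(), while add_module() makes them visible.

import torch.nn as nn

class Unregistered(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list is not tracked by nn.Module, so these
        # parameters never show up in self.parameters().
        self.blocks = [nn.Linear(4, 4), nn.Linear(4, 4)]
    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

class Registered(Unregistered):
    def __init__(self):
        super().__init__()
        # add_module() exposes each block (and its parameters) to
        # parameters(), state_dict(), .to(device), the optimizer, etc.
        for i, blk in enumerate(self.blocks):
            self.add_module('block_%d' % i, blk)

print(len(list(Unregistered().parameters())))  # 0 -- the optimizer would see nothing
print(len(list(Registered().parameters())))    # 4 -- two weights and two biases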
Solution
As shown in the code below, add_module() is used to register the dynamically assembled components of the graph so that no parameters are left out.
# Stereo Siamese Net
import sys
sys.path.append('..')
from utils import AdaIN, adain
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
class StereoSiamNet(nn.Module):
    '''
    Params:
        @input_channel 1 (for gray image) or 3 (for RGB image)
        @texture_encoder_config [[4, 3, 2], [4, 3, 1]]
            describes 2 groups of 2-layer convs: the first group applies 4 filters with a 3x3 kernel
            to the input gray image, then uses 1x1 filters to reduce the channels to 2; a similar
            group follows and finally reduces the channels to 1.
        @style_encoder_config [[8, 3], [16, 3], [32, 3], [64, 3]]
            describes 4 groups of 2-layer convs followed by a mean-pooling layer that reshapes the
            4D feature tensor into 2D (one dimension for batch, one for channel); the resulting
            style vector feeds the fully connected AdaIN heads in the decoder.
        @decoder_config [[4, 3, 2], [4, 3, 1]]
    '''
    def __init__(
            self
            , input_channel = 1
            , texture_encoder_config = [[4, 3, 2], [4, 3, 1]]
            , style_encoder_config = [[8, 3], [16, 3], [32, 3], [64, 3]]
            , decoder_config = [[4, 3, 2], [4, 3, 1]]
            ):
        super().__init__()
        # A gray image is sampled and separated into 2 parts.
        # The first part is the detail part containing all the details of textures and objects.
        # The second part is the style part containing the necessary style information of the input image.
        # The 2 parts together reconstruct the input image.
        # *One is able to modify this into a version with RGB/YUV input.
        # As the detail part should not lose details, all strides are kept at 1;
        # to make it runnable on low-computation-power devices, a bottleneck convolution style is used.
        # For example, if the input is a single channel (gray image), a channel multiplication
        # K is applied to get K feature maps, and then the channel number is shrunk by a pointwise
        # convolution (1x1). This is quite different from Depthwise Separable Convolution. Almost all
        # the computation is spent on NxN (N>=3) filtering, which matches important patterns, and
        # avoids parameter growth and memory expansion.

        # the first part of the encoder: texture encoder
        layers = []  # [[4, 3, 2], [4, 3, 1]]
        assert texture_encoder_config is not None
        assert isinstance(texture_encoder_config, (tuple, list))
        for i in range(len(texture_encoder_config)):
            if i == 0:
                chn_last = input_channel
            else:
                chn_last = texture_encoder_config[i-1][2]  # output channels of the previous group
            head_ = nn.Conv2d(chn_last, texture_encoder_config[i][0], texture_encoder_config[i][1], 1, 'same')  # NxN (N>=3) big-kernel op
            layers.append(head_)
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            tail_ = nn.Conv2d(texture_encoder_config[i][0], texture_encoder_config[i][2], 1, 1, 'same')  # 1x1 small-kernel op
            layers.append(tail_)
            layers.append(nn.LeakyReLU(0.01, inplace=True))
        self.texture_encoder = nn.Sequential(*layers)
        # the second part of the encoder: style encoder
        layers = []  # [[8, 3], [16, 3], [32, 3], [64, 3]]
        assert style_encoder_config is not None
        assert isinstance(style_encoder_config, (tuple, list))
        for i in range(len(style_encoder_config)):
            if i == 0:
                chn_last = input_channel
            else:
                chn_last = style_encoder_config[i-1][0]
            # padding='same' is not supported for strided convolutions in PyTorch,
            # so an explicit padding of kernel_size // 2 is used with stride 2
            head_ = nn.Conv2d(chn_last, style_encoder_config[i][0], style_encoder_config[i][1], 2, style_encoder_config[i][1] // 2)
            layers.append(head_)
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            tail_ = nn.Conv2d(style_encoder_config[i][0], style_encoder_config[i][0], 1, 1, 'same')
            layers.append(tail_)
            layers.append(nn.LeakyReLU(0.01, inplace=True))
        # append the final mean pooling layer to get an (N, C, 1, 1) tensor, then flatten it
        # to (N, C) so the fully connected AdaIN heads in the decoder receive a 2D tensor
        layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        layers.append(nn.Flatten())
        self.style_encoder = nn.Sequential(*layers)
        # the final part is the decoding network:
        # restore images from given texture tensors and style tensors
        assert decoder_config is not None
        assert isinstance(decoder_config, (tuple, list))
        self.decoder_layers = []  # [[4, 3, 2], [4, 3, 1]]
        self.decoder_miu = []     # AdaIN mean heads, one per decoder conv layer
        self.decoder_sigma = []   # AdaIN std heads, one per decoder conv layer
        for i in range(len(decoder_config)):
            if i == 0:
                chn_last = texture_encoder_config[-1][-1]
            else:
                chn_last = decoder_config[i-1][-1]
            # generate AdaIN parameters for the NxN conv input
            fc_miu = nn.Sequential(
                nn.Linear(style_encoder_config[-1][0], 64, True),
                nn.Tanh(),
                nn.Linear(64, 32, True),
                nn.Tanh(),
                nn.Linear(32, chn_last, True),
                nn.Tanh()
            )
            self.decoder_miu.append(fc_miu)
            fc_sigma = nn.Sequential(
                nn.Linear(style_encoder_config[-1][0], 64, True),
                nn.Tanh(),
                nn.Linear(64, 32, True),
                nn.Tanh(),
                nn.Linear(32, chn_last, True),
                nn.Tanh()
            )
            self.decoder_sigma.append(fc_sigma)
            conv_ = nn.Conv2d(chn_last, decoder_config[i][0], decoder_config[i][1], 1, 'same')
            relu_ = nn.LeakyReLU(0.01, inplace=True)
            self.decoder_layers.append(nn.Sequential(conv_, relu_))
            # generate AdaIN parameters for the 1x1 conv input
            fc_miu = nn.Sequential(
                nn.Linear(style_encoder_config[-1][0], 64, True),
                nn.Tanh(),
                nn.Linear(64, 32, True),
                nn.Tanh(),
                nn.Linear(32, decoder_config[i][0], True),
                nn.Tanh()
            )
            self.decoder_miu.append(fc_miu)
            fc_sigma = nn.Sequential(
                nn.Linear(style_encoder_config[-1][0], 64, True),
                nn.Tanh(),
                nn.Linear(64, 32, True),
                nn.Tanh(),
                nn.Linear(32, decoder_config[i][0], True),
                nn.Tanh()
            )
            self.decoder_sigma.append(fc_sigma)
            conv_ = nn.Conv2d(decoder_config[i][0], decoder_config[i][2], 1, 1, 'same')
            relu_ = nn.LeakyReLU(0.01, inplace=True)
            self.decoder_layers.append(nn.Sequential(conv_, relu_))
        # register those modules: plain Python lists are invisible to nn.Module,
        # so add_module() is required to expose their parameters to the optimizer
        assert len(self.decoder_layers) == len(self.decoder_miu)
        for i in range(len(self.decoder_layers)):
            self.add_module('decoder_layers_%d' % i, self.decoder_layers[i])
            self.add_module('decoder_miu_%d' % i, self.decoder_miu[i])
            self.add_module('decoder_sigma_%d' % i, self.decoder_sigma[i])
        self.mix_style = AdaIN()
    def forward(self, x1, x2):
        # texture part
        t1 = self.texture_encoder(x1)
        t2 = self.texture_encoder(x2)
        # style part
        # resize the input to 256x192
        x1s = F.interpolate(x1, (256, 192))
        x2s = F.interpolate(x2, (256, 192))
        s1 = self.style_encoder(x1s)
        s2 = self.style_encoder(x2s)
        # cross decoding
        assert len(self.decoder_miu) == len(self.decoder_layers)
        assert len(self.decoder_sigma) == len(self.decoder_layers)
        x1r = None  # cross-decoded image
        x2r = None  # cross-decoded image
        for i in range(len(self.decoder_layers)):
            # cross decode the left image: mix t2 with s1
            miu1 = self.decoder_miu[i](s1)
            sigma1 = self.decoder_sigma[i](s1)
            if i == 0:
                x1r = self.mix_style(t2, miu1, sigma1)
            else:
                x1r = self.mix_style(x1r, miu1, sigma1)
            x1r = self.decoder_layers[i](x1r)
            # cross decode the right image: mix t1 with s2
            miu2 = self.decoder_miu[i](s2)
            sigma2 = self.decoder_sigma[i](s2)
            if i == 0:
                x2r = self.mix_style(t1, miu2, sigma2)
            else:
                x2r = self.mix_style(x2r, miu2, sigma2)
            x2r = self.decoder_layers[i](x2r)
        return x1r, x2r
    def dump_info(self):
        return {
            'texture_encoder': self.texture_encoder,
            'style_encoder': self.style_encoder,
            'decoder_layers': self.decoder_layers,
            'decoder_miu': self.decoder_miu,
            'decoder_sigma': self.decoder_sigma,
            'mix_style': self.mix_style
        }
The code above implements a Siamese network, a structure commonly used for contrastive or self-supervised learning.
Later, at training time, to make sure that every parameter is discovered by the optimizer, compare the components returned by dump_info() with the output of StereoSiamNet().parameters().
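A quick consistency check along these lines (a sketch, assuming the StereoSiamNet class above is available; the helper count_tensors is illustrative): count the parameter tensors visible to an optimizer and compare them with those held by the components listed in dump_info(). If the add_module() registration were removed, the first count would come out smaller.

def count_tensors(obj):
    # Count trainable parameter tensors in an nn.Module, or in a list of modules.
    if isinstance(obj, nn.Module):
        return sum(1 for p in obj.parameters() if p.requires_grad)
    return sum(count_tensors(m) for m in obj)

net = StereoSiamNet()
visible = sum(1 for p in net.parameters() if p.requires_grad)       # what the optimizer sees
expected = sum(count_tensors(v) for v in net.dump_info().values())  # what the network actually holds
print(visible, expected)
assert visible == expected, 'some submodules were not registered'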