Pytorch中'*'(打包/解包)与nn.Sequential()的结合运用
在看神经网络相关代码时,看见了这种写法
class Generator:
def __init__(self, image_size):
super(Generator, self).__init__()
self.gen = nn.Sequential(
*self._block(12, 128, 0.2),
*self._block(128, 256, 0.2),
*self._block(256, 512, 0.2),
*self._block(512, 1024, 0.2),
nn.Linear(1024, image_size),
nn.Tanh()
)
def _block(self, in_feature, out_feature, slop, normalize=True):
layers = [nn.Linear(in_feature, out_feature)]
if normalize:
layers.append(nn.BatchNorm1d(out_feature, 0.8))
layers.append(nn.LeakyReLU(slop, inplace=True))
return layers
_block
很常见,但*_block
还第一次见,仔细理解了后发现这种写法很巧妙,于是解析并记录下来。顺带将*
的用法一并记录。
在python中,我们经常见到'*'这个符号,我们知道'*'和'**'可以在函数中分别用作接收任意位置的参数和接收任意关键字的参数,具体可见知乎的这篇回答:Python:函数定义和调用时都加*,有什么作用? - 石溪的回答 - 知乎。同时我们也知道'*'和'**'可以分别用作元组的打包/解包和字典的打包/解包。那么这个打包/解包在nn.Sequential()中有什么妙用呢,我们先来回顾下它是怎么打包/解包的。
打包
首先我们创建一个函数func()
,在3个位置接收参数
def func(arg1, *args2):
print('This is arg1: ', arg1)
print('----------')
print('These are args2: ', args2)
print('args2\'s type: ', type(args2))
func(1, 2, 3)
# >>>
# This is arg1: 1
# ----------
# These are args2: (2, 3)
# args2's type: <class 'tuple'>
我们可以看到,一号位形参arg1
接收了一号位实参1
,而二号位形参*args2
接受了剩下所有位置的实参并将其打包为了一个元组。
我们再用'**'看看有什么做用。
def func2(arg3, **args4):
print('This is arg3: ', arg3)
print('----------')
print('These are args4: ', args4)
print('args4\'s type: ', type(args4))
func2(1, a=2, b=3)
# >>>
# This is arg3: 1
# ----------
# These are args4: {'a': 2, 'b': 3}
# args4's type: <class 'dict'>
这里我们可以看到,一号位形参arg3
接受了一号位实参1
,二号位形参接受了剩下所有关键字的实参并将其打包为了一个字典。
接下来我们看看'*'和'**'组合在一起的情况。
def func3(arg5, *args6, **args7):
print('This is arg5: ', arg5)
print('----------')
print('These are args6: ', args6)
print('args6\'s type: ', type(args7))
print('----------')
print('These are args7: ', args7)
print('args7\'s type: ', type(args7))
func3(1, 2, 3, a=4, b=5)
# >>>
# This is arg5: 1
# ----------
# These are args6: (2, 3)
# args6's type: <class 'dict'>
# ----------
# These are args7: {'a': 4, 'b': 5}
# args7's type: <class 'dict'>
一般书写时,我们习惯上把带有两个星号的,接收关键字的形参命名为kwargs
,即**kwargs
。
接下来看看'*'与'**'的解包用法。
解包
在上述例子中,我们使用将任意数目的实参传递给带星号的形参,在完成实参接收的同时又做了打包的工作。现在我们看看下面的代码发生了什么。
def func4(arg, *args):
print('This is arg: ', arg)
print('----------')
print('These are *args: ', *args)
func4(1, *(2, 3))
# >>>
# This is arg: 1
# ----------
# These are *args: 2 3
可以看到,*args
由两个数构成,2和3。这两个数又来自于*(2, 3)
,也就是说'*'将(2, 3)
这个元组拆成了2和3两个数。注意:这里是两个数!而不是什么包含两个值的列表或者元组!什么?你问我为什么是'2 3'而不是func1
所展示的(2, 3)?因为我在第4行打印的是*args
而不是args
,也就是说我打印的是将args
解包过后了的值,我们看看打印args
会发生什么。
def func4(arg, *args):
print('This is arg: ', arg)
print('----------')
print('These are args: ', args) # 注意这里更改了参数:*args -> args
func4(1, *(2, 3))
# >>>
# This is arg: 1
# ----------
# These are args: (2, 3)
这不解了包后经过func4
又打包回去了嘛(恼
同理,我们还可以对字典解包。
def func5(arg2, *args2):
print('This is arg2: ', arg2)
print('----------')
print('These are *args2: ', *args2)
func5(1, *{'a': 2, 'b': 3})
# >>>
# This is arg2: 1
# ----------
# These are *args2: a b
欸乍一看怎么只有字典的键没有值呢。那就对了,'**'可将任意数目的关键字参数打包为字典,但是用'*'解包时只能解出字典的键。那用'**'来解包会怎么样呢?
def func6(arg3, **args3):
print('This is arg3: ', arg3)
print('----------')
print('These are **args3: ', **args3)
func6(1, **{'a': 2, 'b': 3})
# >>>
# This is arg3: 1
# ----------
# TypeError: 'a' is an invalid keyword argument for print()
第四行运行报错了(悲。'a'打印的是一个无效的关键字参数。啥意思?我们来写一段会报同样错误的代码看看
print(a = 3 + 4)
# >>> TypeError: 'a' is an invalid keyword argument for print()
精通Python的你会想到,python3.8版本新加入了一个海象运算符:=
(挺形象),于是乎这段语句可以这么改
print(a := 3 + 4)
# >>> 7
了解海象运算符就很容易理解这是怎么回事了。不理解没关系,简单来说,就是print()
函数只会简单打印出括号里会返回的内容,例如3+4会返回7,但a=3+4则是个赋值语句,并不会返回什么东西,而且也不认识'a',所以报错了。我们知道了报错原因后再回到func6
,解释器告诉我们:我不认识'a'这个玩意儿。好嘛,让你认识认识,我们给它俩个专属的形参位置,让它俩坐进去
def func6(arg3, a, b):
print('This is arg3: ', arg3)
print('----------')
print('These are a b: ', a, b)
func6(1, **{'a': 2, 'b': 3})
# >>>
# This is arg3: 1
# ----------
# These are a b: 2 3
与nn.Sequential()结合
绕了一大圈终于回来了,在对解包有了深刻理解后,我们现在来解释下文章开头抛出的代码。在神经网络中,同一个模型中的相同结构经常会重复使用,如果一行行把这些重复的网络结构写出来,代码将会变得特别繁琐并且可读性差。这时候,我们可以将重复的代码独立写为一个_block
,然后重复调用这个_block
便可以提升我们代码的可读性。但是,这样做有一个缺点,如果说我们在重复调用第二次或第三次的过程中,网络结构不一致该怎么办呢?就像下面这段代码
def __init__(self, image_size):
super(Generator, self).__init__()
self.gen = nn.Sequential(
# Part1
nn.Linear(12, 128),
nn.LeakyReLU(0.2, inplace=True),
# Part2
nn.Linear(128, 256),
nn.BatchNorm1d(256, 0.8),
nn.LeakyReLU(0.2, inplace=True),
# Part3
nn.Linear(256, 512),
nn.BatchNorm1d(512, 0.8),
nn.LeakyReLU(0.2, inplace=True),
# Part4
nn.Linear(512, 1024),
nn.BatchNorm1d(1024, 0.8),
nn.LeakyReLU(0.2, inplace=True),
# Part5
nn.Linear(1024, image_size),
nn.Tanh()
)
在gen这个模型中,网络层次可以划分为5个部分,我已经标注出来了。我们非常容易看出,除去第五部分,每个部分都有相同的结构,即Linear层和LeakyReLU层。除开第1部分,其余的部分又都包含有批次划归一层。我们固然可以把第一部分和第五部分刨除,将2、3、4部分的结构写为一个block然后重复调用,不过我们还可以用一种更优雅的方法来简化这段代码。
首先,我们先回顾下nn.Sequential()
,这个类可以直接接收多个网络层或者一个包含多个网络层及其名字的OrderedDict(这里不展开介绍了)。我们这里使用它的第一个方法,即接收多个网络层,我们突然想到,前面介绍过的解包操作出来的不就是多个值吗,很有道理,但是转念一想,“我都有一个单独的block函数了我还去打包解包做什么?这不多此一举”也很有道理。但是,解包的对象是个列表(前文只提到了元组,其实列表也可以),列表就有个有点就是往里面增删元素非常方便!上例所展示的Part1,他没有BN(批次划归一)层,2、3、4有,那我是不是可以写一个列表,满足一定条件那就往里面塞一个BN层?如下
if condition:
layers.append(nn.BatchNorm1d(num_feature, eps))
那这下好了,我们得到了一个基于if-else变化的网络层列表了。我们可以把它写为一个block方便我们调用。
def _block(self, in_feature, out_feature, slop, normalize=True):
layers = [nn.Linear(in_feature, out_feature)]
if normalize:
layers.append(nn.BatchNorm1d(out_feature, 0.8))
layers.append(nn.LeakyReLU(slop, inplace=True))
return layers
但是nn.Sequential()
接收的是多个值或者OrderedDict怎么办?解包!代码如下
class Generator:
def __init__(self, image_size):
super(Generator, self).__init__()
self.gen = nn.Sequential(
*self._block(12, 128, 0.2, normalize=False),
*self._block(128, 256, 0.2),
*self._block(256, 512, 0.2),
*self._block(512, 1024, 0.2),
nn.Linear(1024, image_size),
nn.Tanh()
)
def _block(self, in_feature, out_feature, slop, normalize=True):
layers = [nn.Linear(in_feature, out_feature)]
if normalize:
layers.append(nn.BatchNorm1d(out_feature, 0.8))
layers.append(nn.LeakyReLU(slop, inplace=True))
return layers
这样我们就得到了一个非常优雅的神经网络模型了。