闲谈之Module.state_dict()和Module.parameter()
如是我闻,姑妄听之!
—— ——题记
Module.state_dict()和Module.parameter()
Module.state_dict() 和Module.parameter()中都存储了网络的参数,本文使用实例来探究这两个方法的不同之处。例如下面一个简单的网络:
class Simple_Net(nn.Module):
def __init__(self):
super(Simple_Net,self).__init__()
self.Con = nn.Sequential(
nn.Conv2d(1,2,kernel_size=(3,3)),
nn.MaxPool2d(kernel_size=(1,2)),
nn.Conv2d(2,2,kernel_size=(3,3),padding=(3//2,3//2))
)
self.fc = nn.Linear(324,2)
def forward(self,x):
x = self.Con(x)
x = x.flatten(1)
x = self.fc(x)
return x
将网络实例化后观察其state_dict():
Net = Simple_Net()
print(Net.steta_dict().keys())
print(Net.state_dict()['Con.0.weight'])
输出:
odict_keys(['Con.0.weight', 'Con.0.bias', 'Con.2.weight', 'Con.2.bias', 'fc.weight', 'fc.bias'])
tensor([[[[-0.0936, -0.2798, -0.0771],
[ 0.0018, 0.1036, 0.1325],
[ 0.0621, -0.1227, -0.2912]]],
[[[ 0.2041, -0.2608, -0.0597],
[-0.1355, -0.0319, -0.0327],
[-0.1809, 0.0811, -0.0524]]]])
从上面输出的结果可以看出,Module.state_dict()返回了一个包含网络参数的列表。需要注意的是参数的Tensor中没有计算梯度的标志。
下面在看一看Module.parameter()到底输出了什么:
print(type(Net.parameter()))
print(next(Net.parameter()))
输出:
<class 'generator'>
Parameter containing:
tensor([[[[-0.0936, -0.2798, -0.0771],
[ 0.0018, 0.1036, 0.1325],
[ 0.0621, -0.1227, -0.2912]]],
[[[ 0.2041, -0.2608, -0.0597],
[-0.1355, -0.0319, -0.0327],
[-0.1809, 0.0811, -0.0524]]]], requires_grad=True)
从代码的输出结果可以看出,当调用Module.parameter()后返回了一个生成器,可以发现Module.parameter()和Module.state_dict()存储了一样的数据,不过有区别的是,Module.parameter()中存放的Tensor的requires_grad设置为了True。

浙公网安备 33010602011771号