闲谈之Module.state_dict()和Module.parameter()

如是我闻,姑妄听之!

—— ——题记

Module.state_dict()和Module.parameter()

  Module.state_dict() 和Module.parameter()中都存储了网络的参数,本文使用实例来探究这两个方法的不同之处。例如下面一个简单的网络:

class Simple_Net(nn.Module):
    def __init__(self):
        super(Simple_Net,self).__init__()
        self.Con = nn.Sequential(
            nn.Conv2d(1,2,kernel_size=(3,3)),
            nn.MaxPool2d(kernel_size=(1,2)),
            nn.Conv2d(2,2,kernel_size=(3,3),padding=(3//2,3//2))
        )
        self.fc = nn.Linear(324,2)
        
    def forward(self,x):
        x = self.Con(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x  

  将网络实例化后观察其state_dict():

Net = Simple_Net()
print(Net.steta_dict().keys())
print(Net.state_dict()['Con.0.weight'])

  输出:

odict_keys(['Con.0.weight', 'Con.0.bias', 'Con.2.weight', 'Con.2.bias', 'fc.weight', 'fc.bias'])
tensor([[[[-0.0936, -0.2798, -0.0771],
        [ 0.0018,  0.1036,  0.1325],
        [ 0.0621, -0.1227, -0.2912]]],


        [[[ 0.2041, -0.2608, -0.0597],
        [-0.1355, -0.0319, -0.0327],
        [-0.1809,  0.0811, -0.0524]]]])

  从上面输出的结果可以看出,Module.state_dict()返回了一个包含网络参数的列表。需要注意的是参数的Tensor中没有计算梯度的标志。

  下面在看一看Module.parameter()到底输出了什么:

print(type(Net.parameter()))
print(next(Net.parameter()))

  输出:

<class 'generator'>
Parameter containing:
tensor([[[[-0.0936, -0.2798, -0.0771],
        [ 0.0018,  0.1036,  0.1325],
        [ 0.0621, -0.1227, -0.2912]]],


        [[[ 0.2041, -0.2608, -0.0597],
        [-0.1355, -0.0319, -0.0327],
        [-0.1809,  0.0811, -0.0524]]]], requires_grad=True)

  从代码的输出结果可以看出,当调用Module.parameter()后返回了一个生成器,可以发现Module.parameter()和Module.state_dict()存储了一样的数据,不过有区别的是,Module.parameter()中存放的Tensor的requires_grad设置为了True。

posted @ 2021-11-24 19:11  流纹抄  阅读(560)  评论(0)    收藏  举报