Pytorch模型中的parameter与buffer(torch.nn.Module的成员)

前言:

 1 class DaGMM(nn.Module):  # 自定义的模型需要继承nn.Module(固定写法)
 2     """Residual Block(残块)."""
 3 
 4     def __init__(self, n_gmm=2, latent_dim=3):
 5         super(DaGMM, self).__init__()  # (固定写法)
 6 
 7         layers = []
 8         layers += [nn.Linear(118, 60)]
 9         layers += [nn.Tanh()]  # 激活函数
10         layers += [nn.Linear(60, 30)]
11         layers += [nn.Tanh()]
12         layers += [nn.Linear(30, 10)]
13         layers += [nn.Tanh()]
14         layers += [nn.Linear(10, 1)]
15 
16         self.encoder = nn.Sequential(*layers)
17 
18         layers = []
19         layers += [nn.Linear(1, 10)]
20         layers += [nn.Tanh()]
21         layers += [nn.Linear(10, 30)]
22         layers += [nn.Tanh()]
23         layers += [nn.Linear(30, 60)]
24         layers += [nn.Tanh()]
25         layers += [nn.Linear(60, 118)]
26 
27         self.decoder = nn.Sequential(*layers)
28 
29         layers = []
30         layers += [nn.Linear(latent_dim, 10)]
31         layers += [nn.Tanh()]
32         layers += [nn.Dropout(p=0.5)]
33         layers += [nn.Linear(10, n_gmm)]
34         layers += [nn.Softmax(dim=1)]
35 
36         self.estimation = nn.Sequential(*layers)
37 
38         self.register_buffer("phi", torch.zeros(n_gmm))
39         self.register_buffer("mu", torch.zeros(n_gmm, latent_dim))
40         self.register_buffer("cov", torch.zeros(n_gmm, latent_dim, latent_dim))

  我们知道,pytorch一般情况下,是将网络中的参数保存成OrderedDict(见附1)形式的。这里的参数其实包括2种:一种是模型中的各种module含的参数,即nn.Parameter,我们当然可以在网络中定义其他的nn.Parameter参数。另外一种是buffer。前者每次optim.step会得到更新,而不会更新后者。

1、模型保存

 在Pytorch中一种模型保存和加载的方式如下:

# save
torch.save(model.state_dict(), PATH)

# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

  可以看到,模型保存的是model.state.dict()的返回对象。model.state_dict的返回对象是一个OrderedDict,它以键值对的形式保存模型中需要保存下来的参数,例如:

 1 import torch
 2 import torch.nn as nn
 3 
 4 class MyModule(nn.Module):
 5     def __init__(self, input_size, output_size):
 6         super(MyModule, self).__init__()
 7         self.lin = nn.Linear(input_size, output_size)
 8 
 9     def forward(self, x):
10         return self.lin(x)
11 
12 module = MyModule(4, 2)
13 print(module.state_dict())
14 输出:
15 OrderedDict([('lin.weight', tensor([[-0.3636, -0.4864, -0.2716, -0.4416],
16         [ 0.0119,  0.4462, -0.4558,  0.0188]])), ('lin.bias', tensor([ 0.1056, -0.1058]))])

  模型中的参数就是线性层的weight和bias。

2、Parameter & buffer

 torch.nn.register_parameter()用于注册Parameter实例到当前Module中(一般可以用torch.nn.Parameter()代替);torch.nn.register_buffer()用于注册Buffer实例到当前Module中。此外,Module中的parameters()函数会返回当前Module中所注册的所有Parameter的迭代器;而_all_buffers()函数会返回当前Module中所注册的所有Buffer的迭代器,(所以优化器不会计算Buffer的梯度,自然不会对其更新)。此外,Module中的state_dict()会返回包含当前Module中所注册的所有Parameter和Buffer(所以模型中未注册成Parameter或Buffer的参数无法被保存)。

 模型中需要保存下来的参数包括两种

  • 一种是反向传播需要被optimizer更新的,称之为parameter
  • 一种是反向传播不需要被optimizer更新的,称之为buffer

 第一种参数我们可以通过model.parameter()返回;第二种参数我们可以通过model.buffers()返回。因为我们模型保存的是state_dict返回的OrderedDict,所以这两种参数不仅要满足是/否需要更新的要求,还需要被保存到OrderedDict

 那么现在的问题是这两种参数如何创建呢?创建好之后如何保存到OrderedDict呢?

 2.1、parameter参数有两种创建方式: 

   1、我们可以直接将模型的成员变量(self.xxx)通过nn.Parameter()创建,会自动注册到parameters中,可以通过model.parameters()返回,并且这样创建的参数会自动保存到OrderedDict中去。

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_param = nn.Parameter(torch.randn(3, 3))  # 模型的成员变量
    def forward(self, x):
        # 可以通过 self.my_param 和 self.my_buffer 访问
        pass

model = MyModel()
for param in model.parameters():
    print(param)
print("----------------")
print(model.state_dict())
输出:
Parameter containing:
tensor([[-0.5421,  2.9562,  0.3447],
        [ 0.0869, -0.3464,  1.1299],
        [ 0.8644, -0.1384, -0.6338]])
----------------
OrderedDict([('param', tensor([[-0.5421,  2.9562,  0.3447],
        [ 0.0869, -0.3464,  1.1299],
        [ 0.8644, -0.1384, -0.6338]]))])

  2、通过nn.Parameter()创建普通的Parameter对象,不作为模型的成员变量,然后将Parameter对象通过register_parameter()进行注册,可以通过model.parameters()返回,注册后的参数也是会自动保存到OrderedDict中去。

import torch
import torch.nn as nn
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        param = nn.Parameter(torch.randn(3, 3))  # 普通 Parameter 对象
        self.register_parameter("my_param", param)

    def forward(self, x):
        # 可以通过 self.my_param 和 self.my_buffer 访问
        pass
model = MyModel()
for param in model.parameters():
    print(param)
print("----------------")
print(model.state_dict())
输出:
Parameter containing:
tensor([[-0.2313, -0.1490, -1.3148],
        [-1.2862, -2.2740,  1.0558],
        [-0.6559,  0.4552,  0.5993]])
----------------
OrderedDict([('my_param', tensor([[-0.2313, -0.1490, -1.3148],
        [-1.2862, -2.2740,  1.0558],
        [-0.6559,  0.4552,  0.5993]]))])

  2.2、buffer参数的创建方式:

  这种参数的创建需要先创建tensor,然后将tensor通过register_buffer()进行注册,可以通过model._all_buffers()返回,注册完成后参数也会自动保存到OrderedDict中去。

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        buffer = torch.randn(2, 3)  # tensor
        self.register_buffer('my_buffer', buffer)

    def forward(self, x):
        # 可以通过 self.param 和 self.my_buffer 访问
        pass
model = MyModel()
for buffer in model._all_buffers():
    print(buffer)
print("----------------")
print(model.state_dict())
输出:
tensor([[-0.2191,  0.1378, -1.5544],
        [-0.4343,  0.1329, -0.3834]])
----------------
OrderedDict([('my_buffer', tensor([[-0.2191,  0.1378, -1.5544],
        [-0.4343,  0.1329, -0.3834]]))])

总结:

 I、模型中需要进行更新的参数注册为Parameter,不需要进行更新的参数注册为buffer

 II、模型保存的参数是Model.state_dict()返回的OrderedDict

 III、模型进行设备移动时(CPU--->GPU),模型中注册的参数(Parameter和buffer)会同时进行移动。

附录:

1、很多人认为python中的字典是无序的,因为它是按照hash来存储的,但是python中有个模块collection,里面自带了一个子类OrderedDict,实现了对字典对象中元素的排序。

import collections

print("Regular dictionary")
d = {}
d['a'] = 'A'
d['b'] = 'B'
d['c'] = 'C'
for k, v in d.items():
    print(k, v)

print("\nOrder dictionary")
d1 = collections.OrderedDict()
d1['a'] = 'A'
d1['b'] = 'B'
d1['c'] = 'C'
d1['1'] = '1'
d1['2'] = '2'
for k, v in d1.items():
    print(k, v)
输出:
Regular dictionary
a A
c C
b B

Order dictionary
a A
b B
c C
1 1
2 2

可以看到,同样是保存了ABC等几个元素,但是使用OrderedDict会根据放入元素的先后顺序进行排序。所以输出的值是拍好序的,OrderedDict对象的字典对象,如果其顺序不同,那么python会把他们当做两个不同的对象。

参考:

参考1:【Pytorch】模型中buffer的使用

参考2:python中OrdredDict用法

参考3:Pytorch模型中的parameter和buffer

参考4:Pytorch采坑记录

 

posted @ 2020-07-10 17:40  小吴的日常  阅读(2016)  评论(0编辑  收藏  举报