[pytorch] Freezing part of a model's parameters during training: module.requires_grad_(False)

prologue

The code uses a decoder \(dec\): the idea is to have it predict the counting encode of the generated result \(g\) and feed that prediction into a loss, thereby constraining the generator to produce reasonable results (ones from which the correct counting encode can be decoded).
But since \(g\) is not accurate, \(dec\)'s parameters would be dragged off course by \(g\) if they were not frozen.

idea

In practice, nn.Module.requires_grad_(False) does the job:

dec.requires_grad_(False)   # freeze dec before it sees the (inaccurate) g
logits = dec.classify(g)
dec.requires_grad_(True)    # restore the flag right after this forward pass
loss(logits, label)         # counting loss used to constrain the generator
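
For reference, Module.requires_grad_ simply flips the flag on every parameter of the module (including submodules), so the call above is equivalent to this loop:

for p in dec.parameters():
  p.requires_grad_(False)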

However, dec also has to process other features (for example the input x, which is used to train dec itself), so its parameters still need to be updated. It is not obvious whether setting the flag back to True before the gradient update, as above, actually works.
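
As a hypothetical sketch of the full training step in question (gen, criterion, cnt_label and label are placeholder names not taken from the original code, and a single optimizer is assumed to hold both gen's and dec's parameters):

g = gen(x)                      # generated result, possibly inaccurate

dec.requires_grad_(False)       # the counting loss should only shape the generator
cnt_logits = dec.classify(g)
dec.requires_grad_(True)        # restore before computing dec's own loss

dec_logits = dec.classify(x)    # dec is still trained normally on x
loss = criterion(cnt_logits, cnt_label) + criterion(dec_logits, label)

optimizer.zero_grad()
loss.backward()                 # dec should only receive gradients from the second term
optimizer.step()

Whether restoring the flag before backward and step really keeps the first pass frozen is exactly what the experiment below checks.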

validate

The validation below uses two MLPs to stand in for \(dec\):

import random
import torch
import torch.nn as nn
from torch import optim
import numpy as np


def setup_seed(seed, strict=True):
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  np.random.seed(seed)
  random.seed(seed)
  if strict:
    torch.backends.cudnn.deterministic = True


class MLP(nn.Module):

  def __init__(self):
    super().__init__()
    self.c_fc = nn.Linear(6, 4 * 6, bias=False)
    self.gelu = nn.GELU()
    self.c_proj = nn.Linear(4 * 6, 6, bias=False)
    self.dropout = nn.Dropout(0.1)

  def forward(self, x):
    x = self.c_fc(x)
    x = self.gelu(x)
    x = self.c_proj(x)
    x = self.dropout(x)
    return x

Next, fix the random seed, create the two MLPs and inspect their initial parameters (looking at one linear layer from each is enough); the input x is randomly initialized, and SGD is used to update the parameters:

setup_seed(233)
m1 = MLP()
m2 = MLP()
print('m1-weight\n', m1.c_fc.weight[:3, :7])
print('m2-weight\n', m2.c_fc.weight[:3, :7])
x = torch.randn((3, 6), requires_grad=True)
# print(x.requires_grad, x)
optimizer1 = optim.SGD(m1.parameters(), lr=0.1)
optimizer2 = optim.SGD(m2.parameters(), lr=0.1)

Output of the above:

m1-weight
 tensor([[-0.1926, -0.3506, -0.0758,  0.2318, -0.1711,  0.3389],
        [-0.1360, -0.3253, -0.1471,  0.3951,  0.4054, -0.0247],
        [-0.1603,  0.1496, -0.2566,  0.3095,  0.2247, -0.3124]],
       grad_fn=<SliceBackward0>)
m2-weight
 tensor([[ 0.3739,  0.3307, -0.1514,  0.0787,  0.3436, -0.1428],
        [-0.1487,  0.1236,  0.4002, -0.2563, -0.0266, -0.2860],
        [ 0.3197, -0.1728, -0.1770, -0.2492,  0.2864, -0.3191]],
       grad_fn=<SliceBackward0>)

Feed x through the models to get the final output x3. retain_grad keeps the gradients of these intermediate tensors so they can be printed for inspection; m2 has its parameters frozen around its call, serving as the contrast to m1:

x1 = m1(x)
m2.requires_grad_(False)   # freeze m2 only for this forward pass
x2 = m2(x1)
m2.requires_grad_(True)    # restore the flag before the optimizer step
x3 = x2 * 0.3 - 3
x1.retain_grad()           # keep gradients of intermediate tensors for inspection
x2.retain_grad()

The t below is the gradient argument of backward: since x3 is not a scalar, backward needs an upstream gradient of the same shape, and x3.backward(t) amounts to (x3 * t).sum().backward(). Any tensor of the right shape works here; a standalone illustration follows, after which the experiment continues.
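
A minimal self-contained sketch (separate from the experiment, with throwaway names w, y, v) of how the tensor passed to backward acts as the upstream gradient:

w = torch.randn(4, requires_grad=True)
y = w * 2.0                        # non-scalar output, so backward needs a gradient argument
v = torch.tensor([1.0, 0.5, 0.0, -1.0])
y.backward(v)                      # same as (y * v).sum().backward()
print(w.grad)                      # 2 * v, i.e. tensor([ 2.,  1.,  0., -2.])

Back to the experiment: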

t = torch.randn((3, 6)) * 10   # arbitrary upstream gradient for the non-scalar x3
# t = torch.ones_like(x)
optimizer1.zero_grad()
optimizer2.zero_grad()
x3.backward(t)
optimizer1.step()
optimizer2.step()
# print('x\n', x.grad)
# print('x1\n', x1.grad)
# print('x2\n', x2.grad)
print('m1-weight\n', m1.c_fc.weight[:3, :7])
print('m2-weight\n', m2.c_fc.weight[:3, :7])

Output of the above:

m1-weight
 tensor([[-0.1897, -0.3464, -0.0873,  0.2439, -0.1725,  0.3384],
        [-0.1663, -0.3467, -0.2410,  0.4209,  0.4023, -0.0558],
        [-0.1909,  0.1313, -0.3833,  0.3582,  0.2189, -0.3493]],
       grad_fn=<SliceBackward0>)
m2-weight
 tensor([[ 0.3739,  0.3307, -0.1514,  0.0787,  0.3436, -0.1428],
        [-0.1487,  0.1236,  0.4002, -0.2563, -0.0266, -0.2860],
        [ 0.3197, -0.1728, -0.1770, -0.2492,  0.2864, -0.3191]],
       grad_fn=<SliceBackward0>)

As expected, m2's parameters did not change while m1's were updated, so the approach works.
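
An extra check (not part of the original run) makes the mechanism explicit: because requires_grad was False while m2 ran, its parameters never receive a .grad at all, so optimizer2.step() has nothing to apply to them:

print(m1.c_fc.weight.grad is None)   # expected: False, m1 received gradients
print(m2.c_fc.weight.grad is None)   # expected: True, so SGD skips m2's parameters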

posted @ 2023-10-17 19:59  NoNoe