pytorch冻结参数
这个主要用在微调的时候
# 冻结参数,这个就是将resnet34中原本的层的param参数不传播,新加的传播
for param in finetune_net.features.parameters():
param.requires_grad = False
比如,这个resnet34,它的输出是1000分类,我们下载了这个resnet34已经它训练好的参数,并且我们再其后面加入了两个线性层,但是我们在训练的时候不想让其下载的参数改变,只让那两层线性层进行训练,然后我们可以使用冻结参数,让其参数不参与训练,比如:
# 微调预训练模型
def get_net():
finetune_net = nn.Sequential()
finetune_net.features = torchvision.models.resnet34(pretrained=True)
# 定义一个新的输出网络,共有120个输出类别
finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256),
nn.ReLU(),
nn.Linear(256, 120))
# 冻结参数,这个就是将resnet34中原本的层的param参数不传播,新加的传播
for param in finetune_net.features.parameters():
param.requires_grad = False
return finetune_net