pytorch冻结参数

这个主要用在微调的时候

# 冻结参数,这个就是将resnet34中原本的层的param参数不传播,新加的传播
    for param in finetune_net.features.parameters():
        param.requires_grad = False

比如,这个resnet34,它的输出是1000分类,我们下载了这个resnet34已经它训练好的参数,并且我们再其后面加入了两个线性层,但是我们在训练的时候不想让其下载的参数改变,只让那两层线性层进行训练,然后我们可以使用冻结参数,让其参数不参与训练,比如:

# 微调预训练模型
def get_net():
    finetune_net = nn.Sequential()
    finetune_net.features = torchvision.models.resnet34(pretrained=True)
    # 定义一个新的输出网络,共有120个输出类别
    finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256),
                                            nn.ReLU(),
                                            nn.Linear(256, 120))
    # 冻结参数,这个就是将resnet34中原本的层的param参数不传播,新加的传播
    for param in finetune_net.features.parameters():
        param.requires_grad = False
    return finetune_net
posted @ 2023-10-28 22:38  lipu123  阅读(129)  评论(0)    收藏  举报