【Pytorch】ResNet50中替换BN层为IN层实现

替换BN层为IN层

最近在做实验时,考虑将官方torchvision包中的Resnet模型进行一些更改,ResNet类中有个可选参数_norm_layer可以直接传入nn.InstanceNorm2d,默认为nn.BatchNorm,但是这样更改后,在使用官方的预训练权重时,会发生一些报错,BN层里的一些权重会导致报错,因此用另一种方式实现替换BN层的需求的同时,尽可能使用预训练权重

实现

  1. 定义一个函数来替换 BN 层为 IN 层
import torch.nn as nn

def replace_bn_with_in(module):
    """
    遍历网络模块,将 BatchNorm 替换为 InstanceNorm
    """
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.InstanceNorm2d(child.num_features, affine=True))
        else:
            replace_bn_with_in(child)

  1. 加载预训练的 ResNet 模型
import torchvision.models as models

# 加载预训练的 ResNet 模型(这里以 resnet50 为例)
model = models.resnet50(pretrained=True)

  1. 替换BN为IN
# 将模型中的 BatchNorm 层替换为 InstanceNorm 层
replace_bn_with_in(model)

补充

setattr 函数是 Python 的内置函数,用于设置对象的属性。如果属性不存在,它会创建一个新属性。setattr 函数的使用格式如下:

setattr(object, name, value)
  • 参数
    • object:要设置属性的对象。
    • name:属性的名称,一个字符串。
    • value:要设置的属性值
posted @ 2024-04-11 16:25  chendsome  阅读(91)  评论(0)    收藏  举报  来源