Meanteacher

💻 PyTorch 语义分割模型

I. 导入与基础工具

import torch # 导入 PyTorch 核心库
import torch.nn as nn # 导入 PyTorch 神经网络模块,用于定义层
import resnet, resnext # 导入 ResNet 和 ResNeXt 骨干网络模块,假定这些模块在当前项目路径下已定义
from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d # 从 PyTorch 中导入标准的 BatchNorm2d,并将其重命名为 SynchronizedBatchNorm2d。在实际多GPU场景中,这通常会替换为自定义的同步BN实现。
#from lib.nn import SynchronizedBatchNorm2d # 示例注释:如果需要使用自定义的同步BN,则取消注释此行,并注释上一行。

# 计算像素准确率的基类
class SegmentationModuleBase(nn.Module):
    def __init__(self):
        super(SegmentationModuleBase, self).__init__() # 初始化父类 nn.Module

    def pixel_acc(self, pred, label):
        # pred: [B, C, H, W] 模型输出的 Logits 或 LogSoftmax 结果
        # label: [B, H, W] 真实标签,其中 C 为类别数
        _, preds = torch.max(pred, dim=1) # 沿着类别维度 (dim=1) 找到概率最大的索引,即预测的类别
        valid = (label >= 0).long()       # 创建一个掩码,标记标签值 >= 0 的像素为有效像素(用于忽略填充或特殊背景值)
        acc_sum = torch.sum(valid * (preds == label).long()) # 计算正确预测的有效像素总数
        pixel_sum = torch.sum(valid)      # 计算总的有效像素数
        acc = acc_sum.float() / (pixel_sum.float() + 1e-10) # 计算像素准确率,加上 1e-10 防止除零错误
        return acc

# 3x3 卷积函数
def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    "3x3 convolution with padding"
    # 定义一个标准的 3x3 卷积层
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=has_bias) # padding=1 使得在 stride=1 时输出尺寸不变

# 3x3 卷积 + BN + ReLU 序列
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    # 定义一个包含三个操作的序列
    return nn.Sequential(
            conv3x3(in_planes, out_planes, stride), # 3x3 卷积
            SynchronizedBatchNorm2d(out_planes), # 批量归一化层
            nn.ReLU(inplace=True), # ReLU 激活函数,inplace=True 节省内存
            )

II. 核心分割模块 (SegmentationModule)

该模块是训练和推理的入口,将编码器和解码器连接起来,并处理损失和评估。

class SegmentationModule(SegmentationModuleBase):
    def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None):
        super(SegmentationModule, self).__init__()
        self.encoder = net_enc          # 编码器实例(如 ResnetDilated)
        self.decoder = net_dec          # 解码器实例(如 PPMBilinear)
        self.crit = crit                # 损失函数实例 (Criterion)
        self.deep_sup_scale = deep_sup_scale # 深度监督损失的加权系数 (alpha)

    def forward(self, feed_dict, *, segSize=None):
        # feed_dict: 输入数据字典,至少包含 'img_data' 和 'seg_label'
        # segSize: 目标分割尺寸,仅在推理时使用(训练时为 None)
        
        if segSize is None: # 训练模式
            # 编码器前向传播,要求返回所有阶段的特征图列表
            enc_features = self.encoder(feed_dict['img_data'], return_feature_maps=True)
            
            if self.deep_sup_scale is not None: # 如果启用了深度监督
                # 解码器返回主预测 (pred) 和辅助预测 (pred_deepsup)
                (pred, pred_deepsup) = self.decoder(enc_features) 
            else:
                pred = self.decoder(enc_features) # 解码器只返回主预测

            loss = self.crit(pred, feed_dict['seg_label']) # 计算主预测的损失 (L_main)
            
            if self.deep_sup_scale is not None:
                loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label']) # 计算辅助预测的损失 (L_deepsup)
                loss = loss + loss_deepsup * self.deep_sup_scale # 总损失 = L_main + alpha * L_deepsup

            acc = self.pixel_acc(pred, feed_dict['seg_label']) # 计算主预测的像素准确率
            return loss, acc
        
        else: # 推理模式
            # 编码器提取特征,解码器进行预测,并根据 segSize 上采样到目标尺寸
            pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize)
            return pred

III. 模型构建器 (ModelBuilder)

负责实例化、初始化和加载模型的权重。

class ModelBuilder():
    # 自定义权重初始化函数
    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1: # 如果模块名包含 'Conv' (如 Conv2d)
            nn.init.kaiming_normal_(m.weight.data) # 使用 Kaiming (He) 正态初始化
        elif classname.find('BatchNorm') != -1: # 如果模块名包含 'BatchNorm'
            m.weight.data.fill_(1.) # weight (尺度参数) 初始化为 1
            m.bias.data.fill_(1e-4) # bias (偏移参数) 初始化为接近 0 的小值
        #elif classname.find('Linear') != -1: # 线性层(全连接层)初始化(已注释)
        #    m.weight.data.normal_(0.0, 0.0001)

    def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''):
        pretrained = True if len(weights) == 0 else False # 检查是否需要加载官方预训练权重
        # fc_dim 在这里是编码器输出特征的通道数,用于解码器输入,通常由架构决定

        # --- 架构选择与实例化 (仅列出 ResNet-50 膨胀版本作为代表) ---
        # ... (省略 ResNet-18, 34 的逻辑,它们与 50 类似) ...
        elif arch == 'resnet50_dilated8':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) # 加载 PyTorch 官方 ResNet-50 实例
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) # 封装为 8x 膨胀 ResNet 编码器
        # ... (省略其他 ResNet 和 ResNeXt 架构的逻辑) ...
        else:
            raise Exception('Architecture undefined!') # 如果 arch 不在列表中,则抛出异常

        # net_encoder.apply(self.weights_init) # 编码器通常依赖预训练权重,故自定义初始化被注释
        if len(weights) > 0:
            print('Loading weights for net_encoder')
            # 加载外部权重文件
            net_encoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False) # strict=False 允许忽略不匹配的层
        return net_encoder

    def build_decoder(self, arch='ppm_bilinear_deepsup',
                      fc_dim=512, num_class=150,
                      weights='', use_softmax=False):
        # fc_dim: 编码器输出特征的通道数(解码器输入通道)
        # num_class: 最终分割的类别数
        
        # --- 架构选择与实例化 (仅列出 PPMBilinearDeepsup 和 UPerNet 作为代表) ---
        if arch == 'ppm_bilinear_deepsup':
            net_decoder = PPMBilinearDeepsup(num_class=num_class, fc_dim=fc_dim, use_softmax=use_softmax)
        elif arch == 'upernet':
            net_decoder = UPerNet(num_class=num_class, fc_dim=fc_dim, use_softmax=use_softmax, fpn_dim=512)
        # ... (省略其他解码器架构的逻辑) ...
        else:
            raise Exception('Architecture undefined!')

        net_decoder.apply(self.weights_init) # 对解码器应用自定义权重初始化
        if len(weights) > 0:
            print('Loading weights for net_decoder')
            net_decoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_decoder

IV. 编码器实现细节 (Resnet & ResnetDilated)

1. Resnet:标准 ResNet 特征提取器

class Resnet(nn.Module):
    def __init__(self, orig_resnet):
        super(Resnet, self).__init__()

        # 接管原始 ResNet 的所有层(除了最后的 AvgPool 和 FC 层)
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def forward(self, x, return_feature_maps=False):
        conv_out = [] # 用于存储每个阶段的输出特征图

        # 初始阶段(在 maxpool 之前)的前向传播
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        # 四个主要残差阶段(Layer 1 到 Layer 4)的前向传播
        x = self.layer1(x); conv_out.append(x); # Layer 1 特征
        x = self.layer2(x); conv_out.append(x); # Layer 2 特征
        x = self.layer3(x); conv_out.append(x); # Layer 3 特征
        x = self.layer4(x); conv_out.append(x); # Layer 4 特征(最高层语义特征)

        if return_feature_maps:
            return conv_out # 返回包含所有阶段特征的列表
        return [x] # 如果不需要所有特征图,只返回最后一个特征图的列表

2. ResnetDilated:膨胀 ResNet 封装(核心)

通过修改底层卷积参数实现膨胀卷积。

class ResnetDilated(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8):
        super(ResnetDilated, self).__init__()
        from functools import partial

        if dilate_scale == 8:
            # 修改 ResNet conv1 的参数,使用膨胀率为 2
            orig_resnet.conv1.apply(partial(self._nostride_dilate, dilate=2))
            # ... (源码中省略了对 layer3 和 layer4 的修改,但通常这些层也需要修改以实现 8x 下采样)
            
        elif dilate_scale == 16:
            # 仅修改 layer4 中的卷积参数
            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))

        # 拷贝原始 ResNet 的所有层(与 Resnet 相同)
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        # ... (省略其他层的拷贝) ...
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        # 递归遍历模块 m 中的所有卷积层并修改参数
        classname = m.__class__.__name__
        if classname.find('Conv') != -1: # 仅对 Conv2d 层操作
            # 1. 如果卷积层带有步长 (stride=2),通常用于下采样
            if m.stride == (2, 2):
                m.stride = (1, 1) # 取消下采样,将步长设为 1
                if m.kernel_size == (3, 3):
                    # 增加膨胀率和填充来保持感受野和尺寸
                    m.dilation = (dilate//2, dilate//2)
                    m.padding = (dilate//2, dilate//2)
            # 2. 如果是普通卷积 (stride=1)
            else:
                if m.kernel_size == (3, 3):
                    # 增加膨胀率和填充来扩大感受野和保持尺寸
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        # 前向传播逻辑与 Resnet 相同
        conv_out = []
        # ... (逻辑相同) ...
        if return_feature_maps:
            return conv_out
        return [x]

V. 解码器实现细节

1. PPMBilinearDeepsup:带深度监督的金字塔池化

# pyramid pooling, bilinear upsample (带深度监督)
class PPMBilinearDeepsup(nn.Module):
    def __init__(self, num_class=150, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6)):
        super(PPMBilinearDeepsup, self).__init__()
        self.use_softmax = use_softmax

        self.ppm = [] # 用于存储金字塔池化分支的列表
        for scale in pool_scales:
            self.ppm.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(scale), # 自适应平均池化到指定尺度
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), # 1x1 卷积调整通道
                SynchronizedBatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm = nn.ModuleList(self.ppm) # 转换为 nn.ModuleList

        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) # 深度监督分支的卷积块 (输入通道为 fc_dim/2,即 P4 特征的通道)

        self.conv_last = nn.Sequential( # 主分支的最终分类卷积序列
            nn.Conv2d(fc_dim+len(pool_scales)*512, 512, # 输入通道:P5 原特征通道 + 所有 PPM 输出通道
                      kernel_size=3, padding=1, bias=False),
            SynchronizedBatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1), # Dropout 用于正则化
            nn.Conv2d(512, num_class, kernel_size=1) # 1x1 卷积输出类别分数
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) # 辅助分支的最终分类层
        self.dropout_deepsup = nn.Dropout2d(0.1) # 辅助分支的 Dropout

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1] # 提取编码器最高层特征(P5)

        input_size = conv5.size()
        ppm_out = [conv5] # 列表初始包含 P5 原特征

        # 1. PPM 模块
        for pool_scale in self.ppm:
            # PyTorch 2.x 应该使用 nn.functional.interpolate,这里使用兼容写法
            ppm_out.append(nn.functional.upsample( 
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False)) # 双线性上采样到 P5 尺寸
        ppm_out = torch.cat(ppm_out, 1) # 拼接所有特征
        x = self.conv_last(ppm_out) # 主分支输出

        if self.use_softmax:  # 推理模式
            x = nn.functional.upsample(x, size=segSize, mode='bilinear', align_corners=False) # 上采样到目标尺寸
            x = nn.functional.softmax(x, dim=1) # Softmax 转换为概率
            return x

        # 2. 深度监督分支
        conv4 = conv_out[-2] # 提取编码器倒数第二层特征(P4)
        _ = self.cbr_deepsup(conv4)
        _ = self.dropout_deepsup(_)
        _ = self.conv_last_deepsup(_) # 辅助预测 (pred_deepsup)

        x = nn.functional.log_softmax(x, dim=1) # 主预测转换为 log-概率
        _ = nn.functional.log_softmax(_, dim=1) # 辅助预测转换为 log-概率

        return (x, _) # 返回 (主预测, 辅助预测)

2. UPerNet:统一感知解析网络 (PPM + FPN)

class UPerNet(nn.Module):
    def __init__(self, num_class=150, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6),
                 fpn_inplanes=(256,512,1024,2048), fpn_dim=256):
        super(UPerNet, self).__init__()
        self.use_softmax = use_softmax
        # fpn_inplanes: 编码器 L1, L2, L3, L4 (P2, P3, P4, P5) 的通道数

        # 1. PPM Module 初始化
        self.ppm_pooling = []
        self.ppm_conv = []
        for scale in pool_scales:
            self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))
            self.ppm_conv.append(nn.Sequential(
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                SynchronizedBatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm_pooling = nn.ModuleList(self.ppm_pooling)
        self.ppm_conv = nn.ModuleList(self.ppm_conv)
        self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1) # PPM 融合后,通道数统一为 fpn_dim

        # 2. FPN Module 初始化
        self.fpn_in = [] # 侧向连接(Lateral Branch)的 1x1 卷积列表
        for fpn_inplane in fpn_inplanes[:-1]: # P2, P3, P4 (跳过最高层 P5)
            self.fpn_in.append(nn.Sequential(
                nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False), # 将通道数统一为 fpn_dim
                SynchronizedBatchNorm2d(fpn_dim),
                nn.ReLU(inplace=True)
            ))
        self.fpn_in = nn.ModuleList(self.fpn_in)

        self.fpn_out = [] # 顶部-到底部(Top-Down Branch)的 3x3 卷积列表
        for i in range(len(fpn_inplanes) - 1): 
            self.fpn_out.append(nn.Sequential(conv3x3_bn_relu(fpn_dim, fpn_dim, 1))) # 3x3 卷积进一步处理
        self.fpn_out = nn.ModuleList(self.fpn_out)

        # 3. 最终分类层
        self.conv_last = nn.Sequential(
            conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1), # 拼接后的特征图进行融合(通道数是 4 * fpn_dim)
            nn.Conv2d(fpn_dim, num_class, kernel_size=1) # 最终 1x1 卷积输出类别分数
        )

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1] # P5 特征

        # 1. PPM 模块 (处理最高层特征)
        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
            ppm_out.append(pool_conv(nn.functional.upsample(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False)))
        ppm_out = torch.cat(ppm_out, 1) # PPM 拼接
        f = self.ppm_last_conv(ppm_out) # 得到 PPM 融合后的 P5 特征 f (通道数为 fpn_dim)

        # 2. FPN Module (自上而下融合)
        fpn_feature_list = [f] # 列表初始存放 P5 特征
        for i in reversed(range(len(conv_out) - 1)): # 从 P4 (i=2) 迭代到 P2 (i=0)
            conv_x = conv_out[i] # 编码器当前级别特征 (P4, P3, P2)
            conv_x = self.fpn_in[i](conv_x) # 侧向连接:1x1 卷积调整通道
            
            f = nn.functional.upsample(f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # Top-Down:上层特征上采样
            f = conv_x + f # 相加融合

            fpn_feature_list.append(self.fpn_out[i](f)) # 存储融合后的特征

        fpn_feature_list.reverse() # 调整顺序为 [P2_fused, P3_fused, P4_fused, P5_fused]

        # 3. FPN 最终融合 (将所有特征图上采样到 P2 尺寸并拼接)
        output_size = fpn_feature_list[0].size()[2:] # 以 P2 特征图的尺寸为基准
        fusion_list = [fpn_feature_list[0]]
        for i in range(1, len(fpn_feature_list)):
            fusion_list.append(nn.functional.upsample( # 将 P3, P4, P5 特征上采样到 P2 尺寸
                fpn_feature_list[i], output_size, mode='bilinear', align_corners=False))
        fusion_out = torch.cat(fusion_list, 1) # 拼接所有特征 (通道数是 4 * fpn_dim)
        x = self.conv_last(fusion_out) # 最终分类

        if self.use_softmax:  # 推理模式
            x = nn.functional.upsample(x, size=segSize, mode='bilinear', align_corners=False)
            x = nn.functional.softmax(x, dim=1)
            return x

        x = nn.functional.log_softmax(x, dim=1) # 训练模式
        return x

最终总结流程

该框架是用于语义分割的 模块化 实现,它将分割任务拆解为三个主要阶段:

  1. 特征提取 (Encoder): 使用 膨胀 ResNet (ResnetDilated)。通过修改步长和应用膨胀卷积,它能够在保持高分辨率(通常为输入尺寸的 $1/8$ 或 $1/16$)的同时,捕获大感受野,输出多级特征图。

  2. 特征融合与上下文提取 (Decoder):

    • PPM 解码器: 通过金字塔池化对最高级特征进行多尺度采样和融合,以捕获图像的全局上下文信息
    • UPerNet 解码器: 结合 PPM(用于最高层)和 FPN(用于多级特征融合),通过自上而下的路径将高级语义信息和低级空间细节有效结合,生成高准确度的像素级预测。
  3. 预测与监督 (Supervision):

    • 深度监督 (Deep Supervision): 在解码器的中间层(如 P4 特征)额外进行一次辅助预测,计算损失,并与主预测损失按权重相加,以增强中间层特征的学习能力。
    • 输出: 最终预测通过双线性插值上采样到输入图像的原始尺寸,并使用 log_softmax(用于训练)或 softmax(用于推理)输出。
posted @ 2025-12-01 21:08  学java的阿驴  阅读(6)  评论(0)    收藏  举报