FCN-ResNet18 语义分割完整实现详解

好的!我来把这段代码整理成博客园风格的笔记,一段代码一段讲解:

FCN-ResNet18 语义分割完整实现详解

1. 导入必要的库

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

代码说明

  • torch:PyTorch深度学习框架
  • torchvision:提供预训练模型和数据集
  • nn:神经网络模块
  • F:函数式接口
  • d2l:《动手学深度学习》工具库

2. VOC数据集类别和颜色定义

VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
    'horse', 'motorbike', 'person', 'potted plant', 'sheep',
    'sofa', 'train', 'tv/monitor'
]

VOC_COLORMAP = [
    [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
    [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
    [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
    [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
    [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]
]

代码说明

  • VOC_CLASSES:21个语义类别名称
  • VOC_COLORMAP:每个类别对应的RGB颜色值
  • 背景为黑色[0,0,0],飞机为红色[128,0,0]
  • 这是PASCAL VOC数据集的官方定义

3. 创建颜色到标签的映射字典

colormap2label = torch.zeros(256 ** 3, dtype=torch.long)
for i, colormap in enumerate(VOC_COLORMAP):
    # 将RGB颜色转换为唯一索引:R*256^2 + G*256 + B
    color_index = (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]
    colormap2label[color_index] = i

代码说明

  • 创建大小为256³的查找表(覆盖所有RGB颜色)
  • 将每个VOC颜色映射到对应的类别ID
  • 例如:黑色[0,0,0] → 索引0 → 类别0(背景)

4. 高效的标签转换函数

def voc_label_indices_fast(colormap, colormap2label):
    """使用查找表快速将RGB标签图转换为类别ID"""
    # colormap: (H, W, 3) RGB图像
    # 将RGB图像转换为索引:R*256^2 + G*256 + B
    indices = (colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256 + colormap[:, :, 2]
    # 使用查找表直接映射到类别ID
    return colormap2label[indices]

def preprocess_mask(mask, colormap2label):
    """预处理掩码:RGB → 类别ID"""
    if mask.dim() == 3 and mask.shape[-1] == 3:  # 如果是RGB图像
        mask = voc_label_indices_fast(mask, colormap2label)
    return mask

代码说明

  • voc_label_indices_fast:批量处理整个图像,比逐像素循环快很多
  • 利用向量化操作一次性计算所有像素的索引
  • preprocess_mask:封装函数,自动判断输入格式

5. 双线性插值卷积核

def bilinear_kernel(in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = torch.arange(kernel_size).reshape(-1, 1), torch.arange(kernel_size).reshape(1, -1)
    filt = (1 - torch.abs(og[0] - center) / factor) * (1 - torch.abs(og[1] - center) / factor)
    weight = torch.zeros((in_channels, out_channels, kernel_size, kernel_size))
    weight[range(in_channels), range(out_channels), :, :] = filt
    return weight

代码说明

  • 创建双线性插值权重核
  • 中心权重最大,向边缘逐渐减小
  • 用于初始化转置卷积,实现平滑的上采样

6. 构建FCN-ResNet18网络

# 1) 加载预训练 ResNet18,做 encoder
pretrained_net = torchvision.models.resnet18(pretrained=True)
net = nn.Sequential(*list(pretrained_net.children())[:-2])

# 2) segmentation head
num_classes = 21
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module('transpose_conv',
               nn.ConvTranspose2d(num_classes, num_classes,
                                  kernel_size=64, padding=16, stride=32))

# 3) 初始化反卷积为双线性插值
net.transpose_conv.weight.data.copy_(bilinear_kernel(num_classes, num_classes, 64))

代码说明

  • 编码器:使用ResNet18(去掉最后两层)
  • 1×1卷积:将512特征通道转换为21个类别通道
  • 转置卷积:32倍上采样,恢复原始分辨率
  • 双线性初始化:避免棋盘伪影,加速收敛

7. 演示颜色映射过程

def demonstrate_colormap2label():
    """演示colormap2label的使用方法"""
    print("=== 演示colormap2label映射 ===")

    test_colors = [
        [0, 0, 0],        # 背景 - 黑色
        [128, 0, 0],      # 飞机 - 红色
        [0, 128, 0],      # 自行车 - 绿色
        [192, 128, 128],  # 人 - 灰色
        [255, 255, 255]   # 不在VOC中的颜色
    ]

    for color in test_colors:
        color_index = (color[0] * 256 + color[1]) * 256 + color[2]
        class_id = colormap2label[color_index].item()
        if class_id == 0 and color != [0, 0, 0]:
            class_name = "未知类别"
        else:
            class_name = VOC_CLASSES[class_id]
        print(f"RGB{color} -> 索引{color_index} -> 类别{class_id}: {class_name}")

# 运行演示
demonstrate_colormap2label()

输出示例

=== 演示colormap2label映射 ===
RGB[0, 0, 0] -> 索引0 -> 类别0: background
RGB[128, 0, 0] -> 索引8388608 -> 类别1: aeroplane
RGB[0, 128, 0] -> 索引32768 -> 类别2: bicycle
RGB[192, 128, 128] -> 索引12632256 -> 类别15: person
RGB[255, 255, 255] -> 索引16777215 -> 类别0: 未知类别

8. 加载VOC数据集

batch_size = 32
crop_size = (320, 480)
print("\n加载VOC数据集...")
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)

# 检查一个批次
for X, Y in train_iter:
    print(f"输入图像形状: {X.shape}")  # (batch, 3, H, W)
    print(f"标签形状: {Y.shape}")      # (batch, H, W) - 已经是类别ID
    break

代码说明

  • d2l.load_data_voc自动处理数据加载和预处理
  • 输入图像:(32, 3, 320, 480) - 批量32,3通道,320×480分辨率
  • 标签:(32, 320, 480) - 每个像素是0-20的类别ID

9. 定义损失函数

def loss(inputs, targets):
    return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)

代码说明

  • inputs: (N, 21, H, W) - 21个类别的概率图
  • targets: (N, H, W) - 每个像素的真实类别ID
  • 先计算每个像素的交叉熵,然后在空间维度取平均

10. 模型训练

num_epochs = 5
lr = 0.001
wd = 1e-3
devices = d2l.try_all_gpus()

print(f"\n开始训练,使用设备: {devices}")
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

print("训练完成!")

代码说明

  • 训练5个epoch,学习率0.001
  • 使用SGD优化器,权重衰减1e-3
  • 自动检测并使用所有可用的GPU
  • d2l.train_ch13封装了标准的训练流程

关键技术点总结

  1. 全卷积网络:去除全连接层,支持任意尺寸输入
  2. 编码器-解码器结构:ResNet18编码 + 转置卷积解码
  3. 双线性初始化:转置卷积权重初始化为双线性插值
  4. 逐像素分类:每个像素独立进行21分类
  5. 颜色映射:RGB标签图 → 类别ID图

这个实现展示了现代语义分割网络的核心思想,结合了迁移学习和端到端训练的优势。

posted @ 2025-11-07 18:46  学java的阿驴  阅读(8)  评论(0)    收藏  举报