【CVPR2026】Multi-Modal Image Fusion via Intervention-Stable Feature Learning

论文标题:Multi-Modal Image Fusion via Intervention-Stable Feature Learning
作者:Xue Wang, Zheng Guan, Wenhua Qian, Chengchao Wang, Runzhuo Ma
发表时间:2026 年 3 月(arXiv 预印本)
会议:CVPR 2026
code:github

一、论文核心贡献

1. 研究问题

这篇论文关注的不是“怎么把红外和可见光简单拼起来”,而是:
现有多模态融合方法大多基于统计相关性学习,容易把数据集里的伪相关当成真实互补关系,导致分布变化时性能下降。
论文明确指出,当前 MMIF 方法主要停留在 association 层面,缺少对跨模态依赖关系的主动检验。

2. 创新方法

论文的核心创新有两层:

第一,提出三种干预策略,用来主动测试模态关系是否真实稳定:

  • Complementary masking:对两个模态做不重叠遮挡,检验跨模态补偿能力
  • Random masking:对相同区域同时遮挡,检验局部充分性
  • Modality dropout:直接去掉一个模态,检验模态必要性

第二,提出 CFI(Causal Feature Integrator)
通过双向跨模态 attention + gate,优先保留在干预下仍稳定的重要特征,而不是只保留高相关但不稳定的特征。

3. 模型结构

4. 关键结论

干预式训练能让模型学到更稳定的跨模态依赖,在公开基准、目标检测、语义分割以及跨域医学融合上都表现更强。
摘要和实验部分都强调了其在 benchmark 和 downstream tasks 上的 SOTA 或接近 SOTA 表现。


二、关键代码

当前仓库代码中,核心是 test.py推理实现;README 说明了运行方式,训练细节主要来自论文正文。

1. 核心模块:CFI

这是整套方法最关键的实现。

class CFI(nn.Module):
    def __init__(self, dim: int, reduce: int = 8, q_chunk: int = 0):
        super().__init__()
        self.dim = dim
        self.reduce = max(1, int(reduce))
        self.q_chunk = int(q_chunk)

        # 1x1投影生成Q,K,V
        self.q_v = nn.Conv2d(dim, dim, 1)
        self.k_v = nn.Conv2d(dim, dim, 1)
        self.v_v = nn.Conv2d(dim, dim, 1)

        self.q_i = nn.Conv2d(dim, dim, 1)
        self.k_i = nn.Conv2d(dim, dim, 1)
        self.v_i = nn.Conv2d(dim, dim, 1)

        self.gate = nn.Sequential(
            nn.Conv2d(dim, dim // 2, 3, padding=1),
            nn.BatchNorm2d(dim//2),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(dim // 2, 1, 1),
            nn.Sigmoid()
        )

        # 融合后再refine
        self.refine = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim),
            nn.LeakyReLU(0.01, inplace=True),
        )

    def forward(self, f_vi, f_ir):
        # Q,K,V
        qv, kv, vv = self.q_v(f_vi), self.k_v(f_vi), self.v_v(f_vi)
        qi, ki, vi = self.q_i(f_ir), self.k_i(f_ir), self.v_i(f_ir)

        vi2ir = self._pooled_xattn(qv, ki, vi)
        ir2vi = self._pooled_xattn(qi, kv, vv)

        cross = (vi2ir + ir2vi)
        local = (f_vi + f_ir)

        g = (self.gate(cross))
        fused = g * cross + (1.0 - g) * local

        fused = self.refine(fused)
        return fused, g

这段代码对应论文里的核心思想:
先做双向跨模态交互,再用 gate 判断哪些区域更“稳定”,最后把 cross feature 和 local feature 自适应融合。


2. 关键逻辑流程:多尺度融合

网络整体是双编码器 + 多尺度 CFI + 解码输出。

class Network(nn.Module):
    def __init__(self, dim: int = 32):
        super().__init__()
        self.encoder_ir = Encoder(dim=dim)
        self.encoder_vis = Encoder(dim=dim)

        self.cfi3 = CFI(dim)
        self.cfi4 = CFI(dim)
        self.cfi5 = CFI(dim)

        self.final_decoder = nn.Sequential(
            nn.Conv2d(dim, 1, 3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, vi, ir):
        v1, v2, v3 = self.encoder_vis(vi)
        i1, i2, i3 = self.encoder_ir(ir)

        # ===== level 3 ( H/4) =====
        f3, g3 = self.cfi3(v3, i3)
        x3 = f3

        # ===== level 4 ( H/2) =====
        x4_up = F.interpolate(x3, size=v2.shape[-2:], mode='bilinear', align_corners=False)
        f4, g4 = self.cfi4(v2, i2)
        x4 = x4_up + f4

        # ===== level 5 ( H) =====
        x5_up = F.interpolate(x4, size=v1.shape[-2:], mode='bilinear', align_corners=False)
        f5, g5 = self.cfi5(v1, i1)
        x5 = x5_up + f5

        out = self.final_decoder(x5)
        return out, (x3, x4, x5), {'g3': g3, 'g4': g4, 'g5': g5}

一句话概括流程:

红外/可见光分别编码 → 在 3 个尺度上用 CFI 做稳定性融合 → 逐层上采样 → 输出融合图。


3. 调用示例:推理入口

README 说明直接运行 test.py 即可,融合结果保存在 ./Fused/

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Network().to(device)

model.load_state_dict(torch.load(r"./best.pth", map_location=device))
model.eval()

test_root = "./test symlink"
ir = prepare_data(os.path.join(test_root, "ir"))
vi = prepare_data1(os.path.join(test_root, "vi"))
save_path = os.path.join("./Fused", "test symlink")

with torch.no_grad():
    output, _, _ = model(Y, img_ir)

代码里的一个关键点是:

可见光图像先转成 YCrCb,只把 Y 通道送进网络做融合,最后再和原 Cr/Cb 拼回彩色结果。
这说明模型主要在做“亮度/结构融合”,颜色信息沿用可见光原图。


三、总结

这篇论文最值得记的就两点:

论文层面
它把多模态融合从“学相关性”推进到“学干预下仍稳定的特征”。

代码层面
最关键的不是 backbone,而是 CFI = 双向跨模态 attention + 稳定性 gate + 多尺度融合

posted @ 2026-04-22 21:07  可以解决问题不大  阅读(17)  评论(0)    收藏  举报