【CVPR2026】Multi-Modal Image Fusion via Intervention-Stable Feature Learning
论文标题:Multi-Modal Image Fusion via Intervention-Stable Feature Learning
作者:Xue Wang, Zheng Guan, Wenhua Qian, Chengchao Wang, Runzhuo Ma
发表时间:2026 年 3 月(arXiv 预印本)
会议:CVPR 2026
code:github
一、论文核心贡献
1. 研究问题
这篇论文关注的不是“怎么把红外和可见光简单拼起来”,而是:
现有多模态融合方法大多基于统计相关性学习,容易把数据集里的伪相关当成真实互补关系,导致分布变化时性能下降。
论文明确指出,当前 MMIF 方法主要停留在 association 层面,缺少对跨模态依赖关系的主动检验。

2. 创新方法
论文的核心创新有两层:
第一,提出三种干预策略,用来主动测试模态关系是否真实稳定:
- Complementary masking:对两个模态做不重叠遮挡,检验跨模态补偿能力
- Random masking:对相同区域同时遮挡,检验局部充分性
- Modality dropout:直接去掉一个模态,检验模态必要性
第二,提出 CFI(Causal Feature Integrator):
通过双向跨模态 attention + gate,优先保留在干预下仍稳定的重要特征,而不是只保留高相关但不稳定的特征。
3. 模型结构
4. 关键结论
干预式训练能让模型学到更稳定的跨模态依赖,在公开基准、目标检测、语义分割以及跨域医学融合上都表现更强。
摘要和实验部分都强调了其在 benchmark 和 downstream tasks 上的 SOTA 或接近 SOTA 表现。
二、关键代码
当前仓库代码中,核心是
test.py的推理实现;README 说明了运行方式,训练细节主要来自论文正文。
1. 核心模块:CFI
这是整套方法最关键的实现。
class CFI(nn.Module):
def __init__(self, dim: int, reduce: int = 8, q_chunk: int = 0):
super().__init__()
self.dim = dim
self.reduce = max(1, int(reduce))
self.q_chunk = int(q_chunk)
# 1x1投影生成Q,K,V
self.q_v = nn.Conv2d(dim, dim, 1)
self.k_v = nn.Conv2d(dim, dim, 1)
self.v_v = nn.Conv2d(dim, dim, 1)
self.q_i = nn.Conv2d(dim, dim, 1)
self.k_i = nn.Conv2d(dim, dim, 1)
self.v_i = nn.Conv2d(dim, dim, 1)
self.gate = nn.Sequential(
nn.Conv2d(dim, dim // 2, 3, padding=1),
nn.BatchNorm2d(dim//2),
nn.LeakyReLU(0.01, inplace=True),
nn.Conv2d(dim // 2, 1, 1),
nn.Sigmoid()
)
# 融合后再refine
self.refine = nn.Sequential(
nn.Conv2d(dim, dim, 3, padding=1),
nn.BatchNorm2d(dim),
nn.LeakyReLU(0.01, inplace=True),
)
def forward(self, f_vi, f_ir):
# Q,K,V
qv, kv, vv = self.q_v(f_vi), self.k_v(f_vi), self.v_v(f_vi)
qi, ki, vi = self.q_i(f_ir), self.k_i(f_ir), self.v_i(f_ir)
vi2ir = self._pooled_xattn(qv, ki, vi)
ir2vi = self._pooled_xattn(qi, kv, vv)
cross = (vi2ir + ir2vi)
local = (f_vi + f_ir)
g = (self.gate(cross))
fused = g * cross + (1.0 - g) * local
fused = self.refine(fused)
return fused, g
这段代码对应论文里的核心思想:
先做双向跨模态交互,再用 gate 判断哪些区域更“稳定”,最后把 cross feature 和 local feature 自适应融合。
2. 关键逻辑流程:多尺度融合
网络整体是双编码器 + 多尺度 CFI + 解码输出。
class Network(nn.Module):
def __init__(self, dim: int = 32):
super().__init__()
self.encoder_ir = Encoder(dim=dim)
self.encoder_vis = Encoder(dim=dim)
self.cfi3 = CFI(dim)
self.cfi4 = CFI(dim)
self.cfi5 = CFI(dim)
self.final_decoder = nn.Sequential(
nn.Conv2d(dim, 1, 3, padding=1),
nn.Sigmoid()
)
def forward(self, vi, ir):
v1, v2, v3 = self.encoder_vis(vi)
i1, i2, i3 = self.encoder_ir(ir)
# ===== level 3 ( H/4) =====
f3, g3 = self.cfi3(v3, i3)
x3 = f3
# ===== level 4 ( H/2) =====
x4_up = F.interpolate(x3, size=v2.shape[-2:], mode='bilinear', align_corners=False)
f4, g4 = self.cfi4(v2, i2)
x4 = x4_up + f4
# ===== level 5 ( H) =====
x5_up = F.interpolate(x4, size=v1.shape[-2:], mode='bilinear', align_corners=False)
f5, g5 = self.cfi5(v1, i1)
x5 = x5_up + f5
out = self.final_decoder(x5)
return out, (x3, x4, x5), {'g3': g3, 'g4': g4, 'g5': g5}
一句话概括流程:
红外/可见光分别编码 → 在 3 个尺度上用 CFI 做稳定性融合 → 逐层上采样 → 输出融合图。
3. 调用示例:推理入口
README 说明直接运行 test.py 即可,融合结果保存在 ./Fused/。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Network().to(device)
model.load_state_dict(torch.load(r"./best.pth", map_location=device))
model.eval()
test_root = "./test symlink"
ir = prepare_data(os.path.join(test_root, "ir"))
vi = prepare_data1(os.path.join(test_root, "vi"))
save_path = os.path.join("./Fused", "test symlink")
with torch.no_grad():
output, _, _ = model(Y, img_ir)
代码里的一个关键点是:
可见光图像先转成 YCrCb,只把 Y 通道送进网络做融合,最后再和原 Cr/Cb 拼回彩色结果。
这说明模型主要在做“亮度/结构融合”,颜色信息沿用可见光原图。
三、总结
这篇论文最值得记的就两点:
论文层面:
它把多模态融合从“学相关性”推进到“学干预下仍稳定的特征”。
代码层面:
最关键的不是 backbone,而是 CFI = 双向跨模态 attention + 稳定性 gate + 多尺度融合。

浙公网安备 33010602011771号