论文阅读 | MxT

Mamba × Transformer

Durham arxiv

Basis

  1. State Space Models (SSMs) - Mamba 基础理论,源于控制理论
  2. Mamba - 选择性 SSMs 之一,与 Transformer 在长距离学习上优势互补。
    • Mamba: 长距离,pixel-wise (因其来源于序列建模)
    • Transformer: 长距离,块间 (global interactions between localized patches)
  3. Overlapped Conv - 重叠卷积
  4. SRSA (Spatial Reduced Self-Att) 通过整合空间信息来减少 Transformer 中使用的自注意力机制的维度。常规 Transformer 的计算复杂度随输入序列长度二次增长。
  5. LayerNorm (层归一化)
  6. 1 × 1 Conv (到底什么用 ?)
  7. Cosine Positional Embedding

Intro

  1. Inpainting 需要平衡局部纹理复制和全局上下文理解。“This technique requires a precise balance of local texture replication and global contextual understanding to ensure the restored image integrates seamlessly with its surroundings.”
  2. CNN 感受野问题。只能捕获局部模式,无法处理更长的上下文关系。“Traditional methods using Convolutional Neural Networks (CNNs) are effective at capturing local patterns but often struggle with broader contextual relationships due to the limited receptive fields.” —— 这个问题在 GSDM 也提到过,即「Conv 无法在对角位置传递信息」他们的解决方案是空洞卷积。
  3. Transformers 改善了对全局信息的处理 “leveraging their ability to understand global interactions.” 然而,面临计算复杂性和丢失细粒度信息问题。“face computational inefficiencies and struggle to maintain fine-grained details.”
  4. 提出 M×T. 提到,Mamba 擅长以线性计算成本高效处理长序列,使其成为处理长尺度数据交互的 Transformer 的理想补充。“Mamba is adept at efficiently processing long sequences with linear computational costs, making it an ideal complement to the transformer for handling long-scale data interactions.”
  5. 实现 pixel-wise + patch-wise 两级交互学习,增强模型生成图片的质量和上下文准确性 (Context-Encoder 提过类似内容,语意连贯 + 风格合适)

Preliminary

State Space Models (SSM)

SSM 是线性时不变系统。其通过一个隐状态 \(h(t) \in \mathbb{R}^N\) 将一维输入序列 \(x(t) \in \mathbb{R}\) 映射到响应 \(y(t) \in \mathbb{R}\).

Mamba

选择性 SSMs,引入门控选择机制 (gated-selective mechanism) 基于当前状态来传播或消除选中的信息,提高内容推理 (content-reasoning) 性能。

Mamba 是时变的。

Methodology

Hybrid Block

M × T Pipeline 由 7 个混合块 (HB) 构成一个 U-Net 结构。原始的遮挡图 \(I_{masked} \in \mathbb{R}^{H \times W \times C (=3)}\) 与掩模 \(M \in \mathbb{R}^{H \times W \times 1}\) 进行 concat 操作,得到输入 \(I_{in}\),通过卷积进行嵌入,而后输入 (feed into) 7 个 HBs,共 3 次下采样和 3 次上采样。最后,由一个卷积层对输出 \(I_{out}\) 进行投影。

Hybrid Module

每个 HB 由 \(n\) 个 HMs 构成。每个 HM 包含一个 Transformer 块 (SRSA),一个 Mamba 块,用于捕获长期依赖。以及一个 上下文广播前馈网络 (CBFN) 用于增强局部上下文并控制数据流的一致性。

SRSA

兼顾全局相关性和局部上下文细节。

\[\begin{aligned} \text{Input Feature: }&F \\ \tag{1} \text{Output Local Features: }&F' \end{aligned}\]

局部特征 \(F’\) 提取过程如下。

\[F’ \leftarrow Conv_{3 \times 3} \leftarrow Conv_{1 \times 1} \leftarrow LayerNorm(F) \tag{2} \]

沿通道维度 (channel dimensions) 切分 \(F’\) 得到 \(Q,K,V\). 参考 PVTv2 将 \(K,V\) 平均池化到固定维度。降低计算量。

\[\begin{aligned} K', V' &= \text{AvgPool}(K),\text{AvgPool}(V) \\ \tag{3} \text{Att} &= \text{softmax}(K' \cdot Q) \\ \text{Initial Output: }F'' &= Att \times V \end{aligned}\]

为了增强局部上下文,进行一个 local enhance 操作 (实际上是 Conv)

\[\begin{aligned} LE(V) &= Conv_{3 \times 3}(V) \\ \tag{4} Output_{SRSA} &= LE(V) + F'' \end{aligned}\]

Mamba + PE

Mamba 用于对扁平 (flattened,多维特征展平到一维) 特征建模,捕获像素级的长期依赖。用自注意力处理则成本太高。

\[F' \leftarrow \text{Tr}\leftarrow F_{B,C,L}\leftarrow\text{Reshape} \leftarrow F_{B,C,H,W} \tag{5} \]

PE (Positional Embedding) 用于增强 Mamba 保持位置感知 (?) 的能力。

\[F'' = F' + PE(L) \tag{6} \]

Mamba 简介

Body-branch Mamba 架构的主干部分,负责处理输入序列的信息提取和编码。

Gate-branch 是 Mamba 架构的控制门部分,主要用于控制信息的流动和筛选。它通过学习适当的门控机制,来调整 body-branch 中的信息传递和处理,以提高模型的表达能力和泛化能力。

得到 \(F''\) 可计算 \(F_{body}\)\(F_{shaped}\),最终得到 \(Output_{mamba}\).

Context Broadcasting Feed-forward Network (CBFN)

基于 GDFN 改进。在 GDFN 增强局部上下文的基础上,添加一个全局处理 post-GDFN.

\[F_{CBFN} = F_{GDFN} + broadcast(\mu) \tag{7} \]

(这个动作是否相当于增强低频)

Loss

  • \(\mathcal{L}_1\)
  • \(\mathcal{L}_{style}\): style loss
  • \(\mathcal{L}_{perceptual}\): 感知损失
  • \(\mathcal{L}_{adv}\): 对抗损失 (见过)

加权环节。不知道怎么加的。

Summary

写得很清楚,难得看懂

posted @ 2025-04-18 14:29  Miya_Official  阅读(46)  评论(0)    收藏  举报