Fork me on GitHub

状态空间模型(State Space Model, SSM)

要把 S4(Structured State Space Sequence model) 放进深度神经网络里,本质上是把一个经典的 状态空间模型(State Space Model, SSM) 变成一个“可训练的序列层(layer)”。但很多人一开始容易误解:一层 ≠ 一个时间步,而是一层实现整个时间序列的状态演化

我们从数学形式 → 神经网络映射,一步一步拆。


一、S4 的核心状态空间方程

标准连续时间 SSM:

\[\begin{aligned} \frac{dx(t)}{dt} &= Ax(t) + Bu(t) \\ y(t) &= Cx(t) + Du(t) \end{aligned} \]

离散化后(真正用于神经网络):

\[x_{t+1} = \bar{A}x_t + \bar{B}u_t \]

\[y_t = Cx_t + Du_t \]

其中:

  • (x_t):隐藏状态(memory)
  • (u_t):输入(token embedding / feature)
  • (y_t):输出
  • (A,B,C,D):可学习参数

二、S4 在神经网络中的结构映射

1️⃣ 一层 S4 ≈ 一个“序列算子”

在 Transformer 里:

  • 一层 = attention + MLP

在 S4 里:

  • 一层 = SSM卷积 + 非线性

更具体:

\[y = \text{SSM}(u) = K * u \]

其中卷积核 (K) 来自状态空间:

\[K_t = C \bar{A}^t \bar{B} \]

关键点:
S4层 = 用状态空间生成的长卷积核


2️⃣ 参数对应关系(重点)

你问的核心:神经网络 vs 状态方程参数

状态空间 神经网络中的角色
(A) 状态转移矩阵(决定“记忆如何衰减/传播”)
(B) 输入投影(input → state)
(C) 状态读出(state → output)
(D) skip connection(类似 residual)

对应关系可以理解为:

input u_t
   ↓ (B)
state x_t
   ↓ (A dynamics across time)
state sequence
   ↓ (C)
output y_t
   + D u_t(残差)

三、时间维 vs 网络层:很多人搞反的地方

❗关键结论:

  • 时间步 t → 状态方程递推
  • 网络层 l → 参数不同的一组 (A, B, C, D)

类比 Transformer:

Transformer S4
token 维度 时间 t
layer 维度 不同 SSM 参数

👉 所以:

  • 不是一层对应一个时间步

  • 而是:

    • 一层处理整个序列
    • 多层堆叠 = 多尺度动态建模

四、前向传播(Forward Pass)

方式一:递归(理论形式)

\[x_{t+1} = \bar{A}x_t + \bar{B}u_t \]

逐步计算(类似 RNN):

for t:
    x[t+1] = A x[t] + B u[t]
    y[t] = C x[t] + D u[t]

方式二:S4 实际实现(关键创新)

用卷积替代递归

\[y = K * u \]

优点:

  • 并行计算(不像 RNN)
  • 长序列高效

✔ 前向过程对应关系

操作 状态空间解释
输入 embedding (u_t)
S4 kernel convolution 展开后的 (A,B,C) 动态
residual (D u_t)
activation 非线性(SSM外加)

五、反向传播(Backward Pass)

反向传播本质就是:

  • 对 (A,B,C,D) 求梯度

但关键难点是:

A 是高维结构化矩阵(HiPPO + diagonal + low-rank)

S4 的 trick:

  1. (A) 被结构化(HiPPO 矩阵)
  2. 用频域(FFT)计算卷积
  3. 梯度通过卷积核 (K) 反传

✔ 反向传播对应关系

梯度路径 实际含义
∂L/∂y → ∂L/∂K 卷积核梯度
∂L/∂K → ∂L/∂(A,B,C) 状态空间参数更新
∂L/∂u 传回上一层

六、一句话总结映射关系

如果压缩成最本质的理解:

S4 = 用状态空间系统参数化的长卷积层

  • 层 = 一组 (A,B,C,D)
  • 时间 = 状态递推展开
  • 前向 = 卷积(隐式状态传播)
  • 反向 = 对动态系统参数求导

七、再给你一个工程视角

实际代码中一层 S4 大概是:

def S4Layer(u):
    K = compute_kernel(A, B, C)   # 频域 or recurrence
    y = conv(u, K)                # 核心计算
    return y + D * u              # residual

然后堆叠:

x -> S4 -> GELU -> S4 -> GELU -> ...

八、一个容易踩的坑

如果你把 S4 理解成:

“类似 RNN 的 step-by-step 状态更新”

那你会错过它最重要的点:

它本质是“可并行的全局卷积模型”,不是RNN


Mamba

先把一个常见误解掐掉:Mamba 不是简单“更快的 S4”。它在“是否线性时不变(LTI)”这个根上动了刀,这直接改变了表达能力、并行方式和工程实现。

下面按“本质差异”来拆,而不是表面指标。


一、最核心差异:LTI vs 输入依赖(Selective)

S4:线性时不变系统(LTI)

S4 的状态方程(离散):

\[x_{t+1} = \bar{A}x_t + \bar{B}u_t,\quad y_t = Cx_t + Du_t \]

👉 关键点:

  • (A,B,C,D) 对所有时间步固定
  • 系统是 线性 + 时不变(LTI)

因此可以展开为卷积:

\[y = K * u \]


Mamba:输入驱动的动态系统(Selective SSM)

Mamba 做了一个关键改动:

\[x_{t+1} = A(u_t),x_t + B(u_t)u_t \]

参数变成:

  • 依赖输入 (u_t)(通过 gating / projection 生成)

✔ 本质区别一句话

  • S4:同一个滤波器处理所有 token
  • Mamba:每个 token 都有自己的滤波器

二、表达能力差异(为什么 Mamba 更强)

S4 的限制

由于 LTI:

  • 对所有输入使用同一个 kernel
  • 无法根据上下文动态调整

类似:

一个固定 attention pattern


Mamba 的能力

Mamba 引入:

  • selective scan(选择性状态更新)
  • 类似 attention 的“条件计算”

可以实现:

  • 内容依赖(content-based reasoning)
  • 动态记忆控制(类似 LSTM gate)

类比

模型 类比
S4 固定卷积核
Mamba 动态卷积核(input-conditioned)
Transformer fully dynamic attention

三、计算方式差异(决定工程性能)

S4:卷积(FFT)

  • 通过 kernel (K)
  • 用 FFT 加速

复杂度:

  • 训练:(O(n \log n))
  • 推理:不适合 streaming(要卷积)

Mamba:scan(并行前缀)

Mamba不能用卷积,因为:

系统不再是 LTI

所以它用:

  • parallel scan(类似 prefix sum)
  • GPU 友好的 recurrence

复杂度:

  • 训练:(O(n))
  • 推理:天然 streaming

✔ 工程本质差异

维度 S4 Mamba
核心算子 卷积 scan
并行方式 FFT prefix-scan
streaming
kernel 是否固定

四、信息选择机制(记忆 vs 忘记)

这是理解 Mamba 的关键。

S4

  • 所有信息都进入状态
  • 没有“选择性遗忘”

问题:

  • 容易被噪声污染
  • 难以做 reasoning

Mamba

引入 gating:

  • 控制:

    • 写入多少
    • 保留多少

类似:

  • LSTM 的 forget gate
  • Transformer 的 attention mask

✔ 本质

Mamba = SSM + attention-like selection


五、训练稳定性与长序列

S4

优点:

  • 数学上稳定(HiPPO)
  • 长序列建模强

缺点:

  • 不够灵活

Mamba

优点:

  • 同样长序列能力
  • 更强表达能力

代价:

  • 训练更复杂
  • 更依赖实现(kernel fusion)

六、结构层面对比

S4 block

x → SSM(conv) → FFN → residual

Mamba block

x 
 → input projection
 → selective SSM (scan)
 → output projection
 → residual

多了:

  • input-dependent 参数生成
  • gating

七、你可以这样理解演化路径

这是最有价值的视角

RNN  →  S4(线性SSM,可并行)
     →  Mamba(非线性SSM,带选择)
     →  Transformer(完全动态)

八、一个更“底层”的本质总结

如果你只记一个结论:

S4 的本质是“结构化长卷积”
Mamba 的本质是“可学习的动态状态机”


九、为什么现在趋势是 Mamba

不是因为 S4 不好,而是:

  • S4:擅长记忆(memory)
  • Mamba:既能记忆,又能选择(reasoning)

而大模型真正需要的是:

memory + selectivity + scalability


十、给你一个工程判断建议

如果你在做系统设计:

  • 长序列建模(信号 / 时间序列)
    S4 仍然很强

  • LLM / agent / reasoning
    优先 Mamba

posted @ 2026-04-20 16:10  stardsd  阅读(41)  评论(0)    收藏  举报