MLA学习

🔍 MLA vs 传统 Attention:详细计算对比解析


🎯 背景说明

Transformer 中最核心的模块之一是 Multi-Head Self-Attention (MHSA),其计算复杂度随上下文长度增长迅速。而 MLA(Multi-head Latent Attention)通过引入固定数量的 latent vector,大幅减少推理中的计算量和 KV 缓存量。


⚙️ 基本符号定义

符号 含义
N 序列长度(token 个数)
d 表征维度(如 768)
h 注意力头数(如 12)
L latent 向量数量(如 8)
Q, K, V Query、Key、Value 表征
C 输入 token 表示(shape: N × d)

🧠 一、传统 Multi-Head Self-Attention(MHSA)

计算流程

  1. 输入 token 表征 \(\mathbf{h}_i \in \mathbb{R}^d\)

  2. 通过线性层获得:

    \[\mathbf{q}_i = W^Q \mathbf{h}_i,\quad \mathbf{k}_i = W^K \mathbf{h}_i,\quad \mathbf{v}_i = W^V \mathbf{h}_i \]

  3. Attention Score 计算(所有 token 两两交互):

    \[\text{score}_{ij} = \frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d}} \]

  4. Softmax 得到 attention 权重:

    \[\alpha_{ij} = \text{softmax}_j(\text{score}_{ij}) \]

  5. 聚合:

    \[\text{output}_i = \sum_{j=1}^{N} \alpha_{ij} \cdot \mathbf{v}_j \]

时间与内存复杂度

  • 计算复杂度:\(\mathcal{O}(N^2 \cdot d)\)
  • KV 缓存复杂度(推理):\(\mathcal{O}(N \cdot d)\)

🧬 二、MLA(Multi-Head Latent Attention)

关键创新

  • token-to-token attention 替换为 latent-to-token attention
  • 避免 \(N \times N\) 注意力矩阵

计算流程

  1. 输入 token \(\mathbf{h}_i\) 映射为:

    \[\mathbf{q}_i^c,\; \mathbf{q}_i^p,\; \mathbf{k}_i^c,\; \mathbf{k}_i^p,\; \mathbf{v}_i = \text{LinearProj}(\mathbf{h}_i) \]

  2. 拼接内容和位置(RoPE 编码):

    \[\mathbf{q}_i = [\mathbf{q}_i^c;\, \text{RoPE}(\mathbf{q}_i^p)],\quad \mathbf{k}_i = [\mathbf{k}_i^c;\, \text{RoPE}(\mathbf{k}_i^p)] \]

  3. 定义 \(L\) 个 learnable latent vector \(\{\mathbf{c}_t^q, \mathbf{c}_t^k\}_{t=1}^L\)

  4. 每个 latent \(t\) 的 attention 计算:

    • Score 计算:

      \[\text{score}_{t,i} = \frac{(\mathbf{Q}_t)(\mathbf{k}_i)^\top}{\sqrt{d}} \]

    • 聚合:

      \[\mathbf{u}_t = \sum_{i=1}^{N} \text{softmax}_i(\text{score}_{t,i}) \cdot \mathbf{v}_i \]

  5. 多个 latent 头并行 → 拼接 → 输出:

    \[\text{MLA}(H) = \text{Concat}(\mathbf{u}_1, ..., \mathbf{u}_L)W^O \]

时间与内存复杂度

  • 计算复杂度:\(\mathcal{O}(L \cdot N \cdot d)\) (通常 \(L \ll N\)
  • KV 缓存复杂度(推理):\(\mathcal{O}(L \cdot d)\) (固定大小)

🔬 三、核心对比总结

项目 MHSA MLA
交互对象 Token-to-token Latent-to-token
注意力矩阵尺寸 \(N \times N\) \(L \times N\)
推理缓存 (KV Cache) \(N \cdot d\) \(L \cdot d\)
计算复杂度 \(\mathcal{O}(N^2 \cdot d)\) \(\mathcal{O}(L \cdot N \cdot d)\)
可并行性
长文本效率

✅ MLA 的优势

  • 更低计算成本:避免 \(N^2\) 计算
  • 极低 KV 缓存成本:推理速度显著提升
  • 适用于长文本:在长上下文场景下表现尤为出色
  • 抽象信息提取能力强:Latent vector 具备聚合全局 token 信息能力

📚 参考资料

  • DeepSeek-VL 技术报告(2024)
  • Attention is All You Need(Vaswani et al., 2017)
  • Rotary Positional Embedding(RoFormer)
posted @ 2025-06-05 21:57  咖啡加油条  阅读(69)  评论(0)    收藏  举报