MLA学习
🔍 MLA vs 传统 Attention:详细计算对比解析
🎯 背景说明
Transformer 中最核心的模块之一是 Multi-Head Self-Attention (MHSA),其计算复杂度随上下文长度增长迅速。而 MLA(Multi-head Latent Attention)通过引入固定数量的 latent vector,大幅减少推理中的计算量和 KV 缓存量。
⚙️ 基本符号定义
| 符号 | 含义 | 
|---|---|
| N | 序列长度(token 个数) | 
| d | 表征维度(如 768) | 
| h | 注意力头数(如 12) | 
| L | latent 向量数量(如 8) | 
| Q, K, V | Query、Key、Value 表征 | 
| C | 输入 token 表示(shape: N × d) | 
🧠 一、传统 Multi-Head Self-Attention(MHSA)
计算流程
- 
输入 token 表征 \(\mathbf{h}_i \in \mathbb{R}^d\)
 - 
通过线性层获得:
\[\mathbf{q}_i = W^Q \mathbf{h}_i,\quad \mathbf{k}_i = W^K \mathbf{h}_i,\quad \mathbf{v}_i = W^V \mathbf{h}_i \] - 
Attention Score 计算(所有 token 两两交互):
\[\text{score}_{ij} = \frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d}} \] - 
Softmax 得到 attention 权重:
\[\alpha_{ij} = \text{softmax}_j(\text{score}_{ij}) \] - 
聚合:
\[\text{output}_i = \sum_{j=1}^{N} \alpha_{ij} \cdot \mathbf{v}_j \] 
时间与内存复杂度
- 计算复杂度:\(\mathcal{O}(N^2 \cdot d)\)
 - KV 缓存复杂度(推理):\(\mathcal{O}(N \cdot d)\)
 
🧬 二、MLA(Multi-Head Latent Attention)
关键创新
- 将 token-to-token attention 替换为 latent-to-token attention
 - 避免 \(N \times N\) 注意力矩阵
 
计算流程
- 
输入 token \(\mathbf{h}_i\) 映射为:
\[\mathbf{q}_i^c,\; \mathbf{q}_i^p,\; \mathbf{k}_i^c,\; \mathbf{k}_i^p,\; \mathbf{v}_i = \text{LinearProj}(\mathbf{h}_i) \] - 
拼接内容和位置(RoPE 编码):
\[\mathbf{q}_i = [\mathbf{q}_i^c;\, \text{RoPE}(\mathbf{q}_i^p)],\quad \mathbf{k}_i = [\mathbf{k}_i^c;\, \text{RoPE}(\mathbf{k}_i^p)] \] - 
定义 \(L\) 个 learnable latent vector \(\{\mathbf{c}_t^q, \mathbf{c}_t^k\}_{t=1}^L\)
 - 
每个 latent \(t\) 的 attention 计算:
- 
Score 计算:
\[\text{score}_{t,i} = \frac{(\mathbf{Q}_t)(\mathbf{k}_i)^\top}{\sqrt{d}} \] - 
聚合:
\[\mathbf{u}_t = \sum_{i=1}^{N} \text{softmax}_i(\text{score}_{t,i}) \cdot \mathbf{v}_i \] 
 - 
 - 
多个 latent 头并行 → 拼接 → 输出:
\[\text{MLA}(H) = \text{Concat}(\mathbf{u}_1, ..., \mathbf{u}_L)W^O \] 
时间与内存复杂度
- 计算复杂度:\(\mathcal{O}(L \cdot N \cdot d)\) (通常 \(L \ll N\))
 - KV 缓存复杂度(推理):\(\mathcal{O}(L \cdot d)\) (固定大小)
 
🔬 三、核心对比总结
| 项目 | MHSA | MLA | 
|---|---|---|
| 交互对象 | Token-to-token | Latent-to-token | 
| 注意力矩阵尺寸 | \(N \times N\) | \(L \times N\) | 
| 推理缓存 (KV Cache) | \(N \cdot d\) | \(L \cdot d\) | 
| 计算复杂度 | \(\mathcal{O}(N^2 \cdot d)\) | \(\mathcal{O}(L \cdot N \cdot d)\) | 
| 可并行性 | 高 | 高 | 
| 长文本效率 | 低 | 高 | 
✅ MLA 的优势
- 更低计算成本:避免 \(N^2\) 计算
 - 极低 KV 缓存成本:推理速度显著提升
 - 适用于长文本:在长上下文场景下表现尤为出色
 - 抽象信息提取能力强:Latent vector 具备聚合全局 token 信息能力
 
📚 参考资料
- DeepSeek-VL 技术报告(2024)
 - Attention is All You Need(Vaswani et al., 2017)
 - Rotary Positional Embedding(RoFormer)
 
                    
                
                
            
        
浙公网安备 33010602011771号