注意力机制 MHA、MQA和GQA
本文结合chatgpt生成
1. MHA:Multi-Head Attention,多头注意力
标准多头注意力里,每个 head 都有自己独立的 Q、K、V 投影。
假设:
- hidden size =
d_model - head 数 =
h - 每个 head 维度 =
d_head - 通常
d_model = h * d_head
那么 MHA 中,X维度为(batch, seq_len, d_model),W_QKV维度为(batch, d_model, d_model):
Q = X W_Q
K = X W_K
V = X W_V
然后切成 h 个 head:
Q: [batch, seq_len, h, d_head]
K: [batch, seq_len, h, d_head]
V: [batch, seq_len, h, d_head]
虽然有h个注意力头,但第 i 个 query head 只和第 i 个 key/value head 做 attention:
head_i = softmax(Q_i K_i^T / sqrt(d_head)) V_i
所以 MHA 的结构,各个头的数量是:
Q heads: h
K heads: h
V heads: h
也就是每个 attention head 都有独立的 K/V 表示。
优点
表达能力强。每个 head 可以学习不同的匹配模式、上下文选择方式和信息子空间。
缺点
推理时 KV cache 很大。
自回归生成时,每生成一个 token,都要缓存所有层的 K/V。MHA 的 KV cache 大小大致正比于:
num_layers * seq_len * h * d_head * 2
这里的 2 是 K 和 V。
长上下文和大 batch 推理时,KV cache 会成为显存瓶颈。
2. MQA:Multi-Query Attention,多查询注意力
MQA 的思路是:Query 仍然有多个 head,但所有 Query heads 共享同一组 K/V head。
结构变成:
Q heads: h
K heads: 1
V heads: 1
shape 大致是:
Q: [batch, seq_len, h, d_head]
K: [batch, seq_len, 1, d_head]
V: [batch, seq_len, 1, d_head]
每个 query head 都拿自己的 Q 去和同一份 K/V 做 attention:
head_i = softmax(Q_i K_shared^T / sqrt(d_head)) V_shared
直观上,MHA 是:
Q1 -> K1,V1
Q2 -> K2,V2
Q3 -> K3,V3
...
Qh -> Kh,Vh
MQA 是:
Q1 -> K,V
Q2 -> K,V
Q3 -> K,V
...
Qh -> K,V
优点
KV cache 大幅下降。
原来 MHA 要缓存 h 组 K/V,现在只缓存 1 组。理论上 KV cache 减少约 h 倍。
比如 h = 32,那么 MQA 的 KV cache 约为 MHA 的 1/32。
这对大模型长文本推理非常重要,因为生成阶段很多时候不是算力不够,而是显存带宽和 KV cache 读写成为瓶颈。
缺点
表达能力下降。
所有 query heads 共享同一套 K/V,意味着不同 head 虽然可以用不同 Query 去“问问题”,但它们看到的 key/value 表示空间是一样的。这样会损失一部分多头注意力的多样性。
所以 MQA 通常推理快、省显存,但模型质量可能比 MHA 略差,尤其是在大模型或复杂任务上。
3. GQA:Grouped-Query Attention,分组查询注意力
GQA 是 MHA 和 MQA 的折中。
它把 Query heads 分成若干组,每一组 Query heads 共享一组 K/V。
假设:
Q heads = h
KV heads = g
其中:
1 < g < h
那么结构是:
Q: [batch, seq_len, h, d_head]
K: [batch, seq_len, g, d_head]
V: [batch, seq_len, g, d_head]
每 h / g 个 Query heads 共享同一个 KV head。
例如:
h = 32
g = 8
那么每 4 个 Q heads 共享一组 K/V:
Q1, Q2, Q3, Q4 -> K1, V1
Q5, Q6, Q7, Q8 -> K2, V2
...
Q29, Q30, Q31, Q32 -> K8, V8
和 MHA/MQA 的关系
GQA 可以看作一个连续谱:
MHA: g = h
GQA: 1 < g < h
MQA: g = 1
也就是说:
MHA = 每个 Q head 独享 K/V
MQA = 所有 Q head 共享一组 K/V
GQA = 一组 Q heads 共享一组 K/V
4. 三者的核心区别
| 方法 | Q head 数 | KV head 数 | KV cache | 表达能力 | 推理速度 |
|---|---|---|---|---|---|
| MHA | h | h | 最大 | 最强 | 较慢 |
| GQA | h | g | 中等 | 接近 MHA | 较快 |
| MQA | h | 1 | 最小 | 相对较弱 | 最快 |
其中:
MHA: num_kv_heads = num_attention_heads
GQA: 1 < num_kv_heads < num_attention_heads
MQA: num_kv_heads = 1
在很多模型配置里,你会看到类似:
"num_attention_heads": 32,
"num_key_value_heads": 8
这就是 GQA。
如果是:
"num_attention_heads": 32,
"num_key_value_heads": 32
就是 MHA。
如果是:
"num_attention_heads": 32,
"num_key_value_heads": 1
就是 MQA。
5. 为什么主要影响推理,而不是训练?
严格说训练也会受影响,但 MQA/GQA 最重要的收益在 自回归推理。
生成第 t 个 token 时,当前 token 的 Q 只来自当前 token,但 K/V 来自历史所有 token,并且历史 K/V 会被缓存:
当前步:
Q_current: 只算当前 token
K_cache, V_cache: 读历史所有 token
MHA 的历史 K/V 是:
[batch, past_seq_len, h, d_head]
GQA 是:
[batch, past_seq_len, g, d_head]
MQA 是:
[batch, past_seq_len, 1, d_head]
所以长上下文生成时,MQA/GQA 可以显著减少 KV cache 占用和显存带宽压力。
训练阶段通常是全序列并行计算 attention,训练时整段文本一次性算完,不需要为每个生成步反复读历史 K/V,不像推理那样一 token 一 token 读取 KV cache,因此收益没有推理阶段那么突出。
6. 一个直观类比
可以把 attention 想成很多个“提问者”在查资料。
MHA
每个提问者都有自己的索引系统和资料库:
Q1 查 K1/V1
Q2 查 K2/V2
Q3 查 K3/V3
信息组织最丰富,但资料库最多,存储开销大。
MQA
所有提问者共享同一个索引系统和资料库:
Q1 查 K/V
Q2 查 K/V
Q3 查 K/V
最省空间,但所有人查的是同一套资料结构。
GQA
几个提问者共享一套资料库:
Q1, Q2, Q3, Q4 查 K1/V1
Q5, Q6, Q7, Q8 查 K2/V2
比 MQA 更有表达力,比 MHA 更省显存。
7. 实际模型为什么喜欢 GQA?
因为 GQA 的性价比通常最好。
MHA 质量好,但 KV cache 太大。
MQA 速度快、省显存,但质量可能掉得比较明显。
GQA 在两者之间,能显著减少 KV cache,同时保留相当多的多头多样性。
所以现代大语言模型经常采用 GQA,尤其是长上下文模型和需要高吞吐推理的模型。
一句话总结:
MHA 是每个 query head 配一套 K/V;MQA 是所有 query heads 共用一套 K/V;GQA 是若干个 query heads 共用一套 K/V。GQA 是 MHA 和 MQA 之间的工程折中。

浙公网安备 33010602011771号