注意力机制 MHA、MQA和GQA

本文结合chatgpt生成

1. MHA:Multi-Head Attention,多头注意力

标准多头注意力里,每个 head 都有自己独立的 Q、K、V 投影。

假设:

  • hidden size = d_model
  • head 数 = h
  • 每个 head 维度 = d_head
  • 通常 d_model = h * d_head

那么 MHA 中,X维度为(batch, seq_len, d_model),W_QKV维度为(batch, d_model, d_model):

Q = X W_Q
K = X W_K
V = X W_V

然后切成 h 个 head:

Q: [batch, seq_len, h, d_head]
K: [batch, seq_len, h, d_head]
V: [batch, seq_len, h, d_head]

虽然有h个注意力头,但第 i 个 query head 只和第 i 个 key/value head 做 attention:

head_i = softmax(Q_i K_i^T / sqrt(d_head)) V_i

所以 MHA 的结构,各个头的数量是:

Q heads: h
K heads: h
V heads: h

也就是每个 attention head 都有独立的 K/V 表示。

优点

表达能力强。每个 head 可以学习不同的匹配模式、上下文选择方式和信息子空间。

缺点

推理时 KV cache 很大。

自回归生成时,每生成一个 token,都要缓存所有层的 K/V。MHA 的 KV cache 大小大致正比于:

num_layers * seq_len * h * d_head * 2

这里的 2 是 K 和 V。

长上下文和大 batch 推理时,KV cache 会成为显存瓶颈。


2. MQA:Multi-Query Attention,多查询注意力

MQA 的思路是:Query 仍然有多个 head,但所有 Query heads 共享同一组 K/V head。

结构变成:

Q heads: h
K heads: 1
V heads: 1

shape 大致是:

Q: [batch, seq_len, h, d_head]
K: [batch, seq_len, 1, d_head]
V: [batch, seq_len, 1, d_head]

每个 query head 都拿自己的 Q 去和同一份 K/V 做 attention:

head_i = softmax(Q_i K_shared^T / sqrt(d_head)) V_shared

直观上,MHA 是:

Q1 -> K1,V1
Q2 -> K2,V2
Q3 -> K3,V3
...
Qh -> Kh,Vh

MQA 是:

Q1 -> K,V
Q2 -> K,V
Q3 -> K,V
...
Qh -> K,V

优点

KV cache 大幅下降。

原来 MHA 要缓存 h 组 K/V,现在只缓存 1 组。理论上 KV cache 减少约 h 倍。

比如 h = 32,那么 MQA 的 KV cache 约为 MHA 的 1/32

这对大模型长文本推理非常重要,因为生成阶段很多时候不是算力不够,而是显存带宽和 KV cache 读写成为瓶颈。

缺点

表达能力下降。

所有 query heads 共享同一套 K/V,意味着不同 head 虽然可以用不同 Query 去“问问题”,但它们看到的 key/value 表示空间是一样的。这样会损失一部分多头注意力的多样性。

所以 MQA 通常推理快、省显存,但模型质量可能比 MHA 略差,尤其是在大模型或复杂任务上。


3. GQA:Grouped-Query Attention,分组查询注意力

GQA 是 MHA 和 MQA 的折中。

它把 Query heads 分成若干组,每一组 Query heads 共享一组 K/V。

假设:

Q heads = h
KV heads = g

其中:

1 < g < h

那么结构是:

Q: [batch, seq_len, h, d_head]
K: [batch, seq_len, g, d_head]
V: [batch, seq_len, g, d_head]

h / g 个 Query heads 共享同一个 KV head。

例如:

h = 32
g = 8

那么每 4 个 Q heads 共享一组 K/V:

Q1, Q2, Q3, Q4       -> K1, V1
Q5, Q6, Q7, Q8       -> K2, V2
...
Q29, Q30, Q31, Q32   -> K8, V8

和 MHA/MQA 的关系

GQA 可以看作一个连续谱:

MHA: g = h
GQA: 1 < g < h
MQA: g = 1

也就是说:

MHA = 每个 Q head 独享 K/V
MQA = 所有 Q head 共享一组 K/V
GQA = 一组 Q heads 共享一组 K/V

4. 三者的核心区别

方法 Q head 数 KV head 数 KV cache 表达能力 推理速度
MHA h h 最大 最强 较慢
GQA h g 中等 接近 MHA 较快
MQA h 1 最小 相对较弱 最快

其中:

MHA: num_kv_heads = num_attention_heads
GQA: 1 < num_kv_heads < num_attention_heads
MQA: num_kv_heads = 1

在很多模型配置里,你会看到类似:

"num_attention_heads": 32,
"num_key_value_heads": 8

这就是 GQA。

如果是:

"num_attention_heads": 32,
"num_key_value_heads": 32

就是 MHA。

如果是:

"num_attention_heads": 32,
"num_key_value_heads": 1

就是 MQA。


5. 为什么主要影响推理,而不是训练?

严格说训练也会受影响,但 MQA/GQA 最重要的收益在 自回归推理

生成第 t 个 token 时,当前 token 的 Q 只来自当前 token,但 K/V 来自历史所有 token,并且历史 K/V 会被缓存:

当前步:
Q_current: 只算当前 token
K_cache, V_cache: 读历史所有 token

MHA 的历史 K/V 是:

[batch, past_seq_len, h, d_head]

GQA 是:

[batch, past_seq_len, g, d_head]

MQA 是:

[batch, past_seq_len, 1, d_head]

所以长上下文生成时,MQA/GQA 可以显著减少 KV cache 占用和显存带宽压力。

训练阶段通常是全序列并行计算 attention,训练时整段文本一次性算完,不需要为每个生成步反复读历史 K/V,不像推理那样一 token 一 token 读取 KV cache,因此收益没有推理阶段那么突出。


6. 一个直观类比

可以把 attention 想成很多个“提问者”在查资料。

MHA

每个提问者都有自己的索引系统和资料库:

Q1 查 K1/V1
Q2 查 K2/V2
Q3 查 K3/V3

信息组织最丰富,但资料库最多,存储开销大。

MQA

所有提问者共享同一个索引系统和资料库:

Q1 查 K/V
Q2 查 K/V
Q3 查 K/V

最省空间,但所有人查的是同一套资料结构。

GQA

几个提问者共享一套资料库:

Q1, Q2, Q3, Q4 查 K1/V1
Q5, Q6, Q7, Q8 查 K2/V2

比 MQA 更有表达力,比 MHA 更省显存。


7. 实际模型为什么喜欢 GQA?

因为 GQA 的性价比通常最好。

MHA 质量好,但 KV cache 太大。
MQA 速度快、省显存,但质量可能掉得比较明显。
GQA 在两者之间,能显著减少 KV cache,同时保留相当多的多头多样性。

所以现代大语言模型经常采用 GQA,尤其是长上下文模型和需要高吞吐推理的模型。

一句话总结:

MHA 是每个 query head 配一套 K/V;MQA 是所有 query heads 共用一套 K/V;GQA 是若干个 query heads 共用一套 K/V。GQA 是 MHA 和 MQA 之间的工程折中。

posted @ 2026-05-31 14:52  wljss  阅读(16)  评论(0)    收藏  举报