模型算法-MHA-MQA-GQA(1)

1. 介绍:

基于最近对大模型 KV_cache,及 Attention 变种学习中遇到的问题和理解记录下来,帮助大家解决一点疑惑。

2. kv_cache 显存对比:

参数说明

  • batch_size:B
  • seq_len:L
  • head_num:H
  • head_dim:D
  • layer_num:N
  • group_size:G,每组 Q_head 数量
  • embedding_dim:D_em = H * D

MHA : 2 * BLHDN * sizeof(DataType)
MQA:2 * BLDN * sizeof(DataType)
GQA:2 * BLDN * (H/G) * sizeof(DataType)

3. MQA和GQA计算量没有减少,为什么能够加速?

  • 因为头的数量减少,WK WV矩阵参数量减少,带来前置计算量减少。

4. MQA 多头Q与单头 KV 计算如何组织数据?

MQA:

  • Q_mul_heads 从 (B, S, H, D) reshape 为 (B, H, S, D);
  • K_head 从 (B, S, 1, D) reshape 为 (B, 1, D, S);
    matmul(Q_mul_heads, K_head) = (B, H, S, S) ,matmul 将 K_head 复制 H 份与 Q_head 计算。

GQA:

  • Q_mul_heads 从 (B, S, H, D) reshape 为 (B, H, S, D);
  • K_head 从 (B, S, H/G, D) -> (B, S, H/G, 1, D) ,再 expand 复制最后一个维度为 (B, S, H/G, G, D), reshape 为 (B, S, H, D) 与 Q_mul_heads 大小一致, 再 reshape 为 (B, H, D, S) 可以进行 malmul 计算。
posted @ 2025-07-11 16:33  安洛8  阅读(35)  评论(0)    收藏  举报