模型算法-MHA-MQA-GQA(1)
1. 介绍:
基于最近对大模型 KV_cache,及 Attention 变种学习中遇到的问题和理解记录下来,帮助大家解决一点疑惑。
2. kv_cache 显存对比:
参数说明
- batch_size:B
 - seq_len:L
 - head_num:H
 - head_dim:D
 - layer_num:N
 - group_size:G,每组 Q_head 数量
 - embedding_dim:D_em = H * D
 
MHA : 2 * BLHDN * sizeof(DataType)
MQA:2 * BLDN * sizeof(DataType)
GQA:2 * BLDN * (H/G) * sizeof(DataType)
3. MQA和GQA计算量没有减少,为什么能够加速?
- 因为头的数量减少,WK WV矩阵参数量减少,带来前置计算量减少。
 
4. MQA 多头Q与单头 KV 计算如何组织数据?
MQA:
- Q_mul_heads 从 (B, S, H, D) reshape 为 (B, H, S, D);
 - K_head 从 (B, S, 1, D) reshape 为 (B, 1, D, S);
 
matmul(Q_mul_heads, K_head) = (B, H, S, S) ,matmul 将 K_head 复制 H 份与 Q_head 计算。
GQA:
- Q_mul_heads 从 (B, S, H, D) reshape 为 (B, H, S, D);
 - K_head 从 (B, S, H/G, D) -> (B, S, H/G, 1, D) ,再 expand 复制最后一个维度为 (B, S, H/G, G, D), reshape 为 (B, S, H, D) 与 Q_mul_heads 大小一致, 再 reshape 为 (B, H, D, S) 可以进行 malmul 计算。
 
                    
                
                
            
        
浙公网安备 33010602011771号