分组查询注意力(GQA)的Pytorch实现

自注意力层(分组查询注意力)

初始化

class SelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head # 查询头的数量
        self.kv_head = config.kv_head # kv头的数量
        self.n_embed = config.n_embed # 嵌入维度
        self.h_dim = self.n_embed // self.n_head
        assert self.n_embed % self.n_head == 0
        assert self.kv_head < self.n_head and self.n_head % self.kv_head == 0
        self.q_linear = nn.Linear(self.n_embed, self.n_embed, bias=False)
        self.v_linear = nn.Linear(self.n_embed, self.kv_head*self.h_dim, bias=False)
        self.k_linear = nn.Linear(self.n_embed, self.kv_head*self.h_dim, bias=False)
        self.out = nn.Linear(self.n_embed, self.n_embed, bias=False)

layer_idx 的作用是作为一个索引,告诉当前这个 Attention 模块它是 Transformer 模型中的第几层。方便后续训练过程中的调试与日志记录以及kv缓存处理。
assert的作用是断定某个条件必须为真,如果该条件为假,程序就会立即崩溃并抛出一个 AssertionError 异常。
nn.Linear的两个必须参数为输入维度和输出维度
这里采用的是分组查询注意力(GQA),即多个查询头共享一个kv头。相比较与MHA,计算量和缓存压力要小很多,但理论上模型质量也会有所下降。

    def forward(self, x, cos_sin, kv_cache):
        # 修改qkv矩阵的形状,方便后续计算
        B, T, C = x.size()
        q = self.q_linear(x).view(B, T, self.n_head, self.h_dim)
        v = self.v_linear(x).view(B, T, self.kv_head, self.h_dim)
        k = self.k_linear(x).view(B, T, self.kv_head, self.h_dim)

        # 进行旋转位置编码
        cos, sin = cos_sin
        q,k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
        q,k = norm(q), norm(k)

        # 矩阵转置(B, T, H, D) -> (B, H, T, D)
        # 方便后续的注意力计算,让 PyTorch 可以将 H 个头看作一个批次维度进行高效的矩阵乘法
        q = q.transpose(1,2)
        v = v.transpose(1,2)
        k = k.transpose(1,2)

view() 本身是一个“零拷贝”操作,它不会移动数据,只是重新解释数据的形状。它并不会真的在内存中移动数据。只是修改了张量的元数据(比如 stride,即访问下一个元素需要跳过多少个内存位置),创建了一个新的“视图”指向原来的数据。
为了提高效率,像 transpose(), permute(), narrow() 等操作都是创建一个新的视图,并没有在内存中移动数据。

# 创建一个 2x3 的连续张量 
x = torch.arange(6).view(2, 3) 
# tensor([[0, 1, 2], 
#		 [3, 4, 5]])

在内存中,x 的数据是这样存储的:[0, 1, 2, 3, 4, 5]。

# 对 x 进行转置
y = x.transpose(0, 1)
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])

虽然 y 的逻辑形状是 (3, 2),但它在内存中的数据仍然是 [0, 1, 2, 3, 4, 5]

  • 为了读取 y 的第一行 [0, 3],程序需要先读取位置0的 0,然后跳过 1 和 2,去读取位置3的 3。
  • 因为元素不再是紧挨着的,所以这个张量 y 不是连续的,对于不连续的数据不能直接进行.view()操作
    旋转位置编码这里先不做赘述

现在来解释一下为什么又要进行矩阵转置
假设我们向注意力层输入的数据x长这样,一次输入4个toekn,每个token15个维度:
tensor([[[7, 2, 0, 2, 3, 7, 1, 2, 5, 5, 6, 6, 2, 9, 6],
[0, 1, 1, 9, 3, 3, 0, 2, 6, 4, 7, 0, 3, 7, 1],
[2, 7, 7, 0, 2, 1, 1, 8, 9, 6, 6, 2, 5, 5, 5],
[3, 6, 1, 1, 0, 0, 5, 3, 2, 0, 0, 9, 6, 6, 6]]])
尺寸为(1,4,15)
以q矩阵为例,假设(B, T, self.n_head, self.h_dim)=(1,4,3,5)
将输入 x 乘以一个可学习的权重矩阵(W_q) 得到q。这个过程叫做线性投影(Linear Projection)
则q初始化后长这样:
tensor([[[[7, 2, 0, 2, 3],# T0, H0
[7, 1, 2, 5, 5],# T0, H1
[6, 6, 2, 9, 6]],# T0, H2

     `[[0, 1, 1, 9, 3],# T1, H0`
      `[3, 0, 2, 6, 4],# T1, H1`
      `[7, 0, 3, 7, 1]],# T1, H2`
      
     `[[2, 7, 7, 0, 2],# T2, H0`
      `[1, 1, 8, 9, 6],# T2, H1`
      `[6, 2, 5, 5, 5]],# T2, H2`
      
     `[[3, 6, 1, 1, 0],# T3, H0`
      `[0, 5, 3, 2, 0],# T3, H1`
      `[0, 9, 6, 6, 6]]]])# T3, H2`

相当于把每个token截成四份,我们知道MQA,包括MHA计算注意力矩阵时都是头矩阵之间进行计算,所以我们按照头来分组,更能利用GPU擅长处理并行运算的特点。
对q的第2,3个维度进行转置:
tensor([[[[7, 2, 0, 2, 3], # H0, T0
[0, 1, 1, 9, 3], # H0, T1
[2, 7, 7, 0, 2], # H0, T2
[3, 6, 1, 1, 0]], # H0, T3

     `[[7, 1, 2, 5, 5],   # H1, T0`
      `[3, 0, 2, 6, 4],   # H1, T1`
      `[1, 1, 8, 9, 6],   # H1, T2`
      `[0, 5, 3, 2, 0]],  # H1, T3`
      
     `[[6, 6, 2, 9, 6],   # H2, T0`
      `[7, 0, 3, 7, 1],   # H2, T1`
      `[6, 2, 5, 5, 5],   # H2, T2`
      `[0, 9, 6, 6, 6]]]]) # H2, T3`

这样我们就得到了每个头要处理的内容。

现在来介绍一下注意力分数的计算

if kv_cache is not None:
            k,v = kv_cache.insert_kv(self.layer_idx, k, v)
        Tq = q.size(2)
        Tk = k.size(2)

        nrep = self.n_head // self.kv_head
        k,v = repeat_kv(k, nrep), repeat_kv(v, nrep)

        if kv_cache is None or Tq == Tk:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            # scaled_dot_product_attention将注意力机制的多个步骤融合,同时会调用最优的实现,比如FlashAttention来实现算力优化

        elif Tq == 1:
            # Tq = 1说明是单token生成,即推理场景,所以不需要掩码
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)

        else:
            mask = torch.zeros((Tq,Tk), device=q.device, dtype=torch.bool)
            prefix_len = Tk - Tq
            if prefix_len > 0:
                mask[:, :prefix_len] = True
            mask[:, prefix_len:] = torch.tril(torch.ones((Tq,Tq), device=q.device, dtype=torch.bool))
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        y.transpose(1,2).contiguous().view(B, T, -1) # (B, H, T, D) -> (B, T, C)
        y = self.out(y)

        return y

这里先不介绍kv缓存的实现
nrep是为了计算出一个kv头要对应多少个查询头,然后对kv头进行复制

def repeat_kv(x, nrep):
    if nrep == 1:
        return x
    bs, h, slen, dim = x.shape
    return(x[:, :, None, :, :]
        .expand(bs, h, nrep, slen, dim)
        .reshape(bs, h * nrep, slen, dim)
    )

假设一个原先的k头长这样
k = tensor([[[[0,0,0,0], [0,0,0,1], [0,0,0,2]], # K_H0
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]]]) # K_H1
x[:, :, None, :, :]后:
tensor([[[ [[0,0,0,0], [0,0,0,1], [0,0,0,2]] ], # K_H0 [ [[1,1,1,0], [1,1,1,1], [1,1,1,2]] ]]]) # K_H1expand后:tensor([[[ [[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # KV_H0 (original) [[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # KV_H0 (view 1) [[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # KV_H0 (view 2) [[0,0,0,0], [0,0,0,1], [0,0,0,2]]]], # KV_H0 (view 3)`

     `[ [[1,1,1,0], [1,1,1,1], [1,1,1,2]]],  # KV_H1 (original)`
       `[[1,1,1,0], [1,1,1,1], [1,1,1,2]]],  # KV_H1 (view 1)`
       `[[1,1,1,0], [1,1,1,1], [1,1,1,2]]],  # KV_H1 (view 2)`
       `[[1,1,1,0], [1,1,1,1], [1,1,1,2]]]]])# KV_H1 (view 3)`

reshape后:
tensor([[[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 0 (from K_H0)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 1 (from K_H0)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 2 (from K_H0)
[[0,0,0,0], [0,0,0,1], [0,0,0,2]]], # Head 3 (from K_H0)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # Head 4 (from K_H1)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # Head 5 (from K_H1)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]], # Head 6 (from K_H1)
[[1,1,1,0], [1,1,1,1], [1,1,1,2]]]])# Head 7 (from K_H1)
这样在与q矩阵运算时就能一一对应了
scaled_dot_product_attention(q, k, v, is_causal=False)torch.nn.functional的一个函数,只要输入q,k,v矩阵和是否进行掩码就能注意力分数的计算。这个函数会在底层使用flashattention机制来实现算力优化。
现在来介绍一下标准注意力的计算
\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\)

attn_score = torch.matmul(q,k.transpose(-2, -1))
attn_score = attn_score / self.h_dim
if is_causal:
	mask =torch.triu(torch.ones(attn_score.size(-2),attn_score.size(-1)), diagonal=1)
attn_score = attn_score.mask_fill(mask, float('inf'))
attn_weights = F.softmax(attn_score, dim=-1)
output = torch.matmul(attn_weights, v)

matul是矩阵相乘,torch.triu用于创建一个上三角矩阵
假如说x长这样:
tensor([[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[1., 1., 1., 1.]])
torch.triu(x, diagonal=0)就长这样:
tensor([[ 1., 1., 1., 1.],
[ 0., 1., 1., 1.],
[ 0., 0., 1., 1.],
[ 0., 0., 0., 1.]])
torch.triu(x, diagonal=1)就长这样:
tensor([[ 0., 1., 1., 1.],
[ 0., 0., 1., 1.],
[ 0., 0., 0., 1.],
[ 0., 0., 0., 0.]])
.mask_fill表示应用掩码矩阵,float('inf')表示负无限,方便后续交给softmax处理
下面来介绍三个选择情况:
if kv_cache is None or Tq == Tk
表示没有kv缓存或q,k矩阵序列长度相同的情况,一般这个情况都是在进行模型训练,此时需要掩码。
elif Tq == 1:
q的序列长度为1,说明此时正在进行单token生成,即推理状态,此时不需要掩码。
 else:
这个块处理的是模型在推理时,一次性处理一批(一个 chunk)新的查询(Tq > 1),并且 KV 缓存中已经存在一部分历史信息(Tk > Tq)。在这种模式下,一个小的、快速的“草稿模型”会先生成一小段文本(比如4个 tokens),然后主模型(就是我们正在分析的这个)会一次性地验证这4个 tokens。这时,Tq 就等于 4。即投机解码算法(speculative decoding)主要用于加速模型的推理。
对于该算法的详细流程,这里先不做过多赘述。

        y.transpose(1,2).contiguous().view(B, T, -1) # (B, H, T, D) -> (B, T, C)
        y = self.out(y)

将所有头拼接起来,相当于前面转置qkv矩阵的逆操作
contiguous()是为了创建一个在内存上连续的y副本。view() 本身也是一个“零拷贝”操作,它不会移动数据,只是重新解释数据的形状。如果数据在内存中是“乱”的(非连续的),view() 就不知道该如何正确地、高效地重新解释它。这一点在前面解释view()时谈到过

posted @ 2025-10-21 09:55  Luxxx23  阅读(6)  评论(1)    收藏  举报