大模型-Qwen3 attention层-98

Qwen3(或相似架构)中的 Attention 层实现,它结合了 Triton 自定义 kernel(KV cache 存储) 和 FlashAttention 库 来实现高效推理

Triton kernel — 存 KV cache

@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    idx = tl.program_id(0)                          # 当前线程处理的 token index
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)

    key = tl.load(key_ptr + key_offsets)            # 加载 key[idx]
    value = tl.load(value_ptr + value_offsets)      # 加载 value[idx]

    slot = tl.load(slot_mapping_ptr + idx)          # 当前 token 对应的 cache 槽位
    cache_offsets = slot * D + tl.arange(0, D)

    tl.store(k_cache_ptr + cache_offsets, key)      # 写入 KV 缓存
    tl.store(v_cache_ptr + cache_offsets, value)

Triton kernel 的作用 把 新生成的 key / value 向量写入全局 KV 缓存
slot_mapping 用来决定新 token 的存放位置(可能涉及到批次拼接、动态序列等)
D = num_heads * head_dim 是每个 token 展平后的 KV 长度。
这个 kernel 保证了:在 prefill 阶段(一次性输入上下文),或者 decode 阶段(逐步生成 token),KV 都能写入缓存,避免重复计算。

Python 封装 — store_kvcache

def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)

这里对 key/value 的 stride 做断言,确保内存布局连续,利于 Triton 访问。
启动 kernel 时 (N,) 表示一维 grid,每个线程块处理一个 token。

attention

class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        o: torch.Tensor
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        context = get_context()
        k_cache = self.k_cache
        v_cache = self.v_cache
        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        if context.is_prefill:
            if context.block_tables is not None:    # prefix cache
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(q, k, v,
                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:    # decode
            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                        cache_seqlens=context.context_lens, block_table=context.block_tables, 
                                        softmax_scale=self.scale, causal=True)
        o = o.view(-1, self.num_heads * self.head_dim)
        return o

num_heads: query 的头数
head_dim: 每个头的维度
num_kv_heads: KV 头数(可能 < num_heads,比如 grouped-query attention)
scale: softmax 缩放因子(通常是 1/sqrt(head_dim))
k_cache, v_cache: KV 缓存,后面会被外部初始化

forward:

def forward(self, q, k, v):
    q = q.view(-1, self.num_heads, self.head_dim)
    k = k.view(-1, self.num_kv_heads, self.head_dim)
    v = v.view(-1, self.num_kv_heads, self.head_dim)

    context = get_context()
    k_cache = self.k_cache
    v_cache = self.v_cache
    store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

q, k, v reshape 成 [tokens, num_heads, head_dim]
调用 store_kvcache 将新 KV 存入缓存。
context 是外部上下文对象,包含 slot_mapping、是否 prefill、block tables、序列长度信息。

if context.is_prefill:
    if context.block_tables is not None:    # prefix cache
        k, v = k_cache, v_cache
    o = flash_attn_varlen_func(q, k, v,
                               max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                               max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                               softmax_scale=self.scale, causal=True, block_table=context.block_tables)

使用 FlashAttention 的 varlen 版本(支持变长序列)
输入 cu_seqlens_* 表示不同 batch 内 token 的起始位置,用于处理变长批次。
block_table 用于 prefix cache(共享 prefix 的场景,如多请求并行)。

else:
    o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                cache_seqlens=context.context_lens, block_table=context.block_tables,
                                softmax_scale=self.scale, causal=True)

使用 FlashAttention 专门的 KV cache 版本
q.unsqueeze(1) 变成 [tokens, 1, head_dim],因为 decode 时只处理一个 step。
直接在 KV cache 上做 attention,避免重复计算历史序列。

o = o.view(-1, self.num_heads * self.head_dim)
return o

输出 [tokens, hidden_size],其中 hidden_size = num_heads * head_dim

总结:
这个 Attention 模块做的事可以分为三步:

KV 缓存更新

新计算的 k,v 存到全局 KV cache(Triton kernel)。

选择执行路径
Prefill 阶段:用 flash_attn_varlen_func,高效处理长上下文批次。
Decode 阶段:用 flash_attn_with_kvcache,只跟 KV cache 做 attention,效率高。

输出投影结果
把多头拼回 [tokens, hidden_size],交给后续层。

亮点:
Triton 自定义 kernel:解决 KV cache 写入的高效问题。
FlashAttention 集成:高效处理大序列(prefill)和逐步解码(decode)。
支持 prefix cache:方便多请求共享上下文。
支持 GQA (num_kv_heads < num_heads)。

posted @ 2025-09-04 09:58  jack-chen666  阅读(46)  评论(0)    收藏  举报