大模型-Qwen3 attention层-98
Qwen3(或相似架构)中的 Attention 层实现,它结合了 Triton 自定义 kernel(KV cache 存储) 和 FlashAttention 库 来实现高效推理
Triton kernel — 存 KV cache
@triton.jit
def store_kvcache_kernel(
key_ptr,
key_stride,
value_ptr,
value_stride,
k_cache_ptr,
v_cache_ptr,
slot_mapping_ptr,
D: tl.constexpr,
):
idx = tl.program_id(0) # 当前线程处理的 token index
key_offsets = idx * key_stride + tl.arange(0, D)
value_offsets = idx * value_stride + tl.arange(0, D)
key = tl.load(key_ptr + key_offsets) # 加载 key[idx]
value = tl.load(value_ptr + value_offsets) # 加载 value[idx]
slot = tl.load(slot_mapping_ptr + idx) # 当前 token 对应的 cache 槽位
cache_offsets = slot * D + tl.arange(0, D)
tl.store(k_cache_ptr + cache_offsets, key) # 写入 KV 缓存
tl.store(v_cache_ptr + cache_offsets, value)
Triton kernel 的作用 把 新生成的 key / value 向量写入全局 KV 缓存
slot_mapping 用来决定新 token 的存放位置(可能涉及到批次拼接、动态序列等)
D = num_heads * head_dim 是每个 token 展平后的 KV 长度。
这个 kernel 保证了:在 prefill 阶段(一次性输入上下文),或者 decode 阶段(逐步生成 token),KV 都能写入缓存,避免重复计算。
Python 封装 — store_kvcache
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
N, num_heads, head_dim = key.shape
D = num_heads * head_dim
assert key.stride(-1) == 1 and value.stride(-1) == 1
assert key.stride(1) == head_dim and value.stride(1) == head_dim
assert k_cache.stride(1) == D and v_cache.stride(1) == D
assert slot_mapping.numel() == N
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
这里对 key/value 的 stride 做断言,确保内存布局连续,利于 Triton 访问。
启动 kernel 时 (N,) 表示一维 grid,每个线程块处理一个 token。
attention
class Attention(nn.Module):
def __init__(
self,
num_heads,
head_dim,
scale,
num_kv_heads,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
o: torch.Tensor
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
context = get_context()
k_cache = self.k_cache
v_cache = self.v_cache
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.block_tables is not None: # prefix cache
k, v = k_cache, v_cache
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else: # decode
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
o = o.view(-1, self.num_heads * self.head_dim)
return o
num_heads: query 的头数
head_dim: 每个头的维度
num_kv_heads: KV 头数(可能 < num_heads,比如 grouped-query attention)
scale: softmax 缩放因子(通常是 1/sqrt(head_dim))
k_cache, v_cache: KV 缓存,后面会被外部初始化
forward:
def forward(self, q, k, v):
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
context = get_context()
k_cache = self.k_cache
v_cache = self.v_cache
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
q, k, v reshape 成 [tokens, num_heads, head_dim]
调用 store_kvcache 将新 KV 存入缓存。
context 是外部上下文对象,包含 slot_mapping、是否 prefill、block tables、序列长度信息。
if context.is_prefill:
if context.block_tables is not None: # prefix cache
k, v = k_cache, v_cache
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
使用 FlashAttention 的 varlen 版本(支持变长序列)
输入 cu_seqlens_* 表示不同 batch 内 token 的起始位置,用于处理变长批次。
block_table 用于 prefix cache(共享 prefix 的场景,如多请求并行)。
else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
使用 FlashAttention 专门的 KV cache 版本
q.unsqueeze(1) 变成 [tokens, 1, head_dim],因为 decode 时只处理一个 step。
直接在 KV cache 上做 attention,避免重复计算历史序列。
o = o.view(-1, self.num_heads * self.head_dim)
return o
输出 [tokens, hidden_size],其中 hidden_size = num_heads * head_dim
总结:
这个 Attention 模块做的事可以分为三步:
KV 缓存更新
新计算的 k,v 存到全局 KV cache(Triton kernel)。
选择执行路径
Prefill 阶段:用 flash_attn_varlen_func,高效处理长上下文批次。
Decode 阶段:用 flash_attn_with_kvcache,只跟 KV cache 做 attention,效率高。
输出投影结果
把多头拼回 [tokens, hidden_size],交给后续层。
亮点:
Triton 自定义 kernel:解决 KV cache 写入的高效问题。
FlashAttention 集成:高效处理大序列(prefill)和逐步解码(decode)。
支持 prefix cache:方便多请求共享上下文。
支持 GQA (num_kv_heads < num_heads)。

浙公网安备 33010602011771号