注意力机制到kv_cache
Decoder-Only 大模型推理机制详解:从 KV Cache 原理到 HuggingFace 实现
引言
在当前主流的大语言模型(LLM)架构中,Decoder-Only 模型(如 GPT 系列、LLaMA、Mistral 等)已成为自然语言生成任务的基石。这类模型的核心能力是自回归生成文本——即根据已生成的 token 序列,逐个预测下一个 token。
然而,若每次生成新 token 都重新计算整个历史序列的注意力,其计算复杂度将随序列长度线性增长(甚至更高),导致推理效率急剧下降。为此,KV Cache(Key-Value Cache) 成为高效推理的关键技术。
本文将结合 D2L 教材与 HuggingFace Transformers 库的实际实现,深入剖析:
- Decoder-Only 模型如何预测下一个 token;
- QKV 的计算过程及多头注意力机制;
- 为何缓存 K 和 V 而非 Q 或 attention 结果;
- D2L 与 HuggingFace 在 KV Cache 实现上的异同;
- KV Cache 的有效性原理及其对推理性能的影响。
我们还将通过代码对比,揭示教学实现与工业级实现之间的设计哲学差异,并最终阐明 KV Cache 如何成为大模型高效推理的基石。
一、Decoder-Only 模型的基本结构与推理流程
1.1 模型结构概览
Decoder-Only 模型由多个 Decoder Block 组成,每个 block 包含:
- 自注意力层(Self-Attention)
- 前馈网络(FFN)
- 残差连接 + 层归一化(Add & Norm)
值得注意的是,在纯 Decoder-Only 架构中(如 GPT 系列),没有编码器,也不包含“编码器-解码器注意力”模块。所有输入和输出都在同一个流中处理,完全依赖于自回归机制进行文本生成。
1.2 推理流程:逐 token 自回归生成
推理过程本质上是自回归的,即每一步都基于已生成的历史 token 来预测下一个 token。具体流程如下:
初始输入: <bos>
生成: token_1
输入: <bos>, token_1 → 生成: token_2
输入: <bos>, token_1, token_2 → 生成: token_3
...
每一步都需要考虑完整的上下文历史,以确保语义连贯性。随着序列变长,如果不做优化,重复计算的成本将迅速上升。
二、QKV 的计算过程:从 Token 到注意力输入
2.1 每个 token 的隐藏状态
设模型隐藏维度为 $ d $,当前输入序列长度为 $ L $,则经过嵌入层后得到隐藏状态矩阵:
其中每个 $ h_i $ 是第 $ i $ 个 token 的上下文感知表示,包含了位置信息和语义信息。
2.2 QKV 的线性变换
通过三个可学习的权重矩阵对每个 $ h_i $ 进行线性变换:
其中:
- $ W_Q \in \mathbb{R}^{d \times d} $
- $ W_K \in \mathbb{R}^{d \times d} $
- $ W_V \in \mathbb{R}^{d \times d} $
这些变换是参数共享但独立应用于每个 token的,也就是说,每个 $ h_i $ 都会独立地生成对应的 $ q_i, k_i, v_i $。
2.3 注意力机制:缩放点积注意力
标准的缩放点积注意力公式为:
在自回归推理中,必须使用 因果掩码(Causal Mask),确保每个 token 只能关注其自身及之前的 token,防止未来信息泄露。
该操作的时间复杂度为 $ O(L^2 d) $,当 $ L $ 很大时开销显著。因此,减少重复计算成为提升推理效率的关键。
三、KV Cache 的核心思想:避免重复计算
3.1 问题:不缓存时的计算冗余
假设当前已生成 5 个 token,现在要生成第 6 个:
- 传统做法:将全部 6 个 token 输入模型,重新计算所有 $ h_i, q_i, k_i, v_i $。
- 但前 5 个 token 的 $ k_i, v_i $ 不会改变,因为模型参数固定、输入不变。
→ 重复计算前 5 步的 K 和 V 是巨大的浪费。
这种重复计算使得推理时间随输出长度呈平方级增长,严重限制了实际应用中的响应速度。
3.2 解法:KV Cache
KV Cache 的核心思想是:
缓存每个已生成 token 的 $ k_i $ 和 $ v_i $,在后续推理中复用,避免重复计算。
这样,第 6 步只需:
- 计算第 6 个 token 的 $ h_6, q_6, k_6, v_6 $
- 从缓存中读取前 5 个 token 的 $ k_1..k_5, v_1..v_5 $
- 拼接 K 和 V:$ K_{\text{full}} = [k_1..k_6], V_{\text{full}} = [v_1..v_6] $
- 计算注意力输出
这将每步新增的计算量从 $ O(L^2 d) $ 降为 $ O(dL) $,实现了线性推理复杂度。
四、D2L 中的 KV Cache 实现分析
我们来看 D2L 教材中 DecoderBlock 的 forward 方法:
def forward(self, X, state):
...
if state[2][self.i] is None:
key_values = X
else:
key_values = torch.cat((state[2][self.i], X), axis=1)
state[2][self.i] = key_values
...
X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
...
4.1 缓存了什么?
state[2][self.i]:表示第i个 Decoder Block 的历史缓存。- 初始为空,每次将当前输入
X与历史缓存拼接,形成key_values。 - 注意:这里缓存的是 原始隐藏状态 $ h_i $,而不是 $ k_i $ 或 $ v_i $。
⚠️ 关键点:这段代码并未直接缓存 K 和 V,而是缓存了 $ h_i $,然后每次用 W_k 和 W_v 重新计算 $ k_i $ 和 $ v_i $。
这对应于我们之前讨论的 “缓存 $ h_i $” 策略。
4.2 为何这样设计?优缺点分析
| 维度 | 分析 |
|---|---|
| 显存占用 | ✅ 较低:只缓存 $ h_i $(d 维),比缓存 $ k_i + v_i $(2d 维)省一半 |
| 计算量 | ❌ 较高:每步需重新计算 $ W_k \cdot h_i $ 和 $ W_v \cdot h_i $,增加 $ 2d^2 $ 计算量 |
| 适用场景 | 更适合显存受限环境(如移动端),但牺牲了计算效率 |
结论:这是一种“以计算换显存”的策略,不是典型的 KV Cache 实现。它简化了教学逻辑,便于理解缓存机制的本质,但在实际部署中并不高效。
五、HuggingFace Transformers 中的 KV Cache 实现
我们以 HuggingFace 的 transformers 库为例,分析其工业级 KV Cache 实现。
5.1 缓存对象:past_key_values
在 HuggingFace 中,KV Cache 存储在 past_key_values 中,其结构为:
past_key_values: tuple[tuple[torch.Tensor, torch.Tensor], ...]
# 外层 tuple:每个 decoder layer
# 内层 tuple:(past_key, past_value),每个 shape: [batch_size, num_heads, seq_len, head_dim]
即:直接缓存每个 layer 的 K 和 V,且是经过 W_k 和 W_v 投影后的结果。
5.2 实际代码片段(简化版)
class GPT2Attention(nn.Module):
def forward(
self,
hidden_states,
layer_past=None, # 即 past_key_values
use_cache=False,
):
# 1. 计算当前输入的 Q, K, V
query = self.q_proj(hidden_states) # [b, q_len, d_model]
key = self.k_proj(hidden_states) # [b, q_len, d_model]
value = self.v_proj(hidden_states) # [b, q_len, d_model]
# 2. 形状变换:合并投影与多头拆分
bsz, q_len, _ = hidden_states.size()
query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# 3. 如果有历史缓存,拼接 K 和 V
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat([past_key, key], dim=-2) # 沿着 sequence length 维度拼接
value = torch.cat([past_value, value], dim=-2)
# 4. 计算注意力
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=None)
# 5. 合并多头
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
# 6. 返回输出和新的缓存
if use_cache:
new_past = (key, value)
return attn_output, new_past
return attn_output
5.3 与 D2L 代码的对比
| 特性 | D2L 实现 | HuggingFace Transformers |
|---|---|---|
| 缓存对象 | $ h_i $(隐藏状态) | $ k_i, v_i $(投影后) |
| 显存占用 | d | 2d |
| 计算量 | 每步重算 $ W_k h_i, W_v h_i $ | 直接复用,无额外计算 |
| 优化目标 | 省显存 | 省计算(GPU 上更高效) |
| 典型性 | 非主流,教学简化 | 工业级标准实现 |
✅ HuggingFace 的实现才是现代 LLM 推理的标准范式:最大化利用 GPU 并行能力,最小化每步计算延迟。
六、为什么缓存 K 和 V?为什么不缓存 Q 或 attention 结果?
6.1 为什么不缓存 Q?
- Q 是“查询”,代表当前 token 对历史的“提问”。
- 每一步的 Q 都不同:第 1 步 Q 是
<bos>的查询,第 2 步是token_1的查询。 - Q 不能复用,必须每步重新计算。
✅ Q 必须实时计算,无法缓存。
6.2 为什么不缓存 attention 结果?
- Attention 输出是 $ \text{softmax}(QK^T)V $,依赖于当前 Q。
- 每次 Q 不同,attention 结果也不同。
- 即使缓存了上一步的 attention,也无法用于下一步。
✅ Attention 结果无法复用,缓存无意义。
6.3 为什么缓存 K 和 V?
- K 和 V 是“记忆”:每个 token 的“知识表示”。
- 一旦生成,其 $ k_i, v_i $ 就固定不变。
- 后续所有步骤都可以复用这些“记忆”来回答新的“提问”(Q)。
✅ K 和 V 是静态的、可复用的,是缓存的理想对象。
🌟 类比人类记忆:你记住了一段话的内容(K/V),之后无论怎么提问(Q),都可以快速回忆并作答,而无需重新“阅读”整段文字。
七、KV Cache 的有效性:计算复杂度分析
| 策略 | 每步计算量 | 显存占用 | 适用场景 |
|---|---|---|---|
| 无缓存 | $ O(L^2 d) $ | $ O(1) $ | 不实用 |
| 缓存 $ h_i $ | $ O(d^2 + dL) $ | $ O(Ld) $ | 显存敏感场景 |
| 缓存 $ k_i, v_i $ | $ O(dL) $ | $ O(2Ld) $ | GPU/TPU 推理首选 |
💡 在典型设置下(如 $ d = 4096, L = 8192 $):
- $ d^2 \approx 16.7M $ 运算(每步都要重复)
- $ dL \approx 33.5M $ 是注意力本身成本
缓存 $ k_i, v_i $ 可消除 $ d^2 $ 的重复项,仅保留必要的 $ dL $ 成本。
结论:缓存 $ k_i, v_i $ 能将每步的额外计算从 $ O(d^2) $ 降到 $ O(1) $,仅保留注意力本身的 $ O(dL) $,实现高效推理。
八、总结:KV Cache 的本质与价值
| 项目 | 说明 |
|---|---|
| 缓存对象 | 每个 token 的 $ k_i $ 和 $ v_i $(投影后) |
| 核心思想 | 避免重复计算历史 token 的 K 和 V |
| 优势 | 将推理复杂度从 $ O(L^2 d) $ 降为 $ O(L d) $ |
| 工业实现 | HuggingFace 缓存 $ k_i, v_i $,最大化计算效率 |
| 教学实现 | D2L 缓存 $ h_i $,简化理解但牺牲性能 |
KV Cache 不是“优化技巧”,而是 Decoder-Only 模型能高效生成长文本的基石。
它体现了深度学习系统设计中的一个基本原则:空间换时间。通过合理使用显存存储中间结果,大幅降低在线推理的计算负担,从而使大模型能够在真实场景中流畅运行。
附录 A:注意力机制实现对比(D2L vs HuggingFace)
A.1 D2L 实现:DotProductAttention
class DotProductAttention(nn.Module):
def __init__(self, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)
特点:
- 输入形状:
queries: ([b, n_q, d_k])keys: ([b, n_k, d_k])values: ([b, n_k, d_v])
- 计算方式:使用
torch.bmm(batch matrix multiplication)进行批量矩阵乘。 - mask 处理:通过
valid_lens参数支持序列长度 mask。 - 输出:返回注意力加权结果,并保存
attention_weights便于可视化。
✅ 优点:代码简洁,逻辑清晰,易于理解。
A.2 Hugging Face 实现:_scaled_dot_product_attention
Hugging Face 使用 PyTorch 内置的 F.scaled_dot_product_attention 或自定义高效实现。
# 简化版核心逻辑(以 GPT-2 为例)
def _attn(self, query, key, value, attention_mask=None):
# query: [b, h, n_q, d_k]
# key: [b, h, n_k, d_k]
# value: [b, h, n_k, d_v]
# 转置 key 用于点积
key_t = key.transpose(-1, -2) # [b, h, d_k, n_k]
# 缩放点积
scale = 1 / math.sqrt(query.size(-1))
scores = torch.matmul(query, key_t) * scale # [b, h, n_q, n_k]
# 应用 attention mask
if attention_mask is not None:
scores = scores.masked_fill(attention_mask == 0, float('-inf'))
# softmax + dropout
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.attn_dropout(attn_weights)
# 加权求和
attn_output = torch.matmul(attn_weights, value) # [b, h, n_q, d_v]
return attn_output, attn_weights
特点:
- 输入形状:
query: ([b, h, n_q, d_k])key: ([b, h, n_k, d_k])value: ([b, h, n_k, d_v])
- 计算方式:使用
torch.matmul,天然支持多维张量。 - mask 处理:直接使用
attention_mask张量进行填充。 - 性能优化:可自动调用 Flash Attention、Memory-Efficient Attention 等内核。
✅ 优点:支持多头维度前置,便于并行计算;集成多种优化后端。
附录 B:多头注意力实现对比
B.1 D2L 实现:MultiHeadAttention
class MultiHeadAttention(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout):
super().__init__()
self.num_heads = num_heads
self.attention = DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
self.W_v = nn.Linear(value_size, num_hiddens, bias=False)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=False)
def forward(self, queries, keys, values, valid_lens):
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
output = self.attention(queries, keys, values, valid_lens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
QKV 计算流程:
- 线性投影:
W_q(queries)→ ([b, n_q, num_hiddens])W_k(keys)→ ([b, n_k, num_hiddens])W_v(values)→ ([b, n_k, num_hiddens])
- 多头拆分(
transpose_qkv):- 将
num_hiddens拆分为(num_heads, d_k), 其中d_k = num_hiddens // num_heads - 调整形状为 ([b * num_heads, seq_len , d_k])
- 将
- 批量计算:调用
DotProductAttention,输入形状为 ([b * num_heads, n_q, d_k]) - 合并多头(
transpose_output):恢复为 ([b, n_q, num_hiddens]) - 输出投影:
W_o(output_concat)
✅ 优点:模块化设计,transpose_qkv 和 transpose_output 清晰展示了形状变换。
B.2 Hugging Face 实现:GPT2Attention
class GPT2Attention(nn.Module):
def __init__(self, config):
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
def forward(self, hidden_states, attention_mask=None, layer_past=None):
bsz, q_len, _ = hidden_states.size()
# 1. QKV 投影
query = self.q_proj(hidden_states) # [b, q_len, d_model]
key = self.k_proj(hidden_states) # [b, q_len, d_model]
value = self.v_proj(hidden_states) # [b, q_len, d_model]
# 2. 形状变换:合并投影与多头拆分
query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# 3. 注意力计算
attn_output, _ = self._attn(query, key, value, attention_mask)
# 4. 合并多头
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.embed_dim)
# 5. 输出投影
output = self.out_proj(attn_output)
return output
QKV 计算流程:
- 线性投影:
q_proj(hidden_states)→[b, q_len, d_model]k_proj,v_proj类似。
- 形状变换:
- 使用
view+transpose将[b, q_len, d_model]→[b, h, q_len, d_k] - 与 D2L 不同:D2L 先
view再permute,Hugging Face 直接view后transpose。
- 使用
- 注意力计算:调用
_attn,输入为多头形状。 - 合并多头:
transpose(1,2)+view→[b, q_len, d_model] - 输出投影:
out_proj(attn_output)
✅ 优点:代码紧凑,view 和 transpose 高效完成多头拆分。
附录 C:关键差异总结
| 维度 | D2L 实现 | Hugging Face 实现 |
|---|---|---|
| 注意力类结构 | DotProductAttention + MultiHeadAttention |
单一 Attention 类包含所有逻辑 |
| QKV 投影层 | W_q, W_k, W_v |
q_proj, k_proj, v_proj |
| 多头拆分方式 | transpose_qkv 函数(reshape → permute → reshape) |
view → transpose |
| 输入张量形状 | [b, seq_len, d](多头前) |
[b, seq_len, d](多头前) |
| 多头张量形状 | [b * h, seq_len, d_k] |
[b, h, seq_len, d_k] |
| 矩阵乘法 | torch.bmm |
torch.matmul |
| mask 处理 | valid_lens(长度) |
attention_mask(布尔张量) |
| 代码风格 | 教学导向,模块解耦 | 工程导向,高度集成 |
结论
- 原理一致:D2L 与 HuggingFace 均实现标准的缩放点积注意力与多头机制,数学本质相同。
- 实现哲学不同:
- D2L:教学优先,将多头拆分/合并抽象为独立函数(
transpose_qkv),便于学生理解张量变换。 - Hugging Face:性能与灵活性优先,在单个类中完成所有操作,使用
view和transpose高效处理形状,支持多种 attention 实现(如 FlashAttention)。
- D2L:教学优先,将多头拆分/合并抽象为独立函数(
- 多头处理差异:
- D2L 使用
b * h批量维度,依赖repeat_interleave扩展valid_lens。 - Hugging Face 保留
b和h维度分离,使用attention_mask更灵活。
- D2L 使用
- 工程优化:Hugging Face 实现更容易集成 CUDA 内核优化(如 FlashAttention),而 D2L 实现更适合作为学习起点。
参考资料:
- Vaswani et al., "Attention is All You Need", 2017
- HuggingFace Transformers 源码:
modeling_gpt2.py,modeling_llama.py - Dive into Deep Learning (d2l.ai): Sequence-to-Sequence Models

浙公网安备 33010602011771号