【大模型】什么是KV cache
KV-Cache(键值缓存)详解
KV-Cache 是 Transformer 模型在推理(生成)阶段 用于 加速自回归生成 的关键优化技术。让我为您详细解释:
一、核心概念
什么是 KV-Cache?
KV-Cache 是缓存 Key(键) 和 Value(值) 的机制,用于减少 Transformer 在生成文本时的重复计算。
为什么需要 KV-Cache?
在自回归生成中(如 GPT 生成文本),每次生成一个新 token 时:
- 无缓存:需要重新计算之前所有 token 的 Key 和 Value → 大量重复计算
- 有缓存:复用已计算的 Key 和 Value → 只计算当前 token 的部分
二、工作原理
Transformer 注意力机制回顾
注意力(Q, K, V) = softmax(Q·Kᵀ/√d) · V
自回归生成示例
生成序列:"The cat sat on the"
-
第一步:输入 "The"
- 计算 Q₁, K₁, V₁(第 1 个 token)
- 输出:预测 "cat"
-
第二步:输入 "The cat"
- 无缓存:重新计算 K₁,V₁ 和 K₂,V₂
- 有缓存:从缓存读取 K₁,V₁,只计算 K₂,V₂
- 输出:预测 "sat"
-
第三步:输入 "The cat sat"
- 有缓存:从缓存读取 K₁,V₁, K₂,V₂,只计算 K₃,V₃
三、KV-Cache 图示
时间步 t=1:
输入: [x₁] → 计算: Q₁, K₁, V₁
缓存: {K₁, V₁}
输出: y₁
时间步 t=2:
输入: [x₂] → 计算: Q₂, K₂, V₂
缓存: {K₁,K₂, V₁,V₂} ← 新增
↓
注意力计算时:使用全部缓存 {K₁,K₂} 和 {V₁,V₂}
时间步 t=3:
输入: [x₃] → 计算: Q₃, K₃, V₃
缓存: {K₁,K₂,K₃, V₁,V₂,V₃} ← 继续扩展
四、数学表达
设第 l 层在时间步 t 的注意力计算:
无 KV-Cache:
Q_{l,t} = X_t · W_l^Q
K_{l,t} = [X_1,...,X_t] · W_l^K ← 每次重新计算全部
V_{l,t} = [X_1,...,X_t] · W_l^V
有 KV-Cache:
Q_{l,t} = X_t · W_l^Q
K_{l,t} = concat(K_cache_l, X_t · W_l^K) ← 只计算新的,连接缓存
V_{l,t} = concat(V_cache_l, X_t · W_l^V)
五、内存占用分析
缓存大小计算
每个token的缓存大小 = 2 × 层数 × 隐藏维度 × 数据类型大小
示例:LLaMA-7B 模型
层数 L = 32
隐藏维度 d = 4096
头数 h = 32
每个头维度 d_h = 128
数据类型: float16 (2字节)
每个token每层缓存大小 = 2 × (h × d_h) × 2字节
= 2 × 4096 × 2
= 16,384字节 ≈ 16KB
每个token总缓存 = 16KB × 32层 = 512KB
生成1000个token的缓存 = 512KB × 1000 = 512MB
内存优化对比
| 方法 | 计算复杂度 | 内存占用 | 适用场景 |
|---|---|---|---|
| 无缓存 | O(n²) 每次 | 低 | 短序列 |
| 有缓存 | O(n) 每次 | 高(线性增长) | 长序列生成 |
| 窗口缓存 | O(w) | 固定 | 长序列(如 4K/8K 上下文) |
六、代码示例
简化实现
import torch
import torch.nn as nn
class AttentionWithKVCache(nn.Module):
def __init__(self, dim, num_heads):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_proj = nn.Linear(dim, dim)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.o_proj = nn.Linear(dim, dim)
def forward(self, x, past_key_value=None):
"""
x: [batch_size, seq_len, dim]
past_key_value: 元组 (past_key, past_value)
past_key: [batch, num_heads, past_len, head_dim]
"""
batch_size, seq_len, _ = x.shape
# 计算 Q, K, V
Q = self.q_proj(x) # [batch, seq_len, dim]
K = self.k_proj(x) # [batch, seq_len, dim]
V = self.v_proj(x) # [batch, seq_len, dim]
# 重塑为多头
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 如果有过去的 KV,与当前连接
if past_key_value is not None:
past_key, past_value = past_key_value
K = torch.cat([past_key, K], dim=2) # 在序列维度连接
V = torch.cat([past_value, V], dim=2)
# 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
attn_weights = torch.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 重塑回原始形状
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.dim)
# 输出投影
output = self.o_proj(attn_output)
# 返回当前 KV 用于下一次
current_key_value = (K, V)
return output, current_key_value
实际使用示例
def generate_with_kv_cache(model, prompt, max_length=100):
"""使用 KV-Cache 生成文本"""
input_ids = tokenizer.encode(prompt)
generated = input_ids.copy()
# 初始化 KV-Cache(每层一个)
past_key_values = None
for _ in range(max_length):
# 只传入最后一个 token(如果是第一次,传入所有)
if past_key_values is None:
inputs = torch.tensor([input_ids])
else:
inputs = torch.tensor([[generated[-1]]])
# 前向传播,传入 KV-Cache
outputs = model(
inputs,
past_key_values=past_key_values,
use_cache=True
)
# 更新 KV-Cache
past_key_values = outputs.past_key_values
# 获取下一个 token
next_token_logits = outputs.logits[0, -1, :]
next_token = torch.argmax(next_token_logits).item()
generated.append(next_token)
if next_token == tokenizer.eos_token_id:
break
return tokenizer.decode(generated)
七、实际应用中的优化
1. 分页注意力(Paged Attention)
- 类似操作系统内存分页
- 解决缓存碎片问题
- vLLM 等推理框架使用
2. 多查询注意力(MQA)
- 多个查询头共享相同的 KV
- 减少 KV-Cache 大小
3. 分组查询注意力(GQA)
- MQA 和 MHA 的折中
- 几个查询头共享一组 KV
4. 滑动窗口注意力
- 只缓存最近 N 个 token
- 适用于长文本(如 128K 上下文)
八、在 Hugging Face Transformers 中的使用
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")
# 启用 KV-Cache
outputs = model.generate(
input_ids,
max_length=100,
use_cache=True, # 启用 KV-Cache
past_key_values=None, # 首次为 None
)
# 手动管理 KV-Cache
past_key_values = None
for i in range(10):
outputs = model(
input_ids[:, i:i+1], # 每次一个 token
past_key_values=past_key_values,
use_cache=True
)
past_key_values = outputs.past_key_values
九、优缺点总结
优点:
- 大幅加速推理:减少 2-10 倍计算量
- 内存效率:避免重复计算
- 适合流式生成:逐个 token 生成时效果显著
缺点:
- 内存占用增长:缓存随序列长度线性增长
- 实现复杂:需要精细的内存管理
- 批处理困难:不同序列长度导致缓存不对齐
十、相关技术
- Flash Attention:优化注意力计算,与 KV-Cache 互补
- vLLM 的 PagedAttention:工业级 KV-Cache 管理
- TensorRT-LLM:NVIDIA 的优化 KV-Cache 实现
- Continuous Batching:动态批处理 + KV-Cache
总结
KV-Cache 是大语言模型推理的 核心技术,通过缓存已计算的 Key 和 Value 来避免重复计算。虽然会增加内存占用,但带来的推理速度提升是数量级的,是当前所有主流 LLM 推理框架的标配优化。
在实际应用中,您会经常在以下场景遇到:
- OpenAI API 流式响应
- 本地部署的 LLM 服务(如 llama.cpp)
- 推理框架(vLLM, TensorRT-LLM, TGI)
- 手机端部署的轻量化模型
理解 KV-Cache 对优化模型部署、减少推理延迟、降低计算成本都至关重要。

浙公网安备 33010602011771号