Loading

【大模型】什么是KV cache

KV-Cache(键值缓存)详解

KV-Cache 是 Transformer 模型在推理(生成)阶段 用于 加速自回归生成 的关键优化技术。让我为您详细解释:

一、核心概念

什么是 KV-Cache?

KV-Cache 是缓存 Key(键)Value(值) 的机制,用于减少 Transformer 在生成文本时的重复计算。

为什么需要 KV-Cache?

在自回归生成中(如 GPT 生成文本),每次生成一个新 token 时:

  • 无缓存:需要重新计算之前所有 token 的 Key 和 Value → 大量重复计算
  • 有缓存:复用已计算的 Key 和 Value → 只计算当前 token 的部分

二、工作原理

Transformer 注意力机制回顾

注意力(Q, K, V) = softmax(Q·Kᵀ/√d) · V

自回归生成示例

生成序列:"The cat sat on the"

  1. 第一步:输入 "The"

    • 计算 Q₁, K₁, V₁(第 1 个 token)
    • 输出:预测 "cat"
  2. 第二步:输入 "The cat"

    • 无缓存:重新计算 K₁,V₁ 和 K₂,V₂
    • 有缓存:从缓存读取 K₁,V₁,只计算 K₂,V₂
    • 输出:预测 "sat"
  3. 第三步:输入 "The cat sat"

    • 有缓存:从缓存读取 K₁,V₁, K₂,V₂,只计算 K₃,V₃

三、KV-Cache 图示

时间步 t=1:
输入: [x₁] → 计算: Q₁, K₁, V₁
缓存: {K₁, V₁}
输出: y₁

时间步 t=2:
输入: [x₂] → 计算: Q₂, K₂, V₂
缓存: {K₁,K₂, V₁,V₂}  ← 新增
           ↓
注意力计算时:使用全部缓存 {K₁,K₂} 和 {V₁,V₂}

时间步 t=3:
输入: [x₃] → 计算: Q₃, K₃, V₃
缓存: {K₁,K₂,K₃, V₁,V₂,V₃} ← 继续扩展

四、数学表达

设第 l 层在时间步 t 的注意力计算:

无 KV-Cache

Q_{l,t} = X_t · W_l^Q
K_{l,t} = [X_1,...,X_t] · W_l^K  ← 每次重新计算全部
V_{l,t} = [X_1,...,X_t] · W_l^V

有 KV-Cache

Q_{l,t} = X_t · W_l^Q
K_{l,t} = concat(K_cache_l, X_t · W_l^K)  ← 只计算新的,连接缓存
V_{l,t} = concat(V_cache_l, X_t · W_l^V)

五、内存占用分析

缓存大小计算

每个token的缓存大小 = 2 × 层数 × 隐藏维度 × 数据类型大小

示例:LLaMA-7B 模型

层数 L = 32
隐藏维度 d = 4096
头数 h = 32
每个头维度 d_h = 128
数据类型: float16 (2字节)

每个token每层缓存大小 = 2 × (h × d_h) × 2字节
                         = 2 × 4096 × 2
                         = 16,384字节 ≈ 16KB

每个token总缓存 = 16KB × 32层 = 512KB

生成1000个token的缓存 = 512KB × 1000 = 512MB

内存优化对比

方法 计算复杂度 内存占用 适用场景
无缓存 O(n²) 每次 短序列
有缓存 O(n) 每次 高(线性增长) 长序列生成
窗口缓存 O(w) 固定 长序列(如 4K/8K 上下文)

六、代码示例

简化实现

import torch
import torch.nn as nn

class AttentionWithKVCache(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)
        
    def forward(self, x, past_key_value=None):
        """
        x: [batch_size, seq_len, dim]
        past_key_value: 元组 (past_key, past_value)
                       past_key: [batch, num_heads, past_len, head_dim]
        """
        batch_size, seq_len, _ = x.shape
        
        # 计算 Q, K, V
        Q = self.q_proj(x)  # [batch, seq_len, dim]
        K = self.k_proj(x)  # [batch, seq_len, dim]
        V = self.v_proj(x)  # [batch, seq_len, dim]
        
        # 重塑为多头
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 如果有过去的 KV,与当前连接
        if past_key_value is not None:
            past_key, past_value = past_key_value
            K = torch.cat([past_key, K], dim=2)  # 在序列维度连接
            V = torch.cat([past_value, V], dim=2)
        
        # 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # 重塑回原始形状
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.dim)
        
        # 输出投影
        output = self.o_proj(attn_output)
        
        # 返回当前 KV 用于下一次
        current_key_value = (K, V)
        
        return output, current_key_value

实际使用示例

def generate_with_kv_cache(model, prompt, max_length=100):
    """使用 KV-Cache 生成文本"""
    input_ids = tokenizer.encode(prompt)
    generated = input_ids.copy()
    
    # 初始化 KV-Cache(每层一个)
    past_key_values = None
    
    for _ in range(max_length):
        # 只传入最后一个 token(如果是第一次,传入所有)
        if past_key_values is None:
            inputs = torch.tensor([input_ids])
        else:
            inputs = torch.tensor([[generated[-1]]])
        
        # 前向传播,传入 KV-Cache
        outputs = model(
            inputs, 
            past_key_values=past_key_values,
            use_cache=True
        )
        
        # 更新 KV-Cache
        past_key_values = outputs.past_key_values
        
        # 获取下一个 token
        next_token_logits = outputs.logits[0, -1, :]
        next_token = torch.argmax(next_token_logits).item()
        
        generated.append(next_token)
        
        if next_token == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(generated)

七、实际应用中的优化

1. 分页注意力(Paged Attention)

  • 类似操作系统内存分页
  • 解决缓存碎片问题
  • vLLM 等推理框架使用

2. 多查询注意力(MQA)

  • 多个查询头共享相同的 KV
  • 减少 KV-Cache 大小

3. 分组查询注意力(GQA)

  • MQA 和 MHA 的折中
  • 几个查询头共享一组 KV

4. 滑动窗口注意力

  • 只缓存最近 N 个 token
  • 适用于长文本(如 128K 上下文)

八、在 Hugging Face Transformers 中的使用

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# 启用 KV-Cache
outputs = model.generate(
    input_ids,
    max_length=100,
    use_cache=True,  # 启用 KV-Cache
    past_key_values=None,  # 首次为 None
)

# 手动管理 KV-Cache
past_key_values = None
for i in range(10):
    outputs = model(
        input_ids[:, i:i+1],  # 每次一个 token
        past_key_values=past_key_values,
        use_cache=True
    )
    past_key_values = outputs.past_key_values

九、优缺点总结

优点:

  1. 大幅加速推理:减少 2-10 倍计算量
  2. 内存效率:避免重复计算
  3. 适合流式生成:逐个 token 生成时效果显著

缺点:

  1. 内存占用增长:缓存随序列长度线性增长
  2. 实现复杂:需要精细的内存管理
  3. 批处理困难:不同序列长度导致缓存不对齐

十、相关技术

  1. Flash Attention:优化注意力计算,与 KV-Cache 互补
  2. vLLM 的 PagedAttention:工业级 KV-Cache 管理
  3. TensorRT-LLM:NVIDIA 的优化 KV-Cache 实现
  4. Continuous Batching:动态批处理 + KV-Cache

总结

KV-Cache 是大语言模型推理的 核心技术,通过缓存已计算的 Key 和 Value 来避免重复计算。虽然会增加内存占用,但带来的推理速度提升是数量级的,是当前所有主流 LLM 推理框架的标配优化。

在实际应用中,您会经常在以下场景遇到:

  • OpenAI API 流式响应
  • 本地部署的 LLM 服务(如 llama.cpp)
  • 推理框架(vLLM, TensorRT-LLM, TGI)
  • 手机端部署的轻量化模型

理解 KV-Cache 对优化模型部署、减少推理延迟、降低计算成本都至关重要。

posted @ 2026-01-06 19:55  SaTsuki26681534  阅读(197)  评论(0)    收藏  举报