Flash Attention & Paged Attention

Flash Attention & Paged Attention

内容

FlashAttention 和 PagedAttention 是两种针对 Transformer 注意力机制的显存优化技术,分别解决不同维度的性能瓶颈。这里用技术对比的方式帮你快速理解:


1. FlashAttention(计算优化)

核心目标减少注意力计算时的显存访问次数
提出方:Stanford DAWN Lab (2022)
适用场景:训练/推理中的注意力计算优化
关键技术

# 传统注意力计算 (伪代码)
QK = Q @ K.T            # O(N²) 显存占用
softmax_QK = softmax(QK) 
Attention = softmax_QK @ V  # 两次显存密集型矩阵乘法

# FlashAttention 改进
将计算拆分为分块(tiling)处理:
1. 分块加载Q/K/V到SRAM(片上高速缓存)
2. 局部计算QK^T + softmax + 与V相乘
3. 通过重计算(recompute)避免存储中间矩阵

优化效果

  • 训练速度提升 15-30% (A100实测)
  • 显存占用降低 5-20 倍
    典型应用:LLaMA、GPT-3 等大模型的训练加速

2. PagedAttention(显存管理优化)

核心目标高效管理 KV Cache 显存碎片
提出方:UC Berkeley (vLLM框架)
适用场景:大模型推理时的显存管理
关键技术

# 传统KV Cache问题:
每个请求的KV缓存分配连续显存 -> 导致内存碎片化

# PagedAttention 改进:
1. 将KV缓存划分为固定大小的"页"(类似OS内存分页)
2. 非连续物理显存 + 逻辑页表映射
3. 支持动态分配/释放显存页

优化效果

  • 吞吐量提升 2-4 倍 (实测)
  • 支持比原有显存大 5 倍的上下文长度
    典型应用:vLLM、ChatGPT 等推理服务

对比表格

维度 FlashAttention PagedAttention
优化方向 计算模式优化 显存分配优化
主要阶段 训练/推理中的注意力计算 推理时的KV缓存管理
关键技术 分块计算 + 重计算 显存分页 + 虚拟映射
显存节省原理 减少中间矩阵存储 消除内存碎片
典型受益场景 长文本训练 高并发推理服务

组合使用示例

当部署一个大语言模型时:

  1. 训练阶段:使用 FlashAttention 加速注意力计算
  2. 推理阶段:使用 PagedAttention 管理 KV Cache
# 实际技术栈示例
训练框架:PyTorch + FlashAttention-2
推理框架:vLLM (内置PagedAttention)

两者的优化维度相互正交,实际应用中常配合使用以获得最大性能收益。如果想更深入理解,推荐通过具体代码案例实践(如用 FlashAttention 实现一个注意力层,或用 vLLM 部署模型观察显存变化)。

posted @ 2025-03-25 15:01  Gold_stein  阅读(782)  评论(0)    收藏  举报