Flash Attention & Paged Attention
Flash Attention & Paged Attention
内容
FlashAttention 和 PagedAttention 是两种针对 Transformer 注意力机制的显存优化技术,分别解决不同维度的性能瓶颈。这里用技术对比的方式帮你快速理解:
1. FlashAttention(计算优化)
核心目标:减少注意力计算时的显存访问次数
提出方:Stanford DAWN Lab (2022)
适用场景:训练/推理中的注意力计算优化
关键技术:
# 传统注意力计算 (伪代码)
QK = Q @ K.T # O(N²) 显存占用
softmax_QK = softmax(QK)
Attention = softmax_QK @ V # 两次显存密集型矩阵乘法
# FlashAttention 改进
将计算拆分为分块(tiling)处理:
1. 分块加载Q/K/V到SRAM(片上高速缓存)
2. 局部计算QK^T + softmax + 与V相乘
3. 通过重计算(recompute)避免存储中间矩阵
优化效果:
- 训练速度提升 15-30% (A100实测)
- 显存占用降低 5-20 倍
典型应用:LLaMA、GPT-3 等大模型的训练加速
2. PagedAttention(显存管理优化)
核心目标:高效管理 KV Cache 显存碎片
提出方:UC Berkeley (vLLM框架)
适用场景:大模型推理时的显存管理
关键技术:
# 传统KV Cache问题:
每个请求的KV缓存分配连续显存 -> 导致内存碎片化
# PagedAttention 改进:
1. 将KV缓存划分为固定大小的"页"(类似OS内存分页)
2. 非连续物理显存 + 逻辑页表映射
3. 支持动态分配/释放显存页
优化效果:
- 吞吐量提升 2-4 倍 (实测)
- 支持比原有显存大 5 倍的上下文长度
典型应用:vLLM、ChatGPT 等推理服务
对比表格
| 维度 | FlashAttention | PagedAttention |
|---|---|---|
| 优化方向 | 计算模式优化 | 显存分配优化 |
| 主要阶段 | 训练/推理中的注意力计算 | 推理时的KV缓存管理 |
| 关键技术 | 分块计算 + 重计算 | 显存分页 + 虚拟映射 |
| 显存节省原理 | 减少中间矩阵存储 | 消除内存碎片 |
| 典型受益场景 | 长文本训练 | 高并发推理服务 |
组合使用示例
当部署一个大语言模型时:
- 训练阶段:使用 FlashAttention 加速注意力计算
- 推理阶段:使用 PagedAttention 管理 KV Cache
# 实际技术栈示例
训练框架:PyTorch + FlashAttention-2
推理框架:vLLM (内置PagedAttention)
两者的优化维度相互正交,实际应用中常配合使用以获得最大性能收益。如果想更深入理解,推荐通过具体代码案例实践(如用 FlashAttention 实现一个注意力层,或用 vLLM 部署模型观察显存变化)。

浙公网安备 33010602011771号