FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
FlashAttention:从 GPU IO 视角重构自注意力计算
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
NeurIPS 2022
Pay Attention
tips:文章基于我自己参考原文做的PPT用AI生成的,建议直接阅读文章末尾的PPT链接,以及原文链接,如果有什么问题请留言
1. Related Work:从“减少计算”到“优化 IO”
1.1 标准 Attention 与 GPU 上的性能瓶颈
自注意力的标准形式为:
[
\text{Attn}(Q,K,V) = \text{softmax}(QK^T)V
]
其时间和空间复杂度均为 (O(N^2))。在实际 GPU 实现中,这一复杂度并不仅仅体现在 FLOPs 上,更体现在 中间张量的内存访问模式:
- (QK^T) 生成的 (N \times N) 矩阵需要写回 HBM
- softmax、mask、dropout 等逐元素操作会反复从 HBM 读取与写回
- 反向传播阶段还需要再次读取完整注意力矩阵
因此,标准 attention 在 GPU 上往往是内存受限(memory-bound)算子。
1.2 近似 Attention 方法的局限
已有大量工作尝试通过近似方式降低 (O(N^2)) 复杂度,例如:
- 局部注意力(Local Attention)
- 低秩分解(Linformer)
- 稀疏注意力(Sparse / Longformer / BigBird)
这些方法的共同特点是:
减少数学上的计算量(FLOPs)或矩阵规模。
然而在实际 GPU 运行中,这类方法常常面临两个问题:
- 模型质量下降(丢失全局依赖)
- 端到端速度提升有限
原因在于:即使 FLOPs 降低,内存访问仍然占主导
FlashAttention 的作者明确指出:
Attention 的性能瓶颈并不在于“算得太多”,而在于“搬得太多”。
1.3 IO-Aware 算法的缺失
在 GPU 上:
- HBM 带宽有限(~1.5–2 TB/s)
- SRAM / registers 带宽极高(~10× HBM)
- 但深度学习框架通常无法对内存访问进行细粒度控制
以往工作多集中在算子层面的数学结构,而很少从 HBM–SRAM 数据流 的角度系统性地重构 attention 算法。
FlashAttention 的核心贡献,正是在这一层面补上了空白。
2. Method:FlashAttention 的 IO-Aware 设计
2.1 GPU 执行模型回顾(系统视角)
在 GPU 上,一个 kernel 的执行流程可简化为:
- 从 HBM 加载数据到 registers / shared memory(SRAM)
- 在片上完成计算
- 将结果写回 HBM
性能的关键不在于 kernel 内部计算是否足够快,而在于:
- HBM 被访问了多少次
- 每次访问搬运了多少数据
FlashAttention 的目标非常明确:
在保证精确 attention 结果的前提下,使每个元素尽可能少地进出 HBM。
2.2 核心思想:分块 + 重计算
FlashAttention 使用了两种经典但此前未被系统性应用于 attention 的技术:
- Tiling(分块)
- Recomputation(重计算而非存储)
这两点直接决定了其 CUDA kernel 的组织方式。
2.3 分块 Attention 的数据流设计
设序列长度为 (N),隐藏维度为 (d)。
FlashAttention 不再显式构造完整的 (QK^T),而是:
- 沿序列维度将 (Q, K, V) 划分为多个 block
- 外层循环:逐块加载 (K, V) 到 SRAM
- 内层循环:加载对应的 (Q) block,并立刻计算 partial attention
关键点在于:
- 每个 attention block 的中间结果只存在于 SRAM / registers
- 不再生成或存储 (N \times N) 的注意力矩阵
从 CUDA 角度看,这种设计:
- 极大降低了 global memory traffic
- 将 attention 计算转化为一个高算术强度(arithmetic intensity) kernel
2.4 分块 Softmax 的数值正确性
Softmax 的困难在于其分母是对整行的全局归约:
[
\text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
]
FlashAttention 使用了一种可分解的数值稳定 softmax:
- 每个 block 计算局部最大值与指数和
- 通过维护全局最大值与缩放因子,将各 block 的结果正确拼接
- 最终结果与一次性 softmax 严格一致
这一设计使得:
- softmax 不再依赖完整行数据
- 可以自然嵌入 block-wise CUDA kernel 中
2.5 重计算替代中间存储
在反向传播阶段:
- 标准 attention 需要从 HBM 读取前向保存的注意力矩阵
- FlashAttention 不保存注意力矩阵
而是:
- 在前向阶段仅保存 softmax 的归一化统计量(如 max、sum)
- 在反向传播中,在 SRAM 中重新计算 attention block
虽然这会增加 FLOPs,但从 GPU 角度看:
额外计算的代价远小于一次 HBM 访问。
因此整体运行时间反而更短。
2.6 单 Kernel 融合执行
得益于分块结构,FlashAttention 可以将以下操作融合进单个 CUDA kernel:
- (QK^T) block GEMM
- softmax(含 mask、dropout)
- 与 (V) 相乘并累积输出
这意味着:
- 输入只从 HBM 读取一次
- 输出只写回一次
- 中间张量完全不落地
从系统角度看,这几乎是 attention 在单卡上的最优执行形态。
3. IO 复杂度分析(系统层结论)
作者给出了严格的 IO 复杂度分析:
-
标准 attention:
HBM 访问量为 (O(N^2)) -
FlashAttention:
HBM 访问量降为 (O(N^2 / M)),其中 (M) 为可用 SRAM 大小
并证明:在给定 SRAM 约束下,不存在 IO 复杂度更低的精确 attention 算法。
实验也验证了一个关键系统结论:
Attention 的运行时间主要由 HBM 访问次数决定,而非 FLOPs。
4. Block-Sparse FlashAttention(简述)
在 FlashAttention 框架内,作者进一步引入 block-level 稀疏掩码:
- 仅对非零 block 执行 attention
- IO 复杂度与稀疏度成正比降低
- 可直接复用 FlashAttention 的 kernel 结构
在超长序列(64K)任务中,Block-Sparse FlashAttention 显著优于所有精确与近似 attention 实现。
5. 总结:FlashAttention 的系统意义
FlashAttention 并未改变 attention 的数学定义,而是:
- 从 GPU IO 视角重新设计了 attention 的执行方式
- 将 attention 从 memory-bound 算子转化为更接近 compute-bound 的形式
- 证明 IO-aware 算法设计在现代 GPU 上具有决定性意义
这一工作也为后续算子优化提供了重要启示:
在高带宽层级差异巨大的硬件上,
“如何搬数据”往往比“算什么公式”更重要。
-
论文原文(NeurIPS 2022)
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
https://proceedings.neurips.cc/paper_files/paper/2022/file/67d57c32e20fd0a7a302cb81d36e40d5-Paper-Conference.pdf -
自己制作的PPT链接
https://1drv.ms/p/c/7a3fa4b8d46fdfb3/IQCXklWR0GjzSqjZ5wxyJd30ATqY25ut2Rm180vozdE0V3A?e=eVEaal

浙公网安备 33010602011771号