FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness


FlashAttention:从 GPU IO 视角重构自注意力计算

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
NeurIPS 2022


Pay Attention
tips:文章基于我自己参考原文做的PPT用AI生成的,建议直接阅读文章末尾的PPT链接,以及原文链接,如果有什么问题请留言

1.1 标准 Attention 与 GPU 上的性能瓶颈

自注意力的标准形式为:

[
\text{Attn}(Q,K,V) = \text{softmax}(QK^T)V
]

其时间和空间复杂度均为 (O(N^2))。在实际 GPU 实现中,这一复杂度并不仅仅体现在 FLOPs 上,更体现在 中间张量的内存访问模式

  • (QK^T) 生成的 (N \times N) 矩阵需要写回 HBM
  • softmax、mask、dropout 等逐元素操作会反复从 HBM 读取与写回
  • 反向传播阶段还需要再次读取完整注意力矩阵

因此,标准 attention 在 GPU 上往往是内存受限(memory-bound)算子


1.2 近似 Attention 方法的局限

已有大量工作尝试通过近似方式降低 (O(N^2)) 复杂度,例如:

  • 局部注意力(Local Attention)
  • 低秩分解(Linformer)
  • 稀疏注意力(Sparse / Longformer / BigBird)

这些方法的共同特点是:
减少数学上的计算量(FLOPs)或矩阵规模

然而在实际 GPU 运行中,这类方法常常面临两个问题:

  1. 模型质量下降(丢失全局依赖)
  2. 端到端速度提升有限
    原因在于:即使 FLOPs 降低,内存访问仍然占主导

FlashAttention 的作者明确指出:

Attention 的性能瓶颈并不在于“算得太多”,而在于“搬得太多”。


1.3 IO-Aware 算法的缺失

在 GPU 上:

  • HBM 带宽有限(~1.5–2 TB/s)
  • SRAM / registers 带宽极高(~10× HBM)
  • 但深度学习框架通常无法对内存访问进行细粒度控制

以往工作多集中在算子层面的数学结构,而很少从 HBM–SRAM 数据流 的角度系统性地重构 attention 算法。

FlashAttention 的核心贡献,正是在这一层面补上了空白。


2. Method:FlashAttention 的 IO-Aware 设计

2.1 GPU 执行模型回顾(系统视角)

在 GPU 上,一个 kernel 的执行流程可简化为:

  1. 从 HBM 加载数据到 registers / shared memory(SRAM)
  2. 在片上完成计算
  3. 将结果写回 HBM

性能的关键不在于 kernel 内部计算是否足够快,而在于:

  • HBM 被访问了多少次
  • 每次访问搬运了多少数据

FlashAttention 的目标非常明确:

在保证精确 attention 结果的前提下,使每个元素尽可能少地进出 HBM。


2.2 核心思想:分块 + 重计算

FlashAttention 使用了两种经典但此前未被系统性应用于 attention 的技术:

  1. Tiling(分块)
  2. Recomputation(重计算而非存储)

这两点直接决定了其 CUDA kernel 的组织方式。


2.3 分块 Attention 的数据流设计

设序列长度为 (N),隐藏维度为 (d)。

FlashAttention 不再显式构造完整的 (QK^T),而是:

  • 沿序列维度将 (Q, K, V) 划分为多个 block
  • 外层循环:逐块加载 (K, V) 到 SRAM
  • 内层循环:加载对应的 (Q) block,并立刻计算 partial attention

关键点在于:

  • 每个 attention block 的中间结果只存在于 SRAM / registers
  • 不再生成或存储 (N \times N) 的注意力矩阵

从 CUDA 角度看,这种设计:

  • 极大降低了 global memory traffic
  • 将 attention 计算转化为一个高算术强度(arithmetic intensity) kernel

2.4 分块 Softmax 的数值正确性

Softmax 的困难在于其分母是对整行的全局归约:

[
\text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
]

FlashAttention 使用了一种可分解的数值稳定 softmax

  • 每个 block 计算局部最大值与指数和
  • 通过维护全局最大值与缩放因子,将各 block 的结果正确拼接
  • 最终结果与一次性 softmax 严格一致

这一设计使得:

  • softmax 不再依赖完整行数据
  • 可以自然嵌入 block-wise CUDA kernel 中

2.5 重计算替代中间存储

在反向传播阶段:

  • 标准 attention 需要从 HBM 读取前向保存的注意力矩阵
  • FlashAttention 不保存注意力矩阵

而是:

  • 在前向阶段仅保存 softmax 的归一化统计量(如 max、sum)
  • 在反向传播中,在 SRAM 中重新计算 attention block

虽然这会增加 FLOPs,但从 GPU 角度看:

额外计算的代价远小于一次 HBM 访问。

因此整体运行时间反而更短。


2.6 单 Kernel 融合执行

得益于分块结构,FlashAttention 可以将以下操作融合进单个 CUDA kernel

  • (QK^T) block GEMM
  • softmax(含 mask、dropout)
  • 与 (V) 相乘并累积输出

这意味着:

  • 输入只从 HBM 读取一次
  • 输出只写回一次
  • 中间张量完全不落地

从系统角度看,这几乎是 attention 在单卡上的最优执行形态。


3. IO 复杂度分析(系统层结论)

作者给出了严格的 IO 复杂度分析:

  • 标准 attention
    HBM 访问量为 (O(N^2))

  • FlashAttention
    HBM 访问量降为 (O(N^2 / M)),其中 (M) 为可用 SRAM 大小

并证明:在给定 SRAM 约束下,不存在 IO 复杂度更低的精确 attention 算法

实验也验证了一个关键系统结论:

Attention 的运行时间主要由 HBM 访问次数决定,而非 FLOPs。


4. Block-Sparse FlashAttention(简述)

在 FlashAttention 框架内,作者进一步引入 block-level 稀疏掩码:

  • 仅对非零 block 执行 attention
  • IO 复杂度与稀疏度成正比降低
  • 可直接复用 FlashAttention 的 kernel 结构

在超长序列(64K)任务中,Block-Sparse FlashAttention 显著优于所有精确与近似 attention 实现。


5. 总结:FlashAttention 的系统意义

FlashAttention 并未改变 attention 的数学定义,而是:

  • 从 GPU IO 视角重新设计了 attention 的执行方式
  • 将 attention 从 memory-bound 算子转化为更接近 compute-bound 的形式
  • 证明 IO-aware 算法设计在现代 GPU 上具有决定性意义

这一工作也为后续算子优化提供了重要启示:

在高带宽层级差异巨大的硬件上,
“如何搬数据”往往比“算什么公式”更重要。



posted @ 2026-01-04 19:32  AzathothLXL  阅读(12)  评论(0)    收藏  举报