FlashAttention 全系列深度解析--IO 感知注意力计算如何重塑 LLM 训练与推理

技术日报 2026-03-25


一、技术背景与动机

1.1 标准注意力的根本瓶颈

Transformer 架构的注意力机制(Self-Attention)自 2017 年提出以来,已成为大语言模型(LLM)、视觉模型、多模态模型的基础组件。然而,随着序列长度 $N$ 的增大,标准注意力的时间与空间复杂度均为 $O(N^2)$,在处理长上下文时面临两大核心瓶颈:

1. 内存墙(Memory Wall)

以序列长度 $N = 2048$、头维度 $d = 64$ 的注意力为例,中间矩阵 $S = QK^\top$(形状 $2048 \times 2048$,FP16)约占 16 MB 的 GPU HBM(高带宽内存)。当 $N = 32768$,这一数字膨胀到 4 GB,远超单 GPU 可承受范围。

2. IO 带宽瓶颈(IO Bottleneck)

GPU 的核心矛盾在于:算力增长远快于内存带宽增长。以 A100 GPU 为例,其浮点算力为 312 TFLOPS(BF16),而 HBM 带宽仅 2 TB/s。标准注意力需要在 HBM 与 SRAM(片上高速缓存,容量约 20-40 MB)之间反复传输数据,绝大多数时间 GPU 的 Tensor Core 在等待数据而非计算。

大多数所谓的"高效注意力"(Efficient Attention)工作,如 Linformer、Performer、Longformer,都以引入近似误差为代价来降低计算量,但忽略了 GPU IO 带宽这一真正瓶颈。

1.2 FlashAttention 的诞生

2022 年,斯坦福大学 Tri Dao、Dan Fu 等人在 NeurIPS 2022 发表了 FlashAttention,首次将 IO 感知(IO-Awareness) 引入注意力计算。其关键洞察是:

注意力计算的瓶颈不在于 FLOPs(浮点运算量),而在于 HBM 的读写次数(IO 次数)。

FlashAttention 通过分块(Tiling)+ 在线 Softmax(Online Softmax)技术,将注意力计算融合进单个 CUDA Kernel,在 不引入任何近似误差 的前提下,将内存复杂度从 $O(N^2)$ 降至 $O(N)$,并大幅减少 HBM 访问次数。


二、核心概念与原理

2.1 GPU 内存层次结构

理解 FlashAttention 必须先理解 GPU 的内存层次:

内存类型 容量(A100) 带宽 延迟
寄存器(Register) ~几十 KB/SM 极高 最低
SRAM(共享内存) 192 KB/SM,约 20-40 MB 全局 ~19 TB/s 极低
HBM(显存) 40/80 GB ~2 TB/s 较高

标准注意力每次计算都要将 $Q, K, V, S, P, O$ 在 HBM 和 SRAM 之间来回搬运,IO 次数为 $\Theta(Nd + N^2)$。FlashAttention 通过分块让所有中间结果留在 SRAM 中计算,将 IO 次数降至 $O(N^2 d^2 M^{-1})$(其中 $M$ 是 SRAM 容量)。

2.2 从标准 Softmax 到 Online Softmax

标准 Softmax(3-Pass)

计算 $\text{softmax}(\mathbf{x})$,传统方式需要三次遍历:

Pass 1: 求最大值 m = max(x_1, ..., x_N) (防数值溢出)
Pass 2: 求指数和 ℓ = Σ exp(x_i - m)
Pass 3: 归一化 softmax(x_i) = exp(x_i - m) / ℓ

这要求必须看完全部输入才能开始输出,无法分块处理。

Online Softmax(单次遍历)

核心思想:通过维护两个递推统计量(局部最大值 $m_j$ 与修正后的指数和 $d_j$),逐块增量更新,无需看完全部数据:

$$m_{j+1} = \max(m_j,\ x_{j+1})$$

$$d_{j+1} = d_j \cdot e^{m_j - m_{j+1}} + e^{x_{j+1} - m_{j+1}}$$

当新元素 $x_{j+1}$ 到来时,若新的最大值比旧的大,对旧的指数和乘以修正因子 $e^{m_j - m_{j+1}}$ 进行缩放,再加上新元素的贡献。这样,单次遍历就能得到正确的 Softmax 分母。

Python 代码示例:

import math

def online_softmax(x):
"""单次遍历计算 Softmax,数值稳定"""
m = float('-inf') # 当前最大值
d = 0.0 # 修正后的指数和

for xi in x:
m_new = max(m, xi)
d = d * math.exp(m - m_new) + math.exp(xi - m_new)
m = m_new

return [math.exp(xi - m) / d for xi in x]

# 验证
x = [1.0, 2.0, 3.0, 4.0]
result = online_softmax(x)
print(result)
# [0.0321, 0.0871, 0.2369, 0.6439] 与 scipy.special.softmax 一致

三、关键算法:FlashAttention 前向传播

3.1 算法框架

FlashAttention 将 Online Softmax 推广到注意力计算,实现对完整注意力矩阵的分块计算。

输入:矩阵 $Q, K, V \in \mathbb{R}^{N \times d}$(序列长度 $N$,头维度 $d$)

块大小设置

  • Query 块大小 $B_r = \lceil M / (4d) \rceil$
  • Key/Value 块大小 $B_c = \min(\lceil M / (4d) \rceil, d)$
  • 其中 $M$ 为 SRAM 容量

核心循环(外循环遍历 K/V 块,内循环遍历 Q 块)

for j = 1 to ⌈N/Bc⌉: # 外循环:遍历 K/V 的分块
从 HBM 加载 K_j, V_j 到 SRAM

for i = 1 to ⌈N/Br⌉: # 内循环:遍历 Q 的分块
从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM

# 计算当前块的注意力分数
S_ij = Q_i · K_j^T * scale # 形状 [Br, Bc]

# 更新局部最大值(Online Softmax 第一步)
m_ij = rowmax(S_ij) # [Br,]
P_ij = exp(S_ij - m_ij) # 减去局部最大值,防溢出
l_ij = rowsum(P_ij) # [Br,]

# 融合旧统计量,更新全局最大值和指数和(Online Softmax 第二步)
m_i_new = max(m_i, m_ij)
l_i_new = exp(m_i - m_i_new) * l_i + exp(m_ij - m_i_new) * l_ij

# 更新输出累加器
O_i = (l_i * exp(m_i - m_i_new) * O_i + exp(m_ij - m_i_new) * P_ij * V_j) / l_i_new

# 更新统计量
m_i, l_i = m_i_new, l_i_new

# 将更新后的 O_i, l_i, m_i 写回 HBM

最终输出:O(与标准注意力数学等价)

关键洞察:整个计算过程中,$N \times N$ 的注意力矩阵 $S$ 从未被完整地实例化到 HBM,只存在于 SRAM 的短暂中间计算中。

3.2 反向传播的重计算技巧

反向传播也需要注意力矩阵 $P$,但 FlashAttention 不存储它,而是在反向传播时重新计算(Recomputation):通过保存的 $O, l, m$ 统计量,在 SRAM 中重新算出 $P$,然后立即计算梯度。这将显存占用从 $O(N^2)$ 进一步降低,代价是额外的浮点运算(约增加 30% FLOPs),但由于算力远快于 IO,整体仍然更快。

复杂度对比

指标 标准注意力 FlashAttention
HBM 读写次数(IO) $\Theta(Nd + N^2)$ $\Theta(N^2 d^2 M^{-1})$
额外显存(前向) $O(N^2)$ $O(N)$
FLOPs $O(N^2 d)$ $O(N^2 d)$(约多 30%)

四、版本演进:FlashAttention-1 → 3

4.1 FlashAttention-2(2023):工程精细化

论文FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning(Tri Dao,普林斯顿大学)

FA-2 的主要改进聚焦于减少非矩阵乘法运算(non-matmul FLOPs)改进 GPU 并行化策略

① 减少非矩阵乘法开销

现代 GPU(如 A100)的 Tensor Core 专门加速矩阵乘法,吞吐量约 312 TFLOPS(BF16);而非矩阵乘法(如指数运算 exp、标量乘除)的吞吐量仅约 20 TFLOPS,相差 15 倍。FA-2 重排了在线 Softmax 的缩放时机,将不必要的重缩放操作延迟到循环结束后统一执行,大幅减少非矩阵乘法的次数。

② 更好的并行化策略:从 Split-KV 到 Split-Q

FA-1 外循环遍历 Q,内循环遍历 K/V("sliced-K" 方案),导致同一线程块内的 4 个 Warp 需要共享内存同步,引入额外通信开销。FA-2 互换循环顺序,外循环遍历 K/V,内循环遍历 Q,让每个 Warp 独立处理 Q 的一部分,输出各自无需通信,彻底消除 Warp 间共享内存同步

③ 序列维度并行

在 Batch 和 Head 维度之外,增加对序列长度维度的并行,使得长序列(小 Batch)场景下 GPU SM 利用率大幅提升。

性能数据(A100 80GB)

  • FP16/BF16:达到 230 TFLOPS/s(FA-1 约 124 TFLOPS/s,提升约 2 倍
  • 对比 PyTorch 标准注意力:最高 9 倍加速
  • GPT-3 2.7B(8K 上下文)训练:从 80 TFLOPS/s 提升至 225 TFLOPS/s

4.2 FlashAttention-3(2024):Hopper 硬件专属优化

论文FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision(Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao;NeurIPS 2024 Spotlight)

FA-2 在 H100 GPU 上的利用率仅 35%,远低于硬件理论峰值。FA-3 与 NVIDIA 官方合作,深度利用 Hopper 架构的三大新特性:

特性一:异步流水线(Asynchronous Pipelining)

Hopper H100 的 SM(流式多处理器)引入了 Warp 专业化(Warp Specialization) 设计:不同 Warp Group 可以同时执行不同类型的操作,计算 Warp 做矩阵乘,数据搬运 Warp 用 TMA(Tensor Memory Accelerator,张量内存加速器) 异步传输数据。

FA-3 利用"乒乓调度"(Ping-Pong Scheduling):

时间线:
Warp Group 0: [GEMM QK^T] → [等待] → [GEMM PV] → [等待] → ...
Warp Group 1: [等待] → [Softmax] → [等待] → [Softmax] → ...
TMA Unit: [加载 K0,V0] → [加载 K1,V1] → ...

→ 矩阵乘(GEMM)与 Softmax、数据加载 三者完全重叠,消除流水线气泡

这使 FP16 前向性能从 FA-2 的约 570 TFLOPS 提升至 640-660 TFLOPS

特性二:WGMMA 指令

FA-3 使用 Hopper 专属的 WGMMA(Warpgroup Matrix Multiply-Accumulate) 指令代替 Ampere 的 mma.sync,单条指令的矩阵乘吞吐量更高,并通过 NVIDIA CUTLASS 库 的高级抽象进行管理,降低编程复杂度。

特性三:FP8 低精度支持与非相干处理

H100 的 Tensor Core 支持 FP8(8 位浮点)精度,理论峰值高达 1.978 PFLOPS,是 FP16 的两倍。然而 FP8 精度低,面临两大挑战:

  • 动态范围窄:FP8 仅 8 bit,极易溢出或下溢
  • 异常值(Outliers)问题:注意力矩阵中少量极大值会导致严重量化误差

FA-3 引入非相干处理(Incoherent Processing)技术:在量化前,对 Query 和 Key 矩阵乘以随机正交矩阵(实践中用 Hadamard 变换 近似实现),将激活值中的异常值"打散",均匀分布到整个向量中,从而显著降低 FP8 量化误差。

Hadamard 变换计算复杂度 $O(d \log d)$,且可与 RoPE 等之前操作融合执行,几乎无额外开销。

性能数据(H100 SXM5 80GB)

精度 TFLOPS H100 理论峰值利用率
FA-2(FP16) ~370 TFLOPS ~35%
FA-3(FP16) 740 TFLOPS 75%
FA-3(FP8) ~1.2 PFLOPS ~60%
  • 相比 FA-2,FA-3 实现 1.5–2.0 倍加速
  • FP8 版本数值误差比基线 FP8 注意力低 2.6 倍
  • 对于中长序列(1K+),FA-3 甚至超越了针对 H100 深度优化的 cuDNN 实现

五、代码示例

5.1 在 PyTorch 中使用 FlashAttention

方式一:通过 torch.nn.functional.scaled_dot_product_attention(PyTorch 2.2+,自动选择后端)

import torch
import torch.nn.functional as F

# 模拟多头注意力输入
batch_size, seq_len, num_heads, head_dim = 2, 2048, 16, 64
device = 'cuda'
dtype = torch.bfloat16

q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)

# PyTorch 会自动选择 FlashAttention 后端(需要 CUDA + 满足条件)
# 转为 [batch, heads, seq, dim] 格式
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(output.shape) # [2, 16, 2048, 64]

方式二:直接使用 flash-attn 库

pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func, flash_attn_with_kvcache
import torch

device, dtype = 'cuda', torch.bfloat16
batch, seqlen, heads, dim = 2, 4096, 16, 64

q = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)

# 因果注意力(自回归 LLM 训练标配)
output = flash_attn_func(q, k, v, causal=True)
print(output.shape) # [2, 4096, 16, 64]

方式三:推理时使用 KV 缓存(FlashDecoding)

from flash_attn import flash_attn_with_kvcache

# 增量解码场景:q 只有 1 个 token
q_new = torch.randn(batch, 1, heads, dim, device=device, dtype=dtype)
k_cache = torch.zeros(batch, 4096, heads, dim, device=device, dtype=dtype)
v_cache = torch.zeros(batch, 4096, heads, dim, device=device, dtype=dtype)
cache_seqlens = torch.tensor([100, 200], device=device, dtype=torch.int32)

output = flash_attn_with_kvcache(
q_new, k_cache, v_cache,
cache_seqlens=cache_seqlens,
causal=True
)
# 直接返回对完整 KV 缓存的注意力结果

六、技术优势与创新点

6.1 精确性(零近似误差)

FlashAttention 是精确算法(Exact Attention),与标准注意力数学上完全等价,不引入任何近似,模型训练的收敛行为完全不变。这与 Linformer、Performer、BigBird 等近似方法有本质区别。

6.2 内存效率

将额外显存占用从 $O(N^2)$ 降至 $O(N)$。以 $N = 32K$ 的长序列为例:

  • 标准注意力:约 4 GB 额外显存
  • FlashAttention:约 256 MB(减少约 16 倍

不需要激活检查点(Activation Checkpointing),即可训练 4K-8K 序列长度的大模型,显存大幅节省。

6.3 速度提升(实测数据)

GPU 对比对象 加速比
A100 (FP16) PyTorch 标准注意力 最高 9x
A100 (BF16) FA-1 2x
H100 (FP16) FA-2 1.5-2x
H100 (FP8) 基线 FP8 注意力 数值误差低 2.6x

6.4 可扩展性

支持超长序列训练(目前结合 Ring Attention 等技术,可支持 1M+ token 长度),为长文档理解、长视频处理、代码补全等应用提供基础支撑。


七、适用场景与实际案例

7.1 大模型训练加速

FlashAttention 已成为当代 LLM 训练的标准配置,以下知名模型均使用了 FlashAttention:

  • OpenAI:GPT-4(训练时采用)
  • Meta:Llama、Llama-2、Llama-3 系列
  • TII(阿联酋):Falcon、Falcon-2
  • 斯坦福:Alpaca 系列
  • 主流训练框架:Megatron-LM、DeepSpeed、NeMo、Axolotl

训练效率数据

  • GPT-3 1.3B(8K 上下文,8×A100):模型 FLOPs 利用率从 ~72 TFLOPS 提升至 220 TFLOPS/s
  • GPT-3 1.3B:训练 26B Token(Chinchilla 最优计算量)仅需约 43 小时

7.2 推理框架集成

框架 集成方式
PyTorch 2.2+ scaled_dot_product_attention 自动调用 FA 后端
vLLM 默认使用 FA 进行推理加速
HuggingFace Transformers 通过 attn_implementation="flash_attention_2" 启用
TensorRT-LLM 内置 FA 优化
xFormers memory_efficient_attention
llama.cpp --flash-attn 参数开启

7.3 长上下文应用

FlashAttention 将以往因显存限制无法处理的长序列变为可能:

  • 长文档问答:法律合同分析、医疗文献综述(64K-128K token)
  • 代码补全:分析整个代码库上下文
  • 视频理解:处理高帧率视频帧序列
  • 生物信息学:超长蛋白质序列建模

八、展望:FlashAttention-4 与未来趋势(2025-2026)

2026 年 3 月,Tri Dao 团队已释出 FlashAttention-4(FA-4)的预览信息,主要方向包括:

  • 算法层面重构:使用多项式逼近替代精确指数运算,在可接受误差范围内进一步减少 Softmax 计算的 FLOPs
  • 异步流水线深化:反向传播也实现完整的异步流水线,在 B200 GPU 上 FP16 利用率预计达到理论峰值的 71%
  • 确定性训练模式(Deterministic Backward Pass):在保证性能的同时支持可复现训练
  • NPU 适配挑战:由于 FlashAttention 要求 Softmax 与矩阵乘的算力配比灵活,华为昇腾等 NPU 架构(固定脉动阵列)对其支持存在系统性挑战,相关研究正在推进中

九、总结

FlashAttention 通过一个精妙的视角转换——将注意力计算视为 IO 问题而非算力问题——彻底改变了 LLM 训练和推理的基础设施格局。其核心创新链条清晰:

  1. 分块(Tiling):避免实例化 $N \times N$ 注意力矩阵
  2. 在线 Softmax(Online Softmax):使分块处理数学上可行
  3. 反向传播重计算(Recomputation):以算力换显存
  4. 硬件协同优化(FA-2/FA-3):深度利用 Tensor Core、TMA、WGMMA 等硬件特性

从 2022 年的 FA-1(3x GPT-2 训练加速)到 2024 年的 FA-3(H100 75% 利用率),FlashAttention 家族在短短两年内将 GPU 注意力计算效率提升了一个数量级,成为现代 AI 基础设施中最重要的系统优化之一。


参考资料

  1. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (NeurIPS 2022)
  2. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (ICLR 2024)
  3. FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (NeurIPS 2024 Spotlight)
  4. Tri Dao 官方博客:FlashAttention-3
  5. PyTorch 官方博客:FlashAttention-3
  6. Hazy Research 博客:FlashAttention-2
  7. NVIDIA 技术博客:新一代 FlashAttention
  8. GitHub: Dao-AILab/flash-attention
  9. GitHub: togethercomputer/flash-attention-3
  10. 从 Online Softmax 到 FlashAttention(腾讯云开发者社区)
  11. FlashAttention 深度解析:从数学原理到工程实现

---

<script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
<script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>

# 技术日报 2026-03-25

---

## 一、技术背景与动机

### 1.1 标准注意力的根本瓶颈

Transformer 架构的注意力机制(Self-Attention)自 2017 年提出以来,已成为大语言模型(LLM)、视觉模型、多模态模型的基础组件。然而,随着序列长度 $N$ 的增大,标准注意力的**时间与空间复杂度均为 $O(N^2)$**,在处理长上下文时面临两大核心瓶颈:

**1. 内存墙(Memory Wall)**

以序列长度 $N = 2048$、头维度 $d = 64$ 的注意力为例,中间矩阵 $S = QK^\top$(形状 $2048 \times 2048$,FP16)约占 **16 MB** 的 GPU HBM(高带宽内存)。当 $N = 32768$,这一数字膨胀到 **4 GB**,远超单 GPU 可承受范围。

**2. IO 带宽瓶颈(IO Bottleneck)**

GPU 的核心矛盾在于:**算力增长远快于内存带宽增长**。以 A100 GPU 为例,其浮点算力为 312 TFLOPS(BF16),而 HBM 带宽仅 2 TB/s。标准注意力需要在 HBM 与 SRAM(片上高速缓存,容量约 20-40 MB)之间反复传输数据,绝大多数时间 GPU 的 Tensor Core 在**等待数据**而非计算。

大多数所谓的"高效注意力"(Efficient Attention)工作,如 Linformer、Performer、Longformer,都以引入**近似误差**为代价来降低计算量,但忽略了 GPU IO 带宽这一真正瓶颈。

### 1.2 FlashAttention 的诞生

2022 年,斯坦福大学 Tri Dao、Dan Fu 等人在 NeurIPS 2022 发表了 [FlashAttention](https://arxiv.org/abs/2205.14135),首次将 **IO 感知(IO-Awareness)** 引入注意力计算。其关键洞察是:

> **注意力计算的瓶颈不在于 FLOPs(浮点运算量),而在于 HBM 的读写次数(IO 次数)。**

FlashAttention 通过**分块(Tiling)+ 在线 Softmax(Online Softmax)**技术,将注意力计算融合进单个 CUDA Kernel,在 **不引入任何近似误差** 的前提下,将内存复杂度从 $O(N^2)$ 降至 $O(N)$,并大幅减少 HBM 访问次数。

---

## 二、核心概念与原理

### 2.1 GPU 内存层次结构

理解 FlashAttention 必须先理解 GPU 的内存层次:

| 内存类型 | 容量(A100) | 带宽 | 延迟 |
|---------|------------|------|------|
| 寄存器(Register) | ~几十 KB/SM | 极高 | 最低 |
| SRAM(共享内存) | 192 KB/SM,约 20-40 MB 全局 | ~19 TB/s | 极低 |
| HBM(显存) | 40/80 GB | ~2 TB/s | 较高 |

标准注意力每次计算都要将 $Q, K, V, S, P, O$ 在 HBM 和 SRAM 之间来回搬运,IO 次数为 $\Theta(Nd + N^2)$。FlashAttention 通过分块让所有中间结果**留在 SRAM** 中计算,将 IO 次数降至 $O(N^2 d^2 M^{-1})$(其中 $M$ 是 SRAM 容量)。

### 2.2 从标准 Softmax 到 Online Softmax

**标准 Softmax(3-Pass)**

计算 $\text{softmax}(\mathbf{x})$,传统方式需要三次遍历:

```
Pass 1: 求最大值 m = max(x_1, ..., x_N) (防数值溢出)
Pass 2: 求指数和 ℓ = Σ exp(x_i - m)
Pass 3: 归一化 softmax(x_i) = exp(x_i - m) / ℓ
```

这要求必须看完全部输入才能开始输出,无法分块处理。

**Online Softmax(单次遍历)**

核心思想:通过维护两个**递推统计量**(局部最大值 $m_j$ 与修正后的指数和 $d_j$),逐块增量更新,无需看完全部数据:

$$m_{j+1} = \max(m_j,\ x_{j+1})$$

$$d_{j+1} = d_j \cdot e^{m_j - m_{j+1}} + e^{x_{j+1} - m_{j+1}}$$

当新元素 $x_{j+1}$ 到来时,若新的最大值比旧的大,对旧的指数和乘以修正因子 $e^{m_j - m_{j+1}}$ 进行缩放,再加上新元素的贡献。这样,**单次遍历**就能得到正确的 Softmax 分母。

**Python 代码示例:**

```python
import math

def online_softmax(x):
"""单次遍历计算 Softmax,数值稳定"""
m = float('-inf') # 当前最大值
d = 0.0 # 修正后的指数和

for xi in x:
m_new = max(m, xi)
d = d * math.exp(m - m_new) + math.exp(xi - m_new)
m = m_new

return [math.exp(xi - m) / d for xi in x]

# 验证
x = [1.0, 2.0, 3.0, 4.0]
result = online_softmax(x)
print(result)
# [0.0321, 0.0871, 0.2369, 0.6439] 与 scipy.special.softmax 一致
```

---

## 三、关键算法:FlashAttention 前向传播

### 3.1 算法框架

FlashAttention 将 Online Softmax 推广到注意力计算,实现对完整注意力矩阵的分块计算。

**输入**:矩阵 $Q, K, V \in \mathbb{R}^{N \times d}$(序列长度 $N$,头维度 $d$)

**块大小设置**:
- Query 块大小 $B_r = \lceil M / (4d) \rceil$
- Key/Value 块大小 $B_c = \min(\lceil M / (4d) \rceil, d)$
- 其中 $M$ 为 SRAM 容量

**核心循环(外循环遍历 K/V 块,内循环遍历 Q 块)**:

```
for j = 1 to ⌈N/Bc⌉: # 外循环:遍历 K/V 的分块
从 HBM 加载 K_j, V_j 到 SRAM

for i = 1 to ⌈N/Br⌉: # 内循环:遍历 Q 的分块
从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM

# 计算当前块的注意力分数
S_ij = Q_i · K_j^T * scale # 形状 [Br, Bc]

# 更新局部最大值(Online Softmax 第一步)
m_ij = rowmax(S_ij) # [Br,]
P_ij = exp(S_ij - m_ij) # 减去局部最大值,防溢出
l_ij = rowsum(P_ij) # [Br,]

# 融合旧统计量,更新全局最大值和指数和(Online Softmax 第二步)
m_i_new = max(m_i, m_ij)
l_i_new = exp(m_i - m_i_new) * l_i + exp(m_ij - m_i_new) * l_ij

# 更新输出累加器
O_i = (l_i * exp(m_i - m_i_new) * O_i + exp(m_ij - m_i_new) * P_ij * V_j) / l_i_new

# 更新统计量
m_i, l_i = m_i_new, l_i_new

# 将更新后的 O_i, l_i, m_i 写回 HBM

最终输出:O(与标准注意力数学等价)
```

**关键洞察**:整个计算过程中,$N \times N$ 的注意力矩阵 $S$ **从未被完整地实例化到 HBM**,只存在于 SRAM 的短暂中间计算中。

### 3.2 反向传播的重计算技巧

反向传播也需要注意力矩阵 $P$,但 FlashAttention 不存储它,而是在反向传播时**重新计算**(Recomputation):通过保存的 $O, l, m$ 统计量,在 SRAM 中重新算出 $P$,然后立即计算梯度。这将显存占用从 $O(N^2)$ 进一步降低,代价是额外的浮点运算(约增加 30% FLOPs),但由于算力远快于 IO,整体仍然更快。

**复杂度对比**:

| 指标 | 标准注意力 | FlashAttention |
|------|----------|----------------|
| HBM 读写次数(IO) | $\Theta(Nd + N^2)$ | $\Theta(N^2 d^2 M^{-1})$ |
| 额外显存(前向) | $O(N^2)$ | $O(N)$ |
| FLOPs | $O(N^2 d)$ | $O(N^2 d)$(约多 30%) |

---

## 四、版本演进:FlashAttention-1 → 3

### 4.1 FlashAttention-2(2023):工程精细化

**论文**:[FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691)(Tri Dao,普林斯顿大学)

FA-2 的主要改进聚焦于**减少非矩阵乘法运算(non-matmul FLOPs)** 与**改进 GPU 并行化策略**:

**① 减少非矩阵乘法开销**

现代 GPU(如 A100)的 Tensor Core 专门加速矩阵乘法,吞吐量约 312 TFLOPS(BF16);而非矩阵乘法(如指数运算 `exp`、标量乘除)的吞吐量仅约 **20 TFLOPS**,相差 15 倍。FA-2 重排了在线 Softmax 的缩放时机,将不必要的重缩放操作延迟到循环结束后统一执行,大幅减少非矩阵乘法的次数。

**② 更好的并行化策略:从 Split-KV 到 Split-Q**

FA-1 外循环遍历 Q,内循环遍历 K/V("sliced-K" 方案),导致同一线程块内的 4 个 Warp 需要共享内存同步,引入额外通信开销。FA-2 **互换循环顺序**,外循环遍历 K/V,内循环遍历 Q,让每个 Warp 独立处理 Q 的一部分,输出各自无需通信,**彻底消除 Warp 间共享内存同步**。

**③ 序列维度并行**

在 Batch 和 Head 维度之外,增加对序列长度维度的并行,使得长序列(小 Batch)场景下 GPU SM 利用率大幅提升。

**性能数据(A100 80GB)**:
- FP16/BF16:达到 **230 TFLOPS/s**(FA-1 约 124 TFLOPS/s,提升约 **2 倍**)
- 对比 PyTorch 标准注意力:最高 **9 倍**加速
- GPT-3 2.7B(8K 上下文)训练:从 80 TFLOPS/s 提升至 **225 TFLOPS/s**

---

### 4.2 FlashAttention-3(2024):Hopper 硬件专属优化

**论文**:[FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision](https://arxiv.org/abs/2407.08608)(Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao;NeurIPS 2024 Spotlight)

FA-2 在 H100 GPU 上的利用率仅 **35%**,远低于硬件理论峰值。FA-3 与 NVIDIA 官方合作,深度利用 Hopper 架构的三大新特性:

#### 特性一:异步流水线(Asynchronous Pipelining)

Hopper H100 的 SM(流式多处理器)引入了 **Warp 专业化(Warp Specialization)** 设计:不同 Warp Group 可以同时执行不同类型的操作,计算 Warp 做矩阵乘,数据搬运 Warp 用 **TMA(Tensor Memory Accelerator,张量内存加速器)** 异步传输数据。

FA-3 利用"乒乓调度"(Ping-Pong Scheduling):

```
时间线:
Warp Group 0: [GEMM QK^T] → [等待] → [GEMM PV] → [等待] → ...
Warp Group 1: [等待] → [Softmax] → [等待] → [Softmax] → ...
TMA Unit: [加载 K0,V0] → [加载 K1,V1] → ...

→ 矩阵乘(GEMM)与 Softmax、数据加载 三者完全重叠,消除流水线气泡
```

这使 FP16 前向性能从 FA-2 的约 570 TFLOPS 提升至 **640-660 TFLOPS**。

#### 特性二:WGMMA 指令

FA-3 使用 Hopper 专属的 **WGMMA(Warpgroup Matrix Multiply-Accumulate)** 指令代替 Ampere 的 `mma.sync`,单条指令的矩阵乘吞吐量更高,并通过 **NVIDIA CUTLASS 库** 的高级抽象进行管理,降低编程复杂度。

#### 特性三:FP8 低精度支持与非相干处理

H100 的 Tensor Core 支持 FP8(8 位浮点)精度,理论峰值高达 **1.978 PFLOPS**,是 FP16 的两倍。然而 FP8 精度低,面临两大挑战:

- **动态范围窄**:FP8 仅 8 bit,极易溢出或下溢
- **异常值(Outliers)问题**:注意力矩阵中少量极大值会导致严重量化误差

FA-3 引入**非相干处理(Incoherent Processing)**技术:在量化前,对 Query 和 Key 矩阵乘以随机正交矩阵(实践中用 **Hadamard 变换** 近似实现),将激活值中的异常值"打散",均匀分布到整个向量中,从而显著降低 FP8 量化误差。

Hadamard 变换计算复杂度 $O(d \log d)$,且可与 RoPE 等之前操作融合执行,几乎无额外开销。

**性能数据(H100 SXM5 80GB)**:

| 精度 | TFLOPS | H100 理论峰值利用率 |
|------|--------|------------------|
| FA-2(FP16) | ~370 TFLOPS | ~35% |
| FA-3(FP16) | **740 TFLOPS** | **75%** |
| FA-3(FP8) | **~1.2 PFLOPS** | ~60% |

- 相比 FA-2,FA-3 实现 **1.5–2.0 倍**加速
- FP8 版本数值误差比基线 FP8 注意力低 **2.6 倍**
- 对于中长序列(1K+),FA-3 甚至超越了针对 H100 深度优化的 **cuDNN** 实现

---

## 五、代码示例

### 5.1 在 PyTorch 中使用 FlashAttention

**方式一:通过 `torch.nn.functional.scaled_dot_product_attention`(PyTorch 2.2+,自动选择后端)**

```python
import torch
import torch.nn.functional as F

# 模拟多头注意力输入
batch_size, seq_len, num_heads, head_dim = 2, 2048, 16, 64
device = 'cuda'
dtype = torch.bfloat16

q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)

# PyTorch 会自动选择 FlashAttention 后端(需要 CUDA + 满足条件)
# 转为 [batch, heads, seq, dim] 格式
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(output.shape) # [2, 16, 2048, 64]
```

**方式二:直接使用 flash-attn 库**

```bash
pip install flash-attn --no-build-isolation
```

```python
from flash_attn import flash_attn_func, flash_attn_with_kvcache
import torch

device, dtype = 'cuda', torch.bfloat16
batch, seqlen, heads, dim = 2, 4096, 16, 64

q = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)

# 因果注意力(自回归 LLM 训练标配)
output = flash_attn_func(q, k, v, causal=True)
print(output.shape) # [2, 4096, 16, 64]
```

**方式三:推理时使用 KV 缓存(FlashDecoding)**

```python
from flash_attn import flash_attn_with_kvcache

# 增量解码场景:q 只有 1 个 token
q_new = torch.randn(batch, 1, heads, dim, device=device, dtype=dtype)
k_cache = torch.zeros(batch, 4096, heads, dim, device=device, dtype=dtype)
v_cache = torch.zeros(batch, 4096, heads, dim, device=device, dtype=dtype)
cache_seqlens = torch.tensor([100, 200], device=device, dtype=torch.int32)

output = flash_attn_with_kvcache(
q_new, k_cache, v_cache,
cache_seqlens=cache_seqlens,
causal=True
)
# 直接返回对完整 KV 缓存的注意力结果
```

---

## 六、技术优势与创新点

### 6.1 精确性(零近似误差)

FlashAttention 是**精确算法(Exact Attention)**,与标准注意力数学上完全等价,不引入任何近似,模型训练的收敛行为完全不变。这与 Linformer、Performer、BigBird 等**近似方法**有本质区别。

### 6.2 内存效率

将额外显存占用从 $O(N^2)$ 降至 $O(N)$。以 $N = 32K$ 的长序列为例:
- 标准注意力:约 **4 GB** 额外显存
- FlashAttention:约 **256 MB**(减少约 **16 倍**)

不需要激活检查点(Activation Checkpointing),即可训练 4K-8K 序列长度的大模型,显存大幅节省。

### 6.3 速度提升(实测数据)

| GPU | 对比对象 | 加速比 |
|-----|---------|--------|
| A100 (FP16) | PyTorch 标准注意力 | 最高 **9x** |
| A100 (BF16) | FA-1 | **2x** |
| H100 (FP16) | FA-2 | **1.5-2x** |
| H100 (FP8) | 基线 FP8 注意力 | 数值误差低 **2.6x** |

### 6.4 可扩展性

支持超长序列训练(目前结合 Ring Attention 等技术,可支持 **1M+ token** 长度),为长文档理解、长视频处理、代码补全等应用提供基础支撑。

---

## 七、适用场景与实际案例

### 7.1 大模型训练加速

FlashAttention 已成为当代 LLM 训练的**标准配置**,以下知名模型均使用了 FlashAttention:

- **OpenAI**:GPT-4(训练时采用)
- **Meta**:Llama、Llama-2、Llama-3 系列
- **TII(阿联酋)**:Falcon、Falcon-2
- **斯坦福**:Alpaca 系列
- **主流训练框架**:Megatron-LM、DeepSpeed、NeMo、Axolotl

**训练效率数据**:
- GPT-3 1.3B(8K 上下文,8×A100):模型 FLOPs 利用率从 ~72 TFLOPS 提升至 **220 TFLOPS/s**
- GPT-3 1.3B:训练 26B Token(Chinchilla 最优计算量)仅需约 **43 小时**

### 7.2 推理框架集成

| 框架 | 集成方式 |
|------|---------|
| PyTorch 2.2+ | `scaled_dot_product_attention` 自动调用 FA 后端 |
| vLLM | 默认使用 FA 进行推理加速 |
| HuggingFace Transformers | 通过 `attn_implementation="flash_attention_2"` 启用 |
| TensorRT-LLM | 内置 FA 优化 |
| xFormers | `memory_efficient_attention` |
| llama.cpp | `--flash-attn` 参数开启 |

### 7.3 长上下文应用

FlashAttention 将以往因显存限制无法处理的长序列变为可能:

- **长文档问答**:法律合同分析、医疗文献综述(64K-128K token)
- **代码补全**:分析整个代码库上下文
- **视频理解**:处理高帧率视频帧序列
- **生物信息学**:超长蛋白质序列建模

---

## 八、展望:FlashAttention-4 与未来趋势(2025-2026)

2026 年 3 月,Tri Dao 团队已释出 FlashAttention-4(FA-4)的预览信息,主要方向包括:

- **算法层面重构**:使用多项式逼近替代精确指数运算,在可接受误差范围内进一步减少 Softmax 计算的 FLOPs
- **异步流水线深化**:反向传播也实现完整的异步流水线,在 B200 GPU 上 FP16 利用率预计达到理论峰值的 **71%**
- **确定性训练模式(Deterministic Backward Pass)**:在保证性能的同时支持可复现训练
- **NPU 适配挑战**:由于 FlashAttention 要求 Softmax 与矩阵乘的算力配比灵活,华为昇腾等 NPU 架构(固定脉动阵列)对其支持存在系统性挑战,相关研究正在推进中

---

## 九、总结

FlashAttention 通过一个精妙的视角转换——**将注意力计算视为 IO 问题而非算力问题**——彻底改变了 LLM 训练和推理的基础设施格局。其核心创新链条清晰:

1. **分块(Tiling)**:避免实例化 $N \times N$ 注意力矩阵
2. **在线 Softmax(Online Softmax)**:使分块处理数学上可行
3. **反向传播重计算(Recomputation)**:以算力换显存
4. **硬件协同优化(FA-2/FA-3)**:深度利用 Tensor Core、TMA、WGMMA 等硬件特性

从 2022 年的 FA-1(3x GPT-2 训练加速)到 2024 年的 FA-3(H100 75% 利用率),FlashAttention 家族在短短两年内将 GPU 注意力计算效率提升了一个数量级,成为现代 AI 基础设施中最重要的系统优化之一。

---

## 参考资料

1. [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (NeurIPS 2022)](https://arxiv.org/abs/2205.14135)
2. [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (ICLR 2024)](https://arxiv.org/abs/2307.08691)
3. [FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (NeurIPS 2024 Spotlight)](https://arxiv.org/abs/2407.08608)
4. [Tri Dao 官方博客:FlashAttention-3](https://tridao.me/blog/2024/flash3/)
5. [PyTorch 官方博客:FlashAttention-3](https://pytorch.org/blog/flashattention-3/)
6. [Hazy Research 博客:FlashAttention-2](https://hazyresearch.stanford.edu/blog/2023-07-17-flash2)
7. [NVIDIA 技术博客:新一代 FlashAttention](https://developer.nvidia.cn/blog/next-generation-of-flashattention/)
8. [GitHub: Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)
9. [GitHub: togethercomputer/flash-attention-3](https://github.com/togethercomputer/flash-attention-3)
10. [从 Online Softmax 到 FlashAttention(腾讯云开发者社区)](https://cloud.tencent.com/developer/article/2616547)
11. [FlashAttention 深度解析:从数学原理到工程实现](https://pillumina.github.io/posts/aiinfra/11-flashattention/)

---

posted @ 2026-04-09 01:18  SHICENT  阅读(29)  评论(0)    收藏  举报