FlashAttention 全系列深度解析--IO 感知注意力计算如何重塑 LLM 训练与推理
技术日报 2026-03-25
一、技术背景与动机
1.1 标准注意力的根本瓶颈
Transformer 架构的注意力机制(Self-Attention)自 2017 年提出以来,已成为大语言模型(LLM)、视觉模型、多模态模型的基础组件。然而,随着序列长度 $N$ 的增大,标准注意力的时间与空间复杂度均为 $O(N^2)$,在处理长上下文时面临两大核心瓶颈:
1. 内存墙(Memory Wall)
以序列长度 $N = 2048$、头维度 $d = 64$ 的注意力为例,中间矩阵 $S = QK^\top$(形状 $2048 \times 2048$,FP16)约占 16 MB 的 GPU HBM(高带宽内存)。当 $N = 32768$,这一数字膨胀到 4 GB,远超单 GPU 可承受范围。
2. IO 带宽瓶颈(IO Bottleneck)
GPU 的核心矛盾在于:算力增长远快于内存带宽增长。以 A100 GPU 为例,其浮点算力为 312 TFLOPS(BF16),而 HBM 带宽仅 2 TB/s。标准注意力需要在 HBM 与 SRAM(片上高速缓存,容量约 20-40 MB)之间反复传输数据,绝大多数时间 GPU 的 Tensor Core 在等待数据而非计算。
大多数所谓的"高效注意力"(Efficient Attention)工作,如 Linformer、Performer、Longformer,都以引入近似误差为代价来降低计算量,但忽略了 GPU IO 带宽这一真正瓶颈。
1.2 FlashAttention 的诞生
2022 年,斯坦福大学 Tri Dao、Dan Fu 等人在 NeurIPS 2022 发表了 FlashAttention,首次将 IO 感知(IO-Awareness) 引入注意力计算。其关键洞察是:
注意力计算的瓶颈不在于 FLOPs(浮点运算量),而在于 HBM 的读写次数(IO 次数)。
FlashAttention 通过分块(Tiling)+ 在线 Softmax(Online Softmax)技术,将注意力计算融合进单个 CUDA Kernel,在 不引入任何近似误差 的前提下,将内存复杂度从 $O(N^2)$ 降至 $O(N)$,并大幅减少 HBM 访问次数。
二、核心概念与原理
2.1 GPU 内存层次结构
理解 FlashAttention 必须先理解 GPU 的内存层次:
| 内存类型 | 容量(A100) | 带宽 | 延迟 |
|---|---|---|---|
| 寄存器(Register) | ~几十 KB/SM | 极高 | 最低 |
| SRAM(共享内存) | 192 KB/SM,约 20-40 MB 全局 | ~19 TB/s | 极低 |
| HBM(显存) | 40/80 GB | ~2 TB/s | 较高 |
标准注意力每次计算都要将 $Q, K, V, S, P, O$ 在 HBM 和 SRAM 之间来回搬运,IO 次数为 $\Theta(Nd + N^2)$。FlashAttention 通过分块让所有中间结果留在 SRAM 中计算,将 IO 次数降至 $O(N^2 d^2 M^{-1})$(其中 $M$ 是 SRAM 容量)。
2.2 从标准 Softmax 到 Online Softmax
标准 Softmax(3-Pass)
计算 $\text{softmax}(\mathbf{x})$,传统方式需要三次遍历:
Pass 1: 求最大值 m = max(x_1, ..., x_N) (防数值溢出)
Pass 2: 求指数和 ℓ = Σ exp(x_i - m)
Pass 3: 归一化 softmax(x_i) = exp(x_i - m) / ℓ
这要求必须看完全部输入才能开始输出,无法分块处理。
Online Softmax(单次遍历)
核心思想:通过维护两个递推统计量(局部最大值 $m_j$ 与修正后的指数和 $d_j$),逐块增量更新,无需看完全部数据:
$$m_{j+1} = \max(m_j,\ x_{j+1})$$
$$d_{j+1} = d_j \cdot e^{m_j - m_{j+1}} + e^{x_{j+1} - m_{j+1}}$$
当新元素 $x_{j+1}$ 到来时,若新的最大值比旧的大,对旧的指数和乘以修正因子 $e^{m_j - m_{j+1}}$ 进行缩放,再加上新元素的贡献。这样,单次遍历就能得到正确的 Softmax 分母。
Python 代码示例:
import math
def online_softmax(x):
"""单次遍历计算 Softmax,数值稳定"""
m = float('-inf') # 当前最大值
d = 0.0 # 修正后的指数和
for xi in x:
m_new = max(m, xi)
d = d * math.exp(m - m_new) + math.exp(xi - m_new)
m = m_new
return [math.exp(xi - m) / d for xi in x]
# 验证
x = [1.0, 2.0, 3.0, 4.0]
result = online_softmax(x)
print(result)
# [0.0321, 0.0871, 0.2369, 0.6439] 与 scipy.special.softmax 一致
三、关键算法:FlashAttention 前向传播
3.1 算法框架
FlashAttention 将 Online Softmax 推广到注意力计算,实现对完整注意力矩阵的分块计算。
输入:矩阵 $Q, K, V \in \mathbb{R}^{N \times d}$(序列长度 $N$,头维度 $d$)
块大小设置:
- Query 块大小 $B_r = \lceil M / (4d) \rceil$
- Key/Value 块大小 $B_c = \min(\lceil M / (4d) \rceil, d)$
- 其中 $M$ 为 SRAM 容量
核心循环(外循环遍历 K/V 块,内循环遍历 Q 块):
for j = 1 to ⌈N/Bc⌉: # 外循环:遍历 K/V 的分块
从 HBM 加载 K_j, V_j 到 SRAM
for i = 1 to ⌈N/Br⌉: # 内循环:遍历 Q 的分块
从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM
# 计算当前块的注意力分数
S_ij = Q_i · K_j^T * scale # 形状 [Br, Bc]
# 更新局部最大值(Online Softmax 第一步)
m_ij = rowmax(S_ij) # [Br,]
P_ij = exp(S_ij - m_ij) # 减去局部最大值,防溢出
l_ij = rowsum(P_ij) # [Br,]
# 融合旧统计量,更新全局最大值和指数和(Online Softmax 第二步)
m_i_new = max(m_i, m_ij)
l_i_new = exp(m_i - m_i_new) * l_i + exp(m_ij - m_i_new) * l_ij
# 更新输出累加器
O_i = (l_i * exp(m_i - m_i_new) * O_i + exp(m_ij - m_i_new) * P_ij * V_j) / l_i_new
# 更新统计量
m_i, l_i = m_i_new, l_i_new
# 将更新后的 O_i, l_i, m_i 写回 HBM
最终输出:O(与标准注意力数学等价)
关键洞察:整个计算过程中,$N \times N$ 的注意力矩阵 $S$ 从未被完整地实例化到 HBM,只存在于 SRAM 的短暂中间计算中。
3.2 反向传播的重计算技巧
反向传播也需要注意力矩阵 $P$,但 FlashAttention 不存储它,而是在反向传播时重新计算(Recomputation):通过保存的 $O, l, m$ 统计量,在 SRAM 中重新算出 $P$,然后立即计算梯度。这将显存占用从 $O(N^2)$ 进一步降低,代价是额外的浮点运算(约增加 30% FLOPs),但由于算力远快于 IO,整体仍然更快。
复杂度对比:
| 指标 | 标准注意力 | FlashAttention |
|---|---|---|
| HBM 读写次数(IO) | $\Theta(Nd + N^2)$ | $\Theta(N^2 d^2 M^{-1})$ |
| 额外显存(前向) | $O(N^2)$ | $O(N)$ |
| FLOPs | $O(N^2 d)$ | $O(N^2 d)$(约多 30%) |
四、版本演进:FlashAttention-1 → 3
4.1 FlashAttention-2(2023):工程精细化
论文:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning(Tri Dao,普林斯顿大学)
FA-2 的主要改进聚焦于减少非矩阵乘法运算(non-matmul FLOPs) 与改进 GPU 并行化策略:
① 减少非矩阵乘法开销
现代 GPU(如 A100)的 Tensor Core 专门加速矩阵乘法,吞吐量约 312 TFLOPS(BF16);而非矩阵乘法(如指数运算 exp、标量乘除)的吞吐量仅约 20 TFLOPS,相差 15 倍。FA-2 重排了在线 Softmax 的缩放时机,将不必要的重缩放操作延迟到循环结束后统一执行,大幅减少非矩阵乘法的次数。
② 更好的并行化策略:从 Split-KV 到 Split-Q
FA-1 外循环遍历 Q,内循环遍历 K/V("sliced-K" 方案),导致同一线程块内的 4 个 Warp 需要共享内存同步,引入额外通信开销。FA-2 互换循环顺序,外循环遍历 K/V,内循环遍历 Q,让每个 Warp 独立处理 Q 的一部分,输出各自无需通信,彻底消除 Warp 间共享内存同步。
③ 序列维度并行
在 Batch 和 Head 维度之外,增加对序列长度维度的并行,使得长序列(小 Batch)场景下 GPU SM 利用率大幅提升。
性能数据(A100 80GB):
- FP16/BF16:达到 230 TFLOPS/s(FA-1 约 124 TFLOPS/s,提升约 2 倍)
- 对比 PyTorch 标准注意力:最高 9 倍加速
- GPT-3 2.7B(8K 上下文)训练:从 80 TFLOPS/s 提升至 225 TFLOPS/s
4.2 FlashAttention-3(2024):Hopper 硬件专属优化
论文:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision(Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao;NeurIPS 2024 Spotlight)
FA-2 在 H100 GPU 上的利用率仅 35%,远低于硬件理论峰值。FA-3 与 NVIDIA 官方合作,深度利用 Hopper 架构的三大新特性:
特性一:异步流水线(Asynchronous Pipelining)
Hopper H100 的 SM(流式多处理器)引入了 Warp 专业化(Warp Specialization) 设计:不同 Warp Group 可以同时执行不同类型的操作,计算 Warp 做矩阵乘,数据搬运 Warp 用 TMA(Tensor Memory Accelerator,张量内存加速器) 异步传输数据。
FA-3 利用"乒乓调度"(Ping-Pong Scheduling):
时间线:
Warp Group 0: [GEMM QK^T] → [等待] → [GEMM PV] → [等待] → ...
Warp Group 1: [等待] → [Softmax] → [等待] → [Softmax] → ...
TMA Unit: [加载 K0,V0] → [加载 K1,V1] → ...
→ 矩阵乘(GEMM)与 Softmax、数据加载 三者完全重叠,消除流水线气泡
这使 FP16 前向性能从 FA-2 的约 570 TFLOPS 提升至 640-660 TFLOPS。
特性二:WGMMA 指令
FA-3 使用 Hopper 专属的 WGMMA(Warpgroup Matrix Multiply-Accumulate) 指令代替 Ampere 的 mma.sync,单条指令的矩阵乘吞吐量更高,并通过 NVIDIA CUTLASS 库 的高级抽象进行管理,降低编程复杂度。
特性三:FP8 低精度支持与非相干处理
H100 的 Tensor Core 支持 FP8(8 位浮点)精度,理论峰值高达 1.978 PFLOPS,是 FP16 的两倍。然而 FP8 精度低,面临两大挑战:
- 动态范围窄:FP8 仅 8 bit,极易溢出或下溢
- 异常值(Outliers)问题:注意力矩阵中少量极大值会导致严重量化误差
FA-3 引入非相干处理(Incoherent Processing)技术:在量化前,对 Query 和 Key 矩阵乘以随机正交矩阵(实践中用 Hadamard 变换 近似实现),将激活值中的异常值"打散",均匀分布到整个向量中,从而显著降低 FP8 量化误差。
Hadamard 变换计算复杂度 $O(d \log d)$,且可与 RoPE 等之前操作融合执行,几乎无额外开销。
性能数据(H100 SXM5 80GB):
| 精度 | TFLOPS | H100 理论峰值利用率 |
|---|---|---|
| FA-2(FP16) | ~370 TFLOPS | ~35% |
| FA-3(FP16) | 740 TFLOPS | 75% |
| FA-3(FP8) | ~1.2 PFLOPS | ~60% |
- 相比 FA-2,FA-3 实现 1.5–2.0 倍加速
- FP8 版本数值误差比基线 FP8 注意力低 2.6 倍
- 对于中长序列(1K+),FA-3 甚至超越了针对 H100 深度优化的 cuDNN 实现
五、代码示例
5.1 在 PyTorch 中使用 FlashAttention
方式一:通过 torch.nn.functional.scaled_dot_product_attention(PyTorch 2.2+,自动选择后端)
import torch
import torch.nn.functional as F
# 模拟多头注意力输入
batch_size, seq_len, num_heads, head_dim = 2, 2048, 16, 64
device = 'cuda'
dtype = torch.bfloat16
q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
# PyTorch 会自动选择 FlashAttention 后端(需要 CUDA + 满足条件)
# 转为 [batch, heads, seq, dim] 格式
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(output.shape) # [2, 16, 2048, 64]
方式二:直接使用 flash-attn 库
pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func, flash_attn_with_kvcache
import torch
device, dtype = 'cuda', torch.bfloat16
batch, seqlen, heads, dim = 2, 4096, 16, 64
q = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)
# 因果注意力(自回归 LLM 训练标配)
output = flash_attn_func(q, k, v, causal=True)
print(output.shape) # [2, 4096, 16, 64]
方式三:推理时使用 KV 缓存(FlashDecoding)
from flash_attn import flash_attn_with_kvcache
# 增量解码场景:q 只有 1 个 token
q_new = torch.randn(batch, 1, heads, dim, device=device, dtype=dtype)
k_cache = torch.zeros(batch, 4096, heads, dim, device=device, dtype=dtype)
v_cache = torch.zeros(batch, 4096, heads, dim, device=device, dtype=dtype)
cache_seqlens = torch.tensor([100, 200], device=device, dtype=torch.int32)
output = flash_attn_with_kvcache(
q_new, k_cache, v_cache,
cache_seqlens=cache_seqlens,
causal=True
)
# 直接返回对完整 KV 缓存的注意力结果
六、技术优势与创新点
6.1 精确性(零近似误差)
FlashAttention 是精确算法(Exact Attention),与标准注意力数学上完全等价,不引入任何近似,模型训练的收敛行为完全不变。这与 Linformer、Performer、BigBird 等近似方法有本质区别。
6.2 内存效率
将额外显存占用从 $O(N^2)$ 降至 $O(N)$。以 $N = 32K$ 的长序列为例:
- 标准注意力:约 4 GB 额外显存
- FlashAttention:约 256 MB(减少约 16 倍)
不需要激活检查点(Activation Checkpointing),即可训练 4K-8K 序列长度的大模型,显存大幅节省。
6.3 速度提升(实测数据)
| GPU | 对比对象 | 加速比 |
|---|---|---|
| A100 (FP16) | PyTorch 标准注意力 | 最高 9x |
| A100 (BF16) | FA-1 | 2x |
| H100 (FP16) | FA-2 | 1.5-2x |
| H100 (FP8) | 基线 FP8 注意力 | 数值误差低 2.6x |
6.4 可扩展性
支持超长序列训练(目前结合 Ring Attention 等技术,可支持 1M+ token 长度),为长文档理解、长视频处理、代码补全等应用提供基础支撑。
七、适用场景与实际案例
7.1 大模型训练加速
FlashAttention 已成为当代 LLM 训练的标准配置,以下知名模型均使用了 FlashAttention:
- OpenAI:GPT-4(训练时采用)
- Meta:Llama、Llama-2、Llama-3 系列
- TII(阿联酋):Falcon、Falcon-2
- 斯坦福:Alpaca 系列
- 主流训练框架:Megatron-LM、DeepSpeed、NeMo、Axolotl
训练效率数据:
- GPT-3 1.3B(8K 上下文,8×A100):模型 FLOPs 利用率从 ~72 TFLOPS 提升至 220 TFLOPS/s
- GPT-3 1.3B:训练 26B Token(Chinchilla 最优计算量)仅需约 43 小时
7.2 推理框架集成
| 框架 | 集成方式 |
|---|---|
| PyTorch 2.2+ | scaled_dot_product_attention 自动调用 FA 后端 |
| vLLM | 默认使用 FA 进行推理加速 |
| HuggingFace Transformers | 通过 attn_implementation="flash_attention_2" 启用 |
| TensorRT-LLM | 内置 FA 优化 |
| xFormers | memory_efficient_attention |
| llama.cpp | --flash-attn 参数开启 |
7.3 长上下文应用
FlashAttention 将以往因显存限制无法处理的长序列变为可能:
- 长文档问答:法律合同分析、医疗文献综述(64K-128K token)
- 代码补全:分析整个代码库上下文
- 视频理解:处理高帧率视频帧序列
- 生物信息学:超长蛋白质序列建模
八、展望:FlashAttention-4 与未来趋势(2025-2026)
2026 年 3 月,Tri Dao 团队已释出 FlashAttention-4(FA-4)的预览信息,主要方向包括:
- 算法层面重构:使用多项式逼近替代精确指数运算,在可接受误差范围内进一步减少 Softmax 计算的 FLOPs
- 异步流水线深化:反向传播也实现完整的异步流水线,在 B200 GPU 上 FP16 利用率预计达到理论峰值的 71%
- 确定性训练模式(Deterministic Backward Pass):在保证性能的同时支持可复现训练
- NPU 适配挑战:由于 FlashAttention 要求 Softmax 与矩阵乘的算力配比灵活,华为昇腾等 NPU 架构(固定脉动阵列)对其支持存在系统性挑战,相关研究正在推进中
九、总结
FlashAttention 通过一个精妙的视角转换——将注意力计算视为 IO 问题而非算力问题——彻底改变了 LLM 训练和推理的基础设施格局。其核心创新链条清晰:
- 分块(Tiling):避免实例化 $N \times N$ 注意力矩阵
- 在线 Softmax(Online Softmax):使分块处理数学上可行
- 反向传播重计算(Recomputation):以算力换显存
- 硬件协同优化(FA-2/FA-3):深度利用 Tensor Core、TMA、WGMMA 等硬件特性
从 2022 年的 FA-1(3x GPT-2 训练加速)到 2024 年的 FA-3(H100 75% 利用率),FlashAttention 家族在短短两年内将 GPU 注意力计算效率提升了一个数量级,成为现代 AI 基础设施中最重要的系统优化之一。
参考资料
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (NeurIPS 2022)
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (ICLR 2024)
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (NeurIPS 2024 Spotlight)
- Tri Dao 官方博客:FlashAttention-3
- PyTorch 官方博客:FlashAttention-3
- Hazy Research 博客:FlashAttention-2
- NVIDIA 技术博客:新一代 FlashAttention
- GitHub: Dao-AILab/flash-attention
- GitHub: togethercomputer/flash-attention-3
- 从 Online Softmax 到 FlashAttention(腾讯云开发者社区)
- FlashAttention 深度解析:从数学原理到工程实现
---
<script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
<script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
# 技术日报 2026-03-25
---
## 一、技术背景与动机
### 1.1 标准注意力的根本瓶颈
Transformer 架构的注意力机制(Self-Attention)自 2017 年提出以来,已成为大语言模型(LLM)、视觉模型、多模态模型的基础组件。然而,随着序列长度 $N$ 的增大,标准注意力的**时间与空间复杂度均为 $O(N^2)$**,在处理长上下文时面临两大核心瓶颈:
**1. 内存墙(Memory Wall)**
以序列长度 $N = 2048$、头维度 $d = 64$ 的注意力为例,中间矩阵 $S = QK^\top$(形状 $2048 \times 2048$,FP16)约占 **16 MB** 的 GPU HBM(高带宽内存)。当 $N = 32768$,这一数字膨胀到 **4 GB**,远超单 GPU 可承受范围。
**2. IO 带宽瓶颈(IO Bottleneck)**
GPU 的核心矛盾在于:**算力增长远快于内存带宽增长**。以 A100 GPU 为例,其浮点算力为 312 TFLOPS(BF16),而 HBM 带宽仅 2 TB/s。标准注意力需要在 HBM 与 SRAM(片上高速缓存,容量约 20-40 MB)之间反复传输数据,绝大多数时间 GPU 的 Tensor Core 在**等待数据**而非计算。
大多数所谓的"高效注意力"(Efficient Attention)工作,如 Linformer、Performer、Longformer,都以引入**近似误差**为代价来降低计算量,但忽略了 GPU IO 带宽这一真正瓶颈。
### 1.2 FlashAttention 的诞生
2022 年,斯坦福大学 Tri Dao、Dan Fu 等人在 NeurIPS 2022 发表了 [FlashAttention](https://arxiv.org/abs/2205.14135),首次将 **IO 感知(IO-Awareness)** 引入注意力计算。其关键洞察是:
> **注意力计算的瓶颈不在于 FLOPs(浮点运算量),而在于 HBM 的读写次数(IO 次数)。**
FlashAttention 通过**分块(Tiling)+ 在线 Softmax(Online Softmax)**技术,将注意力计算融合进单个 CUDA Kernel,在 **不引入任何近似误差** 的前提下,将内存复杂度从 $O(N^2)$ 降至 $O(N)$,并大幅减少 HBM 访问次数。
---
## 二、核心概念与原理
### 2.1 GPU 内存层次结构
理解 FlashAttention 必须先理解 GPU 的内存层次:
| 内存类型 | 容量(A100) | 带宽 | 延迟 |
|---------|------------|------|------|
| 寄存器(Register) | ~几十 KB/SM | 极高 | 最低 |
| SRAM(共享内存) | 192 KB/SM,约 20-40 MB 全局 | ~19 TB/s | 极低 |
| HBM(显存) | 40/80 GB | ~2 TB/s | 较高 |
标准注意力每次计算都要将 $Q, K, V, S, P, O$ 在 HBM 和 SRAM 之间来回搬运,IO 次数为 $\Theta(Nd + N^2)$。FlashAttention 通过分块让所有中间结果**留在 SRAM** 中计算,将 IO 次数降至 $O(N^2 d^2 M^{-1})$(其中 $M$ 是 SRAM 容量)。
### 2.2 从标准 Softmax 到 Online Softmax
**标准 Softmax(3-Pass)**
计算 $\text{softmax}(\mathbf{x})$,传统方式需要三次遍历:
```
Pass 1: 求最大值 m = max(x_1, ..., x_N) (防数值溢出)
Pass 2: 求指数和 ℓ = Σ exp(x_i - m)
Pass 3: 归一化 softmax(x_i) = exp(x_i - m) / ℓ
```
这要求必须看完全部输入才能开始输出,无法分块处理。
**Online Softmax(单次遍历)**
核心思想:通过维护两个**递推统计量**(局部最大值 $m_j$ 与修正后的指数和 $d_j$),逐块增量更新,无需看完全部数据:
$$m_{j+1} = \max(m_j,\ x_{j+1})$$
$$d_{j+1} = d_j \cdot e^{m_j - m_{j+1}} + e^{x_{j+1} - m_{j+1}}$$
当新元素 $x_{j+1}$ 到来时,若新的最大值比旧的大,对旧的指数和乘以修正因子 $e^{m_j - m_{j+1}}$ 进行缩放,再加上新元素的贡献。这样,**单次遍历**就能得到正确的 Softmax 分母。
**Python 代码示例:**
```python
import math
def online_softmax(x):
"""单次遍历计算 Softmax,数值稳定"""
m = float('-inf') # 当前最大值
d = 0.0 # 修正后的指数和
for xi in x:
m_new = max(m, xi)
d = d * math.exp(m - m_new) + math.exp(xi - m_new)
m = m_new
return [math.exp(xi - m) / d for xi in x]
# 验证
x = [1.0, 2.0, 3.0, 4.0]
result = online_softmax(x)
print(result)
# [0.0321, 0.0871, 0.2369, 0.6439] 与 scipy.special.softmax 一致
```
---
## 三、关键算法:FlashAttention 前向传播
### 3.1 算法框架
FlashAttention 将 Online Softmax 推广到注意力计算,实现对完整注意力矩阵的分块计算。
**输入**:矩阵 $Q, K, V \in \mathbb{R}^{N \times d}$(序列长度 $N$,头维度 $d$)
**块大小设置**:
- Query 块大小 $B_r = \lceil M / (4d) \rceil$
- Key/Value 块大小 $B_c = \min(\lceil M / (4d) \rceil, d)$
- 其中 $M$ 为 SRAM 容量
**核心循环(外循环遍历 K/V 块,内循环遍历 Q 块)**:
```
for j = 1 to ⌈N/Bc⌉: # 外循环:遍历 K/V 的分块
从 HBM 加载 K_j, V_j 到 SRAM
for i = 1 to ⌈N/Br⌉: # 内循环:遍历 Q 的分块
从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM
# 计算当前块的注意力分数
S_ij = Q_i · K_j^T * scale # 形状 [Br, Bc]
# 更新局部最大值(Online Softmax 第一步)
m_ij = rowmax(S_ij) # [Br,]
P_ij = exp(S_ij - m_ij) # 减去局部最大值,防溢出
l_ij = rowsum(P_ij) # [Br,]
# 融合旧统计量,更新全局最大值和指数和(Online Softmax 第二步)
m_i_new = max(m_i, m_ij)
l_i_new = exp(m_i - m_i_new) * l_i + exp(m_ij - m_i_new) * l_ij
# 更新输出累加器
O_i = (l_i * exp(m_i - m_i_new) * O_i + exp(m_ij - m_i_new) * P_ij * V_j) / l_i_new
# 更新统计量
m_i, l_i = m_i_new, l_i_new
# 将更新后的 O_i, l_i, m_i 写回 HBM
最终输出:O(与标准注意力数学等价)
```
**关键洞察**:整个计算过程中,$N \times N$ 的注意力矩阵 $S$ **从未被完整地实例化到 HBM**,只存在于 SRAM 的短暂中间计算中。
### 3.2 反向传播的重计算技巧
反向传播也需要注意力矩阵 $P$,但 FlashAttention 不存储它,而是在反向传播时**重新计算**(Recomputation):通过保存的 $O, l, m$ 统计量,在 SRAM 中重新算出 $P$,然后立即计算梯度。这将显存占用从 $O(N^2)$ 进一步降低,代价是额外的浮点运算(约增加 30% FLOPs),但由于算力远快于 IO,整体仍然更快。
**复杂度对比**:
| 指标 | 标准注意力 | FlashAttention |
|------|----------|----------------|
| HBM 读写次数(IO) | $\Theta(Nd + N^2)$ | $\Theta(N^2 d^2 M^{-1})$ |
| 额外显存(前向) | $O(N^2)$ | $O(N)$ |
| FLOPs | $O(N^2 d)$ | $O(N^2 d)$(约多 30%) |
---
## 四、版本演进:FlashAttention-1 → 3
### 4.1 FlashAttention-2(2023):工程精细化
**论文**:[FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691)(Tri Dao,普林斯顿大学)
FA-2 的主要改进聚焦于**减少非矩阵乘法运算(non-matmul FLOPs)** 与**改进 GPU 并行化策略**:
**① 减少非矩阵乘法开销**
现代 GPU(如 A100)的 Tensor Core 专门加速矩阵乘法,吞吐量约 312 TFLOPS(BF16);而非矩阵乘法(如指数运算 `exp`、标量乘除)的吞吐量仅约 **20 TFLOPS**,相差 15 倍。FA-2 重排了在线 Softmax 的缩放时机,将不必要的重缩放操作延迟到循环结束后统一执行,大幅减少非矩阵乘法的次数。
**② 更好的并行化策略:从 Split-KV 到 Split-Q**
FA-1 外循环遍历 Q,内循环遍历 K/V("sliced-K" 方案),导致同一线程块内的 4 个 Warp 需要共享内存同步,引入额外通信开销。FA-2 **互换循环顺序**,外循环遍历 K/V,内循环遍历 Q,让每个 Warp 独立处理 Q 的一部分,输出各自无需通信,**彻底消除 Warp 间共享内存同步**。
**③ 序列维度并行**
在 Batch 和 Head 维度之外,增加对序列长度维度的并行,使得长序列(小 Batch)场景下 GPU SM 利用率大幅提升。
**性能数据(A100 80GB)**:
- FP16/BF16:达到 **230 TFLOPS/s**(FA-1 约 124 TFLOPS/s,提升约 **2 倍**)
- 对比 PyTorch 标准注意力:最高 **9 倍**加速
- GPT-3 2.7B(8K 上下文)训练:从 80 TFLOPS/s 提升至 **225 TFLOPS/s**
---
### 4.2 FlashAttention-3(2024):Hopper 硬件专属优化
**论文**:[FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision](https://arxiv.org/abs/2407.08608)(Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao;NeurIPS 2024 Spotlight)
FA-2 在 H100 GPU 上的利用率仅 **35%**,远低于硬件理论峰值。FA-3 与 NVIDIA 官方合作,深度利用 Hopper 架构的三大新特性:
#### 特性一:异步流水线(Asynchronous Pipelining)
Hopper H100 的 SM(流式多处理器)引入了 **Warp 专业化(Warp Specialization)** 设计:不同 Warp Group 可以同时执行不同类型的操作,计算 Warp 做矩阵乘,数据搬运 Warp 用 **TMA(Tensor Memory Accelerator,张量内存加速器)** 异步传输数据。
FA-3 利用"乒乓调度"(Ping-Pong Scheduling):
```
时间线:
Warp Group 0: [GEMM QK^T] → [等待] → [GEMM PV] → [等待] → ...
Warp Group 1: [等待] → [Softmax] → [等待] → [Softmax] → ...
TMA Unit: [加载 K0,V0] → [加载 K1,V1] → ...
→ 矩阵乘(GEMM)与 Softmax、数据加载 三者完全重叠,消除流水线气泡
```
这使 FP16 前向性能从 FA-2 的约 570 TFLOPS 提升至 **640-660 TFLOPS**。
#### 特性二:WGMMA 指令
FA-3 使用 Hopper 专属的 **WGMMA(Warpgroup Matrix Multiply-Accumulate)** 指令代替 Ampere 的 `mma.sync`,单条指令的矩阵乘吞吐量更高,并通过 **NVIDIA CUTLASS 库** 的高级抽象进行管理,降低编程复杂度。
#### 特性三:FP8 低精度支持与非相干处理
H100 的 Tensor Core 支持 FP8(8 位浮点)精度,理论峰值高达 **1.978 PFLOPS**,是 FP16 的两倍。然而 FP8 精度低,面临两大挑战:
- **动态范围窄**:FP8 仅 8 bit,极易溢出或下溢
- **异常值(Outliers)问题**:注意力矩阵中少量极大值会导致严重量化误差
FA-3 引入**非相干处理(Incoherent Processing)**技术:在量化前,对 Query 和 Key 矩阵乘以随机正交矩阵(实践中用 **Hadamard 变换** 近似实现),将激活值中的异常值"打散",均匀分布到整个向量中,从而显著降低 FP8 量化误差。
Hadamard 变换计算复杂度 $O(d \log d)$,且可与 RoPE 等之前操作融合执行,几乎无额外开销。
**性能数据(H100 SXM5 80GB)**:
| 精度 | TFLOPS | H100 理论峰值利用率 |
|------|--------|------------------|
| FA-2(FP16) | ~370 TFLOPS | ~35% |
| FA-3(FP16) | **740 TFLOPS** | **75%** |
| FA-3(FP8) | **~1.2 PFLOPS** | ~60% |
- 相比 FA-2,FA-3 实现 **1.5–2.0 倍**加速
- FP8 版本数值误差比基线 FP8 注意力低 **2.6 倍**
- 对于中长序列(1K+),FA-3 甚至超越了针对 H100 深度优化的 **cuDNN** 实现
---
## 五、代码示例
### 5.1 在 PyTorch 中使用 FlashAttention
**方式一:通过 `torch.nn.functional.scaled_dot_product_attention`(PyTorch 2.2+,自动选择后端)**
```python
import torch
import torch.nn.functional as F
# 模拟多头注意力输入
batch_size, seq_len, num_heads, head_dim = 2, 2048, 16, 64
device = 'cuda'
dtype = torch.bfloat16
q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
# PyTorch 会自动选择 FlashAttention 后端(需要 CUDA + 满足条件)
# 转为 [batch, heads, seq, dim] 格式
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(output.shape) # [2, 16, 2048, 64]
```
**方式二:直接使用 flash-attn 库**
```bash
pip install flash-attn --no-build-isolation
```
```python
from flash_attn import flash_attn_func, flash_attn_with_kvcache
import torch
device, dtype = 'cuda', torch.bfloat16
batch, seqlen, heads, dim = 2, 4096, 16, 64
q = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, heads, dim, device=device, dtype=dtype)
# 因果注意力(自回归 LLM 训练标配)
output = flash_attn_func(q, k, v, causal=True)
print(output.shape) # [2, 4096, 16, 64]
```
**方式三:推理时使用 KV 缓存(FlashDecoding)**
```python
from flash_attn import flash_attn_with_kvcache
# 增量解码场景:q 只有 1 个 token
q_new = torch.randn(batch, 1, heads, dim, device=device, dtype=dtype)
k_cache = torch.zeros(batch, 4096, heads, dim, device=device, dtype=dtype)
v_cache = torch.zeros(batch, 4096, heads, dim, device=device, dtype=dtype)
cache_seqlens = torch.tensor([100, 200], device=device, dtype=torch.int32)
output = flash_attn_with_kvcache(
q_new, k_cache, v_cache,
cache_seqlens=cache_seqlens,
causal=True
)
# 直接返回对完整 KV 缓存的注意力结果
```
---
## 六、技术优势与创新点
### 6.1 精确性(零近似误差)
FlashAttention 是**精确算法(Exact Attention)**,与标准注意力数学上完全等价,不引入任何近似,模型训练的收敛行为完全不变。这与 Linformer、Performer、BigBird 等**近似方法**有本质区别。
### 6.2 内存效率
将额外显存占用从 $O(N^2)$ 降至 $O(N)$。以 $N = 32K$ 的长序列为例:
- 标准注意力:约 **4 GB** 额外显存
- FlashAttention:约 **256 MB**(减少约 **16 倍**)
不需要激活检查点(Activation Checkpointing),即可训练 4K-8K 序列长度的大模型,显存大幅节省。
### 6.3 速度提升(实测数据)
| GPU | 对比对象 | 加速比 |
|-----|---------|--------|
| A100 (FP16) | PyTorch 标准注意力 | 最高 **9x** |
| A100 (BF16) | FA-1 | **2x** |
| H100 (FP16) | FA-2 | **1.5-2x** |
| H100 (FP8) | 基线 FP8 注意力 | 数值误差低 **2.6x** |
### 6.4 可扩展性
支持超长序列训练(目前结合 Ring Attention 等技术,可支持 **1M+ token** 长度),为长文档理解、长视频处理、代码补全等应用提供基础支撑。
---
## 七、适用场景与实际案例
### 7.1 大模型训练加速
FlashAttention 已成为当代 LLM 训练的**标准配置**,以下知名模型均使用了 FlashAttention:
- **OpenAI**:GPT-4(训练时采用)
- **Meta**:Llama、Llama-2、Llama-3 系列
- **TII(阿联酋)**:Falcon、Falcon-2
- **斯坦福**:Alpaca 系列
- **主流训练框架**:Megatron-LM、DeepSpeed、NeMo、Axolotl
**训练效率数据**:
- GPT-3 1.3B(8K 上下文,8×A100):模型 FLOPs 利用率从 ~72 TFLOPS 提升至 **220 TFLOPS/s**
- GPT-3 1.3B:训练 26B Token(Chinchilla 最优计算量)仅需约 **43 小时**
### 7.2 推理框架集成
| 框架 | 集成方式 |
|------|---------|
| PyTorch 2.2+ | `scaled_dot_product_attention` 自动调用 FA 后端 |
| vLLM | 默认使用 FA 进行推理加速 |
| HuggingFace Transformers | 通过 `attn_implementation="flash_attention_2"` 启用 |
| TensorRT-LLM | 内置 FA 优化 |
| xFormers | `memory_efficient_attention` |
| llama.cpp | `--flash-attn` 参数开启 |
### 7.3 长上下文应用
FlashAttention 将以往因显存限制无法处理的长序列变为可能:
- **长文档问答**:法律合同分析、医疗文献综述(64K-128K token)
- **代码补全**:分析整个代码库上下文
- **视频理解**:处理高帧率视频帧序列
- **生物信息学**:超长蛋白质序列建模
---
## 八、展望:FlashAttention-4 与未来趋势(2025-2026)
2026 年 3 月,Tri Dao 团队已释出 FlashAttention-4(FA-4)的预览信息,主要方向包括:
- **算法层面重构**:使用多项式逼近替代精确指数运算,在可接受误差范围内进一步减少 Softmax 计算的 FLOPs
- **异步流水线深化**:反向传播也实现完整的异步流水线,在 B200 GPU 上 FP16 利用率预计达到理论峰值的 **71%**
- **确定性训练模式(Deterministic Backward Pass)**:在保证性能的同时支持可复现训练
- **NPU 适配挑战**:由于 FlashAttention 要求 Softmax 与矩阵乘的算力配比灵活,华为昇腾等 NPU 架构(固定脉动阵列)对其支持存在系统性挑战,相关研究正在推进中
---
## 九、总结
FlashAttention 通过一个精妙的视角转换——**将注意力计算视为 IO 问题而非算力问题**——彻底改变了 LLM 训练和推理的基础设施格局。其核心创新链条清晰:
1. **分块(Tiling)**:避免实例化 $N \times N$ 注意力矩阵
2. **在线 Softmax(Online Softmax)**:使分块处理数学上可行
3. **反向传播重计算(Recomputation)**:以算力换显存
4. **硬件协同优化(FA-2/FA-3)**:深度利用 Tensor Core、TMA、WGMMA 等硬件特性
从 2022 年的 FA-1(3x GPT-2 训练加速)到 2024 年的 FA-3(H100 75% 利用率),FlashAttention 家族在短短两年内将 GPU 注意力计算效率提升了一个数量级,成为现代 AI 基础设施中最重要的系统优化之一。
---
## 参考资料
1. [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (NeurIPS 2022)](https://arxiv.org/abs/2205.14135)
2. [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (ICLR 2024)](https://arxiv.org/abs/2307.08691)
3. [FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (NeurIPS 2024 Spotlight)](https://arxiv.org/abs/2407.08608)
4. [Tri Dao 官方博客:FlashAttention-3](https://tridao.me/blog/2024/flash3/)
5. [PyTorch 官方博客:FlashAttention-3](https://pytorch.org/blog/flashattention-3/)
6. [Hazy Research 博客:FlashAttention-2](https://hazyresearch.stanford.edu/blog/2023-07-17-flash2)
7. [NVIDIA 技术博客:新一代 FlashAttention](https://developer.nvidia.cn/blog/next-generation-of-flashattention/)
8. [GitHub: Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)
9. [GitHub: togethercomputer/flash-attention-3](https://github.com/togethercomputer/flash-attention-3)
10. [从 Online Softmax 到 FlashAttention(腾讯云开发者社区)](https://cloud.tencent.com/developer/article/2616547)
11. [FlashAttention 深度解析:从数学原理到工程实现](https://pillumina.github.io/posts/aiinfra/11-flashattention/)
---

浙公网安备 33010602011771号