理解-Flash-Attention-在-Triton-中从头编写算法

理解 Flash Attention:在 Triton 中从头编写算法

原文:towardsdatascience.com/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton-5609f0b143ea/

免费阅读alexdremov.me

Flash Attention 是一种革命性的技术,它极大地加速了基于 Transformer 模型的注意力机制,其处理速度比原始方法快得多倍。通过巧妙地划分数据和最小化内存传输,它解决了大型语言模型经常遇到的著名的 GPU 内存瓶颈问题。

在这篇文章中,我们将深入探讨 Flash Attention 如何利用高效的I/O 感知来减少开销,然后通过在 Triton 中构建一个块稀疏注意力内核来更进一步。

💥 我将提供一个关于 Flash 注意力如何工作的简单解释。然后,我们将将在 Triton 中实现所解释的算法!

什么是注意力?

注意力机制(或缩放点积注意力)是 Transformer 模型的核心元素,这是一种解决语言建模问题的领先架构。所有流行的模型,如 GPT、LLaMA 和 BERT,都依赖于注意力。

公式相当简单:

其余的都是历史。

尽管公式看起来很简单,但其计算涉及大张量的乘法和大量数据移动。考虑到这是 Transformer 架构的核心部分,优化算法可以极大地提高模型的整体性能。

在原始实现中,注意力需要O(n²)额外的内存和O(n²)的计算时间复杂度,其中n是序列长度。这太多了!

Flash Attention

核心思想

Flash 注意力(Flash Attention)的主要思想可以用原始论文中的一句话来概括:

我们认为,一个缺失的原则是使注意力算法具有 I/O 感知性——考虑到 GPU 内存层级之间的读写操作。

也就是说,现代 GPU 有几种类型的内存:

  • SRAM – 快速,片上,小

  • HBM – 比 SRAM 慢,体积大。这就是我们通常所说的 GPU 内存。

查看下面的图像以了解不同内存类型的带宽和大小差异。

Image from FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness by Tri Dao et al.

Image from FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness by Tri Dao et al.

💡 为了进行计算,数据必须从 HBM 传输到 SRAM,而这种传输并不是没有开销的!

Flash Attention 算法提出了一种在瓦片中计算注意力的方法,而不需要显式地实现注意力分数张量:

💥 不实现矩阵意味着在任何给定时间,矩阵都没有在内存中以完整形状存在。

很容易看出,这个矩阵需要 O(n²) 的内存来存储。对于长的序列,那将是非常多的数据!所以,如果我们能够避免显式地实现这个矩阵,我们可以节省大量的内存。

然而,这个矩阵对于 transformer 训练是必要的,因为它是反向传播和梯度计算的一部分。作者提出,在反向传播期间重新计算这个矩阵会更好(再次不显式实现)。这不仅节省了大量内存,而且提供了巨大的速度提升,因为我们不需要在不同 GPU 内存类型之间传输这个巨大的矩阵。

总体而言,这种方法不仅通过考虑 GPU I/O 特性来加速计算,而且还允许处理巨大的序列长度,因为内存复杂度降低到 O(n)

瓦片注意力计算

最后要理解的是如何在瓦片中计算注意力。基本上,这意味着我们将通过处理传入的标记的小部分来计算整个序列的注意力。

好吧,计算 QK^T 在瓦片中是很简单的。考虑到注意力维度不高,我们可以加载完整的矩阵行和列,并在瓦片中进行乘法运算。

😡 是的,如果我们想要一个巨大的注意力维度,没有算法修改,Flash Attention 将无法工作。

由于维度通常很小,即使是对于巨大的模型,这种限制也是合理的。

瓦片 QK^T | 图片由作者提供

瓦片 QK^T | 图片由作者提供

因此,我们在 SRAM 中计算了 QK^T。剩下的只是应用 softmax,乘以 V,就完成了!

那就是技巧所在。

问题在于 softmax 分母需要聚合整个序列长度以归一化分数,而我们无法访问整个长度,因为我们以瓦片的形式加载数据。

为了解决这个问题,我们可以实现一个连接 softmax 算法。使用它,我们可以以“批量”模式计算 softmax:通过调整计算值与新的传入数据。

从原始文章中提取算法,我们可以定义规则来计算数据连接的 softmax。有两个向量 x1x2,我们需要计算这些向量连接 [x1, x2] 上的 softmax 分母 l(x)。如果向量的最大值是 m(x),我们可以很容易地推导出连接的 softmax 分母:

最后的等价性可以很容易地验证为

因此,现在我们得到了我们想要的东西——我们可以按块计算 softmax,然后通过执行上述公式的重新归一化来计算全局 softmax。最后要做的就是将 V 张量的块纳入其中,并继续进行相同的重新归一化(因为矩阵乘法是一个线性操作)。

而这一切都不需要将整个序列加载到内存中或具体化 QK^T

💥 注意,我们仅在块中计算 Softmax(QK^T),无需在任何时刻拥有整个矩阵。

此外,在实际算法中,为了数值稳定性,我们将计算的不是 Softmax(x),而是 Softmax(x – max(x))。我们可以这样做,因为 softmax 对常数平移是不变的。

Triton 实现

现在,我们可以在 Triton 中轻松实现概述的算法,Triton 是一个允许我们用 Python 的便捷性编写高效 GPU 内核的工具。

💡 要了解更多关于 Triton 的信息,请查看他们的官方指南。

教程 – Triton 文档

算法概述

第一步是决定我们将如何分配作业以及每个作业将加载什么数据。根据分块 softmax 算法,每个作业必须能够访问整个序列长度的 K, V。因此,每个作业将按块遍历 K, V。我们没有对 Q 块处理数量的算法限制。因此,每个作业将只加载一个 Q 块并仅与它一起工作——这样我们就可以最大化作业并行性。

内核作业数据管理 | 图片由作者提供

内核作业数据管理 | 图片由作者提供

总结来说,每个作业将加载一个单独的 Q 块,遍历 KV 中的所有块,并存储一个与 Q 块对应的输出块。

内核

剩下的就是编写实际的代码了。让我们先关注核心部分,然后再添加 Triton 特定的模板。

下面是一段带有每行解释的 Triton 伪代码。

看见了吗?很简单!

重要的是,您可以看到一旦我们理解了分块 softmax 的概念,编写这样的事情是多么简单。除此之外,从算法角度来看,没有什么是复杂的。

💥 通过实现 triton 优化,这个核函数可以更快。然而,这超出了本文的范围。

这段伪代码与实际代码非常接近。您可以通过以下链接在我的 GitHub 上找到它。我添加的只是数据管理和 PyTorch 包装器。

kernels/src/self_attention/kernel.py at main · alexdremov/kernels

❗ 如果有什么不清楚的地方,请随时提问。我在评论里 😁。

上面的代码经过广泛测试以匹配 PyTorch 的 scaled_dot_product_attention。您也可以查看测试以了解如何使用编写的内核。

基准测试

虽然我们编写了 Triton 中的内核来提高算法理解,但将性能与原始实现和 PyTorch 的 scaled_dot_product_attention 进行比较很有趣。

不同序列长度的基准实现 | 图片由作者提供

不同序列长度的基准实现 | 图片由作者提供

如预期的那样,Flash Attention 算法在性能上完全优于原始实现。同时,我用虚线标记了导致原始实现出现 CUDA 内存不足错误的长度范围。

我们看到,我们的 Triton 实现略逊于 PyTorch SDPA。但差距并不大。考虑到 PyTorch SDPA 是一个经过良好优化的 CUDA 内核,这是一个不错的结果。

基准测试代码也存放在仓库中。

kernels/benchmark/benchmark_self_attention.py 在 main · alexdremov/kernels

这个故事最初发表在 alexdremov.me 上。去看看吧!(至少,那里的 TEX 看起来更好)

结论

在这篇文章中,我介绍了 Flash Attention 算法的动机以及其算法细节。最后,我们成功地在 Triton 中从头开始实现了它,重现了论文中的速度提升。

我希望这篇文章提高了您对 Flash Attention 的理解。如果您有任何问题,请随时在下面留言。

参考文献

FlashAttention:具有 I/O 感知的快速且内存高效的精确注意力

教程 – Triton 文档

GitHub – alexdremov/kernels:有用的内核集合

posted @ 2026-03-27 10:43  布客飞龙IV  阅读(1)  评论(0)    收藏  举报