FlashAttention简介
前置知识
在GPU进行矩阵运算的时候,内部的运算单元具有和CPU类似的存储金字塔。

如果采用经典的Attention的计算方式,需要保存中间变量S和注意力矩阵O,这样子会产生很大的现存占用,并且这些数据的传输也会占用很多带宽和内存。

FlashAttention采用分块的方式来进行计算,这样子就可以减少中间变量的存储,同时也可以减少数据的传输。

具体的思想是改变Attention的运算顺序,标准是先计算 \(S=QK,O=Softmax(S),R=OV\).
FlashAttention的计算顺序是先计算 \(R=OV,S=QK,O=Softmax(S)\).在这个过程中需要保存一些变量用于最终计算Softmax,并且在计算过程中进行分块,利用SRAM的带块,减少HBM的使用。
具体算法如下(ForWard):


浙公网安备 33010602011771号