FlashAttention简介

前置知识

在GPU进行矩阵运算的时候,内部的运算单元具有和CPU类似的存储金字塔。

img

如果采用经典的Attention的计算方式,需要保存中间变量S和注意力矩阵O,这样子会产生很大的现存占用,并且这些数据的传输也会占用很多带宽和内存。

img

FlashAttention采用分块的方式来进行计算,这样子就可以减少中间变量的存储,同时也可以减少数据的传输。

img

具体的思想是改变Attention的运算顺序,标准是先计算 \(S=QK,O=Softmax(S),R=OV\).
FlashAttention的计算顺序是先计算 \(R=OV,S=QK,O=Softmax(S)\).在这个过程中需要保存一些变量用于最终计算Softmax,并且在计算过程中进行分块,利用SRAM的带块,减少HBM的使用。
具体算法如下(ForWard):

img

img

posted @ 2023-12-16 15:47  chenfengshijie  阅读(236)  评论(0)    收藏  举报