cutlass & FA3

GPU mode: Cutlass and FA 3

本次talk的大纲:

  • 复习attention和FA
  • 从高层次理解FA3算法
  • 将算法翻译成cutlass搭建的code

attention机制介绍

$$O=Softmax(QK^T)V$$
attention随着序列长度的变化是二次的scale。

naive的实现受限于内存的读写,会先将$QK^T$的结果存入HBM,之后进行softmax操作时再将其从HBM读取。

GPU的存储模式和内存层级

输入首先存储在HBM中,之后数据会移入计算单元和SRAM来进行计算,最后将结果写入HBM。

如何减少HBM的读和写

使用tiling和recomputation

挑战:

  1. 不将score matrix materialize(完整计算出来)到HBM,只计算部分输入的softmax。
  2. backward的时候没有了score matrix

方法:

  1. 使用tiling和online softmax来进行kernel fusion: 对于Q的一个给定的tile,逐块加载KV的block
  2. 重算: 不去将attn matrix 存储,而是在backward时重算。

FA2在H100上只能有35-40%的利用率

FA3

  1. H100的新指令:

    • wgmma: 更高throughput的mma原语, async, 可以由一个warpgroup(=4个连续的warps)一起执行
    • TMA: 提供gmem和smem之间更快速的loading, async, saves registers
  2. Async

    • builds on async wgmma, tma, transaction barrier
    • inter-warpgroup overlapping: warp-specialization, pingpong
    • intra-warpgroup overlapping: softmax, async matmul
  3. low-precision: fp8, in-kernel V transpose

最基本的思想是利用生产者和消费者模型,生产者主要提供对应的KV相应的分块;消费者就执行对应的计算。

我们再研究下async: overlap gemm和softmax

我们主要想让tensor cores忙的时候也做一些ex2的计算。

inter-warpgroup
img

intra-warpgroup
img

FP8

f8可以达到两倍的wgmma的吞吐,但是tradeoff了精度

fp8 wgmma: 要求操作数 smem tensors是在k-major的

标准的attention通常是bshd;对于gemm0 $QK^T$这是好的;都在d连续;但对于gemm1$PV$,我们需要转置V。解法是在producer warpgroup中transpose V。使用ldsm/stsm等指令,输入自定义的layout以及byte permute

我们也需要reshape scores accumulator的layout;fp32的accumulator不同于fp8 操作数A的layout

注意:我们能使用in-kernel transpose 来进行V的row permutation,避免使用shuffle 指令。

persistent kernels in FA

想法:launch固定数量的CTA,能够overlap 当前work tile的epilogue以及下一个work tile 的 prologue loads.

FA3 for decoding inference

对于decoding来说,q的length为1,但是context length很长。

优化:

  1. 对 KV 的序列长度进行划分;得到并行
  2. GQA pack: 将多个query heads pack为单个的query tile

flash decoding和flash attention的区别是,原始的flash attention直接对Q的长度进行分块,每个CTA计算Q块的结果;而在decoding阶段,Q的长度很小,那么我们对KV的长度分块,最后将结果累加即可。

可能想问原始的flash attention为什么不继续对KV做并行?

​Flash Attention的并行维度已饱和
训练时通常通过数据并行(Batch) + 张量并行(Heads)​ 实现多级并行,占用大量SM资源。对Q分块并行已能有效利用计算单元,无需额外引入KV分块。

​Flash Decoding需应对长KV缓存
生成式推理的KV_len可能远超GPU核心数(如数万token),必须分块并行处理才能避免单CTA过载,同时隐藏内存访问延迟。

posted @ 2025-03-28 16:43  xwher  阅读(89)  评论(0)    收藏  举报