cute 教程 04 算法

cute 教程 04 Algorithms

在这篇博客中,我们总结对Tensor操作的interface和实现.

copy

cute的copy算法主要是将src的元素拷贝到dst的元素

接口和特化

一个Tensor封装了数据类型,数据位置以及编译期已知的tensor的shape和stride;因此copy能够基于参数类型,进行dispatch来使用任何不同的异步或同步的硬件拷贝指令。copy算法有两个主要的重载, 第一个重载的参数仅仅是src和dst的tensor; 第二个重载还有Copy_Atom的参数,这个参数可以使得调用者override默认实现

依赖于参数类型的并行和同步

copy算法可能在每个线程中是顺序的,或者可能在一些线程集合上是并行的(cluster, block)

如果copy是并行的, 那么参与这个copy的线程集合可能需要同步来使得所有线程的copy操作是完成的;例如,如果参与的线程组成一个线程块,那么用户应该调用__syncthreads()或者cooperative group.

copy算法也可能使用异步的拷贝,例如cp.async;在这种情况下,用户需要执行额外的同步。更加优化的gemm实现会使用pipeline技术会使用其他操作来overlap async的copy操作。

最通用的copy实现

void copy(Tensor const& src,
    Tensor & dst) {
        for(int i=0; i < size(dst); ++i) {
            dst(i) = src(i);
        }
    }

一些合理的架构无关的优化如下:

  1. 如果两个tensor有已知的memory spaces,并且这个space有优化的访问指令(如 cp.async),那么会dispatch到自定义指令上
  2. 两个Tensor有静态的layout,并且element vectorization时有效的,那么将会向量化。
  3. 如果可能的话,验证使用的copy指令对src和dst tensor是合适的

copy_if

除了copy的参数,还有一个prediction Tensor,这个tensor有与dst和src相同的shape;只有prediction tensor中的元素非0时才会拷贝.

gemm

gemm 主要有三个参数: A B C

  1. (V) x (V) => (V): 逐元素的乘法, dispatch fma or mma
  2. (M) x (N) => (M,N): vector的outer product, dispatch 4 with V = 1
  3. (M,K) x (N,K) => (M,N): 对每个k,dispatch 2
  4. (V,M) x (V,N) => (V,M,N): batched outer product;对每个m n,dispatch 1
  5. (V,M,K) x (V,N,K) => (V,M,N): 对每个K dispatch 4.

注意这个算法也有MMA_ATOM,可以override默认的FMA指令

axpby

alpha times X Plus b

fill

fill a given scalar value.

clear

fill 0

posted @ 2025-03-28 21:54  xwher  阅读(60)  评论(0)    收藏  举报