cute 教程 04 算法
cute 教程 04 Algorithms
在这篇博客中,我们总结对Tensor操作的interface和实现.
copy
cute的copy算法主要是将src的元素拷贝到dst的元素
接口和特化
一个Tensor封装了数据类型,数据位置以及编译期已知的tensor的shape和stride;因此copy
能够基于参数类型,进行dispatch来使用任何不同的异步或同步的硬件拷贝指令。copy算法有两个主要的重载, 第一个重载的参数仅仅是src和dst的tensor; 第二个重载还有Copy_Atom
的参数,这个参数可以使得调用者override默认实现
依赖于参数类型的并行和同步
copy算法可能在每个线程中是顺序的,或者可能在一些线程集合上是并行的(cluster, block)
如果copy是并行的, 那么参与这个copy的线程集合可能需要同步来使得所有线程的copy操作是完成的;例如,如果参与的线程组成一个线程块,那么用户应该调用__syncthreads()
或者cooperative group.
copy算法也可能使用异步的拷贝,例如cp.async
;在这种情况下,用户需要执行额外的同步。更加优化的gemm实现会使用pipeline技术会使用其他操作来overlap async的copy操作。
最通用的copy实现
void copy(Tensor const& src,
Tensor & dst) {
for(int i=0; i < size(dst); ++i) {
dst(i) = src(i);
}
}
一些合理的架构无关的优化如下:
- 如果两个tensor有已知的memory spaces,并且这个space有优化的访问指令(如
cp.async
),那么会dispatch到自定义指令上 - 两个
Tensor
有静态的layout,并且element vectorization时有效的,那么将会向量化。 - 如果可能的话,验证使用的copy指令对src和dst tensor是合适的
copy_if
除了copy的参数,还有一个prediction Tensor,这个tensor有与dst和src相同的shape;只有prediction tensor中的元素非0时才会拷贝.
gemm
gemm 主要有三个参数: A
B
C
(V) x (V) => (V)
: 逐元素的乘法, dispatch fma or mma(M) x (N) => (M,N)
: vector的outer product, dispatch 4 with V = 1(M,K) x (N,K) => (M,N)
: 对每个k,dispatch 2(V,M) x (V,N) => (V,M,N)
: batched outer product;对每个m n,dispatch 1(V,M,K) x (V,N,K) => (V,M,N)
: 对每个K dispatch 4.
注意这个算法也有MMA_ATOM
,可以override默认的FMA
指令
axpby
alpha times X Plus b
fill
fill a given scalar value.
clear
fill 0