CUTLASS learning notes
References
- Implementing flash-attention with CUTLASS https://arxiv.org/html/2312.11918v1
- Tensor Core GEMM demo https://github.com/NVIDIA-developer-blog/code-samples/blob/master/posts/tensor-cores/simpleTensorCoreGEMM.cu
Demos to read
- https://github.com/NVIDIA/cutlass/blob/main/examples/12_gemm_bias_relu/gemm_bias_relu.cu
- https://github.com/NVIDIA/cutlass/blob/main/examples/35_gemm_softmax/gemm_softmax.cu
- Finish working through a demo of how to get CUTLASS to handle flash(-attention)
MMA abstraction

Atom tile and value tile
In earlier CUTLASS/CuTe terminology the atom tile is also called thr_layout and the value tile is called val_layout.
The atom tile specifies how many times the atomic MMA op above is replicated along the M and N directions. On Tensor Cores each op is executed by one warp, i.e. 32 threads; replicating twice along M and twice along N gives 4 warps, 128 threads in total, which is why it is also called thr_layout.
The value tile specifies how many additional times the expanded atom repeats the computation along M and N. Because this is pure repetition (an internal loop), it does not consume more threads; it only enlarges the matrix tile being processed.

MMA_Atom describes the raw hardware capability; it handles only the single-instruction computation shown in the figure above.
In the middle of the diagram is the expansion just described: the solid lines are the warp-level expansion of the atom (2x along M and 2x along N, i.e. 4 warps of 32 threads each, 128 threads in total); the dashed lines are the value-layout-level repetition, which does not increase the thread count. On top of this a TiledMMA can be defined; each TiledMMA handles A: (32, 16), B: (16, 32), C: (32, 32).
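A minimal host-side sketch of this construction, mirroring Demo 1 below (the op name and the two layouts are taken from that demo; whether the third argument of make_tiled_mma is read as a value layout or a permutation tile depends on the CuTe version):

#include <cstdio>
#include <cute/tensor.hpp>

using namespace cute;

int main() {
  using mma_op     = SM80_16x8x16_F16F16F16F16_TN;  // one warp: C(16x8) += A(16x16) * B(16x8)
  using mma_traits = MMA_Traits<mma_op>;
  using mma_atom   = MMA_Atom<mma_traits>;

  // atom tile (thr_layout): replicate the warp 2x along M and 2x along N -> 4 warps
  // value tile (val_layout): repeat 2x along N inside each warp -> no extra threads
  using MMA = decltype(make_tiled_mma(mma_atom{},
                                      make_layout(Shape<_2, _2, _1>{}),
                                      make_layout(Shape<_1, _2, _1>{})));

  printf("threads per TiledMMA: %d\n", int(size(MMA{})));  // 128 = 4 warps * 32 threads
  print(MMA{});  // dumps the thread/value layouts of the TiledMMA
  // resulting per-TiledMMA tile: M = 16*2 = 32, N = 8*2*2 = 32, K = 16,
  // i.e. A: (32,16), B: (16,32), C: (32,32) as stated above
  return 0;
}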
Demo 1 walkthrough
#include <cuda.h>
#include <cublas_v2.h>
#include <stdlib.h>

#include <cute/tensor.hpp>

template <typename T>
void gen_rand_data(T *data, int n);

template <typename T, int kTileM, int kTileN, int kTileK, typename TiledMMA>
__global__ void gemm_simple(T *Cptr, const T *Aptr, const T *Bptr, int m, int n, int k) {
  using namespace cute;

  // wrap the raw pointers as Tensors: row-major A(m,k), B(n,k), C(m,n)
  Tensor A = make_tensor(make_gmem_ptr(Aptr), make_shape(m, k), make_stride(k, Int<1>{}));
  Tensor B = make_tensor(make_gmem_ptr(Bptr), make_shape(n, k), make_stride(k, Int<1>{}));
  Tensor C = make_tensor(make_gmem_ptr(Cptr), make_shape(m, n), make_stride(n, Int<1>{}));

  int ix = blockIdx.x;
  int iy = blockIdx.y;

  // tile by (kTileM, kTileN, kTileK); this block owns the (iy, ix) tile
  Tensor gA = local_tile(A, make_tile(Int<kTileM>{}, Int<kTileK>{}), make_coord(iy, _));  // (kTileM, kTileK, num_tile_k)
  Tensor gB = local_tile(B, make_tile(Int<kTileN>{}, Int<kTileK>{}), make_coord(ix, _));  // (kTileN, kTileK, num_tile_k)
  Tensor gC = local_tile(C, make_tile(Int<kTileM>{}, Int<kTileN>{}), make_coord(iy, ix)); // (kTileM, kTileN)

  // block-level MMA structure
  TiledMMA tiled_mma;
  // this thread's slice of the MMA
  auto thr_mma = tiled_mma.get_slice(threadIdx.x);

  // the range of global memory this thread touches, used for loading data
  auto tAgA = thr_mma.partition_A(gA);  // (MMA, MMA_M, MMA_K, num_tile_k)
  auto tBgB = thr_mma.partition_B(gB);  // (MMA, MMA_N, MMA_K, num_tile_k)
  auto tCgC = thr_mma.partition_C(gC);  // (MMA, MMA_M, MMA_N)

  // register fragments for the actual computation; fragment sizes match the Tensor Core MMA
  auto tArA = thr_mma.partition_fragment_A(gA(_, _, 0));  // (MMA, MMA_M, MMA_K)
  auto tBrB = thr_mma.partition_fragment_B(gB(_, _, 0));  // (MMA, MMA_N, MMA_K)
  auto tCrC = thr_mma.partition_fragment_C(gC(_, _));     // (MMA, MMA_M, MMA_N)
  clear(tCrC);

  int num_tile_k = size<2>(gA);
  // loop over the K tiles
#pragma unroll 1
  for (int itile = 0; itile < num_tile_k; ++itile) {
    // load the current K tile from global memory into registers, then accumulate
    cute::copy(tAgA(_, _, _, itile), tArA);
    cute::copy(tBgB(_, _, _, itile), tBrB);

    cute::gemm(tiled_mma, tCrC, tArA, tBrB, tCrC);
  }

  cute::copy(tCrC, tCgC);
}

int main() {
  srand(10086);

  using T = cute::half_t;
  using namespace cute;

  T *Cptr;
  T *Aptr;
  T *Bptr;

  int m = 81920;
  int n = 256;
  int k = 256;

  cudaMalloc(&Cptr, sizeof(T) * m * n);
  cudaMalloc(&Aptr, sizeof(T) * m * k);
  cudaMalloc(&Bptr, sizeof(T) * k * n);

  T *Aptr_host;
  T *Bptr_host;
  Aptr_host = (T*)malloc(sizeof(T) * m * k);
  Bptr_host = (T*)malloc(sizeof(T) * n * k);
  gen_rand_data(Aptr_host, m * k);
  gen_rand_data(Bptr_host, n * k);

  cudaMemcpy(Aptr, Aptr_host, sizeof(T) * m * k, cudaMemcpyHostToDevice);
  cudaMemcpy(Bptr, Bptr_host, sizeof(T) * n * k, cudaMemcpyHostToDevice);

  using mma_op = SM80_16x8x16_F16F16F16F16_TN;
  using mma_traits = MMA_Traits<mma_op>;
  using mma_atom = MMA_Atom<mma_traits>;

  using MMA = decltype(make_tiled_mma(mma_atom{},
                                      make_layout(Shape<_2, _2, _1>{}),    // atom tile (thr_layout)
                                      make_layout(Shape<_1, _2, _1>{})));  // value tile (val_layout)

  constexpr int kTileM = 128;
  constexpr int kTileN = 128;
  constexpr int kTileK = 32;

  // size(MMA{}) = 128, which happens to equal kTileM
  // where does 128 come from? one warp = 32 threads; the atom tile expands 2x along
  // both M and N, so 32 * 4 = 128
  printf("mma %d\n", (int)size(MMA{}));

  dim3 block(size(MMA{}));
  printf("block %d\n", block.x);
  dim3 grid(n / kTileN, m / kTileM);
  for (int i = 0; i < 100; ++i) {
    gemm_simple<T, kTileM, kTileN, kTileK, MMA><<<grid, block>>>(Cptr, Aptr, Bptr, m, n, k);
  }
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  printf("err = %d, str = %s\n", err, cudaGetErrorString(err));

  // cuBLAS reference
  T *Cptr_cublas;
  cudaMalloc(&Cptr_cublas, sizeof(T) * m * n);

  cublasHandle_t handle;
  cublasCreate(&handle);

  half alpha = half(1.f);
  half beta = half(0.f);
  for (int i = 0; i < 100; ++i) {
    cublasStatus_t ret = cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                                     n, m, k,
                                     &alpha,
                                     (half *)Bptr, k,
                                     (half *)Aptr, k,
                                     &beta,
                                     (half *)Cptr_cublas, n);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      printf("blas err = %d, str = %s\n", ret, cublasGetStatusString(ret));
    }
  }
  cudaDeviceSynchronize();
  err = cudaGetLastError();
  printf("err = %d, str = %s\n", err, cudaGetErrorString(err));

  T *Cptr_host;
  T *Cptr_cublas_host;
  Cptr_host = (T*)malloc(sizeof(T) * m * n);
  Cptr_cublas_host = (T*)malloc(sizeof(T) * m * n);

  // compare
  cudaMemcpy(Cptr_host, Cptr, sizeof(T) * m * n, cudaMemcpyDeviceToHost);
  cudaMemcpy(Cptr_cublas_host, Cptr_cublas, sizeof(T) * m * n, cudaMemcpyDeviceToHost);

  float threshold = 0.1;
  for (int i = 0; i < m * n; ++i) {
    float v1 = Cptr_host[i];
    float v2 = Cptr_cublas_host[i];
    if (fabs(v2 - v1) > threshold) {
      printf("v1 = %f, v2 = %f\n", v1, v2);
    }
  }

  Tensor tensor_C = make_tensor(Cptr_host, make_shape(m, n), make_stride(n, 1));
  Tensor tensor_C_cublas = make_tensor(Cptr_cublas_host, make_shape(m, n), make_stride(n, 1));

  auto tile = make_tile(8, 8);
  auto coor = make_coord(0, 0);
  Tensor tc1 = local_tile(tensor_C, tile, coor);
  Tensor tc1_cublas = local_tile(tensor_C_cublas, tile, coor);

  print_tensor(tc1);
  print_tensor(tc1_cublas);
}

template <typename T>
void gen_rand_data(T *data, int n) {
  for (int i = 0; i < n; ++i) {
    float v = (rand() % 200 - 100) * 0.01;
    data[i] = v;
  }
}
Demo 2 walkthrough
#include <cublas_v2.h>
#include <cuda.h>
#include <stdarg.h>
#include <stdio.h>

#include <cute/tensor.hpp>

#include "detail/cublaslt-gemm.h"
#include "detail/data.h"

template <typename Config>
__global__ void /* __launch_bounds__(128, 1) */
gemm_multi_stage(void *Dptr, const void *Aptr, const void *Bptr, int m, int n, int k) {
  using namespace cute;
  using X = Underscore;

  using T = typename Config::T;
  using SmemLayoutA = typename Config::SmemLayoutA;
  using SmemLayoutB = typename Config::SmemLayoutB;
  using SmemLayoutC = typename Config::SmemLayoutC;
  using TiledMMA = typename Config::MMA;

  using S2RCopyAtomA = typename Config::S2RCopyAtomA;
  using S2RCopyAtomB = typename Config::S2RCopyAtomB;
  using G2SCopyA = typename Config::G2SCopyA;
  using G2SCopyB = typename Config::G2SCopyB;
  using R2SCopyAtomC = typename Config::R2SCopyAtomC;
  using S2GCopyAtomC = typename Config::S2GCopyAtomC;
  using S2GCopyC = typename Config::S2GCopyC;

  constexpr int kTileM = Config::kTileM;
  constexpr int kTileN = Config::kTileN;
  constexpr int kTileK = Config::kTileK;
  constexpr int kStage = Config::kStage;

  extern __shared__ T shm_data[];

  T *Ashm = shm_data;
  T *Bshm = shm_data + cute::cosize(SmemLayoutA{});

  int idx = threadIdx.x;
  int ix = blockIdx.x;
  int iy = blockIdx.y;

  // use Tensor notation to represent device pointer + dimension
  Tensor A = make_tensor(make_gmem_ptr((T *)Aptr), make_shape(m, k),
                         make_stride(k, Int<1>{}));  // (M, K)
  Tensor B = make_tensor(make_gmem_ptr((T *)Bptr), make_shape(n, k),
                         make_stride(k, Int<1>{}));  // (N, K)
  Tensor D = make_tensor(make_gmem_ptr((T *)Dptr), make_shape(m, n),
                         make_stride(n, Int<1>{}));  // (M, N)

  // slice the tensor to a small one which is used by the current thread block
  Tensor gA = local_tile(A, make_tile(Int<kTileM>{}, Int<kTileK>{}),
                         make_coord(iy, _));  // (kTileM, kTileK, k)
  Tensor gB = local_tile(B, make_tile(Int<kTileN>{}, Int<kTileK>{}),
                         make_coord(ix, _));  // (kTileN, kTileK, k)
  Tensor gD = local_tile(D, make_tile(Int<kTileM>{}, Int<kTileN>{}),
                         make_coord(iy, ix));  // (kTileM, kTileN)

  // shared memory
  auto sA = make_tensor(make_smem_ptr(Ashm), SmemLayoutA{});  // (kTileM, kTileK, kStage)
  auto sB = make_tensor(make_smem_ptr(Bshm), SmemLayoutB{});  // (kTileN, kTileK, kStage)

  // dispatch TileA/TileB/TileC mma tensors into thread fragments via the partition method
  TiledMMA tiled_mma;
  auto thr_mma = tiled_mma.get_slice(idx);
  auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0));  // (MMA, MMA_M, MMA_K)
  auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0));  // (MMA, MMA_N, MMA_K)
  auto tCrD = thr_mma.partition_fragment_C(gD);           // (MMA, MMA_M, MMA_N)

  // fill zero for accumulator
  clear(tCrD);

  // gmem -cp.async-> shm -ldmatrix-> reg
  auto s2r_tiled_copy_a = make_tiled_copy_A(S2RCopyAtomA{}, tiled_mma);
  auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(idx);
  auto tAsA = s2r_thr_copy_a.partition_S(sA);      // ? (CPY, CPY_M, CPY_K, kStage)
  auto tCrA_view = s2r_thr_copy_a.retile_D(tCrA);  // ? (CPY, CPY_M, CPY_K)

  auto s2r_tiled_copy_b = make_tiled_copy_B(S2RCopyAtomB{}, tiled_mma);
  auto s2r_thr_copy_b = s2r_tiled_copy_b.get_slice(idx);
  auto tBsB = s2r_thr_copy_b.partition_S(sB);      // ? (CPY, CPY_M, CPY_K, kStage)
  auto tCrB_view = s2r_thr_copy_b.retile_D(tCrB);  // ? (CPY, CPY_M, CPY_K)

  G2SCopyA g2s_tiled_copy_a;
  auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(idx);
  auto tAgA_copy = g2s_thr_copy_a.partition_S(gA);  // (CPY, CPY_M, CPY_K, k)
  auto tAsA_copy = g2s_thr_copy_a.partition_D(sA);  // (CPY, CPY_M, CPY_K, kStage)

  G2SCopyB g2s_tiled_copy_b;
  auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(idx);
  auto tBgB_copy = g2s_thr_copy_b.partition_S(gB);  // (CPY, CPY_N, CPY_K, k)
  auto tBsB_copy = g2s_thr_copy_b.partition_D(sB);  // (CPY, CPY_N, CPY_K, kStage)

  int itile_to_read = 0;
  int ismem_read = 0;
  int ismem_write = 0;

  // submit kStage - 1 tiles
  // gmem -> shm
#pragma unroll
  for (int istage = 0; istage < kStage - 1; ++istage) {
    cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, istage),
               tAsA_copy(_, _, _, istage));
    cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, istage),
               tBsB_copy(_, _, _, istage));
    cp_async_fence();

    ++itile_to_read;
    ++ismem_write;
  }

  // wait until one submitted gmem->smem copy is done
  cp_async_wait<kStage - 2>();
  __syncthreads();

  int ik = 0;
  // smem -> reg
  cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik, ismem_read), tCrA_view(_, _, ik));
  cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik, ismem_read), tCrB_view(_, _, ik));

  // loop over k: i. load tile, ii. mma
  int ntile = k / kTileK;
#pragma unroll 1
  for (int itile = 0; itile < ntile; ++itile) {
    int nk = size<2>(tCrA);

#pragma unroll
    for (int ik = 0; ik < nk; ++ik) {
      int ik_next = (ik + 1) % nk;

      if (ik == nk - 1) {
        cp_async_wait<kStage - 2>();
        __syncthreads();

        ismem_read = (ismem_read + 1) % kStage;
      }

      // shm -> reg: s[itile][ik + 1] -> r[ik + 1]
      cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik_next, ismem_read),
                 tCrA_view(_, _, ik_next));
      cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik_next, ismem_read),
                 tCrB_view(_, _, ik_next));

      if (ik == 0) {
        if (itile_to_read < ntile) {
          cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, itile_to_read),
                     tAsA_copy(_, _, _, ismem_write));
          cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, itile_to_read),
                     tBsB_copy(_, _, _, ismem_write));

          ++itile_to_read;
          ismem_write = (ismem_write + 1) % kStage;
        }

        cp_async_fence();
      }

      cute::gemm(tiled_mma, tCrD, tCrA(_, _, ik), tCrB(_, _, ik), tCrD);
    }  // for ik
  }    // itile

  // use less shared memory as a scratchpad tile to use large wide instructions
  // Dreg -> shm -> reg -> global
  auto sC = make_tensor(sA(_, _, ismem_read).data(), SmemLayoutC{});

  auto r2s_tiled_copy_c = make_tiled_copy_C(R2SCopyAtomC{}, tiled_mma);
  auto r2s_thr_copy_c = r2s_tiled_copy_c.get_slice(idx);
  auto tCrC_r2s = r2s_thr_copy_c.retile_S(tCrD);   // (CPY, CPY_M, CPY_N)
  auto tCsC_r2s = r2s_thr_copy_c.partition_D(sC);  // (CPY, _1, _1, pipe)

  S2GCopyC s2g_tiled_copy_c;
  auto s2g_thr_copy_c = s2g_tiled_copy_c.get_thread_slice(idx);
  auto tCsC_s2g = s2g_thr_copy_c.partition_S(sC);  // (CPY, _1, _1, pipe)
  auto tCgC_s2g = s2g_thr_copy_c.partition_D(gD);  // (CPY, CPY_M, CPY_N)

  auto tCgC_s2gx = group_modes<1, 3>(tCgC_s2g);  // (CPY_, CPY_MN)
  auto tCrC_r2sx = group_modes<1, 3>(tCrC_r2s);  // (CPY_, CPY_MN)

  int step = size<3>(tCsC_r2s);  // pipe
#pragma unroll
  for (int i = 0; i < size<1>(tCrC_r2sx); i += step) {
    // reg -> shm
#pragma unroll
    for (int j = 0; j < step; ++j) {
      // we add a temp tensor to cope with the accumulator/output data type difference
      auto t = make_tensor_like<T>(tCrC_r2sx(_, i + j));
      cute::copy(tCrC_r2sx(_, i + j), t);

      cute::copy(r2s_tiled_copy_c, t, tCsC_r2s(_, 0, 0, j));
    }
    __syncthreads();

#pragma unroll
    // shm -> global
    for (int j = 0; j < step; ++j) {
      cute::copy(s2g_tiled_copy_c, tCsC_s2g(_, 0, 0, j), tCgC_s2gx(_, i + j));
    }

    __syncthreads();
  }
}

namespace config {

using namespace cute;

template <typename T_, int kTileM_ = 128, int kTileN_ = 128, int kTileK_ = 32,
          int kStage_ = 5, int kSmemLayoutCBatch_ = 2,
          typename ComputeType = T_>
struct GemmConfig {
  using T = T_;

  // tile configuration
  static constexpr int kTileM = kTileM_;
  static constexpr int kTileN = kTileN_;
  static constexpr int kTileK = kTileK_;
  static constexpr int kStage = kStage_;
  static constexpr int kSmemLayoutCBatch = kSmemLayoutCBatch_;

  static constexpr int kShmLoadSwizzleM = 3;
  static constexpr int kShmLoadSwizzleS = 3;
  static constexpr int kShmLoadSwizzleB = 3;

  using SmemLayoutAtom = decltype(composition(
      Swizzle<kShmLoadSwizzleB, kShmLoadSwizzleM, kShmLoadSwizzleS>{},
      make_layout(make_shape(Int<8>{}, Int<kTileK>{}),
                  make_stride(Int<kTileK>{}, Int<1>{}))));
  using SmemLayoutA = decltype(
      tile_to_shape(SmemLayoutAtom{},
                    make_shape(Int<kTileM>{}, Int<kTileK>{}, Int<kStage>{})));
  using SmemLayoutB = decltype(
      tile_to_shape(SmemLayoutAtom{},
                    make_shape(Int<kTileN>{}, Int<kTileK>{}, Int<kStage>{})));

  using mma_op = SM80_16x8x16_F16F16F16F16_TN;
  using mma_traits = MMA_Traits<mma_op>;
  using mma_atom = MMA_Atom<mma_traits>;

  static constexpr int kMmaEURepeatM = 2;
  static constexpr int kMmaEURepeatN = 2;
  static constexpr int kMmaEURepeatK = 1;

  using mma_atom_shape = mma_traits::Shape_MNK;
  static constexpr int kMmaPM = 1 * kMmaEURepeatM * get<0>(mma_atom_shape{});
  static constexpr int kMmaPN = 2 * kMmaEURepeatN * get<1>(mma_atom_shape{});
  static constexpr int kMmaPK = 1 * kMmaEURepeatK * get<2>(mma_atom_shape{});

  using MMA_EU_RepeatT = decltype(make_layout(make_shape(
      Int<kMmaEURepeatM>{}, Int<kMmaEURepeatN>{}, Int<kMmaEURepeatK>{})));
  using MMA_P_T = Tile<Int<kMmaPM>, Int<kMmaPN>, Int<kMmaPK>>;

  using MMA = decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_P_T{}));

  using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
  using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
  using g2s_copy_atom = Copy_Atom<g2s_copy_traits, T>;

  using G2SCopyA =
      decltype(make_tiled_copy(g2s_copy_atom{},
                               make_layout(make_shape(Int<32>{}, Int<4>{}),
                                           make_stride(Int<4>{}, Int<1>{})),
                               make_layout(make_shape(Int<1>{}, Int<8>{}))));
  using G2SCopyB = G2SCopyA;

  // shared memory to register copy
  using s2r_copy_op = SM75_U32x4_LDSM_N;
  using s2r_copy_traits = Copy_Traits<s2r_copy_op>;
  using s2r_copy_atom = Copy_Atom<s2r_copy_traits, T>;

  using S2RCopyAtomA = s2r_copy_atom;
  using S2RCopyAtomB = s2r_copy_atom;

  // epilogue: register to global via shared memory
  using SmemLayoutAtomC = decltype(composition(
      Swizzle<2, 3, 3>{},
      make_layout(make_shape(Int<kMmaPM>{}, Int<kMmaPN>{}),
                  make_stride(Int<kMmaPN>{}, Int<1>{}))));
  using SmemLayoutC = decltype(tile_to_shape(
      SmemLayoutAtomC{},
      make_shape(Int<kMmaPM>{}, Int<kMmaPN>{}, Int<kSmemLayoutCBatch>{})));

  static_assert(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) >=
                    size(SmemLayoutC{}),
                "C shared memory request is larger than A's one pipe");

  using R2SCopyAtomC = Copy_Atom<UniversalCopy<int>, T>;

  using S2GCopyAtomC = Copy_Atom<UniversalCopy<cute::uint128_t>, T>;
  using S2GCopyC =
      decltype(make_tiled_copy(S2GCopyAtomC{},
                               make_layout(make_shape(Int<32>{}, Int<4>{}),
                                           make_stride(Int<4>{}, Int<1>{})),
                               make_layout(make_shape(Int<1>{}, Int<8>{}))));

  static constexpr int kThreadNum = size(MMA{});
  static constexpr int shm_size_AB =
      cute::cosize(SmemLayoutA{}) + cute::cosize(SmemLayoutB{});
  static constexpr int shm_size_C = cute::cosize(SmemLayoutC{});

  static constexpr int kShmSize =
      cute::max(shm_size_AB, shm_size_C) * sizeof(T);
};

}  // namespace config

int main(int argc, char *argv[]) {
  using T = cute::half_t;
  using namespace cute;
  using X = Underscore;

  srand(10086);

  cublasHandle_t handle;
  cublasCreate(&handle);

  int cublas_version;
  cublasGetVersion_v2(handle, &cublas_version);
  printf("cuBLAS version: %d\n", cublas_version);

  // defaults
  int M = 81920;
  int N = 256;
  int K = 256;

  int enable_cpu = 0;
  int enable_cublaslt = 1;
  int nt = 11;

  using ComputeType = T;

  T *Aptr;
  T *Bptr;
  T *Dptr;
  T *Dptr_cublas;
  T *Dptr_cublaslt;

  T *Aptr_host;
  T *Bptr_host;
  T *Dptr_host;
  T *Dptr_host_cpu;
  T *Dptr_host_blas;
  T *Dptr_host_cublaslt;

  Aptr_host = (T *)malloc(sizeof(T) * M * K);
  Bptr_host = (T *)malloc(sizeof(T) * N * K);
  Dptr_host = (T *)malloc(sizeof(T) * M * N);
  Dptr_host_cpu = (T *)malloc(sizeof(T) * M * N);
  Dptr_host_blas = (T *)malloc(sizeof(T) * M * N);
  Dptr_host_cublaslt = (T *)malloc(sizeof(T) * M * N);

  cudaMalloc(&Aptr, sizeof(T) * M * K);
  cudaMalloc(&Bptr, sizeof(T) * N * K);
  cudaMalloc(&Dptr, sizeof(T) * M * N);
  cudaMalloc(&Dptr_cublas, sizeof(T) * M * N);
  cudaMalloc(&Dptr_cublaslt, sizeof(T) * M * N);

  auto tA = make_tensor(Aptr_host, make_shape(M, K), make_stride(K, 1));
  auto tB = make_tensor(Bptr_host, make_shape(N, K), make_stride(K, 1));
  auto tD = make_tensor(Dptr_host, make_shape(M, N), make_stride(N, 1));

  cpu_rand_data(&tA);
  cpu_rand_data(&tB);
  clear(tD);

  cudaMemcpy(Aptr, Aptr_host, sizeof(T) * M * K, cudaMemcpyHostToDevice);
  cudaMemcpy(Bptr, Bptr_host, sizeof(T) * N * K, cudaMemcpyHostToDevice);
  cudaMemcpy(Dptr, Dptr_host, sizeof(T) * M * N, cudaMemcpyHostToDevice);
  cudaMemset(Dptr_cublas, 0, sizeof(T) * M * N);
  cudaMemset(Dptr_cublaslt, 0, sizeof(T) * M * N);

  CublasLtGemm<T, ComputeType> cublaslt_gemm;
  if (enable_cublaslt) {
    cublaslt_gemm.init(Dptr_cublaslt, Bptr, Aptr, N, M, K);
  }

  config::GemmConfig<T, 128, 128, 32, 3> gemm_config;
  print(typename decltype(gemm_config)::MMA{});

  dim3 block = gemm_config.kThreadNum;
  dim3 grid((N + gemm_config.kTileN - 1) / gemm_config.kTileN,
            (M + gemm_config.kTileM - 1) / gemm_config.kTileM);
  int shm_size = gemm_config.kShmSize;

  half alpha = 1.f;
  half beta = 0.f;
  for (int it = 0; it < nt; ++it) {
    // cuBLAS
    cudaMemset(Dptr_cublas, 0, sizeof(T) * M * N);
    cublasStatus_t ret = cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, N, M, K,
                                     &alpha, (half *)Bptr, K, (half *)Aptr, K,
                                     &beta, (half *)Dptr_cublas, N);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      printf("cublas err = %d, str = %s\n", ret, cublasGetStatusString(ret));
    }

    if (enable_cublaslt) {
      cudaMemset(Dptr_cublaslt, 0, sizeof(T) * M * N);
      cublaslt_gemm.run();
    }

    // multi-stage
    cudaMemset(Dptr, 0, sizeof(T) * M * N);
    cudaFuncSetAttribute(gemm_multi_stage<decltype(gemm_config)>,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
    gemm_multi_stage<decltype(gemm_config)>
        <<<grid, block, shm_size>>>(Dptr, Aptr, Bptr, M, N, K);
  }

  cudaMemcpy(Dptr_host, Dptr, sizeof(T) * M * N, cudaMemcpyDeviceToHost);
  cudaMemcpy(Dptr_host_blas, Dptr_cublas, sizeof(T) * M * N,
             cudaMemcpyDeviceToHost);
  cudaMemcpy(Dptr_host_cublaslt, Dptr_cublaslt, sizeof(T) * M * N,
             cudaMemcpyDeviceToHost);
  cudaDeviceSynchronize();

  auto err = cudaGetLastError();
  printf("block = (%d, %d), grid = (%d, %d), shm = %d\n", block.x, block.y,
         grid.x, grid.y, shm_size);
  if (err == cudaSuccess) {
    printf("err = %d, str = %s\n", err, cudaGetErrorString(err));
  } else {
    printf_fail("err = %d, str = %s\n", err, cudaGetErrorString(err));
  }

  gpu_compare(Dptr, Dptr_cublas, M * N);
  if (enable_cublaslt) {
    gpu_compare(Dptr, Dptr_cublaslt, M * N);
  }

  auto tD_host = make_tensor(Dptr_host, make_shape(M, N), make_stride(N, 1));
  auto tD_host_cpu = make_tensor(Dptr_host_cpu, make_shape(M, N), make_stride(N, 1));
  auto tD_host_blas = make_tensor(Dptr_host_blas, make_shape(M, N), make_stride(N, 1));
  auto tD_host_cublaslt = make_tensor(Dptr_host_cublaslt, make_shape(M, N), make_stride(N, 1));

  if (enable_cpu) {
    cpu_gemm(&tD_host_cpu, tA, tB);
    cpu_compare(tD_host_cpu, tD_host, 0.1f);
  }

  auto tile = make_tile(min(8, M), min(8, N));
  auto t32x32 = local_tile(tD_host, tile, make_coord(0, 0));
  auto t32x32_cpu = local_tile(tD_host_cpu, tile, make_coord(0, 0));
  auto t32x32_blas = local_tile(tD_host_blas, tile, make_coord(0, 0));
  auto t32x32_cublaslt = local_tile(tD_host_cublaslt, tile, make_coord(0, 0));

  printf("M = %d, N = %d, K = %d\n", M, N, K);
  printf("our-impl:\n");
  print_tensor(t32x32);
  if (enable_cpu) {
    printf("cpu:\n");
    print_tensor(t32x32_cpu);
  }

  printf("cublas:\n");
  print_tensor(t32x32_blas);
  if (enable_cublaslt) {
    printf("cublaslt:\n");
    print_tensor(t32x32_cublaslt);
  }
}
Key points to sort out:
- the syntax and layouts involved in the shared-memory and register copies (see the layout sketch after this list)
- the pipeline diagram (TODO)
- cp_async_fence() / cp_async_wait<kStage - 2>()
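For the first bullet, here is a minimal host-side sketch that only instantiates the swizzled shared-memory layout from Demo 2's GemmConfig and prints it; the sizes (kTileM = 128, kTileK = 32, kStage = 3) are just the values the demo uses:

#include <cstdio>
#include <cute/tensor.hpp>

using namespace cute;

int main() {
  constexpr int kTileM = 128, kTileK = 32, kStage = 3;  // Demo 2's tile configuration

  // Swizzle<3,3,3> composed onto an 8 x kTileK row-major atom...
  using SmemLayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{},
      make_layout(make_shape(Int<8>{}, Int<kTileK>{}),
                  make_stride(Int<kTileK>{}, Int<1>{}))));

  // ...tiled out to the full (kTileM, kTileK, kStage) A buffer
  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtom{},
      make_shape(Int<kTileM>{}, Int<kTileK>{}, Int<kStage>{})));

  print(SmemLayoutAtom{}); printf("\n");  // the swizzled 8x32 atom
  print(SmemLayoutA{});    printf("\n");  // the (128, 32, 3) staged buffer
  return 0;
}

The point of the swizzle is to avoid shared-memory bank conflicts when ldmatrix later reads the data back into registers.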
In a pipelined GEMM implementation, the wait's template argument N is usually set as:
cp_async_wait<kStage - 2>();
The logic is:
- the pipeline has kStage shared-memory buffers
- at any point there are at most kStage - 1 outstanding async copy groups (one buffer is being consumed)
- setting N = kStage - 2 means:
  - wait until at least one async copy group has completed
  - because "at most kStage - 2 groups still outstanding" is equivalent to "at least one has completed"
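To make the counting concrete, here is a toy host-side model of the commit/complete bookkeeping (illustration only: committed and completed are made-up counters standing in for the hardware's async-group tracking, not a real API; the real kernel uses cute::cp_async_fence() and cute::cp_async_wait<N>() as in Demo 2):

#include <cassert>
#include <cstdio>

int main() {
  constexpr int kStage = 3;  // shared-memory buffers, as in Demo 2's GemmConfig<T, 128, 128, 32, 3>
  constexpr int ntile  = 8;  // K tiles streamed through the pipeline

  int committed = 0;  // async copy groups committed so far (one per cp_async_fence)
  int completed = 0;  // groups the hardware has retired, in commit order

  // Prologue: commit kStage - 1 real tile copies (tiles 0 .. kStage - 2),
  // leaving one buffer free for the consumer.
  for (int s = 0; s < kStage - 1; ++s) ++committed;

  for (int itile = 0; itile < ntile; ++itile) {
    // cp_async_wait<kStage - 2>: block until at most kStage - 2 groups are
    // still in flight, i.e. completed >= committed - (kStage - 2).
    while (committed - completed > kStage - 2) ++completed;  // "hardware" retires a group

    // Tile itile was the (itile + 1)-th committed group, so it is now in smem.
    assert(completed >= itile + 1);
    // ... ldmatrix + mma on tile itile happen here in the real kernel ...

    // Commit one more group: a real prefetch while tiles remain, otherwise an
    // empty cp_async_fence() (Demo 2 issues the fence unconditionally at ik == 0),
    // so the group count keeps advancing and the tail of the pipeline drains.
    ++committed;

    printf("itile=%d committed=%d completed=%d in-flight=%d\n",
           itile, committed, completed, committed - completed);
  }
  return 0;
}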
References
- MMA abstraction https://zhuanlan.zhihu.com/p/663092747
- MMA abstraction https://blog.csdn.net/qq_32742009/article/details/136787966
- Copy abstraction https://zhuanlan.zhihu.com/p/666232173
- https://zhuanlan.zhihu.com/p/667521327