1. 简介
- 用 mma PTX 指令实现 M16N16K16 矩阵乘法
2. 代码
- 调用1:wmma + sharedM
- 调用2:wmma + sharedM + padding 避免 bankcoflict
- 调用3:mma + sharedM + swizzle 避免 bankcoflict
//A 16*16; B 16*16
//wmma 处理 half 使用 16*16*16 size 的 matrix,并使用padding 优化
//mma swizzle
#include <iostream>
#include <cuda_runtime.h>
#include "common/tester.h"
#include "common/common.h"
using namespace nvcuda;
__device__ __forceinline__ void ld_st_128bit(void *dst, void *src)
{
*reinterpret_cast<float4 *>(dst) = *reinterpret_cast<float4 *>(src);
}
//wmma + sharedM
template<uint32_t M,uint32_t N,uint32_t K>
__global__ void sharedM_wmma_kernel(half *A, half *B, half *C)
{
__shared__ half smem_a[M * K];
__shared__ half smem_b[K * N];
__shared__ half smem_c[M * N];
int tx = threadIdx.x;
uint32_t nPerThreadLoad = M*K/32; //8
ld_st_128bit(smem_a + nPerThreadLoad * tx, A + nPerThreadLoad * tx);
ld_st_128bit(smem_b + nPerThreadLoad * tx, B + nPerThreadLoad * tx);
__syncthreads();
wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, M, N, K, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
//load_matrix_sync 底层使用 ldmatrix ptx指令的加载分块矩阵的方式,所以bankConflict 分析可以参考 mma ldmatrix 的加载数据方式
wmma::load_matrix_sync(a_frag, smem_a, K); //ldm
wmma::load_matrix_sync(b_frag, smem_b, N);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
wmma::store_matrix_sync(smem_c, c_frag, N, wmma::mem_row_major);
__syncthreads();
ld_st_128bit(C + nPerThreadLoad * tx, smem_c + nPerThreadLoad * tx);
}
//优化1:wmma + sharedM + padding 避免 bankcoflict
template<uint32_t M,uint32_t N,uint32_t K>
__global__ void sharedM_padding_wmma_kernel(half *A,half *B,half *C)
{
// +8 OFFSET 也可以是别的参数
const uint32_t OFFSET = 8;
__shared__ half smem_a[M][K+OFFSET];
__shared__ half smem_b[K][N+OFFSET];
__shared__ half smem_c[M*N];
uint32_t tx = threadIdx.x;
uint32_t nPerThreadLoad = M*K/32; //8 128bit
ld_st_128bit(&smem_a[tx/2][tx%2*nPerThreadLoad],A+tx*nPerThreadLoad);
ld_st_128bit(&smem_b[tx/2][tx%2*nPerThreadLoad],B+tx*nPerThreadLoad);
__syncthreads();
wmma::fragment<wmma::matrix_a, M, N, K, half,wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, M, N, K, half,wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator,M,N,K,half> c_frag;
wmma::fill_fragment(c_frag,0.0f);
wmma::load_matrix_sync(a_frag,(half*)smem_a,K+OFFSET);
wmma::load_matrix_sync(b_frag,(half*)smem_b,N+OFFSET);
wmma::mma_sync(c_frag,a_frag,b_frag,c_frag);
wmma::store_matrix_sync(smem_c,c_frag,N,wmma::mem_row_major);
__syncthreads();
ld_st_128bit(C + tx*nPerThreadLoad,smem_c + nPerThreadLoad*tx);
}
// mma
#define REG(val) (*reinterpret_cast<uint32_t*>(&(val)))
#define HALF2(val) (*reinterpret_cast<half2*>(&(val)))
//ptx
__device__ __forceinline__ void ldmatrix_sync(half *dst,half *src)
{
//sm_90之后支持
// "=r"约束符用于输出操作数,= 符号表示这是一个只写操作数(输出操作数),r 表示操作数应该放在通用寄存器中
// "l"约束符用于输入操作数,表示该操作数用于提供地址信息
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(REG(dst[0])),
"=r"(REG(dst[2])),
"=r"(REG(dst[4])),
"=r"(REG(dst[6]))
: "l"(__cvta_generic_to_shared(src)));
}
__device__ __forceinline__ void ldmatrix_trans_sync(half *dst, void *src)
{
//LD.trans trans frament 块内 N格式,寄存器按列存储读取。不改变原矩阵的数据排布,只改变寄存器的读写方向。
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.trans.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(REG(dst[0])),
"=r"(REG(dst[2])),
"=r"(REG(dst[4])),
"=r"(REG(dst[6]))
: "l"(__cvta_generic_to_shared(src)));
}
__device__ __forceinline__ void mma_sync_m16n8k16(half *c, half *a, half *b)
{
asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1}, "
"{%2, %3, %4, %5}, "
"{%6, %7}, "
"{%8, %9};"
: "=r"(REG(c[0])), "=r"(REG(c[2]))
: "r"(REG(a[0])),
"r"(REG(a[2])),
"r"(REG(a[4])),
"r"(REG(a[6])),
"r"(REG(b[0])),
"r"(REG(b[2])),
"r"(0),
"r"(0));
}
__device__ __forceinline__ void stmatrix_sync(half *dst, half *src)
{
//sm_100 later
asm volatile(
"stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};"
: // 无输出操作数,它从C/C++代码的角度看,是消耗了输入操作数(即寄存器%1和%2中的数据),整个操作过程并没有产生一个可供C/C++代码使用的“返回值”或“输出变量”,因此被声明为“无输出操作数”
: "l"(__cvta_generic_to_shared(dst)),
"r"(REG(src[0])),
"r"(REG(src[2])),
"r"(REG(src[4])),
"r"(REG(src[6]))
: "memory" //“memory”字段属于clobber列表的一部分,它告知编译器该汇编指令可能会修改内存内容,使用"memory"可以确保编译器不会缓存内存中的旧值,从而保证后续操作能读取到最新的数据
);
}
__device__ __forceinline__ void stmatrix_sync_(half *dst, half *src)
{
//for sm_100 之前的版本
// ! Ampere doesn't have stmatrix.sync, we should simulate it
uint64_t private_addr = (uint64_t)dst;
uint64_t shared_addr[4];
#pragma unroll
for (int i = 0; i < 4; i++)
{
//广播 i * 8 + threadIdx.x / 4 通道的 private_addr值
shared_addr[i] = __shfl_sync(0xFFFFFFFF, private_addr, i * 8 + threadIdx.x / 4);
}
#pragma unroll
for (int i = 0; i < 4; i++)
{
*(reinterpret_cast<half2 *>(shared_addr[i]) + threadIdx.x % 4) = HALF2(src[2 * i]);
}
}
// mma + sharedM + fragment的寄存器存储
template<uint32_t M,uint32_t N,uint32_t K>
__global__ void sharedM_mma_kernel(half *A, half *B, half *C)
{
__shared__ half smem_a[M * K];
__shared__ half smem_b[K * N];
__shared__ half smem_c[M * N];
int tx = threadIdx.x;
uint32_t nPerThreadLoad = M*K/32; //8
//共享内存数据与原数据排布一致,当做矩阵的排布
ld_st_128bit(smem_a + nPerThreadLoad * tx, A + nPerThreadLoad * tx);
ld_st_128bit(smem_b + nPerThreadLoad * tx, B + nPerThreadLoad * tx);
__syncthreads();
wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, M, N, K, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
//共享内存数据转存到 fragment 寄存器中,排布不变
// mma 要求 A按行存储;B(因为是16*8需要trans)输入应按列存储
//此处 row col 的计算方法是将原共享内存矩阵数组转化为 fragment 存储格式的算法,与行列互换无关
uint32_t row = tx%K;
uint32_t col = tx/K;
//共享内存地址固定写法
// fragment.x 指向寄存器数组,第一个寄存器为 fragment.x[0];
// 这里各个线程统一传入 fragment.x,底层可能另有处理
ldmatrix_sync(a_frag.x, smem_a + row*K + col*8);
ldmatrix_trans_sync(b_frag.x, smem_b + row*K + col*8 );
#if 0 //fragment_.num_storage_elements:32
if(tx ==0)
printf("size: %ld, addr: %p\n",sizeof(a_frag.x),®(a_frag.x[0]));
if(tx ==1)
printf("size: %ld, addr: %p\n",sizeof(a_frag.x),®(a_frag.x[1]));
#endif
mma_sync_m16n8k16(c_frag.x, a_frag.x, b_frag.x);
// 偏移量4 = 4个寄存器 × 2元素/寄存器 = 8个FP16元素
mma_sync_m16n8k16(c_frag.x+4, a_frag.x, b_frag.x+4);
stmatrix_sync(smem_c + row*K + col*nPerThreadLoad,c_frag.x);
__syncthreads();
ld_st_128bit(C + nPerThreadLoad * tx, smem_c + nPerThreadLoad * tx);
}
// mma + sharedM + swizzle 避免 bankcoflict
/*
原数据地址与共享内存地址转换
对于gaddr 在共享内存的原地址addr: B 表示列需要的二进制位数,M 表示一个块内的元素索引需要的二进制位数,S 表示 addr 地址按块划分的行列坐标需要位移的二进制位数=M
*/
template<uint32_t B,uint32_t M,uint32_t S>
__device__ __forceinline__ uint32_t swizzle(uint32_t srcAddr)
{
//行列坐标值取后三位进行异或运算
//掩码用来获取行坐标
uint32_t mask = (1 << B - 1) << M;
uint32_t addr = ((srcAddr >> S) & mask) ^ srcAddr;
return addr;
}
template<uint32_t M,uint32_t N,uint32_t K>
__global__ void sharedM_mma_swizzle_kernel(half *A, half *B, half *C)
{
__shared__ half smem_a[M * K];
__shared__ half smem_b[K * N];
__shared__ half smem_c[M * N];
int tx = threadIdx.x;
uint32_t nPerThreadLoad = M*K/32; //8
uint32_t offset = tx * nPerThreadLoad;
//根据线程在全局内存中取数据逻辑地址计算共享内存物理地址存放数据
//16行 3位表示;2列 1位表示;一块8个元素 3位表示
uint32_t g2sAddr = swizzle<3,1,3>(offset);
ld_st_128bit(smem_a + g2sAddr, A+offset);
ld_st_128bit(smem_b + g2sAddr, B+offset);
__syncthreads();
wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, M, N, K, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
//从共享内存取出数据,计算线程原本要加载数据逻辑坐标,通过 swizzle 计算物理坐标
//保持原矩阵布局不变的 row col 计算
uint32_t row = tx%16;
uint32_t col = tx/16;
//根据当前线程提供数据地址逻辑计算在共享内存物理地址
uint32_t s2rAddr = swizzle<3,1,3>(row*M+col*nPerThreadLoad);
ldmatrix_sync(a_frag.x,smem_a+s2rAddr);
ldmatrix_trans_sync(b_frag.x,smem_b+s2rAddr);
#if 1 //使用 wmma api计算
wmma::mma_sync(c_frag,a_frag,b_frag,c_frag);
#else //使用mma ptx 指令计算
mma_sync_m16n8k16(c_frag.x,a_frag.x,b_frag.x);
mma_sync_m16n8k16(c_frag.x+4,a_frag.x,b_frag.x+4);
#endif
stmatrix_sync(smem_c+ s2rAddr,c_frag.x);
//从共享内存取数据,逻辑地址为 tx*nPerThreadLoad,物理地址为 g2sAddr
ld_st_128bit(C+tx*nPerThreadLoad,smem_c+g2sAddr);
}
//以下为调用
void sharedM_wmma(half *A, half *B, half *C, int M, int N, int K)
{
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
dim3 block(32);
dim3 grid(1);
sharedM_wmma_kernel<WMMA_M,WMMA_N,WMMA_K><<<grid, block>>>(A, B, C);
}
void sharedM_padding_wmma(half *A, half *B, half *C, int M, int N, int K)
{
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
dim3 block(32);
dim3 grid(1);
sharedM_padding_wmma_kernel<WMMA_M,WMMA_N,WMMA_K><<<grid, block>>>(A, B, C);
}
void sharedM_mma(half *A, half *B, half *C, int M, int N, int K)
{
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
dim3 block(32);
dim3 grid(1);
sharedM_mma_kernel<WMMA_M,WMMA_N,WMMA_K><<<grid, block>>>(A, B, C);
}
void sharedM_mma_swizzle(half *A, half *B, half *C, int M, int N, int K)
{
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
dim3 block(32);
dim3 grid(1);
sharedM_mma_swizzle_kernel<WMMA_M,WMMA_N,WMMA_K><<<grid, block>>>(A, B, C);
}
int main(int argc, char *argv[])
{
// {
// Tester tester(16, 16, 16, 1, 1, 100, true);
// tester.evaluate(shareM_wmma, "sharedM_wmma");
// }
// {
// Tester tester(16, 16, 16, 1, 1, 100, true);
// tester.evaluate(sharedM_padding_wmma, "sharedM_padding_wmma");
// }
// {
// Tester tester(16, 16, 16, 1, 1, 100, true);
// tester.evaluate(sharedM_mma, "sharedM_mma");
// }
{
Tester tester(16, 16, 16, 1, 1, 100, true);
tester.evaluate(sharedM_mma_swizzle, "sharedM_mma_swizzle");
}
return 0;
}