LeetGPU入门教程 (CUDA guide最佳实践)

5.20更 终于搞定了Tritontma kernel,当时还怀疑是卡坏了,自己没好好读文档的问题。建立了LeetGPU Github仓库,看代码更方便。

本博客原文地址:https://www.cnblogs.com/BobHuang/p/18755313,原文体验更佳

工作中写过一点点算子,现在来刷刷OJ https://leetgpu.com/challenges,其实就是leetcode for GPU

目前平台有28道题,支持CUDAPytorchTriton共3种写算子的框架。订阅是5美刀/月,不订阅看不到运行时间和百分比排位,目前还看不到Nsight Compute的性能分析,但是我已充值。评测卡有Hopper架构的H100Ampere架构的A100Turing架构的Tesla T4

我这里较为深入得探讨4题,剩下的你可以尝试下,诚然Triton是最好写且性能过得去的。

一、Vector Addition

OJ入门均是a+bGPU 中就是向量C=向量A + 向量B,以下是为我们提供的示例代码。

#include "solve.h"
#include <cuda_runtime.h>

__global__ void vector_add(const float* A, const float* B, float* C, int N) {

}

// A, B, C are device pointers (i.e. pointers to memory on the GPU)
void solve(const float* A, const float* B, float* C, int N) {
    int threadsPerBlock = 256;
    int blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;

    vector_add<<<blocksPerGrid, threadsPerBlock>>>(A, B, C, N);
    cudaDeviceSynchronize();
}

2025.3项目刚上线是有cudaMalloccudaMemcpycudaFreeAPI的,目前已简化成申请好的指针。示例代码里计算了threadsPerBlock表示每个block分配的线程数量,blocksPerGrid计算了需要多少个block,由于是除法进行了向上取整,有一些线程是空转的。当然还可以指定shared memory大小和stream,完整的<<< ... >>>即启动(launch)kernel 函数为<<<gridDim, blockDim, sharedMemBytes, stream>>>CUDA guide 也对此进行了细致的讲解。

我们只需要补充cuda kernel函数,计算出全局索引,全局索引有block维度的编号,也有thread维度的编号,乘加得到当前线程的ID,不超过N的数组大小不越界后相加即可。

__global__ void vector_add(const float* A, const float* B, float* C, int N) {
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < N) {
        C[i] = A[i] + B[i];
    }
}

那我们能不能修改下threadsPerBlock的大小呢,比如1024呢,当然可以。当然也可以改为1286432等这样的数字,但是改为2048就会Failed,因为对blcok内线程的限制为1024。以下我搞的目前评测用的Tesla T4./deviceQuery信息。

点击展开代码
Device 0: "Tesla T4"
  CUDA Driver Version / Runtime Version          12.8 / 12.4
  CUDA Capability Major/Minor version number:    7.5
  Total amount of global memory:                 14914 MBytes (15638134784 bytes)
  (040) Multiprocessors, (064) CUDA Cores/MP:    2560 CUDA Cores
  GPU Max Clock rate:                            1590 MHz (1.59 GHz)
  Memory Clock rate:                             5001 Mhz
  Memory Bus Width:                              256-bit
  L2 Cache Size:                                 4194304 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(131072), 2D=(131072, 65536), 3D=(16384, 16384, 16384)
  Maximum Layered 1D Texture Size, (num) layers  1D=(32768), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(32768, 32768), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total shared memory per multiprocessor:        65536 bytes
  Total number of registers available per block: 65536
  Warp size:                                     32
  Maximum number of threads per multiprocessor:  1024
  Maximum number of threads per block:           1024
  Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
  Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
  Maximum memory pitch:                          2147483647 bytes
  Texture alignment:                             512 bytes
  Concurrent copy and kernel execution:          Yes with 3 copy engine(s)
  Run time limit on kernels:                     Yes
  Integrated GPU sharing Host Memory:            No
  Support host page-locked memory mapping:       Yes
  Alignment requirement for Surfaces:            Yes
  Device has ECC support:                        Enabled
  Device supports Unified Addressing (UVA):      Yes
  Device supports Managed Memory:                Yes
  Device supports Compute Preemption:            Yes
  Supports Cooperative Kernel Launch:            Yes
  Supports MultiDevice Co-op Kernel Launch:      Yes

关于performance的信息,使用信用卡 5美刀/月 订阅后就可以看到。,默认的Block内线程数256,可以获得71.2th percentile,也就是超过了71.2%的人水平,我们尝试用满即修改为1024达到了16.5th percentile,修改为512可以得到46.8th percentile,修改为128可以获得83.5th percentile,这个数据范围内128确实比较优,所以有些时候会对threadsPerBlock进行autotune自动调优。

Failed的时候可以看到输出的数据,所以我们可以把N给hack出来,是10000=625*16,设置62516就可以省掉条件判断,但是我们会发现其并不会变快,因为我们不能精细得控制线程调度到指定物理线程,具体见CUDA 如何调度 kernel 到指定的 SM?2025.4 更 数据加强了,现在会直接Failed

Pytorch算子写起来就很简单了,就是a+b,题目描述中有dentical lengths,但是Pytorch帮我们处理好了一切,一行秒杀。

import torch

# A, B, C are tensors on the GPU
def solve(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, N: int):
    torch.add(A, B, out=C)

Triton的话和CUDA类似,不需要考虑block内,Triton按照block整块处理。参考代码给出了Grid信息,grid = (triton.cdiv(N, BLOCK_SIZE),)CUDA编程中blocksPerGridBLOCK_SIZE = 1024CUDA编程中的threadsPerBlock

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, N: int):
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    vector_add_kernel[grid](a_ptr, b_ptr, c_ptr, N, BLOCK_SIZE)

作者以前写过一篇浅析 Triton 执行流程,有兴趣可以看看。Triton kernel基本也呼之欲出了,按照整块处理并mask即可。

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
    # 有多个“程序”在处理不同的数据。我们在这里确定我们是哪个程序:
    pid = tl.program_id(axis=0)  # 我们以1D网格启动 所以 axis 是 0.
    # 该程序将处理从初始数据偏移的输入。
    # 例如,如果你有一个长度为4096且块大小为1024的向量,程序 将分别访问元素 [0:1024), [1024:2048), [2048:3072), [3072:4096)。
    block_start = pid * BLOCK_SIZE
    # 注意,offsets 是一个指针列表
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # 创建一个掩码以防止内存越界访问
    mask = offsets < n_elements
    # 从 DRAM 加载 a 和 b,mask用来解决输入不是块大小的倍数而多余的元素
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    c = a + b
    # 将 a + b 写回 DRAM
    tl.store(c_ptr + offsets, c, mask=mask)

Triton同样BLOCK_SIZE = 128最优,可以超过98.2%的人,CUDA不使用技巧的最快时间为1.14817 msPytorch2.0441 ms,Triton1.19467 ms。人生苦短,我选Triton, all in Triton。

二、matrix-multiplication

这个是C=A*B,矩阵A是M*N,矩阵B是N*K, 得到的矩阵C是M*KC[m][k]的值将由A的一行A[m][i]和对应的B的一列B[i][k]相乘得到,在KMN的循环顺序下我们串行的代码是这样的。

void matmul(const float* A, const float* B, float* C, int M, int N, int K) {
    memset(C, 0, M * K * sizeof(float));
    for (int col = 0; col < K; col++) {
        for (int row = 0; row < M; row++) {
            for (int i = 0; i < N; i++) {
                C[row * K + col] += A[row * N + i] * B[i * K + col];
            }
        }
    }
}

矩阵乘launch也麻烦点,是二维的。以下代码中dim3 threadsPerBlock (16, 16)声明每个线程块包含 16x16=256 个线程,这个二维block有利于全局内存合并访问,规整的线程组可以减少Warp Divergence。blocksPerGrid则分别进行了X方向和Y方向的网格维度,进行了向上取整。

    // Calculate grid and block dimensions
    dim3 threadsPerBlock(16, 16);
    dim3 blocksPerGrid((K + threadsPerBlock.x - 1) / threadsPerBlock.x,
                       (M + threadsPerBlock.y - 1) / threadsPerBlock.y);

如果是(3,2)的Grid,(4,3)的Block可以用下图表示

以上用的是KM,也就是我们用X维度可以求出col,在Y维度可以求出row,是这样的代码

    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;

在数据范围内我们做累加就好了,我们可以用中间变量去保存这个结果,最终代码为

__global__ void matrix_multiplication_kernel(const float* A, const float* B, float* C, int M, int N, int K) {
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;

    if (row < M && col < K) {
        float sum = 0.0f;
        for (int i = 0; i < N; i++) {
            sum += A[row * N + i] * B[i * K + col];
        }
        C[row * K + col] = sum;
    }
}

row和col能不能反一下呢,逻辑上当然是可以的,代码如下所示

#include "solve.h"
#include <cuda_runtime.h>

__global__ void matrix_multiplication_kernel(const float* A, const float* B, float* C, int M, int N, int K) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (row < M && col < K) {
        float sum = 0.0f;
        for (int i = 0; i < N; i++) {
            sum += A[row * N + i] * B[i * K + col];
        }
        C[row * K + col] = sum;
    }
}

// A, B, C are device pointers (i.e. pointers to memory on the GPU)
void solve(const float* A, const float* B, float* C, int M, int N, int K) {
    dim3 threadsPerBlock(16, 16);
    dim3 blocksPerGrid((M + threadsPerBlock.x - 1) / threadsPerBlock.x,
                       (K + threadsPerBlock.y - 1) / threadsPerBlock.y);

    matrix_multiplication_kernel<<<blocksPerGrid, threadsPerBlock>>>(A, B, C, M, N, K);
    cudaDeviceSynchronize();
}

以上代码提交我们将会获得Speed Test timed out after 25 seconds,矩阵 B 的访问 B[i * K + col] 中,col 由 blockIdx.y 控制,相邻线程(threadIdx.x 连续)访问 B 的不同行,导致非合并访问,因此速度恨慢,若想使用此种方法需要使用转置后的矩阵进行相乘。

Pytorch依旧是直接秒杀,图一乐。

import torch

# A, B, C are tensors on the GPU
def solve(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, M: int, N: int, K: int):
    torch.matmul(A, B , out = C)

Triton我们可以按照CUDA逻辑先写下,2D的Grid启动,累加到N的循环,通过了。交换M、K还可以更快。限制少了很多,Triton 帮你做了。

# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl

@triton.jit
def matrix_multiplication_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_an,
    stride_bn, stride_bk,
    stride_cm, stride_ck,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
    pid_k = tl.program_id(axis=0)
    pid_m = tl.program_id(axis=1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

    # 初始化A和B的指针
    a_ptrs = a_ptr + offs_m[:, None] * stride_am
    b_ptrs = b_ptr + offs_k[None, :] * stride_bk

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

    # 沿N维度依次累加
    for n in range(N):
        # 加载A和B的当前块,并应用边界检查
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += a * b
        a_ptrs += stride_an
        b_ptrs += stride_bn

    # 将结果写回C
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K)
    tl.store(c_ptrs, accumulator, mask=c_mask)

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
    stride_am, stride_an = N, 1
    stride_bn, stride_bk = K, 1
    stride_cm, stride_ck = K, 1

    grid = lambda META: (triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(M, META['BLOCK_SIZE_M']), )
    matrix_multiplication_kernel[grid](
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_an,
        stride_bn, stride_bk,
        stride_cm, stride_ck,
        BLOCK_SIZE_M=16,
        BLOCK_SIZE_K=16,
    )

tl.dot是向量乘,可以转换为mmaMatrix Multiply Accumulate指令来调用Tensor Core,Tesla T4由于架构太旧只能生成fma乘加指令。tl.dot做向量乘是有限制的,需要M >= 16, N >= 16 and K >= 16,也就是最小的块必须都大于16。所以我们把N也切掉,再加上BLOCK_SIZE_N,变成分块的矩阵乘法,代码如下所示。

# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl

@triton.jit
def matrix_multiplication_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_an,
    stride_bn, stride_bk,
    stride_cm, stride_ck,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr
):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
    pid_m = tl.program_id(axis=0)
    pid_k = tl.program_id(axis=1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_n = tl.arange(0, BLOCK_SIZE_N)

    # 初始化A和B的指针
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
    b_ptrs = b_ptr + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

    # 沿N维度依次累加
    for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        # 加载A和B的当前块,并应用边界检查
        a = tl.load(a_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N, other=0.0)
        b = tl.load(b_ptrs, mask=offs_n[:, None] < N - n * BLOCK_SIZE_N, other=0.0)
        # 计算矩阵乘法并累加
        accumulator = tl.dot(a, b, accumulator, input_precision="ieee")
        a_ptrs += BLOCK_SIZE_N * stride_an
        b_ptrs += BLOCK_SIZE_N * stride_bn

    # 将结果写回C
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K)
    tl.store(c_ptrs, accumulator, mask=c_mask)

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
    stride_am, stride_an = N, 1
    stride_bn, stride_bk = K, 1
    stride_cm, stride_ck = K, 1

    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(K, META['BLOCK_SIZE_K']))
    matrix_multiplication_kernel[grid](
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_an,
        stride_bn, stride_bk,
        stride_cm, stride_ck,
        BLOCK_SIZE_M=16,
        BLOCK_SIZE_K=16,
        BLOCK_SIZE_N=16
    )

切的块变大的话提升很多,比如切分均为64,直接到了前33%。官方tutorials是利用L2 cache的,其实就是做了一下Group,性能又提升不少

# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl

@triton.jit
def matrix_multiplication_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_an,
    stride_bn, stride_bk,
    stride_cm, stride_ck,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_k
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_k = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_n = tl.arange(0, BLOCK_SIZE_N)

    # 初始化A和B的指针
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
    b_ptrs = b_ptr + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

    # 沿N维度依次累加
    for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        # 加载A和B的当前块,并应用边界检查
        a = tl.load(a_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N, other=0.0)
        b = tl.load(b_ptrs, mask=offs_n[:, None] < N - n * BLOCK_SIZE_N, other=0.0)
        # 计算矩阵乘法并累加
        accumulator = tl.dot(a, b, accumulator, input_precision="ieee")
        a_ptrs += BLOCK_SIZE_N * stride_an
        b_ptrs += BLOCK_SIZE_N * stride_bn

    # 将结果写回C
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K)
    tl.store(c_ptrs, accumulator, mask=c_mask)

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
    stride_am, stride_an = N, 1
    stride_bn, stride_bk = K, 1
    stride_cm, stride_ck = K, 1

    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(K, META['BLOCK_SIZE_K']), )
    matrix_multiplication_kernel[grid](
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_an,
        stride_bn, stride_bk,
        stride_cm, stride_ck,
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_K=64,
        BLOCK_SIZE_N=64,
        GROUP_SIZE_M=8
    )

CUDA里使用Tensor core需要使用WMMA API,会比较繁琐,Triton里就会简单些。当然CUTLASS就要友好地多,大力支持CUTLASS,未来LeetGPU也将支持CUTLASS。目前Triton的版本是3.2.0,在release/3.2.x分支,当前commit id是9641643da6c52000c807b5eeed05edaec4402a67。以下代码是根据SM(stream multi-processor)数做的持久化线程块,SMTensor Core的最小单元。

import triton
import triton.language as tl

@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_k = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_k

@triton.jit
def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr,  #
                             M, N, K,  #
                             stride_am, stride_an,  #
                             stride_bn, stride_bk,  #
                             stride_cm, stride_ck,  #
                             BLOCK_SIZE_M: tl.constexpr,  #
                             BLOCK_SIZE_K: tl.constexpr,  #
                             BLOCK_SIZE_N: tl.constexpr,  #
                             GROUP_SIZE_M: tl.constexpr,  #
                             NUM_SMS: tl.constexpr,  #
                             ):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
    n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    num_tiles = num_pid_m * num_pid_k

    # NOTE: There is currently a bug in blackwell pipelining that means it can't handle a value being
    # used in both the prologue and epilogue, so we duplicate the counters as a work-around.
    tile_id_c = start_pid - NUM_SMS

    offs_n_for_mask = tl.arange(0, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_k

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
        pid_m, pid_k = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
        start_m = pid_m * BLOCK_SIZE_M
        start_k = pid_k * BLOCK_SIZE_K
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bk = start_k + tl.arange(0, BLOCK_SIZE_K)
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bk = tl.where(offs_bk < K, offs_bk, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bk = tl.max_contiguous(tl.multiple_of(offs_bk, BLOCK_SIZE_K), BLOCK_SIZE_K)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
        for ni in range(n_tiles):
            offs_k = ni * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_an)
            b_ptrs = b_ptr + (offs_k[:, None] * stride_bn + offs_bk[None, :] * stride_bk)

            a = tl.load(a_ptrs, mask=offs_n_for_mask[None, :] < N - ni * BLOCK_SIZE_N, other=0.0)
            b = tl.load(b_ptrs, mask=offs_n_for_mask[:, None] < N - ni * BLOCK_SIZE_N, other=0.0)
            accumulator = tl.dot(a, b, accumulator, input_precision="ieee")

        tile_id_c += NUM_SMS
        pid_m, pid_k = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_ck * offs_ck[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K)
        c = accumulator.to(tl.float32)
        tl.store(c_ptrs, c, mask=c_mask)

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
    stride_am, stride_an = N, 1
    stride_bn, stride_bk = K, 1
    stride_cm, stride_ck = K, 1
    # 查询资料可得,torch API 不让用,正常使用torch.cuda.get_device_properties("cuda").multi_processor_count
    NUM_SMS = 40
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(K, META["BLOCK_SIZE_K"])), )
    kernel = matmul_kernel_persistent[grid](
        a_ptr, b_ptr, c_ptr,  #
        M, N, K,  #
        stride_am, stride_an,
        stride_bn, stride_bk,
        stride_cm, stride_ck,
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_K=64,
        BLOCK_SIZE_N=64,
        GROUP_SIZE_M=8,
        NUM_SMS=NUM_SMS,  #
    )

Hopper 架构已经引入TMA Tensor,你还可以试试H100的Kernel编写,这个我在学习。目前平台Triton的版本是3.2.03.3.0变化还是挺大的,感恩,已经更新
5.20更 5.8尝试了一早上,都没看文档的Leading dimensions must be multiples of 16-byte stridespython/triton/language/core.py:2042拷打下自己。我们的数据类型是float32,所以在M,N,K均是4的倍数就可以使用tma,仅提供个参考

# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl

@triton.jit
def matrix_multiplication_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_an,
    stride_bn, stride_bk,
    stride_cm, stride_ck,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_k
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_k = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_n = tl.arange(0, BLOCK_SIZE_N)

    # 初始化A和B的指针
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
    b_ptrs = b_ptr + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

    # 沿N维度依次累加
    for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        # 加载A和B的当前块,并应用边界检查
        a = tl.load(a_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N, other=0.0)
        b = tl.load(b_ptrs, mask=offs_n[:, None] < N - n * BLOCK_SIZE_N, other=0.0)
        # 计算矩阵乘法并累加
        accumulator = tl.dot(a, b, accumulator, input_precision="ieee")
        a_ptrs += BLOCK_SIZE_N * stride_an
        b_ptrs += BLOCK_SIZE_N * stride_bn

    # 将结果写回C
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K)
    tl.store(c_ptrs, accumulator, mask=c_mask)

@triton.jit
def matmul_kernel_make_tensor_desciptor(a_ptr, b_ptr, c_ptr,  #
                                        M, N, K,  #
                                        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                                        BLOCK_SIZE_K: tl.constexpr,  #
                                        ):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = pid_m * BLOCK_SIZE_M
    offs_bn = pid_n * BLOCK_SIZE_N
    offs_k = 0

    a_desc = tl._experimental_make_tensor_descriptor(
        a_ptr,
        shape=[M, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    b_desc = tl._experimental_make_tensor_descriptor(
        b_ptr,
        shape=[K, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N],
    )
    c_desc = tl._experimental_make_tensor_descriptor(
        c_ptr,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
    )
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = a_desc.load([offs_am, offs_k])
        # tl.device_print("a: ", a)
        b = b_desc.load([offs_k, offs_bn])
        accumulator = tl.dot(a, b, acc=accumulator, input_precision="ieee")
        offs_k += BLOCK_SIZE_K
    accumulator = accumulator.to(a_desc.dtype)
    c_desc.store([offs_am, offs_bn], accumulator)

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(K, META['BLOCK_SIZE_K']), )
    # Leading dimensions must be multiples of 16-byte strides
    if M % 4 == 0 and N % 4 == 0 and K % 4 == 0:
        import ctypes
        cudart = ctypes.CDLL("libcudart.so")
        cudart.cudaMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
        cudart.cudaMalloc.restype = ctypes.c_int
        from typing import Optional
        # TMA descriptors require a global memory allocation
        def alloc_fn(size: int, alignment: int, stream: Optional[int]):
            ptr = ctypes.c_void_p()
            err = cudart.cudaMalloc(ctypes.byref(ptr), size)
            if err != 0:
                raise RuntimeError(f"cudaMalloc failed, code {err}")
            return ptr.value
        triton.set_allocator(alloc_fn)
        matmul_kernel_make_tensor_desciptor[grid](
            a_ptr, b_ptr, c_ptr,
            M, K, N,
            BLOCK_SIZE_M=32,
            BLOCK_SIZE_K=32,
            BLOCK_SIZE_N=32,
        )
    else:
        matrix_multiplication_kernel[grid](
            a_ptr, b_ptr, c_ptr,
            M, N, K,
            N, 1,
            K, 1,
            K, 1,
            BLOCK_SIZE_M=64,
            BLOCK_SIZE_K=64,
            BLOCK_SIZE_N=64,
            GROUP_SIZE_M=8
        )

三、Matrix Transpose

矩阵转置和矩阵乘法这个很像,naive的方法就是转换索引

#include <cuda_runtime.h>

__global__ void matrix_transpose_kernel(const float* input, float* output, int rows, int cols) {
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (col < cols && row < rows) {
        output[col * rows + row] = input[row * cols + col];
    }
}

// input, output are device pointers (i.e. pointers to memory on the GPU)
extern "C" void solve(const float* input, float* output, int rows, int cols) {
    dim3 threadsPerBlock(16, 16);
    dim3 blocksPerGrid((cols + threadsPerBlock.x - 1) / threadsPerBlock.x,
                       (rows + threadsPerBlock.y - 1) / threadsPerBlock.y);

    matrix_transpose_kernel<<<blocksPerGrid, threadsPerBlock>>>(input, output, rows, cols);
    cudaDeviceSynchronize();
}

我们可以利用共享内存,padding 解决 bank conflict等等优化得到个性能好的,具体可以参考 PeakCrosser [CUDA 学习笔记] 矩阵转置算子优化

#include "solve.h"
#include <cuda_runtime.h>

template <int BLOCK_SZ, int NUM_PER_THREAD>
__global__ void mat_transpose_kernel_v3(const float* idata, float* odata, int M, int N) {
    const int bx = blockIdx.x, by = blockIdx.y;
    const int tx = threadIdx.x, ty = threadIdx.y;

    __shared__ float sdata[BLOCK_SZ][BLOCK_SZ+1];

    int x = bx * BLOCK_SZ + tx;
    int y = by * BLOCK_SZ + ty;

    constexpr int ROW_STRIDE = BLOCK_SZ / NUM_PER_THREAD;

    if (x < N) {
        #pragma unroll
        for (int y_off = 0; y_off < BLOCK_SZ; y_off += ROW_STRIDE) {
            if (y + y_off < M) {
                sdata[ty + y_off][tx] = idata[(y + y_off) * N + x];
            }
        }
    }
    __syncthreads();
    x = by * BLOCK_SZ + tx;
    y = bx * BLOCK_SZ + ty;
    if (x < M) {
        for (int y_off = 0; y_off < BLOCK_SZ; y_off += ROW_STRIDE) {
            if (y + y_off < N) {
                odata[(y + y_off) * M + x] = sdata[tx][ty + y_off];
            }
        }
    }
}

// input, output are device pointers (i.e. pointers to memory on the GPU)
void solve(const float* input, float* output, int rows, int cols) {
    constexpr int BLOCK_SZ = 32;
    constexpr int NUM_PER_THREAD = 4;
    dim3 block(BLOCK_SZ, BLOCK_SZ/NUM_PER_THREAD);
    dim3 grid((cols+ BLOCK_SZ-1)/BLOCK_SZ, (rows+BLOCK_SZ-1)/BLOCK_SZ);
    mat_transpose_kernel_v3<BLOCK_SZ, NUM_PER_THREAD><<<grid, block>>>(input, output, rows, cols);
    cudaDeviceSynchronize();
}

Pytorch 要做的就是一行秒杀

import torch

# input, output are tensors on the GPU
def solve(input: torch.Tensor, output: torch.Tensor, rows: int, cols: int):
    output.copy_(torch.transpose(input, 0, 1))

Triton 考虑的就不多了,tl.trans会帮我们做掉

# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl

@triton.jit
def matrix_transpose_kernel(
    input_ptr, output_ptr,
    M, N,
    stride_ir, stride_ic,
    stride_or, stride_oc,
    BLOCK_SIZE : tl.constexpr
):
    input_ptr = input_ptr.to(tl.pointer_type(tl.float32))
    output_ptr = output_ptr.to(tl.pointer_type(tl.float32))
    # -----------------------------------------------------------
    # 1. 确定当前线程块负责的输入分块坐标
    # -----------------------------------------------------------
    pid_m = tl.program_id(0)  # 分块在 M 方向的索引
    pid_n = tl.program_id(1)  # 分块在 N 方向的索引

    # -----------------------------------------------------------
    # 2. 计算分块内每个元素的位置偏移
    # -----------------------------------------------------------
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # -----------------------------------------------------------
    # 3. 定义输入分块的全局内存指针(行优先)
    # -----------------------------------------------------------
    input_ptrs = input_ptr + offs_m[:, None] * stride_ir + offs_n[None, :] * stride_ic

    # -----------------------------------------------------------
    # 4. 从全局内存加载输入分块(带边界检查)
    # -----------------------------------------------------------
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    block = tl.load(input_ptrs, mask=mask, other=0)

    # -----------------------------------------------------------
    # 5. 转置分块(交换行列)
    # -----------------------------------------------------------
    transposed_block = tl.trans(block)  # Triton 内置转置函数

    # -----------------------------------------------------------
    # 6. 计算输出分块的全局内存指针(列优先)
    # -----------------------------------------------------------
    output_ptrs = output_ptr + offs_n[:, None] * M + offs_m[None, :]  # 注意 M 是转置后的行步长

    # -----------------------------------------------------------
    # 7. 将转置后的分块写入全局内存
    # -----------------------------------------------------------
    tl.store(output_ptrs, transposed_block, mask=mask.T)  # mask 也需要转置


# input_ptr, output_ptr are raw device pointers
def solve(input_ptr: int, output_ptr: int, rows: int, cols: int):
    stride_ir, stride_ic = cols, 1
    stride_or, stride_oc = rows, 1

    grid = lambda META: (triton.cdiv(rows, META['BLOCK_SIZE']), triton.cdiv(cols, META['BLOCK_SIZE']))
    matrix_transpose_kernel[grid](
        input_ptr, output_ptr,
        rows, cols,
        stride_ir, stride_ic,
        stride_or, stride_oc,
        BLOCK_SIZE=32
    )

四、softmax

softmax就比较麻烦了,所用公式为\(\sigma(\mathbf{z})_i \;=\; \frac{\exp\bigl(z_i - \max_k z_k\bigr)} {\displaystyle\sum_{j=1}^{K} \exp\bigl(z_j - \max_k z_k\bigr)}\),想要性能过得去得用原子指令求最大值,写个块内reduce 规约操作。

__global__ void softmax_kernel(const float* input, float* output, int N) {
    extern __shared__ float shared_mem[]; // 动态共享内存用于归约
    float* max_shared = shared_mem;
    float* sum_shared = &shared_mem[blockDim.x];

    int tid = threadIdx.x;

    // 阶段1: 计算最大值
    float local_max = -FLT_MAX;
    for (int i = tid; i < N; i += blockDim.x) {
        local_max = fmaxf(local_max, input[i]);
    }
    max_shared[tid] = local_max;
    __syncthreads();

    // 块内归约求最大值
    for (int s = blockDim.x/2; s > 0; s >>= 1) {
        if (tid < s) {
            max_shared[tid] = fmaxf(max_shared[tid], max_shared[tid + s]);
        }
        __syncthreads();
    }
    float max_val = max_shared[0];
    __syncthreads();

    // 阶段2: 计算指数和
    float local_sum = 0.0f;
    for (int i = tid; i < N; i += blockDim.x) {
        local_sum += expf(input[i] - max_val);
    }
    sum_shared[tid] = local_sum;
    __syncthreads();

    // 块内归约求和
    for (int s = blockDim.x/2; s > 0; s >>= 1) {
        if (tid < s) {
            sum_shared[tid] += sum_shared[tid + s];
        }
        __syncthreads();
    }
    float sum_exp = sum_shared[0];
    __syncthreads();

    // 阶段3: 计算最终结果
    for (int i = tid; i < N; i += blockDim.x) {
        output[i] = expf(input[i] - max_val) / sum_exp;
    }
}

// input, output are device pointers (i.e. pointers to memory on the GPU)
void solve(const float* input, float* output, int N) {
    int threadsPerBlock = 256;
    int blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;
    size_t sharedSize = 2 * threadsPerBlock * sizeof(float);
    softmax_kernel<<<blocksPerGrid, threadsPerBlock, sharedSize>>>(input, output, N);
    cudaDeviceSynchronize();
}

Pytorch 要做的就是一行秒杀

import torch

# input, output are tensors on the GPU
def solve(input: torch.Tensor, output: torch.Tensor, N: int):
    torch.softmax(input, 0, out=output)

这是一个比较重要的算子,Triton tutorials也有示例,这个块内规约起来就简单了,但是跨线程块(grid-wide)的规约怎么办。Triton关注块内,让我们看看tutorials/02-fused-softmax.py 是怎么做的,首先他是一个在线的方式,先编译了kernel拿到了kernel.metadata.shared, 那这个shared memory会不会不够用呢。计算出了一个更小的tl.num_programs(0) 作为 循环的 step。BLOCK_SIZE是大于等于n_cols的最小2的幂,那我们N这么大。

可以写3个kernel,先做max的局部规约和全局归约,再做sum的局部规约和全局规约,最后再归一化 softmax 输出。但是不能通过题目,因为用torch申请新的空间了,triton mode下禁止使用torch。
5.7更 使用ctypes.CDLL("libcudart.so")cudaMalloc,无敌了

import triton
import triton.language as tl

@triton.jit
def partial_max_value_kernel(X, partial_max, N, BLOCK_SIZE: tl.constexpr):
    X = X.to(tl.pointer_type(tl.float32))
    partial_max = partial_max.to(tl.pointer_type(tl.float32))
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < N
    x = tl.load(X + offset, mask=mask, other=-float("inf"))
    local_max = tl.max(x, axis=0)
    tl.store(partial_max + pid, local_max)

@triton.jit
def partial_exp_sum_value_kernel(X, partial_sum, global_max, N, BLOCK_SIZE: tl.constexpr):
    X = X.to(tl.pointer_type(tl.float32))
    partial_sum = partial_sum.to(tl.pointer_type(tl.float32))
    global_max = global_max.to(tl.pointer_type(tl.float32))
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < N
    x = tl.load(X + offset, mask=mask, other=-float("inf"))
    gmax = tl.load(global_max)
    local_sum = tl.sum(tl.exp(x - gmax), axis=0)
    tl.store(partial_sum + pid, local_sum)

@triton.jit
def normalize_kernel(X, Y, N, global_max, global_sum, BLOCK_SIZE: tl.constexpr):
    X = X.to(tl.pointer_type(tl.float32))
    Y = Y.to(tl.pointer_type(tl.float32))
    global_max = global_max.to(tl.pointer_type(tl.float32))
    global_sum = global_sum.to(tl.pointer_type(tl.float32))
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < N
    x = tl.load(X + offset, mask=mask)
    gmax = tl.load(global_max)
    gsum = tl.load(global_sum)
    y = tl.exp(x - gmax) / gsum
    tl.store(Y + offset, y, mask=mask)

@triton.jit
def get_max_value(partial_max, global_max, BLOCK_SIZE: tl.constexpr):
    partial_max = partial_max.to(tl.pointer_type(tl.float32))
    global_max = global_max.to(tl.pointer_type(tl.float32))
    offset =  tl.arange(0, BLOCK_SIZE)
    x = tl.load(partial_max + offset)
    local_max = tl.max(x, axis=0)
    tl.store(global_max, local_max)

@triton.jit
def get_sum_value(partial_sum, global_sum, BLOCK_SIZE: tl.constexpr):
    partial_sum = partial_sum.to(tl.pointer_type(tl.float32))
    global_sum = global_sum.to(tl.pointer_type(tl.float32))
    offset =  tl.arange(0, BLOCK_SIZE)
    x = tl.load(partial_sum + offset)
    local_sum = tl.sum(x, axis=0)
    tl.store(global_sum, local_sum)

def cudaEmpty(num_elements:int):
    import ctypes
    cudart = ctypes.CDLL("libcudart.so")
    cudart.cudaMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
    cudart.cudaMalloc.restype = ctypes.c_int
    ptr = ctypes.c_void_p()
    err = cudart.cudaMalloc(ctypes.byref(ptr), num_elements*4)
    if err != 0:
        raise RuntimeError(f"cudaMalloc failed, code {err}")
    return ptr.value


# input_ptr, output_ptr are raw device pointers
def solve(input_ptr: int, output_ptr: int, N: int):
    BLOCK_SIZE = 32768
    num_blocks = triton.cdiv(N, BLOCK_SIZE)
    grid = (num_blocks,)
    partial_max = cudaEmpty(BLOCK_SIZE)
    partial_max_value_kernel[grid](
        input_ptr, partial_max, N,
        BLOCK_SIZE=BLOCK_SIZE
    )
    global_max = cudaEmpty(1)
    get_max_value[1,](partial_max, global_max, BLOCK_SIZE=num_blocks)
    partial_sum = cudaEmpty(num_blocks)
    partial_exp_sum_value_kernel[grid](
        input_ptr, partial_sum, global_max, N,
        BLOCK_SIZE=BLOCK_SIZE
    )
    global_sum = cudaEmpty(1)
    get_sum_value[1,](partial_sum, global_sum, BLOCK_SIZE=num_blocks)
    normalize_kernel[grid](
        input_ptr, output_ptr, N,
        global_max, global_sum,
        BLOCK_SIZE=BLOCK_SIZE
    )

暂时领先cuda 心满意足.jpg 但感觉不太合理,怀疑评测程序有点问题。

4.30更那其实把这些规约的东西放到shared memory就好了,一个block干完所有任务算了。具体见以下代码

# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(
    input_ptr, output_ptr,
    N,
    BLOCK_SIZE: tl.constexpr
):
    input_ptr = input_ptr.to(tl.pointer_type(tl.float32))
    output_ptr = output_ptr.to(tl.pointer_type(tl.float32))
    _max = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - float("inf")
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(input_ptr + cols, mask=cols < N, other=-float("inf"))
        _max = tl.maximum(a, _max)
    max = tl.max(_max, axis=0)
    _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(input_ptr + cols, mask=cols < N, other=-float("inf"))
        _sum += tl.exp(a - max)
    sum = tl.sum(_sum, axis=0)
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < N
    x = tl.load(input_ptr + offset, mask=mask)
    y = tl.exp(x - max) / sum
    tl.store(output_ptr + offset, y, mask=mask)

# input_ptr, output_ptr are raw device pointers
def solve(input_ptr: int, output_ptr: int, N: int):
    BLOCK_SIZE = 32768
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    softmax_kernel[grid](
        input_ptr, output_ptr, N,
        BLOCK_SIZE=BLOCK_SIZE
    )

成功拿到number 1,值得纪念

五、提交新题

对比在传统OJ出题少了制造数据,我之前写过TZOJ出题帮助,这个数据看起来是平台这边帮忙给造的,提供cpu标程就可以了。

六、参考

1.CUDA guide

2.LeetCUDA

3.NVIDIA cuda-samples

4.Triton docs

5.Triton tutorials

6.手抓饼熊 Triton学习仓库推荐

7.紫气东来 CUDA(三):通用矩阵乘法:从入门到熟练

8.PeakCrosser [CUDA 学习笔记] 矩阵转置算子优化

9.离心 CUDA编程指北:从入门到实践

10.cloudcore CUDA微架构与指令集(4)-指令发射与warp调度

posted @ 2025-03-06 15:38  暴力都不会的蒟蒻  阅读(1562)  评论(0)    收藏  举报