A simple softmax kernel implemented with Triton

The kernel computes a numerically stable row-wise softmax: each program instance handles one row of the input, subtracts the row maximum before exponentiating, and normalizes by the row sum.
import triton
import triton.language as tl

@triton.jit
def softmax(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N):
    # One program instance per row of the (M, N) input.
    m = tl.program_id(0)
    # Column offsets; the block size is hard-coded to 1024, so N must be <= 1024.
    n = tl.arange(0, 1024)
    # Load row m; masked-off lanes read -inf so they contribute exp(-inf) = 0.
    X = X + m * stride_xm + n * stride_xn
    x = tl.load(X, mask=n < N, other=-float('inf'))
    # Subtract the row max for numerical stability, then normalize.
    z = x - tl.max(x, axis=0)
    num = tl.exp(z)
    denom = tl.sum(num, axis=0)
    y = num / denom
    # Write the result back, again masking lanes beyond the row length.
    Y = Y + m * stride_ym + n * stride_yn
    tl.store(Y, y, mask=n < N)
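
The hard-coded tl.arange(0, 1024) caps the row length at 1024 elements. Below is a minimal sketch of a more flexible variant (my addition, not part of the original post; the names softmax_v2 and BLOCK_SIZE are hypothetical) that takes the block size as a compile-time constexpr, chosen per launch with triton.next_power_of_2. The demo that follows still uses the fixed-size kernel above.

@triton.jit
def softmax_v2(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N,
               BLOCK_SIZE: tl.constexpr):
    # Same row-wise scheme as above, but the block size is a compile-time
    # parameter, so any N works as long as BLOCK_SIZE >= N. (Sketch only.)
    m = tl.program_id(0)
    n = tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + m * stride_xm + n * stride_xn,
                mask=n < N, other=-float('inf'))
    z = x - tl.max(x, axis=0)
    num = tl.exp(z)
    y = num / tl.sum(num, axis=0)
    tl.store(Y + m * stride_ym + n * stride_yn, y, mask=n < N)

# Example launch: pick the smallest power of two covering the row length.
# softmax_v2[(M,)](Y, Y.stride(0), Y.stride(1),
#                  X, X.stride(0), X.stride(1),
#                  M, N, BLOCK_SIZE=triton.next_power_of_2(N))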

import torch

# Generate test data (931 columns fits within the 1024-element block)
X = torch.normal(0, 1, size=(583, 931), device='cuda')
Y = torch.empty_like(X)

# Launch the Triton kernel: one program per row
grid = (X.shape[0],)
softmax[grid](Y, Y.stride(0), Y.stride(1),
              X, X.stride(0), X.stride(1),
              X.shape[0], X.shape[1])

# Compute the PyTorch reference and verify
ref = torch.softmax(X, dim=1)
print("Triton result (top-left 5x5 block):\n", Y[:5, :5])
print("PyTorch reference (top-left 5x5 block):\n", ref[:5, :5])
print("Max absolute difference:", torch.max(torch.abs(Y - ref)).item())
print("Results match:", torch.allclose(Y, ref, rtol=1e-4))
