import torch

import triton
import triton.language as tl


@triton.jit
def softmax(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N):
    # Each program instance normalizes one row of the (M, N) input.
    m = tl.program_id(0)
    # Column offsets; the block is hard-coded to 1024, so this kernel
    # assumes N <= 1024.
    n = tl.arange(0, 1024)
    # Pointers into row m; masked lanes read -inf so they are neutral
    # for both the max and the sum below.
    X = X + m * stride_xm + n * stride_xn
    x = tl.load(X, mask=n < N, other=-float('inf'))
    # Numerically stable softmax: shift by the row max before exponentiating.
    z = x - tl.max(x, axis=0)
    num = tl.exp(z)
    denom = tl.sum(num, axis=0)
    y = num / denom
    # Store the normalized row into Y.
    Y = Y + m * stride_ym + n * stride_yn
    tl.store(Y, y, mask=n < N)
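
# The kernel above hard-codes a block of 1024 columns, so it only supports
# N <= 1024. A common generalization (a sketch, not part of the original
# code) passes the block size as a tl.constexpr parameter and picks the next
# power of two at launch time; softmax_v2 and BLOCK_SIZE are illustrative names.
@triton.jit
def softmax_v2(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N,
               BLOCK_SIZE: tl.constexpr):
    m = tl.program_id(0)
    n = tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + m * stride_xm + n * stride_xn, mask=n < N, other=-float('inf'))
    num = tl.exp(x - tl.max(x, axis=0))
    tl.store(Y + m * stride_ym + n * stride_yn, num / tl.sum(num, axis=0), mask=n < N)

# Launch sketch: softmax_v2[(M,)](..., BLOCK_SIZE=triton.next_power_of_2(N))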

# Generate test data.
X = torch.normal(0, 1, size=(583, 931), device='cuda')
Y = torch.empty_like(X)
# Launch the Triton kernel: one program instance per row.
grid = (X.shape[0],)
softmax[grid](Y, Y.stride(0), Y.stride(1),
              X, X.stride(0), X.stride(1),
              X.shape[0], X.shape[1])
# Compute the PyTorch reference and verify.
ref = torch.softmax(X, dim=1)
print("Triton计算结果示例(前5行):\n", Y[:5, :5])
print("PyTorch参考结果示例(前5行):\n", ref[:5, :5])
print("最大差异:", torch.max(torch.abs(Y - ref)).item())
print("结果是否一致:", torch.allclose(Y, ref, rtol=1e-4))