Triton SPIR-V 后端开发:矩阵乘实现验证
本项目于2025.7.6在NVIDIA显卡上使用OpenCL跑通了矩阵乘matrix-multiplication.py,并对上了答案。

一、前言
1、项目进展
本项目于上个月(2025.6.7)在NVIDIA显卡上使用OpenCL跑通了向量加,具体见Triton SPIR-V 后端开发:向量加实现验证。最近一个月修复了上文中Launch kernel 参数固定、memcpy的完善的问题,还从MLIR加入了BufferLoopHoistingPass等开发后,完成了矩阵乘的验证。
2、生成的OpenCL kernel
没做任何优化,完全符合Triton的语义。memcpy 这次我用 async_work_group_copy来组合了。
//===------------------------------------------------------------*- C++ -*-===//
//
// Automatically generated file for OpenCL
//
//===----------------------------------------------------------------------===//
__kernel void matmul_kernel(
__global float* var_0,
__global float* var_1,
__global float* var_2,
int var_3,
int var_4,
int var_5,
int var_6,
int var_7,
int var_8
) { // L5
event_t ev = 0;
__local float var_9[16][16]; // L32
__local float var_10[16][16]; //
for (int var_11 = 0; var_11 < 16; var_11 += 1) { //
for (int var_12 = 0; var_12 < 16; var_12 += 1) { //
var_10[var_11][var_12] = 0.000000; //
}
}
int var_13 = get_global_id(0);
int var_15 = get_global_id(1);
int var_17 = var_15 * 16; // L16
int var_18 = var_13 * 16; // L17
int var_21 = var_17 * var_6; // L20
__local float var_23[16][1]; // L30
__local float var_24[1][16]; // L31
__local float var_25[16][16]; // L32
for (int var_26 = 0; var_26 < var_5; var_26 += 1) { // L26
int var_28 = var_21 + var_26; // L27
int var_29 = var_26 * var_7; // L28
int var_31 = var_29 + var_18; // L28
// load Block A in col
for (int i = 0; i < 16; i += 1) {
ev = async_work_group_copy(var_23 + i, var_0 + var_28 + i * var_6, 1, 0);
} // L30
wait_group_events(1, &ev);
// load Block B in row
ev = async_work_group_copy(var_24, var_1 + var_31, 16, 0); // L31
wait_group_events(1, &ev);
// broadcast Block A
for (int var_32 = 0; var_32 < 16; var_32 += 1) { // L32
for (int var_33 = 0; var_33 < 16; var_33 += 1) { // L32
float var_34 = var_23[var_32][0]; // L32
var_25[var_32][var_33] = var_34; // L32
}
}
// broadcast Block B
for (int var_35 = 0; var_35 < 16; var_35 += 1) { // L32
for (int var_36 = 0; var_36 < 16; var_36 += 1) { // L32
float var_37 = var_24[0][var_36]; // L32
var_9[var_35][var_36] = var_37; // L32
}
}
// Block A mul Block B
for (int var_38 = 0; var_38 < 16; var_38 += 1) { // L32
for (int var_39 = 0; var_39 < 16; var_39 += 1) { // L32
float var_40 = var_25[var_38][var_39]; // L32
float var_41 = var_9[var_38][var_39]; // L32
float var_42 = var_40 * var_41; // L32
var_25[var_38][var_39] = var_42; // L32
}
}
// result(Block A mul Block B) + acc
for (int var_43 = 0; var_43 < 16; var_43 += 1) { // L32
for (int var_44 = 0; var_44 < 16; var_44 += 1) { // L32
float var_45 = var_10[var_43][var_44]; // L32
float var_46 = var_25[var_43][var_44]; // L32
float var_47 = var_45 + var_46; // L32
var_10[var_43][var_44] = var_47; // L32
}
}
}
int var_49 = var_17 * var_8; // L35
int var_50 = var_49 + var_18; // L35
int var_51 = var_17 + 16; // L39
int var_53 = min(var_51, var_3); // L39
int var_54 = max(var_53, var_17); // L39
int var_55 = var_54 - var_17; // L39
int var_56 = var_18 + 16; // L39
int var_58 = min(var_56, var_4); // L39
int var_59 = max(var_58, var_18); // L39
int var_60 = var_59 - var_18; // L39
int var_61 = min(var_55, 16); // L39
int var_62 = min(var_60, 16); // L39
// store on row
for (int i = 0; i < var_61; i += 1) {
ev = async_work_group_copy(var_2 + var_50 + i * var_8, var_10 + i, var_62, 0);
} // L39
wait_group_events(1, &ev);
}
二、开发要点
本文章所用commit为35fe15b,当前编译stage依旧为Python code->Triton IR->Linalg IR->memref IR->OpenCL file。和向量加验证通过代码diff
1、Launch kernel 完善
根据Triton的参数信息进行clSetKernelArg,Triton kernel的signature会保留着原始的参数信息,比如类型以及名字。我们可以根据类型信息去遍历bound_args, 我们可以先处理掉指针和i32。指针*是tensor,可以用data_ptr()获取到指针,然后设置 OpenCL 缓冲区对象(global buffer),因为我们还是在cpu上的tensor,然后设置cl_mem_flags为CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,可读可写且复制 host_ptr 的数据到cl_mem里。创建好cl_mem后我们需要将其保存起来用于拷回,这是OpenCL没有Pytorch接入带来的问题,不然就0拷贝了。i32转换为整数就可以对应了。这里的话stride最后一维为1,被优化掉了,类型为constexpr,所以参数少了3个,但是不影响我们的处理。
param_tys = [ty for (name, ty) in tt_kernel.src.signature.items()]
buffers = {}
kernel_idx = 0
for (idx, arg) in enumerate(bound_args):
ty = param_tys[idx]
if ty[0] == '*':
nbytes = arg.element_size() * arg.numel()
cl_buf = ctypes.c_void_p(cl.clCreateBuffer(context, CL_MEM_KERNEL, nbytes, ctypes.c_void_p(arg.data_ptr()), ctypes.byref(err)))
cl.clSetKernelArg(kernel, kernel_idx, ctypes.sizeof(ctypes.c_void_p), ctypes.byref(cl_buf))
kernel_idx += 1
buffers[idx] = cl_buf
elif ty == "constexpr":
pass
elif ty == 'i32':
int_value = ctypes.c_int(arg)
cl.clSetKernelArg(kernel, kernel_idx, ctypes.sizeof(int_value), ctypes.byref(int_value))
kernel_idx += 1
else:
raise RuntimeError("Unsuport this kernel type, please add!")
kernel 执行完后进行数据拷回
for (idx, arg) in enumerate(bound_args):
if param_tys[idx][0] == '*':
nbytes = arg.element_size() * arg.numel()
cl.clEnqueueReadBuffer(queue, buffers[idx], False, 0, nbytes, ctypes.c_void_p(arg.data_ptr()), 0, None, None)
cl.clFinish(queue)
2、memcpy的完善
之前是使用for循环+赋值语句的,很傻,这里使用了async_work_group_copy的API,可以在__global和__local互相拷贝。
event_t async_work_group_copy(
__local gentype *dst, const __global gentype *src,
size_t num_gentypes, event_t event);
event_t async_work_group_copy(
__global gentype *dst, const __local gentype *src,
size_t num_gentypes, event_t event);
但是如果有跨步呢,这里还是for循环套了一下。当然这里还要处理好stride,具体可以看源码。
if (is1D(targetMemref)) {
emitAsyncCopyWithOpFoldResult(targetSubView.getSource(),
sourceSubView.getSource(),
targetSubView.getMixedSizes()[0]);
} else if (is2D(targetMemref)) {
indent() << "for (int i = 0; i < ";
emitOpFoldResult(targetSubView.getMixedSizes()[0]);
os << "; i += 1) {\n";
addIndent();
emitAsyncCopyWithOpFoldResult(targetSubView.getSource(),
sourceSubView.getSource(),
targetSubView.getMixedSizes()[1]);
os << "\n";
reduceIndent();
indent() << "}";
}
3、Triton Dialect 转换
多了些Op需要Convert,这里还是借助microsoft/triton-shared的源码。
patterns.add<YieldConverter>(patterns.getContext());
...
patterns.add<LoopConverter>(patterns.getContext());
patterns.add<BroadcastConverter>(patterns.getContext());
...
patterns.add<DenseConstantConverter>(patterns.getContext());
4、Linalg Dialect 转换
在linalg这边矩阵乘还多了linalg::FillOp,全部丢进SmallVector再做处理就好了。
SmallVector<linalg::LinalgOp> linalgOps;
moduleOp.walk(
[&](linalg::GenericOp genericOp) { linalgOps.push_back(genericOp); });
moduleOp.walk(
[&](linalg::FillOp fillOp) { linalgOps.push_back(fillOp); });
PatternRewriter rewriter(&getContext());
for (auto linalgOp : linalgOps) {
rewriter.setInsertionPoint(linalgOp);
if (failed(linalg::linalgOpToAffineLoops(rewriter, linalgOp))) {
llvm::errs() << "Failed to lower to affine loops.\n";
return;
}
rewriter.eraseOp(linalgOp);
}
5、BufferLoopHoistingPass
__local 必须定义必须在 kernel 的最外层作用域声明,不能在条件、循环或其他块中引用。所以可以用BufferLoopHoistingPass把它提出来,他是作用于func::FuncOp的,和Pybind绑定时包一下OpPassManager。
m.def("buffer_loop_hoisting", [](mlir::PassManager &pm) {
mlir::OpPassManager &funcPM = pm.nest<mlir::func::FuncOp>();
funcPM.addPass(mlir::bufferization::createBufferLoopHoistingPass());
});
附录1、项目文档
Triton SPIR-V 后端开发:backend 初始化
附录2、作者相关技术文章
深度剖析 Triton编译器 MatMul优化(三)—— TMA
深度剖析 Triton编译器 MatMul优化(二)—— MMA
本文来自博客园,作者:暴力都不会的蒟蒻,转载请注明原文链接:https://www.cnblogs.com/BobHuang/p/18969274

浙公网安备 33010602011771号