Triton SPIR-V 后端开发:矩阵乘实现验证

项目地址:OpenMLIR/triton-spirv

本项目于2025.7.6在NVIDIA显卡上使用OpenCL跑通了矩阵乘matrix-multiplication.py,并对上了答案。

一、前言

1、项目进展

本项目于上个月(2025.6.7)在NVIDIA显卡上使用OpenCL跑通了向量加,具体见Triton SPIR-V 后端开发:向量加实现验证。最近一个月修复了上文中Launch kernel 参数固定、memcpy的完善的问题,还从MLIR加入了BufferLoopHoistingPass等开发后,完成了矩阵乘的验证。

2、生成的OpenCL kernel

没做任何优化,完全符合Triton的语义。memcpy 这次我用 async_work_group_copy来组合了。

//===------------------------------------------------------------*- C++ -*-===//
//
// Automatically generated file for OpenCL
//
//===----------------------------------------------------------------------===//

__kernel void matmul_kernel(
  __global float* var_0,
  __global float* var_1,
  __global float* var_2,
  int var_3,
  int var_4,
  int var_5,
  int var_6,
  int var_7,
  int var_8
) { // L5
  event_t ev = 0;
  __local float var_9[16][16];  // L32
  __local float var_10[16][16]; //
  for (int var_11 = 0; var_11 < 16; var_11 += 1) {  //
    for (int var_12 = 0; var_12 < 16; var_12 += 1) {    //
      var_10[var_11][var_12] = 0.000000;    //
    }
  }
  int var_13 = get_global_id(0);
  int var_15 = get_global_id(1);
  int var_17 = var_15 * 16; // L16
  int var_18 = var_13 * 16; // L17
  int var_21 = var_17 * var_6;  // L20
  __local float var_23[16][1];  // L30
  __local float var_24[1][16];  // L31
  __local float var_25[16][16]; // L32
  for (int var_26 = 0; var_26 < var_5; var_26 += 1) {   // L26
    int var_28 = var_21 + var_26;   // L27
    int var_29 = var_26 * var_7;    // L28
    int var_31 = var_29 + var_18;   // L28
    // load Block A in col
    for (int i = 0; i < 16; i += 1) {
      ev = async_work_group_copy(var_23 + i, var_0 + var_28 + i * var_6, 1, 0);
    }   // L30
    wait_group_events(1, &ev);
    // load Block B in row
    ev = async_work_group_copy(var_24, var_1 + var_31, 16, 0);  // L31
    wait_group_events(1, &ev);
    // broadcast Block A
    for (int var_32 = 0; var_32 < 16; var_32 += 1) {    // L32
      for (int var_33 = 0; var_33 < 16; var_33 += 1) {  // L32
        float var_34 = var_23[var_32][0];   // L32
        var_25[var_32][var_33] = var_34;    // L32
      }
    }
    // broadcast Block B
    for (int var_35 = 0; var_35 < 16; var_35 += 1) {    // L32
      for (int var_36 = 0; var_36 < 16; var_36 += 1) {  // L32
        float var_37 = var_24[0][var_36];   // L32
        var_9[var_35][var_36] = var_37; // L32
      }
    }
    // Block A mul Block B
    for (int var_38 = 0; var_38 < 16; var_38 += 1) {    // L32
      for (int var_39 = 0; var_39 < 16; var_39 += 1) {  // L32
        float var_40 = var_25[var_38][var_39];  // L32
        float var_41 = var_9[var_38][var_39];   // L32
        float var_42 = var_40 * var_41; // L32
        var_25[var_38][var_39] = var_42;    // L32
      }
    }
    // result(Block A mul Block B) + acc
    for (int var_43 = 0; var_43 < 16; var_43 += 1) {    // L32
      for (int var_44 = 0; var_44 < 16; var_44 += 1) {  // L32
        float var_45 = var_10[var_43][var_44];  // L32
        float var_46 = var_25[var_43][var_44];  // L32
        float var_47 = var_45 + var_46; // L32
        var_10[var_43][var_44] = var_47;    // L32
      }
    }
  }
  int var_49 = var_17 * var_8;  // L35
  int var_50 = var_49 + var_18; // L35
  int var_51 = var_17 + 16; // L39
  int var_53 = min(var_51, var_3);  // L39
  int var_54 = max(var_53, var_17); // L39
  int var_55 = var_54 - var_17; // L39
  int var_56 = var_18 + 16; // L39
  int var_58 = min(var_56, var_4);  // L39
  int var_59 = max(var_58, var_18); // L39
  int var_60 = var_59 - var_18; // L39
  int var_61 = min(var_55, 16); // L39
  int var_62 = min(var_60, 16); // L39
  // store on row
  for (int i = 0; i < var_61; i += 1) {
    ev = async_work_group_copy(var_2 + var_50 + i * var_8, var_10 + i, var_62, 0);
  } // L39
  wait_group_events(1, &ev);
}

二、开发要点

本文章所用commit为35fe15b,当前编译stage依旧为Python code->Triton IR->Linalg IR->memref IR->OpenCL file。和向量加验证通过代码diff

1、Launch kernel 完善

根据Triton的参数信息进行clSetKernelArg,Triton kernel的signature会保留着原始的参数信息,比如类型以及名字。我们可以根据类型信息去遍历bound_args, 我们可以先处理掉指针和i32。指针*是tensor,可以用data_ptr()获取到指针,然后设置 OpenCL 缓冲区对象(global buffer),因为我们还是在cpu上的tensor,然后设置cl_mem_flagsCL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,可读可写且复制 host_ptr 的数据到cl_mem里。创建好cl_mem后我们需要将其保存起来用于拷回,这是OpenCL没有Pytorch接入带来的问题,不然就0拷贝了。i32转换为整数就可以对应了。这里的话stride最后一维为1,被优化掉了,类型为constexpr,所以参数少了3个,但是不影响我们的处理。

    param_tys = [ty for (name, ty) in tt_kernel.src.signature.items()]
    buffers = {}
    kernel_idx = 0
    for (idx, arg) in enumerate(bound_args):
        ty = param_tys[idx]
        if ty[0] == '*':
            nbytes = arg.element_size() * arg.numel()
            cl_buf = ctypes.c_void_p(cl.clCreateBuffer(context, CL_MEM_KERNEL, nbytes, ctypes.c_void_p(arg.data_ptr()), ctypes.byref(err)))
            cl.clSetKernelArg(kernel, kernel_idx, ctypes.sizeof(ctypes.c_void_p), ctypes.byref(cl_buf))
            kernel_idx += 1
            buffers[idx] = cl_buf
        elif ty == "constexpr":
            pass
        elif ty == 'i32':
            int_value = ctypes.c_int(arg)
            cl.clSetKernelArg(kernel, kernel_idx, ctypes.sizeof(int_value), ctypes.byref(int_value))
            kernel_idx += 1
        else:
            raise RuntimeError("Unsuport this kernel type, please add!")

kernel 执行完后进行数据拷回

    for (idx, arg) in enumerate(bound_args):
        if param_tys[idx][0] == '*':
            nbytes = arg.element_size() * arg.numel()
            cl.clEnqueueReadBuffer(queue, buffers[idx], False, 0, nbytes, ctypes.c_void_p(arg.data_ptr()), 0, None, None)
    cl.clFinish(queue)

2、memcpy的完善

之前是使用for循环+赋值语句的,很傻,这里使用了async_work_group_copy的API,可以在__global__local互相拷贝。

event_t async_work_group_copy(
  __local gentype *dst, const __global gentype *src,
  size_t num_gentypes, event_t event);

event_t async_work_group_copy(
  __global gentype *dst, const __local gentype *src,
  size_t num_gentypes, event_t event);

但是如果有跨步呢,这里还是for循环套了一下。当然这里还要处理好stride,具体可以看源码。

    if (is1D(targetMemref)) {
      emitAsyncCopyWithOpFoldResult(targetSubView.getSource(),
                                    sourceSubView.getSource(),
                                    targetSubView.getMixedSizes()[0]);
    } else if (is2D(targetMemref)) {
      indent() << "for (int i = 0; i < ";
      emitOpFoldResult(targetSubView.getMixedSizes()[0]);
      os << "; i += 1) {\n";
      addIndent();
      emitAsyncCopyWithOpFoldResult(targetSubView.getSource(),
                                    sourceSubView.getSource(),
                                    targetSubView.getMixedSizes()[1]);
      os << "\n";
      reduceIndent();
      indent() << "}";
    }

3、Triton Dialect 转换

多了些Op需要Convert,这里还是借助microsoft/triton-shared的源码。

  patterns.add<YieldConverter>(patterns.getContext());
...
  patterns.add<LoopConverter>(patterns.getContext());
  patterns.add<BroadcastConverter>(patterns.getContext());
...
  patterns.add<DenseConstantConverter>(patterns.getContext());

4、Linalg Dialect 转换

在linalg这边矩阵乘还多了linalg::FillOp,全部丢进SmallVector再做处理就好了。

    SmallVector<linalg::LinalgOp> linalgOps;
    moduleOp.walk(
        [&](linalg::GenericOp genericOp) { linalgOps.push_back(genericOp); });
    moduleOp.walk(
        [&](linalg::FillOp fillOp) { linalgOps.push_back(fillOp); });
    PatternRewriter rewriter(&getContext());
    for (auto linalgOp : linalgOps) {
      rewriter.setInsertionPoint(linalgOp);
      if (failed(linalg::linalgOpToAffineLoops(rewriter, linalgOp))) {
        llvm::errs() << "Failed to lower to affine loops.\n";
        return;
      }
      rewriter.eraseOp(linalgOp);
    }

5、BufferLoopHoistingPass

__local 必须定义必须在 kernel 的最外层作用域声明,不能在条件、循环或其他块中引用。所以可以用BufferLoopHoistingPass把它提出来,他是作用于func::FuncOp的,和Pybind绑定时包一下OpPassManager

  m.def("buffer_loop_hoisting", [](mlir::PassManager &pm) {
    mlir::OpPassManager &funcPM = pm.nest<mlir::func::FuncOp>();
    funcPM.addPass(mlir::bufferization::createBufferLoopHoistingPass());
  });

附录1、项目文档

Triton SPIR-V 后端开发:向量加实现验证

Triton SPIR-V 后端开发:PyBind绑定

Triton SPIR-V 后端开发:新增Pass

Triton SPIR-V 后端开发:backend 初始化

附录2、作者相关技术文章

深度剖析 Triton编译器 MatMul优化(三)—— TMA

深度剖析 Triton编译器 MatMul优化(二)—— MMA

深度剖析 Triton编译器 MatMul优化(一)—— FMA

浅析 Triton 执行流程

从零开始教你写一个MLIR Pass

LeetGPU入门教程 (CUDA guide最佳实践)

posted @ 2025-07-06 19:54  暴力都不会的蒟蒻  阅读(45)  评论(0)    收藏  举报