Triton SPIR-V Backend Development: Verifying a Vector-Add Implementation

Original post: https://www.cnblogs.com/BobHuang/p/18917000 (the reading experience is better at the source).

Project repository: OpenMLIR/triton-spirv

On 2025-06-07, this project ran the vector-add example vector-add.py end to end on an NVIDIA GPU via OpenCL, with results matching the reference output.

I. Preface

1. Project origins

The project began on 2025-05-06, when I announced it in rtfff's MLIR Triton developer chat group. Its origins trace back to May 2023 and my work adapting PyTorch to an in-house GPGPU.

2. Project progress

I completed an initial design and planned to build features up step by step, from simple to complex. Over the following month of development I worked through, in turn, NVIDIA GPUs not supporting clCreateProgramWithIL, the OpenCL ABI, and PyTorch's lack of an OpenCL backend, and finally got the vector-add example vector-add.py running with verified-correct results.

The project still has rough edges, but with the demo running I believe other developers can now get involved: with my project documentation (collected in the Zhihu column) plus the Triton source code and other references, you can try to complete a piece of the development independently.

3. More backends to reference

While this project was under way, FlagTree, a Triton compiler for domestic Chinese chips, was open-sourced. So beyond the nvidia, amd, intel, cpu, microsoft (partial), and Meta TLX-extension (low-level) backends, we can now also study backend implementations from Kunlunxin, Moore Threads, MetaX, Arm China, Huawei Ascend, TsingMicro, Iluvatar CoreX, Cambricon (partial), and Seed's distributed-extension ops. My experience is that Triton development calls for compiler engineers who can also write kernels; algorithm engineers who understand computer architecture will dominate when they show up. Perhaps this series of articles can teach some kernel writers the compiler side? In the end it comes down to performance: only by baking kernel-tuning experience into the Triton compiler will Triton keep getting better.

II. Current Project Design

This article is based on commit f66f77c. The compilation stages are currently Python code -> Triton IR -> Linalg IR -> memref IR -> OpenCL file; the code is at third_party/spirv/backend/compiler.py:135 and the generated artifacts live under third_party/spirv/test.
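
For orientation, the sketch below shows how such a staged pipeline is typically chained in a Triton backend's add_stages(); the stage names mirror the artifact extensions used throughout this post, but the code is a paraphrase, not the literal contents of compiler.py.

def add_stages(self, stages, options):
    # One callable per intermediate artifact; Triton runs them in order and
    # caches each stage's output under the corresponding extension.
    stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
    stages["lair"] = lambda src, metadata: self.make_lair(src, metadata, options)
    stages["memir"] = lambda src, metadata: self.make_memir(src, metadata, options)
    stages["cl"] = lambda src, metadata: self.emit_opencl(src, metadata, options)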

The kernel launch calls the OpenCL shared library through ctypes. Because PyTorch has no OpenCL integration, the tensors are allocated on the CPU and have to be copied to and from the device; the code is at third_party/spirv/backend/cl_utils.py:34.
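
cl_utils.py drives libOpenCL directly through ctypes; as a more readable stand-in, the same flow looks roughly like the pyopencl sketch below. This is illustrative only: the 98432-element size from the Triton tutorial, the one-work-item-per-program launch, and all names are assumptions rather than the project's literal code.

import numpy as np
import pyopencl as cl

n = 98432  # size used by Triton's 01-vector-add.py tutorial
x = np.random.rand(n).astype(np.float32)
y = np.random.rand(n).astype(np.float32)
out = np.empty_like(x)

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
prog = cl.Program(ctx, open("add_kernel.cl").read()).build()

# CPU tensors force explicit copies: seed the device buffers from host memory.
mf = cl.mem_flags
x_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=x)
y_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=y)
out_buf = cl.Buffer(ctx, mf.READ_WRITE, size=out.nbytes)

# One work-item per Triton program instance (BLOCK_SIZE = 1024); see the
# "only one thread per block" discussion in part III.
n_blocks = (n + 1023) // 1024
prog.add_kernel(queue, (n_blocks,), (1,), x_buf, y_buf, out_buf, np.int32(n))

# Copy the result back to the host tensor.
cl.enqueue_copy(queue, out, out_buf)
queue.finish()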

1. Fetching the AST and converting it to ttir

ttir is Triton IR. This stage reuses Triton's own source code: since we are not targeting an NPU/DSA we need no extra ops, and existing Triton kernels stay fully compatible. For the call flow, see my article "A Brief Look at Triton's Execution Flow" (浅析 Triton 执行流程).

    @staticmethod
    def make_ttir(mod, metadata, opt):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_combine(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        passes.ttir.add_loop_unroll(pm)
        pm.run(mod)
        return mod

The artifact for vector-add.py is at third_party/spirv/test/add_kernel.ttir, shown below:

#loc = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0)
module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0)) attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)
    %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5)
    %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc6)
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc6)
    %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc7)
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc7)
    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc8)
    %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc9)
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc9)
    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc10)
    %13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc11)
    %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc12)
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc12)
    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc13)
    tt.return loc(#loc14)
  } loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":38:24)
#loc3 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":43:24)
#loc4 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":44:41)
#loc5 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":44:28)
#loc6 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":46:21)
#loc7 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":49:24)
#loc8 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":49:16)
#loc9 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":50:24)
#loc10 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":50:16)
#loc11 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":51:17)
#loc12 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":53:26)
#loc13 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":53:35)
#loc14 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":53:4)

2. Converting ttir to lair

lair is Linalg IR. This stage reuses source from microsoft/triton-shared: it analyzes the pointers in the ttir and then lowers to memref and bufferization ops. The code is shown below.

    @staticmethod
    def make_lair(mod, metadata, opt):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        spirv.passes.lair.triton_to_linalg(pm)
        pm.run(mod)
        return mod

The pass consists mainly of runUseAnalysis plus several converters; runUseAnalysis analyzes pointer usage for the converters to consume. I added GetProgramIDConverter, which rewrites triton::GetProgramIdOp into gpu::GlobalIdOp; the rest are copied over. I added only the patterns that the ttir of vector-add.py needs; the original project has many more.

  populateFunctionOpInterfaceTypeConversionPattern<triton::FuncOp>(
      patterns, typeConverter);
  patterns.add<MetaOpConverter>(patterns.getContext());
  patterns.add<StoreConverter>(patterns.getContext());
  patterns.add<LegacyAddPtrConverter>(patterns.getContext());
  patterns.add<GetProgramIDConverter>(patterns.getContext());
  patterns.add<LoadConverter>(patterns.getContext());
  patterns.add<SplatConverter>(patterns.getContext());
  linalg::populateElementwiseToLinalgConversionPatterns(patterns);

The artifact for vector-add.py is at third_party/spirv/test/add_kernel.lair, shown below:

#loc = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0)
#loc5 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":49:16)
#loc7 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":50:16)
#map = affine_map<(d0) -> (d0)>
module {
  func.func @add_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0), %arg1: memref<*xf32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0), %arg2: memref<*xf32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0)) {
    %c1024 = arith.constant 1024 : index loc(#loc1)
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %global_id_x = gpu.global_id  x loc(#loc2)
    %0 = arith.index_cast %global_id_x : index to i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
    %2 = arith.index_cast %1 : i32 to index loc(#loc4)
    %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%2], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> loc(#loc4)
    %alloc = memref.alloc() : memref<1024xf32> loc(#loc5)
    %3 = arith.addi %2, %c1024 : index loc(#loc5)
    %4 = arith.index_cast %arg3 : i32 to index loc(#loc5)
    %5 = arith.minsi %3, %4 : index loc(#loc5)
    %6 = arith.maxsi %5, %2 : index loc(#loc5)
    %7 = arith.subi %6, %2 : index loc(#loc5)
    %subview = memref.subview %reinterpret_cast[0] [%7] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>> loc(#loc5)
    %subview_0 = memref.subview %alloc[0] [%7] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>> loc(#loc5)
    memref.copy %subview, %subview_0 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>> loc(#loc5)
    %8 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32> to tensor<1024xf32> loc(#loc5)
    %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%2], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> loc(#loc6)
    %alloc_2 = memref.alloc() : memref<1024xf32> loc(#loc7)
    %subview_3 = memref.subview %reinterpret_cast_1[0] [%7] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>> loc(#loc7)
    %subview_4 = memref.subview %alloc_2[0] [%7] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>> loc(#loc7)
    memref.copy %subview_3, %subview_4 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>> loc(#loc7)
    %9 = bufferization.to_tensor %alloc_2 restrict writable : memref<1024xf32> to tensor<1024xf32> loc(#loc7)
    %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%8, %9 : tensor<1024xf32>, tensor<1024xf32>) outs(%8 : tensor<1024xf32>) {
    ^bb0(%in: f32 loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":49:16), %in_7: f32 loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":50:16), %out: f32 loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":49:16)):
      %11 = arith.addf %in, %in_7 : f32 loc(#loc8)
      linalg.yield %11 : f32 loc(#loc8)
    } -> tensor<1024xf32> loc(#loc8)
    %reinterpret_cast_5 = memref.reinterpret_cast %arg2 to offset: [%2], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> loc(#loc9)
    %extracted_slice = tensor.extract_slice %10[0] [%7] [1] : tensor<1024xf32> to tensor<?xf32> loc(#loc10)
    %subview_6 = memref.subview %reinterpret_cast_5[0] [%7] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>> loc(#loc10)
    bufferization.materialize_in_destination %extracted_slice in writable %subview_6 : (tensor<?xf32>, memref<?xf32, strided<[1], offset: ?>>) -> () loc(#loc10)
    return loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":38:24)
#loc3 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":43:24)
#loc4 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":49:24)
#loc6 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":50:24)
#loc8 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":51:17)
#loc9 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":53:26)
#loc10 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":53:35)

3. Converting lair to memir

memir is memref IR. After the Linalg conversion the code still contains tensor and bufferization ops; we eliminate them with --one-shot-bufferize. In addition, once tiling and other optimizations on linalg::GenericOp are done, we lower it to affine loops to keep later optimizations possible. --canonicalize then removes redundant operations. The code is shown below.

    @staticmethod
    def make_memir(mod, metadata, opt):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        spirv.passes.memir.one_shot_bufferize(pm)
        spirv.passes.memir.linalg_to_affine_loops(pm)
        passes.common.add_canonicalizer(pm)
        pm.run(mod)
        return mod

MLIR's --convert-linalg-to-affine-loops had to be rewritten here, because it is not a pass that runs on mlir::ModuleOp. --one-shot-bufferize also requires registering the corresponding bufferizable-op interfaces, as shown below.

  m.def("load_dialects", [](mlir::MLIRContext &context) {
    mlir::DialectRegistry registry;
    mlir::bufferization::func_ext::
        registerBufferizableOpInterfaceExternalModels(registry);
    mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
    mlir::linalg::registerBufferizableOpInterfaceExternalModels(registry);
    mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
    context.appendDialectRegistry(registry);
    context.loadAllAvailableDialects();
  });

The artifact for vector-add.py is at third_party/spirv/test/add_kernel.memir, shown below:

#loc = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0)
#loc8 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":51:17)
module {
  func.func @add_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0), %arg1: memref<*xf32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0), %arg2: memref<*xf32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":29:0)) {
    %c1024 = arith.constant 1024 : index loc(#loc1)
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %global_id_x = gpu.global_id  x loc(#loc2)
    %0 = arith.index_cast %global_id_x : index to i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
    %2 = arith.index_cast %1 : i32 to index loc(#loc4)
    %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%2], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> loc(#loc4)
    %alloc = memref.alloc() : memref<1024xf32> loc(#loc5)
    %3 = arith.addi %2, %c1024 : index loc(#loc5)
    %4 = arith.index_cast %arg3 : i32 to index loc(#loc5)
    %5 = arith.minsi %3, %4 : index loc(#loc5)
    %6 = arith.maxsi %5, %2 : index loc(#loc5)
    %7 = arith.subi %6, %2 : index loc(#loc5)
    %subview = memref.subview %reinterpret_cast[0] [%7] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>> loc(#loc5)
    %subview_0 = memref.subview %alloc[0] [%7] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>> loc(#loc5)
    memref.copy %subview, %subview_0 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>> loc(#loc5)
    %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%2], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> loc(#loc6)
    %alloc_2 = memref.alloc() : memref<1024xf32> loc(#loc7)
    %subview_3 = memref.subview %reinterpret_cast_1[0] [%7] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>> loc(#loc7)
    %subview_4 = memref.subview %alloc_2[0] [%7] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>> loc(#loc7)
    memref.copy %subview_3, %subview_4 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>> loc(#loc7)
    affine.for %arg4 loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":51:17) = 0 to 1024 {
      %8 = affine.load %alloc[%arg4] : memref<1024xf32> loc(#loc8)
      %9 = affine.load %alloc_2[%arg4] : memref<1024xf32> loc(#loc8)
      %10 = arith.addf %8, %9 : f32 loc(#loc8)
      affine.store %10, %alloc[%arg4] : memref<1024xf32> loc(#loc8)
    } loc(#loc8)
    %reinterpret_cast_5 = memref.reinterpret_cast %arg2 to offset: [%2], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> loc(#loc9)
    %subview_6 = memref.subview %alloc[0] [%7] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>> loc(#loc10)
    %subview_7 = memref.subview %reinterpret_cast_5[0] [%7] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>> loc(#loc10)
    memref.copy %subview_6, %subview_7 : memref<?xf32, strided<[1]>> to memref<?xf32, strided<[1], offset: ?>> loc(#loc10)
    return loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":38:24)
#loc3 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":43:24)
#loc4 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":49:24)
#loc5 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":49:16)
#loc6 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":50:24)
#loc7 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":50:16)
#loc9 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":53:26)
#loc10 = loc("/home/ubuntu/triton-spirv/python/tutorials/spirv_demo/01-vector-add.py":53:35)

4. Emitting to OpenCL

NVIDIA's OpenCL does not support clCreateProgramWithIL, so we cannot hand it LLVM SPIR-V IR as input. As a workaround we emit OpenCL source instead: the binary tool triton-spirv-translate writes the MLIR out as an OpenCL file. The code is shown below.

    @staticmethod
    def emit_opencl(src, metadata, opt):
        import re
        names = re.findall(r"func\.func @(\w+)\(", str(src))
        assert len(names) == 1
        metadata["name"] = names[0]
        import triton._C as tc
        spirv_translate = os.path.join(tc.__path__[0], 'triton-spirv-translate')
        with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.memir') as fsrc:
            fsrc.write(str(src))
            fsrc.flush()
            opencl_file = fsrc.name + '.cl'
            emit_opencl_cmd = [
                spirv_translate,
                fsrc.name,
                '-triton-spirv-emit-opencl',
                '-o',
                opencl_file
            ]
            subprocess.run(emit_opencl_cmd, check=True, close_fds=False)
            with open(opencl_file, 'rb') as f:
                opencl_src = f.read()
            if os.path.exists(opencl_file):
                os.remove(opencl_file)
        return opencl_src

While working on HLS I learned that lib/Translation/EmitHLSCpp.cpp in the ScaleHLS-HIDA project can emit MLIR as C++ source; after stripping its dependencies and upgrading it to a newer LLVM, I arrived at triton-spirv-translate. The main piece I wrote handles memref::CopyOp: I expect the memrefs to already be flattened to one dimension at this point, restrict the stride to 1 for now, and emit the offset and size concretely. Part of the code is shown below; the full source is in third_party/spirv/tool/triton-spirv-translate/EmitOpenCL.cpp.

void ModuleEmitter::emitMemCpy(memref::CopyOp op) {
  auto sourceSubView = getSubviewOp(op.getSource());
  auto targetSubView = getSubviewOp(op.getTarget());
  assert(sourceSubView && targetSubView && "need copy subview");
  assert(checkOneDimMemref(sourceSubView) && checkOneDimMemref(targetSubView) &&
         "memcpy not support over 1D");
  assert(checkSubViewOffsetAndStride(sourceSubView) &&
         "source subview not support");
  assert(checkSubViewOffsetAndStride(targetSubView) &&
         "target subview not support");

  indent() << "for (";
  // Emit lower bound.
  os << "int i = 0; ";
  // Emit upper bound.
  os << "i < ";
  OpFoldResult upperBound = targetSubView.getMixedSizes()[0];
  if (auto intAttr = getConstantIntValue(upperBound)) {
    os << intAttr;
  } else {
    emitValue(mlir::dyn_cast<Value>(upperBound));
  }
  os << "; i += 1) {\n";
  addIndent();
  indent();
  emitMemCpyValue(targetSubView.getSource());
  os << " = ";
  emitMemCpyValue(sourceSubView.getSource());
  os << ";\n";
  reduceIndent();
  indent() << "}";
  emitInfoAndNewLine(op);
  if (mlir::isa<memref::AllocOp>(targetSubView.getSource().getDefiningOp())) {
    indent() << "barrier(CLK_LOCAL_MEM_FENCE);\n";
  }
}

The artifact for vector-add.py is at third_party/spirv/test/add_kernel.cl, shown below:

//===------------------------------------------------------------*- C++ -*-===//
//
// Automatically generated file for OpenCL
//
//===----------------------------------------------------------------------===//

__kernel void add_kernel(
  __global float* var_0,
  __global float* var_1,
  __global float* var_2,
  int var_3
) {	// L29
  int var_4 = get_global_id(0);
  int var_6 = var_4 * 1024;	// L43
  __local float var_8[1024];	// L49
  int var_9 = var_6 + 1024;	// L49
  int var_11 = min(var_9, var_3);	// L49
  int var_12 = max(var_11, var_6);	// L49
  int var_13 = var_12 - var_6;	// L49
  for (int i = 0; i < var_13; i += 1) {
    var_8[i] = var_0[i + var_6];
  }	// L49
  barrier(CLK_LOCAL_MEM_FENCE);
  __local float var_14[1024];	// L50
  for (int i = 0; i < var_13; i += 1) {
    var_14[i] = var_1[i + var_6];
  }	// L50
  barrier(CLK_LOCAL_MEM_FENCE);
  for (int var_15 = 0; var_15 < 1024; var_15 += 1) {	// L51
    float var_16 = var_8[var_15];	// L51
    float var_17 = var_14[var_15];	// L51
    float var_18 = var_16 + var_17;	// L51
    var_8[var_15] = var_18;	// L51
  }
  for (int i = 0; i < var_13; i += 1) {
    var_2[i + var_6] = var_8[i];
  }	// L53
}

This step is a standalone binary rather than a pass, so it also has to be integrated into the build. I modified setup.py to place it into triton/_C, as shown below.

        subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
        translate_path = "third_party/spirv/tool/triton-spirv-translate/triton-spirv-translate"
        shutil.copy(os.path.join(cmake_dir, translate_path), extdir)

5. Launching the kernel

For the kernel launch, I first experimented with 12-ctypes-cl.py and then embedded that Python directly into the project. At launch time a special path is taken for this backend; the code at python/triton/runtime/jit.py:633 is shown below.

            grid_2 = grid[2] if grid_size > 2 else 1
            # launch kernel
            import os
            if os.getenv("TRITON_SPIRV_BACKEND", "0") == "1":
                kernel.run(grid_0, grid_1, grid_2, kernel.name, kernel.kernel, bound_args.values())
                return kernel
            launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())

Also, because the kernel is not precompiled, the usual metadata does not exist yet, so _init_handles at python/triton/compiler/compiler.py:466 guards for this case, as shown below:

    def _init_handles(self):
        if self.module is not None:
            return
        device = driver.active.get_current_device()
        # create launcher
        self.run = driver.active.launcher_cls(self.src, self.metadata)
        import os
        if os.getenv("TRITON_SPIRV_BACKEND", "0") == "1":
            return
        # not enough shared memory to run the kernel
        max_shared = max_shared_mem(device)

The actual launch implementation is at third_party/spirv/backend/cl_utils.py:34.

III. Open Problems

1. clCreateProgramWithIL is not supported

One day during testing I hit this failure, and then found related discussion on the NVIDIA forums and on Zhihu: clCreateProgramWithIL-alternatives.
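
You can check this quickly on any machine; the pyopencl sketch below simply inspects each device's extension string (illustrative, not project code):

import pyopencl as cl

# clCreateProgramWithIL needs cl_khr_il_program (or core OpenCL >= 2.1 with
# a non-empty IL version); NVIDIA's driver usually does not advertise it.
for platform in cl.get_platforms():
    for dev in platform.get_devices():
        has_il = "cl_khr_il_program" in dev.extensions
        print(f"{dev.name}: clCreateProgramWithIL available: {has_il}")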

Hitting this halfway through the demo was rough, honestly a bit demoralizing, and it even shook my confidence in the project's feasibility. Triton on NVIDIA is without question the right project to benchmark against, and switching to a different comparison target would sap motivation, both mine and that of anyone wanting to join this open-source project; after all, the AI-chip landscape today is NVIDIA versus everyone else.

Still, I pushed through and finished the demo, and in doing so unexpectedly delivered the "emit the kernel source" idea that 蓝色 described in Baby Triton.

  • Solution 1: switch to Intel or AMD. The upside: staying on NVIDIA means assembling a lot of inline asm if we do not optimize the IR directly, and NVIDIA's OpenCL support may be poor anyway, so moving sidesteps both. The downside: I would need to get hold of such a card, and the project's value would have to be reassessed.

  • Solution 2: keep going, polish a few more kernels and try features such as TMA, then either circle back to solution 1 or drop the project outright if performance is too poor. For now I lean toward this option.

2. Kernel launch arguments are hard-coded

Kernel arguments are not yet handled generically; only vector-add.py is supported. Type-based dispatch is needed to make this general, roughly as sketched below.
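
A sketch of what that dispatch could look like follows. The helper tensor_to_cl_buffer and all names are hypothetical, not functions that exist in cl_utils.py; the idea is just to map each Python-level argument to the (size, pointer) pair that clSetKernelArg expects.

import ctypes

import torch

def to_cl_arg(arg):
    """Map one kernel argument to (size, pointer) for clSetKernelArg."""
    if isinstance(arg, torch.Tensor):
        buf = tensor_to_cl_buffer(arg)  # hypothetical: returns the cl_mem handle
        return ctypes.sizeof(ctypes.c_void_p), ctypes.byref(buf)
    if isinstance(arg, bool):  # check bool before int: bool subclasses int
        val = ctypes.c_int32(int(arg))
    elif isinstance(arg, int):
        val = ctypes.c_int32(arg) if -2**31 <= arg < 2**31 else ctypes.c_int64(arg)
    elif isinstance(arg, float):
        val = ctypes.c_float(arg)  # Triton treats Python floats as fp32 by default
    else:
        raise TypeError(f"unsupported kernel argument type: {type(arg)}")
    return ctypes.sizeof(val), ctypes.byref(val)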

3. Only one thread per block

Triton on NVIDIA drives this with metadata such as threadsPerWarp = [32] and warpsPerCTA = [4], i.e. the number of threads per warp and the number of warps per CTA (thread block). For vector add with BLOCK_SIZE = 1024, the actual block dim is 4*32 = 128 threads, so each thread processes 1024/128 = 8 elements. The conversion happens in the ConvertTritonGPUToLLVM pass via convertTritonTensorType; the core line, at lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp:46, is unsigned numElementsPerThread = getTotalElemsPerThread(type);

4. PyTorch has no OpenCL backend

I tried handing CUDA-allocated tensors to OpenCL and it does not work, so in the end CPU tensors are used and the result is read back.

  • Solution 1: create every buffer with clCreateBuffer as read-write, and copy back to the host only as needed; this seems best.

  • Solution 2: find a PyTorch that can allocate tensors through OpenCL and wire it in; there still appears to be none. See also my article 自有AI芯片接入AI框架Pytorch的方案 (a scheme for integrating an in-house AI chip into PyTorch).

  • Solution 3: drop PyTorch and make Triton a complete language. That probably means rebuilding PyTorch-style memory pooling plus the device/host copy logic.

5. Improving memcpy

In reality memcpy is quite involved: data need not always be copied wholesale into local memory, and the copies could be reused. Barriers likewise need to appear explicitly in the IR, and they too could be reused.

6. Completing the launcher

On NVIDIA the launcher is glued together with Python format strings and works with PyObject; clearly we should likewise move to C++ host code rather than ctypes. By the way, Triton's runtime is still quite costly: for small kernels executed many times, this whole path is worth rewriting in C++.

7. Benchmarking

The answers match; as for performance, it is probably dreadful. Of course, all I want to compare is the clEnqueueNDRangeKernel time itself, as in the sketch below.
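
Measuring just the kernel is easy with OpenCL event profiling. The pyopencl sketch below continues the hypothetical launch sketch from part II (ctx, prog, the buffers, and n_blocks are assumed from there):

import numpy as np
import pyopencl as cl

# A profiling-enabled queue makes every event carry device-side timestamps.
queue = cl.CommandQueue(
    ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)

evt = prog.add_kernel(queue, (n_blocks,), (1,), x_buf, y_buf, out_buf, np.int32(n))
evt.wait()

# CL_PROFILING_COMMAND_START/END are reported in nanoseconds.
elapsed_ms = (evt.profile.end - evt.profile.start) * 1e-6
print(f"clEnqueueNDRangeKernel time: {elapsed_ms:.3f} ms")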

8. Other unknown issues

Since this project currently only aims to get the demo running, all sorts of issues remain; pull requests are welcome.

Appendix 1: Project documentation

Triton SPIR-V Backend Development: Verifying a Vector-Add Implementation

Triton SPIR-V Backend Development: PyBind Bindings

Triton SPIR-V Backend Development: Adding a Pass

Triton SPIR-V Backend Development: Backend Initialization

Appendix 2: Other technical articles by the author

A Brief Look at Triton's Execution Flow (浅析 Triton 执行流程)

Writing an MLIR Pass from Scratch

A LeetGPU Primer (CUDA Guide Best Practices)
