Triton SPIR-V 后端开发:PyBind绑定

本博客原文地址:https://www.cnblogs.com/BobHuang/p/18916737,原文体验更佳

项目地址:OpenMLIR/triton-spirv

上一篇Triton SPIR-V 后端开发:新增Pass 我们新增了TritonToLinalg的Pass,TritonPythonC++融合的项目,我们还需要将C++的Pass通过PyBind绑定到Python中,本次实验对应commit为457f0aa,包含了PyBind接口暴露和编译stage的添加。除此之外,本篇文章顺便写了下spirv-opt的构建,我们可以仅编译C++部分来方便调试。

一、PyBind 接口暴露

third_party/spirv/triton_spirv.ccpassespy::modulesubmodule,通过auto passes = m.def_submodule("passes"); 来定义,我们可以看到third_party/nvidia/backend/compiler.py里有nvidia.passes,我们的就是spirv.passes,我们还可以根据不同的stage再继续划分submodule。方法是通过m.def的lambda 表达式来定义的,暴露"triton_to_linalg"方法给Python用,实际上调用的mlir::triton::spirv::createTritonToLinalgPass());,也就是将Pass将入到PassManager的操作。具体代码如下所示

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/TargetSelect.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

#include "spirv/include/Conversion/TritonToLinalg/Passes.h"

namespace py = pybind11;

void init_triton_spirv_passes_lair(py::module &&m) {
  m.def("triton_to_linalg", [](mlir::PassManager &pm) {
    pm.addPass(
        mlir::triton::spirv::createTritonToLinalgPass());
  });

}

void init_triton_spirv(py::module &&m) {
  auto passes = m.def_submodule("passes");
  init_triton_spirv_passes_lair(passes.def_submodule("lair"));
  // load dialects
  m.def("load_dialects", [](mlir::MLIRContext &context) {});
}

二、编译stage的添加

third_party/spirv/backend/compiler.py。 在PyBind的C++接口暴露完Python这边就可以使用了,也就是spirv.passes.lair.triton_to_linalg(pm)。Triton是通过add_stages执行Pass pipeline的,具体可以参考浅析 Triton 执行流程。我们也需要新增一个stage,可以参照make_ttir,Pass执行在pm.run(mod),具体代码如下所示。

    @staticmethod
    def make_lair(mod, metadata, opt):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()Add commentMore actions
        spirv.passes.lair.triton_to_linalg(pm)
        pm.run(mod)
        return mod


    def add_stages(self, stages, options):
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
        stages["lair"] = lambda src, metadata: self.make_lair(src, metadata, options)

三、spirv-opt工具

在开发的过程中,如果遵循Triton的逻辑从Python开始,会对我们的开发进度造成一定困扰。毕竟pip install安装包是非常耗时的,而且对于backend的开发大多数时间都在在写Pass,我们我们搞一个类似spirv-optmlir-opt的执行Pass工具还是会非常方便的。

1、利用MlirOptMain的工具实现

MLIR本身提供了MlirOptMain工具函数,我们仅需要将我们的Pass注册进来即可,即mlir::triton::spirv::registerTritonToLinalgPass();,然后利用DialectRegistry导入你需要的dialect即可。代码如下所示,如果想要使用mlir里更多的Pass或者要注册--one-shot-bufferize可以参考 最新的[spirv/tool/spirv-opt/spirv-opt.cpp]
(https://github.com/OpenMLIR/triton-spirv/blob/develop/third_party/spirv/tool/spirv-opt/spirv-opt.cpp)

#include "spirv/include/Conversion/TritonToLinalg/Passes.h"
Add commentMore actions
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

int main(int argc, char **argv) {
  mlir::triton::spirv::registerTritonToLinalgPass();

  mlir::DialectRegistry registry;
  registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
                  mlir::math::MathDialect, mlir::arith::ArithDialect,
                  mlir::scf::SCFDialect, mlir::linalg::LinalgDialect,
                  mlir::func::FuncDialect, mlir::tensor::TensorDialect,
                  mlir::memref::MemRefDialect>();

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "spirv-opt test driver\n", registry));
}

2、CMake编译命令

third_party/spirv/tool/CMakeLists.txt。CMake里编译这个文件并链接上相关的lib就可以了,具体代码如下所示。

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)Add commentMore actions

add_llvm_executable(spirv-opt spirv-opt.cpp PARTIAL_SOURCES_INTENDED)

llvm_update_compile_flags(spirv-opt)
target_link_libraries(spirv-opt PRIVATE
  TritonTransforms
  TritonToLinalg
  ${dialect_libs}
  ${conversion_libs}
  # MLIR core
  MLIROptLib
  MLIRPass
  MLIRTransforms
)

mlir_check_all_link_libraries(spirv-opt)

3、根目录CMake 控制

third_party/spirv/CMakeLists.txt。这个工具默认情况下不想编译,用户需要使用BUILD_SPIRV_OPT来控制,具体代码如下所示。

option(BUILD_SPIRV_OPT "build spirv-opt to debug" OFF)
if(BUILD_SPIRV_OPT)Add commentMore actions
  add_subdirectory(tool)
endif()

4、编译命令

仅编译C++的命令我们需要在setup.py:513subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)中去拿cmake_args,另外加入-DBUILD_SPIRV_OPT=ON即可,以下是当前我使用的编译命令,在项目的中文README也会对此及时更新。

mkdir build-opt; cd build-opt
cmake -G Ninja .. -DLLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include  -DLLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib -DTRITON_CODEGEN_BACKENDS="nvidia;amd;spirv" -DCMAKE_BUILD_TYPE=Debug -DBUILD_SPIRV_OPT=ON
ninja

5、运行示例

我们就可以直接使用spirv-opt来执行我们的Pass了,如下所示。若嫌弃路径深可以自己把build-opt/third_party/spirv/tool/加入到Path里。

build-opt/third_party/spirv/tool/spirv-opt third_party/spirv/test/add_kernel.ttir  --triton-to-linalg -o linalg.mlir
posted @ 2025-06-07 15:37  暴力都不会的蒟蒻  阅读(25)  评论(0)    收藏  举报