Triton SPIR-V Backend Development: Backend Initialization

Update (May 18): the project now develops against the upstream main branch. It originally targeted the release/3.3.x branch for stability, but a fork's GitHub page is always compared against main, and since there are not many changes yet, I switched over directly.

Original post: https://www.cnblogs.com/BobHuang/p/18881029 (the reading experience is better there).

The project lives at OpenMLIR/triton-spirv. The goal is to provide a sample of wiring a new backend into Triton and to adapt the tutorials; it will most likely remain a toy. It will ultimately run on Nvidia GPUs, with a performance target of at least 50% of a kernel that uses only CUDA cores.

The commit corresponding to this post adds a new backend and can lower a Triton kernel to ttir (Triton IR).

This change touches 9 files, which fall into four groups: SPIRV registration in setup.py, SPIRV registration in third_party, resolution of the multi-driver problem, and an early return while the SPIRV backend is still incomplete. For building and installing Triton, see the project's Chinese documentation; once it is installed, set the environment variable TRITON_SPIRV_BACKEND=1 and the corresponding ttir (Triton IR) will appear under ~/.triton/cache/XXX.
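
To make this concrete, here is a minimal sketch of exercising the backend. The vector-add kernel is my own stand-in (any kernel from the tutorials behaves the same way); run the script with TRITON_SPIRV_BACKEND=1 and then look for the .ttir file under the cache directory.

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x, y = torch.rand(1024), torch.rand(1024)
out = torch.empty_like(x)
# With TRITON_SPIRV_BACKEND=1 the pipeline stops after emitting ttir and the
# launch returns early, so `out` is left untouched.
add_kernel[(1,)](x, y, out, 1024, BLOCK_SIZE=1024)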

1. SPIRV registration in setup.py

python/setup.py only needs the name of our backend folder under third_party added to the backend list; the build then takes care of integrating the Python code and adjusting the CMake parameters. The code is shown below.

backends = [*BackendInstaller.copy(["nvidia", "amd", "spirv"]), *BackendInstaller.copy_externals()]
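
After rebuilding and installing, a quick way to confirm the registration took effect is to inspect Triton's backend discovery (a hedged check; backends is the dict that triton.backends populates at import time):

from triton.backends import backends

# setup.py copies third_party/spirv/backend into triton/backends/spirv,
# so discovery should now list it next to nvidia and amd.
print(sorted(backends.keys()))  # e.g. ['amd', 'nvidia', 'spirv']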

2. SPIRV registration in third_party

This is split into a Python part and a C++ part, bound together via pybind11.

2.1 Initial Python files

We need third_party/spirv/backend/compiler.py, third_party/spirv/backend/driver.py, and an empty third_party/spirv/backend/__init__.py.

third_party/spirv/backend/compiler.py defines SPIRVOptions, the options that feed into the kernel metadata, as well as SPIRVBackend, which is checked against the target at initialization. It also defines the compilation stages as they stand today, i.e. only the conversion down to ttir (Triton IR).

from triton.backends.compiler import BaseBackend, GPUTarget
import functools
import hashlib
import os
import tempfile
from pathlib import Path

from dataclasses import dataclass
from types import ModuleType
from typing import Any, Dict, Optional, Tuple

from triton._C.libtriton import spirv, ir, llvm, passes

from triton.runtime.build import _build


@dataclass(frozen=True)
class SPIRVOptions:
    backend_name: str = "spirv"
    num_warps: int = 0
    num_stages: int = 0
    num_ctas: int = 0
    num_threads: int = 0
    cluster_dims: tuple = (1, 1, 1)
    extern_libs: dict = None
    debug: bool = False
    launch_cooperative_grid: bool = False
    max_num_imprecise_acc_default: int = 0
    sanitize_overflow: bool = False

    def __post_init__(self):
        pass

    def hash(self):
        hash_dict = dict(self.__dict__)
        key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())])
        return hashlib.sha256(key.encode("utf-8")).hexdigest()


class SPIRVBackend(BaseBackend):

    @staticmethod
    def supports_target(target: GPUTarget):
        return target.backend == "spirv"

    def __init__(self, target: tuple) -> None:
        super().__init__(target)
        self.binary_ext = "so"

    def parse_options(self, opts) -> Any:
        args = {k: opts[k] for k in SPIRVOptions.__dataclass_fields__.keys() if k in opts}
        return SPIRVOptions(**args)

    def pack_metadata(self, metadata):
        return metadata

    def get_codegen_implementation(self, options):
        pass

    def get_module_map(self) -> Dict[str, ModuleType]:
        from triton.language.extra.cuda import libdevice
        return {"triton.language.extra.libdevice": libdevice}

    def load_dialects(self, ctx):
        spirv.load_dialects(ctx)

    @staticmethod
    def make_ttir(mod, metadata, opt):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_combine(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        passes.ttir.add_loop_unroll(pm)
        pm.run(mod)
        return mod


    def add_stages(self, stages, options):
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)


    @functools.lru_cache()
    def hash(self):
        import platform
        return f"{platform.machine()}"
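
As an illustration of how the pieces above are consumed (my own sketch, not code from the commit): the compiler front end instantiates the backend for the active target, parses the options, and collects the stages. With only make_ttir registered, compilation stops at ttir.

from triton.backends.compiler import GPUTarget
from triton.backends.spirv.compiler import SPIRVBackend  # installed location after setup.py copies the backend

backend = SPIRVBackend(GPUTarget("spirv", "test", 0))
opts = backend.parse_options({"num_warps": 4, "debug": True})
stages = {}
backend.add_stages(stages, opts)
print(list(stages.keys()))  # ['ttir'] -- the only stage implemented so far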

third_party/spirv/backend/driver.py defines SPIRVLauncher and SPIRVDriver; get_active_torch_device still returns the cpu device. Everything is either an empty implementation or a direct copy-paste; the actual logic will be filled in later.


import os
import hashlib
import importlib
import importlib.resources
import tempfile
import time

import triton
import triton._C
from triton.runtime.build import _build
from triton.runtime.cache import get_cache_manager
from triton.backends.driver import DriverBase
from triton.backends.compiler import GPUTarget

from pathlib import Path


# ------------------------
# Utils
# ------------------------


class SPIRVUtils(object):

    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(SPIRVUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        pass

    def load_binary(self, name, kernel, shared_mem, device):
        pass

    def get_device_properties(self, *args):
        return {"max_shared_mem": 0}


# ------------------------
# Launcher
# ------------------------

def make_launcher(constants, signature, ids):
    # Record the end of regular arguments;
    # subsequent arguments are architecture-specific descriptors.

    # generate glue code
    src = f""""""
    return src


class SPIRVLauncher(object):

    def __init__(self, src, metadata):
        pass

    def __call__(self, *args, **kwargs):
        self.launch(*args, **kwargs)


class SPIRVDriver(DriverBase):

    def __init__(self):
        self.utils = SPIRVUtils()
        self.launcher_cls = SPIRVLauncher
        super().__init__()

    def get_current_device(self):
        return 0

    def get_active_torch_device(self):
        import torch
        return torch.device("cpu", self.get_current_device())

    def get_current_stream(self, device):
        return 0

    def get_current_target(self):
        # The arch string is a placeholder and the warp size is zero for SPIRV.
        return GPUTarget("spirv", "test", 0)

    def get_device_interface(self):
        import torch
        return torch.cuda

    @staticmethod
    def is_active():
        return True

    def get_benchmarker(self):
        from triton.testing import do_bench
        return do_bench

    def get_empty_cache_for_benchmark(self):
        pass

    def clear_cache(self, cache):
        cache.zero_()

2.2 Initial C++ files

third_party/spirv/triton_spirv.cc defines the module initialization entry point invoked from Python, with load_dialects left as an empty implementation for now.

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/TargetSelect.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

namespace py = pybind11;

void init_triton_spirv(py::module &&m) {
  auto passes = m.def_submodule("passes");
  // load dialects
  m.def("load_dialects", [](mlir::MLIRContext &context) {});
}
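
On the Python side this shows up as the triton._C.libtriton.spirv module that compiler.py already imports; a small hedged sanity check of the binding (ir.context() is the MLIRContext constructor Triton's own compile flow uses):

from triton._C.libtriton import ir, spirv

ctx = ir.context()
spirv.load_dialects(ctx)  # currently a no-op; dialect registration comes later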

We also need third_party/spirv/CMakeLists.txt to compile the file above; it must link against Python and pybind11 as well.

if(TRITON_BUILD_PYTHON_MODULE)
  add_triton_plugin(TritonSPIRV ${CMAKE_CURRENT_SOURCE_DIR}/triton_spirv.cc)
  target_link_libraries(TritonSPIRV PRIVATE Python3::Module pybind11::headers)
endif()

3. Resolving the multi-driver problem

This follows triton-lang/triton-cpu directly: if the environment variable is set, the SPIRV driver is used; otherwise the GPU driver is used. _create_driver in python/triton/runtime/driver.py becomes the following.

def _create_driver() -> DriverBase:
    import os
    if os.getenv("TRITON_SPIRV_BACKEND", "0") == "1":
        if "spirv" not in backends:
            raise RuntimeError("TRITON_SPIRV_BACKEND is set, but SPIRV backend is unavailable.")
        return backends["spirv"].driver()
    active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
    if len(active_drivers) >= 2 and backends["spirv"].driver.is_active():
        print("Both SPIRV and GPU backends are available. Using the GPU backend.")
        active_drivers.remove(backends["spirv"].driver)
    if len(active_drivers) != 1:
        raise RuntimeError(f"{len(active_drivers)} active drivers ({active_drivers}). There should only be one.")
    return active_drivers[0]()
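
A quick way to verify the selection logic (my own sketch): set the variable before anything resolves a driver, then inspect the active one.

import os
os.environ["TRITON_SPIRV_BACKEND"] = "1"  # must be set before the driver is resolved

import triton

# With the variable set, _create_driver short-circuits to the SPIRV driver even
# on a machine where a GPU driver is also active.
print(type(triton.runtime.driver.active).__name__)  # expect: SPIRVDriver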

4. Early return while the backend is incomplete

Because full kernel compilation is not finished yet, some metadata does not exist, so we need CompiledKernel in python/triton/compiler/compiler.py to bail out early, as shown below.

class CompiledKernel:

    def __init__(self, src, metadata_group, hash):
        if os.getenv("TRITON_SPIRV_BACKEND", "0") == "1":
            return
        from collections import namedtuple

Since compilation never runs to completion, python/triton/runtime/jit.py naturally must not launch the kernel either; the code is shown below.

            grid_0 = grid[0]
            grid_1 = grid[1] if grid_size > 1 else 1
            grid_2 = grid[2] if grid_size > 2 else 1
            # launch kernel
            import os
            if os.getenv("TRITON_SPIRV_BACKEND", "0") == "1":
                return
            launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
            kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
                       knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values())

5. Closing remarks

First, some acknowledgements: this project references and copies most of its code from triton-lang/triton-cpu and microsoft/triton-shared, to whom I pay my highest respects. Thanks also to ByteDance-Seed/Triton-distributed, Cambricon/triton-linalg, and intel/intel-xpu-backend-for-triton for their efforts on the triton-lang/triton ecosystem.

Now for some rambling. The seed of this project goes back to May 2023, when I was connecting our in-house GPGPU chip to PyTorch in the hope of training large models; that is when I first learned about Triton. In my post on connecting an in-house AI chip to the PyTorch framework, I recorded my optimism about Triton. In June 2024, with a push from the Beijing Academy of Artificial Intelligence (BAAI), Triton took off in China as well. Itching to join in, I wrote "A brief analysis of the Triton execution flow," in which I mentioned wanting to build a toy SPIR-V backend. This year I threw myself into Triton development properly; in my spare time I also studied operator development and wrote a LeetGPU introductory tutorial (CUDA guide best practices). I gradually came to see Triton's greatness, beyond my first impression of it as a way to broaden an ecosystem and hand in work quickly. Pythonic is the future! More and more people in the world can write Python, and the evolution of AI models may produce ever more operators; letting algorithm people write operators directly can very likely offset the gap in operator performance. I maintain triton-spirv in my spare time partly to fulfill that old wish of building a toy, and partly in the hope of lowering the cost of adopting Triton by providing a concise tutorial. I sincerely hope Triton keeps getting stronger, easy to write with acceptable performance, and that the Python kernel-mode ecosystem flourishes; perhaps it can break down part of the CUDA moat.
