Triton SPIR-V Backend Development: Backend Initialization
Update (May 18): the project now develops on top of the upstream main branch. It was originally developed on the release/3.3.x branch for stability, but the fork's project page always compares against main, and since there are not many changes yet, I switched over directly.
Original post: https://www.cnblogs.com/BobHuang/p/18881029 (the reading experience is better there).
The project lives at OpenMLIR/triton-spirv. The goal is to provide a sample of plugging a new backend into Triton and to get the tutorials working; it will most likely remain a toy. It will ultimately run on NVIDIA GPUs, aiming for at least 50% of the performance of kernels that use only CUDA cores.
The commit corresponding to this article adds a new backend and completes the conversion from a Triton kernel to ttir (Triton IR).
The change touches 9 files, which fall into four groups: the SPIRV registration in setup.py, the SPIRV registration in third_party, the fix for the multiple-driver problem, and early returns while the SPIRV backend is still incomplete. For building and installing Triton, see the project's Chinese documentation; once installed, set TRITON_SPIRV_BACKEND=1 as an environment variable and the corresponding ttir (Triton IR) appears under ~/.triton/cache/XXX.
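For instance, a run along the following lines is enough to produce a ttir artifact. This is a sketch based on the official vector-add tutorial; the kernel itself is plain Triton, and only the environment variable is specific to this project:

# Run with: TRITON_SPIRV_BACKEND=1 python vector_add.py
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

n = 4096
x, y = torch.rand(n), torch.rand(n)  # cpu tensors; the driver reports a cpu torch device for now
out = torch.empty(n)
add_kernel[(triton.cdiv(n, 1024),)](x, y, out, n, BLOCK_SIZE=1024)
# Compilation stops at ttir and the launch is skipped, so `out` is not written;
# the generated IR lands under ~/.triton/cache/.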
1. SPIRV registration in setup.py
python/setup.py only needs our backend's folder name under third_party added; this takes care of integrating the Python code and adjusting the CMake parameters, as shown below.
backends = [*BackendInstaller.copy(["nvidia", "amd", "spirv"]), *BackendInstaller.copy_externals()]
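After reinstalling, the registration can be sanity-checked from Python. triton.backends.backends is the table of discovered backends, each entry pairing a compiler and a driver module; the check below is only an illustration:

from triton.backends import backends
assert "spirv" in backends  # discovered from triton/backends/spirv after the copy step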
2. SPIRV registration in third_party
This splits into a Python part and a C++ part, bound together via pybind11.
2.1 Initial Python files
We need third_party/spirv/backend/compiler.py, third_party/spirv/backend/driver.py, and an empty third_party/spirv/backend/__init__.py.
third_party/spirv/backend/compiler.py defines SPIRVOptions, i.e. the metadata-related options, and SPIRVBackend, which is validated at initialization; it also defines the current compilation stages, which for now stop at the conversion to ttir (Triton IR).
from triton.backends.compiler import BaseBackend, GPUTarget
import functools
import hashlib
import os
import tempfile
from pathlib import Path
from dataclasses import dataclass
from types import ModuleType
from typing import Any, Dict, Optional, Tuple

from triton._C.libtriton import spirv, ir, llvm, passes
from triton.runtime.build import _build


@dataclass(frozen=True)
class SPIRVOptions:
    backend_name: str = "spirv"
    num_warps: int = 0
    num_stages: int = 0
    num_ctas: int = 0
    num_threads: int = 0
    cluster_dims: tuple = (1, 1, 1)
    extern_libs: dict = None
    debug: bool = False
    launch_cooperative_grid: bool = False
    max_num_imprecise_acc_default: int = 0
    sanitize_overflow: bool = False

    def __post_init__(self):
        pass

    def hash(self):
        hash_dict = dict(self.__dict__)
        key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())])
        return hashlib.sha256(key.encode("utf-8")).hexdigest()


class SPIRVBackend(BaseBackend):

    @staticmethod
    def supports_target(target: GPUTarget):
        return target.backend == "spirv"

    def __init__(self, target: tuple) -> None:
        super().__init__(target)
        self.binary_ext = "so"

    def parse_options(self, opts) -> Any:
        args = {k: opts[k] for k in SPIRVOptions.__dataclass_fields__.keys() if k in opts}
        return SPIRVOptions(**args)

    def pack_metadata(self, metadata):
        return metadata

    def get_codegen_implementation(self, options):
        pass

    def get_module_map(self) -> Dict[str, ModuleType]:
        from triton.language.extra.cuda import libdevice
        return {"triton.language.extra.libdevice": libdevice}

    def load_dialects(self, ctx):
        spirv.load_dialects(ctx)

    @staticmethod
    def make_ttir(mod, metadata, opt):
        # The standard Triton IR cleanup/canonicalization pipeline.
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_combine(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        passes.ttir.add_loop_unroll(pm)
        pm.run(mod)
        return mod

    def add_stages(self, stages, options):
        # Only the ttir stage for now; later stages will be added as the backend grows.
        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)

    @functools.lru_cache()
    def hash(self):
        import platform
        return f"{platform.machine()}"
third_party/spirv/backend/driver.py defines SPIRVLauncher and SPIRVDriver; get_active_torch_device still returns a cpu device. Everything is either an empty implementation or copied verbatim; the real logic will be filled in later.
import os
import hashlib
import importlib
import importlib.resources
import tempfile
import time
import triton
import triton._C
from triton.runtime.build import _build
from triton.runtime.cache import get_cache_manager
from triton.backends.driver import DriverBase
from triton.backends.compiler import GPUTarget
from pathlib import Path


# ------------------------
# Utils
# ------------------------
class SPIRVUtils(object):

    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(SPIRVUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        pass

    def load_binary(self, name, kernel, shared_mem, device):
        pass

    def get_device_properties(self, *args):
        return {"max_shared_mem": 0}


# ------------------------
# Launcher
# ------------------------
def make_launcher(constants, signature, ids):
    # Record the end of regular arguments;
    # subsequent arguments are architecture-specific descriptors.
    # generate glue code
    src = f""""""
    return src


class SPIRVLauncher(object):

    def __init__(self, src, metadata):
        pass

    def __call__(self, *args, **kwargs):
        self.launch(*args, **kwargs)


class SPIRVDriver(DriverBase):

    def __init__(self):
        self.utils = SPIRVUtils()
        self.launcher_cls = SPIRVLauncher
        super().__init__()

    def get_current_device(self):
        return 0

    def get_active_torch_device(self):
        import torch
        return torch.device("cpu", self.get_current_device())

    def get_current_stream(self, device):
        return 0

    def get_current_target(self):
        # Capability and warp size are zeros for SPIRV.
        return GPUTarget("spirv", "test", 0)

    def get_device_interface(self):
        import torch
        return torch.cuda

    @staticmethod
    def is_active():
        return True

    def get_benchmarker(self):
        from triton.testing import do_bench
        return do_bench

    def get_empty_cache_for_benchmark(self):
        pass

    def clear_cache(self, cache):
        cache.zero_()
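One detail worth keeping: the __new__ override makes SPIRVUtils a process-wide singleton (the same trick the CUDA backend uses), so repeated construction returns the same cached object:

a = SPIRVUtils()
b = SPIRVUtils()
assert a is b  # both names point at the same cls.instance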
2.2 Initial C++ files
third_party/spirv/triton_spirv.cc defines the module initialization exposed to Python, with load_dialects left as an empty implementation for now.
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/TargetSelect.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
namespace py = pybind11;
void init_triton_spirv(py::module &&m) {
auto passes = m.def_submodule("passes");
// load dialects
m.def("load_dialects", [](mlir::MLIRContext &context) {});
}
We also need third_party/spirv/CMakeLists.txt to compile the file above and link it against Python and pybind11:
if(TRITON_BUILD_PYTHON_MODULE)
  add_triton_plugin(TritonSPIRV ${CMAKE_CURRENT_SOURCE_DIR}/triton_spirv.cc)
  target_link_libraries(TritonSPIRV PRIVATE Python3::Module pybind11::headers)
endif()
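Once this builds, the binding is reachable from Python; the submodule name follows the backend name registered in setup.py, which is why compiler.py can import it. A quick sanity check:

from triton._C.libtriton import spirv
spirv.load_dialects  # the (empty) lambda bound in triton_spirv.cc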
3. Fixing the multiple-driver problem
This follows triton-lang/triton-cpu directly: if the environment variable is set, use the SPIRV driver; otherwise fall back to the GPU driver. _create_driver in python/triton/runtime/driver.py becomes the following:
def _create_driver() -> DriverBase:
    import os
    if os.getenv("TRITON_SPIRV_BACKEND", "0") == "1":
        if "spirv" not in backends:
            raise RuntimeError("TRITON_SPIRV_BACKEND is set, but SPIRV backend is unavailable.")
        return backends["spirv"].driver()
    active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
    if len(active_drivers) >= 2 and backends["spirv"].driver.is_active():
        print("Both SPIRV and GPU backends are available. Using the GPU backend.")
        active_drivers.remove(backends["spirv"].driver)
    if len(active_drivers) != 1:
        raise RuntimeError(f"{len(active_drivers)} active drivers ({active_drivers}). There should only be one.")
    return active_drivers[0]()
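With the environment variable set, the selection can be observed directly. A hypothetical session (triton.runtime.driver is a lazy proxy whose attribute accesses resolve to the chosen driver):

import os
os.environ["TRITON_SPIRV_BACKEND"] = "1"  # must be set before the first kernel runs
import triton
print(triton.runtime.driver.active.get_current_target().backend)  # expect "spirv"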
4. Early returns while the project is incomplete
Since the full kernel compilation is not finished yet, some metadata does not exist, so CompiledKernel in python/triton/compiler/compiler.py needs to bail out early, as shown below.
class CompiledKernel:

    def __init__(self, src, metadata_group, hash):
        if os.getenv("TRITON_SPIRV_BACKEND", "0") == "1":
            return
        from collections import namedtuple
And since compilation never runs to completion, python/triton/runtime/jit.py naturally must not launch the kernel either, as shown below.
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
# launch kernel
import os
if os.getenv("TRITON_SPIRV_BACKEND", "0") == "1":
    return
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
           knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values())
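After such a run, the cached artifacts can be listed to confirm the ttir was produced (one hashed directory per kernel; the exact layout assumed by the glob below may vary across Triton versions):

import glob
from pathlib import Path
for f in glob.glob(str(Path.home() / ".triton" / "cache" / "*" / "*.ttir")):
    print(f)  # e.g. ~/.triton/cache/<hash>/add_kernel.ttir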
5. Closing remarks
First, acknowledgments. This project references and copies large parts of triton-lang/triton-cpu and microsoft/triton-shared, to whose authors I pay my deepest respects. Thanks also to ByteDance-Seed/Triton-distributed, Cambricon/triton-linalg, and intel/intel-xpu-backend-for-triton for their efforts in the triton-lang/triton ecosystem.
Now some rambling. This project traces back to May 2023, when I was connecting our in-house GPGPU chip to PyTorch in the hope of training large models, which is when I first learned about Triton. In my article on connecting an in-house AI chip to the PyTorch framework, I recorded my optimism about Triton. In June 2024, with the push from BAAI (Beijing Academy of Artificial Intelligence), Triton caught fire in China as well. Itching to take part, I wrote a piece analyzing Triton's execution flow, and mentioned in it that I wanted to build a SPIR-V backend toy. This year I formally moved into Triton development; in my spare hours I also studied kernel development and wrote a LeetGPU introductory tutorial (CUDA guide best practices). I have slowly come to appreciate how great Triton is, well beyond my first impressions of it as an ecosystem play or a way to deliver quickly. Pythonic is the future! More and more people in the world can write Python, the evolution of AI models will likely produce ever more operators, and the gains from letting algorithm people write kernels directly may well cancel out the gap in kernel performance. I maintain triton-spirv in my spare time partly to fulfill that old wish of building a toy, and partly to lower the cost of plugging into Triton by offering a concise tutorial. I sincerely hope Triton keeps getting stronger, easy to write with acceptable performance, and that the Python kernel-authoring ecosystem keeps growing; perhaps it can erode part of CUDA's moat.
This post is from 博客园 (cnblogs). Author: 暴力都不会的蒟蒻. Please cite the original link when reposting: https://www.cnblogs.com/BobHuang/p/18881029
