Triton Black Magic: a cubin Runner
Abstract: run a cubin binary directly, bypassing the Triton pass pipeline.
Project repository: OpenMLIR/Triton-ML-Runner
Are you frustrated that Triton can only execute by going all the way from Python to cubin? Discouraged that tweaking an intermediate IR means patching Triton's source? In pain because you extracted a high-performance kernel from a newer Triton version but cannot integrate it? This post brings you a cubin Runner: by dissecting how Triton actually launches kernels, we can splice our own flow into Triton and use it as a simple standalone tool.
I. Usage
Using it is really easy: grab the .json and .cubin files for your kernel from the .cache directory and hand them to utils.py, as shown in the cubin_runner.py diff.
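If you are not sure where those cached files live, here is a quick way to find them; this assumes Triton's default cache location, which the TRITON_CACHE_DIR environment variable can override:

import glob, os

# Triton caches compiled artifacts under ~/.triton/cache/<hash>/<kernel_name>.*
cache_dir = os.environ.get("TRITON_CACHE_DIR", os.path.expanduser("~/.triton/cache"))
print(glob.glob(f"{cache_dir}/*/matmul_kernel.json"))
print(glob.glob(f"{cache_dir}/*/matmul_kernel.cubin"))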

In essence, the following code replaces the original JIT flow.

import triton
from utils import get_cufunction, cubin_launch

# a, b, c are CUDA tensors prepared earlier; M, N, K are the matmul dimensions.
kernel_name = "matmul_kernel"
function = get_cufunction(f"{kernel_name}.json", f"{kernel_name}.cubin", f"{kernel_name}")
bound_args = (a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
              c.stride(0), c.stride(1), 16, 16)
signature_str = "* * * i32 i32 i32 i32 constexpr i32 constexpr i32 constexpr constexpr constexpr"
grid = (triton.cdiv(N, 16), triton.cdiv(M, 16), )
cubin_launch(function, signature_str, bound_args, grid)
signature_str describes the argument list, with entries separated by spaces and in one-to-one correspondence with bound_args. * means the argument is a tensor (pointer), i32 a 32-bit signed integer, and constexpr a compile-time constant. The two trailing 16s are BLOCK_SIZE_M and BLOCK_SIZE_N. Because a.stride(1), b.stride(1), and c.stride(1) are all 1, the JIT specialized them into constexpr. I will write a script later to generate signature_str automatically.
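Until that script exists, here is a minimal sketch of what such a generator could look like; make_signature_str and its constexpr_indices parameter are hypothetical names, not part of the project:

import torch

def make_signature_str(args, constexpr_indices):
    # Map each bound argument to a signature token: tensors become
    # pointers ("*"), Python ints become "i32", and any index listed in
    # constexpr_indices is forced to "constexpr".
    tokens = []
    for i, arg in enumerate(args):
        if i in constexpr_indices:
            tokens.append("constexpr")
        elif isinstance(arg, torch.Tensor):
            tokens.append("*")
        elif isinstance(arg, int):
            tokens.append("i32")
        else:
            raise TypeError(f"unsupported argument type: {type(arg)}")
    return " ".join(tokens)

# For the matmul example: the unit strides (indices 7, 9, 11) and the two
# block sizes (indices 12, 13) are constexpr.
# make_signature_str(bound_args, {7, 9, 11, 12, 13})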
II. Source code analysis
Only Triton v3.3.1 is supported for now, so the analysis below is based on commit d654e0f (May 30, 2025).
1. The final execution entry point
The final execution entry point is at third_party/nvidia/backend/driver.py:522; it calls CudaLauncher's launch.
def __call__(self, gridX, gridY, gridZ, stream, function, *args):
    if self.global_scratch_size > 0:
        grid_size = gridX * gridY * gridZ
        alloc_size = grid_size * self.global_scratch_size
        global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
    else:
        global_scratch = None
    self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
launch is defined in the initializer, so we keep tracing upward.
2. CudaLauncher initialization
third_party/nvidia/backend/driver.py:510. CudaLauncher's __init__ is shown below.
def __init__(self, src, metadata):
    constants = src.constants if hasattr(src, "constants") else dict()
    arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
    constants = {arg_idx(idx): value for idx, value in constants.items()}
    signature = {idx: value for idx, value in src.signature.items()}
    src = make_launcher(constants, signature)
    mod = compile_module_from_src(src, "__triton_launcher")
    self.launch = mod.launch
    self.global_scratch_size = metadata.global_scratch_size
    self.global_scratch_align = metadata.global_scratch_align
    self.launch_cooperative_grid = metadata.launch_cooperative_grid
What we need is make_launcher's signature argument, which is used to generate the host code; constants is no longer used.
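Concretely, that signature is just a dict mapping argument index to type string; for the matmul example it can be built exactly the way cubin_launch does below:

signature_str = "* * * i32 i32 i32 i32 constexpr i32 constexpr i32 constexpr constexpr constexpr"
signature = dict(enumerate(signature_str.split()))
# {0: '*', 1: '*', 2: '*', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32',
#  7: 'constexpr', 8: 'i32', 9: 'constexpr', 10: 'i32', 11: 'constexpr',
#  12: 'constexpr', 13: 'constexpr'}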
3. How the JIT calls into CudaLauncher
We enter from python/triton/runtime/jit.py:591:
assert grid is not None
if callable(grid):
    grid = grid(bound_args)
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
# launch kernel
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
           launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook,
           *bound_args.values())
4. Loading the function
function is actually a pointer, loaded at compile time. See python/triton/compiler/compiler.py:408:
self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
    self.name, self.kernel, self.metadata.shared, device)
III. Writing the code
With the analysis above in hand, all we have to do is prepare the corresponding arguments.
Two steps are needed: first, load the cubin and obtain the function; second, prepare the launch arguments and execute. The code changes follow.
1. Obtaining the CUfunction
We need the shared-memory size from the metadata, which lives in the .json file mentioned earlier.
import json

import triton

def get_cufunction(json_path, cubin_path, kernel_name):
    global metadata  # shared with get_packed_metadata/cubin_launch below
    metadata = json.loads(open(json_path, "r").read())
    kernel = open(cubin_path, "rb").read()
    device = triton.runtime.driver.active.get_current_device()
    module, function, n_regs, n_spills = triton.runtime.driver.active.utils.load_binary(
        kernel_name, kernel, metadata['shared'], device)
    return function
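For orientation, the metadata keys this tool references by name look roughly like the Python dict below after json.loads(); the values are illustrative only (a real cached file contains many more fields):

# Illustrative excerpt of matmul_kernel.json; values are examples, not
# guaranteed for your kernel.
metadata = {
    "shared": 2048,                          # shared-memory bytes, passed to load_binary
    "target": {"backend": "cuda", "arch": 80, "warp_size": 32},  # rebuilt into a GPUTarget
    "cluster_dims": [1, 1, 1],               # converted back to a tuple
    "launch_cooperative_grid": False,        # forwarded to mod.launch
}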
2. cubin_launch
Here we simply reconstruct each of launch's arguments one by one; some of them are still incomplete.
from collections import namedtuple

import triton
from triton.backends.compiler import GPUTarget
from triton.backends.nvidia.driver import make_launcher, compile_module_from_src
from triton.compiler.compiler import make_backend

def get_grid_xyz(grid):
    # Normalize the grid to (x, y, z), defaulting missing dimensions to 1,
    # mirroring what jit.py does before kernel.run.
    assert grid is not None
    grid_size = len(grid)
    grid_0 = grid[0]
    grid_1 = grid[1] if grid_size > 1 else 1
    grid_2 = grid[2] if grid_size > 2 else 1
    return grid_0, grid_1, grid_2

def get_packed_metadata():
    metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
    # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
    target = metadata['target']
    metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
    KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
    compile_metadata = KernelMetadata(**metadata)
    backend = make_backend(compile_metadata.target)
    return backend.pack_metadata(compile_metadata)

def cubin_launch(function, signature_str, bound_args, grid):
    signature = dict(enumerate(signature_str.split()))
    src = make_launcher(None, signature)
    mod = compile_module_from_src(src, "__triton_launcher")
    global_scratch = None
    packed_metadata = get_packed_metadata()
    launch_metadata, launch_enter_hook, launch_exit_hook = None, None, None
    # Launch on the current CUDA stream, as the JIT path would.
    device = triton.runtime.driver.active.get_current_device()
    stream = triton.runtime.driver.active.get_current_stream(device)
    mod.launch(*get_grid_xyz(grid), stream, function, metadata['launch_cooperative_grid'],
               global_scratch, packed_metadata, launch_metadata, launch_enter_hook, launch_exit_hook,
               *bound_args)
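To sanity-check the whole flow, one option is to launch the cubin on random inputs and compare against torch.matmul. The sketch below assumes the cached kernel is the fp16 matmul from the usage example above, with BLOCK_SIZE 16 and the cached files in the working directory:

import torch
import triton
from utils import get_cufunction, cubin_launch

M, N, K = 512, 512, 512
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)
c = torch.empty((M, N), device="cuda", dtype=torch.float16)

function = get_cufunction("matmul_kernel.json", "matmul_kernel.cubin", "matmul_kernel")
bound_args = (a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
              c.stride(0), c.stride(1), 16, 16)
signature_str = "* * * i32 i32 i32 i32 constexpr i32 constexpr i32 constexpr constexpr constexpr"
cubin_launch(function, signature_str, bound_args, (triton.cdiv(N, 16), triton.cdiv(M, 16)))

torch.cuda.synchronize()
print(torch.allclose(c, a @ b, atol=1e-2, rtol=1e-2))  # expect True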
IV. Related articles
Deep Dive into the Triton Compiler's MatMul Optimization (Part 3): TMA
Deep Dive into the Triton Compiler's MatMul Optimization (Part 2): MMA
This post is from 博客园 (cnblogs), by 暴力都不会的蒟蒻. When reposting, please credit the original link: https://www.cnblogs.com/BobHuang/p/18972092
