Triton Black Magic: the cubin Runner

Abstract: bypass the Triton pass pipeline and run a cubin binary directly.

Project repository: OpenMLIR/Triton-ML-Runner

Are you frustrated that Triton can only go from Python all the way to cubin before anything runs? Have you given up on tweaking an intermediate IR because it means patching Triton's source code? Have you gotten hold of a high-performance kernel from a newer Triton release, only to struggle with integrating it? Here is a cubin Runner: by dissecting how Triton actually launches kernels, we can hook our own flow into Triton and use it as a simple standalone tool.

I. Usage

Using it is easy: take the .json and .cubin files for your kernel from your cache directory and hand them to utils.py; see the cubin_runner.py diff for the full change.
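If you are not sure where those files live: after one normal JIT run, Triton writes them into its cache directory. A small sketch to locate them, assuming the default location (~/.triton/cache, overridable via TRITON_CACHE_DIR):

    # Locate the cached artifacts produced by a normal JIT run.
    import glob
    import os

    cache_dir = os.environ.get("TRITON_CACHE_DIR", os.path.expanduser("~/.triton/cache"))
    print(glob.glob(os.path.join(cache_dir, "*", "matmul_kernel.json")))
    print(glob.glob(os.path.join(cache_dir, "*", "matmul_kernel.cubin")))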

In effect, the following code replaces the original JIT flow.

    import triton
    from utils import get_cufunction, cubin_launch

    # Load the cached cubin plus its metadata and get a CUfunction handle back.
    kernel_name = "matmul_kernel"
    function = get_cufunction(f"{kernel_name}.json", f"{kernel_name}.cubin", f"{kernel_name}")
    # Runtime arguments, signature and grid, then launch the binary directly.
    bound_args = (a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                  c.stride(0), c.stride(1), 16, 16)
    signature_str = "* * * i32 i32 i32 i32 constexpr i32 constexpr i32 constexpr constexpr constexpr"
    grid = (triton.cdiv(N, 16), triton.cdiv(M, 16), )
    cubin_launch(function, signature_str, bound_args, grid)
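For comparison, this is roughly the call that the snippet above replaces on the normal JIT path (assuming a matmul_kernel written in the usual Triton tutorial style, whose constexpr parameters are BLOCK_SIZE_M and BLOCK_SIZE_N):

    # Normal JIT path: Python source -> Triton IR -> PTX -> cubin -> launch, all behind one call.
    matmul_kernel[grid](a, b, c, M, N, K,
                        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                        c.stride(0), c.stride(1),
                        BLOCK_SIZE_M=16, BLOCK_SIZE_N=16)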

signature_str is the space-separated list of the argument types being passed in, and it corresponds one-to-one with bound_args. A * means the argument is a tensor, i32 is a 32-bit signed integer, and constexpr marks a compile-time constant. The trailing 16, 16 are BLOCK_SIZE_M and BLOCK_SIZE_N; a.stride(1), b.stride(1) and c.stride(1) are all 1, so the JIT specialized them into constexpr. I will write a script later that generates signature_str automatically.
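Until that script exists, here is a minimal sketch (my own assumption, not part of the repository) that derives the non-constexpr entries from the Python types of bound_args; constexpr positions still have to be marked by hand:

    import torch

    def guess_signature(bound_args, constexpr_idx=()):
        # Tensors become pointer arguments ("*"), Python ints become i32;
        # positions listed in constexpr_idx are compile-time constants.
        parts = []
        for i, arg in enumerate(bound_args):
            if i in constexpr_idx:
                parts.append("constexpr")
            elif isinstance(arg, torch.Tensor):
                parts.append("*")
            elif isinstance(arg, int):
                parts.append("i32")
            else:
                raise TypeError(f"unhandled argument type: {type(arg)}")
        return " ".join(parts)

    # For the matmul example above: the stride(1) arguments and the two block sizes are constexpr.
    # guess_signature(bound_args, constexpr_idx={7, 9, 11, 12, 13})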

II. Source Code Analysis

At the moment only Triton v3.3.1 is supported, so the analysis below is based on the code at commit d654e0f (May 30, 2025).

1. The final launch entry point

The final launch entry point is at third_party/nvidia/backend/driver.py:522, where CudaLauncher's launch gets called.

    def __call__(self, gridX, gridY, gridZ, stream, function, *args):
        if self.global_scratch_size > 0:
            grid_size = gridX * gridY * gridZ
            alloc_size = grid_size * self.global_scratch_size
            global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
        else:
            global_scratch = None
        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)

launch is defined during initialization, so let's keep tracing upward.

2. CudaLauncher initialization

CudaLauncher's __init__ at third_party/nvidia/backend/driver.py:510 is shown below.

    def __init__(self, src, metadata):
        constants = src.constants if hasattr(src, "constants") else dict()
        arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
        constants = {arg_idx(idx): value for idx, value in constants.items()}
        signature = {idx: value for idx, value in src.signature.items()}
        src = make_launcher(constants, signature)
        mod = compile_module_from_src(src, "__triton_launcher")
        self.launch = mod.launch
        self.global_scratch_size = metadata.global_scratch_size
        self.global_scratch_align = metadata.global_scratch_align
        self.launch_cooperative_grid = metadata.launch_cooperative_grid

What we need from this is make_launcher's signature argument, which is what generates the host code; constants is no longer actually used.
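For reference, the signature handed to make_launcher is just a dict mapping argument index to type string, so for the matmul example it can be built straight from signature_str (this is exactly what cubin_launch does below):

    signature_str = "* * * i32 i32 i32 i32 constexpr i32 constexpr i32 constexpr constexpr constexpr"
    signature = dict(enumerate(signature_str.split()))
    # {0: '*', 1: '*', 2: '*', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32', 7: 'constexpr', ...}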

3. How the JIT calls into CudaLauncher

We get there from python/triton/runtime/jit.py:591:

            assert grid is not None
            if callable(grid):
                grid = grid(bound_args)
            grid_size = len(grid)
            grid_0 = grid[0]
            grid_1 = grid[1] if grid_size > 1 else 1
            grid_2 = grid[2] if grid_size > 2 else 1
            # launch kernel
            launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
            kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
                       launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook,
                       *bound_args.values())

4. Loading the function

function is actually a pointer; it gets loaded at compile time. The code at python/triton/compiler/compiler.py:408 is shown below.

        self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
            self.name, self.kernel, self.metadata.shared, device)

III. Writing the Code

With the analysis above in hand, all we need to do is prepare the corresponding arguments.

It takes two steps. Step one: load the cubin and get the function back. Step two: prepare the launch arguments and run. The code changes follow.

1. Obtaining the CUfunction

We need the shared-memory size from the metadata, i.e. from the .json file mentioned earlier.

    import json
    import triton

    def get_cufunction(json_path, cubin_path, kernel_name):
        global metadata
        # The cached .json holds the kernel metadata; 'shared' is the shared-memory size.
        metadata = json.loads(open(json_path, "r").read())
        kernel = open(cubin_path, "rb").read()
        # Same call the compiler makes at compiler.py:408; 'device' is a module-level global.
        module, function, n_regs, n_spills = triton.runtime.driver.active.utils.load_binary(
            kernel_name, kernel, metadata['shared'], device)
        return function
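For orientation, these are the only fields of that .json the runner actually touches (the values below are illustrative; the real file contains many more keys):

    # Illustrative excerpt of matmul_kernel.json -- only the keys used here, with made-up values.
    example_metadata = {
        "shared": 4096,                                              # passed to load_binary above
        "cluster_dims": [1, 1, 1],                                   # turned into a tuple below
        "target": {"backend": "cuda", "arch": 80, "warp_size": 32},  # restored to a GPUTarget below
        "launch_cooperative_grid": False,                            # forwarded to mod.launch below
    }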

2. cubin_launch

Here we just reconstruct each of launch's arguments one by one; some of them are not fully handled yet.


    # Module paths below are assumed for an installed Triton 3.3.1; adjust if your layout differs.
    from collections import namedtuple

    from triton.backends.compiler import GPUTarget
    from triton.backends.nvidia.driver import compile_module_from_src, make_launcher
    from triton.compiler.compiler import make_backend


    def get_grid_xyz(grid):
        # Same normalization the JIT does at jit.py:591: pad missing grid dims with 1.
        assert grid is not None
        grid_size = len(grid)
        grid_0 = grid[0]
        grid_1 = grid[1] if grid_size > 1 else 1
        grid_2 = grid[2] if grid_size > 2 else 1
        return grid_0, grid_1, grid_2

    def get_packed_metadata():
        metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
        # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
        target = metadata['target']
        metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
        # Rebuild the namedtuple the compiler normally hands to the backend, then pack it.
        KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
        compile_metadata = KernelMetadata(**metadata)
        backend = make_backend(compile_metadata.target)
        return backend.pack_metadata(compile_metadata)

    def cubin_launch(function, signature_str, bound_args, grid):
        # Build the {index: type} signature and generate the C launcher from it,
        # just like CudaLauncher.__init__ does (constants are not needed).
        signature = dict(enumerate(signature_str.split()))
        src = make_launcher(None, signature)
        mod = compile_module_from_src(src, "__triton_launcher")
        global_scratch = None
        packed_metadata = get_packed_metadata()
        launch_metadata, launch_enter_hook, launch_exit_hook = None, None, None
        # 'stream' is a module-level global (see the note below).
        mod.launch(*get_grid_xyz(grid), stream, function, metadata['launch_cooperative_grid'],
                   global_scratch, packed_metadata, launch_metadata, launch_enter_hook, launch_exit_hook,
                   *bound_args)
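utils.py also relies on two module-level globals, device and stream, that the snippets above do not define. A minimal sketch of how they can be obtained (my assumption; here via PyTorch, but any equivalent CUDA device index and stream handle will do):

    import torch

    # Device index for load_binary and the raw CUDA stream handle for mod.launch.
    device = torch.cuda.current_device()
    stream = torch.cuda.current_stream().cuda_stream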

IV. Related Articles

Deep Dive into the Triton Compiler's MatMul Optimization (Part 3): TMA

Deep Dive into the Triton Compiler's MatMul Optimization (Part 2): MMA

Deep Dive into the Triton Compiler's MatMul Optimization (Part 1): FMA

A Brief Look at the Triton Execution Flow
