深度剖析 Triton编译器 MatMul优化(一)—— FMA

本文分析了native(不做分块)的Triton Matmul矩阵乘在 NVIDIA B200 上的编译流程:Python -> TTIR -> TTGIR -> LLVM IR -> PTX。最近会出一个系列,分析Triton对矩阵乘的优化以及对Blackwell新特性的支持情况。首先看性能:用上autotune后,相比CUDA native kernel的加速比为1.40×。
CUDA native kernel vs Triton native kernel

我关于Triton分析的上一篇文章《浅析 Triton 执行流程》介绍了Triton,分析了vectorAdd算子的编译Pass Pipeline,简单介绍了jit流程,得到了不少朋友的关注和鼓励,在此感谢大家。本文针对的是各个Pass导致的IR变化,构建项目和完整流程请参考上一篇文章,相关内容本文不再赘述。

本文所用Triton为commit bc75dd0(Jun 27, 2025)的版本。所有IR和kernel文件均已上传至Github:sBobHuang/Triton-blog-file。
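
如果想自己得到这些per-pass的IR,据我所知可以借助Triton的调试环境变量把每个Pass前的IR dump出来(下面只是一个示意,变量名和具体行为以所用版本的README为准,dump到stderr的内容需要自己按Pass切分保存):

import os

# 假设:MLIR_ENABLE_DUMP=1 会在每个MLIR Pass运行前把IR打印到stderr
os.environ["MLIR_ENABLE_DUMP"] = "1"
# 假设:TRITON_ALWAYS_COMPILE=1 跳过编译缓存,保证每次运行都重新走一遍Pass Pipeline
os.environ["TRITON_ALWAYS_COMPILE"] = "1"

import triton                 # noqa: E402  必须在设置环境变量之后再import并触发编译
import triton.language as tl  # noqa: E402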

一、matmul Triton kernel

1、kernel书写

Triton kernel如下所示,矩阵a大小为M*N,矩阵b大小为N*K,结果矩阵c为M*K

# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl

@triton.jit
def matrix_multiplication_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_an,
    stride_bn, stride_bk,
    stride_cm, stride_ck,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
    pid_k = tl.program_id(axis=0)
    pid_m = tl.program_id(axis=1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

    # 初始化A和B的指针
    a_ptrs = a_ptr + offs_m[:, None] * stride_am
    b_ptrs = b_ptr + offs_k[None, :] * stride_bk

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

    # 沿N维度依次累加
    for n in range(N):
        # 加载A和B的当前值
        a = tl.load(a_ptrs + n * stride_an)
        b = tl.load(b_ptrs + n * stride_bn)
        accumulator += a * b

    # 将结果写回C
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K)
    tl.store(c_ptrs, accumulator, mask=c_mask)

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
    stride_am, stride_an = N, 1
    stride_bn, stride_bk = K, 1
    stride_cm, stride_ck = K, 1

    grid = lambda META: (triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(M, META['BLOCK_SIZE_M']), )
    matrix_multiplication_kernel[grid](
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_an,
        stride_bn, stride_bk,
        stride_cm, stride_ck,
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_K=64,
    )

2、kernel简析

针对结果c[i][j],就是a的第i行乘上b的第j列,即c[i][j] = sum(a[i][n] * b[n][j] for n in 0..N-1)

以上公式中,a的起始位置为a[i][0],a[i][0]的指针为a_ptr + i * stride_am。这个i怎么来的呢?我们把program在第1维上分成了triton.cdiv(M, META['BLOCK_SIZE_M'])块,pid_m为拿到的块编号,块编号乘上块大小就是起始位置。另外每个program要处理BLOCK_SIZE_M个元素,即i实际上是一个集合:pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M),然后使用[:, None]将其broadcast为[BLOCK_SIZE_M, 1]。b同理,得到了[1, BLOCK_SIZE_K]的块。

每个program要处理 [BLOCK_SIZE_M, BLOCK_SIZE_K]的数据,且存在循环,所以我们需要创建tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)的累加变量。每次循环取出a[i][n]和b[n][j]的值,相乘后累加进这个变量。

结果c的指针计算同理。store时需要mask,判断(offs_cm[:, None] < M) & (offs_ck[None, :] < K)即可。
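
用NumPy把这里的offset计算和broadcast模拟一遍(仅示意,block取小一点方便打印,变量名和kernel保持一致):

import numpy as np

# 示意用的小尺寸,真实kernel里BLOCK_SIZE是128/64
M, N, K = 8, 8, 4
BLOCK_SIZE_M, BLOCK_SIZE_K = 4, 2
stride_am, stride_an = N, 1
stride_cm, stride_ck = K, 1
pid_m, pid_k = 1, 0

offs_m = pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)   # shape [BLOCK_SIZE_M]
offs_k = pid_k * BLOCK_SIZE_K + np.arange(BLOCK_SIZE_K)   # shape [BLOCK_SIZE_K]

# [:, None] / [None, :] 就是kernel里的broadcast
a_offsets = offs_m[:, None] * stride_am                                   # [BLOCK_SIZE_M, 1],a每行的起始偏移
c_offsets = offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck    # [BLOCK_SIZE_M, BLOCK_SIZE_K]
c_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)                   # 越界位置不store

print(c_offsets)
print(c_mask)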

二、ast_to_ttir

使用JIT装饰器遍历Python AST,最后通过MLIR builder的create系列方法生成对应的Op,具体可以参考上一篇文章。

1、循环分析

得到的ttir比较冗余,全部IR在01-source.mlir中,我们挑其中的循环部分看一下,如下所示。

    %60 = scf.for %arg9 = %56 to %57 step %58 iter_args(%arg10 = %55) -> (tensor<128x64xf32>)  : i32 {
      %c1_i32_49 = arith.constant 1 : i32 loc(#loc17)
      %c1_i32_50 = arith.constant 1 : i32 loc(#loc17)
      %124 = arith.extsi %arg9 : i32 to i64 loc(#loc17)
      %125 = arith.extsi %c1_i32_50 : i32 to i64 loc(#loc17)
      %126 = arith.muli %124, %125 : i64 loc(#loc17)
      %c2147483647_i64_51 = arith.constant 2147483647 : i64 loc(#loc17)
      %c-2147483648_i64_52 = arith.constant -2147483648 : i64 loc(#loc17)
      %127 = arith.cmpi sle, %126, %c2147483647_i64_51 : i64 loc(#loc17)
      %128 = arith.cmpi sge, %126, %c-2147483648_i64_52 : i64 loc(#loc17)
      %129 = arith.andi %127, %128 : i1 loc(#loc17)
      %130 = arith.muli %arg9, %c1_i32_50 : i32 loc(#loc17)
      %131 = tt.splat %130 : i32 -> tensor<128x1xi32> loc(#loc18)
      %132 = tt.addptr %44, %131 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32> loc(#loc18)
      %133 = tt.load %132 : tensor<128x1x!tt.ptr<f32>> loc(#loc19)
      %134 = arith.extsi %arg9 : i32 to i64 loc(#loc20)
      %135 = arith.extsi %arg7 : i32 to i64 loc(#loc20)
      %136 = arith.muli %134, %135 : i64 loc(#loc20)
      %c2147483647_i64_53 = arith.constant 2147483647 : i64 loc(#loc20)
      %c-2147483648_i64_54 = arith.constant -2147483648 : i64 loc(#loc20)
      %137 = arith.cmpi sle, %136, %c2147483647_i64_53 : i64 loc(#loc20)
      %138 = arith.cmpi sge, %136, %c-2147483648_i64_54 : i64 loc(#loc20)
      %139 = arith.andi %137, %138 : i1 loc(#loc20)
      %140 = arith.muli %arg9, %arg7 : i32 loc(#loc20)
      %141 = tt.splat %140 : i32 -> tensor<1x64xi32> loc(#loc21)
      %142 = tt.addptr %54, %141 : tensor<1x64x!tt.ptr<f32>>, tensor<1x64xi32> loc(#loc21)
      %143 = tt.load %142 : tensor<1x64x!tt.ptr<f32>> loc(#loc22)
      %144 = tt.broadcast %133 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc23)
      %145 = tt.broadcast %143 : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc23)
      %146 = arith.mulf %144, %145 : tensor<128x64xf32> loc(#loc23)
      %147 = arith.addf %arg10, %146 : tensor<128x64xf32> loc(#loc24)
      scf.yield %147 : tensor<128x64xf32> loc(#loc25)
    } loc(#loc16)

可以对应到Triton kernel中的循环部分,如下所示

    for n in range(N):
        a = tl.load(a_ptrs + n * stride_an)
        b = tl.load(b_ptrs + n * stride_bn)
        accumulator += a * b

2、sanitize_overflow

有一段IR是这样的:把循环变量乘1,先extend到i64,再和 2147483647、-2147483648比较,得到结果%129。%129没有在后续IR中使用,这段会被后面的Pass优化掉。这段代码是由python/triton/language/semantic.py:206的binary_op_sanitize_overflow_impl生成的,用来检查整数乘法是否溢出int32。

      %c1_i32_50 = arith.constant 1 : i32 loc(#loc17)
      %124 = arith.extsi %arg9 : i32 to i64 loc(#loc17)
      %125 = arith.extsi %c1_i32_50 : i32 to i64 loc(#loc17)
      %126 = arith.muli %124, %125 : i64 loc(#loc17)
      %c2147483647_i64_51 = arith.constant 2147483647 : i64 loc(#loc17)
      %c-2147483648_i64_52 = arith.constant -2147483648 : i64 loc(#loc17)
      %127 = arith.cmpi sle, %126, %c2147483647_i64_51 : i64 loc(#loc17)
      %128 = arith.cmpi sge, %126, %c-2147483648_i64_52 : i64 loc(#loc17)
      %129 = arith.andi %127, %128 : i1 loc(#loc17)
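
把这段IR做的事情翻译成Python大致是下面这样(仅示意):

INT32_MAX, INT32_MIN = 2147483647, -2147483648

def mul_does_not_overflow_i32(x: int, y: int) -> bool:
    # 对应 arith.extsi 到 i64 之后的 arith.muli
    prod = x * y
    # 对应两条 arith.cmpi 和一条 arith.andi,返回值就是IR里的%129
    return INT32_MIN <= prod <= INT32_MAX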

3、其他部分

然后计算了offset,并把a load出来。stride_an是1,被Triton按值1特化成了常量,没有再作为运行时参数出现。b的load同理。

      %130 = arith.muli %arg9, %c1_i32_50 : i32 loc(#loc17)
      %131 = tt.splat %130 : i32 -> tensor<128x1xi32> loc(#loc18)
      %132 = tt.addptr %44, %131 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32> loc(#loc18)
      %133 = tt.load %132 : tensor<128x1x!tt.ptr<f32>> loc(#loc19)

a和b的shape不同,需要broadcast,之后再乘加即可。

      %144 = tt.broadcast %133 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc23)
      %145 = tt.broadcast %143 : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc23)
      %146 = arith.mulf %144, %145 : tensor<128x64xf32> loc(#loc23)
      %147 = arith.addf %arg10, %146 : tensor<128x64xf32> loc(#loc24)

accumulator是被放进iter_args的,所以还需要yield

    %60 = scf.for %arg9 = %56 to %57 step %58 iter_args(%arg10 = %55) -> (tensor<128x64xf32>)  : i32 {
      ...
      scf.yield %147 : tensor<128x64xf32> loc(#loc25)
    } loc(#loc16)

其他IR的理解类似,你还可以丢给chatgpt解读01-source.mlir。此时是完全符合Python DSL语义的,无任何优化。

三、 make_ttir

这个阶段,我们将执行如下流程。

    @staticmethod
    def make_ttir(mod, metadata, opt, capability):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        if capability // 10 < 9:
            passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_combine(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        passes.ttir.add_loop_unroll(pm)
        pm.run(mod)
        return mod

B200是SM_100,不会走add_rewrite_tensor_descriptor_to_pointer这个Pass。Pass执行后的IR从02-Inliner.mlir到12-TritonLoopUnroll.mlir。

1、Canonicalizer

规范化是一个通用Pass,你会看到它经常出现。它做的事情包括消除冗余操作、简化匹配到的模式、折叠常量计算、应用各个Op自己定义的canonicalization模式等。

02-Inliner.mlir vs 03-Canonicalizer.mlir

// old
  tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_64__(1,)cconstexpr_fp32_"() -> tensor<128x64xf32> attributes {noinline = false} {
    %cst = arith.constant 0.000000e+00 : f32 loc(#loc46)
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc46)
    tt.return %cst_0 : tensor<128x64xf32> loc(#loc47)
  ^bb1:  // no predecessors
    %0 = ub.poison : tensor<128x64xf32> loc(#loc48)
    tt.return %0 : tensor<128x64xf32> loc(#loc48)
  } loc(#loc45)

// new
  tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_64__(1,)cconstexpr_fp32_"() -> tensor<128x64xf32> attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc46)
    tt.return %cst : tensor<128x64xf32> loc(#loc47)
  } loc(#loc45)

03-Canonicalizer.mlir vs 04-Canonicalizer.mlir 的变化特别大,我们之前分析的循环中,没有被使用的Op都被去掉了。

    %18 = scf.for %arg9 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg10 = %cst) -> (tensor<128x64xf32>)  : i32 {
      %45 = tt.splat %arg9 : i32 -> tensor<128x1xi32> loc(#loc18)
      %46 = tt.addptr %14, %45 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32> loc(#loc18)
      %47 = tt.load %46 : tensor<128x1x!tt.ptr<f32>> loc(#loc19)
      %48 = arith.muli %arg9, %arg7 : i32 loc(#loc20)
      %49 = tt.splat %48 : i32 -> tensor<1x64xi32> loc(#loc21)
      %50 = tt.addptr %17, %49 : tensor<1x64x!tt.ptr<f32>>, tensor<1x64xi32> loc(#loc21)
      %51 = tt.load %50 : tensor<1x64x!tt.ptr<f32>> loc(#loc22)
      %52 = tt.broadcast %47 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc23)
      %53 = tt.broadcast %51 : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc23)
      %54 = arith.mulf %52, %53 : tensor<128x64xf32> loc(#loc23)
      %55 = arith.addf %arg10, %54 : tensor<128x64xf32> loc(#loc24)
      scf.yield %55 : tensor<128x64xf32> loc(#loc25)
    } loc(#loc17)

04-Canonicalizer.mlir vs 05-Canonicalizer.mlir 变化不大,主要是loc的调整,以及去掉了@"triton.language.standard.zeros..."这个func。

2、CSE

Common Subexpression Elimination,公共子表达式消除,查找等价操作,然后复用。

09-TritonReorderBroadcast.mlir vs 10-CSE.mlir,比如同一个tt.expand_dims %5 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc24)在%15和%19各出现了一次,CSE之后%19就没有了,原来使用%19的arith.muli直接改用%15,其他类似。

3、最终结果

最终产物为matrix_multiplication_kernel.ttir,也是12-TritonLoopUnroll.mlir

四、 make_ttgir

优化越来越多,这个模块也越来越大了,如下所示。

    @staticmethod
    def make_ttgir(mod, metadata, opt, capability):
        # Set maxnreg on all kernels, if it was provided.
        if opt.maxnreg is not None:
            mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))

        cluster_info = nvidia.ClusterInfo()
        if opt.cluster_dims is not None:
            cluster_info.clusterDimX = opt.cluster_dims[0]
            cluster_info.clusterDimY = opt.cluster_dims[1]
            cluster_info.clusterDimZ = opt.cluster_dims[2]
        pm = ir.pass_manager(mod.context)
        dump_enabled = pm.enable_debug()
        passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
        # optimize TTGIR
        passes.ttgpuir.add_coalesce(pm)
        if capability // 10 >= 8:
            passes.ttgpuir.add_f32_dot_tc(pm)
        # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
        nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        passes.ttgpuir.add_accelerate_matmul(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
        passes.ttir.add_loop_aware_cse(pm)
        if capability // 10 in [8, 9]:
            passes.ttgpuir.add_fuse_nested_loops(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttir.add_triton_licm(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
            passes.ttgpuir.add_schedule_loops(pm)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
        elif capability // 10 >= 10:
            passes.ttgpuir.add_fuse_nested_loops(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttir.add_triton_licm(pm)
            passes.ttgpuir.add_optimize_accumulator_init(pm)
            passes.ttgpuir.add_hoist_tmem_alloc(pm)
            nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
            passes.ttgpuir.add_schedule_loops(pm)
            passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
        else:
            passes.ttir.add_triton_licm(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.ttgpuir.add_prefetch(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        passes.ttgpuir.add_coalesce_async_copy(pm)
        nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        passes.ttgpuir.add_reorder_instructions(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.common.add_symbol_dce(pm)
        if capability // 10 >= 9:
            nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
        nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
        passes.common.add_sccp(pm)
        passes.common.add_canonicalizer(pm)
        pm.run(mod)
        metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
        tensordesc_meta = mod.get_tensordesc_metadata()
        metadata["tensordesc_meta"] = tensordesc_meta
        return mod

这里还针对capability做了更细的分支,我们依旧只分析有变化的部分。Pass执行后的IR从13-ConvertTritonToTritonGPU.mlir到61-Canonicalizer.mlir。

1、ConvertTritonToTritonGPU

将Triton IR转换为 TritonGPU IR,以下是这个Pass的description

      This pass converts the Triton Dialect into the TritonGPU Dialect.
      This is a partial conversion that also affects other dialects
      (namely `Arith`, `Math`, `SCF` and `CF`).
      For these dialects, and many Triton dialect operations the conversions
      mainly consists of enhancing the tensor type and the `tt.ptr<tensor<>>`
      type with an appropriate layout encoding (these encodings generally
      include information on `numWarps`, `threadsPerWarp` and `numCTAs`).

12-TritonLoopUnroll.mlir vs 13-ConvertTritonToTritonGPU.mlir,这里主要是加上了一些layout,以下是被加进来的内存访问布局标签。

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>

我们大多数Op都被加上了这类标签。比如#blocked3,也就是a的访问:sizePerThread = [1, 1]表示每个线程处理 1×1 个元素;threadsPerWarp = [32, 1]表示每个 warp 的 32 个线程沿第 0 维(行方向)排布;warpsPerCTA = [4, 1]表示一个 Cooperative Thread Array(block)的 4 个 warp 也沿第 0 维排列,整个 CTA 会覆盖 4 × 32 = 128 行;order = [1, 0]表示第 1 维是变化最快、最连续的维度。在#blocked3下,同一个warp的 32 个线程排成 32 行 1 列,这 32 个线程将访问同一列中连续 32 行的元素。
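
可以用一小段Python粗略模拟#blocked3下线程号到元素坐标的映射(只针对本文这种tensor<128x1xf32>的简单情形,仅示意):

def blocked3_coord(tid: int):
    # #blocked3: sizePerThread=[1,1], threadsPerWarp=[32,1], warpsPerCTA=[4,1]
    lane = tid % 32           # warp内线程号,沿第0维(行)排布
    warp = tid // 32          # CTA内的warp号,同样沿第0维排布
    row = warp * 32 + lane    # 一个CTA正好覆盖 4*32 = 128 行
    col = 0                   # 第1维上只有1列
    return row, col

# 线程0..127分别负责tensor<128x1xf32>的第0..127行
assert blocked3_coord(0) == (0, 0)
assert blocked3_coord(127) == (127, 0)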

我们的module上也带上了一些attributes,后面会用到。ttg.target是CUDA Compute capability,B200是SM_100,也就是"cuda:100"。

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32}

2、TritonGPUCoalesce

访存合并,让warp内相邻线程访问内存中的连续地址,目标是提升访存效率和 GPU 带宽利用率。上面加的#ttg.blocked标记就是给这个Pass使用的。

13-ConvertTritonToTritonGPU.mlir vs 14-TritonGPUCoalesce.mlir,我们先分析下最后store的转换吧。


// old
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
...
    %22 = scf.for %arg9 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg10 = %cst) -> (tensor<128x64xf32, #blocked>)  : i32
    ...
    %30 = tt.addptr %28, %29 : tensor<128x64x!tt.ptr<f32>, #blocked>, tensor<128x64xi32, #blocked> loc(#loc26)
    ...
    %38 = arith.andi %36, %37 : tensor<128x64xi1, #blocked> loc(#loc29)
    tt.store %30, %22, %38 : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc30)

// new
#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
...
    %39 = ttg.convert_layout %30 : tensor<128x64x!tt.ptr<f32>, #blocked> -> tensor<128x64x!tt.ptr<f32>, #blocked5> loc(#loc30)
    %40 = ttg.convert_layout %22 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked5> loc(#loc30)
    %41 = ttg.convert_layout %38 : tensor<128x64xi1, #blocked> -> tensor<128x64xi1, #blocked5> loc(#loc30)
    tt.store %39, %40, %41 : tensor<128x64x!tt.ptr<f32>, #blocked5> loc(#loc30)

新layout的sizePerThread = [1, 4]代表每个线程访问一个 4×f32 的小vector(16字节)。threadsPerWarp = [2, 16]把一个warp的线程排成 2 行 16 列,实现了 memory interleaving,可以提升带宽利用率、减少 bank conflicts;原来的[1, 32]则容易跨 cacheline。
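
同样可以粗略模拟一下#blocked5(新的store layout)里一个warp内lane到元素的映射,可以看出每个lane正好负责一条16字节的连续数据(仅示意,忽略warp沿第0维的重复):

def blocked5_lane_elems(lane: int):
    # #blocked5: sizePerThread=[1,4], threadsPerWarp=[2,16], order=[1,0]
    col_group = lane % 16              # 先沿第1维(列)排16个线程
    row = lane // 16                   # 再沿第0维排2行
    cols = list(range(col_group * 4, col_group * 4 + 4))  # 每线程4个连续f32 = 16B
    return row, cols

assert blocked5_lane_elems(0) == (0, [0, 1, 2, 3])
assert blocked5_lane_elems(15) == (0, [60, 61, 62, 63])   # lane 0..15覆盖第0行的64列
assert blocked5_lane_elems(16) == (1, [0, 1, 2, 3])       # lane 16..31覆盖第1行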

我们再看下之前的load,Coalesce在这里也插入了ttg.convert_layout,但源和目标layout相同,是多余的。

// old
      %40 = tt.addptr %16, %39 : tensor<128x1x!tt.ptr<f32>, #blocked3>, tensor<128x1xi32, #blocked3> loc(#loc16)
      %41 = tt.load %40 : tensor<128x1x!tt.ptr<f32>, #blocked3> loc(#loc17)

// new
      %43 = tt.addptr %16, %42 : tensor<128x1x!tt.ptr<f32>, #blocked3>, tensor<128x1xi32, #blocked3> loc(#loc16)
      %44 = ttg.convert_layout %43 : tensor<128x1x!tt.ptr<f32>, #blocked3> -> tensor<128x1x!tt.ptr<f32>, #blocked3> loc(#loc17)
      %45 = tt.load %44 : tensor<128x1x!tt.ptr<f32>, #blocked3> loc(#loc17)

3、TritonGPURemoveLayoutConversions

去除多余的convert_layout,刚才我们也发现了load前后layout相同、并不需要转换。

16-TritonGPUPlanCTAPass.mlir vs 17-TritonGPURemoveLayoutConversions.mlir,我们可以发现layout减少了。

// old
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>

// new
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>

其实就是前5个layout都被统一成了原来的#blocked3(即新的#blocked),#blocked5变成了新的#blocked1。我们可以看下lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp的内容。注释里写了这个Pass的运行方式:分析只做一轮(one-shot),然后整段 IR 依照支配(dominance)顺序重写一次。

// The current algorithm works by analyzing the IR and doing a one-shot rewrite
// based on the analysis. The algorithm is as follows.
//
// 1. Find all the anchor ops. These are ops that have a layout we want to
//    preserve.
//
// 2. For each anchor, propagate its layout to all its descendants.
//    An op can have multiple ancestors that are anchors, so at this stage an op
//    may have multiple layouts associated with it.
//
// 3. Resolve conflicts by deciding which of the multiple layouts the op should
//    keep, inserting convert-layout ops to resolve conflicts.  After this
//    stage, each value has only one layout associated with it.
//
// 4. Rewrite the IR by walking the function in dominance order. Since we
//    assume the IR is structured we just need to process the regions in the
//    correct order. For each op, rewrite it using the layout decided by the
//    analysis phase.

17-TritonGPURemoveLayoutConversions-debug.mlir 是这个Pass开启Debug后的输出。我们可以大概跟一下代码逻辑,插入layout的部分在 lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp:203

void LayoutPropagation::initAnchorLayout() {
  auto addAnchor = [&](Value v) {
    if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
      layouts.insert({v, LayoutInfo(tensorType.getEncoding())});
    }
  };

  // Consider function args as anchors.  This makes it easier to write tests --
  // you can pass a tensor with an encoding as an arg, instead of explicitly
  // calling tt.load.
  for (auto arg : funcOp.getArguments()) {
    addAnchor(arg);
  }

  funcOp.walk([&](Operation *op) {
    if (isLayoutAnchor(op)) {
      for (auto result : op->getResults()) {
        addAnchor(result);
      }
    }
  });
}

其中是通过isLayoutAnchor函数去判断的,它可以看作一个简单的cost model:对布局敏感的操作做cost-aware的筛选。

// Return true if the op is an op with a layout we don't want to change. We will
// propagate the layout starting from anchor ops.
bool isLayoutAnchor(Operation *op) {
  if (isa<LoadOp, StoreOp>(op))
    return isExpensiveLoadOrStore(op);
  if (isa<DotOp, DotScaledOp, nvidia_gpu::WarpGroupDotOp, AtomicRMWOp,
          AtomicCASOp, triton::nvidia_gpu::TMEMLoadOp>(op))
    return true;
  if (auto gatherOp = dyn_cast<GatherOp>(op))
    return gatherOp.getEfficientLayout();

  // Heuristic: Mark permuting reshape as a layout anchor.  Its dst can be
  // anything, so it stops forward-propagation of layouts.  We rely on the
  // backwards pass to fix it up if necessary.  (If we didn't do this, then
  // anything following the reshape won't be covered by the forward pass at
  // all.)
  if (auto reshape = dyn_cast<ReshapeOp>(op))
    return reshape.getAllowReorder();

  return false;
}

layouts里插入的Value是%45 = tt.load %44 : tensor<128x1x!tt.ptr<f32>, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>>,也就是%45 = tt.load %44 : tensor<128x1x!tt.ptr<f32>, #blocked3> loc(#loc17),之后依照支配(dominance)顺序把相关的Op都重写一遍就好了。

4、SCCP

Sparse Conditional Constant Propagation,在CFG中传播常量。这里只是调整了arith.constant的顺序。

36-TritonGPURewritePartitionDependencies.mlir vs 37-SCCP.mlir

5、TritonNvidiaGPURemoveTMEMTokensPass

清除TMEM token。这种token是用来表示Tensor Memory(TMEM)读写之间依赖关系的抽象,作用类似memory barrier的依赖标记。

43-TritonGPUCombineTensorSelectAndIf.mlir vs 44-TritonNvidiaGPURemoveTMEMTokensPass.mlir,我们的kernel里只是多了一个ub.poison,它在下一个Canonicalizer Pass里就被干掉了,代码注释里也写了这一点。

%0 = ub.poison : !ttg.async.token loc(#loc)

6、SCCP

在CFG中传播常量。59-TritonGPUFenceInsertion.mlir vs 60-SCCP.mlir,arith.constant又调整了一次顺序。

7、最终产物

最终产物为matrix_multiplication_kernel.ttgir,也是61-Canonicalizer.mlir

五、make_llir

从ttgir到llvm ir的跨度是比较大的,不过SIMT的GPGPU编程模型相对简单,可以直接这样lower下去。

    def make_llir(self, src, metadata, options, capability):
        ptx_version = get_ptx_version_from_options(options, self.target.arch)

        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()

        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        passes.ttgpuir.add_allocate_warp_groups(pm)
        passes.convert.add_scf_to_cf(pm)
        passes.ttgpuir.add_allocate_shared_memory(pm)
        nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
        if knobs.compilation.enable_experimental_consan:
            # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
            passes.ttgpuir.add_concurrency_sanitizer(pm)
        passes.ttgpuir.add_allocate_global_scratch_memory(pm)
        nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
        nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
        nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        if not knobs.compilation.disable_line_info:
            passes.llvmir.add_di_scope(pm)
        pm.run(mod)
        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        if knobs.compilation.enable_asan:
            raise RuntimeError(
                "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
        llvm_mod = llvm.to_module(mod, context)
        proc = sm_arch_from_capability(capability)
        features = get_features(options, self.target.arch)
        triple = 'nvptx64-nvidia-cuda'
        nvidia.set_short_ptr()
        llvm.attach_datalayout(llvm_mod, triple, proc, features)
        nvidia.set_nvvm_reflect_ftz(llvm_mod)

        if options.extern_libs:
            paths = [path for (name, path) in options.extern_libs]
            llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)

        # Get some metadata
        # warp-specialization mutates num_warps
        total_num_warps = src.get_int_attr("ttg.total-num-warps")
        if total_num_warps is not None:
            metadata["num_warps"] = total_num_warps
        metadata["shared"] = src.get_int_attr("ttg.shared")
        metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
        metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
        metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
        ret = str(llvm_mod)
        del llvm_mod
        del context
        return ret

Pass执行后的IR从62-TritonNvidiaGPUMMALoweringPass.mlir到79-LLVMDIScope.mlir。

1、SCFToControlFlowPass

64-TritonGPUAllocateWarpGroups vs 65-SCFToControlFlowPass,将scf.for lower成cf.br、cf.cond_br这样的控制流。

// old
    %26 = scf.for %arg9 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg10 = %cst) -> (tensor<128x64xf32, #blocked>)  : i32 {
...
      %52 = arith.addf %arg10, %51 : tensor<128x64xf32, #blocked> loc(#loc22)
      scf.yield %52 : tensor<128x64xf32, #blocked> loc(#loc23)
    } loc(#loc15)

// new
    cf.br ^bb1(%c0_i32, %cst : i32, tensor<128x64xf32, #blocked>) loc(#loc15)
  ^bb1(%26: i32 loc("/home/ubuntu/triton/matmul.py":30:19), %27: tensor<128x64xf32, #blocked> loc(unknown)):  // 2 preds: ^bb0, ^bb2
    %28 = arith.cmpi slt, %26, %arg4 : i32 loc(#loc15)
    cf.cond_br %28, ^bb2, ^bb3 loc(#loc15)
  ^bb2:  // pred: ^bb1
...
    %39 = arith.addf %27, %38 : tensor<128x64xf32, #blocked> loc(#loc22)
    %40 = arith.addi %26, %c1_i32 : i32 loc(#loc15)
    cf.br ^bb1(%40, %39 : i32, tensor<128x64xf32, #blocked>) loc(#loc15)
  ^bb3:  // pred: ^bb1

2、AllocateSharedMemory

65-SCFToControlFlowPass vs 66-AllocateSharedMemory,ttg.convert_layout需要借助Shared Memory来做layout转换(数据重排),可以看到它的allocation.offset,module上也添加了ttg.shared大小描述。注意ttg.shared并不等于tensor<128x64xf32, #blocked>本身的大小:128×64×4 = 32768,而ttg.shared是34816,多出的2048字节来自padding,因为lib/Analysis/Allocation.cpp:207里的elems取的是getNumScratchElemsPaddedCvt(srcTy, dstTy)。

// old
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 4 : i32}
...
    %55 = ttg.convert_layout %27 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1> loc(#loc29)

// new
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 34816 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 4 : i32}
    %55 = ttg.convert_layout %27 {allocation.offset = 0 : i32} : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1> loc(#loc29)

3、TritonTensorMemoryAllocationPass

66-AllocateSharedMemory vs 67-TritonTensorMemoryAllocationPass,仅在module上添加了ttg.tensor_memory_size = 0 : i32

4、TritonGPUGlobalScratchAllocationPass

67-TritonTensorMemoryAllocationPass vs 68-TritonGPUGlobalScratchAllocationPass,加上了global_scratch的相关信息。

// old
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 34816 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 4 : i32} {
  tt.func public @matrix_multiplication_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0)) attributes {noinline = false} }

// new
module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 34816 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 4 : i32} {
  tt.func public @matrix_multiplication_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0)) attributes {noinline = false, ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} }

5、ConvertTritonGPUToLLVM

69-TritonGPUProxyFenceInsertion vs 70-ConvertTritonGPUToLLVM,这是转换中非常重要的一个Pass,我们的IR会在这里膨胀起来,生成的结果里还会出现llvm.inline_asm这种内嵌PTX汇编的指令。

除了用debug方式跟踪RewritePattern之外,还可以借助loc信息来定位。我们先跟踪一下arith.addf这个浮点加法是怎么处理的。

// old
    %39 = arith.addf %27, %38 : tensor<128x64xf32, #blocked> loc(#loc22)
...
#loc22 = loc("/home/ubuntu/triton/matmul.py":34:23)

// new
%1527 = builtin.unrealized_conversion_cast %1526 : tensor<128x64xf32, #blocked> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
...
    %2721 = llvm.extractvalue %1527[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
    %2722 = llvm.extractvalue %1527[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
    %2723 = llvm.extractvalue %1527[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
... 省略58行
    %2782 = llvm.extractvalue %1527[61] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
    %2783 = llvm.extractvalue %1527[62] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
    %2784 = llvm.extractvalue %1527[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
    %2785 = llvm.extractvalue %2720[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
    %2786 = llvm.extractvalue %2720[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
    %2787 = llvm.extractvalue %2720[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
... 省略58行
    %2846 = llvm.extractvalue %2720[61] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
    %2847 = llvm.extractvalue %2720[62] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
    %2848 = llvm.extractvalue %2720[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc16)
    %2849 = llvm.fadd %2721, %2785 : f32 loc(#loc16)
    %2850 = llvm.fadd %2722, %2786 : f32 loc(#loc16)
    %2851 = llvm.fadd %2723, %2787 : f32 loc(#loc16)
... 省略58行
    %2910 = llvm.fadd %2782, %2846 : f32 loc(#loc16)
    %2911 = llvm.fadd %2783, %2847 : f32 loc(#loc16)
    %2912 = llvm.fadd %2784, %2848 : f32 loc(#loc16)
...
#loc16 = loc("/home/ubuntu/triton/matmul.py":34:23)

上面是从%1527里取出的数据和从%2720里取出的数据逐个相加,我们可以看下%2720是怎么来的,如下所示。

    %2592 = llvm.fmul %2464, %2528 : f32 loc(#loc22)
    %2593 = llvm.fmul %2465, %2529 : f32 loc(#loc22)
    %2594 = llvm.fmul %2466, %2530 : f32 loc(#loc22)
... 省略58行
    %2653 = llvm.fmul %2525, %2589 : f32 loc(#loc22)
    %2654 = llvm.fmul %2526, %2590 : f32 loc(#loc22)
    %2655 = llvm.fmul %2527, %2591 : f32 loc(#loc22)
    %2656 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc22)
    %2657 = llvm.insertvalue %2592, %2656[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc22)
    %2658 = llvm.insertvalue %2593, %2657[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc22)
    %2659 = llvm.insertvalue %2594, %2658[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc22)
... 省略58行
    %2718 = llvm.insertvalue %2653, %2717[61] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc22)
    %2719 = llvm.insertvalue %2654, %2718[62] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc22)
    %2720 = llvm.insertvalue %2655, %2719[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc22)

就是乘法(llvm.fmul)得到的值:先定义一个64×f32的struct,把乘积逐个insert进去,然后再和%1527相加。%1527又可以追溯到^bb1标签,^bb1又可以追溯到最前面的%70。

    %4 = llvm.mlir.constant(0.000000e+00 : f32) : f32 loc(#loc1)
    %5 = llvm.bitcast %4 : f32 to f32 loc(#loc1)
    %6 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc1)
    %7 = llvm.insertvalue %5, %6[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc1)
    %8 = llvm.insertvalue %5, %7[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc1)
    %9 = llvm.insertvalue %5, %8[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc1)
... 省略58行
    %68 = llvm.insertvalue %5, %67[61] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc1)
    %69 = llvm.insertvalue %5, %68[62] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc1)
    %70 = llvm.insertvalue %5, %69[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  loc(#loc1)

%70也是一个64×f32的struct,初始值全为0.0,这就是accumulator。我们的原始数据是tensor<128x64xf32, #blocked>,逻辑能对上,那shape怎么变成64个f32了呢?这一步其实就是把tensor分摊到每个线程上,由module中的"ttg.num-warps" = 4 : i32和"ttg.threads-per-warp" = 32参与计算,具体在lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp:46由unsigned numElementsPerThread = getTotalElemsPerThread(type);算出来;getTotalElemsPerThread具体怎么算可以继续跟踪lib/Dialect/TritonGPU/IR/Dialect.cpp:79,所以其实和layout都有关。

unsigned getTotalElemsPerThread(Type type) {
  if (type.isIntOrIndexOrFloat() || isa<triton::PointerType>(type))
    return 1;
  auto tensorType = cast<RankedTensorType>(type);
  return getTotalElemsPerThread(tensorType.getEncoding(),
                                tensorType.getShape());
}
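
按照module attributes粗略验证一下每线程的元素数(仅示意):

num_warps, threads_per_warp = 4, 32          # 来自module attributes
tensor_elems = 128 * 64                      # tensor<128x64xf32, #blocked>
elems_per_thread = tensor_elems // (num_warps * threads_per_warp)
assert elems_per_thread == 64                # 正好对应那个含64个f32的llvm.struct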

我们再分析一下结果是怎么存进去的,新IR这里只截取了其中很小的一段。

// old
    %55 = ttg.convert_layout %27 {allocation.offset = 0 : i32} : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1> loc(#loc29)
    tt.store %47, %55, %54 : tensor<128x64x!tt.ptr<f32>, #blocked1> loc(#loc29)
...
#loc29 = loc("/home/ubuntu/triton/matmul.py":41:21)

// new
...
    %4979 = llvm.getelementptr inbounds %4883[%4978] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 loc(#loc29)
    %4980 = llvm.mlir.undef : vector<4xf32> loc(#loc29)
    %4981 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
    %4982 = llvm.insertelement %4884, %4980[%4981 : i32] : vector<4xf32> loc(#loc29)
...省略4行
    %4987 = llvm.mlir.constant(3 : i32) : i32 loc(#loc29)
    %4988 = llvm.insertelement %4887, %4986[%4987 : i32] : vector<4xf32> loc(#loc29)
    %4989 = llvm.mlir.constant(true) : i1 loc(#loc29)
    %4990 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
    %4991 = llvm.extractelement %4988[%4990 : i32] : vector<4xf32> loc(#loc29)
...省略4行
    %4996 = llvm.mlir.constant(3 : i32) : i32 loc(#loc29)
    %4997 = llvm.extractelement %4988[%4996 : i32] : vector<4xf32> loc(#loc29)
    %4998 = llvm.bitcast %4991 : f32 to i32 loc(#loc29)
...省略2行
    %5001 = llvm.bitcast %4997 : f32 to i32 loc(#loc29)
    %5002 = llvm.mlir.undef : vector<4xi32> loc(#loc29)
    %5003 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
    %5004 = llvm.insertelement %4998, %5002[%5003 : i32] : vector<4xi32> loc(#loc29)
...省略4行
    %5009 = llvm.mlir.constant(3 : i32) : i32 loc(#loc29)
    %5010 = llvm.insertelement %5001, %5008[%5009 : i32] : vector<4xi32> loc(#loc29)
    llvm.store %5010, %4979 {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<3> loc(#loc29)
...
#loc29 = loc("/home/ubuntu/triton/matmul.py":41:21)

逻辑是构造了一个 vector<4xi32>:把 4 个 f32 数据 bitcast 成 i32,打包成一个向量,再用 16 字节对齐的 store 把这个 vector 存到共享内存中对应偏移的位置。内存偏移的计算如下所示。

    %4948 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc29)
    %4949 = llvm.mlir.constant(127 : i32) : i32 loc(#loc29)
    %4950 = llvm.and %4948, %4949 : i32 loc(#loc29)
    %4951 = llvm.mlir.constant(32 : i32) : i32 loc(#loc29)
    %4952 = llvm.urem %4950, %4951 : i32 loc(#loc29)
    %4953 = llvm.udiv %4950, %4951 : i32 loc(#loc29)
    %4954 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
    %4955 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
    %4956 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
    %4957 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
    %4958 = llvm.shl %4952, %4957 : i32 loc(#loc29)
    %4959 = llvm.or %4956, %4958 : i32 loc(#loc29)
    %4960 = llvm.mlir.constant(5 : i32) : i32 loc(#loc29)
    %4961 = llvm.shl %4953, %4960 : i32 loc(#loc29)
    %4962 = llvm.or %4959, %4961 : i32 loc(#loc29)
    %4963 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
    %4964 = llvm.mlir.constant(7 : i32) : i32 loc(#loc29)
    %4965 = llvm.and %4962, %4964 : i32 loc(#loc29)
    %4966 = llvm.mlir.constant(12 : i32) : i32 loc(#loc29)
    %4967 = llvm.shl %4965, %4966 : i32 loc(#loc29)
    %4968 = llvm.xor %4963, %4967 : i32 loc(#loc29)
    %4969 = llvm.mlir.constant(127 : i32) : i32 loc(#loc29)
    %4970 = llvm.and %4962, %4969 : i32 loc(#loc29)
    %4971 = llvm.mlir.constant(4 : i32) : i32 loc(#loc29)
    %4972 = llvm.shl %4970, %4971 : i32 loc(#loc29)
    %4973 = llvm.xor %4968, %4972 : i32 loc(#loc29)
    %4974 = llvm.xor %4955, %4973 : i32 loc(#loc29)
    %4975 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
    %4976 = llvm.xor %4974, %4975 : i32 loc(#loc29)
    %4977 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
    %4978 = llvm.add %4976, %4977 : i32 loc(#loc29)
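
把这一长串IR化简成Python表达式大致是下面这样(仅示意,常量0参与的or/xor/add已省略):

def smem_store_offset(tid: int) -> int:
    t = tid & 127                               # %4950
    v = (t % 32) | ((t // 32) << 5)             # %4962,对t<128而言就等于t
    return ((v & 7) << 12) ^ ((v & 127) << 4)   # %4978,单位为字节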

那么inline asm是怎么产生的呢,我们可以找下出现llvm.inline_asm的位置,这里以loc("/home/ubuntu/triton/matmul.py":32:20)为例

// old
%31 = tt.load %30 : tensor<128x1x!tt.ptr<f32>, #blocked> loc(#loc17)
...
#loc17 = loc("/home/ubuntu/triton/matmul.py":32:20)

// new
    %1537 = llvm.extractvalue %1536[0] : !llvm.struct<(ptr<1>)>  loc(#loc18)
    %1538 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l" %1537 : (!llvm.ptr<1>) -> i32 loc(#loc18)
    %1539 = llvm.bitcast %1538 : i32 to vector<1xf32> loc(#loc18)
    %1540 = llvm.mlir.constant(0 : index) : i32 loc(#loc18)
    %1541 = llvm.extractelement %1539[%1540 : i32] : vector<1xf32> loc(#loc18)
    %1542 = llvm.mlir.undef : !llvm.struct<(f32)> loc(#loc18)
    %1543 = llvm.insertvalue %1541, %1542[0] : !llvm.struct<(f32)>  loc(#loc18)
...
#loc18 = loc("/home/ubuntu/triton/matmul.py":32:20)

描述的是从global memory加载一个float值,再装入LLVM struct的过程。其实我们需要的只是%1541这个float值,无关的代码在下一个Pass中就会被优化掉。这段内嵌汇编实际上是由PTXBuilder拼接出来的。

6、Canonicalizer

70-ConvertTritonGPUToLLVM vs 71-Canonicalizer,规范化Pass,上文提到过,IR可以从7192行降到1882行。

浮点乘加部分被优化后的IR从259行降到了193行,可以看到fadd直接复用了乘法得到的值(原来的%2720那一组),不再经过struct的打包和解包。

store部分的IR量更是从2265行降到了988行,这里同样消除了数据的打包解包,多个llvm.mlir.constant(0 : i32)也被合并成了一条,还有其他优化。

7、CSE

71-Canonicalizer vs 72-CSE,公共子表达式消除,上文提到过,IR可以从1882行降到1847行。

8、LLVMDIScope

78-SymbolDCE vs 79-LLVMDIScope,在 LLVM IR 中附加调试信息作用域(Debug Info Scope) 的 pass,生成调试信息(如 DWARF)以支持源级调试。

; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64"

@global_smem = external local_unnamed_addr addrspace(3) global [0 x i8], align 16

define ptx_kernel void @matrix_multiplication_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, ptr addrspace(1) readnone captures(none) %9) local_unnamed_addr #0 !dbg !5 {
  %11 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !8
  %12 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !dbg !9
  %13 = shl nuw nsw i32 %12, 7, !dbg !10
  %14 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !11
  %15 = and i32 %14, 127, !dbg !11
  %16 = or disjoint i32 %13, %15, !dbg !12 }
...

!5 = distinct !DISubprogram(name: "matrix_multiplication_kernel", linkageName: "matrix_multiplication_kernel", scope: !1, file: !1, line: 6, type: !6, scopeLine: 6, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0)
!6 = !DISubroutineType(cc: DW_CC_normal, types: !7)
!7 = !{}
!8 = !DILocation(line: 17, column: 26, scope: !5)
!9 = !DILocation(line: 18, column: 26, scope: !5)
!10 = !DILocation(line: 20, column: 21, scope: !5)
!11 = !DILocation(line: 20, column: 49, scope: !5)
!12 = !DILocation(line: 20, column: 36, scope: !5)

9、最终产物

最终产物为 matrix_multiplication_kernel.llir

六、make_ptx

这里实际上调用的是LLVM,我们只简单分析下最后的ptx结果。

    def make_ptx(self, src, metadata, opt, capability):
        ptx_version = get_ptx_version_from_options(opt, self.target.arch)

        triple = 'nvptx64-nvidia-cuda'
        proc = sm_arch_from_capability(capability)
        features = get_features(opt, self.target.arch)
        ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
        # Find kernel names (there should only be one)
        names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
        assert len(names) == 1
        metadata["name"] = names[0]
        # post-process
        ptx_version = f'{ptx_version//10}.{ptx_version%10}'
        ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
        ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
        # Remove the debug flag that prevents ptxas from optimizing the code
        ret = re.sub(r",\s*debug|debug,\s*", "", ret)
        if knobs.nvidia.dump_nvptx:
            print("// -----// NVPTX Dump //----- //")
            print(ret)
        return ret

产物为matrix_multiplication_kernel.ptx。输入文件是579行,输出文件是729行,基本一一对应。

1、加法对照

将向量化的fmul + fadd融合生成了具体的fma指令。

// old
    %209 = fmul <64 x float> %208, %205, !dbg !23
    %210 = fadd <64 x float> %39, %209, !dbg !24
...
!24 = !DILocation(line: 34, column: 23, scope: !5)

// new
	.loc	1 34 23                         // matmul.py:34:23
	fma.rn.f32 	%r599, %r338, %r342, %r599;
	fma.rn.f32 	%r598, %r338, %r341, %r598;
	fma.rn.f32 	%r597, %r338, %r340, %r597;
...省略58行
	fma.rn.f32 	%r657, %r338, %r400, %r657;
	fma.rn.f32 	%r658, %r338, %r401, %r658;
	fma.rn.f32 	%r659, %r338, %r402, %r659;
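
是否做这种融合由make_ptx里传给llvm.translate_to_asm的opt.enable_fp_fusion控制。如果想对比未融合(mul + add)的PTX,据我所知可以在launch时把它关掉(写法仅示意,以所用版本为准):

    matrix_multiplication_kernel[grid](
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_an,
        stride_bn, stride_bk,
        stride_cm, stride_ck,
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_K=64,
        enable_fp_fusion=False,  # 假设:关闭后PTX里会是mul.f32 + add.f32而不是fma.rn.f32
    )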

2、store对照

其中对llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)的调用会转换为bar.sync 0;这个barrier操作,其余基本都是对LLVM IR的直接翻译。这部分是这个kernel的核心部分,可以学习一下它对shared memory地址做的padding(以及xor形式的swizzle)处理。

// old
  %331 = and i32 %14, 7, !dbg !33
  %332 = shl nuw nsw i32 %331, 12, !dbg !33
  %333 = shl nuw nsw i32 %15, 4, !dbg !33
  %334 = or disjoint i32 %332, %333, !dbg !33
  %335 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %334, !dbg !33
  %336 = shufflevector <64 x float> %211, <64 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>, !dbg !33
  store <4 x float> %336, ptr addrspace(3) %335, align 16, !dbg !33
  %337 = getelementptr inbounds nuw i8, ptr addrspace(3) %335, i32 2048, !dbg !33
  %338 = shufflevector <64 x float> %211, <64 x float> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>, !dbg !33
  store <4 x float> %338, ptr addrspace(3) %337, align 16, !dbg !33
  %339 = xor i32 %334, 16, !dbg !33
...省略42行
  %370 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %369, !dbg !33
  %371 = shufflevector <64 x float> %211, <64 x float> poison, <4 x i32> <i32 56, i32 57, i32 58, i32 59>, !dbg !33
  store <4 x float> %371, ptr addrspace(3) %370, align 16, !dbg !33
  %372 = getelementptr inbounds nuw i8, ptr addrspace(3) %370, i32 2048, !dbg !33
  %373 = shufflevector <64 x float> %211, <64 x float> poison, <4 x i32> <i32 60, i32 61, i32 62, i32 63>, !dbg !33
  store <4 x float> %373, ptr addrspace(3) %372, align 16, !dbg !33
  tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !33
  %374 = shl nuw nsw i32 %14, 8, !dbg !33
  %375 = and i32 %374, 30720, !dbg !33
  %376 = shl nuw nsw i32 %331, 4, !dbg !33
  %377 = or disjoint i32 %375, %376, !dbg !33
  %378 = xor i32 %377, %215, !dbg !33
  %379 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %378, !dbg !33
  %380 = load <4 x i32>, ptr addrspace(3) %379, align 16, !dbg !33
  %381 = getelementptr inbounds nuw i8, ptr addrspace(3) %379, i32 128, !dbg !33
...省略26行
  %408 = load <4 x i32>, ptr addrspace(3) %407, align 16, !dbg !33
  %409 = getelementptr inbounds nuw i8, ptr addrspace(3) %379, i32 1920, !dbg !33
  %410 = load <4 x i32>, ptr addrspace(3) %409, align 16, !dbg !33
  %.extract = extractelement <4 x i32> %380, i64 0, !dbg !33
  %.extract64 = extractelement <4 x i32> %380, i64 1, !dbg !33
  %.extract65 = extractelement <4 x i32> %380, i64 2, !dbg !33
  %.extract66 = extractelement <4 x i32> %380, i64 3, !dbg !33
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract, i32 %.extract64, i32 %.extract65, i32 %.extract66, ptr addrspace(1) %282, i1 %315) #3, !dbg !33
...省略70行
  %.extract123 = extractelement <4 x i32> %410, i64 0, !dbg !33
  %.extract124 = extractelement <4 x i32> %410, i64 1, !dbg !33
  %.extract125 = extractelement <4 x i32> %410, i64 2, !dbg !33
  %.extract126 = extractelement <4 x i32> %410, i64 3, !dbg !33
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract123, i32 %.extract124, i32 %.extract125, i32 %.extract126, ptr addrspace(1) %297, i1 %330) #3, !dbg !33

// new
	.loc	1 41 21                         // matmul.py:41:21
	and.b32 	%r505, %r2, 7;
	shl.b32 	%r506, %r505, 12;
	shl.b32 	%r507, %r3, 4;
	or.b32 	%r508, %r506, %r507;
	mov.b32 	%r509, global_smem;
	add.s32 	%r510, %r509, %r508;
	st.shared.v4.b32 	[%r510], {%r596, %r597, %r598, %r599};
	st.shared.v4.b32 	[%r510+2048], {%r600, %r601, %r602, %r603};
	xor.b32 	%r511, %r508, 16;
	add.s32 	%r512, %r509, %r511;
	st.shared.v4.b32 	[%r512], {%r604, %r605, %r606, %r607};
	st.shared.v4.b32 	[%r512+2048], {%r608, %r609, %r610, %r611};
...省略16行
	xor.b32 	%r521, %r508, 96;
	add.s32 	%r522, %r509, %r521;
	st.shared.v4.b32 	[%r522], {%r644, %r645, %r646, %r647};
	st.shared.v4.b32 	[%r522+2048], {%r648, %r649, %r650, %r651};
	xor.b32 	%r523, %r508, 112;
	add.s32 	%r524, %r509, %r523;
	st.shared.v4.b32 	[%r524], {%r652, %r653, %r654, %r655};
	st.shared.v4.b32 	[%r524+2048], {%r656, %r657, %r658, %r659};
	bar.sync 	0;
	shl.b32 	%r525, %r2, 8;
	and.b32 	%r526, %r525, 30720;
	shl.b32 	%r527, %r505, 4;
	or.b32 	%r528, %r526, %r527;
	xor.b32 	%r529, %r528, %r470;
	add.s32 	%r530, %r509, %r529;
	ld.shared.v4.b32 	{%r403, %r404, %r405, %r406}, [%r530];
	ld.shared.v4.b32 	{%r407, %r408, %r409, %r410}, [%r530+128];
...省略12行
	ld.shared.v4.b32 	{%r459, %r460, %r461, %r462}, [%r530+1792];
	ld.shared.v4.b32 	{%r463, %r464, %r465, %r466}, [%r530+1920];
	// begin inline asm
	@%p3 st.global.v4.b32 [ %rd31 + 0 ], { %r403, %r404, %r405, %r406 };
	// end inline asm
	// begin inline asm
	@%p4 st.global.v4.b32 [ %rd32 + 0 ], { %r407, %r408, %r409, %r410 };
	// end inline asm
...省略 12*3行
	// begin inline asm
	@%p17 st.global.v4.b32 [ %rd45 + 0 ], { %r459, %r460, %r461, %r462 };
	// end inline asm
	// begin inline asm
	@%p18 st.global.v4.b32 [ %rd46 + 0 ], { %r463, %r464, %r465, %r466 };
	// end inline asm

3、load对照

对内嵌汇编做了展开。

// old
  %41 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %40) #3, !dbg !19
...
!19 = !DILocation(line: 32, column: 20, scope: !5)

// new
	.loc	1 32 20                         // matmul.py:32:20
	// begin inline asm
	mov.u32 %r338, 0x0;
	ld.global.b32 { %r338 }, [ %rd80 + 0 ];