Deep Dive into the Triton Compiler's MatMul Optimization (Part 2): MMA

In Deep Dive into the Triton Compiler's MatMul Optimization (Part 1) we covered the optimization of the naive matrix multiply; in this part we look at tl.dot, the operation that gets you good performance almost for free.

First, the performance comparison: Triton native kernel vs. Triton with dot kernel. The speedup is 3.68x, and relative to the naive CUDA kernel it is 5.16x.

Triton native kernel vs Triton with dot kernel

This speedup already looks high, but it was suspiciously close to the runtime of my earlier TMA kernel, and reading the PTX showed why: input_precision="ieee" makes all of these kernels fall back to FMA instructions, so the gain in this kernel actually comes from loop unrolling. The tf32 default cannot be used because its precision loss fails this problem's checker. LeetGPU also has a GEMM (FP16) problem, but its data size is too small to benchmark. So the following articles will skip performance comparisons; we are stuck on that front.
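
For reference, the precision mode is chosen per tl.dot call. Below is a minimal standalone sketch (my own toy kernel, not the LeetGPU submission) contrasting the two modes; with "ieee" the FP32 dot is lowered to FMA instructions instead of tensor-core MMA, which is exactly what happened above.

import triton
import triton.language as tl

@triton.jit
def dot_precision_demo(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # "ieee": exact FP32 semantics, compiled to FMA instructions (no tensor cores).
    # "tf32" (the default on recent NVIDIA GPUs): tensor-core MMA with reduced input precision.
    c = tl.dot(a, b, input_precision="ieee")
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)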

This article uses Triton at commit bc75dd0 (Jun 27, 2025). All IR and kernel files have been uploaded to GitHub: sBobHuang/Triton-blog-file. Related article in this series: Deep Dive into the Triton Compiler's MatMul Optimization (Part 1).

I. The matmul Triton kernel

1. Writing the kernel

The Triton kernel is shown below. Matrix a is M*N, matrix b is N*K, and the result matrix c is M*K. The complete runnable version is in matmul-with-dot-v2.py.

# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl

@triton.jit
def matrix_multiplication_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_an,
    stride_bn, stride_bk,
    stride_cm, stride_ck,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr
):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
    pid_k = tl.program_id(axis=0)
    pid_m = tl.program_id(axis=1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_n = tl.arange(0, BLOCK_SIZE_N)

    # Initialize the pointers for A and B
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
    b_ptrs = b_ptr + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

    # Accumulate block by block along the N dimension
    for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        offset_n = n * BLOCK_SIZE_N
        max_idx = N - offset_n
        # Load the current blocks of A and B
        a = tl.load(a_ptrs + offset_n * stride_an, mask=offs_n[None, :] < max_idx, other=0.0)
        b = tl.load(b_ptrs + offset_n * stride_bn, mask=offs_n[:, None] < max_idx, other=0.0)
        # Compute a @ b and accumulate into accumulator
        accumulator = tl.dot(a, b, acc=accumulator)

    # Write the result back to C
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck
    c_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    tl.store(c_ptrs, accumulator, mask=c_mask)

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
    stride_am, stride_an = N, 1
    stride_bn, stride_bk = K, 1
    stride_cm, stride_ck = K, 1

    grid = lambda META: (triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(M, META['BLOCK_SIZE_M']), )
    matrix_multiplication_kernel[grid](
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_an,
        stride_bn, stride_bk,
        stride_cm, stride_ck,
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_K=64,
        BLOCK_SIZE_N=64,
    )

2. Kernel walkthrough

Compared with the previous article, this kernel adds a BLOCK_SIZE_N: in each loop iteration, a is loaded as a [BLOCK_SIZE_M, BLOCK_SIZE_N] block and b as a [BLOCK_SIZE_N, BLOCK_SIZE_K] block, and the two blocks are multiplied directly. Accordingly, the base pointers become a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an and b_ptrs = b_ptr + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk; inside the loop we just add the per-block offset and apply the mask when loading. A NumPy sketch of the same blocking scheme follows.
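
A minimal NumPy sketch of this decomposition (my own illustration; it assumes M and K are multiples of the block sizes and, unlike the kernel, handles the tail of N by slicing instead of masking):

import numpy as np

def blocked_matmul(a, b, BM=128, BN=64, BK=64):
    M, N = a.shape
    _, K = b.shape
    c = np.zeros((M, K), dtype=np.float32)
    for m0 in range(0, M, BM):              # pid_m
        for k0 in range(0, K, BK):          # pid_k
            acc = np.zeros((BM, BK), dtype=np.float32)
            for n0 in range(0, N, BN):      # the loop inside the kernel
                acc += a[m0:m0+BM, n0:n0+BN] @ b[n0:n0+BN, k0:k0+BK]
            c[m0:m0+BM, k0:k0+BK] = acc
    return c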

II. ast_to_ttir

The JIT decorator walks the Python AST and, in the end, emits each op through MLIR's self.create< builder calls.

1. The loop IR

The resulting ttir is fairly verbose; the full IR is in 01-source.mlir. Let's pick the loop out of it and take a look, shown below.

    %85 = scf.for %arg9 = %81 to %82 step %83 iter_args(%arg10 = %79) -> (tensor<128x64xf32>)  : i32 {
      %155 = arith.muli %arg9, %c64_i32_61 : i32 loc(#loc25)
...omitted
      %162 = arith.subi %arg4, %155 : i32 loc(#loc26)
      %163 = tt.expand_dims %34 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc27)
      %164 = tt.splat %162 : i32 -> tensor<1x64xi32> loc(#loc28)
      %165 = arith.cmpi slt, %163, %164 : tensor<1x64xi32> loc(#loc28)
...omitted
      %172 = arith.muli %arg9, %c64_i32_67 : i32 loc(#loc29)
...omitted
      %179 = arith.muli %172, %c1_i32_71 : i32 loc(#loc30)
      %180 = tt.splat %179 : i32 -> tensor<128x64xi32> loc(#loc31)
      %181 = tt.addptr %56, %180 : tensor<128x64x!tt.ptr<f32>>, tensor<128x64xi32> loc(#loc31)
      %cst_74 = arith.constant 0.000000e+00 : f32 loc(#loc32)
      %182 = tt.broadcast %165 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc32)
      %cst_75 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc32)
      %183 = tt.load %181, %182, %cst_75 : tensor<128x64x!tt.ptr<f32>> loc(#loc32)
      %184 = tt.expand_dims %34 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc33)
      %185 = tt.splat %162 : i32 -> tensor<64x1xi32> loc(#loc34)
      %186 = arith.cmpi slt, %184, %185 : tensor<64x1xi32> loc(#loc34)
...omitted
      %193 = arith.muli %arg9, %c64_i32_77 : i32 loc(#loc35)
...omitted
      %200 = arith.muli %193, %arg7 : i32 loc(#loc36)
      %201 = tt.splat %200 : i32 -> tensor<64x64xi32> loc(#loc37)
      %202 = tt.addptr %78, %201 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32> loc(#loc37)
      %cst_82 = arith.constant 0.000000e+00 : f32 loc(#loc38)
      %203 = tt.broadcast %186 : tensor<64x1xi1> -> tensor<64x64xi1> loc(#loc38)
      %cst_83 = arith.constant dense<0.000000e+00> : tensor<64x64xf32> loc(#loc38)
      %204 = tt.load %202, %203, %cst_83 : tensor<64x64x!tt.ptr<f32>> loc(#loc38)
      %cst_84 = arith.constant 0.000000e+00 : f32 loc(#loc39)
      %cst_85 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc39)
      %205 = tt.dot %183, %204, %cst_85, inputPrecision = tf32 : tensor<128x64xf32> * tensor<64x64xf32> -> tensor<128x64xf32> loc(#loc39)
      %206 = arith.addf %arg10, %205 : tensor<128x64xf32> loc(#loc40)
      scf.yield %206 : tensor<128x64xf32> loc(#loc41)
    } loc(#loc24)

2. Loading a

This load is now more involved, so let's trace it step by step, starting from %183 = tt.load %181, %182, %cst_75 : tensor<128x64x!tt.ptr<f32>> loc(#loc32): %181 is the pointer tensor (the address with offsets applied), %182 is the mask, and %cst_75 is other, the fill value used where the mask is false. We can also match this against the assemblyFormat in the LoadOp definition.

def TT_LoadOp : TT_Op<"load", [
  SameLoadStoreOperandsAndResultShape,
  SameLoadStoreOperandsAndResultEncoding,
  AttrSizedOperandSegments,
  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
  DeclareOpInterfaceMethods<InferTypeOpInterface>,
  TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">,
  TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))",
                 "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
  TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)",
                 "($_op.getOperands().size() <= 2) || std::equal_to<>()">
]> {
    let summary = "Load from a tensor of pointers or from a tensor pointer";

    let arguments = (
      ins
      AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr,
      Optional<TT_BoolLike>:$mask,
      Optional<TT_Type>:$other,

      DefaultValuedAttr<DenseI32ArrayAttr, "::llvm::ArrayRef<int32_t>{}">:$boundaryCheck,
      OptionalAttr<TT_PaddingOptionAttr>:$padding,
      DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
      DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict,
      DefaultValuedAttr<BoolAttr, "false">:$isVolatile
    );

    let results = (outs TT_Type:$result);
...
    let assemblyFormat = [{
      $ptr (`,` $mask^)? (`,` $other^)?
      oilist(
        `cacheModifier` `=` $cache |
        `evictionPolicy` `=` $evict
      )
      attr-dict `:` type($ptr)
    }];

    let hasCanonicalizer = 1;
}

The complete MLIR for loading a is:

      // build the mask
      %163 = tt.expand_dims %34 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc27)
      %164 = tt.splat %162 : i32 -> tensor<1x64xi32> loc(#loc28)
      %165 = arith.cmpi slt, %163, %164 : tensor<1x64xi32> loc(#loc28)

      // compute the offset
      %c64_i32_67 = arith.constant 64 : i32 loc(#loc29)
      %172 = arith.muli %arg9, %c64_i32_67 : i32 loc(#loc29)
      %c1_i32_71 = arith.constant 1 : i32 loc(#loc30)
      %179 = arith.muli %172, %c1_i32_71 : i32 loc(#loc30)
      %180 = tt.splat %179 : i32 -> tensor<128x64xi32> loc(#loc31)
      %181 = tt.addptr %56, %180 : tensor<128x64x!tt.ptr<f32>>, tensor<128x64xi32> loc(#loc31)
      // broadcast the mask and prepare the fallback (other) data
      %182 = tt.broadcast %165 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc32)
      %cst_75 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc32)
      // masked (safe) load
      %183 = tt.load %181, %182, %cst_75 : tensor<128x64x!tt.ptr<f32>> loc(#loc32)
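
Semantically this masked load is a gather plus a select; a rough NumPy analogy (my own illustration, not a Triton API):

import numpy as np

def masked_load(memory, idx, mask, other=0.0):
    # Analogue of tt.load %ptr, %mask, %other: read memory[idx] where mask is true
    # and fill every other element with `other`.
    out = np.full(idx.shape, other, dtype=memory.dtype)
    out[mask] = memory[idx[mask]]
    return out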

3. The remaining pieces

The two loaded blocks are passed to tt.dot, i.e. a matrix multiply, with a zero tensor as the accumulator operand:

      %cst_85 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc39)
      %205 = tt.dot %183, %204, %cst_85, inputPrecision = tf32 : tensor<128x64xf32> * tensor<64x64xf32> -> tensor<128x64xf32> loc(#loc39)

The accumulator is carried through iter_args, so it still has to be yielded. This is the same as before; the only thing that changed in this kernel is that the computation is now done block by block.

    %85 = scf.for %arg9 = %81 to %82 step %83 iter_args(%arg10 = %79) -> (tensor<128x64xf32>)  : i32 {
...
        scf.yield %206 : tensor<128x64xf32> loc(#loc41)
    } loc(#loc24)

The rest of the IR can be read in the same way; you can also hand 01-source.mlir to ChatGPT to interpret. At this point the IR matches the Python DSL semantics exactly, with no optimization applied.

III. make_ttir

This is the stage that simplifies the MLIR generated from the Python AST; the following pipeline is run.

    @staticmethod
    def make_ttir(mod, metadata, opt, capability):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        if capability // 10 < 9:
            passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_combine(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        passes.ttir.add_loop_unroll(pm)
        pm.run(mod)
        return mod

After these passes run, the IR goes from 02-Inliner.mlir through 13-TritonLoopUnroll.mlir.
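
These per-pass dumps can be reproduced locally by turning on Triton's IR dumping before anything is compiled; a minimal sketch (my own setup, not from the original post):

import os
# Must be set before Triton compiles the kernel: dumps the IR around every MLIR pass.
os.environ["MLIR_ENABLE_DUMP"] = "1"

The handle returned by launching a @triton.jit kernel also exposes the final per-stage artifacts through its asm dictionary, e.g. compiled.asm["ttir"], compiled.asm["ttgir"], compiled.asm["ptx"].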

1. Canonicalizer

Canonicalization is a generic pass and you will see it show up again and again. The kinds of things it can do include removing redundant ops, simplifying matched patterns, folding constant computations, and applying the canonicalization patterns defined on each op.

02-Inliner.mlir vs 03-Canonicalizer.mlir: triton.language.standard.zeros and triton.language.standard.cdiv__i32__ are simplified.

03-Canonicalizer.mlir vs 04-Canonicalizer.mlir: triton.language.standard.cdiv__i32__ is simplified further.

04-Canonicalizer.mlir vs 05-Canonicalizer.mlir: a big change; ops not used inside the loop are removed, and triton.language.standard.zeros and triton.language.standard.cdiv__i32__ are inlined.

05-Canonicalizer.mlir vs 06-Canonicalizer.mlir: further simplification.

2. TritonCombineOps

08-Canonicalizer.mlir vs 09-TritonCombineOps.mlir: tt.dot + arith.addf are fused, so the accumulation now happens directly through the accumulator carried in iter_args (the frontend had expanded tl.dot(a, b, acc=accumulator) into a dot with a zero C operand plus an add; this pass folds the add back into the dot).

// old
      %77 = tt.dot %67, %76, %cst_0, inputPrecision = tf32 : tensor<128x64xf32> * tensor<64x64xf32> -> tensor<128x64xf32> loc(#loc38)
      %78 = arith.addf %arg10, %77 : tensor<128x64xf32> loc(#loc39)
      scf.yield %78 : tensor<128x64xf32> loc(#loc40)

// new
      %77 = tt.dot %67, %76, %arg10, inputPrecision = tf32 : tensor<128x64xf32> * tensor<64x64xf32> -> tensor<128x64xf32> loc(#loc38)
      scf.yield %77 : tensor<128x64xf32> loc(#loc39)

3. CSE

10-TritonReorderBroadcast.mlir vs 11-CSE.mlir: Common Subexpression Elimination; for example, there were many duplicate tt.make_range ops.

4. Final result

The final artifact is matrix_multiplication_kernel.ttir, which is the same as 13-TritonLoopUnroll.mlir.

IV. make_ttgir

This stage runs a lot of passes, but the number that actually change our kernel's IR is still manageable.

    @staticmethod
    def make_ttgir(mod, metadata, opt, capability):
        # Set maxnreg on all kernels, if it was provided.
        if opt.maxnreg is not None:
            mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))

        cluster_info = nvidia.ClusterInfo()
        if opt.cluster_dims is not None:
            cluster_info.clusterDimX = opt.cluster_dims[0]
            cluster_info.clusterDimY = opt.cluster_dims[1]
            cluster_info.clusterDimZ = opt.cluster_dims[2]
        pm = ir.pass_manager(mod.context)
        dump_enabled = pm.enable_debug()
        passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
        # optimize TTGIR
        passes.ttgpuir.add_coalesce(pm)
        if capability // 10 >= 8:
            passes.ttgpuir.add_f32_dot_tc(pm)
        # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
        nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        passes.ttgpuir.add_accelerate_matmul(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
        passes.ttir.add_loop_aware_cse(pm)
        if capability // 10 in [8, 9]:
            passes.ttgpuir.add_fuse_nested_loops(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttir.add_triton_licm(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
            passes.ttgpuir.add_schedule_loops(pm)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
        elif capability // 10 >= 10:
            passes.ttgpuir.add_fuse_nested_loops(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttir.add_triton_licm(pm)
            passes.ttgpuir.add_optimize_accumulator_init(pm)
            passes.ttgpuir.add_hoist_tmem_alloc(pm)
            nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
            passes.ttgpuir.add_schedule_loops(pm)
            passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
        else:
            passes.ttir.add_triton_licm(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.ttgpuir.add_prefetch(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        passes.ttgpuir.add_coalesce_async_copy(pm)
        nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        passes.ttgpuir.add_reorder_instructions(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.common.add_symbol_dce(pm)
        if capability // 10 >= 9:
            nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
        nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
        passes.common.add_sccp(pm)
        passes.common.add_canonicalizer(pm)
        pm.run(mod)
        metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
        tensordesc_meta = mod.get_tensordesc_metadata()
        metadata["tensordesc_meta"] = tensordesc_meta
        return mod

After these passes run, the IR goes from 14-ConvertTritonToTritonGPU.mlir through 62-Canonicalizer.mlir.

1. ConvertTritonToTritonGPU

13-TritonLoopUnroll.mlir vs 14-ConvertTritonToTritonGPU.mlir: this mainly attaches layout encodings.

2. TritonGPUCoalesce

14-ConvertTritonToTritonGPU.mlir vs 15-TritonGPUCoalesce.mlir: memory-access coalescing.

3. TritonGPURemoveLayoutConversions

17-TritonGPUPlanCTAPass.mlir vs 18-TritonGPURemoveLayoutConversions.mlir: redundant convert_layout ops are removed; the cost model picks #blocked5.

// old
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>

// new
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>

4. TritonGPUAccelerateMatmul

19-TritonGPUOptimizeThreadLocality.mlir vs 20-TritonGPUAccelerateMatmul.mlir: this is where tt.dot gets rewritten.

// old
      %57 = tt.load %55, %56, %cst_0 : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc28)
...
      %64 = tt.load %62, %63, %cst : tensor<64x64x!tt.ptr<f32>, #blocked> loc(#loc32)
      %65 = ttg.convert_layout %57 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> loc(#loc28)
      %66 = ttg.convert_layout %64 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> loc(#loc32)
      %67 = tt.dot %65, %66, %arg10, inputPrecision = tf32 : tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf32, #blocked1> loc(#loc33)
      scf.yield %67 : tensor<128x64xf32, #blocked1> loc(#loc34)

// new
#blocked2 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
...
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}>
#smem = #ttg.shared_memory
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>

...
      %57 = tt.load %55, %56, %cst_0 : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc28)
      %58 = ttg.local_alloc %57 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #shared, #smem> loc(#loc28)
...
      %65 = tt.load %63, %64, %cst : tensor<64x64x!tt.ptr<f32>, #blocked> loc(#loc32)
      %66 = ttg.local_alloc %65 : (tensor<64x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared1, #smem> loc(#loc32)
      %67 = ttg.convert_layout %arg10 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #blocked2> loc(#loc33)
      %result, %token = ttng.tmem_alloc %67 : (tensor<128x64xf32, #blocked2>) -> (!ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) loc(#loc33)
      %68 = ttng.tc_gen5_mma %58, %66, %result[%token], %true, %true : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc33)
      %result_2, %token_3 = ttng.tmem_load %result[%68] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked2> loc(#loc33)
      %69 = ttg.convert_layout %result_2 : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #blocked1> loc(#loc33)
      scf.yield %69 : tensor<128x64xf32, #blocked1> loc(#loc34)

ttng is the TritonNvidiaGPU dialect. After the rewrite, the loop body allocates shared memory (ttg.local_alloc) for the two input tiles and tensor memory (tmem) for the accumulator, runs the ttng.tc_gen5_mma op, and finally reads the result back out of tensor memory.

5. TritonGPURemoveLayoutConversions

20-TritonGPUAccelerateMatmul.mlir vs 21-TritonGPURemoveLayoutConversions.mlir: layouts are merged further.

// old
#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>

// new
#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>

6. TritonLoopAwareCSE

24-TritonNvidiaGPUOptimizeDescriptorEncodingPass.mlir vs 25-TritonLoopAwareCSE.mlir: duplicated ops such as tt.make_range are merged.

// old
    %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc8)
...
    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc13)
    %16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> loc(#loc13)
// new
    %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc8)
...
    %15 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> loc(#loc13)

7. TritonGPUOptimizeAccumulatorInit

28-TritonLoopInvariantCodeMotion.mlir vs 29-TritonGPUOptimizeAccumulatorInit.mlir: the transform changes the operands of ttng.tc_gen5_mma; the use-accumulator flag becomes a loop-carried value that is false on the first iteration and true afterwards.

// old
    %31 = scf.for %arg9 = %c0_i32 to %30 step %c1_i32 iter_args(%arg10 = %cst_1) -> (tensor<128x64xf32, #blocked1>)  : i32 {
...
      %64 = ttng.tc_gen5_mma %55, %63, %result[%token], %true, %true : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc33)
      %result_2, %token_3 = ttng.tmem_load %result[%64] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1> loc(#loc33)
      scf.yield %result_2 : tensor<128x64xf32, #blocked1> loc(#loc34)
    } loc(#loc23)
...
    %46 = ttg.convert_layout %31 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #blocked> loc(#loc41)

// new
    %31:2 = scf.for %arg9 = %c0_i32 to %30 step %c1_i32 iter_args(%arg10 = %cst_1, %arg11 = %false) -> (tensor<128x64xf32, #blocked1>, i1)  : i32 {
...
      %64 = ttng.tc_gen5_mma %55, %63, %result[%token], %arg11, %true : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc33)
      %result_2, %token_3 = ttng.tmem_load %result[%64] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1> loc(#loc33)
      scf.yield %result_2, %true : tensor<128x64xf32, #blocked1>, i1 loc(#loc34)
    } loc(#loc23)
    %46 = ttg.convert_layout %31#0 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #blocked> loc(#loc41)

8. TritonGPUHoistTMEMAlloc

29-TritonGPUOptimizeAccumulatorInit.mlir vs 30-TritonGPUHoistTMEMAlloc.mlir: the tensor-memory allocation is hoisted out of the loop, avoiding unnecessary per-iteration overhead.

// new
    %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) loc(#loc23)
    %31 = ttng.tmem_store %cst_1, %result[%token], %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc23)
    %32:2 = scf.for %arg9 = %c0_i32 to %30 step %c1_i32 iter_args(%arg10 = %false, %arg11 = %31) -> (i1, !ttg.async.token)  : i32 {
...
      %65 = ttng.tc_gen5_mma %56, %64, %result[%arg11], %arg10, %true : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc23)
      scf.yield %true, %65 : i1, !ttg.async.token loc(#loc34)
    } loc(#loc24)
    %result_2, %token_3 = ttng.tmem_load %result[%32#1] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1> loc(#loc23)

9. TritonGPUAssignLatencies

31-TritonNvidiaGPUPromoteLHSToTMemPass.mlir vs 32-TritonGPUAssignLatencies.mlir: the relevant ops are tagged with latency attributes.

// new
      %55 = tt.load %53, %54, %cst_0 {tt.latency = 2 : i32} : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc29)
...
      %63 = tt.load %61, %62, %cst {tt.latency = 2 : i32} : tensor<64x64x!tt.ptr<f32>, #blocked> loc(#loc33)
      %64 = ttg.local_alloc %63 : (tensor<64x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared1, #smem> loc(#loc33)
      %65 = ttng.tc_gen5_mma %56, %64, %result[%arg11], %arg10, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc23)

10. TritonGPUScheduleLoops

32-TritonGPUAssignLatencies.mlir vs 33-TritonGPUScheduleLoops.mlir: software-pipeline loop scheduling. Software pipelining is my favorite pass; this is where we get to the heart of the GPU, hiding latency. Every op in the loop is tagged with loop.stage and loop.cluster attributes (the loads go to stage 0, the MMA to stage 2); a conceptual sketch of the resulting schedule follows the IR diff below.

// old
    %32:2 = scf.for %arg9 = %c0_i32 to %30 step %c1_i32 iter_args(%arg10 = %false, %arg11 = %31) -> (i1, !ttg.async.token)  : i32 {
      %48 = arith.muli %arg9, %c64_i32 : i32 loc(#loc25)
      %49 = arith.subi %arg4, %48 : i32 loc(#loc26)
      %50 = tt.splat %49 : i32 -> tensor<1x64xi32, #blocked> loc(#loc27)
      %51 = arith.cmpi slt, %15, %50 : tensor<1x64xi32, #blocked> loc(#loc27)
      %52 = tt.splat %48 : i32 -> tensor<128x64xi32, #blocked> loc(#loc28)
      %53 = tt.addptr %18, %52 : tensor<128x64x!tt.ptr<f32>, #blocked>, tensor<128x64xi32, #blocked> loc(#loc28)
      %54 = tt.broadcast %51 : tensor<1x64xi1, #blocked> -> tensor<128x64xi1, #blocked> loc(#loc29)
      %55 = tt.load %53, %54, %cst_0 {tt.latency = 2 : i32} : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc29)
      %56 = ttg.local_alloc %55 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #shared, #smem> loc(#loc29)
      %57 = tt.splat %49 : i32 -> tensor<64x1xi32, #blocked> loc(#loc30)
      %58 = arith.cmpi slt, %20, %57 : tensor<64x1xi32, #blocked> loc(#loc30)
      %59 = arith.muli %48, %arg7 : i32 loc(#loc31)
      %60 = tt.splat %59 : i32 -> tensor<64x64xi32, #blocked> loc(#loc32)
      %61 = tt.addptr %28, %60 : tensor<64x64x!tt.ptr<f32>, #blocked>, tensor<64x64xi32, #blocked> loc(#loc32)
      %62 = tt.broadcast %58 : tensor<64x1xi1, #blocked> -> tensor<64x64xi1, #blocked> loc(#loc33)
      %63 = tt.load %61, %62, %cst {tt.latency = 2 : i32} : tensor<64x64x!tt.ptr<f32>, #blocked> loc(#loc33)
      %64 = ttg.local_alloc %63 : (tensor<64x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared1, #smem> loc(#loc33)
      %65 = ttng.tc_gen5_mma %56, %64, %result[%arg11], %arg10, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc23)
      scf.yield %true, %65 : i1, !ttg.async.token loc(#loc34)
    } loc(#loc24)

// new
    %32:2 = scf.for %arg9 = %c0_i32 to %30 step %c1_i32 iter_args(%arg10 = %false, %arg11 = %31) -> (i1, !ttg.async.token)  : i32 {
      %48 = arith.muli %arg9, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 loc(#loc25)
      %49 = arith.subi %arg4, %48 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 loc(#loc26)
      %50 = tt.splat %49 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 -> tensor<1x64xi32, #blocked> loc(#loc27)
      %51 = arith.cmpi slt, %15, %50 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<1x64xi32, #blocked> loc(#loc27)
      %52 = tt.splat %48 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 -> tensor<128x64xi32, #blocked> loc(#loc28)
      %53 = tt.addptr %18, %52 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x64x!tt.ptr<f32>, #blocked>, tensor<128x64xi32, #blocked> loc(#loc28)
      %54 = tt.broadcast %51 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<1x64xi1, #blocked> -> tensor<128x64xi1, #blocked> loc(#loc29)
      %55 = tt.load %53, %54, %cst_0 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc29)
      %56 = ttg.local_alloc %55 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #shared, #smem> loc(#loc29)
      %57 = tt.splat %49 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 -> tensor<64x1xi32, #blocked> loc(#loc30)
      %58 = arith.cmpi slt, %20, %57 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<64x1xi32, #blocked> loc(#loc30)
      %59 = arith.muli %48, %arg7 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 loc(#loc31)
      %60 = tt.splat %59 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 -> tensor<64x64xi32, #blocked> loc(#loc32)
      %61 = tt.addptr %28, %60 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<64x64x!tt.ptr<f32>, #blocked>, tensor<64x64xi32, #blocked> loc(#loc32)
      %62 = tt.broadcast %58 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<64x1xi1, #blocked> -> tensor<64x64xi1, #blocked> loc(#loc33)
      %63 = tt.load %61, %62, %cst {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<64x64x!tt.ptr<f32>, #blocked> loc(#loc33)
      %64 = ttg.local_alloc %63 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<64x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared1, #smem> loc(#loc33)
      %65 = ttng.tc_gen5_mma %56, %64, %result[%arg11], %arg10, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc23)
      scf.yield %true, %65 : i1, !ttg.async.token loc(#loc34)
    } {tt.scheduled_max_stage = 2 : i32} loc(#loc24)
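
The later Pipeline pass turns these stage annotations into an actual software pipeline: iteration i issues the loads for tile i while the MMA consumes a tile loaded num_stages-1 iterations earlier. A rough Python sketch of that schedule (my own illustration, not generated code):

def software_pipeline(num_tiles, num_stages, issue_load, compute):
    # issue_load(i) starts an asynchronous load of tile i and returns a handle;
    # compute(h) waits on that handle and runs the MMA for the corresponding tile.
    in_flight = []
    for i in range(num_tiles + num_stages - 1):
        if i < num_tiles:
            in_flight.append(issue_load(i))   # keep num_stages-1 loads in flight
        j = i - (num_stages - 1)
        if j >= 0:
            compute(in_flight.pop(0))         # consume the oldest completed tile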

11. SCCP

37-TritonGPURewritePartitionDependencies.mlir vs 38-SCCP.mlir: Sparse Conditional Constant Propagation, which propagates constants through the CFG; here it again mostly shuffles things around.

12. TritonGPUPipeline

42-TritonGPUScheduleLoops.mlir vs 43-TritonGPUPipeline.mlir: GPU pipelining. The software pipeline takes effect here, and we are now really scheduling for the hardware. Below is the pass description:

    Applies software pipelining to loops in the module based on number of stages.
    This may convert some load into asynchronous loads, and multi-buffer the data.

The first #shared encoding changes:

//old
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>

//new
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>

It also adds local_dealloc to cut shared-memory usage, and async_wait to mark the synchronization points of the asynchronous operations.

// old
    %result_2, %token_3 = ttng.tmem_load %result[%32#1] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked> loc(#loc23)

// new
    %111 = arith.cmpi sgt, %35, %c0_i32 : i32 loc(#loc24)
    %112:16 = scf.if %111 -> (i1, !ttg.async.token, i32, i32, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.memdesc<1xi64, #shared, #smem, mutable, 2>, i32, !ttg.memdesc<128x64xf32, #shared2, #smem, mutable, 3x128x64>, !ttg.memdesc<64x64xf32, #shared1, #smem, mutable, 3x64x64>) {
      ttng.wait_barrier %110#12, %110#13 deps %110#14, %110#15 : !ttg.memdesc<1xi64, #shared, #smem, mutable, 2>, !ttg.memdesc<128x64xf32, #shared2, #smem, mutable, 3x128x64>, !ttg.memdesc<64x64xf32, #shared1, #smem, mutable, 3x64x64> loc(#loc23)
      scf.yield %true, %110#1, %4, %4, %4, %4, %110#7, %c3_i32, %110#9, %3, %110#11, %3, %0, %110#2, %2, %1 : i1, !ttg.async.token, i32, i32, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.memdesc<1xi64, #shared, #smem, mutable, 2>, i32, !ttg.memdesc<128x64xf32, #shared2, #smem, mutable, 3x128x64>, !ttg.memdesc<64x64xf32, #shared1, #smem, mutable, 3x64x64> loc(#loc24)
    } else {
      scf.yield %110#0, %110#1, %110#2, %110#3, %110#4, %110#5, %110#6, %110#7, %110#8, %110#9, %110#10, %110#11, %110#12, %110#13, %110#14, %110#15 : i1, !ttg.async.token, i32, i32, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.memdesc<1xi64, #shared, #smem, mutable, 2>, i32, !ttg.memdesc<128x64xf32, #shared2, #smem, mutable, 3x128x64>, !ttg.memdesc<64x64xf32, #shared1, #smem, mutable, 3x64x64> loc(#loc24)
    } loc(#loc24)
    %113 = ttg.async_wait  {num = 0 : i32} loc(#loc24)
    ttg.local_dealloc %41 : !ttg.memdesc<3x64x64xf32, #shared1, #smem, mutable> loc(#loc24)
    ttg.local_dealloc %40 : !ttg.memdesc<3x128x64xf32, #shared2, #smem, mutable> loc(#loc24)
    %114 = ttg.memdesc_subview %37[%c0_i32] : !ttg.memdesc<2xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #smem, mutable, 2> loc(#loc24)
    ttng.inval_barrier %114 : !ttg.memdesc<1xi64, #shared, #smem, mutable, 2> loc(#loc24)
    %115 = ttg.memdesc_subview %37[%c1_i32] : !ttg.memdesc<2xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #smem, mutable, 2> loc(#loc24)
    ttng.inval_barrier %115 : !ttg.memdesc<1xi64, #shared, #smem, mutable, 2> loc(#loc24)
    ttg.local_dealloc %37 : !ttg.memdesc<2xi64, #shared, #smem, mutable> loc(#loc24)

The loop and the code leading up to it also change drastically; have a look yourself if you are interested.

13. TritonNvidiaGPURemoveTMEMTokensPass

44-TritonGPUCombineTensorSelectAndIf.mlir vs 45-TritonNvidiaGPURemoveTMEMTokensPass.mlir: an extra ub.poison appears; the next Canonicalizer pass removes it together with the five earlier ub.poison ops.

14. TritonLoopAwareCSE

46-Canonicalizer.mlir vs 47-TritonLoopAwareCSE.mlir: a ttg.memdesc_subview is eliminated.

15. TritonNvidiaGPUInterleaveTMemPass

53-TritonGPURemoveLayoutConversions.mlir vs 54-TritonNvidiaGPUInterleaveTMemPass.mlir: the position of ttng.tmem_load is adjusted.

16. SCCP

60-TritonGPUFenceInsertion.mlir vs 61-SCCP.mlir: the ordering is adjusted once more.

17. Final artifact

The final artifact is matrix_multiplication_kernel.ttgir, which is the same as 62-Canonicalizer.mlir.

V. make_llir

Now we go from ttgir to LLVM IR.

    def make_llir(self, src, metadata, options, capability):
        ptx_version = get_ptx_version_from_options(options, self.target.arch)

        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()

        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        passes.ttgpuir.add_allocate_warp_groups(pm)
        passes.convert.add_scf_to_cf(pm)
        passes.ttgpuir.add_allocate_shared_memory(pm)
        nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
        if knobs.compilation.enable_experimental_consan:
            # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
            passes.ttgpuir.add_concurrency_sanitizer(pm)
        passes.ttgpuir.add_allocate_global_scratch_memory(pm)
        nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
        nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
        nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        if not knobs.compilation.disable_line_info:
            passes.llvmir.add_di_scope(pm)
        pm.run(mod)
        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        if knobs.compilation.enable_asan:
            raise RuntimeError(
                "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
        llvm_mod = llvm.to_module(mod, context)
        proc = sm_arch_from_capability(capability)
        features = get_features(options, self.target.arch)
        triple = 'nvptx64-nvidia-cuda'
        nvidia.set_short_ptr()
        llvm.attach_datalayout(llvm_mod, triple, proc, features)
        nvidia.set_nvvm_reflect_ftz(llvm_mod)

        if options.extern_libs:
            paths = [path for (name, path) in options.extern_libs]
            llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)

        # Get some metadata
        # warp-specialization mutates num_warps
        total_num_warps = src.get_int_attr("ttg.total-num-warps")
        if total_num_warps is not None:
            metadata["num_warps"] = total_num_warps
        metadata["shared"] = src.get_int_attr("ttg.shared")
        metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
        metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
        metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
        ret = str(llvm_mod)
        del llvm_mod
        del context
        return ret

After these passes run, the IR goes from 63-TritonNvidiaGPUMMALoweringPass.mlir through 80-LLVMDIScope.mlir.

1. TritonGPUAllocateWarpGroups

64-TritonGPUCombineTensorSelectAndIf.mlir vs 65-TritonGPUAllocateWarpGroups.mlir: a ttg.total-num-warps attribute is added to the module.

2. SCFToControlFlowPass

65-TritonGPUAllocateWarpGroups.mlir vs 66-SCFToControlFlowPass.mlir: scf.for is lowered to cf.br.

3. AllocateSharedMemory

66-SCFToControlFlowPass.mlir vs 67-AllocateSharedMemory.mlir: the module also gets an attribute describing the shared-memory size, ttg.shared = 147472 : i32.
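
That figure is consistent with the multi-buffered tiles visible after pipelining: a 3-deep buffer of 128x64 f32 A tiles (3*128*64*4 = 98304 bytes), a 3-deep buffer of 64x64 f32 B tiles (3*64*64*4 = 49152 bytes), and 16 bytes of 2xi64 mbarrier storage, giving 98304 + 49152 + 16 = 147472.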

4. TritonTensorMemoryAllocationPass

67-AllocateSharedMemory.mlir vs 68-TritonTensorMemoryAllocationPass.mlir: ttg.tensor_memory_size = 64 : i32 is added to the module.

5. TritonGPUGlobalScratchAllocationPass

68-TritonTensorMemoryAllocationPass.mlir vs 69-TritonGPUGlobalScratchAllocationPass.mlir: the global_scratch-related attributes are added.

6. ConvertTritonGPUToLLVM

70-TritonGPUProxyFenceInsertion.mlir vs 71-ConvertTritonGPUToLLVM.mlir: this pass is very complex; you have to read each op's RewritePattern. This step lowers everything down to the per-thread level.

7. Canonicalizer

71-ConvertTritonGPUToLLVM.mlir vs 72-Canonicalizer.mlir: the canonicalization pass is ferociously effective here; the IR shrinks from 24154 lines to 8648.

8. CSE

72-Canonicalizer.mlir vs 73-CSE.mlir: common subexpression elimination takes the IR from 8648 lines down to 2956.

9. ConvertNVGPUToLLVM

73-CSE.mlir vs 74-ConvertNVGPUToLLVM.mlir: intrinsics such as nvvm.read.ptx.sreg.tid.x now appear in the code.

10. Canonicalizer

76-ReconcileUnrealizedCastsPass.mlir vs 77-Canonicalizer.mlir: some reordering, and a few lines fewer.

11. CSE

77-Canonicalizer.mlir vs 78-CSE.mlir: common subexpression elimination.

12. LLVMDIScope

79-SymbolDCE vs 80-LLVMDIScope: a pass that attaches debug-info scopes to the LLVM IR, generating debug information (such as DWARF) to support source-level debugging.

13. Final artifact

The final artifact is matrix_multiplication_kernel.llir.

VI. make_ptx

This stage is essentially a call into LLVM.

    def make_ptx(self, src, metadata, opt, capability):
        ptx_version = get_ptx_version_from_options(opt, self.target.arch)

        triple = 'nvptx64-nvidia-cuda'
        proc = sm_arch_from_capability(capability)
        features = get_features(opt, self.target.arch)
        ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
        # Find kernel names (there should only be one)
        names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
        assert len(names) == 1
        metadata["name"] = names[0]
        # post-process
        ptx_version = f'{ptx_version//10}.{ptx_version%10}'
        ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
        ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
        # Remove the debug flag that prevents ptxas from optimizing the code
        ret = re.sub(r",\s*debug|debug,\s*", "", ret)
        if knobs.nvidia.dump_nvptx:
            print("// -----// NVPTX Dump //----- //")
            print(ret)
        return ret

The artifact is matrix_multiplication_kernel.ptx. The input file is 1683 lines and the output 2166 lines, with a roughly one-to-one correspondence.

1. dot, side by side

// old
  %494 = icmp eq i32 %178, 0, !dbg !28
  %495 = and i1 %186, %494, !dbg !28
  br i1 %495, label %496, label %563, !dbg !28

496:                                              ; preds = %10
  %497 = tail call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1), !dbg !28
  %498 = extractvalue { i32, i1 } %497, 1, !dbg !28
  %499 = lshr exact i32 ptrtoint (ptr addrspace(3) @global_smem to i32), 4, !dbg !28
  %500 = and i32 %499, 16383, !dbg !28
  %501 = zext nneg i32 %500 to i64, !dbg !28
  %502 = or disjoint i64 %501, 4611686293372403712, !dbg !28
  %503 = lshr exact i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), 4, !dbg !28
  %504 = and i32 %503, 16383, !dbg !28
  %505 = zext nneg i32 %504 to i64, !dbg !28
  %506 = or disjoint i64 %505, 4611686293338849280, !dbg !28
  tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %13, i64 %502, i64 %506, i32 135268624, i1 false, i1 %498) #4, !dbg !28
...
!28 = !DILocation(line: 38, column: 33, scope: !5)

// new
	.loc	1 38 33                         // matmul.py:38:33
	setp.ne.s32 	%p27, %r32, 0;
	or.pred 	%p28, %p6, %p27;
	@%p28 bra 	$L__BB0_2;
// %bb.1:
	elect.sync 	%r530|%p30, -1;
	bfe.u32 	%r532, %r103, 4, 14;
	cvt.u64.u32 	%rd197, %r532;
	or.b64 	%rd180, %rd197, 4611686293372403712;
	bfe.u32 	%r534, %r464, 4, 14;
	cvt.u64.u32 	%rd198, %r534;
	or.b64 	%rd181, %rd198, 4611686293338849280;
	mov.b32 	%r515, 135268624;
	mov.pred 	%p29, 0;
	// begin inline asm
	@%p30 tcgen05.mma.cta_group::1.kind::tf32 [ %r1077 + 0 ], %rd180, %rd181, %r515, %p29;
	// end inline asm

2. store, side by side

// old
...
  %1222 = getelementptr inbounds nuw i8, ptr addrspace(3) %1192, i32 1920, !dbg !44
  %1223 = load <4 x i32>, ptr addrspace(3) %1222, align 16, !dbg !44
  %.extract = extractelement <4 x i32> %1193, i64 0, !dbg !44
  %.extract64 = extractelement <4 x i32> %1193, i64 1, !dbg !44
  %.extract65 = extractelement <4 x i32> %1193, i64 2, !dbg !44
  %.extract66 = extractelement <4 x i32> %1193, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract, i32 %.extract64, i32 %.extract65, i32 %.extract66, ptr addrspace(1) %981, i1 %1014) #4, !dbg !44
  %.extract67 = extractelement <4 x i32> %1195, i64 0, !dbg !44
  %.extract68 = extractelement <4 x i32> %1195, i64 1, !dbg !44
  %.extract69 = extractelement <4 x i32> %1195, i64 2, !dbg !44
  %.extract70 = extractelement <4 x i32> %1195, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract67, i32 %.extract68, i32 %.extract69, i32 %.extract70, ptr addrspace(1) %982, i1 %1015) #4, !dbg !44
  %.extract71 = extractelement <4 x i32> %1197, i64 0, !dbg !44
  %.extract72 = extractelement <4 x i32> %1197, i64 1, !dbg !44
  %.extract73 = extractelement <4 x i32> %1197, i64 2, !dbg !44
  %.extract74 = extractelement <4 x i32> %1197, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract71, i32 %.extract72, i32 %.extract73, i32 %.extract74, ptr addrspace(1) %983, i1 %1016) #4, !dbg !44
  %.extract75 = extractelement <4 x i32> %1199, i64 0, !dbg !44
  %.extract76 = extractelement <4 x i32> %1199, i64 1, !dbg !44
  %.extract77 = extractelement <4 x i32> %1199, i64 2, !dbg !44
  %.extract78 = extractelement <4 x i32> %1199, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract75, i32 %.extract76, i32 %.extract77, i32 %.extract78, ptr addrspace(1) %984, i1 %1017) #4, !dbg !44
  %.extract79 = extractelement <4 x i32> %1201, i64 0, !dbg !44
  %.extract80 = extractelement <4 x i32> %1201, i64 1, !dbg !44
  %.extract81 = extractelement <4 x i32> %1201, i64 2, !dbg !44
  %.extract82 = extractelement <4 x i32> %1201, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract79, i32 %.extract80, i32 %.extract81, i32 %.extract82, ptr addrspace(1) %985, i1 %1018) #4, !dbg !44
  %.extract83 = extractelement <4 x i32> %1203, i64 0, !dbg !44
  %.extract84 = extractelement <4 x i32> %1203, i64 1, !dbg !44
  %.extract85 = extractelement <4 x i32> %1203, i64 2, !dbg !44
  %.extract86 = extractelement <4 x i32> %1203, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract83, i32 %.extract84, i32 %.extract85, i32 %.extract86, ptr addrspace(1) %986, i1 %1019) #4, !dbg !44
  %.extract87 = extractelement <4 x i32> %1205, i64 0, !dbg !44
  %.extract88 = extractelement <4 x i32> %1205, i64 1, !dbg !44
  %.extract89 = extractelement <4 x i32> %1205, i64 2, !dbg !44
  %.extract90 = extractelement <4 x i32> %1205, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract87, i32 %.extract88, i32 %.extract89, i32 %.extract90, ptr addrspace(1) %987, i1 %1020) #4, !dbg !44
  %.extract91 = extractelement <4 x i32> %1207, i64 0, !dbg !44
  %.extract92 = extractelement <4 x i32> %1207, i64 1, !dbg !44
  %.extract93 = extractelement <4 x i32> %1207, i64 2, !dbg !44
  %.extract94 = extractelement <4 x i32> %1207, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract91, i32 %.extract92, i32 %.extract93, i32 %.extract94, ptr addrspace(1) %988, i1 %1021) #4, !dbg !44
  %.extract95 = extractelement <4 x i32> %1209, i64 0, !dbg !44
  %.extract96 = extractelement <4 x i32> %1209, i64 1, !dbg !44
  %.extract97 = extractelement <4 x i32> %1209, i64 2, !dbg !44
  %.extract98 = extractelement <4 x i32> %1209, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract95, i32 %.extract96, i32 %.extract97, i32 %.extract98, ptr addrspace(1) %989, i1 %1022) #4, !dbg !44
  %.extract99 = extractelement <4 x i32> %1211, i64 0, !dbg !44
  %.extract100 = extractelement <4 x i32> %1211, i64 1, !dbg !44
  %.extract101 = extractelement <4 x i32> %1211, i64 2, !dbg !44
  %.extract102 = extractelement <4 x i32> %1211, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract99, i32 %.extract100, i32 %.extract101, i32 %.extract102, ptr addrspace(1) %990, i1 %1023) #4, !dbg !44
  %.extract103 = extractelement <4 x i32> %1213, i64 0, !dbg !44
  %.extract104 = extractelement <4 x i32> %1213, i64 1, !dbg !44
  %.extract105 = extractelement <4 x i32> %1213, i64 2, !dbg !44
  %.extract106 = extractelement <4 x i32> %1213, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract103, i32 %.extract104, i32 %.extract105, i32 %.extract106, ptr addrspace(1) %991, i1 %1024) #4, !dbg !44
  %.extract107 = extractelement <4 x i32> %1215, i64 0, !dbg !44
  %.extract108 = extractelement <4 x i32> %1215, i64 1, !dbg !44
  %.extract109 = extractelement <4 x i32> %1215, i64 2, !dbg !44
  %.extract110 = extractelement <4 x i32> %1215, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract107, i32 %.extract108, i32 %.extract109, i32 %.extract110, ptr addrspace(1) %992, i1 %1025) #4, !dbg !44
  %.extract111 = extractelement <4 x i32> %1217, i64 0, !dbg !44
  %.extract112 = extractelement <4 x i32> %1217, i64 1, !dbg !44
  %.extract113 = extractelement <4 x i32> %1217, i64 2, !dbg !44
  %.extract114 = extractelement <4 x i32> %1217, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract111, i32 %.extract112, i32 %.extract113, i32 %.extract114, ptr addrspace(1) %993, i1 %1026) #4, !dbg !44
  %.extract115 = extractelement <4 x i32> %1219, i64 0, !dbg !44
  %.extract116 = extractelement <4 x i32> %1219, i64 1, !dbg !44
  %.extract117 = extractelement <4 x i32> %1219, i64 2, !dbg !44
  %.extract118 = extractelement <4 x i32> %1219, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract115, i32 %.extract116, i32 %.extract117, i32 %.extract118, ptr addrspace(1) %994, i1 %1027) #4, !dbg !44
  %.extract119 = extractelement <4 x i32> %1221, i64 0, !dbg !44
  %.extract120 = extractelement <4 x i32> %1221, i64 1, !dbg !44
  %.extract121 = extractelement <4 x i32> %1221, i64 2, !dbg !44
  %.extract122 = extractelement <4 x i32> %1221, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract119, i32 %.extract120, i32 %.extract121, i32 %.extract122, ptr addrspace(1) %995, i1 %1028) #4, !dbg !44
  %.extract123 = extractelement <4 x i32> %1223, i64 0, !dbg !44
  %.extract124 = extractelement <4 x i32> %1223, i64 1, !dbg !44
  %.extract125 = extractelement <4 x i32> %1223, i64 2, !dbg !44
  %.extract126 = extractelement <4 x i32> %1223, i64 3, !dbg !44
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract123, i32 %.extract124, i32 %.extract125, i32 %.extract126, ptr addrspace(1) %996, i1 %1029) #4, !dbg !44
...
!44 = !DILocation(line: 45, column: 21, scope: !5)

// new
	.loc	1 45 21                         // matmul.py:45:21
...
	ld.shared.v4.b32 	{%r1073, %r1074, %r1075, %r1076}, [%r1123+1920];
	// begin inline asm
	@%p92 st.global.v4.b32 [ %rd346 + 0 ], { %r1013, %r1014, %r1015, %r1016 };
	// end inline asm
	// begin inline asm
	@%p93 st.global.v4.b32 [ %rd347 + 0 ], { %r1017, %r1018, %r1019, %r1020 };
	// end inline asm
	// begin inline asm
	@%p94 st.global.v4.b32 [ %rd348 + 0 ], { %r1021, %r1022, %r1023, %r1024 };
	// end inline asm
	// begin inline asm
	@%p95 st.global.v4.b32 [ %rd349 + 0 ], { %r1025, %r1026, %r1027, %r1028 };
	// end inline asm
	// begin inline asm
	@%p96 st.global.v4.b32 [ %rd350 + 0 ], { %r1029, %r1030, %r1031, %r1032 };
	// end inline asm
	// begin inline asm
	@%p97 st.global.v4.b32 [ %rd351 + 0 ], { %r1033, %r1034, %r1035, %r1036 };
	// end inline asm
	// begin inline asm
	@%p98 st.global.v4.b32 [ %rd352 + 0 ], { %r1037, %r1038, %r1039, %r1040 };
	// end inline asm
	// begin inline asm
	@%p99 st.global.v4.b32 [ %rd353 + 0 ], { %r1041, %r1042, %r1043, %r1044 };
	// end inline asm
	// begin inline asm
	@%p100 st.global.v4.b32 [ %rd354 + 0 ], { %r1045, %r1046, %r1047, %r1048 };
	// end inline asm
	// begin inline asm
	@%p101 st.global.v4.b32 [ %rd355 + 0 ], { %r1049, %r1050, %r1051, %r1052 };
	// end inline asm
	// begin inline asm
	@%p102 st.global.v4.b32 [ %rd356 + 0 ], { %r1053, %r1054, %r1055, %r1056 };
	// end inline asm
	// begin inline asm
	@%p103 st.global.v4.b32 [ %rd357 + 0 ], { %r1057, %r1058, %r1059, %r1060 };
	// end inline asm
	// begin inline asm
	@%p104 st.global.v4.b32 [ %rd358 + 0 ], { %r1061, %r1062, %r1063, %r1064 };
	// end inline asm
	// begin inline asm
	@%p105 st.global.v4.b32 [ %rd359 + 0 ], { %r1065, %r1066, %r1067, %r1068 };
	// end inline asm
	// begin inline asm
	@%p106 st.global.v4.b32 [ %rd360 + 0 ], { %r1069, %r1070, %r1071, %r1072 };
	// end inline asm
	// begin inline asm
	@%p107 st.global.v4.b32 [ %rd361 + 0 ], { %r1073, %r1074, %r1075, %r1076 };
	// end inline asm

3. load, side by side

// old
  %189 = and i32 %11, 1, !dbg !31
  %190 = icmp eq i32 %189, 0, !dbg !31
  %191 = and i32 %44, 28, !dbg !31
  %192 = shl nuw nsw i32 %11, 9, !dbg !31
  %193 = and i32 %192, 4096, !dbg !31
  %194 = or disjoint i32 %191, %193, !dbg !31
  %195 = and i32 %11, 16, !dbg !31
  %.not = icmp eq i32 %195, 0, !dbg !31
  %196 = select i1 %.not, i32 0, i32 36, !dbg !31
  %197 = and i32 %11, 32, !dbg !31
  %198 = icmp eq i32 %197, 0, !dbg !31
  %199 = select i1 %198, i32 0, i32 72, !dbg !31
  %200 = and i32 %11, 64, !dbg !31
  %201 = icmp eq i32 %200, 0, !dbg !31
  %202 = select i1 %201, i32 0, i32 144, !dbg !31
  %203 = or disjoint i32 %199, %196, !dbg !31
  %204 = xor i32 %203, %194, !dbg !31
  %205 = xor i32 %204, %202, !dbg !31
  %206 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %205, !dbg !31
  %207 = or disjoint i32 %194, 256, !dbg !31
  %208 = or disjoint i32 %203, %202, !dbg !31
  %209 = xor i32 %208, %207, !dbg !31
  %210 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %209, !dbg !31
  %211 = or disjoint i32 %194, 512, !dbg !31
  %212 = xor i32 %208, %211, !dbg !31
  %213 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %212, !dbg !31
  %214 = or disjoint i32 %194, 768, !dbg !31
  %215 = xor i32 %208, %214, !dbg !31
  %216 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %215, !dbg !31
  %217 = or disjoint i32 %194, 1024, !dbg !31
  %218 = xor i32 %208, %217, !dbg !31
  %219 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %218, !dbg !31
  %220 = or disjoint i32 %194, 1280, !dbg !31
  %221 = xor i32 %208, %220, !dbg !31
  %222 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %221, !dbg !31
  %223 = or disjoint i32 %194, 1536, !dbg !31
  %224 = xor i32 %208, %223, !dbg !31
  %225 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %224, !dbg !31
  %226 = or disjoint i32 %194, 1792, !dbg !31
  %227 = xor i32 %208, %226, !dbg !31
  %228 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %227, !dbg !31
  %229 = or disjoint i32 %194, 2048, !dbg !31
  %230 = xor i32 %208, %229, !dbg !31
  %231 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %230, !dbg !31
  %232 = or disjoint i32 %194, 2304, !dbg !31
  %233 = xor i32 %208, %232, !dbg !31
  %234 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %233, !dbg !31
  %235 = or disjoint i32 %194, 2560, !dbg !31
  %236 = xor i32 %208, %235, !dbg !31
  %237 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %236, !dbg !31
  %238 = or disjoint i32 %194, 2816, !dbg !31
  %239 = xor i32 %208, %238, !dbg !31
  %240 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %239, !dbg !31
  %241 = or disjoint i32 %194, 3072, !dbg !31
  %242 = xor i32 %208, %241, !dbg !31
  %243 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %242, !dbg !31
  %244 = or disjoint i32 %194, 3328, !dbg !31
  %245 = xor i32 %208, %244, !dbg !31
  %246 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %245, !dbg !31
  %247 = or disjoint i32 %194, 3584, !dbg !31
  %248 = xor i32 %208, %247, !dbg !31
  %249 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %248, !dbg !31
  %250 = or disjoint i32 %194, 3840, !dbg !31
  %251 = xor i32 %208, %250, !dbg !31
  %252 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %251, !dbg !31
  %253 = select i1 %188, i32 16, i32 0, !dbg !31
...
  %838 = select i1 %821, i32 16, i32 0, !dbg !31
  tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %822, ptr addrspace(1) %804, i32 %838) #4, !dbg !31
  tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %823, ptr addrspace(1) %805, i32 %838) #4, !dbg !31
  tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %824, ptr addrspace(1) %806, i32 %838) #4, !dbg !31
  tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %825, ptr addrspace(1) %807, i32 %838) #4, !dbg !31
...
!31 = !DILocation(line: 35, column: 20, scope: !5)

// new
	.loc	1 35 20                         // matmul.py:35:20
	and.b32 	%r406, %r1, 1;
	neg.s32 	%r407, %r406;
	shl.b32 	%r408, %r1, 9;
	or.b32 	%r409, %r368, %r408;
	and.b32 	%r410, %r409, 4124;
	bfe.s32 	%r411, %r1, 4, 1;
	and.b32 	%r412, %r411, 36;
	and.b32 	%r413, %r1, 32;
	bfe.s32 	%r414, %r1, 5, 1;
	and.b32 	%r415, %r414, 72;
	and.b32 	%r416, %r1, 64;
	bfe.s32 	%r417, %r1, 6, 1;
	and.b32 	%r418, %r417, 144;
	or.b32 	%r419, %r415, %r412;
	xor.b32 	%r420, %r419, %r410;
	xor.b32 	%r34, %r420, %r418;
	shl.b32 	%r421, %r34, 2;
	add.s32 	%r171, %r103, %r421;
	or.b32 	%r422, %r410, 256;
	or.b32 	%r423, %r419, %r418;
	xor.b32 	%r35, %r423, %r422;
	shl.b32 	%r424, %r35, 2;
	add.s32 	%r173, %r103, %r424;
	or.b32 	%r425, %r410, 512;
	xor.b32 	%r36, %r423, %r425;
	shl.b32 	%r426, %r36, 2;
	add.s32 	%r175, %r103, %r426;
	or.b32 	%r427, %r410, 768;
	xor.b32 	%r37, %r423, %r427;
	shl.b32 	%r428, %r37, 2;
	add.s32 	%r177, %r103, %r428;
	or.b32 	%r429, %r410, 1024;
	xor.b32 	%r38, %r423, %r429;
	shl.b32 	%r430, %r38, 2;
	add.s32 	%r179, %r103, %r430;
	or.b32 	%r431, %r410, 1280;
	xor.b32 	%r39, %r423, %r431;
	shl.b32 	%r432, %r39, 2;
	add.s32 	%r181, %r103, %r432;
	or.b32 	%r433, %r410, 1536;
	xor.b32 	%r40, %r423, %r433;
	shl.b32 	%r434, %r40, 2;
	add.s32 	%r183, %r103, %r434;
	or.b32 	%r435, %r410, 1792;
	xor.b32 	%r41, %r423, %r435;
	shl.b32 	%r436, %r41, 2;
	add.s32 	%r185, %r103, %r436;
	or.b32 	%r437, %r410, 2048;
	xor.b32 	%r42, %r423, %r437;
	shl.b32 	%r438, %r42, 2;
	add.s32 	%r187, %r103, %r438;
	or.b32 	%r439, %r410, 2304;
	xor.b32 	%r43, %r423, %r439;
	shl.b32 	%r440, %r43, 2;
	add.s32 	%r189, %r103, %r440;
	or.b32 	%r441, %r410, 2560;
	xor.b32 	%r44, %r423, %r441;
	shl.b32 	%r442, %r44, 2;
	add.s32 	%r191, %r103, %r442;
	or.b32 	%r443, %r410, 2816;
	xor.b32 	%r45, %r423, %r443;
	shl.b32 	%r444, %r45, 2;
	add.s32 	%r193, %r103, %r444;
	or.b32 	%r445, %r410, 3072;
	xor.b32 	%r46, %r423, %r445;
	shl.b32 	%r446, %r46, 2;
	add.s32 	%r195, %r103, %r446;
	or.b32 	%r447, %r410, 3328;
	xor.b32 	%r47, %r423, %r447;
	shl.b32 	%r448, %r47, 2;
	add.s32 	%r197, %r103, %r448;
	or.b32 	%r449, %r410, 3584;
	xor.b32 	%r48, %r423, %r449;
	shl.b32 	%r450, %r48, 2;
	add.s32 	%r199, %r103, %r450;
	or.b32 	%r451, %r410, 3840;
	xor.b32 	%r49, %r423, %r451;
	shl.b32 	%r452, %r49, 2;
	add.s32 	%r201, %r103, %r452;
	selp.b32 	%r453, 16, 0, %p7;
	selp.b32 	%r174, %r453, 0, %p8;
...
	selp.b32 	%r453, 16, 0, %p7;
	selp.b32 	%r174, %r453, 0, %p8;
	// begin inline asm
	cp.async.cg.shared.global [ %r171 + 0 ], [ %rd31 + 0 ], 0x10, %r174;
	// end inline asm
	// begin inline asm
	cp.async.cg.shared.global [ %r173 + 0 ], [ %rd32 + 0 ], 0x10, %r174;
	// end inline asm
	// begin inline asm
	cp.async.cg.shared.global [ %r175 + 0 ], [ %rd33 + 0 ], 0x10, %r174;
	// end inline asm
	// begin inline asm
	cp.async.cg.shared.global [ %r177 + 0 ], [ %rd34 + 0 ], 0x10, %r174;
	// end inline asm
...

VII. Articles in this series

Deep Dive into the Triton Compiler's MatMul Optimization (Part 1)

A Brief Look at the Triton Execution Flow
