A Deep Dive into the Triton Compiler's MatMul Optimizations (3): TMA

In A Deep Dive into the Triton Compiler's MatMul Optimizations (2): MMA we looked at tl.dot, which delivers good performance almost for free and generates the tcgen05.mma.cta_group::1.kind::tf32 and cp.async.cg.shared.global instructions. This time we look at TMA, which generates the cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes and ldmatrix.sync.aligned.m8n8.x4.shared.b16 instructions.

This post uses Triton at commit bc75dd0 (Jun 27, 2025). All IR and kernel files have been uploaded to GitHub: sBobHuang/Triton-blog-file. Related posts in this series:

A Deep Dive into the Triton Compiler's MatMul Optimizations (2): MMA

A Deep Dive into the Triton Compiler's MatMul Optimizations (1): FMA

1. matmul Triton kernel

1.1 Writing the kernel

The Triton kernel is shown below. Matrix a is M×N, matrix b is N×K, and the result matrix c is M×K; note that under this naming the reduction dimension is N rather than the usual K. The complete runnable code is in matmul-with-tma-v2.py.

# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl

@triton.jit
def matmul_kernel_make_tensor_desciptor(a_ptr, b_ptr, c_ptr,  #
                                        M, N, K,  #
                                        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                                        BLOCK_SIZE_K: tl.constexpr,  #
                                        ):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
    pid_m = tl.program_id(axis=0)
    pid_k = tl.program_id(axis=1)

    a_desc = tl.make_tensor_descriptor(
        a_ptr,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
    )
    b_desc = tl.make_tensor_descriptor(
        b_ptr,
        shape=[N, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
    )
    c_desc = tl.make_tensor_descriptor(
        c_ptr,
        shape=[M, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

    for n in range(tl.cdiv(N, BLOCK_SIZE_N)):
        a = a_desc.load([pid_m * BLOCK_SIZE_M, n * BLOCK_SIZE_N])
        b = b_desc.load([n * BLOCK_SIZE_N, pid_k * BLOCK_SIZE_K])
        accumulator = tl.dot(a, b, acc=accumulator)

    accumulator = accumulator.to(a_desc.dtype)
    c_desc.store([pid_m * BLOCK_SIZE_M, pid_k * BLOCK_SIZE_K], accumulator)

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(K, META['BLOCK_SIZE_K']), )
    # Leading dimensions must be multiples of 16-byte strides
    # if M % 4 == 0 and N % 4 == 0 and K % 4 == 0:

    import torch
    # TMA descriptors require a global memory allocation
    def alloc_fn(size, alignment, stream):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)
    matmul_kernel_make_tensor_desciptor[grid](
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_K=64,
        BLOCK_SIZE_N=64,
    )
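
For reference, a minimal host-side driver could look like the sketch below (my own example, not part of the original file; the sizes are chosen as multiples of the block sizes so no padding is needed, and the tolerance is loose because tl.dot runs in tf32 here):

import torch

M, N, K = 1024, 512, 256  # multiples of 128/64/64, so no padding is needed
a = torch.randn(M, N, device="cuda", dtype=torch.float32)
b = torch.randn(N, K, device="cuda", dtype=torch.float32)
c = torch.empty(M, K, device="cuda", dtype=torch.float32)

# solve takes raw device pointers, as the comment above notes
solve(a.data_ptr(), b.data_ptr(), c.data_ptr(), M, N, K)

torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)  # tf32-level tolerance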

tl.make_tensor_descriptor is a 3.4.0 API; on 3.3.1, use tl._experimental_make_tensor_descriptor instead.
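
If you want one file to run on both versions, a hedged workaround (assuming the experimental builtin behaves identically, which I have not verified) is to alias it at import time:

import triton.language as tl

# assumption: 3.3.1's experimental builtin is otherwise identical to the 3.4.0 API
if not hasattr(tl, "make_tensor_descriptor"):
    tl.make_tensor_descriptor = tl._experimental_make_tensor_descriptor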

1.2 A quick look at the kernel

This kernel is quite direct: the descriptor for a fixes a [BLOCK_SIZE_M, BLOCK_SIZE_N] block shape, and each iteration loads the block at offset [pid_m * BLOCK_SIZE_M, n * BLOCK_SIZE_N]; b likewise loads the block at [n * BLOCK_SIZE_N, pid_k * BLOCK_SIZE_K] and feeds the dot; at the end the result is stored directly. Notice that there are no masks anywhere, so tl.make_tensor_descriptor comes with restrictions: the inputs must be padded by hand, and per "Leading dimensions must be multiples of 16-byte strides" the leading dimension must be 16-byte aligned (the minimum transfer size). There used to be tl.make_block_ptr; it works too, but its store still needs a per-row mask, see matmul-with-block.py.
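Since there is no masking, a host-side padding sketch (my own helper, assuming row-major contiguous torch tensors) might look like this:

import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, rows_to: int, cols_to: int) -> torch.Tensor:
    """Zero-pad a 2-D tensor so each dimension is a multiple of its block size."""
    pad_rows = (rows_to - x.shape[0] % rows_to) % rows_to
    pad_cols = (cols_to - x.shape[1] % cols_to) % cols_to
    return F.pad(x, (0, pad_cols, 0, pad_rows))  # (left, right, top, bottom)

a_padded = pad_to_multiple(a, 128, 64)  # BLOCK_SIZE_M, BLOCK_SIZE_N
# the leading dimension must also satisfy the 16-byte stride rule:
assert a_padded.stride(0) * a_padded.element_size() % 16 == 0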

2. ast_to_ttir

This stage walks the Python AST via the JIT decorator and ends up calling MLIR's self.create<...> builders.

2.1 The loop IR

The resulting ttir is fairly redundant; the full IR is in 01-source.mlir. Let's look at the loop, excerpted below.

    %15 = scf.for %arg6 = %11 to %12 step %13 iter_args(%arg7 = %8) -> (tensor<128x64xf32>)  : i32 {
...
      %36 = arith.muli %0, %c128_i32_7 : i32 loc(#loc9)
      %c64_i32_10 = arith.constant 64 : i32 loc(#loc10)
      %c64_i32_11 = arith.constant 64 : i32 loc(#loc10)
...
      %43 = arith.muli %arg6, %c64_i32_11 : i32 loc(#loc10)
      %44 = tt.descriptor_load %3[%36, %43] : !tt.tensordesc<tensor<128x64xf32>> -> tensor<128x64xf32> loc(#loc11)
      %c64_i32_14 = arith.constant 64 : i32 loc(#loc12)
      %c64_i32_15 = arith.constant 64 : i32 loc(#loc12)
...
      %51 = arith.muli %arg6, %c64_i32_15 : i32 loc(#loc12)
      %c64_i32_18 = arith.constant 64 : i32 loc(#loc13)
      %c64_i32_19 = arith.constant 64 : i32 loc(#loc13)
...
      %58 = arith.muli %1, %c64_i32_19 : i32 loc(#loc13)
      %59 = tt.descriptor_load %5[%51, %58] : !tt.tensordesc<tensor<64x64xf32>> -> tensor<64x64xf32> loc(#loc14)
      %cst = arith.constant 0.000000e+00 : f32 loc(#loc15)
      %60 = tt.dot %44, %59, %arg7, inputPrecision = tf32 : tensor<128x64xf32> * tensor<64x64xf32> -> tensor<128x64xf32> loc(#loc15)
      scf.yield %60 : tensor<128x64xf32> loc(#loc16)
    } loc(#loc8)

2.2 The descriptor loads

These loads are very clean and simple, since each just takes one block out of the larger tensor; the excerpt below shows the load of b.

      %51 = arith.muli %arg6, %c64_i32_15 : i32 loc(#loc12)
...
      %58 = arith.muli %1, %c64_i32_19 : i32 loc(#loc13)
      %59 = tt.descriptor_load %5[%51, %58] : !tt.tensordesc<tensor<64x64xf32>> -> tensor<64x64xf32> loc(#loc14)

%5 is produced by tt.make_tensor_descriptor (b's descriptor, built from %arg1):

    %4 = arith.extsi %arg5 : i32 to i64 loc(#loc4)
    %c1_i64_0 = arith.constant 1 : i64 loc(#loc4)
    %5 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%4, %c1_i64_0] : <f32>, <tensor<64x64xf32>> loc(#loc4)

The Op is defined as follows; it abstracts access to a block of the parent tensor, essentially a subview.

def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
    Pure,
    SameVariadicOperandSize,
]> {
  let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size";

  let description = [{
      `tt.make_tensor_descriptor` takes both meta information of the parent tensor and the block size,
      and returns a descriptor object which can be used to load/store from the tensor in global memory.
  }];

  let arguments = (ins
    TT_Ptr:$base,
    Variadic<I32>:$shape,
    Variadic<I64>:$strides
  );

  let results = (outs TT_TensorDescType:$result);

  let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";

  let builders = [
    OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape, "bool":$isSignedInteger)>
  ];

  let extraClassDeclaration = [{
    ArrayRef<int64_t> getTensorShape() {
      return getType().getBlockType().getShape();
    }
  }];
}

2.3 The rest

The store likewise uses the corresponding tt.descriptor_store; all of this is produced directly while walking the Python AST.

    %22 = arith.muli %0, %c128_i32_2 : i32 loc(#loc17)
...
    %29 = arith.muli %1, %c64_i32_3 : i32 loc(#loc18)
    tt.descriptor_store %7[%22, %29], %15 : !tt.tensordesc<tensor<128x64xf32>>, tensor<128x64xf32> loc(#loc19)
    tt.return loc(#loc20)

You can also hand 01-source.mlir to ChatGPT for an explanation. At this point the IR follows the Python DSL semantics exactly, with no optimization at all.

3. make_ttir

This stage simplifies the MLIR obtained from the Python AST by running the pipeline below.

    @staticmethod
    def make_ttir(mod, metadata, opt, capability):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        if capability // 10 < 9:
            passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_combine(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        passes.ttir.add_loop_unroll(pm)
        pm.run(mod)
        return mod

After these passes run, the IR goes from 02-Inliner.mlir to 13-TritonLoopUnroll.mlir.

3.1 Canonicalizer

Canonicalization is a generic pass, and you will see it appear again and again. Its possible effects include removing redundant ops, simplifying matched patterns, folding constant computations, and applying each Op's registered canonicalization patterns.

02-Inliner.mlir vs 03-Canonicalizer.mlir: simplifies triton.language.standard.zeros and triton.language.standard.cdiv__i32__.

03-Canonicalizer.mlir vs 04-Canonicalizer.mlir: simplifies triton.language.standard.cdiv__i32__ further.

04-Canonicalizer.mlir vs 05-Canonicalizer.mlir: a much larger change; ops unused inside the loop are removed, and triton.language.standard.zeros and triton.language.standard.cdiv__i32__ are inlined.

05-Canonicalizer.mlir vs 06-Canonicalizer.mlir: further simplification.

3.2 CSE

10-TritonReorderBroadcast.mlir vs 11-CSE.mlir: Common Subexpression Elimination; for example, the many duplicated arith.extsi ops get merged.

3.3 IR produced by this stage

The output of this stage is matmul_kernel_make_tensor_desciptor.ttir, identical to 13-TritonLoopUnroll.mlir.

4. make_ttgir

This stage runs quite a few passes, but the number that actually change our IR stays manageable.

    @staticmethod
    def make_ttgir(mod, metadata, opt, capability):
        # Set maxnreg on all kernels, if it was provided.
        if opt.maxnreg is not None:
            mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))

        cluster_info = nvidia.ClusterInfo()
        if opt.cluster_dims is not None:
            cluster_info.clusterDimX = opt.cluster_dims[0]
            cluster_info.clusterDimY = opt.cluster_dims[1]
            cluster_info.clusterDimZ = opt.cluster_dims[2]
        pm = ir.pass_manager(mod.context)
        dump_enabled = pm.enable_debug()
        passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
        # optimize TTGIR
        passes.ttgpuir.add_coalesce(pm)
        if capability // 10 >= 8:
            passes.ttgpuir.add_f32_dot_tc(pm)
        # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
        nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        passes.ttgpuir.add_accelerate_matmul(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
        passes.ttir.add_loop_aware_cse(pm)
        if capability // 10 in [8, 9]:
            passes.ttgpuir.add_fuse_nested_loops(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttir.add_triton_licm(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
            passes.ttgpuir.add_schedule_loops(pm)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
        elif capability // 10 >= 10:
            passes.ttgpuir.add_fuse_nested_loops(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttir.add_triton_licm(pm)
            passes.ttgpuir.add_optimize_accumulator_init(pm)
            passes.ttgpuir.add_hoist_tmem_alloc(pm)
            nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
            passes.ttgpuir.add_schedule_loops(pm)
            passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
        else:
            passes.ttir.add_triton_licm(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.ttgpuir.add_prefetch(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        passes.ttgpuir.add_coalesce_async_copy(pm)
        nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        passes.ttgpuir.add_reorder_instructions(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.common.add_symbol_dce(pm)
        if capability // 10 >= 9:
            nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
        nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
        passes.common.add_sccp(pm)
        passes.common.add_canonicalizer(pm)
        pm.run(mod)
        metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
        tensordesc_meta = mod.get_tensordesc_metadata()
        metadata["tensordesc_meta"] = tensordesc_meta
        return mod

After these passes run, the IR goes from 14-ConvertTritonToTritonGPU.mlir to 62-Canonicalizer.mlir.

4.1 ConvertTritonToTritonGPU

13-TritonLoopUnroll.mlir vs 14-ConvertTritonToTritonGPU.mlir: this mainly attaches layouts.

4.2 TritonGPURemoveLayoutConversions

17-TritonGPUPlanCTAPass.mlir vs 18-TritonGPURemoveLayoutConversions.mlir: removes redundant ttg.convert_layout ops and swaps the layout of the iter_args, saving two ttg.convert_layout ops inside the loop at the cost of one more after the loop.

// old
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
...
%9 = scf.for %arg6 = %c0_i32 to %8 step %c1_i32 iter_args(%arg7 = %cst) -> (tensor<128x64xf32, #blocked>)  : i32 {}

// new
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
...
    %9 = scf.for %arg6 = %c0_i32 to %8 step %c1_i32 iter_args(%arg7 = %cst) -> (tensor<128x64xf32, #blocked>)  : i32 {}
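
A quick way to read these #blocked layouts: along each dimension, one CTA covers sizePerThread × threadsPerWarp × warpsPerCTA elements, and the pattern repeats if the tensor is larger. A tiny helper (my own sketch, not a Triton API) to check that:

def cta_tile_shape(size_per_thread, threads_per_warp, warps_per_cta):
    """Elements covered by one CTA before the blocked layout wraps around."""
    return [s * t * w for s, t, w in zip(size_per_thread, threads_per_warp, warps_per_cta)]

print(cta_tile_shape([1, 1], [1, 32], [2, 2]))   # [2, 64]: repeats 64x over a 128x64 tensor
print(cta_tile_shape([4, 4], [2, 16], [4, 1]))   # [32, 64]: repeats 4x over a 128x64 tensor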

4.3 TritonGPUAccelerateMatmul

19-TritonGPUOptimizeThreadLocality.mlir vs 20-TritonGPUAccelerateMatmul.mlir: this rewrites tt.dot into ttng.tc_gen5_mma, staging the operands in shared memory via ttg.local_alloc and the accumulator in tensor memory via ttng.tmem_alloc.

// old
      %16 = arith.muli %1, %c64_i32 : i32 loc(#loc14)
      %17 = tt.descriptor_load %5[%14, %16] : !tt.tensordesc<tensor<64x64xf32>> -> tensor<64x64xf32, #blocked1> loc(#loc15)
      %18 = ttg.convert_layout %15 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> loc(#loc13)
      %19 = ttg.convert_layout %17 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> loc(#loc15)
      %20 = tt.dot %18, %19, %arg7, inputPrecision = tf32 : tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> loc(#loc16)
      scf.yield %20 : tensor<128x64xf32, #blocked> loc(#loc17)

// new
#blocked2 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#loc = loc("/home/ubuntu/triton/matmul.py":6:0)
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}>
#smem = #ttg.shared_memory
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
...
      %16 = ttg.local_alloc %15 : (tensor<128x64xf32, #blocked1>) -> !ttg.memdesc<128x64xf32, #shared, #smem> loc(#loc13)
      %17 = arith.muli %1, %c64_i32 : i32 loc(#loc14)
      %18 = tt.descriptor_load %5[%14, %17] : !tt.tensordesc<tensor<64x64xf32>> -> tensor<64x64xf32, #blocked1> loc(#loc15)
      %19 = ttg.local_alloc %18 : (tensor<64x64xf32, #blocked1>) -> !ttg.memdesc<64x64xf32, #shared1, #smem> loc(#loc15)
      %20 = ttg.convert_layout %arg7 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked2> loc(#loc16)
      %result, %token = ttng.tmem_alloc %20 : (tensor<128x64xf32, #blocked2>) -> (!ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) loc(#loc16)
      %21 = ttng.tc_gen5_mma %16, %19, %result[%token], %true, %true : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc16)
      %result_0, %token_1 = ttng.tmem_load %result[%21] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked2> loc(#loc16)
      %22 = ttg.convert_layout %result_0 : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #blocked> loc(#loc16)
      scf.yield %22 : tensor<128x64xf32, #blocked> loc(#loc17)

4.4 TritonGPURemoveLayoutConversions

20-TritonGPUAccelerateMatmul.mlir vs 21-TritonGPURemoveLayoutConversions.mlir: merges layouts further.

// old
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>

// new
#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>

4.5 TritonNvidiaGPUOptimizeDescriptorEncodingPass

23-Canonicalizer.mlir vs 24-TritonNvidiaGPUOptimizeDescriptorEncodingPass.mlir: annotates the tensordescs with the #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> encoding.

4.6 TritonLoopInvariantCodeMotion

27-Canonicalizer.mlir vs 28-TritonLoopInvariantCodeMotion.mlir: loop-invariant code motion; the two offset computations are hoisted out of the loop.

4.7 TritonGPUOptimizeAccumulatorInit

28-TritonLoopInvariantCodeMotion.mlir vs 29-TritonGPUOptimizeAccumulatorInit.mlir: the transformation changes the operands of ttng.tc_gen5_mma.

4.8 TritonGPUHoistTMEMAlloc

29-TritonGPUOptimizeAccumulatorInit.mlir vs 30-TritonGPUHoistTMEMAlloc.mlir: hoists the tensor-memory alloc out of the loop and initializes it there, avoiding unnecessary per-iteration overhead.

4.9 TritonGPUAssignLatencies

31-TritonNvidiaGPUPromoteLHSToTMemPass.mlir vs 32-TritonGPUAssignLatencies.mlir: tags the relevant ops with a tt.latency attribute.

4.10 TritonGPUScheduleLoops

32-TritonGPUAssignLatencies.mlir vs 33-TritonGPUScheduleLoops.mlir: software-pipelined loop scheduling. Software pipelining is my favorite pass; it reaches the heart of the GPU, hiding latency. Here the tt.latency attributes are converted into loop.cluster and loop.stage attributes.

4.11 SCCP

37-TritonGPURewritePartitionDependencies.mlir vs 38-SCCP.mlir: Sparse Conditional Constant Propagation, which propagates constants through the CFG; here it again mostly adjusts ordering.

4.12 CSE

38-SCCP.mlir vs 39-CSE.mlir: Common Subexpression Elimination.

4.13 TritonGPUPipeline

42-TritonGPUScheduleLoops.mlir vs 43-TritonGPUPipeline.mlir: the software pipeline takes effect and the schedule is materialized at the hardware level; the first iteration is even peeled out of the loop, which is exactly the prologue.

Here numStages defaults to 3, i.e. the pipeline is built with 3 stages so that the prefetching / computing / waiting phases can overlap; see the conceptual sketch after the options below.

  let options = [
    Option<"numStages", "num-stages",
           "int32_t", /*default*/"3",
           "number of pipeline stages">,
    Option<"dumpIntermediateSteps", "dump-intermediate-steps",
           "bool", /*default*/"false",
           "Dump intermediate steps">
  ];
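
To make the schedule concrete, here is pure-Python pseudocode of what a 3-stage pipeline does. issue_async_load, wait_for_load, and mma are hypothetical stand-ins for the async bulk copy, the mbarrier wait, and the tcgen05.mma step; this is a conceptual sketch, not what Triton emits:

NUM_STAGES = 3

def pipelined_loop(num_iters):
    # prologue: fill the pipeline with the first NUM_STAGES - 1 async copies
    for i in range(min(NUM_STAGES - 1, num_iters)):
        issue_async_load(i, buffer=i % NUM_STAGES)

    for i in range(num_iters):
        wait_for_load(i)                       # mbarrier wait on this stage's buffer
        if i + NUM_STAGES - 1 < num_iters:     # keep the pipeline full
            issue_async_load(i + NUM_STAGES - 1,
                             buffer=(i + NUM_STAGES - 1) % NUM_STAGES)
        mma(buffer=i % NUM_STAGES)             # compute overlaps the in-flight copies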

4.14 TritonNvidiaGPURemoveTMEMTokensPass

44-TritonGPUCombineTensorSelectAndIf.mlir vs 45-TritonNvidiaGPURemoveTMEMTokensPass.mlir: this adds one more ub.poison; the next Canonicalizer pass then wipes it out together with the four earlier ub.poison ops.

4.15 TritonLoopAwareCSE

46-Canonicalizer.mlir vs 47-TritonLoopAwareCSE.mlir: removes some of the ttg.memdesc_subview ops.

4.16 TritonGPUReorderInstructions

55-TritonGPUReduceDataDuplication.mlir vs 56-TritonGPUReorderInstructions.mlir: groups the ttg.local_alloc ops together.

4.17 TritonNvidiaGPUTMALoweringPass

58-SymbolDCE.mlir vs 59-TritonNvidiaGPUTMALoweringPass.mlir: lowers tt.make_tensor_descriptor and tt.descriptor_store.

4.18 TritonGPUFenceInsertion

59-TritonNvidiaGPUTMALoweringPass.mlir vs 60-TritonGPUFenceInsertion.mlir: inserts ttng.fence_async_shared before the ttng.tc_gen5_mma to ensure the asynchronous shared-memory operations have completed.

4.19 SCCP

60-TritonGPUFenceInsertion.mlir vs 61-SCCP.mlir: another round of reordering.

4.20 Canonicalizer

61-SCCP.mlir vs 62-Canonicalizer.mlir: after the reordering above, one line of ttg.convert_layout is optimized away.

4.21 IR produced by this stage

The output of this stage is matmul_kernel_make_tensor_desciptor.ttgir, identical to 62-Canonicalizer.mlir.

5. make_llir

This stage goes from ttgir to LLVM IR.

    def make_llir(self, src, metadata, options, capability):
        ptx_version = get_ptx_version_from_options(options, self.target.arch)

        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()

        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        passes.ttgpuir.add_allocate_warp_groups(pm)
        passes.convert.add_scf_to_cf(pm)
        passes.ttgpuir.add_allocate_shared_memory(pm)
        nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
        if knobs.compilation.enable_experimental_consan:
            # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
            passes.ttgpuir.add_concurrency_sanitizer(pm)
        passes.ttgpuir.add_allocate_global_scratch_memory(pm)
        nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
        nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
        nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        if not knobs.compilation.disable_line_info:
            passes.llvmir.add_di_scope(pm)
        pm.run(mod)
        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        if knobs.compilation.enable_asan:
            raise RuntimeError(
                "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
        llvm_mod = llvm.to_module(mod, context)
        proc = sm_arch_from_capability(capability)
        features = get_features(options, self.target.arch)
        triple = 'nvptx64-nvidia-cuda'
        nvidia.set_short_ptr()
        llvm.attach_datalayout(llvm_mod, triple, proc, features)
        nvidia.set_nvvm_reflect_ftz(llvm_mod)

        if options.extern_libs:
            paths = [path for (name, path) in options.extern_libs]
            llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)

        # Get some metadata
        # warp-specialization mutates num_warps
        total_num_warps = src.get_int_attr("ttg.total-num-warps")
        if total_num_warps is not None:
            metadata["num_warps"] = total_num_warps
        metadata["shared"] = src.get_int_attr("ttg.shared")
        metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
        metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
        metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
        ret = str(llvm_mod)
        del llvm_mod
        del context
        return ret

After these passes run, the IR goes from 63-TritonNvidiaGPUMMALoweringPass.mlir to 80-LLVMDIScope.mlir.

5.1 TritonGPUAllocateWarpGroups

64-TritonGPUCombineTensorSelectAndIf.mlir vs 65-TritonGPUAllocateWarpGroups.mlir: adds ttg.total-num-warps on the module.

5.2 SCFToControlFlowPass

65-TritonGPUAllocateWarpGroups.mlir vs 66-SCFToControlFlowPass.mlir: lowers scf.for and scf.if to cf.br and cf.cond_br.

5.3 AllocateSharedMemory

66-SCFToControlFlowPass.mlir vs 67-AllocateSharedMemory.mlir: adds the shared-memory size on the module as ttg.shared = 180272 : i32; each op's allocation.offset is the start position it was assigned.

    ttng.tensormap_create %3, %arg0, [%c32_i32, %c128_i32], [%arg4, %arg3], [%4], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc4)
...
    ttng.tensormap_create %7, %arg1, [%c32_i32, %c64_i32], [%arg5, %arg4], [%8], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc5)
...
    ttng.tensormap_create %10, %arg2, [%c32_i32, %c128_i32], [%arg5, %arg3], [%11], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc6)
...
    %17 = ttg.local_alloc {allocation.offset = 180256 : i32} : () -> !ttg.memdesc<2xi64, #shared1, #smem, mutable> loc(#loc13)
...
    %59 = ttg.local_alloc %58 {allocation.offset = 163840 : i32} : (tensor<64x64xf32, #blocked1>) -> !ttg.memdesc<64x64xf32, #shared2, #smem> loc(#loc15)

5.4 TritonTensorMemoryAllocationPass

67-AllocateSharedMemory.mlir vs 68-TritonTensorMemoryAllocationPass.mlir: adds ttg.tensor_memory_size = 64 on the module.

5.5 TritonGPUGlobalScratchAllocationPass

68-TritonTensorMemoryAllocationPass.mlir vs 69-TritonGPUGlobalScratchAllocationPass.mlir: adds ttg.global_scratch_memory_alignment = 128 : i32 and ttg.global_scratch_memory_size = 384 : i32.

Scratch memory really is used here: it is passed as an operand to ttng.tensormap_create, describing the data structure staged in that piece of smem.

// old
    %10 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr<i8> loc(#loc6)
    %11 = arith.muli %6, %c4_i64 : i64 loc(#loc6)
    ttng.tensormap_create %10, %arg2, [%c32_i32, %c128_i32], [%arg5, %arg3], [%11], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc6)

// new
    %10 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32, ttg.global_scratch_memory_offset = 256 : i32} : !tt.ptr<i8> loc(#loc6)
    %11 = arith.muli %6, %c4_i64 : i64 loc(#loc6)
    ttng.tensormap_create %10, %arg2, [%c32_i32, %c128_i32], [%arg5, %arg3], [%11], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc6)

5.6 TritonGPUProxyFenceInsertion

69-TritonGPUGlobalScratchAllocationPass.mlir vs 70-TritonGPUProxyFenceInsertion.mlir: inserts ttng.fence_async_shared {bCluster = false} loc(#loc14) after %27.

5.7 ConvertTritonGPUToLLVM

70-TritonGPUProxyFenceInsertion.mlir vs 71-ConvertTritonGPUToLLVM.mlir: this pass is very involved; you need to look at each Op's RewritePattern. This step lowers the IR down to the thread level.

// old
    %10 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32, ttg.global_scratch_memory_offset = 256 : i32} : !tt.ptr<i8> loc(#loc6)
    %11 = arith.muli %6, %c4_i64 : i64 loc(#loc6)
    ttng.tensormap_create %10, %arg2, [%c32_i32, %c128_i32], [%arg5, %arg3], [%11], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc6)
    ttng.tensormap_fenceproxy_acquire %10 : !tt.ptr<i8> loc(#loc6)
    %12 = ttng.reinterpret_tensor_descriptor %10 : !tt.ptr<i8> to !tt.tensordesc<tensor<128x64xf32, #shared>> loc(#loc6)
...
#loc6 = loc("/home/ubuntu/triton/matmul.py":30:8)

// new
    %326 = llvm.mlir.constant(256 : i32) : i32 loc(#loc6)
    %327 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32 loc(#loc6)
    %328 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.y"() : () -> i32 loc(#loc6)
    %329 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.z"() : () -> i32 loc(#loc6)
    %330 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.nctaid.x"() : () -> i32 loc(#loc6)
    %331 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.nctaid.y"() : () -> i32 loc(#loc6)
    %332 = llvm.mul %329, %331 : i32 loc(#loc6)
    %333 = llvm.add %328, %332 : i32 loc(#loc6)
    %334 = llvm.mul %333, %330 : i32 loc(#loc6)
    %335 = llvm.add %327, %334 : i32 loc(#loc6)
    %336 = llvm.mlir.constant(384 : i32) : i32 loc(#loc6)
    %337 = llvm.mul %335, %336 : i32 loc(#loc6)
    %338 = llvm.add %337, %326 : i32 loc(#loc6)
    %339 = llvm.getelementptr %arg6[%338] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8 loc(#loc6)
    %340 = llvm.mul %203, %77 : i64 loc(#loc6)
    nvvm.barrier0 loc(#loc6)
    %341 = llvm.mlir.constant(0 : i32) : i32 loc(#loc6)
    %342 = llvm.mlir.addressof @global_smem : !llvm.ptr<3> loc(#loc)
    %343 = llvm.getelementptr %342[%341] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 loc(#loc6)
    %344 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc6)
    %345 = llvm.mlir.constant(127 : i32) : i32 loc(#loc6)
    %346 = llvm.and %344, %345 : i32 loc(#loc6)
    %347 = llvm.mlir.constant(32 : i32) : i32 loc(#loc6)
    %348 = llvm.icmp "slt" %346, %347 : i32 loc(#loc6)
    %349 = llvm.mlir.constant(0 : i32) : i32 loc(#loc6)
    %350 = llvm.getelementptr %343[%346] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i32 loc(#loc6)
    %351 = llvm.mlir.undef : vector<1xi32> loc(#loc6)
    %352 = llvm.mlir.constant(0 : i32) : i32 loc(#loc6)
    %353 = llvm.insertelement %349, %351[%352 : i32] : vector<1xi32> loc(#loc6)
    %354 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b" %350, %353, %348 : (!llvm.ptr<3>, vector<1xi32>, i1) -> !llvm.void loc(#loc6)
    %355 = llvm.mlir.constant(-1 : i32) : i32 loc(#loc6)
    %356 = llvm.call_intrinsic "llvm.nvvm.bar.warp.sync"(%355) : (i32) -> !llvm.void loc(#loc6)
    %357 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc6)
    %358 = llvm.mlir.constant(127 : i32) : i32 loc(#loc6)
    %359 = llvm.and %357, %358 : i32 loc(#loc6)
    %360 = llvm.mlir.constant(0 : i32) : i32 loc(#loc6)
    %361 = llvm.icmp "eq" %359, %360 : i32 loc(#loc6)
    %362 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 tensormap.replace.tile.global_address.shared::cta.b1024.b64 [ $0 + 0 ], $1;", "l,l,b" %343, %arg2, %361 : (!llvm.ptr<3>, !llvm.ptr<1>, i1) -> !llvm.void loc(#loc6)
    %363 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc6)
    %364 = llvm.mlir.constant(127 : i32) : i32 loc(#loc6)
    %365 = llvm.and %363, %364 : i32 loc(#loc6)
    %366 = llvm.mlir.constant(0 : i32) : i32 loc(#loc6)
    %367 = llvm.icmp "eq" %365, %366 : i32 loc(#loc6)
    %368 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$1 tensormap.replace.tile.rank.shared::cta.b1024.b32 [ $0 + 0 ], 0x1;", "l,b" %343, %367 : (!llvm.ptr<3>, i1) -> !llvm.void loc(#loc6)
...
    %435 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc6)
    %436 = llvm.mlir.constant(127 : i32) : i32 loc(#loc6)
    %437 = llvm.and %435, %436 : i32 loc(#loc6)
    %438 = llvm.mlir.constant(32 : i32) : i32 loc(#loc6)
    %439 = llvm.icmp "slt" %437, %438 : i32 loc(#loc6)
    %440 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [ $0 + 0 ], [ $1 + 0 ], 0x80;", "l,l,b" %339, %343, %439 : (!llvm.ptr<1>, !llvm.ptr<3>, i1) -> !llvm.void loc(#loc6)
    %441 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc6)
    %442 = llvm.mlir.constant(127 : i32) : i32 loc(#loc6)
    %443 = llvm.and %441, %442 : i32 loc(#loc6)
    %444 = llvm.mlir.constant(32 : i32) : i32 loc(#loc6)
    %445 = llvm.icmp "slt" %443, %444 : i32 loc(#loc6)
    %446 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$1 fence.proxy.tensormap::generic.acquire.gpu [ $0 + 0 ], 0x80;\0A\09@$2 cp.async.bulk.commit_group ;\0A\09@$3 cp.async.bulk.wait_group.read 0 ;", "l,b,b,b" %339, %445, %445, %445 : (!llvm.ptr<1>, i1, i1, i1) -> !llvm.void loc(#loc6)
    nvvm.barrier0 loc(#loc6)
    %447 = llvm.addrspacecast %339 : !llvm.ptr<1> to !llvm.ptr loc(#loc6)
...
#loc6 = loc("/home/ubuntu/triton/matmul.py":30:8)
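
The global-scratch addressing in %326-%339 above reduces to: linearize the CTA id, multiply by the per-CTA scratch size (384 bytes, from ttg.global_scratch_memory_size), and add this allocation's offset (256). A Python mirror of that math (function and argument names are mine):

def global_scratch_offset(ctaid, nctaid, per_cta_size=384, alloc_offset=256):
    """Byte offset into the global scratch buffer, mirroring %326-%339."""
    x, y, z = ctaid
    nx, ny, _ = nctaid
    linear_cta = x + (y + z * ny) * nx   # the llvm.mul / llvm.add chain
    return linear_cta * per_cta_size + alloc_offset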

5.8 Canonicalizer

71-ConvertTritonGPUToLLVM.mlir vs 72-Canonicalizer.mlir: the canonicalizer is ferociously effective here; the IR drops from 4629 lines to 2555.

5.9 CSE

72-Canonicalizer.mlir vs 73-CSE.mlir: common subexpression elimination takes the IR from 2555 lines down to 1945.

5.10 ConvertNVGPUToLLVM

73-CSE.mlir vs 74-ConvertNVGPUToLLVM.mlir: this pass lowers all remaining nvgpu-dialect ops, e.g. nvgpu.tensor_memory_base, nvgpu.warp_id, nvgpu.fence_async_shared, and nvgpu.ldmatrix.

5.11 Canonicalizer

76-ReconcileUnrealizedCastsPass.mlir vs 77-Canonicalizer.mlir: optimizes away two lines of IR.

5.12 CSE

77-Canonicalizer.mlir vs 78-CSE.mlir: optimizes away another two lines of IR.

5.13 LLVMDIScope

79-SymbolDCE vs 80-LLVMDIScope: a pass that attaches debug-info scopes to the LLVM IR, generating debug information (e.g. DWARF) to support source-level debugging.

5.14 IR produced by this stage

The output of this stage is matmul_kernel_make_tensor_desciptor.llir.

6. make_ptx

This step essentially just calls into LLVM.

    def make_ptx(self, src, metadata, opt, capability):
        ptx_version = get_ptx_version_from_options(opt, self.target.arch)

        triple = 'nvptx64-nvidia-cuda'
        proc = sm_arch_from_capability(capability)
        features = get_features(opt, self.target.arch)
        ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
        # Find kernel names (there should only be one)
        names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
        assert len(names) == 1
        metadata["name"] = names[0]
        # post-process
        ptx_version = f'{ptx_version//10}.{ptx_version%10}'
        ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
        ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
        # Remove the debug flag that prevents ptxas from optimizing the code
        ret = re.sub(r",\s*debug|debug,\s*", "", ret)
        if knobs.nvidia.dump_nvptx:
            print("// -----// NVPTX Dump //----- //")
            print(ret)
        return ret

The output is matmul_kernel_make_tensor_desciptor.ptx. The input file is 1072 lines and the output 1174, so they correspond almost one-to-one.

6.1 dot, side by side

// old
  tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #5, !dbg !21
  tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !21
  %260 = icmp eq i32 %44, 0, !dbg !21
  %261 = and i1 %51, %260, !dbg !21
  br i1 %261, label %262, label %329, !dbg !21

262:                                              ; preds = %7
  %263 = tail call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1), !dbg !21
  %264 = extractvalue { i32, i1 } %263, 1, !dbg !21
  %265 = lshr exact i32 ptrtoint (ptr addrspace(3) @global_smem to i32), 4, !dbg !21
  %266 = and i32 %265, 16383, !dbg !21
  %267 = zext nneg i32 %266 to i64, !dbg !21
  %268 = or disjoint i64 %267, 4611686293372403712, !dbg !21
  %269 = lshr exact i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), 4, !dbg !21
  %270 = and i32 %269, 16383, !dbg !21
  %271 = zext nneg i32 %270 to i64, !dbg !21
  %272 = or disjoint i64 %271, 4611686293338849280, !dbg !21
  tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %268, i64 %272, i32 135268624, i1 false, i1 %264) #5, !dbg !21
  %273 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 32), 4, !dbg !21
  %274 = and i32 %273, 16383, !dbg !21
  %275 = zext nneg i32 %274 to i64, !dbg !21
  %276 = or disjoint i64 %275, 4611686293372403712, !dbg !21
  %277 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 32), 4, !dbg !21
  %278 = and i32 %277, 16383, !dbg !21
  %279 = zext nneg i32 %278 to i64, !dbg !21
  %280 = or disjoint i64 %279, 4611686293338849280, !dbg !21
  tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %276, i64 %280, i32 135268624, i1 true, i1 %264) #5, !dbg !21
  %281 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 64), 4, !dbg !21
  %282 = and i32 %281, 16383, !dbg !21
  %283 = zext nneg i32 %282 to i64, !dbg !21
  %284 = or disjoint i64 %283, 4611686293372403712, !dbg !21
  %285 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 64), 4, !dbg !21
  %286 = and i32 %285, 16383, !dbg !21
  %287 = zext nneg i32 %286 to i64, !dbg !21
  %288 = or disjoint i64 %287, 4611686293338849280, !dbg !21
  tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %284, i64 %288, i32 135268624, i1 true, i1 %264) #5, !dbg !21
  %289 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 96), 4, !dbg !21
  %290 = and i32 %289, 16383, !dbg !21
  %291 = zext nneg i32 %290 to i64, !dbg !21
  %292 = or disjoint i64 %291, 4611686293372403712, !dbg !21
  %293 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 96), 4, !dbg !21
  %294 = and i32 %293, 16383, !dbg !21
  %295 = zext nneg i32 %294 to i64, !dbg !21
  %296 = or disjoint i64 %295, 4611686293338849280, !dbg !21
  tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %292, i64 %296, i32 135268624, i1 true, i1 %264) #5, !dbg !21
  %297 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 16384), 4, !dbg !21
  %298 = and i32 %297, 16383, !dbg !21
  %299 = zext nneg i32 %298 to i64, !dbg !21
  %300 = or disjoint i64 %299, 4611686293372403712, !dbg !21
  %301 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 8192), 4, !dbg !21
  %302 = and i32 %301, 16383, !dbg !21
  %303 = zext nneg i32 %302 to i64, !dbg !21
  %304 = or disjoint i64 %303, 4611686293338849280, !dbg !21
  tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %300, i64 %304, i32 135268624, i1 true, i1 %264) #5, !dbg !21
  %305 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 16416), 4, !dbg !21
  %306 = and i32 %305, 16383, !dbg !21
  %307 = zext nneg i32 %306 to i64, !dbg !21
  %308 = or disjoint i64 %307, 4611686293372403712, !dbg !21
  %309 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 8224), 4, !dbg !21
  %310 = and i32 %309, 16383, !dbg !21
  %311 = zext nneg i32 %310 to i64, !dbg !21
  %312 = or disjoint i64 %311, 4611686293338849280, !dbg !21
  tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %308, i64 %312, i32 135268624, i1 true, i1 %264) #5, !dbg !21
  %313 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 16448), 4, !dbg !21
  %314 = and i32 %313, 16383, !dbg !21
  %315 = zext nneg i32 %314 to i64, !dbg !21
  %316 = or disjoint i64 %315, 4611686293372403712, !dbg !21
  %317 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 8256), 4, !dbg !21
  %318 = and i32 %317, 16383, !dbg !21
  %319 = zext nneg i32 %318 to i64, !dbg !21
  %320 = or disjoint i64 %319, 4611686293338849280, !dbg !21
  tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %316, i64 %320, i32 135268624, i1 true, i1 %264) #5, !dbg !21
  %321 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 16480), 4, !dbg !21
  %322 = and i32 %321, 16383, !dbg !21
  %323 = zext nneg i32 %322 to i64, !dbg !21
  %324 = or disjoint i64 %323, 4611686293372403712, !dbg !21
  %325 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 8288), 4, !dbg !21
  %326 = and i32 %325, 16383, !dbg !21
  %327 = zext nneg i32 %326 to i64, !dbg !21
  %328 = or disjoint i64 %327, 4611686293338849280, !dbg !21
  tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %324, i64 %328, i32 135268624, i1 true, i1 %264) #5, !dbg !21
  tail call void asm sideeffect "@$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];", "b,l"(i1 %264, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 180256)) #5, !dbg !21
  br label %329, !dbg !21
...
!21 = !DILocation(line: 40, column: 32, scope: !5)

// new
	.loc	1 40 32                         // matmul.py:40:32
	// begin inline asm
	fence.proxy.async.shared::cta;
	// end inline asm
	bar.sync 	0;
	@%p81 bra 	$L__BB0_6;
// %bb.5:                               //   in Loop: Header=BB0_4 Depth=1
	.loc	1 38 24                         // matmul.py:38:24
	shl.b32 	%r446, %r585, 15;
	add.s32 	%r448, %r62, %r446;
	.loc	1 40 32                         // matmul.py:40:32
	elect.sync 	%r449|%p115, -1;
	bfe.u32 	%r450, %r448, 4, 14;
	cvt.u64.u32 	%rd134, %r450;
	or.b64 	%rd117, %rd134, 4611686293372403712;
	mov.b32 	%r431, 135268624;
	// begin inline asm
	@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd117, %rd118, %r431, %p114;
	// end inline asm
	add.s32 	%r451, %r448, 32;
	bfe.u32 	%r452, %r451, 4, 14;
	cvt.u64.u32 	%rd135, %r452;
	or.b64 	%rd119, %rd135, 4611686293372403712;
	// begin inline asm
	@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd119, %rd120, %r431, %p114;
	// end inline asm
	add.s32 	%r453, %r448, 64;
	bfe.u32 	%r454, %r453, 4, 14;
	cvt.u64.u32 	%rd136, %r454;
	or.b64 	%rd121, %rd136, 4611686293372403712;
	// begin inline asm
	@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd121, %rd122, %r431, %p114;
	// end inline asm
	add.s32 	%r455, %r448, 96;
	bfe.u32 	%r456, %r455, 4, 14;
	cvt.u64.u32 	%rd137, %r456;
	or.b64 	%rd123, %rd137, 4611686293372403712;
	// begin inline asm
	@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd123, %rd124, %r431, %p114;
	// end inline asm
	add.s32 	%r457, %r448, 16384;
	bfe.u32 	%r458, %r457, 4, 14;
	cvt.u64.u32 	%rd138, %r458;
	or.b64 	%rd125, %rd138, 4611686293372403712;
	// begin inline asm
	@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd125, %rd126, %r431, %p114;
	// end inline asm
	add.s32 	%r459, %r448, 16416;
	bfe.u32 	%r460, %r459, 4, 14;
	cvt.u64.u32 	%rd139, %r460;
	or.b64 	%rd127, %rd139, 4611686293372403712;
	// begin inline asm
	@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd127, %rd128, %r431, %p114;
	// end inline asm
	add.s32 	%r461, %r448, 16448;
	bfe.u32 	%r462, %r461, 4, 14;
	cvt.u64.u32 	%rd140, %r462;
	or.b64 	%rd129, %rd140, 4611686293372403712;
	// begin inline asm
	@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd129, %rd130, %r431, %p114;
	// end inline asm
	add.s32 	%r463, %r448, 16480;
	bfe.u32 	%r464, %r463, 4, 14;
	cvt.u64.u32 	%rd141, %r464;
	or.b64 	%rd131, %rd141, 4611686293372403712;
	// begin inline asm
	@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd131, %rd132, %r431, %p114;
	// end inline asm
	cvt.u64.u32 	%rd133, %r590;
	// begin inline asm
	@%p115 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [%rd133];
	// end inline asm
	bra.uni 	$L__BB0_6;

6.2 store, side by side

// old
  %651 = and i32 %160, 16256, !dbg !26
  %652 = or disjoint i32 %651, %162, !dbg !26
  %653 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %652, !dbg !26
  %654 = insertelement <4 x i32> poison, i32 %587, i64 0, !dbg !26
  %655 = insertelement <4 x i32> %654, i32 %588, i64 1, !dbg !26
  %656 = insertelement <4 x i32> %655, i32 %589, i64 2, !dbg !26
  %657 = insertelement <4 x i32> %656, i32 %590, i64 3, !dbg !26
  store <4 x i32> %657, ptr addrspace(3) %653, align 16, !dbg !26
...
  %729 = xor i32 %652, 112, !dbg !26
  %730 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %729, !dbg !26
  %731 = insertelement <4 x i32> poison, i32 %615, i64 0, !dbg !26
  %732 = insertelement <4 x i32> %731, i32 %616, i64 1, !dbg !26
  %733 = insertelement <4 x i32> %732, i32 %617, i64 2, !dbg !26
  %734 = insertelement <4 x i32> %733, i32 %618, i64 3, !dbg !26
  store <4 x i32> %734, ptr addrspace(3) %730, align 16, !dbg !26
  %735 = getelementptr inbounds nuw i8, ptr addrspace(3) %730, i32 16384, !dbg !26
  %736 = insertelement <4 x i32> poison, i32 %647, i64 0, !dbg !26
  %737 = insertelement <4 x i32> %736, i32 %648, i64 1, !dbg !26
  %738 = insertelement <4 x i32> %737, i32 %649, i64 2, !dbg !26
  %739 = insertelement <4 x i32> %738, i32 %650, i64 3, !dbg !26
  store <4 x i32> %739, ptr addrspace(3) %735, align 16, !dbg !26
  tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #5, !dbg !26
  tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !26
  %740 = tail call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1), !dbg !26
  %741 = extractvalue { i32, i1 } %740, 1, !dbg !26
  %742 = and i1 %56, %741, !dbg !26
  tail call void asm sideeffect "@$0 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$1, {$2, $3}], [$4];", "b,l,r,r,r"(i1 %742, ptr %584, i32 %68, i32 %41, ptr addrspace(3) %60) #5, !dbg !26
  tail call void @llvm.nvvm.cp.async.bulk.commit.group(), !dbg !26
  tail call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0), !dbg !26
  tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !26
...
!26 = !DILocation(line: 43, column: 63, scope: !5)

// new
	.loc	1 43 63                         // matmul.py:43:63
	and.b32 	%r563, %r23, 16256;
	or.b32 	%r564, %r563, %r24;
	add.s32 	%r565, %r62, %r564;
	st.shared.v4.b32 	[%r565], {%r493, %r494, %r495, %r496};
	st.shared.v4.b32 	[%r565+16384], {%r525, %r526, %r527, %r528};
	xor.b32 	%r566, %r564, 16;
	add.s32 	%r567, %r62, %r566;
	st.shared.v4.b32 	[%r567], {%r497, %r498, %r499, %r500};
	st.shared.v4.b32 	[%r567+16384], {%r529, %r530, %r531, %r532};
	xor.b32 	%r568, %r564, 32;
	add.s32 	%r569, %r62, %r568;
	st.shared.v4.b32 	[%r569], {%r501, %r502, %r503, %r504};
	st.shared.v4.b32 	[%r569+16384], {%r533, %r534, %r535, %r536};
	xor.b32 	%r570, %r564, 48;
	add.s32 	%r571, %r62, %r570;
	st.shared.v4.b32 	[%r571], {%r505, %r506, %r507, %r508};
	st.shared.v4.b32 	[%r571+16384], {%r537, %r538, %r539, %r540};
	xor.b32 	%r572, %r564, 64;
	add.s32 	%r573, %r62, %r572;
	st.shared.v4.b32 	[%r573], {%r509, %r510, %r511, %r512};
	st.shared.v4.b32 	[%r573+16384], {%r541, %r542, %r543, %r544};
	xor.b32 	%r574, %r564, 80;
	add.s32 	%r575, %r62, %r574;
	st.shared.v4.b32 	[%r575], {%r513, %r514, %r515, %r516};
	st.shared.v4.b32 	[%r575+16384], {%r545, %r546, %r547, %r548};
	xor.b32 	%r576, %r564, 96;
	add.s32 	%r577, %r62, %r576;
	st.shared.v4.b32 	[%r577], {%r517, %r518, %r519, %r520};
	st.shared.v4.b32 	[%r577+16384], {%r549, %r550, %r551, %r552};
	xor.b32 	%r578, %r564, 112;
	add.s32 	%r579, %r62, %r578;
	st.shared.v4.b32 	[%r579], {%r521, %r522, %r523, %r524};
	st.shared.v4.b32 	[%r579+16384], {%r553, %r554, %r555, %r556};
	// begin inline asm
	fence.proxy.async.shared::cta;
	// end inline asm
	bar.sync 	0;
	elect.sync 	%r580|%p153, -1;
	and.pred 	%p150, %p73, %p153;
	// begin inline asm
	@%p150 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd144, {%r558, %r559}], [%r560];
	// end inline asm
	cp.async.bulk.commit_group;
	cp.async.bulk.wait_group.read 	0;
	bar.sync 	0;

6.3 load, side by side

// old
  tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !23
  %331 = tail call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1), !dbg !23
  %332 = extractvalue { i32, i1 } %331, 1, !dbg !23
  %333 = and i1 %82, %332, !dbg !23
  %334 = and i1 %56, %333, !dbg !23
  %335 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %59, !dbg !23
  %336 = or disjoint i32 %61, 128, !dbg !23
  tail call void asm sideeffect "@$0 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4}], [$5];", "b,r,l,r,r,r"(i1 %334, ptr addrspace(3) %335, ptr %29, i32 %336, i32 %41, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 180240)) #5, !dbg !23
...
!23 = !DILocation(line: 38, column: 24, scope: !5)

// new
	.loc	1 38 24                         // matmul.py:38:24
	bar.sync 	0;
	elect.sync 	%r341|%p107, -1;
	and.pred 	%p108, %p103, %p107;
	and.pred 	%p101, %p73, %p108;
	shl.b32 	%r342, %r9, 2;
	add.s32 	%r343, %r62, %r342;
	add.s32 	%r330, %r343, 65536;
	or.b32 	%r331, %r159, 128;
	// begin inline asm
	@%p101 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [%r330], [%rd66, {%r331, %r559}], [%r329];
	// end inline asm

7. Posts in this series

A Deep Dive into the Triton Compiler's MatMul Optimizations (2): MMA

A Deep Dive into the Triton Compiler's MatMul Optimizations (1): FMA

A Brief Look at Triton's Execution Flow
