A Deep Dive into Triton Compiler MatMul Optimization (Part 3): TMA
In A Deep Dive into Triton Compiler MatMul Optimization (Part 2): MMA we covered tl.dot, which delivers good performance almost for free and lowers to tcgen05.mma.cta_group::1.kind::tf32 and cp.async.cg.shared.global instructions. This time we look at TMA, which generates cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes and ldmatrix.sync.aligned.m8n8.x4.shared.b16 instructions.
This post uses Triton at commit bc75dd0 (Jun 27, 2025). All IR and kernel files have been uploaded to GitHub: sBobHuang/Triton-blog-file. Related posts in this series:
A Deep Dive into Triton Compiler MatMul Optimization (Part 2): MMA
A Deep Dive into Triton Compiler MatMul Optimization (Part 1): FMA
I. The matmul Triton kernel
1. Writing the kernel
The Triton kernel is shown below. Matrix a is M*N, matrix b is N*K, and the result matrix c is M*K. The complete runnable code is in matmul-with-tma-v2.py.
# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl
@triton.jit
def matmul_kernel_make_tensor_desciptor(a_ptr, b_ptr, c_ptr, #
M, N, K, #
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, #
):
a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
pid_m = tl.program_id(axis=0)
pid_k = tl.program_id(axis=1)
a_desc = tl.make_tensor_descriptor(
a_ptr,
shape=[M, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
b_desc = tl.make_tensor_descriptor(
b_ptr,
shape=[N, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
c_desc = tl.make_tensor_descriptor(
c_ptr,
shape=[M, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
for n in range(tl.cdiv(N, BLOCK_SIZE_N)):
a = a_desc.load([pid_m * BLOCK_SIZE_M, n * BLOCK_SIZE_N])
b = b_desc.load([n * BLOCK_SIZE_N, pid_k * BLOCK_SIZE_K])
accumulator = tl.dot(a, b, acc=accumulator)
accumulator = accumulator.to(a_desc.dtype)
c_desc.store([pid_m * BLOCK_SIZE_M, pid_k * BLOCK_SIZE_K], accumulator)
# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(K, META['BLOCK_SIZE_K']), )
# Leading dimensions must be multiples of 16-byte strides
# if M % 4 == 0 and N % 4 == 0 and K % 4 == 0:
import torch
# TMA descriptors require a global memory allocation
def alloc_fn(size, alignment, stream):
return torch.empty(size, device="cuda", dtype=torch.int8)
triton.set_allocator(alloc_fn)
matmul_kernel_make_tensor_desciptor[grid](
a_ptr, b_ptr, c_ptr,
M, N, K,
BLOCK_SIZE_M=128,
BLOCK_SIZE_K=64,
BLOCK_SIZE_N=64,
)
tl.make_tensor_descriptor is a 3.4.0 API; on 3.3.1 use tl._experimental_make_tensor_descriptor instead.
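To exercise solve() end to end, a minimal harness such as the sketch below works; it is an illustration only, it assumes M, N, and K are already multiples of the block sizes (128/64/64) since the kernel has no masks, and it uses torch.matmul purely as a reference check.
import torch

if __name__ == "__main__":
    # Sizes chosen as multiples of BLOCK_SIZE_M / BLOCK_SIZE_N / BLOCK_SIZE_K,
    # because the kernel performs no masking.
    M, N, K = 256, 128, 192
    a = torch.randn((M, N), device="cuda", dtype=torch.float32)
    b = torch.randn((N, K), device="cuda", dtype=torch.float32)
    c = torch.empty((M, K), device="cuda", dtype=torch.float32)
    solve(a.data_ptr(), b.data_ptr(), c.data_ptr(), M, N, K)
    # tf32 MMA truncates the mantissa, so compare with a loose tolerance.
    torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)
    print("ok")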
2. Kernel walkthrough
This kernel is fairly intuitive: a descriptor is set up for [BLOCK_SIZE_M, BLOCK_SIZE_N] tiles of a, and each iteration loads the tile at [pid_m * BLOCK_SIZE_M, n * BLOCK_SIZE_N]; b likewise loads the tile at [n * BLOCK_SIZE_N, pid_k * BLOCK_SIZE_K]; the two tiles are fed to tl.dot, and the result is stored directly at the end. Notice that there are no masks anywhere. tl.make_tensor_descriptor therefore comes with restrictions: you have to pad the inputs manually, and the leading dimension's stride must be a multiple of 16 bytes (the minimum transfer size), hence the comment "Leading dimensions must be multiples of 16-byte strides". There is also the older tl.make_block_ptr; it works, but the store still has to be masked row by row, see matmul-with-block.py.
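A hedged sketch of the manual padding mentioned above: zero-pad a and b so every dimension is a multiple of its block size (which also makes the fp32 row strides multiples of 16 bytes), run the kernel on the padded buffers, then slice the valid part back out. The helper below is illustrative only and assumes the block sizes from the launch above (128/64/64).
import torch

def ceil_to(x: int, m: int) -> int:
    # Round x up to the next multiple of m.
    return (x + m - 1) // m * m

def matmul_padded(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    M, N = a.shape
    N2, K = b.shape
    assert N == N2
    Mp, Np, Kp = ceil_to(M, 128), ceil_to(N, 64), ceil_to(K, 64)
    a_p = torch.zeros((Mp, Np), device=a.device, dtype=a.dtype)
    b_p = torch.zeros((Np, Kp), device=b.device, dtype=b.dtype)
    c_p = torch.empty((Mp, Kp), device=a.device, dtype=a.dtype)
    a_p[:M, :N] = a
    b_p[:N, :K] = b
    # Zero padding contributes nothing to the dot product, so the valid
    # sub-block of c_p equals a @ b exactly.
    solve(a_p.data_ptr(), b_p.data_ptr(), c_p.data_ptr(), Mp, Np, Kp)
    return c_p[:M, :K].clone()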
II. ast_to_ttir
The JIT decorator walks the Python AST and ultimately emits MLIR ops through the builder's self.create<...> calls.
1. Loop IR
The resulting TTIR is fairly redundant; the full IR is in 01-source.mlir. Let's pick out the loop and take a look, shown below.
%15 = scf.for %arg6 = %11 to %12 step %13 iter_args(%arg7 = %8) -> (tensor<128x64xf32>) : i32 {
...
%36 = arith.muli %0, %c128_i32_7 : i32 loc(#loc9)
%c64_i32_10 = arith.constant 64 : i32 loc(#loc10)
%c64_i32_11 = arith.constant 64 : i32 loc(#loc10)
...
%43 = arith.muli %arg6, %c64_i32_11 : i32 loc(#loc10)
%44 = tt.descriptor_load %3[%36, %43] : !tt.tensordesc<tensor<128x64xf32>> -> tensor<128x64xf32> loc(#loc11)
%c64_i32_14 = arith.constant 64 : i32 loc(#loc12)
%c64_i32_15 = arith.constant 64 : i32 loc(#loc12)
...
%51 = arith.muli %arg6, %c64_i32_15 : i32 loc(#loc12)
%c64_i32_18 = arith.constant 64 : i32 loc(#loc13)
%c64_i32_19 = arith.constant 64 : i32 loc(#loc13)
...
%58 = arith.muli %1, %c64_i32_19 : i32 loc(#loc13)
%59 = tt.descriptor_load %5[%51, %58] : !tt.tensordesc<tensor<64x64xf32>> -> tensor<64x64xf32> loc(#loc14)
%cst = arith.constant 0.000000e+00 : f32 loc(#loc15)
%60 = tt.dot %44, %59, %arg7, inputPrecision = tf32 : tensor<128x64xf32> * tensor<64x64xf32> -> tensor<128x64xf32> loc(#loc15)
scf.yield %60 : tensor<128x64xf32> loc(#loc16)
} loc(#loc8)
2. The descriptor load
This load is very clean and simple, because it just pulls one tile out of the larger tensor; the snippet below is the load of b through descriptor %5.
%51 = arith.muli %arg6, %c64_i32_15 : i32 loc(#loc12)
...
%58 = arith.muli %1, %c64_i32_19 : i32 loc(#loc13)
%59 = tt.descriptor_load %5[%51, %58] : !tt.tensordesc<tensor<64x64xf32>> -> tensor<64x64xf32> loc(#loc14)
%5 is produced by tt.make_tensor_descriptor:
%4 = arith.extsi %arg5 : i32 to i64 loc(#loc4)
%c1_i64_0 = arith.constant 1 : i64 loc(#loc4)
%5 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%4, %c1_i64_0] : <f32>, <tensor<64x64xf32>> loc(#loc4)
Its Op definition is shown below; it abstracts blocked access to a tensor, essentially a subview.
def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
Pure,
SameVariadicOperandSize,
]> {
let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size";
let description = [{
`tt.make_tensor_descriptor` takes both meta information of the parent tensor and the block size,
and returns a descriptor object which can be used to load/store from the tensor in global memory.
}];
let arguments = (ins
TT_Ptr:$base,
Variadic<I32>:$shape,
Variadic<I64>:$strides
);
let results = (outs TT_TensorDescType:$result);
let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";
let builders = [
OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape, "bool":$isSignedInteger)>
];
let extraClassDeclaration = [{
ArrayRef<int64_t> getTensorShape() {
return getType().getBlockType().getShape();
}
}];
}
3. The remaining parts
The store likewise uses the corresponding tt.descriptor_store; all of this is emitted directly by the Python AST walker.
%22 = arith.muli %0, %c128_i32_2 : i32 loc(#loc17)
...
%29 = arith.muli %1, %c64_i32_3 : i32 loc(#loc18)
tt.descriptor_store %7[%22, %29], %15 : !tt.tensordesc<tensor<128x64xf32>>, tensor<128x64xf32> loc(#loc19)
tt.return loc(#loc20)
You can also hand 01-source.mlir to ChatGPT to have it explained. At this point the IR matches the Python DSL semantics exactly, with no optimization applied.
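If you want to reproduce per-pass and per-stage dumps like the ones referenced throughout this post, Triton exposes a few environment knobs for that. The variable names below are taken from Triton's README/knobs and may differ across versions, so treat this as a sketch; they must be set before the kernel is compiled.
import os

os.environ["MLIR_ENABLE_DUMP"] = "1"         # print the IR before every MLIR pass
os.environ["TRITON_KERNEL_DUMP"] = "1"       # dump the IR of each compilation stage
os.environ["TRITON_DUMP_DIR"] = "./ir_dump"  # directory for the dumped files
os.environ["TRITON_ALWAYS_COMPILE"] = "1"    # bypass the cache so the dump always runs

# ...then import triton, define the kernel, and launch it as above.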
III. make_ttir
This stage simplifies the MLIR obtained from the Python AST; the following pipeline is executed.
@staticmethod
def make_ttir(mod, metadata, opt, capability):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.common.add_inliner(pm)
passes.ttir.add_rewrite_tensor_pointer(pm)
if capability // 10 < 9:
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_combine(pm)
passes.ttir.add_reorder_broadcast(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
passes.ttir.add_loop_unroll(pm)
pm.run(mod)
return mod
After these passes run, the IR goes from 02-Inliner.mlir to 13-TritonLoopUnroll.mlir.
1. Canonicalizer
Canonicalization is a generic pass, and you will see it show up again and again. The work it may do includes eliminating redundant ops, simplifying matched patterns, folding constant computations, and applying each Op's registered canonicalization patterns.
02-Inliner.mlir vs 03-Canonicalizer.mlir: triton.language.standard.zeros and triton.language.standard.cdiv__i32__ are simplified.
03-Canonicalizer.mlir vs 04-Canonicalizer.mlir: triton.language.standard.cdiv__i32__ is simplified further.
04-Canonicalizer.mlir vs 05-Canonicalizer.mlir: a big change; ops unused inside the loop are removed, and triton.language.standard.zeros and triton.language.standard.cdiv__i32__ are inlined.
05-Canonicalizer.mlir vs 06-Canonicalizer.mlir: further simplification.
2. CSE
10-TritonReorderBroadcast.mlir vs 11-CSE.mlir: Common Subexpression Elimination; for example, there were many duplicated arith.extsi ops.
3. IR produced by this stage
The IR product of this stage is matmul_kernel_make_tensor_desciptor.ttir, which is the same as 13-TritonLoopUnroll.mlir.
IV. make_ttgir
This stage runs quite a few passes, but the number that actually change our kernel's IR is still manageable.
@staticmethod
def make_ttgir(mod, metadata, opt, capability):
# Set maxnreg on all kernels, if it was provided.
if opt.maxnreg is not None:
mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))
cluster_info = nvidia.ClusterInfo()
if opt.cluster_dims is not None:
cluster_info.clusterDimX = opt.cluster_dims[0]
cluster_info.clusterDimY = opt.cluster_dims[1]
cluster_info.clusterDimZ = opt.cluster_dims[2]
pm = ir.pass_manager(mod.context)
dump_enabled = pm.enable_debug()
passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
# optimize TTGIR
passes.ttgpuir.add_coalesce(pm)
if capability // 10 >= 8:
passes.ttgpuir.add_f32_dot_tc(pm)
# TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
passes.ttgpuir.add_accelerate_matmul(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
passes.ttir.add_loop_aware_cse(pm)
if capability // 10 in [8, 9]:
passes.ttgpuir.add_fuse_nested_loops(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_triton_licm(pm)
passes.common.add_canonicalizer(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
passes.ttgpuir.add_schedule_loops(pm)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
elif capability // 10 >= 10:
passes.ttgpuir.add_fuse_nested_loops(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_triton_licm(pm)
passes.ttgpuir.add_optimize_accumulator_init(pm)
passes.ttgpuir.add_hoist_tmem_alloc(pm)
nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
passes.ttgpuir.add_schedule_loops(pm)
passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
else:
passes.ttir.add_triton_licm(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_loop_aware_cse(pm)
passes.ttgpuir.add_prefetch(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.ttgpuir.add_coalesce_async_copy(pm)
nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
passes.ttgpuir.add_reorder_instructions(pm)
passes.ttir.add_loop_aware_cse(pm)
passes.common.add_symbol_dce(pm)
if capability // 10 >= 9:
nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
passes.common.add_sccp(pm)
passes.common.add_canonicalizer(pm)
pm.run(mod)
metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
tensordesc_meta = mod.get_tensordesc_metadata()
metadata["tensordesc_meta"] = tensordesc_meta
return mod
After these passes run, the IR goes from 14-ConvertTritonToTritonGPU.mlir to 62-Canonicalizer.mlir.
1. ConvertTritonToTritonGPU
13-TritonLoopUnroll.mlir vs 14-ConvertTritonToTritonGPU.mlir: this mainly attaches layout encodings to the tensors.
2. TritonGPURemoveLayoutConversions
17-TritonGPUPlanCTAPass.mlir vs 18-TritonGPURemoveLayoutConversions.mlir: removes redundant convert_layout ops and swaps the layout of the iter_args, so the loop body has two fewer ttg.convert_layout ops at the cost of one extra ttg.convert_layout after the loop.
// old
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
...
%9 = scf.for %arg6 = %c0_i32 to %8 step %c1_i32 iter_args(%arg7 = %cst) -> (tensor<128x64xf32, #blocked>) : i32 {}
// new
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
...
%9 = scf.for %arg6 = %c0_i32 to %8 step %c1_i32 iter_args(%arg7 = %cst) -> (tensor<128x64xf32, #blocked>) : i32 {}
3. TritonGPUAccelerateMatmul
19-TritonGPUOptimizeThreadLocality.mlir vs 20-TritonGPUAccelerateMatmul.mlir: this is where tt.dot gets rewritten.
// old
%16 = arith.muli %1, %c64_i32 : i32 loc(#loc14)
%17 = tt.descriptor_load %5[%14, %16] : !tt.tensordesc<tensor<64x64xf32>> -> tensor<64x64xf32, #blocked1> loc(#loc15)
%18 = ttg.convert_layout %15 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> loc(#loc13)
%19 = ttg.convert_layout %17 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> loc(#loc15)
%20 = tt.dot %18, %19, %arg7, inputPrecision = tf32 : tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> loc(#loc16)
scf.yield %20 : tensor<128x64xf32, #blocked> loc(#loc17)
// new
#blocked2 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#loc = loc("/home/ubuntu/triton/matmul.py":6:0)
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}>
#smem = #ttg.shared_memory
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
...
%16 = ttg.local_alloc %15 : (tensor<128x64xf32, #blocked1>) -> !ttg.memdesc<128x64xf32, #shared, #smem> loc(#loc13)
%17 = arith.muli %1, %c64_i32 : i32 loc(#loc14)
%18 = tt.descriptor_load %5[%14, %17] : !tt.tensordesc<tensor<64x64xf32>> -> tensor<64x64xf32, #blocked1> loc(#loc15)
%19 = ttg.local_alloc %18 : (tensor<64x64xf32, #blocked1>) -> !ttg.memdesc<64x64xf32, #shared1, #smem> loc(#loc15)
%20 = ttg.convert_layout %arg7 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked2> loc(#loc16)
%result, %token = ttng.tmem_alloc %20 : (tensor<128x64xf32, #blocked2>) -> (!ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) loc(#loc16)
%21 = ttng.tc_gen5_mma %16, %19, %result[%token], %true, %true : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc16)
%result_0, %token_1 = ttng.tmem_load %result[%21] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked2> loc(#loc16)
%22 = ttg.convert_layout %result_0 : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #blocked> loc(#loc16)
scf.yield %22 : tensor<128x64xf32, #blocked> loc(#loc17)
4. TritonGPURemoveLayoutConversions
20-TritonGPUAccelerateMatmul.mlir vs 21-TritonGPURemoveLayoutConversions.mlir: layouts are merged further.
// old
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// new
#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
5. TritonNvidiaGPUOptimizeDescriptorEncodingPass
23-Canonicalizer.mlir vs 24-TritonNvidiaGPUOptimizeDescriptorEncodingPass.mlir: the tensordesc types are tagged with #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>.
6. TritonLoopInvariantCodeMotion
27-Canonicalizer.mlir vs 28-TritonLoopInvariantCodeMotion.mlir: loop-invariant code motion; the two offset computations are hoisted out of the loop.
7. TritonGPUOptimizeAccumulatorInit
28-TritonLoopInvariantCodeMotion.mlir vs 29-TritonGPUOptimizeAccumulatorInit.mlir: this transformation changes the operands of ttng.tc_gen5_mma.
8. TritonGPUHoistTMEMAlloc
29-TritonGPUOptimizeAccumulatorInit.mlir vs 30-TritonGPUHoistTMEMAlloc.mlir: hoists the tensor-memory allocation out of the loop and initializes it there, avoiding unnecessary overhead.
9. TritonGPUAssignLatencies
31-TritonNvidiaGPUPromoteLHSToTMemPass.mlir vs 32-TritonGPUAssignLatencies.mlir: tags the relevant ops with a tt.latency attribute.
10. TritonGPUScheduleLoops
32-TritonGPUAssignLatencies.mlir vs 33-TritonGPUScheduleLoops.mlir: software-pipelined loop scheduling. Software pipelining is my favorite part; it goes to the heart of keeping the GPU busy by hiding latency. Here the tt.latency attributes are turned into loop.cluster and loop.stage attributes.
11. SCCP
37-TritonGPURewritePartitionDependencies.mlir vs 38-SCCP.mlir: Sparse Conditional Constant Propagation propagates constants through the CFG; here it again mostly just reorders things.
12. CSE
38-SCCP.mlir vs 39-CSE.mlir: Common Subexpression Elimination.
13. TritonGPUPipeline
42-TritonGPUScheduleLoops.mlir vs 43-TritonGPUPipeline.mlir: the software pipeline now takes effect; the loop is rescheduled for the hardware, and the first iteration is even peeled out of the loop, which is exactly the prologue.
numStages defaults to 3, i.e., the pipeline is built with three stages so that prefetching, computing, and waiting can overlap (see the conceptual sketch after the options below).
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"3",
"number of pipeline stages">,
Option<"dumpIntermediateSteps", "dump-intermediate-steps",
"bool", /*default*/"false",
"Dump intermediate steps">
];
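Conceptually, what the pipeliner builds for three stages looks like the plain-Python sketch below; this is a didactic illustration, not the generated code. Async TMA copies are issued a couple of iterations ahead, and each steady-state iteration computes on a tile whose copy was started earlier, so copy latency hides behind the MMAs. The default of 3 can also be overridden per launch via the num_stages kwarg.
# Didactic sketch of a 3-stage software pipeline (not the generated code).
NUM_STAGES = 3

def pipelined_loop(num_tiles, issue_copy, wait_copy, mma):
    # Prologue: start NUM_STAGES - 1 copies before doing any compute.
    for i in range(min(NUM_STAGES - 1, num_tiles)):
        issue_copy(i)
    # Steady state: compute tile i while tile i + NUM_STAGES - 1 is in flight.
    for i in range(num_tiles):
        nxt = i + NUM_STAGES - 1
        if nxt < num_tiles:
            issue_copy(nxt)   # prefetch a future tile (cp.async.bulk.tensor)
        wait_copy(i)          # mbarrier wait until tile i's data has arrived
        mma(i)                # tcgen05.mma on a tile that is already resident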
14. TritonNvidiaGPURemoveTMEMTokensPass
44-TritonGPUCombineTensorSelectAndIf.mlir vs 45-TritonNvidiaGPURemoveTMEMTokensPass.mlir: one more ub.poison appears; the next Canonicalizer pass then removes it along with the four earlier ub.poison ops.
15. TritonLoopAwareCSE
46-Canonicalizer.mlir vs 47-TritonLoopAwareCSE.mlir: removes some of the ttg.memdesc_subview ops.
16. TritonGPUReorderInstructions
55-TritonGPUReduceDataDuplication.mlir vs 56-TritonGPUReorderInstructions.mlir: groups the ttg.local_alloc ops together.
17. TritonNvidiaGPUTMALoweringPass
58-SymbolDCE.mlir vs 59-TritonNvidiaGPUTMALoweringPass.mlir: lowers tt.make_tensor_descriptor and tt.descriptor_store.
18. TritonGPUFenceInsertion
59-TritonNvidiaGPUTMALoweringPass.mlir vs 60-TritonGPUFenceInsertion.mlir: inserts a ttng.fence_async_shared before the ttng.tc_gen5_mma to ensure the asynchronous shared-memory operations have completed.
19. SCCP
60-TritonGPUFenceInsertion.mlir vs 61-SCCP.mlir: another round of reordering.
20. Canonicalizer
61-SCCP.mlir vs 62-Canonicalizer.mlir: after the reordering above, one ttg.convert_layout is optimized away.
21. IR produced by this stage
The IR product of this stage is matmul_kernel_make_tensor_desciptor.ttgir, which is the same as 62-Canonicalizer.mlir.
V. make_llir
This stage goes from TTGIR to LLVM IR.
def make_llir(self, src, metadata, options, capability):
ptx_version = get_ptx_version_from_options(options, self.target.arch)
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_allocate_warp_groups(pm)
passes.convert.add_scf_to_cf(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
if knobs.compilation.enable_experimental_consan:
# Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
passes.ttgpuir.add_concurrency_sanitizer(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if not knobs.compilation.disable_line_info:
passes.llvmir.add_di_scope(pm)
pm.run(mod)
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
llvm.init_targets()
context = llvm.context()
if knobs.compilation.enable_asan:
raise RuntimeError(
"Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
llvm_mod = llvm.to_module(mod, context)
proc = sm_arch_from_capability(capability)
features = get_features(options, self.target.arch)
triple = 'nvptx64-nvidia-cuda'
nvidia.set_short_ptr()
llvm.attach_datalayout(llvm_mod, triple, proc, features)
nvidia.set_nvvm_reflect_ftz(llvm_mod)
if options.extern_libs:
paths = [path for (name, path) in options.extern_libs]
llvm.link_extern_libs(llvm_mod, paths)
llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
# Get some metadata
# warp-specialization mutates num_warps
total_num_warps = src.get_int_attr("ttg.total-num-warps")
if total_num_warps is not None:
metadata["num_warps"] = total_num_warps
metadata["shared"] = src.get_int_attr("ttg.shared")
metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
ret = str(llvm_mod)
del llvm_mod
del context
return ret
After these passes run, the IR goes from 63-TritonNvidiaGPUMMALoweringPass.mlir to 80-LLVMDIScope.mlir.
1. TritonGPUAllocateWarpGroups
64-TritonGPUCombineTensorSelectAndIf.mlir vs 65-TritonGPUAllocateWarpGroups.mlir: adds ttg.total-num-warps to the module.
2. SCFToControlFlowPass
65-TritonGPUAllocateWarpGroups.mlir vs 66-SCFToControlFlowPass.mlir: lowers scf.for and scf.if to cf.br and cf.cond_br.
3. AllocateSharedMemory
66-SCFToControlFlowPass.mlir vs 67-AllocateSharedMemory.mlir: the module also gets the shared-memory size, ttg.shared = 180272 : i32, and each op's allocation.offset is the starting position it was assigned.
ttng.tensormap_create %3, %arg0, [%c32_i32, %c128_i32], [%arg4, %arg3], [%4], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc4)
...
ttng.tensormap_create %7, %arg1, [%c32_i32, %c64_i32], [%arg5, %arg4], [%8], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc5)
...
ttng.tensormap_create %10, %arg2, [%c32_i32, %c128_i32], [%arg5, %arg3], [%11], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc6)
...
%17 = ttg.local_alloc {allocation.offset = 180256 : i32} : () -> !ttg.memdesc<2xi64, #shared1, #smem, mutable> loc(#loc13)
...
%59 = ttg.local_alloc %58 {allocation.offset = 163840 : i32} : (tensor<64x64xf32, #blocked1>) -> !ttg.memdesc<64x64xf32, #shared2, #smem> loc(#loc15)
4. TritonTensorMemoryAllocationPass
67-AllocateSharedMemory.mlir vs 68-TritonTensorMemoryAllocationPass.mlir: adds ttg.tensor_memory_size = 64 to the module.
5. TritonGPUGlobalScratchAllocationPass
68-TritonTensorMemoryAllocationPass.mlir vs 69-TritonGPUGlobalScratchAllocationPass.mlir: adds ttg.global_scratch_memory_alignment = 128 : i32 and ttg.global_scratch_memory_size = 384 : i32.
Global scratch memory is indeed used here: it is passed as an operand to ttng.tensormap_create and holds the descriptor data structure that is staged in shared memory.
// old
%10 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr<i8> loc(#loc6)
%11 = arith.muli %6, %c4_i64 : i64 loc(#loc6)
ttng.tensormap_create %10, %arg2, [%c32_i32, %c128_i32], [%arg5, %arg3], [%11], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc6)
// new
%10 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32, ttg.global_scratch_memory_offset = 256 : i32} : !tt.ptr<i8> loc(#loc6)
%11 = arith.muli %6, %c4_i64 : i64 loc(#loc6)
ttng.tensormap_create %10, %arg2, [%c32_i32, %c128_i32], [%arg5, %arg3], [%11], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc6)
6. TritonGPUProxyFenceInsertion
69-TritonGPUGlobalScratchAllocationPass.mlir vs 70-TritonGPUProxyFenceInsertion.mlir: inserts ttng.fence_async_shared {bCluster = false} loc(#loc14) after %27.
7. ConvertTritonGPUToLLVM
70-TritonGPUProxyFenceInsertion.mlir vs 71-ConvertTritonGPUToLLVM.mlir: this pass is very involved, and you need to look at each Op's RewritePattern. At this step the IR is lowered down to the per-thread level.
// old
%10 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32, ttg.global_scratch_memory_offset = 256 : i32} : !tt.ptr<i8> loc(#loc6)
%11 = arith.muli %6, %c4_i64 : i64 loc(#loc6)
ttng.tensormap_create %10, %arg2, [%c32_i32, %c128_i32], [%arg5, %arg3], [%11], [%c1_i32, %c1_i32] {allocation.offset = 0 : i32, elem_type = 7 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<f32>, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc6)
ttng.tensormap_fenceproxy_acquire %10 : !tt.ptr<i8> loc(#loc6)
%12 = ttng.reinterpret_tensor_descriptor %10 : !tt.ptr<i8> to !tt.tensordesc<tensor<128x64xf32, #shared>> loc(#loc6)
...
#loc6 = loc("/home/ubuntu/triton/matmul.py":30:8)
// new
%326 = llvm.mlir.constant(256 : i32) : i32 loc(#loc6)
%327 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32 loc(#loc6)
%328 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.y"() : () -> i32 loc(#loc6)
%329 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.z"() : () -> i32 loc(#loc6)
%330 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.nctaid.x"() : () -> i32 loc(#loc6)
%331 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.nctaid.y"() : () -> i32 loc(#loc6)
%332 = llvm.mul %329, %331 : i32 loc(#loc6)
%333 = llvm.add %328, %332 : i32 loc(#loc6)
%334 = llvm.mul %333, %330 : i32 loc(#loc6)
%335 = llvm.add %327, %334 : i32 loc(#loc6)
%336 = llvm.mlir.constant(384 : i32) : i32 loc(#loc6)
%337 = llvm.mul %335, %336 : i32 loc(#loc6)
%338 = llvm.add %337, %326 : i32 loc(#loc6)
%339 = llvm.getelementptr %arg6[%338] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8 loc(#loc6)
%340 = llvm.mul %203, %77 : i64 loc(#loc6)
nvvm.barrier0 loc(#loc6)
%341 = llvm.mlir.constant(0 : i32) : i32 loc(#loc6)
%342 = llvm.mlir.addressof @global_smem : !llvm.ptr<3> loc(#loc)
%343 = llvm.getelementptr %342[%341] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 loc(#loc6)
%344 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc6)
%345 = llvm.mlir.constant(127 : i32) : i32 loc(#loc6)
%346 = llvm.and %344, %345 : i32 loc(#loc6)
%347 = llvm.mlir.constant(32 : i32) : i32 loc(#loc6)
%348 = llvm.icmp "slt" %346, %347 : i32 loc(#loc6)
%349 = llvm.mlir.constant(0 : i32) : i32 loc(#loc6)
%350 = llvm.getelementptr %343[%346] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i32 loc(#loc6)
%351 = llvm.mlir.undef : vector<1xi32> loc(#loc6)
%352 = llvm.mlir.constant(0 : i32) : i32 loc(#loc6)
%353 = llvm.insertelement %349, %351[%352 : i32] : vector<1xi32> loc(#loc6)
%354 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b" %350, %353, %348 : (!llvm.ptr<3>, vector<1xi32>, i1) -> !llvm.void loc(#loc6)
%355 = llvm.mlir.constant(-1 : i32) : i32 loc(#loc6)
%356 = llvm.call_intrinsic "llvm.nvvm.bar.warp.sync"(%355) : (i32) -> !llvm.void loc(#loc6)
%357 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc6)
%358 = llvm.mlir.constant(127 : i32) : i32 loc(#loc6)
%359 = llvm.and %357, %358 : i32 loc(#loc6)
%360 = llvm.mlir.constant(0 : i32) : i32 loc(#loc6)
%361 = llvm.icmp "eq" %359, %360 : i32 loc(#loc6)
%362 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 tensormap.replace.tile.global_address.shared::cta.b1024.b64 [ $0 + 0 ], $1;", "l,l,b" %343, %arg2, %361 : (!llvm.ptr<3>, !llvm.ptr<1>, i1) -> !llvm.void loc(#loc6)
%363 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc6)
%364 = llvm.mlir.constant(127 : i32) : i32 loc(#loc6)
%365 = llvm.and %363, %364 : i32 loc(#loc6)
%366 = llvm.mlir.constant(0 : i32) : i32 loc(#loc6)
%367 = llvm.icmp "eq" %365, %366 : i32 loc(#loc6)
%368 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$1 tensormap.replace.tile.rank.shared::cta.b1024.b32 [ $0 + 0 ], 0x1;", "l,b" %343, %367 : (!llvm.ptr<3>, i1) -> !llvm.void loc(#loc6)
...
%435 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc6)
%436 = llvm.mlir.constant(127 : i32) : i32 loc(#loc6)
%437 = llvm.and %435, %436 : i32 loc(#loc6)
%438 = llvm.mlir.constant(32 : i32) : i32 loc(#loc6)
%439 = llvm.icmp "slt" %437, %438 : i32 loc(#loc6)
%440 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [ $0 + 0 ], [ $1 + 0 ], 0x80;", "l,l,b" %339, %343, %439 : (!llvm.ptr<1>, !llvm.ptr<3>, i1) -> !llvm.void loc(#loc6)
%441 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc6)
%442 = llvm.mlir.constant(127 : i32) : i32 loc(#loc6)
%443 = llvm.and %441, %442 : i32 loc(#loc6)
%444 = llvm.mlir.constant(32 : i32) : i32 loc(#loc6)
%445 = llvm.icmp "slt" %443, %444 : i32 loc(#loc6)
%446 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$1 fence.proxy.tensormap::generic.acquire.gpu [ $0 + 0 ], 0x80;\0A\09@$2 cp.async.bulk.commit_group ;\0A\09@$3 cp.async.bulk.wait_group.read 0 ;", "l,b,b,b" %339, %445, %445, %445 : (!llvm.ptr<1>, i1, i1, i1) -> !llvm.void loc(#loc6)
nvvm.barrier0 loc(#loc6)
%447 = llvm.addrspacecast %339 : !llvm.ptr<1> to !llvm.ptr loc(#loc6)
...
#loc6 = loc("/home/ubuntu/triton/matmul.py":30:8)
8. Canonicalizer
71-ConvertTritonGPUToLLVM.mlir vs 72-Canonicalizer.mlir: the canonicalization pass is aggressive here; the IR drops from 4629 lines to 2555.
9. CSE
72-Canonicalizer.mlir vs 73-CSE.mlir: common subexpression elimination; the IR drops from 2555 lines to 1945.
10. ConvertNVGPUToLLVM
73-CSE.mlir vs 74-ConvertNVGPUToLLVM.mlir: this pass lowers all ops of the nvgpu dialect, e.g. nvgpu.tensor_memory_base, nvgpu.warp_id, nvgpu.fence_async_shared, nvgpu.ldmatrix, and so on.
11. Canonicalizer
76-ReconcileUnrealizedCastsPass.mlir vs 77-Canonicalizer.mlir: 2 lines of IR are optimized away.
12. CSE
77-Canonicalizer.mlir vs 78-CSE.mlir: 2 lines of IR are optimized away.
13. LLVMDIScope
79-SymbolDCE vs 80-LLVMDIScope: a pass that attaches debug-info scopes to the LLVM IR, generating debug information (e.g., DWARF) to support source-level debugging.
14. IR produced by this stage
The IR product of this stage is matmul_kernel_make_tensor_desciptor.llir.
VI. make_ptx
This step actually just calls into LLVM.
def make_ptx(self, src, metadata, opt, capability):
ptx_version = get_ptx_version_from_options(opt, self.target.arch)
triple = 'nvptx64-nvidia-cuda'
proc = sm_arch_from_capability(capability)
features = get_features(opt, self.target.arch)
ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
# Find kernel names (there should only be one)
names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
assert len(names) == 1
metadata["name"] = names[0]
# post-process
ptx_version = f'{ptx_version//10}.{ptx_version%10}'
ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
# Remove the debug flag that prevents ptxas from optimizing the code
ret = re.sub(r",\s*debug|debug,\s*", "", ret)
if knobs.nvidia.dump_nvptx:
print("// -----// NVPTX Dump //----- //")
print(ret)
return ret
The product is matmul_kernel_make_tensor_desciptor.ptx. The input file has 1072 lines and the output 1174, and they correspond roughly one to one.
1. dot, side by side
// old
tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #5, !dbg !21
tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !21
%260 = icmp eq i32 %44, 0, !dbg !21
%261 = and i1 %51, %260, !dbg !21
br i1 %261, label %262, label %329, !dbg !21
262: ; preds = %7
%263 = tail call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1), !dbg !21
%264 = extractvalue { i32, i1 } %263, 1, !dbg !21
%265 = lshr exact i32 ptrtoint (ptr addrspace(3) @global_smem to i32), 4, !dbg !21
%266 = and i32 %265, 16383, !dbg !21
%267 = zext nneg i32 %266 to i64, !dbg !21
%268 = or disjoint i64 %267, 4611686293372403712, !dbg !21
%269 = lshr exact i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), 4, !dbg !21
%270 = and i32 %269, 16383, !dbg !21
%271 = zext nneg i32 %270 to i64, !dbg !21
%272 = or disjoint i64 %271, 4611686293338849280, !dbg !21
tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %268, i64 %272, i32 135268624, i1 false, i1 %264) #5, !dbg !21
%273 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 32), 4, !dbg !21
%274 = and i32 %273, 16383, !dbg !21
%275 = zext nneg i32 %274 to i64, !dbg !21
%276 = or disjoint i64 %275, 4611686293372403712, !dbg !21
%277 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 32), 4, !dbg !21
%278 = and i32 %277, 16383, !dbg !21
%279 = zext nneg i32 %278 to i64, !dbg !21
%280 = or disjoint i64 %279, 4611686293338849280, !dbg !21
tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %276, i64 %280, i32 135268624, i1 true, i1 %264) #5, !dbg !21
%281 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 64), 4, !dbg !21
%282 = and i32 %281, 16383, !dbg !21
%283 = zext nneg i32 %282 to i64, !dbg !21
%284 = or disjoint i64 %283, 4611686293372403712, !dbg !21
%285 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 64), 4, !dbg !21
%286 = and i32 %285, 16383, !dbg !21
%287 = zext nneg i32 %286 to i64, !dbg !21
%288 = or disjoint i64 %287, 4611686293338849280, !dbg !21
tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %284, i64 %288, i32 135268624, i1 true, i1 %264) #5, !dbg !21
%289 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 96), 4, !dbg !21
%290 = and i32 %289, 16383, !dbg !21
%291 = zext nneg i32 %290 to i64, !dbg !21
%292 = or disjoint i64 %291, 4611686293372403712, !dbg !21
%293 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 96), 4, !dbg !21
%294 = and i32 %293, 16383, !dbg !21
%295 = zext nneg i32 %294 to i64, !dbg !21
%296 = or disjoint i64 %295, 4611686293338849280, !dbg !21
tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %292, i64 %296, i32 135268624, i1 true, i1 %264) #5, !dbg !21
%297 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 16384), 4, !dbg !21
%298 = and i32 %297, 16383, !dbg !21
%299 = zext nneg i32 %298 to i64, !dbg !21
%300 = or disjoint i64 %299, 4611686293372403712, !dbg !21
%301 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 8192), 4, !dbg !21
%302 = and i32 %301, 16383, !dbg !21
%303 = zext nneg i32 %302 to i64, !dbg !21
%304 = or disjoint i64 %303, 4611686293338849280, !dbg !21
tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %300, i64 %304, i32 135268624, i1 true, i1 %264) #5, !dbg !21
%305 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 16416), 4, !dbg !21
%306 = and i32 %305, 16383, !dbg !21
%307 = zext nneg i32 %306 to i64, !dbg !21
%308 = or disjoint i64 %307, 4611686293372403712, !dbg !21
%309 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 8224), 4, !dbg !21
%310 = and i32 %309, 16383, !dbg !21
%311 = zext nneg i32 %310 to i64, !dbg !21
%312 = or disjoint i64 %311, 4611686293338849280, !dbg !21
tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %308, i64 %312, i32 135268624, i1 true, i1 %264) #5, !dbg !21
%313 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 16448), 4, !dbg !21
%314 = and i32 %313, 16383, !dbg !21
%315 = zext nneg i32 %314 to i64, !dbg !21
%316 = or disjoint i64 %315, 4611686293372403712, !dbg !21
%317 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 8256), 4, !dbg !21
%318 = and i32 %317, 16383, !dbg !21
%319 = zext nneg i32 %318 to i64, !dbg !21
%320 = or disjoint i64 %319, 4611686293338849280, !dbg !21
tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %316, i64 %320, i32 135268624, i1 true, i1 %264) #5, !dbg !21
%321 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) @global_smem to i32), i32 16480), 4, !dbg !21
%322 = and i32 %321, 16383, !dbg !21
%323 = zext nneg i32 %322 to i64, !dbg !21
%324 = or disjoint i64 %323, 4611686293372403712, !dbg !21
%325 = lshr exact i32 add (i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 147456) to i32), i32 8288), 4, !dbg !21
%326 = and i32 %325, 16383, !dbg !21
%327 = zext nneg i32 %326 to i64, !dbg !21
%328 = or disjoint i64 %327, 4611686293338849280, !dbg !21
tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %10, i64 %324, i64 %328, i32 135268624, i1 true, i1 %264) #5, !dbg !21
tail call void asm sideeffect "@$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];", "b,l"(i1 %264, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 180256)) #5, !dbg !21
br label %329, !dbg !21
...
!21 = !DILocation(line: 40, column: 32, scope: !5)
// new
.loc 1 40 32 // matmul.py:40:32
// begin inline asm
fence.proxy.async.shared::cta;
// end inline asm
bar.sync 0;
@%p81 bra $L__BB0_6;
// %bb.5: // in Loop: Header=BB0_4 Depth=1
.loc 1 38 24 // matmul.py:38:24
shl.b32 %r446, %r585, 15;
add.s32 %r448, %r62, %r446;
.loc 1 40 32 // matmul.py:40:32
elect.sync %r449|%p115, -1;
bfe.u32 %r450, %r448, 4, 14;
cvt.u64.u32 %rd134, %r450;
or.b64 %rd117, %rd134, 4611686293372403712;
mov.b32 %r431, 135268624;
// begin inline asm
@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd117, %rd118, %r431, %p114;
// end inline asm
add.s32 %r451, %r448, 32;
bfe.u32 %r452, %r451, 4, 14;
cvt.u64.u32 %rd135, %r452;
or.b64 %rd119, %rd135, 4611686293372403712;
// begin inline asm
@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd119, %rd120, %r431, %p114;
// end inline asm
add.s32 %r453, %r448, 64;
bfe.u32 %r454, %r453, 4, 14;
cvt.u64.u32 %rd136, %r454;
or.b64 %rd121, %rd136, 4611686293372403712;
// begin inline asm
@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd121, %rd122, %r431, %p114;
// end inline asm
add.s32 %r455, %r448, 96;
bfe.u32 %r456, %r455, 4, 14;
cvt.u64.u32 %rd137, %r456;
or.b64 %rd123, %rd137, 4611686293372403712;
// begin inline asm
@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd123, %rd124, %r431, %p114;
// end inline asm
add.s32 %r457, %r448, 16384;
bfe.u32 %r458, %r457, 4, 14;
cvt.u64.u32 %rd138, %r458;
or.b64 %rd125, %rd138, 4611686293372403712;
// begin inline asm
@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd125, %rd126, %r431, %p114;
// end inline asm
add.s32 %r459, %r448, 16416;
bfe.u32 %r460, %r459, 4, 14;
cvt.u64.u32 %rd139, %r460;
or.b64 %rd127, %rd139, 4611686293372403712;
// begin inline asm
@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd127, %rd128, %r431, %p114;
// end inline asm
add.s32 %r461, %r448, 16448;
bfe.u32 %r462, %r461, 4, 14;
cvt.u64.u32 %rd140, %r462;
or.b64 %rd129, %rd140, 4611686293372403712;
// begin inline asm
@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd129, %rd130, %r431, %p114;
// end inline asm
add.s32 %r463, %r448, 16480;
bfe.u32 %r464, %r463, 4, 14;
cvt.u64.u32 %rd141, %r464;
or.b64 %rd131, %rd141, 4611686293372403712;
// begin inline asm
@%p115 tcgen05.mma.cta_group::1.kind::tf32 [ %r561 + 0 ], %rd131, %rd132, %r431, %p114;
// end inline asm
cvt.u64.u32 %rd133, %r590;
// begin inline asm
@%p115 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [%rd133];
// end inline asm
bra.uni $L__BB0_6;
2. store, side by side
// old
%651 = and i32 %160, 16256, !dbg !26
%652 = or disjoint i32 %651, %162, !dbg !26
%653 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %652, !dbg !26
%654 = insertelement <4 x i32> poison, i32 %587, i64 0, !dbg !26
%655 = insertelement <4 x i32> %654, i32 %588, i64 1, !dbg !26
%656 = insertelement <4 x i32> %655, i32 %589, i64 2, !dbg !26
%657 = insertelement <4 x i32> %656, i32 %590, i64 3, !dbg !26
store <4 x i32> %657, ptr addrspace(3) %653, align 16, !dbg !26
...
%729 = xor i32 %652, 112, !dbg !26
%730 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %729, !dbg !26
%731 = insertelement <4 x i32> poison, i32 %615, i64 0, !dbg !26
%732 = insertelement <4 x i32> %731, i32 %616, i64 1, !dbg !26
%733 = insertelement <4 x i32> %732, i32 %617, i64 2, !dbg !26
%734 = insertelement <4 x i32> %733, i32 %618, i64 3, !dbg !26
store <4 x i32> %734, ptr addrspace(3) %730, align 16, !dbg !26
%735 = getelementptr inbounds nuw i8, ptr addrspace(3) %730, i32 16384, !dbg !26
%736 = insertelement <4 x i32> poison, i32 %647, i64 0, !dbg !26
%737 = insertelement <4 x i32> %736, i32 %648, i64 1, !dbg !26
%738 = insertelement <4 x i32> %737, i32 %649, i64 2, !dbg !26
%739 = insertelement <4 x i32> %738, i32 %650, i64 3, !dbg !26
store <4 x i32> %739, ptr addrspace(3) %735, align 16, !dbg !26
tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #5, !dbg !26
tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !26
%740 = tail call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1), !dbg !26
%741 = extractvalue { i32, i1 } %740, 1, !dbg !26
%742 = and i1 %56, %741, !dbg !26
tail call void asm sideeffect "@$0 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$1, {$2, $3}], [$4];", "b,l,r,r,r"(i1 %742, ptr %584, i32 %68, i32 %41, ptr addrspace(3) %60) #5, !dbg !26
tail call void @llvm.nvvm.cp.async.bulk.commit.group(), !dbg !26
tail call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0), !dbg !26
tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !26
...
!26 = !DILocation(line: 43, column: 63, scope: !5)
// new
.loc 1 43 63 // matmul.py:43:63
and.b32 %r563, %r23, 16256;
or.b32 %r564, %r563, %r24;
add.s32 %r565, %r62, %r564;
st.shared.v4.b32 [%r565], {%r493, %r494, %r495, %r496};
st.shared.v4.b32 [%r565+16384], {%r525, %r526, %r527, %r528};
xor.b32 %r566, %r564, 16;
add.s32 %r567, %r62, %r566;
st.shared.v4.b32 [%r567], {%r497, %r498, %r499, %r500};
st.shared.v4.b32 [%r567+16384], {%r529, %r530, %r531, %r532};
xor.b32 %r568, %r564, 32;
add.s32 %r569, %r62, %r568;
st.shared.v4.b32 [%r569], {%r501, %r502, %r503, %r504};
st.shared.v4.b32 [%r569+16384], {%r533, %r534, %r535, %r536};
xor.b32 %r570, %r564, 48;
add.s32 %r571, %r62, %r570;
st.shared.v4.b32 [%r571], {%r505, %r506, %r507, %r508};
st.shared.v4.b32 [%r571+16384], {%r537, %r538, %r539, %r540};
xor.b32 %r572, %r564, 64;
add.s32 %r573, %r62, %r572;
st.shared.v4.b32 [%r573], {%r509, %r510, %r511, %r512};
st.shared.v4.b32 [%r573+16384], {%r541, %r542, %r543, %r544};
xor.b32 %r574, %r564, 80;
add.s32 %r575, %r62, %r574;
st.shared.v4.b32 [%r575], {%r513, %r514, %r515, %r516};
st.shared.v4.b32 [%r575+16384], {%r545, %r546, %r547, %r548};
xor.b32 %r576, %r564, 96;
add.s32 %r577, %r62, %r576;
st.shared.v4.b32 [%r577], {%r517, %r518, %r519, %r520};
st.shared.v4.b32 [%r577+16384], {%r549, %r550, %r551, %r552};
xor.b32 %r578, %r564, 112;
add.s32 %r579, %r62, %r578;
st.shared.v4.b32 [%r579], {%r521, %r522, %r523, %r524};
st.shared.v4.b32 [%r579+16384], {%r553, %r554, %r555, %r556};
// begin inline asm
fence.proxy.async.shared::cta;
// end inline asm
bar.sync 0;
elect.sync %r580|%p153, -1;
and.pred %p150, %p73, %p153;
// begin inline asm
@%p150 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd144, {%r558, %r559}], [%r560];
// end inline asm
cp.async.bulk.commit_group;
cp.async.bulk.wait_group.read 0;
bar.sync 0;
3. load, side by side
// old
tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !23
%331 = tail call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1), !dbg !23
%332 = extractvalue { i32, i1 } %331, 1, !dbg !23
%333 = and i1 %82, %332, !dbg !23
%334 = and i1 %56, %333, !dbg !23
%335 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %59, !dbg !23
%336 = or disjoint i32 %61, 128, !dbg !23
tail call void asm sideeffect "@$0 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4}], [$5];", "b,r,l,r,r,r"(i1 %334, ptr addrspace(3) %335, ptr %29, i32 %336, i32 %41, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 180240)) #5, !dbg !23
...
!23 = !DILocation(line: 38, column: 24, scope: !5)
// new
.loc 1 38 24 // matmul.py:38:24
bar.sync 0;
elect.sync %r341|%p107, -1;
and.pred %p108, %p103, %p107;
and.pred %p101, %p73, %p108;
shl.b32 %r342, %r9, 2;
add.s32 %r343, %r62, %r342;
add.s32 %r330, %r343, 65536;
or.b32 %r331, %r159, 128;
// begin inline asm
@%p101 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [%r330], [%rd66, {%r331, %r559}], [%r329];
// end inline asm
VII. Posts in this series
A Deep Dive into Triton Compiler MatMul Optimization (Part 2): MMA
This post is from cnblogs (博客园). Author: 暴力都不会的蒟蒻. When reposting, please credit the original link: https://www.cnblogs.com/BobHuang/p/18958924
