In-Depth Analysis of Triton Compiler MatMul Optimization (Part 1): FMA
This article analyzes the compilation flow of a naive (untiled) Triton matmul kernel on an NVIDIA B200, from Python -> TTIR -> TTGIR -> LLVM IR -> PTX. It opens a series on Triton's matmul optimizations and its support for the new Blackwell features. First, performance: with autotune, the Triton kernel achieves a 1.40× speedup over the CUDA baseline.
CUDA native kernel vs Triton native kernel
My previous Triton article introduced Triton, analyzed the pass pipeline of a vectorAdd kernel, and briefly covered the JIT flow; it drew a lot of attention and encouragement, for which I thank everyone. This article focuses on how each pass changes the IR; for building the project and the overall flow, please refer to the previous article, 浅析 Triton 执行流程, which I won't repeat here.
This article uses Triton at commit bc75dd0 (Jun 27, 2025). All IR and kernel files have been uploaded to GitHub: sBobHuang/Triton-blog-file
1. matmul Triton kernel
1.1 Writing the kernel
The Triton kernel is shown below. Matrix a is M×N, matrix b is N×K, and the result matrix c is M×K.
# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl
@triton.jit
def matrix_multiplication_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_an,
stride_bn, stride_bk,
stride_cm, stride_ck,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
):
a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
pid_k = tl.program_id(axis=0)
pid_m = tl.program_id(axis=1)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
# Initialize the A and B pointers
a_ptrs = a_ptr + offs_m[:, None] * stride_am
b_ptrs = b_ptr + offs_k[None, :] * stride_bk
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
# Accumulate along the N dimension
for n in range(N):
# Load the current values of A and B
a = tl.load(a_ptrs + n * stride_an)
b = tl.load(b_ptrs + n * stride_bn)
accumulator += a * b
# Write the result back to C
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_ck = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask)
# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
stride_am, stride_an = N, 1
stride_bn, stride_bk = K, 1
stride_cm, stride_ck = K, 1
grid = lambda META: (triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(M, META['BLOCK_SIZE_M']), )
matrix_multiplication_kernel[grid](
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_an,
stride_bn, stride_bk,
stride_cm, stride_ck,
BLOCK_SIZE_M=128,
BLOCK_SIZE_K=64,
)
1.2 A quick walkthrough of the kernel
For an output element c[i][j], we multiply row i of a by column j of b, i.e. c[i][j] = sum(a[i][n] * b[n][j] for n in 0..N-1).
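As a reference, the same computation in plain Python (a naive sketch to pin down the semantics, not meant to be fast):
def matmul_ref(a, b, M, N, K):
    # c[i][j] accumulates a[i][n] * b[n][j] over n = 0..N-1
    c = [[0.0] * K for _ in range(M)]
    for i in range(M):
        for j in range(K):
            acc = 0.0
            for n in range(N):
                acc += a[i][n] * b[n][j]
            c[i][j] = acc
    return c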
In the formula, a's starting position is a[i][0], whose pointer is a_ptr + i * stride_am. Where does i come from? We split the grid along axis 1 into triton.cdiv(M, META['BLOCK_SIZE_M']) blocks; pid_m is the block id this program gets, and block id times block size gives the starting row. Since each program handles BLOCK_SIZE_M elements, i is really a set of indices, pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M), which [:, None] then expands to shape [BLOCK_SIZE_M, 1]. By the same reasoning, b yields a [1, BLOCK_SIZE_K] block.
Each program processes a [BLOCK_SIZE_M, BLOCK_SIZE_K] tile across the loop, so we create an accumulator with tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32). Every iteration loads a[i][n] and b[n][j], multiplies them, and adds the product into the accumulator.
The pointer computation for the result c is analogous, but the store needs a mask: checking (offs_cm[:, None] < M) & (offs_ck[None, :] < K) is enough.
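The index arithmetic is easy to check outside Triton. A small NumPy sketch (the sizes are hypothetical, chosen so the last block hangs off the edge):
import numpy as np

M, K = 10, 10                       # hypothetical matrix dims
BLOCK_SIZE_M, BLOCK_SIZE_K = 4, 4
pid_m, pid_k = 2, 2                 # the last block in each dimension

offs_cm = pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)  # rows 8..11
offs_ck = pid_k * BLOCK_SIZE_K + np.arange(BLOCK_SIZE_K)  # cols 8..11
# [:, None] -> shape (4, 1); [None, :] -> shape (1, 4);
# & broadcasts both to the full (4, 4) tile mask
c_mask = (offs_cm[:, None] < M) & (offs_ck[None, :] < K)
print(c_mask)  # True only for the 2x2 corner still inside C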
2. ast_to_ttir
The JIT decorator walks the Python AST and eventually calls MLIR's self.create< builder methods; see the previous article for details.
2.1 The loop
The resulting TTIR is fairly verbose; the full IR is in 01-source.mlir. Let's pick out the loop, shown below.
%60 = scf.for %arg9 = %56 to %57 step %58 iter_args(%arg10 = %55) -> (tensor<128x64xf32>) : i32 {
%c1_i32_49 = arith.constant 1 : i32 loc(#loc17)
%c1_i32_50 = arith.constant 1 : i32 loc(#loc17)
%124 = arith.extsi %arg9 : i32 to i64 loc(#loc17)
%125 = arith.extsi %c1_i32_50 : i32 to i64 loc(#loc17)
%126 = arith.muli %124, %125 : i64 loc(#loc17)
%c2147483647_i64_51 = arith.constant 2147483647 : i64 loc(#loc17)
%c-2147483648_i64_52 = arith.constant -2147483648 : i64 loc(#loc17)
%127 = arith.cmpi sle, %126, %c2147483647_i64_51 : i64 loc(#loc17)
%128 = arith.cmpi sge, %126, %c-2147483648_i64_52 : i64 loc(#loc17)
%129 = arith.andi %127, %128 : i1 loc(#loc17)
%130 = arith.muli %arg9, %c1_i32_50 : i32 loc(#loc17)
%131 = tt.splat %130 : i32 -> tensor<128x1xi32> loc(#loc18)
%132 = tt.addptr %44, %131 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32> loc(#loc18)
%133 = tt.load %132 : tensor<128x1x!tt.ptr<f32>> loc(#loc19)
%134 = arith.extsi %arg9 : i32 to i64 loc(#loc20)
%135 = arith.extsi %arg7 : i32 to i64 loc(#loc20)
%136 = arith.muli %134, %135 : i64 loc(#loc20)
%c2147483647_i64_53 = arith.constant 2147483647 : i64 loc(#loc20)
%c-2147483648_i64_54 = arith.constant -2147483648 : i64 loc(#loc20)
%137 = arith.cmpi sle, %136, %c2147483647_i64_53 : i64 loc(#loc20)
%138 = arith.cmpi sge, %136, %c-2147483648_i64_54 : i64 loc(#loc20)
%139 = arith.andi %137, %138 : i1 loc(#loc20)
%140 = arith.muli %arg9, %arg7 : i32 loc(#loc20)
%141 = tt.splat %140 : i32 -> tensor<1x64xi32> loc(#loc21)
%142 = tt.addptr %54, %141 : tensor<1x64x!tt.ptr<f32>>, tensor<1x64xi32> loc(#loc21)
%143 = tt.load %142 : tensor<1x64x!tt.ptr<f32>> loc(#loc22)
%144 = tt.broadcast %133 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc23)
%145 = tt.broadcast %143 : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc23)
%146 = arith.mulf %144, %145 : tensor<128x64xf32> loc(#loc23)
%147 = arith.addf %arg10, %146 : tensor<128x64xf32> loc(#loc24)
scf.yield %147 : tensor<128x64xf32> loc(#loc25)
} loc(#loc16)
This corresponds to the loop in the Triton kernel, shown below.
for n in range(N):
a = tl.load(a_ptrs + n * stride_an)
b = tl.load(b_ptrs + n * stride_bn)
accumulator += a * b
2.2 sanitize_overflow
One stretch of IR multiplies the loop variable by 1, then compares the product against 2147483647 and -2147483648, producing %129. Since %129 is never used in later IR, this code gets optimized away. It is generated by binary_op_sanitize_overflow_impl at python/triton/language/semantic.py:206.
%c1_i32_50 = arith.constant 1 : i32 loc(#loc17)
%124 = arith.extsi %arg9 : i32 to i64 loc(#loc17)
%125 = arith.extsi %c1_i32_50 : i32 to i64 loc(#loc17)
%126 = arith.muli %124, %125 : i64 loc(#loc17)
%c2147483647_i64_51 = arith.constant 2147483647 : i64 loc(#loc17)
%c-2147483648_i64_52 = arith.constant -2147483648 : i64 loc(#loc17)
%127 = arith.cmpi sle, %126, %c2147483647_i64_51 : i64 loc(#loc17)
%128 = arith.cmpi sge, %126, %c-2147483648_i64_52 : i64 loc(#loc17)
%129 = arith.andi %127, %128 : i1 loc(#loc17)
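Semantically, the emitted check boils down to this (a Python sketch of what the IR computes, not the actual binary_op_sanitize_overflow_impl code):
INT32_MAX, INT32_MIN = 2**31 - 1, -(2**31)   # 2147483647, -2147483648

def mul_fits_in_i32(lhs: int, rhs: int) -> bool:
    # arith.extsi to i64, multiply in i64, then range-check the product
    prod = lhs * rhs                          # exact; Python ints are unbounded
    return INT32_MIN <= prod <= INT32_MAX     # %127 and %128, joined by arith.andi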
2.3 The remaining parts
Next the offset is computed and a is loaded. stride_an is 1 and is not passed as a runtime parameter (Triton specializes integer arguments whose value is 1 into constants). The load of b works the same way.
%130 = arith.muli %arg9, %c1_i32_50 : i32 loc(#loc17)
%131 = tt.splat %130 : i32 -> tensor<128x1xi32> loc(#loc18)
%132 = tt.addptr %44, %131 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32> loc(#loc18)
%133 = tt.load %132 : tensor<128x1x!tt.ptr<f32>> loc(#loc19)
a and b have different shapes, so a broadcast is needed; after that it is a simple multiply and add.
%144 = tt.broadcast %133 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc23)
%145 = tt.broadcast %143 : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc23)
%146 = arith.mulf %144, %145 : tensor<128x64xf32> loc(#loc23)
%147 = arith.addf %arg10, %146 : tensor<128x64xf32> loc(#loc24)
accumulator is carried in iter_args, so it must also be yielded.
%60 = scf.for %arg9 = %56 to %57 step %58 iter_args(%arg10 = %55) -> (tensor<128x64xf32>) : i32 {
...
scf.yield %147 : tensor<128x64xf32> loc(#loc25)
} loc(#loc16)
The rest of the IR reads similarly; you can also hand 01-source.mlir to ChatGPT for an explanation. At this point the IR matches the Python DSL semantics exactly, with no optimization applied.
3. make_ttir
In this stage we run the pipeline below.
@staticmethod
def make_ttir(mod, metadata, opt, capability):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.common.add_inliner(pm)
passes.ttir.add_rewrite_tensor_pointer(pm)
if capability // 10 < 9:
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_combine(pm)
passes.ttir.add_reorder_broadcast(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
passes.ttir.add_loop_unroll(pm)
pm.run(mod)
return mod
B200 is SM_100, so the add_rewrite_tensor_descriptor_to_pointer pass is skipped. After these passes the IR goes from 02-Inliner.mlir to 12-TritonLoopUnroll.mlir.
3.1 Canonicalizer
Canonicalization is a general-purpose pass that shows up again and again. Its work includes eliminating redundant operations, simplifying matched patterns, folding constant computations, and applying the canonicalizations defined on each Op.
02-Inliner.mlir vs 03-Canonicalizer.mlir
// old
tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_64__(1,)cconstexpr_fp32_"() -> tensor<128x64xf32> attributes {noinline = false} {
%cst = arith.constant 0.000000e+00 : f32 loc(#loc46)
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc46)
tt.return %cst_0 : tensor<128x64xf32> loc(#loc47)
^bb1: // no predecessors
%0 = ub.poison : tensor<128x64xf32> loc(#loc48)
tt.return %0 : tensor<128x64xf32> loc(#loc48)
} loc(#loc45)
// new
tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_64__(1,)cconstexpr_fp32_"() -> tensor<128x64xf32> attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc46)
tt.return %cst : tensor<128x64xf32> loc(#loc47)
} loc(#loc45)
03-Canonicalizer.mlir vs 04-Canonicalizer.mlir changes a lot: all the unused Ops we examined in the loop are removed.
%18 = scf.for %arg9 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg10 = %cst) -> (tensor<128x64xf32>) : i32 {
%45 = tt.splat %arg9 : i32 -> tensor<128x1xi32> loc(#loc18)
%46 = tt.addptr %14, %45 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32> loc(#loc18)
%47 = tt.load %46 : tensor<128x1x!tt.ptr<f32>> loc(#loc19)
%48 = arith.muli %arg9, %arg7 : i32 loc(#loc20)
%49 = tt.splat %48 : i32 -> tensor<1x64xi32> loc(#loc21)
%50 = tt.addptr %17, %49 : tensor<1x64x!tt.ptr<f32>>, tensor<1x64xi32> loc(#loc21)
%51 = tt.load %50 : tensor<1x64x!tt.ptr<f32>> loc(#loc22)
%52 = tt.broadcast %47 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc23)
%53 = tt.broadcast %51 : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc23)
%54 = arith.mulf %52, %53 : tensor<128x64xf32> loc(#loc23)
%55 = arith.addf %arg10, %54 : tensor<128x64xf32> loc(#loc24)
scf.yield %55 : tensor<128x64xf32> loc(#loc25)
} loc(#loc17)
04-Canonicalizer.mlir vs 05-Canonicalizer.mlir changes little: loc updates, plus removal of the @"triton.language.standard.zeros..." func.
3.2 CSE
Common Subexpression Elimination: find equivalent operations and reuse them. In 09-TritonReorderBroadcast.mlir vs 10-CSE.mlir, for example, tt.expand_dims %5 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc24) appeared as both %15 and %19; after CSE %19 is gone and arith.muli uses %15 directly. The other changes are similar.
3.3 Final result
The final artifact is matrix_multiplication_kernel.ttir, which is the same as 12-TritonLoopUnroll.mlir.
4. make_ttgir
The optimizations keep piling up, and this stage has grown accordingly, as shown below.
@staticmethod
def make_ttgir(mod, metadata, opt, capability):
# Set maxnreg on all kernels, if it was provided.
if opt.maxnreg is not None:
mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))
cluster_info = nvidia.ClusterInfo()
if opt.cluster_dims is not None:
cluster_info.clusterDimX = opt.cluster_dims[0]
cluster_info.clusterDimY = opt.cluster_dims[1]
cluster_info.clusterDimZ = opt.cluster_dims[2]
pm = ir.pass_manager(mod.context)
dump_enabled = pm.enable_debug()
passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
# optimize TTGIR
passes.ttgpuir.add_coalesce(pm)
if capability // 10 >= 8:
passes.ttgpuir.add_f32_dot_tc(pm)
# TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
passes.ttgpuir.add_accelerate_matmul(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
passes.ttir.add_loop_aware_cse(pm)
if capability // 10 in [8, 9]:
passes.ttgpuir.add_fuse_nested_loops(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_triton_licm(pm)
passes.common.add_canonicalizer(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
passes.ttgpuir.add_schedule_loops(pm)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
elif capability // 10 >= 10:
passes.ttgpuir.add_fuse_nested_loops(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_triton_licm(pm)
passes.ttgpuir.add_optimize_accumulator_init(pm)
passes.ttgpuir.add_hoist_tmem_alloc(pm)
nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
passes.ttgpuir.add_schedule_loops(pm)
passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
else:
passes.ttir.add_triton_licm(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_loop_aware_cse(pm)
passes.ttgpuir.add_prefetch(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.ttgpuir.add_coalesce_async_copy(pm)
nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
passes.ttgpuir.add_reorder_instructions(pm)
passes.ttir.add_loop_aware_cse(pm)
passes.common.add_symbol_dce(pm)
if capability // 10 >= 9:
nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
passes.common.add_sccp(pm)
passes.common.add_canonicalizer(pm)
pm.run(mod)
metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
tensordesc_meta = mod.get_tensordesc_metadata()
metadata["tensordesc_meta"] = tensordesc_meta
return mod
Note the finer-grained branching on capability here; as before, we only analyze what changes. After these passes the IR goes from 13-ConvertTritonToTritonGPU.mlir to 61-Canonicalizer.mlir.
4.1 ConvertTritonToTritonGPU
Converts Triton IR to TritonGPU IR; the pass description follows.
This pass converts the Triton Dialect into the TritonGPU Dialect.
This is a partial conversion that also affects other dialects
(namely `Arith`, `Math`, `SCF` and `CF`).
For these dialects, and many Triton dialect operations the conversions
mainly consists of enhancing the tensor type and the `tt.ptr<tensor<>>`
type with an appropriate layout encoding (these encodings generally
include information on `numWarps`, `threadsPerWarp` and `numCTAs`).
12-TritonLoopUnroll.mlir vs 13-ConvertTritonToTritonGPU.mlir: the main change is that layouts are attached. Below are the memory-access layout annotations that get added.
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
Most of our Ops receive one of these annotations. Take #blocked3, the layout used to access a: sizePerThread = [1, 1] means each thread handles a single 1×1 element; threadsPerWarp = [32, 1] means a warp's 32 threads are laid out along dimension 0, the rows; warpsPerCTA = [4, 1] means the 4 warps of a Cooperative Thread Array (block) are also stacked along dimension 0, so the whole CTA covers 4 × 32 = 128 rows; order = [1, 0] means dimension 1 is the fastest-varying in the traversal order. Within one warp of #blocked3 (32 threads arranged as 32 rows × 1 column), the 32 threads access 32 consecutive rows of the same column.
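To make this concrete, here is a sketch of the (thread id) -> (row, col) mapping that #blocked3 implies for the 128x1 a-tile (my own reconstruction from the parameters above, not Triton's code):
def blocked3_coords(tid: int):
    # threadsPerWarp = [32, 1]: lanes stack along dim 0 (rows)
    # warpsPerCTA   = [4, 1]:  warps stack along dim 0 as well
    lane, warp = tid % 32, tid // 32
    return warp * 32 + lane, 0     # (row, col); one element per thread

# the CTA's 128 threads cover all 128 rows, one element each
assert sorted(blocked3_coords(t)[0] for t in range(128)) == list(range(128))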
Our module also gains some attributes that later passes will use. ttg.target is the CUDA compute capability; B200 is SM_100, hence "cuda:100".
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32}
4.2 TritonGPUCoalesce
Memory coalescing: make consecutive threads access consecutive memory addresses, to improve access efficiency and GPU bandwidth utilization. The #ttg.blocked annotations added above exist precisely for this pass.
13-ConvertTritonToTritonGPU.mlir vs 14-TritonGPUCoalesce.mlir: let's first look at how the final store is transformed.
// old
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
...
%22 = scf.for %arg9 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg10 = %cst) -> (tensor<128x64xf32, #blocked>) : i32
...
%30 = tt.addptr %28, %29 : tensor<128x64x!tt.ptr<f32>, #blocked>, tensor<128x64xi32, #blocked> loc(#loc26)
...
%38 = arith.andi %36, %37 : tensor<128x64xi1, #blocked> loc(#loc29)
tt.store %30, %22, %38 : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc30)
// new
#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
...
%39 = ttg.convert_layout %30 : tensor<128x64x!tt.ptr<f32>, #blocked> -> tensor<128x64x!tt.ptr<f32>, #blocked5> loc(#loc30)
%40 = ttg.convert_layout %22 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked5> loc(#loc30)
%41 = ttg.convert_layout %38 : tensor<128x64xi1, #blocked> -> tensor<128x64xi1, #blocked5> loc(#loc30)
tt.store %39, %40, %41 : tensor<128x64x!tt.ptr<f32>, #blocked5> loc(#loc30)
In the new layout, sizePerThread = [1, 4] means each thread accesses a small vector. threadsPerWarp = [2, 16] arranges the threads over 2 rows, interleaving memory accesses, which improves bandwidth utilization and reduces bank conflicts, whereas a [1, 32] arrangement would more easily straddle cache lines.
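A quick address check for one warp under the new layout (a back-of-the-envelope sketch; the row stride is assumed to be one contiguous 64-float tile row):
ROW_BYTES = 64 * 4                 # 64 f32 per tile row
spans = {}
for lane in range(32):
    r, c = lane // 16, (lane % 16) * 4        # 2 rows x 16 threads, 4 floats each
    addr = r * ROW_BYTES + c * 4
    spans.setdefault(r, []).append((addr, addr + 16))  # one 16B vector per thread
for r, s in sorted(spans.items()):
    lo, hi = min(a for a, _ in s), max(b for _, b in s)
    print(f"row {r}: bytes [{lo}, {hi})")     # each row: one contiguous 256B span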
Now look back at the earlier load: the inserted ttg.convert_layout is redundant.
// old
%40 = tt.addptr %16, %39 : tensor<128x1x!tt.ptr<f32>, #blocked3>, tensor<128x1xi32, #blocked3> loc(#loc16)
%41 = tt.load %40 : tensor<128x1x!tt.ptr<f32>, #blocked3> loc(#loc17)
// new
%43 = tt.addptr %16, %42 : tensor<128x1x!tt.ptr<f32>, #blocked3>, tensor<128x1xi32, #blocked3> loc(#loc16)
%44 = ttg.convert_layout %43 : tensor<128x1x!tt.ptr<f32>, #blocked3> -> tensor<128x1x!tt.ptr<f32>, #blocked3> loc(#loc17)
%45 = tt.load %44 : tensor<128x1x!tt.ptr<f32>, #blocked3> loc(#loc17)
4.3 TritonGPURemoveLayoutConversions
Removes superfluous convert_layout ops; we just saw that the load needs no conversion at all.
In 16-TritonGPUPlanCTAPass.mlir vs 17-TritonGPURemoveLayoutConversions.mlir, the number of layouts shrinks.
// old
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
// new
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
In effect, the first five layouts were all folded into #blocked3 (renamed to plain #blocked). Take a look at lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp: the comment explains how the pass operates; the analysis runs one-shot, then the whole IR is rewritten once in dominance order.
// The current algorithm works by analyzing the IR and doing a one-shot rewrite
// based on the analysis. The algorithm is as follows.
//
// 1. Find all the anchor ops. These are ops that have a layout we want to
// preserve.
//
// 2. For each anchor, propagate its layout to all its descendants.
// An op can have multiple ancestors that are anchors, so at this stage an op
// may have multiple layouts associated with it.
//
// 3. Resolve conflicts by deciding which of the multiple layouts the op should
// keep, inserting convert-layout ops to resolve conflicts. After this
// stage, each value has only one layout associated with it.
//
// 4. Rewrite the IR by walking the function in dominance order. Since we
// assume the IR is structured we just need to process the regions in the
// correct order. For each op, rewrite it using the layout decided by the
// analysis phase.
17-TritonGPURemoveLayoutConversions-debug.mlir is the debug output of this pass. Roughly following the code, the anchor layouts are seeded at lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp:203
void LayoutPropagation::initAnchorLayout() {
auto addAnchor = [&](Value v) {
if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
layouts.insert({v, LayoutInfo(tensorType.getEncoding())});
}
};
// Consider function args as anchors. This makes it easier to write tests --
// you can pass a tensor with an encoding as an arg, instead of explicitly
// calling tt.load.
for (auto arg : funcOp.getArguments()) {
addAnchor(arg);
}
funcOp.walk([&](Operation *op) {
if (isLayoutAnchor(op)) {
for (auto result : op->getResults()) {
addAnchor(result);
}
}
});
}
Anchors are chosen by the isLayoutAnchor function, which is in effect a cost model: a cost-aware filter for layout-sensitive operations.
// Return true if the op is an op with a layout we don't want to change. We will
// propagate the layout starting from anchor ops.
bool isLayoutAnchor(Operation *op) {
if (isa<LoadOp, StoreOp>(op))
return isExpensiveLoadOrStore(op);
if (isa<DotOp, DotScaledOp, nvidia_gpu::WarpGroupDotOp, AtomicRMWOp,
AtomicCASOp, triton::nvidia_gpu::TMEMLoadOp>(op))
return true;
if (auto gatherOp = dyn_cast<GatherOp>(op))
return gatherOp.getEfficientLayout();
// Heuristic: Mark permuting reshape as a layout anchor. Its dst can be
// anything, so it stops forward-propagation of layouts. We rely on the
// backwards pass to fix it up if necessary. (If we didn't do this, then
// anything following the reshape won't be covered by the forward pass at
// all.)
if (auto reshape = dyn_cast<ReshapeOp>(op))
return reshape.getAllowReorder();
return false;
}
The Value inserted into layouts is %45 = tt.load %44 : tensor<128x1x!tt.ptr<f32>, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>>, i.e. the load of a.
4.4 SCCP
Sparse Conditional Constant Propagation: propagates constants through the CFG. In 36-TritonGPURewritePartitionDependencies.mlir vs 37-SCCP.mlir it only reorders the arith.constant ops.
4.5 TritonNvidiaGPURemoveTMEMTokensPass
Removes TMEM tokens, an abstraction for dependencies involving shared memory or other memory barriers. In 43-TritonGPUCombineTensorSelectAndIf.mlir vs 44-TritonNvidiaGPURemoveTMEMTokensPass.mlir we gain an ub.poison, which the next Canonicalizer pass then removes, just as the code comment says.
%0 = ub.poison : !ttg.async.token loc(#loc)
4.6 SCCP
Constant propagation through the CFG once more: 59-TritonGPUFenceInsertion.mlir vs 60-SCCP.mlir reorders the arith.constant ops again.
4.7 Final artifact
The final artifact is matrix_multiplication_kernel.ttgir, the same as 61-Canonicalizer.mlir.
5. make_llir
The drop from TTGIR down to LLVM IR is fairly steep, but a SIMT GPGPU can be lowered this way; the programming model is comparatively simple.
def make_llir(self, src, metadata, options, capability):
ptx_version = get_ptx_version_from_options(options, self.target.arch)
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_allocate_warp_groups(pm)
passes.convert.add_scf_to_cf(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
if knobs.compilation.enable_experimental_consan:
# Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
passes.ttgpuir.add_concurrency_sanitizer(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if not knobs.compilation.disable_line_info:
passes.llvmir.add_di_scope(pm)
pm.run(mod)
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
llvm.init_targets()
context = llvm.context()
if knobs.compilation.enable_asan:
raise RuntimeError(
"Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
llvm_mod = llvm.to_module(mod, context)
proc = sm_arch_from_capability(capability)
features = get_features(options, self.target.arch)
triple = 'nvptx64-nvidia-cuda'
nvidia.set_short_ptr()
llvm.attach_datalayout(llvm_mod, triple, proc, features)
nvidia.set_nvvm_reflect_ftz(llvm_mod)
if options.extern_libs:
paths = [path for (name, path) in options.extern_libs]
llvm.link_extern_libs(llvm_mod, paths)
llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
# Get some metadata
# warp-specialization mutates num_warps
total_num_warps = src.get_int_attr("ttg.total-num-warps")
if total_num_warps is not None:
metadata["num_warps"] = total_num_warps
metadata["shared"] = src.get_int_attr("ttg.shared")
metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
ret = str(llvm_mod)
del llvm_mod
del context
return ret
After these passes the IR goes from 62-TritonNvidiaGPUMMALoweringPass.mlir to 79-LLVMDIScope.mlir.
5.1 SCFToControlFlowPass
64-TritonGPUAllocateWarpGroups vs 65-SCFToControlFlowPass: lowers scf.for to cf.br.
// old
%26 = scf.for %arg9 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg10 = %cst) -> (tensor<128x64xf32, #blocked>) : i32 {
...
%52 = arith.addf %arg10, %51 : tensor<128x64xf32, #blocked> loc(#loc22)
scf.yield %52 : tensor<128x64xf32, #blocked> loc(#loc23)
} loc(#loc15)
// new
cf.br ^bb1(%c0_i32, %cst : i32, tensor<128x64xf32, #blocked>) loc(#loc15)
^bb1(%26: i32 loc("/home/ubuntu/triton/matmul.py":30:19), %27: tensor<128x64xf32, #blocked> loc(unknown)): // 2 preds: ^bb0, ^bb2
%28 = arith.cmpi slt, %26, %arg4 : i32 loc(#loc15)
cf.cond_br %28, ^bb2, ^bb3 loc(#loc15)
^bb2: // pred: ^bb1
...
%39 = arith.addf %27, %38 : tensor<128x64xf32, #blocked> loc(#loc22)
%40 = arith.addi %26, %c1_i32 : i32 loc(#loc15)
cf.br ^bb1(%40, %39 : i32, tensor<128x64xf32, #blocked>) loc(#loc15)
^bb3: // pred: ^bb1
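In Python terms, the structured loop has been rewritten into explicit compare-and-branch form, roughly like this (a control-flow sketch, not the real lowering; body stands in for the loads and the multiply-add):
def loop_as_branches(N, acc, body):
    n = 0                       # cf.br ^bb1(%c0_i32, %cst)
    while True:                 # ^bb1: loop header; (n, acc) are block arguments
        if not (n < N):         # arith.cmpi slt + cf.cond_br
            break               # -> ^bb3: loop exit
        acc = body(n, acc)      # ^bb2: loop body
        n += 1
        # cf.br ^bb1(%40, %39): jump back with updated block arguments
    return acc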
5.2 AllocateSharedMemory
65-SCFToControlFlowPass vs 66-AllocateSharedMemory: the ttg.convert_layout uses shared memory to perform the transpose, and you can see its allocation offset; the module also gains a ttg.shared size attribute. Note that ttg.shared does not equal the size of tensor<128x64xf32, #blocked>: 128 × 64 × 4 = 32768 ≠ 34816. That is because elems at lib/Analysis/Allocation.cpp:207 is obtained via getNumScratchElemsPaddedCvt(srcTy, dstTy).
// old
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 4 : i32}
...
%55 = ttg.convert_layout %27 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1> loc(#loc29)
// new
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 34816 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 4 : i32}
%55 = ttg.convert_layout %27 {allocation.offset = 0 : i32} : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1> loc(#loc29)
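The mismatch is easy to check by hand (the extra bytes being the padding that getNumScratchElemsPaddedCvt accounts for):
raw = 128 * 64 * 4     # bytes in tensor<128x64xf32>
print(raw)             # 32768
print(34816 - raw)     # 2048 bytes of padding in the conversion scratch buffer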
5.3 TritonTensorMemoryAllocationPass
66-AllocateSharedMemory vs 67-TritonTensorMemoryAllocationPass: only adds ttg.tensor_memory_size = 0 : i32 on the module.
5.4 TritonGPUGlobalScratchAllocationPass
67-TritonTensorMemoryAllocationPass vs 68-TritonGPUGlobalScratchAllocationPass: adds the global_scratch attributes.
// old
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 34816 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 4 : i32} {
tt.func public @matrix_multiplication_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0)) attributes {noinline = false} }
// new
module attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 34816 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 4 : i32} {
tt.func public @matrix_multiplication_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/home/ubuntu/triton/matmul.py":6:0)) attributes {noinline = false, ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} }
5.5 ConvertTritonGPUToLLVM
69-TritonGPUProxyFenceInsertion vs 70-ConvertTritonGPUToLLVM: this is a hugely important conversion pass. The IR balloons, and the output now contains llvm.inline_asm ops carrying inline PTX.
Besides debugging via the RewritePattern output, the loc info also helps navigation; let's first trace how the arith.addf floating-point addition is handled.
// old
%39 = arith.addf %27, %38 : tensor<128x64xf32, #blocked> loc(#loc22)
...
#loc22 = loc("/home/ubuntu/triton/matmul.py":34:23)
// new
%1527 = builtin.unrealized_conversion_cast %1526 : tensor<128x64xf32, #blocked> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
...
%2721 = llvm.extractvalue %1527[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
%2722 = llvm.extractvalue %1527[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
%2723 = llvm.extractvalue %1527[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
... 58 lines omitted
%2782 = llvm.extractvalue %1527[61] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
%2783 = llvm.extractvalue %1527[62] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
%2784 = llvm.extractvalue %1527[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
%2785 = llvm.extractvalue %2720[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
%2786 = llvm.extractvalue %2720[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
%2787 = llvm.extractvalue %2720[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
... 58 lines omitted
%2846 = llvm.extractvalue %2720[61] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
%2847 = llvm.extractvalue %2720[62] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
%2848 = llvm.extractvalue %2720[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc16)
%2849 = llvm.fadd %2721, %2785 : f32 loc(#loc16)
%2850 = llvm.fadd %2722, %2786 : f32 loc(#loc16)
%2851 = llvm.fadd %2723, %2787 : f32 loc(#loc16)
... 58 lines omitted
%2910 = llvm.fadd %2782, %2846 : f32 loc(#loc16)
%2911 = llvm.fadd %2783, %2847 : f32 loc(#loc16)
%2912 = llvm.fadd %2784, %2848 : f32 loc(#loc16)
...
#loc16 = loc("/home/ubuntu/triton/matmul.py":34:23)
The elements extracted from %1527 are added to those extracted from %2720. Where does %2720 come from? See below.
%2592 = llvm.fmul %2464, %2528 : f32 loc(#loc22)
%2593 = llvm.fmul %2465, %2529 : f32 loc(#loc22)
%2594 = llvm.fmul %2466, %2530 : f32 loc(#loc22)
... 58 lines omitted
%2653 = llvm.fmul %2525, %2589 : f32 loc(#loc22)
%2654 = llvm.fmul %2526, %2590 : f32 loc(#loc22)
%2655 = llvm.fmul %2527, %2591 : f32 loc(#loc22)
%2656 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc22)
%2657 = llvm.insertvalue %2592, %2656[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc22)
%2658 = llvm.insertvalue %2593, %2657[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc22)
%2659 = llvm.insertvalue %2594, %2658[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc22)
... 58 lines omitted
%2718 = llvm.insertvalue %2653, %2717[61] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc22)
%2719 = llvm.insertvalue %2654, %2718[62] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc22)
%2720 = llvm.insertvalue %2655, %2719[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc22)
It is the result of the multiplies: a struct of 64 f32 values is defined and the fmul results are inserted into it, and this struct is what gets added to %1527. %1527 traces back to the ^bb1 block arguments, and ^bb1 traces back to %70 at the very beginning.
%4 = llvm.mlir.constant(0.000000e+00 : f32) : f32 loc(#loc1)
%5 = llvm.bitcast %4 : f32 to f32 loc(#loc1)
%6 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc1)
%7 = llvm.insertvalue %5, %6[0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc1)
%8 = llvm.insertvalue %5, %7[1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc1)
%9 = llvm.insertvalue %5, %8[2] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc1)
... 58 lines omitted
%68 = llvm.insertvalue %5, %67[61] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc1)
%69 = llvm.insertvalue %5, %68[62] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc1)
%70 = llvm.insertvalue %5, %69[63] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> loc(#loc1)
%70 is likewise a struct of 64 f32 values, initially all 0.0: this is the accumulator. Our original value was tensor<128x64xf32, #blocked>, so the logic lines up; but what happened to the shape? This step is exactly the switch to the per-thread view. The per-thread element count is derived from the module's "ttg.num-warps" = 4 : i32 and "ttg.threads-per-warp" = 32; concretely, lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp:46 computes unsigned numElementsPerThread = getTotalElemsPerThread(type); and you can trace how getTotalElemsPerThread works from lib/Dialect/TritonGPU/IR/Dialect.cpp:79, so it all ties back to the layout.
unsigned getTotalElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || isa<triton::PointerType>(type))
return 1;
auto tensorType = cast<RankedTensorType>(type);
return getTotalElemsPerThread(tensorType.getEncoding(),
tensorType.getShape());
}
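For our accumulator the arithmetic is a quick sanity check (not the real layout computation):
shape = (128, 64)
num_warps, threads_per_warp = 4, 32   # from the module attributes
elems = shape[0] * shape[1] // (num_warps * threads_per_warp)
print(elems)                          # 64 -> hence the struct of 64 x f32 per thread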
Next, let's analyze how the result is stored; the new IR shown here is only a very small slice.
// old
%55 = ttg.convert_layout %27 {allocation.offset = 0 : i32} : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1> loc(#loc29)
tt.store %47, %55, %54 : tensor<128x64x!tt.ptr<f32>, #blocked1> loc(#loc29)
...
#loc29 = loc("/home/ubuntu/triton/matmul.py":41:21)
// new
...
%4979 = llvm.getelementptr inbounds %4883[%4978] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 loc(#loc29)
%4980 = llvm.mlir.undef : vector<4xf32> loc(#loc29)
%4981 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
%4982 = llvm.insertelement %4884, %4980[%4981 : i32] : vector<4xf32> loc(#loc29)
... 4 lines omitted
%4987 = llvm.mlir.constant(3 : i32) : i32 loc(#loc29)
%4988 = llvm.insertelement %4887, %4986[%4987 : i32] : vector<4xf32> loc(#loc29)
%4989 = llvm.mlir.constant(true) : i1 loc(#loc29)
%4990 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
%4991 = llvm.extractelement %4988[%4990 : i32] : vector<4xf32> loc(#loc29)
... 4 lines omitted
%4996 = llvm.mlir.constant(3 : i32) : i32 loc(#loc29)
%4997 = llvm.extractelement %4988[%4996 : i32] : vector<4xf32> loc(#loc29)
%4998 = llvm.bitcast %4991 : f32 to i32 loc(#loc29)
... 2 lines omitted
%5001 = llvm.bitcast %4997 : f32 to i32 loc(#loc29)
%5002 = llvm.mlir.undef : vector<4xi32> loc(#loc29)
%5003 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
%5004 = llvm.insertelement %4998, %5002[%5003 : i32] : vector<4xi32> loc(#loc29)
... 4 lines omitted
%5009 = llvm.mlir.constant(3 : i32) : i32 loc(#loc29)
%5010 = llvm.insertelement %5001, %5008[%5009 : i32] : vector<4xi32> loc(#loc29)
llvm.store %5010, %4979 {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<3> loc(#loc29)
...
#loc29 = loc("/home/ubuntu/triton/matmul.py":41:21)
The logic: build a vector<4xi32> by bitcasting 4 f32 values to i32 and packing them into the vector, then store that vector into shared memory at the corresponding offset with a 16-byte-aligned store. The offset computation is shown below.
%4948 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc29)
%4949 = llvm.mlir.constant(127 : i32) : i32 loc(#loc29)
%4950 = llvm.and %4948, %4949 : i32 loc(#loc29)
%4951 = llvm.mlir.constant(32 : i32) : i32 loc(#loc29)
%4952 = llvm.urem %4950, %4951 : i32 loc(#loc29)
%4953 = llvm.udiv %4950, %4951 : i32 loc(#loc29)
%4954 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
%4955 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
%4956 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
%4957 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
%4958 = llvm.shl %4952, %4957 : i32 loc(#loc29)
%4959 = llvm.or %4956, %4958 : i32 loc(#loc29)
%4960 = llvm.mlir.constant(5 : i32) : i32 loc(#loc29)
%4961 = llvm.shl %4953, %4960 : i32 loc(#loc29)
%4962 = llvm.or %4959, %4961 : i32 loc(#loc29)
%4963 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
%4964 = llvm.mlir.constant(7 : i32) : i32 loc(#loc29)
%4965 = llvm.and %4962, %4964 : i32 loc(#loc29)
%4966 = llvm.mlir.constant(12 : i32) : i32 loc(#loc29)
%4967 = llvm.shl %4965, %4966 : i32 loc(#loc29)
%4968 = llvm.xor %4963, %4967 : i32 loc(#loc29)
%4969 = llvm.mlir.constant(127 : i32) : i32 loc(#loc29)
%4970 = llvm.and %4962, %4969 : i32 loc(#loc29)
%4971 = llvm.mlir.constant(4 : i32) : i32 loc(#loc29)
%4972 = llvm.shl %4970, %4971 : i32 loc(#loc29)
%4973 = llvm.xor %4968, %4972 : i32 loc(#loc29)
%4974 = llvm.xor %4955, %4973 : i32 loc(#loc29)
%4975 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
%4976 = llvm.xor %4974, %4975 : i32 loc(#loc29)
%4977 = llvm.mlir.constant(0 : i32) : i32 loc(#loc29)
%4978 = llvm.add %4976, %4977 : i32 loc(#loc29)
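Stripping the constant-zero noise, the sequence reduces to a short swizzle; a Python transcription of the IR above (my reading of it, worth cross-checking against the PTX):
def smem_store_offset(tid: int) -> int:
    tid &= 127                        # %4950: thread id within the CTA's 128 threads
    lane, warp = tid % 32, tid // 32  # %4952 / %4953
    linear = lane | (warp << 5)       # %4962: recombines to tid itself
    hi = (linear & 7) << 12           # %4967: selects one of 8 4096-byte slabs
    lo = (linear & 127) << 4          # %4972: 16-byte slot within the slab
    return hi ^ lo                    # XOR swizzle (the bit ranges are disjoint here)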
So where does the inline asm come from? We can search for the occurrences of llvm.inline_asm; here we take loc("/home/ubuntu/triton/matmul.py":32:20) as an example.
// old
%31 = tt.load %30 : tensor<128x1x!tt.ptr<f32>, #blocked> loc(#loc17)
...
#loc17 = loc("/home/ubuntu/triton/matmul.py":32:20)
// new
%1537 = llvm.extractvalue %1536[0] : !llvm.struct<(ptr<1>)> loc(#loc18)
%1538 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l" %1537 : (!llvm.ptr<1>) -> i32 loc(#loc18)
%1539 = llvm.bitcast %1538 : i32 to vector<1xf32> loc(#loc18)
%1540 = llvm.mlir.constant(0 : index) : i32 loc(#loc18)
%1541 = llvm.extractelement %1539[%1540 : i32] : vector<1xf32> loc(#loc18)
%1542 = llvm.mlir.undef : !llvm.struct<(f32)> loc(#loc18)
%1543 = llvm.insertvalue %1541, %1542[0] : !llvm.struct<(f32)> loc(#loc18)
...
#loc18 = loc("/home/ubuntu/triton/matmul.py":32:20)
This describes loading a float from global memory and wrapping it into an LLVM struct. What we actually need is %1541, the float value; the unrelated code is optimized away in the next pass. The asm string itself is stitched together by PTXBuilder.
5.6 Canonicalizer
70-ConvertTritonGPUToLLVM vs 71-Canonicalizer: the canonicalization pass mentioned earlier takes the IR from 7192 lines down to 1882.
After optimization, the floating-point multiply IR shrinks from 259 lines to 193; you can see %2720, the product, being reused directly, with no more packing and unpacking.
The store IR drops even further, from 2265 lines to 988: the same pack/unpack elimination, plus the llvm.mlir.constant(0 : i32) ops being merged into one, among other cleanups.
5.7 CSE
71-Canonicalizer vs 72-CSE: common subexpression elimination, described earlier, takes the IR from 1882 lines to 1847.
5.8 LLVMDIScope
78-SymbolDCE vs 79-LLVMDIScope: a pass that attaches debug-info scopes to the LLVM IR, generating debug metadata (e.g. DWARF) to support source-level debugging.
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64"
@global_smem = external local_unnamed_addr addrspace(3) global [0 x i8], align 16
define ptx_kernel void @matrix_multiplication_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, ptr addrspace(1) readnone captures(none) %9) local_unnamed_addr #0 !dbg !5 {
%11 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !8
%12 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !dbg !9
%13 = shl nuw nsw i32 %12, 7, !dbg !10
%14 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !11
%15 = and i32 %14, 127, !dbg !11
%16 = or disjoint i32 %13, %15, !dbg !12 }
...
!5 = distinct !DISubprogram(name: "matrix_multiplication_kernel", linkageName: "matrix_multiplication_kernel", scope: !1, file: !1, line: 6, type: !6, scopeLine: 6, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0)
!6 = !DISubroutineType(cc: DW_CC_normal, types: !7)
!7 = !{}
!8 = !DILocation(line: 17, column: 26, scope: !5)
!9 = !DILocation(line: 18, column: 26, scope: !5)
!10 = !DILocation(line: 20, column: 21, scope: !5)
!11 = !DILocation(line: 20, column: 49, scope: !5)
!12 = !DILocation(line: 20, column: 36, scope: !5)
5.9 Final artifact
The final artifact is matrix_multiplication_kernel.llir
6. make_ptx
This step actually just calls into LLVM; we will only briefly analyze the final PTX.
def make_ptx(self, src, metadata, opt, capability):
ptx_version = get_ptx_version_from_options(opt, self.target.arch)
triple = 'nvptx64-nvidia-cuda'
proc = sm_arch_from_capability(capability)
features = get_features(opt, self.target.arch)
ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
# Find kernel names (there should only be one)
names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
assert len(names) == 1
metadata["name"] = names[0]
# post-process
ptx_version = f'{ptx_version//10}.{ptx_version%10}'
ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
# Remove the debug flag that prevents ptxas from optimizing the code
ret = re.sub(r",\s*debug|debug,\s*", "", ret)
if knobs.nvidia.dump_nvptx:
print("// -----// NVPTX Dump //----- //")
print(ret)
return ret
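The post-processing is plain string munging; for instance, with a hypothetical ptx_version = 88 and capability = 100, the rewrite behaves like this:
import re

ptx_version = 88                     # hypothetical value from the options
ret = ".version 8.3\n.target sm_90\n"
ver = f"{ptx_version // 10}.{ptx_version % 10}"          # "8.8"
ret = re.sub(r'\.version \d+\.\d+', f'.version {ver}', ret)
ret = re.sub(r'\.target sm_\d+', '.target sm_100', ret)
print(ret)                           # .version 8.8 / .target sm_100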
The artifact is matrix_multiplication_kernel.ptx. The input file is 579 lines and the output 729, corresponding almost one-to-one.
6.1 Addition, side by side
The vector fmul + fadd pairs are turned into concrete fma instructions.
// old
%209 = fmul <64 x float> %208, %205, !dbg !23
%210 = fadd <64 x float> %39, %209, !dbg !24
...
!24 = !DILocation(line: 34, column: 23, scope: !5)
// new
.loc 1 34 23 // matmul.py:34:23
fma.rn.f32 %r599, %r338, %r342, %r599;
fma.rn.f32 %r598, %r338, %r341, %r598;
fma.rn.f32 %r597, %r338, %r340, %r597;
... 58 lines omitted
fma.rn.f32 %r657, %r338, %r400, %r657;
fma.rn.f32 %r658, %r338, %r401, %r658;
fma.rn.f32 %r659, %r338, %r402, %r659;
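fma.rn multiplies and adds with a single rounding step, which is both faster and slightly more accurate than separate fmul + fadd. A small demonstration of the rounding difference (math.fma requires Python 3.13+; the operands are just a convenient double-precision example):
import math

a = 1.0 + 2.0**-27            # a*a = 1 + 2**-26 + 2**-54 exactly
print(a * a - 1.0)            # 2**-26: the 2**-54 tail is lost by the separate multiply
print(math.fma(a, a, -1.0))   # 2**-26 + 2**-54: the single rounding keeps the tail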
6.2 Store, side by side
The call to llvm.nvvm.barrier.cta.sync.aligned.all(i32 0) becomes the bar.sync 0; barrier. The rest is essentially a direct translation of the LLVM IR. This is the core of the kernel, and its shared-memory padding is worth studying.
// old
%331 = and i32 %14, 7, !dbg !33
%332 = shl nuw nsw i32 %331, 12, !dbg !33
%333 = shl nuw nsw i32 %15, 4, !dbg !33
%334 = or disjoint i32 %332, %333, !dbg !33
%335 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %334, !dbg !33
%336 = shufflevector <64 x float> %211, <64 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>, !dbg !33
store <4 x float> %336, ptr addrspace(3) %335, align 16, !dbg !33
%337 = getelementptr inbounds nuw i8, ptr addrspace(3) %335, i32 2048, !dbg !33
%338 = shufflevector <64 x float> %211, <64 x float> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>, !dbg !33
store <4 x float> %338, ptr addrspace(3) %337, align 16, !dbg !33
%339 = xor i32 %334, 16, !dbg !33
... 42 lines omitted
%370 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %369, !dbg !33
%371 = shufflevector <64 x float> %211, <64 x float> poison, <4 x i32> <i32 56, i32 57, i32 58, i32 59>, !dbg !33
store <4 x float> %371, ptr addrspace(3) %370, align 16, !dbg !33
%372 = getelementptr inbounds nuw i8, ptr addrspace(3) %370, i32 2048, !dbg !33
%373 = shufflevector <64 x float> %211, <64 x float> poison, <4 x i32> <i32 60, i32 61, i32 62, i32 63>, !dbg !33
store <4 x float> %373, ptr addrspace(3) %372, align 16, !dbg !33
tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !33
%374 = shl nuw nsw i32 %14, 8, !dbg !33
%375 = and i32 %374, 30720, !dbg !33
%376 = shl nuw nsw i32 %331, 4, !dbg !33
%377 = or disjoint i32 %375, %376, !dbg !33
%378 = xor i32 %377, %215, !dbg !33
%379 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %378, !dbg !33
%380 = load <4 x i32>, ptr addrspace(3) %379, align 16, !dbg !33
%381 = getelementptr inbounds nuw i8, ptr addrspace(3) %379, i32 128, !dbg !33
... 26 lines omitted
%408 = load <4 x i32>, ptr addrspace(3) %407, align 16, !dbg !33
%409 = getelementptr inbounds nuw i8, ptr addrspace(3) %379, i32 1920, !dbg !33
%410 = load <4 x i32>, ptr addrspace(3) %409, align 16, !dbg !33
%.extract = extractelement <4 x i32> %380, i64 0, !dbg !33
%.extract64 = extractelement <4 x i32> %380, i64 1, !dbg !33
%.extract65 = extractelement <4 x i32> %380, i64 2, !dbg !33
%.extract66 = extractelement <4 x i32> %380, i64 3, !dbg !33
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract, i32 %.extract64, i32 %.extract65, i32 %.extract66, ptr addrspace(1) %282, i1 %315) #3, !dbg !33
... 70 lines omitted
%.extract123 = extractelement <4 x i32> %410, i64 0, !dbg !33
%.extract124 = extractelement <4 x i32> %410, i64 1, !dbg !33
%.extract125 = extractelement <4 x i32> %410, i64 2, !dbg !33
%.extract126 = extractelement <4 x i32> %410, i64 3, !dbg !33
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract123, i32 %.extract124, i32 %.extract125, i32 %.extract126, ptr addrspace(1) %297, i1 %330) #3, !dbg !33
// new
.loc 1 41 21 // matmul.py:41:21
and.b32 %r505, %r2, 7;
shl.b32 %r506, %r505, 12;
shl.b32 %r507, %r3, 4;
or.b32 %r508, %r506, %r507;
mov.b32 %r509, global_smem;
add.s32 %r510, %r509, %r508;
st.shared.v4.b32 [%r510], {%r596, %r597, %r598, %r599};
st.shared.v4.b32 [%r510+2048], {%r600, %r601, %r602, %r603};
xor.b32 %r511, %r508, 16;
add.s32 %r512, %r509, %r511;
st.shared.v4.b32 [%r512], {%r604, %r605, %r606, %r607};
st.shared.v4.b32 [%r512+2048], {%r608, %r609, %r610, %r611};
... 16 lines omitted
xor.b32 %r521, %r508, 96;
add.s32 %r522, %r509, %r521;
st.shared.v4.b32 [%r522], {%r644, %r645, %r646, %r647};
st.shared.v4.b32 [%r522+2048], {%r648, %r649, %r650, %r651};
xor.b32 %r523, %r508, 112;
add.s32 %r524, %r509, %r523;
st.shared.v4.b32 [%r524], {%r652, %r653, %r654, %r655};
st.shared.v4.b32 [%r524+2048], {%r656, %r657, %r658, %r659};
bar.sync 0;
shl.b32 %r525, %r2, 8;
and.b32 %r526, %r525, 30720;
shl.b32 %r527, %r505, 4;
or.b32 %r528, %r526, %r527;
xor.b32 %r529, %r528, %r470;
add.s32 %r530, %r509, %r529;
ld.shared.v4.b32 {%r403, %r404, %r405, %r406}, [%r530];
ld.shared.v4.b32 {%r407, %r408, %r409, %r410}, [%r530+128];
... 12 lines omitted
ld.shared.v4.b32 {%r459, %r460, %r461, %r462}, [%r530+1792];
ld.shared.v4.b32 {%r463, %r464, %r465, %r466}, [%r530+1920];
// begin inline asm
@%p3 st.global.v4.b32 [ %rd31 + 0 ], { %r403, %r404, %r405, %r406 };
// end inline asm
// begin inline asm
@%p4 st.global.v4.b32 [ %rd32 + 0 ], { %r407, %r408, %r409, %r410 };
// end inline asm
... 12*3 lines omitted
// begin inline asm
@%p17 st.global.v4.b32 [ %rd45 + 0 ], { %r459, %r460, %r461, %r462 };
// end inline asm
// begin inline asm
@%p18 st.global.v4.b32 [ %rd46 + 0 ], { %r463, %r464, %r465, %r466 };
// end inline asm
6.3 Load, side by side
The inline asm is expanded in place.
// old
%41 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %40) #3, !dbg !19
...
!19 = !DILocation(line: 32, column: 20, scope: !5)
// new
.loc 1 32 20 // matmul.py:32:20
// begin inline asm
mov.u32 %r338, 0x0;
ld.global.b32 { %r338 }, [ %rd80 + 0 ];
This post is from cnblogs, by 暴力都不会的蒟蒻. When reposting, please cite the original link: https://www.cnblogs.com/BobHuang/p/18810026