Deep Dive into the Triton Compiler's MatMul Optimization (2): MMA
In Deep Dive into the Triton Compiler's MatMul Optimization (1) we covered the naive matrix multiplication; this article looks at tl.dot, the op that gets you most of the performance almost for free.
First, a performance comparison: the Triton native kernel vs. the Triton kernel with tl.dot. The speedup is 3.68x, and 5.16x relative to the native CUDA kernel.

This speedup already looks high, but it is roughly the same runtime as my earlier TMA kernel, which made me suspicious; looking at the PTX showed something was off. input_precision="ieee" makes all of these kernels fall back to FMA instructions, so the gain in this kernel actually comes from loop unrolling. Using tf32 (the default) cannot pass this problem because of the precision loss. LeetGPU also has a GEMM (FP16) problem, but its data size is too small to benchmark. So later articles in this series will drop the performance comparisons; we are stuck on this point. (The sketch at the end of section I shows one way to count the fma instructions in the generated PTX from Python.)
This article uses Triton at commit bc75dd0 (Jun 27, 2025). All IR and kernel files have been uploaded to GitHub: sBobHuang/Triton-blog-file. Related article in this series: Deep Dive into the Triton Compiler's MatMul Optimization (1).
I. The matmul Triton kernel
1. Writing the kernel
The Triton kernel is shown below. Matrix a is M*N, matrix b is N*K, and the result matrix c is M*K. The full runnable file is matmul-with-dot-v2.py.
# The use of PyTorch in Triton programs is not allowed for the purposes of fair benchmarking.
import triton
import triton.language as tl
@triton.jit
def matrix_multiplication_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_an,
    stride_bn, stride_bk,
    stride_cm, stride_ck,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr
):
    a_ptr = a_ptr.to(tl.pointer_type(tl.float32))
    b_ptr = b_ptr.to(tl.pointer_type(tl.float32))
    c_ptr = c_ptr.to(tl.pointer_type(tl.float32))
    pid_k = tl.program_id(axis=0)
    pid_m = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    # initialize the pointers into A and B
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
    b_ptrs = b_ptr + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
    # accumulate block by block along the N dimension
    for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
        offset_n = n * BLOCK_SIZE_N
        max_idx = N - offset_n
        # load the blocks of A and B
        a = tl.load(a_ptrs + offset_n * stride_an, mask=offs_n[None, :] < max_idx, other=0.0)
        b = tl.load(b_ptrs + offset_n * stride_bn, mask=offs_n[:, None] < max_idx, other=0.0)
        # compute a @ b and accumulate into accumulator
        accumulator = tl.dot(a, b, acc=accumulator)
    # write the result back to C
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_k[None, :] * stride_ck
    c_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    tl.store(c_ptrs, accumulator, mask=c_mask)

# a_ptr, b_ptr, c_ptr are raw device pointers
def solve(a_ptr: int, b_ptr: int, c_ptr: int, M: int, N: int, K: int):
    stride_am, stride_an = N, 1
    stride_bn, stride_bk = K, 1
    stride_cm, stride_ck = K, 1
    grid = lambda META: (triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(M, META['BLOCK_SIZE_M']), )
    matrix_multiplication_kernel[grid](
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_an,
        stride_bn, stride_bk,
        stride_cm, stride_ck,
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_K=64,
        BLOCK_SIZE_N=64,
    )
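As a side note, here is a minimal way to sanity-check the kernel and peek at the generated PTX from Python. This is my own sketch, not part of the LeetGPU harness: it assumes PyTorch with a CUDA device is available (for verification only, not benchmarking) and that launching the JIT'ed kernel returns a handle exposing an asm dict, which is the case on recent Triton versions.
# My own verification sketch (hypothetical sizes; not part of the LeetGPU submission).
import torch

def check():
    M, N, K = 512, 384, 256
    a = torch.randn(M, N, device="cuda", dtype=torch.float32)
    b = torch.randn(N, K, device="cuda", dtype=torch.float32)
    c = torch.empty(M, K, device="cuda", dtype=torch.float32)
    # solve() takes raw device pointers, so pass the data_ptr() integers.
    solve(a.data_ptr(), b.data_ptr(), c.data_ptr(), M, N, K)
    torch.cuda.synchronize()
    # Loose check: tl.dot's default tf32 input precision loses a few mantissa bits.
    print("max abs error:", (c - a @ b).abs().max().item())
    # Launch once more just to grab the compiled-kernel handle and inspect its PTX
    # (the .asm dict with 'ttir'/'ttgir'/'llir'/'ptx' is an assumption about recent Triton).
    handle = matrix_multiplication_kernel[(1, 1)](
        a.data_ptr(), b.data_ptr(), c.data_ptr(), M, N, K,
        N, 1, K, 1, K, 1,
        BLOCK_SIZE_M=128, BLOCK_SIZE_K=64, BLOCK_SIZE_N=64,
    )
    ptx = handle.asm["ptx"]
    print("fma.rn.f32 count:", ptx.count("fma.rn.f32"))

if __name__ == "__main__":
    check()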
2. Kernel walkthrough
Compared with the previous article, this kernel adds a BLOCK_SIZE_N: in each loop iteration a is read as a [BLOCK_SIZE_M, BLOCK_SIZE_N] block and b as a [BLOCK_SIZE_N, BLOCK_SIZE_K] block, and the two blocks are multiplied directly. For a we extend a_ptrs = a_ptr + offs_m[:, None] * stride_am with offs_n[None, :] * stride_an, and for b we extend b_ptrs = b_ptr + offs_k[None, :] * stride_bk with offs_n[:, None] * stride_bn; inside the loop the load then only needs the block offset and a mask. The NumPy sketch below mirrors this blocked accumulation on the host.
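The blocked accumulation the kernel performs can be mirrored on the host. A minimal NumPy sketch (my own illustration; the sizes and block size are arbitrary):
import numpy as np

M, N, K = 128, 192, 64
BLOCK_N = 64
a = np.random.rand(M, N).astype(np.float32)
b = np.random.rand(N, K).astype(np.float32)

# One program instance owns a full [M, K] tile of C and walks the N dimension in blocks,
# exactly like the loop over tl.cdiv(N, BLOCK_SIZE_N) in the kernel.
acc = np.zeros((M, K), dtype=np.float32)
for n0 in range(0, N, BLOCK_N):
    a_blk = a[:, n0:n0 + BLOCK_N]   # the [BLOCK_SIZE_M, BLOCK_SIZE_N] block of A
    b_blk = b[n0:n0 + BLOCK_N, :]   # the [BLOCK_SIZE_N, BLOCK_SIZE_K] block of B
    acc += a_blk @ b_blk            # the tl.dot(a, b, acc=accumulator) step

print(np.allclose(acc, a @ b, atol=1e-3))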
II. ast_to_ttir
The JIT decorator walks the Python AST and finally calls MLIR's self.create<...> builders to emit the ops.
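If you want to reproduce the per-pass IR dumps referenced below, Triton has debug knobs for this. A hedged sketch: as far as I know, MLIR_ENABLE_DUMP dumps the IR before every MLIR pass, and TRITON_KERNEL_DUMP together with TRITON_DUMP_DIR saves the artifacts of each stage, but treat the exact names as assumptions and check the README of your Triton commit.
import os

# Assumed debug knobs; verify the names against the Triton README for commit bc75dd0.
os.environ["MLIR_ENABLE_DUMP"] = "1"       # dump IR before every MLIR pass
os.environ["TRITON_KERNEL_DUMP"] = "1"     # dump the IR/PTX of each compilation stage
os.environ["TRITON_DUMP_DIR"] = "./dumps"  # where the per-stage dumps are written

# set these before importing triton, then launch the kernel as usual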
1. The loop IR
The generated TTIR is fairly verbose; the full IR is in 01-source.mlir. Let's pick out the loop, shown below.
%85 = scf.for %arg9 = %81 to %82 step %83 iter_args(%arg10 = %79) -> (tensor<128x64xf32>) : i32 {
%155 = arith.muli %arg9, %c64_i32_61 : i32 loc(#loc25)
... omitted
%162 = arith.subi %arg4, %155 : i32 loc(#loc26)
%163 = tt.expand_dims %34 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc27)
%164 = tt.splat %162 : i32 -> tensor<1x64xi32> loc(#loc28)
%165 = arith.cmpi slt, %163, %164 : tensor<1x64xi32> loc(#loc28)
... omitted
%172 = arith.muli %arg9, %c64_i32_67 : i32 loc(#loc29)
... omitted
%179 = arith.muli %172, %c1_i32_71 : i32 loc(#loc30)
%180 = tt.splat %179 : i32 -> tensor<128x64xi32> loc(#loc31)
%181 = tt.addptr %56, %180 : tensor<128x64x!tt.ptr<f32>>, tensor<128x64xi32> loc(#loc31)
%cst_74 = arith.constant 0.000000e+00 : f32 loc(#loc32)
%182 = tt.broadcast %165 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc32)
%cst_75 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc32)
%183 = tt.load %181, %182, %cst_75 : tensor<128x64x!tt.ptr<f32>> loc(#loc32)
%184 = tt.expand_dims %34 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc33)
%185 = tt.splat %162 : i32 -> tensor<64x1xi32> loc(#loc34)
%186 = arith.cmpi slt, %184, %185 : tensor<64x1xi32> loc(#loc34)
... omitted
%193 = arith.muli %arg9, %c64_i32_77 : i32 loc(#loc35)
... omitted
%200 = arith.muli %193, %arg7 : i32 loc(#loc36)
%201 = tt.splat %200 : i32 -> tensor<64x64xi32> loc(#loc37)
%202 = tt.addptr %78, %201 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32> loc(#loc37)
%cst_82 = arith.constant 0.000000e+00 : f32 loc(#loc38)
%203 = tt.broadcast %186 : tensor<64x1xi1> -> tensor<64x64xi1> loc(#loc38)
%cst_83 = arith.constant dense<0.000000e+00> : tensor<64x64xf32> loc(#loc38)
%204 = tt.load %202, %203, %cst_83 : tensor<64x64x!tt.ptr<f32>> loc(#loc38)
%cst_84 = arith.constant 0.000000e+00 : f32 loc(#loc39)
%cst_85 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc39)
%205 = tt.dot %183, %204, %cst_85, inputPrecision = tf32 : tensor<128x64xf32> * tensor<64x64xf32> -> tensor<128x64xf32> loc(#loc39)
%206 = arith.addf %arg10, %205 : tensor<128x64xf32> loc(#loc40)
scf.yield %206 : tensor<128x64xf32> loc(#loc41)
} loc(#loc24)
2. Loading a
The load is more involved now, but we can trace it operand by operand. Start from %183 = tt.load %181, %182, %cst_75 : tensor<128x64x!tt.ptr<f32>> loc(#loc32): %181 is the pointer tensor with the offset applied, %182 is the mask, and %cst_75 is `other`, the fill value used where the mask is false. You can also match this against the assemblyFormat in the LoadOp definition.
def TT_LoadOp : TT_Op<"load", [
    SameLoadStoreOperandsAndResultShape,
    SameLoadStoreOperandsAndResultEncoding,
    AttrSizedOperandSegments,
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">,
    TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))",
                   "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
    TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)",
                   "($_op.getOperands().size() <= 2) || std::equal_to<>()">
]> {
  let summary = "Load from a tensor of pointers or from a tensor pointer";
  let arguments = (
    ins
    AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr,
    Optional<TT_BoolLike>:$mask,
    Optional<TT_Type>:$other,
    DefaultValuedAttr<DenseI32ArrayAttr, "::llvm::ArrayRef<int32_t>{}">:$boundaryCheck,
    OptionalAttr<TT_PaddingOptionAttr>:$padding,
    DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
    DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict,
    DefaultValuedAttr<BoolAttr, "false">:$isVolatile
  );
  let results = (outs TT_Type:$result);
  ...
  let assemblyFormat = [{
    $ptr (`,` $mask^)? (`,` $other^)?
    oilist(
      `cacheModifier` `=` $cache |
      `evictionPolicy` `=` $evict
    )
    attr-dict `:` type($ptr)
  }];
  let hasCanonicalizer = 1;
}
The complete MLIR for loading a is:
// build the mask
%163 = tt.expand_dims %34 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc27)
%164 = tt.splat %162 : i32 -> tensor<1x64xi32> loc(#loc28)
%165 = arith.cmpi slt, %163, %164 : tensor<1x64xi32> loc(#loc28)
// compute the offset
%c64_i32_67 = arith.constant 64 : i32 loc(#loc29)
%172 = arith.muli %arg9, %c64_i32_67 : i32 loc(#loc29)
%c1_i32_71 = arith.constant 1 : i32 loc(#loc30)
%179 = arith.muli %172, %c1_i32_71 : i32 loc(#loc30)
%180 = tt.splat %179 : i32 -> tensor<128x64xi32> loc(#loc31)
%181 = tt.addptr %56, %180 : tensor<128x64x!tt.ptr<f32>>, tensor<128x64xi32> loc(#loc31)
// broadcast the mask and prepare the fallback value
%182 = tt.broadcast %165 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc32)
%cst_75 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc32)
// guarded load
%183 = tt.load %181, %182, %cst_75 : tensor<128x64x!tt.ptr<f32>> loc(#loc32)
3. The remaining pieces
The two loaded blocks are fed to dot, i.e. the matrix multiply, with a zero tensor as the dot's accumulator operand; the result is then added to the loop-carried accumulator.
%cst_85 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc39)
%205 = tt.dot %183, %204, %cst_85, inputPrecision = tf32 : tensor<128x64xf32> * tensor<64x64xf32> -> tensor<128x64xf32> loc(#loc39)
The accumulator lives in iter_args, so it still has to be yielded at the end of the loop body. This is the same as in the previous article; the only change is that the computation in this kernel now proceeds block by block. (A Python analogy of iter_args follows the snippet.)
%85 = scf.for %arg9 = %81 to %82 step %83 iter_args(%arg10 = %79) -> (tensor<128x64xf32>) : i32 {
...
scf.yield %206 : tensor<128x64xf32> loc(#loc41)
} loc(#loc24)
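For readers less used to MLIR, iter_args behave like loop-carried variables: the value passed to scf.yield becomes the block argument of the next iteration, and the last yielded value is the loop's result. A rough Python analogy (illustration only, not real compiler code):
# Python analogy for scf.for with iter_args.
def scf_for(lb, ub, step, init_acc, body):
    acc = init_acc              # iter_args(%arg10 = %79)
    for i in range(lb, ub, step):
        acc = body(i, acc)      # the loop body computes a new value and scf.yield's it
    return acc                  # %85, the value the loop produces

print(scf_for(0, 8, 1, 0, lambda i, acc: acc + i))  # 28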
The rest of the IR can be read in the same way; you can also hand 01-source.mlir to ChatGPT for an explanation. At this point the IR follows the Python DSL semantics exactly, with no optimization applied.
III. make_ttir
This is the stage that simplifies the MLIR produced from the Python AST; the following pipeline is run.
@staticmethod
def make_ttir(mod, metadata, opt, capability):
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    passes.common.add_inliner(pm)
    passes.ttir.add_rewrite_tensor_pointer(pm)
    if capability // 10 < 9:
        passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
    passes.common.add_canonicalizer(pm)
    passes.ttir.add_combine(pm)
    passes.ttir.add_reorder_broadcast(pm)
    passes.common.add_cse(pm)
    passes.common.add_symbol_dce(pm)
    passes.ttir.add_loop_unroll(pm)
    pm.run(mod)
    return mod
After these passes, the IR dumps go from 02-Inliner.mlir to 13-TritonLoopUnroll.mlir.
1. Canonicalizer
Canonicalization is a generic pass, and you will see it show up again and again. It can remove redundant operations, simplify matched patterns, fold constant computations, and apply the canonicalization patterns defined on each op.
02-Inliner.mlir vs 03-Canonicalizer.mlir: triton.language.standard.zeros and triton.language.standard.cdiv__i32__ are simplified.
03-Canonicalizer.mlir vs 04-Canonicalizer.mlir: triton.language.standard.cdiv__i32__ is simplified further.
04-Canonicalizer.mlir vs 05-Canonicalizer.mlir: a much larger change; ops unused inside the loop are removed, and triton.language.standard.zeros and triton.language.standard.cdiv__i32__ are inlined.
05-Canonicalizer.mlir vs 06-Canonicalizer.mlir: further simplification.
2. TritonCombineOps
08-Canonicalizer.mlir vs 09-TritonCombineOps.mlir: tt.dot followed by arith.addf is fused into a single tt.dot, i.e. the pattern dot(a, b, 0) + acc becomes dot(a, b, acc), so the accumulation now goes directly through the iter_args accumulator.
// old
%77 = tt.dot %67, %76, %cst_0, inputPrecision = tf32 : tensor<128x64xf32> * tensor<64x64xf32> -> tensor<128x64xf32> loc(#loc38)
%78 = arith.addf %arg10, %77 : tensor<128x64xf32> loc(#loc39)
scf.yield %78 : tensor<128x64xf32> loc(#loc40)
// new
%77 = tt.dot %67, %76, %arg10, inputPrecision = tf32 : tensor<128x64xf32> * tensor<64x64xf32> -> tensor<128x64xf32> loc(#loc38)
scf.yield %77 : tensor<128x64xf32> loc(#loc39)
3. CSE
10-TritonReorderBroadcast.mlir vs 11-CSE.mlir: Common Subexpression Elimination; for example, there are several duplicated tt.make_range ops to merge.
4. Final result
The final artifact is matrix_multiplication_kernel.ttir, which is identical to 13-TritonLoopUnroll.mlir.
IV. make_ttgir
This stage runs many more passes, but the number that actually change our kernel's IR is still manageable.
@staticmethod
def make_ttgir(mod, metadata, opt, capability):
    # Set maxnreg on all kernels, if it was provided.
    if opt.maxnreg is not None:
        mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))
    cluster_info = nvidia.ClusterInfo()
    if opt.cluster_dims is not None:
        cluster_info.clusterDimX = opt.cluster_dims[0]
        cluster_info.clusterDimY = opt.cluster_dims[1]
        cluster_info.clusterDimZ = opt.cluster_dims[2]
    pm = ir.pass_manager(mod.context)
    dump_enabled = pm.enable_debug()
    passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
    # optimize TTGIR
    passes.ttgpuir.add_coalesce(pm)
    if capability // 10 >= 8:
        passes.ttgpuir.add_f32_dot_tc(pm)
    # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
    nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
    passes.ttgpuir.add_remove_layout_conversions(pm)
    passes.ttgpuir.add_optimize_thread_locality(pm)
    passes.ttgpuir.add_accelerate_matmul(pm)
    passes.ttgpuir.add_remove_layout_conversions(pm)
    passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
    nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
    passes.ttir.add_loop_aware_cse(pm)
    if capability // 10 in [8, 9]:
        passes.ttgpuir.add_fuse_nested_loops(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_triton_licm(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
        passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
        passes.ttgpuir.add_schedule_loops(pm)
        passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
    elif capability // 10 >= 10:
        passes.ttgpuir.add_fuse_nested_loops(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_triton_licm(pm)
        passes.ttgpuir.add_optimize_accumulator_init(pm)
        passes.ttgpuir.add_hoist_tmem_alloc(pm)
        nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
        passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
        passes.ttgpuir.add_schedule_loops(pm)
        passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
        passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
    else:
        passes.ttir.add_triton_licm(pm)
    passes.common.add_canonicalizer(pm)
    passes.ttir.add_loop_aware_cse(pm)
    passes.ttgpuir.add_prefetch(pm)
    passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
    passes.ttgpuir.add_coalesce_async_copy(pm)
    nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
    passes.ttgpuir.add_remove_layout_conversions(pm)
    nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
    passes.ttgpuir.add_reduce_data_duplication(pm)
    passes.ttgpuir.add_reorder_instructions(pm)
    passes.ttir.add_loop_aware_cse(pm)
    passes.common.add_symbol_dce(pm)
    if capability // 10 >= 9:
        nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
        nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
    passes.common.add_sccp(pm)
    passes.common.add_canonicalizer(pm)
    pm.run(mod)
    metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
    tensordesc_meta = mod.get_tensordesc_metadata()
    metadata["tensordesc_meta"] = tensordesc_meta
    return mod
After these passes, the IR dumps go from 14-ConvertTritonToTritonGPU.mlir to 62-Canonicalizer.mlir.
1. ConvertTritonToTritonGPU
13-TritonLoopUnroll.mlir vs 14-ConvertTritonToTritonGPU.mlir: this mainly attaches layout encodings to the tensors.
2. TritonGPUCoalesce
14-ConvertTritonToTritonGPU.mlir vs 15-TritonGPUCoalesce.mlir: memory access coalescing.
3. TritonGPURemoveLayoutConversions
17-TritonGPUPlanCTAPass.mlir vs 18-TritonGPURemoveLayoutConversions.mlir: removes redundant convert_layout ops; the cost model settles on #blocked5. A quick sanity check of how that layout tiles the tensor follows the listing below.
// old
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
// new
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
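A quick way to read the surviving #blocked layout (my own back-of-the-envelope check, not something the compiler prints): sizePerThread times threadsPerWarp times warpsPerCTA covers an 8x64 tile per repetition, so the 128x64 operand needs 16 repetitions along the rows, and each of the 128 threads ends up holding 64 elements.
# Sanity check of #blocked<{sizePerThread=[1,4], threadsPerWarp=[2,16], warpsPerCTA=[4,1]}>
# applied to a 128x64 tensor (my own reading of the blocked layout, for intuition only).
size_per_thread = [1, 4]
threads_per_warp = [2, 16]
warps_per_cta = [4, 1]
shape = [128, 64]

tile = [s * t * w for s, t, w in zip(size_per_thread, threads_per_warp, warps_per_cta)]
reps = [shape[d] // tile[d] for d in range(2)]
elems_per_thread = size_per_thread[0] * size_per_thread[1] * reps[0] * reps[1]
print(tile, reps, elems_per_thread)  # [8, 64] [16, 1] 64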
4. TritonGPUAccelerateMatmul
19-TritonGPUOptimizeThreadLocality.mlir vs 20-TritonGPUAccelerateMatmul.mlir: this is where tt.dot gets rewritten.
// old
%57 = tt.load %55, %56, %cst_0 : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc28)
...
%64 = tt.load %62, %63, %cst : tensor<64x64x!tt.ptr<f32>, #blocked> loc(#loc32)
%65 = ttg.convert_layout %57 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> loc(#loc28)
%66 = ttg.convert_layout %64 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> loc(#loc32)
%67 = tt.dot %65, %66, %arg10, inputPrecision = tf32 : tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x64xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf32, #blocked1> loc(#loc33)
scf.yield %67 : tensor<128x64xf32, #blocked1> loc(#loc34)
// new
#blocked2 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
...
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}>
#smem = #ttg.shared_memory
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
...
%57 = tt.load %55, %56, %cst_0 : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc28)
%58 = ttg.local_alloc %57 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #shared, #smem> loc(#loc28)
...
%65 = tt.load %63, %64, %cst : tensor<64x64x!tt.ptr<f32>, #blocked> loc(#loc32)
%66 = ttg.local_alloc %65 : (tensor<64x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared1, #smem> loc(#loc32)
%67 = ttg.convert_layout %arg10 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #blocked2> loc(#loc33)
%result, %token = ttng.tmem_alloc %67 : (tensor<128x64xf32, #blocked2>) -> (!ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) loc(#loc33)
%68 = ttng.tc_gen5_mma %58, %66, %result[%token], %true, %true : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc33)
%result_2, %token_3 = ttng.tmem_load %result[%68] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked2> loc(#loc33)
%69 = ttg.convert_layout %result_2 : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #blocked1> loc(#loc33)
scf.yield %69 : tensor<128x64xf32, #blocked1> loc(#loc34)
ttng is the TritonNvidiaGPU dialect. After the rewrite, the loaded tiles are copied into shared memory (ttg.local_alloc), the accumulator is placed in tensor memory (ttng.tmem_alloc), the ttng.tc_gen5_mma op performs the MMA, and the result is read back from tensor memory with ttng.tmem_load.
5. TritonGPURemoveLayoutConversions
20-TritonGPUAccelerateMatmul.mlir vs 21-TritonGPURemoveLayoutConversions.mlir: merges layouts further.
// old
#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// new
#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
6. TritonLoopAwareCSE
24-TritonNvidiaGPUOptimizeDescriptorEncodingPass.mlir vs 25-TritonLoopAwareCSE.mlir: duplicated ops such as tt.make_range are merged.
// old
%7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc8)
...
%15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc13)
%16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> loc(#loc13)
// new
%7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc8)
...
%15 = tt.expand_dims %7 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> loc(#loc13)
7. TritonGPUOptimizeAccumulatorInit
28-TritonLoopInvariantCodeMotion.mlir vs 29-TritonGPUOptimizeAccumulatorInit.mlir: the pass rewrites the accumulate-flag operand of ttng.tc_gen5_mma into a loop-carried i1 that starts out false, so the first MMA overwrites tensor memory instead of accumulating into a zero-initialized value.
// old
%31 = scf.for %arg9 = %c0_i32 to %30 step %c1_i32 iter_args(%arg10 = %cst_1) -> (tensor<128x64xf32, #blocked1>) : i32 {
...
%64 = ttng.tc_gen5_mma %55, %63, %result[%token], %true, %true : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc33)
%result_2, %token_3 = ttng.tmem_load %result[%64] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1> loc(#loc33)
scf.yield %result_2 : tensor<128x64xf32, #blocked1> loc(#loc34)
} loc(#loc23)
...
%46 = ttg.convert_layout %31 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #blocked> loc(#loc41)
// new
%31:2 = scf.for %arg9 = %c0_i32 to %30 step %c1_i32 iter_args(%arg10 = %cst_1, %arg11 = %false) -> (tensor<128x64xf32, #blocked1>, i1) : i32 {
...
%64 = ttng.tc_gen5_mma %55, %63, %result[%token], %arg11, %true : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc33)
%result_2, %token_3 = ttng.tmem_load %result[%64] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1> loc(#loc33)
scf.yield %result_2, %true : tensor<128x64xf32, #blocked1>, i1 loc(#loc34)
} loc(#loc23)
%46 = ttg.convert_layout %31#0 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #blocked> loc(#loc41)
8. TritonGPUHoistTMEMAlloc
29-TritonGPUOptimizeAccumulatorInit.mlir vs 30-TritonGPUHoistTMEMAlloc.mlir: hoists the tensor-memory allocation out of the loop so it is not repeated every iteration.
// new
%result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) loc(#loc23)
%31 = ttng.tmem_store %cst_1, %result[%token], %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc23)
%32:2 = scf.for %arg9 = %c0_i32 to %30 step %c1_i32 iter_args(%arg10 = %false, %arg11 = %31) -> (i1, !ttg.async.token) : i32 {
...
%65 = ttng.tc_gen5_mma %56, %64, %result[%arg11], %arg10, %true : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc23)
scf.yield %true, %65 : i1, !ttg.async.token loc(#loc34)
} loc(#loc24)
%result_2, %token_3 = ttng.tmem_load %result[%32#1] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1> loc(#loc23)
9. TritonGPUAssignLatencies
31-TritonNvidiaGPUPromoteLHSToTMemPass.mlir vs 32-TritonGPUAssignLatencies.mlir: tags the relevant ops with latency attributes.
// new
%55 = tt.load %53, %54, %cst_0 {tt.latency = 2 : i32} : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc29)
...
%63 = tt.load %61, %62, %cst {tt.latency = 2 : i32} : tensor<64x64x!tt.ptr<f32>, #blocked> loc(#loc33)
%64 = ttg.local_alloc %63 : (tensor<64x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared1, #smem> loc(#loc33)
%65 = ttng.tc_gen5_mma %56, %64, %result[%arg11], %arg10, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc23)
10. TritonGPUScheduleLoops
32-TritonGPUAssignLatencies.mlir vs 33-TritonGPUScheduleLoops.mlir: software-pipeline loop scheduling. This is my favorite pass; it goes to the heart of the GPU, hiding latency. Each op gets loop.cluster / loop.stage annotations describing which pipeline stage it belongs to (a conceptual sketch of the resulting pipeline appears at the end of section 12).
// old
%32:2 = scf.for %arg9 = %c0_i32 to %30 step %c1_i32 iter_args(%arg10 = %false, %arg11 = %31) -> (i1, !ttg.async.token) : i32 {
%48 = arith.muli %arg9, %c64_i32 : i32 loc(#loc25)
%49 = arith.subi %arg4, %48 : i32 loc(#loc26)
%50 = tt.splat %49 : i32 -> tensor<1x64xi32, #blocked> loc(#loc27)
%51 = arith.cmpi slt, %15, %50 : tensor<1x64xi32, #blocked> loc(#loc27)
%52 = tt.splat %48 : i32 -> tensor<128x64xi32, #blocked> loc(#loc28)
%53 = tt.addptr %18, %52 : tensor<128x64x!tt.ptr<f32>, #blocked>, tensor<128x64xi32, #blocked> loc(#loc28)
%54 = tt.broadcast %51 : tensor<1x64xi1, #blocked> -> tensor<128x64xi1, #blocked> loc(#loc29)
%55 = tt.load %53, %54, %cst_0 {tt.latency = 2 : i32} : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc29)
%56 = ttg.local_alloc %55 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #shared, #smem> loc(#loc29)
%57 = tt.splat %49 : i32 -> tensor<64x1xi32, #blocked> loc(#loc30)
%58 = arith.cmpi slt, %20, %57 : tensor<64x1xi32, #blocked> loc(#loc30)
%59 = arith.muli %48, %arg7 : i32 loc(#loc31)
%60 = tt.splat %59 : i32 -> tensor<64x64xi32, #blocked> loc(#loc32)
%61 = tt.addptr %28, %60 : tensor<64x64x!tt.ptr<f32>, #blocked>, tensor<64x64xi32, #blocked> loc(#loc32)
%62 = tt.broadcast %58 : tensor<64x1xi1, #blocked> -> tensor<64x64xi1, #blocked> loc(#loc33)
%63 = tt.load %61, %62, %cst {tt.latency = 2 : i32} : tensor<64x64x!tt.ptr<f32>, #blocked> loc(#loc33)
%64 = ttg.local_alloc %63 : (tensor<64x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared1, #smem> loc(#loc33)
%65 = ttng.tc_gen5_mma %56, %64, %result[%arg11], %arg10, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc23)
scf.yield %true, %65 : i1, !ttg.async.token loc(#loc34)
} loc(#loc24)
// new
%32:2 = scf.for %arg9 = %c0_i32 to %30 step %c1_i32 iter_args(%arg10 = %false, %arg11 = %31) -> (i1, !ttg.async.token) : i32 {
%48 = arith.muli %arg9, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 loc(#loc25)
%49 = arith.subi %arg4, %48 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 loc(#loc26)
%50 = tt.splat %49 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 -> tensor<1x64xi32, #blocked> loc(#loc27)
%51 = arith.cmpi slt, %15, %50 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<1x64xi32, #blocked> loc(#loc27)
%52 = tt.splat %48 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 -> tensor<128x64xi32, #blocked> loc(#loc28)
%53 = tt.addptr %18, %52 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x64x!tt.ptr<f32>, #blocked>, tensor<128x64xi32, #blocked> loc(#loc28)
%54 = tt.broadcast %51 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<1x64xi1, #blocked> -> tensor<128x64xi1, #blocked> loc(#loc29)
%55 = tt.load %53, %54, %cst_0 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x64x!tt.ptr<f32>, #blocked> loc(#loc29)
%56 = ttg.local_alloc %55 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #shared, #smem> loc(#loc29)
%57 = tt.splat %49 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 -> tensor<64x1xi32, #blocked> loc(#loc30)
%58 = arith.cmpi slt, %20, %57 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<64x1xi32, #blocked> loc(#loc30)
%59 = arith.muli %48, %arg7 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 loc(#loc31)
%60 = tt.splat %59 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 -> tensor<64x64xi32, #blocked> loc(#loc32)
%61 = tt.addptr %28, %60 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<64x64x!tt.ptr<f32>, #blocked>, tensor<64x64xi32, #blocked> loc(#loc32)
%62 = tt.broadcast %58 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<64x1xi1, #blocked> -> tensor<64x64xi1, #blocked> loc(#loc33)
%63 = tt.load %61, %62, %cst {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<64x64x!tt.ptr<f32>, #blocked> loc(#loc33)
%64 = ttg.local_alloc %63 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<64x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared1, #smem> loc(#loc33)
%65 = ttng.tc_gen5_mma %56, %64, %result[%arg11], %arg10, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x64xf32, #shared1, #smem>, !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc23)
scf.yield %true, %65 : i1, !ttg.async.token loc(#loc34)
} {tt.scheduled_max_stage = 2 : i32} loc(#loc24)
11. SCCP
37-TritonGPURewritePartitionDependencies.mlir vs 38-SCCP.mlir: Sparse Conditional Constant Propagation, which propagates constants through the CFG; here the visible effect is again mostly a reordering.
12. TritonGPUPipeline
42-TritonGPUScheduleLoops.mlir vs 43-TritonGPUPipeline.mlir: the software pipeline materializes here, and we really get to hardware-level scheduling. The pass description reads:
Applies software pipelining to loops in the module based on number of stages.
This may convert some load into asynchronous loads, and multi-buffer the data.
The first #shared encoding changes:
//old
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
//new
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
The pass also adds local_dealloc ops to release shared memory earlier, and async_wait ops marking the synchronization points of the asynchronous copies:
// old
%result_2, %token_3 = ttng.tmem_load %result[%32#1] : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked> loc(#loc23)
// new
%111 = arith.cmpi sgt, %35, %c0_i32 : i32 loc(#loc24)
%112:16 = scf.if %111 -> (i1, !ttg.async.token, i32, i32, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.memdesc<1xi64, #shared, #smem, mutable, 2>, i32, !ttg.memdesc<128x64xf32, #shared2, #smem, mutable, 3x128x64>, !ttg.memdesc<64x64xf32, #shared1, #smem, mutable, 3x64x64>) {
ttng.wait_barrier %110#12, %110#13 deps %110#14, %110#15 : !ttg.memdesc<1xi64, #shared, #smem, mutable, 2>, !ttg.memdesc<128x64xf32, #shared2, #smem, mutable, 3x128x64>, !ttg.memdesc<64x64xf32, #shared1, #smem, mutable, 3x64x64> loc(#loc23)
scf.yield %true, %110#1, %4, %4, %4, %4, %110#7, %c3_i32, %110#9, %3, %110#11, %3, %0, %110#2, %2, %1 : i1, !ttg.async.token, i32, i32, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.memdesc<1xi64, #shared, #smem, mutable, 2>, i32, !ttg.memdesc<128x64xf32, #shared2, #smem, mutable, 3x128x64>, !ttg.memdesc<64x64xf32, #shared1, #smem, mutable, 3x64x64> loc(#loc24)
} else {
scf.yield %110#0, %110#1, %110#2, %110#3, %110#4, %110#5, %110#6, %110#7, %110#8, %110#9, %110#10, %110#11, %110#12, %110#13, %110#14, %110#15 : i1, !ttg.async.token, i32, i32, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.memdesc<1xi64, #shared, #smem, mutable, 2>, i32, !ttg.memdesc<128x64xf32, #shared2, #smem, mutable, 3x128x64>, !ttg.memdesc<64x64xf32, #shared1, #smem, mutable, 3x64x64> loc(#loc24)
} loc(#loc24)
%113 = ttg.async_wait {num = 0 : i32} loc(#loc24)
ttg.local_dealloc %41 : !ttg.memdesc<3x64x64xf32, #shared1, #smem, mutable> loc(#loc24)
ttg.local_dealloc %40 : !ttg.memdesc<3x128x64xf32, #shared2, #smem, mutable> loc(#loc24)
%114 = ttg.memdesc_subview %37[%c0_i32] : !ttg.memdesc<2xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #smem, mutable, 2> loc(#loc24)
ttng.inval_barrier %114 : !ttg.memdesc<1xi64, #shared, #smem, mutable, 2> loc(#loc24)
%115 = ttg.memdesc_subview %37[%c1_i32] : !ttg.memdesc<2xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #smem, mutable, 2> loc(#loc24)
ttng.inval_barrier %115 : !ttg.memdesc<1xi64, #shared, #smem, mutable, 2> loc(#loc24)
ttg.local_dealloc %37 : !ttg.memdesc<2xi64, #shared, #smem, mutable> loc(#loc24)
The loop body and the code before the loop also change dramatically; dig into the dump yourself if you are interested. The sketch below gives a rough feel for what the pipelined loop now does.
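Conceptually, the stage annotations from ScheduleLoops plus the multi-buffering done here give a loop where the asynchronous copies for a future iteration are issued while the MMA of the current iteration runs. A rough Python-flavoured sketch (my own illustration; the helpers are stand-ins, not the generated code):
# Conceptual 3-stage software pipeline, matching the 3x multi-buffered tiles in the IR above.
NUM_STAGES = 3

def issue_async_load(i, buf):
    print(f"cp.async  iter={i} -> buffer {buf}")   # stand-in for the async global->shared copy

def wait_and_mma(i, buf):
    print(f"mma       iter={i} <- buffer {buf}")   # stand-in for wait_barrier + tc_gen5_mma

def pipelined_loop(num_iters):
    # Prologue: prefetch the first NUM_STAGES - 1 iterations (the loop.stage = 0 work).
    for i in range(min(NUM_STAGES - 1, num_iters)):
        issue_async_load(i, i % NUM_STAGES)
    # Steady state: the load for iteration i + 2 overlaps with the MMA of iteration i.
    for i in range(num_iters):
        nxt = i + NUM_STAGES - 1
        if nxt < num_iters:
            issue_async_load(nxt, nxt % NUM_STAGES)
        wait_and_mma(i, i % NUM_STAGES)            # the loop.stage = 2 work

pipelined_loop(6)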
13. TritonNvidiaGPURemoveTMEMTokensPass
44-TritonGPUCombineTensorSelectAndIf.mlir vs 45-TritonNvidiaGPURemoveTMEMTokensPass.mlir: an extra ub.poison appears; the next Canonicalizer pass then removes it together with the five earlier ub.poison ops.
14. TritonLoopAwareCSE
46-Canonicalizer.mlir vs 47-TritonLoopAwareCSE.mlir: removes a redundant ttg.memdesc_subview.
15. TritonNvidiaGPUInterleaveTMemPass
53-TritonGPURemoveLayoutConversions.mlir vs 54-TritonNvidiaGPUInterleaveTMemPass.mlir: moves the ttng.tmem_load slightly.
16. SCCP
60-TritonGPUFenceInsertion.mlir vs 61-SCCP.mlir: another reordering.
17. Final artifact
The final artifact is matrix_multiplication_kernel.ttgir, which is identical to 62-Canonicalizer.mlir.
V. make_llir
This stage lowers TTGIR to LLVM IR.
def make_llir(self, src, metadata, options, capability):
    ptx_version = get_ptx_version_from_options(options, self.target.arch)
    mod = src
    # TritonGPU -> LLVM-IR (MLIR)
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    nvidia.passes.ttnvgpuir.add_lower_mma(pm)
    passes.ttgpuir.add_combine_tensor_select_and_if(pm)
    passes.ttgpuir.add_allocate_warp_groups(pm)
    passes.convert.add_scf_to_cf(pm)
    passes.ttgpuir.add_allocate_shared_memory(pm)
    nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
    if knobs.compilation.enable_experimental_consan:
        # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
        passes.ttgpuir.add_concurrency_sanitizer(pm)
    passes.ttgpuir.add_allocate_global_scratch_memory(pm)
    nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
    nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
    passes.common.add_canonicalizer(pm)
    passes.common.add_cse(pm)
    nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
    nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
    passes.common.add_canonicalizer(pm)
    passes.common.add_cse(pm)
    passes.common.add_symbol_dce(pm)
    if not knobs.compilation.disable_line_info:
        passes.llvmir.add_di_scope(pm)
    pm.run(mod)
    # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
    llvm.init_targets()
    context = llvm.context()
    if knobs.compilation.enable_asan:
        raise RuntimeError(
            "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
    llvm_mod = llvm.to_module(mod, context)
    proc = sm_arch_from_capability(capability)
    features = get_features(options, self.target.arch)
    triple = 'nvptx64-nvidia-cuda'
    nvidia.set_short_ptr()
    llvm.attach_datalayout(llvm_mod, triple, proc, features)
    nvidia.set_nvvm_reflect_ftz(llvm_mod)
    if options.extern_libs:
        paths = [path for (name, path) in options.extern_libs]
        llvm.link_extern_libs(llvm_mod, paths)
    llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
    # Get some metadata
    # warp-specialization mutates num_warps
    total_num_warps = src.get_int_attr("ttg.total-num-warps")
    if total_num_warps is not None:
        metadata["num_warps"] = total_num_warps
    metadata["shared"] = src.get_int_attr("ttg.shared")
    metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
    metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
    metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
    ret = str(llvm_mod)
    del llvm_mod
    del context
    return ret
After these passes, the IR dumps go from 63-TritonNvidiaGPUMMALoweringPass.mlir to 80-LLVMDIScope.mlir.
1. TritonGPUAllocateWarpGroups
64-TritonGPUCombineTensorSelectAndIf.mlir vs 65-TritonGPUAllocateWarpGroups.mlir: adds ttg.total-num-warps on the module.
2. SCFToControlFlowPass
65-TritonGPUAllocateWarpGroups.mlir vs 66-SCFToControlFlowPass.mlir: lowers scf.for to cf.br-style branches, as sketched below.
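For intuition, a small Python analogy of what lowering structured control flow to branches means (illustration only; the actual pass rewrites MLIR regions, not Python):
# Structured loop vs. the branch-based form it is lowered to (Python analogy only).
def structured(n):
    acc = 0
    for i in range(n):          # scf.for with an induction variable and iter_args
        acc += i
    return acc

def unstructured(n):
    acc, i = 0, 0
    while True:                 # ^bb1: loop header reached via cf.br
        if not (i < n):         # cf.cond_br on the loop condition
            break
        acc, i = acc + i, i + 1 # loop body plus induction-variable update
    return acc                  # exit block

print(structured(5), unstructured(5))  # 10 10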
3. AllocateSharedMemory
66-SCFToControlFlowPass.mlir vs 67-AllocateSharedMemory.mlir: adds the shared-memory size on the module, ttg.shared = 147472 : i32; the number can be accounted for as shown below.
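The number checks out against the buffers visible in the pipelined IR: three stages of the 128x64 A tile, three stages of the 64x64 B tile (both f32), plus the 2xi64 mbarrier buffer. A quick accounting (my own arithmetic):
# Rough accounting of ttg.shared = 147472 based on the buffers seen in 43-TritonGPUPipeline.mlir.
a_tiles = 3 * 128 * 64 * 4      # 98304 bytes: 3 pipeline stages of the A tile
b_tiles = 3 * 64 * 64 * 4       # 49152 bytes: 3 pipeline stages of the B tile
barriers = 2 * 8                # 16 bytes: the 2xi64 barrier allocation
print(a_tiles + b_tiles + barriers)  # 147472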
4. TritonTensorMemoryAllocationPass
67-AllocateSharedMemory.mlir vs 68-TritonTensorMemoryAllocationPass.mlir: adds ttg.tensor_memory_size = 64 : i32 on the module.
5. TritonGPUGlobalScratchAllocationPass
68-TritonTensorMemoryAllocationPass.mlir vs 69-TritonGPUGlobalScratchAllocationPass.mlir: adds the global-scratch attributes.
6. ConvertTritonGPUToLLVM
70-TritonGPUProxyFenceInsertion.mlir vs 71-ConvertTritonGPUToLLVM.mlir: a very involved pass; you have to read each op's RewritePattern. This is the step that lowers everything to per-thread code.
7. Canonicalizer
71-ConvertTritonGPUToLLVM.mlir vs 72-Canonicalizer.mlir: the canonicalizer is brutally effective here; the IR drops from 24154 lines to 8648.
8. CSE
72-Canonicalizer.mlir vs 73-CSE.mlir: common subexpression elimination takes the IR from 8648 lines down to 2956.
9. ConvertNVGPUToLLVM
73-CSE.mlir vs 74-ConvertNVGPUToLLVM.mlir: intrinsics such as nvvm.read.ptx.sreg.tid.x now show up in the code.
10. Canonicalizer
76-ReconcileUnrealizedCastsPass.mlir vs 77-Canonicalizer.mlir: reorders things a little and drops a few lines.
11. CSE
77-Canonicalizer.mlir vs 78-CSE.mlir: common subexpression elimination.
12. LLVMDIScope
79-SymbolDCE vs 80-LLVMDIScope: a pass that attaches debug-info scopes to the LLVM IR so that debug information (e.g. DWARF) can be generated for source-level debugging.
13. Final artifact
The final artifact is matrix_multiplication_kernel.llir.
VI. make_ptx
This step is essentially a call into LLVM.
def make_ptx(self, src, metadata, opt, capability):
    ptx_version = get_ptx_version_from_options(opt, self.target.arch)
    triple = 'nvptx64-nvidia-cuda'
    proc = sm_arch_from_capability(capability)
    features = get_features(opt, self.target.arch)
    ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
    # Find kernel names (there should only be one)
    names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
    assert len(names) == 1
    metadata["name"] = names[0]
    # post-process
    ptx_version = f'{ptx_version//10}.{ptx_version%10}'
    ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
    ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
    # Remove the debug flag that prevents ptxas from optimizing the code
    ret = re.sub(r",\s*debug|debug,\s*", "", ret)
    if knobs.nvidia.dump_nvptx:
        print("// -----// NVPTX Dump //----- //")
        print(ret)
    return ret
The artifact is matrix_multiplication_kernel.ptx. The input file has 1683 lines and the output 2166, and they correspond almost one to one.
1. The dot, side by side
// old
%494 = icmp eq i32 %178, 0, !dbg !28
%495 = and i1 %186, %494, !dbg !28
br i1 %495, label %496, label %563, !dbg !28
496: ; preds = %10
%497 = tail call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1), !dbg !28
%498 = extractvalue { i32, i1 } %497, 1, !dbg !28
%499 = lshr exact i32 ptrtoint (ptr addrspace(3) @global_smem to i32), 4, !dbg !28
%500 = and i32 %499, 16383, !dbg !28
%501 = zext nneg i32 %500 to i64, !dbg !28
%502 = or disjoint i64 %501, 4611686293372403712, !dbg !28
%503 = lshr exact i32 ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), 4, !dbg !28
%504 = and i32 %503, 16383, !dbg !28
%505 = zext nneg i32 %504 to i64, !dbg !28
%506 = or disjoint i64 %505, 4611686293338849280, !dbg !28
tail call void asm sideeffect "@$5 tcgen05.mma.cta_group::1.kind::tf32 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b"(i32 %13, i64 %502, i64 %506, i32 135268624, i1 false, i1 %498) #4, !dbg !28
...
!28 = !DILocation(line: 38, column: 33, scope: !5)
// new
.loc 1 38 33 // matmul.py:38:33
setp.ne.s32 %p27, %r32, 0;
or.pred %p28, %p6, %p27;
@%p28 bra $L__BB0_2;
// %bb.1:
elect.sync %r530|%p30, -1;
bfe.u32 %r532, %r103, 4, 14;
cvt.u64.u32 %rd197, %r532;
or.b64 %rd180, %rd197, 4611686293372403712;
bfe.u32 %r534, %r464, 4, 14;
cvt.u64.u32 %rd198, %r534;
or.b64 %rd181, %rd198, 4611686293338849280;
mov.b32 %r515, 135268624;
mov.pred %p29, 0;
// begin inline asm
@%p30 tcgen05.mma.cta_group::1.kind::tf32 [ %r1077 + 0 ], %rd180, %rd181, %r515, %p29;
// end inline asm
2. The store, side by side
// old
...
%1222 = getelementptr inbounds nuw i8, ptr addrspace(3) %1192, i32 1920, !dbg !44
%1223 = load <4 x i32>, ptr addrspace(3) %1222, align 16, !dbg !44
%.extract = extractelement <4 x i32> %1193, i64 0, !dbg !44
%.extract64 = extractelement <4 x i32> %1193, i64 1, !dbg !44
%.extract65 = extractelement <4 x i32> %1193, i64 2, !dbg !44
%.extract66 = extractelement <4 x i32> %1193, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract, i32 %.extract64, i32 %.extract65, i32 %.extract66, ptr addrspace(1) %981, i1 %1014) #4, !dbg !44
%.extract67 = extractelement <4 x i32> %1195, i64 0, !dbg !44
%.extract68 = extractelement <4 x i32> %1195, i64 1, !dbg !44
%.extract69 = extractelement <4 x i32> %1195, i64 2, !dbg !44
%.extract70 = extractelement <4 x i32> %1195, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract67, i32 %.extract68, i32 %.extract69, i32 %.extract70, ptr addrspace(1) %982, i1 %1015) #4, !dbg !44
%.extract71 = extractelement <4 x i32> %1197, i64 0, !dbg !44
%.extract72 = extractelement <4 x i32> %1197, i64 1, !dbg !44
%.extract73 = extractelement <4 x i32> %1197, i64 2, !dbg !44
%.extract74 = extractelement <4 x i32> %1197, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract71, i32 %.extract72, i32 %.extract73, i32 %.extract74, ptr addrspace(1) %983, i1 %1016) #4, !dbg !44
%.extract75 = extractelement <4 x i32> %1199, i64 0, !dbg !44
%.extract76 = extractelement <4 x i32> %1199, i64 1, !dbg !44
%.extract77 = extractelement <4 x i32> %1199, i64 2, !dbg !44
%.extract78 = extractelement <4 x i32> %1199, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract75, i32 %.extract76, i32 %.extract77, i32 %.extract78, ptr addrspace(1) %984, i1 %1017) #4, !dbg !44
%.extract79 = extractelement <4 x i32> %1201, i64 0, !dbg !44
%.extract80 = extractelement <4 x i32> %1201, i64 1, !dbg !44
%.extract81 = extractelement <4 x i32> %1201, i64 2, !dbg !44
%.extract82 = extractelement <4 x i32> %1201, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract79, i32 %.extract80, i32 %.extract81, i32 %.extract82, ptr addrspace(1) %985, i1 %1018) #4, !dbg !44
%.extract83 = extractelement <4 x i32> %1203, i64 0, !dbg !44
%.extract84 = extractelement <4 x i32> %1203, i64 1, !dbg !44
%.extract85 = extractelement <4 x i32> %1203, i64 2, !dbg !44
%.extract86 = extractelement <4 x i32> %1203, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract83, i32 %.extract84, i32 %.extract85, i32 %.extract86, ptr addrspace(1) %986, i1 %1019) #4, !dbg !44
%.extract87 = extractelement <4 x i32> %1205, i64 0, !dbg !44
%.extract88 = extractelement <4 x i32> %1205, i64 1, !dbg !44
%.extract89 = extractelement <4 x i32> %1205, i64 2, !dbg !44
%.extract90 = extractelement <4 x i32> %1205, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract87, i32 %.extract88, i32 %.extract89, i32 %.extract90, ptr addrspace(1) %987, i1 %1020) #4, !dbg !44
%.extract91 = extractelement <4 x i32> %1207, i64 0, !dbg !44
%.extract92 = extractelement <4 x i32> %1207, i64 1, !dbg !44
%.extract93 = extractelement <4 x i32> %1207, i64 2, !dbg !44
%.extract94 = extractelement <4 x i32> %1207, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract91, i32 %.extract92, i32 %.extract93, i32 %.extract94, ptr addrspace(1) %988, i1 %1021) #4, !dbg !44
%.extract95 = extractelement <4 x i32> %1209, i64 0, !dbg !44
%.extract96 = extractelement <4 x i32> %1209, i64 1, !dbg !44
%.extract97 = extractelement <4 x i32> %1209, i64 2, !dbg !44
%.extract98 = extractelement <4 x i32> %1209, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract95, i32 %.extract96, i32 %.extract97, i32 %.extract98, ptr addrspace(1) %989, i1 %1022) #4, !dbg !44
%.extract99 = extractelement <4 x i32> %1211, i64 0, !dbg !44
%.extract100 = extractelement <4 x i32> %1211, i64 1, !dbg !44
%.extract101 = extractelement <4 x i32> %1211, i64 2, !dbg !44
%.extract102 = extractelement <4 x i32> %1211, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract99, i32 %.extract100, i32 %.extract101, i32 %.extract102, ptr addrspace(1) %990, i1 %1023) #4, !dbg !44
%.extract103 = extractelement <4 x i32> %1213, i64 0, !dbg !44
%.extract104 = extractelement <4 x i32> %1213, i64 1, !dbg !44
%.extract105 = extractelement <4 x i32> %1213, i64 2, !dbg !44
%.extract106 = extractelement <4 x i32> %1213, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract103, i32 %.extract104, i32 %.extract105, i32 %.extract106, ptr addrspace(1) %991, i1 %1024) #4, !dbg !44
%.extract107 = extractelement <4 x i32> %1215, i64 0, !dbg !44
%.extract108 = extractelement <4 x i32> %1215, i64 1, !dbg !44
%.extract109 = extractelement <4 x i32> %1215, i64 2, !dbg !44
%.extract110 = extractelement <4 x i32> %1215, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract107, i32 %.extract108, i32 %.extract109, i32 %.extract110, ptr addrspace(1) %992, i1 %1025) #4, !dbg !44
%.extract111 = extractelement <4 x i32> %1217, i64 0, !dbg !44
%.extract112 = extractelement <4 x i32> %1217, i64 1, !dbg !44
%.extract113 = extractelement <4 x i32> %1217, i64 2, !dbg !44
%.extract114 = extractelement <4 x i32> %1217, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract111, i32 %.extract112, i32 %.extract113, i32 %.extract114, ptr addrspace(1) %993, i1 %1026) #4, !dbg !44
%.extract115 = extractelement <4 x i32> %1219, i64 0, !dbg !44
%.extract116 = extractelement <4 x i32> %1219, i64 1, !dbg !44
%.extract117 = extractelement <4 x i32> %1219, i64 2, !dbg !44
%.extract118 = extractelement <4 x i32> %1219, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract115, i32 %.extract116, i32 %.extract117, i32 %.extract118, ptr addrspace(1) %994, i1 %1027) #4, !dbg !44
%.extract119 = extractelement <4 x i32> %1221, i64 0, !dbg !44
%.extract120 = extractelement <4 x i32> %1221, i64 1, !dbg !44
%.extract121 = extractelement <4 x i32> %1221, i64 2, !dbg !44
%.extract122 = extractelement <4 x i32> %1221, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract119, i32 %.extract120, i32 %.extract121, i32 %.extract122, ptr addrspace(1) %995, i1 %1028) #4, !dbg !44
%.extract123 = extractelement <4 x i32> %1223, i64 0, !dbg !44
%.extract124 = extractelement <4 x i32> %1223, i64 1, !dbg !44
%.extract125 = extractelement <4 x i32> %1223, i64 2, !dbg !44
%.extract126 = extractelement <4 x i32> %1223, i64 3, !dbg !44
tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %.extract123, i32 %.extract124, i32 %.extract125, i32 %.extract126, ptr addrspace(1) %996, i1 %1029) #4, !dbg !44
...
!44 = !DILocation(line: 45, column: 21, scope: !5)
// new
.loc 1 45 21 // matmul.py:45:21
...
ld.shared.v4.b32 {%r1073, %r1074, %r1075, %r1076}, [%r1123+1920];
// begin inline asm
@%p92 st.global.v4.b32 [ %rd346 + 0 ], { %r1013, %r1014, %r1015, %r1016 };
// end inline asm
// begin inline asm
@%p93 st.global.v4.b32 [ %rd347 + 0 ], { %r1017, %r1018, %r1019, %r1020 };
// end inline asm
// begin inline asm
@%p94 st.global.v4.b32 [ %rd348 + 0 ], { %r1021, %r1022, %r1023, %r1024 };
// end inline asm
// begin inline asm
@%p95 st.global.v4.b32 [ %rd349 + 0 ], { %r1025, %r1026, %r1027, %r1028 };
// end inline asm
// begin inline asm
@%p96 st.global.v4.b32 [ %rd350 + 0 ], { %r1029, %r1030, %r1031, %r1032 };
// end inline asm
// begin inline asm
@%p97 st.global.v4.b32 [ %rd351 + 0 ], { %r1033, %r1034, %r1035, %r1036 };
// end inline asm
// begin inline asm
@%p98 st.global.v4.b32 [ %rd352 + 0 ], { %r1037, %r1038, %r1039, %r1040 };
// end inline asm
// begin inline asm
@%p99 st.global.v4.b32 [ %rd353 + 0 ], { %r1041, %r1042, %r1043, %r1044 };
// end inline asm
// begin inline asm
@%p100 st.global.v4.b32 [ %rd354 + 0 ], { %r1045, %r1046, %r1047, %r1048 };
// end inline asm
// begin inline asm
@%p101 st.global.v4.b32 [ %rd355 + 0 ], { %r1049, %r1050, %r1051, %r1052 };
// end inline asm
// begin inline asm
@%p102 st.global.v4.b32 [ %rd356 + 0 ], { %r1053, %r1054, %r1055, %r1056 };
// end inline asm
// begin inline asm
@%p103 st.global.v4.b32 [ %rd357 + 0 ], { %r1057, %r1058, %r1059, %r1060 };
// end inline asm
// begin inline asm
@%p104 st.global.v4.b32 [ %rd358 + 0 ], { %r1061, %r1062, %r1063, %r1064 };
// end inline asm
// begin inline asm
@%p105 st.global.v4.b32 [ %rd359 + 0 ], { %r1065, %r1066, %r1067, %r1068 };
// end inline asm
// begin inline asm
@%p106 st.global.v4.b32 [ %rd360 + 0 ], { %r1069, %r1070, %r1071, %r1072 };
// end inline asm
// begin inline asm
@%p107 st.global.v4.b32 [ %rd361 + 0 ], { %r1073, %r1074, %r1075, %r1076 };
// end inline asm
3. The load, side by side
// old
%189 = and i32 %11, 1, !dbg !31
%190 = icmp eq i32 %189, 0, !dbg !31
%191 = and i32 %44, 28, !dbg !31
%192 = shl nuw nsw i32 %11, 9, !dbg !31
%193 = and i32 %192, 4096, !dbg !31
%194 = or disjoint i32 %191, %193, !dbg !31
%195 = and i32 %11, 16, !dbg !31
%.not = icmp eq i32 %195, 0, !dbg !31
%196 = select i1 %.not, i32 0, i32 36, !dbg !31
%197 = and i32 %11, 32, !dbg !31
%198 = icmp eq i32 %197, 0, !dbg !31
%199 = select i1 %198, i32 0, i32 72, !dbg !31
%200 = and i32 %11, 64, !dbg !31
%201 = icmp eq i32 %200, 0, !dbg !31
%202 = select i1 %201, i32 0, i32 144, !dbg !31
%203 = or disjoint i32 %199, %196, !dbg !31
%204 = xor i32 %203, %194, !dbg !31
%205 = xor i32 %204, %202, !dbg !31
%206 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %205, !dbg !31
%207 = or disjoint i32 %194, 256, !dbg !31
%208 = or disjoint i32 %203, %202, !dbg !31
%209 = xor i32 %208, %207, !dbg !31
%210 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %209, !dbg !31
%211 = or disjoint i32 %194, 512, !dbg !31
%212 = xor i32 %208, %211, !dbg !31
%213 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %212, !dbg !31
%214 = or disjoint i32 %194, 768, !dbg !31
%215 = xor i32 %208, %214, !dbg !31
%216 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %215, !dbg !31
%217 = or disjoint i32 %194, 1024, !dbg !31
%218 = xor i32 %208, %217, !dbg !31
%219 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %218, !dbg !31
%220 = or disjoint i32 %194, 1280, !dbg !31
%221 = xor i32 %208, %220, !dbg !31
%222 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %221, !dbg !31
%223 = or disjoint i32 %194, 1536, !dbg !31
%224 = xor i32 %208, %223, !dbg !31
%225 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %224, !dbg !31
%226 = or disjoint i32 %194, 1792, !dbg !31
%227 = xor i32 %208, %226, !dbg !31
%228 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %227, !dbg !31
%229 = or disjoint i32 %194, 2048, !dbg !31
%230 = xor i32 %208, %229, !dbg !31
%231 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %230, !dbg !31
%232 = or disjoint i32 %194, 2304, !dbg !31
%233 = xor i32 %208, %232, !dbg !31
%234 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %233, !dbg !31
%235 = or disjoint i32 %194, 2560, !dbg !31
%236 = xor i32 %208, %235, !dbg !31
%237 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %236, !dbg !31
%238 = or disjoint i32 %194, 2816, !dbg !31
%239 = xor i32 %208, %238, !dbg !31
%240 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %239, !dbg !31
%241 = or disjoint i32 %194, 3072, !dbg !31
%242 = xor i32 %208, %241, !dbg !31
%243 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %242, !dbg !31
%244 = or disjoint i32 %194, 3328, !dbg !31
%245 = xor i32 %208, %244, !dbg !31
%246 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %245, !dbg !31
%247 = or disjoint i32 %194, 3584, !dbg !31
%248 = xor i32 %208, %247, !dbg !31
%249 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %248, !dbg !31
%250 = or disjoint i32 %194, 3840, !dbg !31
%251 = xor i32 %208, %250, !dbg !31
%252 = getelementptr inbounds nuw float, ptr addrspace(3) @global_smem, i32 %251, !dbg !31
%253 = select i1 %188, i32 16, i32 0, !dbg !31
...
%838 = select i1 %821, i32 16, i32 0, !dbg !31
tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %822, ptr addrspace(1) %804, i32 %838) #4, !dbg !31
tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %823, ptr addrspace(1) %805, i32 %838) #4, !dbg !31
tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %824, ptr addrspace(1) %806, i32 %838) #4, !dbg !31
tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %825, ptr addrspace(1) %807, i32 %838) #4, !dbg !31
...
!31 = !DILocation(line: 35, column: 20, scope: !5)
// new
.loc 1 35 20 // matmul.py:35:20
and.b32 %r406, %r1, 1;
neg.s32 %r407, %r406;
shl.b32 %r408, %r1, 9;
or.b32 %r409, %r368, %r408;
and.b32 %r410, %r409, 4124;
bfe.s32 %r411, %r1, 4, 1;
and.b32 %r412, %r411, 36;
and.b32 %r413, %r1, 32;
bfe.s32 %r414, %r1, 5, 1;
and.b32 %r415, %r414, 72;
and.b32 %r416, %r1, 64;
bfe.s32 %r417, %r1, 6, 1;
and.b32 %r418, %r417, 144;
or.b32 %r419, %r415, %r412;
xor.b32 %r420, %r419, %r410;
xor.b32 %r34, %r420, %r418;
shl.b32 %r421, %r34, 2;
add.s32 %r171, %r103, %r421;
or.b32 %r422, %r410, 256;
or.b32 %r423, %r419, %r418;
xor.b32 %r35, %r423, %r422;
shl.b32 %r424, %r35, 2;
add.s32 %r173, %r103, %r424;
or.b32 %r425, %r410, 512;
xor.b32 %r36, %r423, %r425;
shl.b32 %r426, %r36, 2;
add.s32 %r175, %r103, %r426;
or.b32 %r427, %r410, 768;
xor.b32 %r37, %r423, %r427;
shl.b32 %r428, %r37, 2;
add.s32 %r177, %r103, %r428;
or.b32 %r429, %r410, 1024;
xor.b32 %r38, %r423, %r429;
shl.b32 %r430, %r38, 2;
add.s32 %r179, %r103, %r430;
or.b32 %r431, %r410, 1280;
xor.b32 %r39, %r423, %r431;
shl.b32 %r432, %r39, 2;
add.s32 %r181, %r103, %r432;
or.b32 %r433, %r410, 1536;
xor.b32 %r40, %r423, %r433;
shl.b32 %r434, %r40, 2;
add.s32 %r183, %r103, %r434;
or.b32 %r435, %r410, 1792;
xor.b32 %r41, %r423, %r435;
shl.b32 %r436, %r41, 2;
add.s32 %r185, %r103, %r436;
or.b32 %r437, %r410, 2048;
xor.b32 %r42, %r423, %r437;
shl.b32 %r438, %r42, 2;
add.s32 %r187, %r103, %r438;
or.b32 %r439, %r410, 2304;
xor.b32 %r43, %r423, %r439;
shl.b32 %r440, %r43, 2;
add.s32 %r189, %r103, %r440;
or.b32 %r441, %r410, 2560;
xor.b32 %r44, %r423, %r441;
shl.b32 %r442, %r44, 2;
add.s32 %r191, %r103, %r442;
or.b32 %r443, %r410, 2816;
xor.b32 %r45, %r423, %r443;
shl.b32 %r444, %r45, 2;
add.s32 %r193, %r103, %r444;
or.b32 %r445, %r410, 3072;
xor.b32 %r46, %r423, %r445;
shl.b32 %r446, %r46, 2;
add.s32 %r195, %r103, %r446;
or.b32 %r447, %r410, 3328;
xor.b32 %r47, %r423, %r447;
shl.b32 %r448, %r47, 2;
add.s32 %r197, %r103, %r448;
or.b32 %r449, %r410, 3584;
xor.b32 %r48, %r423, %r449;
shl.b32 %r450, %r48, 2;
add.s32 %r199, %r103, %r450;
or.b32 %r451, %r410, 3840;
xor.b32 %r49, %r423, %r451;
shl.b32 %r452, %r49, 2;
add.s32 %r201, %r103, %r452;
selp.b32 %r453, 16, 0, %p7;
selp.b32 %r174, %r453, 0, %p8;
...
selp.b32 %r453, 16, 0, %p7;
selp.b32 %r174, %r453, 0, %p8;
// begin inline asm
cp.async.cg.shared.global [ %r171 + 0 ], [ %rd31 + 0 ], 0x10, %r174;
// end inline asm
// begin inline asm
cp.async.cg.shared.global [ %r173 + 0 ], [ %rd32 + 0 ], 0x10, %r174;
// end inline asm
// begin inline asm
cp.async.cg.shared.global [ %r175 + 0 ], [ %rd33 + 0 ], 0x10, %r174;
// end inline asm
// begin inline asm
cp.async.cg.shared.global [ %r177 + 0 ], [ %rd34 + 0 ], 0x10, %r174;
// end inline asm
...
VII. Articles in this series
This post originally appeared on cnblogs (author: 暴力都不会的蒟蒻). Please keep the original link when reposting: https://www.cnblogs.com/BobHuang/p/18956171
