Introduction to TOPI
This is an introduction to TOPI, the TVM Operator Inventory. TOPI provides numpy-style generic operations and schedules with a higher level of abstraction than TVM's compute/schedule primitives. In this tutorial we will see how TOPI can save us from writing boilerplate code in TVM.
import tvm
import tvm.testing
from tvm import te
from tvm import topi
import numpy as np
Basic example
Let's revisit the row-sum operation (equivalent to B = numpy.sum(A, axis=1)). To compute the row sum of a two-dimensional TVM tensor A, we first specify the symbolic operation and the schedule as follows:
n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), "k")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
s = te.create_schedule(B.op)
Then we can inspect the generated IR code:
print(tvm.lower(s, [A], simple_mode=True))
Output:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        n, m = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (n, m), strides=("stride", "stride"), buffer_type="auto")
        B = T.allocate([n], "float32", "global")
        for i in range(n):
            B_1 = T.Buffer((n,), data=B)
            B_1[i] = T.float32(0)
            for k in range(m):
                A_2 = T.Buffer((A_1.strides[0] * n,), data=A_1.data, buffer_type="auto")
                B_1[i] = B_1[i] + A_2[i * A_1.strides[0] + k * A_1.strides[1]]
However, even for such a common operation we had to define the reduce axis ourselves and spell out the exact computation with te.compute. Imagine how much more detail we would need to provide for more complicated operations. Fortunately, we can replace those two lines with a single topi.sum, which behaves much like numpy.sum:
C = topi.sum(A, axis=1)
ts = te.create_schedule(C.op)
print(tvm.lower(ts, [A], simple_mode=True))
Output:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        n, m = T.int32(), T.int32()
        A_1 = T.match_buffer(A, (n, m), strides=("stride", "stride"), buffer_type="auto")
        A_red = T.allocate([n], "float32", "global")
        for ax0 in range(n):
            A_red_1 = T.Buffer((n,), data=A_red)
            A_red_1[ax0] = T.float32(0)
            for k1 in range(m):
                A_2 = T.Buffer((A_1.strides[0] * n,), data=A_1.data, buffer_type="auto")
                A_red_1[ax0] = A_red_1[ax0] + A_2[ax0 * A_1.strides[0] + k1 * A_1.strides[1]]
Numpy-style operator overloading
We can add two tensors with broadcast-compatible shapes using topi.broadcast_add. Even shorter, TOPI provides operator overloading for such common operations. For example:
x, y = 100, 10
a = te.placeholder((x, y, y), name="a")
b = te.placeholder((y, y), name="b")
c = a + b # same as topi.broadcast_add
d = a * b # same as topi.broadcast_mul
Overloaded with the same syntax, TOPI also handles broadcasting a primitive (int, float) against a tensor, as in d - 3.14.
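As a minimal sketch (the name d_sub below is introduced here only for illustration and is not used later in the tutorial), the overloaded operators broadcast a Python scalar in the same way:

# Minimal sketch: subtracting a Python scalar is also handled as a broadcast
# operation, equivalent to topi.subtract(d, 3.14).
d_sub = d - 3.14
print(type(d_sub))  # still a te.Tensor, so it can be scheduled like any other op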
Generic schedules and fusing operations
Up to now, we have seen how TOPI lets us define computations without writing them out explicitly with the lower-level API. But that is not all: we still need to schedule the computation, just as before. TOPI also provides higher-level scheduling recipes depending on a given context. For example, for CUDA, we can schedule the following series of operations ending with topi.sum using only topi.cuda.schedule_reduce:
e = topi.elemwise_sum([c, d])
f = e / 2.0
g = topi.sum(f)
with tvm.target.cuda():
    sg = topi.cuda.schedule_reduce(g)
    print(tvm.lower(sg, [a, b], simple_mode=True))
Output:
/workspace/python/tvm/target/target.py:422: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.
warnings.warn("Try specifying cuda arch by adding 'arch=sm_xx' to your target.")
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(a: T.Buffer((100, 10, 10), "float32"), b: T.Buffer((10, 10), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        T_divide_red = T.allocate([1], "float32", "global")
        threadIdx_x = T.launch_thread("threadIdx.x", 1024)
        T_divide_red_rf = T.allocate([1], "float32", "local")
        reduce_temp0 = T.allocate([1], "float32", "local")
        T_divide_red_rf_1 = T.Buffer((1,), data=T_divide_red_rf, scope="local", align=4)
        T_divide_red_rf_1[0] = T.float32(0)
        for k0_k1_fused_k2_fused_outer in range(10):
            if T.likely(k0_k1_fused_k2_fused_outer * 64 + threadIdx_x // 16 < 625 and k0_k1_fused_k2_fused_outer * 64 + threadIdx_x // 16 < 625 and k0_k1_fused_k2_fused_outer * 64 + threadIdx_x // 16 < 625):
                a_1 = T.Buffer((10000,), data=a.data)
                b_1 = T.Buffer((100,), data=b.data)
                T_divide_red_rf_1[0] = T_divide_red_rf_1[0] + (a_1[k0_k1_fused_k2_fused_outer * 1024 + threadIdx_x] + b_1[(k0_k1_fused_k2_fused_outer * 24 + threadIdx_x) % 100] + a_1[k0_k1_fused_k2_fused_outer * 1024 + threadIdx_x] * b_1[(k0_k1_fused_k2_fused_outer * 24 + threadIdx_x) % 100]) * T.float32(0.5)
        reduce_temp0_1 = T.Buffer((1,), data=reduce_temp0, scope="local")
        with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
            T.tvm_thread_allreduce(T.uint32(1), T_divide_red_rf_1[0], T.bool(True), reduce_temp0_1[0], threadIdx_x)
        if threadIdx_x == 0:
            T_divide_red_1 = T.Buffer((1,), data=T_divide_red, align=4)
            T_divide_red_1[0] = reduce_temp0_1[0]
As you can see, the scheduled stages of the computation have been accumulated, and we can examine them:
print(sg.stages)
Output:
[stage(a, placeholder(a, 0x1cc25820)), stage(b, placeholder(b, 0x148a2750)), stage(T_add, compute(T_add, body=[a[ax0, ax1, ax2] + b[ax1, ax2]], axis=[T.iter_var(ax0, T.Range(0, 100), "DataPar", ""), T.iter_var(ax1, T.Range(0, 10), "DataPar", ""), T.iter_var(ax2, T.Range(0, 10), "DataPar", "")], reduce_axis=[], tag=broadcast, attrs={})), stage(T_multiply, compute(T_multiply, body=[a[ax0, ax1, ax2] * b[ax1, ax2]], axis=[T.iter_var(ax0, T.Range(0, 100), "DataPar", ""), T.iter_var(ax1, T.Range(0, 10), "DataPar", ""), T.iter_var(ax2, T.Range(0, 10), "DataPar", "")], reduce_axis=[], tag=broadcast, attrs={})), stage(T_elemwise_sum, compute(T_elemwise_sum, body=[T_add[ax0, ax1, ax2] + T_multiply[ax0, ax1, ax2]], axis=[T.iter_var(ax0, T.Range(0, 100), "DataPar", ""), T.iter_var(ax1, T.Range(0, 10), "DataPar", ""), T.iter_var(ax2, T.Range(0, 10), "DataPar", "")], reduce_axis=[], tag=elemwise, attrs={})), stage(T_divide, compute(T_divide, body=[T_elemwise_sum[ax0, ax1, ax2] / T.float32(2)], axis=[T.iter_var(ax0, T.Range(0, 100), "DataPar", ""), T.iter_var(ax1, T.Range(0, 10), "DataPar", ""), T.iter_var(ax2, T.Range(0, 10), "DataPar", "")], reduce_axis=[], tag=elemwise, attrs={})), stage(T_divide_red.rf, compute(T_divide_red.rf, body=[T.reduce(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), source=[T_divide[(k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024) // 10 // 10, (k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024) // 10 % 10, (k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024) % 10]], init=[], axis=[T.iter_var(k0_k1_fused_k2_fused_outer, T.Range(0, 10), "CommReduce", "")], condition=T.likely((k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024) // 10 // 10 < 100 and (k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024) // 10 < 1000 and k0_k1_fused_k2_fused_inner + k0_k1_fused_k2_fused_outer * 1024 < 10000), value_index=0)], axis=[T.iter_var(k0_k1_fused_k2_fused_inner, T.Range(0, 1024), "DataPar", "")], reduce_axis=[T.iter_var(k0_k1_fused_k2_fused_outer, T.Range(0, 10), "CommReduce", "")], tag=, attrs={})), stage(T_divide_red, compute(T_divide_red.repl, body=[T.reduce(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), source=[T_divide_red.rf[k0_k1_fused_k2_fused_inner_v]], init=[], axis=[T.iter_var(k0_k1_fused_k2_fused_inner_v, T.Range(0, 1024), "CommReduce", "")], condition=T.bool(True), value_index=0)], axis=[], reduce_axis=[T.iter_var(k0_k1_fused_k2_fused_inner_v, T.Range(0, 1024), "CommReduce", "")], tag=, attrs={}))]
We can test the correctness by comparing with the numpy result as follows:
func = tvm.build(sg, [a, b, g], "cuda")
dev = tvm.cuda(0)
a_np = np.random.uniform(size=(x, y, y)).astype(a.dtype)
b_np = np.random.uniform(size=(y, y)).astype(b.dtype)
g_np = np.sum(np.add(a_np + b_np, a_np * b_np) / 2.0)
a_nd = tvm.nd.array(a_np, dev)
b_nd = tvm.nd.array(b_np, dev)
g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), dev)
func(a_nd, b_nd, g_nd)
tvm.testing.assert_allclose(g_nd.numpy(), g_np, rtol=1e-5)
TOPI also provides common neural-network operations such as softmax, with an optimized schedule:
tarray = te.placeholder((512, 512), name="tarray")
softmax_topi = topi.nn.softmax(tarray)
with tvm.target.Target("cuda"):
    sst = topi.cuda.schedule_softmax(softmax_topi)
    print(tvm.lower(sst, [tarray], simple_mode=True))
Output:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(tarray: T.Buffer((512, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        T_softmax_norm = T.allocate([65536], "float32x4", "global")
        blockIdx_x = T.launch_thread("blockIdx.x", 512)
        normal_reduce_temp0 = T.allocate([1], "float32", "local")
        reduce_temp0 = T.allocate([1], "float32", "local")
        T_softmax_exp = T.allocate([512], "float32", "warp")
        normal_reduce_temp0_1 = T.allocate([1], "float32", "local")
        reduce_temp0_1 = T.allocate([1], "float32", "local")
        threadIdx_x = T.env_thread("threadIdx.x")
        T_softmax_exp_1 = T.Buffer((512,), data=T_softmax_exp, scope="warp")
        with T.launch_thread(threadIdx_x, 32):
            normal_reduce_temp0_2 = T.Buffer((1,), data=normal_reduce_temp0, scope="local")
            normal_reduce_temp0_2[0] = T.float32(-3.4028234663852886e+38)
            tarray_1 = T.Buffer((262144,), data=tarray.data)
            for k_inner in range(16):
                normal_reduce_temp0_2[0] = T.max(normal_reduce_temp0_2[0], tarray_1[blockIdx_x * 512 + threadIdx_x * 16 + k_inner])
            with T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.float32(-3.4028234663852886e+38)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
                reduce_temp0_2 = T.Buffer((1,), data=reduce_temp0, scope="local")
                T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0_2[0], T.bool(True), reduce_temp0_2[0], threadIdx_x)
            for i1_inner_outer in range(4):
                cse_var_1: T.int32 = i1_inner_outer * 4
                reduce_temp0_2 = T.Buffer((1,), data=reduce_temp0, scope="local", align=4)
                T_softmax_exp_1[threadIdx_x * 16 + cse_var_1:threadIdx_x * 16 + cse_var_1 + 4] = T.exp(tarray_1[blockIdx_x * 512 + threadIdx_x * 16 + cse_var_1:blockIdx_x * 512 + threadIdx_x * 16 + cse_var_1 + 4] - T.Broadcast(reduce_temp0_2[0], 4))
        T.launch_thread(threadIdx_x, 32)
        normal_reduce_temp0_2 = T.Buffer((1,), data=normal_reduce_temp0_1, scope="local")
        normal_reduce_temp0_2[0] = T.float32(0)
        for k_inner in range(16):
            normal_reduce_temp0_2[0] = normal_reduce_temp0_2[0] + T_softmax_exp_1[threadIdx_x * 16 + k_inner]
        with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))):
            reduce_temp0_2 = T.Buffer((1,), data=reduce_temp0_1, scope="local")
            T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0_2[0], T.bool(True), reduce_temp0_2[0], threadIdx_x)
        for i1_inner_outer in range(4):
            T_softmax_norm_1 = T.Buffer((65536,), "float32x4", data=T_softmax_norm)
            reduce_temp0_2 = T.Buffer((1,), data=reduce_temp0_1, scope="local", align=4)
            T_softmax_norm_1[blockIdx_x * 128 + threadIdx_x * 4 + i1_inner_outer] = T_softmax_exp_1[threadIdx_x * 16 + i1_inner_outer * 4:threadIdx_x * 16 + i1_inner_outer * 4 + 4] / T.Broadcast(reduce_temp0_2[0], 4)
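As a quick sanity check (a hedged sketch that is not part of the original tutorial; it assumes a local CUDA device, and the names f_softmax, x_np, y_np are introduced here only for illustration), the scheduled softmax can be built and compared against a plain numpy implementation:

# Hedged sketch: build the scheduled softmax for CUDA and verify it against a
# straightforward row-wise numpy softmax.
f_softmax = tvm.build(sst, [tarray, softmax_topi], "cuda")
dev = tvm.cuda(0)
x_np = np.random.uniform(size=(512, 512)).astype(tarray.dtype)
# numpy reference: subtract the row max for numerical stability, then normalize
e_np = np.exp(x_np - x_np.max(axis=1, keepdims=True))
y_np = e_np / e_np.sum(axis=1, keepdims=True)
x_nd = tvm.nd.array(x_np, dev)
y_nd = tvm.nd.array(np.zeros((512, 512), dtype=tarray.dtype), dev)
f_softmax(x_nd, y_nd)
tvm.testing.assert_allclose(y_nd.numpy(), y_np, rtol=1e-5)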
Fusing convolutions
We can fuse topi.nn.conv2d and topi.nn.relu together:
data = te.placeholder((1, 3, 224, 224))
kernel = te.placeholder((10, 3, 5, 5))
with tvm.target.Target("cuda"):
    conv = topi.cuda.conv2d_nchw(data, kernel, 1, 2, 1)
    out = topi.nn.relu(conv)
    sconv = topi.cuda.schedule_conv2d_nchw([out])
    print(tvm.lower(sconv, [data, kernel], simple_mode=True))
Output:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(placeholder: T.Buffer((1, 3, 224, 224), "float32"), placeholder_1: T.Buffer((10, 3, 5, 5), "float32")):
T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
compute = T.allocate([501760], "float32", "global")
blockIdx_z = T.launch_thread("blockIdx.z", 5)
conv2d_nchw = T.allocate([14], "float32", "local")
pad_temp_shared = T.allocate([112], "float32", "shared")
placeholder_shared = T.allocate([2], "float32", "shared")
blockIdx_y = T.launch_thread("blockIdx.y", 224)
blockIdx_x = T.launch_thread("blockIdx.x", 2)
threadIdx_z = T.launch_thread("threadIdx.z", 1)
threadIdx_y = T.launch_thread("threadIdx.y", 1)
threadIdx_x = T.launch_thread("threadIdx.x", 16)
conv2d_nchw_1 = T.Buffer((4,), data=conv2d_nchw, scope="local", align=8)
conv2d_nchw_1[0] = T.float32(0)
conv2d_nchw_1[2] = T.float32(0)
conv2d_nchw_1[4] = T.float32(0)
conv2d_nchw_1[6] = T.float32(0)
conv2d_nchw_1[8] = T.float32(0)
conv2d_nchw_1[10] = T.float32(0)
conv2d_nchw_1[12] = T.float32(0)
conv2d_nchw_1[1] = T.float32(0)
conv2d_nchw_1[3] = T.float32(0)
conv2d_nchw_1[5] = T.float32(0)
conv2d_nchw_1[7] = T.float32(0)
conv2d_nchw_1[9] = T.float32(0)
conv2d_nchw_1[11] = T.float32(0)
conv2d_nchw_1[13] = T.float32(0)
for rc_outer, ry_outer in T.grid(3, 5):
threadIdx_x_1 = T.env_thread("threadIdx.x")
pad_temp_shared_1 = T.Buffer((112,), data=pad_temp_shared, scope="shared")
placeholder_2 = T.Buffer((150528,), data=placeholder.data)
with T.launch_thread("threadIdx.z", 1) as threadIdx_z_1:
threadIdx_y_1 = T.launch_thread("threadIdx.y", 1)
T.launch_thread(threadIdx_x_1, 16)
pad_temp_shared_1[threadIdx_x_1 * 7] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and 1 <= blockIdx_x * 56 + threadIdx_x_1 * 7 // 2, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 450], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 1] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and 1 <= blockIdx_x * 56 + (threadIdx_x_1 * 7 + 1) // 2, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 449], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 2] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 448], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 3] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 447], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 4] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 446], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 5] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 445], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 6] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 444], T.float32(0))
threadIdx_x_2 = T.env_thread("threadIdx.x")
placeholder_shared_1 = T.Buffer((2,), data=placeholder_shared, scope="shared", align=8)
placeholder_3 = T.Buffer((750,), data=placeholder_1.data)
with T.launch_thread("threadIdx.z", 1) as threadIdx_z_1:
threadIdx_y_1 = T.launch_thread("threadIdx.y", 1)
T.launch_thread(threadIdx_x_2, 16)
if T.likely(threadIdx_x_2 < 2):
placeholder_shared_1[threadIdx_x_2] = placeholder_3[blockIdx_z * 150 + threadIdx_x_2 * 75 + rc_outer * 25 + ry_outer * 5]
conv2d_nchw_1[0] = conv2d_nchw_1[0] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[0]
conv2d_nchw_1[2] = conv2d_nchw_1[2] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[0]
conv2d_nchw_1[4] = conv2d_nchw_1[4] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[0]
conv2d_nchw_1[6] = conv2d_nchw_1[6] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[0]
conv2d_nchw_1[8] = conv2d_nchw_1[8] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[0]
conv2d_nchw_1[10] = conv2d_nchw_1[10] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[0]
conv2d_nchw_1[12] = conv2d_nchw_1[12] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[0]
conv2d_nchw_1[1] = conv2d_nchw_1[1] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[1]
conv2d_nchw_1[3] = conv2d_nchw_1[3] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[1]
conv2d_nchw_1[5] = conv2d_nchw_1[5] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[1]
conv2d_nchw_1[7] = conv2d_nchw_1[7] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[1]
conv2d_nchw_1[9] = conv2d_nchw_1[9] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[1]
conv2d_nchw_1[11] = conv2d_nchw_1[11] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[1]
conv2d_nchw_1[13] = conv2d_nchw_1[13] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[1]
threadIdx_z_1 = T.env_thread("threadIdx.z")
threadIdx_y_1 = T.env_thread("threadIdx.y")
with T.launch_thread(threadIdx_z_1, 1):
T.launch_thread(threadIdx_y_1, 1)
T.launch_thread(threadIdx_x_1, 16)
pad_temp_shared_1[threadIdx_x_1 * 7] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and 1 <= blockIdx_x * 56 + (threadIdx_x_1 * 7 + 1) // 2, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 449], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 1] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 448], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 2] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 447], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 3] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 446], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 4] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 445], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 5] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 444], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 6] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 443], T.float32(0))
threadIdx_z_2 = T.env_thread("threadIdx.z")
threadIdx_y_2 = T.env_thread("threadIdx.y")
with T.launch_thread(threadIdx_z_2, 1):
T.launch_thread(threadIdx_y_2, 1)
T.launch_thread(threadIdx_x_2, 16)
if T.likely(threadIdx_x_2 < 2):
placeholder_shared_1[threadIdx_x_2] = placeholder_3[blockIdx_z * 150 + threadIdx_x_2 * 75 + rc_outer * 25 + ry_outer * 5 + 1]
conv2d_nchw_1[0] = conv2d_nchw_1[0] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[0]
conv2d_nchw_1[2] = conv2d_nchw_1[2] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[0]
conv2d_nchw_1[4] = conv2d_nchw_1[4] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[0]
conv2d_nchw_1[6] = conv2d_nchw_1[6] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[0]
conv2d_nchw_1[8] = conv2d_nchw_1[8] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[0]
conv2d_nchw_1[10] = conv2d_nchw_1[10] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[0]
conv2d_nchw_1[12] = conv2d_nchw_1[12] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[0]
conv2d_nchw_1[1] = conv2d_nchw_1[1] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[1]
conv2d_nchw_1[3] = conv2d_nchw_1[3] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[1]
conv2d_nchw_1[5] = conv2d_nchw_1[5] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[1]
conv2d_nchw_1[7] = conv2d_nchw_1[7] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[1]
conv2d_nchw_1[9] = conv2d_nchw_1[9] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[1]
conv2d_nchw_1[11] = conv2d_nchw_1[11] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[1]
conv2d_nchw_1[13] = conv2d_nchw_1[13] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[1]
with T.launch_thread(threadIdx_z_1, 1):
T.launch_thread(threadIdx_y_1, 1)
T.launch_thread(threadIdx_x_1, 16)
pad_temp_shared_1[threadIdx_x_1 * 7] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 448], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 1] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 447], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 2] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 446], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 3] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 445], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 4] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 444], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 5] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 443], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 6] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 442], T.float32(0))
with T.launch_thread(threadIdx_z_2, 1):
T.launch_thread(threadIdx_y_2, 1)
T.launch_thread(threadIdx_x_2, 16)
if T.likely(threadIdx_x_2 < 2):
placeholder_shared_1[threadIdx_x_2] = placeholder_3[blockIdx_z * 150 + threadIdx_x_2 * 75 + rc_outer * 25 + ry_outer * 5 + 2]
conv2d_nchw_1[0] = conv2d_nchw_1[0] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[0]
conv2d_nchw_1[2] = conv2d_nchw_1[2] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[0]
conv2d_nchw_1[4] = conv2d_nchw_1[4] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[0]
conv2d_nchw_1[6] = conv2d_nchw_1[6] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[0]
conv2d_nchw_1[8] = conv2d_nchw_1[8] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[0]
conv2d_nchw_1[10] = conv2d_nchw_1[10] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[0]
conv2d_nchw_1[12] = conv2d_nchw_1[12] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[0]
conv2d_nchw_1[1] = conv2d_nchw_1[1] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[1]
conv2d_nchw_1[3] = conv2d_nchw_1[3] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[1]
conv2d_nchw_1[5] = conv2d_nchw_1[5] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[1]
conv2d_nchw_1[7] = conv2d_nchw_1[7] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[1]
conv2d_nchw_1[9] = conv2d_nchw_1[9] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[1]
conv2d_nchw_1[11] = conv2d_nchw_1[11] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[1]
conv2d_nchw_1[13] = conv2d_nchw_1[13] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[1]
with T.launch_thread(threadIdx_z_1, 1):
T.launch_thread(threadIdx_y_1, 1)
T.launch_thread(threadIdx_x_1, 16)
pad_temp_shared_1[threadIdx_x_1 * 7] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 447], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 1] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 446], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 2] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 445], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 3] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 444], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 4] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 443], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 5] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 442], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 6] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and blockIdx_x * 112 + threadIdx_x_1 * 7 < 217, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 441], T.float32(0))
with T.launch_thread(threadIdx_z_2, 1):
T.launch_thread(threadIdx_y_2, 1)
T.launch_thread(threadIdx_x_2, 16)
if T.likely(threadIdx_x_2 < 2):
placeholder_shared_1[threadIdx_x_2] = placeholder_3[blockIdx_z * 150 + threadIdx_x_2 * 75 + rc_outer * 25 + ry_outer * 5 + 3]
conv2d_nchw_1[0] = conv2d_nchw_1[0] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[0]
conv2d_nchw_1[2] = conv2d_nchw_1[2] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[0]
conv2d_nchw_1[4] = conv2d_nchw_1[4] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[0]
conv2d_nchw_1[6] = conv2d_nchw_1[6] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[0]
conv2d_nchw_1[8] = conv2d_nchw_1[8] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[0]
conv2d_nchw_1[10] = conv2d_nchw_1[10] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[0]
conv2d_nchw_1[12] = conv2d_nchw_1[12] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[0]
conv2d_nchw_1[1] = conv2d_nchw_1[1] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[1]
conv2d_nchw_1[3] = conv2d_nchw_1[3] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[1]
conv2d_nchw_1[5] = conv2d_nchw_1[5] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[1]
conv2d_nchw_1[7] = conv2d_nchw_1[7] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[1]
conv2d_nchw_1[9] = conv2d_nchw_1[9] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[1]
conv2d_nchw_1[11] = conv2d_nchw_1[11] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[1]
conv2d_nchw_1[13] = conv2d_nchw_1[13] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[1]
with T.launch_thread(threadIdx_z_1, 1):
T.launch_thread(threadIdx_y_1, 1)
T.launch_thread(threadIdx_x_1, 16)
pad_temp_shared_1[threadIdx_x_1 * 7] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 446], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 1] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 445], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 2] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 444], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 3] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 443], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 4] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 442], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 5] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and blockIdx_x * 112 + threadIdx_x_1 * 7 < 217, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 441], T.float32(0))
pad_temp_shared_1[threadIdx_x_1 * 7 + 6] = T.if_then_else(2 <= blockIdx_y + ry_outer and blockIdx_y + ry_outer < 226 and blockIdx_x * 112 + threadIdx_x_1 * 7 < 216, placeholder_2[rc_outer * 50176 + blockIdx_y * 224 + ry_outer * 224 + blockIdx_x * 112 + threadIdx_x_1 * 7 - 440], T.float32(0))
with T.launch_thread(threadIdx_z_2, 1):
T.launch_thread(threadIdx_y_2, 1)
T.launch_thread(threadIdx_x_2, 16)
if T.likely(threadIdx_x_2 < 2):
placeholder_shared_1[threadIdx_x_2] = placeholder_3[blockIdx_z * 150 + threadIdx_x_2 * 75 + rc_outer * 25 + ry_outer * 5 + 4]
conv2d_nchw_1[0] = conv2d_nchw_1[0] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[0]
conv2d_nchw_1[2] = conv2d_nchw_1[2] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[0]
conv2d_nchw_1[4] = conv2d_nchw_1[4] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[0]
conv2d_nchw_1[6] = conv2d_nchw_1[6] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[0]
conv2d_nchw_1[8] = conv2d_nchw_1[8] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[0]
conv2d_nchw_1[10] = conv2d_nchw_1[10] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[0]
conv2d_nchw_1[12] = conv2d_nchw_1[12] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[0]
conv2d_nchw_1[1] = conv2d_nchw_1[1] + pad_temp_shared_1[threadIdx_x] * placeholder_shared_1[1]
conv2d_nchw_1[3] = conv2d_nchw_1[3] + pad_temp_shared_1[threadIdx_x + 16] * placeholder_shared_1[1]
conv2d_nchw_1[5] = conv2d_nchw_1[5] + pad_temp_shared_1[threadIdx_x + 32] * placeholder_shared_1[1]
conv2d_nchw_1[7] = conv2d_nchw_1[7] + pad_temp_shared_1[threadIdx_x + 48] * placeholder_shared_1[1]
conv2d_nchw_1[9] = conv2d_nchw_1[9] + pad_temp_shared_1[threadIdx_x + 64] * placeholder_shared_1[1]
conv2d_nchw_1[11] = conv2d_nchw_1[11] + pad_temp_shared_1[threadIdx_x + 80] * placeholder_shared_1[1]
conv2d_nchw_1[13] = conv2d_nchw_1[13] + pad_temp_shared_1[threadIdx_x + 96] * placeholder_shared_1[1]
compute_1 = T.Buffer((501760,), data=compute)
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x] = T.max(conv2d_nchw_1[0], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 16] = T.max(conv2d_nchw_1[2], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 32] = T.max(conv2d_nchw_1[4], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 48] = T.max(conv2d_nchw_1[6], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 64] = T.max(conv2d_nchw_1[8], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 80] = T.max(conv2d_nchw_1[10], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 96] = T.max(conv2d_nchw_1[12], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50176] = T.max(conv2d_nchw_1[1], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50192] = T.max(conv2d_nchw_1[3], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50208] = T.max(conv2d_nchw_1[5], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50224] = T.max(conv2d_nchw_1[7], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50240] = T.max(conv2d_nchw_1[9], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50256] = T.max(conv2d_nchw_1[11], T.float32(0))
compute_1[blockIdx_z * 100352 + blockIdx_y * 224 + blockIdx_x * 112 + threadIdx_x + 50272] = T.max(conv2d_nchw_1[13], T.float32(0))
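As with the earlier reduction example, the fused kernel could be verified end to end. The sketch below is not part of the original tutorial: it assumes a local CUDA device and SciPy (used internally by tvm.topi.testing.conv2d_nchw_python), and the names f_conv, data_np, kernel_np, ref_np are introduced here only for illustration:

# Hedged sketch: build the fused conv2d + relu kernel and compare it with
# TVM's numpy reference convolution followed by a ReLU.
import tvm.topi.testing

f_conv = tvm.build(sconv, [data, kernel, out], "cuda")
dev = tvm.cuda(0)
data_np = np.random.uniform(size=(1, 3, 224, 224)).astype("float32")
kernel_np = np.random.uniform(size=(10, 3, 5, 5)).astype("float32")
# Reference: conv2d_nchw with stride 1 and padding 2, then ReLU
ref_np = tvm.topi.testing.conv2d_nchw_python(data_np, kernel_np, 1, 2)
ref_np = np.maximum(ref_np, 0)
data_nd = tvm.nd.array(data_np, dev)
kernel_nd = tvm.nd.array(kernel_np, dev)
out_nd = tvm.nd.array(np.zeros(ref_np.shape, dtype="float32"), dev)
f_conv(data_nd, kernel_nd, out_nd)
tvm.testing.assert_allclose(out_nd.numpy(), ref_np, rtol=1e-5)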
Summary
In this tutorial, we have seen:
- How to use the TOPI API for common operations with numpy-style operators.
- How TOPI facilitates generic schedules and operator fusion for a given context, producing optimized kernel code.
