# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm


def test_bound1():
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.create_schedule([A2.op])
    xo, xi = s[A2].split(s[A2].op.axis[0], 8)
    s[A1].compute_at(s[A2], xo)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[A1.op.axis[0]].extent.value == 8


def test_bound2():
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.create_schedule(A2.op)
    xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
    # test normalize not affecting schedule
    _ = s.normalize()
    s[A1].compute_at(s[A2], yo)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[A1.op.axis[0]].extent.value == 8
    assert bounds[A1.op.axis[1]].extent.value == 8


def test_bound3():
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.create_schedule(A2.op)
    s[A1].set_scope("shared")
    xo, xi = s[A2].split(A2.op.axis[0], 32)
    xi0, xi1 = s[A2].split(xi, nparts=16)
    s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
    yo, yi = s[A2].split(A2.op.axis[1], 16)
    # test normalize not affecting schedule
    _ = s.normalize()
    s[A2].reorder(xo, xi0, yo, xi1, yi)
    s[A1].compute_at(s[A2], yo)

    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[A1.op.axis[0]].extent.value == 32
    assert bounds[A1.op.axis[1]].extent.value == 16


def test_bound_fusesplit1():
    m = tvm.var('m')
    l = tvm.var('l')
    split1 = tvm.var('s')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.create_schedule(A2.op)
    fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
    xo, xi = s[A2].split(fused_axes, split1)
    s[A1].compute_at(s[A2], xo)

    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)

    assert tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].min - (xo * split1) / l).value == 0

    expected_extent = (((xo + 1) * split1 - 1) / l - (xo * split1) / l + 1)
    for i in range(1, 6):
        for j in range(1, 6):
            for k in range(1, 6):
                vars = tvm.convert({split1: tvm.const(i, "int32"),
                                    l: tvm.const(j, "int32"),
                                    xo.var: tvm.const(k, "int32")})
                comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value
                exp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expected_extent, vars)).value
                assert comp_ext == exp_ext

    assert tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0


def test_bound_fusesplit2():
    m = tvm.var("m")
    l = tvm.convert(6)
    split = tvm.convert(3)
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.create_schedule(A2.op)
    fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
    xo, xi = s[A2].split(fused_axes, split)
    s[A1].compute_at(s[A2], xo)

    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    vars = tvm.convert({xo.var: tvm.const(5, "int32")})
    assert tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars)).value == 2
    assert tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars)).value == 3
    assert tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value == 1
    assert tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3


def test_bound_warp():
    m = tvm.var('m')
    l = tvm.var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

    s = tvm.create_schedule(A2.op)
    s[A1].set_scope("warp")
    xo, xi = s[A2].split(A2.op.axis[0], 32)
    xi0, xi1 = s[A2].split(xi, factor=16)
    tx = tvm.thread_axis("threadIdx.x")
    s[A2].bind(xi1, tx)
    s[A2].bind(xi0, tvm.thread_axis("threadIdx.y"))
    y = s[A2].op.axis[1]
    s[A1].compute_at(s[A2], y)
    xo, xi = s[A1].split(s[A1].op.axis[0], factor=16)
    s[A1].bind(xi, tx)

    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[A1.op.axis[0]].extent.value == 16


def test_bound_scan():
    m = tvm.var("m")
    n = tvm.var("n")
    X = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
    s_scan = tvm.scan(s_init, s_update, s_state)

    assert tuple(s_scan.shape) == (m, n)
    s = tvm.create_schedule(s_scan.op)
    XX = s.cache_read(X, "local", s_update)
    xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
    s[XX].compute_at(s[s_update], xo)

    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    assert bounds[XX.op.axis[1]].extent.value == 4


def test_bound_conv1d():
    n = tvm.var('n')
    A = tvm.compute((n + 2,), lambda i: 1, name='A')

    def computeB(ii):
        i = ii + 1
        return A[i - 1] + A[i] + A[i + 1]

    B = tvm.compute((n,), computeB, name='B')
    s = tvm.create_schedule(B.op)
    s[A].compute_at(s[B], B.op.axis[0])
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    assert bounds[A.op.axis[0]].extent.value == 3


def test_bound_blur():
    n = tvm.convert(12)
    A = tvm.compute((n, n), lambda i, j: 1, name='A')

    def computeB(ii, jj):
        # set the correct center
        i = ii + 1
        j = jj + 1
        return A[i][j] + A[i-1][j] + A[i+1][j] + A[i][j+1] + A[i][j-1]

    B = tvm.compute((n-2, n-2), computeB, name='B')
    s = tvm.create_schedule(B.op)
    s[A].compute_at(s[B], B.op.axis[1])
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    assert bounds[A.op.axis[0]].extent.value == 3
    assert bounds[A.op.axis[1]].extent.value == 3


def test_bound_rfactor():
    n = tvm.var('n')
    A = tvm.placeholder((n,), name='A')
    k = tvm.reduce_axis((0, n))
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i > 1)), name='B')
    # schedule
    s = tvm.create_schedule(B.op)
    kf, ki = s[B].split(k, nparts=4)
    BF = s.rfactor(B, kf)
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)

    assert bounds[BF.op.axis[0]].extent.value == 4
    assert bounds[BF.op.axis[1]].extent.value == 1


def test_bound_group_schedule():
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
    s = tvm.create_schedule(x2.op)
    g = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    g.compute_at(s[x2], x2.op.axis[0])
    assert s[x1].group == g
    assert s[x].group == g
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    assert bounds[x.op.axis[0]].extent.value == 1
    assert bounds[x.op.axis[1]].extent == n


def test_bound_nest_group():
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    x1 = tvm.compute(x.shape, lambda *i: x(*i) + 1, name="x1")
    x2 = tvm.compute(x.shape, lambda *i: x1(*i) + 2, name="x2")
    s = tvm.create_schedule(x2.op)
    g1 = s.create_group(outputs=x, inputs=x, include_inputs=True)
    g2 = s.create_group(outputs=x1, inputs=x, include_inputs=True)
    assert s[x].group == g1
    assert s[x1].group == g2
    g2.compute_at(s[x2], x2.op.axis[0])
    g1.compute_at(s[x1], s[x1].op.axis[1])
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    assert bounds[x.op.axis[0]].extent.value == 1
    assert bounds[x.op.axis[1]].extent.value == 1
    assert bounds[x1.op.axis[0]].extent.value == 1
    assert bounds[x1.op.axis[1]].extent == n


def test_bound_nest_thread():
    m = tvm.var('m')
    A = tvm.placeholder((m,), name='A')
    A1 = tvm.compute((m,), lambda i: A[i], name='A1')
    A2 = tvm.compute((m,), lambda i: A1[i] + 2, name='A2')
    A3 = tvm.compute((m,), lambda i: A2[i] + 3, name='A3')

    s = tvm.create_schedule(A3.op)
    s[A2].set_scope("shared")
    s[A1].set_scope("local")

    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis("threadIdx.x")
    bx, tx = s[A3].split(A3.op.axis[0], factor=32)
    s[A3].bind(bx, block_x)
    s[A3].bind(tx, thread_x)
    s[A2].compute_at(s[A3], tx)
    _, xi = s[A2].split(A2.op.axis[0], nparts=1)
    s[A2].bind(xi, thread_x)
    s[A1].compute_at(s[A3], tx)
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    assert bounds[A1.op.axis[0]].extent.value == 1
    assert bounds[A2.op.axis[0]].extent.value == 32
    assert bounds[A3.op.axis[0]].extent == m


def test_gemm_bound():
    nn = 1024
    n = tvm.convert(nn)
    A = tvm.placeholder((n, n), name='A')
    B = tvm.placeholder((n, n), name='B')
    k = tvm.reduce_axis((0, n), name='k')
    C = tvm.compute(
        (n, n),
        lambda ii, jj: tvm.sum(A[ii, k] * B[jj, k], axis=k),
        name='CC')
    # schedule
    s = tvm.create_schedule(C.op)
    xtile, ytile = 32, 32
    scale = 8
    num_thread = 8
    block_factor = scale * num_thread
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis("threadIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    thread_y = tvm.thread_axis("threadIdx.y")

    CC = s.cache_write(C, "local")
    AA = s.cache_read(A, "shared", [CC])
    BB = s.cache_read(B, "shared", [CC])
    by, yi = s[C].split(C.op.axis[0], factor=block_factor)
    bx, xi = s[C].split(C.op.axis[1], factor=block_factor)
    s[C].reorder(by, bx, yi, xi)
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    ty, yi = s[C].split(yi, nparts=num_thread)
    tx, xi = s[C].split(xi, nparts=num_thread)
    s[C].reorder(ty, tx, yi, xi)
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    yo, xo = CC.op.axis
    s[CC].reorder(k, yo, xo)

    s[CC].compute_at(s[C], tx)
    s[AA].compute_at(s[CC], k)
    s[BB].compute_at(s[CC], k)

    ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread)
    tx, xi = s[AA].split(xi, nparts=num_thread)
    s[AA].bind(ty, thread_y)
    s[AA].bind(tx, thread_x)
    ty, xi = s[BB].split(s[BB].op.axis[0], nparts=num_thread)
    tx, xi = s[BB].split(xi, nparts=num_thread)
    s[BB].bind(ty, thread_y)
    s[BB].bind(tx, thread_x)

    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    assert bounds[BB.op.axis[0]].extent.value == 64
    assert bounds[AA.op.axis[0]].extent.value == 64
    assert bounds[CC.op.axis[0]].extent.value == 8
    assert bounds[CC.op.axis[1]].extent.value == 8


def test_bound_tensor_compute_op():
    def intrin_test():
        m1 = tvm.var("m1")
        n1 = tvm.var("n1")
        a = tvm.placeholder((m1, n1), name='a')
        c = tvm.compute((1, n1), lambda i, j: a[0, j] + a[1, j] + a[2, j], name='c')

        Ab = tvm.decl_buffer(a.shape, name="Abuf", offset_factor=1)
        Cb = tvm.decl_buffer(c.shape, name="Cbuf", offset_factor=1)

        def intrin_func(ins, outs):
            aa = ins[0]
            cc = outs[0]
            def _body():
                ib = tvm.ir_builder.create()
                ib.emit(tvm.call_extern("int32", "test", cc.access_ptr("w"), aa.access_ptr("r")))
                return ib.get()
            return _body()

        with tvm.build_config(offset_factor=1):
            return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, c: Cb})

    test_func = intrin_test()
    A = tvm.placeholder((20, 20), name='A')
    B = tvm.compute(A.shape, lambda i, j: A[i, j], name='B')
    C = tvm.compute((10, 20), lambda i: test_func(B[i:10, 0:20]), name='C')
    s = tvm.create_schedule(C.op)
    bounds = tvm.schedule.InferBound(s)
    assert isinstance(bounds, tvm.container.Map)
    assert bounds[B.op.axis[0]].extent.value == 10


def test_bound_simplification_failure():
    # Check that the bounds are not expanded
    A = tvm.compute((2,), lambda j: j, "A")

    def _check(B, A=A):
        s = tvm.create_schedule(B.op)
        s = s.normalize()
        bounds = tvm.schedule.InferBound(s)
        stmt = tvm.lower(s, [B, A], simple_mode=True)
        if not bounds[A.op.axis[0]].extent.value <= 2:
            print(stmt)
            assert bounds[A.op.axis[0]].extent.value <= 2

    # These are hard to simplify, moreover we don't simplify them
    _check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)]))
    _check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)]))
    _check(tvm.compute((10,), lambda i: A[-2*(i/2) - tvm.min(i, 0-i)]))
    _check(tvm.compute((10,), lambda i: A[i + (0 - i)]))
    # This would cause out of bounds, but we nevertheless include it
    _check(tvm.compute((10,), lambda i: A[i]))


if __name__ == "__main__":
    test_bound_nest_thread()
    test_bound1()
    test_bound_nest_group()
    test_bound_group_schedule()
    test_bound_scan()
    test_bound3()
    test_bound_rfactor()
    test_bound_blur()
    test_bound_conv1d()
    test_bound2()
    test_gemm_bound()
    test_bound_warp()
    test_bound_tensor_compute_op()
    test_bound_simplification_failure()
    test_bound_fusesplit1()
    test_bound_fusesplit2()