# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import tvm def csimplify(z): return tvm.ir_pass.CanonicalSimplify( tvm.make.Evaluate(z)).value def test_simplify(): x = tvm.var('n') z = x * 4 - x * 2 zz = csimplify(z) assert zz.b.value == 2 z = (x / 4) * 2 - (x / 4) zz = csimplify(z) assert zz.a == x and zz.b.value == 4 z = (x % 4) * 3 + (x % 4) zz = csimplify(z) assert zz.b.value == 4 zz = zz.a assert zz.a == x and zz.b.value == 4 n = tvm.var('n') assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0, "int32")) assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / 1), n) tvm.ir_pass.CanonicalSimplify(n / (-1)) # This is not true in the current implementation # assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)), # tvm.ir_pass.CanonicalSimplify(-n)) def test_simplify_mod(): ib = tvm.ir_builder.create() n = tvm.var('n') A = ib.pointer("float32", name="A") with ib.for_range(0, 10, name="j") as j: with ib.for_range(0, 16, name="i") as i: A[i] = A[(j * 32 + i+1) % 16] body = ib.get() stmt = tvm.ir_pass.CanonicalSimplify(body) diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16) assert diff.value == 0 # if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16 index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16) assert index != j index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)}) assert index == j # if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16 index = tvm.ir_pass.CanonicalSimplify( (j + n * 32) % 16, {j: tvm.Range(0, 6)}) assert index != j index = tvm.ir_pass.CanonicalSimplify( (j + n * 32) % 16, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)}) assert index == j def test_simplify_minmax(): x = tvm.var('x') e1 = tvm.max(x, 1) - tvm.max(x, 1) e1s = tvm.ir_pass.CanonicalSimplify(e1) assert e1s.value == 0 e2 = tvm.min(x, 1) - tvm.min(x, 1) e2s = tvm.ir_pass.CanonicalSimplify(e2) assert e2s.value == 0 def test_mul(): x = tvm.var('x') e = x * x - x * x es = tvm.ir_pass.CanonicalSimplify(e) assert es.value == 0 def test_modular(): rx = tvm.var("rx") ry = tvm.var("ry") y = tvm.var("y") x = tvm.var("x") i32_const = lambda x: tvm.const(x, "int32") vmap = {rx: tvm.Range(i32_const(0), i32_const(3)), ry: tvm.Range(i32_const(0), i32_const(3)), y: tvm.Range(i32_const(0), i32_const(2)), x: tvm.Range(i32_const(0), i32_const(14))} idx = ry * 16 + rx + y * 16 + x z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap) z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap) assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0 assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0 def test_const_propagation(): x1 = tvm.const(4, "int32") x2 = x1 + 5 assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9 x3 = x2 / 3 assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3 x4 = x3 + 0.5 assert isinstance(x4, tvm.expr.FloatImm) and x4.value == 3.5 x5 = tvm.ceil(x4) assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4 x6 = x5.astype('int') assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4 y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int') assert isinstance(y, tvm.expr.IntImm) and y.value == 6 if __name__ == "__main__": test_modular() test_simplify() test_mul() test_simplify_minmax() test_const_propagation() test_simplify_mod()