# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm


def test_cast():
    """Check modular sets of expressions that pass through dtype casts."""
    analyzer = tvm.arith.Analyzer()
    x = tvm.var("x", dtype="int8")

    m = analyzer.modular_set((x * 3).astype("uint32"))
    assert m.coeff == 3
    assert m.base == 0

    # Round trip through float32 and back to an integer type.
    m = analyzer.modular_set(
        (x * 3 + 1).astype("float32").astype("int32"))
    assert m.coeff == 3
    assert m.base == 1


def test_add_sub():
    """Check modular sets of sums and differences, including a bound var."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x", "int64"), tvm.var("y", "int64")

    m = analyzer.modular_set(x * 6 + y * 4)
    assert m.coeff == 2
    assert m.base == 0

    # After binding y to x * 4 + 1, the expression 1 - y equals -4 * x.
    analyzer.bind(y, x * 4 + 1)
    m = analyzer.modular_set(1 - y)
    assert m.coeff == 4
    assert m.base == 0


def test_mul():
    """Check the modular set of a product of two affine expressions."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    m = analyzer.modular_set((x * 4 + 2) * (y * 6 + 1))
    assert m.coeff == 4
    assert m.base == 2


def test_div_shift():
    """Check truncdiv, right shift and floordiv, with and without a bound on x."""
    analyzer = tvm.arith.Analyzer()
    x = tvm.var("x")
    tdiv = tvm.truncdiv
    fld = tvm.floordiv

    # Not known whether x is non-negative: truncdiv yields no information.
    m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
    assert m.coeff == 1
    assert m.base == 0

    # Right shift always rounds down, so it is fine.
    m = analyzer.modular_set((x * 4 + 2) >> 1)
    assert m.coeff == 2
    assert m.base == 1

    m = analyzer.modular_set(fld(x * 4 + 2, 2))
    assert m.coeff == 2
    assert m.base == 1

    # x is now known to be non-negative, so truncdiv gives the same result.
    analyzer.update(x, tvm.arith.ConstIntBound(0, 100))
    m = analyzer.modular_set(tdiv(x * 4 + 2, 2))
    assert m.coeff == 2
    assert m.base == 1


def test_min_max_select():
    """Check modular sets of min, max and Select over two branches."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")

    m = analyzer.modular_set(tvm.min(x * 3, y * 9))
    assert m.coeff == 3
    assert m.base == 0

    m = analyzer.modular_set(tvm.max(x * 3 + 1, y * 9 + 4))
    assert m.coeff == 3
    assert m.base == 1

    m = analyzer.modular_set(tvm.expr.Select(x > 0, x * 3 + 1, y * 9 + 2))
    assert m.coeff == 1
    assert m.base == 0


def test_mix_index():
    """Check modular sets of index expressions mixing +, -, *, truncdiv, min."""
    a = tvm.var("a")
    b = tvm.var("b")
    analyzer = tvm.arith.Analyzer()
    tdiv = tvm.truncdiv

    m = analyzer.modular_set(a * 4 + b * 6 + 7)
    assert m.coeff == 2
    assert m.base == 1

    m = analyzer.modular_set((a * 4 + 1) * (b * 8 + 3))
    assert m.coeff == 4
    assert m.base == 3

    m = analyzer.modular_set(tdiv(a * 4 + 1, b * 8 + 3))
    assert m.coeff == 1
    assert m.base == 0

    m = analyzer.modular_set((a * 4 + 1) * tdiv(b * 8, 4))
    assert m.coeff == 2
    assert m.base == 0

    m = analyzer.modular_set((a * 12 + 1) - (b * 3 * 7 + 2))
    assert m.coeff == 3
    assert m.base == 2

    m = analyzer.modular_set(a * 12 + tvm.min(b * 3 * 7, 2))
    assert m.coeff == 1
    assert m.base == 0


def test_constraint_scope():
    """Check that constraints apply inside their scope and lapse on exit."""
    a = tvm.var("a")
    b = tvm.var("b")
    analyzer = tvm.arith.Analyzer()
    tmod = tvm.truncmod

    with analyzer.constraint_scope(tmod(b, 4) == 2):
        m = analyzer.modular_set(b + 1)
        assert m.coeff == 4
        assert m.base == 3

        with analyzer.constraint_scope(tmod(a, 2) == 1):
            m = analyzer.modular_set(b + a * 2)
            assert m.coeff == 4
            assert m.base == 0

        # Inner constraint on a has been left; only the one on b remains.
        m = analyzer.modular_set(b + a * 2)
        assert m.coeff == 2
        assert m.base == 0

    # Both constraints have been left.
    m = analyzer.modular_set(b + 1)
    assert m.coeff == 1
    assert m.base == 0


def test_intersect():
    """Check that nested modular constraints on one variable are combined."""
    a = tvm.var("a")
    analyzer = tvm.arith.Analyzer()
    tmod = tvm.truncmod

    with analyzer.constraint_scope(tmod(a, 4) == 1):
        with analyzer.constraint_scope(tmod(a, 3) == 1):
            m = analyzer.modular_set(a)
            assert m.coeff == 12
            assert m.base == 1

    with analyzer.constraint_scope(tmod(a, 3) == 2):
        with analyzer.constraint_scope(tmod(a, 5) == 3):
            with analyzer.constraint_scope(tmod(a, 7) == 2):
                m = analyzer.modular_set(a)
                assert m.coeff == 105
                assert m.base == 23


if __name__ == "__main__":
    test_cast()
    test_add_sub()
    test_mul()
    test_div_shift()
    test_min_max_select()
    test_mix_index()
    test_constraint_scope()
    test_intersect()