test_arith_intset.py 4.03 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
import tvm
18
from tvm import te
19

20 21 22 23 24 25 26 27 28 29 30

class IntSetChecker:
    """Helper that evaluates integer sets with a single ``tvm.arith.Analyzer``.

    The analyzer is stored on the instance so callers can register extra
    variable bounds on it (``self.analyzer.update(...)``) before ``verify``.
    """

    def __init__(self):
        self.analyzer = tvm.arith.Analyzer()

    def verify(self, data, dmap, expected):
        """Assert that the int set of ``data`` under ``dmap`` equals ``expected``.

        Parameters
        ----------
        data : PrimExpr
            Expression whose integer set is evaluated.
        dmap : dict
            Mapping from variables to the set each one ranges over
            (e.g. ``tvm.arith.IntervalSet``).
        expected : tuple
            ``(min, max)`` pair; each entry may be an int or an expression.
        """
        res = self.analyzer.int_set(data, dmap)

        def err_msg():
            return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)

        def equal(x, y):
            # Compare symbolically: two bounds are equal when their difference
            # canonicalizes to zero. Use a distinct local name so the enclosing
            # ``res`` (reported by err_msg) is not shadowed.
            diff = self.analyzer.canonical_simplify(x - y)
            return tvm.tir.ir_pass.Equal(diff, 0)

        assert equal(res.min_value, expected[0]), err_msg()
        assert equal(res.max_value, expected[1]), err_msg()

35
def test_basic():
    """Check the bounds reported by the basic IntSet constructors."""
    # An IntervalSet reports exactly the bounds it was constructed with.
    s = tvm.arith.IntervalSet(2, 3)
    assert s.min_value.value == 2
    assert s.max_value.value == 3

    # A single point has identical lower and upper bounds.
    s = tvm.arith.IntSet.single_point(2)
    assert s.min_value.value == 2
    assert s.max_value.value == 2

44

45 46 47 48
def test_vector():
    """Check the integer set covered by a ``Ramp`` vector expression."""
    base = 10
    stride = 3
    lanes = 2
    s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes))
    assert s.min_value.value == base
    # NOTE(review): the expected upper bound is base + stride * lanes - 1
    # (15 here), not the last ramp element base + stride * (lanes - 1) (13);
    # presumably IntSet.vector returns a conservative span that covers the
    # full stride of every lane -- confirm against the TVM implementation.
    assert s.max_value.value == base + stride * lanes - 1


def test_add_sub():
    """Check int-set propagation through addition and subtraction."""
    ck = IntSetChecker()
    x, y = te.var("x"), te.var("y")

    # Only x is bounded: y stays symbolic in the result.
    ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))

    # Both bounded: [0, 10] + [1, 11] = [1, 21].
    ck.verify(x + y,
              {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)},
              (1, 21))

    # Both bounded: [0, 10] - [1, 11] = [0 - 11, 10 - 1] = [-11, 9].
    ck.verify(x - y,
              {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)},
              (-11, 9))

def test_mul_div():
    """Check int-set propagation through multiplication and division."""
    ck = IntSetChecker()
    x, y = te.var("x"), te.var("y")
    tdiv = tvm.tir.truncdiv
    fld = tvm.te.floordiv

    # Give y a strictly positive constant bound [1, 100] so the bounds below
    # can be stated symbolically in terms of y.
    ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)

    # Multiplication; a negative constant factor swaps the interval ends.
    ck.verify(x * y, {x: tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
    ck.verify(x * 2, {x: tvm.arith.IntervalSet(1, 10)}, (2, 20))
    ck.verify(x * -2, {x: tvm.arith.IntervalSet(1, 10)}, (-20, -2))

    # Truncated division rounds toward zero, so truncdiv(1, 2) == 0.
    ck.verify(tdiv(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
    ck.verify(tdiv(x, 2), {x: tvm.arith.IntervalSet(1, 10)}, (0, 5))

    # Floor division rounds toward negative infinity, so floordiv(-1, 2) == -1.
    ck.verify(fld(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
    ck.verify(fld(x, 2), {x: tvm.arith.IntervalSet(-1, 10)}, (-1, 5))

82 83 84

def test_mod():
    """Check int-set propagation through truncated and floored modulo."""
    ck = IntSetChecker()
    x, y = te.var("x"), te.var("y")
    tmod = tvm.tir.truncmod
    flm = tvm.te.floormod

    # With y bounded to [1, 100] and x non-negative, tmod(x, y) lies in [0, y - 1].
    ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
    ck.verify(tmod(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
    ck.verify(tmod(x, 10), {x: tvm.arith.IntervalSet(1, 10)}, (0, 9))

    # Floor modulo by a positive constant stays in [0, 9] even for negative x.
    ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(-10, 10)}, (0, 9))


95 96
def test_max_min():
    """Check int-set propagation through max and min."""
    ck = IntSetChecker()
    x, y = te.var("x"), te.var("y")

    # Bounded x: the result follows the dominating operand.
    ck.verify(tvm.te.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 11))
    ck.verify(tvm.te.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 9))

    # Empty dmap: x and y are unconstrained, so both bounds are the expression itself.
    ck.verify(tvm.te.min(x, y), {}, (tvm.te.min(x, y), tvm.te.min(x, y)))
    ck.verify(tvm.te.max(x, y), {}, (tvm.te.max(x, y), tvm.te.max(x, y)))
102 103 104 105


def test_select():
    """Check that a Select covers the union of both branch sets."""
    ck = IntSetChecker()
    x, y = te.var("x"), te.var("y")
    # True branch x - 1 spans [-1, 9], false branch x + 1 spans [1, 11];
    # their union is [-1, 11].
    ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1),
              {x: tvm.arith.IntervalSet(0, 10)}, (-1, 11))

110 111 112

if __name__ == "__main__":
    # Allow running the tests directly without pytest.
    test_basic()
    test_vector()
    test_add_sub()
    test_mul_div()
    test_max_min()
    test_select()
    test_mod()