# test_arith_intset.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm

def test_basic():
    """Check the min and max of the integer interval [2, 3]."""
    s = tvm.arith.intset_interval(2, 3)
    assert s.min().value == 2
    assert s.max().value == 3


def test_vector():
    """Check the integer set computed for a Ramp vector expression.

    NOTE: the expected upper bound is base + stride * lanes - 1 (15 here),
    not the last lane's value base + stride * (lanes - 1) (13 here), i.e.
    the asserted bound is a conservative over-approximation of the lanes.
    """
    base = 10
    stride = 3
    lanes = 2
    s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
    assert s.min().value == base
    assert s.max().value == base + stride * lanes - 1


def test_deduce():
    """Check tvm.arith.DeduceBound on several inequalities in ``a``.

    Every case is asserted twice: once written with ``a`` on the left-hand
    side, and once with the comparison mirrored so that the expression
    containing ``a`` sits on the right-hand side. Both spellings must give
    the same bound.
    """
    a = tvm.var('a')
    b = tvm.var('b')
    c = tvm.var('c')
    d = tvm.var('d')

    b_s = tvm.arith.intset_interval(2, 3)
    c_s = tvm.arith.intset_interval(10, 15)
    d_s = tvm.arith.intset_interval(-3, -1)
    # Integer ranges of the non-target variables; passed as the third
    # DeduceBound argument in every query below.
    hint_map = {b: b_s, c: c_s, d: d_s}

    zero = tvm.const(0, "int32")

    def simplified(expr):
        # The deduced bound is simplified before comparing; the expected
        # answers are compared exactly as written.
        return str(tvm.ir_pass.Simplify(expr))

    # Case 0: (-b)*a + c - d >= 0, checked via the deduced max of a.
    e0 = (-b)*a+c-d
    res0 = tvm.arith.DeduceBound(a, e0>=0, hint_map, {})
    ans0 = ((d - c) /(b*-1))
    assert simplified(res0.max()) == str(ans0)

    # expression containing variable a is on rhs
    res0 = tvm.arith.DeduceBound(a, zero <= e0, hint_map, {})
    assert simplified(res0.max()) == str(ans0)

    # Case 0': d*a + c - d >= 0, where the coefficient d lies in [-3, -1].
    e0 = d*a+c-d
    res0 = tvm.arith.DeduceBound(a, e0>=0, hint_map, {})
    ans0 = ((0-c)/d + 1)
    assert simplified(res0.max()) == str(ans0)

    # expression containing variable a is on rhs
    res0 = tvm.arith.DeduceBound(a, zero <= e0, hint_map, {})
    assert simplified(res0.max()) == str(ans0)

    # Case 1: a*4 + b < c, checked via the deduced max of a.
    e1 = (a*4+b < c)
    res1 = tvm.arith.DeduceBound(a, e1, hint_map, {})
    ans1 = (((c - b) + -1)/4)
    assert simplified(res1.max()) == str(ans1)

    # expression containing variable a is on rhs
    e1 = (c > a*4+b)
    res1 = tvm.arith.DeduceBound(a, e1, hint_map, {})
    assert simplified(res1.max()) == str(ans1)

    # Case 2: conditions involving tvm.max; both spellings are expected to
    # yield the empty set (max is neg_inf, min is pos_inf).
    # NOTE(review): 0 < max(5, a*4) holds for every a, so the empty result
    # for the rhs form presumably means the deducer gives up on max()
    # rather than that no a satisfies it -- confirm against DeduceBound.
    e2 = (tvm.max(5, a * 4) < 0)
    res2 = tvm.arith.DeduceBound(a, e2, hint_map, {})
    assert str(res2.max()) == "neg_inf"
    assert str(res2.min()) == "pos_inf"

    # expression containing variable a is on rhs
    e2 = (zero < tvm.max(5, a * 4))
    res2 = tvm.arith.DeduceBound(a, e2, hint_map, {})
    assert str(res2.max()) == "neg_inf"
    assert str(res2.min()) == "pos_inf"

    # Case 3: -b + a*c - d >= 0, checked via the deduced min of a. Unlike the
    # cases above, b and d are also passed as the fourth DeduceBound argument
    # (presumably the relax map -- confirm against DeduceBound's signature).
    e3 = (-b)+a*c-d
    relax_map = {b: b_s, d: d_s}
    res3 = tvm.arith.DeduceBound(a, e3>=0, hint_map, relax_map)
    ans3 = 2/c+1
    assert simplified(res3.min()) == str(ans3)

    # expression containing variable a is on rhs
    res3 = tvm.arith.DeduceBound(a, zero <= e3, hint_map, relax_map)
    assert simplified(res3.min()) == str(ans3)


def test_check():
    """Check that DeduceBound yields an empty set (is_nothing) for
    conditions it does not handle."""
    a = tvm.var('a')
    b = tvm.var('b')
    c = tvm.var('c')

    b_s = tvm.arith.intset_interval(2, 3)
    c_s = tvm.arith.intset_interval(5, 7)

    # no compare operator
    res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
    assert res1.is_nothing()

    # multiple compare operators
    res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
    assert res2.is_nothing()

    # target variable appears more than once in the condition
    res3 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
    assert res3.is_nothing()


def test_deduce_basic():
    """Check bounds deduced for ``b + a*coff + 3`` compared against 17.

    For each comparison, the deduced extreme of ``a`` is substituted back
    together with an extreme of ``b``'s range, and the resulting constant
    condition must simplify to true (value 1).
    """
    def _check(a1, a2, coff):
        # a1, a2: endpoints of b's interval; coff: coefficient of a.
        a = tvm.var('a')
        b = tvm.var('b')
        b_s = tvm.arith.intset_interval(a1, a2)
        e0 = b + a*coff + 3

        # NOTE(review): b's coefficient is always +1, yet the extreme of b
        # paired with a below follows the sign of coff. For coff < 0 this
        # picks the non-worst-case end of b's range, so those checks are
        # weaker than the coff > 0 ones -- consider tightening.
        res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
        x, y = (res1.max(), b_s.max()) if coff > 0 else (res1.min(), b_s.min())
        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1

        # expression containing variable a is on rhs
        res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
        x, y = (res1.max(), b_s.max()) if coff < 0 else (res1.min(), b_s.min())
        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1

        # expression containing variable a is on rhs
        res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
        x, y = (res1.max(), b_s.max()) if coff > 0 else (res1.min(), b_s.min())
        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1

        res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
        x, y = (res1.max(), b_s.max()) if coff < 0 else (res1.min(), b_s.min())
        assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1

    _check(0, 4, 4)
    _check(1, 5, 4)
    _check(2, 6, 4)
    _check(0, 4, -4)
    _check(1, 5, -4)
    _check(2, 6, -4)


def test_deduce_complex():
    """Check bounds deduced for ``(b*3 + a*coff) * 4`` compared against 63.

    For each comparison, the deduced extreme of ``a`` is substituted back
    together with an extreme of ``b``'s range, and the resulting constant
    condition must simplify to true (value 1).
    """
    def _check(a1, a2, coff):
        # a1, a2: endpoints of b's interval; coff: coefficient of a.
        a = tvm.var('a')
        b = tvm.var('b')
        b_s = tvm.arith.intset_interval(a1, a2)
        e0 = (b*3 + a* coff) * 4

        # NOTE(review): b's coefficient is always positive, yet the extreme
        # of b paired with a below follows the sign of coff. For coff < 0
        # this picks the non-worst-case end of b's range, so those checks
        # are weaker than the coff > 0 ones -- consider tightening.
        res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
        t, x = (res1.max(), b_s.max()) if coff > 0 else (res1.min(), b_s.min())
        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1

        # expression containing variable a is on rhs
        res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
        t, x = (res1.max(), b_s.max()) if coff > 0 else (res1.min(), b_s.min())
        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1

        res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
        t, x = (res1.max(), b_s.max()) if coff < 0 else (res1.min(), b_s.min())
        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1

        # expression containing variable a is on rhs
        res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
        t, x = (res1.max(), b_s.max()) if coff < 0 else (res1.min(), b_s.min())
        assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1

    # Same parameter grid as test_deduce_basic. The original list ran
    # (0, 4, -4) twice and never exercised (1, 5, 4).
    _check(0, 4, 4)
    _check(1, 5, 4)
    _check(2, 6, 4)
    _check(0, 4, -4)
    _check(1, 5, -4)
    _check(2, 6, -4)


if __name__ == "__main__":
    # Run every test directly when the module is executed as a script.
    test_basic()
    test_vector()
    test_deduce()
    test_check()
    test_deduce_basic()
    test_deduce_complex()