# test_arith_intset.py: tests for tvm.arith integer sets and DeduceBound.
import tvm

def test_basic():
    """intset_interval(lo, hi) reports lo via min() and hi via max()."""
    s = tvm.arith.intset_interval(2, 3)
    assert s.min().value == 2
    assert s.max().value == 3

def test_vector():
    """intset_vector of a Ramp spans [base, base + stride * lanes - 1].

    With base=10, stride=3, lanes=2 the asserted upper bound is 15, not the
    last lane's value 13 (base + stride * (lanes - 1)).
    """
    base = 10
    stride = 3
    lanes = 2
    s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
    assert s.min().value == base
    assert s.max().value == base + stride * lanes - 1

def test_deduce():
    """Check the bound expressions DeduceBound derives for ``a``.

    The other variables get the ranges b in [2, 3], c in [10, 15] and
    d in [-3, -1]. Each deduced endpoint is passed through
    ``tvm.ir_pass.Simplify`` and compared by string against the expected
    expression.
    """
    a = tvm.var('a')
    b = tvm.var('b')
    c = tvm.var('c')
    d = tvm.var('d')

    b_s = tvm.arith.intset_interval(2, 3)
    c_s = tvm.arith.intset_interval(10, 15)
    d_s = tvm.arith.intset_interval(-3, -1)

    # Coefficient of a is -b; the condition yields an upper bound on a.
    e0 = (-b)*a+c-d
    res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
    ans0 = ((d - c) /(b*-1))
    assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)

    # Coefficient of a is d, which is negative over its whole range [-3, -1].
    e0 = d*a+c-d
    res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
    ans0 = ((0-c)/d + 1)
    assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)

    # Strict inequality with a positive constant coefficient on a.
    e1 = (a*4+b < c)
    res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
    ans1 = (((c - b) + -1)/4)
    assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)

    # max(5, a*4) < 0 has no solution: the result has max neg_inf and
    # min pos_inf.
    e2 = (tvm.max(5, a * 4) < 0)
    res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
    assert str(res2.max()) == "neg_inf"
    assert str(res2.min()) == "pos_inf"

    # b and d are also passed in the fourth argument; the expected lower
    # bound of a (2/c + 1) involves only c.
    e3 = (-b)+a*c-d
    res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
    ans3 = 2/c+1
    assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)

def test_check():
    """DeduceBound yields an empty set (is_nothing) for unsupported conditions."""
    a = tvm.var('a')
    b = tvm.var('b')
    c = tvm.var('c')

    b_s = tvm.arith.intset_interval(2, 3)
    c_s = tvm.arith.intset_interval(5, 7)

    # No comparison operator: a + b is an arithmetic expression, not a condition.
    res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
    assert res1.is_nothing()

    # Nested comparison operators: (a + b > 3) > c.
    res2 = tvm.arith.DeduceBound(a, (a+b>3)>c, {b: b_s, c: c_s}, {})
    assert res2.is_nothing()

    # The target variable a occurs more than once in the condition.
    res3 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
    assert res3.is_nothing()

def test_deduce_basic():
    """Round-trip DeduceBound on ``b + a*coff + 3 <op> 17`` for <, >, <=, >=.

    ``b`` ranges over [lo, hi] and its interval is passed as both the third
    and fourth argument of DeduceBound. Either the max() endpoints or the
    min() endpoints of the deduced range of ``a`` and of ``b``'s interval are
    substituted back into the comparison, which must simplify to 1.
    """
    # Each entry: (comparison builder, sign). The max() endpoints are used
    # when coff * sign > 0, otherwise the min() endpoints.
    # NOTE(review): for coff < 0 the paired b endpoint is b.min() under
    # < / <= and b.max() under > / >=, which is not the worst case over all
    # b in [lo, hi] (that would be b.max() and b.min() respectively).
    # Confirm DeduceBound's behaviour before tightening these checks.
    comparisons = (
        (lambda lhs, rhs: lhs < rhs, 1),
        (lambda lhs, rhs: lhs > rhs, -1),
        (lambda lhs, rhs: lhs <= rhs, 1),
        (lambda lhs, rhs: lhs >= rhs, -1),
    )

    def check(lo, hi, coff):
        a = tvm.var('a')
        b = tvm.var('b')
        b_range = tvm.arith.intset_interval(lo, hi)
        expr = b + a*coff + 3
        for compare, sign in comparisons:
            deduced = tvm.arith.DeduceBound(
                a, compare(expr, 17), {b: b_range}, {b: b_range})
            if coff * sign > 0:
                a_end, b_end = deduced.max(), b_range.max()
            else:
                a_end, b_end = deduced.min(), b_range.min()
            substituted = compare(a_end * coff + 3 + b_end, 17)
            assert tvm.ir_pass.Simplify(substituted).value == 1

    for coff in (4, -4):
        for lo, hi in ((0, 4), (1, 5), (2, 6)):
            check(lo, hi, coff)

def test_deduce_complex():
    """Round-trip DeduceBound on ``(b*3 + a*coff) * 4 <op> 63`` for <, <=, >, >=.

    ``b`` ranges over [a1, a2] and its interval is passed as both the third
    and fourth argument of DeduceBound. For each operator, either the max()
    endpoints or the min() endpoints of the deduced range of ``a`` and of
    ``b``'s interval are substituted back, and the comparison must simplify
    to 1. Intervals [0, 4], [1, 5] and [2, 6] are each tried with coff = 4
    and coff = -4.
    """
    def test_complex(a1, a2, coff):
        a = tvm.var('a')
        b = tvm.var('b')
        b_s = tvm.arith.intset_interval(a1, a2)
        e0 = (b*3 + a*coff) * 4

        # NOTE(review): for coff < 0 the paired b endpoint (b_s.min() under
        # < / <=, b_s.max() under > / >=) is not the worst case over all b
        # in [a1, a2]; confirm DeduceBound's behaviour before tightening.
        res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
        [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
        assert (tvm.ir_pass.Simplify(((x*3 + t*coff) * 4) < 63)).value == 1

        res1 = tvm.arith.DeduceBound(a, e0<=63, {b: b_s}, {b: b_s})
        [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
        assert (tvm.ir_pass.Simplify(((x*3 + t*coff) * 4) <= 63)).value == 1

        res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
        [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
        assert (tvm.ir_pass.Simplify(((x*3 + t*coff) * 4) > 63)).value == 1

        res1 = tvm.arith.DeduceBound(a, e0>=63, {b: b_s}, {b: b_s})
        [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
        assert (tvm.ir_pass.Simplify(((x*3 + t*coff) * 4) >= 63)).value == 1

    # Every interval is exercised once with a positive and once with a
    # negative coefficient.
    test_complex(0, 4, 4)
    test_complex(1, 5, 4)
    test_complex(2, 6, 4)
    test_complex(0, 4, -4)
    test_complex(1, 5, -4)
    test_complex(2, 6, -4)

if __name__ == "__main__":
    # Run every test directly when the module is executed as a script.
    test_basic()
    test_vector()
    test_deduce()
    test_check()
    test_deduce_basic()
    test_deduce_complex()