# test_arith_intset.py: tests for tvm.arith integer sets (intset_interval,
# intset_vector) and bound deduction (DeduceBound).
import tvm

def test_basic():
    """Check that an interval int-set reports the bounds it was built from."""
    lower, upper = 2, 3
    interval = tvm.arith.intset_interval(lower, upper)
    assert interval.min().value == lower
    assert interval.max().value == upper

def test_vector():
    """Check the int-set derived from a Ramp vector expression.

    Assuming Ramp(base, stride, lanes) denotes base + i * stride for
    i in [0, lanes), the last lane here is 13. The maximum asserted below
    is base + stride * lanes - 1, which is 15 for these constants, so the
    test pins a bound wider than the last lane value.
    """
    base = 10
    stride = 3
    lanes = 2
    s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
    assert s.min().value == base
    # NOTE(review): presumably intset_vector yields a conservative covering
    # interval rather than the exact last lane -- confirm against tvm.arith.
    assert s.max().value == base + stride * lanes - 1

def test_deduce():
    """Check DeduceBound on several conditions over the target variable ``a``.

    Each case deduces a bound on ``a`` from a condition, given interval
    hints for ``b``, ``c`` and ``d``. It then compares the string form of
    the (simplified) bound with an expected expression.
    """
    a = tvm.var('a')
    b = tvm.var('b')
    c = tvm.var('c')
    d = tvm.var('d')

    b_s = tvm.arith.intset_interval(2, 3)
    c_s = tvm.arith.intset_interval(10, 15)
    d_s = tvm.arith.intset_interval(-3, -1)

    # (-b)*a + c - d >= 0: the coefficient of ``a`` is -b with b in [2, 3],
    # i.e. negative, and the case checks the upper bound res0.max().
    e0 = (-b)*a+c-d
    res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
    ans0 = (d-c)/(-b)+(-1)
    assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)

    # a*4 + b < c: checks the upper bound res1.max().
    e1 = (a*4+b < c)
    res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
    ans1 = (c-b)/4+(-2)
    assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)

    # max(5, a*4) < 0 can never hold (max(5, x) >= 5), and the expected
    # result has max == neg_inf and min == pos_inf.
    e2 = (tvm.max(5, a * 4) < 0)
    res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
    assert str(res2.max()) == "neg_inf"
    assert str(res2.min()) == "pos_inf"

    # (-b) + a*c - d >= 0: checks the lower bound res3.min(). Unlike the
    # cases above, b and d are also passed in the fourth argument
    # (presumably the set of variables to relax -- TODO confirm).
    e3 = (-b)+a*c-d
    res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
    ans3 = 2/c+1
    assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)

def test_check():
    """Check that DeduceBound gives an is_nothing() result for bad conditions.

    Covers three conditions:
    - an expression with no comparison;
    - a chain of comparisons;
    - a condition where the target variable occurs more than once.
    """
    a = tvm.var('a')
    b = tvm.var('b')
    c = tvm.var('c')

    b_s = tvm.arith.intset_interval(2, 3)
    c_s = tvm.arith.intset_interval(5, 7)

    # no compare operator
    res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
    assert res1.is_nothing()

    # multiple compare operators
    res2 = tvm.arith.DeduceBound(a, (a+b>3)>c, {b: b_s, c: c_s}, {})
    # Assert on this case's own result; the old code re-checked res1 here.
    assert res2.is_nothing()

    # target variable occurs more than once
    res3 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
    assert res3.is_nothing()

if __name__ == "__main__":
    # Run every test in this module directly, without a test runner.
    test_basic()
    test_vector()
    test_deduce()
    test_check()