# test_arith_detect_linear_equation.py: unit tests for tvm.arith.DetectLinearEquation.
import tvm

def test_basic():
    """Check single-variable detection with DetectLinearEquation.

    As the assertions below encode, detecting over ``[a]`` yields
    ``[coeff_of_a, remainder]`` when the expression is linear in ``a``.
    It yields an empty result when the expression is not linear in ``a``.
    """
    a = tvm.var("a")
    b = tvm.var("b")

    # Plain linear form: coefficient 4, everything else goes to the remainder.
    m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, [a])
    assert m[0].value == 4
    assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0

    # a * 4 * (a + 1) contains a**2, so detection must fail (empty result).
    m = tvm.arith.DetectLinearEquation(a * 4 * (a + 1) + b * 6 + 7, [a])
    assert len(m) == 0

    # Terms in `a` spread over several summands are merged: 4a + a -> 5a,
    # and the stray "+ 1" joins the remainder.
    m = tvm.arith.DetectLinearEquation(a * 4 + (a + 1) + b * 6 + 7, [a])
    assert m[0].value == 5
    assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0

    # The coefficient may be symbolic. Compare structurally rather than
    # relying on `==` for expression nodes.
    m = tvm.arith.DetectLinearEquation(a * b + 7, [a])
    assert tvm.ir_pass.Equal(m[0], b)

    # `a` does not occur at all, so its coefficient is the constant 0.
    m = tvm.arith.DetectLinearEquation(b * 7, [a])
    assert m[0].value == 0

def test_multivariate():
    """Check DetectLinearEquation over several variables at once.

    The result holds one coefficient per variable in the given order,
    followed by a trailing constant/remainder term. It is empty when the
    expression is not jointly linear in those variables.
    """
    v = [tvm.var("v%d" % i) for i in range(4)]
    b = tvm.var("b")

    # v0 * (b + 4) + v0 collapses to a symbolic coefficient b + 5 for v0.
    m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
    assert tvm.ir_pass.Equal(tvm.ir_pass.Simplify(m[0]), b + 5)
    assert m[1].value == 8

    # v1 * v2 is a cross term between two detected variables, so it is not linear.
    m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
    assert len(m) == 0

    # v1 * v1 is quadratic in a detected variable, so it is not linear.
    m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v)
    assert len(m) == 0

    # Nested scaling: ((v0*b + v1) * 8 + v2 + 1) * 2 distributes the outer
    # factors, giving v1 -> 16, v2 -> 2 and constant term -> 2.
    m = tvm.arith.DetectLinearEquation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v)
    assert m[1].value == 16
    assert m[2].value == 2
    # Explicit index rather than m[-1]: legacy TVM Arrays may not accept
    # negative indices.
    assert m[len(m) - 1].value == 2

    # Detecting over a variable that is absent puts the whole expression
    # into the remainder with a zero coefficient.
    m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [v[2]])
    assert m[0].value == 0
    assert tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0

if __name__ == "__main__":
    # Allow running the tests directly without a test runner.
    test_basic()
    test_multivariate()