test_arith_detect_clip_bound.py 727 Bytes
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
import tvm

def test_basic():
    """Exercise ``tvm.arith.DetectClipBound`` on small conjunctions of bounds.

    Judging by the indices asserted below, the returned array appears to hold
    two entries per requested variable (lower bound at ``2*i``, upper bound at
    ``2*i + 1``). It appears to be empty when the condition cannot be split
    into per-variable bounds. NOTE(review): layout inferred from these
    assertions only, so confirm against the DetectClipBound implementation.
    """
    a = tvm.var("a")
    b = tvm.var("b")
    c = tvm.var("c")

    def is_zero_after_simplify(expr):
        # The difference between the detected bound and the expected
        # expression should fold down to the constant 0.
        return tvm.ir_pass.Simplify(expr).value == 0

    # Case 1: only ``a`` is requested.
    # a < 6b yields 6b - 1 at index 1; a - 1 > 0 yields 2 at index 0.
    cond = tvm.all(a * 1 < b * 6, a - 1 > 0)
    bounds = tvm.arith.DetectClipBound(cond, [a])
    assert is_zero_after_simplify(bounds[1] - (b * 6 - 1))
    assert bounds[0].value == 2

    # Case 2: the same condition, but ``a`` and ``b`` are both requested.
    # The first inequality couples them, and an empty result is expected.
    cond = tvm.all(a * 1 < b * 6, a - 1 > 0)
    bounds = tvm.arith.DetectClipBound(cond, [a, b])
    assert len(bounds) == 0

    # Case 3: ``c`` is not requested, so it may appear inside a bound.
    # Index 1 (for ``a``) is 20 - 10c; index 2 (for ``b``) is 2.
    cond = tvm.all(a + 10 * c <= 20, b - 1 > 0)
    bounds = tvm.arith.DetectClipBound(cond, [a, b])
    assert is_zero_after_simplify(bounds[1] - (20 - 10 * c))
    assert is_zero_after_simplify(bounds[2] - 2)


if __name__ == "__main__":
    # Allow running this file directly as a script, without a test runner.
    test_basic()