# test_lang_operator.py: constant-folding and simplification tests for TVM operator expressions.
import tvm

def test_const_fold():
    """Check that binary operators on TVM constants fold to integer immediates.

    Every case evaluates one lambda twice: once on ``tvm.const`` operands and
    once on the plain Python ints. The TVM result must be an ``IntImm`` or
    ``UIntImm`` whose ``value`` equals ``int()`` of the Python result;
    otherwise a ``ValueError`` is raised.
    """
    def check(f, *args):
        folded = f(*[tvm.const(arg) for arg in args])
        expected = f(*args)
        is_immediate = isinstance(folded, (tvm.expr.IntImm, tvm.expr.UIntImm))
        if is_immediate and folded.value == int(expected):
            return
        raise ValueError("check error: %s vs %s " % (folded, expected))

    cases = [
        (lambda x, y: x + y, 3, 4),
        (lambda x, y: x * y, 3, 12),
        (lambda x, y: x * y - 10, 3, 12),
        (lambda x, y: x - y % 10, 3, 12),
        (lambda x, y: x // y + 10, 100, 12),
        # `+` binds tighter than `&`, so this is x & (y + 10) on both sides.
        (lambda x, y: x & y + 10, 112, 128),
        # Comparisons: the Python side yields a bool, compared through int().
        (lambda x, y: x > y, 112, 128),
        (lambda x, y: x < y, 112, 128),
        (lambda x, y: x <= y, 112, 128),
        (lambda x, y: x >= y, 112, 128),
        (lambda x, y: (x | y) ^ 10, 112, 128),
    ]
    for fn, lhs, rhs in cases:
        check(fn, lhs, rhs)


def test_const_fold2():
    """Check that trivial arithmetic on a symbolic variable is simplified.

    Adding or subtracting 0 and multiplying by 1 must return the variable
    node itself. ``x % 1`` must fold to an immediate with value 0, and
    ``1 / x`` must remain a ``Div`` node.
    """
    var = tvm.var("x")

    # Additive identities: the very same node must come back.
    for build in (lambda: var + 0, lambda: 0 + var, lambda: var - 0):
        assert build().same_as(var)

    # Anything modulo 1 is zero, so this folds to a constant.
    assert (var % 1).value == 0

    # Multiplicative identities: again the same node.
    for build in (lambda: var * 1, lambda: 1 * var):
        assert build().same_as(var)

    # With a symbolic divisor there is nothing to fold; a Div node remains.
    assert isinstance(1 / var, tvm.expr.Div)

def test_const_fold3():
    """Check logical ``tvm.all`` / ``tvm.any`` on ``uint1`` operands.

    Covers three behaviors:
    * mixing a ``uint1`` constant with an integer variable raises
      ``tvm.TVMError``;
    * two ``uint1`` constants fold to the matching Python ``and``/``or``;
    * identity and absorbing constants simplify against a ``uint1`` variable.
    """
    def check_throws(f):
        """Call ``f`` and fail unless it raises ``tvm.TVMError``.

        Any other exception type propagates unchanged.
        """
        raised = False
        try:
            f()
        except tvm.TVMError:
            raised = True
        if not raised:
            raise AssertionError("Should have raised an exception but didn't.")

    # Using ints with logic operations is forbidden.
    int_var = tvm.var("x")
    for val in [0, 1]:
        for func in [tvm.all, tvm.any]:
            # Loop values are bound as defaults. The constant is still built
            # inside the thunk, so check_throws also catches errors raised there.
            check_throws(lambda op=func, v=val: op(tvm.const(v, 'uint1'), int_var))
            check_throws(lambda op=func, v=val: op(int_var, tvm.const(v, 'uint1')))

    # When both arguments are constants, the result must fold completely.
    folding_cases = [
        (tvm.all, lambda a, b: a and b),
        (tvm.any, lambda a, b: a or b),
    ]
    bits = [0, 1]
    for tvm_func, py_func in folding_cases:
        for lhs in bits:
            for rhs in bits:
                folded = tvm_func(tvm.const(lhs, 'uint1'), tvm.const(rhs, 'uint1'))
                expected = tvm.const(py_func(lhs, rhs), 'uint1')
                assert tvm.ir_pass.Equal(folded, expected)

    bool_var = tvm.var("x", 'uint1')
    true = tvm.const(1, 'uint1')
    false = tvm.const(0, 'uint1')

    # Identity element: all(x, true) and any(x, false) collapse to x.
    for op, neutral in ((tvm.all, true), (tvm.any, false)):
        assert op(bool_var, neutral).same_as(bool_var)
        assert op(neutral, bool_var).same_as(bool_var)

    # Absorbing element: all(x, false) is false, any(x, true) is true.
    for op, absorbing in ((tvm.all, false), (tvm.any, true)):
        assert op(bool_var, absorbing).same_as(absorbing)
        assert op(absorbing, bool_var).same_as(absorbing)

if __name__ == "__main__":
    # Run every test in this module when executed directly as a script.
    test_const_fold()
    test_const_fold2()
    test_const_fold3()