import tvm
def test_const():
    """tvm.const on a Python int yields an int32 IntImm node."""
    const_node = tvm.const(1)
    assert const_node.dtype == 'int32'
    assert isinstance(const_node, tvm.expr.IntImm)
def test_const_saveload_json():
    """An expression tree survives a save_json / load_json round trip."""
    one, ten = tvm.const(1), tvm.const(10)
    partial = one + ten
    expr = partial + partial
    serialized = tvm.save_json(expr)
    restored = tvm.load_json(serialized)
    # Re-serializing the restored tree must reproduce the original JSON text.
    assert tvm.save_json(restored) == tvm.save_json(expr)

def test_make():
    """Add a tvm.const to an IntImm built via tvm.make and print the sum."""
    lhs = tvm.const(1)
    rhs = tvm.make.IntImm('int32', 1)
    print(lhs + rhs)

def test_ir():
    """tvm.make.Evaluate wraps an expression in a tvm.stmt.Evaluate node."""
    total = tvm.const(1) + tvm.make.IntImm('int32', 1)
    evaluate_stmt = tvm.make.Evaluate(total)
    assert isinstance(evaluate_stmt, tvm.stmt.Evaluate)

def test_let():
    """tvm.make.LetStmt binds a Var to a value around a body statement.

    The original version built the statement but asserted nothing, and
    also created an unused Var; the test now checks the node type.
    """
    x = tvm.Var('x')
    stmt = tvm.make.LetStmt(
        x, 10, tvm.make.Evaluate(x + 1))
    assert isinstance(stmt, tvm.stmt.LetStmt)
def test_attr():
    """Check AttrStmt field access and attribute errors on converted nodes."""
    x = tvm.Var('x')
    y = tvm.Var('y')
    stmt = tvm.make.AttrStmt(
        y, "stride", 10, tvm.make.Evaluate(x + 1))
    assert stmt.node == y

    a = tvm.convert(1)
    assert a.value == 1
    # Accessing an unknown field must raise AttributeError. The failure is
    # raised explicitly from the ``else`` branch rather than via
    # ``assert False`` so the check is not stripped under ``python -O``.
    try:
        a.no_field
    except AttributeError:
        pass
    else:
        raise AssertionError("accessing a.no_field should raise AttributeError")

def test_basic():
    """Adding two Vars renders as a parenthesised infix expression."""
    var_a = tvm.Var('a')
    var_b = tvm.Var('b')
    total = var_a + var_b
    assert str(total) == '(%s + %s)' % (var_a.name, var_b.name)


def test_stmt():
    """tvm.make.For builds a serial loop statement around a body.

    The original version discarded the constructed statement and asserted
    nothing; the result is now bound and its node type checked.
    """
    x = tvm.make.Evaluate(0)
    stmt = tvm.make.For(tvm.Var('i'), 0, 1,
                        tvm.stmt.For.Serial, 0,
                        x)
    assert isinstance(stmt, tvm.stmt.For)
if __name__ == "__main__":
    # Run each test in a fixed order; the first failure aborts the run.
    for _test in (test_attr, test_const, test_const_saveload_json,
                  test_make, test_ir, test_basic, test_stmt, test_let):
        _test()