test_lang_basic.py 3.54 KB
Newer Older
tqchen committed
1
import tvm
tqchen committed
2

3 4
def test_const():
    """Check the dtype and node class of ``tvm.const(1)``.

    The constant's dtype is printed, then compared against ``tvm.int32``,
    and the node itself must be a ``tvm.expr.IntImm``.
    """
    one = tvm.const(1)
    dtype = one.dtype
    print(dtype)
    assert dtype == tvm.int32
    assert isinstance(one, tvm.expr.IntImm)
8

tqchen committed
9 10 11 12 13 14 15 16 17 18 19 20
def test_make():
    """Add a ``tvm.const`` to an ``IntImm`` built through ``tvm.make``.

    There are no assertions; the test passes when construction and the
    addition complete without raising.
    """
    lhs = tvm.const(1)
    rhs = tvm.make.IntImm('int32', 1)
    _ = lhs + rhs  # result intentionally unused

def test_ir():
    """Wrap the sum of two integer constants in an ``Evaluate`` statement.

    The resulting node must be an instance of ``tvm.stmt.Evaluate``.
    """
    total = tvm.const(1) + tvm.make.IntImm('int32', 1)
    evaluated = tvm.make.Evaluate(total)
    assert isinstance(evaluated, tvm.stmt.Evaluate)

21 22 23 24 25 26 27
def test_ir2():
    """Build a ``Store`` statement and check its type and ``buffer_var``.

    The buffer variable is created with ``tvm.handle`` as its second
    argument; the store is expected to report that variable back through
    ``buffer_var``.
    """
    n = tvm.var("n")
    buf = tvm.var("array", tvm.handle)
    store = tvm.make.Store(buf, n + 1, 1)
    assert isinstance(store, tvm.stmt.Store)
    assert store.buffer_var == buf

28
def test_let():
    """Build a ``LetStmt`` around an ``Evaluate`` body and check its type.

    The statement is constructed from a variable, the literal ``10`` and
    an ``Evaluate`` of ``x + 1``. Previously the test asserted nothing, so
    it only caught construction errors; it now also verifies that the
    result is a ``tvm.stmt.LetStmt`` node, matching the other
    statement-construction tests in this file.
    """
    x = tvm.var('x')
    stmt = tvm.make.LetStmt(x, 10, tvm.make.Evaluate(x + 1))
    assert isinstance(stmt, tvm.stmt.LetStmt)
33

34 35 36 37 38 39 40 41 42
def test_cast():
    """Check the node types produced by ``astype`` on a float32 variable.

    Converting to ``"int32"`` is expected to give a ``Cast`` node, while
    converting to ``"float32x4"`` is expected to give a ``Broadcast`` node
    whose ``lanes`` attribute equals 4.
    """
    src = tvm.var('x', dtype="float32")
    as_int = src.astype("int32")
    as_vec = src.astype("float32x4")

    assert isinstance(as_int, tvm.expr.Cast)
    assert isinstance(as_vec, tvm.expr.Broadcast)
    assert as_vec.lanes == 4


tqchen committed
43
def test_attr():
    """Check attribute access on an ``AttrStmt`` and on a converted value.

    First, an ``AttrStmt`` built on ``y`` must expose ``y`` as ``node``.
    Second, the object returned by ``tvm.convert(1)`` must expose
    ``value == 1`` and raise ``AttributeError`` for an unknown field.
    """
    x = tvm.var('x')
    y = tvm.var('y')
    body = tvm.make.Evaluate(x + 1)
    attr_stmt = tvm.make.AttrStmt(y, "stride", 10, body)
    assert attr_stmt.node == y

    converted = tvm.convert(1)
    assert converted.value == 1
    try:
        converted.no_field
    except AttributeError:
        pass
    else:
        # Reaching here means the unknown attribute lookup did not raise.
        assert False

58

tqchen committed
59
def test_basic():
    """Check the string form of the sum of two variables.

    The expected text is ``(<a.name> + <b.name>)``.
    """
    lhs = tvm.var('a')
    rhs = tvm.var('b')
    total = lhs + rhs
    expected = '(%s + %s)' % (lhs.name, rhs.name)
    assert str(total) == expected
tqchen committed
64 65 66


def test_stmt():
    """Construct a ``For`` statement around an ``Evaluate(0)`` body.

    There are no assertions; the test passes when ``tvm.make.For`` accepts
    the arguments (including ``tvm.stmt.For.Serial``) without raising.
    """
    body = tvm.make.Evaluate(0)
    loop_var = tvm.var('i')
    tvm.make.For(loop_var, 0, 1, tvm.stmt.For.Serial, 0, body)
tqchen committed
71

72 73 74 75 76 77 78 79 80
def test_dir():
    """Call ``dir`` on a variable; passes as long as it does not raise."""
    var = tvm.var('x')
    dir(var)

def test_dtype():
    """Check the dtype of a default variable and of a comparison result.

    A variable created without an explicit dtype is expected to be
    ``'int32'``; comparing two such variables with ``>`` is expected to
    yield an expression whose dtype is ``'uint1'``.
    """
    lhs = tvm.var('x')
    assert lhs.dtype == 'int32'

    rhs = tvm.var('y')
    comparison = lhs > rhs
    assert comparison.dtype == 'uint1'
tqchen committed
81

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125

def test_any():
    """Check ``tvm.any`` and the rejection of Python's ``or`` on expressions.

    Expectations encoded by the assertions:
    * ``x or x`` raises ``ValueError``;
    * ``tvm.any()`` with no arguments raises ``ValueError``;
    * one argument prints as the argument itself;
    * several arguments print as a left-nested chain joined with ``||``.
    """
    x = tvm.var('x')
    y = tvm.var('y')
    z = tvm.var('z')

    try:
        _ = x or x
    except ValueError:
        pass
    else:
        assert False

    try:
        tvm.any()
    except ValueError:
        pass
    else:
        assert False

    single = tvm.any(x < y)
    assert str(single) == '(%s < %s)' % (x.name, y.name)

    pair = tvm.any(x < y, x > z)
    assert str(pair) == '((%s < %s) || (%s > %s))' % (
        x.name, y.name, x.name, z.name)

    triple = tvm.any(x < y, y > z + 1, x < z * 2)
    expected_triple = '(((%s < %s) || (%s > (%s + 1))) || (%s < (%s*2)))' % (
        x.name, y.name, y.name, z.name, x.name, z.name)
    assert str(triple) == expected_triple


def test_all():
    """Check ``tvm.all`` and the rejection of Python's ``and`` on expressions.

    Expectations encoded by the assertions:
    * ``x and x`` raises ``ValueError``;
    * ``tvm.all()`` with no arguments raises ``ValueError``;
    * one argument prints as the argument itself;
    * several arguments print as a left-nested chain joined with ``&&``.
    """
    x = tvm.var('x')
    y = tvm.var('y')
    z = tvm.var('z')

    try:
        _ = x and x
    except ValueError:
        pass
    else:
        assert False

    try:
        tvm.all()
    except ValueError:
        pass
    else:
        assert False

    single = tvm.all(x < y)
    assert str(single) == '(%s < %s)' % (x.name, y.name)

    pair = tvm.all(x < y, x > z)
    assert str(pair) == '((%s < %s) && (%s > %s))' % (
        x.name, y.name, x.name, z.name)

    triple = tvm.all(x < y, y > z + 1, x < z * 2)
    expected_triple = '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
        x.name, y.name, y.name, z.name, x.name, z.name)
    assert str(triple) == expected_triple

126 127 128 129 130 131 132 133 134 135
def test_bitwise():
    """Check the string form of shift and bitwise operators on variables.

    Each operator applied to the variables named ``x`` and ``y`` is
    expected to print as a named call such as ``shift_left(x, y)``.
    """
    x = tvm.var('x')
    y = tvm.var('y')
    # (expression, expected text) pairs, checked in order.
    cases = [
        (x << y, 'shift_left(x, y)'),
        (x >> y, 'shift_right(x, y)'),
        (x & y, 'bitwise_and(x, y)'),
        (x | y, 'bitwise_or(x, y)'),
        (x ^ y, 'bitwise_xor(x, y)'),
        (~x, 'bitwise_not(x)'),
    ]
    for expr, expected in cases:
        assert str(expr) == expected

136

137 138 139 140 141 142 143 144
def test_equality():
    """Check the truthiness of ``==`` and ``!=`` results on expressions.

    Comparing two different variables with ``==`` is expected to be falsy,
    and comparing that result with itself using ``!=`` is expected to be
    falsy as well.
    """
    lhs = tvm.var('a')
    rhs = tvm.var('b')

    eq_result = lhs == rhs
    assert not eq_result

    ne_result = eq_result != eq_result
    assert not ne_result

145
if __name__ == "__main__":
    # Script entry point: run the tests directly, without a test runner.
    # test_ir2 was defined in this file but previously never invoked here,
    # so running the file as a script silently skipped it.
    test_cast()
    test_attr()
    test_const()
    test_make()
    test_ir()
    test_ir2()
    test_basic()
    test_stmt()
    test_let()
    test_dir()
    test_dtype()
    test_any()
    test_all()
    test_bitwise()
    test_equality()