# test_lang_basic.py: basic tests for building TVM expressions and statements.
import tvm


def test_const():
    """tvm.const on a Python int yields an int32 IntImm node."""
    x = tvm.const(1)
    assert x.dtype == tvm.int32
    assert isinstance(x, tvm.expr.IntImm)


def test_make():
    """Arithmetic and the min/max helpers build expression nodes."""
    one = tvm.const(1)
    var_x = tvm.var("x")
    # Only checks that `+` between a constant and a Var does not raise.
    _ = one + var_x
    assert isinstance(tvm.max(one, var_x), tvm.expr.Max)
    assert isinstance(tvm.min(one, var_x), tvm.expr.Min)


def test_ir():
    """Wrapping an integer expression in make.Evaluate yields an Evaluate stmt."""
    lhs = tvm.const(1)
    rhs = tvm.make.IntImm('int32', 1)
    evaluated = tvm.make.Evaluate(lhs + rhs)
    assert isinstance(evaluated, tvm.stmt.Evaluate)


def test_ir2():
    """make.Store builds a Store stmt whose buffer_var is the given handle."""
    n = tvm.var("n")
    buf = tvm.var("array", tvm.handle)
    store = tvm.make.Store(buf, n + 1, 1)
    assert isinstance(store, tvm.stmt.Store)
    # Relies on the truth value of Expr `==`.
    assert store.buffer_var == buf


def test_let():
    """make.LetStmt builds a LetStmt node.

    It is built from (var x, value 10, body Evaluate(x + 1)).
    """
    x = tvm.var('x')
    stmt = tvm.make.LetStmt(
        x, 10, tvm.make.Evaluate(x + 1))
    # Previously the result was never checked; pin the node type.
    assert isinstance(stmt, tvm.stmt.LetStmt)


def test_cast():
    """astype to a scalar type yields a Cast.

    astype to float32x4 yields a 4-lane Broadcast.
    """
    src = tvm.var('x', dtype="float32")
    as_int = src.astype("int32")
    as_vec = src.astype("float32x4")
    assert isinstance(as_int, tvm.expr.Cast)
    assert isinstance(as_vec, tvm.expr.Broadcast)
    assert as_vec.lanes == 4


def test_attr():
    """AttrStmt exposes its node field.

    A converted constant exposes `value` and raises AttributeError for
    unknown fields.
    """
    x = tvm.var('x')
    y = tvm.var('y')
    stmt = tvm.make.AttrStmt(
        y, "stride", 10, tvm.make.Evaluate(x + 1))
    assert stmt.node == y

    a = tvm.convert(1)
    assert a.value == 1
    # Use try/except/else rather than `assert False` inside the try:
    # asserts are stripped under `python -O`, which would make this
    # negative check silently pass.
    try:
        a.no_field
    except AttributeError:
        pass
    else:
        raise AssertionError("accessing a.no_field should raise AttributeError")


def test_basic():
    """The sum of two Vars prints as a parenthesized infix expression."""
    a = tvm.var('a')
    b = tvm.var('b')
    expected = '(%s + %s)' % (a.name, b.name)
    assert str(a + b) == expected


def test_stmt():
    """make.For with ForType Serial and an Evaluate body yields a For stmt."""
    body = tvm.make.Evaluate(0)
    loop = tvm.make.For(tvm.var('i'), 0, 1,
                        tvm.stmt.For.Serial, 0,
                        body)
    # Previously the constructed loop was discarded; pin the node type.
    assert isinstance(loop, tvm.stmt.For)


def test_dir():
    """Calling dir() on a Var must not raise."""
    dir(tvm.var('x'))


def test_dtype():
    """A Var created without a dtype is int32; a comparison of two Vars is bool."""
    x = tvm.var('x')
    y = tvm.var('y')
    assert x.dtype == 'int32'
    assert (x > y).dtype == 'bool'


def test_any():
    """tvm.any combines its conditions into a left-nested `||` expression.

    Python's `or` on a Var and a call to tvm.any() with no arguments must
    both raise ValueError.
    """
    x = tvm.var('x')
    y = tvm.var('y')
    z = tvm.var('z')
    # `or` truth-tests x, which is expected to raise for a plain Var.
    # try/except/else keeps these negative checks effective under `python -O`,
    # where a bare `assert False` would be stripped.
    try:
        x or x
    except ValueError:
        pass
    else:
        raise AssertionError("`x or x` should raise ValueError")
    try:
        tvm.any()
    except ValueError:
        pass
    else:
        raise AssertionError("tvm.any() should raise ValueError")
    assert str(tvm.any(x < y)) == '(%s < %s)' % (x.name, y.name)
    assert str(tvm.any(x < y, x > z)) == '((%s < %s) || (%s > %s))' % (
        x.name, y.name, x.name, z.name)
    assert str(tvm.any(x < y, y > z + 1, x < z * 2)) == \
        '(((%s < %s) || (%s > (%s + 1))) || (%s < (%s*2)))' % (
            x.name, y.name, y.name, z.name, x.name, z.name)


def test_all():
    """tvm.all combines its conditions into a left-nested `&&` expression.

    Python's `and` on a Var and a call to tvm.all() with no arguments must
    both raise ValueError.
    """
    x = tvm.var('x')
    y = tvm.var('y')
    z = tvm.var('z')
    # `and` truth-tests x, which is expected to raise for a plain Var.
    # try/except/else keeps these negative checks effective under `python -O`,
    # where a bare `assert False` would be stripped.
    try:
        x and x
    except ValueError:
        pass
    else:
        raise AssertionError("`x and x` should raise ValueError")
    try:
        tvm.all()
    except ValueError:
        pass
    else:
        raise AssertionError("tvm.all() should raise ValueError")
    assert str(tvm.all(x < y)) == '(%s < %s)' % (x.name, y.name)
    assert str(tvm.all(x < y, x > z)) == '((%s < %s) && (%s > %s))' % (
        x.name, y.name, x.name, z.name)
    assert str(tvm.all(x < y, y > z + 1, x < z * 2)) == \
        '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
            x.name, y.name, y.name, z.name, x.name, z.name)


def test_bitwise():
    """Bitwise operators print as named calls and preserve vector dtypes."""
    x = tvm.var('x')
    y = tvm.var('y')
    printed = [
        (x << y, 'shift_left(x, y)'),
        (x >> y, 'shift_right(x, y)'),
        (x & y, 'bitwise_and(x, y)'),
        (x | y, 'bitwise_or(x, y)'),
        (x ^ y, 'bitwise_xor(x, y)'),
        (~x, 'bitwise_not(x)'),
    ]
    for expr, text in printed:
        assert str(expr) == text
    # Shifts involving vector operands keep the vector dtype.
    assert (tvm.const(1, "int8x2") >> 1).dtype == "int8x2"
    assert (x >> tvm.const(1, "int32x2")).dtype == "int32x2"
    assert (tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2"


def test_equality():
    """`==` on two distinct Vars is falsy.

    Comparing that result with itself via `!=` is also falsy.
    """
    lhs = tvm.var('a')
    rhs = tvm.var('b')
    eq = lhs == rhs
    assert not eq
    assert not (eq != eq)


if __name__ == "__main__":
    # Keep this list in sync with the test_* functions defined in this module
    # so running the file directly covers every one of them (test_ir2 was
    # previously missing).
    test_cast()
    test_attr()
    test_const()
    test_make()
    test_ir()
    test_ir2()
    test_basic()
    test_stmt()
    test_let()
    test_dir()
    test_dtype()
    test_any()
    test_all()
    test_bitwise()
    test_equality()