# test_lang_reflection.py -- tests for node reflection (tvm.make.node) and JSON save/load.
import tvm

def test_const_saveload_json():
    """Round-trip a constant arithmetic expression through JSON.

    Builds ``(1 + 10) + (1 + 10)`` from int32 constants, serializes it with
    ``tvm.save_json``, reloads it, and checks that re-serializing the loaded
    node yields exactly the same JSON string.
    """
    x = tvm.const(1, "int32")
    y = tvm.const(10, "int32")
    z = x + y
    # Reusing ``z`` twice means the same sub-node appears in two places.
    z = z + z
    json_str = tvm.save_json(z)
    zz = tvm.load_json(json_str)
    assert tvm.save_json(zz) == tvm.save_json(z)


def test_make_smap():
    """Serialize an array holding a string-keyed map and reload it.

    ``z`` is an ``Add`` whose first operand is the same node stored under key
    ``"x"``. After the JSON round trip, the loaded ``z.a`` is compared against
    the loaded ``"x"`` entry. Presumably this verifies that the shared node is
    preserved as one node rather than duplicated (TODO confirm the exact
    semantics of ``==`` on loaded nodes).
    """
    x = tvm.const(1, "int32")
    y = tvm.const(10, "int32")
    z = tvm.expr.Add(x, y)
    smap = tvm.convert({"z": z, "x": x})
    json_str = tvm.save_json(tvm.convert([smap]))
    arr = tvm.load_json(json_str)
    assert len(arr) == 1
    assert arr[0]["z"].a == arr[0]["x"]


def test_make_node():
    """Construct nodes by type key via ``tvm.make.node`` and check their fields.

    Builds an ``IntImm`` directly. Then builds a ``Tensor`` from the fields of
    an existing placeholder and checks that the ``op`` and ``value_index``
    fields carry over.
    """
    x = tvm.make.node("IntImm", dtype="int32", value=10)
    assert isinstance(x, tvm.expr.IntImm)
    assert x.value == 10

    A = tvm.placeholder((10, ), name='A')
    AA = tvm.make.node("Tensor",
                       shape=A.shape,
                       dtype=A.dtype,
                       op=A.op,
                       value_index=A.value_index)
    assert AA.op == A.op
    assert AA.value_index == A.value_index


def test_make_attrs():
    """Construct attribute nodes by type key and check validation and defaults.

    Covers ``attrs.TestAttrs``:
      * unknown-keyword rejection,
      * the upper-bound check on ``axis``,
      * field values,
      * the value of ``axis`` when it is not passed.

    Also covers a ``DictAttrs`` node, including a JSON save/load round trip.
    """
    # An unrecognised keyword must be rejected, and the error should name it.
    # try/except/else (not ``assert False``) so the check survives ``python -O``.
    try:
        tvm.make.node("attrs.TestAttrs", unknown_key=1, name="xx")
    except tvm.TVMError as e:
        assert "unknown_key" in str(e)
    else:
        raise AssertionError("expected TVMError for unknown_key")

    # axis=100 is expected to violate TestAttrs' upper bound on ``axis``.
    try:
        tvm.make.node("attrs.TestAttrs", axis=100, name="xx")
    except tvm.TVMError as e:
        assert "upper bound" in str(e)
    else:
        raise AssertionError("expected TVMError for out-of-range axis")

    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4))
    assert x.name == "xx"
    assert x.padding[0].value == 3
    assert x.padding[1].value == 4
    # ``axis`` was not passed above, so this checks the field's default value.
    assert x.axis == 10

    dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0, 0))
    assert dattr.x.value == 1
    # Assert on the deserialized copy so the JSON round trip is actually verified.
    loaded = tvm.load_json(tvm.save_json(dattr))
    assert loaded.name.value == "xyz"


def test_make_sum():
    """Round-trip a compute op containing a ``tvm.sum`` reduction through JSON.

    Checks that the first body expression of the op has a non-None
    ``combiner`` both before and after serialization. Presumably that
    expression is the reduction node (TODO confirm).
    """
    A = tvm.placeholder((2, 10), name='A')
    k = tvm.reduce_axis((0, 10), "k")
    B = tvm.compute((2,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
    json_str = tvm.save_json(B)
    BB = tvm.load_json(json_str)
    assert B.op.body[0].combiner is not None
    assert BB.op.body[0].combiner is not None


def test_env_func():
    """Exercise EnvFunc lookup, JSON round trip, and use as an attrs field.

    Registers ``test.env_func`` (``x -> x + 1``) as a global function and
    fetches it both as a packed function and as an EnvFunc. It then checks
    that the EnvFunc survives ``save_json``/``load_json`` and stays callable,
    both directly and when stored in an ``attrs.TestAttrs`` node.
    """
    @tvm.register_func("test.env_func")
    def test(x):
        return x + 1

    f = tvm.get_global_func("test.env_func")
    assert f(1) == 2
    x = tvm.get_env_func("test.env_func")
    assert x.name == "test.env_func"
    json_str = tvm.save_json([x])
    y = tvm.load_json(json_str)[0]
    assert y.name == x.name
    assert y(1) == 2
    assert y.func(1) == 2

    x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4), func=y)
    assert x.name == "xx"
    assert x.padding[0].value == 3
    assert x.padding[1].value == 4
    assert x.axis == 10
    # The ``func`` field must come back as an EnvFunc after a JSON round trip.
    x = tvm.load_json(tvm.save_json(x))
    assert isinstance(x.func, tvm.container.EnvFunc)
    assert x.func(10) == 11


if __name__ == "__main__":
    # Allow running the file directly, without a test runner.
    test_env_func()
    test_make_attrs()
    test_make_node()
    test_make_smap()
    test_const_saveload_json()
    test_make_sum()