# test_lang_reflection.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm

def test_const_saveload_json():
    """Round-trip an expression built from int32 constants through JSON.

    Builds ``z = (x + y) + (x + y)`` from ``tvm.const`` values, serializes it
    with ``tvm.ir.save_json``, loads it back with ``tvm.ir.load_json`` and
    checks that re-serializing the loaded object reproduces the same JSON
    text as serializing the original.
    """
    x = tvm.const(1, "int32")
    y = tvm.const(10, "int32")

    z = x + y
    z = z + z

    json_str = tvm.ir.save_json(z)
    zz = tvm.ir.load_json(json_str)
    assert tvm.ir.save_json(zz) == tvm.ir.save_json(z)


def test_make_smap():
    """Serialize a string-keyed map of IR nodes (wrapped in an array) and load it back.

    The node ``x`` appears twice in the graph: as the ``"x"`` map value and as
    the ``a`` operand of the ``tir.Add`` stored under ``"z"``. After loading,
    the test checks that ``z.a`` compares equal to the loaded ``"x"`` entry.
    Presumably this verifies that the deserializer keeps the shared reference
    rather than duplicating it (TODO confirm the exact ``==`` semantics for
    these objects).
    """
    x = tvm.const(1, "int32")
    y = tvm.const(10, "int32")
    z = tvm.tir.Add(x, y)

    smap = tvm.convert({"z": z, "x": x})

    json_str = tvm.ir.save_json(tvm.convert([smap]))
    arr = tvm.ir.load_json(json_str)

    assert len(arr) == 1
    assert arr[0]["z"].a == arr[0]["x"]


def test_make_node():
    """Construct IR nodes by type key with ``tvm.ir.make_node``.

    Checks that an ``"IntImm"`` built by name is a ``tvm.tir.IntImm`` with the
    requested value. It also checks that a ``"Tensor"`` rebuilt from a
    placeholder's ``shape``/``dtype``/``op``/``value_index`` fields refers to
    the same op and output index as the placeholder.
    """
    x = tvm.ir.make_node("IntImm", dtype="int32", value=10)
    assert isinstance(x, tvm.tir.IntImm)
    assert x.value == 10

    A = tvm.placeholder((10,), name="A")
    AA = tvm.ir.make_node(
        "Tensor",
        shape=A.shape,
        dtype=A.dtype,
        op=A.op,
        value_index=A.value_index,
    )
    assert AA.op == A.op
    assert AA.value_index == A.value_index


def test_make_attrs():
    """Validate attribute construction via ``tvm.ir.make_node``.

    Covers:
      * ``attrs.TestAttrs`` rejects an unknown keyword, with an error naming it;
      * ``attrs.TestAttrs`` rejects an out-of-range ``axis``, with an
        "upper bound" error;
      * valid ``attrs.TestAttrs`` fields are stored, and ``axis`` reads back as 10
        when not supplied;
      * ``DictAttrs`` holds arbitrary keys and survives a JSON round-trip.
    """
    # The "should have raised" path uses raise, not `assert False`: assert
    # statements are stripped under `python -O`, which would let the check
    # pass silently.
    try:
        tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
    except tvm.error.TVMError as e:
        assert "unknown_key" in str(e)
    else:
        raise AssertionError("make_node accepted an unknown attribute key")

    try:
        tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
    except tvm.error.TVMError as e:
        assert "upper bound" in str(e)
    else:
        raise AssertionError("make_node accepted an out-of-range axis")

    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
    assert x.name == "xx"
    assert x.padding[0].value == 3
    assert x.padding[1].value == 4
    assert x.axis == 10

    dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0, 0))
    assert dattr.x.value == 1
    # Check the *reloaded* object. Previously the result was bound to a
    # misspelled name and never inspected, so the round-trip went untested.
    dattr_loaded = tvm.ir.load_json(tvm.ir.save_json(dattr))
    assert dattr_loaded.x.value == 1
    assert dattr_loaded.name.value == "xyz"


def test_make_sum():
    """Serialize a compute op containing a ``tvm.sum`` reduction and load it back.

    Checks that the first element of the op body has a non-``None``
    ``combiner``, both on the original tensor ``B`` and on the tensor ``BB``
    recovered from JSON.
    """
    A = tvm.placeholder((2, 10), name="A")
    k = tvm.reduce_axis((0, 10), "k")
    B = tvm.compute((2,), lambda i: tvm.sum(A[i, k], axis=k), name="B")

    json_str = tvm.ir.save_json(B)
    BB = tvm.ir.load_json(json_str)

    assert B.op.body[0].combiner is not None
    assert BB.op.body[0].combiner is not None


def test_env_func():
    """Exercise ``tvm.ir.EnvFunc`` wrapping a Python function registered as a TVM global.

    Registers ``test.env_func`` (``x -> x + 1``) and checks that it works when
    fetched with ``tvm.get_global_func`` and when fetched with
    ``EnvFunc.get``. It also checks that the ``EnvFunc`` survives a JSON
    round-trip and is callable both directly and through ``.func``. Finally,
    the ``EnvFunc`` is stored in the ``func`` field of an ``attrs.TestAttrs``,
    which is then round-tripped through JSON.
    """
    @tvm.register_func("test.env_func")
    def test(x):
        return x + 1

    # Previously fetched but unused; now it verifies the global lookup itself.
    f = tvm.get_global_func("test.env_func")
    assert f(1) == 2

    x = tvm.ir.EnvFunc.get("test.env_func")
    assert x.name == "test.env_func"

    json_str = tvm.ir.save_json([x])
    y = tvm.ir.load_json(json_str)[0]
    assert y.name == x.name
    assert y(1) == 2
    assert y.func(1) == 2

    x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4), func=y)
    assert x.name == "xx"
    assert x.padding[0].value == 3
    assert x.padding[1].value == 4
    assert x.axis == 10

    x = tvm.ir.load_json(tvm.ir.save_json(x))
    assert isinstance(x.func, tvm.ir.EnvFunc)
    assert x.func(10) == 11


if __name__ == "__main__":
    # Allow running this module directly, without a test runner.
    test_env_func()
    test_make_attrs()
    test_make_node()
    test_make_smap()
    test_const_saveload_json()
    test_make_sum()