# test_expr_functor.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
import tvm
18
from tvm import te
19
from tvm import relay
20
from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor
21 22 23

def check_visit(expr):
    """Run ``expr`` through the three Relay functor base classes.

    The checks performed are:

    * a bare ``ExprFunctor`` must raise ``NotImplementedError`` when asked
      to visit ``expr``;
    * ``ExprVisitor`` must visit ``expr`` without raising;
    * ``ExprMutator`` must return a truthy result for ``expr``.

    Parameters
    ----------
    expr
        The Relay expression to visit.

    Raises
    ------
    AssertionError
        If ``ExprFunctor.visit`` does not raise ``NotImplementedError`` or
        if ``ExprMutator.visit`` returns a falsy value.
    """
    try:
        ExprFunctor().visit(expr)
    except NotImplementedError:
        pass
    else:
        # Raised explicitly (instead of ``assert False``) so the check is
        # not stripped when Python runs with -O, and kept outside the
        # ``try`` body so only the visit call is guarded.
        raise AssertionError(
            "ExprFunctor.visit was expected to raise NotImplementedError"
        )

    ev = ExprVisitor()
    ev.visit(expr)

    em = ExprMutator()
    assert em.visit(expr), "ExprMutator.visit returned a falsy result"

36

37 38 39
def test_constant():
    """Run the shared visitor checks on a float constant."""
    const_expr = relay.const(1.0)
    check_visit(const_expr)

40

41 42 43 44
def test_tuple():
    """Run the shared visitor checks on a one-field tuple expression."""
    field = relay.var('x', shape=())
    check_visit(relay.Tuple([field]))

45

46 47 48 49
def test_var():
    """Run the shared visitor checks on a single variable with shape ()."""
    check_visit(relay.var('x', shape=()))

50

51 52 53 54
def test_global():
    """Run the shared visitor checks on a global variable reference."""
    global_ref = relay.GlobalVar('f')
    check_visit(global_ref)

55

56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
def test_function():
    """Run the shared visitor checks on a two-parameter function ``x + y``."""
    lhs = relay.var('x', shape=())
    rhs = relay.var('y', shape=())
    # Positional arguments, in order: params, body, ret_type, type_params,
    # attrs. No attrs are supplied; how to build them was left open by the
    # original author.
    func = relay.Function([lhs, rhs], lhs + rhs, relay.TensorType(()), [], None)
    check_visit(func)

73

74 75 76 77 78 79
def test_call():
    """Run the shared visitor checks on a call to the ``add`` operator."""
    lhs = relay.var('x', shape=())
    rhs = relay.var('y', shape=())
    check_visit(relay.op.add(lhs, rhs))

80

81 82 83 84 85 86 87
def test_let():
    """Run the shared visitor checks on a let-binding whose body is ``x + x``."""
    bound = relay.var('x', shape=())
    let_expr = relay.Let(bound, relay.const(2.0), bound + bound)
    check_visit(let_expr)

88

89 90 91 92 93
def test_ite():
    """Run the shared visitor checks on an if-then-else expression.

    The same boolean variable is used as condition and as both branches.
    """
    flag = relay.var('x', shape=(), dtype='bool')
    check_visit(relay.If(flag, flag, flag))

94

95 96 97 98 99
def test_get_item():
    """Run the shared visitor checks on a projection of a tuple's field 0."""
    tup = relay.Tuple([relay.var('x', shape=())])
    projection = relay.TupleGetItem(tup, 0)
    check_visit(projection)

100

101 102 103 104
def test_ref_create():
    """Run the shared visitor checks on a ``RefCreate`` expression."""
    initial = relay.const(1.0)
    check_visit(relay.expr.RefCreate(initial))

105

106 107 108 109 110
def test_ref_read():
    """Run the shared visitor checks on a ``RefRead`` of a new reference."""
    cell = relay.expr.RefCreate(relay.const(1.0))
    check_visit(relay.expr.RefRead(cell))

111

112 113 114 115 116
def test_ref_write():
    """Run the shared visitor checks on a ``RefWrite`` into a new reference."""
    cell = relay.expr.RefCreate(relay.const(1.0))
    write_expr = relay.expr.RefWrite(cell, relay.const(2.0))
    check_visit(write_expr)

117 118 119 120 121 122 123 124

def test_memo():
    """Run the shared visitor checks on a heavily shared expression.

    Each iteration adds the current expression to itself, so both operands
    of every addition are the same object. After 100 rounds a traversal
    that did not reuse results for repeated sub-expressions would have to
    walk on the order of 2**100 nodes; judging by the test name, this is
    meant to exercise the visitors' memoization.
    """
    doublings = 100
    shared = relay.const(1)
    for _round in range(doublings):
        shared = shared + shared
    check_visit(shared)


雾雨魔理沙 committed
125 126 127 128
def test_match():
    """Run the shared visitor checks on the Prelude's ``map`` definition."""
    prelude = relay.prelude.Prelude()
    map_def = prelude.mod[prelude.map]
    check_visit(map_def)

129 130 131 132 133 134 135 136 137 138

def test_match_completeness():
    """Check that ``ExprMutator`` preserves a ``Match`` node's ``complete`` flag.

    A clause-less ``Match`` over the Prelude's ``nil`` is built once with
    ``complete=True`` and once with ``complete=False``; the mutated result
    must carry the same flag value in both cases.
    """
    prelude = relay.prelude.Prelude()
    for flag in (True, False):
        original = relay.adt.Match(prelude.nil, [], complete=flag)
        mutated = ExprMutator().visit(original)
        # The mutator must not mangle the completeness flag.
        assert mutated.complete == flag


139 140 141 142 143 144 145 146 147 148 149 150
if __name__ == "__main__":
    # Script entry point: run every test in this file in definition order.
    # test_get_item() was previously defined but never invoked here.
    test_constant()
    test_tuple()
    test_var()
    test_global()
    test_function()
    test_call()
    test_let()
    test_ite()
    test_get_item()
    test_ref_create()
    test_ref_read()
    test_ref_write()
    test_memo()
    test_match()
    test_match_completeness()