# test_ir_well_formed.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay.analysis import well_formed
from tvm.relay.prelude import Prelude

def test_let():
    """Check ``well_formed`` on variables, ``Let`` bindings and a ``Function``.

    The test expects a single binding of ``x`` to be well formed. Wrapping a
    ``Let`` that already binds ``x`` in another ``Let`` of the same ``x`` is
    expected to be rejected. Reusing one ``Function`` value in two sibling
    lets is expected to be accepted.
    """
    x = relay.Var("x")
    assert well_formed(x)
    v = relay.Constant(tvm.nd.array(10))
    # Passed as the Function's third argument below (presumably its return
    # type; None leaves it unspecified) -- TODO confirm against relay API.
    ty = None
    let = relay.Let(x, v, x)
    assert well_formed(let)
    # `let` already binds x, so binding x again around it must be rejected.
    assert not well_formed(relay.Let(x, v, let))
    f = relay.Function([x], x, ty)
    assert well_formed(f)
    # The same `f` object (which binds x as a parameter) appears twice here;
    # the test expects this sharing to be accepted.
    assert well_formed(
        relay.Let(relay.Var("y"), f,
                  relay.Let(relay.Var("z"), f, v)))


def test_tuple():
    """Check ``well_formed`` on ``Tuple`` expressions.

    A tuple of plain constants is expected to be well formed. A tuple whose
    two fields each contain a ``Let`` binding the same variable ``x`` is
    expected to be rejected, even though each field is fine on its own.
    """
    x = relay.Var("x")
    assert well_formed(x)
    v = relay.Constant(tvm.nd.array(10))
    let = relay.Let(x, v, x)
    assert well_formed(let)
    assert well_formed(relay.Tuple([v, v]))
    # Both tuple fields bind x, which the test expects to be disallowed.
    assert not well_formed(relay.Tuple([let, relay.Let(x, v, x)]))


def test_tuple_get_item():
    """Check ``well_formed`` on a ``TupleGetItem`` over a free variable.

    ``t`` is never bound by any enclosing expression. The test expects the
    projection to still be reported as well formed.
    """
    t = relay.Var("t")
    assert well_formed(relay.TupleGetItem(t, 2))


def test_adt():
    """Check ``well_formed`` on ``Match`` expressions built from the Prelude.

    A match whose only clause binds ``x`` is expected to be well formed. A
    match with two clauses that both bind the same ``x`` is expected to be
    rejected.
    """
    mod = relay.Module()
    p = Prelude(mod)
    x = relay.Var("x")
    # Clause matching the Prelude's `some` constructor, binding its field to x.
    some_case = relay.Clause(relay.PatternConstructor(p.some,
                                                      [relay.PatternVar(x)]),
                             x)
    # Catch-all clause that also binds x.
    default_case = relay.Clause(relay.PatternVar(x), x)
    m0 = relay.Match(p.none(), [default_case])
    m1 = relay.Match(p.none(), [some_case, default_case])
    assert well_formed(m0)
    # m1 binds x in two different clauses, which the test expects to fail.
    assert not well_formed(m1)

if __name__ == "__main__":
    # Run each test of this module in definition order when executed directly.
    for _test in (test_let, test_tuple, test_tuple_get_item, test_adt):
        _test()