# test_ir_text_printer.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
import tvm.relay.testing
from tvm import relay

# Flipped to True by the ``__main__`` block so show() echoes printed IR.
do_print = [False]

# Text-format version header expected at the start of astext() output.
SEMVER = "v0.0.1\n"

def show(text):
    """Print ``text`` below a separator line, but only while ``do_print[0]`` is truthy."""
    if not do_print[0]:
        return
    print("---------------------------", text, sep="\n")

def test_func():
    """Smoke-test printing a bare expression and the function that wraps it."""
    inp = relay.var("x", shape=(3, 2))
    # A parameter the function body never references.
    extra = relay.var("y")
    # Note: 10e10 == 1e11.
    big = relay.const(10e10, dtype="float32")
    body = relay.add(inp, big)
    body = relay.add(body, body)
    func = relay.Function([inp, extra], body)
    show(body.astext())
    show(func.astext())


def test_env():
    """Print a module that holds a single global function.

    Both ``astext()`` and ``str()`` of the module must contain the global
    definition header and the ``// ty=...`` type comment on the call.
    """
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    z = relay.add(x, y)
    z = relay.add(z, z)
    f = relay.Function([x, y], z)
    env = relay.Module()
    env["myf"] = f
    text = env.astext()
    env_str = str(env)
    assert "def @myf" in text
    assert "def @myf" in env_str
    assert "%1 = add(%0, %0) // ty=float32" in text
    assert "%1 = add(%0, %0) // ty=float32" in env_str
    # Custom annotation callback; presumably replaces the default type comment
    # with just the node's dtype — confirm against the printer API.
    show(env.astext(annotate=lambda node: str(node.checked_type.dtype)))
    show(text)


def test_meta_data():
    """Check which parts of the IR are printed as metadata references.

    ``astext()`` output includes the serialized metadata section (detected via
    ``type_key``). ``str()`` uses the same ``meta[...]`` references but leaves
    the section out.
    """
    n, c, h, w = tvm.var("n"), 10, 224, 224
    x = relay.var("x", shape=(n, c, h, w))
    # Bound to `weight` so it does not rebind the spatial width `w` above.
    weight = relay.var("w")
    z = relay.nn.conv2d(x, weight,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        channels=2)
    f = relay.Function([x, weight], z)
    text = f.astext()
    text_no_meta = str(f)
    assert "channels=2" in text
    assert "channels=2" in text_no_meta
    # Presumably the symbolic dimension `n` (the only tvm.var here).
    assert "meta[Variable][0]" in text
    assert "meta[Variable][0]" in text_no_meta
    assert "type_key" in text
    assert "type_key" not in text_no_meta
    show(text)
    show(f)

    # A non-scalar constant is also printed as a metadata reference.
    text = relay.const([1, 2, 3]).astext()
    assert "meta[relay.Constant][0]" in text
    show(text)


def test_call_attrs():
    """Call attributes are printed only when they differ from their defaults."""
    data = relay.var("x")
    # non-default axis -> printed
    assert "axis=2" in relay.nn.softmax(data, axis=2).astext()
    # all defaults -> bare call with no attribute list
    assert "softmax(%x)" in relay.nn.softmax(data).astext()
    # non-default num_newaxis -> printed
    expanded = relay.expand_dims(data, axis=2, num_newaxis=2)
    assert "num_newaxis=2" in expanded.astext()


def test_let_if_scope():
    """Print let bindings nested inside both branches of an if/else.

    Both branches bind a variable with the name hint ``v``. The exact count of
    opening braces pins how many scopes the printer emits.
    """
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    cond = relay.var("cond", "bool")

    sb = relay.ScopeBuilder()
    with sb.if_scope(cond):
        v1 = sb.let("v", relay.const(1, "float32"))
        v2 = sb.let("v", x)
        sb.ret(relay.subtract(v1, v2))
    with sb.else_scope():
        v3 = relay.var("v")
        let2 = relay.Let(v3, y, v3)
        sb.ret(relay.add(let2, let2))
    result = sb.get()

    f = relay.Function([x, y, cond], result)
    text = f.astext()
    assert text.count("{") == 6
    assert "%cond: bool" in text
    show(text)


def test_variable_name():
    """A purely numeric name hint must not be printed as a bare number."""
    # The hint "1" is printed with a "v" prefix (``%v1``), presumably so it
    # cannot be confused with numbered temporaries such as ``%1``.
    v1 = relay.var("1")
    assert "%v1" in v1.astext()


def test_mlp():
    """The MLP workload prints without raising."""
    net, _ = tvm.relay.testing.mlp.get_workload(batch_size=1)
    net.astext()


def test_resnet():
    """The ResNet workload prints without raising."""
    net, _ = tvm.relay.testing.resnet.get_workload(batch_size=1)
    net.astext()


def test_mobilenet():
    """The MobileNet workload prints without raising."""
    net, _ = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
    net.astext()


def test_dqn():
    """The DQN workload prints without raising."""
    net, _ = tvm.relay.testing.dqn.get_workload(batch_size=1)
    net.astext()


def test_dcgan():
    """The DCGAN workload prints without raising."""
    net, _ = tvm.relay.testing.dcgan.get_workload(batch_size=1)
    net.astext()


def test_lstm():
    """The LSTM workload prints without raising."""
    net, _ = tvm.relay.testing.lstm.get_workload(4, 4)
    net.astext()

def test_inception_v3():
    """The Inception-v3 workload prints without raising."""
    module, _ = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
    module.astext()

def test_squeezenet():
    """Both SqueezeNet variants print without raising."""
    versions = ('1.0', '1.1')
    for ver in versions:
        module, _ = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=ver)
        module.astext()

def test_vgg():
    """The VGG workload prints without raising."""
    net, _ = tvm.relay.testing.vgg.get_workload(batch_size=1)
    net.astext()

def test_densenet():
    """The DenseNet workload prints without raising."""
    net, _ = tvm.relay.testing.densenet.get_workload(batch_size=1)
    net.astext()

def test_call_node_order():
    """Nested calls are printed in evaluation order, innermost call first."""
    x = relay.var("x")
    y = relay.var("y")
    # Same construction order as a single nested expression would use.
    outer_fn = relay.Function([x], x)
    inner_call = relay.Call(relay.Function([y], y), [relay.const(1)])
    outer_call = relay.Call(outer_fn, [inner_call])
    expected = SEMVER + (
        "%0 = fn (%y) {\n"
        "  %y\n"
        "}\n"
        "%1 = %0(1)\n"
        "%2 = fn (%x) {\n"
        "  %x\n"
        "}\n"
        "%3 = %2(%1)\n"
        "%3"
    )
    assert outer_call.astext() == expected


if __name__ == "__main__":
    # Enable show() output when the file is run directly.
    do_print[0] = True
    test_resnet()
    test_mobilenet()
    test_mlp()
    test_dqn()
    test_dcgan()
    test_lstm()
    test_squeezenet()
    test_inception_v3()
    test_vgg()
    test_densenet()
    test_func()
    test_env()
    test_meta_data()
    test_call_attrs()
    test_let_if_scope()
    test_variable_name()
    test_call_node_order()