test_ir_text_printer.py 6.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
import tvm
18
from tvm import relay
19
import tvm.relay.testing
20
import numpy as np
21 22
from tvm.relay import Expr
from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equal, free_vars
23 24 25

# One-element list used as a mutable flag: show() prints only when
# do_print[0] is true; the __main__ guard switches it on.
do_print = [False]

# Version header that the exact-match tests prepend to the expected text.
SEMVER = "v0.0.3\n"
27

28 29 30 31 32 33 34 35 36 37 38
def astext(p, graph_equal=False):
    """Return ``p.astext()`` and, where possible, verify it parses back to ``p``.

    If ``p`` is an ``Expr`` with free variables the text is returned without
    any check. Otherwise the text is parsed with ``relay.fromtext`` and the
    result is compared against ``p``.

    Parameters
    ----------
    p : object with an ``astext()`` method
        The node or module to print.
    graph_equal : bool
        Compare with ``assert_graph_equal`` when true, and with
        ``assert_alpha_equal`` otherwise.

    Returns
    -------
    str
        The text produced by ``p.astext()``.
    """
    txt = p.astext()
    skip_check = isinstance(p, Expr) and free_vars(p)
    if not skip_check:
        parsed = relay.fromtext(txt)
        check = assert_graph_equal if graph_equal else assert_alpha_equal
        check(parsed, p)
    return txt

39 40 41 42 43 44 45 46 47 48 49 50
def show(text):
    """Print ``text`` below a separator line, but only when ``do_print[0]`` is set."""
    if not do_print[0]:
        return
    print("---------------------------")
    print(text)

def test_func():
    """Run astext() on a nested add expression and on the function wrapping it."""
    x = relay.var("x", shape=(3, 2))
    y = relay.var("y")
    const = relay.const(10e10, dtype="float32")
    z = relay.add(x, const)
    z = relay.add(z, z)
    f = relay.Function([x, y], z)
    show(astext(z))
    show(astext(f))
53 54 55 56 57 58 59 60


def test_env():
    """Check the printed form of a module that holds a single function ``myf``.

    Both ``astext`` and ``str`` output must contain the function definition
    and the type-annotated ``add`` call.
    """
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    z = relay.add(x, y)
    z = relay.add(z, z)
    f = relay.Function([x, y], z)
    env = relay.Module()
    env["myf"] = f
    text = astext(env)
    assert "def @myf" in text
    assert "def @myf" in str(env)
    assert "add(%0, %0) /* ty=float32 */" in text
    assert "add(%0, %0) /* ty=float32 */" in str(env)

    def call_dtype(node):
        # Annotate call nodes with the dtype of their checked type; every
        # other node gets an empty annotation.
        if isinstance(node, relay.Call):
            return str(node.checked_type.dtype)
        return ""

    show(env.astext(annotate=call_dtype))
    show(text)


def test_meta_data():
    """Check how meta references appear in astext() and str() output.

    For a conv2d function with a symbolic batch dimension, both outputs must
    mention ``channels=2`` and ``meta[Variable][0]``; only the astext() output
    may contain ``type_key``. A list constant must print as
    ``meta[relay.Constant][0]``.
    """
    n, c, h, w = tvm.var("n"), 10, 224, 224
    x = relay.var("x", shape=(n, c, h, w))
    # Separate name for the weight variable so the integer width above is
    # not rebound.
    weight = relay.var("w")
    z = relay.nn.conv2d(x, weight,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        channels=2)
    f = relay.Function([x, weight], z)
    text = astext(f, graph_equal=True)
    text_no_meta = str(f)
    assert "channels=2" in text
    assert "channels=2" in text_no_meta
    assert "meta[Variable][0]" in text
    assert "meta[Variable][0]" in text_no_meta
    assert "type_key" in text
    assert "type_key" not in text_no_meta

    text = astext(relay.const([1, 2, 3]))
    assert "meta[relay.Constant][0]" in text
92 93 94 95 96 97


def test_call_attrs():
    """Check which call attributes show up in the printed text."""
    x = relay.var("x")
    # non default args
    z = relay.nn.softmax(x, axis=2)
    assert "axis=2" in astext(z)
    # default args
    z = relay.nn.softmax(x)
    assert "softmax(%x)" in astext(z)
    # non default args
    z = relay.expand_dims(x, axis=2, num_newaxis=2)
    assert "num_newaxis=2" in astext(z)
105 106 107 108 109 110


def test_let_if_scope():
    """Print a function whose body is an if/else built with ScopeBuilder.

    The printed text must contain exactly four ``{`` characters and the
    typed parameter ``%cond: bool``.
    """
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    cond = relay.var("cond", "bool")

    sb = relay.ScopeBuilder()
    with sb.if_scope(cond):
        v1 = sb.let("v", relay.const(1, "float32"))
        v2 = sb.let("v", x)
        sb.ret(relay.subtract(v1, v2))
    with sb.else_scope():
        v3 = relay.var("v")
        let2 = relay.Let(v3, y, v3)
        sb.ret(relay.add(let2, let2))
    result = sb.get()

    f = relay.Function([x, y, cond], result)
    text = astext(f)
    assert text.count("{") == 4
    assert "%cond: bool" in text
    # Reuse the text already produced above rather than printing f again.
    show(text)
128 129


130 131 132
def test_variable_name():
    """Check that a variable whose name hint is "1" is printed as ``%v1``."""
    # avoid pure number even if the namehint is pure number
    v1 = relay.var("1")
    assert "%v1" in astext(v1)
134

135

136 137
def test_mlp():
    """Run astext() on the MLP workload built with batch_size=1."""
    net, _ = tvm.relay.testing.mlp.get_workload(batch_size=1)
    astext(net)
139

140

141 142
def test_resnet():
    """Run astext() on the ResNet workload built with batch_size=1."""
    net, _ = tvm.relay.testing.resnet.get_workload(batch_size=1)
    astext(net)
144

145

eqy committed
146 147
def test_mobilenet():
    """Run astext() on the MobileNet workload built with batch_size=1."""
    net, _ = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
    astext(net)
eqy committed
149

150

151 152
def test_dqn():
    """Run astext() on the DQN workload built with batch_size=1."""
    net, _ = tvm.relay.testing.dqn.get_workload(batch_size=1)
    astext(net)
154

155

156 157
def test_dcgan():
    """Run astext() on the DCGAN workload built with batch_size=1."""
    net, _ = tvm.relay.testing.dcgan.get_workload(batch_size=1)
    astext(net)
159

160 161

def test_lstm():
    """Run astext() on two LSTM workloads, built with arguments (1, 1) and (4, 4)."""
    net, _ = tvm.relay.testing.lstm.get_workload(1, 1)
    astext(net)

    net, _ = tvm.relay.testing.lstm.get_workload(4, 4)
    astext(net)
167

168 169
def test_inception_v3():
    """Run astext() on the Inception v3 workload built with batch_size=1."""
    net, _ = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
    astext(net)
171 172 173 174

def test_squeezenet():
    """Run astext() on the SqueezeNet workloads for versions '1.0' and '1.1'."""
    for version in ['1.0', '1.1']:
        net, _ = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
        astext(net)
176 177 178

def test_vgg():
    """Run astext() on the VGG workload built with batch_size=1."""
    net, _ = tvm.relay.testing.vgg.get_workload(batch_size=1)
    astext(net)
180

181 182
def test_densenet():
    """Run astext() on the DenseNet workload built with batch_size=1."""
    net, _ = tvm.relay.testing.densenet.get_workload(batch_size=1)
    astext(net)
184

185 186 187
def test_call_node_order():
    """Check the exact printed text of a call whose argument is another call.

    The expected text binds the inner function and its call (%0, %1) before
    the outer function (%2).
    """
    x = relay.var("x")
    y = relay.var("y")
    # Build the outer function first, then the inner call, matching the
    # evaluation order of the original single-expression construction.
    outer_fn = relay.Function([x], x)
    inner_call = relay.Call(relay.Function([y], y), [relay.const(1)])
    prog = relay.Call(outer_fn, [inner_call])
    expected = SEMVER + (
        "%0 = fn (%y) {\n"
        "  %y\n"
        "};\n"
        "%1 = %0(1);\n"
        "%2 = fn (%x) {\n"
        "  %x\n"
        "};\n"
        "%2(%1)")
    assert astext(prog) == expected

def test_let_inlining():
    """Check the exact printed text of let bindings over a tuple.

    When the tuple is used as both the bound value and the body, the expected
    text binds it to %0 first. When the body is the let variable, the tuple
    is expected inline in the let.
    """
    tup = relay.Tuple([relay.const(0), relay.const(0)])
    x = relay.var("x")
    expected_shared = SEMVER + (
        "%0 = (0, 0);\n"
        "let %x = %0;\n"
        "%0")
    assert astext(relay.Let(x, tup, tup)) == expected_shared

    expected_inline = SEMVER + (
        "let %x = (0, 0);\n"
        "%x")
    assert astext(relay.Let(x, tup, x)) == expected_inline
210

211 212 213 214
def test_zeros():
    """Run astext() on a ``zeros`` op with an empty shape and float32 dtype."""
    zeros_expr = relay.op.zeros([], "float32")
    astext(zeros_expr)

215 216
if __name__ == "__main__":
    # Turn on printing in show() when run as a script, then run the tests
    # in the same fixed order as before.
    do_print[0] = True
    test_lstm()
    test_zeros()
    test_meta_data()
    test_let_inlining()
    test_resnet()
    test_mobilenet()
    test_mlp()
    test_dqn()
    test_dcgan()
    test_squeezenet()
    test_inception_v3()
    test_vgg()
    test_densenet()
    test_func()
    test_env()
    test_call_attrs()
    test_let_if_scope()
    test_variable_name()
    test_call_node_order()