"""Tests for the Relay IR text printer (``astext()`` / ``str()`` output)."""
import tvm
import tvm.relay.testing
import numpy as np
from tvm import relay

# One-element list so the __main__ block can flip the flag in place;
# show() only prints when do_print[0] is True.
do_print = [False]

# Version header expected at the start of astext() output
# (compared against in test_call_node_order).
SEMVER = "v0.0.1\n"

10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
def show(text):
    if do_print[0]:
        print("---------------------------")
        print(text)

def test_func():
    """Print a shared-subexpression body, both on its own and wrapped in a Function."""
    x = relay.var("x", shape=(3, 2))
    y = relay.var("y")
    big = relay.const(10e10, dtype="float32")
    summed = relay.add(x, big)
    # The same `summed` node is used for both operands, so it is shared.
    body = relay.add(summed, summed)
    func = relay.Function([x, y], body)
    for node in (body, func):
        show(node.astext())


def test_env():
    """Check that a Module prints its global function with type annotations.

    Both ``astext()`` and ``str()`` of the module must contain the
    ``def @myf`` header and the ``// ty=float32``-annotated ``add`` line.
    """
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    z = relay.add(x, y)
    z = relay.add(z, z)
    f = relay.Function([x, y], z)
    env = relay.Module()
    env["myf"] = f
    text = env.astext()
    text_str = str(env)
    assert "def @myf" in text
    assert "def @myf" in text_str
    assert "%1 = add(%0, %0) // ty=float32" in text
    assert "%1 = add(%0, %0) // ty=float32" in text_str
    # Custom annotation callback: label each node with its checked dtype.
    show(env.astext(annotate=lambda node: str(node.checked_type.dtype)))
    show(text)


def test_meta_data():
    """Check metadata printing for a symbolic shape and a tensor constant.

    ``astext()`` includes the metadata section (its ``type_key`` entries).
    ``str()`` keeps the ``meta[...]`` references but omits that section.
    """
    # `width` was previously named `w` and then silently shadowed by the
    # weight variable below; renamed so each name has one meaning.
    n, c, h, width = tvm.var("n"), 10, 224, 224
    x = relay.var("x", shape=(n, c, h, width))
    w = relay.var("w")
    z = relay.nn.conv2d(x, w,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        channels=2)
    f = relay.Function([x, w], z)
    text = f.astext()
    text_no_meta = str(f)
    assert "channels=2" in text
    assert "channels=2" in text_no_meta
    # Presumably the symbolic dimension `n` (the only tvm.var here) -- confirm.
    assert "meta[Variable][0]" in text
    assert "meta[Variable][0]" in text_no_meta
    assert "type_key" in text
    assert "type_key" not in text_no_meta
    show(text)
    show(f)

    # A non-scalar constant is also printed as a metadata reference.
    text = relay.const([1, 2, 3]).astext()
    assert "meta[relay.Constant][0]" in text
    show(text)


def test_call_attrs():
    """Check which call attributes show up in astext() output."""
    x = relay.var("x")
    # A non-default axis must be printed.
    assert "axis=2" in relay.nn.softmax(x, axis=2).astext()
    # With only default attributes, nothing extra follows the argument.
    assert "softmax(%x)" in relay.nn.softmax(x).astext()
    # A non-default num_newaxis must be printed.
    expanded = relay.expand_dims(x, axis=2, num_newaxis=2)
    assert "num_newaxis=2" in expanded.astext()


def test_let_if_scope():
    """Check how let bindings inside if/else branches are scoped when printed.

    Both branches bind a variable named "v". The printed function must open
    exactly six ``{`` blocks and annotate ``%cond`` as ``bool``.
    """
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    cond = relay.var("cond", "bool")

    sb = relay.ScopeBuilder()
    with sb.if_scope(cond):
        # Two lets with the same name hint "v" in one branch.
        v1 = sb.let("v", relay.const(1, "float32"))
        v2 = sb.let("v", x)
        sb.ret(relay.subtract(v1, v2))
    with sb.else_scope():
        v3 = relay.var("v")
        let2 = relay.Let(v3, y, v3)
        sb.ret(relay.add(let2, let2))
    result = sb.get()

    f = relay.Function([x, y, cond], result)
    text = f.astext()
    assert text.count("{") == 6
    assert "%cond: bool" in text
    show(text)


def test_variable_name():
    """A purely numeric name hint must not be printed as a bare number."""
    # avoid pure number even if the namehint is pure number
    v1 = relay.var("1")
    assert "%v1" in v1.astext()

def test_mlp():
    """Smoke test: the MLP workload can be printed with astext()."""
    net, _ = tvm.relay.testing.mlp.get_workload(batch_size=1)
    net.astext()

def test_resnet():
    """Smoke test: the ResNet workload can be printed with astext()."""
    net, _ = tvm.relay.testing.resnet.get_workload(batch_size=1)
    net.astext()

def test_mobilenet():
    """Smoke test: the MobileNet workload can be printed with astext()."""
    net, _ = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
    net.astext()

def test_dqn():
    """Smoke test: the DQN workload can be printed with astext()."""
    net, _ = tvm.relay.testing.dqn.get_workload(batch_size=1)
    net.astext()

def test_dcgan():
    """Smoke test: the DCGAN workload can be printed with astext()."""
    net, _ = tvm.relay.testing.dcgan.get_workload(batch_size=1)
    net.astext()

def test_lstm():
    """Smoke test: the LSTM workload can be printed with astext()."""
    # NOTE(review): positional (4, 4) -- presumably iterations and hidden size;
    # confirm against tvm.relay.testing.lstm.get_workload's signature.
    net, _ = tvm.relay.testing.lstm.get_workload(4, 4)
    net.astext()

def test_inception_v3():
    """Smoke test: the Inception-v3 workload can be printed with astext()."""
    net, _ = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
    net.astext()

def test_squeezenet():
    """Smoke test: both SqueezeNet variants can be printed with astext()."""
    for ver in ('1.0', '1.1'):
        net, _params = tvm.relay.testing.squeezenet.get_workload(
            batch_size=1, version=ver)
        net.astext()

def test_vgg():
    """Smoke test: the VGG workload can be printed with astext()."""
    net, _ = tvm.relay.testing.vgg.get_workload(batch_size=1)
    net.astext()

def test_densenet():
    """Smoke test: the DenseNet workload can be printed with astext()."""
    net, _ = tvm.relay.testing.densenet.get_workload(batch_size=1)
    net.astext()

def test_call_node_order():
    """Check that nested calls are printed in evaluation order.

    The inner call (identity on ``y`` applied to 1) must be bound as %0/%1
    before the outer identity function and its call (%2/%3).
    """
    x = relay.var("x")
    y = relay.var("y")
    inner = relay.Call(relay.Function([y], y), [relay.const(1)])
    outer = relay.Call(relay.Function([x], x), [inner])
    expected = SEMVER + ("%0 = fn (%y) {\n"
                         "  %y\n"
                         "}\n"
                         "%1 = %0(1)\n"
                         "%2 = fn (%x) {\n"
                         "  %x\n"
                         "}\n"
                         "%3 = %2(%1)\n"
                         "%3")
    assert outer.astext() == expected

if __name__ == "__main__":
    # Script mode: enable show() output and run every test in this file.
    do_print[0] = True
    test_resnet()
    test_mobilenet()
    test_mlp()
    test_dqn()
    test_dcgan()
    test_lstm()  # was defined but never invoked from the script runner
    test_squeezenet()
    test_inception_v3()
    test_vgg()
    test_densenet()
    test_func()
    test_env()
    test_meta_data()
    test_call_attrs()
    test_let_if_scope()
    test_variable_name()
    test_call_node_order()