# test_ir_text_printer.py
import tvm
import tvm.relay.testing
import numpy as np
from tvm import relay


# Single-element list so the ``__main__`` entry point can switch printing on
# with ``do_print[0] = True`` without needing a ``global`` statement.
do_print = [False]


def show(text):
    """Print ``text`` below a separator line when printing is enabled.

    Does nothing unless ``do_print[0]`` is truthy.
    """
    if not do_print[0]:
        return
    print("---------------------------", text, sep="\n")

def test_func():
    """Print a small add graph, first as a bare expression and then as a Function.

    This is a smoke test: it only checks that ``astext`` succeeds on both forms.
    """
    data = relay.var("x", shape=(3, 2))
    unused = relay.var("y")
    # Despite being added as a constant, the value is 10e10 (= 1e11), not one.
    big = relay.const(10e10, dtype="float32")
    shifted = relay.add(data, big)
    doubled = relay.add(shifted, shifted)
    func = relay.Function([data, unused], doubled)
    for node in (doubled, func):
        show(node.astext())


def test_env():
    """Check how an Environment holding one global function is printed.

    The text must contain the global definition ``def @myf`` and the
    annotated ``add`` binding ``%1 = add(%0, %0) # ty=float32``.
    """
    lhs = relay.var("x", "float32")
    rhs = relay.var("y", "float32")
    summed = relay.add(lhs, rhs)
    doubled = relay.add(summed, summed)
    func = relay.Function([lhs, rhs], doubled)
    env = relay.Environment()
    env["myf"] = func

    text = env.astext()
    for expected in ("def @myf", "%1 = add(%0, %0) # ty=float32"):
        assert expected in text
    show(text)


def test_meta_data():
    """Check that some values are printed as ``meta`` references, not inline.

    Two cases are covered: a function whose input shape has a symbolic
    ``tvm.var`` dimension, and a standalone tensor constant.
    """
    batch = tvm.var("n")
    channels, height, width = 10, 224, 224
    data = relay.var("x", shape=(batch, channels, height, width))
    weight = relay.var("w")
    conv = relay.nn.conv2d(data, weight,
                           kernel_size=(3, 3),
                           padding=(1, 1),
                           channels=2)
    func = relay.Function([data, weight], conv)

    func_text = func.astext()
    assert "channels=2" in func_text
    # The only tvm.var in this function is the symbolic batch dimension, so it
    # is expected to be the meta Variable with id 0.
    assert "meta.Variable(id=0)" in func_text
    show(func_text)

    const_text = relay.const([1, 2, 3]).astext()
    assert "meta.relay.Constant(id=0)" in const_text
    show(const_text)


def test_call_attrs():
    """Check how call attributes are printed.

    Non-default attributes appear in the text. A call that uses only default
    attributes prints as a plain ``softmax(%x)``.
    """
    data = relay.var("x")

    # Non-default axis: the attribute must be printed.
    softmax_axis2 = relay.nn.softmax(data, axis=2)
    assert "axis=2" in softmax_axis2.astext()

    # All defaults: plain call with no attributes shown.
    softmax_default = relay.nn.softmax(data)
    assert "softmax(%x)" in softmax_default.astext()

    # Non-default num_newaxis: the attribute must be printed.
    expanded = relay.expand_dims(data, axis=2, num_newaxis=2)
    assert "num_newaxis=2" in expanded.astext()


def test_let_if_scope():
    """Print a function whose body is an if/else built with ``ScopeBuilder``.

    The then-branch binds the name hint ``v`` twice through ``sb.let``. The
    else-branch uses an explicit ``relay.Let`` whose result is used twice.
    The printed text is expected to contain exactly four ``{`` and the typed
    parameter ``%cond: bool``.
    """
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    cond = relay.var("cond", "bool")

    sb = relay.ScopeBuilder()
    with sb.if_scope(cond):
        # Two lets that share the same name hint "v".
        v1 = sb.let("v", relay.const(1, "float32"))
        v2 = sb.let("v", x)
        sb.ret(relay.subtract(v1, v2))
    with sb.else_scope():
        v3 = relay.var("v")
        let2 = relay.Let(v3, y, v3)
        sb.ret(relay.add(let2, let2))
    result = sb.get()

    f = relay.Function([x, y, cond], result)
    text = f.astext()
    assert text.count("{") == 4
    assert "%cond: bool" in text
    # Show the text that was just checked instead of printing the function
    # again, consistent with the other tests in this file.
    show(text)


def test_variable_name():
    """A purely numeric name hint must not be printed as a bare number.

    The name hint ``"1"`` is expected to be printed as ``%v1``.
    """
    printed = relay.var("1").astext()
    assert "%v1" in printed

def _check_workload(get_workload):
    """Build a workload with batch size 1 and check that it prints to text.

    ``get_workload`` is a ``tvm.relay.testing.<model>.get_workload`` builder.
    Only the network from its ``(net, params)`` result is used.
    Returns the printed text.
    """
    net, _params = get_workload(batch_size=1)
    text = net.astext()
    # These tests used to call astext() and discard the result; at minimum
    # require that the printer produced some output.
    assert text
    return text


def test_mlp():
    """Print the MLP workload."""
    _check_workload(tvm.relay.testing.mlp.get_workload)


def test_resnet():
    """Print the ResNet workload."""
    _check_workload(tvm.relay.testing.resnet.get_workload)


def test_dqn():
    """Print the DQN workload."""
    _check_workload(tvm.relay.testing.dqn.get_workload)


def test_dcgan():
    """Print the DCGAN workload."""
    _check_workload(tvm.relay.testing.dcgan.get_workload)

if __name__ == "__main__":
    # When run as a script, turn on show() output so the printed IR is visible.
    do_print[0] = True
    for case in (test_resnet, test_mlp, test_dqn, test_dcgan,
                 test_func, test_env, test_meta_data, test_call_attrs,
                 test_let_if_scope, test_variable_name):
        case()