import tvm
import tvm.relay.testing
import numpy as np
from tvm import relay

do_print = [False]


def show(text):
    if do_print[0]:
        print("---------------------------")
        print(text)


def test_func():
    x = relay.var("x", shape=(3, 2))
    y = relay.var("y")
    one = relay.const(10e10, dtype="float32")
    z = relay.add(x, one)
    z = relay.add(z, z)
    f = relay.Function([x, y], z)
    show(z.astext())
    show(f.astext())


def test_env():
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    z = relay.add(x, y)
    z = relay.add(z, z)
    f = relay.Function([x, y], z)
    env = relay.Environment()
    env["myf"] = f
    text = env.astext()
    assert "def @myf" in text
    assert "%1 = add(%0, %0) # ty=float32" in text
    show(text)


def test_meta_data():
    n, c, h, w = tvm.var("n"), 10, 224, 224
    x = relay.var("x", shape=(n, c, h, w))
    w = relay.var("w")
    z = relay.nn.conv2d(x, w,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        channels=2)
    f = relay.Function([x, w], z)
    text = f.astext()
    assert "channels=2" in text
    assert "meta.Variable(id=0)" in text
    show(text)

    text = relay.const([1, 2, 3]).astext()
    assert "meta.relay.Constant(id=0)" in text
    show(text)


def test_call_attrs():
    x = relay.var("x")
    # non default args
    z = relay.nn.softmax(x, axis=2)
    assert "axis=2" in z.astext()
    # default args
    z = relay.nn.softmax(x)
    assert "softmax(%x)" in z.astext()
    # non default args
    z = relay.expand_dims(x, axis=2, num_newaxis=2)
    assert "num_newaxis=2" in z.astext()


def test_let_if_scope():
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    cond = relay.var("cond", "bool")

    sb = relay.ScopeBuilder()
    with sb.if_scope(cond):
        v1 = sb.let("v", relay.const(1, "float32"))
        v2 = sb.let("v", x)
        sb.ret(relay.subtract(v1, v2))
    with sb.else_scope():
        v3 = relay.var("v")
        let2 = relay.Let(v3, y, v3)
        sb.ret(relay.add(let2, let2))
    result = sb.get()

    f = relay.Function([x, y, cond], result)
    text = f.astext()
    assert text.count("{") == 4
    assert "%cond: bool" in text
    show(f.astext())


def test_variable_name():
    # avoid pure number even if the namehint is pure number
    v1 = relay.var("1")
    assert "%v1" in v1.astext()


def test_mlp():
    net, params = tvm.relay.testing.mlp.get_workload(batch_size=1)
    net.astext()


def test_resnet():
    net, params = tvm.relay.testing.resnet.get_workload(batch_size=1)
    net.astext()


if __name__ == "__main__":
    do_print[0] = True
    test_resnet()
    test_mlp()
    test_func()
    test_env()
    test_meta_data()
    test_call_attrs()
    test_let_if_scope()
    test_variable_name()