Commit 77445311 authored by 雾雨魔理沙, committed by Tianqi Chen

[Relay] fix 'please use input parameter mod' warning triggered in build_module (#3452)

parent dfc1fb25
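The change is mechanical: internal callers and tests that passed a bare relay.Function to relay.build (which now triggers the "please use input parameter mod" deprecation warning) are switched to wrapping the function in a relay.Module first. A minimal sketch of the two calling conventions, assuming the dev-era API of this commit (relay.Module.from_expr exists and relay.build still accepts both forms):

import tvm
from tvm import relay

# A trivial Relay function: f(x) = x + x
x = relay.var("x", shape=(10,), dtype="float32")
func = relay.Function([x], relay.add(x, x))

# Deprecated form: passing the bare Function emits the
# "please use input parameter mod" warning in this TVM version.
graph, lib, params = relay.build(func, "llvm")

# Form this commit migrates callers to: wrap the Function in a Module.
mod = relay.Module.from_expr(func)
graph, lib, params = relay.build(mod, "llvm")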
@@ -127,10 +127,10 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
     free_var = relay.Var("var_%d" % i, input_type)
     params.append(free_var)
     call = relay.Call(node.op, params, node.attrs)
-    func = relay.Function(params, call)
+    mod = relay.Module.from_expr(relay.Function(params, call))
     relay.backend.compile_engine.get().clear()
     build_thread = threading.Thread(target=relay.build,
-                                    args=(func,
+                                    args=(mod,
                                           "llvm -device=tracing",
                                           None,
                                           None))
......
@@ -105,8 +105,9 @@ def extract_from_program(func, params, ops, target, target_host=None):
     relay.backend.compile_engine.get().clear()
     # wrap build call in thread to avoid multiprocessing problems
+    mod = relay.Module.from_expr(func)
     build_thread = threading.Thread(target=_build,
-                                    args=(func,
+                                    args=(mod,
                                           target,
                                           target_host,
                                           params))
@@ -183,8 +184,9 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
     for func, param in zip(funcs, params):
         relay.backend.compile_engine.get().clear()
         # wrap build call in thread to avoid multiprocessing problems
+        mod = relay.Module.from_expr(func)
         build_thread = threading.Thread(target=my_build,
-                                        args=(func,
+                                        args=(mod,
                                               target,
                                               target_host,
                                               params))
......
@@ -163,6 +163,8 @@ class Executor(object):
         args: List[tvm.NDArray]
             The new arguments with all keyword arguments placed in the correct slot.
         """
+        assert expr is not None
+
         if not kwargs:
             return args
......
@@ -25,7 +25,6 @@ from tvm import expr as tvm_expr
 from .. import nd as _nd, target as _target, autotvm
 from ..contrib import graph_runtime as _graph_rt
 from . import _build_module
-from . import ir_pass
 from . import ty as _ty
 from . import expr as _expr
 from .module import Module as _Module
@@ -227,23 +226,23 @@ class GraphExecutor(_interpreter.Executor):
     """
     def __init__(self, mod, ctx, target):
+        assert mod is not None
         self.mod = mod
         self.ctx = ctx
         self.target = target

     def _make_executor(self, expr=None):
-        if not expr:
-            assert self.mod, "either expr or self.mod should be not null."
-            expr = self.mod[self.mod.entry_func]
-        ret_type = ir_pass.infer_type(expr).ret_type
+        if expr:
+            self.mod[self.mod.entry_func] = expr
+        ret_type = self.mod[self.mod.entry_func].checked_type.ret_type
         num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
-        graph_json, mod, params = build(expr, target=self.target)
+        graph_json, mod, params = build(self.mod, target=self.target)
         gmodule = _graph_rt.create(graph_json, mod, self.ctx)
         if params:
             gmodule.set_input(**params)

         def _graph_wrapper(*args, **kwargs):
-            args = self._convert_args(expr, args, kwargs)
+            args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs)
             # Create map of inputs.
             for i, arg in enumerate(args):
                 gmodule.set_input(i, arg)
@@ -280,6 +279,8 @@ def create_executor(kind="debug",
     target : :py:class:`tvm.Target`
         The corresponding context
     """
+    if mod is None:
+        mod = _Module()
     if ctx is not None:
         assert ctx.device_type == _nd.context(str(target), 0).device_type
     else:
......
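With the two hunks above, create_executor falls back to a fresh empty _Module() when no module is given, and GraphExecutor._make_executor installs the expression under mod.entry_func and builds the whole module rather than the bare expression. A minimal usage sketch under the same dev-era API (the kind="graph" executor and evaluate call exist in this version; the concrete function is illustrative):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(5,), dtype="float32")
func = relay.Function([x], relay.multiply(x, relay.const(2.0, "float32")))

# mod is omitted, so create_executor now supplies an empty relay.Module();
# evaluate(func) places func at mod.entry_func before building.
exe = relay.create_executor(kind="graph", ctx=tvm.cpu(0), target="llvm")
run = exe.evaluate(func)
out = run(np.ones(5, dtype="float32"))
print(out.asnumpy())  # expected: [2. 2. 2. 2. 2.]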
@@ -33,7 +33,7 @@ class Module(RelayNode):
     Parameters
     ----------
-    functions : dict, optional.
+    functions: Optional[dict].
        Map of global var to Function
    """
    def __init__(self, functions=None, type_definitions=None):
@@ -100,7 +100,7 @@ class Module(RelayNode):
        Parameters
        ----------
-       var: str or GlobalVar
+       var: Union[String, GlobalVar, GlobalTypeVar]
           The name or global variable.

       Returns
......
@@ -40,8 +40,8 @@ def test_task_extraction():
     net, params, input_shape = get_network('resnet-18', batch_size=1)
     tasks = autotvm.task.extract_from_program(net, target=target,
-                                              params=params,
-                                              ops=(relay.op.nn.conv2d,))
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d,))
     assert len(tasks) == 12

     net, params, input_shape = get_network('resnet-18', batch_size=1)
......
@@ -57,7 +57,7 @@ def test_compile_placeholder_bypass():
     result = relay.Tuple([x, relay.op.concatenate([y, z], axis=0)])
     func = relay.Function(relay.ir_pass.free_vars(result), result)
     with relay.build_config(opt_level=0):
-        graph, lib, params = relay.build(func, 'llvm')
+        graph, lib, params = relay.build(relay.Module.from_expr(func), 'llvm')

 def test_compile_injective_with_tuple():
@@ -66,7 +66,7 @@ def test_compile_injective_with_tuple():
     x_transpose = relay.transpose(x)
     output = relay.Tuple([x_transpose, y])
     func = relay.Function([x, y], output)
-    relay.build(func, 'llvm')
+    relay.build(relay.Module.from_expr(func), 'llvm')

 def test_compile_tuple_dup():
@@ -74,7 +74,7 @@ def test_compile_tuple_dup():
     log = relay.log(x)
     output = relay.Tuple([log, log])
     f = relay.Function([x], output)
-    relay.build(f, 'llvm')
+    relay.build(relay.Module.from_expr(f), 'llvm')

 if __name__ == "__main__":
......
@@ -101,7 +101,7 @@ def test_with_params():
     x_data = np.random.rand(10, 5).astype('float32')
     y_data = np.random.rand(1, 5).astype('float32')
     params = {"y": y_data}
-    graph, lib, params = relay.build(func, "llvm", params=params)
+    graph, lib, params = relay.build(relay.Module.from_expr(func), "llvm", params=params)
     mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
     mod.set_input(**params)
     mod.set_input(x=x_data)
@@ -170,7 +170,7 @@ def test_gru_like():
     for target, ctx in ctx_list():
         with relay.build_config(opt_level=2):
-            graph, lib, params = relay.build(z, target)
+            graph, lib, params = relay.build(relay.Module.from_expr(z), target)
         m = graph_runtime.create(graph, lib, ctx)
         m.set_input("X", tvm.nd.array(x.astype(dtype)))
         m.set_input("y", tvm.nd.array(y.astype(dtype)))
......
@@ -43,7 +43,7 @@ def test_basic_build():
     targets = {
         tvm.expr.IntImm("int32", ctx.device_type): tgt
     }
-    g_json, mmod, params = relay.build(func, targets, "llvm", params=params)
+    g_json, mmod, params = relay.build(relay.Module.from_expr(func), targets, "llvm", params=params)
     # test
     rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
@@ -115,7 +115,7 @@ def test_fp16_conversion():
         # build
         with relay.build_config(opt_level=1):
-            g_json, mmod, params = relay.build(func, tgt)
+            g_json, mmod, params = relay.build(relay.Module.from_expr(func), tgt)
         # test
         rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
......
@@ -342,6 +342,9 @@ def test_tuple_get_root():
     assert relay.ir_pass.alpha_equal(zz, after)

+fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
+fuse2 = relay.transform.FuseOps(fuse_opt_level=2)
+
 def test_tuple_intermediate():
     def before(x):
         inj = relay.squeeze(x)
@@ -363,16 +366,12 @@ def test_tuple_intermediate():
     dshape = (1, 16, 64, 64)
     x = relay.var("x", shape=dshape)
-    z = before(x)
-    z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
-    assert not relay.ir_pass.free_vars(zz)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
-    relay.build(zz, 'llvm')
-    zz = relay.ir_pass.infer_type(zz)
-    assert not relay.ir_pass.free_vars(zz)
+    orig = before(x)
+    fuse0(relay.Module.from_expr(orig))
+    m = fuse2(relay.Module.from_expr(orig))
+    relay.build(m, 'llvm')
     after = relay.ir_pass.infer_type(expected(x))
-    assert relay.ir_pass.alpha_equal(zz, after)
+    assert relay.ir_pass.alpha_equal(m[m.entry_func], after)

 def test_tuple_consecutive():
@@ -422,16 +421,12 @@ def test_tuple_consecutive():
     dshape = (1, 16, 64, 64)
     x = relay.var("x", shape=dshape)
-    z = before(x)
-    z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
-    assert not relay.ir_pass.free_vars(zz)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
-    relay.build(zz, 'llvm')
-    zz = relay.ir_pass.infer_type(zz)
-    assert not relay.ir_pass.free_vars(zz)
+    orig = before(x)
+    fuse0(relay.Module.from_expr(orig))
+    m = fuse2(relay.Module.from_expr(orig))
+    relay.build(m, 'llvm')
     after = relay.ir_pass.infer_type(expected(dshape))
-    assert relay.ir_pass.alpha_equal(zz, after)
+    assert relay.ir_pass.alpha_equal(m[m.entry_func], after)

 def test_inception_like():
@@ -493,16 +488,12 @@ def test_inception_like():
         return relay.Function(relay.ir_pass.free_vars(out), out)

     dshape = (1, 16, 64, 64)
-    z = before(dshape)
-    z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
-    assert not relay.ir_pass.free_vars(zz)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
-    relay.build(zz, 'llvm')
-    zz = relay.ir_pass.infer_type(zz)
-    assert not relay.ir_pass.free_vars(zz)
+    orig = before(dshape)
+    fuse0(relay.Module.from_expr(orig))
+    m = fuse2(relay.Module.from_expr(orig))
+    relay.build(m, 'llvm')
     after = relay.ir_pass.infer_type(expected(dshape))
-    assert relay.ir_pass.alpha_equal(zz, after)
+    assert relay.ir_pass.alpha_equal(m[m.entry_func], after)

 def test_fuse_parallel_injective():
......
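The fuse-ops tests above are rewritten from the expression-level relay.ir_pass.fuse_ops helper to the module-level relay.transform.FuseOps pass, again so that relay.build receives a Module. A small sketch of that pattern under the same assumptions (the operators here are illustrative; the test file uses its own before/expected graphs):

from tvm import relay

x = relay.var("x", shape=(1, 16, 64, 64))
y = relay.exp(relay.add(x, relay.const(1.0, "float32")))
func = relay.Function([x], y)

# FuseOps is a module-to-module pass; calling it returns a transformed Module.
fuse2 = relay.transform.FuseOps(fuse_opt_level=2)
m = fuse2(relay.Module.from_expr(func))

# The fused function is read back through the module entry point,
# mirroring the m[m.entry_func] checks in the updated tests.
fused = m[m.entry_func]
relay.build(m, 'llvm')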