Commit 77445311 by 雾雨魔理沙, committed by Tianqi Chen

[Relay] fix 'please use input parameter mod warning' triggered in build_module (#3452)

parent dfc1fb25
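The gist of the change: relay.build (and the executors layered on top of it) now expect a relay.Module rather than a bare relay.Function, so every caller that passed a Function directly is updated to wrap it with relay.Module.from_expr first, which silences the "please use input parameter mod" deprecation warning. A minimal sketch of the migration on a toy workload (the function below is illustrative, not part of this patch):

    import tvm
    from tvm import relay

    # A toy function: f(x) = x + x.
    x = relay.var("x", shape=(10,), dtype="float32")
    func = relay.Function([x], relay.add(x, x))

    # Old style: passing the Function directly still works, but warns.
    # graph, lib, params = relay.build(func, "llvm")

    # New style: wrap the Function in a Module first.
    mod = relay.Module.from_expr(func)
    graph, lib, params = relay.build(mod, "llvm")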
@@ -127,10 +127,10 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
             free_var = relay.Var("var_%d" % i, input_type)
             params.append(free_var)
         call = relay.Call(node.op, params, node.attrs)
-        func = relay.Function(params, call)
+        mod = relay.Module.from_expr(relay.Function(params, call))
         relay.backend.compile_engine.get().clear()
         build_thread = threading.Thread(target=relay.build,
-                                        args=(func,
+                                        args=(mod,
                                               "llvm -device=tracing",
                                               None,
                                               None))
......
@@ -105,8 +105,9 @@ def extract_from_program(func, params, ops, target, target_host=None):
         relay.backend.compile_engine.get().clear()
         # wrap build call in thread to avoid multiprocessing problems
+        mod = relay.Module.from_expr(func)
         build_thread = threading.Thread(target=_build,
-                                        args=(func,
+                                        args=(mod,
                                               target,
                                               target_host,
                                               params))
@@ -183,8 +184,9 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
     for func, param in zip(funcs, params):
         relay.backend.compile_engine.get().clear()
         # wrap build call in thread to avoid multiprocessing problems
+        mod = relay.Module.from_expr(func)
         build_thread = threading.Thread(target=my_build,
-                                        args=(func,
+                                        args=(mod,
                                               target,
                                               target_host,
                                               params))
......
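Both autotvm task-extraction helpers above follow the same recipe: the Function is wrapped into a Module before it is handed to the build thread. A minimal sketch of that pattern, assuming func, target, target_host, and params are already defined as in the surrounding code:

    import threading
    from tvm import relay

    # Wrap the Relay Function in a Module so relay.build no longer warns.
    mod = relay.Module.from_expr(func)

    # As in the code above, the build call runs in a separate thread
    # to avoid multiprocessing problems.
    build_thread = threading.Thread(target=relay.build,
                                    args=(mod, target, target_host, params))
    build_thread.start()
    build_thread.join()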
@@ -163,6 +163,8 @@ class Executor(object):
         args: List[tvm.NDArray]
             The new arguments with all keyword arguments placed in the correct slot.
         """
+        assert expr is not None
+
         if not kwargs:
             return args
......
@@ -25,7 +25,6 @@ from tvm import expr as tvm_expr
 from .. import nd as _nd, target as _target, autotvm
 from ..contrib import graph_runtime as _graph_rt
 from . import _build_module
-from . import ir_pass
 from . import ty as _ty
 from . import expr as _expr
 from .module import Module as _Module
@@ -227,23 +226,23 @@ class GraphExecutor(_interpreter.Executor):
     """
     def __init__(self, mod, ctx, target):
+        assert mod is not None
         self.mod = mod
         self.ctx = ctx
         self.target = target

     def _make_executor(self, expr=None):
-        if not expr:
-            assert self.mod, "either expr or self.mod should be not null."
-            expr = self.mod[self.mod.entry_func]
-        ret_type = ir_pass.infer_type(expr).ret_type
+        if expr:
+            self.mod[self.mod.entry_func] = expr
+        ret_type = self.mod[self.mod.entry_func].checked_type.ret_type
         num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
-        graph_json, mod, params = build(expr, target=self.target)
+        graph_json, mod, params = build(self.mod, target=self.target)
         gmodule = _graph_rt.create(graph_json, mod, self.ctx)
         if params:
             gmodule.set_input(**params)

         def _graph_wrapper(*args, **kwargs):
-            args = self._convert_args(expr, args, kwargs)
+            args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs)
             # Create map of inputs.
             for i, arg in enumerate(args):
                 gmodule.set_input(i, arg)
@@ -280,6 +279,8 @@ def create_executor(kind="debug",
     target : :py:class:`tvm.Target`
         The corresponding context
     """
+    if mod is None:
+        mod = _Module()
     if ctx is not None:
         assert ctx.device_type == _nd.context(str(target), 0).device_type
     else:
......
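With these changes, create_executor falls back to a fresh, empty Module when none is supplied, and GraphExecutor keeps everything in self.mod: an expr passed to _make_executor is installed at the module's entry_func and the whole module is built. A small usage sketch of the graph-executor path (illustrative only, not from the patch):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(10,), dtype="float32")
    func = relay.Function([x], relay.add(x, x))

    # No module is passed, so create_executor allocates an empty one internally.
    ex = relay.create_executor(kind="graph", ctx=tvm.cpu(0), target="llvm")
    f = ex.evaluate(func)
    out = f(tvm.nd.array(np.ones((10,), dtype="float32")))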
@@ -33,7 +33,7 @@ class Module(RelayNode):
     Parameters
     ----------
-    functions : dict, optional.
+    functions: Optional[dict].
         Map of global var to Function
     """
     def __init__(self, functions=None, type_definitions=None):
@@ -100,7 +100,7 @@ class Module(RelayNode):
         Parameters
         ----------
-        var: str or GlobalVar
+        var: Union[String, GlobalVar, GlobalTypeVar]
             The name or global variable.

         Returns
......
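The docstring fixes above also make explicit that a Module may be constructed from an optional dict mapping global variables to functions, and can be indexed by a string name, a GlobalVar, or a GlobalTypeVar. A small illustrative sketch (the name "double" is hypothetical):

    from tvm import relay

    x = relay.var("x", shape=(10,), dtype="float32")
    f = relay.Function([x], relay.add(x, x))

    gv = relay.GlobalVar("double")   # hypothetical global name
    mod = relay.Module({gv: f})      # functions: Optional[dict]

    # Lookup works with either the GlobalVar or its string name.
    func_by_var = mod[gv]
    func_by_name = mod["double"]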
@@ -40,8 +40,8 @@ def test_task_extraction():
     net, params, input_shape = get_network('resnet-18', batch_size=1)
     tasks = autotvm.task.extract_from_program(net, target=target,
                                               params=params,
                                               ops=(relay.op.nn.conv2d,))
     assert len(tasks) == 12

     net, params, input_shape = get_network('resnet-18', batch_size=1)
......
@@ -57,7 +57,7 @@ def test_compile_placeholder_bypass():
     result = relay.Tuple([x, relay.op.concatenate([y, z], axis=0)])
     func = relay.Function(relay.ir_pass.free_vars(result), result)
     with relay.build_config(opt_level=0):
-        graph, lib, params = relay.build(func, 'llvm')
+        graph, lib, params = relay.build(relay.Module.from_expr(func), 'llvm')

 def test_compile_injective_with_tuple():
@@ -66,7 +66,7 @@ def test_compile_injective_with_tuple():
     x_transpose = relay.transpose(x)
     output = relay.Tuple([x_transpose, y])
     func = relay.Function([x, y], output)
-    relay.build(func, 'llvm')
+    relay.build(relay.Module.from_expr(func), 'llvm')

 def test_compile_tuple_dup():
@@ -74,7 +74,7 @@ def test_compile_tuple_dup():
     log = relay.log(x)
     output = relay.Tuple([log, log])
     f = relay.Function([x], output)
-    relay.build(f, 'llvm')
+    relay.build(relay.Module.from_expr(f), 'llvm')

 if __name__ == "__main__":
......
@@ -101,7 +101,7 @@ def test_with_params():
     x_data = np.random.rand(10, 5).astype('float32')
     y_data = np.random.rand(1, 5).astype('float32')
     params = {"y": y_data}
-    graph, lib, params = relay.build(func, "llvm", params=params)
+    graph, lib, params = relay.build(relay.Module.from_expr(func), "llvm", params=params)
     mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
     mod.set_input(**params)
     mod.set_input(x=x_data)
@@ -170,7 +170,7 @@ def test_gru_like():
     for target, ctx in ctx_list():
         with relay.build_config(opt_level=2):
-            graph, lib, params = relay.build(z, target)
+            graph, lib, params = relay.build(relay.Module.from_expr(z), target)
         m = graph_runtime.create(graph, lib, ctx)
         m.set_input("X", tvm.nd.array(x.astype(dtype)))
         m.set_input("y", tvm.nd.array(y.astype(dtype)))
......
@@ -43,7 +43,7 @@ def test_basic_build():
     targets = {
         tvm.expr.IntImm("int32", ctx.device_type): tgt
     }
-    g_json, mmod, params = relay.build(func, targets, "llvm", params=params)
+    g_json, mmod, params = relay.build(relay.Module.from_expr(func), targets, "llvm", params=params)

     # test
     rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
@@ -115,7 +115,7 @@ def test_fp16_conversion():
         # build
         with relay.build_config(opt_level=1):
-            g_json, mmod, params = relay.build(func, tgt)
+            g_json, mmod, params = relay.build(relay.Module.from_expr(func), tgt)

         # test
         rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
......
@@ -342,6 +342,9 @@ def test_tuple_get_root():
     assert relay.ir_pass.alpha_equal(zz, after)

+fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
+fuse2 = relay.transform.FuseOps(fuse_opt_level=2)
+
 def test_tuple_intermediate():
     def before(x):
         inj = relay.squeeze(x)
@@ -363,16 +366,12 @@ def test_tuple_intermediate():
     dshape = (1, 16, 64, 64)
     x = relay.var("x", shape=dshape)
-    z = before(x)
-    z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
-    assert not relay.ir_pass.free_vars(zz)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
-    relay.build(zz, 'llvm')
-    zz = relay.ir_pass.infer_type(zz)
-    assert not relay.ir_pass.free_vars(zz)
+    orig = before(x)
+    fuse0(relay.Module.from_expr(orig))
+    m = fuse2(relay.Module.from_expr(orig))
+    relay.build(m, 'llvm')
     after = relay.ir_pass.infer_type(expected(x))
-    assert relay.ir_pass.alpha_equal(zz, after)
+    assert relay.ir_pass.alpha_equal(m[m.entry_func], after)

 def test_tuple_consecutive():
@@ -422,16 +421,12 @@ def test_tuple_consecutive():
     dshape = (1, 16, 64, 64)
     x = relay.var("x", shape=dshape)
-    z = before(x)
-    z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
-    assert not relay.ir_pass.free_vars(zz)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
-    relay.build(zz, 'llvm')
-    zz = relay.ir_pass.infer_type(zz)
-    assert not relay.ir_pass.free_vars(zz)
+    orig = before(x)
+    fuse0(relay.Module.from_expr(orig))
+    m = fuse2(relay.Module.from_expr(orig))
+    relay.build(m, 'llvm')
     after = relay.ir_pass.infer_type(expected(dshape))
-    assert relay.ir_pass.alpha_equal(zz, after)
+    assert relay.ir_pass.alpha_equal(m[m.entry_func], after)

 def test_inception_like():
@@ -493,16 +488,12 @@ def test_inception_like():
         return relay.Function(relay.ir_pass.free_vars(out), out)

     dshape = (1, 16, 64, 64)
-    z = before(dshape)
-    z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
-    assert not relay.ir_pass.free_vars(zz)
-    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
-    relay.build(zz, 'llvm')
-    zz = relay.ir_pass.infer_type(zz)
-    assert not relay.ir_pass.free_vars(zz)
+    orig = before(dshape)
+    fuse0(relay.Module.from_expr(orig))
+    m = fuse2(relay.Module.from_expr(orig))
+    relay.build(m, 'llvm')
     after = relay.ir_pass.infer_type(expected(dshape))
-    assert relay.ir_pass.alpha_equal(zz, after)
+    assert relay.ir_pass.alpha_equal(m[m.entry_func], after)

 def test_fuse_parallel_injective():
......
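The fuse-ops tests are rewritten from the expression-level relay.ir_pass.fuse_ops helper to the module-level relay.transform.FuseOps pass; the fused result is then read back from the module's entry_func. A minimal sketch of that pass-based flow on a small injective workload (the names here are illustrative):

    from tvm import relay

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    orig = relay.Function([x], relay.exp(relay.squeeze(x)))

    # Module-level fusion pass, as used in the tests above.
    fuse2 = relay.transform.FuseOps(fuse_opt_level=2)
    m = fuse2(relay.Module.from_expr(orig))

    # The fused function lives at the module's entry point.
    fused = m[m.entry_func]
    relay.build(m, 'llvm')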