Commit fa351045 by Zhi Committed by Tianqi Chen

[relay][frontend] Return module from frontend parsers (#3353)

parent 07fbe5c8
...@@ -21,7 +21,8 @@ from __future__ import absolute_import ...@@ -21,7 +21,8 @@ from __future__ import absolute_import
import numpy as np import numpy as np
from . import _backend from . import _backend
from .. import _make, ir_pass from .. import _make, ir_pass, transform
from .. import module
from ... import register_func, nd from ... import register_func, nd
from ..base import NodeBase, register_relay_node from ..base import NodeBase, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
...@@ -191,14 +192,14 @@ class Executor(object): ...@@ -191,14 +192,14 @@ class Executor(object):
return tuple(cargs) return tuple(cargs)
def _make_executor(self, _): def _make_executor(self, expr=None):
""" """
Construct a Python function that implements the evaluation Construct a Python function that implements the evaluation
of expression. of expression.
Parameters Parameters
---------- ----------
expr: relay.Expr expr: Optional[relay.Expr]
The Relay expression to execute. The Relay expression to execute.
Returns Returns
...@@ -208,16 +209,16 @@ class Executor(object): ...@@ -208,16 +209,16 @@ class Executor(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def evaluate(self, expr, binds=None): def evaluate(self, expr=None, binds=None):
""" """
Evaluate a Relay expression on the executor. Evaluate a Relay expression on the executor.
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr: Optional[tvm.relay.Expr]
The expression to evaluate. The expression to evaluate.
binds: Map[tvm.relay.Var, tvm.relay.Expr] binds: Optional[Map[tvm.relay.Var, tvm.relay.Expr]]
Additional binding of free variable. Additional binding of free variable.
Returns Returns
...@@ -232,6 +233,9 @@ class Executor(object): ...@@ -232,6 +233,9 @@ class Executor(object):
scope_builder.ret(expr) scope_builder.ret(expr)
expr = scope_builder.get() expr = scope_builder.get()
if not expr:
return self._make_executor()
if isinstance(expr, Function): if isinstance(expr, Function):
assert not ir_pass.free_vars(expr) assert not ir_pass.free_vars(expr)
...@@ -264,46 +268,47 @@ class Interpreter(Executor): ...@@ -264,46 +268,47 @@ class Interpreter(Executor):
self.target = target self.target = target
self._intrp = _backend.CreateInterpreter(mod, ctx, target) self._intrp = _backend.CreateInterpreter(mod, ctx, target)
def optimize(self, expr): def optimize(self):
"""Optimize an expr. """Optimize functions in a module.
Parameters
----------
expr : Expr
The expression to be optimized.
Returns Returns
------- -------
opt_expr : Expr opt_mod : tvm.relay.Module
The optimized expression. The optimized module.
""" """
# TODO: We need to move this optimization code into the optimizer/pass manager seq = transform.Sequential([transform.SimplifyInference(),
wrapped_expr = expr if isinstance(expr, Function) else Function([], expr) transform.FuseOps(0),
if self.mod: transform.InferType()])
self.mod[self.mod.entry_func] = wrapped_expr return seq(self.mod)
ck_expr = ir_pass.infer_type(wrapped_expr, mod=self.mod)
simp_expr = ir_pass.simplify_inference(ck_expr) def _make_executor(self, expr=None):
ck_simp = ir_pass.infer_type(simp_expr, mod=self.mod) if expr is None or isinstance(expr, GlobalVar):
fused_expr = ir_pass.fuse_ops(ck_simp, 0, mod=self.mod) assert self.mod is not None
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused if isinstance(expr, Function) else Call(ck_fused, [])
def _make_executor(self, expr):
def _interp_wrapper(*args, **kwargs): def _interp_wrapper(*args, **kwargs):
args = self._convert_args(expr, args, kwargs) if expr is None:
args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs)
else:
args = self._convert_args(expr, args, kwargs)
relay_args = [] relay_args = []
for arg in args: for arg in args:
relay_args.append(_arg_to_ast(arg)) relay_args.append(_arg_to_ast(arg))
if isinstance(expr, GlobalVar): # Set the entry function for the module.
func = self.mod[expr] if expr is None:
func = self.optimize(func) pass
self.mod._add(expr, func, True) elif isinstance(expr, GlobalVar):
opt_expr = Call(expr, relay_args) self.mod[self.mod.entry_func] = self.mod[expr]
return self._intrp(opt_expr)
else: else:
call = Call(expr, relay_args) assert isinstance(expr, Function)
opt_expr = self.optimize(call) func = Function([], Call(expr, relay_args))
return self._intrp(opt_expr) relay_args = []
if self.mod:
self.mod[self.mod.entry_func] = func
else:
self.mod = module.Module.from_expr(func)
mod = self.optimize()
opt_expr = Call(mod[self.mod.entry_func.name_hint], relay_args)
return self._intrp(opt_expr)
return _interp_wrapper return _interp_wrapper
...@@ -130,9 +130,11 @@ class VMExecutor(Executor): ...@@ -130,9 +130,11 @@ class VMExecutor(Executor):
self.ctx = ctx self.ctx = ctx
self.target = target self.target = target
def _make_executor(self, expr): def _make_executor(self, expr=None):
assert isinstance(expr, Expr) expr = expr if expr else self.mod
self.mod[self.mod.entry_func] = expr assert expr, "either expr or self.mod should be not null."
if isinstance(expr, Expr):
self.mod[self.mod.entry_func] = expr
main = self.mod[self.mod.entry_func] main = self.mod[self.mod.entry_func]
def _vm_wrapper(*args, **kwargs): def _vm_wrapper(*args, **kwargs):
......
...@@ -219,16 +219,19 @@ class GraphExecutor(_interpreter.Executor): ...@@ -219,16 +219,19 @@ class GraphExecutor(_interpreter.Executor):
self.ctx = ctx self.ctx = ctx
self.target = target self.target = target
def _make_executor(self, func): def _make_executor(self, expr=None):
ret_type = ir_pass.infer_type(func).ret_type if not expr:
assert self.mod, "either expr or self.mod should be not null."
expr = self.mod[self.mod.entry_func]
ret_type = ir_pass.infer_type(expr).ret_type
num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1 num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
graph_json, mod, params = build(func, target=self.target) graph_json, mod, params = build(expr, target=self.target)
gmodule = _graph_rt.create(graph_json, mod, self.ctx) gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params: if params:
gmodule.set_input(**params) gmodule.set_input(**params)
def _graph_wrapper(*args, **kwargs): def _graph_wrapper(*args, **kwargs):
args = self._convert_args(func, args, kwargs) args = self._convert_args(expr, args, kwargs)
# Create map of inputs. # Create map of inputs.
for i, arg in enumerate(args): for i, arg in enumerate(args):
gmodule.set_input(i, arg) gmodule.set_input(i, arg)
......
...@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs ...@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs
import tvm import tvm
from .. import ir_pass from .. import ir_pass
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from .common import AttrCvt, Renamer from .common import AttrCvt, Renamer
...@@ -382,6 +383,7 @@ class Caffe2NetDef(object): ...@@ -382,6 +383,7 @@ class Caffe2NetDef(object):
self._ops = {} self._ops = {}
self._shape = shape self._shape = shape
self._dtype = dtype self._dtype = dtype
self._mod = _module.Module({})
def from_caffe2(self, init_net, predict_net): def from_caffe2(self, init_net, predict_net):
"""Construct Relay expression from caffe2 graph. """Construct Relay expression from caffe2 graph.
...@@ -393,8 +395,9 @@ class Caffe2NetDef(object): ...@@ -393,8 +395,9 @@ class Caffe2NetDef(object):
Returns Returns
------- -------
func : tvm.relay.expr.Function mod : tvm.relay.Module
Compatible relay function The module that optimizations will be performed on.
params : dict params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights A dict of name: tvm.nd.array pairs, used as pretrained weights
""" """
...@@ -448,8 +451,9 @@ class Caffe2NetDef(object): ...@@ -448,8 +451,9 @@ class Caffe2NetDef(object):
outputs = out[0] outputs = out[0]
func = _expr.Function(ir_pass.free_vars(outputs), outputs) func = _expr.Function(ir_pass.free_vars(outputs), outputs)
self._mod[self._mod.entry_func] = func
return func, self._params return self._mod, self._params
def _get_node(self, blob): def _get_node(self, blob):
"""Get the Symbol of blob and detect cyclic dependency in the graph.""" """Get the Symbol of blob and detect cyclic dependency in the graph."""
...@@ -560,8 +564,8 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"): ...@@ -560,8 +564,8 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"):
Returns Returns
------- -------
sym : tvm.relay.expr.Function mod : tvm.relay.Module
Compatible relay function The module that optimizations will be performed on.
params : dict of str to tvm.ndarray params : dict of str to tvm.ndarray
Dict of converted parameters stored in tvm.ndarray format Dict of converted parameters stored in tvm.ndarray format
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import tvm import tvm
from .. import ir_pass from .. import ir_pass
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from ..._ffi import base as _base from ..._ffi import base as _base
...@@ -416,8 +417,8 @@ def from_coreml(model, shape=None): ...@@ -416,8 +417,8 @@ def from_coreml(model, shape=None):
Returns Returns
------- -------
func : tvm.relay.Function mod : tvm.relay.Module
Compatible relay Function. The relay module for compilation.
params : dict of str to tvm.NDArray params : dict of str to tvm.NDArray
The parameter dict to be used by Relay. The parameter dict to be used by Relay.
...@@ -463,4 +464,4 @@ def from_coreml(model, shape=None): ...@@ -463,4 +464,4 @@ def from_coreml(model, shape=None):
outexpr = outexpr[0] outexpr = outexpr[0]
func = _expr.Function(ir_pass.free_vars(outexpr), outexpr) func = _expr.Function(ir_pass.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return func, params return _module.Module.from_expr(func), params
...@@ -25,6 +25,7 @@ import numpy as np ...@@ -25,6 +25,7 @@ import numpy as np
import tvm import tvm
from .. import ir_pass from .. import ir_pass
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .common import get_relay_op, new_var from .common import get_relay_op, new_var
__all__ = ['from_darknet'] __all__ = ['from_darknet']
...@@ -820,7 +821,7 @@ class GraphProto(object): ...@@ -820,7 +821,7 @@ class GraphProto(object):
outputs = _as_list(sym) + self._outs outputs = _as_list(sym) + self._outs
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
sym = _expr.Function(ir_pass.free_vars(outputs), outputs) sym = _expr.Function(ir_pass.free_vars(outputs), outputs)
return sym, self._tvmparams return _module.Module.from_expr(sym), self._tvmparams
def from_darknet(net, def from_darknet(net,
shape=None, shape=None,
...@@ -838,8 +839,9 @@ def from_darknet(net, ...@@ -838,8 +839,9 @@ def from_darknet(net,
Returns Returns
------- -------
sym : tvm.relay.Function mod : tvm.relay.Module
Compatible relay Function The relay module for compilation.
params : dict of str to tvm.NDArray params : dict of str to tvm.NDArray
The parameter dict to be used by relay The parameter dict to be used by relay
""" """
......
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
import tvm import tvm
from .. import ir_pass from .. import ir_pass
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from .common import ExprTable, new_var from .common import ExprTable, new_var
...@@ -679,8 +680,8 @@ def from_keras(model, shape=None): ...@@ -679,8 +680,8 @@ def from_keras(model, shape=None):
Returns Returns
------- -------
func : tvm.relay.Function mod : tvm.relay.Module
Compatible relay Function. The relay module for compilation.
params : dict of str to tvm.NDArray params : dict of str to tvm.NDArray
The parameter dict to be used by Relay. The parameter dict to be used by Relay.
...@@ -744,4 +745,4 @@ def from_keras(model, shape=None): ...@@ -744,4 +745,4 @@ def from_keras(model, shape=None):
outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
func = _expr.Function(ir_pass.free_vars(outexpr), outexpr) func = _expr.Function(ir_pass.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return func, params return _module.Module.from_expr(func), params
...@@ -23,6 +23,7 @@ import tvm ...@@ -23,6 +23,7 @@ import tvm
from .. import ir_pass from .. import ir_pass
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from .. import module as _module
from ... import nd as _nd from ... import nd as _nd
from .common import StrAttrsDict from .common import StrAttrsDict
...@@ -992,7 +993,8 @@ _convert_map = { ...@@ -992,7 +993,8 @@ _convert_map = {
_convert_map.update({k : _rename(k) for k in _identity_list}) _convert_map.update({k : _rename(k) for k in _identity_list})
def _from_mxnet_impl(symbol, shape_dict, dtype_info): def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None):
#pylint: disable=unused-argument
"""Convert mxnet symbol to compatible relay Function. """Convert mxnet symbol to compatible relay Function.
Reconstruct a relay Function by traversing the mxnet symbol. Reconstruct a relay Function by traversing the mxnet symbol.
...@@ -1009,6 +1011,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): ...@@ -1009,6 +1011,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
dtype_info : dict or str. dtype_info : dict or str.
Known parameter dtypes Known parameter dtypes
mod : tvm.relay.Module
The module that contains global information. It will be used for
converting ops that need global information, e.g. control-flow ops.
Returns: Returns:
------- -------
func : tvm.relay.Function func : tvm.relay.Function
...@@ -1097,8 +1103,8 @@ def from_mxnet(symbol, ...@@ -1097,8 +1103,8 @@ def from_mxnet(symbol,
Returns Returns
------- -------
sym : tvm.relay.Function mod : tvm.relay.Module
Compatible relay Function The relay module for compilation
params : dict of str to tvm.NDArray params : dict of str to tvm.NDArray
The parameter dict to be used by nnvm The parameter dict to be used by nnvm
...@@ -1108,6 +1114,7 @@ def from_mxnet(symbol, ...@@ -1108,6 +1114,7 @@ def from_mxnet(symbol,
except ImportError as e: except ImportError as e:
raise ImportError("{}. MXNet is required to parse symbols.".format(e)) raise ImportError("{}. MXNet is required to parse symbols.".format(e))
mod = _module.Module()
if isinstance(symbol, mx.sym.Symbol): if isinstance(symbol, mx.sym.Symbol):
params = {} params = {}
arg_params = arg_params if arg_params else {} arg_params = arg_params if arg_params else {}
...@@ -1117,7 +1124,7 @@ def from_mxnet(symbol, ...@@ -1117,7 +1124,7 @@ def from_mxnet(symbol,
for k, v in aux_params.items(): for k, v in aux_params.items():
params[k] = _nd.array(v.asnumpy()) params[k] = _nd.array(v.asnumpy())
shape, dtype = _update_shape_dtype(shape, dtype, params) shape, dtype = _update_shape_dtype(shape, dtype, params)
sym = _from_mxnet_impl(symbol, shape, dtype) func = _from_mxnet_impl(symbol, shape, dtype, mod)
elif isinstance(symbol, mx.gluon.HybridBlock): elif isinstance(symbol, mx.gluon.HybridBlock):
if arg_params is not None or aux_params is not None: if arg_params is not None or aux_params is not None:
raise ValueError("arg_params and aux_params ae not used when importing HybridBlock") raise ValueError("arg_params and aux_params ae not used when importing HybridBlock")
...@@ -1129,10 +1136,11 @@ def from_mxnet(symbol, ...@@ -1129,10 +1136,11 @@ def from_mxnet(symbol,
if isinstance(sym, (list, tuple)): if isinstance(sym, (list, tuple)):
sym = mx.sym.Group(sym) sym = mx.sym.Group(sym)
shape, dtype = _update_shape_dtype(shape, dtype, params) shape, dtype = _update_shape_dtype(shape, dtype, params)
sym = _from_mxnet_impl(sym, shape, dtype) func = _from_mxnet_impl(sym, shape, dtype, mod)
elif isinstance(symbol, mx.gluon.Block): elif isinstance(symbol, mx.gluon.Block):
raise NotImplementedError("Only Hybrid Blocks are supported now.") raise NotImplementedError("Only Hybrid Blocks are supported now.")
else: else:
msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol)) msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol))
raise ValueError(msg) raise ValueError(msg)
return sym, params mod[mod.entry_func] = func
return mod, params
...@@ -24,6 +24,7 @@ import tvm ...@@ -24,6 +24,7 @@ import tvm
from ... import nd as _nd from ... import nd as _nd
from .. import ir_pass from .. import ir_pass
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import op as _op from .. import op as _op
from .common import AttrCvt, Renamer from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels, get_name from .common import get_relay_op, new_var, infer_shape, infer_channels, get_name
...@@ -999,8 +1000,9 @@ class GraphProto(object): ...@@ -999,8 +1000,9 @@ class GraphProto(object):
Returns Returns
------- -------
sym : tvm.relay.expr.Function mod : tvm.relay.Module
The returned relay function The returned relay module
params : dict params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights A dict of name: tvm.nd.array pairs, used as pretrained weights
""" """
...@@ -1090,7 +1092,7 @@ class GraphProto(object): ...@@ -1090,7 +1092,7 @@ class GraphProto(object):
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(ir_pass.free_vars(outputs), outputs) func = _expr.Function(ir_pass.free_vars(outputs), outputs)
return func, self._params return _module.Module.from_expr(func), self._params
def _parse_value_proto(self, value_proto): def _parse_value_proto(self, value_proto):
"""Parse ValueProto or raw str.""" """Parse ValueProto or raw str."""
...@@ -1219,8 +1221,8 @@ def from_onnx(model, ...@@ -1219,8 +1221,8 @@ def from_onnx(model,
Returns Returns
------- -------
sym : tvm.relay.expr.Function mod : tvm.relay.Module
Compatible relay function The relay module for compilation
params : dict of str to tvm.NDArray params : dict of str to tvm.NDArray
The parameter dict to be used by relay The parameter dict to be used by relay
...@@ -1243,5 +1245,5 @@ def from_onnx(model, ...@@ -1243,5 +1245,5 @@ def from_onnx(model,
opset = model.opset_import[0].version if model.opset_import else 1 opset = model.opset_import[0].version if model.opset_import else 1
except AttributeError: except AttributeError:
opset = 1 opset = 1
sym, params = g.from_onnx(graph, opset) mod, params = g.from_onnx(graph, opset)
return sym, params return mod, params
...@@ -31,6 +31,7 @@ from .. import ir_pass ...@@ -31,6 +31,7 @@ from .. import ir_pass
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from ..expr_functor import ExprMutator from ..expr_functor import ExprMutator
from .. import module as _module
__all__ = ['from_tensorflow'] __all__ = ['from_tensorflow']
...@@ -1823,6 +1824,7 @@ class GraphProto(object): ...@@ -1823,6 +1824,7 @@ class GraphProto(object):
self._input_shapes = {} self._input_shapes = {}
self._loops = {} self._loops = {}
self._branches = {} self._branches = {}
self._mod = _module.Module({})
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef. """Construct relay nodes from tensorflow graph definition - GraphDef.
...@@ -1856,8 +1858,9 @@ class GraphProto(object): ...@@ -1856,8 +1858,9 @@ class GraphProto(object):
Returns Returns
------- -------
sym : relay.op mod : tvm.relay.Module
The returned relay operator The module that optimizations will be performed on.
params : dict params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights A dict of name: tvm.nd.array pairs, used as pretrained weights
""" """
...@@ -2046,8 +2049,8 @@ class GraphProto(object): ...@@ -2046,8 +2049,8 @@ class GraphProto(object):
out = out[0] if len(out) == 1 else _expr.Tuple(out) out = out[0] if len(out) == 1 else _expr.Tuple(out)
func = _expr.Function(ir_pass.free_vars(out), out) func = _expr.Function(ir_pass.free_vars(out), out)
self._mod[self._mod.entry_func] = func
return func, self._params return self._mod, self._params
def _parse_import_prerequisites(self, graph): def _parse_import_prerequisites(self, graph):
""" Calculate the named preconditions from TensorFlow `graph`. """ Calculate the named preconditions from TensorFlow `graph`.
...@@ -2336,12 +2339,12 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): ...@@ -2336,12 +2339,12 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
Returns Returns
------- -------
sym : relay.op mod : tvm.relay.Module
Compatible relay operator The module that optimizations will be performed on.
params : dict of str to tvm.ndarray params : dict of str to tvm.ndarray
Dict of converted parameters stored in tvm.ndarray format Dict of converted parameters stored in tvm.ndarray format
""" """
g = GraphProto() g = GraphProto()
sym, params = g.from_tensorflow(graph, layout, shape, outputs) mod, params = g.from_tensorflow(graph, layout, shape, outputs)
return sym, params return mod, params
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
import tvm import tvm
from .. import ir_pass from .. import ir_pass
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from .common import ExprTable from .common import ExprTable
...@@ -749,8 +750,8 @@ def from_tflite(model, shape_dict, dtype_dict): ...@@ -749,8 +750,8 @@ def from_tflite(model, shape_dict, dtype_dict):
Returns Returns
------- -------
func : tvm.relay.Function mod : tvm.relay.Module
Compatible relay Function The relay module for compilation.
params : dict of str to tvm.NDArray params : dict of str to tvm.NDArray
The parameter dict to be used by relay The parameter dict to be used by relay
...@@ -788,4 +789,4 @@ def from_tflite(model, shape_dict, dtype_dict): ...@@ -788,4 +789,4 @@ def from_tflite(model, shape_dict, dtype_dict):
outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(ir_pass.free_vars(outputs), outputs) func = _expr.Function(ir_pass.free_vars(outputs), outputs)
return func, params return _module.Module.from_expr(func), params
...@@ -40,9 +40,10 @@ def get_tvm_output(model, ...@@ -40,9 +40,10 @@ def get_tvm_output(model,
input_names = model.predict_net.op[0].input[0] input_names = model.predict_net.op[0].input[0]
shape_dict = {input_names: input_data.shape} shape_dict = {input_names: input_data.shape}
dtype_dict = {input_names: input_data.dtype} dtype_dict = {input_names: input_data.dtype}
func, params = relay.frontend.from_caffe2(model.init_net, model.predict_net, shape_dict, dtype_dict) mod, params = relay.frontend.from_caffe2(
model.init_net, model.predict_net, shape_dict, dtype_dict)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params) graph, lib, params = relay.build(mod[mod.entry_func], target, params=params)
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
......
...@@ -28,9 +28,10 @@ def compare_graph(f1, f2): ...@@ -28,9 +28,10 @@ def compare_graph(f1, f2):
def test_squeeze_net(): def test_squeeze_net():
shape_dict = {'data': (1, 3, 224, 224)} shape_dict = {'data': (1, 3, 224, 224)}
dtype_dict = {'data': 'float32'} dtype_dict = {'data': 'float32'}
from_c2_func, _ = relay.frontend.from_caffe2(c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict) mod, _, = relay.frontend.from_caffe2(
c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict)
relay_func, _ = relay_squeezenet() relay_func, _ = relay_squeezenet()
compare_graph(from_c2_func, relay_func) compare_graph(mod[mod.entry_func], relay_func)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -46,9 +46,9 @@ def run_model_checkonly(model_file, model_name='', input_name='image'): ...@@ -46,9 +46,9 @@ def run_model_checkonly(model_file, model_name='', input_name='image'):
model = cm.models.MLModel(model_file) model = cm.models.MLModel(model_file)
x = model_zoo.get_cat_image() x = model_zoo.get_cat_image()
shape_dict = {input_name : x.shape} shape_dict = {input_name : x.shape}
func, params = relay.frontend.from_coreml(model, shape_dict) mod, params = relay.frontend.from_coreml(model, shape_dict)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
tvm_output = get_tvm_output(func, x, params, target, ctx) tvm_output = get_tvm_output(mod[mod.entry_func], x, params, target, ctx)
print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat)) print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat))
def test_mobilenet_checkonly(): def test_mobilenet_checkonly():
...@@ -71,9 +71,9 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap ...@@ -71,9 +71,9 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap
shape_dict = {input_name: input_data.shape} shape_dict = {input_name: input_data.shape}
dtype_dict = {input_name: input_data.dtype} dtype_dict = {input_name: input_data.dtype}
func, params = relay.frontend.from_coreml(coreml_model, shape_dict) mod, params = relay.frontend.from_coreml(coreml_model, shape_dict)
with relay.transform.build_config(opt_level=3): with relay.transform.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params) graph, lib, params = relay.build(mod[mod.entry_func], target, params=params)
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
......
...@@ -52,10 +52,12 @@ def _read_memory_buffer(shape, data, dtype='float32'): ...@@ -52,10 +52,12 @@ def _read_memory_buffer(shape, data, dtype='float32'):
def _get_tvm_output(net, data, build_dtype='float32', states=None): def _get_tvm_output(net, data, build_dtype='float32', states=None):
'''Compute TVM output''' '''Compute TVM output'''
dtype = 'float32' dtype = 'float32'
sym, params = relay.frontend.from_darknet(net, data.shape, dtype) mod, params = relay.frontend.from_darknet(net, data.shape, dtype)
target = 'llvm' target = 'llvm'
shape_dict = {'data': data.shape} shape_dict = {'data': data.shape}
graph, library, params = relay.build(sym, target, params=params) graph, library, params = relay.build(mod[mod.entry_func],
target,
params=params)
# Execute on TVM # Execute on TVM
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
......
...@@ -42,9 +42,11 @@ def verify_keras_frontend(keras_model, need_transpose=True): ...@@ -42,9 +42,11 @@ def verify_keras_frontend(keras_model, need_transpose=True):
def get_tvm_output(xs, target, ctx, dtype='float32'): def get_tvm_output(xs, target, ctx, dtype='float32'):
shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)} shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
func, params = relay.frontend.from_keras(keras_model, shape_dict) mod, params = relay.frontend.from_keras(keras_model, shape_dict)
with relay.transform.build_config(opt_level=2): with relay.transform.build_config(opt_level=2):
graph, lib, params = relay.build(func, target, params=params) graph, lib, params = relay.build(mod[mod.entry_func],
target,
params=params)
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
for name, x in zip(keras_model.input_names, xs): for name, x in zip(keras_model.input_names, xs):
m.set_input(name, tvm.nd.array(x.astype(dtype))) m.set_input(name, tvm.nd.array(x.astype(dtype)))
......
...@@ -26,60 +26,60 @@ def compare_graph(f1, f2): ...@@ -26,60 +26,60 @@ def compare_graph(f1, f2):
def test_mlp(): def test_mlp():
shape = {"data": (1, 1, 28, 28)} shape = {"data": (1, 1, 28, 28)}
mx_fun = model_zoo.mx_mlp() mx_fun = model_zoo.mx_mlp()
from_mx_fun, _ = relay.frontend.from_mxnet(mx_fun, shape=shape) mod, _ = relay.frontend.from_mxnet(mx_fun, shape=shape)
relay_fun = model_zoo.relay_mlp() relay_fun = model_zoo.relay_mlp()
compare_graph(from_mx_fun, relay_fun) compare_graph(mod[mod.entry_func], relay_fun)
def test_vgg(): def test_vgg():
shape = {"data": (1, 3, 224, 224)} shape = {"data": (1, 3, 224, 224)}
for n in [11, 13, 16, 19]: for n in [11, 13, 16, 19]:
mx_sym = model_zoo.mx_vgg(n) mx_sym = model_zoo.mx_vgg(n)
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
relay_sym = model_zoo.relay_vgg(n) relay_sym = model_zoo.relay_vgg(n)
compare_graph(from_mx_sym, relay_sym) compare_graph(mod[mod.entry_func], relay_sym)
def test_resnet(): def test_resnet():
shape = {"data": (1, 3, 224, 224)} shape = {"data": (1, 3, 224, 224)}
for n in [18, 34, 50, 101]: for n in [18, 34, 50, 101]:
mx_sym = model_zoo.mx_resnet(n) mx_sym = model_zoo.mx_resnet(n)
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
relay_sym = model_zoo.relay_resnet(n) relay_sym = model_zoo.relay_resnet(n)
compare_graph(from_mx_sym, relay_sym) compare_graph(mod[mod.entry_func], relay_sym)
def test_squeezenet(): def test_squeezenet():
shape = {"data": (1, 3, 224, 224)} shape = {"data": (1, 3, 224, 224)}
for version in ['1.0', '1.1']: for version in ['1.0', '1.1']:
mx_sym = model_zoo.mx_squeezenet(version) mx_sym = model_zoo.mx_squeezenet(version)
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_squeezenet(version) relay_sym = model_zoo.relay_squeezenet(version)
compare_graph(from_mx_sym, relay_sym) compare_graph(mod[mod.entry_func], relay_sym)
def test_inception_v3(): def test_inception_v3():
shape = {"data": (1, 3, 299, 299)} shape = {"data": (1, 3, 299, 299)}
mx_sym = model_zoo.mx_inception_v3() mx_sym = model_zoo.mx_inception_v3()
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_inception_v3() relay_sym = model_zoo.relay_inception_v3()
compare_graph(from_mx_sym, relay_sym) compare_graph(mod[mod.entry_func], relay_sym)
def test_dqn(): def test_dqn():
shape = {"data": (1, 4, 84, 84)} shape = {"data": (1, 4, 84, 84)}
mx_sym = model_zoo.mx_dqn() mx_sym = model_zoo.mx_dqn()
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_dqn() relay_sym = model_zoo.relay_dqn()
compare_graph(from_mx_sym, relay_sym) compare_graph(mod[mod.entry_func], relay_sym)
def test_dcgan(): def test_dcgan():
shape = {"data": (2, 100)} shape = {"data": (2, 100)}
mx_sym = model_zoo.mx_dcgan() mx_sym = model_zoo.mx_dcgan()
from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_dcgan(batch_size=2) relay_sym = model_zoo.relay_dcgan(batch_size=2)
compare_graph(from_mx_sym, relay_sym) compare_graph(mod[mod.entry_func], relay_sym)
def test_multi_outputs(): def test_multi_outputs():
...@@ -100,10 +100,10 @@ def test_multi_outputs(): ...@@ -100,10 +100,10 @@ def test_multi_outputs():
return relay.Function(relay.ir_pass.free_vars(z), z) return relay.Function(relay.ir_pass.free_vars(z), z)
mx_sym = mx_compose(mx, num_outputs=3, axis=1) mx_sym = mx_compose(mx, num_outputs=3, axis=1)
from_mx_sym, _ = relay.frontend.from_mxnet( mod, _ = relay.frontend.from_mxnet(
mx_sym, shape={"x":xshape, "y":yshape}) mx_sym, shape={"x":xshape, "y":yshape})
relay_sym = relay_compose(relay, indices_or_sections=3, axis=1) relay_sym = relay_compose(relay, indices_or_sections=3, axis=1)
compare_graph(from_mx_sym, relay_sym) compare_graph(mod[mod.entry_func], relay_sym)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -42,9 +42,11 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output ...@@ -42,9 +42,11 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
shape_dict = {input_names: input_data.shape} shape_dict = {input_names: input_data.shape}
dtype_dict = {input_names: input_data.dtype} dtype_dict = {input_names: input_data.dtype}
sym, params = relay.frontend.from_onnx(graph_def, shape_dict) mod, params = relay.frontend.from_onnx(graph_def, shape_dict)
with relay.build_config(opt_level=1): with relay.build_config(opt_level=1):
graph, lib, params = relay.build(sym, target, params=params) graph, lib, params = relay.build(mod[mod.entry_func],
target,
params=params)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
......
...@@ -22,9 +22,9 @@ from tvm.relay.frontend.tensorflow import from_tensorflow ...@@ -22,9 +22,9 @@ from tvm.relay.frontend.tensorflow import from_tensorflow
def check_equal(graph, tf_out): def check_equal(graph, tf_out):
expr, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
ex = relay.create_executor('debug') ex = relay.create_executor('debug', mod=mod)
relay_out = ex.evaluate(expr)(**params) relay_out = ex.evaluate()(**params)
if isinstance(relay_out, relay.backend.interpreter.TensorValue): if isinstance(relay_out, relay.backend.interpreter.TensorValue):
np.testing.assert_allclose(tf_out, relay_out.asnumpy()) np.testing.assert_allclose(tf_out, relay_out.asnumpy())
else: else:
......
...@@ -60,13 +60,12 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, ...@@ -60,13 +60,12 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
shape_dict = {e: i.shape for e, i in zip(input_node, input_data)} shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
sym, params = relay.frontend.from_tensorflow(graph_def, mod, params = relay.frontend.from_tensorflow(graph_def,
layout=layout, layout=layout,
shape=shape_dict, shape=shape_dict,
outputs=out_names) outputs=out_names)
with relay.build_config(opt_level=opt_level): with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build(sym, target, target_host, params) graph, lib, params = relay.build(mod[mod.entry_func], target, target_host, params)
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
...@@ -1442,14 +1441,16 @@ def test_forward_ptb(): ...@@ -1442,14 +1441,16 @@ def test_forward_ptb():
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':(num_layers, batch_size, num_hidden), 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':(num_layers, batch_size, num_hidden),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':(num_layers, batch_size, num_hidden)} 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':(num_layers, batch_size, num_hidden)}
sym, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict) mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
dtype_dict = {'Model/Placeholder': 'int32', dtype_dict = {'Model/Placeholder': 'int32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32', 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'} 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'}
target = 'llvm' target = 'llvm'
with relay.build_config(opt_level=0): with relay.build_config(opt_level=0):
graph, lib, params = relay.build(sym, target, params=params) graph, lib, params = relay.build(mod[mod.entry_func],
target,
params=params)
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
return params, graph_runtime.create(graph, lib, ctx) return params, graph_runtime.create(graph, lib, ctx)
......
...@@ -63,11 +63,13 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target ...@@ -63,11 +63,13 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
shape_dict[e] = input_data[i].shape shape_dict[e] = input_data[i].shape
dtype_dict[e] = input_data[i].dtype.name dtype_dict[e] = input_data[i].dtype.name
func, params = relay.frontend.from_tflite(tflite_model, mod, params = relay.frontend.from_tflite(tflite_model,
shape_dict=shape_dict, shape_dict=shape_dict,
dtype_dict=dtype_dict) dtype_dict=dtype_dict)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params) graph, lib, params = relay.build(mod[mod.entry_func],
target,
params=params)
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
......
...@@ -35,9 +35,9 @@ def veval(f, *args, ctx=tvm.cpu()): ...@@ -35,9 +35,9 @@ def veval(f, *args, ctx=tvm.cpu()):
mod = f mod = f
ex = relay.create_executor('vm', mod=mod, ctx=ctx) ex = relay.create_executor('vm', mod=mod, ctx=ctx)
if len(args) == 0: if len(args) == 0:
return ex.evaluate(mod[mod.entry_func]) return ex.evaluate()
else: else:
return ex.evaluate(mod[mod.entry_func])(*args) return ex.evaluate()(*args)
def test_split(): def test_split():
x = relay.var('x', shape=(12,)) x = relay.var('x', shape=(12,))
......
...@@ -260,10 +260,10 @@ elif test_target == 'vulkan': ...@@ -260,10 +260,10 @@ elif test_target == 'vulkan':
input_name = 'input_1' input_name = 'input_1'
shape_dict = {input_name: x.shape} shape_dict = {input_name: x.shape}
func, params = relay.frontend.from_keras(keras_mobilenet_v2, shape_dict) mod, params = relay.frontend.from_keras(keras_mobilenet_v2, shape_dict)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target=target, graph, lib, params = relay.build(mod[mod.entry_func], target=target,
target_host=target_host, params=params) target_host=target_host, params=params)
# After `relay.build`, you will get three return values: graph, # After `relay.build`, you will get three return values: graph,
......
...@@ -140,8 +140,9 @@ with open(synset_path) as f: ...@@ -140,8 +140,9 @@ with open(synset_path) as f:
# We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon # We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon
shape_dict = {'data': x.shape} shape_dict = {'data': x.shape}
func, params = relay.frontend.from_mxnet(block, shape_dict) mod, params = relay.frontend.from_mxnet(block, shape_dict)
# we want a probability so add a softmax operator # we want a probability so add a softmax operator
func = mod[mod.entry_func]
func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)
###################################################################### ######################################################################
......
...@@ -76,9 +76,9 @@ x, img = data.transforms.presets.ssd.load_test(im_fname, short=512) ...@@ -76,9 +76,9 @@ x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)
block = model_zoo.get_model(model_name, pretrained=True) block = model_zoo.get_model(model_name, pretrained=True)
def build(target): def build(target):
net, params = relay.frontend.from_mxnet(block, {"data": dshape}) mod, params = relay.frontend.from_mxnet(block, {"data": dshape})
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(net, target, params=params) graph, lib, params = relay.build(mod[mod.entry_func], target, params=params)
return graph, lib, params return graph, lib, params
###################################################################### ######################################################################
......
...@@ -83,13 +83,13 @@ dtype_dict = {input_name: data.dtype} ...@@ -83,13 +83,13 @@ dtype_dict = {input_name: data.dtype}
# parse Caffe2 model and convert into Relay computation graph # parse Caffe2 model and convert into Relay computation graph
from tvm import relay from tvm import relay
func, params = relay.frontend.from_caffe2(resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict) mod, params = relay.frontend.from_caffe2(resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict)
# compile the model # compile the model
# target x86 CPU # target x86 CPU
target = 'llvm' target = 'llvm'
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params) graph, lib, params = relay.build(mod[mod.entry_func], target, params=params)
###################################################################### ######################################################################
# Execute on TVM # Execute on TVM
......
...@@ -68,10 +68,12 @@ target = 'cuda' ...@@ -68,10 +68,12 @@ target = 'cuda'
shape_dict = {'image': x.shape} shape_dict = {'image': x.shape}
# Parse CoreML model and convert into Relay computation graph # Parse CoreML model and convert into Relay computation graph
func, params = relay.frontend.from_coreml(mlmodel, shape_dict) mod, params = relay.frontend.from_coreml(mlmodel, shape_dict)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params) graph, lib, params = relay.build(mod[mod.entry_func],
target,
params=params)
###################################################################### ######################################################################
# Execute on TVM # Execute on TVM
......
...@@ -82,7 +82,7 @@ batch_size = 1 ...@@ -82,7 +82,7 @@ batch_size = 1
data = np.empty([batch_size, net.c, net.h, net.w], dtype) data = np.empty([batch_size, net.c, net.h, net.w], dtype)
shape_dict = {'data': data.shape} shape_dict = {'data': data.shape}
print("Converting darknet to relay functions...") print("Converting darknet to relay functions...")
sym, params = relay.frontend.from_darknet(net, dtype=dtype, shape=data.shape) mod, params = relay.frontend.from_darknet(net, dtype=dtype, shape=data.shape)
###################################################################### ######################################################################
# Import the graph to Relay # Import the graph to Relay
...@@ -95,7 +95,10 @@ data = np.empty([batch_size, net.c, net.h, net.w], dtype) ...@@ -95,7 +95,10 @@ data = np.empty([batch_size, net.c, net.h, net.w], dtype)
shape = {'data': data.shape} shape = {'data': data.shape}
print("Compiling the model...") print("Compiling the model...")
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params) graph, lib, params = relay.build(mod[mod.entry_func],
target=target,
target_host=target_host,
params=params)
[neth, netw] = shape['data'][2:] # Current image shape is 608x608 [neth, netw] = shape['data'][2:] # Current image shape is 608x608
###################################################################### ######################################################################
......
...@@ -74,18 +74,18 @@ print('input_1', data.shape) ...@@ -74,18 +74,18 @@ print('input_1', data.shape)
# ---------------------------- # ----------------------------
# convert the keras model(NHWC layout) to Relay format(NCHW layout). # convert the keras model(NHWC layout) to Relay format(NCHW layout).
shape_dict = {'input_1': data.shape} shape_dict = {'input_1': data.shape}
func, params = relay.frontend.from_keras(keras_resnet50, shape_dict) mod, params = relay.frontend.from_keras(keras_resnet50, shape_dict)
# compile the model # compile the model
target = 'cuda' target = 'cuda'
ctx = tvm.gpu(0) ctx = tvm.gpu(0)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
executor = relay.build_module.create_executor('graph', func, ctx, target) executor = relay.build_module.create_executor('graph', mod, ctx, target)
###################################################################### ######################################################################
# Execute on TVM # Execute on TVM
# --------------- # ---------------
dtype = 'float32' dtype = 'float32'
tvm_out = executor.evaluate(func)(tvm.nd.array(data.astype(dtype)), **params) tvm_out = executor.evaluate()(tvm.nd.array(data.astype(dtype)), **params)
top1_tvm = np.argmax(tvm_out.asnumpy()[0]) top1_tvm = np.argmax(tvm_out.asnumpy()[0])
##################################################################### #####################################################################
......
...@@ -82,8 +82,9 @@ print('x', x.shape) ...@@ -82,8 +82,9 @@ print('x', x.shape)
# It's as easy as several lines. # It's as easy as several lines.
# We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon # We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon
shape_dict = {'data': x.shape} shape_dict = {'data': x.shape}
func, params = relay.frontend.from_mxnet(block, shape_dict) mod, params = relay.frontend.from_mxnet(block, shape_dict)
## we want a probability so add a softmax operator ## we want a probability so add a softmax operator
func = mod[mod.entry_func]
func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)
###################################################################### ######################################################################
...@@ -132,6 +133,6 @@ mx.model.save_checkpoint('resnet18_v1', 0, mx_sym, args, auxs) ...@@ -132,6 +133,6 @@ mx.model.save_checkpoint('resnet18_v1', 0, mx_sym, args, auxs)
# for a normal mxnet model, we start from here # for a normal mxnet model, we start from here
mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v1', 0) mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v1', 0)
# now we use the same API to get Relay computation graph # now we use the same API to get Relay computation graph
relay_func, relay_params = relay.frontend.from_mxnet(mx_sym, shape_dict, mod, relay_params = relay.frontend.from_mxnet(mx_sym, shape_dict,
arg_params=args, aux_params=auxs) arg_params=args, aux_params=auxs)
# repeat the same steps to run this model using TVM # repeat the same steps to run this model using TVM
...@@ -71,16 +71,16 @@ target = 'llvm' ...@@ -71,16 +71,16 @@ target = 'llvm'
input_name = '1' input_name = '1'
shape_dict = {input_name: x.shape} shape_dict = {input_name: x.shape}
sym, params = relay.frontend.from_onnx(onnx_model, shape_dict) mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
with relay.build_config(opt_level=1): with relay.build_config(opt_level=1):
intrp = relay.build_module.create_executor('graph', sym, tvm.cpu(0), target) intrp = relay.build_module.create_executor('graph', mod, tvm.cpu(0), target)
###################################################################### ######################################################################
# Execute on TVM # Execute on TVM
# --------------------------------------------- # ---------------------------------------------
dtype = 'float32' dtype = 'float32'
tvm_output = intrp.evaluate(sym)(tvm.nd.array(x.astype(dtype)), **params).asnumpy() tvm_output = intrp.evaluate()(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
###################################################################### ######################################################################
# Display results # Display results
......
...@@ -124,7 +124,9 @@ x = np.array(image) ...@@ -124,7 +124,9 @@ x = np.array(image)
# params: params converted from tensorflow params (tensor protobuf). # params: params converted from tensorflow params (tensor protobuf).
shape_dict = {'DecodeJpeg/contents': x.shape} shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'} dtype_dict = {'DecodeJpeg/contents': 'uint8'}
sym, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict) mod, params = relay.frontend.from_tensorflow(graph_def,
layout=layout,
shape=shape_dict)
print("Tensorflow protobuf imported to relay frontend.") print("Tensorflow protobuf imported to relay frontend.")
###################################################################### ######################################################################
...@@ -138,7 +140,10 @@ print("Tensorflow protobuf imported to relay frontend.") ...@@ -138,7 +140,10 @@ print("Tensorflow protobuf imported to relay frontend.")
# lib: target library which can be deployed on target with TVM runtime. # lib: target library which can be deployed on target with TVM runtime.
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params) graph, lib, params = relay.build(mod[mod.entry_func],
target=target,
target_host=target_host,
params=params)
###################################################################### ######################################################################
# Execute the portable graph on TVM # Execute the portable graph on TVM
......
...@@ -138,14 +138,14 @@ input_dtype = "float32" ...@@ -138,14 +138,14 @@ input_dtype = "float32"
# parse TFLite model and convert into Relay computation graph # parse TFLite model and convert into Relay computation graph
from tvm import relay from tvm import relay
func, params = relay.frontend.from_tflite(tflite_model, mod, params = relay.frontend.from_tflite(tflite_model,
shape_dict={input_tensor: input_shape}, shape_dict={input_tensor: input_shape},
dtype_dict={input_tensor: input_dtype}) dtype_dict={input_tensor: input_dtype})
# target x86 CPU # target x86 CPU
target = "llvm" target = "llvm"
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(func, target, params=params) graph, lib, params = relay.build(mod[mod.entry_func], target, params=params)
###################################################################### ######################################################################
# Execute on TVM # Execute on TVM
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment