Commit f35f2276 by Zhi Committed by Jared Roesch

[relay][frontend] Return Module from get_workload (#3483)

* [relay][frontend] Return Module from get_workload

* pass entry_func to autotvm

* disable tune

* add property to module

* mod.entry_func to main

* .main -> mod["main"]

* fix
parent 42eee923
...@@ -34,20 +34,6 @@ namespace tvm { ...@@ -34,20 +34,6 @@ namespace tvm {
namespace relay { namespace relay {
/*! /*!
* \brief Infer the type of a function as if it is mapped to var in the mod.
*
* \param f the function.
* \param mod The module used for referencing global functions.
* \param var The global variable corresponding to the function.
*
* \return A type checked Function with its checked_type field populated.
* \note this function mutates mod and is not thread-safe.
*/
TVM_DLL Function InferType(const Function& f,
const Module& mod,
const GlobalVar& var);
/*!
* \brief Check that types are well kinded by applying "kinding rules". * \brief Check that types are well kinded by applying "kinding rules".
* *
* This pass ensures we do not do things that violate the design of the * This pass ensures we do not do things that violate the design of the
......
...@@ -65,16 +65,12 @@ class ModuleNode : public RelayNode { ...@@ -65,16 +65,12 @@ class ModuleNode : public RelayNode {
/*! \brief A map from global type vars to ADT type data. */ /*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions; tvm::Map<GlobalTypeVar, TypeData> type_definitions;
/*! \brief The entry function (i.e. "main"). */
GlobalVar entry_func;
ModuleNode() {} ModuleNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions); v->Visit("functions", &functions);
v->Visit("type_definitions", &type_definitions); v->Visit("type_definitions", &type_definitions);
v->Visit("global_var_map_", &global_var_map_); v->Visit("global_var_map_", &global_var_map_);
v->Visit("entry_func", &entry_func);
v->Visit("global_type_var_map_", &global_type_var_map_); v->Visit("global_type_var_map_", &global_type_var_map_);
} }
...@@ -120,6 +116,13 @@ class ModuleNode : public RelayNode { ...@@ -120,6 +116,13 @@ class ModuleNode : public RelayNode {
TVM_DLL void Remove(const GlobalVar& var); TVM_DLL void Remove(const GlobalVar& var);
/*! /*!
* \brief Check if the global_var_map_ contains a global variable.
* \param name The variable name.
* \returns true if contains, otherise false.
*/
TVM_DLL bool ContainGlobalVar(const std::string& name) const;
/*!
* \brief Lookup a global function by its variable. * \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable. * \param str The unique string specifying the global variable.
* \returns The global variable. * \returns The global variable.
...@@ -180,10 +183,10 @@ class ModuleNode : public RelayNode { ...@@ -180,10 +183,10 @@ class ModuleNode : public RelayNode {
* Allows one to optionally pass a global function map as * Allows one to optionally pass a global function map as
* well. * well.
* *
* \param expr The expression to set as the entry point to the module. * \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map. * \param global_funcs The global function map.
* *
* \returns A module with expr set as the entry point. * \returns A module with expr set as the main function.
*/ */
TVM_DLL static Module FromExpr( TVM_DLL static Module FromExpr(
const Expr& expr, const Expr& expr,
......
...@@ -142,7 +142,7 @@ class BaseGraphTuner(object): ...@@ -142,7 +142,7 @@ class BaseGraphTuner(object):
# Generate workload and schedule dictionaries. # Generate workload and schedule dictionaries.
if isinstance(graph, relay.Module): if isinstance(graph, relay.Module):
graph = graph[graph.entry_func] graph = graph["main"]
if isinstance(graph, relay.expr.Function): if isinstance(graph, relay.expr.Function):
node_dict = {} node_dict = {}
......
...@@ -85,7 +85,7 @@ def _infer_type(node): ...@@ -85,7 +85,7 @@ def _infer_type(node):
"""A method to infer the type of a relay expression.""" """A method to infer the type of a relay expression."""
mod = relay.Module.from_expr(node) mod = relay.Module.from_expr(node)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(node, relay.Function) else entry.body return entry if isinstance(node, relay.Function) else entry.body
......
...@@ -110,5 +110,5 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"): ...@@ -110,5 +110,5 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
mod = relay.Module.from_expr(updated_expr) mod = relay.Module.from_expr(updated_expr)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(updated_expr, relay.Function) else entry.body return entry if isinstance(updated_expr, relay.Function) else entry.body
...@@ -289,7 +289,7 @@ class Interpreter(Executor): ...@@ -289,7 +289,7 @@ class Interpreter(Executor):
assert self.mod is not None assert self.mod is not None
def _interp_wrapper(*args, **kwargs): def _interp_wrapper(*args, **kwargs):
if expr is None: if expr is None:
args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs) args = self._convert_args(self.mod["main"], args, kwargs)
else: else:
args = self._convert_args(expr, args, kwargs) args = self._convert_args(expr, args, kwargs)
...@@ -301,17 +301,17 @@ class Interpreter(Executor): ...@@ -301,17 +301,17 @@ class Interpreter(Executor):
if expr is None: if expr is None:
pass pass
elif isinstance(expr, GlobalVar): elif isinstance(expr, GlobalVar):
self.mod[self.mod.entry_func] = self.mod[expr] self.mod["main"] = self.mod[expr]
else: else:
assert isinstance(expr, Function) assert isinstance(expr, Function)
func = Function([], Call(expr, relay_args)) func = Function([], Call(expr, relay_args))
relay_args = [] relay_args = []
if self.mod: if self.mod:
self.mod[self.mod.entry_func] = func self.mod["main"] = func
else: else:
self.mod = module.Module.from_expr(func) self.mod = module.Module.from_expr(func)
mod = self.optimize() mod = self.optimize()
opt_expr = Call(mod[self.mod.entry_func.name_hint], relay_args) opt_expr = Call(mod["main"], relay_args)
return self._intrp(opt_expr) return self._intrp(opt_expr)
return _interp_wrapper return _interp_wrapper
...@@ -45,7 +45,7 @@ def optimize(mod): ...@@ -45,7 +45,7 @@ def optimize(mod):
ret : tvm.relay.Module ret : tvm.relay.Module
The optimized module. The optimized module.
""" """
main_func = mod[mod.entry_func] main_func = mod["main"]
opt_passes = [] opt_passes = []
if not main_func.params and isinstance(main_func.body, GlobalVar): if not main_func.params and isinstance(main_func.body, GlobalVar):
...@@ -134,8 +134,8 @@ class VMExecutor(Executor): ...@@ -134,8 +134,8 @@ class VMExecutor(Executor):
expr = expr if expr else self.mod expr = expr if expr else self.mod
assert expr, "either expr or self.mod should be not null." assert expr, "either expr or self.mod should be not null."
if isinstance(expr, Expr): if isinstance(expr, Expr):
self.mod[self.mod.entry_func] = expr self.mod["main"] = expr
main = self.mod[self.mod.entry_func] main = self.mod["main"]
def _vm_wrapper(*args, **kwargs): def _vm_wrapper(*args, **kwargs):
args = self._convert_args(main, args, kwargs) args = self._convert_args(main, args, kwargs)
......
...@@ -177,7 +177,7 @@ def build(mod, target=None, target_host=None, params=None): ...@@ -177,7 +177,7 @@ def build(mod, target=None, target_host=None, params=None):
The parameters of the final graph. The parameters of the final graph.
""" """
if isinstance(mod, _Module): if isinstance(mod, _Module):
func = mod[mod.entry_func] func = mod["main"]
elif isinstance(mod, _expr.Function): elif isinstance(mod, _expr.Function):
func = mod func = mod
warnings.warn( warnings.warn(
...@@ -233,8 +233,8 @@ class GraphExecutor(_interpreter.Executor): ...@@ -233,8 +233,8 @@ class GraphExecutor(_interpreter.Executor):
def _make_executor(self, expr=None): def _make_executor(self, expr=None):
if expr: if expr:
self.mod[self.mod.entry_func] = expr self.mod["main"] = expr
ret_type = self.mod[self.mod.entry_func].checked_type.ret_type ret_type = self.mod["main"].checked_type.ret_type
num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1 num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
graph_json, mod, params = build(self.mod, target=self.target) graph_json, mod, params = build(self.mod, target=self.target)
gmodule = _graph_rt.create(graph_json, mod, self.ctx) gmodule = _graph_rt.create(graph_json, mod, self.ctx)
...@@ -242,7 +242,7 @@ class GraphExecutor(_interpreter.Executor): ...@@ -242,7 +242,7 @@ class GraphExecutor(_interpreter.Executor):
gmodule.set_input(**params) gmodule.set_input(**params)
def _graph_wrapper(*args, **kwargs): def _graph_wrapper(*args, **kwargs):
args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs) args = self._convert_args(self.mod["main"], args, kwargs)
# Create map of inputs. # Create map of inputs.
for i, arg in enumerate(args): for i, arg in enumerate(args):
gmodule.set_input(i, arg) gmodule.set_input(i, arg)
......
...@@ -451,7 +451,7 @@ class Caffe2NetDef(object): ...@@ -451,7 +451,7 @@ class Caffe2NetDef(object):
outputs = out[0] outputs = out[0]
func = _expr.Function(analysis.free_vars(outputs), outputs) func = _expr.Function(analysis.free_vars(outputs), outputs)
self._mod[self._mod.entry_func] = func self._mod["main"] = func
return self._mod, self._params return self._mod, self._params
......
...@@ -412,7 +412,7 @@ def infer_type(node): ...@@ -412,7 +412,7 @@ def infer_type(node):
"""A method to infer the type of an intermediate node in the relay graph.""" """A method to infer the type of an intermediate node in the relay graph."""
mod = _module.Module.from_expr(node) mod = _module.Module.from_expr(node)
mod = _transform.InferType()(mod) mod = _transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body return entry if isinstance(node, _expr.Function) else entry.body
def infer_shape(inputs): def infer_shape(inputs):
......
...@@ -45,7 +45,7 @@ def _infer_type(node): ...@@ -45,7 +45,7 @@ def _infer_type(node):
"""A method to infer the type of an intermediate node in the relay graph.""" """A method to infer the type of an intermediate node in the relay graph."""
mod = _module.Module.from_expr(node) mod = _module.Module.from_expr(node)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body return entry if isinstance(node, _expr.Function) else entry.body
def _mx_fully_connected(inputs, attrs): def _mx_fully_connected(inputs, attrs):
...@@ -1200,5 +1200,5 @@ def from_mxnet(symbol, ...@@ -1200,5 +1200,5 @@ def from_mxnet(symbol,
else: else:
msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol)) msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol))
raise ValueError(msg) raise ValueError(msg)
mod[mod.entry_func] = func mod["main"] = func
return mod, params return mod, params
...@@ -240,7 +240,7 @@ def _infer_type(node): ...@@ -240,7 +240,7 @@ def _infer_type(node):
"""A method to infer the type of an intermediate node in the relay graph.""" """A method to infer the type of an intermediate node in the relay graph."""
mod = _module.Module.from_expr(node) mod = _module.Module.from_expr(node)
mod = _transform.InferType()(mod) mod = _transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body return entry if isinstance(node, _expr.Function) else entry.body
def _infer_shape(node, params=None): def _infer_shape(node, params=None):
...@@ -2122,7 +2122,7 @@ class GraphProto(object): ...@@ -2122,7 +2122,7 @@ class GraphProto(object):
out = out[0] if len(out) == 1 else _expr.Tuple(out) out = out[0] if len(out) == 1 else _expr.Tuple(out)
func = _expr.Function(analysis.free_vars(out), out) func = _expr.Function(analysis.free_vars(out), out)
self._mod[self._mod.entry_func] = func self._mod["main"] = func
return self._mod, self._params return self._mod, self._params
def _parse_import_prerequisites(self, graph): def _parse_import_prerequisites(self, graph):
......
...@@ -78,8 +78,11 @@ class Module(RelayNode): ...@@ -78,8 +78,11 @@ class Module(RelayNode):
def _add(self, var, val, update=False): def _add(self, var, val, update=False):
if isinstance(val, _expr.Expr): if isinstance(val, _expr.Expr):
if isinstance(var, _base.string_types): if isinstance(var, _base.string_types):
if _module.Module_ContainGlobalVar(self, var):
var = _module.Module_GetGlobalVar(self, var)
else:
var = _expr.GlobalVar(var) var = _expr.GlobalVar(var)
_make.Module_Add(self, var, val, update) _module.Module_Add(self, var, val, update)
else: else:
assert isinstance(val, _ty.Type) assert isinstance(val, _ty.Type)
if isinstance(var, _base.string_types): if isinstance(var, _base.string_types):
......
...@@ -365,4 +365,4 @@ def quantize(graph, params=None, dataset=None): ...@@ -365,4 +365,4 @@ def quantize(graph, params=None, dataset=None):
mod = optimize(mod) mod = optimize(mod)
mod = quantize_seq(mod) mod = quantize_seq(mod)
return mod[mod.entry_func.name_hint] return mod["main"]
...@@ -41,7 +41,7 @@ def run_opt_pass(expr, opt_pass): ...@@ -41,7 +41,7 @@ def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass) assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = opt_pass(mod) mod = opt_pass(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -103,8 +103,8 @@ def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype= ...@@ -103,8 +103,8 @@ def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype=
Returns Returns
------- -------
net : nnvm.symbol mod : tvm.relay.Module
The computational graph The relay module that contains a DCGAN network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
""" """
......
...@@ -105,8 +105,8 @@ def get_workload(densenet_size=121, classes=1000, batch_size=4, ...@@ -105,8 +105,8 @@ def get_workload(densenet_size=121, classes=1000, batch_size=4,
Returns Returns
------- -------
net: relay.Function mod: tvm.relay.Module
The computation graph representing densenet. The relay module that contains a DenseNet network.
params : dict of str to NDArray params : dict of str to NDArray
The benchmark paraeters. The benchmark paraeters.
......
...@@ -72,8 +72,8 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo ...@@ -72,8 +72,8 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo
The data type The data type
Returns Returns
------- -------
net : nnvm.symbol mod : tvm.relay.Module
The computational graph The relay module that contains a DQN network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
""" """
......
...@@ -289,8 +289,8 @@ def get_workload(batch_size=1, num_classes=1000, ...@@ -289,8 +289,8 @@ def get_workload(batch_size=1, num_classes=1000,
Returns Returns
------- -------
net : nnvm.Symbol mod : tvm.relay.Module
The computational graph The relay module that contains an Inception V3 network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
......
...@@ -144,17 +144,16 @@ def create_workload(net, initializer=None, seed=0): ...@@ -144,17 +144,16 @@ def create_workload(net, initializer=None, seed=0):
Returns Returns
------- -------
net : tvm.relay.Function mod : tvm.relay.Module
The updated dataflow The created relay module.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
""" """
mod = relay.Module.from_expr(net) mod = relay.Module.from_expr(net)
mod = relay.transform.InferType()(mod) mod = relay.transform.InferType()(mod)
net = mod[mod.entry_func]
shape_dict = { shape_dict = {
v.name_hint : v.checked_type for v in net.params} v.name_hint : v.checked_type for v in mod["main"].params}
np.random.seed(seed) np.random.seed(seed)
initializer = initializer if initializer else Xavier() initializer = initializer if initializer else Xavier()
params = {} params = {}
...@@ -164,4 +163,4 @@ def create_workload(net, initializer=None, seed=0): ...@@ -164,4 +163,4 @@ def create_workload(net, initializer=None, seed=0):
init_value = np.zeros(v.concrete_shape).astype(v.dtype) init_value = np.zeros(v.concrete_shape).astype(v.dtype)
initializer(k, init_value) initializer(k, init_value)
params[k] = tvm.nd.array(init_value, ctx=tvm.cpu(0)) params[k] = tvm.nd.array(init_value, ctx=tvm.cpu(0))
return net, params return mod, params
...@@ -173,8 +173,8 @@ def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"): ...@@ -173,8 +173,8 @@ def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"):
The data type The data type
Returns Returns
------- -------
net : nnvm.symbol mod : tvm.relay.Module
The computational graph The relay module that contains a LSTM network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
""" """
......
...@@ -84,8 +84,8 @@ def get_workload(batch_size, ...@@ -84,8 +84,8 @@ def get_workload(batch_size,
Returns Returns
------- -------
net : relay.Function mod : tvm.relay.Module
The dataflow. The relay module that contains a mlp network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
......
...@@ -130,8 +130,8 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtyp ...@@ -130,8 +130,8 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtyp
Returns Returns
------- -------
net : relay.Function mod : tvm.relay.Module
The computational graph The relay module that contains a MobileNet network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
......
...@@ -261,8 +261,8 @@ def get_workload(batch_size=1, ...@@ -261,8 +261,8 @@ def get_workload(batch_size=1,
Returns Returns
------- -------
net : relay.Function mod : tvm.relay.Module
The computational graph The relay module that contains a ResNet network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
......
...@@ -149,8 +149,8 @@ def get_workload(batch_size=1, ...@@ -149,8 +149,8 @@ def get_workload(batch_size=1,
Returns Returns
------- -------
net : nnvm.Symbol mod : tvm.relay.Module
The computational graph The relay module that contains a SqueezeNet network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
......
...@@ -124,8 +124,8 @@ def get_workload(batch_size, ...@@ -124,8 +124,8 @@ def get_workload(batch_size,
Returns Returns
------- -------
net : nnvm.Symbol mod : tvm.relay.Module
The computational graph The relay module that contains a VGG network.
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
......
...@@ -434,7 +434,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -434,7 +434,7 @@ class RelayBuildModule : public runtime::ModuleNode {
relay_module = Optimize(relay_module, targets_, params); relay_module = Optimize(relay_module, targets_, params);
CHECK(relay_module.defined()); CHECK(relay_module.defined());
// Get the updated function. // Get the updated function.
func = relay_module->Lookup(relay_module->entry_func->name_hint); func = relay_module->Lookup("main");
// Generate code for the updated function. // Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen()); graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
......
...@@ -52,10 +52,10 @@ Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs, ...@@ -52,10 +52,10 @@ Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
// TODO(zhiics): This measurement is for temporary usage. Remove it later. We // TODO(zhiics): This measurement is for temporary usage. Remove it later. We
// need to introduce a better profiling method. // need to introduce a better profiling method.
#if ENABLE_PROFILING #if ENABLE_PROFILING
DLOG(INFO) << "Entry function is " << module->entry_func << std::endl; DLOG(INFO) << "Entry function is main." << std::endl;
auto start = std::chrono::high_resolution_clock::now(); auto start = std::chrono::high_resolution_clock::now();
#endif // ENABLE_PROFILING #endif // ENABLE_PROFILING
Object res = vm.Invoke(module->entry_func->name_hint, vm_args); Object res = vm.Invoke("main", vm_args);
#if ENABLE_PROFILING #if ENABLE_PROFILING
auto end = std::chrono::high_resolution_clock::now(); auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count(); auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
......
...@@ -46,8 +46,6 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs, ...@@ -46,8 +46,6 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
n->global_var_map_.Set(kv.first->name_hint, kv.first); n->global_var_map_.Set(kv.first->name_hint, kv.first);
} }
n->entry_func = GlobalVarNode::make("main");
for (const auto& kv : n->type_definitions) { for (const auto& kv : n->type_definitions) {
// set global typevar map // set global typevar map
CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint)) CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint))
...@@ -59,6 +57,10 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs, ...@@ -59,6 +57,10 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
return Module(n); return Module(n);
} }
bool ModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}
GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const { GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name); auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end()) CHECK(it != global_var_map_.end())
...@@ -194,7 +196,8 @@ Module ModuleNode::FromExpr( ...@@ -194,7 +196,8 @@ Module ModuleNode::FromExpr(
} else { } else {
func = FunctionNode::make({}, expr, Type(), {}, {}); func = FunctionNode::make({}, expr, Type(), {}, {});
} }
mod->Add(mod->entry_func, func); auto main_gv = GlobalVarNode::make("main");
mod->Add(main_gv, func);
return mod; return mod;
} }
...@@ -203,7 +206,7 @@ TVM_REGISTER_NODE_TYPE(ModuleNode); ...@@ -203,7 +206,7 @@ TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module") TVM_REGISTER_API("relay._make.Module")
.set_body_typed(ModuleNode::make); .set_body_typed(ModuleNode::make);
TVM_REGISTER_API("relay._make.Module_Add") TVM_REGISTER_API("relay._module.Module_Add")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Module mod = args[0]; Module mod = args[0];
GlobalVar var = args[1]; GlobalVar var = args[1];
...@@ -231,6 +234,9 @@ TVM_REGISTER_API("relay._module.Module_AddDef") ...@@ -231,6 +234,9 @@ TVM_REGISTER_API("relay._module.Module_AddDef")
TVM_REGISTER_API("relay._module.Module_GetGlobalVar") TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
.set_body_method<Module>(&ModuleNode::GetGlobalVar); .set_body_method<Module>(&ModuleNode::GetGlobalVar);
TVM_REGISTER_API("relay._module.Module_ContainGlobalVar")
.set_body_method<Module>(&ModuleNode::ContainGlobalVar);
TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar") TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
.set_body_method<Module>(&ModuleNode::GetGlobalTypeVar); .set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);
......
...@@ -161,7 +161,7 @@ class ConstantFolder : public ExprMutator { ...@@ -161,7 +161,7 @@ class ConstantFolder : public ExprMutator {
auto mod = ModuleNode::FromExpr(expr); auto mod = ModuleNode::FromExpr(expr);
auto seq = transform::Sequential(passes); auto seq = transform::Sequential(passes);
mod = seq(mod); mod = seq(mod);
auto entry_func = mod->Lookup(mod->entry_func); auto entry_func = mod->Lookup("main");
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func; expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return ValueToExpr(executor_(expr)); return ValueToExpr(executor_(expr));
} }
......
...@@ -751,7 +751,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -751,7 +751,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
auto mod = ModuleNode::FromExpr(expr); auto mod = ModuleNode::FromExpr(expr);
auto seq = transform::Sequential(passes); auto seq = transform::Sequential(passes);
mod = seq(mod); mod = seq(mod);
auto entry_func = mod->Lookup(mod->entry_func); auto entry_func = mod->Lookup("main");
auto fused_infered = auto fused_infered =
expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func; expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return Reify(executor_(fused_infered), ll); return Reify(executor_(fused_infered), ll);
...@@ -1018,7 +1018,6 @@ Expr PostProcess(const Expr& e) { ...@@ -1018,7 +1018,6 @@ Expr PostProcess(const Expr& e) {
} // namespace partial_eval } // namespace partial_eval
Module PartialEval(const Module& m) { Module PartialEval(const Module& m) {
CHECK(m->entry_func.defined());
relay::partial_eval::PartialEvaluator pe(m); relay::partial_eval::PartialEvaluator pe(m);
std::vector<GlobalVar> gvs; std::vector<GlobalVar> gvs;
for (const auto& p : m->functions) { for (const auto& p : m->functions) {
......
...@@ -263,7 +263,7 @@ Expr QuantizeRealize(const Call& ref_call, ...@@ -263,7 +263,7 @@ Expr QuantizeRealize(const Call& ref_call,
Expr FoldConstantOpt(const Expr& expr) { Expr FoldConstantOpt(const Expr& expr) {
auto mod = ModuleNode::FromExpr(expr); auto mod = ModuleNode::FromExpr(expr);
mod = transform::FoldConstant()(mod); mod = transform::FoldConstant()(mod);
auto entry_func = mod->Lookup(mod->entry_func); auto entry_func = mod->Lookup("main");
return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func; return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
} }
......
...@@ -774,7 +774,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) { ...@@ -774,7 +774,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) {
// type check it anyway; afterwards we can just recover type // type check it anyway; afterwards we can just recover type
// from the type-checked function to avoid doing unnecessary work. // from the type-checked function to avoid doing unnecessary work.
Function func = mod->Lookup(mod->entry_func); Function func = mod->Lookup("main");
// FromExpr wraps a naked expression as a function, we will unbox // FromExpr wraps a naked expression as a function, we will unbox
// it here. // it here.
...@@ -784,7 +784,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) { ...@@ -784,7 +784,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) {
return func->body; return func->body;
} }
} else { } else {
auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); auto e = TypeInferencer(mod_ref, mod_ref->GetGlobalVar("main")).Infer(expr);
CHECK(WellFormed(e)); CHECK(WellFormed(e));
auto free_tvars = FreeTypeVars(e, mod_ref); auto free_tvars = FreeTypeVars(e, mod_ref);
CHECK(free_tvars.size() == 0) CHECK(free_tvars.size() == 0)
......
...@@ -35,7 +35,7 @@ TEST(Relay, SelfReference) { ...@@ -35,7 +35,7 @@ TEST(Relay, SelfReference) {
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {}); auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto mod = relay::ModuleNode::FromExpr(fx); auto mod = relay::ModuleNode::FromExpr(fx);
mod = relay::transform::InferType()(mod); mod = relay::transform::InferType()(mod);
auto type_fx = mod->Lookup(mod->entry_func); auto type_fx = mod->Lookup("main");
auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {}); auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
CHECK(AlphaEqual(type_fx->checked_type(), expected)); CHECK(AlphaEqual(type_fx->checked_type(), expected));
......
...@@ -84,9 +84,9 @@ TEST(Relay, Sequential) { ...@@ -84,9 +84,9 @@ TEST(Relay, Sequential) {
} }
CHECK(mod.defined()); CHECK(mod.defined());
auto entry_func = mod->entry_func; auto entry_func = mod->GetGlobalVar("main");
CHECK(entry_func.defined()); CHECK(entry_func.defined());
relay::Function f = mod->Lookup(entry_func->name_hint); relay::Function f = mod->Lookup("main");
CHECK(f.defined()); CHECK(f.defined());
// Expected function // Expected function
...@@ -102,7 +102,7 @@ TEST(Relay, Sequential) { ...@@ -102,7 +102,7 @@ TEST(Relay, Sequential) {
// Infer type for the expected function. // Infer type for the expected function.
auto mod1 = relay::ModuleNode::FromExpr(expected_func); auto mod1 = relay::ModuleNode::FromExpr(expected_func);
mod1 = relay::transform::InferType()(mod1); mod1 = relay::transform::InferType()(mod1);
auto expected = mod1->Lookup(mod1->entry_func); auto expected = mod1->Lookup("main");
CHECK(relay::AlphaEqual(f, expected)); CHECK(relay::AlphaEqual(f, expected));
} }
......
...@@ -20,11 +20,10 @@ from tvm.relay import transform ...@@ -20,11 +20,10 @@ from tvm.relay import transform
from model_zoo import c2_squeezenet, relay_squeezenet from model_zoo import c2_squeezenet, relay_squeezenet
def compare_graph(lhs_mod, func): def compare_graph(lhs_mod, rhs_mod):
rhs_mod = relay.Module.from_expr(func) lhs_mod = transform.InferType()(lhs_mod)
rhs_mod = transform.InferType()(rhs_mod) rhs_mod = transform.InferType()(rhs_mod)
assert relay.analysis.alpha_equal(lhs_mod[lhs_mod.entry_func], assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
rhs_mod[rhs_mod.entry_func])
def test_squeeze_net(): def test_squeeze_net():
...@@ -32,8 +31,8 @@ def test_squeeze_net(): ...@@ -32,8 +31,8 @@ def test_squeeze_net():
dtype_dict = {'data': 'float32'} dtype_dict = {'data': 'float32'}
mod, _, = relay.frontend.from_caffe2( mod, _, = relay.frontend.from_caffe2(
c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict) c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict)
relay_func, _ = relay_squeezenet() relay_mod, _ = relay_squeezenet()
compare_graph(mod, relay_func) compare_graph(mod, relay_mod)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -48,7 +48,7 @@ def run_model_checkonly(model_file, model_name='', input_name='image'): ...@@ -48,7 +48,7 @@ def run_model_checkonly(model_file, model_name='', input_name='image'):
shape_dict = {input_name : x.shape} shape_dict = {input_name : x.shape}
mod, params = relay.frontend.from_coreml(model, shape_dict) mod, params = relay.frontend.from_coreml(model, shape_dict)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
tvm_output = get_tvm_output(mod[mod.entry_func], x, params, target, ctx) tvm_output = get_tvm_output(mod["main"], x, params, target, ctx)
print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat)) print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat))
def test_mobilenet_checkonly(): def test_mobilenet_checkonly():
......
...@@ -19,15 +19,17 @@ from tvm import relay ...@@ -19,15 +19,17 @@ from tvm import relay
from tvm.relay import transform from tvm.relay import transform
import model_zoo import model_zoo
def compare_graph(f1, f2): def compare_graph(lhs_mod, rhs_mod):
assert relay.analysis.alpha_equal(f1, f2) lhs_mod = transform.InferType()(lhs_mod)
rhs_mod = transform.InferType()(rhs_mod)
assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
def test_mlp(): def test_mlp():
shape = {"data": (1, 1, 28, 28)} shape = {"data": (1, 1, 28, 28)}
mx_fun = model_zoo.mx_mlp() mx_fun = model_zoo.mx_mlp()
mod, _ = relay.frontend.from_mxnet(mx_fun, shape=shape) mod, _ = relay.frontend.from_mxnet(mx_fun, shape=shape)
relay_fun = model_zoo.relay_mlp() relay_fun = model_zoo.relay_mlp()
compare_graph(mod[mod.entry_func], relay_fun) compare_graph(mod, relay_fun)
def test_vgg(): def test_vgg():
...@@ -35,8 +37,8 @@ def test_vgg(): ...@@ -35,8 +37,8 @@ def test_vgg():
for n in [11, 13, 16, 19]: for n in [11, 13, 16, 19]:
mx_sym = model_zoo.mx_vgg(n) mx_sym = model_zoo.mx_vgg(n)
mod, _ = relay.frontend.from_mxnet(mx_sym, shape=shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
relay_sym = model_zoo.relay_vgg(n) relay_mod = model_zoo.relay_vgg(n)
compare_graph(mod[mod.entry_func], relay_sym) compare_graph(mod, relay_mod)
def test_resnet(): def test_resnet():
...@@ -44,8 +46,8 @@ def test_resnet(): ...@@ -44,8 +46,8 @@ def test_resnet():
for n in [18, 34, 50, 101]: for n in [18, 34, 50, 101]:
mx_sym = model_zoo.mx_resnet(n) mx_sym = model_zoo.mx_resnet(n)
mod, _ = relay.frontend.from_mxnet(mx_sym, shape=shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape=shape)
relay_sym = model_zoo.relay_resnet(n) relay_mod = model_zoo.relay_resnet(n)
compare_graph(mod[mod.entry_func], relay_sym) compare_graph(mod, relay_mod)
def test_squeezenet(): def test_squeezenet():
...@@ -53,32 +55,32 @@ def test_squeezenet(): ...@@ -53,32 +55,32 @@ def test_squeezenet():
for version in ['1.0', '1.1']: for version in ['1.0', '1.1']:
mx_sym = model_zoo.mx_squeezenet(version) mx_sym = model_zoo.mx_squeezenet(version)
mod, _ = relay.frontend.from_mxnet(mx_sym, shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_squeezenet(version) relay_mod = model_zoo.relay_squeezenet(version)
compare_graph(mod[mod.entry_func], relay_sym) compare_graph(mod, relay_mod)
def test_inception_v3(): def test_inception_v3():
shape = {"data": (1, 3, 299, 299)} shape = {"data": (1, 3, 299, 299)}
mx_sym = model_zoo.mx_inception_v3() mx_sym = model_zoo.mx_inception_v3()
mod, _ = relay.frontend.from_mxnet(mx_sym, shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_inception_v3() relay_mod = model_zoo.relay_inception_v3()
compare_graph(mod[mod.entry_func], relay_sym) compare_graph(mod, relay_mod)
def test_dqn(): def test_dqn():
shape = {"data": (1, 4, 84, 84)} shape = {"data": (1, 4, 84, 84)}
mx_sym = model_zoo.mx_dqn() mx_sym = model_zoo.mx_dqn()
mod, _ = relay.frontend.from_mxnet(mx_sym, shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_dqn() relay_mod = model_zoo.relay_dqn()
compare_graph(mod[mod.entry_func], relay_sym) compare_graph(mod, relay_mod)
def test_dcgan(): def test_dcgan():
shape = {"data": (2, 100)} shape = {"data": (2, 100)}
mx_sym = model_zoo.mx_dcgan() mx_sym = model_zoo.mx_dcgan()
mod, _ = relay.frontend.from_mxnet(mx_sym, shape) mod, _ = relay.frontend.from_mxnet(mx_sym, shape)
relay_sym = model_zoo.relay_dcgan(batch_size=2) relay_mod = model_zoo.relay_dcgan(batch_size=2)
compare_graph(mod[mod.entry_func], relay_sym) compare_graph(mod, relay_mod)
def test_multi_outputs(): def test_multi_outputs():
...@@ -97,15 +99,13 @@ def test_multi_outputs(): ...@@ -97,15 +99,13 @@ def test_multi_outputs():
z = F.split(x, **kwargs) z = F.split(x, **kwargs)
z = F.subtract(F.add(z[0], z[2]), y) z = F.subtract(F.add(z[0], z[2]), y)
func = relay.Function(relay.analysis.free_vars(z), z) func = relay.Function(relay.analysis.free_vars(z), z)
mod = relay.Module.from_expr(func) return relay.Module.from_expr(func)
mod = transform.InferType()(mod)
return mod[mod.entry_func]
mx_sym = mx_compose(mx, num_outputs=3, axis=1) mx_sym = mx_compose(mx, num_outputs=3, axis=1)
mod, _ = relay.frontend.from_mxnet( mod, _ = relay.frontend.from_mxnet(
mx_sym, shape={"x":xshape, "y":yshape}) mx_sym, shape={"x":xshape, "y":yshape})
relay_sym = relay_compose(relay, indices_or_sections=3, axis=1) relay_mod = relay_compose(relay, indices_or_sections=3, axis=1)
compare_graph(mod[mod.entry_func], relay_sym) compare_graph(mod, relay_mod)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -77,7 +77,7 @@ def test_alter_layout_conv2d(): ...@@ -77,7 +77,7 @@ def test_alter_layout_conv2d():
with autotvm.tophub.context(target): with autotvm.tophub.context(target):
mod = relay.Module.from_expr(N) mod = relay.Module.from_expr(N)
mod = transform.AlterOpLayout()(mod) mod = transform.AlterOpLayout()(mod)
O = mod[mod.entry_func] O = mod["main"]
# graph should differ # graph should differ
assert not relay.analysis.alpha_equal(N, O) assert not relay.analysis.alpha_equal(N, O)
......
...@@ -23,15 +23,15 @@ from tvm import relay ...@@ -23,15 +23,15 @@ from tvm import relay
from tvm.relay import testing from tvm.relay import testing
def benchmark_execution(net, def benchmark_execution(mod,
params, params,
measure=False, measure=False,
data_shape=(1, 3, 224, 224), data_shape=(1, 3, 224, 224),
out_shape=(1, 1000), out_shape=(1, 1000),
dtype='float32'): dtype='float32'):
def get_tvm_output(net, data, params, target, ctx, dtype='float32'): def get_tvm_output(mod, data, params, target, ctx, dtype='float32'):
with relay.build_config(opt_level=1): with relay.build_config(opt_level=1):
graph, lib, params = relay.build(net, target, params=params) graph, lib, params = relay.build(mod, target, params=params)
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
# set inputs # set inputs
...@@ -50,9 +50,9 @@ def benchmark_execution(net, ...@@ -50,9 +50,9 @@ def benchmark_execution(net,
return out.asnumpy() return out.asnumpy()
def get_tvm_vm_output(net, data, params, target, ctx, dtype='float32'): def get_tvm_vm_output(mod, data, params, target, ctx, dtype='float32'):
ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx) ex = relay.create_executor('vm', mod=mod, ctx=ctx)
result = ex.evaluate(net)(data, **params) result = ex.evaluate()(data, **params)
return result.asnumpy().astype(dtype) return result.asnumpy().astype(dtype)
# random input # random input
...@@ -60,64 +60,64 @@ def benchmark_execution(net, ...@@ -60,64 +60,64 @@ def benchmark_execution(net,
target = "llvm" target = "llvm"
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
tvm_out = get_tvm_output(net, tvm.nd.array(data.astype(dtype)), params, tvm_out = get_tvm_output(mod, tvm.nd.array(data.astype(dtype)), params,
target, ctx, dtype) target, ctx, dtype)
vm_out = get_tvm_vm_output(net, tvm.nd.array(data.astype(dtype)), params, vm_out = get_tvm_vm_output(mod, tvm.nd.array(data.astype(dtype)), params,
target, ctx, dtype) target, ctx, dtype)
tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_mlp(): def test_mlp():
image_shape = (1, 28, 28) image_shape = (1, 1, 28, 28)
net, params = testing.mlp.get_workload(1) mod, params = testing.mlp.get_workload(1)
benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 10)) benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 10))
def test_vgg(): def test_vgg():
for n in [11, 16]: for n in [11, 16]:
net, params = testing.vgg.get_workload(1, num_layers=n) mod, params = testing.vgg.get_workload(1, num_layers=n)
benchmark_execution(net, params) benchmark_execution(mod, params)
def test_resnet(): def test_resnet():
for n in [18, 50]: for n in [18, 50]:
net, params = testing.resnet.get_workload(batch_size=1, num_layers=n) mod, params = testing.resnet.get_workload(batch_size=1, num_layers=n)
benchmark_execution(net, params, True) benchmark_execution(mod, params, True)
def test_squeezenet(): def test_squeezenet():
for version in ['1.0', '1.1']: for version in ['1.0', '1.1']:
net, params = testing.squeezenet.get_workload(version=version) mod, params = testing.squeezenet.get_workload(version=version)
benchmark_execution(net, params) benchmark_execution(mod, params)
def test_inception_v3(): def test_inception_v3():
image_shape = (3, 299, 299) image_shape = (1, 3, 299, 299)
net, params = testing.inception_v3.get_workload(image_shape=image_shape) mod, params = testing.inception_v3.get_workload(image_shape=image_shape)
benchmark_execution(net, params, data_shape=image_shape) benchmark_execution(mod, params, data_shape=image_shape)
def test_dqn(): def test_dqn():
image_shape = (4, 84, 84) image_shape = (1, 4, 84, 84)
net, params = testing.dqn.get_workload( mod, params = testing.dqn.get_workload(
batch_size=1, image_shape=image_shape) batch_size=1, image_shape=image_shape)
benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 18)) benchmark_execution(mod, params, data_shape=image_shape, out_shape=(1, 18))
def test_dcgan(): def test_dcgan():
image_shape = (1, 100) image_shape = (1, 100)
net, params = testing.dcgan.get_workload(batch_size=1) mod, params = testing.dcgan.get_workload(batch_size=1)
benchmark_execution(net, params, data_shape=image_shape) benchmark_execution(mod, params, data_shape=image_shape)
def test_mobilenet(): def test_mobilenet():
net, params = testing.mobilenet.get_workload(batch_size=1) mod, params = testing.mobilenet.get_workload(batch_size=1)
benchmark_execution(net, params) benchmark_execution(mod, params)
def test_densenet(): def test_densenet():
net, params = testing.densenet.get_workload(batch_size=1) mod, params = testing.densenet.get_workload(batch_size=1)
benchmark_execution(net, params) benchmark_execution(mod, params)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -24,46 +24,46 @@ def get_network(name, batch_size): ...@@ -24,46 +24,46 @@ def get_network(name, batch_size):
input_shape = (batch_size, 3, 224, 224) input_shape = (batch_size, 3, 224, 224)
if name == 'resnet-18': if name == 'resnet-18':
net, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=batch_size) mod, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
elif name == 'mobilenet': elif name == 'mobilenet':
net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size) mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
elif name == 'dcgan': elif name == 'dcgan':
net, params = relay.testing.dcgan.get_workload(batch_size=batch_size) mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size)
input_shape = (batch_size, 100) input_shape = (batch_size, 100)
else: else:
raise ValueError("Unsupported network: " + name) raise ValueError("Unsupported network: " + name)
return net, params, input_shape return mod, params, input_shape
def test_task_extraction(): def test_task_extraction():
target = 'llvm' target = 'llvm'
net, params, input_shape = get_network('resnet-18', batch_size=1) mod, params, input_shape = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target, tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params, params=params,
ops=(relay.op.nn.conv2d,)) ops=(relay.op.nn.conv2d,))
assert len(tasks) == 12 assert len(tasks) == 12
net, params, input_shape = get_network('resnet-18', batch_size=1) mod, params, input_shape = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target, tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params, params=params,
ops=(relay.op.nn.dense,)) ops=(relay.op.nn.dense,))
assert len(tasks) == 1 assert len(tasks) == 1
net, params, input_shape = get_network('resnet-18', batch_size=1) mod, params, input_shape = get_network('resnet-18', batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target, tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params, params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense)) ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 13 assert len(tasks) == 13
net, params, input_shape = get_network('mobilenet', batch_size=1) mod, params, input_shape = get_network('mobilenet', batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target, tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params, params=params,
ops=(relay.op.nn.conv2d, relay.op.nn.dense)) ops=(relay.op.nn.conv2d, relay.op.nn.dense))
assert len(tasks) == 20 assert len(tasks) == 20
net, params, input_shape = get_network('dcgan', batch_size=1) mod, params, input_shape = get_network('dcgan', batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target, tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params, params=params,
ops=(relay.op.nn.conv2d_transpose,)) ops=(relay.op.nn.conv2d_transpose,))
assert len(tasks) == 4 assert len(tasks) == 4
......
...@@ -29,7 +29,7 @@ def test_compile_engine(): ...@@ -29,7 +29,7 @@ def test_compile_engine():
f = relay.Function([x], z) f = relay.Function([x], z)
mod = relay.Module.from_expr(f) mod = relay.Module.from_expr(f)
mod = relay.transform.InferType()(mod) mod = relay.transform.InferType()(mod)
return mod[mod.entry_func] return mod["main"]
z1 = engine.lower(get_func((10,)), "llvm") z1 = engine.lower(get_func((10,)), "llvm")
z2 = engine.lower(get_func((10,)), "llvm") z2 = engine.lower(get_func((10,)), "llvm")
z3 = engine.lower(get_func(()), "llvm") z3 = engine.lower(get_func(()), "llvm")
......
...@@ -125,7 +125,7 @@ def test_plan_memory(): ...@@ -125,7 +125,7 @@ def test_plan_memory():
func = relay.Function([x, y], z) func = relay.Function([x, y], z)
mod = relay.Module.from_expr(func) mod = relay.Module.from_expr(func)
mod = relay.transform.FuseOps(0)(mod) mod = relay.transform.FuseOps(0)(mod)
func = mod[mod.entry_func] func = mod["main"]
smap = relay.backend._backend.GraphPlanMemory(func) smap = relay.backend._backend.GraphPlanMemory(func)
storage_ids = set() storage_ids = set()
device_types = set() device_types = set()
......
...@@ -224,9 +224,8 @@ def test_tuple_passing(): ...@@ -224,9 +224,8 @@ def test_tuple_passing():
fn = relay.Function([x], relay.expr.TupleGetItem(x, 0)) fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
mod = relay.Module({}) mod = relay.Module({})
gv = relay.GlobalVar('fn') gv = relay.GlobalVar('main')
mod[gv] = fn mod[gv] = fn
mod.entry_func = gv
mod = relay.transform.InferType()(mod) mod = relay.transform.InferType()(mod)
ctx = tvm.cpu() ctx = tvm.cpu()
......
...@@ -21,7 +21,7 @@ def check_type_err(expr, msg): ...@@ -21,7 +21,7 @@ def check_type_err(expr, msg):
try: try:
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = relay.transform.InferType()(mod) mod = relay.transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
expr = entry if isinstance(expr, relay.Function) else entry.body expr = entry if isinstance(expr, relay.Function) else entry.body
assert False assert False
except tvm.TVMError as err: except tvm.TVMError as err:
......
...@@ -49,7 +49,7 @@ def test_ad(): ...@@ -49,7 +49,7 @@ def test_ad():
func = relay.Function([x], x + x) func = relay.Function([x], x + x)
mod = relay.Module.from_expr(gradient(func)) mod = relay.Module.from_expr(gradient(func))
mod = relay.transform.InferType()(mod) mod = relay.transform.InferType()(mod)
back_func = mod[mod.entry_func] back_func = mod["main"]
feats = detect_feature(back_func) feats = detect_feature(back_func)
assert feats == set([ assert feats == set([
Feature.fVar, Feature.fVar,
......
...@@ -24,7 +24,7 @@ from tvm.relay.testing import ctx_list ...@@ -24,7 +24,7 @@ from tvm.relay.testing import ctx_list
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = relay.transform.InferType()(mod) mod = relay.transform.InferType()(mod)
return mod[mod.entry_func] return mod["main"]
def sigmoid(x): def sigmoid(x):
......
...@@ -24,7 +24,7 @@ import topi.testing ...@@ -24,7 +24,7 @@ import topi.testing
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
def sigmoid(x): def sigmoid(x):
......
...@@ -28,7 +28,7 @@ import topi.testing ...@@ -28,7 +28,7 @@ import topi.testing
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
def test_collapse_sum_like(): def test_collapse_sum_like():
......
...@@ -26,7 +26,7 @@ import topi.testing ...@@ -26,7 +26,7 @@ import topi.testing
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
def test_conv2d_infer_type(): def test_conv2d_infer_type():
......
...@@ -26,7 +26,7 @@ from tvm.relay.testing import ctx_list ...@@ -26,7 +26,7 @@ from tvm.relay.testing import ctx_list
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
def test_zeros_ones(): def test_zeros_ones():
......
...@@ -24,7 +24,7 @@ import topi.testing ...@@ -24,7 +24,7 @@ import topi.testing
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
def test_binary_op(): def test_binary_op():
......
...@@ -27,7 +27,7 @@ import topi.testing ...@@ -27,7 +27,7 @@ import topi.testing
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
def test_resize_infer_type(): def test_resize_infer_type():
......
...@@ -28,7 +28,7 @@ def run_opt_pass(expr, passes): ...@@ -28,7 +28,7 @@ def run_opt_pass(expr, passes):
seq = transform.Sequential(passes) seq = transform.Sequential(passes)
with transform.PassContext(opt_level=3): with transform.PassContext(opt_level=3):
mod = seq(mod) mod = seq(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -31,7 +31,7 @@ def run_opt_pass(expr, passes): ...@@ -31,7 +31,7 @@ def run_opt_pass(expr, passes):
seq = transform.Sequential(passes) seq = transform.Sequential(passes)
with transform.PassContext(opt_level=3): with transform.PassContext(opt_level=3):
mod = seq(mod) mod = seq(mod)
return mod[mod.entry_func] return mod["main"]
def test_redundant_annotation(): def test_redundant_annotation():
......
...@@ -58,7 +58,7 @@ def test_canonicalize_cast(): ...@@ -58,7 +58,7 @@ def test_canonicalize_cast():
_transform.InferType()]) _transform.InferType()])
with _transform.PassContext(opt_level=3): with _transform.PassContext(opt_level=3):
mod = seq(mod) mod = seq(mod)
y = mod[mod.entry_func.name_hint] y = mod["main"]
y_expected = expected(data, conv_weight, bias1, bias2) y_expected = expected(data, conv_weight, bias1, bias2)
gv = relay.GlobalVar("expected") gv = relay.GlobalVar("expected")
mod[gv] = y_expected mod[gv] = y_expected
......
...@@ -21,13 +21,13 @@ from tvm.relay import transform ...@@ -21,13 +21,13 @@ from tvm.relay import transform
def run_combine_parallel(expr, min_num_branches=3): def run_combine_parallel(expr, min_num_branches=3):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = transform.CombineParallelConv2D(min_num_branches)(mod) mod = transform.CombineParallelConv2D(min_num_branches)(mod)
return mod[mod.entry_func] return mod["main"]
def run_opt_pass(expr, opt_pass): def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass) assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = opt_pass(mod) mod = opt_pass(mod)
return mod[mod.entry_func] return mod["main"]
def test_combine_parallel_conv2d(): def test_combine_parallel_conv2d():
......
...@@ -49,7 +49,7 @@ def run_opt_pass(expr, opt_pass): ...@@ -49,7 +49,7 @@ def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass) assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = opt_pass(mod) mod = opt_pass(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -24,7 +24,7 @@ def run_opt_pass(expr, opt_pass): ...@@ -24,7 +24,7 @@ def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass) assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = opt_pass(mod) mod = opt_pass(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -26,7 +26,7 @@ def test_eta_expand_basic(): ...@@ -26,7 +26,7 @@ def test_eta_expand_basic():
with _transform.PassContext(opt_level=3): with _transform.PassContext(opt_level=3):
mod = seq(mod) mod = seq(mod)
got = mod[mod.entry_func.name_hint] got = mod["main"]
y = relay.var('y', 'int32') y = relay.var('y', 'int32')
expected = relay.Function([y], orig(y)) expected = relay.Function([y], orig(y))
......
...@@ -25,7 +25,7 @@ def run_opt_pass(expr, opt_pass): ...@@ -25,7 +25,7 @@ def run_opt_pass(expr, opt_pass):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = opt_pass(mod) mod = opt_pass(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -27,7 +27,7 @@ def run_opt_pass(expr, opt_pass): ...@@ -27,7 +27,7 @@ def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass) assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = opt_pass(mod) mod = opt_pass(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -357,7 +357,7 @@ def test_tuple_intermediate(): ...@@ -357,7 +357,7 @@ def test_tuple_intermediate():
m = fuse2(relay.Module.from_expr(orig)) m = fuse2(relay.Module.from_expr(orig))
relay.build(m, 'llvm') relay.build(m, 'llvm')
after = run_opt_pass(expected(x), transform.InferType()) after = run_opt_pass(expected(x), transform.InferType())
assert relay.analysis.alpha_equal(m[m.entry_func], after) assert relay.analysis.alpha_equal(m["main"], after)
def test_tuple_consecutive(): def test_tuple_consecutive():
...@@ -412,7 +412,7 @@ def test_tuple_consecutive(): ...@@ -412,7 +412,7 @@ def test_tuple_consecutive():
m = fuse2(relay.Module.from_expr(orig)) m = fuse2(relay.Module.from_expr(orig))
relay.build(m, 'llvm') relay.build(m, 'llvm')
after = run_opt_pass(expected(dshape), transform.InferType()) after = run_opt_pass(expected(dshape), transform.InferType())
assert relay.analysis.alpha_equal(m[m.entry_func], after) assert relay.analysis.alpha_equal(m["main"], after)
def test_inception_like(): def test_inception_like():
...@@ -479,7 +479,7 @@ def test_inception_like(): ...@@ -479,7 +479,7 @@ def test_inception_like():
m = fuse2(relay.Module.from_expr(orig)) m = fuse2(relay.Module.from_expr(orig))
relay.build(m, 'llvm') relay.build(m, 'llvm')
after = run_opt_pass(expected(dshape), transform.InferType()) after = run_opt_pass(expected(dshape), transform.InferType())
assert relay.analysis.alpha_equal(m[m.entry_func], after) assert relay.analysis.alpha_equal(m["main"], after)
def test_fuse_parallel_injective(): def test_fuse_parallel_injective():
......
...@@ -185,9 +185,9 @@ def test_pow(): ...@@ -185,9 +185,9 @@ def test_pow():
i = relay.var("i", t) i = relay.var("i", t)
func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i)) func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
func = gradient(func, mod=mod) func = gradient(func, mod=mod)
mod[mod.entry_func] = func mod["main"] = func
m = transform.InferType()(mod) m = transform.InferType()(mod)
back_func = m[m.entry_func] back_func = m["main"]
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
i_nd = rand(dtype, *shape) i_nd = rand(dtype, *shape)
ex = create_executor(mod=mod) ex = create_executor(mod=mod)
......
...@@ -25,7 +25,7 @@ def run_opt_pass(expr, opt_pass): ...@@ -25,7 +25,7 @@ def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass) assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = opt_pass(mod) mod = opt_pass(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -29,7 +29,7 @@ from tvm.relay.testing import ctx_list ...@@ -29,7 +29,7 @@ from tvm.relay.testing import ctx_list
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = _transform.InferType()(mod) mod = _transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -41,7 +41,7 @@ def run_opt_pass(expr, passes): ...@@ -41,7 +41,7 @@ def run_opt_pass(expr, passes):
seq = transform.Sequential(passes) seq = transform.Sequential(passes)
with transform.PassContext(opt_level=3): with transform.PassContext(opt_level=3):
mod = seq(mod) mod = seq(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
...@@ -57,10 +57,10 @@ def dcpe(expr, mod=None, grad=False): ...@@ -57,10 +57,10 @@ def dcpe(expr, mod=None, grad=False):
expr = gradient(expr) expr = gradient(expr)
if mod: if mod:
assert isinstance(expr, Function) assert isinstance(expr, Function)
mod[mod.entry_func] = expr mod["main"] = expr
seq = transform.Sequential(passes) seq = transform.Sequential(passes)
mod = seq(mod) mod = seq(mod)
return mod[mod.entry_func] return mod["main"]
return run_opt_pass(expr, passes) return run_opt_pass(expr, passes)
...@@ -192,8 +192,8 @@ def test_map(): ...@@ -192,8 +192,8 @@ def test_map():
orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil())))) orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
expected = p.cons((const(1)), p.cons((const(2)), p.cons((const(3)), p.nil()))) expected = p.cons((const(1)), p.cons((const(2)), p.cons((const(3)), p.nil())))
expected = Function([], expected) expected = Function([], expected)
mod[mod.entry_func] = expected mod["main"] = expected
expected = mod[mod.entry_func] expected = mod["main"]
orig = Function([], orig) orig = Function([], orig)
res = dcpe(orig, mod=mod) res = dcpe(orig, mod=mod)
assert alpha_equal(res.body, expected.body) assert alpha_equal(res.body, expected.body)
...@@ -206,8 +206,8 @@ def test_loop(): ...@@ -206,8 +206,8 @@ def test_loop():
loop = GlobalVar("loop") loop = GlobalVar("loop")
mod[loop] = Function([x], loop(x), t, [t]) mod[loop] = Function([x], loop(x), t, [t])
expected = Call(loop, [const(1)]) expected = Call(loop, [const(1)])
mod[mod.entry_func] = Function([], expected) mod["main"] = Function([], expected)
expected = mod[mod.entry_func].body expected = mod["main"].body
call = Function([], loop(const(1))) call = Function([], loop(const(1)))
res = dcpe(call, mod=mod) res = dcpe(call, mod=mod)
assert alpha_equal(res.body, expected) assert alpha_equal(res.body, expected)
......
...@@ -25,7 +25,7 @@ from tvm.relay import transform ...@@ -25,7 +25,7 @@ from tvm.relay import transform
def run_infer_type(expr): def run_infer_type(expr):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -30,7 +30,7 @@ def run_opt_pass(expr, passes): ...@@ -30,7 +30,7 @@ def run_opt_pass(expr, passes):
seq = transform.Sequential(passes) seq = transform.Sequential(passes)
with transform.PassContext(opt_level=3): with transform.PassContext(opt_level=3):
mod = seq(mod) mod = seq(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
...@@ -195,7 +195,7 @@ def test_gradient_if(): ...@@ -195,7 +195,7 @@ def test_gradient_if():
net = relay.Function([cond,x,y], net) net = relay.Function([cond,x,y], net)
mod = relay.Module.from_expr(net) mod = relay.Module.from_expr(net)
mod = relay.transform.ToANormalForm()(mod) mod = relay.transform.ToANormalForm()(mod)
mod[mod.entry_func] = relay.transform.gradient(mod[mod.entry_func], mode='higher_order') mod["main"] = relay.transform.gradient(mod["main"], mode='higher_order')
mod = relay.transform.ToANormalForm()(mod) mod = relay.transform.ToANormalForm()(mod)
......
...@@ -42,12 +42,12 @@ def test_recursion(): ...@@ -42,12 +42,12 @@ def test_recursion():
double = relay.Function([x], x + x) double = relay.Function([x], x + x)
i = relay.var("i", t) i = relay.var("i", t)
func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i)) func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
mod[mod.entry_func] = func mod["main"] = func
mod[mod.entry_func] = to_cps(mod[mod.entry_func], mod=mod) mod["main"] = to_cps(mod["main"], mod=mod)
mod[mod.entry_func] = un_cps(mod[mod.entry_func]) mod["main"] = un_cps(mod["main"])
ex = create_executor(mod=mod) ex = create_executor(mod=mod)
i_nd = rand(dtype, *shape) i_nd = rand(dtype, *shape)
forward = ex.evaluate(mod.entry_func)(i_nd) forward = ex.evaluate()(i_nd)
tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy()) tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
......
...@@ -24,7 +24,7 @@ from tvm.relay.analysis import detect_feature ...@@ -24,7 +24,7 @@ from tvm.relay.analysis import detect_feature
def run_opt_pass(expr, opt_pass): def run_opt_pass(expr, opt_pass):
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = opt_pass(mod) mod = opt_pass(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
......
...@@ -25,7 +25,7 @@ def run_infer_type(expr, mod=None): ...@@ -25,7 +25,7 @@ def run_infer_type(expr, mod=None):
if not mod: if not mod:
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
else: else:
if isinstance(expr, relay.GlobalVar): if isinstance(expr, relay.GlobalVar):
...@@ -34,7 +34,7 @@ def run_infer_type(expr, mod=None): ...@@ -34,7 +34,7 @@ def run_infer_type(expr, mod=None):
func = expr func = expr
if not isinstance(expr, relay.Function): if not isinstance(expr, relay.Function):
func = relay.Function(analysis.free_vars(expr), expr) func = relay.Function(analysis.free_vars(expr), expr)
mod[mod.entry_func] = func mod["main"] = func
gv = "main" gv = "main"
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
...@@ -266,7 +266,7 @@ def test_type_args(): ...@@ -266,7 +266,7 @@ def test_type_args():
def test_global_var_recursion(): def test_global_var_recursion():
mod = relay.Module({}) mod = relay.Module({})
gv = relay.GlobalVar("foo") gv = relay.GlobalVar("main")
x = relay.var('x', shape=[]) x = relay.var('x', shape=[])
tt = relay.scalar_type('float32') tt = relay.scalar_type('float32')
......
...@@ -25,7 +25,7 @@ def test_dup_type(): ...@@ -25,7 +25,7 @@ def test_dup_type():
b = relay.Var("b", t) b = relay.Var("b", t)
mod = relay.Module.from_expr(make_id(b)) mod = relay.Module.from_expr(make_id(b))
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
inferred = mod[mod.entry_func].body inferred = mod["main"].body
assert inferred.checked_type == relay.TupleType([t, t]) assert inferred.checked_type == relay.TupleType([t, t])
...@@ -39,9 +39,9 @@ def test_id_type(): ...@@ -39,9 +39,9 @@ def test_id_type():
make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b])) make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b]))
t = relay.scalar_type("float32") t = relay.scalar_type("float32")
b = relay.Var("b", t) b = relay.Var("b", t)
mod[mod.entry_func] = relay.Function([], make_id(b)) mod["main"] = relay.Function([], make_id(b))
mod = transform.InferType()(mod) mod = transform.InferType()(mod)
assert mod[mod.entry_func].body.checked_type == id_type(t) assert mod["main"].body.checked_type == id_type(t)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -121,7 +121,7 @@ def test_simple_call(): ...@@ -121,7 +121,7 @@ def test_simple_call():
mod[sum_up] = func mod[sum_up] = func
i_data = np.array(0, dtype='int32') i_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32') iarg = relay.var('i', shape=[], dtype='int32')
mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg)) mod["main"] = relay.Function([iarg], sum_up(iarg))
result = veval(mod, i_data) result = veval(mod, i_data)
tvm.testing.assert_allclose(result.asnumpy(), i_data) tvm.testing.assert_allclose(result.asnumpy(), i_data)
...@@ -140,7 +140,7 @@ def test_count_loop(): ...@@ -140,7 +140,7 @@ def test_count_loop():
mod[sum_up] = func mod[sum_up] = func
i_data = np.array(0, dtype='int32') i_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32') iarg = relay.var('i', shape=[], dtype='int32')
mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg)) mod["main"] = relay.Function([iarg], sum_up(iarg))
result = veval(mod, i_data) result = veval(mod, i_data)
tvm.testing.assert_allclose(result.asnumpy(), i_data) tvm.testing.assert_allclose(result.asnumpy(), i_data)
...@@ -163,7 +163,7 @@ def test_sum_loop(): ...@@ -163,7 +163,7 @@ def test_sum_loop():
accum_data = np.array(0, dtype='int32') accum_data = np.array(0, dtype='int32')
iarg = relay.var('i', shape=[], dtype='int32') iarg = relay.var('i', shape=[], dtype='int32')
aarg = relay.var('accum', shape=[], dtype='int32') aarg = relay.var('accum', shape=[], dtype='int32')
mod[mod.entry_func] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
result = veval(mod, i_data, accum_data) result = veval(mod, i_data, accum_data)
tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1))) tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
...@@ -212,7 +212,7 @@ def test_list_constructor(): ...@@ -212,7 +212,7 @@ def test_list_constructor():
one4 = cons(relay.const(3), one3) one4 = cons(relay.const(3), one3)
f = relay.Function([], one4) f = relay.Function([], one4)
mod[mod.entry_func] = f mod["main"] = f
result = veval(mod)() result = veval(mod)()
obj = to_list(result) obj = to_list(result)
...@@ -284,7 +284,7 @@ def test_compose(): ...@@ -284,7 +284,7 @@ def test_compose():
mod[add_one] = add_one_func mod[add_one] = add_one_func
f = relay.Function([y], add_two_body) f = relay.Function([y], add_two_body)
mod[mod.entry_func] = f mod["main"] = f
x_data = np.array(np.random.rand()).astype('float32') x_data = np.array(np.random.rand()).astype('float32')
result = veval(mod)(x_data) result = veval(mod)(x_data)
......
...@@ -44,8 +44,8 @@ def _create_data(target, dshape, dtype, layout): ...@@ -44,8 +44,8 @@ def _create_data(target, dshape, dtype, layout):
conv2 = relay.nn.conv2d(conv1, w2, channels=32, kernel_size=(3, 3), padding=(1, 1)) conv2 = relay.nn.conv2d(conv1, w2, channels=32, kernel_size=(3, 3), padding=(1, 1))
out = relay.add(conv1, conv2) out = relay.add(conv1, conv2)
net = relay.Function(relay.analysis.free_vars(out), out) net = relay.Function(relay.analysis.free_vars(out), out)
net, params = relay.testing.create_workload(net) mod, params = relay.testing.create_workload(net)
tasks = autotvm.task.extract_from_program(net, tasks = autotvm.task.extract_from_program(mod["main"],
target=target, target=target,
params=params, params=params,
ops=(relay.op.nn.conv2d,)) ops=(relay.op.nn.conv2d,))
...@@ -160,7 +160,7 @@ def test_DPTuner_run(): ...@@ -160,7 +160,7 @@ def test_DPTuner_run():
g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout) g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
mod = relay.module.Module() mod = relay.module.Module()
mod[mod.entry_func] = g mod["main"] = g
costs = [0.02, 0.02, 0.045] costs = [0.02, 0.02, 0.045]
config_list = [] config_list = []
cfg_dict = {"i": -1, cfg_dict = {"i": -1,
......
...@@ -64,7 +64,7 @@ def test_has_multiple_inputs(): ...@@ -64,7 +64,7 @@ def test_has_multiple_inputs():
def test_expr2graph(): def test_expr2graph():
net, _ = resnet.get_workload(num_layers=50, batch_size=1) mod, _ = resnet.get_workload(num_layers=50, batch_size=1)
node_dict = {} node_dict = {}
node_list = [] node_list = []
target_ops = ["conv2d"] target_ops = ["conv2d"]
...@@ -80,9 +80,9 @@ def test_expr2graph(): ...@@ -80,9 +80,9 @@ def test_expr2graph():
op_name_list.append("Tuple") op_name_list.append("Tuple")
else: else:
op_name_list.append("null") op_name_list.append("null")
relay.analysis.post_order_visit(net, _count_node) relay.analysis.post_order_visit(mod["main"], _count_node)
expr2graph(net, target_ops, node_dict, node_list) expr2graph(mod["main"], target_ops, node_dict, node_list)
for i, item in enumerate(zip(op_name_list, node_list)): for i, item in enumerate(zip(op_name_list, node_list)):
op_name, node = item op_name, node = item
assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \ assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \
......
...@@ -81,28 +81,29 @@ def get_network(name, batch_size): ...@@ -81,28 +81,29 @@ def get_network(name, batch_size):
if "resnet" in name: if "resnet" in name:
n_layer = int(name.split('-')[1]) n_layer = int(name.split('-')[1])
net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif "vgg" in name: elif "vgg" in name:
n_layer = int(name.split('-')[1]) n_layer = int(name.split('-')[1])
net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif name == 'mobilenet': elif name == 'mobilenet':
net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size) mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
elif name == 'squeezenet_v1.1': elif name == 'squeezenet_v1.1':
net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype) mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
elif name == 'inception_v3': elif name == 'inception_v3':
input_shape = (1, 3, 299, 299) input_shape = (1, 3, 299, 299)
net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'mxnet': elif name == 'mxnet':
# an example for mxnet model # an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True) block = get_model('resnet18_v1', pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func] net = mod["main"]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
mod = relay.Module.from_expr(net)
else: else:
raise ValueError("Unsupported network: " + name) raise ValueError("Unsupported network: " + name)
return net, params, input_shape, output_shape return mod, params, input_shape, output_shape
################################################################# #################################################################
...@@ -316,8 +317,8 @@ def tune_tasks(tasks, ...@@ -316,8 +317,8 @@ def tune_tasks(tasks,
def tune_and_evaluate(tuning_opt): def tune_and_evaluate(tuning_opt):
# extract workloads from relay program # extract workloads from relay program
print("Extract tasks...") print("Extract tasks...")
net, params, input_shape, _ = get_network(network, batch_size=1) mod, params, input_shape, _ = get_network(network, batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target, tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params, params=params,
ops=(relay.op.nn.conv2d,)) ops=(relay.op.nn.conv2d,))
...@@ -330,7 +331,7 @@ def tune_and_evaluate(tuning_opt): ...@@ -330,7 +331,7 @@ def tune_and_evaluate(tuning_opt):
print("Compile...") print("Compile...")
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build_module.build( graph, lib, params = relay.build_module.build(
net, target=target, params=params) mod, target=target, params=params)
# export library # export library
tmp = tempdir() tmp = tempdir()
......
...@@ -81,28 +81,29 @@ def get_network(name, batch_size): ...@@ -81,28 +81,29 @@ def get_network(name, batch_size):
if "resnet" in name: if "resnet" in name:
n_layer = int(name.split('-')[1]) n_layer = int(name.split('-')[1])
net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif "vgg" in name: elif "vgg" in name:
n_layer = int(name.split('-')[1]) n_layer = int(name.split('-')[1])
net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif name == 'mobilenet': elif name == 'mobilenet':
net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype) mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'squeezenet_v1.1': elif name == 'squeezenet_v1.1':
net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype) mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
elif name == 'inception_v3': elif name == 'inception_v3':
input_shape = (1, 3, 299, 299) input_shape = (1, 3, 299, 299)
net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'mxnet': elif name == 'mxnet':
# an example for mxnet model # an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True) block = get_model('resnet18_v1', pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func] net = mod["main"]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
mod = relay.Module.from_expr(net)
else: else:
raise ValueError("Unsupported network: " + name) raise ValueError("Unsupported network: " + name)
return net, params, input_shape, output_shape return mod, params, input_shape, output_shape
########################################### ###########################################
# Set Tuning Options # Set Tuning Options
...@@ -218,8 +219,8 @@ def tune_tasks(tasks, ...@@ -218,8 +219,8 @@ def tune_tasks(tasks,
def tune_and_evaluate(tuning_opt): def tune_and_evaluate(tuning_opt):
# extract workloads from relay program # extract workloads from relay program
print("Extract tasks...") print("Extract tasks...")
net, params, input_shape, out_shape = get_network(network, batch_size=1) mod, params, input_shape, out_shape = get_network(network, batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target, tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params, ops=(relay.op.nn.conv2d,)) params=params, ops=(relay.op.nn.conv2d,))
# run tuning tasks # run tuning tasks
...@@ -231,7 +232,7 @@ def tune_and_evaluate(tuning_opt): ...@@ -231,7 +232,7 @@ def tune_and_evaluate(tuning_opt):
print("Compile...") print("Compile...")
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build_module.build( graph, lib, params = relay.build_module.build(
net, target=target, params=params) mod, target=target, params=params)
# export library # export library
tmp = tempdir() tmp = tempdir()
......
...@@ -82,28 +82,29 @@ def get_network(name, batch_size): ...@@ -82,28 +82,29 @@ def get_network(name, batch_size):
if "resnet" in name: if "resnet" in name:
n_layer = int(name.split('-')[1]) n_layer = int(name.split('-')[1])
net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif "vgg" in name: elif "vgg" in name:
n_layer = int(name.split('-')[1]) n_layer = int(name.split('-')[1])
net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif name == 'mobilenet': elif name == 'mobilenet':
net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype) mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'squeezenet_v1.1': elif name == 'squeezenet_v1.1':
net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype) mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
elif name == 'inception_v3': elif name == 'inception_v3':
input_shape = (1, 3, 299, 299) input_shape = (1, 3, 299, 299)
net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'mxnet': elif name == 'mxnet':
# an example for mxnet model # an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True) block = get_model('resnet18_v1', pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func] net = mod["main"]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
mod = relay.Module.from_expr(net)
else: else:
raise ValueError("Unsupported network: " + name) raise ValueError("Unsupported network: " + name)
return net, params, input_shape, output_shape return mod, params, input_shape, output_shape
################################################################# #################################################################
...@@ -300,8 +301,10 @@ def tune_tasks(tasks, ...@@ -300,8 +301,10 @@ def tune_tasks(tasks,
def tune_and_evaluate(tuning_opt): def tune_and_evaluate(tuning_opt):
# extract workloads from relay program # extract workloads from relay program
print("Extract tasks...") print("Extract tasks...")
net, params, input_shape, _ = get_network(network, batch_size=1) mod, params, input_shape, _ = get_network(network, batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target, target_host=target_host, tasks = autotvm.task.extract_from_program(mod["main"],
target=target,
target_host=target_host,
params=params, ops=(relay.op.nn.conv2d,)) params=params, ops=(relay.op.nn.conv2d,))
# run tuning tasks # run tuning tasks
...@@ -313,7 +316,7 @@ def tune_and_evaluate(tuning_opt): ...@@ -313,7 +316,7 @@ def tune_and_evaluate(tuning_opt):
print("Compile...") print("Compile...")
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build_module.build( graph, lib, params = relay.build_module.build(
net, target=target, params=params, target_host=target_host) mod, target=target, params=params, target_host=target_host)
# export library # export library
tmp = tempdir() tmp = tempdir()
if use_android: if use_android:
......
...@@ -49,28 +49,29 @@ def get_network(name, batch_size): ...@@ -49,28 +49,29 @@ def get_network(name, batch_size):
if "resnet" in name: if "resnet" in name:
n_layer = int(name.split('-')[1]) n_layer = int(name.split('-')[1])
net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif "vgg" in name: elif "vgg" in name:
n_layer = int(name.split('-')[1]) n_layer = int(name.split('-')[1])
net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif name == 'mobilenet': elif name == 'mobilenet':
net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype) mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'squeezenet_v1.1': elif name == 'squeezenet_v1.1':
net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype) mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
elif name == 'inception_v3': elif name == 'inception_v3':
input_shape = (1, 3, 299, 299) input_shape = (1, 3, 299, 299)
net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'mxnet': elif name == 'mxnet':
# an example for mxnet model # an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True) block = get_model('resnet18_v1', pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func] net = mod["main"]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
mod = relay.Module.from_expr(net)
else: else:
raise ValueError("Unsupported network: " + name) raise ValueError("Unsupported network: " + name)
return net, params, input_shape, output_shape return mod, params, input_shape, output_shape
# Replace "llvm" with the correct target of your CPU. # Replace "llvm" with the correct target of your CPU.
# For example, for AWS EC2 c5 instance with Intel Xeon # For example, for AWS EC2 c5 instance with Intel Xeon
...@@ -177,21 +178,21 @@ def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True): ...@@ -177,21 +178,21 @@ def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
def tune_and_evaluate(tuning_opt): def tune_and_evaluate(tuning_opt):
# extract workloads from relay program # extract workloads from relay program
print("Extract tasks...") print("Extract tasks...")
net, params, data_shape, out_shape = get_network(model_name, batch_size) mod, params, data_shape, out_shape = get_network(model_name, batch_size)
tasks = autotvm.task.extract_from_program(net, target=target, tasks = autotvm.task.extract_from_program(mod["main"], target=target,
params=params, ops=(relay.op.nn.conv2d,)) params=params, ops=(relay.op.nn.conv2d,))
# run tuning tasks # run tuning tasks
print("Tuning...") print("Tuning...")
tune_kernels(tasks, **tuning_opt) tune_kernels(tasks, **tuning_opt)
tune_graph(net, data_shape, log_file, graph_opt_sch_file) tune_graph(mod["main"], data_shape, log_file, graph_opt_sch_file)
# compile kernels with graph-level best records # compile kernels with graph-level best records
with autotvm.apply_graph_best(graph_opt_sch_file): with autotvm.apply_graph_best(graph_opt_sch_file):
print("Compile...") print("Compile...")
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build_module.build( graph, lib, params = relay.build_module.build(
net, target=target, params=params) mod, target=target, params=params)
# upload parameters to device # upload parameters to device
ctx = tvm.cpu() ctx = tvm.cpu()
......
...@@ -142,7 +142,7 @@ with open(synset_path) as f: ...@@ -142,7 +142,7 @@ with open(synset_path) as f:
shape_dict = {'data': x.shape} shape_dict = {'data': x.shape}
mod, params = relay.frontend.from_mxnet(block, shape_dict) mod, params = relay.frontend.from_mxnet(block, shape_dict)
# we want a probability so add a softmax operator # we want a probability so add a softmax operator
func = mod[mod.entry_func] func = mod["main"]
func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)
###################################################################### ######################################################################
......
...@@ -84,7 +84,7 @@ print('x', x.shape) ...@@ -84,7 +84,7 @@ print('x', x.shape)
shape_dict = {'data': x.shape} shape_dict = {'data': x.shape}
mod, params = relay.frontend.from_mxnet(block, shape_dict) mod, params = relay.frontend.from_mxnet(block, shape_dict)
## we want a probability so add a softmax operator ## we want a probability so add a softmax operator
func = mod[mod.entry_func] func = mod["main"]
func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)
###################################################################### ######################################################################
......
...@@ -65,11 +65,11 @@ image_shape = (3, 224, 224) ...@@ -65,11 +65,11 @@ image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_class) out_shape = (batch_size, num_class)
net, params = relay.testing.resnet.get_workload( mod, params = relay.testing.resnet.get_workload(
num_layers=18, batch_size=batch_size, image_shape=image_shape) num_layers=18, batch_size=batch_size, image_shape=image_shape)
# set show_meta_data=True if you want to show meta data # set show_meta_data=True if you want to show meta data
print(net.astext(show_meta_data=False)) print(mod.astext(show_meta_data=False))
###################################################################### ######################################################################
# Compilation # Compilation
...@@ -98,7 +98,7 @@ opt_level = 3 ...@@ -98,7 +98,7 @@ opt_level = 3
target = tvm.target.cuda() target = tvm.target.cuda()
with relay.build_config(opt_level=opt_level): with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build_module.build( graph, lib, params = relay.build_module.build(
net, target, params=params) mod, target, params=params)
##################################################################### #####################################################################
# Run the generate library # Run the generate library
......
...@@ -26,7 +26,7 @@ def run_opt_pass(expr, opt_pass): ...@@ -26,7 +26,7 @@ def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, transform.Pass) assert isinstance(opt_pass, transform.Pass)
mod = relay.Module.from_expr(expr) mod = relay.Module.from_expr(expr)
mod = opt_pass(mod) mod = opt_pass(mod)
entry = mod[mod.entry_func] entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body return entry if isinstance(expr, relay.Function) else entry.body
def _to_shape(shape): def _to_shape(shape):
......
...@@ -127,7 +127,7 @@ def compile_network(opt, env, target): ...@@ -127,7 +127,7 @@ def compile_network(opt, env, target):
# Perform quantization in Relay # Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0, with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]): skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params) relay_prog = relay.quantize.quantize(mod["main"], params=params)
# Perform graph packing and constant folding for VTA target # Perform graph packing and constant folding for VTA target
if target.device_name == "vta": if target.device_name == "vta":
......
...@@ -91,7 +91,7 @@ def compile_network(env, target, model, start_pack, stop_pack): ...@@ -91,7 +91,7 @@ def compile_network(env, target, model, start_pack, stop_pack):
# Perform quantization in Relay # Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0, with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]): skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params) relay_prog = relay.quantize.quantize(mod["main"], params=params)
# Perform graph packing and constant folding for VTA target # Perform graph packing and constant folding for VTA target
if target.device_name == "vta": if target.device_name == "vta":
......
...@@ -160,7 +160,7 @@ with autotvm.tophub.context(target): ...@@ -160,7 +160,7 @@ with autotvm.tophub.context(target):
# Perform quantization in Relay # Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0, with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]): skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params) relay_prog = relay.quantize.quantize(mod["main"], params=params)
# Perform graph packing and constant folding for VTA target # Perform graph packing and constant folding for VTA target
if target.device_name == "vta": if target.device_name == "vta":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment