Commit 3a5b9fcb by Tianqi Chen

[DOCS][FRONTEND] Modify from_mxnet to also return params, update docs (#36)

parent 34d74282
nnvm.compiler
-------------
.. automodule:: nnvm.compiler
.. autofunction:: nnvm.compiler.build
.. autofunction:: nnvm.compiler.build_config
.. autofunction:: nnvm.compiler.optimize
.. automodule:: nnvm.compiler.graph_util
   :members:

.. automodule:: nnvm.compiler.graph_attr
   :members:

.. automodule:: nnvm.compiler.compile_engine
   :members:
nnvm.frontend
-------------
.. automodule:: nnvm.frontend
.. autofunction:: nnvm.frontend.from_mxnet
nnvm.graph
----------
.. automodule:: nnvm.graph
.. autofunction:: nnvm.graph.create
.. autoclass:: nnvm.graph.Graph
   :members:
Python API
==========
This document contains the Python API of the NNVM compiler toolchain.

For users
.. toctree::
   :maxdepth: 2

   compiler
   frontend
   runtime
   symbol
   graph
   top
nnvm.runtime
------------
.. automodule:: nnvm.runtime
.. autofunction:: nnvm.runtime.create
.. autoclass:: nnvm.runtime.Module
   :members:
nnvm.symbol
-----------
.. automodule:: nnvm.symbol
.. autoclass:: nnvm.symbol.Symbol
.. autofunction:: nnvm.symbol.Group
nnvm.top
--------
.. automodule:: nnvm.top
.. autofunction:: register_compute
.. autofunction:: register_schedule
.. autofunction:: register_pattern
.. autoclass:: nnvm.top.AttrDict
   :members:
@@ -11,4 +11,5 @@ Contents
   self
   top
   tutorials/index
   api/python/index
   dev/index
@@ -2,4 +2,4 @@ NNVM Examples
=============
This folder contains example snippets of running NNVM Compilation.
- See also [Tutorials](../tutorials) for tutorials with detailed explanations.
"""Namespace for NNVM-TVM compiler toolchain""" """NNVM compiler toolchain.
User only need to use :any:`build` and :any:`build_config` to do the compilation.
The other APIs are for more advanced interaction with the compiler toolchain.
"""
from __future__ import absolute_import from __future__ import absolute_import
import tvm import tvm
@@ -10,9 +14,6 @@ from . compile_engine import engine, graph_key
from .. import symbol as _symbol
from .. import graph as _graph
from .. import top as _top
......
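A minimal usage sketch of the workflow described in the module docstring above (not part of this commit; the tiny dense/softmax network, the input shape, and the llvm target are illustrative assumptions):

import nnvm.symbol as sym
import nnvm.compiler

x = sym.Variable("x")
net = sym.softmax(sym.dense(x, units=10, name="fc"))

# build_config sets optimization options; build returns the executable graph,
# the compiled library, and the parameter dict.
with nnvm.compiler.build_config(opt_level=1):
    graph, lib, params = nnvm.compiler.build(net, "llvm", shape={"x": (1, 100)})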
@@ -105,6 +105,9 @@ def _update_shape_dtype(shape, dtype, params):
def optimize(graph, shape, dtype="float32"):
    """Perform target and parameter invariant graph optimization.

    This is an advanced function that usually does not need to be called.
    Call build instead.

    Parameters
    ----------
    graph : Graph
@@ -126,7 +129,11 @@ def optimize(graph, shape, dtype="float32"):
def build(graph, target, shape, dtype="float32", params=None):
    """Build graph into runtime library.

    This is the final step of graph compilation. The build function will
    optimize the graph and do the compilation.

    When params is provided, the compiler might split the graph to
    pre-compute certain values, so the final execution graph can
    be different from the original one.

    Parameters
    ----------
@@ -255,8 +262,10 @@ def precompute_prune(graph, params):
    graph._set_json_attr("param_name_list", list(params.keys()), "list_str")
    graph = graph.apply("PrecomputePrune")
    pre_graph = graph_attr._move_out_graph(graph, "precompute_graph")
    if pre_graph is None:
        return graph, params
    out_names = pre_graph.json_attr("output_names")
    if not pre_graph.symbol.list_output_names():
        return graph, params
    out_arrs = _run_graph(pre_graph, params)
    return graph, dict(zip(out_names, out_arrs))
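A hedged sketch of the params path handled above (not part of this commit; elemwise_add, the variable names, and the shapes are illustrative): when a parameter-only subexpression exists, PrecomputePrune evaluates it ahead of time and the params returned by build differ from the ones passed in.

import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler

x = sym.Variable("x")
w1 = sym.Variable("w1")
w2 = sym.Variable("w2")
net = sym.elemwise_add(x, sym.elemwise_add(w1, w2))
params = {"w1": tvm.nd.array(np.ones((4,), dtype="float32")),
          "w2": tvm.nd.array(np.ones((4,), dtype="float32"))}

# w1 + w2 depends only on parameters, so it can be folded offline;
# new_params then carries the pre-computed value instead of w1 and w2.
graph, lib, new_params = nnvm.compiler.build(
    net, "llvm", shape={"x": (4,)}, params=params)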
# pylint: disable=invalid-name
"""Compiler engine interface to internal engine

You can get the engine singleton at ``nnvm.compiler.engine``
"""
import tvm
_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems")
@@ -30,7 +33,10 @@ class GraphFunc(tvm.node.NodeBase):
class Engine(object):
    """Global singleton compilation engine.

    You can get the singleton at ``nnvm.compiler.engine``
    """
    def items(self):
        """List the available cache key value pairs.
......
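A small illustrative sketch of inspecting the singleton mentioned above (not part of this commit; it assumes items() yields the cached key/value pairs the docstring describes):

import nnvm.compiler

# After one or more build() calls, the engine cache can be listed.
for item in nnvm.compiler.engine.items():
    print(item)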
@@ -13,6 +13,9 @@ def infer_shape(graph, **shape):
    graph : Graph
        The graph to perform shape inference from

    shape : dict of str to tuple
        The specific input shape.

    Returns
    -------
    in_shape : list of tuple
@@ -38,6 +41,9 @@ def infer_dtype(graph, **dtype):
    graph : Graph
        The graph to perform type inference from

    dtype : dict of str to dtype
        The specific input data type.

    Returns
    -------
    in_dtype : list of tuple
......
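An illustrative sketch of the two helpers documented above (not part of this commit; it assumes both return the inferred input and output lists, and uses a hypothetical one-layer graph):

import nnvm.symbol as sym
import nnvm.graph
from nnvm.compiler import graph_util

x = sym.Variable("x")
g = nnvm.graph.create(sym.dense(x, units=3))

# Keyword arguments map input names to their shapes / dtypes.
in_shapes, out_shapes = graph_util.infer_shape(g, x=(4, 16))
in_dtypes, out_dtypes = graph_util.infer_dtype(g, x="float32")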
"""Frontend package.""" """NNVM frontends."""
from __future__ import absolute_import from __future__ import absolute_import
from .mxnet import from_mxnet from .mxnet import from_mxnet
@@ -2,6 +2,7 @@
"""MXNet symbol frontend."""
from __future__ import absolute_import as _abs
import json
import tvm
from .. import symbol as _sym
__all__ = ['from_mxnet']
@@ -288,17 +289,34 @@ def _from_mxnet_impl(symbol, graph):
    return node
def from_mxnet(symbol, arg_params=None, aux_params=None):
    """Convert from MXNet's model into compatible NNVM format.

    Parameters
    ----------
    symbol : mxnet.Symbol
        MXNet symbol

    arg_params : dict of str to mx.NDArray
        The argument parameters in mxnet

    aux_params : dict of str to mx.NDArray
        The auxiliary parameters in mxnet

    Returns
    -------
    net : nnvm.Symbol
        Compatible nnvm symbol

    params : dict of str to tvm.NDArray
        The parameter dict to be used by nnvm
    """
    sym = _from_mxnet_impl(symbol, {})
    params = {}
    arg_params = arg_params if arg_params else {}
    aux_params = aux_params if aux_params else {}
    for k, v in arg_params.items():
        params[k] = tvm.nd.array(v.asnumpy())
    for k, v in aux_params.items():
        params[k] = tvm.nd.array(v.asnumpy())
    return sym, params
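A hedged end-to-end sketch of the new two-value return (not part of this commit; the FullyConnected/SoftmaxOutput symbol, shapes, and llvm target are illustrative, and arg_params/aux_params would normally come from mod.get_params()):

import mxnet as mx
import nnvm.compiler
from nnvm.frontend import from_mxnet

data = mx.sym.Variable("data")
fc = mx.sym.FullyConnected(data, num_hidden=10, name="fc")
mx_sym = mx.sym.SoftmaxOutput(fc, name="softmax")

# Without trained weights the returned params dict is simply empty.
net, params = from_mxnet(mx_sym)
graph, lib, params = nnvm.compiler.build(
    net, "llvm", shape={"data": (1, 100)}, params=params)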
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
"""NNVM Graph IR API.

This is a developer API that is used to manipulate and transform graphs.
"""
from __future__ import absolute_import as _abs
import ctypes
......
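A tiny illustrative sketch of the developer API described above (not part of this commit; it assumes the create/json/symbol helpers documented under nnvm.graph):

import nnvm.symbol as sym
import nnvm.graph

g = nnvm.graph.create(sym.relu(sym.Variable("x")))
print(g.json())   # serialized graph IR
s = g.symbol      # convert the graph back to a symbol view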
# pylint: disable=invalid-name, unused-import
"""Symbolic graph construction API.

This namespace contains most of the registered operators.
For a detailed list of operators, check out ``Core Tensor Operators``.
"""
from __future__ import absolute_import as _abs
import sys as _sys
import os as _os
......
"""Declaration about Tensor operators""" """Tensor operator property registry
Provide information to lower and schedule tensor operators.
"""
from .attr_dict import AttrDict from .attr_dict import AttrDict
from . import tensor from . import tensor
from . import nn from . import nn
from . import transform from . import transform
from . import reduction from . import reduction
from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern
@@ -10,6 +10,7 @@ class AttrDict(object):
    """Attribute dictionary in nnvm.

    Used by python registration of compute and schedule function.
    AttrDict is passed as the first argument to schedule and compute function.
    """
    _tvm_tcode = 18
......
@@ -6,8 +6,8 @@ import tvm
import topi
from topi.util import get_const_int
from .tensor import _fschedule_broadcast
from . import registry as reg
from .registry import OpPattern

# relu
@reg.register_compute("relu")
@@ -55,9 +55,26 @@ def schedule_softmax(_, outs, target):
    # naive schedule
    return tvm.create_schedule([x.op for x in outs])

# Mark softmax as extern as we do not fuse it in all cases
reg.register_pattern("softmax", OpPattern.OPAQUE)
# log softmax
@reg.register_compute("log_softmax")
def compute_log_softmax(attrs, inputs, _):
    """Compute definition of log softmax"""
    axis = attrs.get_int("axis")
    assert axis == -1, "only support axis == -1 for now"
    return topi.nn.log_softmax(inputs[0])

@reg.register_schedule("log_softmax")
def schedule_log_softmax(_, outs, target):
    """Schedule definition of log softmax"""
    if target == "cuda":
        return topi.cuda.schedule_softmax(outs)
    # naive schedule
    return tvm.create_schedule([x.op for x in outs])

# Mark log softmax as extern as we do not fuse it in all cases
reg.register_pattern("log_softmax", OpPattern.OPAQUE)
# dense
@reg.register_compute("dense")
......
@@ -5,8 +5,8 @@ from __future__ import absolute_import
import tvm
import topi
import topi.cuda
from . import registry as reg
from .registry import OpPattern

def _schedule_reduce(_, outs, target):
    """Generic schedule for reduce"""
......
@@ -5,8 +5,8 @@ from __future__ import absolute_import
import tvm
import topi
import topi.cuda
from . import registry as reg
from .registry import OpPattern

def _schedule_injective(_, outs, target):
    """Generic schedule for binary bcast"""
......
@@ -5,8 +5,8 @@ from __future__ import absolute_import
import tvm
import topi
from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg
from .registry import OpPattern

# Need add reshape
@reg.register_compute("expand_dims")
......
@@ -110,8 +110,13 @@ TVM_REGISTER_GLOBAL("nnvm.graph._move_module")
TVM_REGISTER_GLOBAL("nnvm.graph._move_graph")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    const nnvm::Graph& g = args[0].AsExtension<Graph>();
    std::string key = args[1];
    if (g.attrs.count(key)) {
      *rv = const_cast<nnvm::Graph*>(&g)->
          MoveCopyAttr<nnvm::Graph>(key);
    } else {
      *rv = nullptr;
    }
  });
} // namespace compiler
} // namespace nnvm
@@ -24,6 +24,8 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
  std::unordered_set<nnvm::Node*> pruned;
  nnvm::NodeEntryMap<nnvm::NodePtr> entry_var;
  std::unordered_set<std::string> unique_name;
  // number of edges that are not variable
  int non_var_edge = 0;

  DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
    bool can_be_pruned = true;
@@ -46,6 +48,9 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
      for (auto& e : n->inputs) {
        if (pruned.count(e.node.get())) {
          if (!entry_var.count(e)) {
            if (!e.node->is_variable()) {
              ++non_var_edge;
            }
            nnvm::NodePtr var = nnvm::Node::Create();
            var->attrs.name = e.node->attrs.name;
            if (e.node->num_outputs() != 1) {
@@ -61,6 +66,11 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
    }
  });

  // nothing being pruned.
  if (non_var_edge == 0) {
    return src;
  }

  nnvm::Graph pre_graph;
  pre_graph.outputs.reserve(entry_var.size());
  std::vector<std::string> output_names;
......
@@ -107,6 +107,23 @@ def test_softmax():
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)

def test_log_softmax():
    x = sym.Variable("x")
    y = sym.log_softmax(x)
    dtype = "float32"
    dshape = (10, 1000)
    oshape = dshape
    for target, ctx in ctx_list():
        with nnvm.compiler.build_config(opt_level=1):
            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        data = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        y_np = topi.testing.log_softmax_python(data)
        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)

def test_dense():
    x = sym.Variable("x")
    y = sym.dense(x, units=3, name="dense")
@@ -161,6 +178,7 @@ def test_batchnorm():
if __name__ == "__main__":
    test_log_softmax()
    test_batchnorm()
    test_dense()
    test_relu()
......
@@ -5,23 +5,11 @@ import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from nnvm.testing.config import ctx_list
from nnvm import frontend
import mxnet as mx
import model_zoo
def test_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000)):
    def get_mxnet_output(symbol, x, dtype='float32'):
@@ -35,37 +23,28 @@ def test_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(
        args, auxs = mod.get_params()
        return out, args, auxs
    def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
        new_sym, params = frontend.from_mxnet(symbol, args, auxs)
        dshape = x.shape
        shape_dict = {'data': dshape}
        graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
        m = nnvm.runtime.create(graph, lib, ctx)
        # set inputs
        m.set_input("data", tvm.nd.array(x.astype(dtype)))
        m.set_input(**params)
        m.run()
        # get outputs
        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
        return out.asnumpy()

    # random input
    dtype = 'float32'
    x = np.random.uniform(size=data_shape)
    mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
    assert "data" not in args
    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
        np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5)

def test_forward_mlp():
    mlp = model_zoo.mx_mlp
......