Commit 3a5b9fcb by Tianqi Chen

[DOCS][FRONTEND] Modify from_mxnet to also return params, update docs (#36)

parent 34d74282
nnvm.compiler
-------------
.. automodule:: nnvm.compiler
.. autofunction:: nnvm.compiler.build
.. autofunction:: nnvm.compiler.build_config
.. autofunction:: nnvm.compiler.optimize
.. automodule:: nnvm.compiler.graph_util
:members:
.. automodule:: nnvm.compiler.graph_attr
:members:
.. automodule:: nnvm.compiler.compile_engine
:members:
nnvm.frontend
-------------
.. automodule:: nnvm.frontend
.. autofunction:: nnvm.frontend.from_mxnet
nnvm.graph
----------
.. automodule:: nnvm.graph
.. autofunction:: nnvm.graph.create
.. autoclass:: nnvm.graph.Graph
:members:
Python API
==========
This document contains the python API to NNVM compiler toolchain.
For user
.. toctree::
:maxdepth: 2
compiler
frontend
runtime
symbol
graph
top
nnvm.runtime
------------
.. automodule:: nnvm.runtime
.. autofunction:: nnvm.runtime.create
.. autoclass:: nnvm.runtime.Module
:members:
nnvm.symbol
-----------
.. automodule:: nnvm.symbol
.. autoclass:: nnvm.symbol.Symbol
.. autofunction:: nnvm.symbol.Group
nnvm.top
--------
.. automodule:: nnvm.top
.. autofunction:: register_compute
.. autofunction:: register_schedule
.. autofunction:: register_pattern
.. autoclass:: nnvm.top.AttrDict
:members:
......@@ -11,4 +11,5 @@ Contents
self
top
tutorials/index
api/python/index
dev/index
......@@ -2,4 +2,4 @@ NNVM Examples
=============
This folder contains example snippets of running NNVM Compilation.
- See also [Tutorials](tutorials) for tutorials with detailed explainations.
- See also [Tutorials](../tutorials) for tutorials with detailed explainations.
"""Namespace for NNVM-TVM compiler toolchain"""
"""NNVM compiler toolchain.
User only need to use :any:`build` and :any:`build_config` to do the compilation.
The other APIs are for more advanced interaction with the compiler toolchain.
"""
from __future__ import absolute_import
import tvm
......@@ -10,9 +14,6 @@ from . compile_engine import engine, graph_key
from .. import symbol as _symbol
from .. import graph as _graph
from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern
from .. import top as _top
......
......@@ -105,6 +105,9 @@ def _update_shape_dtype(shape, dtype, params):
def optimize(graph, shape, dtype="float32"):
"""Perform target and parameter invariant graph optimization.
This is an advanced function that usually do not need to be called.
Call build instead.
Parameters
----------
graph : Graph
......@@ -126,7 +129,11 @@ def optimize(graph, shape, dtype="float32"):
def build(graph, target, shape, dtype="float32", params=None):
"""Build graph into runtime library.
This is the final step of graph compilation.
The build function will optimize the graph and do the compilation.
When params is provided, the compiler might split the graph to
pre-compute certain values, so the final execution graph can
be different from the original one.
Parameters
----------
......@@ -255,8 +262,10 @@ def precompute_prune(graph, params):
graph._set_json_attr("param_name_list", list(params.keys()), "list_str")
graph = graph.apply("PrecomputePrune")
pre_graph = graph_attr._move_out_graph(graph, "precompute_graph")
if not pre_graph.symbol.list_output_names():
if pre_graph is None:
return graph, params
out_names = pre_graph.json_attr("output_names")
if not pre_graph.symbol.list_output_names():
return graph, params
out_arrs = _run_graph(pre_graph, params)
return graph, dict(zip(out_names, out_arrs))
# pylint: disable=invalid-name
"""Compiler engine interface to internal engine"""
"""Compiler engine interface to internal engine
You can get the engine singleton at ``nnvm.compiler.engine``
"""
import tvm
_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems")
......@@ -30,7 +33,10 @@ class GraphFunc(tvm.node.NodeBase):
class Engine(object):
"""Global singleton compilation engine."""
"""Global singleton compilation engine.
You can get the singleton at ``nnvm.compiler.engine``
"""
def items(self):
"""List the available cache key value pairs.
......
......@@ -13,6 +13,9 @@ def infer_shape(graph, **shape):
graph : Graph
The graph to perform shape inference from
shape : dict of str to tuple
The specific input shape.
Returns
-------
in_shape : list of tuple
......@@ -38,6 +41,9 @@ def infer_dtype(graph, **dtype):
graph : Graph
The graph to perform type inference from
dtype : dict of str to dtype
The specific input data type.
Returns
-------
in_dtype : list of tuple
......
"""Frontend package."""
"""NNVM frontends."""
from __future__ import absolute_import
from .mxnet import from_mxnet
......@@ -2,6 +2,7 @@
"""MXNet symbol frontend."""
from __future__ import absolute_import as _abs
import json
import tvm
from .. import symbol as _sym
__all__ = ['from_mxnet']
......@@ -288,17 +289,34 @@ def _from_mxnet_impl(symbol, graph):
return node
def from_mxnet(symbol):
"""Convert from mxnet.Symbol to compatible nnvm.Symbol
def from_mxnet(symbol, arg_params=None, aux_params=None):
"""Convert from MXNet's model into compatible NNVM format.
Parameters
----------
symbol : mxnet.Symbol
MXNet symbol
arg_params : dict of str to mx.NDArray
The argument parameters in mxnet
aux_params : dict of str to mx.NDArray
The auxiliary parameters in mxnet
Returns
-------
nnvm.Symbol
net: nnvm.Symbol
Compatible nnvm symbol
params : dict of str to tvm.NDArray
The parameter dict to be used by nnvm
"""
return _from_mxnet_impl(symbol, {})
sym = _from_mxnet_impl(symbol, {})
params = {}
arg_params = arg_params if arg_params else {}
aux_params = aux_params if aux_params else {}
for k, v in arg_params.items():
params[k] = tvm.nd.array(v.asnumpy())
for k, v in aux_params.items():
params[k] = tvm.nd.array(v.asnumpy())
return sym, params
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
"""Symbolic configuration API."""
"""NNVM Graph IR API.
This is a developer API that is used to manipulate and transform graphs.
"""
from __future__ import absolute_import as _abs
import ctypes
......
# pylint: disable=invalid-name, unused-import
"""Symbolic configuration API."""
"""Symbolic graph construction API.
This namespace contains most of the registered operators.
For detailed list of operators, checkout ``Core Tensor Operators``
"""
from __future__ import absolute_import as _abs
import sys as _sys
import os as _os
......
"""Declaration about Tensor operators"""
"""Tensor operator property registry
Provide information to lower and schedule tensor operators.
"""
from .attr_dict import AttrDict
from . import tensor
from . import nn
from . import transform
from . import reduction
from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern
......@@ -10,6 +10,7 @@ class AttrDict(object):
"""Attribute dictionary in nnvm.
Used by python registration of compute and schedule function.
AttrDict is passed as the first argument to schedule and compute function.
"""
_tvm_tcode = 18
......
......@@ -6,8 +6,8 @@ import tvm
import topi
from topi.util import get_const_int
from .tensor import _fschedule_broadcast
from ..compiler import registry as reg
from ..compiler import OpPattern
from . import registry as reg
from .registry import OpPattern
# relu
@reg.register_compute("relu")
......@@ -55,9 +55,26 @@ def schedule_softmax(_, outs, target):
# naive schedule
return tvm.create_schedule([x.op for x in outs])
# Mark softmax as extern as we do not fuse it in call cases
reg.register_pattern("softmax", OpPattern.OPAQUE)
# log softmax
@reg.register_compute("log_softmax")
def compute_log_softmax(attrs, inputs, _):
"""Compute definition of softmax"""
axis = attrs.get_int("axis")
assert axis == -1, "only support axis == -1 for now"
return topi.nn.log_softmax(inputs[0])
@reg.register_schedule("log_softmax")
def schedule_log_softmax(_, outs, target):
"""Schedule definition of softmax"""
if target == "cuda":
return topi.cuda.schedule_softmax(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
# Mark softmax as extern as we do not fuse it in call cases
reg.register_pattern("log_softmax", OpPattern.OPAQUE)
# dense
@reg.register_compute("dense")
......
......@@ -5,8 +5,8 @@ from __future__ import absolute_import
import tvm
import topi
import topi.cuda
from ..compiler import registry as reg
from ..compiler import OpPattern
from . import registry as reg
from .registry import OpPattern
def _schedule_reduce(_, outs, target):
"""Generic schedule for reduce"""
......
......@@ -5,8 +5,8 @@ from __future__ import absolute_import
import tvm
import topi
import topi.cuda
from ..compiler import registry as reg
from ..compiler import OpPattern
from . import registry as reg
from .registry import OpPattern
def _schedule_injective(_, outs, target):
"""Generic schedule for binary bcast"""
......
......@@ -5,8 +5,8 @@ from __future__ import absolute_import
import tvm
import topi
from .tensor import _fschedule_broadcast, _fschedule_injective
from ..compiler import registry as reg
from ..compiler import OpPattern
from . import registry as reg
from .registry import OpPattern
# Need add reshape
@reg.register_compute("expand_dims")
......
......@@ -110,8 +110,13 @@ TVM_REGISTER_GLOBAL("nnvm.graph._move_module")
TVM_REGISTER_GLOBAL("nnvm.graph._move_graph")
.set_body([](TVMArgs args, TVMRetValue *rv) {
const nnvm::Graph& g = args[0].AsExtension<Graph>();
*rv = const_cast<nnvm::Graph*>(&g)->
MoveCopyAttr<nnvm::Graph>(args[1]);
std::string key = args[1];
if (g.attrs.count(key)) {
*rv = const_cast<nnvm::Graph*>(&g)->
MoveCopyAttr<nnvm::Graph>(key);
} else {
*rv = nullptr;
}
});
} // namespace compiler
} // namespace nnvm
......@@ -24,6 +24,8 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
std::unordered_set<nnvm::Node*> pruned;
nnvm::NodeEntryMap<nnvm::NodePtr> entry_var;
std::unordered_set<std::string> unique_name;
// number of edges that are not variable
int non_var_edge = 0;
DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
bool can_be_pruned = true;
......@@ -46,6 +48,9 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
for (auto& e : n->inputs) {
if (pruned.count(e.node.get())) {
if (!entry_var.count(e)) {
if (!e.node->is_variable()) {
++non_var_edge;
}
nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name;
if (e.node->num_outputs() != 1) {
......@@ -61,6 +66,11 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
}
});
// nothing being pruned.
if (non_var_edge == 0) {
return src;
}
nnvm::Graph pre_graph;
pre_graph.outputs.reserve(entry_var.size());
std::vector<std::string> output_names;
......
......@@ -107,6 +107,23 @@ def test_softmax():
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
def test_log_softmax():
x = sym.Variable("x")
y = sym.log_softmax(x)
dtype = "float32"
dshape = (10, 1000)
oshape = dshape
for target, ctx in ctx_list():
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
data = np.random.uniform(size=dshape).astype(dtype)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
y_np = topi.testing.log_softmax_python(data)
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
def test_dense():
x = sym.Variable("x")
y = sym.dense(x, units=3, name="dense")
......@@ -161,6 +178,7 @@ def test_batchnorm():
if __name__ == "__main__":
test_log_softmax()
test_batchnorm()
test_dense()
test_relu()
......
......@@ -5,23 +5,11 @@ import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from nnvm.testing.config import ctx_list
from nnvm import frontend
import mxnet as mx
import model_zoo
USE_GPU=True
def default_target():
if USE_GPU:
return 'cuda'
else:
return 'llvm'
def default_ctx():
if USE_GPU:
return tvm.gpu(0)
else:
return tvm.cpu(0)
def test_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000)):
def get_mxnet_output(symbol, x, dtype='float32'):
......@@ -35,37 +23,28 @@ def test_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(
args, auxs = mod.get_params()
return out, args, auxs
def get_tvm_output(symbol, x, args, auxs, dtype='float32'):
def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
new_sym, params = frontend.from_mxnet(symbol, args, auxs)
dshape = x.shape
shape_dict = {'data': dshape}
for k, v in args.items():
shape_dict[k] = v.shape
for k, v in auxs.items():
shape_dict[k] = v.shape
graph, lib, _ = nnvm.compiler.build(symbol, default_target(), shape_dict)
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m['set_input'], m['run'], m['get_output']
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
m = nnvm.runtime.create(graph, lib, ctx)
# set inputs
set_input('data', tvm.nd.array(x.astype(dtype)))
for k, v in args.items():
set_input(k, tvm.nd.array(v.asnumpy().astype(dtype)))
for k, v in auxs.items():
set_input(k, tvm.nd.array(v.asnumpy().astype(dtype)))
# execute
run()
m.set_input("data", tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
m.run()
# get outputs
out = tvm.nd.empty(out_shape, dtype)
get_output(0, out)
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy()
# random input
dtype = 'float32'
x = np.random.uniform(size=data_shape)
mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
new_sym = frontend.from_mxnet(mx_symbol)
tvm_out = get_tvm_output(new_sym, x, args, auxs, dtype)
np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5)
assert "data" not in args
for target, ctx in ctx_list():
tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5)
def test_forward_mlp():
mlp = model_zoo.mx_mlp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment