Commit c6c42af0 by Tianqi Chen

[COMPILER] Initial compiler infra (#12)

parent f6f448e1
...@@ -11,7 +11,7 @@ include $(config) ...@@ -11,7 +11,7 @@ include $(config)
export LDFLAGS = -pthread -lm export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC
CFLAGS += -Itvm/include -Itvm/dlpack/include CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src
ifdef DMLC_CORE_PATH ifdef DMLC_CORE_PATH
CFLAGS += -I$(DMLC_CORE_PATH)/include CFLAGS += -I$(DMLC_CORE_PATH)/include
...@@ -38,7 +38,7 @@ PLUGIN_OBJ = ...@@ -38,7 +38,7 @@ PLUGIN_OBJ =
include $(NNVM_PLUGINS) include $(NNVM_PLUGINS)
# specify tensor path # specify tensor path
.PHONY: clean all test lint doc cython cython3 cyclean .PHONY: clean all test lint pylint doc cython cython3 cyclean
UNAME_S := $(shell uname -s) UNAME_S := $(shell uname -s)
...@@ -55,7 +55,7 @@ endif ...@@ -55,7 +55,7 @@ endif
all: lib/libnnvm.a lib/libnnvm_top.$(SHARED_LIBRARY_SUFFIX) lib/libnnvm_top_runtime.$(SHARED_LIBRARY_SUFFIX) all: lib/libnnvm.a lib/libnnvm_top.$(SHARED_LIBRARY_SUFFIX) lib/libnnvm_top_runtime.$(SHARED_LIBRARY_SUFFIX)
SRC = $(wildcard src/*.cc src/c_api/*.cc src/core/*.cc src/pass/*.cc) SRC = $(wildcard src/*.cc src/c_api/*.cc src/core/*.cc src/pass/*.cc)
SRC_TOP = $(wildcard src/top/*.cc, src/top/*/*.cc src/runtime/*.cc) SRC_TOP = $(wildcard src/top/*/*.cc src/runtime/*.cc src/compiler/*.cc src/compiler/*/*.cc)
ALL_OBJ = $(patsubst %.cc, build/%.o, $(SRC)) ALL_OBJ = $(patsubst %.cc, build/%.o, $(SRC))
TOP_OBJ = $(patsubst %.cc, build/%.o, $(SRC_TOP)) TOP_OBJ = $(patsubst %.cc, build/%.o, $(SRC_TOP))
ALL_DEP = $(ALL_OBJ) ALL_DEP = $(ALL_OBJ)
...@@ -90,9 +90,12 @@ cython3: ...@@ -90,9 +90,12 @@ cython3:
cyclean: cyclean:
rm -rf python/nnvm/*/*.so python/nnvm/*/*.dylib python/nnvm/*/*.cpp rm -rf python/nnvm/*/*.so python/nnvm/*/*.dylib python/nnvm/*/*.cpp
lint: lint: pylint
python dmlc-core/scripts/lint.py nnvm cpp include src python dmlc-core/scripts/lint.py nnvm cpp include src
pylint:
pylint python/nnvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
doc: doc:
doxygen docs/Doxyfile doxygen docs/Doxyfile
......
/*!
* Copyright (c) 2017 by Contributors
* \file contrib_op_param.h
* \brief Additional parameters for compiler optimized operators.
*/
#ifndef NNVM_COMPILER_CONTRIB_OP_PARAM_H_
#define NNVM_COMPILER_CONTRIB_OP_PARAM_H_
#include <dmlc/parameter.h>
#include <string>
namespace nnvm {
namespace compiler {
/*! \brief Parameters of layout transform operator */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
std::string src_layout;
std::string dst_layout;
DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
DMLC_DECLARE_FIELD(src_layout);
DMLC_DECLARE_FIELD(dst_layout);
}
};
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_CONTRIB_OP_PARAM_H_
/*!
* Copyright (c) 2017 by Contributors
* \file op_attr_types.h
* \brief The Expr and related elements in DataFlow construction.
*/
#ifndef NNVM_COMPILER_OP_ATTR_TYPES_H_
#define NNVM_COMPILER_OP_ATTR_TYPES_H_
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/graph.h>
#include <vector>
#include <string>
namespace nnvm {
namespace compiler {
using ::tvm::Array;
using ::tvm::Tensor;
using ::tvm::Schedule;
/*! \brief operator pattern used in graph fusion */
enum OpPatternKind : int {
// Elementwise operation
kElemWise = 0,
// Broadcast operation
kBroadcast = 1,
// Complex operation, can fuse bcast in input/outputs
// but cannot chain another complex op
kComplex = 2,
// Extern operation, cannot fuse anything.
kExtern = 3
};
/*! \brief the operator pattern */
using TOpPattern = int;
/*!
* \brief Computation description interface
* \param attrs The attribute of the node.
* \param inputs The input tensors(placeholders)
* \return The output description of the tensor.
*/
using FTVMCompute = std::function<
Array<Tensor>
(const NodeAttrs& attrs, const Array<Tensor>& inputs)>;
/*!
* \brief Build the computation schedule for
* op whose root is at current op.
* \param attrs The attribute of the node.
* \param outs The output tensors.
* \param target The build target.
* \return schedule The computation schedule.
*/
using FTVMSchedule = std::function<
Schedule(const NodeAttrs& attrs,
const Array<Tensor>& outs,
const std::string& target)>;
/*! \brief Layout Information about an entry */
using TLayoutInfo = std::string;
/*!
* \brief The producer consumer function of node layout
* \param attrs The attribute of the node.
* \param ilayouts The input layouts that the node request.
* \param olayouts The output layouts that the node produce.
* \return bool The success flag.
*/
using FTVMLayoutRequest = std::function<bool (const NodeAttrs& attrs,
std::vector<TLayoutInfo> *ilayouts,
std::vector<TLayoutInfo> *olayouts)>;
/*!
* \brief Transform from normal operator to vectorized operator
* \param node The source node.
* \return Transformed vectorized op.
*/
using FTVMVectorizedOp = std::function<nnvm::NodePtr (const nnvm::Node* node)>;
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_OP_ATTR_TYPES_H_
/*!
* Copyright (c) 2017 by Contributors
* \file packed_func_ext.h
* \brief Extension to enable packed functionn for nnvm types
*/
#ifndef NNVM_COMPILER_PACKED_FUNC_EXT_H_
#define NNVM_COMPILER_PACKED_FUNC_EXT_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <nnvm/graph.h>
#include <nnvm/symbolic.h>
#include <string>
#include <unordered_map>
namespace nnvm {
namespace compiler {
using tvm::runtime::PackedFunc;
using AttrDict = std::unordered_map<std::string, std::string>;
/*!
* \brief Get PackedFunction from global registry and
* report error if it does not exist
* \param name The name of the function.
* \return The created PackedFunc.
*/
inline const PackedFunc& GetPackedFunc(const std::string& name) {
const PackedFunc* pf = tvm::runtime::Registry::Get(name);
CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
return *pf;
}
} // namespace compiler
} // namespace nnvm
// Enable the graph and symbol object exchange.
namespace tvm {
namespace runtime {
template<>
struct extension_class_info<nnvm::Symbol> {
static const int code = 16;
};
template<>
struct extension_class_info<nnvm::Graph> {
static const int code = 17;
};
template<>
struct extension_class_info<nnvm::compiler::AttrDict> {
static const int code = 18;
};
} // namespace runtime
} // namespace tvm
#endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_
...@@ -72,6 +72,18 @@ template<typename AttrType> ...@@ -72,6 +72,18 @@ template<typename AttrType>
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs, using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
std::vector<AttrType> *in_attrs, std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs)>; std::vector<AttrType> *out_attrs)>;
/*!
* \brief Get attribute dictionary from node.
*
* \param attrs The attributes of the node.
* \return The attribute dict.
* \note Register under "FUpdateAttrDict"
*/
using FGetAttrDict = std::function<
std::unordered_map<std::string, std::string>
(const NodeAttrs& attrs)>;
/*! /*!
* \brief Shape inference function. * \brief Shape inference function.
* Update the shapes given the input shape information. * Update the shapes given the input shape information.
......
NNVM Core Operator Specs NNVM Core Operator and Compiler
# coding: utf-8 # coding: utf-8
# pylint: disable=invalid-name # pylint: disable=invalid-name, unused-import
""" ctypes library of nnvm and helper functions """ """ ctypes library of nnvm and helper functions """
from __future__ import absolute_import from __future__ import absolute_import
import sys import sys
import os
import ctypes import ctypes
import numpy as np import numpy as np
from . import libinfo from . import libinfo
__all__ = ['NNNetError'] try:
import tvm
except ImportError:
pass
#---------------------------- #----------------------------
# library loading # library loading
#---------------------------- #----------------------------
...@@ -181,7 +184,7 @@ def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True) ...@@ -181,7 +184,7 @@ def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True)
param_keys.add(key) param_keys.add(key)
type_info = py_str(arg_types[i]) type_info = py_str(arg_types[i])
ret = '%s : %s' % (key, type_info) ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0: if arg_descs[i]:
ret += '\n ' + py_str(arg_descs[i]) ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret) param_str.append(ret)
doc_str = ('Parameters\n' + doc_str = ('Parameters\n' +
......
# coding: utf-8 # coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines # pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines,
# pylint: disable=len-as-condition, consider-iterating-dictionary
"""Symbolic configuration API.""" """Symbolic configuration API."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
...@@ -7,7 +8,7 @@ import copy ...@@ -7,7 +8,7 @@ import copy
import ctypes import ctypes
import sys import sys
from .._base import _LIB from .._base import _LIB
from .._base import c_array, c_str, nn_uint, py_str, string_types from .._base import c_array, c_str, nn_uint, py_str
from .._base import SymbolHandle, OpHandle from .._base import SymbolHandle, OpHandle
from .._base import check_call, ctypes2docstring from .._base import check_call, ctypes2docstring
from ..name import NameManager from ..name import NameManager
......
"""Namespace for NNVM-TVM compiler toolchain"""
from __future__ import absolute_import
import tvm
from . import build_module
from . build_module import build
from .. import symbol as _symbol
from .. import graph as _graph
from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern
from .. import top as _top
tvm.register_extension(_symbol.Symbol, _symbol.Symbol)
tvm.register_extension(_graph.Graph, _graph.Graph)
# pylint: disable=invalid-name
"""Namespace for building operators."""
from __future__ import absolute_import as _abs
import tvm
from . import graph_attr
from .. import graph as _graph
@tvm.register_func("nnvm.compiler.lower")
def _lower(sch, inputs, func_name):
f = tvm.lower(sch, inputs, name=func_name)
return f if isinstance(
f, (tvm.container.Array, tuple, list)) else [f]
@tvm.register_func("nnvm.compiler.build_target")
def _build(funcs, target):
return tvm.build(funcs, target=target)
_move_module = tvm.get_global_func("nnvm.compiler._move_module")
def optimize(graph):
"""Perform graph optimization
Parameters
----------
graph : Graph
The graph to be used in lowering.
Returns
-------
graph : Graph
The optimized execution graph.
"""
return graph
def build(graph, target, shape, dtype="float32"):
"""Build graph into runtime library.
This is the final step of graph compilation.
Parameters
----------
graph : Graph
The graph to be used in lowering
target : str
The build target
shape : dict of str to tuple
The input shape to the graph
dtype : str or dict of str to str
The input types to the graph
Returns
-------
graph : Graph
The final execution graph.
libmod : tvm.Module
The modue that comes with the execution graph
"""
if not isinstance(target, str):
raise TypeError("require target to be str")
if not isinstance(shape, dict):
raise TypeError("require shape to be dict")
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
graph = graph_attr.set_shape(graph, shape)
graph = graph_attr.set_dtype(graph, dtype)
graph._set_json_attr("target", target, "str")
graph = graph.apply("InferShape").apply("InferType")
graph = graph.apply("GraphFusePartition").apply("GraphFuse")
libmod = _move_module(graph)
return graph, libmod
"""Utilities to access graph attributes"""
from __future__ import absolute_import as _abs
def set_shape(g, shape):
"""Set the shape of graph nodes in the graph attribute.
Parameters
----------
g : Graph
The input graph
shape : dict of str to tuple
The input shape
Returns
-------
g : Graph
The updated graph with updated shape.
"""
index = g.index
list_shape = [[]] * index.num_node_entries
for k, v in shape.items():
list_shape[index.entry_id(k)] = v
g._set_json_attr("shape", list_shape, 'list_shape')
return g
DTYPE_DICT = {
"float32": 0
}
def set_dtype(g, dtype):
"""Set the dtype of graph nodes
Parameters
----------
g : Graph
The input graph
dtype : dict of str to str or str
The input dtype
Returns
-------
g : Graph
The updated graph with updated dtype.
"""
index = g.index
if isinstance(dtype, dict):
list_dtype = [-1] * index.num_node_entries
for k, v in dtype.items():
list_dtype[index.entry_id(k)] = DTYPE_DICT[v]
else:
list_dtype = [DTYPE_DICT[dtype]] * index.num_node_entries
g._set_json_attr("dtype", list_dtype, "list_int")
return g
"""Namespace of graph pass.
Principle:
- Graph in, graph out: always takes in graph as first argument and returns a graph
- Composable API: break graph transformation pass as segments of small transformations.
"""
from __future__ import absolute_import as _abs
# pylint: disable=invalid-name
"""Information registry to register operator information for compiler"""
import tvm
class OpPattern(object):
ELEM_WISE = 0
BROADCAST = 1
COMPLEX = 2
EXTERN = 2
_register_compute = tvm.get_global_func("nnvm._register_compute")
_register_schedule = tvm.get_global_func("nnvm._register_schedule")
_register_pattern = tvm.get_global_func("nnvm._register_pattern")
def register_compute(op_name, f=None, level=10):
"""Register compute function for operator
Parameters
----------
op_name : str
The name of operator
f : function
The schedule function
level : int
The priority level
Returns
-------
fregister : function
Register function if f is not specified.
"""
def register(myf):
"""internal register function"""
_register_compute(op_name, myf, level)
return myf
return register(f) if f else register
def register_schedule(op_name, f=None, level=10):
"""Register schedule function for operator
Parameters
----------
op_name : str
The name of operator
f : function
The schedule function
level : int
The priority level
Returns
-------
fregister : function
Register function if f is not specified.
"""
def register(myf):
"""internal register function"""
_register_schedule(op_name, myf, level)
return myf
return register(f) if f else register
def register_pattern(op_name, pattern, level=10):
"""Register pattern code for operator
Parameters
----------
op_name : str
The name of operator
pattern : int
The pattern code.
level : int
The priority level
"""
_register_pattern(op_name, pattern, level)
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import ctypes import ctypes
import sys
import json import json
from ._base import _LIB from ._base import _LIB
from ._base import c_array, c_str, nn_uint, py_str, string_types from ._base import c_array, c_str, nn_uint, py_str, string_types
...@@ -12,12 +11,73 @@ from ._base import GraphHandle, SymbolHandle ...@@ -12,12 +11,73 @@ from ._base import GraphHandle, SymbolHandle
from ._base import check_call from ._base import check_call
from .symbol import Symbol, Group as _Group from .symbol import Symbol, Group as _Group
class GraphIndex(object):
"""Index for quickly accessing graph attributes.
Parameters
----------
graph : Graph
The graph to create index.
"""
def __init__(self, graph):
jgraph = json.loads(create(graph).apply("SaveJSON").json_attr("json"))
self.nodes = jgraph["nodes"]
self.entry_ptr = jgraph["node_row_ptr"]
self._name2nodeid = {n["name"]: i for i, n in enumerate(self.nodes)}
@property
def num_nodes(self):
"""Number of nodes in graph."""
return len(self.entry_ptr) - 1
@property
def num_node_entries(self):
"""Number of nodes in graph."""
return self.entry_ptr[-1]
def node_id(self, key):
"""Get the node index for a given key.
Parameters
----------
key : str or int
The node key or index
Returns
-------
index : int
The entry index
"""
return self._name2nodeid[key]
def entry_id(self, key, value_index=0):
"""Get the entry id of a node entry.
Parameters
----------
key : str or int
The node key or index
value_index : int
The value index of output
Returns
-------
index : int
The entry index
"""
idx = self.node_id(key) if isinstance(key, str) else key
assert value_index < self.entry_ptr[idx + 1]
return self.entry_ptr[idx] + value_index
class Graph(object): class Graph(object):
"""Graph is the graph object that can be used to apply optimization pass. """Graph is the graph object that can be used to apply optimization pass.
It contains additional graphwise attribute besides the internal symbol.
It contains additional graphwise attribute besides the internal symbol.
""" """
_tvm_tcode = 17
# pylint: disable=no-member # pylint: disable=no-member
def __init__(self, handle): def __init__(self, handle):
...@@ -29,6 +89,7 @@ class Graph(object): ...@@ -29,6 +89,7 @@ class Graph(object):
the handle to the underlying C++ Graph the handle to the underlying C++ Graph
""" """
self.handle = handle self.handle = handle
self._index = None
def __del__(self): def __del__(self):
check_call(_LIB.NNGraphFree(self.handle)) check_call(_LIB.NNGraphFree(self.handle))
...@@ -53,8 +114,7 @@ class Graph(object): ...@@ -53,8 +114,7 @@ class Graph(object):
if success.value != 0: if success.value != 0:
json_str = py_str(ret.value) json_str = py_str(ret.value)
return json.loads(json_str)[1] return json.loads(json_str)[1]
else: return None
return None
def _set_symbol_list_attr(self, key, value): def _set_symbol_list_attr(self, key, value):
"""Set the attribute of the graph. """Set the attribute of the graph.
...@@ -96,17 +156,33 @@ class Graph(object): ...@@ -96,17 +156,33 @@ class Graph(object):
self.handle, c_str(key), c_str(json_value))) self.handle, c_str(key), c_str(json_value)))
@property @property
def _tvm_handle(self):
return self.handle.value
@property
def symbol(self): def symbol(self):
shandle = SymbolHandle() shandle = SymbolHandle()
check_call(_LIB.NNGraphGetSymbol(self.handle, ctypes.byref(shandle))) check_call(_LIB.NNGraphGetSymbol(self.handle, ctypes.byref(shandle)))
return Symbol(shandle) return Symbol(shandle)
@property
def index(self):
if not self._index:
self._index = GraphIndex(self)
return self._index
def apply(self, passes): def apply(self, passes):
"""Apply passes to the graph """Apply passes to the graph
Parameters Parameters
---------- ----------
passes : str or list of str
The passes to be applied
Returns
-------
g : Graph
The transformed graph.
""" """
if isinstance(passes, string_types): if isinstance(passes, string_types):
passes = [passes] passes = [passes]
......
...@@ -52,7 +52,7 @@ def find_lib_path(): ...@@ -52,7 +52,7 @@ def find_lib_path():
dll_path = [os.path.join(p, '%s.so' % lib_name) for p in dll_path] dll_path = [os.path.join(p, '%s.so' % lib_name) for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0: if not lib_path:
raise RuntimeError('Cannot find the files.\n' + raise RuntimeError('Cannot find the files.\n' +
'List of candidates:\n' + str('\n'.join(dll_path))) 'List of candidates:\n' + str('\n'.join(dll_path)))
return lib_path return lib_path
......
"""Runtime environment for nnvm relies on TVM."""
import tvm
from tvm.contrib import rpc
def create(graph, libmod, ctx):
"""Create a runtime executor module given the graph and module.
Parameters
----------
graph : The graph to be deployed
The graph to be loaded.
libmod : tvm.Module
The module of the corresponding function
ctx : TVMContext
The context to deploy the module, can be local or remote.
Returns
-------
graph_module : tvm.Module
Runtime graph module to execute the graph.
"""
json_str = graph if isinstance(graph, str) else graph.apply("SaveJSON").json_attr("json")
device_type = ctx.device_type
device_id = ctx.device_id
if device_type >= rpc.RPC_SESS_MASK:
assert libmod.type_key == "rpc"
assert rpc._SessTableIndex(libmod) == ctx._rpc_sess._tbl_index
hmod = rpc._ModuleHandle(libmod)
fcreate = ctx._rpc_sess.get_function("nnvm.runtime.remote_create")
device_type = device_type % rpc.RPC_SESS_MASK
return fcreate(json_str, hmod, device_type, device_id)
fcreate = tvm.get_global_func("nnvm.runtime.create")
return fcreate(json_str, libmod, device_type, device_id)
# pylint: disable=invalid-name, unused-import
"""Symbolic configuration API.""" """Symbolic configuration API."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import sys as _sys import sys as _sys
...@@ -7,21 +8,21 @@ import ctypes as _ctypes ...@@ -7,21 +8,21 @@ import ctypes as _ctypes
from numbers import Number as _Number from numbers import Number as _Number
from . import _base from . import _base
from ._base import _LIB, check_call as _check_call from ._base import _LIB, check_call as _check_call
from . import _symbol_internal as _internal
from .attribute import AttrScope from .attribute import AttrScope
from . import _symbol_internal as _internal
# Use different verison of SymbolBase # Use different verison of SymbolBase
# When possible, use cython to speedup part of computation. # When possible, use cython to speedup part of computation.
try: try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from ._ctypes.symbol import SymbolBase, _init_symbol_module from ._ctypes.symbol import SymbolBase, _init_symbol_module
elif _sys.version_info >= (3, 0): elif _sys.version_info >= (3, 0):
from ._cy3.symbol import SymbolBase, _init_symbol_module from ._cy3.symbol import SymbolBase, _init_symbol_module
else: else:
from ._cy2.symbol import SymbolBase, _init_symbol_module from ._cy2.symbol import SymbolBase, _init_symbol_module
except ImportError: except ImportError:
from ._ctypes.symbol import SymbolBase, _init_symbol_module from ._ctypes.symbol import SymbolBase, _init_symbol_module
class Symbol(SymbolBase): class Symbol(SymbolBase):
...@@ -29,6 +30,12 @@ class Symbol(SymbolBase): ...@@ -29,6 +30,12 @@ class Symbol(SymbolBase):
# disable dictionary storage, also do not have parent type. # disable dictionary storage, also do not have parent type.
__slots__ = [] __slots__ = []
_tvm_tcode = 16
@property
def _tvm_handle(self):
return self.handle.value
def __add__(self, other): def __add__(self, other):
if isinstance(other, Symbol): if isinstance(other, Symbol):
return __add_symbol__(self, other) return __add_symbol__(self, other)
...@@ -148,8 +155,7 @@ class Symbol(SymbolBase): ...@@ -148,8 +155,7 @@ class Symbol(SymbolBase):
self.handle, _base.c_str(key), _ctypes.byref(ret), _ctypes.byref(success))) self.handle, _base.c_str(key), _ctypes.byref(ret), _ctypes.byref(success)))
if success.value != 0: if success.value != 0:
return _base.py_str(ret.value) return _base.py_str(ret.value)
else: return None
return None
def list_attr(self, recursive=False): def list_attr(self, recursive=False):
"""Get all attributes from the symbol. """Get all attributes from the symbol.
......
"""Declaration about Tensor operators"""
from .attr_dict import AttrDict
from . import tensor
from . import nn
# pylint: disable=invalid-name
"""Attr dictionary object used by schedule functions"""
import json
import tvm
_dict_get = tvm.get_global_func("nnvm.compiler._dict_get")
_dict_size = tvm.get_global_func("nnvm.compiler._dict_size")
_dict_keys = tvm.get_global_func("nnvm.compiler._dict_keys")
class AttrDict(object):
"""Attribute dictionary in nnvm.
Used by python registration of compute and schedule function.
"""
_tvm_tcode = 18
def __init__(self, handle):
self.handle = handle
def __del__(self):
tvm.nd.free_extension_handle(self.handle, 18)
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, key):
return _dict_get(self, key)
def keys(self):
"""Get list of keys in the dict.
Returns
-------
keys : list of str
List of keys
"""
return [x.value for x in _dict_keys(self)]
def get_int_tuple(self, key):
"""Get tuple of integer from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
tuple : tuple of int
The result tuple
"""
return tuple(json.loads(self[key]))
def get_int(self, key):
"""Get integer from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
value : int
The result value
"""
return int(self[key])
def get_bool(self, key):
"""Get bool from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
value : bool
The result value
"""
return self[key] != "False"
def __repr__(self):
return str({k : self[k] for k in self.keys()})
tvm.register_extension(AttrDict, AttrDict)
"""Definition of nn ops"""
from __future__ import absolute_import
import tvm
import topi
from ..compiler import registry as reg
from ..compiler import OpPattern
# conv
@reg.register_compute("conv2d")
def compute_conv2d(attrs, inputs):
"""Compute definition of conv2d"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
layout = attrs["layout"]
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
out = topi.nn.conv2d_nchw(inputs[0], inputs[1], strides, padding)
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.broadcast_to(bias, (1, bias.shape[0], 1, 1))
out = topi.broadcast_add(out, bias)
return out
@reg.register_schedule("conv2d")
def schedule_conv2d(_, outs, target):
"""Schedule definition of conv2d"""
if target == "cuda":
return topi.cuda.schedule_conv2d_nchw(outs)
# naive schedule
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("conv2d", OpPattern.COMPLEX)
# pylint: disable=invalid-name
"""Tensor ops"""
from __future__ import absolute_import
import tvm
import topi
import topi.cuda
from ..compiler import registry as reg
from ..compiler import OpPattern
def _schedule_broadcast(_, outs, target):
"""Generic schedule for binary bcast"""
if target == "cuda":
return topi.cuda.schedule_elemwise(outs)
assert target.startswith("llvm")
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
return s
_fschedule_broadcast = tvm.convert(_schedule_broadcast)
# exp
reg.register_compute("exp",
lambda _, x: topi.exp(x[0]))
reg.register_pattern("exp", OpPattern.ELEM_WISE)
reg.register_schedule("exp", _fschedule_broadcast)
# broadcast_add
reg.register_compute("broadcast_add",
lambda _, x: topi.broadcast_add(x[0], x[1]))
reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
reg.register_schedule("broadcast_add", _fschedule_broadcast)
# broadcast_sub
reg.register_compute("broadcast_sub",
lambda _, x: topi.broadcast_sub(x[0], x[1]))
reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
reg.register_schedule("broadcast_sub", _fschedule_broadcast)
# broadcast_mul
reg.register_compute("broadcast_mul",
lambda _, x: topi.broadcast_mul(x[0], x[1]))
reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
reg.register_schedule("broadcast_mul", _fschedule_broadcast)
# broadcast_div
reg.register_compute("broadcast_div",
lambda _, x: topi.broadcast_div(x[0], x[1]))
reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
reg.register_schedule("broadcast_div", _fschedule_broadcast)
/*!
* Copyright (c) 2017 by Contributors
* \file packed_func_ext.cc
* \brief Registeration of extension type.
*/
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/op.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <nnvm/compiler/op_attr_types.h>
namespace tvm {
namespace runtime {
TVM_REGISTER_EXT_TYPE(nnvm::Graph);
TVM_REGISTER_EXT_TYPE(nnvm::Symbol);
TVM_REGISTER_EXT_TYPE(nnvm::compiler::AttrDict);
} // namespace runtime
} // namespace tvm
namespace nnvm {
namespace compiler {
using tvm::Tensor;
using tvm::Array;
using tvm::Node;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
TVM_REGISTER_GLOBAL("nnvm.compiler._dict_get")
.set_body([](TVMArgs args, TVMRetValue *rv) {
const AttrDict& dict = args[0].AsExtension<AttrDict>();
std::string key = args[1];
auto it = dict.find(key);
if (it != dict.end()) {
*rv = it->second;
} else {
*rv = nullptr;
}
});
TVM_REGISTER_GLOBAL("nnvm.compiler._dict_size")
.set_body([](TVMArgs args, TVMRetValue *rv) {
const AttrDict& dict = args[0].AsExtension<AttrDict>();
*rv = static_cast<int64_t>(dict.size());
});
TVM_REGISTER_GLOBAL("nnvm.compiler._dict_keys")
.set_body([](TVMArgs args, TVMRetValue *rv) {
const AttrDict& dict = args[0].AsExtension<AttrDict>();
tvm::Array<tvm::Expr> keys;
for (const auto& kv : dict) {
keys.push_back(kv.first);
}
*rv = keys;
});
// custom version of TVM compute
inline std::unordered_map<std::string, std::string>
GetAttrDict(const NodeAttrs& attrs) {
static auto& fgetdict = nnvm::Op::GetAttr<FGetAttrDict>("FGetAttrDict");
if (fgetdict.count(attrs.op)) {
return fgetdict[attrs.op](attrs);
} else {
return attrs.dict;
}
}
TVM_REGISTER_GLOBAL("nnvm._register_compute")
.set_body([](TVMArgs args, TVMRetValue *rv) {
PackedFunc f = args[1];
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
auto fcompute = [f](const NodeAttrs& attrs, const Array<Tensor>& inputs)
-> Array<Tensor> {
TVMRetValue ret = f(GetAttrDict(attrs), inputs);
if ((*ret.ptr<std::shared_ptr<tvm::Node> >())->derived_from<tvm::TensorNode>()) {
return {ret.operator Tensor()};
} else {
return ret;
}
};
op.set_attr<FTVMCompute>("FTVMCompute", fcompute, args[2]);
});
TVM_REGISTER_GLOBAL("nnvm._register_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) {
PackedFunc f = args[1];
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
auto fschedule = [f](const NodeAttrs& attrs,
const Array<Tensor>& outs,
const std::string& target) {
return f(GetAttrDict(attrs), outs, target).operator Schedule();
};
op.set_attr<FTVMSchedule>("FTVMSchedule", fschedule, args[2]);
});
TVM_REGISTER_GLOBAL("nnvm._register_pattern")
.set_body([](TVMArgs args, TVMRetValue *rv) {
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
op.set_attr<TOpPattern>("TOpPattern", args[1].operator int(), args[2]);
});
} // namespace compiler
} // namespace nnvm
/*!
* Copyright (c) 2017 by Contributors
* \file layout_transform.cc
* \brief Transforms layout.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/contrib_op_param.h>
namespace nnvm {
namespace compiler {
const TLayoutInfo& GetDefaultLayout() {
static TLayoutInfo default_layout = "default";
return default_layout;
}
nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
const std::string& dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform");
static int count = 0;
nnvm::NodePtr n = nnvm::Node::Create();
n->attrs.op = trans_op;
n->attrs.name = src + "_to_" + dst + std::to_string(count++);
n->attrs.dict["src_layout"] = src;
n->attrs.dict["dst_layout"] = dst;
n->op()->attr_parser(&(n->attrs));
return n;
}
/*!
* \brief A simple layout transform pass that will
* insert layout transform nodes automatically.
*/
nnvm::Graph LayoutTransform(nnvm::Graph src) {
static auto& op_layout_request =
nnvm::Op::GetAttr<FTVMLayoutRequest>("FTVMLayoutRequest");
static auto& op_vecop =
nnvm::Op::GetAttr<FTVMVectorizedOp>("FTVMVectorizedOp");
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
const std::vector<TLayoutInfo>& input_layouts =
src.GetAttr<std::vector<TLayoutInfo> >("layout");
const IndexedGraph& idx = src.indexed_graph();
std::vector<TLayoutInfo> produce_vec(idx.num_node_entries(), GetDefaultLayout());
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
// use op pattern to decide whether an op is map
auto is_map_op = [&](size_t nid) {
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kExtern);
bool is_map = (pt <= kBroadcast);
if (pt == kBroadcast) {
for (const auto& e : idx[nid].inputs) {
if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) {
is_map = false;
break;
}
}
}
return is_map;
};
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
nnvm::NodePtr new_node = nnvm::Node::Create();
*new_node = *(inode.source);
if (new_node->is_variable()) {
auto input_iter = std::find(
idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
CHECK(input_iter != idx.input_nodes().cend());
size_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter);
produce_vec[idx.entry_id(nid, 0)] = input_layouts[input_id];
mirror_vec[nid] = new_node;
continue;
}
if (op_vecop.count(inode.source->op())) {
new_node = op_vecop[inode.source->op()](inode.source);
new_node->inputs.resize(new_node->num_inputs());
}
// set up output and input layouts
std::vector<TLayoutInfo> request_ilayouts(new_node->num_inputs(), GetDefaultLayout());
if (op_layout_request.count(new_node->op())) {
std::vector<TLayoutInfo> produce_olayouts(new_node->num_outputs(), GetDefaultLayout());
CHECK(op_layout_request[new_node->op()](
new_node->attrs, &request_ilayouts, &produce_olayouts))
<< "Layout request fail";
CHECK_EQ(request_ilayouts.size(), new_node->num_inputs());
CHECK_EQ(produce_olayouts.size(), new_node->num_outputs());
for (size_t i = 0; i < new_node->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = produce_olayouts[i];
}
}
bool map_layout = is_map_op(nid);
if (map_layout) {
const TLayoutInfo& layout = produce_vec[idx.entry_id(inode.inputs[0])];
for (const auto& e : inode.inputs) {
if (produce_vec[idx.entry_id(e)] != layout) {
map_layout = false;
break;
}
}
if (map_layout) {
for (size_t i = 0; i < inode.source->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = layout;
}
}
}
for (size_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
const nnvm::NodePtr& in = mirror_vec[e.node_id];
new_node->inputs[i] =
nnvm::NodeEntry{in, e.index, e.version};
TLayoutInfo produce = produce_vec[idx.entry_id(e)];
TLayoutInfo request = request_ilayouts[i];
if (!map_layout && (produce != request)) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_" + request;
tnode->inputs.emplace_back(new_node->inputs[i]);
new_node->inputs[i] = nnvm::NodeEntry{tnode, 0, 0};
}
}
mirror_vec[nid] = new_node;
}
std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : idx.outputs()) {
TLayoutInfo produce = produce_vec[idx.entry_id(e)];
if (produce != GetDefaultLayout()) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, GetDefaultLayout());
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_default";
tnode->inputs.emplace_back(
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0});
} else {
outputs.emplace_back(
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
}
}
nnvm::Graph ret;
ret.outputs = std::move(outputs);
return ret;
}
} // namespace compiler
} // namespace nnvm
/*!
* Copyright (c) 2017 by Contributors
* \file prune_graph.cc
* \brief Prune the graph to do constant folding.
*
* This pass breaks the graph into pre-compute graph
* and the execution graph.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <unordered_set>
namespace nnvm {
namespace compiler {
nnvm::Graph PruneGraph(nnvm::Graph src) {
const auto& params = src.GetAttr<std::unordered_set<std::string> >("params");
std::unordered_set<nnvm::Node*> pruned;
nnvm::NodeEntryMap<nnvm::NodePtr> entry_var;
DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
bool can_be_pruned = true;
if (n->is_variable()) {
if (params.count(n->attrs.name)) {
pruned.emplace(n.get());
}
can_be_pruned = false;
}
for (const auto& e : n->inputs) {
if (!pruned.count(e.node.get())) {
can_be_pruned = false;
}
}
if (can_be_pruned) {
pruned.emplace(n.get());
} else {
// scan again to find edge nodes, skip variables
for (auto& e : n->inputs) {
if (!e.node->is_variable() && pruned.count(e.node.get())) {
if (!entry_var.count(e)) {
nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name + "_output" + std::to_string(e.index);
entry_var.emplace(e, var);
}
e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
}
}
}
});
nnvm::Graph pre_graph;
pre_graph.outputs.reserve(entry_var.size());
std::vector<std::string> output_names;
output_names.reserve(entry_var.size());
for (auto kv : entry_var) {
if (kv.first.node->is_variable()) continue;
pre_graph.outputs.emplace_back(kv.first);
output_names.emplace_back(kv.second->attrs.name);
}
pre_graph.attrs["pruned_params"] =
std::make_shared<dmlc::any>(std::move(output_names));
src.attrs["pre_graph"] =
std::make_shared<dmlc::any>(std::move(pre_graph));
return src;
}
NNVM_REGISTER_PASS(PruneGraph)
.set_body(PruneGraph);
} // namespace compiler
} // namespace nnvm
...@@ -312,23 +312,34 @@ NNVM_REGISTER_OP(tvm_op) ...@@ -312,23 +312,34 @@ NNVM_REGISTER_OP(tvm_op)
return param.num_outputs; return param.num_outputs;
}); });
TVM_REGISTER_GLOBAL("nnvm.tvm.create_executor") tvm::runtime::Module RuntimeCreate(std::string sym_json,
tvm::runtime::Module m,
int device_type,
int device_id) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
// load graph from json string
nnvm::Graph g;
g.attrs["json"] = std::make_shared<nnvm::any>(sym_json);
g = nnvm::ApplyPass(std::move(g), "LoadJSON");
std::shared_ptr<GraphExecutor> exec = std::make_shared<GraphExecutor>();
exec->Init(g, m, ctx);
return tvm::runtime::Module(exec);
}
TVM_REGISTER_GLOBAL("nnvm.runtime.create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = RuntimeCreate(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("nnvm.runtime.remote_create")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
std::string sym_json = args[0]; void* mhandle = args[1];
std::string param_blob = args[1]; *rv = RuntimeCreate(args[0],
tvm::runtime::Module m = args[2]; *static_cast<tvm::runtime::Module*>(mhandle),
TVMContext ctx; args[2], args[3]);
ctx.device_type = static_cast<DLDeviceType>(args[3].operator int());
ctx.device_id = args[4];
// load graph from json string
nnvm::Graph g;
g.attrs["json"] = std::make_shared<nnvm::any>(sym_json);
g = nnvm::ApplyPass(std::move(g), "LoadJSON");
std::shared_ptr<GraphExecutor> exec = std::make_shared<GraphExecutor>();
exec->Init(g, m, ctx);
// load params form stream of string
exec->LoadParams(std::move(param_blob));
*rv = tvm::runtime::Module(exec);
}); });
} // namespace runtime } // namespace runtime
} // namespace nnvm } // namespace nnvm
...@@ -114,11 +114,12 @@ a bias vector is created and added to the outputs. ...@@ -114,11 +114,12 @@ a bias vector is created and added to the outputs.
.add_argument("bias", "1D Tensor", "Bias parameter.") .add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(Conv2DParam::__FIELDS__()) .add_arguments(Conv2DParam::__FIELDS__())
.set_attr_parser(ParamParser<Conv2DParam>) .set_attr_parser(ParamParser<Conv2DParam>)
.set_num_outputs(1) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>) .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", Conv2DInferShape) .set_attr<FInferShape>("FInferShape", Conv2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(2); .set_support_level(2);
...@@ -203,11 +204,12 @@ said convolution. ...@@ -203,11 +204,12 @@ said convolution.
.add_argument("bias", "1D Tensor", "Bias parameter.") .add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(Conv2DTransposeParam::__FIELDS__()) .add_arguments(Conv2DTransposeParam::__FIELDS__())
.set_attr_parser(ParamParser<Conv2DTransposeParam>) .set_attr_parser(ParamParser<Conv2DTransposeParam>)
.set_num_outputs(1) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DTransposeParam>)
.set_num_inputs(UseBiasNumInputs<Conv2DTransposeParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DTransposeParam>) .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DTransposeParam>)
.set_attr<FInferShape>("FInferShape", Conv2DTransposeInferShape) .set_attr<FInferShape>("FInferShape", Conv2DTransposeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DTransposeParam>)
.set_support_level(2); .set_support_level(2);
} // namespace top } // namespace top
......
...@@ -66,6 +66,7 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored. ...@@ -66,6 +66,7 @@ If ``use_bias`` is set to be false, then the ``bias`` term is ignored.
.add_argument("bias", "1D Tensor", "Bias parameter.") .add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(DenseParam::__FIELDS__()) .add_arguments(DenseParam::__FIELDS__())
.set_attr_parser(ParamParser<DenseParam>) .set_attr_parser(ParamParser<DenseParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<DenseParam>)
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<DenseParam>) .set_num_inputs(UseBiasNumInputs<DenseParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>) .set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<DenseParam>)
...@@ -95,10 +96,11 @@ NNVM_REGISTER_OP(dropout) ...@@ -95,10 +96,11 @@ NNVM_REGISTER_OP(dropout)
)" NNVM_ADD_FILELINE) )" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input to which dropout will be applied") .add_argument("data", "Tensor", "Input to which dropout will be applied")
.add_arguments(DropoutParam::__FIELDS__())
.set_attr_parser(ParamParser<DropoutParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<DropoutParam>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(2) .set_num_outputs(2)
.set_attr_parser(ParamParser<DropoutParam>)
.add_arguments(DropoutParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 2>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 2>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 2>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 2>)
.set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) { .set_attr<FNumVisibleOutputs>("FNumVisibleOutputs", [](const NodeAttrs& attrs) {
...@@ -172,10 +174,11 @@ axis to be the last item in the input shape. ...@@ -172,10 +174,11 @@ axis to be the last item in the input shape.
.add_argument("beta", "Tensor", "The beta offset factor") .add_argument("beta", "Tensor", "The beta offset factor")
.add_argument("moving_mean", "Tensor", "running mean of input") .add_argument("moving_mean", "Tensor", "running mean of input")
.add_argument("moving_var", "Tensor", "running variance of input") .add_argument("moving_var", "Tensor", "running variance of input")
.add_arguments(BatchNormParam::__FIELDS__())
.set_attr_parser(ParamParser<BatchNormParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BatchNormParam>)
.set_num_inputs(5) .set_num_inputs(5)
.set_num_outputs(3) .set_num_outputs(3)
.set_attr_parser(ParamParser<BatchNormParam>)
.add_arguments(BatchNormParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", BatchNormInferShape) .set_attr<FInferShape>("FInferShape", BatchNormInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<5, 3>) .set_attr<FInferType>("FInferType", ElemwiseType<5, 3>)
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { .set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
...@@ -203,10 +206,12 @@ NNVM_REGISTER_OP(softmax) ...@@ -203,10 +206,12 @@ NNVM_REGISTER_OP(softmax)
.. note:: .. note::
This operator can be optimized away for inference. This operator can be optimized away for inference.
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data.")
.add_arguments(SoftmaxParam::__FIELDS__())
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SoftmaxParam>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr_parser(ParamParser<SoftmaxParam>)
.add_arguments(SoftmaxParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1); .set_support_level(1);
...@@ -220,10 +225,12 @@ NNVM_REGISTER_OP(log_softmax) ...@@ -220,10 +225,12 @@ NNVM_REGISTER_OP(log_softmax)
.. note:: .. note::
This operator can be optimized away for inference. This operator can be optimized away for inference.
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data.")
.add_arguments(SoftmaxParam::__FIELDS__())
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SoftmaxParam>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr_parser(ParamParser<SoftmaxParam>)
.add_arguments(SoftmaxParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1); .set_support_level(1);
...@@ -237,10 +244,12 @@ NNVM_REGISTER_OP(leaky_relu) ...@@ -237,10 +244,12 @@ NNVM_REGISTER_OP(leaky_relu)
`y = x > 0 ? x : alpha * x` `y = x > 0 ? x : alpha * x`
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data.")
.add_arguments(LeakyReLUParam::__FIELDS__())
.set_attr_parser(ParamParser<LeakyReLUParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<LeakyReLUParam>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr_parser(ParamParser<LeakyReLUParam>)
.add_arguments(LeakyReLUParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1); .set_support_level(1);
......
...@@ -72,6 +72,7 @@ NNVM_REGISTER_OP(max_pool2d) ...@@ -72,6 +72,7 @@ NNVM_REGISTER_OP(max_pool2d)
.add_argument("data", "4D Tensor", "Input data.") .add_argument("data", "4D Tensor", "Input data.")
.add_arguments(Pool2DParam::__FIELDS__()) .add_arguments(Pool2DParam::__FIELDS__())
.set_attr_parser(ParamParser<Pool2DParam>) .set_attr_parser(ParamParser<Pool2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>)
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(1) .set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape) .set_attr<FInferShape>("FInferShape", Pool2DInferShape)
...@@ -98,10 +99,11 @@ NNVM_REGISTER_OP(avg_pool2d) ...@@ -98,10 +99,11 @@ NNVM_REGISTER_OP(avg_pool2d)
.add_argument("data", "4D Tensor", "Input data.") .add_argument("data", "4D Tensor", "Input data.")
.add_arguments(Pool2DParam::__FIELDS__()) .add_arguments(Pool2DParam::__FIELDS__())
.set_attr_parser(ParamParser<Pool2DParam>) .set_attr_parser(ParamParser<Pool2DParam>)
.set_num_outputs(1) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Pool2DParam>)
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape) .set_attr<FInferShape>("FInferShape", Pool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2); .set_support_level(2);
...@@ -135,10 +137,11 @@ NNVM_REGISTER_OP(global_max_pool2d) ...@@ -135,10 +137,11 @@ NNVM_REGISTER_OP(global_max_pool2d)
.add_argument("data", "4D Tensor", "Input data.") .add_argument("data", "4D Tensor", "Input data.")
.add_arguments(GlobalPool2DParam::__FIELDS__()) .add_arguments(GlobalPool2DParam::__FIELDS__())
.set_attr_parser(ParamParser<GlobalPool2DParam>) .set_attr_parser(ParamParser<GlobalPool2DParam>)
.set_num_outputs(1) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape) .set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2); .set_support_level(2);
...@@ -154,10 +157,11 @@ NNVM_REGISTER_OP(global_avg_pool2d) ...@@ -154,10 +157,11 @@ NNVM_REGISTER_OP(global_avg_pool2d)
.add_argument("data", "4D Tensor", "Input data.") .add_argument("data", "4D Tensor", "Input data.")
.add_arguments(GlobalPool2DParam::__FIELDS__()) .add_arguments(GlobalPool2DParam::__FIELDS__())
.set_attr_parser(ParamParser<GlobalPool2DParam>) .set_attr_parser(ParamParser<GlobalPool2DParam>)
.set_num_outputs(1) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape) .set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2); .set_support_level(2);
} // namespace top } // namespace top
......
...@@ -37,6 +37,19 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) { ...@@ -37,6 +37,19 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) {
attrs->parsed = std::move(param); attrs->parsed = std::move(param);
} }
/*!
* \brief Parse keyword arguments as PType arguments and save to parsed
* \tparam PType the arameter type.
* \param attrs The attributes.
*/
template<typename PType>
inline std::unordered_map<std::string, std::string>
ParamGetAttrDict(const nnvm::NodeAttrs& attrs) {
std::unordered_map<std::string, std::string> dict = attrs.dict;
nnvm::get<PType>(attrs.parsed).UpdateDict(&dict);
return dict;
}
/*! \brief check if shape is empty or contains unkown (0) dim. */ /*! \brief check if shape is empty or contains unkown (0) dim. */
inline bool shape_is_none(const TShape& x) { inline bool shape_is_none(const TShape& x) {
return x.ndim() == 0 || x.Size() == 0; return x.ndim() == 0 || x.Size() == 0;
......
...@@ -61,13 +61,14 @@ The dimension which you do not want to change can also be kept as `0` which mean ...@@ -61,13 +61,14 @@ The dimension which you do not want to change can also be kept as `0` which mean
So with `shape=(2,0)`, we will obtain the same result as in the above example. So with `shape=(2,0)`, we will obtain the same result as in the above example.
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<BroadcastToParam>) .add_argument("data", "Tensor", "Input data.")
.add_arguments(BroadcastToParam::__FIELDS__()) .add_arguments(BroadcastToParam::__FIELDS__())
.set_num_inputs(1) .set_attr_parser(ParamParser<BroadcastToParam>)
.set_num_outputs(1) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>)
.set_attr<FInferShape>("FInferShape", BroadcastToInferShape) .set_attr<FInferShape>("FInferShape", BroadcastToInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("data", "Tensor", "Input data.") .set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(4); .set_support_level(4);
// binary broadcast op // binary broadcast op
......
...@@ -95,70 +95,60 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy) ...@@ -95,70 +95,60 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
// unary scalar op // unary scalar op
DMLC_REGISTER_PARAMETER(ScalarParam); DMLC_REGISTER_PARAMETER(ScalarParam);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__add_scalar__) #define NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(op) \
NNVM_REGISTER_ELEMWISE_UNARY_OP(op) \
.add_arguments(ScalarParam::__FIELDS__()) \
.set_attr_parser(ParamParser<ScalarParam>) \
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ScalarParam>)
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__add_scalar__)
.describe(R"code(Tensor add scalar .describe(R"code(Tensor add scalar
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3); .set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__sub_scalar__) NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__sub_scalar__)
.describe(R"code(Tensor substract scalar .describe(R"code(Tensor substract scalar
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3); .set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__rsub_scalar__) NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rsub_scalar__)
.describe(R"code(scalar substract Tensor .describe(R"code(scalar substract Tensor
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3); .set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__mul_scalar__) NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__mul_scalar__)
.describe(R"code(Tensor multiplies scalar .describe(R"code(Tensor multiplies scalar
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3); .set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__div_scalar__) NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__div_scalar__)
.describe(R"code(Tensor divides scalar .describe(R"code(Tensor divides scalar
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3); .set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__rdiv_scalar__) NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rdiv_scalar__)
.describe(R"code(scalar divides Tensor .describe(R"code(scalar divides Tensor
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3); .set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__pow_scalar__) NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__pow_scalar__)
.describe(R"code(Tensor power scalar .describe(R"code(Tensor power scalar
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3); .set_support_level(3);
NNVM_REGISTER_ELEMWISE_UNARY_OP(__rpow_scalar__) NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rpow_scalar__)
.describe(R"code(scalar power Tensor .describe(R"code(scalar power Tensor
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_attr_parser(ParamParser<ScalarParam>)
.add_arguments(ScalarParam::__FIELDS__())
.set_support_level(3); .set_support_level(3);
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
...@@ -90,15 +90,17 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) { ...@@ -90,15 +90,17 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
attrs->parsed = std::move(param); attrs->parsed = std::move(param);
} }
#define NNVM_REGISTER_REDUCE_OP(op) \ #define NNVM_REGISTER_REDUCE_OP(op) \
NNVM_REGISTER_OP(op) \ NNVM_REGISTER_OP(op) \
.set_num_inputs(1) \ .add_argument("data", "Tensor", "The input") \
.set_num_outputs(1) \ .add_arguments(ReduceParam::__FIELDS__()) \
.set_attr_parser(AxesParamParser<ReduceParam>) \ .set_attr_parser(AxesParamParser<ReduceParam>) \
.set_attr<FInferShape>("FInferShape", ReduceShape) \ .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \ .set_attr<FInferShape>("FInferShape", ReduceShape) \
.add_argument("data", "Tensor", "The input") \ .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.add_arguments(ReduceParam::__FIELDS__()) .set_num_inputs(1) \
.set_num_outputs(1) \
NNVM_REGISTER_REDUCE_OP(sum) NNVM_REGISTER_REDUCE_OP(sum)
......
...@@ -132,13 +132,14 @@ Example:: ...@@ -132,13 +132,14 @@ Example::
[ 5., 5., 8., 8.]] [ 5., 5., 8., 8.]]
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_num_outputs(1)
.set_num_inputs(kVarg)
.set_attr_parser(ParamParser<ConcatenateParam>)
.add_arguments(ConcatenateParam::__FIELDS__())
.add_argument("data", "Tensor-or-Tensor[]", "List of arrays to concatenate") .add_argument("data", "Tensor-or-Tensor[]", "List of arrays to concatenate")
.add_arguments(ConcatenateParam::__FIELDS__())
.set_attr_parser(ParamParser<ConcatenateParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>)
.set_attr<FInferShape>("FInferShape", ConcatenateInferShape) .set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_num_outputs(1)
.set_num_inputs(kVarg)
.set_support_level(1); .set_support_level(1);
...@@ -204,6 +205,7 @@ inline uint32_t SplitNumOutputs(const NodeAttrs& attrs) { ...@@ -204,6 +205,7 @@ inline uint32_t SplitNumOutputs(const NodeAttrs& attrs) {
} }
} }
// Intentionally not add ParamGetAttrDict for indices_or_sections.
NNVM_REGISTER_OP(split) NNVM_REGISTER_OP(split)
.describe(R"code(Splits an array along a particular axis into multiple sub-arrays. .describe(R"code(Splits an array along a particular axis into multiple sub-arrays.
...@@ -211,13 +213,13 @@ NNVM_REGISTER_OP(split) ...@@ -211,13 +213,13 @@ NNVM_REGISTER_OP(split)
along which to split the array. along which to split the array.
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attr_parser(SplitParamParser)
.set_num_outputs(SplitNumOutputs)
.add_arguments(SplitParam::__FIELDS__())
.add_argument("data", "Tensor", "List of arrays to concatenate") .add_argument("data", "Tensor", "List of arrays to concatenate")
.add_arguments(SplitParam::__FIELDS__())
.set_attr_parser(SplitParamParser)
.set_attr<FInferShape>("FInferShape", SplitInferShape) .set_attr<FInferShape>("FInferShape", SplitInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_num_inputs(1)
.set_num_outputs(SplitNumOutputs)
.set_support_level(1); .set_support_level(1);
// cast // cast
...@@ -237,8 +239,9 @@ NNVM_REGISTER_OP(cast) ...@@ -237,8 +239,9 @@ NNVM_REGISTER_OP(cast)
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data array") .add_argument("data", "Tensor", "Input data array")
.set_attr_parser(ParamParser<CastParam>)
.add_arguments(CastParam::__FIELDS__()) .add_arguments(CastParam::__FIELDS__())
.set_attr_parser(ParamParser<CastParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<CastParam>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", CastInferType) .set_attr<FInferType>("FInferType", CastInferType)
.set_num_inputs(1) .set_num_inputs(1)
...@@ -387,13 +390,14 @@ The significance of each is explained below: ...@@ -387,13 +390,14 @@ The significance of each is explained below:
- input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4) - input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_num_inputs(1) .add_argument("data", "Tensor", "Input data.")
.set_num_outputs(1)
.set_attr_parser(ParamParser<ReshapeParam>)
.add_arguments(ReshapeParam::__FIELDS__()) .add_arguments(ReshapeParam::__FIELDS__())
.set_attr_parser(ParamParser<ReshapeParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReshapeParam>)
.set_attr<FInferShape>("FInferShape", ReshapeInferShape) .set_attr<FInferShape>("FInferShape", ReshapeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("data", "Tensor", "Input data.") .set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(3); .set_support_level(3);
// tranpose // tranpose
...@@ -453,13 +457,14 @@ Examples:: ...@@ -453,13 +457,14 @@ Examples::
[[ 3., 4.], [[ 3., 4.],
[ 7., 8.]]] [ 7., 8.]]]
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.set_num_inputs(1) .add_argument("data", "Tensor", "Source input")
.set_num_outputs(1)
.set_attr_parser(ParamParser<TransposeParam>)
.add_arguments(TransposeParam::__FIELDS__()) .add_arguments(TransposeParam::__FIELDS__())
.set_attr_parser(ParamParser<TransposeParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<TransposeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", TransposeShape) .set_attr<nnvm::FInferShape>("FInferShape", TransposeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.add_argument("data", "Tensor", "Source input") .set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(4); .set_support_level(4);
} // namespace top } // namespace top
......
import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
def test_compile():
x = sym.Variable("x")
y = sym.Variable("y")
z = sym.exp(y + x)
shape = (10, 128)
dtype = tvm.float32
shape_dict = {"x": shape, "y": shape}
graph, lib = nnvm.compiler.build(z, "llvm", shape_dict)
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
na = tvm.nd.array(np.ones(shape).astype(dtype))
nb = tvm.nd.array(np.ones(shape).astype(dtype))
# set inputs
set_input("x", na)
set_input("y", nb)
# execute
run()
# get outputs
out = tvm.nd.empty(shape, dtype)
get_output(0, out)
np.testing.assert_allclose(
out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
if __name__ == "__main__":
test_compile()
from tvm.contrib import util, rpc
import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
import numpy as np
def test_rpc_executor():
host = "localhost"
port = 9091
server = rpc.Server(host, port)
x = sym.Variable("x")
y = sym.Variable("y")
z = sym.exp(y + x)
shape = (10, 128)
dtype = tvm.float32
shape_dict = {"x": shape, "y": shape}
tmp = util.tempdir()
lib_name = tmp.relpath("net.o")
graph, lib = nnvm.compiler.build(z, "llvm", shape_dict)
# save module
lib.save(lib_name)
remote = rpc.connect(host, port)
remote.upload(lib_name)
ctx = remote.cpu(0)
# load remote
rlib = remote.load_module("net.o")
# Create remotemodule
m = nnvm.runtime.create(graph, rlib, remote.cpu(0))
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
na = tvm.nd.array(np.ones(shape).astype(dtype), ctx)
nb = tvm.nd.array(np.ones(shape).astype(dtype), ctx)
# set inputs
set_input("x", na)
set_input("y", nb)
# execute
run()
# get outputs
out = tvm.nd.empty(shape, dtype, ctx)
get_output(0, out)
np.testing.assert_allclose(
out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
server.terminate()
if __name__ == "__main__":
test_rpc_executor()
import numpy as np
import tvm
import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
def test_conv2d():
x = sym.Variable("x")
y = sym.conv2d(x, channels=10, kernel_size=(3, 3),
name="y", use_bias=False, padding=(1,1))
dtype = "float32"
dshape = (1, 3, 18, 18)
kshape = (10, 3, 3, 3)
oshape = (1, 10, 18, 18)
shape_dict = {"x": dshape}
graph, lib = nnvm.compiler.build(y, "llvm", shape_dict)
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
# execute
run()
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
set_input("x", data)
set_input("y_weight", kernel)
# execute
run()
# get outputs
out = tvm.nd.empty(oshape, dtype)
get_output(0, out)
c_np = topi.testing.conv2d_nchw_python(
data.asnumpy(), kernel.asnumpy(), 1, 1)
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
if __name__ == "__main__":
test_conv2d()
...@@ -6,13 +6,11 @@ def infer_shape(sym): ...@@ -6,13 +6,11 @@ def infer_shape(sym):
g = graph.create(sym) g = graph.create(sym)
g._set_json_attr("shape_attr_key", "shape") g._set_json_attr("shape_attr_key", "shape")
g = g.apply("InferShape") g = g.apply("InferShape")
jgraph = json.loads(g.apply("SaveJSON").json_attr("json"))
jnodes = jgraph["nodes"]
jnode_row_ptr = jgraph["node_row_ptr"]
sdict = {} sdict = {}
vshape = g.json_attr("shape") vshape = g.json_attr("shape")
for i, n in enumerate(jnodes): entry_ptr = g.index.entry_ptr
begin, end = jnode_row_ptr[i], jnode_row_ptr[i + 1] for i, n in enumerate(g.index.nodes):
begin, end = entry_ptr[i], entry_ptr[i + 1]
sdict[n["name"]] = vshape[begin:end] sdict[n["name"]] = vshape[begin:end]
return sdict return sdict
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment