Commit 67179f78 by Tianqi Chen

Enable Alias, refactor C API to reflect Op Semantics (#41)

* Enable Alias, refactor C API to reflect Op Semantics

* add alias example
parent 410062df
......@@ -127,6 +127,7 @@ NNVM_REGISTER_OP(identity)
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.add_alias("__add_symbol__")
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>(
......
......@@ -26,9 +26,19 @@ namespace dmlc {
template<typename EntryType>
class Registry {
public:
/*! \return list of functions in the registry */
inline static const std::vector<const EntryType*> &List() {
return Get()->entry_list_;
/*! \return list of entries in the registry(excluding alias) */
inline static const std::vector<const EntryType*>& List() {
return Get()->const_list_;
}
/*! \return list all names registered in the registry, including alias */
inline static std::vector<std::string> ListAllNames() {
const std::map<std::string, EntryType*> &fmap = Get()->fmap_;
typename std::map<std::string, EntryType*>::const_iterator p;
std::vector<std::string> names;
for (p = fmap.begin(); p !=fmap.end(); ++p) {
names.push_back(p->first);
}
return names;
}
/*!
* \brief Find the entry with corresponding name.
......@@ -45,6 +55,21 @@ class Registry {
}
}
/*!
* \brief Add alias to the key_name
* \param key_name The original entry key
* \param alias The alias key.
*/
inline void AddAlias(const std::string& key_name,
const std::string& alias) {
EntryType* e = fmap_.at(key_name);
if (fmap_.count(alias)) {
CHECK_EQ(e, fmap_.at(alias))
<< "Entry " << e->name << " already registered under different entry";
} else {
fmap_[alias] = e;
}
}
/*!
* \brief Internal function to register a name function under name.
* \param name name of the function
* \return ref to the registered entry, used to set properties
......@@ -55,6 +80,7 @@ class Registry {
EntryType *e = new EntryType();
e->name = name;
fmap_[name] = e;
const_list_.push_back(e);
entry_list_.push_back(e);
return *e;
}
......@@ -79,16 +105,17 @@ class Registry {
private:
/*! \brief list of entry types */
std::vector<const EntryType*> entry_list_;
std::vector<EntryType*> entry_list_;
/*! \brief list of entry types */
std::vector<const EntryType*> const_list_;
/*! \brief map of name->function */
std::map<std::string, EntryType*> fmap_;
/*! \brief constructor */
Registry() {}
/*! \brief destructor */
~Registry() {
for (typename std::map<std::string, EntryType*>::iterator p = fmap_.begin();
p != fmap_.end(); ++p) {
delete p->second;
for (size_t i = 0; i < entry_list_.size(); ++i) {
delete entry_list_[i];
}
}
};
......
......@@ -29,7 +29,7 @@
typedef unsigned int nn_uint;
/*! \brief handle to a function that takes param and creates symbol */
typedef void *AtomicSymbolCreator;
typedef void *OpHandle;
/*! \brief handle to a symbol that can be bind as operator */
typedef void *SymbolHandle;
/*! \brief handle to Graph */
......@@ -53,17 +53,39 @@ NNVM_DLL void NNAPISetLastError(const char* msg);
NNVM_DLL const char *NNGetLastError(void);
/*!
* \brief list all the available AtomicSymbolEntry
* \brief list all the available operator names, include entries.
* \param out_size the size of returned array
* \param out_array the output operator name array.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNListAllOpNames(nn_uint *out_size,
const char*** out_array);
/*!
* \brief Get operator handle given name.
* \param op_name The name of the operator.
* \param op_out The returnning op handle.
*/
NNVM_DLL int NNGetOpHandle(const char* op_name,
OpHandle* op_out);
/*!
* \brief list all the available operators.
* This won't include the alias, use ListAllNames
* instead to get all alias names.
*
* \param out_size the size of returned array
* \param out_array the output AtomicSymbolCreator array
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
AtomicSymbolCreator **out_array);
NNVM_DLL int NNListUniqueOps(nn_uint *out_size,
OpHandle **out_array);
/*!
* \brief Get the detailed information about atomic symbol.
* \param creator the AtomicSymbolCreator.
* \param name The returned name of the creator.
* \param op The operator handle.
* \param real_name The returned name of the creator.
* This name is not the alias name of the atomic symbol.
* \param description The returned description of the symbol.
* \param num_doc_args Number of arguments that contain documents.
* \param arg_names Name of the arguments of doc args
......@@ -72,24 +94,24 @@ NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
* \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
NNVM_DLL int NNGetOpInfo(OpHandle op,
const char **real_name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
/*!
* \brief Create an AtomicSymbol functor.
* \param creator the AtomicSymbolCreator
* \param op The operator handle
* \param num_param the number of parameters
* \param keys the keys to the params
* \param vals the vals of the params
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op,
nn_uint num_param,
const char **keys,
const char **vals,
......
......@@ -200,6 +200,13 @@ class Op {
*/
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
/*!
* \brief Add another alias to this operator.
* The same Op can be queried with Op::Get(alias)
* \param alias The alias of the operator.
* \return reference to self.
*/
Op& add_alias(const std::string& alias); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
......
......@@ -45,7 +45,7 @@ _LIB = _load_lib()
# type definitions
nn_uint = ctypes.c_uint
SymbolCreatorHandle = ctypes.c_void_p
OpHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p
......
......@@ -8,7 +8,7 @@ import ctypes
import sys
from .._base import _LIB
from .._base import c_array, c_str, nn_uint, py_str, string_types
from .._base import SymbolHandle
from .._base import SymbolHandle, OpHandle
from .._base import check_call, ctypes2docstring
from ..name import NameManager
from ..attribute import AttrScope
......@@ -114,9 +114,9 @@ def _set_symbol_class(cls):
_symbol_cls = cls
def _make_atomic_symbol_function(handle):
def _make_atomic_symbol_function(handle, name):
"""Create an atomic symbol function by handle and funciton name."""
name = ctypes.c_char_p()
real_name = ctypes.c_char_p()
desc = ctypes.c_char_p()
num_args = nn_uint()
arg_names = ctypes.POINTER(ctypes.c_char_p)()
......@@ -124,15 +124,15 @@ def _make_atomic_symbol_function(handle):
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
ret_type = ctypes.c_char_p()
check_call(_LIB.NNSymbolGetAtomicSymbolInfo(
handle, ctypes.byref(name), ctypes.byref(desc),
check_call(_LIB.NNGetOpInfo(
handle, ctypes.byref(real_name), ctypes.byref(desc),
ctypes.byref(num_args),
ctypes.byref(arg_names),
ctypes.byref(arg_types),
ctypes.byref(arg_descs),
ctypes.byref(ret_type)))
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
func_name = py_str(name.value)
func_name = name
desc = py_str(desc.value)
doc_str = ('%s\n\n' +
......@@ -199,22 +199,25 @@ def _make_atomic_symbol_function(handle):
return creator
def _init_symbol_module():
def _init_symbol_module(symbol_class, root_namespace):
"""List and add all the atomic symbol functions to current module."""
plist = ctypes.POINTER(ctypes.c_void_p)()
_set_symbol_class(symbol_class)
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size),
ctypes.byref(plist)))
module_obj = sys.modules["nnvm.symbol"]
module_internal = sys.modules["nnvm._symbol_internal"]
check_call(_LIB.NNListAllOpNames(ctypes.byref(size),
ctypes.byref(plist)))
op_names = []
for i in range(size.value):
hdl = SymbolHandle(plist[i])
function = _make_atomic_symbol_function(hdl)
op_names.append(py_str(plist[i]))
module_obj = sys.modules["%s.symbol" % root_namespace]
module_internal = sys.modules["%s._symbol_internal" % root_namespace]
for name in op_names:
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
function = _make_atomic_symbol_function(hdl, name)
if function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
# Initialize the atomic symbol in startups
_init_symbol_module()
ctypedef void* SymbolHandle
ctypedef void* AtomicSymbolCreator
ctypedef void* OpHandle
ctypedef unsigned nn_uint
cdef py_str(const char* x):
......
ctypedef void* SymbolHandle
cdef class Symbol:
# handle for symbolic operator.
cdef SymbolHandle handle
......@@ -14,21 +14,25 @@ include "./base.pyi"
cdef extern from "nnvm/c_api.h":
const char* NNGetLastError();
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
AtomicSymbolCreator **out_array);
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
int NNListAllOpNames(nn_uint *out_size,
const char ***out_array);
int NNGetOpHandle(const char *op_name,
OpHandle *handle);
int NNGetOpInfo(OpHandle op,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
int NNListOpNames(nn_uint *out_size,
const char ***out_array);
int NNSymbolCreateAtomicSymbol(OpHandle op,
nn_uint num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
int NNSymbolFree(SymbolHandle symbol);
int NNSymbolSetAttrs(SymbolHandle symbol,
nn_uint num_param,
......@@ -88,7 +92,7 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
_symbol_cls = SymbolBase
def _set_symbol_class(cls):
cdef _set_symbol_class(cls):
global _symbol_cls
_symbol_cls = cls
......@@ -98,9 +102,9 @@ cdef NewSymbol(SymbolHandle handle):
(<SymbolBase>sym).handle = handle
return sym
cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
cdef _make_atomic_symbol_function(OpHandle handle, string name):
"""Create an atomic symbol function by handle and funciton name."""
cdef const char *name
cdef const char *real_name
cdef const char *desc
cdef nn_uint num_args
cdef const char** arg_names
......@@ -108,13 +112,14 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
cdef const char** arg_descs
cdef const char* return_type
CALL(NNSymbolGetAtomicSymbolInfo(
handle, &name, &desc,
CALL(NNGetOpInfo(
handle, &real_name, &desc,
&num_args, &arg_names,
&arg_types, &arg_descs,
&return_type))
param_str = BuildDoc(num_args, arg_names, arg_types, arg_descs)
func_name = py_str(name)
func_name = py_str(name.c_str())
doc_str = ('%s\n\n' +
'%s\n' +
'name : string, optional.\n' +
......@@ -190,20 +195,23 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
return creator
def _init_symbol_module():
def _init_symbol_module(symbol_class, root_namespace):
"""List and add all the atomic symbol functions to current module."""
cdef AtomicSymbolCreator* plist
cdef const char** op_name_ptrs
cdef nn_uint size
CALL(NNSymbolListAtomicSymbolCreators(&size, &plist))
module_obj = _sys.modules["nnvm.symbol"]
module_internal = _sys.modules["nnvm._symbol_internal"]
for i in range(size):
function = _make_atomic_symbol_function(plist[i])
cdef vector[string] op_names
cdef OpHandle handle
_set_symbol_class(symbol_class)
CALL(NNListAllOpNames(&size, &op_name_ptrs))
for i in range(size):
op_names.push_back(string(op_name_ptrs[i]));
module_obj = _sys.modules["%s.symbol" % root_namespace]
module_internal = _sys.modules["%s._symbol_internal" % root_namespace]
for i in range(op_names.size()):
CALL(NNGetOpHandle(op_names[i].c_str(), &handle))
function = _make_atomic_symbol_function(handle, op_names[i])
if function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
# Initialize the atomic symbol in startups
_init_symbol_module()
......@@ -12,15 +12,16 @@ from .attribute import AttrScope
# Use different verison of SymbolBase
# When possible, use cython to speedup part of computation.
try:
if int(_os.environ.get("NNVM_ENABLE_CYTHON", True)) == 0:
from .ctypes.symbol import SymbolBase, _set_symbol_class
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from ._ctypes.symbol import SymbolBase, _init_symbol_module
elif _sys.version_info >= (3, 0):
from ._cy3.symbol import SymbolBase, _set_symbol_class
from ._cy3.symbol import SymbolBase, _init_symbol_module
else:
from ._cy2.symbol import SymbolBase, _set_symbol_class
except:
from .ctypes.symbol import SymbolBase, _set_symbol_class
from ._cy2.symbol import SymbolBase, _init_symbol_module
except ImportError:
from ._ctypes.symbol import SymbolBase, _init_symbol_module
class Symbol(SymbolBase):
......@@ -286,4 +287,4 @@ def Group(symbols):
return Symbol(handle)
# Set the real symbol class to Symbol
_set_symbol_class(Symbol)
_init_symbol_module(Symbol, "nnvm")
......@@ -10,24 +10,45 @@
using namespace nnvm;
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
AtomicSymbolCreator **out_array) {
int NNListAllOpNames(nn_uint *out_size,
const char*** out_array) {
API_BEGIN();
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
ret->ret_vec_str = dmlc::Registry<Op>::ListAllNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<nn_uint>(ret->ret_vec_str.size());
API_END();
}
int NNGetOpHandle(const char* op_name,
OpHandle* op_out) {
API_BEGIN();
*op_out = (OpHandle)Op::Get(op_name); // NOLINT(*)
API_END();
}
int NNListUniqueOps(nn_uint *out_size,
OpHandle **out_array) {
API_BEGIN();
auto &vec = dmlc::Registry<Op>::List();
*out_size = static_cast<nn_uint>(vec.size());
*out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*)
*out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END();
}
int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type) {
const Op *op = static_cast<const Op *>(creator);
int NNGetOpInfo(OpHandle handle,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type) {
const Op *op = static_cast<const Op *>(handle);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
......@@ -51,7 +72,7 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
API_END();
}
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
int NNSymbolCreateAtomicSymbol(OpHandle creator,
nn_uint num_param,
const char **keys,
const char **vals,
......
......@@ -38,6 +38,11 @@ Op::Op() {
index_ = mgr->op_counter++;
}
Op& Op::add_alias(const std::string& alias) { // NOLINT(*)
dmlc::Registry<Op>::Get()->AddAlias(this->name, alias);
return *this;
}
// find operator by name
const Op* Op::Get(const std::string& name) {
const Op* op = dmlc::Registry<Op>::Find(name);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment