Commit 67179f78 by Tianqi Chen

Enable Alias, refactor C API to reflect Op Semantics (#41)

* Enable Alias, refactor C API to reflect Op Semantics

* add alias example
parent 410062df
...@@ -127,6 +127,7 @@ NNVM_REGISTER_OP(identity) ...@@ -127,6 +127,7 @@ NNVM_REGISTER_OP(identity)
NNVM_REGISTER_OP(add) NNVM_REGISTER_OP(add)
.describe("add two data together") .describe("add two data together")
.set_num_inputs(2) .set_num_inputs(2)
.add_alias("__add_symbol__")
.attr<FInferShape>("FInferShape", SameShape) .attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0) .attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>( .attr<FGradient>(
......
...@@ -26,9 +26,19 @@ namespace dmlc { ...@@ -26,9 +26,19 @@ namespace dmlc {
template<typename EntryType> template<typename EntryType>
class Registry { class Registry {
public: public:
/*! \return list of functions in the registry */ /*! \return list of entries in the registry(excluding alias) */
inline static const std::vector<const EntryType*> &List() { inline static const std::vector<const EntryType*>& List() {
return Get()->entry_list_; return Get()->const_list_;
}
/*! \return list all names registered in the registry, including alias */
inline static std::vector<std::string> ListAllNames() {
const std::map<std::string, EntryType*> &fmap = Get()->fmap_;
typename std::map<std::string, EntryType*>::const_iterator p;
std::vector<std::string> names;
for (p = fmap.begin(); p !=fmap.end(); ++p) {
names.push_back(p->first);
}
return names;
} }
/*! /*!
* \brief Find the entry with corresponding name. * \brief Find the entry with corresponding name.
...@@ -45,6 +55,21 @@ class Registry { ...@@ -45,6 +55,21 @@ class Registry {
} }
} }
/*! /*!
* \brief Add alias to the key_name
* \param key_name The original entry key
* \param alias The alias key.
*/
inline void AddAlias(const std::string& key_name,
const std::string& alias) {
EntryType* e = fmap_.at(key_name);
if (fmap_.count(alias)) {
CHECK_EQ(e, fmap_.at(alias))
<< "Entry " << e->name << " already registered under different entry";
} else {
fmap_[alias] = e;
}
}
/*!
* \brief Internal function to register a name function under name. * \brief Internal function to register a name function under name.
* \param name name of the function * \param name name of the function
* \return ref to the registered entry, used to set properties * \return ref to the registered entry, used to set properties
...@@ -55,6 +80,7 @@ class Registry { ...@@ -55,6 +80,7 @@ class Registry {
EntryType *e = new EntryType(); EntryType *e = new EntryType();
e->name = name; e->name = name;
fmap_[name] = e; fmap_[name] = e;
const_list_.push_back(e);
entry_list_.push_back(e); entry_list_.push_back(e);
return *e; return *e;
} }
...@@ -79,16 +105,17 @@ class Registry { ...@@ -79,16 +105,17 @@ class Registry {
private: private:
/*! \brief list of entry types */ /*! \brief list of entry types */
std::vector<const EntryType*> entry_list_; std::vector<EntryType*> entry_list_;
/*! \brief list of entry types */
std::vector<const EntryType*> const_list_;
/*! \brief map of name->function */ /*! \brief map of name->function */
std::map<std::string, EntryType*> fmap_; std::map<std::string, EntryType*> fmap_;
/*! \brief constructor */ /*! \brief constructor */
Registry() {} Registry() {}
/*! \brief destructor */ /*! \brief destructor */
~Registry() { ~Registry() {
for (typename std::map<std::string, EntryType*>::iterator p = fmap_.begin(); for (size_t i = 0; i < entry_list_.size(); ++i) {
p != fmap_.end(); ++p) { delete entry_list_[i];
delete p->second;
} }
} }
}; };
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
typedef unsigned int nn_uint; typedef unsigned int nn_uint;
/*! \brief handle to a function that takes param and creates symbol */ /*! \brief handle to a function that takes param and creates symbol */
typedef void *AtomicSymbolCreator; typedef void *OpHandle;
/*! \brief handle to a symbol that can be bind as operator */ /*! \brief handle to a symbol that can be bind as operator */
typedef void *SymbolHandle; typedef void *SymbolHandle;
/*! \brief handle to Graph */ /*! \brief handle to Graph */
...@@ -53,17 +53,39 @@ NNVM_DLL void NNAPISetLastError(const char* msg); ...@@ -53,17 +53,39 @@ NNVM_DLL void NNAPISetLastError(const char* msg);
NNVM_DLL const char *NNGetLastError(void); NNVM_DLL const char *NNGetLastError(void);
/*! /*!
* \brief list all the available AtomicSymbolEntry * \brief list all the available operator names, include entries.
* \param out_size the size of returned array
* \param out_array the output operator name array.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNListAllOpNames(nn_uint *out_size,
const char*** out_array);
/*!
* \brief Get operator handle given name.
* \param op_name The name of the operator.
* \param op_out The returnning op handle.
*/
NNVM_DLL int NNGetOpHandle(const char* op_name,
OpHandle* op_out);
/*!
* \brief list all the available operators.
* This won't include the alias, use ListAllNames
* instead to get all alias names.
*
* \param out_size the size of returned array * \param out_size the size of returned array
* \param out_array the output AtomicSymbolCreator array * \param out_array the output AtomicSymbolCreator array
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, NNVM_DLL int NNListUniqueOps(nn_uint *out_size,
AtomicSymbolCreator **out_array); OpHandle **out_array);
/*! /*!
* \brief Get the detailed information about atomic symbol. * \brief Get the detailed information about atomic symbol.
* \param creator the AtomicSymbolCreator. * \param op The operator handle.
* \param name The returned name of the creator. * \param real_name The returned name of the creator.
* This name is not the alias name of the atomic symbol.
* \param description The returned description of the symbol. * \param description The returned description of the symbol.
* \param num_doc_args Number of arguments that contain documents. * \param num_doc_args Number of arguments that contain documents.
* \param arg_names Name of the arguments of doc args * \param arg_names Name of the arguments of doc args
...@@ -72,24 +94,24 @@ NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, ...@@ -72,24 +94,24 @@ NNVM_DLL int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
* \param return_type Return type of the function, if any. * \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
NNVM_DLL int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, NNVM_DLL int NNGetOpInfo(OpHandle op,
const char **name, const char **real_name,
const char **description, const char **description,
nn_uint *num_doc_args, nn_uint *num_doc_args,
const char ***arg_names, const char ***arg_names,
const char ***arg_type_infos, const char ***arg_type_infos,
const char ***arg_descriptions, const char ***arg_descriptions,
const char **return_type); const char **return_type);
/*! /*!
* \brief Create an AtomicSymbol functor. * \brief Create an AtomicSymbol functor.
* \param creator the AtomicSymbolCreator * \param op The operator handle
* \param num_param the number of parameters * \param num_param the number of parameters
* \param keys the keys to the params * \param keys the keys to the params
* \param vals the vals of the params * \param vals the vals of the params
* \param out pointer to the created symbol handle * \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
NNVM_DLL int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op,
nn_uint num_param, nn_uint num_param,
const char **keys, const char **keys,
const char **vals, const char **vals,
......
...@@ -200,6 +200,13 @@ class Op { ...@@ -200,6 +200,13 @@ class Op {
*/ */
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*) inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
/*! /*!
* \brief Add another alias to this operator.
* The same Op can be queried with Op::Get(alias)
* \param alias The alias of the operator.
* \return reference to self.
*/
Op& add_alias(const std::string& alias); // NOLINT(*)
/*!
* \brief Register additional attributes to operator. * \brief Register additional attributes to operator.
* \param attr_name The name of the attribute. * \param attr_name The name of the attribute.
* \param value The value to be set. * \param value The value to be set.
......
...@@ -45,7 +45,7 @@ _LIB = _load_lib() ...@@ -45,7 +45,7 @@ _LIB = _load_lib()
# type definitions # type definitions
nn_uint = ctypes.c_uint nn_uint = ctypes.c_uint
SymbolCreatorHandle = ctypes.c_void_p OpHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p GraphHandle = ctypes.c_void_p
......
...@@ -8,7 +8,7 @@ import ctypes ...@@ -8,7 +8,7 @@ import ctypes
import sys import sys
from .._base import _LIB from .._base import _LIB
from .._base import c_array, c_str, nn_uint, py_str, string_types from .._base import c_array, c_str, nn_uint, py_str, string_types
from .._base import SymbolHandle from .._base import SymbolHandle, OpHandle
from .._base import check_call, ctypes2docstring from .._base import check_call, ctypes2docstring
from ..name import NameManager from ..name import NameManager
from ..attribute import AttrScope from ..attribute import AttrScope
...@@ -114,9 +114,9 @@ def _set_symbol_class(cls): ...@@ -114,9 +114,9 @@ def _set_symbol_class(cls):
_symbol_cls = cls _symbol_cls = cls
def _make_atomic_symbol_function(handle): def _make_atomic_symbol_function(handle, name):
"""Create an atomic symbol function by handle and funciton name.""" """Create an atomic symbol function by handle and funciton name."""
name = ctypes.c_char_p() real_name = ctypes.c_char_p()
desc = ctypes.c_char_p() desc = ctypes.c_char_p()
num_args = nn_uint() num_args = nn_uint()
arg_names = ctypes.POINTER(ctypes.c_char_p)() arg_names = ctypes.POINTER(ctypes.c_char_p)()
...@@ -124,15 +124,15 @@ def _make_atomic_symbol_function(handle): ...@@ -124,15 +124,15 @@ def _make_atomic_symbol_function(handle):
arg_descs = ctypes.POINTER(ctypes.c_char_p)() arg_descs = ctypes.POINTER(ctypes.c_char_p)()
ret_type = ctypes.c_char_p() ret_type = ctypes.c_char_p()
check_call(_LIB.NNSymbolGetAtomicSymbolInfo( check_call(_LIB.NNGetOpInfo(
handle, ctypes.byref(name), ctypes.byref(desc), handle, ctypes.byref(real_name), ctypes.byref(desc),
ctypes.byref(num_args), ctypes.byref(num_args),
ctypes.byref(arg_names), ctypes.byref(arg_names),
ctypes.byref(arg_types), ctypes.byref(arg_types),
ctypes.byref(arg_descs), ctypes.byref(arg_descs),
ctypes.byref(ret_type))) ctypes.byref(ret_type)))
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs) param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
func_name = py_str(name.value) func_name = name
desc = py_str(desc.value) desc = py_str(desc.value)
doc_str = ('%s\n\n' + doc_str = ('%s\n\n' +
...@@ -199,22 +199,25 @@ def _make_atomic_symbol_function(handle): ...@@ -199,22 +199,25 @@ def _make_atomic_symbol_function(handle):
return creator return creator
def _init_symbol_module(): def _init_symbol_module(symbol_class, root_namespace):
"""List and add all the atomic symbol functions to current module.""" """List and add all the atomic symbol functions to current module."""
plist = ctypes.POINTER(ctypes.c_void_p)() _set_symbol_class(symbol_class)
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint() size = ctypes.c_uint()
check_call(_LIB.NNSymbolListAtomicSymbolCreators(ctypes.byref(size), check_call(_LIB.NNListAllOpNames(ctypes.byref(size),
ctypes.byref(plist))) ctypes.byref(plist)))
module_obj = sys.modules["nnvm.symbol"] op_names = []
module_internal = sys.modules["nnvm._symbol_internal"]
for i in range(size.value): for i in range(size.value):
hdl = SymbolHandle(plist[i]) op_names.append(py_str(plist[i]))
function = _make_atomic_symbol_function(hdl)
module_obj = sys.modules["%s.symbol" % root_namespace]
module_internal = sys.modules["%s._symbol_internal" % root_namespace]
for name in op_names:
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
function = _make_atomic_symbol_function(hdl, name)
if function.__name__.startswith('_'): if function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function) setattr(module_internal, function.__name__, function)
else: else:
setattr(module_obj, function.__name__, function) setattr(module_obj, function.__name__, function)
# Initialize the atomic symbol in startups
_init_symbol_module()
ctypedef void* SymbolHandle ctypedef void* SymbolHandle
ctypedef void* AtomicSymbolCreator ctypedef void* OpHandle
ctypedef unsigned nn_uint ctypedef unsigned nn_uint
cdef py_str(const char* x): cdef py_str(const char* x):
......
ctypedef void* SymbolHandle
cdef class Symbol:
# handle for symbolic operator.
cdef SymbolHandle handle
...@@ -14,21 +14,25 @@ include "./base.pyi" ...@@ -14,21 +14,25 @@ include "./base.pyi"
cdef extern from "nnvm/c_api.h": cdef extern from "nnvm/c_api.h":
const char* NNGetLastError(); const char* NNGetLastError();
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, int NNListAllOpNames(nn_uint *out_size,
AtomicSymbolCreator **out_array); const char ***out_array);
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, int NNGetOpHandle(const char *op_name,
OpHandle *handle);
int NNGetOpInfo(OpHandle op,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
int NNListOpNames(nn_uint *out_size,
const char ***out_array);
int NNSymbolCreateAtomicSymbol(OpHandle op,
nn_uint num_param, nn_uint num_param,
const char **keys, const char **keys,
const char **vals, const char **vals,
SymbolHandle *out); SymbolHandle *out);
int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
int NNSymbolFree(SymbolHandle symbol); int NNSymbolFree(SymbolHandle symbol);
int NNSymbolSetAttrs(SymbolHandle symbol, int NNSymbolSetAttrs(SymbolHandle symbol,
nn_uint num_param, nn_uint num_param,
...@@ -88,7 +92,7 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs): ...@@ -88,7 +92,7 @@ cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
_symbol_cls = SymbolBase _symbol_cls = SymbolBase
def _set_symbol_class(cls): cdef _set_symbol_class(cls):
global _symbol_cls global _symbol_cls
_symbol_cls = cls _symbol_cls = cls
...@@ -98,9 +102,9 @@ cdef NewSymbol(SymbolHandle handle): ...@@ -98,9 +102,9 @@ cdef NewSymbol(SymbolHandle handle):
(<SymbolBase>sym).handle = handle (<SymbolBase>sym).handle = handle
return sym return sym
cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): cdef _make_atomic_symbol_function(OpHandle handle, string name):
"""Create an atomic symbol function by handle and funciton name.""" """Create an atomic symbol function by handle and funciton name."""
cdef const char *name cdef const char *real_name
cdef const char *desc cdef const char *desc
cdef nn_uint num_args cdef nn_uint num_args
cdef const char** arg_names cdef const char** arg_names
...@@ -108,13 +112,14 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): ...@@ -108,13 +112,14 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
cdef const char** arg_descs cdef const char** arg_descs
cdef const char* return_type cdef const char* return_type
CALL(NNSymbolGetAtomicSymbolInfo( CALL(NNGetOpInfo(
handle, &name, &desc, handle, &real_name, &desc,
&num_args, &arg_names, &num_args, &arg_names,
&arg_types, &arg_descs, &arg_types, &arg_descs,
&return_type)) &return_type))
param_str = BuildDoc(num_args, arg_names, arg_types, arg_descs) param_str = BuildDoc(num_args, arg_names, arg_types, arg_descs)
func_name = py_str(name) func_name = py_str(name.c_str())
doc_str = ('%s\n\n' + doc_str = ('%s\n\n' +
'%s\n' + '%s\n' +
'name : string, optional.\n' + 'name : string, optional.\n' +
...@@ -190,20 +195,23 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): ...@@ -190,20 +195,23 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
return creator return creator
def _init_symbol_module(): def _init_symbol_module(symbol_class, root_namespace):
"""List and add all the atomic symbol functions to current module.""" """List and add all the atomic symbol functions to current module."""
cdef AtomicSymbolCreator* plist cdef const char** op_name_ptrs
cdef nn_uint size cdef nn_uint size
CALL(NNSymbolListAtomicSymbolCreators(&size, &plist)) cdef vector[string] op_names
module_obj = _sys.modules["nnvm.symbol"] cdef OpHandle handle
module_internal = _sys.modules["nnvm._symbol_internal"]
for i in range(size):
function = _make_atomic_symbol_function(plist[i])
_set_symbol_class(symbol_class)
CALL(NNListAllOpNames(&size, &op_name_ptrs))
for i in range(size):
op_names.push_back(string(op_name_ptrs[i]));
module_obj = _sys.modules["%s.symbol" % root_namespace]
module_internal = _sys.modules["%s._symbol_internal" % root_namespace]
for i in range(op_names.size()):
CALL(NNGetOpHandle(op_names[i].c_str(), &handle))
function = _make_atomic_symbol_function(handle, op_names[i])
if function.__name__.startswith('_'): if function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function) setattr(module_internal, function.__name__, function)
else: else:
setattr(module_obj, function.__name__, function) setattr(module_obj, function.__name__, function)
# Initialize the atomic symbol in startups
_init_symbol_module()
...@@ -12,15 +12,16 @@ from .attribute import AttrScope ...@@ -12,15 +12,16 @@ from .attribute import AttrScope
# Use different verison of SymbolBase # Use different verison of SymbolBase
# When possible, use cython to speedup part of computation. # When possible, use cython to speedup part of computation.
try: try:
if int(_os.environ.get("NNVM_ENABLE_CYTHON", True)) == 0: if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from .ctypes.symbol import SymbolBase, _set_symbol_class from ._ctypes.symbol import SymbolBase, _init_symbol_module
elif _sys.version_info >= (3, 0): elif _sys.version_info >= (3, 0):
from ._cy3.symbol import SymbolBase, _set_symbol_class from ._cy3.symbol import SymbolBase, _init_symbol_module
else: else:
from ._cy2.symbol import SymbolBase, _set_symbol_class from ._cy2.symbol import SymbolBase, _init_symbol_module
except: except ImportError:
from .ctypes.symbol import SymbolBase, _set_symbol_class from ._ctypes.symbol import SymbolBase, _init_symbol_module
class Symbol(SymbolBase): class Symbol(SymbolBase):
...@@ -286,4 +287,4 @@ def Group(symbols): ...@@ -286,4 +287,4 @@ def Group(symbols):
return Symbol(handle) return Symbol(handle)
# Set the real symbol class to Symbol # Set the real symbol class to Symbol
_set_symbol_class(Symbol) _init_symbol_module(Symbol, "nnvm")
...@@ -10,24 +10,45 @@ ...@@ -10,24 +10,45 @@
using namespace nnvm; using namespace nnvm;
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, int NNListAllOpNames(nn_uint *out_size,
AtomicSymbolCreator **out_array) { const char*** out_array) {
API_BEGIN();
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
ret->ret_vec_str = dmlc::Registry<Op>::ListAllNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<nn_uint>(ret->ret_vec_str.size());
API_END();
}
int NNGetOpHandle(const char* op_name,
OpHandle* op_out) {
API_BEGIN();
*op_out = (OpHandle)Op::Get(op_name); // NOLINT(*)
API_END();
}
int NNListUniqueOps(nn_uint *out_size,
OpHandle **out_array) {
API_BEGIN(); API_BEGIN();
auto &vec = dmlc::Registry<Op>::List(); auto &vec = dmlc::Registry<Op>::List();
*out_size = static_cast<nn_uint>(vec.size()); *out_size = static_cast<nn_uint>(vec.size());
*out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) *out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END(); API_END();
} }
int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, int NNGetOpInfo(OpHandle handle,
const char **name, const char **name,
const char **description, const char **description,
nn_uint *num_doc_args, nn_uint *num_doc_args,
const char ***arg_names, const char ***arg_names,
const char ***arg_type_infos, const char ***arg_type_infos,
const char ***arg_descriptions, const char ***arg_descriptions,
const char **return_type) { const char **return_type) {
const Op *op = static_cast<const Op *>(creator); const Op *op = static_cast<const Op *>(handle);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN(); API_BEGIN();
...@@ -51,7 +72,7 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, ...@@ -51,7 +72,7 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
API_END(); API_END();
} }
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, int NNSymbolCreateAtomicSymbol(OpHandle creator,
nn_uint num_param, nn_uint num_param,
const char **keys, const char **keys,
const char **vals, const char **vals,
......
...@@ -38,6 +38,11 @@ Op::Op() { ...@@ -38,6 +38,11 @@ Op::Op() {
index_ = mgr->op_counter++; index_ = mgr->op_counter++;
} }
Op& Op::add_alias(const std::string& alias) { // NOLINT(*)
dmlc::Registry<Op>::Get()->AddAlias(this->name, alias);
return *this;
}
// find operator by name // find operator by name
const Op* Op::Get(const std::string& name) { const Op* Op::Get(const std::string& name) {
const Op* op = dmlc::Registry<Op>::Find(name); const Op* op = dmlc::Registry<Op>::Find(name);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment