Commit 807400aa by Tianqi Chen

Check in a experimental cython API (#10)

parent b63cb4d1
export LDFLAGS = -pthread -lm export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops\ export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -fPIC -Iinclude -Idmlc-core/include -I../include -fPIC -L../lib
# specify tensor path # specify tensor path
.PHONY: clean all test lint doc .PHONY: clean all test lint doc python
all: lib/libnnvm.so lib/libnnvm.a cli_test all: lib/libnnvm.so lib/libnnvm.a cli_test
...@@ -31,6 +31,9 @@ lib/libnnvm.a: $(ALL_DEP) ...@@ -31,6 +31,9 @@ lib/libnnvm.a: $(ALL_DEP)
cli_test: $(ALL_DEP) build/test_main.o cli_test: $(ALL_DEP) build/test_main.o
$(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS) $(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
python:
cd python; python setup.py build_ext --inplace
lint: lint:
python2 dmlc-core/scripts/lint.py nnvm cpp include src python2 dmlc-core/scripts/lint.py nnvm cpp include src
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#ifdef __cplusplus #ifdef __cplusplus
#define NNVM_EXTERN_C extern "C" #define NNVM_EXTERN_C extern "C"
#else
#define NNVM_EXTERN_C
#endif #endif
/*! \brief NNVM_DLL prefix for windows */ /*! \brief NNVM_DLL prefix for windows */
...@@ -42,7 +44,7 @@ typedef void *GraphHandle; ...@@ -42,7 +44,7 @@ typedef void *GraphHandle;
* this function is threadsafe and can be called by different thread * this function is threadsafe and can be called by different thread
* \return error info * \return error info
*/ */
NNVM_DLL const char *NNGetLastError(); NNVM_DLL const char *NNGetLastError(void);
/*! /*!
* \brief list all the available AtomicSymbolEntry * \brief list all the available AtomicSymbolEntry
......
ctypedef void* SymbolHandle
cdef class Symbol:
# handle for symbolic operator.
cdef SymbolHandle handle
import sys
from libcpp.vector cimport vector
ctypedef void* SymbolHandle
ctypedef void* AtomicSymbolCreator
ctypedef unsigned nn_uint
cdef extern from "nnvm/c_api.h":
int NNSymbolFree(SymbolHandle symbol)
int NNSymbolCreateVariable(const char *name, SymbolHandle *out)
const char* NNGetLastError()
int NNSymbolPrint(SymbolHandle symbol, const char **out_str)
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
AtomicSymbolCreator **out_array);
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
nn_uint num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
int NNSymbolCompose(SymbolHandle sym,
const char* name,
nn_uint num_args,
const char** keys,
SymbolHandle* args);
cdef CALL(int ret):
if ret != 0:
raise RuntimeError(NNGetLastError())
cdef const char** CBeginPtr(vector[const char*]& vec):
if (vec.size() != 0):
return &vec[0]
else:
return NULL
cdef ctypes2docstring(nn_uint num_args,
const char** arg_names,
const char** arg_types,
const char** arg_descs,
remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : nn_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args):
key = arg_names[i]
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = arg_types[i]
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + arg_descs[i]
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
cdef class Symbol:
# handle for symbolic operator.
cdef SymbolHandle handle
def __dealloc__(self):
CALL(NNSymbolFree(self.handle))
def debug_str(self):
cdef const char* out_str
CALL(NNSymbolPrint(self.handle, &out_str))
return str(out_str)
cdef NewSymbol(SymbolHandle handle):
"""Create a new symbol given handle"""
sym = Symbol()
sym.handle = handle
return sym
def Variable(const char* name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
cdef SymbolHandle handle
CALL(NNSymbolCreateVariable(name, &handle))
return NewSymbol(handle)
cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
"""Create an atomic symbol function by handle and funciton name."""
cdef const char *name
cdef const char *desc
cdef nn_uint num_args
cdef const char** arg_names
cdef const char** arg_types
cdef const char** arg_descs
cdef const char* return_type
CALL(NNSymbolGetAtomicSymbolInfo(
handle, &name, &desc,
&num_args, &arg_names,
&arg_types, &arg_descs,
&return_type))
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
func_name = name
doc_str = ('%s\n\n' +
'%s\n' +
'name : string, optional.\n' +
' Name of the resulting symbol.\n\n' +
'Returns\n' +
'-------\n' +
'symbol: Symbol\n' +
' The result symbol.')
doc_str = doc_str % (desc, param_str)
def creator(*args, **kwargs):
cdef vector[const char*] param_keys
cdef vector[const char*] param_vals
cdef vector[SymbolHandle] symbol_args
cdef vector[const char*] symbol_keys
cdef SymbolHandle ret_handle
cdef const char* c_name = NULL
name = kwargs.pop('name', None)
attr = kwargs.pop('attr', None)
if name:
c_name = name
if len(kwargs) != 0:
for k, v in kwargs.items():
if isinstance(v, Symbol):
symbol_keys.push_back(k)
symbol_args.push_back((<Symbol>v).handle)
else:
param_keys.push_back(k)
param_vals.push_back(str(v))
if len(args) != 0:
if symbol_args.size() != 0:
raise TypeError("compose only accept input Symbols\
either as positional or keyword arguments, not both")
for v in args:
if not isinstance(v, Symbol):
raise TypeError('Compose expect `Symbol` as arguments')
symbol_args.push_back((<Symbol>v).handle)
CALL(NNSymbolCreateAtomicSymbol(
handle,
<nn_uint>param_keys.size(),
CBeginPtr(param_keys),
CBeginPtr(param_vals),
&ret_handle))
num_args = <nn_uint>(symbol_args.size())
CALL(NNSymbolCompose(
ret_handle, c_name, num_args,
&symbol_keys[0] if symbol_keys.size() != 0 else NULL,
&symbol_args[0] if symbol_args.size() != 0 else NULL))
return NewSymbol(ret_handle)
creator.__name__ = func_name
creator.__doc__ = doc_str
return creator
def _init_symbol_module():
"""List and add all the atomic symbol functions to current module."""
cdef AtomicSymbolCreator* plist
cdef nn_uint size
CALL(NNSymbolListAtomicSymbolCreators(&size, &plist))
module_obj = sys.modules[__name__]
for i in range(size):
function = _make_atomic_symbol_function(plist[i])
if function.__name__.startswith('_'):
setattr(Symbol, function.__name__, staticmethod(function))
else:
setattr(module_obj, function.__name__, function)
# Initialize the atomic symbol in startups
_init_symbol_module()
from distutils.core import setup
from Cython.Build import cythonize
from distutils.extension import Extension
setup(
name='nnvm',
ext_modules = cythonize([
Extension("nnvm/symbolx",
["nnvm/symbolx.pyx"],
libraries=["nnvm"],
language="c++")
])
)
...@@ -123,7 +123,11 @@ Symbol Symbol::Copy() const { ...@@ -123,7 +123,11 @@ Symbol Symbol::Copy() const {
void Symbol::Print(std::ostream &os) const { void Symbol::Print(std::ostream &os) const {
if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0) { if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0) {
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n'; if (outputs[0].node->is_variable()) {
os << "Variable:" << outputs[0].node->attrs.name << '\n';
} else {
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n';
}
} else { } else {
// use DFSVisit to copy all the nodes // use DFSVisit to copy all the nodes
os << "Symbol Outputs:\n"; os << "Symbol Outputs:\n";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment