Commit bed950da by Tianqi Chen

Create a ctypes cython optional compatible package (#11)

parent 807400aa
export LDFLAGS = -pthread -lm export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops\ export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -I../include -fPIC -L../lib -Iinclude -Idmlc-core/include -fPIC
# specify tensor path # specify tensor path
.PHONY: clean all test lint doc python .PHONY: clean all test lint doc cython cython3
all: lib/libnnvm.so lib/libnnvm.a cli_test all: lib/libnnvm.so lib/libnnvm.a cli_test
...@@ -31,9 +31,13 @@ lib/libnnvm.a: $(ALL_DEP) ...@@ -31,9 +31,13 @@ lib/libnnvm.a: $(ALL_DEP)
cli_test: $(ALL_DEP) build/test_main.o cli_test: $(ALL_DEP) build/test_main.o
$(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS) $(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
python: cython:
cd python; python setup.py build_ext --inplace cd python; python setup.py build_ext --inplace
cython3:
cd python; python3 setup.py build_ext --inplace
lint: lint:
python2 dmlc-core/scripts/lint.py nnvm cpp include src python2 dmlc-core/scripts/lint.py nnvm cpp include src
......
#!/usr/bin/env python #!/usr/bin/env python
# coding: utf-8 # coding: utf-8
"""NNVM python API for ease of use and help new framework establish python API. """ """NNVM python API for ease of use and help new framework establish python API. """
from __future__ import absolute_import from __future__ import absolute_import as _abs
from . import base from . import _base
from . import symbol as sym from . import symbol as sym
from . import symbol from . import symbol
from ._base import NNVMError
__version__ = base.__version__ __version__ = _base.__version__
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from __future__ import absolute_import from __future__ import absolute_import
import sys import sys
import os
import ctypes import ctypes
import numpy as np import numpy as np
from . import libinfo from . import libinfo
...@@ -31,7 +32,7 @@ class NNVMError(Exception): ...@@ -31,7 +32,7 @@ class NNVMError(Exception):
def _load_lib(): def _load_lib():
"""Load libary by searching possible path.""" """Load libary by searching possible path."""
lib_path = libinfo.find_lib_path() lib_path = libinfo.find_lib_path()
lib = ctypes.cdll.LoadLibrary(lib_path[0]) lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
# DMatrix functions # DMatrix functions
lib.NNGetLastError.restype = ctypes.c_char_p lib.NNGetLastError.restype = ctypes.c_char_p
return lib return lib
...@@ -41,13 +42,13 @@ __version__ = libinfo.__version__ ...@@ -41,13 +42,13 @@ __version__ = libinfo.__version__
# library instance of nnvm # library instance of nnvm
_LIB = _load_lib() _LIB = _load_lib()
# type definitions # type definitions
nn_uint = ctypes.c_uint nn_uint = ctypes.c_uint
SymbolCreatorHandle = ctypes.c_void_p SymbolCreatorHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p GraphHandle = ctypes.c_void_p
#---------------------------- #----------------------------
# helper function definition # helper function definition
#---------------------------- #----------------------------
......
This folder is by default empty and will hold DLLs generated by cython.
"""Namespace for cython generated modules for python2"""
This folder is by default empty and will hold DLLs generated by cython.
\ No newline at end of file
"""Cython generated modules"""
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"""Attribute scoping support for symbolic API.""" """Attribute scoping support for symbolic API."""
from __future__ import absolute_import from __future__ import absolute_import
from .base import string_types from ._base import string_types
class AttrScope(object): class AttrScope(object):
"""Attribute manager for scoping. """Attribute manager for scoping.
...@@ -59,4 +59,3 @@ class AttrScope(object): ...@@ -59,4 +59,3 @@ class AttrScope(object):
AttrScope.current = self._old_scope AttrScope.current = self._old_scope
AttrScope.current = AttrScope() AttrScope.current = AttrScope()
Ctypes specific implementation of certain modules
\ No newline at end of file
""""ctypes implementation of the Symbol"""
Cython specific implementation of certain modules
\ No newline at end of file
ctypedef void* SymbolHandle
ctypedef void* AtomicSymbolCreator
ctypedef unsigned nn_uint
cdef py_str(const char* x):
if PY_MAJOR_VERSION < 3:
return x
else:
return x.decode("utf-8")
cdef CALL(int ret):
if ret != 0:
raise NNVMError(NNGetLastError())
cdef const char** CBeginPtr(vector[const char*]& vec):
if (vec.size() != 0):
return &vec[0]
else:
return NULL
cdef BuildDoc(nn_uint num_args,
const char** arg_names,
const char** arg_types,
const char** arg_descs,
remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : nn_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args):
key = arg_names[i]
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = arg_types[i]
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + arg_descs[i]
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
...@@ -6,10 +6,10 @@ from __future__ import absolute_import as _abs ...@@ -6,10 +6,10 @@ from __future__ import absolute_import as _abs
import ctypes import ctypes
import sys import sys
import json import json
from .base import _LIB from ._base import _LIB
from .base import c_array, c_str, nn_uint, py_str, string_types from ._base import c_array, c_str, nn_uint, py_str, string_types
from .base import GraphHandle, SymbolHandle from ._base import GraphHandle, SymbolHandle
from .base import check_call from ._base import check_call
from .symbol import Symbol from .symbol import Symbol
......
# coding: utf-8 # coding: utf-8
"""Automatic naming support for symbolic API.""" """Automatic naming support for symbolic API."""
from __future__ import absolute_import from __future__ import absolute_import as _abs
class NameManager(object): class NameManager(object):
"""NameManager to do automatic naming. """NameManager to do automatic naming.
......
import sys
from libcpp.vector cimport vector
ctypedef void* SymbolHandle
ctypedef void* AtomicSymbolCreator
ctypedef unsigned nn_uint
cdef extern from "nnvm/c_api.h":
int NNSymbolFree(SymbolHandle symbol)
int NNSymbolCreateVariable(const char *name, SymbolHandle *out)
const char* NNGetLastError()
int NNSymbolPrint(SymbolHandle symbol, const char **out_str)
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
AtomicSymbolCreator **out_array);
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
nn_uint num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
int NNSymbolCompose(SymbolHandle sym,
const char* name,
nn_uint num_args,
const char** keys,
SymbolHandle* args);
cdef CALL(int ret):
if ret != 0:
raise RuntimeError(NNGetLastError())
cdef const char** CBeginPtr(vector[const char*]& vec):
if (vec.size() != 0):
return &vec[0]
else:
return NULL
cdef ctypes2docstring(nn_uint num_args,
const char** arg_names,
const char** arg_types,
const char** arg_descs,
remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : nn_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args):
key = arg_names[i]
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = arg_types[i]
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + arg_descs[i]
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
cdef class Symbol:
# handle for symbolic operator.
cdef SymbolHandle handle
def __dealloc__(self):
CALL(NNSymbolFree(self.handle))
def debug_str(self):
cdef const char* out_str
CALL(NNSymbolPrint(self.handle, &out_str))
return str(out_str)
cdef NewSymbol(SymbolHandle handle):
"""Create a new symbol given handle"""
sym = Symbol()
sym.handle = handle
return sym
def Variable(const char* name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
kwargs : dict of string -> string
Additional attributes to set on the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
cdef SymbolHandle handle
CALL(NNSymbolCreateVariable(name, &handle))
return NewSymbol(handle)
cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
"""Create an atomic symbol function by handle and funciton name."""
cdef const char *name
cdef const char *desc
cdef nn_uint num_args
cdef const char** arg_names
cdef const char** arg_types
cdef const char** arg_descs
cdef const char* return_type
CALL(NNSymbolGetAtomicSymbolInfo(
handle, &name, &desc,
&num_args, &arg_names,
&arg_types, &arg_descs,
&return_type))
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
func_name = name
doc_str = ('%s\n\n' +
'%s\n' +
'name : string, optional.\n' +
' Name of the resulting symbol.\n\n' +
'Returns\n' +
'-------\n' +
'symbol: Symbol\n' +
' The result symbol.')
doc_str = doc_str % (desc, param_str)
def creator(*args, **kwargs):
cdef vector[const char*] param_keys
cdef vector[const char*] param_vals
cdef vector[SymbolHandle] symbol_args
cdef vector[const char*] symbol_keys
cdef SymbolHandle ret_handle
cdef const char* c_name = NULL
name = kwargs.pop('name', None)
attr = kwargs.pop('attr', None)
if name:
c_name = name
if len(kwargs) != 0:
for k, v in kwargs.items():
if isinstance(v, Symbol):
symbol_keys.push_back(k)
symbol_args.push_back((<Symbol>v).handle)
else:
param_keys.push_back(k)
param_vals.push_back(str(v))
if len(args) != 0:
if symbol_args.size() != 0:
raise TypeError("compose only accept input Symbols\
either as positional or keyword arguments, not both")
for v in args:
if not isinstance(v, Symbol):
raise TypeError('Compose expect `Symbol` as arguments')
symbol_args.push_back((<Symbol>v).handle)
CALL(NNSymbolCreateAtomicSymbol(
handle,
<nn_uint>param_keys.size(),
CBeginPtr(param_keys),
CBeginPtr(param_vals),
&ret_handle))
num_args = <nn_uint>(symbol_args.size())
CALL(NNSymbolCompose(
ret_handle, c_name, num_args,
&symbol_keys[0] if symbol_keys.size() != 0 else NULL,
&symbol_args[0] if symbol_args.size() != 0 else NULL))
return NewSymbol(ret_handle)
creator.__name__ = func_name
creator.__doc__ = doc_str
return creator
def _init_symbol_module():
"""List and add all the atomic symbol functions to current module."""
cdef AtomicSymbolCreator* plist
cdef nn_uint size
CALL(NNSymbolListAtomicSymbolCreators(&size, &plist))
module_obj = sys.modules[__name__]
for i in range(size):
function = _make_atomic_symbol_function(plist[i])
if function.__name__.startswith('_'):
setattr(Symbol, function.__name__, staticmethod(function))
else:
setattr(module_obj, function.__name__, function)
# Initialize the atomic symbol in startups
_init_symbol_module()
import os
import sys
from distutils.core import setup from distutils.core import setup
from Cython.Build import cythonize from Cython.Build import cythonize
from distutils.extension import Extension from distutils.extension import Extension
def config():
if sys.version_info >= (3, 0):
subdir = "_cy3"
else:
subdir = "_cy2"
ret = []
path = "nnvm/cython"
for fn in os.listdir(path):
if not fn.endswith(".pyx"):
continue
ret.append(Extension(
"nnvm/%s/%s" % (subdir, fn[:-4]),
["nnvm/cython/%s" % fn],
include_dirs=["../include/"],
language="c++"))
return ret
setup( setup(
name='nnvm', name='nnvm',
ext_modules = cythonize([ ext_modules = cythonize(config())
Extension("nnvm/symbolx",
["nnvm/symbolx.pyx"],
libraries=["nnvm"],
language="c++")
])
) )
import nnvm.symbol as sym import nnvm.symbol as sym
from nnvm.base import NNVMError from nnvm import NNVMError
def test_compose(): def test_compose():
x = sym.Variable('x') x = sym.Variable('x')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment