Commit 31f9fc0a by Tianqi Chen

Make cython compatible with python3 (#12)

parent bed950da
...@@ -3,7 +3,7 @@ export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loop ...@@ -3,7 +3,7 @@ export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loop
-Iinclude -Idmlc-core/include -fPIC -Iinclude -Idmlc-core/include -fPIC
# specify tensor path # specify tensor path
.PHONY: clean all test lint doc cython cython3 .PHONY: clean all test lint doc cython cython3 cyclean
all: lib/libnnvm.so lib/libnnvm.a cli_test all: lib/libnnvm.so lib/libnnvm.a cli_test
...@@ -37,6 +37,8 @@ cython: ...@@ -37,6 +37,8 @@ cython:
cython3: cython3:
cd python; python3 setup.py build_ext --inplace cd python; python3 setup.py build_ext --inplace
cyclean:
rm -rf python/nnvm/*/*.so python/nnvm/*/*.cpp
lint: lint:
python2 dmlc-core/scripts/lint.py nnvm cpp include src python2 dmlc-core/scripts/lint.py nnvm cpp include src
......
...@@ -9,6 +9,21 @@ cdef py_str(const char* x): ...@@ -9,6 +9,21 @@ cdef py_str(const char* x):
return x.decode("utf-8") return x.decode("utf-8")
cdef c_str(pystr):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return pystr.encode("utf-8")
cdef CALL(int ret): cdef CALL(int ret):
if ret != 0: if ret != 0:
raise NNVMError(NNGetLastError()) raise NNVMError(NNGetLastError())
...@@ -20,6 +35,13 @@ cdef const char** CBeginPtr(vector[const char*]& vec): ...@@ -20,6 +35,13 @@ cdef const char** CBeginPtr(vector[const char*]& vec):
else: else:
return NULL return NULL
cdef vector[const char*] SVec2Ptr(vector[string]& vec):
cdef vector[const char*] svec
svec.resize(vec.size())
for i in range(vec.size()):
svec[i] = vec[i].c_str()
return svec
cdef BuildDoc(nn_uint num_args, cdef BuildDoc(nn_uint num_args,
const char** arg_names, const char** arg_names,
......
...@@ -6,6 +6,7 @@ from .._base import NNVMError ...@@ -6,6 +6,7 @@ from .._base import NNVMError
from ..name import NameManager from ..name import NameManager
from ..attribute import AttrScope from ..attribute import AttrScope
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libcpp.string cimport string
from cpython.version cimport PY_MAJOR_VERSION from cpython.version cimport PY_MAJOR_VERSION
include "./base.pyi" include "./base.pyi"
...@@ -110,7 +111,7 @@ cdef class Symbol: ...@@ -110,7 +111,7 @@ cdef class Symbol:
CALL(NNSymbolGetOutput(self.handle, c_index, &handle)) CALL(NNSymbolGetOutput(self.handle, c_index, &handle))
return NewSymbol(handle) return NewSymbol(handle)
def attr(self, const char* key): def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol. """Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters Parameters
...@@ -125,6 +126,8 @@ cdef class Symbol: ...@@ -125,6 +126,8 @@ cdef class Symbol:
""" """
cdef const char* ret cdef const char* ret
cdef int success cdef int success
key = c_str(key)
CALL(NNSymbolGetAttr( CALL(NNSymbolGetAttr(
self.handle, key, &ret, &success)) self.handle, key, &ret, &success))
if success != 0: if success != 0:
...@@ -203,16 +206,19 @@ cdef class Symbol: ...@@ -203,16 +206,19 @@ cdef class Symbol:
def debug_str(self): def debug_str(self):
cdef const char* out_str cdef const char* out_str
CALL(NNSymbolPrint(self.handle, &out_str)) CALL(NNSymbolPrint(self.handle, &out_str))
return str(out_str) return py_str(out_str)
cdef SymbolSetAttr(SymbolHandle handle, dict kwargs): cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
cdef vector[const char*] param_keys cdef vector[string] sparam_keys
cdef vector[const char*] param_vals cdef vector[string] sparam_vals
cdef nn_uint num_args cdef nn_uint num_args
for k, v in kwargs.items(): for k, v in kwargs.items():
param_keys.push_back(k) sparam_keys.push_back(c_str(k))
param_vals.push_back(str(v)) sparam_vals.push_back(c_str(str(v)))
# keep strings in vector
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
num_args = param_keys.size() num_args = param_keys.size()
CALL(NNSymbolSetAttrs( CALL(NNSymbolSetAttrs(
handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals))) handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals)))
...@@ -225,7 +231,7 @@ cdef NewSymbol(SymbolHandle handle): ...@@ -225,7 +231,7 @@ cdef NewSymbol(SymbolHandle handle):
return sym return sym
def Variable(const char* name, **kwargs): def Variable(name, **kwargs):
"""Create a symbolic variable with specified name. """Create a symbolic variable with specified name.
Parameters Parameters
...@@ -241,6 +247,7 @@ def Variable(const char* name, **kwargs): ...@@ -241,6 +247,7 @@ def Variable(const char* name, **kwargs):
The created variable symbol. The created variable symbol.
""" """
cdef SymbolHandle handle cdef SymbolHandle handle
name = c_str(name)
CALL(NNSymbolCreateVariable(name, &handle)) CALL(NNSymbolCreateVariable(name, &handle))
return NewSymbol(handle) return NewSymbol(handle)
...@@ -274,10 +281,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): ...@@ -274,10 +281,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
func_hint = func_name.lower() func_hint = func_name.lower()
def creator(*args, **kwargs): def creator(*args, **kwargs):
cdef vector[const char*] param_keys cdef vector[string] sparam_keys
cdef vector[const char*] param_vals cdef vector[string] sparam_vals
cdef vector[SymbolHandle] symbol_args cdef vector[SymbolHandle] symbol_args
cdef vector[const char*] symbol_keys cdef vector[string] ssymbol_keys
cdef SymbolHandle ret_handle cdef SymbolHandle ret_handle
name = kwargs.pop("name", None) name = kwargs.pop("name", None)
...@@ -286,11 +293,11 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): ...@@ -286,11 +293,11 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
if len(kwargs) != 0: if len(kwargs) != 0:
for k, v in kwargs.items(): for k, v in kwargs.items():
if isinstance(v, Symbol): if isinstance(v, Symbol):
symbol_keys.push_back(k) ssymbol_keys.push_back(c_str(k))
symbol_args.push_back((<Symbol>v).handle) symbol_args.push_back((<Symbol>v).handle)
else: else:
param_keys.push_back(k) sparam_keys.push_back(c_str(k))
param_vals.push_back(str(v)) sparam_vals.push_back(c_str(str(v)))
if len(args) != 0: if len(args) != 0:
if symbol_args.size() != 0: if symbol_args.size() != 0:
...@@ -301,6 +308,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): ...@@ -301,6 +308,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
raise TypeError('Compose expect `Symbol` as arguments') raise TypeError('Compose expect `Symbol` as arguments')
symbol_args.push_back((<Symbol>v).handle) symbol_args.push_back((<Symbol>v).handle)
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
cdef vector[const char*] symbol_keys = SVec2Ptr(ssymbol_keys)
CALL(NNSymbolCreateAtomicSymbol( CALL(NNSymbolCreateAtomicSymbol(
handle, handle,
<nn_uint>param_keys.size(), <nn_uint>param_keys.size(),
...@@ -315,7 +326,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): ...@@ -315,7 +326,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
name = NameManager.current.get(name, func_hint) name = NameManager.current.get(name, func_hint)
cdef const char* c_name = NULL cdef const char* c_name = NULL
if name: if name:
name = c_str(name)
c_name = name c_name = name
CALL(NNSymbolCompose( CALL(NNSymbolCompose(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment