Commit 31f9fc0a by Tianqi Chen

Make cython compatible with python3 (#12)

parent bed950da
......@@ -3,7 +3,7 @@ export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loop
-Iinclude -Idmlc-core/include -fPIC
# specify tensor path
.PHONY: clean all test lint doc cython cython3
.PHONY: clean all test lint doc cython cython3 cyclean
all: lib/libnnvm.so lib/libnnvm.a cli_test
......@@ -37,6 +37,8 @@ cython:
cython3:
cd python; python3 setup.py build_ext --inplace
cyclean:
rm -rf python/nnvm/*/*.so python/nnvm/*/*.cpp
lint:
python2 dmlc-core/scripts/lint.py nnvm cpp include src
......
......@@ -9,6 +9,21 @@ cdef py_str(const char* x):
return x.decode("utf-8")
cdef c_str(pystr):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return pystr.encode("utf-8")
cdef CALL(int ret):
if ret != 0:
raise NNVMError(NNGetLastError())
......@@ -20,6 +35,13 @@ cdef const char** CBeginPtr(vector[const char*]& vec):
else:
return NULL
cdef vector[const char*] SVec2Ptr(vector[string]& vec):
cdef vector[const char*] svec
svec.resize(vec.size())
for i in range(vec.size()):
svec[i] = vec[i].c_str()
return svec
cdef BuildDoc(nn_uint num_args,
const char** arg_names,
......
......@@ -6,6 +6,7 @@ from .._base import NNVMError
from ..name import NameManager
from ..attribute import AttrScope
from libcpp.vector cimport vector
from libcpp.string cimport string
from cpython.version cimport PY_MAJOR_VERSION
include "./base.pyi"
......@@ -110,7 +111,7 @@ cdef class Symbol:
CALL(NNSymbolGetOutput(self.handle, c_index, &handle))
return NewSymbol(handle)
def attr(self, const char* key):
def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Parameters
......@@ -125,6 +126,8 @@ cdef class Symbol:
"""
cdef const char* ret
cdef int success
key = c_str(key)
CALL(NNSymbolGetAttr(
self.handle, key, &ret, &success))
if success != 0:
......@@ -203,16 +206,19 @@ cdef class Symbol:
def debug_str(self):
cdef const char* out_str
CALL(NNSymbolPrint(self.handle, &out_str))
return str(out_str)
return py_str(out_str)
cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
cdef vector[const char*] param_keys
cdef vector[const char*] param_vals
cdef vector[string] sparam_keys
cdef vector[string] sparam_vals
cdef nn_uint num_args
for k, v in kwargs.items():
param_keys.push_back(k)
param_vals.push_back(str(v))
sparam_keys.push_back(c_str(k))
sparam_vals.push_back(c_str(str(v)))
# keep strings in vector
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
num_args = param_keys.size()
CALL(NNSymbolSetAttrs(
handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals)))
......@@ -225,7 +231,7 @@ cdef NewSymbol(SymbolHandle handle):
return sym
def Variable(const char* name, **kwargs):
def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
......@@ -241,6 +247,7 @@ def Variable(const char* name, **kwargs):
The created variable symbol.
"""
cdef SymbolHandle handle
name = c_str(name)
CALL(NNSymbolCreateVariable(name, &handle))
return NewSymbol(handle)
......@@ -274,10 +281,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
func_hint = func_name.lower()
def creator(*args, **kwargs):
cdef vector[const char*] param_keys
cdef vector[const char*] param_vals
cdef vector[string] sparam_keys
cdef vector[string] sparam_vals
cdef vector[SymbolHandle] symbol_args
cdef vector[const char*] symbol_keys
cdef vector[string] ssymbol_keys
cdef SymbolHandle ret_handle
name = kwargs.pop("name", None)
......@@ -286,11 +293,11 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
if len(kwargs) != 0:
for k, v in kwargs.items():
if isinstance(v, Symbol):
symbol_keys.push_back(k)
ssymbol_keys.push_back(c_str(k))
symbol_args.push_back((<Symbol>v).handle)
else:
param_keys.push_back(k)
param_vals.push_back(str(v))
sparam_keys.push_back(c_str(k))
sparam_vals.push_back(c_str(str(v)))
if len(args) != 0:
if symbol_args.size() != 0:
......@@ -301,6 +308,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
raise TypeError('Compose expect `Symbol` as arguments')
symbol_args.push_back((<Symbol>v).handle)
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
cdef vector[const char*] symbol_keys = SVec2Ptr(ssymbol_keys)
CALL(NNSymbolCreateAtomicSymbol(
handle,
<nn_uint>param_keys.size(),
......@@ -315,7 +326,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
name = NameManager.current.get(name, func_hint)
cdef const char* c_name = NULL
if name:
name = c_str(name)
c_name = name
CALL(NNSymbolCompose(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment