Unverified Commit 19cf5c66 by Tianqi Chen Committed by GitHub

[DLPACK] Enable cython support (#1589)

parent ec3f09b3
Subproject commit a5a80bdc8232c9dbfe508bb5c46e8f58cdf7ec20
Subproject commit a0b9563f45719553adf4d39fe3c14db1af0e1f40
# pylint: disable=invalid-name
"""Runtime NDArray api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor')
# used for PyCapsule manipulation
if hasattr(ctypes, 'pythonapi'):
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
def _from_dlpack(dltensor):
dltensor = ctypes.py_object(dltensor)
if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor):
ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
handle = TVMArrayHandle()
check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
return _make_array(handle, False)
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")
def _dlpack_deleter(pycapsule):
pycapsule = ctypes.cast(pycapsule, ctypes.py_object)
if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)
_LIB.TVMDLManagedTensorCallDeleter(ptr)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
_c_dlpack_deleter = TVMPyCapsuleDestructor(_dlpack_deleter)
class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"]
......@@ -29,6 +65,17 @@ class NDArrayBase(object):
def _tvm_handle(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory
Returns
-------
dlpack : DLPack tensor view of the array data
"""
handle = ctypes.c_void_p()
check_call(_LIB.TVMArrayToDLPack(self.handle, ctypes.byref(handle)))
return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)
def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle)
......
from ..base import TVMError
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
import ctypes
......@@ -40,6 +41,11 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
int64_t* strides
uint64_t byte_offset
ctypedef struct DLManagedTensor:
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
ctypedef struct TVMValue:
int64_t v_int64
double v_float64
......@@ -49,7 +55,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
DLContext v_ctx
ctypedef int64_t tvm_index_t
ctypedef void* DLTensorHandle
ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
......@@ -92,6 +98,11 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
int TVMArrayCopyFromTo(DLTensorHandle src,
DLTensorHandle to,
TVMStreamHandle stream)
int TVMArrayFromDLPack(DLManagedTensor* arr_from,
DLTensorHandle* out)
int TVMArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out)
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
......
from ..runtime_ctypes import TVMArrayHandle
cdef const char* _c_str_dltensor = "dltensor"
cdef const char* _c_str_used_dltensor = "used_dltensor"
cdef void _c_dlpack_deleter(object pycaps):
cdef DLManagedTensor* dltensor
if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor):
dltensor = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor)
TVMDLManagedTensorCallDeleter(dltensor)
def _from_dlpack(object dltensor):
cdef DLManagedTensor* ptr
cdef DLTensorHandle chandle
if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
CALL(TVMArrayFromDLPack(ptr, &chandle))
# set name and destructor to be empty
pycapsule.PyCapsule_SetDestructor(dltensor, NULL)
pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
return c_make_array(chandle, 0)
raise ValueError("Expect a dltensor field, pycapsule.PyCapsule can only be consumed once")
cdef class NDArrayBase:
cdef DLTensor* chandle
cdef int c_is_view
......@@ -35,12 +59,26 @@ cdef class NDArrayBase:
if self.c_is_view == 0:
CALL(TVMArrayFree(self.chandle))
def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory
Returns
-------
dlpack : DLPack tensor view of the array data
"""
cdef DLManagedTensor* dltensor
if self.c_is_view != 0:
raise ValueError("to_dlpack do not work with memory views")
CALL(TVMArrayToDLPack(self.chandle, &dltensor))
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
cdef c_make_array(void* chandle, is_view):
ret = _CLASS_NDARRAY(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret
cdef _TVM_COMPATS = ()
cdef _TVM_EXT_RET = {}
......
......@@ -17,28 +17,17 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy3.core import NDArrayBase as _NDArrayBase
else:
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._cy2.core import NDArrayBase as _NDArrayBase
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor')
# used for PyCapsule manipulation
if hasattr(ctypes, 'pythonapi'):
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
def context(dev_type, dev_id=0):
"""Construct a TVM context with given device type and id.
......@@ -134,30 +123,14 @@ def from_dlpack(dltensor):
Parameters
----------
dltensor : DLPack tensor
Input DLManagedTensor, can only be consumed once.
Returns
-------
arr: tvm.nd.NDArray
The array view of the tensor data.
"""
dltensor = ctypes.py_object(dltensor)
name = ctypes.pythonapi.PyCapsule_GetName(dltensor)
ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, name)
handle = TVMArrayHandle()
check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, None)
return _make_array(handle, False)
def _dlpack_deleter(pycapsule):
pycapsule = ctypes.py_object(pycapsule)
if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)
_LIB.TVMDLManagedTensorCallDeleter(ptr)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
_c_dlpack_deleter = TVMPyCapsuleDestructor(_dlpack_deleter)
return _from_dlpack(dltensor)
class NDArrayBase(_NDArrayBase):
......@@ -308,17 +281,6 @@ class NDArrayBase(_NDArrayBase):
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory
Returns
-------
dlpack : DLPack tensor view of the array data
"""
handle = ctypes.c_void_p()
check_call(_LIB.TVMArrayToDLPack(self.handle, ctypes.byref(handle)))
return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)
def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
......
......@@ -4,6 +4,10 @@ export PYTHONPATH=nnvm/python:python:topi/python
# to avoid openblas threading error
export OMP_NUM_THREADS=1
# Rebuild cython
make cython || exit -1
make cython3 || exit -1
echo "Running unittest..."
python -m nose -v nnvm/tests/python/unittest || exit -1
python3 -m nose -v nnvm/tests/python/unittest || exit -1
......
export PYTHONPATH=python:topi/python
# Rebuild cython
make cython || exit -1
make cython3 || exit -1
python -m nose -v topi/tests/python || exit -1
python3 -m nose -v topi/tests/python || exit -1
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment