Commit 305614a9 by Tianqi Chen Committed by GitHub

[PYTHON] Enable cython ndarray API (#113)

parent 706f9b6f
...@@ -10,7 +10,8 @@ from numbers import Number, Integral ...@@ -10,7 +10,8 @@ from numbers import Number, Integral
from ..base import _LIB, check_call from ..base import _LIB, check_call
from ..base import c_str, string_types from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import TVMType, TVMByteArray, NDArrayBase, _make_array from ..runtime_ctypes import TVMType, TVMByteArray
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
......
"""Runtime NDArray api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from ..runtime_ctypes import TVMArrayHandle
class NDArrayBase(object):
    """Base holder of a runtime array handle for the ctypes backend.

    The instance keeps the handle of the underlying TVMArray together with a
    flag telling whether it is only a view.  Arrays that are not views are
    released through ``TVMArrayFree`` when the object is collected.
    """
    # pylint: disable=no-member
    __slots__ = ["handle", "is_view"]

    def __init__(self, handle, is_view=False):
        """Initialize the array object with a handle.

        Parameters
        ----------
        handle : TVMArrayHandle
            the handle to the underlying C++ TVMArray

        is_view : bool, optional
            when true, the handle is not freed by this object
        """
        self.is_view = is_view
        self.handle = handle

    def __del__(self):
        # Views do not release the array they point to.
        if self.is_view:
            return
        check_call(_LIB.TVMArrayFree(self.handle))
def _make_array(handle, is_view):
    """Create an instance of the registered NDArray class from a raw handle.

    The handle is cast to ``TVMArrayHandle`` and passed, together with
    ``is_view``, to the class stored in ``_CLASS_NDARRAY``.
    """
    array_handle = ctypes.cast(handle, TVMArrayHandle)
    ndarray_cls = _CLASS_NDARRAY
    return ndarray_cls(array_handle, is_view)
# Class instantiated by _make_array; None until _set_class_ndarray is called.
_CLASS_NDARRAY = None


def _set_class_ndarray(cls):
    """Record ``cls`` as the class that ``_make_array`` instantiates."""
    global _CLASS_NDARRAY
    _CLASS_NDARRAY = cls
...@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs ...@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import ctypes import ctypes
from ..base import py_str, check_call, _LIB from ..base import py_str, check_call, _LIB
from ..ndarray import TVMByteArray from ..runtime_ctypes import TVMByteArray
class TypeCode(object): class TypeCode(object):
"""Type code used in API calls""" """Type code used in API calls"""
......
...@@ -19,18 +19,34 @@ cdef enum TVMTypeCode: ...@@ -19,18 +19,34 @@ cdef enum TVMTypeCode:
kBytes = 11 kBytes = 11
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "tvm/runtime/c_runtime_api.h":
struct DLType: ctypedef struct DLDataType:
uint8_t code uint8_t code
uint8_t bits uint8_t bits
uint16_t lanes uint16_t lanes
ctypedef struct DLContext:
int device_id
int device_type
ctypedef struct DLTensor:
void* data
DLContext ctx
int ndim
DLDataType dtype
int64_t* shape
int64_t* strides
size_t byte_offset;
ctypedef struct TVMValue: ctypedef struct TVMValue:
int64_t v_int64 int64_t v_int64
double v_float64 double v_float64
void* v_handle void* v_handle
const char* v_str const char* v_str
DLType v_type DLDataType v_type
ctypedef int64_t tvm_index_t
ctypedef void* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle ctypedef void* NodeHandle
...@@ -61,6 +77,15 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ...@@ -61,6 +77,15 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
void* resource_handle, void* resource_handle,
TVMPackedCFuncFinalizer fin, TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out) TVMFunctionHandle *out)
int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim,
DLDataType dtype,
DLContext ctx,
DLTensorHandle* out)
int TVMArrayFree(DLTensorHandle handle)
int TVMArrayCopyFromTo(DLTensorHandle src,
DLTensorHandle to,
TVMStreamHandle stream)
cdef extern from "tvm/c_api.h": cdef extern from "tvm/c_api.h":
int TVMCbArgToReturn(TVMValue* value, int code) int TVMCbArgToReturn(TVMValue* value, int code)
...@@ -106,6 +131,7 @@ cdef inline object ctypes_handle(void* chandle): ...@@ -106,6 +131,7 @@ cdef inline object ctypes_handle(void* chandle):
"""Cast C handle to ctypes handle.""" """Cast C handle to ctypes handle."""
return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p) return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p)
cdef inline void* c_handle(object handle): cdef inline void* c_handle(object handle):
"""Cast C types handle to c handle.""" """Cast C types handle to c handle."""
cdef unsigned long long v_ptr cdef unsigned long long v_ptr
......
include "./base.pxi" include "./base.pxi"
include "./node.pxi" include "./node.pxi"
include "./function.pxi" include "./function.pxi"
include "./ndarray.pxi"
...@@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF ...@@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral from numbers import Number, Integral
from ..base import string_types from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import NDArrayBase, TVMType, TVMByteArray, _make_array from ..runtime_ctypes import TVMType, TVMByteArray
print("TVM: Initializing cython mode...") print("TVM: Initializing cython mode...")
...@@ -32,7 +32,7 @@ cdef int tvm_callback(TVMValue* args, ...@@ -32,7 +32,7 @@ cdef int tvm_callback(TVMValue* args,
if tcode != kArrayHandle: if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode)) pyargs.append(make_ret(value, tcode))
else: else:
pyargs.append(_make_array(ctypes_handle(value.v_handle), True)) pyargs.append(c_make_array(value.v_handle, True))
try: try:
rv = local_pyfunc(*pyargs) rv = local_pyfunc(*pyargs)
except Exception: except Exception:
...@@ -81,8 +81,7 @@ cdef inline void make_arg(object arg, ...@@ -81,8 +81,7 @@ cdef inline void make_arg(object arg,
value[0].v_handle = (<NodeBase>arg).chandle value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle tcode[0] = kNodeHandle
elif isinstance(arg, NDArrayBase): elif isinstance(arg, NDArrayBase):
value[0].v_handle = c_handle( value[0].v_handle = (<NDArrayBase>arg).chandle
ctypes.cast(arg.handle, ctypes.c_void_p))
tcode[0] = kArrayHandle tcode[0] = kArrayHandle
elif isinstance(arg, Integral): elif isinstance(arg, Integral):
value[0].v_int64 = arg value[0].v_int64 = arg
...@@ -205,7 +204,7 @@ cdef class FunctionBase: ...@@ -205,7 +204,7 @@ cdef class FunctionBase:
cdef TVMFunctionHandle chandle cdef TVMFunctionHandle chandle
cdef int is_global cdef int is_global
cdef _set_handle(self, handle): cdef inline _set_handle(self, handle):
if handle is None: if handle is None:
self.chandle = NULL self.chandle = NULL
else: else:
......
from ..runtime_ctypes import TVMArrayHandle
cdef class NDArrayBase:
    # Raw pointer to the array structure, and a flag that is non-zero when
    # this object is only a view and must not free the pointer.
    cdef DLTensor* chandle
    cdef int c_is_view

    cdef inline _set_handle(self, handle):
        # Store the address behind a ctypes TVMArrayHandle as a raw pointer;
        # None is stored as NULL.
        cdef unsigned long long ptr
        if handle is None:
            self.chandle = NULL
        else:
            ptr = ctypes.addressof(handle.contents)
            self.chandle = <DLTensor*>(ptr)

    property handle:
        # Python-visible view of chandle as a ctypes TVMArrayHandle.
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return ctypes.cast(
                    <unsigned long long>self.chandle, TVMArrayHandle)

        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_view):
        self._set_handle(handle)
        self.c_is_view = is_view

    def __dealloc__(self):
        # Views never free the array.  An object that never received a handle
        # (chandle still NULL) has nothing to free either, so the C call is
        # skipped instead of being issued on a NULL pointer.
        if self.c_is_view == 0 and self.chandle != NULL:
            CALL(TVMArrayFree(self.chandle))
cdef c_make_array(void* chandle, is_view):
    # Instantiate the registered NDArray class without a handle, then attach
    # the raw pointer directly to the new object's chandle field.
    ret = _CLASS_NDARRAY(None, is_view)
    (<NDArrayBase>ret).chandle = <DLTensor*>chandle
    return ret
def _make_array(handle, is_view):
    """Create an instance of the registered NDArray class from a ctypes handle.

    The handle is cast to ``TVMArrayHandle`` and passed, together with
    ``is_view``, to the class stored in ``_CLASS_NDARRAY``.
    """
    array_handle = ctypes.cast(handle, TVMArrayHandle)
    ndarray_cls = _CLASS_NDARRAY
    return ndarray_cls(array_handle, is_view)
# Class instantiated by _make_array and c_make_array; None until
# _set_class_ndarray is called.
_CLASS_NDARRAY = None


def _set_class_ndarray(cls):
    """Record ``cls`` as the class used to build NDArray objects."""
    global _CLASS_NDARRAY
    _CLASS_NDARRAY = cls
# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement # pylint: disable=invalid-name, unused-import
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
"""Runtime NDArray api""" """Runtime NDArray api"""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import ctypes import ctypes
import numpy as np import numpy as np
from .base import _LIB, check_call, c_array, string_types from .base import _LIB, check_call, c_array, string_types, _FFI_MODE
from .. import _api_internal from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle, tvm_shape_index_t
tvm_shape_index_t = ctypes.c_int64
class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
class TVMType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float',
4 : 'handle'
}
def __init__(self, type_str, lanes=1):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
elif type_str.startswith("handle"):
self.type_code = 4
bits = 64
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
def __eq__(self, other):
return (self.bits == other.bits and
self.type_code == other.type_code and
self.lanes == other.lanes)
def __ne__(self, other):
return not self.__eq__(other)
class TVMContext(ctypes.Structure):
"""TVM context strucure."""
_fields_ = [("device_id", ctypes.c_int),
("device_type", ctypes.c_int)]
MASK2STR = {
1 : 'cpu',
2 : 'gpu',
4 : 'opencl',
8 : 'metal',
9 : 'vpi'
}
STR2MASK = {
'cpu': 1,
'gpu': 2,
'cuda': 2,
'cl': 4,
'opencl': 4,
'metal': 8,
'vpi': 9
}
def __init__(self, device_type, device_id):
super(TVMContext, self).__init__()
self.device_id = device_id
self.device_type = device_type
@property
def exist(self):
"""Whether this device exist."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 0) != 0
@property
def max_threads_per_block(self):
"""Maximum number of threads on each block."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 1)
@property
def warp_size(self):
"""Number of threads that executes in concurrent."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 2)
def sync(self): IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self, None))
def __eq__(self, other): try:
return (isinstance(other, TVMContext) and # pylint: disable=wrong-import-position
self.device_id == other.device_id and if _FFI_MODE == "ctypes":
self.device_type == other.device_type) raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase
else:
from ._cy2.core import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return "%s(%d)" % (
TVMContext.MASK2STR[self.device_type], self.device_id)
class TVMArray(ctypes.Structure):
"""TVMValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
("ctx", TVMContext),
("ndim", ctypes.c_int),
("dtype", TVMType),
("shape", ctypes.POINTER(tvm_shape_index_t)),
("strides", ctypes.POINTER(tvm_shape_index_t)),
("byte_offset", ctypes.c_size_t)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
def context(dev_type, dev_id=0): def context(dev_type, dev_id=0):
"""Construct a TVM context with given device type and id. """Construct a TVM context with given device type and id.
...@@ -214,28 +100,10 @@ def empty(shape, dtype="float32", ctx=context(1, 0)): ...@@ -214,28 +100,10 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
dtype = TVMType(dtype) dtype = TVMType(dtype)
check_call(_LIB.TVMArrayAlloc( check_call(_LIB.TVMArrayAlloc(
shape, ndim, dtype, ctx, ctypes.byref(handle))) shape, ndim, dtype, ctx, ctypes.byref(handle)))
return _CLASS_NDARRAY(handle) return _make_array(handle, False)
class NDArrayBase(object): class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime.""" """A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"]
# pylint: disable=no-member
def __init__(self, handle, is_view=False):
"""Initialize the function with handle
Parameters
----------
handle : TVMArrayHandle
the handle to the underlying C++ TVMArray
"""
self.handle = handle
self.is_view = is_view
def __del__(self):
if not self.is_view:
check_call(_LIB.TVMArrayFree(self.handle))
@property @property
def shape(self): def shape(self):
"""Shape of this array""" """Shape of this array"""
...@@ -324,13 +192,3 @@ class NDArrayBase(object): ...@@ -324,13 +192,3 @@ class NDArrayBase(object):
else: else:
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
return target return target
def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view)
_CLASS_NDARRAY = None
def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
"""Common runtime ctypes."""
# pylint: disable=invalid-name
from __future__ import absolute_import
import ctypes
import numpy as np
from .base import _LIB, check_call
from .. import _api_internal
# Integer type of the entries in shape and stride arrays.
tvm_shape_index_t = ctypes.c_int64


class TVMByteArray(ctypes.Structure):
    """Temporary byte buffer: a pointer to the bytes plus their count."""
    _fields_ = [
        ("data", ctypes.POINTER(ctypes.c_byte)),
        ("size", ctypes.c_size_t),
    ]
class TVMType(ctypes.Structure):
    """TVM datatype structure.

    A type is described by a type code (see ``CODE2STR``), a bit width and a
    number of lanes.

    Parameters
    ----------
    type_str : str or numpy.dtype
        Type name such as ``"int8"``, ``"uint16"``, ``"float32"`` or
        ``"handle"``.  ``int``, ``uint`` and ``float`` without a width mean
        32 bits.  A trailing ``x<lanes>`` (the form produced by ``repr``,
        e.g. ``"float32x4"``) is accepted and takes precedence over ``lanes``.

    lanes : int, optional
        Number of lanes, used when ``type_str`` carries no lane suffix.

    Raises
    ------
    ValueError
        If the type name is unknown, the width or lane suffix is not a number,
        or the width is not a power of two between 8 and 128.
    """
    _fields_ = [("type_code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]
    CODE2STR = {
        0 : 'int',
        1 : 'uint',
        2 : 'float',
        4 : 'handle'
    }

    def __init__(self, type_str, lanes=1):
        super(TVMType, self).__init__()
        if isinstance(type_str, np.dtype):
            type_str = str(type_str)

        def _to_int(text, default):
            # An empty suffix selects the default; anything that is not a
            # number is reported with the same message as an unknown type.
            if not text:
                return default
            try:
                return int(text)
            except ValueError:
                raise ValueError("Donot know how to handle type %s" % type_str)

        # Split off an optional "x<lanes>" suffix so repr() output round-trips.
        head, sep, lane_text = type_str.partition("x")
        if sep:
            lanes = _to_int(lane_text, lanes)

        if head.startswith("int"):
            self.type_code = 0
            bits = _to_int(head[3:], 32)
        elif head.startswith("uint"):
            self.type_code = 1
            bits = _to_int(head[4:], 32)
        elif head.startswith("float"):
            self.type_code = 2
            bits = _to_int(head[5:], 32)
        elif head.startswith("handle"):
            self.type_code = 4
            bits = 64
        else:
            raise ValueError("Donot know how to handle type %s" % type_str)

        bits = 32 if bits == 0 else bits
        # The width is stored in a uint8 field, so anything above 255 would
        # wrap silently; reject it together with non power-of-two widths.
        if (bits & (bits - 1)) != 0 or bits < 8 or bits > 255:
            raise ValueError("Donot know how to handle type %s" % type_str)
        self.bits = bits
        self.lanes = lanes

    def __repr__(self):
        x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
        if self.lanes != 1:
            x += "x%d" % self.lanes
        return x

    def __eq__(self, other):
        # Objects of other types never compare equal (and no longer raise).
        return (isinstance(other, TVMType) and
                self.bits == other.bits and
                self.type_code == other.type_code and
                self.lanes == other.lanes)

    def __ne__(self, other):
        return not self.__eq__(other)
class TVMContext(ctypes.Structure):
    """TVM context structure: a device type code plus a device index.

    Parameters
    ----------
    device_type : int
        Device type code; the known codes are listed in ``MASK2STR``.

    device_id : int
        Index of the device.

    Note
    ----
    The constructor takes ``device_type`` first, while the structure stores
    ``device_id`` first.
    """
    _fields_ = [("device_id", ctypes.c_int),
                ("device_type", ctypes.c_int)]
    MASK2STR = {
        1 : 'cpu',
        2 : 'gpu',
        4 : 'opencl',
        8 : 'metal',
        9 : 'vpi'
    }
    STR2MASK = {
        'cpu': 1,
        'gpu': 2,
        'cuda': 2,
        'cl': 4,
        'opencl': 4,
        'metal': 8,
        'vpi': 9
    }

    def __init__(self, device_type, device_id):
        super(TVMContext, self).__init__()
        self.device_id = device_id
        self.device_type = device_type

    @property
    def exist(self):
        """Whether this device exist."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 0) != 0

    @property
    def max_threads_per_block(self):
        """Maximum number of threads on each block."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 1)

    @property
    def warp_size(self):
        """Number of threads that executes in concurrent."""
        return _api_internal._GetDeviceAttr(
            self.device_type, self.device_id, 2)

    def sync(self):
        """Synchronize until jobs finished at the context."""
        check_call(_LIB.TVMSynchronize(self, None))

    def __eq__(self, other):
        return (isinstance(other, TVMContext) and
                self.device_id == other.device_id and
                self.device_type == other.device_type)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone makes the class unhashable on Python 3;
        # hash the same two fields that __eq__ compares.
        return hash((self.device_type, self.device_id))

    def __repr__(self):
        name = TVMContext.MASK2STR.get(self.device_type)
        if name is None:
            # Unknown device code: show the raw numbers instead of raising.
            return "TVMContext(%d, %d)" % (self.device_type, self.device_id)
        return "%s(%d)" % (name, self.device_id)
class TVMArray(ctypes.Structure):
    """ctypes layout of the runtime array (tensor) structure in the C API.

    The fields are, in order: the data pointer, the context, the number of
    dimensions, the data type, pointers to the shape and stride arrays, and
    the byte offset.
    """
    _fields_ = [("data", ctypes.c_void_p),
                ("ctx", TVMContext),
                ("ndim", ctypes.c_int),
                ("dtype", TVMType),
                ("shape", ctypes.POINTER(tvm_shape_index_t)),
                ("strides", ctypes.POINTER(tvm_shape_index_t)),
                ("byte_offset", ctypes.c_size_t)]

# Pointer-to-TVMArray type used as the array handle.
TVMArrayHandle = ctypes.POINTER(TVMArray)
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
namespace tvm { namespace tvm {
TVM_REGISTER_API("_format_str") TVM_REGISTER_API("_format_str")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kNodeHandle); CHECK(args[0].type_code() == kNodeHandle);
...@@ -34,4 +33,7 @@ TVM_REGISTER_API("_load_json") ...@@ -34,4 +33,7 @@ TVM_REGISTER_API("_load_json")
*ret = NodeRef(LoadJSON_(args[0])); *ret = NodeRef(LoadJSON_(args[0]));
}); });
TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment