Commit 305614a9 by Tianqi Chen Committed by GitHub

[PYTHON] Enable cython ndarray API (#113)

parent 706f9b6f
......@@ -10,7 +10,8 @@ from numbers import Number, Integral
from ..base import _LIB, check_call
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import TVMType, TVMByteArray, NDArrayBase, _make_array
from ..runtime_ctypes import TVMType, TVMByteArray
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
......
"""Runtime NDArray api"""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call
from ..runtime_ctypes import TVMArrayHandle
class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"]
# pylint: disable=no-member
def __init__(self, handle, is_view=False):
"""Initialize the function with handle
Parameters
----------
handle : TVMArrayHandle
the handle to the underlying C++ TVMArray
"""
self.handle = handle
self.is_view = is_view
def __del__(self):
if not self.is_view:
check_call(_LIB.TVMArrayFree(self.handle))
def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view)
_CLASS_NDARRAY = None
def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
......@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import ctypes
from ..base import py_str, check_call, _LIB
from ..ndarray import TVMByteArray
from ..runtime_ctypes import TVMByteArray
class TypeCode(object):
"""Type code used in API calls"""
......
......@@ -19,18 +19,34 @@ cdef enum TVMTypeCode:
kBytes = 11
cdef extern from "tvm/runtime/c_runtime_api.h":
struct DLType:
ctypedef struct DLDataType:
uint8_t code
uint8_t bits
uint16_t lanes
ctypedef struct DLContext:
int device_id
int device_type
ctypedef struct DLTensor:
void* data
DLContext ctx
int ndim
DLDataType dtype
int64_t* shape
int64_t* strides
size_t byte_offset;
ctypedef struct TVMValue:
int64_t v_int64
double v_float64
void* v_handle
const char* v_str
DLType v_type
DLDataType v_type
ctypedef int64_t tvm_index_t
ctypedef void* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle
......@@ -61,6 +77,15 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out)
int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim,
DLDataType dtype,
DLContext ctx,
DLTensorHandle* out)
int TVMArrayFree(DLTensorHandle handle)
int TVMArrayCopyFromTo(DLTensorHandle src,
DLTensorHandle to,
TVMStreamHandle stream)
cdef extern from "tvm/c_api.h":
int TVMCbArgToReturn(TVMValue* value, int code)
......@@ -106,6 +131,7 @@ cdef inline object ctypes_handle(void* chandle):
"""Cast C handle to ctypes handle."""
return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p)
cdef inline void* c_handle(object handle):
"""Cast C types handle to c handle."""
cdef unsigned long long v_ptr
......
include "./base.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
......@@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import NDArrayBase, TVMType, TVMByteArray, _make_array
from ..runtime_ctypes import TVMType, TVMByteArray
print("TVM: Initializing cython mode...")
......@@ -32,7 +32,7 @@ cdef int tvm_callback(TVMValue* args,
if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode))
else:
pyargs.append(_make_array(ctypes_handle(value.v_handle), True))
pyargs.append(c_make_array(value.v_handle, True))
try:
rv = local_pyfunc(*pyargs)
except Exception:
......@@ -81,8 +81,7 @@ cdef inline void make_arg(object arg,
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
elif isinstance(arg, NDArrayBase):
value[0].v_handle = c_handle(
ctypes.cast(arg.handle, ctypes.c_void_p))
value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = kArrayHandle
elif isinstance(arg, Integral):
value[0].v_int64 = arg
......@@ -205,7 +204,7 @@ cdef class FunctionBase:
cdef TVMFunctionHandle chandle
cdef int is_global
cdef _set_handle(self, handle):
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
......
from ..runtime_ctypes import TVMArrayHandle
cdef class NDArrayBase:
cdef DLTensor* chandle
cdef int c_is_view
cdef inline _set_handle(self, handle):
cdef unsigned long long ptr
if handle is None:
self.chandle = NULL
else:
ptr = ctypes.addressof(handle.contents)
self.chandle = <DLTensor*>(ptr)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(
<unsigned long long>self.chandle, TVMArrayHandle)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle, is_view):
self._set_handle(handle)
self.c_is_view = is_view
def __dealloc__(self):
if self.c_is_view == 0:
CALL(TVMArrayFree(self.chandle))
cdef c_make_array(void* chandle, is_view):
ret = _CLASS_NDARRAY(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret
def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view)
_CLASS_NDARRAY = None
def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
# pylint: disable=invalid-name, unused-import
"""Runtime NDArray api"""
from __future__ import absolute_import
import sys
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types
from .. import _api_internal
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle, tvm_shape_index_t
tvm_shape_index_t = ctypes.c_int64
class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
class TVMType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float',
4 : 'handle'
}
def __init__(self, type_str, lanes=1):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
elif type_str.startswith("handle"):
self.type_code = 4
bits = 64
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
def __eq__(self, other):
return (self.bits == other.bits and
self.type_code == other.type_code and
self.lanes == other.lanes)
def __ne__(self, other):
return not self.__eq__(other)
class TVMContext(ctypes.Structure):
"""TVM context strucure."""
_fields_ = [("device_id", ctypes.c_int),
("device_type", ctypes.c_int)]
MASK2STR = {
1 : 'cpu',
2 : 'gpu',
4 : 'opencl',
8 : 'metal',
9 : 'vpi'
}
STR2MASK = {
'cpu': 1,
'gpu': 2,
'cuda': 2,
'cl': 4,
'opencl': 4,
'metal': 8,
'vpi': 9
}
def __init__(self, device_type, device_id):
super(TVMContext, self).__init__()
self.device_id = device_id
self.device_type = device_type
@property
def exist(self):
"""Whether this device exist."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 0) != 0
@property
def max_threads_per_block(self):
"""Maximum number of threads on each block."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 1)
@property
def warp_size(self):
"""Number of threads that executes in concurrent."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 2)
def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self, None))
def __eq__(self, other):
return (isinstance(other, TVMContext) and
self.device_id == other.device_id and
self.device_type == other.device_type)
from ._cy2.core import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return "%s(%d)" % (
TVMContext.MASK2STR[self.device_type], self.device_id)
class TVMArray(ctypes.Structure):
"""TVMValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
("ctx", TVMContext),
("ndim", ctypes.c_int),
("dtype", TVMType),
("shape", ctypes.POINTER(tvm_shape_index_t)),
("strides", ctypes.POINTER(tvm_shape_index_t)),
("byte_offset", ctypes.c_size_t)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
def context(dev_type, dev_id=0):
"""Construct a TVM context with given device type and id.
......@@ -214,28 +100,10 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
dtype = TVMType(dtype)
check_call(_LIB.TVMArrayAlloc(
shape, ndim, dtype, ctx, ctypes.byref(handle)))
return _CLASS_NDARRAY(handle)
return _make_array(handle, False)
class NDArrayBase(object):
class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"]
# pylint: disable=no-member
def __init__(self, handle, is_view=False):
"""Initialize the function with handle
Parameters
----------
handle : TVMArrayHandle
the handle to the underlying C++ TVMArray
"""
self.handle = handle
self.is_view = is_view
def __del__(self):
if not self.is_view:
check_call(_LIB.TVMArrayFree(self.handle))
@property
def shape(self):
"""Shape of this array"""
......@@ -324,13 +192,3 @@ class NDArrayBase(object):
else:
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view)
_CLASS_NDARRAY = None
def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
"""Common runtime ctypes."""
# pylint: disable=invalid-name
from __future__ import absolute_import
import ctypes
import numpy as np
from .base import _LIB, check_call
from .. import _api_internal
tvm_shape_index_t = ctypes.c_int64
class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
class TVMType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
("lanes", ctypes.c_uint16)]
CODE2STR = {
0 : 'int',
1 : 'uint',
2 : 'float',
4 : 'handle'
}
def __init__(self, type_str, lanes=1):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
elif type_str.startswith("handle"):
self.type_code = 4
bits = 64
else:
raise ValueError("Donot know how to handle type %s" % type_str)
bits = 32 if bits == 0 else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes
def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
def __eq__(self, other):
return (self.bits == other.bits and
self.type_code == other.type_code and
self.lanes == other.lanes)
def __ne__(self, other):
return not self.__eq__(other)
class TVMContext(ctypes.Structure):
"""TVM context strucure."""
_fields_ = [("device_id", ctypes.c_int),
("device_type", ctypes.c_int)]
MASK2STR = {
1 : 'cpu',
2 : 'gpu',
4 : 'opencl',
8 : 'metal',
9 : 'vpi'
}
STR2MASK = {
'cpu': 1,
'gpu': 2,
'cuda': 2,
'cl': 4,
'opencl': 4,
'metal': 8,
'vpi': 9
}
def __init__(self, device_type, device_id):
super(TVMContext, self).__init__()
self.device_id = device_id
self.device_type = device_type
@property
def exist(self):
"""Whether this device exist."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 0) != 0
@property
def max_threads_per_block(self):
"""Maximum number of threads on each block."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 1)
@property
def warp_size(self):
"""Number of threads that executes in concurrent."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 2)
def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self, None))
def __eq__(self, other):
return (isinstance(other, TVMContext) and
self.device_id == other.device_id and
self.device_type == other.device_type)
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return "%s(%d)" % (
TVMContext.MASK2STR[self.device_type], self.device_id)
class TVMArray(ctypes.Structure):
"""TVMValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
("ctx", TVMContext),
("ndim", ctypes.c_int),
("dtype", TVMType),
("shape", ctypes.POINTER(tvm_shape_index_t)),
("strides", ctypes.POINTER(tvm_shape_index_t)),
("byte_offset", ctypes.c_size_t)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
......@@ -8,7 +8,6 @@
#include <tvm/api_registry.h>
namespace tvm {
TVM_REGISTER_API("_format_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kNodeHandle);
......@@ -34,4 +33,7 @@ TVM_REGISTER_API("_load_json")
*ret = NodeRef(LoadJSON_(args[0]));
});
TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment