Commit 553657eb by Tianqi Chen Committed by GitHub

[PYTHON] Support DLTensor compatible API (#136)

* [PYTHON] Support DLTensor compatible API

* optimize for common path
parent e3695cad
...@@ -14,3 +14,5 @@ tvm.ndarray ...@@ -14,3 +14,5 @@ tvm.ndarray
.. autofunction:: tvm.opencl .. autofunction:: tvm.opencl
.. autofunction:: tvm.metal .. autofunction:: tvm.metal
.. autofunction:: tvm.ndarray.array .. autofunction:: tvm.ndarray.array
.. autofunction:: tvm.register_dltensor
...@@ -23,5 +23,6 @@ from ._ffi.base import TVMError, __version__ ...@@ -23,5 +23,6 @@ from ._ffi.base import TVMError, __version__
from .api import * from .api import *
from .intrin import * from .intrin import *
from .node import register_node from .node import register_node
from .ndarray import register_dltensor
from .schedule import create_schedule from .schedule import create_schedule
from .build import build, lower from .build import build, lower
...@@ -11,6 +11,7 @@ from ..base import _LIB, check_call ...@@ -11,6 +11,7 @@ from ..base import _LIB, check_call
from ..base import c_str, string_types from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMByteArray from ..runtime_ctypes import TVMType, TVMByteArray
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import TVMPackedCFunc, TVMCFuncFinalizer
...@@ -94,6 +95,9 @@ def _make_tvm_args(args, temp_args): ...@@ -94,6 +95,9 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, NDArrayBase): elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = TypeCode.ARRAY_HANDLE type_codes[i] = TypeCode.ARRAY_HANDLE
elif isinstance(arg, _nd._DLTENSOR_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._dltensor_addr)
type_codes[i] = TypeCode.ARRAY_HANDLE
elif isinstance(arg, Integral): elif isinstance(arg, Integral):
values[i].v_int64 = arg values[i].v_int64 = arg
type_codes[i] = TypeCode.INT type_codes[i] = TypeCode.INT
......
...@@ -24,10 +24,20 @@ class NDArrayBase(object): ...@@ -24,10 +24,20 @@ class NDArrayBase(object):
if not self.is_view: if not self.is_view:
check_call(_LIB.TVMArrayFree(self.handle)) check_call(_LIB.TVMArrayFree(self.handle))
@property
def _dltensor_addr(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
def _make_array(handle, is_view): def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle) handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view) return _CLASS_NDARRAY(handle, is_view)
_DLTENSOR_COMPATS = ()
def _reg_dltensor(cls):
global _DLTENSOR_COMPATS
_DLTENSOR_COMPATS += (cls,)
_CLASS_NDARRAY = None _CLASS_NDARRAY = None
def _set_class_ndarray(cls): def _set_class_ndarray(cls):
......
...@@ -77,21 +77,34 @@ cdef inline void make_arg(object arg, ...@@ -77,21 +77,34 @@ cdef inline void make_arg(object arg,
int* tcode, int* tcode,
list temp_args): list temp_args):
"""Pack arguments into c args tvm call accept""" """Pack arguments into c args tvm call accept"""
cdef unsigned long long ptr
if isinstance(arg, NodeBase): if isinstance(arg, NodeBase):
value[0].v_handle = (<NodeBase>arg).chandle value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle tcode[0] = kNodeHandle
elif isinstance(arg, NDArrayBase): elif isinstance(arg, NDArrayBase):
value[0].v_handle = (<NDArrayBase>arg).chandle value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = kArrayHandle tcode[0] = kArrayHandle
elif isinstance(arg, Integral): elif isinstance(arg, _DLTENSOR_COMPATS):
ptr = arg._dltensor_addr
value[0].v_handle = (<void*>ptr)
tcode[0] = kArrayHandle
elif isinstance(arg, (int, long)):
value[0].v_int64 = arg value[0].v_int64 = arg
tcode[0] = kInt tcode[0] = kInt
elif isinstance(arg, Number): elif isinstance(arg, float):
value[0].v_float64 = arg value[0].v_float64 = arg
tcode[0] = kFloat tcode[0] = kFloat
elif isinstance(arg, str):
tstr = c_str(arg)
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif arg is None: elif arg is None:
value[0].v_handle = NULL value[0].v_handle = NULL
tcode[0] = kNull tcode[0] = kNull
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
elif isinstance(arg, TVMType): elif isinstance(arg, TVMType):
tstr = c_str(str(arg)) tstr = c_str(str(arg))
value[0].v_str = tstr value[0].v_str = tstr
...@@ -167,9 +180,9 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -167,9 +180,9 @@ cdef inline object make_ret(TVMValue value, int tcode):
raise ValueError("Unhandled type code %d" % tcode) raise ValueError("Unhandled type code %d" % tcode)
cdef inline object FuncCall2(void* chandle, tuple args, int nargs): cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
cdef TVMValue[2] values cdef TVMValue[3] values
cdef int[2] tcodes cdef int[3] tcodes
cdef TVMValue ret_val cdef TVMValue ret_val
cdef int ret_code cdef int ret_code
nargs = len(args) nargs = len(args)
...@@ -183,8 +196,8 @@ cdef inline object FuncCall2(void* chandle, tuple args, int nargs): ...@@ -183,8 +196,8 @@ cdef inline object FuncCall2(void* chandle, tuple args, int nargs):
cdef inline object FuncCall(void* chandle, tuple args): cdef inline object FuncCall(void* chandle, tuple args):
cdef int nargs cdef int nargs
nargs = len(args) nargs = len(args)
if nargs <= 2: if nargs <= 3:
return FuncCall2(chandle, args, nargs) return FuncCall3(chandle, args, nargs)
cdef vector[TVMValue] values cdef vector[TVMValue] values
cdef vector[int] tcodes cdef vector[int] tcodes
......
...@@ -12,6 +12,10 @@ cdef class NDArrayBase: ...@@ -12,6 +12,10 @@ cdef class NDArrayBase:
ptr = ctypes.addressof(handle.contents) ptr = ctypes.addressof(handle.contents)
self.chandle = <DLTensor*>(ptr) self.chandle = <DLTensor*>(ptr)
property _dltensor_addr:
def __get__(self):
return <unsigned long long>self.chandle
property handle: property handle:
def __get__(self): def __get__(self):
if self.chandle == NULL: if self.chandle == NULL:
...@@ -23,12 +27,10 @@ cdef class NDArrayBase: ...@@ -23,12 +27,10 @@ cdef class NDArrayBase:
def __set__(self, value): def __set__(self, value):
self._set_handle(value) self._set_handle(value)
def __init__(self, handle, is_view): def __init__(self, handle, is_view):
self._set_handle(handle) self._set_handle(handle)
self.c_is_view = is_view self.c_is_view = is_view
def __dealloc__(self): def __dealloc__(self):
if self.c_is_view == 0: if self.c_is_view == 0:
CALL(TVMArrayFree(self.chandle)) CALL(TVMArrayFree(self.chandle))
...@@ -39,12 +41,17 @@ cdef c_make_array(void* chandle, is_view): ...@@ -39,12 +41,17 @@ cdef c_make_array(void* chandle, is_view):
(<NDArrayBase>ret).chandle = <DLTensor*>chandle (<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret return ret
cdef _DLTENSOR_COMPATS = ()
def _reg_dltensor(cls):
global _DLTENSOR_COMPATS
_DLTENSOR_COMPATS += (cls,)
def _make_array(handle, is_view): def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle) handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view) return _CLASS_NDARRAY(handle, is_view)
_CLASS_NDARRAY = None cdef object _CLASS_NDARRAY = None
def _set_class_ndarray(cls): def _set_class_ndarray(cls):
global _CLASS_NDARRAY global _CLASS_NDARRAY
......
...@@ -16,12 +16,15 @@ try: ...@@ -16,12 +16,15 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase from ._cy3.core import _set_class_ndarray, _reg_dltensor, _make_array
from ._cy3.core import NDArrayBase as _NDArrayBase
else: else:
from ._cy2.core import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase from ._cy2.core import _set_class_ndarray, _reg_dltensor, _make_array
from ._cy2.core import NDArrayBase as _NDArrayBase
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.ndarray import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase from ._ctypes.ndarray import _set_class_ndarray, _reg_dltensor, _make_array
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
def context(dev_type, dev_id=0): def context(dev_type, dev_id=0):
...@@ -192,3 +195,44 @@ class NDArrayBase(_NDArrayBase): ...@@ -192,3 +195,44 @@ class NDArrayBase(_NDArrayBase):
else: else:
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
return target return target
def register_dltensor(cls):
"""Register a DLTensor compatible class to TVM.
After the class is registered, the class will be able
to directly pass as Function argument generated by TVM.
Parameters
----------
cls : class
The class object to be registered as DLTensor compatible.
Note
----
The registered class is requires a property _dltensor_addr,
which returns an integer that represents the address of DLTensor.
Returns
-------
cls : class
The class being registered.
Example
-------
The following code registers user defined class
MyTensor to be DLTensor compatible.
.. code-block:: python
@tvm.register_dltensor
class MyTensor(object):
def __init__(self):
self.handle = _LIB.NewDLTensor()
@property
def _dltensor_addr(self):
return self.handle.value
"""
_reg_dltensor(cls)
return cls
...@@ -9,7 +9,7 @@ import numpy as _np ...@@ -9,7 +9,7 @@ import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty from ._ffi.ndarray import context, empty
from ._ffi.ndarray import _set_class_ndarray from ._ffi.ndarray import _set_class_ndarray, register_dltensor
class NDArray(NDArrayBase): class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime. """Lightweight NDArray class of TVM runtime.
......
import tvm
import numpy as np
@tvm.register_dltensor
class MyTensorView(object):
def __init__(self, arr):
self.arr = arr
@property
def _dltensor_addr(self):
return self.arr._dltensor_addr
def test_dltensor_compatible():
dtype = 'int64'
n = tvm.var('n')
Ab = tvm.decl_buffer((n,), dtype)
i = tvm.var('i')
ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n - 1, "i") as i:
A[i + 1] = A[i] + 1
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
f = tvm.codegen.build_module(fapi, "stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a)
f(aview)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
if __name__ == "__main__":
test_dltensor_compatible()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment