Commit 553657eb by Tianqi Chen, committed by GitHub

[PYTHON] Support DLTensor compatible API (#136)

* [PYTHON] Support DLTensor compatible API

* optimize for common path
parent e3695cad
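
In short, this commit lets any Python object be passed directly to a TVM packed function, provided its class has been registered with tvm.register_dltensor and it exposes a _dltensor_addr property that returns the integer address of a DLTensor. A minimal sketch of such a wrapper, mirroring the MyTensorView class used in the new test at the bottom of this diff (only tvm and numpy are assumed; NDArrayView is an illustrative name):

import numpy as np
import tvm

@tvm.register_dltensor
class NDArrayView(object):
    """Illustrative view over an existing tvm.nd.NDArray; the NDArray keeps ownership."""
    def __init__(self, nd):
        self.nd = nd

    @property
    def _dltensor_addr(self):
        # forward the DLTensor address already exposed by the wrapped NDArray
        return self.nd._dltensor_addr

# Any TVM packed function that expects an array argument can now be called with
# NDArrayView(tvm.nd.array(np.zeros(4, dtype='float32'))) instead of the NDArray itself.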
@@ -14,3 +14,5 @@ tvm.ndarray
.. autofunction:: tvm.opencl
.. autofunction:: tvm.metal
.. autofunction:: tvm.ndarray.array
.. autofunction:: tvm.register_dltensor
@@ -23,5 +23,6 @@ from ._ffi.base import TVMError, __version__
from .api import *
from .intrin import *
from .node import register_node
from .ndarray import register_dltensor
from .schedule import create_schedule
from .build import build, lower
@@ -11,6 +11,7 @@ from ..base import _LIB, check_call
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMByteArray
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
@@ -94,6 +95,9 @@ def _make_tvm_args(args, temp_args):
        elif isinstance(arg, NDArrayBase):
            values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
            type_codes[i] = TypeCode.ARRAY_HANDLE
        elif isinstance(arg, _nd._DLTENSOR_COMPATS):
            values[i].v_handle = ctypes.c_void_p(arg._dltensor_addr)
            type_codes[i] = TypeCode.ARRAY_HANDLE
        elif isinstance(arg, Integral):
            values[i].v_int64 = arg
            type_codes[i] = TypeCode.INT
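On the ctypes path, the new branch above simply casts the integer returned by _dltensor_addr into a void pointer and tags it with TypeCode.ARRAY_HANDLE, so the C side receives it exactly like an NDArray handle. A rough, standalone illustration of that packing step (standard library only; FakeTensor and the string buffer are stand-ins, not TVM types):

import ctypes

class FakeTensor(object):
    """Stand-in for a DLTensor-compatible object."""
    def __init__(self, addr):
        self._addr = addr

    @property
    def _dltensor_addr(self):
        return self._addr

buf = ctypes.create_string_buffer(64)        # pretend this memory holds a DLTensor
t = FakeTensor(ctypes.addressof(buf))
handle = ctypes.c_void_p(t._dltensor_addr)   # what _make_tvm_args stores in values[i].v_handle
assert handle.value == ctypes.addressof(buf)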
@@ -24,10 +24,20 @@ class NDArrayBase(object):
        if not self.is_view:
            check_call(_LIB.TVMArrayFree(self.handle))

    @property
    def _dltensor_addr(self):
        return ctypes.cast(self.handle, ctypes.c_void_p).value

def _make_array(handle, is_view):
    handle = ctypes.cast(handle, TVMArrayHandle)
    return _CLASS_NDARRAY(handle, is_view)

_DLTENSOR_COMPATS = ()

def _reg_dltensor(cls):
    global _DLTENSOR_COMPATS
    _DLTENSOR_COMPATS += (cls,)

_CLASS_NDARRAY = None

def _set_class_ndarray(cls):
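Registration on the ctypes side is nothing more than appending the class to the module-level _DLTENSOR_COMPATS tuple; since isinstance accepts a tuple of classes, _make_tvm_args can test an argument against the whole registry in one call. A tiny self-contained illustration of that pattern (the names here are arbitrary, not TVM's):

_registry = ()

def _register(cls):
    global _registry
    _registry += (cls,)
    return cls

@_register
class TensorA(object):
    pass

@_register
class TensorB(object):
    pass

assert isinstance(TensorA(), _registry)
assert isinstance(TensorB(), _registry)
assert not isinstance(1.5, _registry)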
@@ -77,21 +77,34 @@ cdef inline void make_arg(object arg,
                          int* tcode,
                          list temp_args):
    """Pack arguments into C args that the TVM call accepts."""
    cdef unsigned long long ptr
    if isinstance(arg, NodeBase):
        value[0].v_handle = (<NodeBase>arg).chandle
        tcode[0] = kNodeHandle
    elif isinstance(arg, NDArrayBase):
        value[0].v_handle = (<NDArrayBase>arg).chandle
        tcode[0] = kArrayHandle
    elif isinstance(arg, Integral):
    elif isinstance(arg, _DLTENSOR_COMPATS):
        ptr = arg._dltensor_addr
        value[0].v_handle = (<void*>ptr)
        tcode[0] = kArrayHandle
    elif isinstance(arg, (int, long)):
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, Number):
    elif isinstance(arg, float):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, str):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif arg is None:
        value[0].v_handle = NULL
        tcode[0] = kNull
    elif isinstance(arg, Number):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, TVMType):
        tstr = c_str(str(arg))
        value[0].v_str = tstr
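The reordering in make_arg above ties into the "optimize for common path" note in the commit message: concrete builtin checks (int/long, float, str, None) now run before the generic numbers.Number fallback, so the most frequent argument kinds never reach the slower ABC isinstance test. A Python-level sketch of the new check order (the type-code names are plain strings here, not the real kInt/kFloat constants):

from numbers import Number

def classify(arg):
    # mirrors the order of branches in make_arg, nothing more
    if isinstance(arg, int):
        return "kInt"
    elif isinstance(arg, float):
        return "kFloat"
    elif isinstance(arg, str):
        return "kStr"
    elif arg is None:
        return "kNull"
    elif isinstance(arg, Number):
        # e.g. numeric types that subclass neither int nor float
        return "kFloat"
    raise TypeError("unsupported argument type")

assert classify(3) == "kInt"
assert classify(2.5) == "kFloat"
assert classify("x") == "kStr"
assert classify(None) == "kNull"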
@@ -167,9 +180,9 @@ cdef inline object make_ret(TVMValue value, int tcode):
    raise ValueError("Unhandled type code %d" % tcode)

cdef inline object FuncCall2(void* chandle, tuple args, int nargs):
    cdef TVMValue[2] values
    cdef int[2] tcodes
cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
    cdef TVMValue[3] values
    cdef int[3] tcodes
    cdef TVMValue ret_val
    cdef int ret_code
    nargs = len(args)
@@ -183,8 +196,8 @@ cdef inline object FuncCall2(void* chandle, tuple args, int nargs):
cdef inline object FuncCall(void* chandle, tuple args):
    cdef int nargs
    nargs = len(args)
    if nargs <= 2:
        return FuncCall2(chandle, args, nargs)
    if nargs <= 3:
        return FuncCall3(chandle, args, nargs)
    cdef vector[TVMValue] values
    cdef vector[int] tcodes
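The FuncCall2 to FuncCall3 change above is the other half of "optimize for common path": the fast-path threshold goes from two arguments to three, so calls with up to three arguments use fixed-size, stack-allocated TVMValue/int arrays while larger calls still fall back to C++ vectors. A plain-Python analogue of that size-based dispatch, purely illustrative:

def pack_args(args):
    """Illustrative analogue of FuncCall's dispatch (not the real Cython code)."""
    if len(args) <= 3:
        values = [None, None, None]      # fixed-size fast path, like TVMValue[3]
    else:
        values = [None] * len(args)      # growable slow path, like vector[TVMValue]
    for i, a in enumerate(args):
        values[i] = a
    return values[:len(args)]

assert pack_args((1, 2)) == [1, 2]
assert pack_args((1, 2, 3, 4)) == [1, 2, 3, 4]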
@@ -12,6 +12,10 @@ cdef class NDArrayBase:
        ptr = ctypes.addressof(handle.contents)
        self.chandle = <DLTensor*>(ptr)

    property _dltensor_addr:
        def __get__(self):
            return <unsigned long long>self.chandle

    property handle:
        def __get__(self):
            if self.chandle == NULL:
@@ -23,12 +27,10 @@
        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_view):
        self._set_handle(handle)
        self.c_is_view = is_view

    def __dealloc__(self):
        if self.c_is_view == 0:
            CALL(TVMArrayFree(self.chandle))
@@ -39,12 +41,17 @@ cdef c_make_array(void* chandle, is_view):
    (<NDArrayBase>ret).chandle = <DLTensor*>chandle
    return ret

cdef _DLTENSOR_COMPATS = ()

def _reg_dltensor(cls):
    global _DLTENSOR_COMPATS
    _DLTENSOR_COMPATS += (cls,)

def _make_array(handle, is_view):
    handle = ctypes.cast(handle, TVMArrayHandle)
    return _CLASS_NDARRAY(handle, is_view)

_CLASS_NDARRAY = None
cdef object _CLASS_NDARRAY = None

def _set_class_ndarray(cls):
    global _CLASS_NDARRAY
@@ -16,12 +16,15 @@ try:
    if _FFI_MODE == "ctypes":
        raise ImportError()
    if sys.version_info >= (3, 0):
        from ._cy3.core import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase
        from ._cy3.core import _set_class_ndarray, _reg_dltensor, _make_array
        from ._cy3.core import NDArrayBase as _NDArrayBase
    else:
        from ._cy2.core import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase
        from ._cy2.core import _set_class_ndarray, _reg_dltensor, _make_array
        from ._cy2.core import NDArrayBase as _NDArrayBase
except IMPORT_EXCEPT:
    # pylint: disable=wrong-import-position
    from ._ctypes.ndarray import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase
    from ._ctypes.ndarray import _set_class_ndarray, _reg_dltensor, _make_array
    from ._ctypes.ndarray import NDArrayBase as _NDArrayBase

def context(dev_type, dev_id=0):
@@ -192,3 +195,44 @@ class NDArrayBase(_NDArrayBase):
        else:
            raise ValueError("Unsupported target type %s" % str(type(target)))
        return target

def register_dltensor(cls):
    """Register a DLTensor compatible class to TVM.

    After the class is registered, instances of the class can be
    passed directly as arguments to Functions generated by TVM.

    Parameters
    ----------
    cls : class
        The class object to be registered as DLTensor compatible.

    Note
    ----
    The registered class is required to provide a property
    _dltensor_addr, which returns an integer representing the
    address of the underlying DLTensor.

    Returns
    -------
    cls : class
        The class being registered.

    Example
    -------
    The following code registers the user-defined class
    MyTensor as DLTensor compatible.

    .. code-block:: python

        @tvm.register_dltensor
        class MyTensor(object):
            def __init__(self):
                self.handle = _LIB.NewDLTensor()

            @property
            def _dltensor_addr(self):
                return self.handle.value
    """
    _reg_dltensor(cls)
    return cls
@@ -9,7 +9,7 @@ import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty
from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import _set_class_ndarray, register_dltensor
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
import tvm
import numpy as np

@tvm.register_dltensor
class MyTensorView(object):
    def __init__(self, arr):
        self.arr = arr

    @property
    def _dltensor_addr(self):
        return self.arr._dltensor_addr

def test_dltensor_compatible():
    dtype = 'int64'
    n = tvm.var('n')
    Ab = tvm.decl_buffer((n,), dtype)
    i = tvm.var('i')
    ib = tvm.ir_builder.create()
    A = ib.buffer_ptr(Ab)
    with ib.for_range(0, n - 1, "i") as i:
        A[i + 1] = A[i] + 1
    stmt = ib.get()
    fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0)
    fapi = tvm.ir_pass.LowerPackedCall(fapi)
    f = tvm.codegen.build_module(fapi, "stackvm")
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
    aview = MyTensorView(a)
    f(aview)
    np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))

if __name__ == "__main__":
    test_dltensor_compatible()