Commit f52255b9 by eqy Committed by Tianqi Chen

DLPack Conversion API (#1573)

parent b0368e33
......@@ -446,6 +446,32 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMStreamHandle stream);
/*!
* \brief Produce an array from the DLManagedTensor that shares data memory
* with the DLManagedTensor.
* \param from The source DLManagedTensor.
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from,
TVMArrayHandle* out);
/*!
* \brief Produce a DLMangedTensor from the array that shares data memory with
* the array.
* \param from The source array.
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from,
DLManagedTensor** out);
/*!
* \brief Delete (free) a DLManagedTensor's data.
* \param dltensor Pointer to the DLManagedTensor.
*/
TVM_DLL void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
/*!
* \brief Create a new runtime stream.
*
* \param device_type The device type of context
......
......@@ -155,7 +155,7 @@ class NDArray {
* that is DLPack compatible.
*
* The memory is retained until the NDArray went out of scope.
*
* \param tensor The DLPack tensor to copy from.
* \return The created NDArray view.
*/
TVM_DLL static NDArray FromDLPack(DLManagedTensor* tensor);
......
......@@ -5,7 +5,7 @@ from __future__ import absolute_import
import sys
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t
......@@ -28,6 +28,17 @@ except IMPORT_EXCEPT:
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor')
# used for PyCapsule manipulation
if hasattr(ctypes, 'pythonapi'):
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
def context(dev_type, dev_id=0):
"""Construct a TVM context with given device type and id.
......@@ -62,6 +73,7 @@ def context(dev_type, dev_id=0):
dev_type = TVMContext.STR2MASK[dev_type]
return TVMContext(dev_type, dev_id)
def numpyasarray(np_data):
"""Return a TVMArray representation of a numpy array.
"""
......@@ -112,6 +124,42 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
ctypes.byref(handle)))
return _make_array(handle, False)
def from_dlpack(dltensor):
"""Produce an array from a DLPack tensor without memory copy.
Retreives the underlying DLPack tensor's pointer to create an array from the
data. Removes the original DLPack tensor's destructor as now the array is
responsible for destruction.
Parameters
----------
dltensor : DLPack tensor
Returns
-------
arr: tvm.nd.NDArray
The array view of the tensor data.
"""
dltensor = ctypes.py_object(dltensor)
name = ctypes.pythonapi.PyCapsule_GetName(dltensor)
ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, name)
handle = TVMArrayHandle()
check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, None)
return _make_array(handle, False)
def _dlpack_deleter(pycapsule):
pycapsule = ctypes.py_object(pycapsule)
if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)
_LIB.TVMDLManagedTensorCallDeleter(ptr)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
_c_dlpack_deleter = TVMPyCapsuleDestructor(_dlpack_deleter)
class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime."""
@property
......@@ -260,6 +308,18 @@ class NDArrayBase(_NDArrayBase):
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory
Returns
-------
dlpack : DLPack tensor view of the array data
"""
handle = ctypes.c_void_p()
check_call(_LIB.TVMArrayToDLPack(self.handle, ctypes.byref(handle)))
return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)
def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
......
"""Wrapping functions to bridge frameworks with DLPack support to TVM"""
from .. import ndarray
def convert_func(tvm_func, tensor_type, to_dlpack_func):
"""Convert a tvm function into one that accepts a tensor from another
framework, provided the other framework supports DLPACK
Parameters
----------
tvm_func: Function
Built tvm function operating on arrays
tensor_type: Type
Type of the tensors of the target framework
to_dlpack_func: Function
Function to convert the source tensors to DLPACK
"""
assert callable(tvm_func)
def _wrapper(*args):
args = tuple(ndarray.from_dlpack(to_dlpack_func(arg))\
if isinstance(arg, tensor_type) else arg for arg in args)
return tvm_func(*args)
return _wrapper
def to_pytorch_func(tvm_func):
"""Convert a tvm function into one that accepts PyTorch tensors
Parameters
----------
tvm_func: Function
Built tvm function operating on arrays
Returns
-------
wrapped_func: Function
Wrapped tvm function that operates on PyTorch tensors
"""
import torch
import torch.utils.dlpack
return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack)
......@@ -8,7 +8,7 @@ from __future__ import absolute_import as _abs
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty
from ._ffi.ndarray import context, empty, from_dlpack
from ._ffi.ndarray import _set_class_ndarray
from ._ffi.ndarray import register_extension, free_extension_handle
......
......@@ -93,6 +93,16 @@ struct NDArray::Internal {
arr.data_ = nullptr;
return tensor;
}
// Container to DLManagedTensor
static DLManagedTensor* ToDLPack(NDArray::Container* from) {
CHECK(from != nullptr);
DLManagedTensor* ret = new DLManagedTensor();
ret->dl_tensor = from->dl_tensor;
ret->manager_ctx = from;
from->IncRef();
ret->deleter = NDArrayDLPackDeleter;
return ret;
}
};
NDArray NDArray::CreateView(std::vector<int64_t> shape,
......@@ -115,13 +125,7 @@ NDArray NDArray::CreateView(std::vector<int64_t> shape,
}
DLManagedTensor* NDArray::ToDLPack() const {
CHECK(data_ != nullptr);
DLManagedTensor* ret = new DLManagedTensor();
ret->dl_tensor = data_->dl_tensor;
ret->manager_ctx = const_cast<NDArray*>(this);
data_->IncRef();
ret->deleter = NDArrayDLPackDeleter;
return ret;
return Internal::ToDLPack(data_);
}
NDArray NDArray::Empty(std::vector<int64_t> shape,
......@@ -213,6 +217,24 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
API_END();
}
int TVMArrayFromDLPack(DLManagedTensor* from,
TVMArrayHandle* out) {
API_BEGIN();
*out = NDArray::Internal::MoveAsDLTensor(NDArray::FromDLPack(from));
API_END();
}
int TVMArrayToDLPack(TVMArrayHandle from,
DLManagedTensor** out) {
API_BEGIN();
*out = NDArray::Internal::ToDLPack(reinterpret_cast<NDArray::Container*>(from));
API_END();
}
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) {
(*(dltensor->deleter))(dltensor);
}
int TVMArrayCopyFromBytes(TVMArrayHandle handle,
void* data,
size_t nbytes) {
......
import tvm
import numpy as np
from tvm.contrib.dlpack import to_pytorch_func
def test():
a = np.random.randn(1337)
tvm_a = tvm.nd.array(a)
np.testing.assert_equal(tvm.nd.from_dlpack(tvm_a.to_dlpack()).asnumpy(), a)
try:
import torch
import torch.utils.dlpack
x = torch.rand(56, 56)
tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
np.testing.assert_equal(x.numpy(), tvm_x.asnumpy())
y = tvm.nd.from_dlpack(tvm_x.to_dlpack())
np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy())
np.testing.assert_equal(torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.asnumpy())
n = tvm.convert(137)
xx = torch.rand(137,137)
yy = torch.rand(137,137)
zz2 = torch.empty(137,137)
zz = xx.mm(yy)
XX = tvm.placeholder((n,n), name='X')
YY = tvm.placeholder((n,n), name='Y')
k = tvm.reduce_axis((0, n), name='k')
ZZ = tvm.compute((n,n), lambda i,j : tvm.sum(XX[i,k]*YY[k,j], axis=k))
s = tvm.create_schedule(ZZ.op)
f = tvm.build(s, [XX, YY, ZZ], target_host='llvm', name='f')
f_pytorch = to_pytorch_func(f)
zz2 = torch.empty(137,137)
f_pytorch(xx, yy, zz2)
np.testing.assert_allclose(zz.numpy(), zz2.numpy(), rtol=1e-6)
except ImportError:
pass
if __name__ == '__main__':
test()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment