Commit bc5367a0 by Tianqi Chen Committed by ziheng

[PYTHON][FFI] Cythonize NDArray.copyto (#4549)

* [PYTHON][FFI] Cythonize NDArray.copyto

* Cythonize the shape property
parent ce0b6d5a
...@@ -85,6 +85,16 @@ class NDArrayBase(object): ...@@ -85,6 +85,16 @@ class NDArrayBase(object):
def _tvm_handle(self): def _tvm_handle(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value return ctypes.cast(self.handle, ctypes.c_void_p).value
def _copyto(self, target_nd):
"""Internal function that implements copy to target ndarray."""
check_call(_LIB.TVMArrayCopyFromTo(self.handle, target_nd.handle, None))
return target_nd
@property
def shape(self):
"""Shape of this array"""
return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))
def to_dlpack(self): def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory """Produce an array from a DLPack Tensor without copying memory
......
...@@ -68,6 +68,11 @@ cdef class NDArrayBase: ...@@ -68,6 +68,11 @@ cdef class NDArrayBase:
def __set__(self, value): def __set__(self, value):
self._set_handle(value) self._set_handle(value)
@property
def shape(self):
"""Shape of this array"""
return tuple(self.chandle.shape[i] for i in range(self.chandle.ndim))
def __init__(self, handle, is_view): def __init__(self, handle, is_view):
self._set_handle(handle) self._set_handle(handle)
self.c_is_view = is_view self.c_is_view = is_view
...@@ -76,6 +81,11 @@ cdef class NDArrayBase: ...@@ -76,6 +81,11 @@ cdef class NDArrayBase:
if self.c_is_view == 0: if self.c_is_view == 0:
CALL(TVMArrayFree(self.chandle)) CALL(TVMArrayFree(self.chandle))
def _copyto(self, target_nd):
"""Internal function that implements copy to target ndarray."""
CALL(TVMArrayCopyFromTo(self.chandle, (<NDArrayBase>target_nd).chandle, NULL))
return target_nd
def to_dlpack(self): def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory """Produce an array from a DLPack Tensor without copying memory
......
...@@ -157,10 +157,6 @@ def from_dlpack(dltensor): ...@@ -157,10 +157,6 @@ def from_dlpack(dltensor):
class NDArrayBase(_NDArrayBase): class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime.""" """A simple Device/CPU Array object in runtime."""
@property
def shape(self):
"""Shape of this array"""
return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))
@property @property
def dtype(self): def dtype(self):
...@@ -240,6 +236,7 @@ class NDArrayBase(_NDArrayBase): ...@@ -240,6 +236,7 @@ class NDArrayBase(_NDArrayBase):
except: except:
raise TypeError('array must be an array_like data,' + raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array))) 'type %s is not supported' % str(type(source_array)))
t = TVMType(self.dtype) t = TVMType(self.dtype)
shape, dtype = self.shape, self.dtype shape, dtype = self.shape, self.dtype
if t.lanes > 1: if t.lanes > 1:
...@@ -294,14 +291,12 @@ class NDArrayBase(_NDArrayBase): ...@@ -294,14 +291,12 @@ class NDArrayBase(_NDArrayBase):
target : NDArray target : NDArray
The target array to be copied, must have same shape as this array. The target array to be copied, must have same shape as this array.
""" """
if isinstance(target, TVMContext):
target = empty(self.shape, self.dtype, target)
if isinstance(target, NDArrayBase): if isinstance(target, NDArrayBase):
check_call(_LIB.TVMArrayCopyFromTo( return self._copyto(target)
self.handle, target.handle, None)) elif isinstance(target, TVMContext):
else: res = empty(self.shape, self.dtype, target)
raise ValueError("Unsupported target type %s" % str(type(target))) return self._copyto(res)
return target raise ValueError("Unsupported target type %s" % str(type(target)))
def free_extension_handle(handle, type_code): def free_extension_handle(handle, type_code):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment