Commit bc5367a0 by Tianqi Chen Committed by ziheng

[PYTHON][FFI] Cythonize NDArray.copyto (#4549)

* [PYTHON][FFI] Cythonize NDArray.copyto

* Cythonize the shape property
parent ce0b6d5a
......@@ -85,6 +85,16 @@ class NDArrayBase(object):
def _tvm_handle(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value
def _copyto(self, target_nd):
"""Internal function that implements copy to target ndarray."""
check_call(_LIB.TVMArrayCopyFromTo(self.handle, target_nd.handle, None))
return target_nd
@property
def shape(self):
"""Shape of this array"""
return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))
def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory
......
......@@ -68,6 +68,11 @@ cdef class NDArrayBase:
def __set__(self, value):
self._set_handle(value)
@property
def shape(self):
"""Shape of this array"""
return tuple(self.chandle.shape[i] for i in range(self.chandle.ndim))
def __init__(self, handle, is_view):
self._set_handle(handle)
self.c_is_view = is_view
......@@ -76,6 +81,11 @@ cdef class NDArrayBase:
if self.c_is_view == 0:
CALL(TVMArrayFree(self.chandle))
def _copyto(self, target_nd):
"""Internal function that implements copy to target ndarray."""
CALL(TVMArrayCopyFromTo(self.chandle, (<NDArrayBase>target_nd).chandle, NULL))
return target_nd
def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory
......
......@@ -157,10 +157,6 @@ def from_dlpack(dltensor):
class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime."""
@property
def shape(self):
"""Shape of this array"""
return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))
@property
def dtype(self):
......@@ -240,6 +236,7 @@ class NDArrayBase(_NDArrayBase):
except:
raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array)))
t = TVMType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
......@@ -294,14 +291,12 @@ class NDArrayBase(_NDArrayBase):
target : NDArray
The target array to be copied, must have same shape as this array.
"""
if isinstance(target, TVMContext):
target = empty(self.shape, self.dtype, target)
if isinstance(target, NDArrayBase):
check_call(_LIB.TVMArrayCopyFromTo(
self.handle, target.handle, None))
else:
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
return self._copyto(target)
elif isinstance(target, TVMContext):
res = empty(self.shape, self.dtype, target)
return self._copyto(res)
raise ValueError("Unsupported target type %s" % str(type(target)))
def free_extension_handle(handle, type_code):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment