Commit 2607a836 by Tianqi Chen Committed by GitHub

[RUNTIME][PYTHON] More compatibility in ndarray (#463)

parent 0220abba
......@@ -165,6 +165,10 @@ class NDArrayBase(_NDArrayBase):
arr : NDArray
Reference to self.
"""
if isinstance(source_array, NDArrayBase):
source_array.copyto(self)
return self
if not isinstance(source_array, np.ndarray):
try:
source_array = np.array(source_array, dtype=self.dtype)
......@@ -187,6 +191,14 @@ class NDArrayBase(_NDArrayBase):
check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
return self
def __repr__(self):
res = "<tvm.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
res += self.asnumpy().__repr__()
return res
def __str__(self):
return str(self.asnumpy())
def asnumpy(self):
"""Convert this array to numpy array
......
......@@ -140,7 +140,7 @@ def array(arr, ctx=cpu(0)):
ret : NDArray
The created array
"""
if not isinstance(arr, _np.ndarray):
if not isinstance(arr, (_np.ndarray, NDArray)):
arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment