Commit 2607a836 by Tianqi Chen Committed by GitHub

[RUNTIME][PYTHON] More compatibility in ndarray (#463)

parent 0220abba
...@@ -165,6 +165,10 @@ class NDArrayBase(_NDArrayBase): ...@@ -165,6 +165,10 @@ class NDArrayBase(_NDArrayBase):
arr : NDArray arr : NDArray
Reference to self. Reference to self.
""" """
if isinstance(source_array, NDArrayBase):
source_array.copyto(self)
return self
if not isinstance(source_array, np.ndarray): if not isinstance(source_array, np.ndarray):
try: try:
source_array = np.array(source_array, dtype=self.dtype) source_array = np.array(source_array, dtype=self.dtype)
...@@ -187,6 +191,14 @@ class NDArrayBase(_NDArrayBase): ...@@ -187,6 +191,14 @@ class NDArrayBase(_NDArrayBase):
check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes)) check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
return self return self
def __repr__(self):
res = "<tvm.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
res += self.asnumpy().__repr__()
return res
def __str__(self):
return str(self.asnumpy())
def asnumpy(self): def asnumpy(self):
"""Convert this array to numpy array """Convert this array to numpy array
......
...@@ -140,7 +140,7 @@ def array(arr, ctx=cpu(0)): ...@@ -140,7 +140,7 @@ def array(arr, ctx=cpu(0)):
ret : NDArray ret : NDArray
The created array The created array
""" """
if not isinstance(arr, _np.ndarray): if not isinstance(arr, (_np.ndarray, NDArray)):
arr = _np.array(arr) arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr) return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment