ndarray.py 8.94 KB
Newer Older
1
# pylint: disable=invalid-name, unused-import
2
"""Runtime NDArray api"""
3
from __future__ import absolute_import
4 5

import sys
6 7
import ctypes
import numpy as np
8
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE
9 10
from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle
from .runtime_ctypes import TypeCode, tvm_shape_index_t
11

12

13
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
14

15 16 17 18 19
try:
    # pylint: disable=wrong-import-position
    if _FFI_MODE == "ctypes":
        raise ImportError()
    if sys.version_info >= (3, 0):
20
        from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array
21
        from ._cy3.core import NDArrayBase as _NDArrayBase
22
    else:
23
        from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array
24
        from ._cy2.core import NDArrayBase as _NDArrayBase
25 26
except IMPORT_EXCEPT:
    # pylint: disable=wrong-import-position
27
    from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array
28
    from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
29

30

31 32
def context(dev_type, dev_id=0):
    """Construct a TVM context with given device type and id.

    Parameters
    ----------
    dev_type: int or str
        The device type mask or name of the device. When a string is given,
        only its first whitespace-separated token is looked up as the device
        name.

    dev_id : int, optional
        The integer device id

    Returns
    -------
    ctx: TVMContext
        The corresponding context.

    Raises
    ------
    ValueError
        If ``dev_type`` is a string that does not name a known device type
        (including an empty or all-whitespace string).

    Examples
    --------
    Context can be used to create reflection of context by
    string representation of the device type.

    .. code-block:: python

      assert tvm.context("cpu", 1) == tvm.cpu(1)
      assert tvm.context("gpu", 0) == tvm.gpu(0)
      assert tvm.context("cuda", 0) == tvm.gpu(0)
    """
    if isinstance(dev_type, string_types):
        tokens = dev_type.split()
        # An empty / all-whitespace string has no tokens; report it as an
        # unknown device type rather than failing with an IndexError.
        dev_type = tokens[0] if tokens else dev_type
        if dev_type not in TVMContext.STR2MASK:
            raise ValueError("Unknown device type %s" % dev_type)
        dev_type = TVMContext.STR2MASK[dev_type]
    return TVMContext(dev_type, dev_id)
64

65 66 67 68 69 70
def numpyasarray(np_data):
    """Return a TVMArray representation of a numpy array.

    No data is copied: the returned TVMArray's ``data`` field points directly
    into ``np_data``'s buffer, so ``np_data`` must stay alive for as long as
    the TVMArray is in use.

    Parameters
    ----------
    np_data : numpy.ndarray
        A C-contiguous numpy array.

    Returns
    -------
    arr : TVMArray
        Array descriptor placed on the CPU context ``context(1, 0)``.
    shape : ctypes array
        The ctypes shape buffer assigned to ``arr.shape``. Presumably returned
        so the caller can hold a reference to it alongside ``arr`` — keep it
        alive while ``arr`` is used.

    Raises
    ------
    ValueError
        If ``np_data`` is not C-contiguous.
    """
    data = np_data
    if not data.flags['C_CONTIGUOUS']:
        # Explicit check instead of ``assert``: strides are left as None below
        # (presumably meaning a compact row-major layout), so a non-contiguous
        # input would be silently misread if the check were stripped by -O.
        raise ValueError("numpyasarray requires a C-contiguous numpy array")
    arr = TVMArray()
    shape = c_array(tvm_shape_index_t, data.shape)
    arr.data = data.ctypes.data_as(ctypes.c_void_p)
    arr.shape = shape
    arr.strides = None
    arr.dtype = TVMType(np.dtype(data.dtype).name)
    arr.ndim = data.ndim
    # CPU device
    arr.ctx = context(1, 0)
    return arr, shape

81 82

def empty(shape, dtype="float32", ctx=context(1, 0)):
    """Create an empty array given shape and device

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the array. A single integer ``n`` is treated as the
        one-dimensional shape ``(n,)``.

    dtype : type or str
        The data type of the array.

    ctx : TVMContext
        The context of the array. The default, ``context(1, 0)``, is
        evaluated once when the module is imported.

    Returns
    -------
    arr : tvm.nd.NDArray
        The array tvm supported.
    """
    import numbers  # local import keeps this block self-contained

    # Accept a bare integer (including numpy integer scalars) like numpy does.
    if isinstance(shape, numbers.Integral):
        shape = (shape,)
    shape = c_array(tvm_shape_index_t, shape)
    ndim = ctypes.c_int(len(shape))
    handle = TVMArrayHandle()
    dtype = TVMType(dtype)
    check_call(_LIB.TVMArrayAlloc(
        shape, ndim,
        ctypes.c_int(dtype.type_code),
        ctypes.c_int(dtype.bits),
        ctypes.c_int(dtype.lanes),
        ctx.device_type,
        ctx.device_id,
        ctypes.byref(handle)))
    # NOTE(review): second argument presumably flags the result as a view
    # (False => owning array) — confirm against _make_array.
    return _make_array(handle, False)
114

115
class NDArrayBase(_NDArrayBase):
    """A simple Device/CPU Array object in runtime."""
    @property
    def shape(self):
        """Shape of this array"""
        return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))

    @property
    def dtype(self):
        """Type of this array"""
        return str(self.handle.contents.dtype)

    @property
    def ctx(self):
        """context of this array"""
        return self.handle.contents.ctx

    @property
    def context(self):
        """context of this array"""
        return self.ctx

    def __setitem__(self, in_slice, value):
        """Set ndarray value.

        Only whole-array assignment (``arr[:] = value``) is supported.

        Parameters
        ----------
        in_slice : slice
            Must be a full slice: no start, no stop, and a step of None or 1.

        value : NDArrayBase or numpy.ndarray or numpy.generic
            The data to copy into this array.

        Raises
        ------
        ValueError
            If ``in_slice`` is not a full slice.
        TypeError
            If ``value`` has an unsupported type.
        """
        # A step other than 1 (e.g. ``arr[::-1] = x``) used to be accepted
        # and silently performed a plain forward copy; reject it instead.
        if (not isinstance(in_slice, slice) or
                in_slice.start is not None or
                in_slice.stop is not None or
                in_slice.step not in (None, 1)):
            raise ValueError('Array only support set from numpy array')
        if isinstance(value, NDArrayBase):
            # Skip the copy when assigning an array onto itself.
            if value.handle is not self.handle:
                value.copyto(self)
        elif isinstance(value, (np.ndarray, np.generic)):
            self.copyfrom(value)
        else:
            raise TypeError('type %s not supported' % str(type(value)))

    def _np_shape_dtype(self):
        """Return the (shape, dtype) of a numpy buffer mirroring this array.

        For vector dtypes (lanes > 1) the lanes are unrolled into an extra
        trailing axis and the element type becomes the single-lane type.
        """
        t = TVMType(self.dtype)
        shape, dtype = self.shape, self.dtype
        if t.lanes > 1:
            shape = shape + (t.lanes,)
            t.lanes = 1
            dtype = str(t)
        return shape, dtype

    def copyfrom(self, source_array):
        """Perform a synchronous copy from the array.

        Parameters
        ----------
        source_array : array_like
            The data source we should like to copy from.

        Returns
        -------
        arr : NDArray
            Reference to self.

        Raises
        ------
        TypeError
            If ``source_array`` cannot be converted to a numpy array.
        ValueError
            If the source shape does not match this array's shape (with
            vector lanes unrolled into a trailing axis).
        """
        if isinstance(source_array, NDArrayBase):
            source_array.copyto(self)
            return self

        if not isinstance(source_array, np.ndarray):
            try:
                source_array = np.array(source_array, dtype=self.dtype)
            except Exception:
                # ``except Exception`` (not a bare except) so that
                # KeyboardInterrupt / SystemExit are not turned into TypeError.
                raise TypeError('array must be an array_like data, '
                                'type %s is not supported' % str(type(source_array)))
        shape, dtype = self._np_shape_dtype()
        source_array = np.ascontiguousarray(source_array, dtype=dtype)
        if source_array.shape != shape:
            raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
                source_array.shape, shape))
        assert source_array.flags['C_CONTIGUOUS']
        data = source_array.ctypes.data_as(ctypes.c_void_p)
        # ``nbytes`` is always an int; ``np.prod(())`` is the float 1.0, which
        # made ctypes.c_size_t raise TypeError for 0-d arrays.
        nbytes = ctypes.c_size_t(source_array.nbytes)
        check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
        return self

    def __repr__(self):
        res = "<tvm.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
        res += self.asnumpy().__repr__()
        return res

    def __str__(self):
        return str(self.asnumpy())

    def asnumpy(self):
        """Convert this array to numpy array

        Returns
        -------
        np_arr : numpy.ndarray
            The corresponding numpy array. For vector dtypes the lanes appear
            as an extra trailing axis.
        """
        shape, dtype = self._np_shape_dtype()
        np_arr = np.empty(shape, dtype=dtype)
        assert np_arr.flags['C_CONTIGUOUS']
        data = np_arr.ctypes.data_as(ctypes.c_void_p)
        # ``nbytes`` avoids the float result of ``np.prod(())`` for 0-d arrays.
        nbytes = ctypes.c_size_t(np_arr.nbytes)
        check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
        return np_arr

    def copyto(self, target):
        """Copy array to target

        Parameters
        ----------
        target : NDArray or TVMContext
            The target array to be copied, must have same shape as this array.
            If a TVMContext is given, a new empty array with this array's
            shape and dtype is allocated on that context and used as target.

        Returns
        -------
        target : NDArray
            The array that was copied into.

        Raises
        ------
        ValueError
            If ``target`` is neither an NDArray nor a TVMContext.
        """
        if isinstance(target, TVMContext):
            target = empty(self.shape, self.dtype, target)
        if isinstance(target, NDArrayBase):
            # NOTE(review): the trailing None is presumably the stream
            # argument of TVMArrayCopyFromTo — confirm against the C API.
            check_call(_LIB.TVMArrayCopyFromTo(
                self.handle, target.handle, None))
        else:
            raise ValueError("Unsupported target type %s" % str(type(target)))
        return target
235

236 237
def free_extension_handle(handle, type_code):
    """Release a C++ extension type handle via the TVM C API.

    The call goes through ``check_call``, so a failure reported by
    ``TVMExtTypeFree`` surfaces as an exception.

    Parameters
    ----------
    handle : ctypes.c_void_p
        The handle to the extension type.

    type_code : int
        The type code of the extension.
    """
    code = ctypes.c_int(type_code)
    status = _LIB.TVMExtTypeFree(handle, code)
    check_call(status)

def register_extension(cls, fcreate=None):
    """Register an extension class to TVM.

    Once registered, instances of the class can be passed directly as
    arguments to functions generated by TVM.

    Parameters
    ----------
    cls : class
        The class object to be registered as extension.

    fcreate : function, optional
        The creation function to create a class object given handle value.

    Returns
    -------
    cls : class
        The class being registered (so this can be used as a decorator).

    Raises
    ------
    ValueError
        If ``fcreate`` is given while ``cls._tvm_tcode`` is below
        ``TypeCode.EXT_BEGIN``.

    Note
    ----
    The registered class requires one property, ``_tvm_handle``, and one
    class attribute, ``_tvm_tcode``.

    - ``_tvm_handle`` returns an integer representing the address of the handle.
    - ``_tvm_tcode`` gives an integer representing the type code of the class.

    Example
    -------
    The following code registers user defined class
    MyTensor to be DLTensor compatible.

    .. code-block:: python

       @tvm.register_extension
       class MyTensor(object):
           _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE

           def __init__(self):
               self.handle = _LIB.NewDLTensor()

           @property
           def _tvm_handle(self):
               return self.handle.value
    """
    # A creation function is only allowed for type codes outside the
    # builtin range (at or above TypeCode.EXT_BEGIN).
    clashes_with_builtin = bool(fcreate) and cls._tvm_tcode < TypeCode.EXT_BEGIN
    if clashes_with_builtin:
        raise ValueError("Cannot register create when extension tcode is same as buildin")
    _reg_extension(cls, fcreate)
    return cls