ndarray.pxi 1.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
from ..runtime_ctypes import TVMArrayHandle

cdef class NDArrayBase:
    """Cython base class that holds the raw ``DLTensor*`` behind an NDArray.

    The pointer is either owned (freed on deallocation) or a view
    (``is_view`` non-zero, never freed by this object).
    """
    # Pointer to the underlying tensor; NULL when no handle is set.
    cdef DLTensor* chandle
    # Non-zero when this object does not own ``chandle``.
    cdef int c_is_view

    cdef inline _set_handle(self, handle):
        """Store *handle* as ``chandle``.

        ``handle`` is None or anything ``ctypes.cast(..., c_void_p)`` accepts.
        Both ``None`` and a NULL ctypes pointer result in a NULL ``chandle``.
        """
        cdef unsigned long long ptr
        addr = None
        if handle is not None:
            # ``c_void_p.value`` is None for a NULL pointer; treat it like
            # ``handle is None`` instead of failing the integer conversion.
            addr = ctypes.cast(handle, ctypes.c_void_p).value
        if addr is None:
            self.chandle = NULL
        else:
            ptr = addr
            self.chandle = <DLTensor*>(ptr)

    property _tvm_handle:
        def __get__(self):
            # Raw address of the tensor as a Python integer (0 when NULL).
            return <unsigned long long>self.chandle

    property handle:
        def __get__(self):
            """ctypes ``TVMArrayHandle`` for the tensor, or None when unset."""
            if self.chandle == NULL:
                return None
            else:
                return ctypes.cast(
                    <unsigned long long>self.chandle, TVMArrayHandle)

        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_view):
        """Wrap *handle*; *is_view* non-zero means the pointer is not owned."""
        self._set_handle(handle)
        self.c_is_view = is_view

    def __dealloc__(self):
        # Cython zero-initialises C attributes, so ``chandle`` may still be
        # NULL here (e.g. __init__ raised, or the handle was None). Only free
        # a pointer this object actually owns and that is non-NULL.
        if self.c_is_view == 0 and self.chandle != NULL:
            CALL(TVMArrayFree(self.chandle))


cdef c_make_array(void* chandle, is_view):
    """Build an instance of the registered NDArray class around *chandle*.

    The instance is created with a None handle, and the raw pointer is then
    written straight into its ``chandle`` slot. *is_view* is forwarded to
    the constructor unchanged.
    """
    cdef NDArrayBase base_view
    arr = _CLASS_NDARRAY(None, is_view)
    # Unchecked cast: the registered class is expected to derive from
    # NDArrayBase.
    base_view = <NDArrayBase>arr
    base_view.chandle = <DLTensor*>chandle
    return arr

# Extension classes registered through ``_reg_extension``; grown as a tuple.
cdef _TVM_COMPATS = ()

# Maps an extension class's ``_tvm_tcode`` to the ``fcreate`` callback
# registered for it through ``_reg_extension``.
cdef _TVM_EXT_RET = {}

def _reg_extension(cls, fcreate):
    """Register an extension class with this module.

    Parameters
    ----------
    cls : type
        Class appended to ``_TVM_COMPATS``. It must define ``_tvm_tcode``
        when ``fcreate`` is given.
    fcreate : callable or None
        Factory stored in ``_TVM_EXT_RET`` under ``cls._tvm_tcode``.
        Nothing is stored when it is falsy.
    """
    global _TVM_COMPATS
    # Tuples are immutable, so the module-level name is rebound.
    _TVM_COMPATS += (cls,)
    if fcreate:
        _TVM_EXT_RET[cls._tvm_tcode] = fcreate


def _make_array(handle, is_view):
    """Python-visible wrapper around ``c_make_array``.

    Parameters
    ----------
    handle : ctypes pointer-like
        Any object accepted by ``ctypes.cast(handle, ctypes.c_void_p)``.
        A NULL pointer yields ``.value is None``, which fails the integer
        conversion below with a TypeError.
    is_view : int or bool
        Forwarded unchanged to ``c_make_array``.

    Returns
    -------
    An instance of the class registered with ``_set_class_ndarray``.
    """
    cdef unsigned long long ptr
    ptr = ctypes.cast(handle, ctypes.c_void_p).value
    return c_make_array(<void*>ptr, is_view)


# Python class instantiated by ``c_make_array``; assigned by
# ``_set_class_ndarray`` (None until then).
cdef object _CLASS_NDARRAY = None


def _set_class_ndarray(cls):
    """Make *cls* the class that ``c_make_array`` instantiates."""
    # ``_CLASS_NDARRAY`` is a cdef module global, so it must be rebound via
    # ``global`` rather than through the module's attribute dictionary.
    global _CLASS_NDARRAY
    _CLASS_NDARRAY = cls