# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from ..runtime_ctypes import TVMArrayHandle

# Capsule names used for DLPack exchange. An unconsumed capsule is named
# "dltensor"; _from_dlpack renames it to "used_dltensor" once it has taken
# ownership of the underlying tensor.
cdef const char* _c_str_dltensor = "dltensor"
cdef const char* _c_str_used_dltensor = "used_dltensor"


cdef void _c_dlpack_deleter(object pycaps):
    """PyCapsule destructor for capsules produced by NDArrayBase.to_dlpack.

    The managed tensor is released only while the capsule still carries the
    "dltensor" name. A capsule that has been renamed (e.g. to "used_dltensor"
    by a consumer that took ownership) is left alone.
    """
    cdef DLManagedTensor* managed
    if not pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor):
        return
    managed = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor)
    TVMDLManagedTensorCallDeleter(managed)


def _from_dlpack(object dltensor):
    """Consume a DLPack capsule and wrap its tensor as an NDArray.

    Parameters
    ----------
    dltensor : PyCapsule
        A capsule named "dltensor" holding a DLManagedTensor pointer.

    Returns
    -------
    NDArray
        An owning (non-view) array built from the capsule's tensor.

    Raises
    ------
    ValueError
        If ``dltensor`` is not a valid "dltensor" capsule, for example because
        it has already been consumed.
    """
    cdef DLManagedTensor* ptr
    cdef DLTensorHandle chandle
    if not pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
        raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")
    ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
    CALL(TVMArrayFromDLPack(ptr, &chandle))
    # Ownership now belongs to the new array. Clear the capsule's destructor
    # so the tensor is not freed twice, and rename the capsule so that a
    # second _from_dlpack call on it is rejected.
    pycapsule.PyCapsule_SetDestructor(dltensor, NULL)
    pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
    return c_make_array(chandle, False, False)


cdef class NDArrayBase:
    """Cython base class that wraps a raw DLTensor handle from the TVM runtime."""
    # Raw pointer to the tensor; NULL when no handle has been set.
    cdef DLTensor* chandle
    # Nonzero when this wrapper is a non-owning view. Owning wrappers
    # (c_is_view == 0) release chandle via TVMArrayFree in __dealloc__.
    cdef int c_is_view

    cdef inline _set_handle(self, handle):
        # Accepts None (clears the handle) or anything ctypes can cast to a
        # void pointer; stores the raw address in chandle.
        cdef unsigned long long ptr
        if handle is None:
            self.chandle = NULL
        else:
            ptr = ctypes.cast(handle, ctypes.c_void_p).value
            self.chandle = <DLTensor*>(ptr)

    property _tvm_handle:
        def __get__(self):
            # Raw address of the tensor as an integer.
            return <unsigned long long>self.chandle

    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return ctypes.cast(
                    <unsigned long long>self.chandle, TVMArrayHandle)

        def __set__(self, value):
            self._set_handle(value)

    @property
    def shape(self):
        """Shape of this array as a tuple of ints.

        Raises
        ------
        ValueError
            If no tensor handle is set.
        """
        # Guard against dereferencing a NULL handle, which would crash the
        # interpreter instead of raising.
        if self.chandle == NULL:
            raise ValueError("NDArray handle is NULL; shape is unavailable")
        return tuple(self.chandle.shape[i] for i in range(self.chandle.ndim))

    def __init__(self, handle, is_view):
        self._set_handle(handle)
        self.c_is_view = is_view

    def __dealloc__(self):
        # Only owning wrappers free the tensor. Skip NULL so that a wrapper
        # whose handle was never set (or was set to None) does not hand NULL
        # to TVMArrayFree.
        if self.c_is_view == 0 and self.chandle != NULL:
            CALL(TVMArrayFree(self.chandle))

    def _copyto(self, target_nd):
        """Internal function that implements copy to target ndarray.

        Raises
        ------
        TypeError
            If ``target_nd`` is not an NDArrayBase instance.
        """
        # Checked cast: an unchecked cast of an arbitrary object would read
        # garbage as a DLTensor pointer.
        CALL(TVMArrayCopyFromTo(self.chandle, (<NDArrayBase?>target_nd).chandle, NULL))
        return target_nd

    def to_dlpack(self):
        """Produce a DLPack tensor from this array without copying memory.

        Returns
        -------
        dlpack : PyCapsule
            A capsule named "dltensor" wrapping a DLManagedTensor that shares
            this array's data. If no consumer renames the capsule, its
            destructor releases the managed tensor.

        Raises
        ------
        ValueError
            If this array is a view (``c_is_view`` is nonzero).
        """
        cdef DLManagedTensor* dltensor
        if self.c_is_view != 0:
            raise ValueError("to_dlpack do not work with memory views")
        CALL(TVMArrayToDLPack(self.chandle, &dltensor))
        try:
            return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
        except BaseException:
            # No capsule owns the managed tensor yet; release it so it does
            # not leak, then propagate the original error.
            TVMDLManagedTensorCallDeleter(dltensor)
            raise


# Import a limited set of object-related functions from the C++ side to
# improve speed.
# NOTE: only POD, C-compatible types can be passed across this FFI boundary.
cdef extern from "tvm/runtime/ndarray.h" namespace "tvm::runtime":
    # Maps a tensor handle to its runtime object handle (per the name; see
    # ndarray.h). c_make_array reads the object's type_index_ to choose the
    # Python class.
    cdef void* TVMArrayHandleToObjectHandle(DLTensorHandle handle)


cdef c_make_array(void* chandle, is_view, is_container):
    """Wrap a raw tensor handle in a Python NDArray object.

    Parameters
    ----------
    chandle : void*
        Raw DLTensor handle to wrap.
    is_view : bool
        Stored as ``c_is_view``; a nonzero value means the wrapper does not
        free the handle.
    is_container : bool
        When true, the handle's runtime type index is used to look up a
        subclass registered via _register_ndarray.

    Returns
    -------
    object
        An instance of the registered subclass, or of ``_CLASS_NDARRAY`` when
        no subclass is registered for the type index (or is_container is
        false).
    """
    cls = None
    if is_container:
        tindex = (
            <TVMObject*>TVMArrayHandleToObjectHandle(<DLTensorHandle>chandle)).type_index_
        if tindex < len(_TVM_ND_CLS):
            cls = _TVM_ND_CLS[tindex]
    if cls is None:
        cls = _CLASS_NDARRAY
    # Bypass __init__: the handle is installed directly on the C fields.
    ret = cls.__new__(cls)
    (<NDArrayBase>ret).chandle = <DLTensor*>chandle
    (<NDArrayBase>ret).c_is_view = <int>is_view
    return ret


# Extension classes registered through _reg_extension, in registration order.
cdef _TVM_COMPATS = ()

# Maps an extension class's ``_tvm_tcode`` to the ``fcreate`` callback given
# at registration.
cdef _TVM_EXT_RET = {}

def _reg_extension(cls, fcreate):
    """Register ``cls`` as a compatible extension type.

    ``cls`` is appended to ``_TVM_COMPATS``. When ``fcreate`` is truthy it is
    also stored in ``_TVM_EXT_RET`` under the key ``cls._tvm_tcode``.
    """
    global _TVM_COMPATS
    _TVM_COMPATS = _TVM_COMPATS + (cls,)
    if not fcreate:
        return
    _TVM_EXT_RET[cls._tvm_tcode] = fcreate

# Lookup table from runtime type index to the registered NDArray subclass;
# None marks an index with no registration.
cdef list _TVM_ND_CLS = []

cdef _register_ndarray(int index, object cls):
    """Register ``cls`` as the NDArray subclass for runtime type index ``index``.

    The table is padded with None up to ``index``. A later registration for the
    same index replaces the earlier one.

    Raises
    ------
    ValueError
        If ``index`` is negative (it would otherwise silently overwrite the
        last table entry through Python's negative indexing).
    """
    if index < 0:
        raise ValueError("type index must be non-negative, got %d" % index)
    # Unfilled slots stay None, which c_make_array treats as "use the default
    # NDArray class".
    missing = index + 1 - len(_TVM_ND_CLS)
    if missing > 0:
        _TVM_ND_CLS.extend([None] * missing)
    _TVM_ND_CLS[index] = cls


def _make_array(handle, is_view, is_container):
    """Wrap a ctypes tensor handle in an NDArray Python object.

    The handle is cast through ``ctypes.c_void_p`` to a raw address. That
    address, together with ``is_view`` and ``is_container``, is forwarded to
    c_make_array.
    """
    cdef unsigned long long address = ctypes.cast(handle, ctypes.c_void_p).value
    return c_make_array(<void*>address, is_view, is_container)

# Default Python class that c_make_array instantiates; None until
# _set_class_ndarray is called.
cdef object _CLASS_NDARRAY = None

def _set_class_ndarray(cls):
    """Install ``cls`` as the default NDArray class used by c_make_array."""
    global _CLASS_NDARRAY
    _CLASS_NDARRAY = cls