# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import ctypes
import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..runtime_ctypes import DataType, TVMContext, TVMByteArray


cdef void tvm_callback_finalize(void* fhandle):
    """Release the reference to the Python callable behind a TVM callback.

    ``fhandle`` is the resource pointer given to ``TVMFuncCreateFromCFunc``
    by ``convert_to_tvm_func``, which took an extra reference with
    ``Py_INCREF``. This finalizer drops that reference.
    """
    wrapped_callable = <object>(fhandle)
    Py_DECREF(wrapped_callable)

cdef int tvm_callback(TVMValue* args,
                      int* type_codes,
                      int num_args,
                      TVMRetValueHandle ret,
                      void* fhandle) with gil:
    """C trampoline that lets TVM invoke a Python callable.

    Parameters
    ----------
    args, type_codes, num_args :
        The packed argument values and their TVM type codes.
    ret :
        Handle through which the return value is handed back to TVM.
    fhandle :
        The Python callable registered by ``convert_to_tvm_func``.

    Returns
    -------
    0 on success, -1 on failure. On failure the formatted Python traceback
    is stored with ``TVMAPISetLastError``.

    Notes
    -----
    This function is called from C and has no ``except`` clause. A Python
    exception escaping it could not propagate correctly: it would be
    reported as "unraisable" and success would be returned. So every step
    that can raise (argument conversion, the call itself, return-value
    packing) is inside the ``try``.
    """
    cdef list pyargs
    cdef TVMValue value
    cdef int tcode
    local_pyfunc = <object>(fhandle)
    try:
        pyargs = []
        for i in range(num_args):
            value = args[i]
            tcode = type_codes[i]
            # Handle-typed arguments are converted with TVMCbArgToReturn
            # before being wrapped. Presumably this gives the Python wrapper
            # its own reference rather than borrowing the caller's -- TODO
            # confirm against the TVM C API.
            if (tcode == kTVMObjectHandle or
                tcode == kTVMPackedFuncHandle or
                tcode == kTVMModuleHandle or
                tcode > kTVMExtBegin):
                CALL(TVMCbArgToReturn(&value, tcode))

            if tcode != kTVMDLTensorHandle:
                pyargs.append(make_ret(value, tcode))
            else:
                # A raw DLTensor is wrapped as a view (no ownership).
                pyargs.append(c_make_array(value.v_handle, True, False))

        rv = local_pyfunc(*pyargs)

        if rv is not None:
            if isinstance(rv, tuple):
                raise ValueError("PackedFunction can only support one return value")
            # temp_args keeps any temporaries created by make_arg alive
            # until TVMCFuncSetReturn has consumed the packed value.
            temp_args = []
            make_arg(rv, &value, &tcode, temp_args)
            CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1))
    except Exception:
        msg = traceback.format_exc()
        msg = py2cerror(msg)
        TVMAPISetLastError(c_str(msg))
        return -1
    return 0


cdef object make_packed_func(TVMPackedFuncHandle chandle, int is_global):
    """Wrap a raw PackedFunc handle in an instance of the registered class.

    The instance is created with ``__new__``, so ``__init__`` is bypassed.
    The handle and the global flag are set directly on the instance.
    """
    func = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
    cdef PackedFuncBase base = <PackedFuncBase>func
    base.chandle = chandle
    base.is_global = is_global
    return func


def convert_to_tvm_func(object pyfunc):
    """Convert a python function to TVM function

    Parameters
    ----------
    pyfunc : python function
        The python function to be converted.

    Returns
    -------
    tvmfunc: tvm.Function
        The converted tvm function.
    """
    cdef TVMPackedFuncHandle new_handle
    cdef void* resource = <void*>(pyfunc)
    # TVM holds this extra reference as the callback's resource;
    # tvm_callback_finalize drops it when TVM releases the function.
    Py_INCREF(pyfunc)
    CALL(TVMFuncCreateFromCFunc(tvm_callback,
                                resource,
                                tvm_callback_finalize,
                                &new_handle))
    return make_packed_func(new_handle, False)


cdef inline int make_arg(object arg,
                         TVMValue* value,
                         int* tcode,
                         list temp_args) except -1:
    """Pack arguments into c args tvm call accept

    Parameters
    ----------
    arg : object
        The Python value to convert.
    value : TVMValue*
        Output slot that receives the packed value.
    tcode : int*
        Output slot that receives the TVM type code.
    temp_args : list
        Temporaries that must outlive the call. For example, the bytes
        object whose buffer ``value.v_str`` points into is appended here so
        the caller can keep it alive.

    Returns
    -------
    0 on success.

    Raises
    ------
    TypeError
        If ``arg`` has no known conversion.
    """
    cdef unsigned long long ptr
    if isinstance(arg, ObjectBase):
        value[0].v_handle = (<ObjectBase>arg).chandle
        tcode[0] = kTVMObjectHandle
    elif isinstance(arg, NDArrayBase):
        value[0].v_handle = (<NDArrayBase>arg).chandle
        # Views are passed as raw DLTensor handles, owning arrays as NDArrays.
        tcode[0] = (kTVMNDArrayHandle if
                    not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
    elif isinstance(arg, _TVM_COMPATS):
        ptr = arg._tvm_handle
        value[0].v_handle = (<void*>ptr)
        tcode[0] = arg.__class__._tvm_tcode
    elif isinstance(arg, (int, long)):
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, float):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, str):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kTVMStr
        temp_args.append(tstr)
    elif arg is None:
        value[0].v_handle = NULL
        tcode[0] = kTVMNullptr
    elif isinstance(arg, Integral):
        # Integral types that are not `int` (e.g. numpy integer scalars)
        # must be passed as integers. Otherwise the generic Number branch
        # below would silently turn them into floats.
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, Number):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, DataType):
        tstr = c_str(str(arg))
        value[0].v_str = tstr
        tcode[0] = kTVMStr
        temp_args.append(tstr)
    elif isinstance(arg, TVMContext):
        value[0].v_ctx = (<DLContext*>(
            <unsigned long long>ctypes.addressof(arg)))[0]
        tcode[0] = kTVMContext
    elif isinstance(arg, bytearray):
        # Expose the bytearray's buffer through a TVMByteArray descriptor;
        # the descriptor is kept alive via temp_args.
        arr = TVMByteArray()
        arr.data = ctypes.cast(
            (ctypes.c_byte * len(arg)).from_buffer(arg),
            ctypes.POINTER(ctypes.c_byte))
        arr.size = len(arg)
        value[0].v_handle = <void*>(
            <unsigned long long>ctypes.addressof(arr))
        tcode[0] = kTVMBytes
        temp_args.append(arr)
    elif isinstance(arg, string_types):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kTVMStr
        temp_args.append(tstr)
    elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
        arg = _FUNC_CONVERT_TO_OBJECT(arg)
        value[0].v_handle = (<ObjectBase>arg).chandle
        tcode[0] = kTVMObjectHandle
        temp_args.append(arg)
    elif isinstance(arg, _CLASS_MODULE):
        value[0].v_handle = c_handle(arg.handle)
        tcode[0] = kTVMModuleHandle
    elif isinstance(arg, PackedFuncBase):
        value[0].v_handle = (<PackedFuncBase>arg).chandle
        tcode[0] = kTVMPackedFuncHandle
    elif isinstance(arg, ctypes.c_void_p):
        value[0].v_handle = c_handle(arg)
        tcode[0] = kTVMOpaqueHandle
    elif callable(arg):
        arg = convert_to_tvm_func(arg)
        value[0].v_handle = (<PackedFuncBase>arg).chandle
        tcode[0] = kTVMPackedFuncHandle
        temp_args.append(arg)
    else:
        raise TypeError("Don't know how to handle type %s" % type(arg))
    return 0


cdef inline bytearray make_ret_bytes(void* chandle):
    """Copy the TVMByteArray behind ``chandle`` into a new Python bytearray."""
    byte_desc = ctypes.cast(ctypes_handle(chandle),
                            ctypes.POINTER(TVMByteArray))[0]
    nbytes = byte_desc.size
    result = bytearray(nbytes)
    # Give memmove a writable ctypes view over the freshly allocated buffer.
    dst_view = (ctypes.c_byte * nbytes).from_buffer(result)
    copied_to = ctypes.memmove(dst_view, byte_desc.data, nbytes)
    if not copied_to:
        raise RuntimeError('memmove failed')
    return result


cdef inline object make_ret(TVMValue value, int tcode):
    """convert result to return value.

    Maps a packed ``TVMValue`` and its type code to the matching Python
    object. Raises ValueError for a type code with no known conversion.
    """
    if tcode == kTVMObjectHandle:
        return make_ret_object(value.v_handle)
    if tcode == kTVMNullptr:
        return None
    if tcode == kInt:
        return value.v_int64
    if tcode == kFloat:
        return value.v_float64
    if tcode == kTVMNDArrayHandle:
        return c_make_array(value.v_handle, False, True)
    if tcode == kTVMStr:
        return py_str(value.v_str)
    if tcode == kTVMBytes:
        return make_ret_bytes(value.v_handle)
    if tcode == kTVMOpaqueHandle:
        return ctypes_handle(value.v_handle)
    if tcode == kTVMContext:
        return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
    if tcode == kTVMModuleHandle:
        return _CLASS_MODULE(ctypes_handle(value.v_handle))
    if tcode == kTVMPackedFuncHandle:
        return make_packed_func(value.v_handle, False)
    # Extension types registered at runtime, looked up by type code.
    if tcode in _TVM_EXT_RET:
        return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
    raise ValueError("Unhandled type code %d" % tcode)


cdef inline int FuncCall3(void* chandle,
                          tuple args,
                          int nargs,
                          TVMValue* ret_val,
                          int* ret_tcode) except -1:
    """Call a PackedFunc with at most three arguments using stack buffers.

    ``nargs`` is recomputed from ``args``. Callers must make sure
    ``len(args) <= 3``, because the argument buffers hold exactly three
    entries.
    """
    cdef TVMValue[3] arg_values
    cdef int[3] arg_codes
    cdef int idx
    nargs = len(args)
    # Keeps converted temporaries alive until TVMFuncCall returns.
    keep_alive = []
    for idx in range(nargs):
        make_arg(args[idx], &arg_values[idx], &arg_codes[idx], keep_alive)
    CALL(TVMFuncCall(chandle, &arg_values[0], &arg_codes[0],
                     nargs, ret_val, ret_tcode))
    return 0

cdef inline int FuncCall(void* chandle,
                         tuple args,
                         TVMValue* ret_val,
                         int* ret_tcode) except -1:
    """Call the PackedFunc ``chandle`` with the Python values in ``args``.

    The result is written to ``ret_val`` / ``ret_tcode``. Calls with three
    or fewer arguments go to ``FuncCall3`` (fixed-size stack arrays).
    Larger calls use heap-backed ``vector`` buffers.
    """
    cdef int nargs = len(args)
    cdef vector[TVMValue] arg_values
    cdef vector[int] arg_codes
    cdef int idx
    if nargs <= 3:
        FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
        return 0

    # nargs >= 4 here, so the buffers are never empty.
    arg_values.resize(nargs)
    arg_codes.resize(nargs)
    # Keeps converted temporaries alive until TVMFuncCall returns.
    keep_alive = []
    for idx in range(nargs):
        make_arg(args[idx], &arg_values[idx], &arg_codes[idx], keep_alive)
    CALL(TVMFuncCall(chandle, &arg_values[0], &arg_codes[0],
                     nargs, ret_val, ret_tcode))
    return 0


cdef inline int ConstructorCall(void* constructor_handle,
                                int type_code,
                                tuple args,
                                void** handle) except -1:
    """Call constructor of a handle function

    Invokes ``constructor_handle`` with ``args`` and stores the returned
    handle in ``handle[0]``.

    Raises
    ------
    ValueError
        If the constructor returns a type code other than ``type_code``.
        This is checked explicitly rather than with ``assert``, because an
        assert can be compiled out. The wrong handle type would then be used
        silently.
    """
    cdef TVMValue ret_val
    cdef int ret_tcode
    FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
    if ret_tcode != type_code:
        raise ValueError(
            "Constructor returned type code %d, expected %d"
            % (ret_tcode, type_code))
    handle[0] = ret_val.v_handle
    return 0


cdef class PackedFuncBase:
    """Cython base class that owns (or borrows) a TVM PackedFunc handle.

    ``chandle`` is the raw handle. ``c_is_global`` marks handles obtained
    from the global registry. Those handles are not freed in
    ``__dealloc__``.
    """
    cdef TVMPackedFuncHandle chandle
    # C storage behind the ``is_global`` property. It must not share the
    # property's name, otherwise the attribute is redeclared.
    cdef int c_is_global

    cdef inline _set_handle(self, handle):
        if handle is None:
            self.chandle = NULL
        else:
            self.chandle = c_handle(handle)

    property is_global:
        def __get__(self):
            return self.c_is_global != 0

        def __set__(self, value):
            self.c_is_global = value

    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_global):
        self._set_handle(handle)
        self.c_is_global = is_global

    def __dealloc__(self):
        # Read the C field directly: Python-level attribute access (such as
        # the is_global property) should be avoided during deallocation.
        # A NULL handle owns nothing, so there is nothing to free.
        if self.c_is_global == 0 and self.chandle != NULL:
            CALL(TVMFuncFree(self.chandle))

    def __call__(self, *args):
        cdef TVMValue ret_val
        cdef int ret_tcode
        FuncCall(self.chandle, args, &ret_val, &ret_tcode)
        return make_ret(ret_val, ret_tcode)


def _get_global_func(name, allow_missing):
    """Look up the globally registered PackedFunc called ``name``.

    Returns the wrapped function if it exists. Otherwise returns None when
    ``allow_missing`` is true and raises ValueError when it is not.
    """
    cdef TVMPackedFuncHandle found
    CALL(TVMFuncGetGlobal(c_str(name), &found))
    if found == NULL:
        if allow_missing:
            return None
        raise ValueError("Cannot find global function %s" % name)
    return make_packed_func(found, True)


# Python-level classes and helpers are injected by higher-level modules
# through the _set_* functions below. Until then they are None.
_CLASS_PACKED_FUNC = _CLASS_MODULE = _CLASS_OBJECT = None
_CLASS_OBJECT_GENERIC = _FUNC_CONVERT_TO_OBJECT = None

def _set_class_module(module_class):
    """Initialize the module."""
    global _CLASS_MODULE
    _CLASS_MODULE = module_class

def _set_class_packed_func(func_class):
    """Register the class used to wrap PackedFunc handles."""
    global _CLASS_PACKED_FUNC
    _CLASS_PACKED_FUNC = func_class

def _set_class_object(obj_class):
    """Register the base Object class."""
    global _CLASS_OBJECT
    _CLASS_OBJECT = obj_class

def _set_class_object_generic(object_generic_class, func_convert_to_object):
    """Register the ObjectGeneric class and its conversion function."""
    global _CLASS_OBJECT_GENERIC, _FUNC_CONVERT_TO_OBJECT
    _CLASS_OBJECT_GENERIC = object_generic_class
    _FUNC_CONVERT_TO_OBJECT = func_convert_to_object