import ctypes
import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray


cdef void tvm_callback_finalize(void* fhandle):
    local_pyfunc = <object>(fhandle)
    Py_DECREF(local_pyfunc)

cdef int tvm_callback(TVMValue* args,
                      int* type_codes,
                      int num_args,
                      TVMRetValueHandle ret,
                      void* fhandle) with gil:
    cdef list pyargs
    cdef TVMValue value
    cdef int tcode
    local_pyfunc = <object>(fhandle)
    pyargs = []
    for i in range(num_args):
        value = args[i]
        tcode = type_codes[i]
        if (tcode == kNodeHandle or
            tcode == kFuncHandle or
            tcode == kModuleHandle or
            tcode > kExtBegin):
            CALL(TVMCbArgToReturn(&value, tcode))

        if tcode != kArrayHandle:
            pyargs.append(make_ret(value, tcode))
        else:
            pyargs.append(c_make_array(value.v_handle, True))
    try:
        rv = local_pyfunc(*pyargs)
    except Exception:
        msg = traceback.format_exc()
        TVMAPISetLastError(c_str(msg))
        return -1
    if rv is not None:
        if isinstance(rv, tuple):
            raise ValueError("PackedFunction can only support one return value")
        temp_args = []
        make_arg(rv, &value, &tcode, temp_args)
        CALL(TVMCFuncSetReturn(ret, &value, &tcode, 1))
    return 0


def convert_to_tvm_func(object pyfunc):
    """Convert a python function to TVM function

    Parameters
    ----------
    pyfunc : python function
        The python function to be converted.

    Returns
    -------
    tvmfunc: tvm.Function
        The converted tvm function.
    """
    cdef TVMFunctionHandle chandle
    Py_INCREF(pyfunc)
    CALL(TVMFuncCreateFromCFunc(tvm_callback,
                                <void*>(pyfunc),
                                tvm_callback_finalize,
                                &chandle))
    ret = _CLASS_FUNCTION(None, False)
    (<FunctionBase>ret).chandle = chandle
    return ret


cdef inline int make_arg(object arg,
                         TVMValue* value,
                         int* tcode,
                         list temp_args) except -1:
    """Pack arguments into c args tvm call accept"""
    cdef unsigned long long ptr
    if isinstance(arg, NodeBase):
        value[0].v_handle = (<NodeBase>arg).chandle
        tcode[0] = kNodeHandle
    elif isinstance(arg, NDArrayBase):
        value[0].v_handle = (<NDArrayBase>arg).chandle
        tcode[0] = kArrayHandle
    elif isinstance(arg, _TVM_COMPATS):
        ptr = arg._tvm_handle
        value[0].v_handle = (<void*>ptr)
        tcode[0] = arg.__class__._tvm_tcode
    elif isinstance(arg, (int, long)):
        value[0].v_int64 = arg
        tcode[0] = kInt
    elif isinstance(arg, float):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, str):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif arg is None:
        value[0].v_handle = NULL
        tcode[0] = kNull
    elif isinstance(arg, Number):
        value[0].v_float64 = arg
        tcode[0] = kFloat
    elif isinstance(arg, TVMType):
        tstr = c_str(str(arg))
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif isinstance(arg, TVMContext):
        value[0].v_ctx = (<DLContext*>(
            <unsigned long long>ctypes.addressof(arg)))[0]
        tcode[0] = kTVMContext
    elif isinstance(arg, bytearray):
        arr = TVMByteArray()
        arr.data = ctypes.cast(
            (ctypes.c_byte * len(arg)).from_buffer(arg),
            ctypes.POINTER(ctypes.c_byte))
        arr.size = len(arg)
        value[0].v_handle = <void*>(
            <unsigned long long>ctypes.addressof(arr))
        tcode[0] = kBytes
        temp_args.append(arr)
    elif isinstance(arg, string_types):
        tstr = c_str(arg)
        value[0].v_str = tstr
        tcode[0] = kStr
        temp_args.append(tstr)
    elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
        arg = convert_to_node(arg)
        value[0].v_handle = (<NodeBase>arg).chandle
        tcode[0] = kNodeHandle
        temp_args.append(arg)
    elif isinstance(arg, _CLASS_MODULE):
        value[0].v_handle = c_handle(arg.handle)
        tcode[0] = kModuleHandle
    elif isinstance(arg, FunctionBase):
        value[0].v_handle = (<FunctionBase>arg).chandle
        tcode[0] = kFuncHandle
    elif isinstance(arg, ctypes.c_void_p):
        value[0].v_handle = c_handle(arg)
        tcode[0] = kHandle
    elif callable(arg):
        arg = convert_to_tvm_func(arg)
        value[0].v_handle = (<FunctionBase>arg).chandle
        tcode[0] = kFuncHandle
        temp_args.append(arg)
    else:
        raise TypeError("Don't know how to handle type %s" % type(arg))
    return 0

cdef inline bytearray make_ret_bytes(void* chandle):
    handle = ctypes_handle(chandle)
    arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
    size = arr.size
    res = bytearray(size)
    rptr = (ctypes.c_byte * size).from_buffer(res)
    if not ctypes.memmove(rptr, arr.data, size):
        raise RuntimeError('memmove failed')
    return res

cdef inline object make_ret(TVMValue value, int tcode):
    """convert result to return value."""
    if tcode == kNodeHandle:
        return make_ret_node(value.v_handle)
    elif tcode == kNull:
        return None
    elif tcode == kInt:
        return value.v_int64
    elif tcode == kFloat:
        return value.v_float64
    elif tcode == kStr:
        return py_str(value.v_str)
    elif tcode == kBytes:
        return make_ret_bytes(value.v_handle)
    elif tcode == kHandle:
        return ctypes_handle(value.v_handle)
    elif tcode == kTVMContext:
        return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
    elif tcode == kModuleHandle:
        return _CLASS_MODULE(ctypes_handle(value.v_handle))
    elif tcode == kFuncHandle:
        fobj = _CLASS_FUNCTION(None, False)
        (<FunctionBase>fobj).chandle = value.v_handle
        return fobj
    elif tcode in _TVM_EXT_RET:
        return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))

    raise ValueError("Unhandled type code %d" % tcode)


cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
    cdef TVMValue[3] values
    cdef int[3] tcodes
    cdef TVMValue ret_val
    cdef int ret_code
    nargs = len(args)
    temp_args = []
    for i in range(nargs):
        make_arg(args[i], &values[i], &tcodes[i], temp_args)
    CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
                     nargs, &ret_val, &ret_code))
    return make_ret(ret_val, ret_code)

cdef inline object FuncCall(void* chandle, tuple args):
    cdef int nargs
    nargs = len(args)
    if nargs <= 3:
        return FuncCall3(chandle, args, nargs)

    cdef vector[TVMValue] values
    cdef vector[int] tcodes
    cdef TVMValue ret_val
    cdef int ret_code
    values.resize(max(nargs, 1))
    tcodes.resize(max(nargs, 1))
    temp_args = []
    for i in range(nargs):
        make_arg(args[i], &values[i], &tcodes[i], temp_args)
    CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
                     nargs, &ret_val, &ret_code))
    return make_ret(ret_val, ret_code)


cdef class FunctionBase:
    cdef TVMFunctionHandle chandle
    cdef int is_global

    cdef inline _set_handle(self, handle):
        if handle is None:
            self.chandle = NULL
        else:
            self.chandle = c_handle(handle)

    property is_global:
        def __get__(self):
            return self.c_is_global != 0

        def __set__(self, value):
            self.c_is_global = value

    property handle:
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle, is_global):
        self._set_handle(handle)
        self.c_is_global = is_global

    def __dealloc__(self):
        if self.is_global == 0:
            CALL(TVMFuncFree(self.chandle))

    def __call__(self, *args):
        return FuncCall(self.chandle, args)

_CLASS_FUNCTION = None
_CLASS_MODULE = None

def _set_class_module(module_class):
    """Initialize the module."""
    global _CLASS_MODULE
    _CLASS_MODULE = module_class

def _set_class_function(func_class):
    global _CLASS_FUNCTION
    _CLASS_FUNCTION = func_class