# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from ... import _api_internal
from ..base import string_types
from ..node_generic import _set_class_node_base

# Maps node type to its constructor.
# The list is indexed by node type index; slots with no registered class
# hold None (see _register_node, which pads the list on demand).
NODE_TYPE = []

def _register_node(int index, object cls):
    """Record ``cls`` as the constructor for node type ``index``.

    ``NODE_TYPE`` is padded with ``None`` placeholders whenever ``index``
    lies beyond its current end, so the assignment below always targets an
    existing slot for non-negative indices.
    """
    table = NODE_TYPE
    if index >= len(table):
        # Grow in one step up to and including position ``index``.
        table.extend([None] * (index - len(table) + 1))
    table[index] = cls


cdef inline object make_ret_node(void* chandle):
    """Wrap a raw node handle in the Python class registered for its type.

    The type index of ``chandle`` is queried through ``TVMNodeGetTypeIndex``
    and looked up in ``NODE_TYPE``.  When the index is outside the table, or
    its slot is ``None``, a plain ``NodeBase`` is created instead.  The new
    object is built with ``__new__`` only (``__init__`` is not called) and
    ``chandle`` is stored directly in its ``chandle`` field.
    """
    cdef int type_index
    cdef list registry = NODE_TYPE
    cdef object node_cls = None
    CALL(TVMNodeGetTypeIndex(chandle, &type_index))
    if type_index < len(registry):
        node_cls = registry[type_index]
    if node_cls is None:
        # Unregistered type index: fall back to the generic wrapper.
        node = NodeBase.__new__(NodeBase)
    else:
        node = node_cls.__new__(node_cls)
    # Unchecked cast: registered classes are expected to derive from NodeBase.
    (<NodeBase>node).chandle = chandle
    return node


cdef class NodeBase:
    """Cython-level wrapper around a raw node handle.

    The handle is kept in the C field ``chandle``.  It is passed to
    ``TVMNodeFree`` when the wrapper is deallocated, and ``__getattr__``
    resolves attribute names through ``TVMNodeGetAttr``.
    """
    cdef void* chandle

    cdef _set_handle(self, handle):
        """Store ``handle`` as the raw C pointer ``chandle``.

        Parameters
        ----------
        handle : object or None
            ``None``, or an object whose ``value`` attribute is the pointer
            address as an integer (a ctypes ``c_void_p`` fits this shape).
            A ``value`` of ``None`` -- which is how ctypes reports a null
            ``c_void_p`` -- is stored as ``NULL`` instead of failing the
            integer conversion with ``TypeError``.

        Note
        ----
        NOTE(review): the previous ``chandle`` is overwritten without being
        freed here; confirm callers only assign handles to fresh wrappers.
        """
        cdef unsigned long long ptr
        if handle is None:
            self.chandle = NULL
            return
        address = handle.value
        if address is None:
            # Null pointer object: treat exactly like ``handle is None``.
            self.chandle = NULL
        else:
            ptr = address
            self.chandle = <void*>(ptr)

    property handle:
        """The wrapped handle as a ctypes-style object, or None when NULL."""

        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return ctypes_handle(self.chandle)

        def __set__(self, value):
            self._set_handle(value)

    def __dealloc__(self):
        # Hand the stored handle (possibly NULL) back to the C API.
        CALL(TVMNodeFree(self.chandle))

    def __getattr__(self, name):
        # Resolve ``name`` through TVMNodeGetAttr.  ``ret_succ == 0`` means
        # the node reported no such attribute, which is surfaced as the
        # standard AttributeError so hasattr()/getattr() defaults work.
        cdef TVMValue ret_val
        cdef int ret_type_code, ret_succ
        CALL(TVMNodeGetAttr(self.chandle, c_str(name),
                            &ret_val, &ret_type_code, &ret_succ))
        if ret_succ == 0:
            raise AttributeError(
                "'%s' object has no attribute '%s'" % (type(self), name))
        return make_ret(ret_val, ret_type_code)

    def __init_handle_by_constructor__(self, fconstructor, *args):
        """Initialize the handle by calling constructor function.

        Parameters
        ----------
        fconstructor : Function
            Constructor function.

        args: list of objects
            The arguments to the constructor

        Note
        ----
        We have a special calling convention to call constructor functions.
        So the return handle is directly set into the Node object
        instead of creating a new Node.
        """
        cdef void* chandle
        # Keep the stored handle NULL while the constructor runs, so that an
        # error raised during construction leaves this object holding NULL.
        self.chandle = NULL
        ConstructorCall(
            (<FunctionBase>fconstructor).chandle,
            kNodeHandle, args, &chandle)
        self.chandle = chandle

# Runs at import time: hands this Cython NodeBase class to the
# _set_class_node_base hook imported from ..node_generic.
# NOTE(review): presumably this lets node_generic recognise NodeBase
# instances -- confirm against node_generic.
_set_class_node_base(NodeBase)