# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unified type system in the project."""
from enum import IntEnum
import tvm._ffi

from .base import Node
from . import _ffi_api


class Type(Node):
    """The base class of all types."""
    def __eq__(self, other):
        """Compare two types for structural equivalence."""
        return bool(_ffi_api.type_alpha_equal(self, other))

    def __ne__(self, other):
        return not self.__eq__(other)

    def same_as(self, other):
        """Compares two Relay types by referential equality."""
        return super().__eq__(other)


class TypeKind(IntEnum):
    """Possible kinds of TypeVars."""
    Type = 0
    ShapeVar = 1
    BaseType = 2
    Constraint = 4
    AdtHandle = 5
    TypeData = 6


@tvm._ffi.register_object("relay.TypeVar")
class TypeVar(Type):
    """Type parameter in functions.

    A type variable represents a type placeholder which will
    be filled in later on. This allows the user to write
    functions which are generic over types.

    Parameters
    ----------
    name_hint: str
        The name of the type variable. This name only acts as a hint, and
        is not used for equality.

    kind : Optional[TypeKind]
        The kind of the type parameter.
    """
    def __init__(self, name_hint, kind=TypeKind.Type):
        self.__init_handle_by_constructor__(
            _ffi_api.TypeVar, name_hint, kind)

    def __call__(self, *args):
        """Create a type call from this type.

        Parameters
        ----------
        args: List[Type]
            The arguments to the type call.

        Returns
        -------
        call: Type
            The result type call.
        """
        # pylint: disable=import-outside-toplevel
        from .type_relation import TypeCall
        return TypeCall(self, args)


@tvm._ffi.register_object("relay.GlobalTypeVar")
class GlobalTypeVar(Type):
    """A global type variable that is used for defining new types or type aliases.

    Parameters
    ----------
    name_hint: str
        The name of the type variable. This name only acts as a hint, and
        is not used for equality.

    kind : Optional[TypeKind]
        The kind of the type parameter.
    """
    def __init__(self, name_hint, kind=TypeKind.AdtHandle):
        self.__init_handle_by_constructor__(
            _ffi_api.GlobalTypeVar, name_hint, kind)

    def __call__(self, *args):
        """Create a type call from this type.

        Parameters
        ----------
        args: List[Type]
            The arguments to the type call.

        Returns
        -------
        call: Type
            The result type call.
        """
        # pylint: disable=import-outside-toplevel
        from .type_relation import TypeCall
        return TypeCall(self, args)


@tvm._ffi.register_object("relay.TupleType")
class TupleType(Type):
    """The type of tuple values.

    Parameters
    ----------
    fields : List[Type]
        The fields in the tuple
    """

    def __init__(self, fields):
        self.__init_handle_by_constructor__(
            _ffi_api.TupleType, fields)


@tvm._ffi.register_object("relay.TypeConstraint")
class TypeConstraint(Type):
    """Abstract class representing a type constraint."""


@tvm._ffi.register_object("relay.FuncType")
class FuncType(Type):
    """Function type.

    A function type consists of a list of type parameters to enable
    the definition of generic functions,
    a set of type constraints which we omit for the time being,
    a sequence of argument types, and a return type.

    We can informally write them as:
    `forall (type_params), (arg_types) -> ret_type where type_constraints`

    Parameters
    ----------
    arg_types : List[tvm.relay.Type]
        The argument types

    ret_type : tvm.relay.Type
        The return type.

    type_params : Optional[List[tvm.relay.TypeVar]]
        The type parameters

    type_constraints : Optional[List[tvm.relay.TypeConstraint]]
        The type constraints.
    """
    def __init__(self,
                 arg_types,
                 ret_type,
                 type_params=None,
                 type_constraints=None):
        if type_params is None:
            type_params = []
        if type_constraints is None:
            type_constraints = []
        self.__init_handle_by_constructor__(
            _ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints)


@tvm._ffi.register_object("relay.IncompleteType")
class IncompleteType(Type):
    """Incomplete type during type inference.

    kind : Optional[TypeKind]
        The kind of the incomplete type.
    """
    def __init__(self, kind=TypeKind.Type):
        self.__init_handle_by_constructor__(
            _ffi_api.IncompleteType, kind)


@tvm._ffi.register_object("relay.RefType")
class RelayRefType(Type):
    """Reference Type in relay.

    Parameters
    ----------
    value: Type
        The value type.
    """
    def __init__(self, value):
        self.__init_handle_by_constructor__(_ffi_api.RelayRefType, value)