# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import NodeBase, register_relay_node
from . import _make


class Type(NodeBase):
    """Common base class of every Relay type node.

    ``==`` and ``!=`` on types are structural (alpha equivalence);
    use :meth:`same_as` when a by-reference comparison is wanted.
    """

    # NOTE(review): because ``__eq__`` is overridden here and no ``__hash__``
    # is defined in this class, Python 3 sets ``Type.__hash__`` to ``None``,
    # so instances are unhashable unless something else restores a hash.
    # Confirm that this is intended.

    def __eq__(self, other):
        """Return True when `self` and `other` are alpha-equivalent types.

        The comparison is delegated to ``_make._type_alpha_eq`` and its
        result is coerced to a plain ``bool``.
        """
        alpha_equal = _make._type_alpha_eq(self, other)
        return bool(alpha_equal)

    def __ne__(self, other):
        """Return the logical negation of :meth:`__eq__`."""
        return not self.__eq__(other)

    def same_as(self, other):
        """Compare two Relay types by reference instead of by structure.

        This forwards to the parent class's ``__eq__``, bypassing the
        alpha-equivalence override defined on this class.
        """
        return super().__eq__(other)


@register_relay_node
class TensorType(Type):
    """A concrete tensor type in Relay; see tvm/relay/type.h for details.

    This is the type given to tensors whose dtype and shape are both
    known, for example a tensor of `float32` with shape `(5, 5)`.
    """

    def __init__(self, shape, dtype):
        """Build a tensor type from a shape and an element dtype.

        Parameters
        ----------
        shape : list of tvm.Expr
            The shape of the tensor.

        dtype : str
            The dtype of the tensor's elements.
        """
        make_tensor_type = _make.TensorType
        self.__init_handle_by_constructor__(make_tensor_type, shape, dtype)


class Kind(IntEnum):
    """The kind of a type parameter.

    A kind says whether a type parameter stands for a variable shape,
    a base type, a type, or a dimension, and therefore restricts what
    the parameter may be instantiated with. For example, parameters of
    kind ``BaseType`` can only be `float32`, `int32`, and so on.
    """

    # The integer values are spelled out explicitly; keep them stable.
    ShapeVar = 0
    Shape = 1
    BaseType = 2
    Type = 3


@register_relay_node
class TypeParam(Type):
    """A type parameter for generic types in Relay;
    see tvm/relay/type.h for details.

    A type parameter is a placeholder for a type that gets filled in
    later, which lets users write functions that are generic over types.
    """

    def __init__(self, var, kind):
        """Build a type parameter from its backing variable and its kind.

        Parameters
        ----------
        var : tvm.expr.Var
            The tvm.Var backing this type parameter.

        kind : Kind
            The kind of this type parameter.
        """
        make_type_param = _make.TypeParam
        self.__init_handle_by_constructor__(make_type_param, var, kind)


@register_relay_node
class TypeConstraint(Type):
    """Abstract base class for type constraints.

    It adds nothing of its own on top of ``Type``.
    """


@register_relay_node
class TupleType(Type):
    """A tuple type in Relay; see tvm/relay/type.h for details.

    A tuple type lists the type of every field of the tuple.
    """

    def __init__(self, fields):
        """Build a tuple type from the types of its fields.

        Parameters
        ----------
        fields : list of Type
            The type of each field in the tuple.
        """
        make_tuple_type = _make.TupleType
        self.__init_handle_by_constructor__(make_tuple_type, fields)


@register_relay_node
class FuncType(Type):
    """A function type in Relay, see tvm/relay/type.h for more details.

    This is the type assigned to functions in Relay. They consist of
    a list of type parameters which enable the definition of generic
    functions, a set of type constraints which we omit for the time
    being, a sequence of argument types, and a return type.

    We informally write them as:
    `forall (type_params), (arg_types) -> ret_type where type_constraints`
    """

    def __init__(self,
                 arg_types,
                 ret_type,
                 type_params=None,
                 type_constraints=None):
        """Construct a function type.

        Parameters
        ----------
        arg_types : list of Type
            The types of the function's arguments.

        ret_type : Type
            The function's return type.

        type_params : list of TypeParam, optional
            The type parameters of a generic function. When omitted
            (or None) an empty list is used.

        type_constraints : list of TypeConstraint, optional
            The type constraints of the function. When omitted
            (or None) an empty list is used.
        """
        # Use None as the default and build a fresh list on each call so
        # that no mutable default object is shared between instances.
        if type_params is None:
            type_params = []
        if type_constraints is None:
            type_constraints = []

        self.__init_handle_by_constructor__(
            _make.FuncType, arg_types, ret_type, type_params, type_constraints)


@register_relay_node
class IncompleteType(Type):
    """An incomplete type."""

    def __init__(self, kind=Kind.Type):
        """Build an incomplete type of the given kind.

        Parameters
        ----------
        kind : Kind, optional
            The kind forwarded to the underlying constructor;
            defaults to ``Kind.Type``.
        """
        make_incomplete_type = _make.IncompleteType
        self.__init_handle_by_constructor__(make_incomplete_type, kind)

@register_relay_node
class TypeRelation(TypeConstraint):
    """A type relation in Relay.

    Parameters
    ----------
    func : EnvFunc
        The user-defined relation function.

    args : list of types
        The types handed to `func`.

    num_inputs : int
        How many entries of `args` are input arguments;
        this acts as a hint for type inference.

    attrs : Attrs
        The attributes attached to the relation.
    """

    def __init__(self, func, args, num_inputs, attrs):
        make_type_relation = _make.TypeRelation
        self.__init_handle_by_constructor__(
            make_type_relation, func, args, num_inputs, attrs)