# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""Type nodes of the Relay language: tensor, tuple, function, type-variable
and incomplete types, plus type constraints such as type relations."""
from enum import IntEnum

from .base import RelayNode, register_relay_node
from . import _make


class Type(RelayNode):
    """Common base class of every Relay type.

    ``==`` and ``!=`` compare types structurally (alpha equivalence);
    :meth:`same_as` compares them by reference.
    """

    # NOTE: defining __eq__ without __hash__ makes Python set
    # Type.__hash__ to None, so Type instances are unhashable. A
    # reference-based hash would be inconsistent with this structural __eq__.
    def __eq__(self, other):
        """Return True when ``self`` and ``other`` are alpha-equivalent."""
        alpha_equal = _make._type_alpha_equal(self, other)
        return bool(alpha_equal)

    def __ne__(self, other):
        equal = self.__eq__(other)
        return not equal

    def same_as(self, other):
        """Compare by referential equality, via the base class ``__eq__``."""
        return super(Type, self).__eq__(other)


@register_relay_node
class TensorType(Type):
    """A tensor type with a known dtype and shape.

    For example, a ``float32`` tensor of shape ``(5, 5)``.

    Parameters
    ----------
    shape : List[tvm.Expr]
        The shape of the tensor.

    dtype : Optional[str]
        The element data type, "float32" when omitted.

    Returns
    -------
    tensor_type : tvm.relay.TensorType
        The tensor type.
    """

    def __init__(self, shape, dtype="float32"):
        self.__init_handle_by_constructor__(_make.TensorType, shape, dtype)

    @property
    def concrete_shape(self):
        """The shape of this type as a tuple of Python ints.

        Returns
        -------
        shape : Tuple[int, ...]
            Every dimension of ``self.shape`` converted with ``int()``.

        Raises
        ------
        TypeError
            If the shape is symbolic (presumably raised by ``int()`` on a
            non-constant dimension).
        """
        return tuple(map(int, self.shape))


class Kind(IntEnum):
    """Kind of a type parameter, i.e. what it may be instantiated with.

    A kind describes a type, a shape variable, a base type or a shape.
    For example, a parameter of kind ``BaseType`` may only be instantiated
    with base types such as ``float32`` or ``int32``.
    """
    Type = 0
    ShapeVar = 1
    BaseType = 2
    Shape = 3


@register_relay_node
class TypeVar(Type):
    """A type variable for writing type-generic Relay functions.

    It is a placeholder that gets filled in with a concrete type later on.
    See tvm/relay/type.h for more details.
    """

    def __init__(self, var, kind=Kind.Type):
        """Build a type variable backed by ``var``.

        Parameters
        ----------
        var : tvm.expr.Var
            The tvm.Var backing the type parameter.

        kind : Optional[Kind]
            The kind of the type parameter, Kind.Type when omitted.

        Returns
        -------
        type_var : tvm.relay.TypeVar
            The type variable.
        """
        self.__init_handle_by_constructor__(_make.TypeVar, var, kind)


@register_relay_node
class TypeConstraint(Type):
    """Abstract base class of type constraints (e.g. TypeRelation)."""


@register_relay_node
class TupleType(Type):
    """A tuple type, recording the type of each field.

    See tvm/relay/type.h for more details.
    """

    def __init__(self, fields):
        """Build a tuple type from its field types.

        Parameters
        ----------
        fields : List[tvm.relay.Type]
            The type of each tuple field, in order.

        Returns
        -------
        tuple_type : tvm.relay.TupleType
            The tuple type.
        """
        self.__init_handle_by_constructor__(_make.TupleType, fields)


@register_relay_node
class FuncType(Type):
    """The type assigned to Relay functions; see tvm/relay/type.h.

    A function type has four parts:

    - a list of type parameters, which allows generic functions;
    - a set of type constraints (omitted for the time being);
    - a sequence of argument types;
    - a return type.

    Informally it is written as:

    `forall (type_params), (arg_types) -> ret_type where type_constraints`

    Parameters
    ----------
    arg_types : List[tvm.relay.Type]
        The argument types.

    ret_type : tvm.relay.Type
        The return type.

    type_params : Optional[List[tvm.relay.TypeVar]]
        The type parameters; an empty list when omitted.

    type_constraints : Optional[List[tvm.relay.TypeConstraint]]
        The type constraints; an empty list when omitted.
    """

    def __init__(self,
                 arg_types,
                 ret_type,
                 type_params=None,
                 type_constraints=None):
        # Fresh lists per call, so no mutable default is shared between calls.
        type_params = [] if type_params is None else type_params
        type_constraints = [] if type_constraints is None else type_constraints
        self.__init_handle_by_constructor__(
            _make.FuncType,
            arg_types,
            ret_type,
            type_params,
            type_constraints,
        )


@register_relay_node
class IncompleteType(Type):
    """A type that is not yet known, tagged with a kind (Kind.Type by default)."""

    def __init__(self, kind=Kind.Type):
        self.__init_handle_by_constructor__(_make.IncompleteType, kind)


@register_relay_node
class TypeRelation(TypeConstraint):
    """A type relation in Relay.

    Parameters
    ----------
    func : EnvFunc
        The user-defined relation function.

    args : [tvm.relay.Type]
        The types passed to ``func``.

    num_inputs : int
        How many entries of ``args`` are inputs; this acts as a hint for
        type inference.

    attrs : Attrs
        The attributes attached to the relation.

    Returns
    -------
    type_relation : tvm.relay.TypeRelation
        The type relation.
    """

    def __init__(self, func, args, num_inputs, attrs):
        self.__init_handle_by_constructor__(
            _make.TypeRelation, func, args, num_inputs, attrs)


def scalar_type(dtype):
    """Create a scalar (rank-0) tensor type, i.e. ``TensorType((), dtype)``.

    Parameters
    ----------
    dtype : str
        The element data type.

    Returns
    -------
    s_type : tvm.relay.TensorType
        The scalar tensor type.
    """
    empty_shape = ()
    return TensorType(empty_shape, dtype)