# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import RelayNode, register_relay_node
from . import _make


class Type(RelayNode):
    """The base type for all Relay types."""

    # NOTE: this class defines __eq__ without __hash__, so Python sets
    # Type.__hash__ to None and Relay types are unhashable. That matches the
    # structural (alpha-equivalence) equality below: an identity-based hash
    # would violate the hash/eq contract for distinct but alpha-equal types.

    def __eq__(self, other):
        """Compare two Relay types for structural equivalence using
           alpha equivalence.

        Parameters
        ----------
        other : object
            The value to compare against.

        Returns
        -------
        result : bool or NotImplemented
            True if ``other`` is a Type alpha-equal to this one, False if it
            is a Type that is not. For non-Type operands ``NotImplemented``
            is returned so Python falls back to its default comparison
            instead of handing a foreign object to the alpha-equality
            routine (which presumably expects two Type nodes).
        """
        if not isinstance(other, Type):
            return NotImplemented
        return bool(_make._type_alpha_equal(self, other))

    def __ne__(self, other):
        """Negation of ``__eq__`` that propagates ``NotImplemented``."""
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal

    def same_as(self, other):
        """Compares two Relay types by referential equality."""
        return super().__eq__(other)


@register_relay_node
class TensorType(Type):
    """A concrete TensorType in Relay.

    This is the type assigned to tensors with a known dtype and shape. For
    example, a tensor of `float32` and `(5, 5)`.

    Parameters
    ----------
    shape : List[tvm.Expr]
        The shape of the Tensor

    dtype : Optional[str]
        The content data type.
        Default to "float32".

    Returns
    -------
    tensor_type : tvm.relay.TensorType
        The tensor type.
    """
    def __init__(self, shape, dtype="float32"):
        self.__init_handle_by_constructor__(
            _make.TensorType, shape, dtype)

    @property
    def concrete_shape(self):
        """Get shape of the type as concrete tuple of int.

        Returns
        -------
        shape : Tuple[int, ...]
            The concrete shape of the Type, one int per dimension.

        Raises
        ------
        TypeError : If the shape is symbolic
        """
        # Each dimension is converted with int(); a non-constant dimension
        # cannot be converted, which is why symbolic shapes raise.
        return tuple(int(x) for x in self.shape)


class Kind(IntEnum):
    """The kind of a type parameter, representing a variable shape,
       base type, type, or dimension.

       This controls what a type parameter is allowed to be instantiated
       with. For example, parameters of kind BaseType can only be
       `float32`, `int32`, and so on.
    """
    # NOTE(review): these integer values are handed to the C++ side (e.g. by
    # the TypeVar and IncompleteType constructors) and presumably must mirror
    # the Kind enum in tvm/relay/type.h -- keep them in sync (TODO confirm).
    Type = 0
    ShapeVar = 1
    BaseType = 2
    Shape = 3


@register_relay_node
class TypeVar(Type):
    """A type variable used for generic types in Relay,
    see tvm/relay/type.h for more details.

    A type variable is a placeholder for a type that is filled in
    later, which lets users write functions that are generic over
    types.
    """

    def __init__(self, var, kind=Kind.Type):
        """Construct a TypeVar.

        Parameters
        ----------
        var : tvm.expr.Var
            The tvm.Var which backs the type parameter.

        kind : Optional[Kind]
            The kind of the type parameter.
            Default to Kind.Type.

        Returns
        -------
        type_var : tvm.relay.TypeVar
            The type variable.
        """
        constructor = _make.TypeVar
        self.__init_handle_by_constructor__(constructor, var, kind)


@register_relay_node
class TypeConstraint(Type):
    """Abstract class representing a type constraint.

    Abstractness is by convention only (no abstract methods are enforced);
    the class adds no behavior of its own and serves as the base class of
    constraint nodes such as TypeRelation.
    """


@register_relay_node
class TupleType(Type):
    """A tuple type in Relay, see tvm/relay/type.h for more details.

    Records the type of each field of the tuple.
    """

    def __init__(self, fields):
        """Construct a tuple type.

        Parameters
        ----------
        fields : List[tvm.relay.Type]
            The type of each field in the tuple.

        Returns
        -------
        tuple_type : tvm.relay.TupleType
            The tuple type.
        """
        make_tuple = _make.TupleType
        self.__init_handle_by_constructor__(make_tuple, fields)


@register_relay_node
class FuncType(Type):
    """A function type in Relay, see tvm/relay/type.h for more details.

    This is the type assigned to functions in Relay. It consists of
    a list of type parameters, which enable the definition of generic
    functions, a set of type constraints (which we omit for the time
    being), a sequence of argument types, and a return type.

    We informally write them as:
    `forall (type_params), (arg_types) -> ret_type where type_constraints`

    Parameters
    ----------
    arg_types : List[tvm.relay.Type]
        The argument types

    ret_type : tvm.relay.Type
        The return type.

    type_params : Optional[List[tvm.relay.TypeVar]]
        The type parameters. ``None`` (the default) means an empty list.

    type_constraints : Optional[List[tvm.relay.TypeConstraint]]
        The type constraints. ``None`` (the default) means an empty list.
    """
    def __init__(self,
                 arg_types,
                 ret_type,
                 type_params=None,
                 type_constraints=None):
        # None sentinels avoid mutable default arguments; a fresh empty
        # list is substituted for each omitted argument.
        params = [] if type_params is None else type_params
        constraints = [] if type_constraints is None else type_constraints
        self.__init_handle_by_constructor__(
            _make.FuncType, arg_types, ret_type, params, constraints)


@register_relay_node
class IncompleteType(Type):
    """An incomplete type.

    Parameters
    ----------
    kind : Optional[Kind]
        The kind of the incomplete type.
        Default to Kind.Type.
    """
    def __init__(self, kind=Kind.Type):
        make_incomplete = _make.IncompleteType
        self.__init_handle_by_constructor__(make_incomplete, kind)


@register_relay_node
class TypeRelation(TypeConstraint):
    """Type relation in relay.

    Parameters
    ----------
    func : EnvFunc
        User defined relation function.

    args : List[tvm.relay.Type]
        List of types to the func.

    num_inputs : int
        Number of input arguments in args;
        this acts as a hint for type inference.

    attrs : Attrs
        The attribute attached to the relation information

    Returns
    -------
    type_relation : tvm.relay.TypeRelation
        The type relation.
    """
    def __init__(self, func, args, num_inputs, attrs):
        ctor_args = (func, args, num_inputs, attrs)
        self.__init_handle_by_constructor__(_make.TypeRelation, *ctor_args)


def scalar_type(dtype):
    """Create a scalar type.

    A scalar is represented as a tensor with an empty shape, i.e. this
    returns ``TensorType((), dtype)``.

    Parameters
    ----------
    dtype : str
        The content data type.

    Returns
    -------
    s_type : tvm.relay.TensorType
        The result type.
    """
    scalar_shape = ()
    return TensorType(scalar_shape, dtype)