ty.py 8.69 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
20
from .base import RelayNode, register_relay_node
21 22
from . import _make

23
Any = _make.Any
24

25
class Type(RelayNode):
    """The base type for all Relay types."""

    # NOTE(review): this class overrides __eq__ without defining __hash__.
    # In Python 3 that sets __hash__ to None for the class, so instances are
    # unhashable unless a subclass restores it -- confirm this is intended.

    def __eq__(self, other):
        """Compare two Relay types for structural equivalence using
        alpha equivalence.

        Parameters
        ----------
        other : relay.Type
            The type to compare against.

        Returns
        -------
        equal : bool
            The result of ``_make._alpha_equal`` coerced to ``bool``.
        """
        return bool(_make._alpha_equal(self, other))

    def __ne__(self, other):
        """Return the negation of ``__eq__``."""
        return not self.__eq__(other)

    def same_as(self, other):
        """Compares two Relay types by referential equality.

        This skips the alpha-equivalence ``__eq__`` defined on this class
        and uses the ``__eq__`` inherited from the parent class instead.
        """
        return super().__eq__(other)

    def __call__(self, *args):
        """Create a type call from this type.

        Parameters
        ----------
        args: List[relay.Type]
            The arguments to the type call.

        Returns
        -------
        call: relay.TypeCall
        """
        return TypeCall(self, args)
54

55 56
@register_relay_node
class TensorType(Type):
    """A concrete TensorType in Relay.

    This is the type assigned to tensors with a known dtype and shape. For
    example a tensor of `float32` and `(5, 5)`.

    Parameters
    ----------
    shape : List[tvm.Expr]
        The shape of the Tensor

    dtype : Optional[str]
        The content data type.
        Default to "float32".

    Returns
    -------
    tensor_type : tvm.relay.TensorType
        The tensor type.
    """
    def __init__(self, shape, dtype="float32"):
        self.__init_handle_by_constructor__(
            _make.TensorType, shape, dtype)

    @property
    def concrete_shape(self):
        """Get shape of the type as concrete tuple of int.

        Each entry of ``self.shape`` is converted with ``int``.

        Returns
        -------
        shape : Tuple[int]
            The concrete shape of the Type.

        Raises
        ------
        TypeError : If the shape is symbolic
        """
        return tuple(int(dim) for dim in self.shape)

95 96 97 98 99 100 101 102 103

class Kind(IntEnum):
    """The kind of a type parameter, represents a variable shape,
    base type, type, or dimension.

    This controls what a type parameter is allowed to be instantiated
    with. For example, ones of kind BaseType can only be `float32`, `int32`,
    and so on.
    """
    Type = 0
    ShapeVar = 1
    BaseType = 2
    Shape = 3
    Constraint = 4
    AdtHandle = 5
    TypeData = 6
111 112

@register_relay_node
class TypeVar(Type):
    """A type variable used for generic types in Relay,
    see tvm/relay/type.h for more details.

    A type variable represents a type placeholder which will
    be filled in later on. This allows the user to write
    functions which are generic over types.
    """

    def __init__(self, var, kind=Kind.Type):
        """Construct a TypeVar.

        Parameters
        ----------
        var : tvm.expr.Var
            The tvm.Var which backs the type parameter.
            NOTE(review): the ShapeVar helper in this module forwards a
            ``str`` name here -- confirm the accepted type.

        kind : Optional[Kind]
            The kind of the type parameter.
            Default to Kind.Type.

        Returns
        -------
        type_var : tvm.relay.TypeVar
            The type variable.
        """
        self.__init_handle_by_constructor__(_make.TypeVar, var, kind)
140

141 142 143 144 145 146 147 148 149 150 151 152 153
def ShapeVar(name):
    """Build a type variable whose kind is ``Kind.ShapeVar``.

    Parameters
    ----------
    name : str
        Forwarded as the first argument of ``TypeVar``.

    Returns
    -------
    type_var : tvm.relay.TypeVar
        The shape variable.
    """
    shape_kind = Kind.ShapeVar
    return TypeVar(name, kind=shape_kind)
154 155

@register_relay_node
class GlobalTypeVar(Type):
    """A global type variable in Relay.

    GlobalTypeVar is used to refer to the global type-level definitions
    stored in the environment.
    """

    def __init__(self, var, kind=Kind.AdtHandle):
        """Construct a GlobalTypeVar.

        Parameters
        ----------
        var: tvm.Var
            The tvm.Var which backs the type parameter.

        kind: Kind, optional
            The kind of the type parameter, Kind.AdtHandle by default.

        Returns
        -------
        type_var: GlobalTypeVar
            The global type variable.
        """
        self.__init_handle_by_constructor__(_make.GlobalTypeVar, var, kind)


@register_relay_node
class TypeCall(Type):
    """Type-level function application in Relay.

    A type call applies argument types to a constructor (type-level function).
    """

    def __init__(self, func, args):
        """Construct a TypeCall.

        Parameters
        ----------
        func: tvm.relay.Type
            The function being applied.

        args: List[tvm.relay.Type]
            The arguments it is applied to.

        Returns
        -------
        type_call: TypeCall
            The type function application.
        """
        constructor = _make.TypeCall
        self.__init_handle_by_constructor__(constructor, func, args)


@register_relay_node
class TypeConstraint(Type):
    """Abstract class representing a type constraint."""


@register_relay_node
class TupleType(Type):
    """A tuple type in Relay, see tvm/relay/type.h for more details.

    Lists the type of each field in the tuple.
    """

    def __init__(self, fields):
        """Constructs a tuple type

        Parameters
        ----------
        fields : List[tvm.relay.Type]
            The fields in the tuple

        Returns
        -------
        tuple_type : tvm.relay.TupleType
            the tuple type
        """
        self.__init_handle_by_constructor__(_make.TupleType, fields)


@register_relay_node
class FuncType(Type):
    """A function type in Relay, see tvm/relay/type.h for more details.

    This is the type assigned to functions in Relay. They consist of
    a list of type parameters which enable the definition of generic
    functions, a set of type constraints which we omit for the time
    being, a sequence of argument types, and a return type.

    We informally write them as:
    `forall (type_params), (arg_types) -> ret_type where type_constraints`

    Parameters
    ----------
    arg_types : List[tvm.relay.Type]
        The argument types

    ret_type : tvm.relay.Type
        The return type.

    type_params : Optional[List[tvm.relay.TypeVar]]
        The type parameters. ``None`` is replaced by an empty list.

    type_constraints : Optional[List[tvm.relay.TypeConstraint]]
        The type constraints. ``None`` is replaced by an empty list.
    """
    def __init__(self,
                 arg_types,
                 ret_type,
                 type_params=None,
                 type_constraints=None):
        # None is the sentinel for "not supplied"; a fresh list is created
        # per call so no mutable default is shared between instances.
        type_params = [] if type_params is None else type_params
        type_constraints = [] if type_constraints is None else type_constraints
        self.__init_handle_by_constructor__(
            _make.FuncType, arg_types, ret_type, type_params, type_constraints)


@register_relay_node
class IncompleteType(Type):
    """An incomplete type.

    Parameters
    ----------
    kind : Optional[Kind]
        Forwarded to ``_make.IncompleteType``.
        Default to Kind.Type.
    """
    def __init__(self, kind=Kind.Type):
        self.__init_handle_by_constructor__(_make.IncompleteType, kind)
274 275 276 277 278 279 280 281


@register_relay_node
class TypeRelation(TypeConstraint):
    """Type relation in relay.

    Parameters
    ----------
    func : EnvFunc
        User defined relation function.

    args : [tvm.relay.Type]
        List of types to the func.

    num_inputs : int
        Number of input arguments in args,
        this acts as a hint for type inference.

    attrs : Attrs
        The attribute attached to the relation information

    Returns
    -------
    type_relation : tvm.relay.TypeRelation
        The type relation.
    """
    def __init__(self, func, args, num_inputs, attrs):
        self.__init_handle_by_constructor__(_make.TypeRelation,
                                            func, args, num_inputs, attrs)
303 304


305 306 307 308 309 310 311 312 313 314 315 316 317
@register_relay_node
class RefType(Type):
    """A reference type in Relay.

    Parameters
    ----------
    value: Type
        The type of the value held by the reference.
    """
    def __init__(self, value):
        constructor = _make.RefType
        self.__init_handle_by_constructor__(constructor, value)


318 319 320 321 322 323 324
def scalar_type(dtype):
    """Creates a scalar type.

    This function returns TensorType((), dtype)

    Parameters
    ----------
    dtype : str
        The content data type.

    Returns
    -------
    s_type : tvm.relay.TensorType
        The result type.
    """
    return TensorType((), dtype)