ty.py 8.89 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
20
from .base import RelayNode, register_relay_node
21 22
from . import _make

23
# Alias `_make.Any` at module level so it can be referenced as `Any`
# from this module.
Any = _make.Any
24

25
class Type(RelayNode):
    """The base type for all Relay types.

    ``==`` and ``!=`` are overridden to use ``_make._alpha_equal``; the
    parent class's equality stays available through :py:meth:`same_as`.

    NOTE(review): ``__eq__`` is overridden here without a matching
    ``__hash__``. In Python 3 that implicitly sets ``__hash__`` to ``None``
    on this class -- confirm that unhashable types are intended.
    """

    def __eq__(self, other):
        """Compare two Relay types for structural equivalence using
        alpha equivalence.

        Parameters
        ----------
        other : object
            The value to compare against; it is passed to
            ``_make._alpha_equal`` unchanged.

        Returns
        -------
        result : bool
            The outcome of ``_make._alpha_equal(self, other)`` coerced to
            ``bool``.
        """
        return bool(_make._alpha_equal(self, other))

    def __ne__(self, other):
        """Return the negation of :py:meth:`__eq__`."""
        return not self.__eq__(other)

    def same_as(self, other):
        """Compares two Relay types by referential equality.

        This delegates to the parent class's ``__eq__`` and therefore
        bypasses the alpha-equivalence check installed by this class.
        """
        return super().__eq__(other)

    def __call__(self, *args):
        """Create a type call from this type.

        Parameters
        ----------
        args: List[relay.Type]
            The arguments to the type call.

        Returns
        -------
        call: relay.TypeCall
        """
        return TypeCall(self, args)

    def is_dynamic(self):
        """Return the result of ``_make.IsDynamic`` for this type.

        NOTE(review): presumably this reports whether the type has a
        dynamic (``Any``) component -- confirm against ``_make.IsDynamic``.
        """
        return _make.IsDynamic(self)

58 59
@register_relay_node
class TensorType(Type):
    """A concrete TensorType in Relay.

    This is the type assigned to tensors with a known dtype and shape. For
    example, a tensor of `float32` and `(5, 5)`.

    Parameters
    ----------
    shape : List[tvm.Expr]
        The shape of the Tensor

    dtype : Optional[str]
        The content data type.
        Default to "float32".

    Returns
    -------
    tensor_type : tvm.relay.TensorType
        The tensor type.
    """
    def __init__(self, shape, dtype="float32"):
        self.__init_handle_by_constructor__(
            _make.TensorType, shape, dtype)

    @property
    def concrete_shape(self):
        """Get shape of the type as concrete tuple of int.

        Every entry of ``self.shape`` is converted with ``int``.

        Returns
        -------
        shape : Tuple[int]
            The concrete shape of the Type.

        Raises
        ------
        TypeError : If the shape is symbolic
        """
        return tuple(int(x) for x in self.shape)

98 99 100 101 102 103 104 105 106

class Kind(IntEnum):
    """The kind of a type parameter, represents a variable shape,
    base type, type, or dimension.

    This controls what a type parameter is allowed to be instantiated
    with. For example, parameters of kind BaseType can only be `float32`,
    `int32`, and so on.
    """
    Type = 0
    ShapeVar = 1
    BaseType = 2
    Shape = 3
    Constraint = 4
    AdtHandle = 5
    TypeData = 6
114 115

@register_relay_node
class TypeVar(Type):
    """A type variable used for generic types in Relay,
    see tvm/relay/type.h for more details.

    A type variable represents a type placeholder which will
    be filled in later on. This allows the user to write
    functions which are generic over types.
    """

    def __init__(self, name_hint, kind=Kind.Type):
        """Construct a TypeVar.

        Parameters
        ----------
        name_hint: str
            The name of the type variable. This name only acts as a hint, and
            is not used for equality.

        kind : Optional[Kind]
            The kind of the type parameter.
            Default to Kind.Type.

        Returns
        -------
        type_var : tvm.relay.TypeVar
            The type variable.
        """
        self.__init_handle_by_constructor__(_make.TypeVar, name_hint, kind)
144

145 146 147 148 149 150 151 152 153 154 155 156 157
def ShapeVar(name):
    """Build a type variable whose kind is ``Kind.ShapeVar``.

    Parameters
    ----------
    name : str
        Forwarded to ``TypeVar`` as its ``name_hint``.

    Returns
    -------
    type_var : tvm.relay.TypeVar
        The shape variable.
    """
    shape_kind = Kind.ShapeVar
    return TypeVar(name, kind=shape_kind)
158 159

@register_relay_node
class GlobalTypeVar(Type):
    """A global type variable in Relay.

    GlobalTypeVar is used to refer to the global type-level definitions
    stored in the environment.
    """

    def __init__(self, name_hint, kind=Kind.AdtHandle):
        """Construct a GlobalTypeVar.

        Parameters
        ----------
        name_hint: str
            The name of the global type variable. This name only acts as a
            hint, and is not used for equality.

        kind: Kind, optional
            The kind of the type parameter, Kind.AdtHandle by default.

        Returns
        -------
        type_var: GlobalTypeVar
            The global type variable.
        """
        self.__init_handle_by_constructor__(_make.GlobalTypeVar, name_hint, kind)
184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208


@register_relay_node
class TypeCall(Type):
    """Type-level function application in Relay.

    A type call applies argument types to a constructor (type-level
    function).
    """

    def __init__(self, func, args):
        """Construct a TypeCall.

        Parameters
        ----------
        func: tvm.relay.Type
            The function.

        args: List[tvm.relay.Type]
            The arguments.

        Returns
        -------
        type_call: TypeCall
            The type function application.
        """
        make_type_call = _make.TypeCall
        self.__init_handle_by_constructor__(make_type_call, func, args)


@register_relay_node
class TypeConstraint(Type):
    """Abstract class representing a type constraint."""


@register_relay_node
class TupleType(Type):
    """A tuple type in Relay, see tvm/relay/type.h for more details.

    Lists the type of each field in the tuple.
    """

    def __init__(self, fields):
        """Constructs a tuple type

        Parameters
        ----------
        fields : List[tvm.relay.Type]
            The fields in the tuple

        Returns
        -------
        tuple_type : tvm.relay.TupleType
            the tuple type
        """
        self.__init_handle_by_constructor__(_make.TupleType, fields)


@register_relay_node
class FuncType(Type):
    """A function type in Relay, see tvm/relay/type.h for more details.

    This is the type assigned to functions in Relay. They consist of
    a list of type parameters which enable the definition of generic
    functions, a set of type constraints which we omit for the time
    being, a sequence of argument types, and a return type.

    We informally write them as:
    `forall (type_params), (arg_types) -> ret_type where type_constraints`

    Parameters
    ----------
    arg_types : List[tvm.relay.Type]
        The argument types

    ret_type : tvm.relay.Type
        The return type.

    type_params : Optional[List[tvm.relay.TypeVar]]
        The type parameters. ``None`` is replaced by an empty list.

    type_constraints : Optional[List[tvm.relay.TypeConstraint]]
        The type constraints. ``None`` is replaced by an empty list.
    """
    def __init__(self,
                 arg_types,
                 ret_type,
                 type_params=None,
                 type_constraints=None):
        # Use None as the default and build a fresh list per call, so no
        # mutable default object is shared between instances.
        if type_params is None:
            type_params = []
        if type_constraints is None:
            type_constraints = []
        self.__init_handle_by_constructor__(
            _make.FuncType, arg_types, ret_type, type_params, type_constraints)


@register_relay_node
class IncompleteType(Type):
    """An incomplete type.

    Parameters
    ----------
    kind : Optional[Kind]
        The kind passed to ``_make.IncompleteType``.
        Default to Kind.Type.
    """
    def __init__(self, kind=Kind.Type):
        self.__init_handle_by_constructor__(_make.IncompleteType, kind)
280 281 282 283 284 285 286 287


@register_relay_node
class TypeRelation(TypeConstraint):
    """Type relation in relay.

    Parameters
    ----------
    func : EnvFunc
        User defined relation function.

    args : [tvm.relay.Type]
        List of types to the func.

    num_inputs : int
        Number of input arguments in args,
        this acts as a hint for type inference.

    attrs : Attrs
        The attribute attached to the relation information

    Returns
    -------
    type_relation : tvm.relay.TypeRelation
        The type relation.
    """
    def __init__(self, func, args, num_inputs, attrs):
        self.__init_handle_by_constructor__(_make.TypeRelation,
                                            func, args, num_inputs, attrs)
309 310


311 312 313 314 315 316 317 318 319 320 321 322
@register_relay_node
class RefType(Type):
    """Reference Type in relay.

    Parameters
    ----------
    value: Type
        The value type, handed to ``_make.RefType``.
    """
    def __init__(self, value):
        make_ref_type = _make.RefType
        self.__init_handle_by_constructor__(make_ref_type, value)

323 324 325 326 327 328 329
def scalar_type(dtype):
    """Creates a scalar type.

    This function returns TensorType((), dtype)

    Parameters
    ----------
    dtype : str
        The content data type.

    Returns
    -------
    s_type : tvm.relay.TensorType
        The result type.
    """
    # An empty shape tuple is what makes the resulting tensor type a scalar.
    return TensorType((), dtype)