/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/type.h
 * \brief Relay typed AST nodes.
 */
#ifndef TVM_RELAY_TYPE_H_
#define TVM_RELAY_TYPE_H_

// TVM core headers.
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/node/node.h>
// Standard library (std::string is used by the make() factories below).
#include <string>

// Headers located relative to this file.
#include "base.h"
#include "../attrs.h"

namespace tvm {
namespace relay {

38 39
using Any = tvm::ir::Any;

40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
/*! \brief Base type of the Relay type hiearchy. */
class TypeNode : public RelayNode {
 public:
  static constexpr const char* _type_key = "relay.Type";
  TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node);
};

/*!
 * \brief Type is the base type of relay type hiearchy.
 *
 * Relay's type system contains following two key concepts:
 *
 * - TensorType: type of certain Tensor values in the expression.
 * - FunctionType: the type of the function.
 *
 * There are also advanced types to support generic(polymorphic types),
 * which can be ignored when first reading the code base.
 */
class Type : public NodeRef {
 public:
  Type() {}
61
  explicit Type(NodePtr<tvm::Node> p) : NodeRef(p) {}
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92

  using ContainerType = TypeNode;
};

/*!
 * \brief Base of all Tensor types.
 *  This container can hold TensorType or GenericTensorType.
 *
 *  The node declares only its type key and base node information;
 *  it adds no fields on top of TypeNode.
 */
class BaseTensorTypeNode : public TypeNode {
 public:
  static constexpr const char* _type_key = "relay.BaseTensorType";
  TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode);
};

// Declares the BaseTensorType reference for BaseTensorTypeNode with Type as
// its base; the generated members come from RELAY_DEFINE_NODE_REF, which is
// defined outside this file.
RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type);

/*!
 * \brief This is the most commonly used type in relay.
 *  TensorType have a fixed dimension, data type.
 *
 *  The elements of shape can be either IntImm(constant integer),
 *  or any symbolic integer expression.
 *  The symbolic integer allows generic shape inference in certain cases.
 * \sa TensorTypeNode The container class of TensorType.
 */
class TensorType;
/*! \brief TensorType container node */
class TensorTypeNode : public BaseTensorTypeNode {
 public:
  /*!
   * \brief The shape of the tensor,
93
   *  represented by IndexExpr(tvm::Expr).
94
   */
95
  Array<IndexExpr> shape;
96 97 98 99 100 101 102 103
  /*! \brief The content data type */
  DataType dtype;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
    v->Visit("span", &span);
  }
Siju committed
104 105 106 107 108

  /*! \brief Return product of elements in the shape.
   *  \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
   */
  TVM_DLL IndexExpr Size() const;
109

110
  TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);
111 112 113 114 115 116 117 118 119 120

  /*! \brief Construct an scalar containing elements of dtype.  */
  TVM_DLL static TensorType Scalar(DataType dtype);

  static constexpr const char* _type_key = "relay.TensorType";
  TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode);
};

RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);

/*!
 * \brief Possible kinds of Type.
 *
 *  The enum has an explicit int underlying type and every enumerator
 *  carries an explicit value. Nodes in this file pass their kind field
 *  to an attribute visitor, so keep the numeric values stable.
 */
enum Kind : int {
  kType = 0,
  /*! \brief Template variable in shape expression. */
  kShapeVar = 1,
  kBaseType = 2,
  kShape = 3,
  kConstraint = 4,
  kAdtHandle = 5,
  kTypeData = 6
};

133 134 135 136 137
/*!
 * \brief Type parameter in the function.
 *  This can be viewed as template parameter in c++ template function.
 *
 * For example, in the following pesudo code,
138
 * the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
139 140 141 142 143 144 145 146 147
 * This function can take in a Tensor with shape=(3, 3) and
 * returns a Tensor with shape=(9,)
 *
 * \code
 *
 *  template<i32 n>
 *  f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
 *
 * \endcode
148
 * \sa TypeVarNode The actual container class of TypeVar
149
 */
150 151 152
class TypeVar;
/*! \brief TypeVar container node */
class TypeVarNode : public TypeNode {
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
 public:
  /*!
   * \brief The variable itself is only meaningful when
   *  kind is ShapeVar, otherwise, we only use the name.
   */
  tvm::Var var;
  /*! \brief The kind of type parameter */
  Kind kind;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("var", &var);
    v->Visit("kind", &kind);
    v->Visit("span", &span);
  }

168
  TVM_DLL static TypeVar make(std::string name, Kind kind);
169

170 171
  static constexpr const char* _type_key = "relay.TypeVar";
  TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode);
172 173
};

174
RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type);
175 176

/*!
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
 * \brief A global type variable that is used for defining new types or type aliases.
 */
class GlobalTypeVar;
/*! \brief GlobalTypeVar container node */
class GlobalTypeVarNode : public TypeNode {
 public:
  /*!
   * \brief The variable itself is only meaningful when
   *  kind is ShapeVar; otherwise, we only use the name.
   */
  tvm::Var var;
  /*! \brief The kind of type parameter */
  Kind kind;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("var", &var);
    v->Visit("kind", &kind);
    v->Visit("span", &span);
  }

  TVM_DLL static GlobalTypeVar make(std::string name, Kind kind);

  static constexpr const char* _type_key = "relay.GlobalTypeVar";
  TVM_DECLARE_NODE_TYPE_INFO(GlobalTypeVarNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(GlobalTypeVar, GlobalTypeVarNode, Type);

/*!
 * \brief Type application.
 */
class TypeCall;
/*! \brief TypeCall container node */
class TypeCallNode : public TypeNode {
 public:
  /*!
   * \brief The type-level function (ADT that takes type params).
   */
  Type func;
  /*! \brief The arguments. */
  tvm::Array<Type> args;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("func", &func);
    v->Visit("args", &args);
    v->Visit("span", &span);
  }

  TVM_DLL static TypeCall make(Type func, tvm::Array<Type> args);

  static constexpr const char* _type_key = "relay.TypeCall";
  TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type);

/*!
234 235 236 237 238
 * \brief IncompleteType.
 * This is intermediate values that is used during type inference.
 *
 * If we view the type relations as "computational graph of types",
 * then IncompleteType represents intermediate values of the graph,
239
 * TypeVar represents the input to the graph.
240 241 242 243 244 245
 */
class IncompleteType;

/*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode {
 public:
246
  Kind kind;
247 248 249

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("kind", &kind);
250
    v->Visit("span", &span);
251 252
  }

253
  TVM_DLL static IncompleteType make(Kind kind);
254 255 256 257 258 259 260 261

  static constexpr const char* _type_key = "relay.IncompleteType";
  TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type);

/*!
 * \brief Potential Constraints in the type.
 * \note This is reserved for future use.
 */
class TypeConstraint;
/*!
 * \brief TypeConstraint container node.
 *
 *  Declares only its type key and base node information; TypeRelationNode
 *  in this file derives from it.
 */
class TypeConstraintNode : public TypeNode {
 public:
  static constexpr const char* _type_key = "relay.TypeConstraint";
  TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, Type);

class FuncType;
/*!
 * \brief Function type in Relay.
 *
 * Relay support polymorphic function type.
 * This can be roughly viewed as template function in C++.
 *
282
 * \sa TypeVar, TypeConstraint
283 284 285 286 287 288 289 290 291 292
 */
class FuncTypeNode : public TypeNode {
 public:
  /*! \brief type type of arguments */
  tvm::Array<Type> arg_types;
  /*! \brief The type of return value. */
  Type ret_type;
  // The following fields are used in polymorphic(template) functions
  // For normal functions, the following two fields will be empty.
  /*! \brief The type parameters of the function */
293
  tvm::Array<TypeVar> type_params;
294 295 296 297 298 299 300 301 302 303 304 305 306 307
  /*!
   * \brief potential constraint the type need to obey
   * \note this field is reserved for futher purposes.
   */
  tvm::Array<TypeConstraint> type_constraints;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("arg_types", &arg_types);
    v->Visit("ret_type", &ret_type);
    v->Visit("type_params", &type_params);
    v->Visit("type_constraints", &type_constraints);
    v->Visit("span", &span);
  }

308 309
  TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
                               Type ret_type,
310
                               tvm::Array<TypeVar> type_params,
311 312 313 314 315 316 317 318
                               tvm::Array<TypeConstraint> type_constraints);

  static constexpr const char* _type_key = "relay.FuncType";
  TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type);

319 320 321 322 323 324 325 326 327 328 329 330 331 332
/*!
 * \brief The type of tuple values.
 */
class TupleType;
/*!
 * \brief TupleType container.
 */
class TupleTypeNode : public TypeNode {
 public:
  /*! \brief The type of each field in the tuple. */
  tvm::Array<Type> fields;

  TupleTypeNode() {}

333 334 335 336
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("fields", &fields);
    v->Visit("span", &span);
  }
337 338 339

  TVM_DLL static TupleType make(tvm::Array<Type> fields);

340
  static constexpr const char* _type_key = "relay.TupleType";
341 342 343 344 345
  TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);

346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372
/*!
 * \brief The type of reference values.
 */
class RefType;
/*!
 * \brief Reference Type in relay.
 */
class RefTypeNode : public TypeNode {
 public:
  /*! \brief The type of value in the Reference. */
  Type value;

  RefTypeNode() {}

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("value", &value);
    v->Visit("span", &span);
  }

  TVM_DLL static RefType make(Type value);

  static constexpr const char* _type_key = "relay.RefType";
  TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type);

373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388
class TypeReporter;

/*!
 * \brief reporter that reports back to the
 *  type resolution information.
 */
class TypeReporterNode : public Node {
 public:
  /*!
   * \brief Create a type equality constraint.
   *
   *  The "assign direction" acts as a hint to the solver
   *  showing that it is more likely to resolve dst by src.
   *  But it is possible for the solver to resolve src by dst as well.
   */
  TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
389

390
  /*!
391
   * \brief assert shape expression comparison.
Siju committed
392
   * \note Use assert only if any of the condition input is symbolic.
393 394 395 396 397 398
   * \param cond The condition of operation.
   * \return false if assertation can be proven to have failed
   *      true if solver can still proceed.
   */
  TVM_DLL virtual bool Assert(const IndexExpr& cond)= 0;
  /*!
399 400 401
   * \brief assert shape expression equals each other.
   * \param lhs The left operand.
   * \param rhs The right operand.
402 403
   * \return false if assertation can be proven to have failed
   *      true if solver can still proceed.
404
   */
405
  TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;
406

407 408 409 410 411 412
  /*!
   * \brief Set the location at which to report unification errors.
   * \param ref The program node to report the error.
   */
  TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;

413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453
  // solver is not serializable.
  void VisitAttrs(tvm::AttrVisitor* v) final {}

  static constexpr const char* _type_key = "relay.TypeReporter";
  TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node);
};

/*!
 * \brief Reference class that points at a TypeReporterNode.
 * \sa TypeReporterNode
 */
class TypeReporter : public NodeRef {
 public:
  /*! \brief The container node type associated with this reference. */
  using ContainerType = TypeReporterNode;

  /*! \brief Default constructor; no node pointer is handed to NodeRef. */
  TypeReporter() {}

  /*!
   * \brief Construct a reporter reference from a node pointer.
   * \param node_ptr The node pointer forwarded to the NodeRef base class.
   */
  explicit TypeReporter(::tvm::NodePtr<::tvm::Node> node_ptr)
      : NodeRef(node_ptr) {}

  /*!
   * \brief Access the underlying node as a TypeReporterNode.
   * \return The stored node pointer after an unchecked static_cast.
   */
  TypeReporterNode* operator->() const {
    auto* raw = node_.get();
    return static_cast<TypeReporterNode*>(raw);
  }
};

/*!
 * \brief User defined type constraint function.
 *
 * If the input type information can be used to fully decide
 * the IncompleteTypes, then the function should call
 * reporter.Assign to report the new types, and return true.
 * Otherwise, the function should return false.
 *
 * \param args The arguments to the relation.
 *   The types are stored in the form of
 *   [input_type_0, input_type_1, ... input_type_n,
 *    output_type_0, output_type_1, ... output_type_m]
 *
 * \param num_inputs Number of input types in the args.
 * \param attrs The additional attributes of the operator.
 * \param reporter The reporter to report solution to.
 * \return false if This relation cannot be resolved.
 *   true if this relation has been resolved.
 */
454
using TypeRelationFn =
455 456 457 458
    TypedEnvFunc<bool(const Array<Type>& args,
                      int num_inputs,
                      const Attrs& attrs,
                      const TypeReporter& reporter)>;
459 460

/*!
461
 * \brief User defined type relation, is an input-output relation on types.
462 463 464 465 466
 */
class TypeRelation;
/*!
 * \brief TypeRelation container.
 * \note This node is not directly serializable.
467
 * The type function need to be lookedup in the module.
468 469 470 471 472 473
 */
class TypeRelationNode : public TypeConstraintNode {
 public:
  /*!
   * \brief The function on input and output variables which
   *  this is not directly serializable,
474
   *  need to be looked-up in the module.
475
   */
476
  TypeRelationFn func;
477 478
  /*! \brief The type arguments to the type function. */
  tvm::Array<Type> args;
479 480 481 482
  /*! \brief Number of inputs arguments */
  int num_inputs;
  /*! \brief Attributes to the relation function */
  Attrs attrs;
483 484

  void VisitAttrs(tvm::AttrVisitor* v) final {
485 486 487 488
    v->Visit("func", &func);
    v->Visit("args", &args);
    v->Visit("num_inputs", &num_inputs);
    v->Visit("attrs", &attrs);
489
    v->Visit("span", &span);
490 491
  }

492 493 494 495
  TVM_DLL static TypeRelation make(TypeRelationFn func,
                                   Array<Type> args,
                                   int num_args,
                                   Attrs attrs);
496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513

  static constexpr const char* _type_key = "relay.TypeRelation";
  TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode);
};

RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint);

// The following declarations belong to advanced typing.
// Only the class names are kept here; they are reserved for future usage
// and no definition is provided in this file.
class GenericTensorType;
// stores a DataType.
class GenericDataType;
// NOTE(review): the original comment here repeated "stores a DataType.";
// presumably this class is meant to store a shape instead. Confirm.
class GenericShape;

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_TYPE_H_