/*!
 *  Copyright (c) 2018 by Contributors
 * \file tvm/relay/expr.h
 * \brief Relay expression language.
 */
#ifndef TVM_RELAY_EXPR_H_
#define TVM_RELAY_EXPR_H_

#include <tvm/attrs.h>
#include <functional>
#include <string>
#include <type_traits>
#include "./base.h"
#include "./type.h"

namespace tvm {
namespace relay {

/*!
 * \brief A Relay expression.
 */
class Expr;
/*!
 * \brief Base type of the Relay expression hierarchy.
 */
class ExprNode : public RelayNode {
 public:
  /*!
   * \brief Stores the result of type inference(type checking).
   *
   * \note This can be undefined before type inference.
   *       This value is discarded during serialization.
   */
  mutable Type checked_type_ = Type(nullptr);
  /*!
   * \brief Access the type stored in checked_type_.
   *
   * \note A CHECK fails if checked_type_ is not defined.
   *
   * \return The checked_type
   */
  const Type& checked_type() const;
  /*!
   * \brief Check if the inferred(checked) type of the Expr
   *  is backed by a TTypeNode and return it.
   *
   * \note This function will throw an error if the node type
   *       of this Expr is not TTypeNode.
   *
   * \return The corresponding TTypeNode pointer.
   * \tparam TTypeNode The specific TypeNode we look for.
   */
  template<typename TTypeNode>
  inline const TTypeNode* type_as() const;

  static constexpr const char* _type_key = "relay.Expr";
  TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
};

RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef);

/*!
 * \brief Constant tensor, backed by an NDArray on the cpu(0) device.
 *
 * \note Scalar constants are represented by rank-0 const tensor.
 *  Constant folding are handled uniformly via Tensor types.
 */
class Constant;
/*!
 * \brief Constant tensor type.
 */
class ConstantNode : public ExprNode {
 public:
  /*! \brief The data of the tensor */
  runtime::NDArray data;

  /*! \return The corresponding tensor type of the data */
  TensorType tensor_type() const;

  /*! \return Whether it is scalar(rank-0 tensor) */
76 77 78
  bool is_scalar() const {
    return data->ndim == 0;
  }
79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("data", &data);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static Constant make(runtime::NDArray data);

  static constexpr const char* _type_key = "relay.Constant";
  TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr);

/*! \brief Tuple of multiple Exprs */
class Tuple;
/*! \brief Tuple container */
class TupleNode : public ExprNode {
 public:
  /*! \brief the fields of the tuple */
  tvm::Array<relay::Expr> fields;

  /*!
   * \brief Expose the fields of this node to an attribute visitor.
   * \param v The visitor that receives each (key, field) pair.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("fields", &fields);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  /*!
   * \brief Factory for a Tuple.
   * \param fields The expressions that make up the tuple.
   * \return The Tuple reference.
   */
  TVM_DLL static Tuple make(tvm::Array<relay::Expr> fields);

  static constexpr const char* _type_key = "relay.Tuple";
  TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);

/*!
 * \brief Local variables used in the let expression.
 *
 * Its semantics are similar to tvm.Var node used in TVM's low level
 * tensor expression language.
 *
122
 * \note Each Var is bind only once and is immutable.
123 124 125 126 127
 */
class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
 public:
128
  /*!
129 130 131 132 133 134 135
   * \brief The unique identifier of the Var.
   *
   * vid will be preserved for the same Var during type inference
   * and other rewritings, while the VarNode might be recreated
   * to attach additional information.
   * This property can be used to keep track of parameter Var
   * information across passes.
136
   */
137
  Id vid;
138 139 140 141 142 143
  /*!
   * \brief type annotaion of the variable.
   * This field records user provided type annotation of the Var.
   * This field is optional and can be None.
   */
  Type type_annotation;
144

145 146 147 148 149
  /*! \return The name hint of the variable */
  const std::string& name_hint() const {
    return vid->name_hint;
  }

150
  void VisitAttrs(tvm::AttrVisitor* v) final {
151
    v->Visit("vid", &vid);
152
    v->Visit("type_annotation", &type_annotation);
雾雨魔理沙 committed
153
    v->Visit("span", &span);
154 155 156
    v->Visit("_checked_type_", &checked_type_);
  }

157 158
  TVM_DLL static Var make(std::string name_hint,
                          Type type_annotation);
159

160 161 162
  TVM_DLL static Var make(Id vid,
                          Type type_annotation);

163 164 165 166 167 168 169
  static constexpr const char* _type_key = "relay.Var";
  TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);

/*!
170
 * \brief Global variable that leaves in the top-level module.
171 172 173 174 175 176 177 178 179 180 181 182 183
 * This is used to enable recursive calls between function.
 *
 * \note A GlobalVar may only point to functions.
 */
class GlobalVar;
/*! \brief A GlobalId from the node's current type to target type. */
class GlobalVarNode : public ExprNode {
 public:
  /*! \brief The name of the variable, this only acts as a hint. */
  std::string name_hint;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("name_hint", &name_hint);
雾雨魔理沙 committed
184
    v->Visit("span", &span);
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static GlobalVar make(std::string name_hint);

  static constexpr const char* _type_key = "relay.GlobalVar";
  TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);

/*!
 * \brief Function (subgraph in computational graph)
 */
class Function;
/*! \brief Function container */
class FunctionNode : public ExprNode {
 public:
  /*! \brief Function parameters */
204
  tvm::Array<Var> params;
205 206 207 208 209 210 211
  /*!
   * \brief
   * The expression which represents the computation of the function,
   * the expression may reference the parameters, and the type of it
   * or sub-expressions may reference the type variables.
   */
  Expr body;
212 213
  /*! \brief User annotated return type of the function. */
  Type ret_type;
214 215 216 217 218 219 220
  /*!
   * \brief Type parameters of the function.
   *  Enables the function to vary its type based on these.
   *  This corresponds to template paramaters in c++'s terminology.
   *
   * \note This can be usually empty for non-polymorphic functions.
   */
221
  tvm::Array<TypeVar> type_params;
222

223 224 225 226 227
  /*!
   * \brief The attributes which store metadata about functions.
   */
  tvm::Attrs attrs;

228 229 230
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("params", &params);
    v->Visit("body", &body);
231
    v->Visit("ret_type", &ret_type);
232
    v->Visit("type_params", &type_params);
233
    v->Visit("attrs", &attrs);
234
    v->Visit("span", &span);
235 236 237
    v->Visit("_checked_type_", &checked_type_);
  }

238 239 240 241 242 243 244
  /*!
   * \brief Return the derived function annotation of this expression.
   *
   * \return The function type annotation.
   * \note The function type annotation can contain IncompleteType.
   */
  TVM_DLL FuncType func_type_annotation() const;
245

246 247 248 249 250 251 252
  /*!
   * \brief Check whether the function is a primitive function.
   *
   * \return Whether the function is primitive or not.
   */
  bool IsPrimitive() const;

253 254
  TVM_DLL static Function make(tvm::Array<Var> params,
                               Expr body,
255
                               Type ret_type,
256 257
                               tvm::Array<TypeVar> ty_params,
                               tvm::Attrs attrs = Attrs());
258 259 260 261 262 263 264

  static constexpr const char* _type_key = "relay.Function";
  TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);

/*!
 * \brief Read a function attribute by key.
 * \param func The function to query.
 * \param key The attribute key.
 * \return The attribute value as a NodeRef.
 * \note NOTE(review): the result for a missing key is decided by the
 *       implementation, which is not visible in this header - confirm there.
 */
TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
/*!
 * \brief Set a function attribute by key.
 * \param func The function to start from; taken by const reference.
 * \param key The attribute key.
 * \param data The attribute value.
 * \return A Function carrying the attribute.
 * \note NOTE(review): presumably returns an updated copy rather than
 *       mutating func, given the const reference - confirm in the implementation.
 */
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);


270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320
/*!
 * \brief Call corresponds to operator invocation.
 *  Corresponds to the operator in computational graph terminology.
 */
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
 public:
  /*!
   * \brief The operator(function) being invoked
   *
   *  - It can be relay::Op which corresponds to the primitive operators.
   *  - It can also be user defined functions (Function, GlobalVar, Var).
   */
  Expr op;

  /*! \brief The arguments(inputs) of the call */
  tvm::Array<relay::Expr> args;

  /*! \brief The additional attributes */
  Attrs attrs;

  /*!
   * \brief The type arguments passed to polymorphic(template) function.
   *
   * This is the advance feature that is only used when the function is
   * polymorphic. It is safe to be ignored in most cases. For example, in the
   * following code, the type_args of addone call is [int].
   *
   * \code
   *
   * template<typename T>
   * T addone(T a) { return a + 1; }
   *
   * void main() {
   *   int x = addone<int>(10);
   * }
   *
   * \endcode
   */
  tvm::Array<Type> type_args;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("op", &op);
    v->Visit("args", &args);
    v->Visit("attrs", &attrs);
    v->Visit("type_args", &type_args);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

321 322 323
  TVM_DLL static Call make(Expr op,
                           Array<Expr> args,
                           Attrs attrs = Attrs(),
324
                           Array<Type> type_args = Array<Type>());
325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361

  static constexpr const char* _type_key = "relay.Call";
  TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Call, CallNode, Expr);

/*!
 * \brief Let binding that binds a local var and optionally a type annotation.
 *
 * \note Let is useful to transform the program to be A-normal form.
 *  where each of the expression corresponds to a let binding.
 *
 *  For developers who are familar with the computational graph.
 *  Each of the let can be viewed as a operator node in the computational graph.
 *  Traversing the list of let bindings is similar to running
 * PostDFS-order(topo-order) traversal on the computational graph.
 */
class Let;
/*! \brief A binding of a sub-network. */
class LetNode : public ExprNode {
 public:
  /*! \brief The variable we bind to */
  Var var;
  /*! \brief The value we bind var to */
  Expr value;
  /*! \brief The body of the let binding */
  Expr body;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("var", &var);
    v->Visit("value", &value);
    v->Visit("body", &body);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

362
  TVM_DLL static Let make(Var var, Expr value, Expr body);
363 364 365 366 367 368 369 370 371 372 373 374 375 376 377

  static constexpr const char* _type_key = "relay.Let";
  TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Let, LetNode, Expr);

/*!
 * \brief Condition expression
 *
 * Unlike traditional statement `if`s, the if evaluates
 * to the result of the branch taken.
 *
 * let x = if (true) { 1 } else { 0 }; // x is 1
 * let y = if (false) { 1 } else { 0 }; // y is 0
 *
 * \note This is similar to C's ternary operator.
 */
class If;
/*! \brief container of If */
class IfNode : public ExprNode {
 public:
  /*! \brief The condition */
  Expr cond;
  /*! \brief The expression evaluated when condition is true. */
  Expr true_branch;
  /*! \brief The expression evaluated when condition is false */
  Expr false_branch;

  /*!
   * \brief Expose the fields of this node to an attribute visitor.
   * \param v The visitor that receives each (key, field) pair.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("cond", &cond);
    v->Visit("true_branch", &true_branch);
    v->Visit("false_branch", &false_branch);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  /*!
   * \brief Factory for an If.
   * \param cond The condition.
   * \param true_branch The expression evaluated when condition is true.
   * \param false_branch The expression evaluated when condition is false.
   * \return The If reference.
   */
  TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch);

  static constexpr const char* _type_key = "relay.If";
  TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(If, IfNode, Expr);

408
/*! \brief Get index-th field out of a tuple. */
409 410 411
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
 public:
412
  /*! \brief The tuple Expression */
413 414 415 416 417
  Expr tuple;
  /*! \brief which value to get */
  int index;

  void VisitAttrs(tvm::AttrVisitor* v) final {
418
    v->Visit("tuple_value", &tuple);
419
    v->Visit("index", &index);
雾雨魔理沙 committed
420
    v->Visit("span", &span);
421
    v->Visit("_checked_type_", &checked_type_);
422 423 424 425
  }

  TVM_DLL static TupleGetItem make(Expr tuple, int index);

426
  static constexpr const char* _type_key = "relay.TupleGetItem";
427 428 429 430 431
  TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497
/*! \brief Create a new Reference out of initial value. */
class RefCreate;
class RefCreateNode : public ExprNode {
 public:
  /*! \brief The initial value of the Reference. */
  Expr value;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("value", &value);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static RefCreate make(Expr value);

  static constexpr const char* _type_key = "relay.RefCreate";
  TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);

/*! \brief Get value out of Reference. */
class RefRead;
/*! \brief Container of RefRead */
class RefReadNode : public ExprNode {
 public:
  /*! \brief The Reference Expression. */
  Expr ref;

  /*!
   * \brief Expose the fields of this node to an attribute visitor.
   * \param v The visitor that receives each (key, field) pair.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("ref", &ref);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  /*!
   * \brief Factory for a RefRead.
   * \param ref The Reference Expression.
   * \return The RefRead reference.
   */
  TVM_DLL static RefRead make(Expr ref);

  static constexpr const char* _type_key = "relay.RefRead";
  TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);

/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
/*! \brief Container of RefWrite */
class RefWriteNode : public ExprNode {
 public:
  /*! \brief The Reference Expression. */
  Expr ref;
  /*! \brief The value to write into. */
  Expr value;

  /*!
   * \brief Expose the fields of this node to an attribute visitor.
   * \param v The visitor that receives each (key, field) pair.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("ref", &ref);
    v->Visit("value", &value);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  /*!
   * \brief Factory for a RefWrite.
   * \param ref The Reference Expression.
   * \param value The value to write.
   * \return The RefWrite reference.
   */
  TVM_DLL static RefWrite make(Expr ref, Expr value);

  static constexpr const char* _type_key = "relay.RefWrite";
  TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);

498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523
/*!
 * \brief Base class of the temporary expression.
 *
 * TempExprs are pass specific expression that can be
 * useful to define intermediate result in the
 * rewriting pass such as layout or type transformation.
 *
 * Subclass TempExprNode allows us to pattern match on
 * specific kind TempExpr and use them for expression rewriting.
 *
 * TempExpr should only be used within a pass,
 */
class TempExprNode : public ExprNode {
 public:
  /*!
   * \brief Convert the expression to a normal(non-temp) Expr.
   * \return The corresponding normal(non-temp) expression.
   */
  virtual Expr Realize() const = 0;

  static constexpr const char* _type_key = "relay.TempExpr";
  TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);

524
// implementataions
525 526 527 528 529 530 531 532
inline const Type& ExprNode::checked_type() const {
  CHECK(checked_type_.defined()) << "internal error: the type checker has "
    "not populated the checked_type "
    "field for "
                                 << GetRef<Expr>(this);
  return this->checked_type_;
}

533 534 535 536 537
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                "TType must be a special case of type");
  CHECK(checked_type_.defined())
538
      << "Type inference for this Expr has not completed. Try to call infer_type pass.";
539 540 541 542 543 544 545
  const TTypeNode* node = checked_type_.as<TTypeNode>();
  CHECK(node != nullptr)
      << "Expected type to be " << TTypeNode::_type_key
      << ", but get " << checked_type_->type_key();
  return node;
}

546 547 548
/*!
 * \brief Print node as text format.
 * \param node The node to be printed.
549
 * \param show_meta_data Whether to print meta data section.
550 551 552 553 554 555
 * \param annotate An optional callback function for attaching
 *        additional comment block to an expr.
 * \return The text representation.
 */
std::string RelayPrint(
    const NodeRef& node,
556
    bool show_meta_data = true,
557
    runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
558 559 560
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_EXPR_H_