/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/expr.h
 * \brief Relay expression language.
 */
#ifndef TVM_RELAY_EXPR_H_
#define TVM_RELAY_EXPR_H_

#include <tvm/attrs.h>
#include <functional>
#include <string>
#include <type_traits>
#include "./base.h"
#include "./type.h"

namespace tvm {
namespace relay {

/*!
 * \brief A Relay expression.
 */
class Expr;
/*!
 * \brief Base type of the Relay expression hierarchy.
 */
class ExprNode : public RelayNode {
 public:
  /*!
   * \brief Stores the result of type inference (type checking).
   *
   * \note This can be undefined before type inference.
   *       This value is discarded during serialization.
   */
  mutable Type checked_type_ = Type(nullptr);
  /*!
   * \brief Access the type populated by type inference.
   * \note Fails a CHECK if checked_type_ has not been populated yet.
   * \return The checked_type
   */
  const Type& checked_type() const;
  /*!
   * \brief Check if the inferred (checked) type of the Expr
   *  is backed by a TTypeNode and return it.
   *
   * \note This function will throw an error if the checked type is
   *       not yet defined, or if its node type is not TTypeNode.
   *
   * \return The corresponding TTypeNode pointer.
   * \tparam TTypeNode The specific TypeNode we look for.
   */
  template<typename TTypeNode>
  inline const TTypeNode* type_as() const;

  static constexpr const char* _type_key = "relay.Expr";
  TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
};

RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef);

/*!
 * \brief Constant tensor, backed by an NDArray on the cpu(0) device.
 *
 * \note Scalar constants are represented by rank-0 const tensor.
 *  Constant folding are handled uniformly via Tensor types.
 */
class Constant;
/*!
 * \brief Constant tensor type.
 */
class ConstantNode : public ExprNode {
 public:
  /*! \brief The data of the tensor */
  runtime::NDArray data;

  /*! \return The corresponding tensor type of the data */
  TensorType tensor_type() const;

  /*! \return Whether it is scalar(rank-0 tensor) */
94 95 96
  bool is_scalar() const {
    return data->ndim == 0;
  }
97

98
  void VisitAttrs(tvm::AttrVisitor* v) {
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
    v->Visit("data", &data);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static Constant make(runtime::NDArray data);

  static constexpr const char* _type_key = "relay.Constant";
  TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr);

/*! \brief Tuple of multiple Exprs */
class Tuple;
/*! \brief Tuple container */
class TupleNode : public ExprNode {
 public:
  /*! \brief the fields of the tuple */
  tvm::Array<relay::Expr> fields;

120
  void VisitAttrs(tvm::AttrVisitor* v) {
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
    v->Visit("fields", &fields);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static Tuple make(tvm::Array<relay::Expr> fields);

  static constexpr const char* _type_key = "relay.Tuple";
  TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);

/*!
 * \brief Local variables used in the let expression.
 *
 * Its semantics are similar to tvm.Var node used in TVM's low level
 * tensor expression language.
 *
140
 * \note Each Var is bind only once and is immutable.
141 142 143 144 145
 */
class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
 public:
146
  /*!
147 148 149 150 151 152 153
   * \brief The unique identifier of the Var.
   *
   * vid will be preserved for the same Var during type inference
   * and other rewritings, while the VarNode might be recreated
   * to attach additional information.
   * This property can be used to keep track of parameter Var
   * information across passes.
154
   */
155
  Id vid;
156 157 158 159 160 161
  /*!
   * \brief type annotaion of the variable.
   * This field records user provided type annotation of the Var.
   * This field is optional and can be None.
   */
  Type type_annotation;
162

163 164 165 166 167
  /*! \return The name hint of the variable */
  const std::string& name_hint() const {
    return vid->name_hint;
  }

168
  void VisitAttrs(tvm::AttrVisitor* v) {
169
    v->Visit("vid", &vid);
170
    v->Visit("type_annotation", &type_annotation);
雾雨魔理沙 committed
171
    v->Visit("span", &span);
172 173 174
    v->Visit("_checked_type_", &checked_type_);
  }

175 176
  TVM_DLL static Var make(std::string name_hint,
                          Type type_annotation);
177

178 179 180
  TVM_DLL static Var make(Id vid,
                          Type type_annotation);

181 182 183 184 185 186 187
  static constexpr const char* _type_key = "relay.Var";
  TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);

/*!
188
 * \brief Global variable that leaves in the top-level module.
189 190 191 192 193 194 195 196 197 198 199
 * This is used to enable recursive calls between function.
 *
 * \note A GlobalVar may only point to functions.
 */
class GlobalVar;
/*! \brief A GlobalId from the node's current type to target type. */
class GlobalVarNode : public ExprNode {
 public:
  /*! \brief The name of the variable, this only acts as a hint. */
  std::string name_hint;

200
  void VisitAttrs(tvm::AttrVisitor* v) {
201
    v->Visit("name_hint", &name_hint);
雾雨魔理沙 committed
202
    v->Visit("span", &span);
203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static GlobalVar make(std::string name_hint);

  static constexpr const char* _type_key = "relay.GlobalVar";
  TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);

/*!
 * \brief Function (subgraph in computational graph)
 */
class Function;
/*! \brief Function container */
class FunctionNode : public ExprNode {
 public:
  /*! \brief Function parameters */
222
  tvm::Array<Var> params;
223 224 225 226 227 228 229
  /*!
   * \brief
   * The expression which represents the computation of the function,
   * the expression may reference the parameters, and the type of it
   * or sub-expressions may reference the type variables.
   */
  Expr body;
230 231
  /*! \brief User annotated return type of the function. */
  Type ret_type;
232 233 234 235 236 237 238
  /*!
   * \brief Type parameters of the function.
   *  Enables the function to vary its type based on these.
   *  This corresponds to template paramaters in c++'s terminology.
   *
   * \note This can be usually empty for non-polymorphic functions.
   */
239
  tvm::Array<TypeVar> type_params;
240

241 242 243 244 245
  /*!
   * \brief The attributes which store metadata about functions.
   */
  tvm::Attrs attrs;

246
  void VisitAttrs(tvm::AttrVisitor* v) {
247 248
    v->Visit("params", &params);
    v->Visit("body", &body);
249
    v->Visit("ret_type", &ret_type);
250
    v->Visit("type_params", &type_params);
251
    v->Visit("attrs", &attrs);
252
    v->Visit("span", &span);
253 254 255
    v->Visit("_checked_type_", &checked_type_);
  }

256 257 258 259 260 261 262
  /*!
   * \brief Return the derived function annotation of this expression.
   *
   * \return The function type annotation.
   * \note The function type annotation can contain IncompleteType.
   */
  TVM_DLL FuncType func_type_annotation() const;
263

264 265 266 267 268 269 270
  /*!
   * \brief Check whether the function is a primitive function.
   *
   * \return Whether the function is primitive or not.
   */
  bool IsPrimitive() const;

Zhi committed
271 272 273 274 275 276 277 278 279
  /*!
   * \brief Check whether the function should use the TVM default compiler to build, or
   * use other compilers.
   *
   * \return Whether the function will be compiled using the default compiler
   * (e.g. those are used in the TVM stack).
   */
  bool UseDefaultCompiler() const;

280 281
  TVM_DLL static Function make(tvm::Array<Var> params,
                               Expr body,
282
                               Type ret_type,
283 284
                               tvm::Array<TypeVar> ty_params,
                               tvm::Attrs attrs = Attrs());
285

286 287 288 289 290 291 292 293 294 295 296 297 298
  /*!
   * \brief Attach the function's parameters to its attributes for use in analysis.
   * \return The function with its parameters attached.
   */
  Function SetParams(const tvm::Map<Var, Constant>& parameters) const;

  /*!
   * \brief Retrieve the function's parameters.
   *
   * \return The function's parameter.
   */
  tvm::Map<Var, Constant> GetParams() const;

299 300 301 302 303 304
  static constexpr const char* _type_key = "relay.Function";
  TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);

305 306 307 308

TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);

309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350
/*!
 * \brief Call corresponds to operator invocation.
 *  Corresponds to the operator in computational graph terminology.
 */
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
 public:
  /*!
   * \brief The operator(function) being invoked
   *
   *  - It can be relay::Op which corresponds to the primitive operators.
   *  - It can also be user defined functions (Function, GlobalVar, Var).
   */
  Expr op;

  /*! \brief The arguments(inputs) of the call */
  tvm::Array<relay::Expr> args;

  /*! \brief The additional attributes */
  Attrs attrs;

  /*!
   * \brief The type arguments passed to polymorphic(template) function.
   *
   * This is the advance feature that is only used when the function is
   * polymorphic. It is safe to be ignored in most cases. For example, in the
   * following code, the type_args of addone call is [int].
   *
   * \code
   *
   * template<typename T>
   * T addone(T a) { return a + 1; }
   *
   * void main() {
   *   int x = addone<int>(10);
   * }
   *
   * \endcode
   */
  tvm::Array<Type> type_args;

351
  void VisitAttrs(tvm::AttrVisitor* v) {
352 353 354 355 356 357 358 359
    v->Visit("op", &op);
    v->Visit("args", &args);
    v->Visit("attrs", &attrs);
    v->Visit("type_args", &type_args);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

360 361 362
  TVM_DLL static Call make(Expr op,
                           Array<Expr> args,
                           Attrs attrs = Attrs(),
363
                           Array<Type> type_args = Array<Type>());
364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392

  static constexpr const char* _type_key = "relay.Call";
  TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Call, CallNode, Expr);

/*!
 * \brief Let binding that binds a local var and optionally a type annotation.
 *
 * \note Let is useful to transform the program to be A-normal form.
 *  where each of the expression corresponds to a let binding.
 *
 *  For developers who are familar with the computational graph.
 *  Each of the let can be viewed as a operator node in the computational graph.
 *  Traversing the list of let bindings is similar to running
 * PostDFS-order(topo-order) traversal on the computational graph.
 */
class Let;
/*! \brief A binding of a sub-network. */
class LetNode : public ExprNode {
 public:
  /*! \brief The variable we bind to */
  Var var;
  /*! \brief The value we bind var to */
  Expr value;
  /*! \brief The body of the let binding */
  Expr body;

393
  void VisitAttrs(tvm::AttrVisitor* v) {
394 395 396 397 398 399 400
    v->Visit("var", &var);
    v->Visit("value", &value);
    v->Visit("body", &body);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

401
  TVM_DLL static Let make(Var var, Expr value, Expr body);
402 403 404 405 406 407 408 409 410 411 412 413 414 415 416

  static constexpr const char* _type_key = "relay.Let";
  TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Let, LetNode, Expr);

/*!
 * \brief Condition expression
 *
 * Unlike traditional statement `if`s, the if evalutes
 * to the result of the branch taken.
 *
 * let x = if (true) { 1 } else { 0 }; // x is 1
 * let y = if (false) { 1 } else { 0 }; // y is 0
417
 *
418 419 420 421 422 423 424 425 426 427 428 429 430
 * \note This is similar to C's ternary operator.
 */
class If;
/*! \brief container of If */
class IfNode : public ExprNode {
 public:
  /*! \brief The condition */
  Expr cond;
  /*! \brief The expression evaluated when condition is true. */
  Expr true_branch;
  /*! \brief The expression evaluated when condition is false */
  Expr false_branch;

431
  void VisitAttrs(tvm::AttrVisitor* v) {
432 433 434 435 436 437 438 439 440 441 442 443 444 445 446
    v->Visit("cond", &cond);
    v->Visit("true_branch", &true_branch);
    v->Visit("false_branch", &false_branch);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch);

  static constexpr const char* _type_key = "relay.If";
  TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(If, IfNode, Expr);

447
/*! \brief Get index-th field out of a tuple. */
448 449 450
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
 public:
451
  /*! \brief The tuple Expression */
452 453 454 455
  Expr tuple;
  /*! \brief which value to get */
  int index;

456
  void VisitAttrs(tvm::AttrVisitor* v) {
457
    v->Visit("tuple_value", &tuple);
458
    v->Visit("index", &index);
雾雨魔理沙 committed
459
    v->Visit("span", &span);
460
    v->Visit("_checked_type_", &checked_type_);
461 462 463 464
  }

  TVM_DLL static TupleGetItem make(Expr tuple, int index);

465
  static constexpr const char* _type_key = "relay.TupleGetItem";
466 467 468 469 470
  TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

471 472 473 474 475 476 477
/*! \brief Create a new Reference out of initial value. */
class RefCreate;
class RefCreateNode : public ExprNode {
 public:
  /*! \brief The initial value of the Reference. */
  Expr value;

478
  void VisitAttrs(tvm::AttrVisitor* v) {
479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498
    v->Visit("value", &value);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static RefCreate make(Expr value);

  static constexpr const char* _type_key = "relay.RefCreate";
  TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);

/*! \brief Get value out of Reference. */
class RefRead;
class RefReadNode : public ExprNode {
 public:
  /*! \brief The Reference Expression. */
  Expr ref;

499
  void VisitAttrs(tvm::AttrVisitor* v) {
500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520
    v->Visit("ref", &ref);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static RefRead make(Expr ref);

  static constexpr const char* _type_key = "relay.RefRead";
  TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
class RefWriteNode : public ExprNode {
 public:
  /*! \brief The Reference Expression. */
  Expr ref;
  /*! \brief The value to write into. */
  Expr value;

521
  void VisitAttrs(tvm::AttrVisitor* v) {
522 523 524 525 526 527 528 529 530 531 532 533 534 535
    v->Visit("ref", &ref);
    v->Visit("value", &value);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static RefWrite make(Expr ref, Expr value);

  static constexpr const char* _type_key = "relay.RefWrite";
  TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);

536 537 538 539 540 541 542 543
/*!
 * \brief Base class of the temporary expression.
 *
 * TempExprs are pass specific expression that can be
 * useful to define intermediate result in the
 * rewriting pass such as layout or type transformation.
 *
 * Subclass TempExprNode allows us to pattern match on
544
 * specific kind of TempExpr and use them for expression rewriting.
545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561
 *
 * TempExpr should only be used within a pass,
 */
class TempExprNode : public ExprNode {
 public:
  /*!
   * \brief Convert the expression to a normal(non-temp) Expr.
   * \return The corresponding normal(non-temp) expression.
   */
  virtual Expr Realize() const = 0;

  static constexpr const char* _type_key = "relay.TempExpr";
  TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);

562
// implementataions
563
inline const Type& ExprNode::checked_type() const {
564 565 566 567 568
  CHECK(checked_type_.defined())
      << "internal error: the type checker has "
      << "not populated the checked_type "
      << "field for "
      << GetRef<Expr>(this);
569 570 571
  return this->checked_type_;
}

572 573 574 575 576
template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                "TType must be a special case of type");
  CHECK(checked_type_.defined())
577
      << "Type inference for this Expr has not completed. Try to call infer_type pass.";
578 579 580
  const TTypeNode* node = checked_type_.as<TTypeNode>();
  CHECK(node != nullptr)
      << "Expected type to be " << TTypeNode::_type_key
581
      << ", but get " << checked_type_->GetTypeKey();
582 583 584
  return node;
}

585 586 587
/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const NodeRef& node);

588
/*!
589 590
 * \brief Render the node as a string in the Relay text format.
 * \param node The node to be rendered.
591
 * \param show_meta_data Whether to print meta data section.
592 593 594 595
 * \param annotate An optional callback function for attaching
 *        additional comment block to an expr.
 * \return The text representation.
 */
596
std::string AsText(const NodeRef& node,
雾雨魔理沙 committed
597 598
                   bool show_meta_data = true,
                   runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
599

/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
 * \brief Indicate the compiler that should be used for building this function.
 * When this is unset or set to "default", the default compilation pipeline will be used.
 */
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
}  // namespace attr

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_EXPR_H_