/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/expr.h
 * \brief Relay expression language.
 */
#ifndef TVM_RELAY_EXPR_H_
#define TVM_RELAY_EXPR_H_

#include <tvm/attrs.h>
#include <functional>
#include <string>
#include <type_traits>
#include "./base.h"
#include "./type.h"

namespace tvm {
namespace relay {

/*!
 * \brief A Relay expression.
 */
class Expr;
/*!
 * \brief Base type of the Relay expression hierarchy.
 */
class ExprNode : public RelayNode {
 public:
  /*!
   * \brief Stores the result of type inference (type checking).
   *
   * \note This can be undefined before type inference.
   *       This value is discarded during serialization.
   */
  mutable Type checked_type_ = Type(nullptr);
  /*!
   * \return The checked_type
   */
  const Type& checked_type() const;
  /*!
   * \brief Check if the inferred (checked) type of the Expr
   *  is backed by a TTypeNode and return it.
   *
   * \note This function will throw an error if the node type
   *       of this Expr is not TTypeNode.
   *
   * \return The corresponding TTypeNode pointer.
   * \tparam TTypeNode The specific TypeNode we look for.
   */
  template<typename TTypeNode>
  inline const TTypeNode* type_as() const;

  static constexpr const char* _type_key = "relay.Expr";
  TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
};

RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef);

/*!
 * \brief Constant tensor, backed by an NDArray on the cpu(0) device.
 *
 * \note Scalar constants are represented by rank-0 const tensor.
 *  Constant folding are handled uniformly via Tensor types.
 */
class Constant;
/*!
 * \brief Constant tensor type.
 */
class ConstantNode : public ExprNode {
 public:
  /*! \brief The data of the tensor */
  runtime::NDArray data;

  /*! \return The corresponding tensor type of the data */
  TensorType tensor_type() const;

  /*! \return Whether it is scalar(rank-0 tensor) */
94 95 96
  bool is_scalar() const {
    return data->ndim == 0;
  }
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("data", &data);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static Constant make(runtime::NDArray data);

  static constexpr const char* _type_key = "relay.Constant";
  TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr);

/*! \brief Tuple of multiple Exprs */
class Tuple;
/*! \brief Tuple container */
class TupleNode : public ExprNode {
 public:
  /*! \brief the fields of the tuple */
  tvm::Array<relay::Expr> fields;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("fields", &fields);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static Tuple make(tvm::Array<relay::Expr> fields);

  static constexpr const char* _type_key = "relay.Tuple";
  TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);

/*!
 * \brief Local variables used in the let expression.
 *
 * Its semantics are similar to tvm.Var node used in TVM's low level
 * tensor expression language.
 *
140
 * \note Each Var is bind only once and is immutable.
141 142 143 144 145
 */
class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
 public:
146
  /*!
147 148 149 150 151 152 153
   * \brief The unique identifier of the Var.
   *
   * vid will be preserved for the same Var during type inference
   * and other rewritings, while the VarNode might be recreated
   * to attach additional information.
   * This property can be used to keep track of parameter Var
   * information across passes.
154
   */
155
  Id vid;
156 157 158 159 160 161
  /*!
   * \brief type annotaion of the variable.
   * This field records user provided type annotation of the Var.
   * This field is optional and can be None.
   */
  Type type_annotation;
162

163 164 165 166 167
  /*! \return The name hint of the variable */
  const std::string& name_hint() const {
    return vid->name_hint;
  }

168
  void VisitAttrs(tvm::AttrVisitor* v) final {
169
    v->Visit("vid", &vid);
170
    v->Visit("type_annotation", &type_annotation);
雾雨魔理沙 committed
171
    v->Visit("span", &span);
172 173 174
    v->Visit("_checked_type_", &checked_type_);
  }

175 176
  TVM_DLL static Var make(std::string name_hint,
                          Type type_annotation);
177

178 179 180
  TVM_DLL static Var make(Id vid,
                          Type type_annotation);

181 182 183 184 185 186 187
  static constexpr const char* _type_key = "relay.Var";
  TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);

/*!
188
 * \brief Global variable that leaves in the top-level module.
189 190 191 192 193 194 195 196 197 198 199 200 201
 * This is used to enable recursive calls between function.
 *
 * \note A GlobalVar may only point to functions.
 */
class GlobalVar;
/*! \brief A GlobalId from the node's current type to target type. */
class GlobalVarNode : public ExprNode {
 public:
  /*! \brief The name of the variable, this only acts as a hint. */
  std::string name_hint;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("name_hint", &name_hint);
雾雨魔理沙 committed
202
    v->Visit("span", &span);
203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static GlobalVar make(std::string name_hint);

  static constexpr const char* _type_key = "relay.GlobalVar";
  TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);

/*!
 * \brief Function (subgraph in computational graph)
 */
class Function;
/*! \brief Function container */
class FunctionNode : public ExprNode {
 public:
  /*! \brief Function parameters */
222
  tvm::Array<Var> params;
223 224 225 226 227 228 229
  /*!
   * \brief
   * The expression which represents the computation of the function,
   * the expression may reference the parameters, and the type of it
   * or sub-expressions may reference the type variables.
   */
  Expr body;
230 231
  /*! \brief User annotated return type of the function. */
  Type ret_type;
232 233 234 235 236 237 238
  /*!
   * \brief Type parameters of the function.
   *  Enables the function to vary its type based on these.
   *  This corresponds to template paramaters in c++'s terminology.
   *
   * \note This can be usually empty for non-polymorphic functions.
   */
239
  tvm::Array<TypeVar> type_params;
240

241 242 243 244 245
  /*!
   * \brief The attributes which store metadata about functions.
   */
  tvm::Attrs attrs;

246 247 248
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("params", &params);
    v->Visit("body", &body);
249
    v->Visit("ret_type", &ret_type);
250
    v->Visit("type_params", &type_params);
251
    v->Visit("attrs", &attrs);
252
    v->Visit("span", &span);
253 254 255
    v->Visit("_checked_type_", &checked_type_);
  }

256 257 258 259 260 261 262
  /*!
   * \brief Return the derived function annotation of this expression.
   *
   * \return The function type annotation.
   * \note The function type annotation can contain IncompleteType.
   */
  TVM_DLL FuncType func_type_annotation() const;
263

264 265 266 267 268 269 270
  /*!
   * \brief Check whether the function is a primitive function.
   *
   * \return Whether the function is primitive or not.
   */
  bool IsPrimitive() const;

271 272
  TVM_DLL static Function make(tvm::Array<Var> params,
                               Expr body,
273
                               Type ret_type,
274 275
                               tvm::Array<TypeVar> ty_params,
                               tvm::Attrs attrs = Attrs());
276 277 278 279 280 281 282

  static constexpr const char* _type_key = "relay.Function";
  TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);


/*!
 * \brief Get a value associated with a key for a function.
 * \param func The function to query.
 * \param key The key to look up.
 * \return A NodeRef for the key.
 * \note Declaration only. Presumably this reads from the function's attrs;
 *       the result for a missing key is decided by the implementation,
 *       which is not visible here -- confirm before relying on it.
 */
TVM_DLL NodeRef FunctionGetAttr(const Function& func,
                                const std::string& key);
/*!
 * \brief Associate a value with a key for a function.
 * \param func The input function (taken by const reference).
 * \param key The key to set.
 * \param data The value to store under the key.
 * \return A Function; presumably one carrying the updated entry --
 *         confirm against the implementation, which is not visible here.
 */
TVM_DLL Function FunctionSetAttr(const Function& func,
                                 const std::string& key,
                                 const NodeRef& data);


/*!
 * \brief Call corresponds to operator invocation.
 *  Corresponds to the operator in computational graph terminology.
 */
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
 public:
  /*!
   * \brief The operator(function) being invoked
   *
   *  - It can be relay::Op which corresponds to the primitive operators.
   *  - It can also be user defined functions (Function, GlobalVar, Var).
   */
  Expr op;

  /*! \brief The arguments(inputs) of the call */
  tvm::Array<relay::Expr> args;

  /*! \brief The additional attributes */
  Attrs attrs;

  /*!
   * \brief The type arguments passed to polymorphic(template) function.
   *
   * This is the advance feature that is only used when the function is
   * polymorphic. It is safe to be ignored in most cases. For example, in the
   * following code, the type_args of addone call is [int].
   *
   * \code
   *
   * template<typename T>
   * T addone(T a) { return a + 1; }
   *
   * void main() {
   *   int x = addone<int>(10);
   * }
   *
   * \endcode
   */
  tvm::Array<Type> type_args;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("op", &op);
    v->Visit("args", &args);
    v->Visit("attrs", &attrs);
    v->Visit("type_args", &type_args);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

339 340 341
  TVM_DLL static Call make(Expr op,
                           Array<Expr> args,
                           Attrs attrs = Attrs(),
342
                           Array<Type> type_args = Array<Type>());
343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379

  static constexpr const char* _type_key = "relay.Call";
  TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Call, CallNode, Expr);

/*!
 * \brief Let binding that binds a local var and optionally a type annotation.
 *
 * \note Let is useful to transform the program to be A-normal form.
 *  where each of the expression corresponds to a let binding.
 *
 *  For developers who are familar with the computational graph.
 *  Each of the let can be viewed as a operator node in the computational graph.
 *  Traversing the list of let bindings is similar to running
 * PostDFS-order(topo-order) traversal on the computational graph.
 */
class Let;
/*! \brief A binding of a sub-network. */
class LetNode : public ExprNode {
 public:
  /*! \brief The variable we bind to */
  Var var;
  /*! \brief The value we bind var to */
  Expr value;
  /*! \brief The body of the let binding */
  Expr body;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("var", &var);
    v->Visit("value", &value);
    v->Visit("body", &body);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

380
  TVM_DLL static Let make(Var var, Expr value, Expr body);
381 382 383 384 385 386 387 388 389 390 391 392 393 394 395

  static constexpr const char* _type_key = "relay.Let";
  TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(Let, LetNode, Expr);

/*!
 * \brief Condition expression
 *
 * Unlike traditional statement `if`s, the if evaluates
 * to the result of the branch taken.
 *
 * let x = if (true) { 1 } else { 0 }; // x is 1
 * let y = if (false) { 1 } else { 0 }; // y is 0
 *
 * \note This is similar to C's ternary operator.
 */
class If;
/*! \brief container of If */
class IfNode : public ExprNode {
 public:
  /*! \brief The condition */
  Expr cond;
  /*! \brief The expression evaluated when condition is true. */
  Expr true_branch;
  /*! \brief The expression evaluated when condition is false */
  Expr false_branch;

  /*!
   * \brief Expose the fields of this node to an attribute visitor.
   * \param v The visitor; it receives "cond", "true_branch",
   *          "false_branch", "span" and "_checked_type_".
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("cond", &cond);
    v->Visit("true_branch", &true_branch);
    v->Visit("false_branch", &false_branch);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch);

  static constexpr const char* _type_key = "relay.If";
  TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(If, IfNode, Expr);

/*! \brief Get index-th field out of a tuple. */
427 428 429
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
 public:
430
  /*! \brief The tuple Expression */
431 432 433 434 435
  Expr tuple;
  /*! \brief which value to get */
  int index;

  void VisitAttrs(tvm::AttrVisitor* v) final {
436
    v->Visit("tuple_value", &tuple);
437
    v->Visit("index", &index);
雾雨魔理沙 committed
438
    v->Visit("span", &span);
439
    v->Visit("_checked_type_", &checked_type_);
440 441 442 443
  }

  TVM_DLL static TupleGetItem make(Expr tuple, int index);

444
  static constexpr const char* _type_key = "relay.TupleGetItem";
445 446 447 448 449
  TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

/*!
 * \brief Create a new Reference out of initial value.
 */
class RefCreate;
/*!
 * \brief Container for RefCreate.
 */
class RefCreateNode : public ExprNode {
 public:
  /*!
   * \brief The initial value of the Reference.
   */
  Expr value;

  /*!
   * \brief Expose the fields of this node to an attribute visitor.
   * \param v The visitor; it receives "value", "span" and "_checked_type_".
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("value", &value);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  /*!
   * \brief Factory for RefCreate (declaration only in this header).
   * \param value The initial value of the Reference.
   * \return The RefCreate reference.
   */
  TVM_DLL static RefCreate make(Expr value);

  static constexpr const char* _type_key = "relay.RefCreate";
  TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);

/*!
 * \brief Get value out of Reference.
 */
class RefRead;
/*!
 * \brief Container for RefRead.
 */
class RefReadNode : public ExprNode {
 public:
  /*!
   * \brief The Reference Expression.
   */
  Expr ref;

  /*!
   * \brief Expose the fields of this node to an attribute visitor.
   * \param v The visitor; it receives "ref", "span" and "_checked_type_".
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("ref", &ref);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  /*!
   * \brief Factory for RefRead (declaration only in this header).
   * \param ref The Reference Expression.
   * \return The RefRead reference.
   */
  TVM_DLL static RefRead make(Expr ref);

  static constexpr const char* _type_key = "relay.RefRead";
  TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);

/*!
 * \brief Set value of Reference.
 *
 * The whole expression evaluates to an Empty Tuple.
 */
class RefWrite;
/*!
 * \brief Container for RefWrite.
 */
class RefWriteNode : public ExprNode {
 public:
  /*!
   * \brief The Reference Expression.
   */
  Expr ref;
  /*!
   * \brief The value to write into the Reference.
   */
  Expr value;

  /*!
   * \brief Expose the fields of this node to an attribute visitor.
   * \param v The visitor; it receives "ref", "value", "span"
   *          and "_checked_type_".
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("ref", &ref);
    v->Visit("value", &value);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }

  /*!
   * \brief Factory for RefWrite (declaration only in this header).
   * \param ref The Reference Expression.
   * \param value The value to write.
   * \return The RefWrite reference.
   */
  TVM_DLL static RefWrite make(Expr ref, Expr value);

  static constexpr const char* _type_key = "relay.RefWrite";
  TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);

/*!
 * \brief Base class of the temporary expression.
 *
 * TempExprs are pass-specific expressions that can be useful
 * to define intermediate results in a rewriting pass, such as
 * layout or type transformation.
 *
 * Subclassing TempExprNode allows us to pattern match on a
 * specific kind of TempExpr and use it for expression rewriting.
 *
 * \note TempExpr should only be used within a pass.
 */
class TempExprNode : public ExprNode {
 public:
  /*!
   * \brief Convert the expression to a normal (non-temp) Expr.
   *
   * Pure virtual: every concrete subclass must provide it.
   *
   * \return The corresponding normal (non-temp) expression.
   */
  virtual Expr Realize() const = 0;

  static constexpr const char* _type_key = "relay.TempExpr";
  TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);

// implementations
/*!
 * \brief Return the type stored in checked_type_.
 *
 * Fails a CHECK, printing this expression, when checked_type_
 * is not defined.
 *
 * \return A const reference to checked_type_.
 */
inline const Type& ExprNode::checked_type() const {
  CHECK(checked_type_.defined())
      << "internal error: the type checker has not populated the checked_type field for "
      << GetRef<Expr>(this);
  return checked_type_;
}

template<typename TTypeNode>
inline const TTypeNode* ExprNode::type_as() const {
  static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                "TType must be a special case of type");
  CHECK(checked_type_.defined())
556
      << "Type inference for this Expr has not completed. Try to call infer_type pass.";
557 558 559 560 561 562 563
  const TTypeNode* node = checked_type_.as<TTypeNode>();
  CHECK(node != nullptr)
      << "Expected type to be " << TTypeNode::_type_key
      << ", but get " << checked_type_->type_key();
  return node;
}

/*!
565 566
 * \brief Render the node as a string in the Relay text format.
 * \param node The node to be rendered.
567
 * \param show_meta_data Whether to print meta data section.
568 569 570 571
 * \param annotate An optional callback function for attaching
 *        additional comment block to an expr.
 * \return The text representation.
 */
572
std::string AsText(const NodeRef& node,
雾雨魔理沙 committed
573 574
                   bool show_meta_data = true,
                   runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_EXPR_H_