/*!
 *  Copyright (c) 2018 by Contributors
 * \file text_printer.cc
 * \brief Text printer to print relay in text form.
 */
#include <tvm/relay/module.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <cctype>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "type_functor.h"
#include "../../lang/attr_functor.h"

namespace tvm {
namespace relay {

/*!
 * \brief the text value used in text printer.
 * Defined as a struct for future compatibility reason
 */
struct TextValue {
  /*! \brief The str representation */
  std::string name;
  /*! \brief Construct a text value with an empty name. */
  TextValue() {}
  /*!
   * \brief Construct a text value from its string representation.
   * \param name The text to hold; taken by value and moved into place.
   */
  explicit TextValue(std::string name) : name(std::move(name)) {}
  /*!
   * \brief Concatenate with another text value.
   * \param rhs The value whose name is appended.
   * \return A new value holding name followed by rhs.name.
   */
  TextValue operator+(const TextValue& rhs) const {
    return TextValue(name + rhs.name);
  }
  /*!
   * \brief Concatenate with a plain string.
   * \param str The string to append.
   * \return A new value holding name followed by str.
   */
  TextValue operator+(const std::string& str) const {
    return TextValue(name + str);
  }
};

// Stream insertion for TextValue: writes the held name and hands the stream back.
inline std::ostream& operator<<(std::ostream& os, const TextValue& val) {  // NOLINT(*)
  os << val.name;
  return os;
}

/*!
 * \brief Meta data context for TextPrinter.
 *
 * This is an important part to enable bi-directional serializability.
 * We use tvm's Node system to build the current IR.
 * It can be hard to design a text format for all the possible nodes
 * as the set of nodes can grow when we do more extensions.
 *
 * Instead of trying to design readable text format for every node,
 * we support a meta-data section in the text format.
 * We allow the text format to refer to a node in the meta-data section.
 *
 * The meta-data section is a json serialized string of a Map<string, Array<NodeRef>>.
 * Each element in the meta-data section can be referenced by the text format.
 * Each meta data node is printed in the following format.
 *
 * meta[<type-key-of-node>][<index-in-meta-section>]
 *
 * Specifically, consider the following IR(constructed by python).
 *
 * \code
 *
 * n = tvm.var("n")
 * x = tvm.relay.var("x", shape=(n, 1))
 * f = tvm.relay.Function([x], x)
 * print(f.astext())
 *
 * \endcode
 *
 * The corresponding text format is shown in the following code block.
 *
 * \code
 *
 * fn (%x: Tensor[(meta[Variable][0],), float32]) {
 *   %x
 * }
 * # Meta data section is a json-serialized string
 * # of the following array.
 * # [tvm.var("n")]
 *
 * \endcode
 *
 * Note that we store tvm.var("n") in the meta data section.
 * Since it is stored at index 0 in the meta-data section,
 * we print it as meta[Variable][0].
 *
 * The text parser can recover this object by loading from the corresponding
 * location in the meta data section.
 *
 * This is a design trade-off.
 * It allows us to embed any meta-data in the text format,
 * while still being able to tweak the text part of the printed IR easily.
 */
class TextMetaDataContext {
 public:
  /*!
   * \brief Get text representation of meta node.
   * \param node The node to be converted to meta node.
   * \return A string representation of the meta node.
   */
  std::string GetMetaNode(const NodeRef& node) {
101 102 103
    auto it = meta_repr_.find(node);
    if (it != meta_repr_.end()) {
      return it->second;
104
    }
105 106 107 108 109 110 111 112
    Array<NodeRef>& mvector =
        meta_data_[node->type_key()];
    int64_t index = static_cast<int64_t>(mvector.size());
    mvector.push_back(node);
    std::ostringstream os;
    os << "meta[" << node->type_key() << "][" << index << "]";
    meta_repr_[node] = os.str();
    return meta_repr_[node];
113 114 115 116 117 118 119
  }
  /*!
   * \brief Get the metadata section in json format.
   * \return the meta datastring.
   */
  std::string GetMetaSection() const {
    if (meta_data_.size() == 0) return std::string();
120 121
    return SaveJSON(Map<std::string, NodeRef>(
        meta_data_.begin(), meta_data_.end()));
122 123
  }

124 125 126 127 128
  /*! \return whether the meta data context is empty. */
  bool empty() const {
    return meta_data_.empty();
  }

129 130
 private:
  /*! \brief additional metadata stored in TVM json format */
131 132 133
  std::unordered_map<std::string, Array<NodeRef> > meta_data_;
  /*! \brief map from meta data into its string representation */
  std::unordered_map<NodeRef, std::string, NodeHash, NodeEqual> meta_repr_;
134 135 136
};

class TextPrinter :
137
    public ExprFunctor<TextValue(const Expr&)>,
138
    public PatternFunctor<TextValue(const Pattern&)>,
139 140 141
    public TypeFunctor<void (const Type&, std::ostream& os)>,  // NOLINT(*)
    public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
 public:
142 143 144
  explicit TextPrinter(bool show_meta_data,
                       runtime::TypedPackedFunc<std::string(Expr)> annotate)
      : show_meta_data_(show_meta_data), annotate_(annotate) {}
145 146 147 148 149 150 151 152
  /*!
   * \brief Print a node to string.
   * \param node.
   * \return The string representation.
   */
  std::string Print(const NodeRef& node) {
    if (node.as<FunctionNode>()) {
      this->PrintFunc(Downcast<Function>(node));
153 154
    } else if (node.as<ModuleNode>()) {
      this->PrintEnv(Downcast<Module>(node));
155 156 157 158 159 160 161
    } else if (node.as_derived<TypeNode>()) {
      this->PrintType(Downcast<Type>(node), stream_);
    } else if (node.as_derived<ExprNode>()) {
      this->PrintExpr(Downcast<Expr>(node));
    } else {
      stream_ << node;
    }
162 163 164 165 166 167 168 169 170 171 172
    if (!meta_.empty()) {
      if (show_meta_data_) {
        std::string meta_json = meta_.GetMetaSection();
        // append meta data in the end.
        stream_ << "# meta data\n"
                << "r\"\"\"\n"
                << meta_json << "\n"
                << "\"\"\"";
      } else {
        stream_ << "# meta data omitted. you can use show_meta_data=True to include meta-data\n";
      }
173 174 175 176 177
    }
    return stream_.str();
  }

  void PrintFunc(const Function& func) {
178
    this->PrintFuncInternal("fn ", func);
179 180 181
    stream_ << "\n";
  }

182
  void PrintEnv(const Module& mod) {
183
    int counter = 0;
184
    for (const auto& kv : mod->functions) {
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
      std::ostringstream os;
      if (counter++ != 0) {
        stream_ << "\n";
      }
      os << "def @" << kv.first->name_hint;
      this->PrintFuncInternal(os.str(), kv.second);
      stream_ << "\n";
    }
  }

  void PrintExpr(const Expr& expr) {
    TextValue val = GetValue(expr);
    stream_ << val << "\n";
  }

  /*!
   * \brief Get text representation of expr.
   *
   * This function may generate additional instructions
   * in order to compute the final result id of expr.
   *
   * When trying to recursively print out an Expr.
   * The caller should always call GetValue of its children first.
   * Then the caller can print out to stream_ using the obtained value.
   *
   * This is to avoid the call of subsequent GetValue print out
   * additional instructions which get mixed with the partial instruction
   * printed by the caller.
   *
   * \param expr The input expression.
   * \return The text value of Expr.
   */
  TextValue GetValue(const Expr& expr) {
    auto it = memo_.find(expr);
    if (it != memo_.end()) return it->second;
    TextValue val = this->VisitExpr(expr);
    memo_[expr] = val;
    return val;
  }
224 225 226
  TextValue GetValue(const Pattern& p) {
    return this->VisitPattern(p);
  }
227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
  //------------------------------------
  // Overload of Expr printing functions
  //------------------------------------
  TextValue VisitExpr_(const ConstantNode* op) final {
    // Print out simple scalar directly.
    if (op->is_scalar()) {
      std::ostringstream os;
      DataType dtype = TVMType2Type(op->data->dtype);
      CHECK_EQ(op->data->ctx.device_type, kDLCPU);
      if (dtype == Int(32)) {
        return ConstScalar(dtype, static_cast<const int32_t*>(op->data->data));
      } else if (dtype == Int(64)) {
        return ConstScalar(dtype, static_cast<const int64_t*>(op->data->data));
      } else if (dtype == Float(32)) {
        return ConstScalar(dtype, static_cast<const float*>(op->data->data));
      } else if (dtype == Float(64)) {
        return ConstScalar(dtype, static_cast<const double*>(op->data->data));
244 245
      } else if (dtype == Bool()) {
        return ConstScalar(dtype, static_cast<const uint8_t*>(op->data->data));
246 247 248 249 250 251
      }
    }
    // default fall-back, record it as meta node.
    TextValue id = this->AllocTempVar();
    this->PrintIndent();
    stream_ << id << " = " << meta_.GetMetaNode(GetRef<NodeRef>(op));
252 253 254
    this->PrintEndInst("");
    this->PrintOptionalInfo(GetRef<Expr>(op));
    stream_ << '\n';
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273
    return id;
  }

  TextValue VisitExpr_(const TupleNode* op) final {
    std::vector<TextValue> fields;
    for (Expr field : op->fields) {
      fields.push_back(GetValue(field));
    }
    // NOTE: always recursively visit to get ids,
    // before print out the current line
    TextValue id = this->AllocTempVar();
    this->PrintIndent();
    stream_ << id << " = (";
    for (size_t i = 0; i < fields.size(); ++i) {
      stream_ << fields[i];
      if (i + 1 != fields.size()) {
        stream_ << ", ";
      }
    }
274 275 276
    if (fields.size() == 1) {
      stream_ << ',';
    }
277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299
    stream_ << ')';
    this->PrintEndInst("\n");
    return id;
  }

  TextValue VisitExpr_(const VarNode* op) final {
    Var var = GetRef<Var>(op);
    // This is an unbounded var.
    TextValue val = AllocVarName(var);
    this->PrintIndent();
    stream_ << "free_var ";
    this->PrintVarDecl(var, stream_);
    this->PrintEndInst("\n");
    return val;
  }

  TextValue VisitExpr_(const GlobalVarNode* op) final {
    return TextValue('@' + op->name_hint);
  }

  TextValue VisitExpr_(const FunctionNode* op) final {
    TextValue id = AllocTempVar();
    std::ostringstream os;
300
    os << id << " = fn";
301 302 303 304 305 306 307 308 309 310 311
    this->PrintFuncInternal(os.str(), GetRef<Function>(op));
    this->PrintEndInst("\n");
    return id;
  }

  TextValue VisitExpr_(const CallNode* op) final {
    // possibly through meta-data
    std::vector<TextValue> args;
    for (Expr arg : op->args) {
      args.emplace_back(GetValue(arg));
    }
312
    TextValue call_op = GetValue(op->op);
313 314
    TextValue id = this->AllocTempVar();
    this->PrintIndent();
315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331

    stream_ << id << " = " << call_op;

    auto type_args = op->type_args;

    if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) {
      stream_ << "<";
      for (size_t i = 0; i < op->type_args.size(); ++i) {
        this->PrintType(type_args[i], stream_);
        if (i + 1 != type_args.size()) {
          stream_ << ", ";
        }
      }
      stream_ << ">";
    }

    stream_ << "(";
332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371
    for (size_t i = 0; i < args.size(); ++i) {
      stream_ << args[i];
      if (i + 1 != args.size()) {
        stream_ << ", ";
      }
    }
    this->PrintCallAttrs(op->op, op->attrs, stream_);
    stream_ << ")";
    this->PrintEndInst("");
    this->PrintOptionalInfo(GetRef<Expr>(op));
    stream_ << '\n';
    return id;
  }

  TextValue VisitExpr_(const LetNode* op) final {
    TextValue id = this->AllocTempVar();
    this->PrintIndent();
    stream_ << id << " = ";
    this->PrintScope(GetRef<Expr>(op));
    this->PrintEndInst("\n");
    return id;
  }

  TextValue VisitExpr_(const IfNode* op) final {
    TextValue id = this->AllocTempVar();
    this->PrintIndent();
    stream_ << id << " = ";
    this->PrintScope(GetRef<Expr>(op));
    this->PrintEndInst("\n");
    return id;
  }

  TextValue VisitExpr_(const OpNode* op) final {
    return TextValue(op->name);
  }

  TextValue VisitExpr_(const TupleGetItemNode* op) final {
    TextValue tuple = GetValue(op->tuple);
    TextValue id = this->AllocTempVar();
    this->PrintIndent();
372
    stream_ << id << " = " << tuple << "." << op->index << "";
373 374 375 376
    this->PrintEndInst("\n");
    return id;
  }

377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404
  TextValue VisitExpr_(const RefCreateNode* op) final {
    TextValue value = GetValue(op->value);
    TextValue id = this->AllocTempVar();
    this->PrintIndent();
    stream_ << id << " = " << "RefCreate(" << op->value << ")";
    this->PrintEndInst("\n");
    return id;
  }

  TextValue VisitExpr_(const RefReadNode* op) final {
    TextValue ref = GetValue(op->ref);
    TextValue id = this->AllocTempVar();
    this->PrintIndent();
    stream_ << id << " = " << "RefRead(" << ref << ")";
    this->PrintEndInst("\n");
    return id;
  }

  TextValue VisitExpr_(const RefWriteNode* op) final {
    TextValue ref = GetValue(op->ref);
    TextValue value = GetValue(op->value);
    TextValue id = this->AllocTempVar();
    this->PrintIndent();
    stream_ << id << " = " << "RefWrite(" << ref << ", " << value << ")";
    this->PrintEndInst("\n");
    return id;
  }

405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434
  TextValue VisitExpr_(const MatchNode* op) final {
    TextValue data = GetValue(op->data);
    this->PrintIndent();
    TextValue id = this->AllocTempVar();
    stream_ << id << " = " << "Match " << data << " with";
    this->PrintEndInst("\n");
    for (const auto& c : op->clauses) {
      this->PrintIndent();
      stream_ << GetValue(c->lhs) << " to " << GetValue(c->rhs);
      this->PrintEndInst("\n");
    }
    return id;
  }

  TextValue VisitPattern_(const PatternConstructorNode* p) final {
    TextValue ret(p->constructor->name_hint + "(");
    for (const Pattern& pat : p->patterns) {
      ret = ret + " " + GetValue(pat);
    }
    return ret + ")";
  }

  TextValue VisitPattern_(const PatternVarNode* pv) final {
    return GetValue(pv->var);
  }

  TextValue VisitExpr_(const ConstructorNode* n) final {
    return TextValue(n->name_hint);
  }

435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465
  /*!
   * \brief Print the type to os
   * \param type The type to be printed.
   * \param os The output type.
   */
  void PrintType(const Type& type, std::ostream& os) {  // NOLINT(*)
    this->VisitType(type, os);
  }
  //------------------------------------
  // Overload of Expr printing functions
  //------------------------------------
  void VisitType_(const TensorTypeNode* node, std::ostream& os) final {  // NOLINT(*)
    // scalar type
    if (node->shape.size() == 0) {
      os << runtime::TVMType2String(Type2TVMType(node->dtype));
      return;
    }
    os << "Tensor[(";
    for (size_t i = 0; i < node->shape.size(); ++i) {
      this->PrintAttr(node->shape[i], os);
      if (i + 1 != node->shape.size()) {
        os << ", ";
      }
    }
    // conform to python tuple format (1,)
    if (node->shape.size() == 1) {
      os << ",";
    }
    os << "), " << runtime::TVMType2String(Type2TVMType(node->dtype)) << "]";
  }

466 467 468 469 470 471 472 473 474 475 476
  void VisitType_(const TupleTypeNode* node, std::ostream& os) final {  // NOLINT(*)
    os << "Tuple[";
    for (size_t i = 0; i < node->fields.size(); ++i) {
      this->PrintType(node->fields[i], os);
      if (i + 1 != node->fields.size()) {
        os << ", ";
      }
    }
    os << "]";
  }

477 478 479 480
  void VisitType_(const RefTypeNode* node, std::ostream& os) final {
    VisitTypeDefault_(node, os);
  }

481 482 483 484 485 486 487 488 489 490 491 492
  void VisitType_(const TypeCallNode* node, std::ostream& os) final {
    os << node->func << "(" << node->args << ")";
  }

  void VisitType_(const GlobalTypeVarNode* node, std::ostream& os) final {
    VisitTypeDefault_(node, os);
  }

  void VisitType_(const TypeDataNode* node, std::ostream& os) final {
    VisitTypeDefault_(node, os);
  }

493 494 495 496 497 498 499 500 501 502 503
  void VisitTypeDefault_(const Node* node, std::ostream& os) final {  // NOLINT(*)
    // by default always print as meta-data
    os << meta_.GetMetaNode(GetRef<NodeRef>(node));
  }

  /*!
   * \brief Print an attribute value to os.
   * \param value The value to be printed.
   * \param os The output type.
   */
  void PrintAttr(const NodeRef& value, std::ostream& os) {  // NOLINT(*)
504 505 506 507 508
    if (value.defined()) {
      this->VisitAttr(value, os);
    } else {
      os << "None";
    }
509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622
  }
  //------------------------------------
  // Overload of Attr printing functions
  //------------------------------------
  void VisitAttr_(const ArrayNode* op, std::ostream& os) final {  // NOLINT(*)
    os << "[";
    for (size_t i = 0; i < op->data.size(); ++i) {
      this->PrintAttr(NodeRef(op->data[i]), os);
      if (i + 1 != op->data.size()) {
        os << ", ";
      }
    }
    os << "]";
  }
  void VisitAttrDefault_(const Node* op, std::ostream& os) final { // NOLINT(*)
    os << meta_.GetMetaNode(GetRef<NodeRef>(op));
  }

  void VisitAttr_(const ir::IntImm* op, std::ostream& os) final {  // NOLINT(*)
    this->PrintConstScalar(op->type, &(op->value), os);
  }

  void VisitAttr_(const ir::UIntImm* op, std::ostream& os) final {  // NOLINT(*)
    this->PrintConstScalar(op->type, &(op->value), os);
  }

  void VisitAttr_(const ir::FloatImm* op, std::ostream& os) final {  // NOLINT(*)
    this->PrintConstScalar(op->type, &(op->value), os);
  }

  void VisitAttr_(const ir::StringImm* op, std::ostream& os) final {  // NOLINT(*)
    this->PrintString(op->value, os);
  }

 protected:
  /*!
   * \brief Print attributes after call.
   * \param op The operator to be called.
   * \param attrs The attributes.
   * \param os The output stream.
   */
  void PrintCallAttrs(const Expr& op, const Attrs& attrs, std::ostream& os);  // NOLINT(*)

  /*!
   * \brief Print the a new scopr.
   * \param body The body.
   */
  void PrintScope(Expr body) {
    stream_ << "{\n";
    int sid = this->BeginScope();
    this->PrintScopeBody(body);
    this->EndScope(sid);
    this->PrintIndent();
    stream_ << "}";
  }
  /*!
   * \brief Print the body of a new scope without {}
   *
   * This function will keep printing continuous sequence
   * of let/if scope without introducing a new scope in the text.
   *
   * \param body The body.
   */
  void PrintScopeBody(Expr body) {
    if (const LetNode* let = body.as<LetNode>()) {
      TextValue value = GetValue(let->value);
      AllocVarName(let->var);
      // let var = value;
      this->PrintIndent();
      stream_ << "let ";
      this->PrintVarDecl(let->var, stream_);
      stream_ << " = " << value;
      this->PrintEndInst("\n");
      this->PrintScopeBody(let->body);
    } else if (const IfNode* ifnode = body.as<IfNode>()) {
      TextValue cond = GetValue(ifnode->cond);
      this->PrintIndent();
      stream_ << "if (" << cond << ") ";
      this->PrintScope(ifnode->true_branch);
      this->PrintIndent();
      stream_ << "else ";
      this->PrintScope(ifnode->false_branch);
      this->PrintEndInst("\n");
    } else {
      TextValue value = GetValue(body);
      this->PrintIndent();
      stream_ << value;
      this->PrintEndInst("\n");
    }
  }

  /*!
   * \brief Internal function to print a function argument list and its body.
   * \param prefix The prefix before argument list.
   * \param fn The function to be printed.
   */
  void PrintFuncInternal(std::string prefix, const Function& fn) {
    // TODO(tqchen, M.K.) support generic function
    // Possibly through meta-data
    CHECK_EQ(fn->type_params.size(), 0U)
        << "generic fn not yet supported";
    this->PrintIndent();
    stream_ << prefix << "(";
    size_t decl_indent = prefix.length() + 1;
    for (size_t i = 0; i < fn->params.size(); ++i) {
      if (i != 0) {
        this->PrintIndent(decl_indent);
      }
      AllocVarName(fn->params[i]);
      this->PrintVarDecl(fn->params[i], stream_);
      if (i + 1 != fn->params.size()) {
        stream_ << ",\n";
      }
    }
623
    stream_ << ')';
624
    if (fn->ret_type.defined()) {
625 626 627
      stream_ << '\n';
      this->PrintIndent(decl_indent);
      stream_ << "-> ";
628 629
      this->PrintType(fn->ret_type, stream_);
    }
630
    stream_ << ' ';
631 632 633 634 635 636 637 638
    this->PrintScope(fn->body);
  }
  /*!
   * \brief Print additional info about expr in comment.
   * \param expr The expression.
   */
  void PrintOptionalInfo(const Expr& expr) {
    // additional information in comment.
639 640 641
    if (annotate_ != nullptr) {
      stream_ << " # " << annotate_(expr);
    } else if (expr->checked_type_.defined()) {
642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778
      stream_ << " # ty=";
      this->PrintType(expr->checked_type(), stream_);
    }
  }
  /*!
   * \brief print var_name[:type]
   * \param var The variable to be printed
   * \param os The output stream
   */
  void PrintVarDecl(const Var& var, std::ostream& os) {  // NOLINT(*)
    TextValue v = GetValue(var);
    os << v;
    if (var->type_annotation.defined()) {
      os << ": ";
      this->PrintType(var->type_annotation, os);
    }
  }
  /*!
   * \brief Get a constant scalar value.
   * \param dtype The data type.
   * \param data The pointer to the data.
   * \tparam T the content data type holding the data.
   */
  template<typename T>
  TextValue ConstScalar(DataType dtype, const T* data) {
    std::ostringstream os;
    PrintConstScalar(dtype, data, os);
    return TextValue(os.str());
  }
  /*!
   * \brief special method to print out const scalar
   * \param dtype The data type
   * \param data The pointer to hold the data.
   * \param os The output stream.
   */
  template<typename T>
  void PrintConstScalar(DataType dtype, const T* data, std::ostream& os) {  // NOLINT(*)
    if (dtype == Int(32)) {
      os << data[0];
    } else if (dtype == Float(32)) {
      os << data[0] << 'f';
    } else if (dtype == Bool()) {
      PrintBool(data[0] != 0, os);
    } else {
      os << dtype << "(" << data[0] << ")";
    }
  }
  /*!
   * \brief Print constant bool value.
   * \param value The value to be printed.
   * \param os The output stream
   */
  void PrintBool(bool value, std::ostream& os) { // NOLINT(*)
    if (value) {
      os << "True";
    } else {
      os << "False";
    }
  }
  /*!
   * \brief Print constant string.
   * \param value The value to be printed.
   * \param os The output stream
   */
  void PrintString(const std::string& value, std::ostream& os) { // NOLINT(*)
    // TODO(M.K.): add escape.
    os << "\"" << value << "\"";
  }
  /*!
   * \brief get a unique name with the corresponding prefix
   * \param prefix The prefix of the name
   * \return The returned name.
   */
  std::string GetUniqueName(std::string prefix) {
    auto it = name_alloc_map_.find(prefix);
    if (it != name_alloc_map_.end()) {
      while (true) {
        std::ostringstream os;
        os << prefix << (++it->second);
        std::string name = os.str();
        if (name_alloc_map_.count(name) == 0) {
          prefix = name;
          break;
        }
      }
    }
    name_alloc_map_[prefix] = 0;
    return prefix;
  }
  /*!
   * \brief mark the beginning of a new scope
   * \return The scope id.
   */
  int BeginScope() {
    int sid = static_cast<int>(scope_valid_.size());
    scope_valid_.push_back(true);
    indent_ += 2;
    return sid;
  }
  /*!
   * \brief mark the end of an old scope.
   * \param scope_id The scope id to be ended.
   */
  void EndScope(int scope_id) {
    scope_valid_[scope_id] = false;
    indent_ -= 2;
  }
  /*!
   * \brief Print the indent to the stream.
   * \param more_indent More indentation besides the current one.
   */
  void PrintIndent(int64_t more_indent = 0) {
    for (int i = 0; i < indent_ + more_indent; ++i) {
      stream_ << ' ';
    }
  }
  /*!
   * \brief print end of the line.
   */
  void PrintEndInst(const char* suffix) {
    stream_ << suffix;
  }
  /*!
   * \brief Allocate temporary value
   * \return A new text value.
   */
  TextValue AllocTempVar() {
    std::ostringstream os;
    os << '%' << temp_var_counter_++;
    return TextValue(os.str());
  }
  /*!
   * \brief Allocate name to a variable.
   * \param var The input variable.
   * \return The corresponding name.
   */
  TextValue AllocVarName(const Var& var) {
779
    std::string name = var->name_hint();
780 781 782 783 784 785 786
    // always make sure first name is alpha
    if (name.length() != 0 && !std::isalpha(name[0])) {
      name = "%v" + name;
    } else {
      name = "%" + name;
    }
    TextValue val(GetUniqueName(name));
787 788 789 790
    // still print if ir is malformed, but show the error.
    if (memo_.count(var)) {
      memo_[var] = TextValue(val.name + "-malformed-ir");
    }
791 792 793 794 795 796 797
    memo_[var] = val;
    return val;
  }

 private:
  class AttrPrinter;
  friend class AttrPrinter;
798 799
  /*! \brief Whether to print meta data. */
  bool show_meta_data_;
800 801
  /*! \brief additional comment function */
  runtime::TypedPackedFunc<std::string(Expr)> annotate_;
802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891
  /*! \brief meta data context */
  TextMetaDataContext meta_;
  /*! \brief Check whether scope is still valid */
  std::vector<bool> scope_valid_;
  /*! \brief The current indentation value */
  int indent_{0};
  /*! \brief name allocation map */
  std::unordered_map<std::string, int> name_alloc_map_;
  /*! \brief Map from expression to its text value */
  std::unordered_map<Expr, TextValue, NodeHash, NodeEqual> memo_;
  /*! \brief counter of temporary variable */
  int64_t temp_var_counter_{0};
  /*! \brief Output stream */
  std::ostringstream stream_;
};

/*!
 * \brief Attribute visitor that writes each visited attribute of a call
 *  as ", key=value" to the given stream.
 */
class TextPrinter::AttrPrinter: public AttrVisitor {
 public:
  AttrPrinter(std::ostream& stream, TextPrinter* parent)  // NOLINT(*)
      : stream_(stream), parent_(parent) {}

  void Visit(const char* key, double* value) final {
    PrintKey(key);
    stream_ << *value;
  }
  void Visit(const char* key, int64_t* value) final {
    PrintKey(key);
    stream_ << *value;
  }
  void Visit(const char* key, uint64_t* value) final {
    PrintKey(key);
    stream_ << *value;
  }
  void Visit(const char* key, int* value) final {
    PrintKey(key);
    stream_ << *value;
  }
  void Visit(const char* key, bool* value) final {
    PrintKey(key);
    parent_->PrintBool(*value, stream_);
  }
  void Visit(const char* key, std::string* value) final {
    PrintKey(key);
    parent_->PrintString(*value, stream_);
  }
  void Visit(const char* key, void** value) final {
    LOG(FATAL) << "do not allow void as argument";
  }
  void Visit(const char* key, DataType* value) final {
    // data types are printed as quoted type strings.
    PrintKey(key);
    parent_->PrintString(runtime::TVMType2String(Type2TVMType(*value)), stream_);
  }
  void Visit(const char* key, NodeRef* value) final {
    PrintKey(key);
    parent_->PrintAttr(*value, stream_);
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    LOG(FATAL) << "do not allow NDarray as argument";
  }

 private:
  /*! \brief Write the separator followed by "key=". */
  void PrintKey(const char* key) {
    stream_ << ", ";
    stream_ << key << "=";
  }
  /*! \brief The stream to write to. */
  std::ostream& stream_;  // NOLINT(*)
  /*! \brief The printer used for bool, string and node values. */
  TextPrinter* parent_;
};

// Print the attributes of a call after its arguments. Attributes whose type
// index matches the one declared by the operator are printed field by field
// (non-default fields only); anything else becomes a meta data reference.
void TextPrinter::PrintCallAttrs(const Expr& op,
                                 const Attrs& attrs,
                                 std::ostream& os) {  // NOLINT(*)
  if (!attrs.defined()) {
    return;
  }
  const auto* op_node = op.as<OpNode>();
  const bool attrs_match_op =
      op_node != nullptr && attrs->type_index() == op_node->attrs_type_index;
  if (!attrs_match_op) {
    os << ", " << meta_.GetMetaNode(attrs);
    return;
  }
  AttrPrinter printer(os, this);
  // the visitor interface is non-const, hence the cast.
  const_cast<BaseAttrsNode*>(attrs.operator->())
      ->VisitNonDefaultAttrs(&printer);
}

892
std::string RelayPrint(const NodeRef& node,
893
                       bool show_meta_data,
894
                       runtime::TypedPackedFunc<std::string(Expr)> annotate) {
895
  return TextPrinter(show_meta_data, annotate).Print(node);
896 897
}

898 899
TVM_REGISTER_API("relay._expr.RelayPrint")
.set_body_typed<std::string(
900
    const NodeRef&, bool,
901
    runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);
902 903 904

}  // namespace relay
}  // namespace tvm