pretty_printer.cc 23.3 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file pretty_printer.cc
 * \brief Pretty printer for Relay programs
 * Supports ANF, GNF, and metadata.
 *
 * Inlining heuristics:
 *  - Always inline:
 *    - GlobalVar
 *    - Constant
 *    - Op
 *    - Var
 *  - Otherwise, inline if the node is used at most once and is either the
 *    result of a scope or the value bound by a let.
 */
34

35 36 37 38 39
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
#include "type_functor.h"
40
#include "../pass/dependency_graph.h"
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
#include "../../lang/attr_functor.h"

namespace tvm {
namespace relay {

/*!
 * \brief Meta data context for PrettyPrinter.
 *
 * This is an important part to enable bi-directional serializability.
 * We use tvm's Node system to build the current IR.
 * It can be hard to design a text format for all the possible nodes
 * as the set of nodes can grow when we do more extensions.
 *
 * Instead of trying to design readable text format for every node,
 * we support a meta data section in the text format.
 * We allow the text format to refer to a node in the meta data section.
 *
 * The meta data section is a JSON-serialized string of a Map<string, Array<NodeRef>>.
 * Each element in the meta data section can be referenced by the text format.
 * Each meta data node is printed in the following format.
 *
 * meta[<type-key-of-node>][<index-in-meta-section>]
 *
 * Specifically, consider the following IR (constructed by Python).
 *
 * \code
 *
 * n = tvm.var("n")
 * x = tvm.relay.var("x", shape=(n, 1))
 * f = tvm.relay.Function([x], x)
 * print(f.astext())
 *
 * \endcode
 *
 * The corresponding text format is shown in the following code block.
 *
 * \code
 *
 * fn (%x: Tensor[(meta[Variable][0],), float32]) {
 *   %x
 * }
 * # Meta data section is a json-serialized string
 * # of the following array.
 * # [tvm.var("n")]
 *
 * \endcode
 *
 * Note that we store tvm.var("n") in the meta data section.
 * Since it is stored in the index-0 in the meta data section,
 * we print it as meta[Variable][0].
 *
 * The text parser can recover this object by loading from the corresponding
 * location in the meta data section.
 *
 * This is a design trade-off.
 * It allows us to embed any meta data in the text format,
 * while still being able to tweak the text part of the printed IR easily.
 */
class TextMetaDataContext {
 public:
  /*!
   * \brief Get text representation of meta node.
   * \param node The node to be converted to meta node.
   * \return A string representation of the meta node.
   */
  Doc GetMetaNode(const NodeRef& node) {
    auto it = meta_repr_.find(node);
    if (it != meta_repr_.end()) {
      return it->second;
    }
    Array<NodeRef>& mvector =
        meta_data_[node->type_key()];
    int64_t index = static_cast<int64_t>(mvector.size());
    mvector.push_back(node);
    Doc doc;
    doc << "meta[" << node->type_key() << "][" << index << "]";
    meta_repr_[node] = doc;
    return meta_repr_[node];
  }
  /*!
   * \brief Get the metadata section in json format.
   * \return the meta data string.
   */
  std::string GetMetaSection() const {
    if (meta_data_.size() == 0) return std::string();
    return SaveJSON(Map<std::string, NodeRef>(
        meta_data_.begin(), meta_data_.end()));
  }

  /*! \return whether the meta data context is empty. */
  bool empty() const {
    return meta_data_.empty();
  }

 private:
  /*! \brief additional metadata stored in TVM json format */
  std::unordered_map<std::string, Array<NodeRef> > meta_data_;
  /*! \brief map from meta data into its string representation */
  std::unordered_map<NodeRef, Doc, NodeHash, NodeEqual> meta_repr_;
};

class PrettyPrinter :
    public ExprFunctor<Doc(const Expr&)>,
    public PatternFunctor<Doc(const Pattern&)>,
    public TypeFunctor<Doc(const Type&)>,
    public AttrFunctor<Doc(const NodeRef&)> {
 public:
  /*!
   * \brief Create a printer.
   * \param show_meta_data If true, PrintFinal appends the JSON meta data
   *        section; otherwise it appends a note that it was omitted.
   * \param annotate Optional callback; when set, its result is printed as a
   *        comment instead of the checked type (see PrintOptionalInfo).
   */
  explicit PrettyPrinter(bool show_meta_data,
                         runtime::TypedPackedFunc<std::string(Expr)> annotate) :
                         show_meta_data_(show_meta_data),
                         annotate_(annotate) {}

  /*!
   * \brief Print additional info about expr in comment.
   *
   * Uses the annotate_ callback when present, otherwise the checked type when
   * it is defined; prints nothing if neither is available.
   * \param expr The expression.
   * \return The comment doc (possibly empty).
   */
  Doc PrintOptionalInfo(const Expr& expr) {
    Doc doc;
    // additional information in comment.
    if (annotate_ != nullptr) {
      return doc << " /* " << annotate_(expr) << " */";
    } else if (expr->checked_type_.defined()) {
      return doc << " /* ty=" << Print(expr->checked_type()) << " */";
    } else {
      return doc;
    }
  }

  /*!
   * \brief Print node as a braced, indented body with its own GNF scope.
   * \param node The node to print.
   * \param indent The indentation amount passed to Indent().
   * \return "{", the indented scope, and "}".
   */
  // TODO(jmp): indent should be an instance variable of the printer
  Doc PrintBody(const NodeRef& node, int indent = 2) {
    Doc doc;
    Doc body;
    doc << "{";
    doc << Indent(indent, body << "\n" << PrintScope(node)) << "\n";
    doc << "}";
    return doc;
  }

  /*!
   * \brief Print node in a new scope.
   *
   * A fresh doc is pushed onto doc_stack_, so bindings hoisted while printing
   * node are collected there and emitted just before node's own doc. This
   * keeps hoisted vars from escaping too far.
   * \param node The node to print.
   * \return The hoisted bindings followed by the printed node.
   */
  Doc PrintScope(const NodeRef& node) {
    // print in a new scope
    doc_stack_.push_back(Doc());
    // must print first so doc_stack_.back() reference doesn't become stale
    Doc doc = Print(node, false, true);
    doc = doc_stack_.back() << doc;
    doc_stack_.pop_back();
    return doc;
  }

  /*!
   * \brief Top-level entry: print node followed by the meta data section.
   *
   * When node is an Expr, the dependency graph used for inlining decisions
   * is built first.
   * \param node The node to print.
   * \return The complete printed program.
   */
  Doc PrintFinal(const NodeRef& node) {
    if (node.as_derived<ExprNode>()) {
      Expr expr = Downcast<Expr>(node);
      dg_ = DependencyGraph::Create(&arena_, expr);
    }

    Doc doc;
    doc << PrintScope(node);
    if (!meta_.empty()) {
      if (show_meta_data_) {
        std::string meta_json = meta_.GetMetaSection();
        // append meta data in the end.
        doc << "\n" << "/* meta data */" << "\n" << meta_json;
      } else {
        doc << "\n"
            << "// meta data omitted. you can use show_meta_data=True to include meta data";
      }
    }
    return doc;
  }

  Doc PrintAttrs(const Attrs& attrs, const Expr& op);

  /*!
   * \brief Dispatch printing on the node's category (Expr, Type, Module, other).
   * \param node The node to print.
   * \param meta Whether an Expr/Type should be printed as a meta data reference.
   * \param try_inline Whether an Expr may be inlined when used at most once.
   * \return The printed node.
   */
  Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
    if (node.as_derived<ExprNode>()) {
      return PrintExpr(Downcast<Expr>(node), meta, try_inline);
    } else if (node.as_derived<TypeNode>()) {
      return PrintType(Downcast<Type>(node), meta);
    } else if (node.as_derived<ModuleNode>()) {
      return PrintMod(Downcast<Module>(node));
    } else {
      Doc doc;
      return doc << node;
    }
  }

  /*! \brief Doc for the n-th temporary variable, "%n". */
  Doc TempVar(int n) {
    Doc doc;
    return doc << "%" << n;
  }

  /*! \brief Allocate the next unused temporary variable. */
  Doc AllocTemp() {
    return TempVar(static_cast<int>(temp_var_counter_++));
  }

  /*!
   * \brief get a unique name with the corresponding prefix
   * \param prefix The prefix of the name
   * \return The returned name.
   */
  Doc GetUniqueName(const std::string& prefix) {
    std::string unique_prefix = prefix;
    auto it = name_alloc_map_.find(prefix);
    if (it != name_alloc_map_.end()) {
      while (true) {
        std::ostringstream os;
        os << prefix << (++it->second);
        std::string name = os.str();
        if (name_alloc_map_.count(name) == 0) {
          unique_prefix = name;
          break;
        }
      }
    }
    name_alloc_map_[unique_prefix] = 0;
    return Doc(unique_prefix);
  }

  /*!
   * \brief Print a type kind.
   * \param k The kind.
   * \return The kind's name.
   */
  Doc Print(Kind k) {
    switch (k) {
    case kType:
      return Doc("Type");
    case kShapeVar:
      return Doc("Shape");
    case kBaseType:
      return Doc("BaseType");
    case kConstraint:
      return Doc("Constraint");
    case kAdtHandle:
      return Doc("AdtHandle");
    case kTypeData:
      return Doc("TypeData");
    default:
      // A bare `throw;` with no exception in flight calls std::terminate;
      // LOG(FATAL) goes through the logging framework's fatal-error path.
      LOG(FATAL) << "Unknown Kind";
      return Doc();
    }
  }

  /*!
   * \brief Allocate name to a type variable.
   * \param var The input type variable.
   * \return The corresponding name.
   */
  Doc AllocTypeVar(const TypeVar& var) {
    std::string name = var->var->name_hint;
    // always make sure first name is alpha; the cast avoids undefined
    // behavior in std::isalpha for negative (non-ASCII) char values.
    if (name.length() == 0 || !std::isalpha(static_cast<unsigned char>(name[0]))) {
      name = "t" + name;
    }
    Doc val = GetUniqueName("%" + name);
    // still print if ir is malformed, but show the error.
    if (memo_type_.count(var)) {
      val << "-malformed-ir";
    }
    // memoize the bare name, before the kind suffix is appended.
    memo_type_[var] = val;
    if (var->kind != kType) {
      val << ": " << Print(var->kind);
    }
    return val;
  }

  /*!
   * \brief Allocate name to a variable.
   * \param var The input variable.
   * \return The corresponding name.
   */
  Doc AllocVar(const Var& var) {
    std::string name = var->name_hint();
    // always make sure first name is alpha; the cast avoids undefined
    // behavior in std::isalpha for negative (non-ASCII) char values.
    if (name.length() == 0 || !std::isalpha(static_cast<unsigned char>(name[0]))) {
      name = "v" + name;
    }
    Doc val = GetUniqueName("%" + name);
    // still print if ir is malformed, but show the error.
    if (memo_.count(var)) {
      val << "-malformed-ir";
    }
    // memoize the bare name, before the type annotation is appended.
    memo_[var] = val;
    if (var->type_annotation.defined()) {
      val << ": " << Print(var->type_annotation);
    }
    return val;
  }

  /*! \brief Whether expr has at most one parent in the dependency graph. */
  bool IsUnique(const Expr& expr) {
    auto* node = dg_.expr_node.at(expr);
    return !(node->parents.head && node->parents.head->next);
  }

  /*! \brief Whether expr is always printed inline regardless of use count. */
  bool AlwaysInline(const Expr& expr) {
    return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
           expr.as<OpNode>() || expr.as<VarNode>();
  }

  //------------------------------------
  // Overload of Expr printing functions
  //------------------------------------
  /*!
   * \brief Print an expression, binding it to a temp var unless it is inlined.
   * \param expr The expression.
   * \param meta Whether to print expr as a meta data reference.
   * \param try_inline Whether expr may be inlined when it is used at most once.
   * \return Either the printed expression itself or the name bound to it.
   */
  Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
    // Exploit memoization to print GNF.
    // The first time we visit an expression, we need to allocate a temp var
    // for it. Every subsequent time we can just use its assigned variable.
    // This works since hashing uses pointer equality.
    // The memo is checked first: a memoized expr needs no inlining decision.
    auto it = memo_.find(expr);
    if (it != memo_.end()) return it->second;

    // determine whether to inline. IsUnique is consulted only when it can
    // change the answer, since expr_node.at() fails for exprs absent from dg_.
    bool inline_expr = AlwaysInline(expr);
    if (try_inline && !inline_expr) {
      inline_expr = IsUnique(expr);
    }

    Doc printed_expr;
    if (meta) {
      printed_expr = meta_.GetMetaNode(GetRef<NodeRef>(expr.get()));
    } else if (!inline_expr && expr.as<LetNode>()) {
      // wrap GNFed let in brackets
      Doc body;
      printed_expr << "{";
      printed_expr << Indent(2, body << "\n" << VisitExpr(expr)) << "\n";
      printed_expr << "}";
    } else {
      printed_expr = VisitExpr(expr);
    }

    if (expr.as<CallNode>()) {
      printed_expr << PrintOptionalInfo(expr);
    }

    // add expr to doc
    if (expr.as<VarNode>()) {
      // This is our first time visiting the var and we hit the VarNode case
      // in the visitor. Thus the variable is free.
      doc_stack_.back() << "free_var " << printed_expr << "\n";
      // Memoization is done in AllocVar.
      return memo_[expr];
    } else if (inline_expr) {
      memo_[expr] = printed_expr;
      return printed_expr;
    } else {
      Doc temp_var = AllocTemp();
      memo_[expr] = temp_var;
      doc_stack_.back() << temp_var << " = " << printed_expr << "\n";
      return temp_var;
    }
  }

  // Should only be triggered when op is a free variable being visited for the
  // first time.
  Doc VisitExpr_(const VarNode* op) final {
    return AllocVar(GetRef<Var>(op));
  }

  Doc VisitExpr_(const ConstantNode* op) final {
    // Print out simple scalars directly.
    if (op->is_scalar()) {
      DataType dtype = TVMType2Type(op->data->dtype);
      CHECK_EQ(op->data->ctx.device_type, kDLCPU);
      if (dtype == Int(32)) {
        return PrintConstScalar(dtype, static_cast<const int32_t*>(op->data->data));
      } else if (dtype == Int(64)) {
        return PrintConstScalar(dtype, static_cast<const int64_t*>(op->data->data));
      } else if (dtype == Float(32)) {
        return PrintConstScalar(dtype, static_cast<const float*>(op->data->data));
      } else if (dtype == Float(64)) {
        return PrintConstScalar(dtype, static_cast<const double*>(op->data->data));
      } else if (dtype == Bool()) {
        return PrintConstScalar(dtype, static_cast<const uint8_t*>(op->data->data));
      }
    }
    // default fall-back, record it as meta node.
    Doc doc;
    return doc << Print(GetRef<NodeRef>(op), true)
               << PrintOptionalInfo(GetRef<Expr>(op));
  }

  Doc VisitExpr_(const TupleNode* op) final {
    std::vector<Doc> fields;
    for (Expr field : op->fields) {
      fields.push_back(Print(field));
    }
    Doc doc;
    doc << "(" << PrintVec(fields);
    // conform to python tuple format (1,)
    if (op->fields.size() == 1) {
      doc << ",";
    }
    return doc << ")";
  }

  Doc VisitExpr_(const TupleGetItemNode* op) final {
    Doc doc;
    return doc << Print(op->tuple) << "." << op->index;
  }

  Doc VisitExpr_(const IfNode* op) final {
    Doc doc;
    doc << "if (" << Print(op->cond) << ") ";
    doc << PrintBody(op->true_branch);
    doc << " else ";
    doc << PrintBody(op->false_branch);
    return doc;
  }

  Doc VisitExpr_(const LetNode* op) final {
    Doc doc;
    doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << "\n";
    // we use a scope here so GNF hoisting doesn't escape too far
    // and nested, unique lets are not hoisted
    doc << PrintScope(op->body);
    return doc;
  }

  /*!
   * \brief Print a function: prefix, type params, params, attrs, return type, body.
   * \param prefix Leading text, e.g. "fn " or "def @name".
   * \param fn The function.
   */
  Doc PrintFunc(const Doc& prefix, const Function& fn) {
    Doc doc;
    doc << prefix;
    if (fn->type_params.size() > 0) {
      doc << "<";
      std::vector<Doc> type_params;
      for (const TypeVar& tv : fn->type_params) {
        type_params.push_back(AllocTypeVar(tv));
      }
      doc << PrintVec(type_params);
      doc << ">";
    }
    doc << "(";
    std::vector<Doc> params;
    for (Var param : fn->params) {
      params.push_back(AllocVar(param));
    }
    doc << PrintVec(params) << PrintAttrs(fn->attrs, fn);
    doc << ") ";
    if (fn->ret_type.defined()) {
      doc << "-> " << Print(fn->ret_type) << " ";
    }
    doc << PrintBody(fn->body);
    return doc;
  }

  /*!
   * \brief Print every function of a module as "def @name ...".
   * The dependency graph is rebuilt for each function before printing it.
   */
  Doc PrintMod(const Module& mod) {
    Doc doc;
    int counter = 0;
    for (const auto& kv : mod->functions) {
      dg_ = DependencyGraph::Create(&arena_, kv.second);

      std::ostringstream os;
      if (counter++ != 0) {
        doc << "\n";
      }
      os << "def @" << kv.first->name_hint;
      doc << PrintFunc(Doc(os.str()), kv.second);
      doc << "\n";
    }
    return doc;
  }

  Doc VisitExpr_(const FunctionNode* op) final {
    return PrintFunc(Doc("fn "), GetRef<Function>(op));
  }

  Doc VisitExpr_(const GlobalVarNode* op) final {
    return Doc('@' + op->name_hint);
  }

  Doc VisitExpr_(const OpNode* op) final {
    return Doc(op->name);
  }

  Doc VisitExpr_(const CallNode* op) final {
    Doc doc;
    // visit args first so they are lifted before the op
    // this places op closer to its call site
    std::vector<Doc> args;
    for (Expr arg : op->args) {
      args.push_back(Print(arg));
    }
    doc << Print(op->op);
    return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs, op->op) << ")";
  }

  Doc VisitExpr_(const RefCreateNode* op) final {
    Doc doc;
    return doc << "ref(" << Print(op->value) << ")";
  }

  Doc VisitExpr_(const RefReadNode* op) final {
    Doc doc;
    return doc << Print(op->ref) << "^";
  }

  Doc VisitExpr_(const RefWriteNode* op) final {
    Doc doc;
    return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")";
  }

  Doc VisitExpr_(const MatchNode* op) final {
    // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
    Doc doc;
    Doc body;
    doc << "match " << Print(op->data) << " ";
    doc << "{";
    std::vector<Doc> clauses;
    for (const auto& clause : op->clauses) {
      Doc clause_doc;
      clauses.push_back(clause_doc << Print(clause->lhs) << " -> "
                                   << Print(clause->rhs));
    }
    doc << Indent(2, body << "\n" << PrintVec(clauses, Doc("\n"))) << "\n";
    doc << "}";
    return doc;
  }

  Doc VisitPattern_(const PatternConstructorNode* p) final {
    Doc doc;
    doc << p->constructor->name_hint << "(";
    std::vector<Doc> pats;
    for (const auto& pat : p->patterns) {
      pats.push_back(Print(pat));
    }
    return doc << PrintVec(pats) << ")";
  }

  Doc VisitPattern_(const PatternVarNode* pv) final {
    return AllocVar(pv->var);
  }

  Doc VisitExpr_(const ConstructorNode* n) final {
    return Doc(n->name_hint);
  }

  //------------------------------------
  // Overload of Type printing functions
  //------------------------------------
  /*!
   * \brief Print a type, memoized in memo_type_.
   * \param type The type.
   * \param meta Whether to print it as a meta data reference.
   */
  Doc PrintType(const Type& type, bool meta) {
    auto it = memo_type_.find(type);
    if (it != memo_type_.end()) return it->second;
    Doc printed_type;
    if (meta) {
      printed_type = meta_.GetMetaNode(GetRef<NodeRef>(type.get()));
    } else {
      printed_type = VisitType(type);
    }
    memo_type_[type] = printed_type;
    return printed_type;
  }

  Doc VisitTypeDefault_(const Node* node) final {
    // by default always print as meta data
    return Print(GetRef<NodeRef>(node), true);
  }

  Doc VisitType_(const TypeVarNode* node) final {
    return AllocTypeVar(GetRef<TypeVar>(node));
  }

  Doc VisitType_(const TensorTypeNode* node) final {
    // scalar type
    if (node->shape.size() == 0) {
      return PrintDType(node->dtype);
    }
    Doc doc;
    doc << "Tensor[(";
    std::vector<Doc> shapes;
    for (NodeRef shape : node->shape) {
      shapes.push_back(PrintAttr(shape));
    }
    doc << PrintVec(shapes);
    // conform to python tuple format (1,)
    if (node->shape.size() == 1) {
      doc << ",";
    }
    return doc << "), " << PrintDType(node->dtype) << "]";
  }

  Doc VisitType_(const TupleTypeNode* node) final {
    std::vector<Doc> fields;
    for (Type field : node->fields) {
      fields.push_back(Print(field));
    }
    Doc doc;
    doc << "(" << PrintVec(fields);
    // conform to python tuple format (1,)
    if (node->fields.size() == 1) {
      doc << ",";
    }
    return doc << ")";
  }

  Doc VisitType_(const FuncTypeNode* node) final {
    Doc doc;
    std::vector<Doc> arg_types;
    for (Type arg_type : node->arg_types) {
      arg_types.push_back(Print(arg_type));
    }
    return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type);
  }

  Doc VisitType_(const RefTypeNode* node) final {
    Doc doc;
    return doc << "ref(" << Print(node->value) << ")";
  }

  //------------------------------------
  // Overload of Attr printing functions
  //------------------------------------

  /*!
   * \brief Print an attribute value; undefined values print as "None".
   * \param value The attribute value.
   * \param meta Whether to print it as a meta data reference.
   */
  Doc PrintAttr(const NodeRef& value, bool meta = false) {
    if (value.defined()) {
      Doc printed_attr;
      if (meta) {
        printed_attr = meta_.GetMetaNode(value);
      } else {
        printed_attr = VisitAttr(value);
      }
      return printed_attr;
    } else {
      return Doc("None");
    }
  }

  Doc VisitAttrDefault_(const Node* op) final {
    return PrintAttr(GetRef<NodeRef>(op), true);
  }

  Doc VisitAttr_(const ArrayNode* op) final {
    Doc doc;
    doc << "[";
    std::vector<Doc> arr_vals;
    for (const NodePtr<Node>& val : op->data) {
      arr_vals.push_back(PrintAttr(NodeRef(val)));
    }
    doc << PrintVec(arr_vals);
    doc << "]";
    return doc;
  }

  Doc VisitAttr_(const ir::IntImm* op) final {
    return PrintConstScalar(op->type, &(op->value));
  }

  Doc VisitAttr_(const ir::UIntImm* op) final {
    return PrintConstScalar(op->type, &(op->value));
  }

  Doc VisitAttr_(const ir::FloatImm* op) final {
    return PrintConstScalar(op->type, &(op->value));
  }

  Doc VisitAttr_(const ir::StringImm* op) final {
    return PrintString(op->value);
  }

 private:
  /*! \brief Whether to print meta data. */
  bool show_meta_data_;
  /*! \brief additional comment function */
  runtime::TypedPackedFunc<std::string(Expr)> annotate_;
  /*! \brief Stack of docs to implement scoped GNFing. */
  std::vector<Doc> doc_stack_{};
  /*! \brief Map from Expr to Doc */
  std::unordered_map<Expr, Doc, NodeHash, NodeEqual> memo_;
  /*! \brief Map from Type to Doc */
  std::unordered_map<Type, Doc, NodeHash, NodeEqual> memo_type_;
  /*! \brief name allocation map */
  std::unordered_map<std::string, int> name_alloc_map_;
  /*! \brief meta data context */
  TextMetaDataContext meta_;
  /*! \brief counter of temporary variable */
  size_t temp_var_counter_{0};
  /*! \brief arena for dependency graph */
  common::Arena arena_;
  /*! \brief dependency graph of the expr */
  DependencyGraph dg_;
  class AttrPrinter;
  friend class AttrPrinter;
};

/*!
 * \brief Attribute printer which prints the attributes in the call.
 */
class PrettyPrinter::AttrPrinter : public AttrVisitor {
 public:
  AttrPrinter(Doc& doc, PrettyPrinter* parent) : doc_(doc), parent_(parent) {}

  /*!
   * \brief Format one attribute as ", key=value".
   * \param key The attribute name.
   * \param value The value, streamed into a Doc.
   */
  template<typename T>
  Doc PrintKV(const char* key, const T& value) {
    Doc kv;
    kv << ", " << key << "=" << value;
    return kv;
  }

  void Visit(const char* key, double* value) final {
    Append(key, *value);
  }
  void Visit(const char* key, int64_t* value) final {
    Append(key, *value);
  }
  void Visit(const char* key, uint64_t* value) final {
    Append(key, *value);
  }
  void Visit(const char* key, int* value) final {
    Append(key, *value);
  }
  void Visit(const char* key, bool* value) final {
    Append(key, PrintBool(*value));
  }
  void Visit(const char* key, std::string* value) final {
    Append(key, PrintString(*value));
  }
  void Visit(const char* key, void** value) final {
    LOG(FATAL) << "do not allow void as argument";
  }
  void Visit(const char* key, DataType* value) final {
    // dtypes are printed as quoted strings, e.g. "float32".
    Append(key, PrintString(runtime::TVMType2String(Type2TVMType(*value))));
  }
  void Visit(const char* key, NodeRef* value) final {
    // node-valued attributes go back through the parent's attr printer.
    Append(key, parent_->PrintAttr(*value));
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    LOG(FATAL) << "do not allow NDarray as argument";
  }

 private:
  /*! \brief Append ", key=value" to the output doc. */
  template<typename T>
  void Append(const char* key, const T& value) {
    doc_ << PrintKV(key, value);
  }

  /*! \brief output doc that attributes are appended to (owned by the caller) */
  Doc& doc_;
  /*! \brief printer used for node-valued attributes */
  PrettyPrinter* parent_;
};

/*!
 * \brief Print the attributes of a call or function as ", key=value" pairs.
 *
 * When op is an operator whose declared attrs type differs from the actual
 * attrs object, the whole attrs object is printed as a meta data reference.
 * \param attrs The attributes (may be undefined).
 * \param op The callee or function the attributes belong to.
 * \return The attribute doc; empty when attrs is undefined.
 */
Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) {
  Doc doc;
  if (!attrs.defined()) {
    return doc;
  }
  const OpNode* op_node = op.as<OpNode>();
  const bool type_mismatch = op_node != nullptr &&
                             attrs->type_index() != op_node->attrs_type_index;
  if (type_mismatch) {
    // fallback: cannot walk fields of an unexpected attrs type.
    doc << ", " << meta_.GetMetaNode(attrs);
    return doc;
  }
  AttrPrinter printer(doc, this);
  // VisitNonDefaultAttrs is non-const, while Attrs only exposes a const node.
  auto* attrs_node = const_cast<BaseAttrsNode*>(attrs.operator->());
  attrs_node->VisitNonDefaultAttrs(&printer);
  return doc;
}

std::string PrettyPrint_(const NodeRef& node,
                         bool show_meta_data,
782
                         runtime::TypedPackedFunc<std::string(Expr)> annotate) {
783 784
  Doc doc;
  doc << "v0.0.1" << "\n"
785
      << PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
786 787 788
  return doc.str();
}

789
std::string AsText(const NodeRef& node,
790 791 792
                       bool show_meta_data,
                       runtime::TypedPackedFunc<std::string(Expr)> annotate) {
  return PrettyPrint_(node, show_meta_data, annotate);
793 794
}

795
TVM_REGISTER_API("relay._expr.AsText")
796 797
.set_body_typed<std::string(const NodeRef&,
                            bool,
798
                            runtime::TypedPackedFunc<std::string(Expr)>)>(AsText);
799 800 801

}  // namespace relay
}  // namespace tvm