/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2019 by Contributors
 * \file pretty_printer.cc
 * \brief Pretty printer for Relay programs
 * Supports ANF, GNF, and metadata.
 *
 * Inlining heuristics:
 *  - Always inline:
 *    - GlobalVar
 *    - Constant
 *    - Op
 *    - Var
 *    - Constructor
 *  - Otherwise, inline if the node is at the end of a scope and is used at most once.
 */
#include <dmlc/json.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h>
#include <cctype>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "doc.h"
#include "type_functor.h"
#include "../pass/dependency_graph.h"
#include "../../lang/attr_functor.h"

namespace tvm {
namespace relay {

// Version tag written as the first line of PrettyPrint_ (and therefore AsText) output.
// Both the pointer and the pointee are const so the tag cannot be reassigned.
static const char* const kSemVer = "v0.0.4";

49 50 51 52 53 54 55 56 57 58 59
Doc Brace(const Doc& d,
          const std::string& open = "{",
          const std::string& close = "}",
          int indent = 2) {
  Doc doc;
  doc << open;
  doc << Indent(indent, PrintNewLine() << d) << PrintNewLine();
  doc << close;
  return doc;
}

60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
/*!
 * \brief Meta data context for PrettyPrinter.
 *
 * This is an important part to enable bi-directional serializability.
 * We use tvm's Node system to build the current IR.
 * It can be hard to design a text format for all the possible nodes
 * as the set of nodes can grow when we do more extensions.
 *
 * Instead of trying to design readable text format for every node,
 * we support a meta data section in the text format.
 * We allow the text format to refer to a node in the meta data section.
 *
 * The meta data section is a json serialized string of an Map<string, Array<NodeRef>>.
 * Each element in the meta data section can be referenced by the text format.
 * Each meta data node is printed in the following format.
 *
 * meta[type-key-of-node>][<index-in-meta-section>]
 *
 * Specifically, consider the following IR(constructed by python).
 *
 * \code
 *
 * n = tvm.var("n")
 * x = tvm.relay.var("x", shape=(n, 1))
 * f = tvm.relay.Function([x], x)
 * print(f.astext())
 *
 * \endcode
 *
 * The corresponding text format is shown in the following code block.
 *
 * \code
 *
 * fn (%x: Tensor[(meta[Variable][0],), float32]) {
 *   %x
 * }
 * # Meta data section is a json-serialized string
 * # of the following array.
 * # [tvm.var("n")]
 *
 * \endcode
 *
 * Note that we store tvm.var("n") in the meta data section.
 * Since it is stored in the index-0 in the meta data section,
 * we print it as meta[Variable][0].
 *
 * The text parser can recover this object by loading from the corresponding
 * location in the meta data section.
 *
 * This is is a design trade-off.
 * It allows us to embedded any meta data in the text format,
 * while still being able to tweak the text part of the printed IR easily.
 */
class TextMetaDataContext {
 public:
  /*!
   * \brief Get text representation of meta node.
   * \param node The node to be converted to meta node.
   * \return A string representation of the meta node.
   */
  Doc GetMetaNode(const NodeRef& node) {
    auto it = meta_repr_.find(node);
    if (it != meta_repr_.end()) {
      return it->second;
    }
125 126
    std::string type_key = node->type_key();
    CHECK(!type_key.empty());
127
    Array<NodeRef>& mvector =
128
        meta_data_[type_key];
129 130 131 132 133 134 135
    int64_t index = static_cast<int64_t>(mvector.size());
    mvector.push_back(node);
    Doc doc;
    doc << "meta[" << node->type_key() << "][" << index << "]";
    meta_repr_[node] = doc;
    return meta_repr_[node];
  }
136 137 138 139 140

  Doc PrintKeyValue(const std::string& str, const Doc& v) const {
    return Doc("\"") << str << "\": " << v;
  }

141 142 143 144
  /*!
   * \brief Get the metadata section in json format.
   * \return the meta data string.
   */
145 146 147
  Doc GetMetaSection() const {
    if (meta_data_.size() == 0) return Doc();
    return Doc(SaveJSON(Map<std::string, NodeRef>(meta_data_.begin(), meta_data_.end())));
148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
  }

  /*! \return whether the meta data context is empty. */
  bool empty() const {
    return meta_data_.empty();
  }

 private:
  /*! \brief additional metadata stored in TVM json format */
  std::unordered_map<std::string, Array<NodeRef> > meta_data_;
  /*! \brief map from meta data into its string representation */
  std::unordered_map<NodeRef, Doc, NodeHash, NodeEqual> meta_repr_;
};

/*!
 * \brief Prints Relay expressions, types, patterns, attributes and modules
 *        to the Relay text format (see the inlining heuristics in the file header).
 */
class PrettyPrinter :
    public ExprFunctor<Doc(const Expr&)>,
    public PatternFunctor<Doc(const Pattern&)>,
    public TypeFunctor<Doc(const Type&)>,
    public AttrFunctor<Doc(const NodeRef&)> {
 public:
  /*!
   * \param show_meta_data Whether PrintFinal appends the meta data section.
   * \param annotate Optional callback whose result replaces the default
   *        type annotation printed after each expression.
   */
  explicit PrettyPrinter(bool show_meta_data,
                         runtime::TypedPackedFunc<std::string(Expr)> annotate) :
                         show_meta_data_(show_meta_data),
                         annotate_(annotate) {}

  /*!
   * \brief Print additional info about expr in comment.
   * \param expr The expression.
   * \return With no annotate callback: a ty= comment for Constant and Call
   *         nodes whose checked type is defined, empty otherwise.
   *         With a callback: the callback's (possibly empty) result.
   */
  Doc PrintOptionalInfo(const Expr& expr) {
    Doc doc;
    if (annotate_ == nullptr) {
      // default annotations
      if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
        doc << " /* ty=" << Print(expr->checked_type()) << " */";
      }
    } else {
      std::string annotated_expr = annotate_(expr);
      if (!annotated_expr.empty()) {
        doc << annotated_expr;
      }
    }
    return doc;
  }

  // indent a new body
  Doc PrintBody(const NodeRef& node, int indent = 2) {
    Doc doc;
    Doc body;
    doc << "{";
    doc << Indent(indent, body << PrintNewLine() << PrintScope(node)) << PrintNewLine();
    doc << "}";
    return doc;
  }

  // create a new scope by pushing a fresh doc onto doc_stack_. This keeps GNF
  // bindings hoisted while printing node local to this scope.
  Doc PrintScope(const NodeRef& node) {
    // print in a new scope
    doc_stack_.push_back(Doc());
    // must print first so doc_stack_.back() reference doesn't become stale
    Doc doc = Print(node, false, true);
    doc = doc_stack_.back() << doc;
    doc_stack_.pop_back();
    return doc;
  }

  /*!
   * \brief Entry point: print node in a top-level scope and append either the
   *        meta data section or a note that it was omitted.
   */
  Doc PrintFinal(const NodeRef& node) {
    if (node.as_derived<ExprNode>()) {
      Expr expr = Downcast<Expr>(node);
      dg_ = DependencyGraph::Create(&arena_, expr);
    }

    Doc doc;
    doc << PrintScope(node);
    if (!meta_.empty()) {
      doc << PrintNewLine();
      if (show_meta_data_) {
        // append meta data in the end.
        doc << "METADATA:" << PrintNewLine() << meta_.GetMetaSection();
      } else {
        doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
      }
    }
    return doc;
  }

  std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
  std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);

  /*!
   * \brief Dispatch printing on the node category (Expr, Type, Pattern, Module).
   * \param node The node to print.
   * \param meta Whether to print the node as a meta data reference.
   * \param try_inline Whether the expression may be inlined if used only once.
   */
  Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
    if (node.as_derived<ExprNode>()) {
      return PrintExpr(Downcast<Expr>(node), meta, try_inline);
    } else if (node.as_derived<TypeNode>()) {
      return PrintType(Downcast<Type>(node), meta);
    } else if (node.as_derived<PatternNode>()) {
      return PrintPattern(Downcast<Pattern>(node), meta);
    } else if (node.as_derived<ModuleNode>()) {
      return PrintMod(Downcast<Module>(node));
    } else {
      Doc doc;
      return doc << node;
    }
  }

  Doc TempVar(int n) {
    Doc doc;
    return doc << "%" << n;
  }

  Doc AllocTemp() {
    return TempVar(temp_var_counter_++);
  }

  /*!
   * \brief get a unique name with the corresponding prefix
   * \param prefix The prefix of the name
   * \return The returned name.
   */
  Doc GetUniqueName(const std::string& prefix) {
    std::string unique_prefix = prefix;
    auto it = name_alloc_map_.find(prefix);
    if (it != name_alloc_map_.end()) {
      while (true) {
        std::ostringstream os;
        os << prefix << (++it->second);
        std::string name = os.str();
        if (name_alloc_map_.count(name) == 0) {
          unique_prefix = name;
          break;
        }
      }
    }
    name_alloc_map_[unique_prefix] = 0;
    return Doc(unique_prefix);
  }

  /*!
   * \brief Print the textual name of a type-variable kind.
   * \param k The kind.
   * \return The kind name; an unknown kind is a fatal error.
   */
  Doc Print(Kind k) {
    switch (k) {
    case kType:
      return Doc("Type");
    case kShapeVar:
      return Doc("Shape");
    case kBaseType:
      return Doc("BaseType");
    case kConstraint:
      return Doc("Constraint");
    case kAdtHandle:
      return Doc("AdtHandle");
    case kTypeData:
      return Doc("TypeData");
    default:
      // The previous `LOG(ERROR); throw;` rethrew with no active exception,
      // which calls std::terminate. LOG(FATAL) raises a proper error instead.
      LOG(FATAL) << "Unknown Kind";
      return Doc();  // not reached; LOG(FATAL) is expected not to return
    }
  }

  /*!
   * \brief Allocate name to a type variable.
   * \param var The input type variable.
   * \return The corresponding name (with ": <kind>" when the kind is not kType).
   */
  Doc AllocTypeVar(const TypeVar& var) {
    auto it = memo_type_.find(var);
    if (it != memo_type_.end()) {
      // still print if ir is malformed, but show the error.
      Doc val = it->second;
      val << "-malformed-ir";
      return val;
    }
    std::string name = var->var->name_hint;
    // always make sure first name is alpha. The unsigned char cast avoids
    // undefined behavior in std::isalpha for negative char values.
    if (name.empty() || !std::isalpha(static_cast<unsigned char>(name[0]))) {
      name = "t" + name;
    }
    Doc val = GetUniqueName(name);
    memo_type_[var] = val;
    if (var->kind != kType) {
      val << ": " << Print(var->kind);
    }
    return val;
  }

  /*!
   * \brief Allocate name to a variable.
   * \param var The input variable.
   * \return The corresponding name (with ": <type>" when annotated).
   */
  Doc AllocVar(const Var& var) {
    // still print if ir is malformed, but show the error.
    auto it = memo_.find(var);
    if (it != memo_.end()) {
      Doc val = it->second;
      val << "-malformed-ir";
      return val;
    }
    std::string name = var->name_hint();
    // always make sure first name is alpha. The unsigned char cast avoids
    // undefined behavior in std::isalpha for negative char values.
    if (name.empty() || !std::isalpha(static_cast<unsigned char>(name[0]))) {
      name = "v" + name;
    }
    Doc val = GetUniqueName("%" + name);
    memo_[var] = val;
    if (var->type_annotation.defined()) {
      val << ": " << Print(var->type_annotation);
    }
    return val;
  }

  // True when expr has at most one parent in the dependency graph
  // (or is not in the graph at all).
  bool IsUnique(const Expr& expr) {
    auto it = dg_.expr_node.find(expr);
    if (it == dg_.expr_node.end()) {
      return true;
    } else {
      return !(it->second->parents.head && it->second->parents.head->next);
    }
  }

  bool AlwaysInline(const Expr& expr) {
    return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
           expr.as<VarNode>() || expr.as<ConstructorNode>();
  }

  //------------------------------------
  // Overload of Expr printing functions
  //------------------------------------
  Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
    // Exploit memoization to print GNF.
    // The first time we visit an expression, we need to allocate a temp var
    // for it. Every subsequent time we can just use its assigned variable.
    // This works since hashing uses pointer equality.
    auto it = memo_.find(expr);
    if (it != memo_.end()) return it->second;

    // determine whether to inline
    bool inline_expr = AlwaysInline(expr);
    if (try_inline) {
      inline_expr |= IsUnique(expr);
    }

    Doc printed_expr;
    if (meta) {
      printed_expr = meta_.GetMetaNode(GetRef<NodeRef>(expr.get()));
    } else if (!inline_expr && expr.as<LetNode>()) {
      // wrap GNFed let in brackets
      Doc body;
      printed_expr << "(";
      printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
      printed_expr << ")";
    } else {
      printed_expr = VisitExpr(expr);
    }

    printed_expr << PrintOptionalInfo(expr);

    // add expr to doc
    if (expr.as<VarNode>()) {
      // This is our first time visiting the var and we hit the VarNode case
      // in the visitor. Thus the variable is free.
      doc_stack_.back() << "free_var " << printed_expr << PrintNewLine();
      // Memoization is done in AllocVar.
      return memo_[expr];
    } else if (inline_expr) {
      memo_[expr] = printed_expr;
      return printed_expr;
    } else {
      Doc temp_var = AllocTemp();
      memo_[expr] = temp_var;
      doc_stack_.back() << temp_var << " = " << printed_expr << ";" << PrintNewLine();
      return temp_var;
    }
  }

  // Should only be triggered when op is a free variable being visited for the
  // first time.
  Doc VisitExpr_(const VarNode* op) final {
    return AllocVar(GetRef<Var>(op));
  }

  Doc VisitExpr_(const ConstantNode* op) final {
    // Print out simple scalars directly.
    if (op->is_scalar()) {
      DataType dtype = TVMType2Type(op->data->dtype);
      CHECK_EQ(op->data->ctx.device_type, kDLCPU);
      if (dtype == Int(32)) {
        return PrintConstScalar(dtype, static_cast<const int32_t*>(op->data->data));
      } else if (dtype == Int(64)) {
        return PrintConstScalar(dtype, static_cast<const int64_t*>(op->data->data));
      } else if (dtype == Float(32)) {
        return PrintConstScalar(dtype, static_cast<const float*>(op->data->data));
      } else if (dtype == Float(64)) {
        return PrintConstScalar(dtype, static_cast<const double*>(op->data->data));
      } else if (dtype == Bool()) {
        return PrintConstScalar(dtype, static_cast<const uint8_t*>(op->data->data));
      }
    }
    // default fall-back, record it as meta node.
    Doc doc;
    return doc << Print(GetRef<NodeRef>(op), true);
  }

  Doc VisitExpr_(const TupleNode* op) final {
    std::vector<Doc> fields;
    for (const Expr& field : op->fields) {
      fields.push_back(Print(field));
    }
    Doc doc;
    doc << "(" << PrintSep(fields);
    // conform to python tuple format (1,)
    if (op->fields.size() == 1) {
      doc << ",";
    }
    return doc << ")";
  }

  Doc VisitExpr_(const TupleGetItemNode* op) final {
    Doc doc;
    return doc << Print(op->tuple) << "." << op->index;
  }

  Doc VisitExpr_(const IfNode* op) final {
    Doc doc;
    doc << "if (" << Print(op->cond) << ") ";
    doc << PrintBody(op->true_branch);
    doc << " else ";
    doc << PrintBody(op->false_branch);
    return doc;
  }

  Doc VisitExpr_(const LetNode* op) final {
    Doc doc;
    doc
      << "let "
      << AllocVar(op->var)
      << " = "
      << Print(op->value, false, true)
      << ";"
      << PrintNewLine();
    // we use a scope here so GNF hoisting doesn't escape too far
    // and nested, unique lets are not hoisted
    doc << PrintScope(op->body);
    return doc;
  }

  Doc PrintFunc(const Doc& prefix, const Function& fn) {
    Doc doc;
    doc << prefix;
    if (fn->type_params.size() > 0) {
      doc << "[";
      std::vector<Doc> type_params;
      for (const TypeVar& tv : fn->type_params) {
        type_params.push_back(Doc(tv->var->name_hint));
      }
      doc << PrintSep(type_params);
      doc << "]";
    }
    doc << "(";
    std::vector<Doc> params;
    for (const Var& param : fn->params) {
      params.push_back(AllocVar(param));
    }
    for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
      params.push_back(d);
    }
    doc << PrintSep(params) << ") ";
    if (fn->ret_type.defined()) {
      doc << "-> " << Print(fn->ret_type) << " ";
    }
    doc << PrintBody(fn->body);
    return doc;
  }

  Doc PrintMod(const Module& mod) {
    Doc doc;
    int counter = 0;
    // type definitions
    for (const auto& kv : mod->type_definitions) {
      if (counter++ != 0) {
        doc << PrintNewLine();
      }
      doc << Print(kv.second);
      doc << PrintNewLine();
    }
    // functions
    for (const auto& kv : mod->functions) {
      // rebuild the dependency graph for each function so inlining decisions are per-function
      dg_ = DependencyGraph::Create(&arena_, kv.second);

      if (counter++ != 0) {
        doc << PrintNewLine();
      }
      std::ostringstream os;
      os << "def @" << kv.first->name_hint;
      doc << PrintFunc(Doc(os.str()), kv.second);
      doc << PrintNewLine();
    }
    return doc;
  }

  Doc VisitExpr_(const FunctionNode* op) final {
    return PrintFunc(Doc("fn "), GetRef<Function>(op));
  }

  Doc VisitExpr_(const GlobalVarNode* op) final {
    return Doc('@' + op->name_hint);
  }

  Doc VisitExpr_(const OpNode* op) final {
    return Doc(op->name);
  }

  Doc VisitExpr_(const CallNode* op) final {
    Doc doc;
    // visit args first so they are lifted before the op
    // this places op closer to its call site
    std::vector<Doc> args;
    for (const Expr& arg : op->args) {
      args.push_back(Print(arg));
    }
    for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
      args.push_back(d);
    }
    const auto* cons_node = op->op.as<ConstructorNode>();
    if (cons_node) {
      doc << cons_node->name_hint;
    } else {
      doc << Print(op->op);
    }
    return doc << "(" << PrintSep(args) << ")";
  }

  Doc VisitExpr_(const RefCreateNode* op) final {
    Doc doc;
    return doc << "ref(" << Print(op->value) << ")";
  }

  Doc VisitExpr_(const RefReadNode* op) final {
    Doc doc;
    return doc << Print(op->ref) << "^";
  }

  Doc VisitExpr_(const RefWriteNode* op) final {
    Doc doc;
    return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")";
  }

  Doc VisitExpr_(const MatchNode* op) final {
    // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
    Doc doc;
    Doc body;
    doc << "match";
    if (!op->complete) {
      doc << "?";
    }
    doc << " (" << Print(op->data) << ") {";
    std::vector<Doc> clause_docs;
    for (const auto& clause : op->clauses) {
      Doc clause_doc;
      clause_doc << PrintPattern(clause->lhs, false) << " => ";
      Doc rhs_doc = PrintScope(clause->rhs);
      if (clause->rhs.as<LetNode>()) {
        // only add braces if there are multiple lines on the rhs
        rhs_doc = Brace(rhs_doc);
      }
      clause_doc << rhs_doc << ",";
      clause_docs.push_back(clause_doc);
    }
    doc << Indent(2, body << PrintNewLine() << PrintSep(clause_docs, PrintNewLine()))
        << PrintNewLine() << "}";
    return doc;
  }

  Doc PrintPattern(const Pattern& pattern, bool meta) {
    auto it = memo_pattern_.find(pattern);
    if (it != memo_pattern_.end()) return it->second;
    Doc printed_pattern;
    if (meta) {
      printed_pattern = meta_.GetMetaNode(GetRef<NodeRef>(pattern.get()));
    } else {
      printed_pattern = VisitPattern(pattern);
    }
    memo_pattern_[pattern] = printed_pattern;
    return printed_pattern;
  }

  Doc VisitPattern_(const PatternConstructorNode* p) final {
    Doc doc;
    doc << p->constructor->name_hint;
    if (!p->patterns.empty()) {
      doc << "(";
      std::vector<Doc> pats;
      for (const auto& pat : p->patterns) {
        pats.push_back(Print(pat));
      }
      doc << PrintSep(pats) << ")";
    }
    return doc;
  }

  Doc VisitPattern_(const PatternWildcardNode* pw) final {
    return Doc("_");
  }

  Doc VisitPattern_(const PatternVarNode* pv) final {
    return AllocVar(pv->var);
  }

  Doc VisitExpr_(const ConstructorNode* n) final {
    Doc doc;
    doc << n->name_hint;
    if (n->inputs.size() != 0) {
      doc << "(";
      std::vector<Doc> inputs;
      for (const Type& input : n->inputs) {
        inputs.push_back(Print(input));
      }
      doc << PrintSep(inputs) << ")";
    }
    return doc;
  }

  //------------------------------------
  // Overload of Type printing functions
  //------------------------------------
  Doc PrintType(const Type& type, bool meta) {
    auto it = memo_type_.find(type);
    if (it != memo_type_.end()) return it->second;
    Doc printed_type;
    if (meta) {
      printed_type = meta_.GetMetaNode(GetRef<NodeRef>(type.get()));
    } else {
      printed_type = VisitType(type);
    }
    memo_type_[type] = printed_type;
    return printed_type;
  }

  Doc VisitTypeDefault_(const Node* node) final {
    // by default always print as meta data
    return Print(GetRef<NodeRef>(node), true);
  }

  Doc VisitType_(const TypeVarNode* node) final {
    return Doc(node->var->name_hint);
  }

  Doc VisitType_(const GlobalTypeVarNode* node) final {
    return Doc(node->var->name_hint);
  }

  Doc VisitType_(const TypeCallNode* node) final {
    Doc doc = PrintType(node->func, false);
    std::vector<Doc> args;
    for (const Type& t : node->args) {
      args.push_back(PrintType(t, false));
    }
    doc << "[";
    doc << PrintSep(args);
    doc << "]";
    return doc;
  }

  Doc VisitType_(const TensorTypeNode* node) final {
    // scalar type
    if (node->shape.size() == 0) {
      return PrintDType(node->dtype);
    }
    Doc doc;
    doc << "Tensor[(";
    std::vector<Doc> shapes;
    for (NodeRef shape : node->shape) {
      shapes.push_back(PrintAttr(shape));
    }
    doc << PrintSep(shapes);
    return doc << "), " << PrintDType(node->dtype) << "]";
  }

  Doc VisitType_(const TupleTypeNode* node) final {
    std::vector<Doc> fields;
    for (const Type& field : node->fields) {
      fields.push_back(Print(field));
    }
    Doc doc;
    doc << "(" << PrintSep(fields);
    // conform to python tuple format (1,)
    if (node->fields.size() == 1) {
      doc << ",";
    }
    return doc << ")";
  }

  Doc VisitType_(const FuncTypeNode* node) final {
    Doc doc;
    doc << "fn ";
    if (node->type_params.size() != 0) {
      doc << "[";
      std::vector<Doc> type_params;
      for (Type type_param : node->type_params) {
        type_params.push_back(Print(type_param));
      }
      doc << PrintSep(type_params);
      doc << "]";
    }
    std::vector<Doc> arg_types;
    for (const Type& arg_type : node->arg_types) {
      arg_types.push_back(Print(arg_type));
    }
    return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type);
  }

  Doc VisitType_(const RefTypeNode* node) final {
    Doc doc;
    return doc << "ref(" << Print(node->value) << ")";
  }

  Doc VisitType_(const TypeDataNode* node) final {
    Doc doc;
    doc << "type " << Print(node->header);

    // type vars
    if (node->type_vars.size() != 0) {
      doc << "[";
      std::vector<Doc> type_vars;
      for (Type type_var : node->type_vars) {
        type_vars.push_back(Print(type_var));
      }
      doc << PrintSep(type_vars) << "]";
    }
    doc << " ";

    std::vector<Doc> constructor_docs;
    for (const Constructor& constructor : node->constructors) {
      constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
    }
    Doc separator;
    separator << "," << PrintNewLine();
    Doc adt_body;
    adt_body << PrintSep(constructor_docs, separator);
    // add trailing comma if there are any constructors
    if (!constructor_docs.empty()) {
      adt_body << ",";
    }
    doc << Brace(adt_body);
    return doc;
  }

  //------------------------------------
  // Overload of Attr printing functions
  //------------------------------------

  /*!
   * \brief Print an attribute value.
   * \return "None" for undefined values, "?" for Any, a meta reference when
   *         meta is set, otherwise the visitor's rendering.
   */
  Doc PrintAttr(const NodeRef& value, bool meta = false) {
    if (value.defined()) {
      Doc printed_attr;
      if (value.as<tvm::ir::Any>()) {
        printed_attr << "?";
      } else if (meta) {
        printed_attr = meta_.GetMetaNode(value);
      } else {
        printed_attr = VisitAttr(value);
      }
      return printed_attr;
    } else {
      return Doc("None");
    }
  }

  Doc VisitAttrDefault_(const Node* op) final {
    return PrintAttr(GetRef<NodeRef>(op), true);
  }

  Doc VisitAttr_(const ArrayNode* op) final {
    Doc doc;
    doc << "[";
    std::vector<Doc> arr_vals;
    for (const NodePtr<Node>& val : op->data) {
      arr_vals.push_back(PrintAttr(NodeRef(val)));
    }
    doc << PrintSep(arr_vals);
    doc << "]";
    return doc;
  }

  Doc VisitAttr_(const ir::IntImm* op) final {
    return PrintConstScalar(op->type, &(op->value));
  }

  Doc VisitAttr_(const ir::UIntImm* op) final {
    return PrintConstScalar(op->type, &(op->value));
  }

  Doc VisitAttr_(const ir::FloatImm* op) final {
    return PrintConstScalar(op->type, &(op->value));
  }

  Doc VisitAttr_(const ir::StringImm* op) final {
    return PrintString(op->value);
  }

 private:
  /*! \brief Whether to print meta data. */
  bool show_meta_data_;
  /*! \brief additional comment function */
  runtime::TypedPackedFunc<std::string(Expr)> annotate_;
  /*! \brief Stack of docs to implement scoped GNFing. */
  std::vector<Doc> doc_stack_{};
  /*! \brief Map from Expr to Doc */
  std::unordered_map<Expr, Doc, NodeHash, NodeEqual> memo_;
  /*! \brief Map from Type to Doc */
  std::unordered_map<Type, Doc, NodeHash, NodeEqual> memo_type_;
  /*! \brief Map from Pattern to Doc */
  std::unordered_map<Pattern, Doc, NodeHash, NodeEqual> memo_pattern_;
  /*! \brief name allocation map */
  std::unordered_map<std::string, int> name_alloc_map_;
  /*! \brief meta data context */
  TextMetaDataContext meta_;
  /*! \brief counter of temporary variable */
  size_t temp_var_counter_{0};
  /*! \brief arena for dependency graph */
  common::Arena arena_;
  /*! \brief dependency graph of the expr */
  DependencyGraph dg_;
  class AttrPrinter;
  friend class AttrPrinter;
};

/*!
 * \brief Attribute visitor used by PrintCallAttrs: renders each visited
 *        attribute as a "key=value" Doc and appends it to a caller-owned vector.
 */
class PrettyPrinter::AttrPrinter : public AttrVisitor {
 public:
  /*!
   * \param docs Output vector receiving one Doc per visited attribute (not owned).
   * \param parent Printer used to render NodeRef-valued attributes (not owned).
   */
  AttrPrinter(std::vector<Doc>* docs, PrettyPrinter* parent) : docs_(docs), parent_(parent) {}

  /*! \brief Append "key=value", streaming value into a Doc. */
  template<typename T>
  void PrintKV(const char* key, const T& value) {
    Doc entry;
    entry << key << "=" << value;
    docs_->push_back(entry);
  }

  void Visit(const char* key, double* value) final {
    // floating point attributes carry an "f" suffix
    Doc entry;
    entry << key << "=" << *value << "f";
    docs_->push_back(entry);
  }
  void Visit(const char* key, int64_t* value) final {
    PrintKV(key, *value);
  }
  void Visit(const char* key, uint64_t* value) final {
    PrintKV(key, *value);
  }
  void Visit(const char* key, int* value) final {
    PrintKV(key, *value);
  }
  void Visit(const char* key, bool* value) final {
    PrintKV(key, PrintBool(*value));
  }
  void Visit(const char* key, std::string* value) final {
    PrintKV(key, PrintString(*value));
  }
  void Visit(const char* key, void** value) final {
    LOG(FATAL) << "do not allow void as argument";
  }
  void Visit(const char* key, DataType* value) final {
    const std::string dtype_str = runtime::TVMType2String(Type2TVMType(*value));
    PrintKV(key, PrintString(dtype_str));
  }
  void Visit(const char* key, NodeRef* value) final {
    PrintKV(key, parent_->PrintAttr(*value));
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    LOG(FATAL) << "do not allow NDarray as argument";
  }
  void Visit(const char* key, runtime::Object* obj) final {
    LOG(FATAL) << "do not allow Object as argument";
  }

 private:
  /*! \brief Destination for rendered attributes. */
  std::vector<Doc>* docs_;
  /*! \brief Printer used for NodeRef-valued attributes. */
  PrettyPrinter* parent_;
};

925 926 927
std::vector<Doc> PrettyPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
  std::vector<Doc> docs;
  if (!attrs.defined()) return docs;
928 929 930
  const auto* op_node = op.as<OpNode>();
  if (op_node && (attrs->type_index() != op_node->attrs_type_index)) {
    // fallback
931 932 933 934
    Doc doc;
    doc << meta_.GetMetaNode(attrs);
    docs.push_back(doc);
    return docs;
935
  } else {
936
    AttrPrinter printer(&docs, this);
937
    const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
938 939 940 941 942 943 944 945 946 947 948 949 950
    return docs;
  }
}

std::vector<Doc> PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) {
  std::vector<Doc> docs;
  if (!attrs.defined()) return docs;
  const auto* dict_attrs = attrs.as<DictAttrsNode>();
  CHECK(dict_attrs);
  for (const auto& k : dict_attrs->dict) {
    Doc doc;
    doc << k.first << "=" << Print(k.second);
    docs.push_back(doc);
951
  }
952
  return docs;
953 954 955 956
}

std::string PrettyPrint_(const NodeRef& node,
                         bool show_meta_data,
957
                         runtime::TypedPackedFunc<std::string(Expr)> annotate) {
958
  Doc doc;
959
  doc << kSemVer << PrintNewLine()
960
      << PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
961 962 963
  return doc.str();
}

964 965 966 967 968 969
std::string PrettyPrint(const NodeRef& node) {
  Doc doc;
  doc << PrettyPrinter(false, runtime::TypedPackedFunc<std::string(Expr)>()).PrintFinal(node);
  return doc.str();
}

970
std::string AsText(const NodeRef& node,
971 972 973
                       bool show_meta_data,
                       runtime::TypedPackedFunc<std::string(Expr)> annotate) {
  return PrettyPrint_(node, show_meta_data, annotate);
974 975
}

976
TVM_REGISTER_API("relay._expr.AsText")
977 978
.set_body_typed<std::string(const NodeRef&,
                            bool,
979
                            runtime::TypedPackedFunc<std::string(Expr)>)>(AsText);
980 981 982

}  // namespace relay
}  // namespace tvm