Unverified Commit e4d817d4 by Tianqi Chen Committed by GitHub

[REFACTOR] Establish printer in the source folder (#4752)

* [REFACTOR] Establish printer in the source folder.

As we move towards the unified IR, we will eventually want to build a unified
printers for both relay and TIR.

This PR isolate the printer component into a separate folder in src as a first step.

- Refactored the Doc DSL using Object, clean up APIs.
- Isolate out the meta data into a header.
- move printer into relay_text_printer, add comments about further TODos.

* Rename NodePrinter -> ReprPrinter to distinguish it from other printers
parent f8f75ca2
......@@ -132,6 +132,7 @@ file(GLOB_RECURSE COMPILER_SRCS
src/autotvm/*.cc
src/tir/*.cc
src/driver/*.cc
src/printer/*.cc
src/api/*.cc
)
......
......@@ -144,7 +144,7 @@ def _GetContext(debugger):
def PrettyPrint(debugger, command, result, internal_dict):
ctx = _GetContext(debugger)
rc = ctx.EvaluateExpression(
"tvm::relay::PrettyPrint({command})".format(command=command)
"tvm::PrettyPrint({command})".format(command=command)
)
result.AppendMessage(str(rc))
......@@ -175,7 +175,7 @@ def _EvalExpressionAsString(logger, ctx, expr):
def _EvalAsNodeRef(logger, ctx, value):
return _EvalExpressionAsString(
logger, ctx, "tvm::relay::PrettyPrint({name})".format(name=value.name)
logger, ctx, "tvm::PrettyPrint({name})".format(name=value.name)
)
......
......@@ -308,5 +308,33 @@ class IRModule : public ObjectRef {
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
};
/*!
* \brief Pretty print a node for debug purposes.
*
* \param node The node to be printed.
* \return The text reperesentation.
* \note This function does not show version or meta-data.
* Use AsText if you want to store the text.
* \sa AsText.
*/
TVM_DLL std::string PrettyPrint(const ObjectRef& node);
/*!
* \brief Render the node as a string in the text format.
*
* \param node The node to be rendered.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
*
* \note We support a limited set of IR nodes that are part of
* relay IR and
*
* \sa PrettyPrint.
* \return The text representation.
*/
TVM_DLL std::string AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate = nullptr);
} // namespace tvm
#endif // TVM_IR_MODULE_H_
......@@ -139,11 +139,11 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
*
* \code
* // Use NodeFunctor to implement NodePrinter similar to Visitor Pattern.
* // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
* // vtable allows easy patch of new Node types, without changing
* // interface of NodePrinter.
* // interface of ReprPrinter.
*
* class NodePrinter {
* class ReprPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
......@@ -152,18 +152,18 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* f(e, this);
* }
*
* using FType = NodeFunctor<void (const ObjectRef&, NodePrinter* )>;
* using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* NodePrinter::FType& NodePrinter::vtable() { // NOLINT(*)
* ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, NodePrinter* p) {
* TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a);
* p->stream << '+'
......
......@@ -38,7 +38,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <string>
#include <vector>
......
......@@ -17,25 +17,25 @@
* under the License.
*/
/*!
* \file tvm/node/printer.h
* \file tvm/node/repr_printer.h
* \brief Printer class to print repr string of each AST/IR nodes.
*/
#ifndef TVM_NODE_PRINTER_H_
#define TVM_NODE_PRINTER_H_
#ifndef TVM_NODE_REPR_PRINTER_H_
#define TVM_NODE_REPR_PRINTER_H_
#include <tvm/node/functor.h>
#include <iostream>
namespace tvm {
/*! \brief A printer class to print the AST/IR nodes. */
class NodePrinter {
class ReprPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};
explicit NodePrinter(std::ostream& stream) // NOLINT(*)
explicit ReprPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}
/*! \brief The node to be printed. */
......@@ -43,7 +43,7 @@ class NodePrinter {
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, NodePrinter*)>;
using FType = NodeFunctor<void(const ObjectRef&, ReprPrinter*)>;
TVM_DLL static FType& vtable();
};
......@@ -60,9 +60,9 @@ namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
NodePrinter(os).Print(n);
ReprPrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_NODE_PRINTER_H_
#endif // TVM_NODE_REPR_PRINTER_H_
......@@ -26,6 +26,7 @@
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <string>
#include <functional>
#include "./base.h"
......@@ -40,6 +41,7 @@ using BaseFunc = tvm::BaseFunc;
using BaseFuncNode = tvm::BaseFuncNode;
using GlobalVar = tvm::GlobalVar;
using GlobalVarNode = tvm::GlobalVarNode;
using tvm::PrettyPrint;
/*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device.
......@@ -539,20 +541,6 @@ class TempExpr : public Expr {
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
};
/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const ObjectRef& node);
/*!
* \brief Render the node as a string in the Relay text format.
* \param node The node to be rendered.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std::string AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
......
......@@ -51,8 +51,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
}
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ConstIntBoundNode*>(node.get());
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
......
......@@ -813,8 +813,8 @@ IntSet EvalSet(Range r,
TVM_REGISTER_NODE_TYPE(IntervalSetNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntervalSetNode*>(node.get());
p->stream << "IntervalSet"
<< "[" << op->min_value << ", "
......
......@@ -44,8 +44,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
data_ = std::move(node);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ModularSetNode*>(node.get());
p->stream << "ModularSet("
<< "coeff=" << op->coeff << ", base="
......
......@@ -45,8 +45,8 @@ TVM_REGISTER_GLOBAL("relay._make.Constructor")
return Constructor(name_hint, inputs, belong_to);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
......@@ -71,8 +71,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeData")
return TypeData(header, type_vars, constructors);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
......
......@@ -59,8 +59,8 @@ Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
return Attrs(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict;
});
......
......@@ -31,8 +31,8 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EnvFuncNode*>(node.get());
p->stream << "EnvFunc(" << op->name << ")";
});
......
......@@ -111,7 +111,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
//
// The annotation callback will annotate the error messages
// contained in the map.
annotated_prog << relay::AsText(func, false, [&err_map](tvm::relay::Expr expr) {
annotated_prog << AsText(func, false, [&err_map](const ObjectRef& expr) {
auto it = err_map.find(expr);
if (it != err_map.end()) {
CHECK_NE(it->second.size(), 0);
......
......@@ -78,8 +78,8 @@ TVM_REGISTER_GLOBAL("make.IntImm")
TVM_REGISTER_NODE_TYPE(IntImmNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
......@@ -104,8 +104,8 @@ TVM_REGISTER_GLOBAL("make.FloatImm")
TVM_REGISTER_NODE_TYPE(FloatImmNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
auto& stream = p->stream;
switch (op->dtype.bits()) {
......@@ -134,8 +134,8 @@ Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
......@@ -159,15 +159,15 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
return GlobalVar(name);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")";
});
// Container printer
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ArrayNode*>(node.get());
p->stream << '[';
for (size_t i = 0 ; i < op->data.size(); ++i) {
......@@ -179,8 +179,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ']';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
......@@ -194,8 +194,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << '}';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StrMapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
......
......@@ -434,8 +434,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd")
mod->ImportFromStd(path);
});;
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IRModuleNode*>(ref.get());
p->stream << "IRModuleNode( " << node->functions << ")";
});
......
......@@ -227,8 +227,8 @@ TVM_REGISTER_NODE_TYPE(OpNode)
return static_cast<const OpNode*>(n)->name;
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const OpNode*>(ref.get());
p->stream << "Op(" << node->name << ")";
});
......
......@@ -48,8 +48,8 @@ SourceName SourceName::Get(const std::string& name) {
TVM_REGISTER_GLOBAL("relay._make.SourceName")
.set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SourceNameNode*>(ref.get());
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
......@@ -73,8 +73,8 @@ TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_GLOBAL("relay._make.Span")
.set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
p->stream << "Span(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
......
......@@ -27,7 +27,7 @@
namespace tvm {
using tvm::NodePrinter;
using tvm::ReprPrinter;
using namespace tvm::runtime;
TensorType::TensorType(Array<PrimExpr> shape, DataType dtype) {
......@@ -60,8 +60,8 @@ TVM_REGISTER_GLOBAL("relay._make.TensorType")
return TensorType(shape, dtype);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TensorTypeNode*>(ref.get());
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
......
......@@ -24,7 +24,7 @@
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/ir/transform.h>
// TODO(tqchen): Update to use String container after it is merged.
......@@ -38,7 +38,7 @@ namespace transform {
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
using tvm::NodePrinter;
using tvm::ReprPrinter;
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
......@@ -341,8 +341,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Info")
*ret = pass->Info();
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) {
auto* node = static_cast<const PassInfoNode*>(ref.get());
p->stream << "The meta data of the pass: ";
p->stream << "pass name: " << node->name;
......@@ -371,8 +371,8 @@ TVM_REGISTER_GLOBAL("relay._transform.RunPass")
*ret = pass(mod);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModulePassNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ModulePassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Module pass: " << info->name
......@@ -391,8 +391,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Sequential")
*ret = Sequential(passes, pass_info);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SequentialNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SequentialNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Sequential pass: " << info->name
......@@ -421,8 +421,8 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
*ret = pctx;
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PassContextNode*>(ref.get());
p->stream << "Pass context information: " << "\n";
p->stream << "\topt_level: " << node->opt_level << "\n";
......
......@@ -38,8 +38,8 @@ TVM_REGISTER_GLOBAL("relay._make.PrimType")
return PrimType(dtype);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PrimTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrimTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PrimTypeNode*>(ref.get());
p->stream << node->dtype;
});
......@@ -59,8 +59,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeVar")
return TypeVar(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVar(" << node->name_hint << ", "
<< node->kind << ")";
......@@ -81,8 +81,8 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
return GlobalTypeVar(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
p->stream << "GlobalTypeVar(" << node->name_hint << ", "
<< node->kind << ")";
......@@ -110,8 +110,8 @@ TVM_REGISTER_GLOBAL("relay._make.FuncType")
return FuncType(arg_types, ret_type, type_params, type_constraints);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FuncTypeNode*>(ref.get());
p->stream << "FuncType(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", "
......@@ -136,8 +136,8 @@ TVM_REGISTER_GLOBAL("relay._make.TupleType")
return TupleType(fields);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TupleTypeNode*>(ref.get());
p->stream << "TupleTypeNode(" << node->fields << ")";
});
......@@ -156,8 +156,8 @@ TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
return IncompleteType(static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
......@@ -176,8 +176,8 @@ TVM_REGISTER_GLOBAL("relay._make.RefType")
TVM_REGISTER_NODE_TYPE(RelayRefTypeNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RelayRefTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RelayRefTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RelayRefTypeNode*>(ref.get());
p->stream << "RelayRefTypeNode(" << node->value << ")";
});
......
......@@ -40,8 +40,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeCall")
return TypeCall(func, type);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeCallNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeCallNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeCallNode*>(ref.get());
p->stream << "TypeCallNode(" << node->func << ", "
<< node->args << ")";
......@@ -69,8 +69,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeRelation")
return TypeRelation(func, args, num_inputs, attrs);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeRelationNode*>(ref.get());
p->stream << "TypeRelationNode("
<< node->func->name
......
......@@ -19,13 +19,13 @@
/*!
* Printer utilities
* \file node/printer.cc
* \file node/repr_printer.cc
*/
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
namespace tvm {
void NodePrinter::Print(const ObjectRef& node) {
void ReprPrinter::Print(const ObjectRef& node) {
static const FType& f = vtable();
if (!node.defined()) {
stream << "(nullptr)";
......@@ -39,13 +39,13 @@ void NodePrinter::Print(const ObjectRef& node) {
}
}
void NodePrinter::PrintIndent() {
void ReprPrinter::PrintIndent() {
for (int i = 0; i < indent; ++i) {
stream << ' ';
}
}
NodePrinter::FType& NodePrinter::vtable() {
ReprPrinter::FType& ReprPrinter::vtable() {
static FType inst;
return inst;
}
......
......@@ -20,43 +20,82 @@
/*!
* \file src/tvm/relay/doc.cc
* \brief Doc ADT used for pretty printing.
* Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf.
*
* Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98
*/
#include <memory>
#include <tvm/runtime/packed_func.h>
#include <vector>
#include <sstream>
#include "doc.h"
namespace tvm {
namespace relay {
// Text constructor
DocAtom Text(const std::string& str) {
return std::make_shared<TextNode>(str);
}
/*!
* \brief Represent a piece of text in the doc.
*/
class DocTextNode : public DocAtomNode {
public:
/*! \brief The str content in the text. */
std::string str;
explicit DocTextNode(std::string str_val)
: str(str_val) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
}
}
// Line constructor
DocAtom Line(int indent = 0) {
return std::make_shared<LineNode>(indent);
}
static constexpr const char* _type_key = "printer.DocText";
TVM_DECLARE_FINAL_OBJECT_INFO(DocTextNode, DocAtomNode);
};
Doc::Doc(const std::string& str) {
if (str == "\n") {
this->stream_ = {Line()};
} else {
this->stream_ = {Text(str)};
TVM_REGISTER_OBJECT_TYPE(DocTextNode);
class DocText : public DocAtom {
public:
explicit DocText(std::string str) {
data_ = runtime::make_object<DocTextNode>(str);
}
}
// DSL function implementations
TVM_DEFINE_OBJECT_REF_METHODS(DocText, DocAtom, DocTextNode);
};
/*!
* \brief Represent a line breaker in the doc.
*/
class DocLineNode : public DocAtomNode {
public:
/*! \brief The amount of indent in newline. */
int indent;
explicit DocLineNode(int indent)
: indent(indent) {}
static constexpr const char* _type_key = "printer.DocLine";
TVM_DECLARE_FINAL_OBJECT_INFO(DocLineNode, DocAtomNode);
};
TVM_REGISTER_OBJECT_TYPE(DocLineNode);
class DocLine : public DocAtom {
public:
explicit DocLine(int indent) {
data_ = runtime::make_object<DocLineNode>(indent);
}
TVM_DEFINE_OBJECT_REF_METHODS(DocLine, DocAtom, DocLineNode);
};
// DSL function implementations
Doc& Doc::operator<<(const Doc& right) {
CHECK(this != &right);
this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end());
this->stream_.insert(
this->stream_.end(), right.stream_.begin(), right.stream_.end());
return *this;
}
Doc& Doc::operator<<(const std::string& right) {
return *this << Doc(right);
Doc& Doc::operator<<(std::string right) {
return *this << DocText(right);
}
Doc& Doc::operator<<(const DocAtom& right) {
......@@ -64,63 +103,71 @@ Doc& Doc::operator<<(const DocAtom& right) {
return *this;
}
Doc Indent(int indent, const Doc& doc) {
Doc ret;
for (auto atom : doc.stream_) {
if (auto text = std::dynamic_pointer_cast<TextNode>(atom)) {
ret.stream_.push_back(text);
} else if (auto line = std::dynamic_pointer_cast<LineNode>(atom)) {
ret.stream_.push_back(Line(indent + line->indent));
} else {CHECK(false);}
}
return ret;
}
std::string Doc::str() {
std::ostringstream os;
for (auto atom : this->stream_) {
if (auto text = std::dynamic_pointer_cast<TextNode>(atom)) {
if (auto* text = atom.as<DocTextNode>()) {
os << text->str;
} else if (auto line = std::dynamic_pointer_cast<LineNode>(atom)) {
} else if (auto* line = atom.as<DocLineNode>()) {
os << "\n" << std::string(line->indent, ' ');
} else {CHECK(false);}
} else {
LOG(FATAL) << "do not expect type " << atom->GetTypeKey();
}
}
return os.str();
}
Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc seq;
if (vec.size() != 0) {
seq = vec[0];
for (size_t i = 1; i < vec.size(); i++) {
seq << sep << vec[i];
}
}
return seq;
Doc Doc::NewLine(int indent) {
return Doc() << DocLine(indent);
}
Doc PrintBool(bool value) {
if (value) {
return Doc("True");
} else {
return Doc("False");
}
Doc Doc::Text(std::string text) {
return Doc() << DocText(text);
}
Doc PrintDType(DataType dtype) {
return Doc(runtime::DLDataType2String(dtype));
Doc Doc::Indent(int indent, Doc doc) {
for (size_t i = 0; i < doc.stream_.size(); ++i) {
if (auto* line = doc.stream_[i].as<DocLineNode>()) {
doc.stream_[i] = DocLine(indent + line->indent);
}
}
return doc;
}
Doc PrintString(const std::string& value) {
Doc Doc::StrLiteral(const std::string& value, std::string quote) {
// TODO(M.K.): add escape.
Doc doc;
return doc << "\"" << value << "\"";
return doc << quote << value << quote;
}
Doc PrintNewLine(int ident) {
Doc Doc::PyBoolLiteral(bool value) {
if (value) {
return Doc::Text("True");
} else {
return Doc::Text("False");
}
}
Doc Doc::Brace(std::string open,
const Doc& body,
std::string close,
int indent) {
Doc doc;
return doc << Line(ident);
doc << open;
doc << Indent(indent, NewLine() << body) << NewLine();
doc << close;
return doc;
}
} // namespace relay
Doc Doc::Concat(const std::vector<Doc>& vec, const Doc& sep) {
Doc seq;
if (vec.size() != 0) {
if (vec.size() == 1) return vec[0];
seq << vec[0];
for (size_t i = 1; i < vec.size(); ++i) {
seq << sep << vec[i];
}
}
return seq;
}
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/printer/doc.h
* \brief Doc ADT used for pretty printing.
*
* Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98
*/
#ifndef TVM_PRINTER_DOC_H_
#define TVM_PRINTER_DOC_H_
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <string>
#include <vector>
#include <type_traits>
namespace tvm {
/*!
* \brief Doc atom node for the ADT.
* \sa DocAtom
*/
class DocAtomNode : public Object {
public:
static constexpr const char* _type_key = "printer.DocAtom";
TVM_DECLARE_BASE_OBJECT_INFO(DocAtomNode, Object);
};
/*!
* \brief Managed reference to DocAtomNode.
* \sa DocAtomNode.
*/
class DocAtom : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(DocAtom, ObjectRef, DocAtomNode);
};
/*!
* \brief Stream-like interface for Doc DSL.
*
* The Doc DSL de-couples the layout decision from the printing decision.
*
* The layout(code formating) decisions include:
* - Change indentation.
* - Break single line into multiple ones(subjected to future improvements).
*/
class Doc {
public:
/*! \brief default constructor */
Doc() {}
/*!
* \brief Append right to the end of the current doc stream.
* \param right The doc to be appended.
* \return reference to self.
*/
Doc& operator<<(const Doc& right);
/*!
* \brief Append right to the end of the current doc stream.
* \param right The doc to be appended.
* \return reference to self.
* \note pass by value to allow copy elison optimization.
*/
Doc& operator<<(std::string right);
/*!
* \brief Append right to the end of the current doc stream.
* \param right The doc to be appended.
* \return reference to self.
*/
Doc& operator<<(const DocAtom& right);
/*!
* \brief Convert value to string via std::ostreamstream
* the append to the current doc stream.
* \param right The doc to be appended.
* \tparam T the type of the value.
* \return reference to self.
*/
template<typename T,
typename = typename std::enable_if<!std::is_class<T>::value>::type>
Doc& operator<<(const T& value) {
std::ostringstream os;
os << value;
return *this << os.str();
}
/*!
* \brief Convert the doc stream into string.
* \return The string representation.
*/
std::string str();
/*!
* \brief Create a doc that represents text content.
* \return The created doc.
*/
static Doc Text(std::string value);
/*!
* \brief Create a doc that represents a new line.
* \return The created doc.
*/
static Doc NewLine(int indent = 0);
/*!
* \brief Create a new doc that adds indentation to everyline of the doc.
* \param indent The indent to be added.
* \param doc The doc to be indented.
* \return The created doc.
* \note pass by value to allow copy elison optimization.
*/
static Doc Indent(int indent, Doc doc);
/*!
* \brief Create a Doc that represents a string literal.
* \param value The content of the string literal.
* \param quote The quote in the literal.
* \return The created doc.
*/
static Doc StrLiteral(const std::string& value, std::string quote = "\"");
/*!
* \brief Create a Doc that represents a boolean literal in python syntax.
* \param value The bool value.
* \return The created doc.
*/
static Doc PyBoolLiteral(bool value);
/*!
* \brief Enclose body by brace and add indent.
* \param body The body
* \param open The open brace.
* \param close The close brace.
* \param indent amount of indentation.
* \return The created doc.
*/
static Doc Brace(std::string open,
const Doc& body,
std::string close,
int indent = 2);
/*!
* \brief Create a doc by concatenating together with separator.
* \param vec The docs to be concatenated.
* \param sep The seperator.
* \return The created doc.
*/
static Doc Concat(const std::vector<Doc>& vec, const Doc& sep = Text(", "));
private:
/*! \brief Internal doc stream. */
std::vector<DocAtom> stream_;
};
} // namespace tvm
#endif // TVM_PRINTER_DOC_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/printer/meta_data.h
* \brief Meta data context for printers.
*/
#ifndef TVM_PRINTER_META_DATA_H_
#define TVM_PRINTER_META_DATA_H_
#include <tvm/node/serialization.h>
#include <tvm/node/container.h>
#include <string>
#include <unordered_map>
#include "doc.h"
namespace tvm {
/*!
* \brief Meta data context for Printers
*
* This is an important part to enable bi-directional serializability.
* We use tvm's Node system to build the current IR.
* It can be hard to design a text format for all the possible nodes
* as the set of nodes can grow when we do more extensions.
*
* Instead of trying to design readable text format for every node,
* we support a meta data section in the text format.
* We allow the text format to refer to a node in the meta data section.
*
* The meta data section is a json serialized string of an Map<string, Array<NodeRef>>.
* Each element in the meta data section can be referenced by the text format.
* Each meta data node is printed in the following format.
*
* meta[type-key-of-node>][<index-in-meta-section>]
*
* Specifically, consider the following IR(constructed by python).
*
* \code
*
* n = tvm.var("n")
* x = tvm.relay.var("x", shape=(n, 1))
* f = tvm.relay.Function([x], x)
* print(f.astext())
*
* \endcode
*
* The corresponding text format is shown in the following code block.
*
* \code
*
* fn (%x: Tensor[(meta[Variable][0],), float32]) {
* %x
* }
* # Meta data section is a json-serialized string
* # of the following array.
* # [tvm.var("n")]
*
* \endcode
*
* Note that we store tvm.var("n") in the meta data section.
* Since it is stored in the index-0 in the meta data section,
* we print it as meta[Variable][0].
*
* The text parser can recover this object by loading from the corresponding
* location in the meta data section.
*
* This is is a design trade-off.
* It allows us to embedded any meta data in the text format,
* while still being able to tweak the text part of the printed IR easily.
*/
class TextMetaDataContext {
public:
/*!
* \brief Get text representation of meta node.
* \param node The node to be converted to meta node.
* \return A string representation of the meta node.
*/
Doc GetMetaNode(const ObjectRef& node) {
auto it = meta_repr_.find(node);
if (it != meta_repr_.end()) {
return it->second;
}
std::string type_key = node->GetTypeKey();
CHECK(!type_key.empty());
Array<ObjectRef>& mvector =
meta_data_[type_key];
int64_t index = static_cast<int64_t>(mvector.size());
mvector.push_back(node);
Doc doc;
doc << "meta[" << type_key << "][" << index << "]";
meta_repr_[node] = doc;
return meta_repr_[node];
}
/*!
* \brief Print a key value pair
*/
Doc PrintKeyValue(const std::string& str, const Doc& v) const {
return Doc() << "\"" << str << "\": " << v;
}
/*!
* \brief Get the metadata section in json format.
* \return the meta data string.
*/
Doc GetMetaSection() const {
if (meta_data_.size() == 0) return Doc();
return Doc::Text(
SaveJSON(Map<std::string, ObjectRef>(meta_data_.begin(), meta_data_.end())));
}
/*! \return whether the meta data context is empty. */
bool empty() const {
return meta_data_.empty();
}
private:
/*! \brief additional metadata stored in TVM json format */
std::unordered_map<std::string, Array<ObjectRef> > meta_data_;
/*! \brief map from meta data into its string representation */
std::unordered_map<ObjectRef, Doc, ObjectHash, ObjectEqual> meta_repr_;
};
} // namespace tvm
#endif // TVM_PRINTER_META_DATA_H_
......@@ -18,9 +18,11 @@
*/
/*!
* \file pretty_printer.cc
* \brief Pretty printer for Relay programs
* Supports ANF, GNF, and metadata.
* \file text_format_printer.cc
* \brief Printer to print out the IR text format
* that can be parsed by a parser.
*
* Supports ANF, GNF in relay and metadata.
*
* Inlining heuristics:
* - Always inline:
......@@ -31,141 +33,26 @@
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/node/serialization.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/ir/module.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
#include "../pass/dependency_graph.h"
#include "../../ir/attr_functor.h"
#include "meta_data.h"
#include "../relay/pass/dependency_graph.h"
#include "../ir/attr_functor.h"
namespace tvm {
namespace relay {
static const char* kSemVer = "v0.0.4";
Doc Brace(const Doc& d,
const std::string& open = "{",
const std::string& close = "}",
int indent = 2) {
Doc doc;
doc << open;
doc << Indent(indent, PrintNewLine() << d) << PrintNewLine();
doc << close;
return doc;
}
/*!
* \brief Meta data context for PrettyPrinter.
*
* This is an important part to enable bi-directional serializability.
* We use tvm's Node system to build the current IR.
* It can be hard to design a text format for all the possible nodes
* as the set of nodes can grow when we do more extensions.
*
* Instead of trying to design readable text format for every node,
* we support a meta data section in the text format.
* We allow the text format to refer to a node in the meta data section.
*
* The meta data section is a json serialized string of an Map<string, Array<NodeRef>>.
* Each element in the meta data section can be referenced by the text format.
* Each meta data node is printed in the following format.
*
* meta[type-key-of-node>][<index-in-meta-section>]
*
* Specifically, consider the following IR(constructed by python).
*
* \code
*
* n = tvm.var("n")
* x = tvm.relay.var("x", shape=(n, 1))
* f = tvm.relay.Function([x], x)
* print(f.astext())
*
* \endcode
*
* The corresponding text format is shown in the following code block.
*
* \code
*
* fn (%x: Tensor[(meta[Variable][0],), float32]) {
* %x
* }
* # Meta data section is a json-serialized string
* # of the following array.
* # [tvm.var("n")]
*
* \endcode
*
* Note that we store tvm.var("n") in the meta data section.
* Since it is stored in the index-0 in the meta data section,
* we print it as meta[Variable][0].
*
* The text parser can recover this object by loading from the corresponding
* location in the meta data section.
*
* This is is a design trade-off.
* It allows us to embedded any meta data in the text format,
* while still being able to tweak the text part of the printed IR easily.
*/
class TextMetaDataContext {
public:
/*!
* \brief Get text representation of meta node.
* \param node The node to be converted to meta node.
* \return A string representation of the meta node.
*/
Doc GetMetaNode(const ObjectRef& node) {
auto it = meta_repr_.find(node);
if (it != meta_repr_.end()) {
return it->second;
}
std::string type_key = node->GetTypeKey();
CHECK(!type_key.empty());
Array<ObjectRef>& mvector =
meta_data_[type_key];
int64_t index = static_cast<int64_t>(mvector.size());
mvector.push_back(node);
Doc doc;
doc << "meta[" << type_key << "][" << index << "]";
meta_repr_[node] = doc;
return meta_repr_[node];
}
Doc PrintKeyValue(const std::string& str, const Doc& v) const {
return Doc("\"") << str << "\": " << v;
}
/*!
* \brief Get the metadata section in json format.
* \return the meta data string.
*/
Doc GetMetaSection() const {
if (meta_data_.size() == 0) return Doc();
return Doc(SaveJSON(Map<std::string, ObjectRef>(meta_data_.begin(), meta_data_.end())));
}
/*! \return whether the meta data context is empty. */
bool empty() const {
return meta_data_.empty();
}
private:
/*! \brief additional metadata stored in TVM json format */
std::unordered_map<std::string, Array<ObjectRef> > meta_data_;
/*! \brief map from meta data into its string representation */
std::unordered_map<ObjectRef, Doc, ObjectHash, ObjectEqual> meta_repr_;
};
class PrettyPrinter :
class RelayTextPrinter :
public ExprFunctor<Doc(const Expr&)>,
public PatternFunctor<Doc(const Pattern&)>,
public TypeFunctor<Doc(const Type&)>,
public AttrFunctor<Doc(const ObjectRef&)> {
public:
explicit PrettyPrinter(bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) :
show_meta_data_(show_meta_data),
explicit RelayTextPrinter(bool show_meta_data,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
: show_meta_data_(show_meta_data),
annotate_(annotate) {}
/*!
......@@ -194,7 +81,7 @@ class PrettyPrinter :
Doc doc;
Doc body;
doc << "{";
doc << Indent(indent, body << PrintNewLine() << PrintScope(node)) << PrintNewLine();
doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine();
doc << "}";
return doc;
}
......@@ -220,10 +107,10 @@ class PrettyPrinter :
Doc doc;
doc << PrintScope(node);
if (!meta_.empty()) {
doc << PrintNewLine();
doc << Doc::NewLine();
if (show_meta_data_) {
// append meta data in the end.
doc << "METADATA:" << PrintNewLine() << meta_.GetMetaSection();
doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection();
} else {
doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
}
......@@ -244,8 +131,9 @@ class PrettyPrinter :
} else if (node.as<IRModuleNode>()) {
return PrintMod(Downcast<IRModule>(node));
} else {
Doc doc;
return doc << node;
std::ostringstream os;
os << node;
return Doc() << os.str();
}
}
......@@ -278,23 +166,23 @@ class PrettyPrinter :
}
}
name_alloc_map_[unique_prefix] = 0;
return Doc(unique_prefix);
return Doc::Text(unique_prefix);
}
Doc Print(Kind k) {
switch (k) {
case kType:
return Doc("Type");
return Doc::Text("Type");
case kShapeVar:
return Doc("Shape");
return Doc::Text("Shape");
case kBaseType:
return Doc("BaseType");
return Doc::Text("BaseType");
case kConstraint:
return Doc("Constraint");
return Doc::Text("Constraint");
case kAdtHandle:
return Doc("AdtHandle");
return Doc::Text("AdtHandle");
case kTypeData:
return Doc("TypeData");
return Doc::Text("TypeData");
default:
LOG(ERROR) << "Unknown Kind";
throw;
......@@ -387,7 +275,7 @@ class PrettyPrinter :
// wrap GNFed let in brackets
Doc body;
printed_expr << "(";
printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine();
printed_expr << ")";
} else {
printed_expr = VisitExpr(expr);
......@@ -399,7 +287,7 @@ class PrettyPrinter :
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << PrintNewLine();
doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine();
// Memoization is done in AllocVar.
return memo_[expr];
} else if (inline_expr) {
......@@ -408,7 +296,7 @@ class PrettyPrinter :
} else {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr << ";" << PrintNewLine();
doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
return temp_var;
}
}
......@@ -419,6 +307,28 @@ class PrettyPrinter :
return AllocVar(GetRef<Var>(op));
}
/*!
* \brief special method to print out const scalar
* \param dtype The data type
* \param value The value to be printed.
*/
template<typename T>
static Doc ScalarLiteral(DataType dtype, const T& value) {
std::ostringstream os;
if (dtype == DataType::Int(32)) {
os << value;
} else if (dtype == DataType::Float(32)) {
os << value << 'f';
} else if (dtype == DataType::Float(64)) {
os << value;
} else if (dtype == DataType::Bool()) {
return Doc::PyBoolLiteral(value != 0);
} else {
os << value;
}
return Doc::Text(os.str());
}
Doc VisitExpr_(const ConstantNode* op) final {
// Print out simple scalars directly.
if (op->is_scalar()) {
......@@ -426,15 +336,15 @@ class PrettyPrinter :
DataType dtype = DataType(op->data->dtype);
CHECK_EQ(op->data->ctx.device_type, kDLCPU);
if (dtype == DataType::Int(32)) {
return PrintConstScalar(dtype, static_cast<const int32_t*>(op->data->data));
return ScalarLiteral(dtype, static_cast<const int32_t*>(op->data->data)[0]);
} else if (dtype == DataType::Int(64)) {
return PrintConstScalar(dtype, static_cast<const int64_t*>(op->data->data));
return ScalarLiteral(dtype, static_cast<const int64_t*>(op->data->data)[0]);
} else if (dtype == DataType::Float(32)) {
return PrintConstScalar(dtype, static_cast<const float*>(op->data->data));
return ScalarLiteral(dtype, static_cast<const float*>(op->data->data)[0]);
} else if (dtype == DataType::Float(64)) {
return PrintConstScalar(dtype, static_cast<const double*>(op->data->data));
return ScalarLiteral(dtype, static_cast<const double*>(op->data->data)[0]);
} else if (dtype == DataType::Bool()) {
return PrintConstScalar(dtype, static_cast<const uint8_t*>(op->data->data));
return ScalarLiteral(dtype, static_cast<const uint8_t*>(op->data->data)[0]);
}
}
// default fall-back, record it as meta node.
......@@ -448,7 +358,7 @@ class PrettyPrinter :
fields.push_back(Print(field));
}
Doc doc;
doc << "(" << PrintSep(fields);
doc << "(" << Doc::Concat(fields);
// conform to python tuple format (1,)
if (op->fields.size() == 1) {
doc << ",";
......@@ -478,7 +388,7 @@ class PrettyPrinter :
<< " = "
<< Print(op->value, false, true)
<< ";"
<< PrintNewLine();
<< Doc::NewLine();
// we use a scope here so GNF hoisting doesn't escape too far
// and nested, unique lets are not hoisted
doc << PrintScope(op->body);
......@@ -492,9 +402,9 @@ class PrettyPrinter :
doc << "[";
std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(Doc(tv->name_hint));
type_params.push_back(Doc::Text(tv->name_hint));
}
doc << PrintSep(type_params);
doc << Doc::Concat(type_params);
doc << "]";
}
doc << "(";
......@@ -505,7 +415,7 @@ class PrettyPrinter :
for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
params.push_back(d);
}
doc << PrintSep(params) << ") ";
doc << Doc::Concat(params) << ") ";
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}
......@@ -530,36 +440,36 @@ class PrettyPrinter :
// type definitions
for (const auto& kv : mod->type_definitions) {
if (counter++ != 0) {
doc << PrintNewLine();
doc << Doc::NewLine();
}
doc << Print(kv.second);
doc << PrintNewLine();
doc << Doc::NewLine();
}
// functions
for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);
if (counter++ != 0) {
doc << PrintNewLine();
doc << Doc::NewLine();
}
std::ostringstream os;
os << "def @" << kv.first->name_hint;
doc << PrintFunc(Doc(os.str()), kv.second);
doc << PrintNewLine();
doc << PrintFunc(Doc::Text(os.str()), kv.second);
doc << Doc::NewLine();
}
return doc;
}
Doc VisitExpr_(const FunctionNode* op) final {
return PrintFunc(Doc("fn "), GetRef<Function>(op));
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}
Doc VisitExpr_(const GlobalVarNode* op) final {
return Doc('@' + op->name_hint);
return Doc::Text('@' + op->name_hint);
}
Doc VisitExpr_(const OpNode* op) final {
return Doc(op->name);
return Doc::Text(op->name);
}
Doc VisitExpr_(const CallNode* op) final {
......@@ -584,7 +494,7 @@ class PrettyPrinter :
// don't print as a call if it's a 0-arity cons
return doc;
} else {
return doc << "(" << PrintSep(args) << ")";
return doc << "(" << Doc::Concat(args) << ")";
}
}
......@@ -619,13 +529,13 @@ class PrettyPrinter :
Doc rhs_doc = PrintScope(clause->rhs);
if (clause->rhs.as<LetNode>()) {
// only add braces if there are multiple lines on the rhs
rhs_doc = Brace(rhs_doc);
rhs_doc = Doc::Brace("{", rhs_doc, "}");
}
clause_doc << rhs_doc << ",";
clause_docs.push_back(clause_doc);
}
doc << Indent(2, body << PrintNewLine() << PrintSep(clause_docs, PrintNewLine()))
<< PrintNewLine() << "}";
doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine()))
<< Doc::NewLine() << "}";
return doc;
}
......@@ -651,7 +561,7 @@ class PrettyPrinter :
for (const auto& pat : p->patterns) {
pats.push_back(Print(pat));
}
doc << PrintSep(pats) << ")";
doc << Doc::Concat(pats) << ")";
}
return doc;
}
......@@ -663,12 +573,12 @@ class PrettyPrinter :
for (const auto& pat : pt->patterns) {
pats.push_back(Print(pat));
}
doc << PrintSep(pats) << ")";
doc << Doc::Concat(pats) << ")";
return doc;
}
Doc VisitPattern_(const PatternWildcardNode* pw) final {
return Doc("_");
return Doc::Text("_");
}
Doc VisitPattern_(const PatternVarNode* pv) final {
......@@ -684,7 +594,7 @@ class PrettyPrinter :
for (Type input : n->inputs) {
inputs.push_back(Print(input));
}
doc << PrintSep(inputs) << ")";
doc << Doc::Concat(inputs) << ")";
}
return doc;
}
......@@ -711,11 +621,11 @@ class PrettyPrinter :
}
Doc VisitType_(const TypeVarNode* node) final {
return Doc(node->name_hint);
return Doc::Text(node->name_hint);
}
Doc VisitType_(const GlobalTypeVarNode* node) final {
return Doc(node->name_hint);
return Doc::Text(node->name_hint);
}
Doc VisitType_(const TypeCallNode* node) final {
......@@ -725,11 +635,15 @@ class PrettyPrinter :
args.push_back(PrintType(t, false));
}
doc << "[";
doc << PrintSep(args);
doc << Doc::Concat(args);
doc << "]";
return doc;
}
Doc PrintDType(DataType dtype) {
return Doc::Text(runtime::DLDataType2String(dtype));
}
Doc VisitType_(const TensorTypeNode* node) final {
// scalar type
if (node->shape.size() == 0) {
......@@ -741,7 +655,7 @@ class PrettyPrinter :
for (ObjectRef shape : node->shape) {
shapes.push_back(PrintAttr(shape));
}
doc << PrintSep(shapes);
doc << Doc::Concat(shapes);
return doc << "), " << PrintDType(node->dtype) << "]";
}
......@@ -751,7 +665,7 @@ class PrettyPrinter :
fields.push_back(Print(field));
}
Doc doc;
doc << "(" << PrintSep(fields);
doc << "(" << Doc::Concat(fields);
// conform to python tuple format (1,)
if (node->fields.size() == 1) {
doc << ",";
......@@ -768,14 +682,14 @@ class PrettyPrinter :
for (Type type_param : node->type_params) {
type_params.push_back(Print(type_param));
}
doc << PrintSep(type_params);
doc << Doc::Concat(type_params);
doc << "]";
}
std::vector<Doc> arg_types;
for (Type arg_type : node->arg_types) {
arg_types.push_back(Print(arg_type));
}
return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type);
return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type);
}
Doc VisitType_(const RelayRefTypeNode* node) final {
......@@ -795,7 +709,7 @@ class PrettyPrinter :
for (Type type_var : node->type_vars) {
type_vars.push_back(Print(type_var));
}
doc << PrintSep(type_vars) << "]";
doc << Doc::Concat(type_vars) << "]";
}
doc << " ";
......@@ -804,14 +718,14 @@ class PrettyPrinter :
constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
}
Doc separator;
separator << "," << PrintNewLine();
separator << "," << Doc::NewLine();
Doc adt_body;
adt_body << PrintSep(constructor_docs, separator);
adt_body << Doc::Concat(constructor_docs, separator);
// add trailing comma if there are any constructors
if (!constructor_docs.empty()) {
adt_body << ",";
}
doc << Brace(adt_body);
doc << Doc::Brace("{", adt_body, "}");
in_adt_def_ = false;
return doc;
}
......@@ -832,7 +746,7 @@ class PrettyPrinter :
}
return printed_attr;
} else {
return Doc("None");
return Doc::Text("None");
}
}
......@@ -847,28 +761,28 @@ class PrettyPrinter :
for (auto val : op->data) {
arr_vals.push_back(PrintAttr(val));
}
doc << PrintSep(arr_vals);
doc << Doc::Concat(arr_vals);
doc << "]";
return doc;
}
Doc VisitAttr_(const tir::IntImmNode* op) final {
return PrintConstScalar(op->dtype, &(op->value));
return ScalarLiteral(op->dtype, op->value);
}
Doc VisitAttr_(const tir::FloatImmNode* op) final {
return PrintConstScalar(op->dtype, &(op->value));
return ScalarLiteral(op->dtype, op->value);
}
Doc VisitAttr_(const tir::StringImmNode* op) final {
return PrintString(op->value);
return Doc::StrLiteral(op->value);
}
private:
/*! \brief Whether to print meta data. */
bool show_meta_data_;
/*! \brief additional comment function */
runtime::TypedPackedFunc<std::string(Expr)> annotate_;
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
/*! \brief Stack of docs to implement scoped GNFing. */
std::vector<Doc> doc_stack_{};
/*! \brief Map from Expr to Doc */
......@@ -896,9 +810,11 @@ class PrettyPrinter :
/*!
* \brief Attribute printer which prints the attributes in the call.
*/
class PrettyPrinter::AttrPrinter : public AttrVisitor {
class RelayTextPrinter::AttrPrinter :
public AttrVisitor {
public:
AttrPrinter(std::vector<Doc>* doc, PrettyPrinter* parent) : docs(doc), parent_(parent) {}
AttrPrinter(std::vector<Doc>* doc, RelayTextPrinter* parent)
: docs(doc), parent_(parent) {}
template<typename T>
void PrintKV(const char* key, const T& value) {
......@@ -922,16 +838,16 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
PrintKV(key, *value);
}
void Visit(const char* key, bool* value) final {
PrintKV(key, PrintBool(*value));
PrintKV(key, Doc::PyBoolLiteral(*value));
}
void Visit(const char* key, std::string* value) final {
PrintKV(key, PrintString(*value));
PrintKV(key, Doc::StrLiteral(*value));
}
void Visit(const char* key, void** value) final {
LOG(FATAL) << "do not allow void as argument";
}
void Visit(const char* key, DataType* value) final {
PrintKV(key, PrintString(runtime::DLDataType2String(*value)));
PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value)));
}
void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument";
......@@ -942,10 +858,11 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
private:
std::vector<Doc>* docs;
PrettyPrinter* parent_;
RelayTextPrinter* parent_;
};
std::vector<Doc> PrettyPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
std::vector<Doc> RelayTextPrinter::PrintCallAttrs(
const Attrs& attrs, const Expr& op) {
std::vector<Doc> docs;
if (!attrs.defined()) return docs;
const auto* op_node = op.as<OpNode>();
......@@ -962,7 +879,7 @@ std::vector<Doc> PrettyPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& o
}
}
std::vector<Doc> PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) {
std::vector<Doc> RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) {
std::vector<Doc> docs;
if (!attrs.defined()) return docs;
const auto* dict_attrs = attrs.as<DictAttrsNode>();
......@@ -974,30 +891,34 @@ std::vector<Doc> PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) {
}
return docs;
}
} // namespace relay
std::string PrettyPrint_(const ObjectRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc;
doc << kSemVer << PrintNewLine()
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}
static const char* kSemVer = "v0.0.4";
// TODO(tvm-team): split into files, related: arith/analyzer.h
//
// - text_printer.h (common header)
// - text_printer.cc (prints modules dispatch into relay and tir files)
// - type_text_printer.cc(specific printing logics for types,
// can also consider put under type_text_printer)
// - Implements AsText
// - relay_text_printer.cc (specific printing logics for relay)
// - tir_text_printer.cc (specific printing logics for TIR)
std::string PrettyPrint(const ObjectRef& node) {
Doc doc;
doc << PrettyPrinter(false, runtime::TypedPackedFunc<std::string(Expr)>()).PrintFinal(node);
doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
return doc.str();
}
std::string AsText(const ObjectRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return PrettyPrint_(node, show_meta_data, annotate);
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) {
Doc doc;
doc << kSemVer << Doc::NewLine()
<< relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}
TVM_REGISTER_GLOBAL("relay._expr.AsText")
.set_body_typed(AsText);
} // namespace relay
} // namespace tvm
......@@ -47,8 +47,8 @@ InterpreterClosure::InterpreterClosure(tvm::Map<Var, ObjectRef> env,
data_ = std::move(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<InterpreterClosureObj >([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<InterpreterClosureObj >([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const InterpreterClosureObj*>(ref.get());
p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")";
});
......@@ -68,8 +68,8 @@ RecClosure::RecClosure(InterpreterClosure clos, Var bind) {
data_ = std::move(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RecClosureObj>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RecClosureObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RecClosureObj*>(ref.get());
p->stream << "RecClosureObj(" << node->clos << ")";
});
......@@ -87,8 +87,8 @@ TVM_REGISTER_GLOBAL("relay._make.RefValue")
TVM_REGISTER_NODE_TYPE(RefValueObj);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefValueObj>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefValueObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RefValueObj*>(ref.get());
p->stream << "RefValueObj(" << node->value << ")";
});
......@@ -111,8 +111,8 @@ TVM_REGISTER_GLOBAL("relay._make.ConstructorValue")
TVM_REGISTER_NODE_TYPE(ConstructorValueObj);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorValueObj>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstructorValueObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ConstructorValueObj*>(ref.get());
p->stream << "ConstructorValueObj(" << node->tag << ","
<< node->fields << ")";
......
......@@ -37,8 +37,8 @@ TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_GLOBAL("relay._make.PatternWildcard")
.set_body_typed(PatternWildcardNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, ReprPrinter* p) {
p->stream << "PatternWildcardNode()";
});
......@@ -53,8 +53,8 @@ TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_GLOBAL("relay._make.PatternVar")
.set_body_typed(PatternVarNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PatternVarNode*>(ref.get());
p->stream << "PatternVarNode(" << node->var << ")";
});
......@@ -72,8 +72,8 @@ TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.PatternConstructor")
.set_body_typed(PatternConstructorNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PatternConstructorNode*>(ref.get());
p->stream << "PatternConstructorNode(" << node->constructor
<< ", " << node->patterns << ")";
......@@ -90,8 +90,8 @@ TVM_REGISTER_NODE_TYPE(PatternTupleNode);
TVM_REGISTER_GLOBAL("relay._make.PatternTuple")
.set_body_typed(PatternTupleNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PatternTupleNode*>(ref.get());
p->stream << "PatternTupleNode(" << node->patterns << ")";
});
......@@ -108,8 +108,8 @@ TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_GLOBAL("relay._make.Clause")
.set_body_typed(ClauseNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ClauseNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ClauseNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ClauseNode*>(ref.get());
p->stream << "ClauseNode(" << node->lhs << ", "
<< node->rhs << ")";
......@@ -128,8 +128,8 @@ TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_GLOBAL("relay._make.Match")
.set_body_typed(MatchNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MatchNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MatchNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const MatchNode*>(ref.get());
p->stream << "MatchNode(" << node->data << ", "
<< node->clauses << ", " << node->complete << ")";
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/doc.h
* \brief Doc ADT used for pretty printing.
* Based on Section 1 of
* https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf, but with
* a vector instead of an implicitly linked list.
*/
#ifndef TVM_RELAY_IR_DOC_H_
#define TVM_RELAY_IR_DOC_H_
#include <tvm/relay/expr.h>
#include <memory>
#include <string>
#include <vector>
namespace tvm {
namespace relay {
// Doc Atom ADT
struct DocAtomNode {
virtual ~DocAtomNode() = default;
};
using DocAtom = std::shared_ptr<DocAtomNode>;
struct TextNode : DocAtomNode {
std::string str;
explicit TextNode(const std::string& str) : str(str) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
}
}
};
struct LineNode : DocAtomNode {
int indent;
explicit LineNode(int indent) : indent(indent) {}
};
// Doc is a stream-like interface
class Doc {
public:
Doc() {}
explicit Doc(const std::string& str);
template<typename T>
explicit Doc(const T& str) {
(*this) << str;
}
// Append right to this.
Doc& operator<<(const Doc& right);
// Like above.
Doc& operator<<(const std::string& right);
// Like above.
Doc& operator<<(const DocAtom& right);
// Like above, but converts right to a string first.
template<typename T>
Doc& operator<<(const T& right) {
std::ostringstream os;
os << right;
return *this << os.str();
}
// Indent a doc stream.
friend Doc Indent(int indent, const Doc& doc);
// Wadler's `layout`
std::string str();
private:
std::vector<DocAtom> stream_;
};
// DSL functions
// Render vectors of docs with a separator. e.g. PrintSep([1, 2, 3], f) -> 1f2f3
Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep = Doc(", "));
// Print a constant bool value.
Doc PrintBool(bool value);
// Print a data type.
Doc PrintDType(DataType dtype);
// Print a string.
Doc PrintString(const std::string& value);
// Print a newline.
Doc PrintNewLine(int indent = 0);
/*!
* \brief special method to print out const scalar
* \param dtype The data type
* \param data The pointer to hold the data.
*/
template<typename T>
Doc PrintConstScalar(DataType dtype, const T* data) {
std::ostringstream os;
if (dtype == DataType::Int(32)) {
os << data[0];
} else if (dtype == DataType::Float(32)) {
os << data[0] << 'f';
} else if (dtype == DataType::Bool()) {
return PrintBool(data[0] != 0);
} else {
// todo(@M.K.) this is unsafe. fix.
os << data[0];
}
return Doc(os.str());
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_IR_DOC_H_
......@@ -21,12 +21,13 @@
* \file src/tvm/relay/ir/expr.cc
* \brief The expression AST nodes of Relay.
*/
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
using tvm::NodePrinter;
using tvm::ReprPrinter;
using namespace tvm::runtime;
Constant ConstantNode::make(runtime::NDArray data) {
......@@ -40,8 +41,8 @@ TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_GLOBAL("relay._make.Constant")
.set_body_typed(ConstantNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstantNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ConstantNode*>(ref.get());
const PackedFunc* fprint = Registry::Get("relay._constant_repr");
CHECK(fprint) << "unable to find printing function for constants";
......@@ -73,8 +74,8 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_GLOBAL("relay._make.Tuple")
.set_body_typed(TupleNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TupleNode*>(ref.get());
p->stream << "Tuple(" << node->fields << ")";
});
......@@ -98,8 +99,8 @@ TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_GLOBAL("relay._make.Var")
.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<VarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const VarNode*>(ref.get());
p->stream << "Var(" << node->name_hint();
if (node->type_annotation.defined()) {
......@@ -208,8 +209,8 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function")
.set_body_typed(FunctionNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionNode*>(ref.get());
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ", "
......@@ -231,8 +232,8 @@ TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_GLOBAL("relay._make.Call")
.set_body_typed(CallNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const CallNode*>(ref.get());
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
......@@ -251,8 +252,8 @@ TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_GLOBAL("relay._make.Let")
.set_body_typed(LetNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const LetNode*>(ref.get());
p->stream << "LetNode(" << node->var << ", " << node->value
<< ", " << node->body << ")";
......@@ -271,8 +272,8 @@ TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_GLOBAL("relay._make.If")
.set_body_typed(IfNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IfNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IfNode*>(ref.get());
p->stream << "IfNode(" << node->cond << ", " << node->true_branch
<< ", " << node->false_branch << ")";
......@@ -290,8 +291,8 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_GLOBAL("relay._make.TupleGetItem")
.set_body_typed(TupleGetItemNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TupleGetItemNode*>(ref.get());
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});
......@@ -307,8 +308,8 @@ TVM_REGISTER_NODE_TYPE(RefCreateNode);
TVM_REGISTER_GLOBAL("relay._make.RefCreate")
.set_body_typed(RefCreateNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefCreateNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RefCreateNode*>(ref.get());
p->stream << "RefCreateNode(" << node->value << ")";
});
......@@ -324,8 +325,8 @@ TVM_REGISTER_NODE_TYPE(RefReadNode);
TVM_REGISTER_GLOBAL("relay._make.RefRead")
.set_body_typed(RefReadNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefReadNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RefReadNode*>(ref.get());
p->stream << "RefReadNode(" << node->ref << ")";
});
......@@ -342,8 +343,8 @@ TVM_REGISTER_NODE_TYPE(RefWriteNode);
TVM_REGISTER_GLOBAL("relay._make.RefWrite")
.set_body_typed(RefWriteNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefWriteNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefWriteNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RefWriteNode*>(ref.get());
p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
});
......
......@@ -23,7 +23,7 @@
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/relay/transform.h>
......@@ -157,8 +157,8 @@ TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass")
.set_body_typed(FunctionPassNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionPassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Function pass: " << info->name
......
......@@ -957,7 +957,7 @@ class FuseMutator : private ExprMutator {
// Debug function, dump the group assignment in text.
void DebugDumpGroup(const Expr& body) {
std::string text = AsText(body, false, [this](const Expr& expr) -> std::string {
std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string {
auto it = gmap_.find(expr.get());
if (it == gmap_.end()) return "";
std::ostringstream os;
......
......@@ -116,8 +116,8 @@ QConfig& QConfig::Current() {
TVM_REGISTER_NODE_TYPE(QConfigNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<QConfigNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<QConfigNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* op = static_cast<const QConfigNode*>(ref.get());
p->stream << "qconfig(";
p->stream << "nbit_input=" << op->nbit_input << ", ";
......
......@@ -23,7 +23,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/node/node.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/target/target.h>
#include <tvm/target/generic_func.h>
#include <tvm/runtime/registry.h>
......
......@@ -23,7 +23,7 @@
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
......@@ -39,8 +39,8 @@ using runtime::PackedFunc;
TVM_REGISTER_NODE_TYPE(TargetNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TargetNode*>(node.get());
p->stream << op->str();
});
......@@ -381,8 +381,8 @@ tvm::BuildConfig BuildConfig::Current() {
TVM_REGISTER_NODE_TYPE(BuildConfigNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BuildConfigNode*>(node.get());
p->stream << "build_config(";
p->stream << "data_alignment=" << op->data_alignment << ", ";
......
......@@ -21,13 +21,13 @@
* \file target/target_info.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/target/target_info.h>
namespace tvm {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MemoryInfoNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MemoryInfoNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MemoryInfoNode*>(node.get());
p->stream << "mem-info("
<< "unit_bits=" << op->unit_bits << ", "
......
......@@ -453,8 +453,8 @@ Buffer BufferNode::make(Var data,
return Buffer(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BufferNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferNode*>(node.get());
p->stream << "buffer(" << op->name << ", " << op << ")";
});
......
......@@ -198,8 +198,8 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
return -1;
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* l = static_cast<const LayoutNode*>(node.get());
p->stream << "Layout(" << l->name << ")";
});
......@@ -365,8 +365,8 @@ BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
return BijectiveLayout(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
p->stream << "BijectiveLayout(" << b->src_layout.name()
<< "->" << b->dst_layout.name() << ")";
......
......@@ -57,8 +57,8 @@ IterVar IterVarNode::make(Range dom,
return IterVar(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IterVarNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IterVarNode*>(node.get());
p->stream << "iter_var(";
if (op->var->name_hint.length() != 0) {
......@@ -339,8 +339,8 @@ PrimExpr AnyNode::make() {
return PrimExpr(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StringImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StringImmNode*>(node.get());
auto& stream = p->stream;
stream << '"';
......@@ -375,24 +375,24 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
stream << '"';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<CastNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CastNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const CastNode*>(node.get());
p->stream << op->dtype << '(';
p->Print(op->value);
p->stream << ')';
})
.set_dispatch<VarNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<VarNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const VarNode*>(node.get());
// omit the type
// stream << op->name << "." << op->type;
p->stream << op->name_hint;
})
.set_dispatch<SizeVarNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<SizeVarNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SizeVarNode*>(node.get());
p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
})
.set_dispatch<AddNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<AddNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AddNode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -400,7 +400,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<SubNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<SubNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SubNode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -408,7 +408,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<MulNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<MulNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MulNode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -416,7 +416,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<DivNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<DivNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DivNode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -424,7 +424,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<ModNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<ModNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ModNode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -432,7 +432,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<MinNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<MinNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MinNode*>(node.get());
p->stream << "min(";
p->Print(op->a);
......@@ -440,7 +440,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ")";
})
.set_dispatch<MaxNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<MaxNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MaxNode*>(node.get());
p->stream << "max(";
p->Print(op->a);
......@@ -448,7 +448,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ")";
})
.set_dispatch<EQNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<EQNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EQNode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -456,7 +456,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<NENode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<NENode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const NENode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -464,7 +464,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<LTNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<LTNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LTNode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -472,7 +472,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<LENode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<LENode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LENode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -480,7 +480,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<GTNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<GTNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const GTNode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -488,7 +488,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->b);
p->stream << ')';
})
.set_dispatch<GENode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<GENode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const GENode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -497,20 +497,20 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ')';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FloorDivNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloorDivNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloorDivNode*>(node.get());
p->stream << "floordiv(" << op->a << ", " << op->b << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FloorModNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloorModNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloorModNode*>(node.get());
p->stream << "floormod(" << op->a << ", " << op->b << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<AndNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AndNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AndNode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -519,8 +519,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ')';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<OrNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<OrNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const OrNode*>(node.get());
p->stream << '(';
p->Print(op->a);
......@@ -529,15 +529,15 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ')';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<NotNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<NotNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const NotNode*>(node.get());
p->stream << '!';
p->Print(op->a);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SelectNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SelectNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SelectNode*>(node.get());
p->stream << "select(";
p->Print(op->condition);
......@@ -548,8 +548,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LoadNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LoadNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LoadNode*>(node.get());
p->stream << op->buffer_var << "[";
p->Print(op->index);
......@@ -560,8 +560,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
}
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RampNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RampNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RampNode*>(node.get());
p->stream << "ramp(";
p->Print(op->base);
......@@ -570,16 +570,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ", " << op->lanes << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BroadcastNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BroadcastNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BroadcastNode*>(node.get());
p->stream << "x" << op->lanes << "(";
p->Print(op->value);
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const CallNode*>(node.get());
p->stream << op->name << "(";
for (size_t i = 0; i < op->args.size(); ++i) {
......@@ -591,8 +591,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LetNode*>(node.get());
p->stream << "(let " << op->var << " = ";
p->Print(op->value);
......@@ -601,13 +601,13 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<AnyNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AnyNode>([](const ObjectRef& node, ReprPrinter* p) {
p->stream << "?";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ReduceNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ReduceNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ReduceNode*>(node.get());
p->stream << "reduce(combiner="
<< op->combiner;
......@@ -618,8 +618,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<CommReducerNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CommReducerNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const CommReducerNode*>(node.get());
p->stream << "comm_reducer(result=" << op->result
<< ", lhs=" << op->lhs
......
......@@ -24,8 +24,8 @@
namespace tvm {
namespace tir {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LoweredFuncNode*>(node.get());
p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
});
......
......@@ -248,8 +248,8 @@ Stmt EvaluateNode::make(PrimExpr value) {
// Printers
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LetStmtNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LetStmtNode*>(node.get());
p->PrintIndent();
p->stream << "let " << op->var << " = ";
......@@ -258,8 +258,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<AttrStmtNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AttrStmtNode*>(node.get());
p->PrintIndent();
p->stream << "// attr [";
......@@ -271,8 +271,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<AssertStmtNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AssertStmtNode*>(node.get());
p->PrintIndent();
p->stream << "assert(";
......@@ -283,8 +283,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ProducerConsumerNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ProducerConsumerNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ProducerConsumerNode*>(node.get());
if (op->is_producer) {
p->PrintIndent();
......@@ -317,8 +317,8 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
return out;
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ForNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ForNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ForNode*>(node.get());
p->PrintIndent();
p->stream << op->for_type << " (" << op->loop_var << ", ";
......@@ -335,8 +335,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StoreNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StoreNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer_var << "[";
......@@ -350,8 +350,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ProvideNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ProvideNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ProvideNode*>(node.get());
p->PrintIndent();
p->stream << op->func->func_name() << "(";
......@@ -368,8 +368,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<AllocateNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AllocateNode*>(node.get());
p->PrintIndent();
p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
......@@ -386,16 +386,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FreeNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FreeNode*>(node.get());
p->PrintIndent();
p->stream << "free " << op->buffer_var;
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RealizeNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RealizeNode*>(node.get());
p->PrintIndent();
p->stream << "realize " << op->func->func_name() << "(";
......@@ -425,8 +425,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PrefetchNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const PrefetchNode*>(node.get());
p->PrintIndent();
p->stream << "prefetch " << op->func->func_name() << "(";
......@@ -444,16 +444,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
}
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SeqStmtNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SeqStmtNode*>(node.get());
for (Stmt stmt : op->seq) {
p->Print(stmt);
}
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IfThenElseNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IfThenElseNode*>(node.get());
p->PrintIndent();
while (true) {
......@@ -483,8 +483,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<EvaluateNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EvaluateNode*>(node.get());
p->PrintIndent();
p->Print(op->value);
......@@ -492,7 +492,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
});
template<typename T>
void PrintList(const Array<T> &exprs, NodePrinter* p) {
void PrintList(const Array<T> &exprs, ReprPrinter* p) {
for (size_t i = 0; i < exprs.size(); ++i) {
p->Print(exprs[i]);
if (i < exprs.size() - 1) {
......@@ -501,8 +501,8 @@ void PrintList(const Array<T> &exprs, NodePrinter* p) {
}
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ShuffleNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ShuffleNode*>(node.get());
p->stream << "shuffle(";
PrintList(op->vectors, p);
......
......@@ -39,8 +39,8 @@ namespace tvm {
namespace top {
using namespace tir;
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ComputeOpNode*>(node.get());
p->stream << "compute(" << op->name << ", " << op << ")";
});
......
......@@ -31,8 +31,8 @@ namespace tvm {
namespace top {
using namespace tir;
// ExternOpNode
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ExternOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ExternOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ExternOpNode*>(node.get());
p->stream << "extern(" << op->name << ", " << op << ")";
});
......
......@@ -37,8 +37,8 @@ namespace tvm {
namespace top {
using namespace tir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<HybridOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<HybridOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const HybridOpNode*>(node.get());
p->stream << "hybrid(" << op->name << ", " << op << ")";
});
......
......@@ -27,8 +27,8 @@ namespace tvm {
namespace top {
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const PlaceholderOpNode*>(node.get());
p->stream << "placeholder(" << op->name << ", " << op << ")";
});
......
......@@ -31,8 +31,8 @@ namespace tvm {
namespace top {
using namespace tir;
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ScanOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ScanOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ScanOpNode*>(node.get());
p->stream << "scan(" << op->name << ", " << op << ")";
});
......
......@@ -34,8 +34,8 @@ namespace tvm {
namespace top {
using namespace tir;
// TensorComputeOpNode
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TensorComputeOpNode*>(node.get());
p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
});
......
......@@ -795,8 +795,8 @@ TVM_REGISTER_NODE_TYPE(SingletonNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
// Printer
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StageNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StageNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StageNode*>(node.get());
if (op->op.defined()) {
p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
......@@ -804,11 +804,11 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "group-stage(" << op << ")";
}
})
.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IterVarAttrNode*>(node.get());
p->stream << IterVarType2String(op->iter_type);
})
.set_dispatch<SplitNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<SplitNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SplitNode*>(node.get());
p->stream << "split(parent=";
p->Print(op->parent);
......@@ -818,7 +818,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->inner);
p->stream << ')';
})
.set_dispatch<FuseNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<FuseNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FuseNode*>(node.get());
p->stream << "split(";
p->stream << "outer=";
......@@ -829,7 +829,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->fused);
p->stream << ')';
})
.set_dispatch<RebaseNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<RebaseNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RebaseNode*>(node.get());
p->stream << "rebase(";
p->stream << "parent=";
......@@ -838,13 +838,13 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->rebased);
p->stream << ')';
})
.set_dispatch<SingletonNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<SingletonNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SingletonNode*>(node.get());
p->stream << "singleton(";
p->Print(op->iter);
p->stream << ')';
})
.set_dispatch<ScheduleNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<ScheduleNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ScheduleNode*>(node.get());
p->stream << "schedule(" << op << ")";
});
......
......@@ -82,8 +82,8 @@ Tensor TensorNode::make(Array<PrimExpr> shape,
return Tensor(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* t = static_cast<const TensorNode*>(node.get());
p->stream << "Tensor(shape=" << t->shape
<< ", op.name=" << t->op->name << ')';
......@@ -114,8 +114,8 @@ TensorIntrin TensorIntrinNode::make(std::string name,
return TensorIntrin(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TensorIntrinNode*>(node.get());
p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
});
......@@ -139,8 +139,8 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
return TensorIntrinCall(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment