Unverified Commit e4d817d4 by Tianqi Chen Committed by GitHub

[REFACTOR] Establish printer in the source folder (#4752)

* [REFACTOR] Establish printer in the source folder.

As we move towards the unified IR, we will eventually want to build a unified
printers for both relay and TIR.

This PR isolate the printer component into a separate folder in src as a first step.

- Refactored the Doc DSL using Object, clean up APIs.
- Isolate out the meta data into a header.
- move printer into relay_text_printer, add comments about further TODos.

* Rename NodePrinter -> ReprPrinter to distinguish it from other printers
parent f8f75ca2
......@@ -132,6 +132,7 @@ file(GLOB_RECURSE COMPILER_SRCS
src/autotvm/*.cc
src/tir/*.cc
src/driver/*.cc
src/printer/*.cc
src/api/*.cc
)
......
......@@ -144,7 +144,7 @@ def _GetContext(debugger):
def PrettyPrint(debugger, command, result, internal_dict):
ctx = _GetContext(debugger)
rc = ctx.EvaluateExpression(
"tvm::relay::PrettyPrint({command})".format(command=command)
"tvm::PrettyPrint({command})".format(command=command)
)
result.AppendMessage(str(rc))
......@@ -175,7 +175,7 @@ def _EvalExpressionAsString(logger, ctx, expr):
def _EvalAsNodeRef(logger, ctx, value):
return _EvalExpressionAsString(
logger, ctx, "tvm::relay::PrettyPrint({name})".format(name=value.name)
logger, ctx, "tvm::PrettyPrint({name})".format(name=value.name)
)
......
......@@ -308,5 +308,33 @@ class IRModule : public ObjectRef {
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
};
/*!
* \brief Pretty print a node for debug purposes.
*
* \param node The node to be printed.
* \return The text reperesentation.
* \note This function does not show version or meta-data.
* Use AsText if you want to store the text.
* \sa AsText.
*/
TVM_DLL std::string PrettyPrint(const ObjectRef& node);
/*!
* \brief Render the node as a string in the text format.
*
* \param node The node to be rendered.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
*
* \note We support a limited set of IR nodes that are part of
* relay IR and
*
* \sa PrettyPrint.
* \return The text representation.
*/
TVM_DLL std::string AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate = nullptr);
} // namespace tvm
#endif // TVM_IR_MODULE_H_
......@@ -139,11 +139,11 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
*
* \code
* // Use NodeFunctor to implement NodePrinter similar to Visitor Pattern.
* // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
* // vtable allows easy patch of new Node types, without changing
* // interface of NodePrinter.
* // interface of ReprPrinter.
*
* class NodePrinter {
* class ReprPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
......@@ -152,18 +152,18 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* f(e, this);
* }
*
* using FType = NodeFunctor<void (const ObjectRef&, NodePrinter* )>;
* using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* NodePrinter::FType& NodePrinter::vtable() { // NOLINT(*)
* ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, NodePrinter* p) {
* TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a);
* p->stream << '+'
......
......@@ -38,7 +38,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <string>
#include <vector>
......
......@@ -17,25 +17,25 @@
* under the License.
*/
/*!
* \file tvm/node/printer.h
* \file tvm/node/repr_printer.h
* \brief Printer class to print repr string of each AST/IR nodes.
*/
#ifndef TVM_NODE_PRINTER_H_
#define TVM_NODE_PRINTER_H_
#ifndef TVM_NODE_REPR_PRINTER_H_
#define TVM_NODE_REPR_PRINTER_H_
#include <tvm/node/functor.h>
#include <iostream>
namespace tvm {
/*! \brief A printer class to print the AST/IR nodes. */
class NodePrinter {
class ReprPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};
explicit NodePrinter(std::ostream& stream) // NOLINT(*)
explicit ReprPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}
/*! \brief The node to be printed. */
......@@ -43,7 +43,7 @@ class NodePrinter {
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, NodePrinter*)>;
using FType = NodeFunctor<void(const ObjectRef&, ReprPrinter*)>;
TVM_DLL static FType& vtable();
};
......@@ -60,9 +60,9 @@ namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
NodePrinter(os).Print(n);
ReprPrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_NODE_PRINTER_H_
#endif // TVM_NODE_REPR_PRINTER_H_
......@@ -26,6 +26,7 @@
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <string>
#include <functional>
#include "./base.h"
......@@ -40,6 +41,7 @@ using BaseFunc = tvm::BaseFunc;
using BaseFuncNode = tvm::BaseFuncNode;
using GlobalVar = tvm::GlobalVar;
using GlobalVarNode = tvm::GlobalVarNode;
using tvm::PrettyPrint;
/*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device.
......@@ -539,20 +541,6 @@ class TempExpr : public Expr {
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
};
/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const ObjectRef& node);
/*!
* \brief Render the node as a string in the Relay text format.
* \param node The node to be rendered.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std::string AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
......
......@@ -51,8 +51,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
}
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ConstIntBoundNode*>(node.get());
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
......
......@@ -813,8 +813,8 @@ IntSet EvalSet(Range r,
TVM_REGISTER_NODE_TYPE(IntervalSetNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntervalSetNode*>(node.get());
p->stream << "IntervalSet"
<< "[" << op->min_value << ", "
......
......@@ -44,8 +44,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
data_ = std::move(node);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ModularSetNode*>(node.get());
p->stream << "ModularSet("
<< "coeff=" << op->coeff << ", base="
......
......@@ -45,8 +45,8 @@ TVM_REGISTER_GLOBAL("relay._make.Constructor")
return Constructor(name_hint, inputs, belong_to);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
......@@ -71,8 +71,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeData")
return TypeData(header, type_vars, constructors);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
......
......@@ -59,8 +59,8 @@ Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
return Attrs(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict;
});
......
......@@ -31,8 +31,8 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EnvFuncNode*>(node.get());
p->stream << "EnvFunc(" << op->name << ")";
});
......
......@@ -111,7 +111,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
//
// The annotation callback will annotate the error messages
// contained in the map.
annotated_prog << relay::AsText(func, false, [&err_map](tvm::relay::Expr expr) {
annotated_prog << AsText(func, false, [&err_map](const ObjectRef& expr) {
auto it = err_map.find(expr);
if (it != err_map.end()) {
CHECK_NE(it->second.size(), 0);
......
......@@ -78,8 +78,8 @@ TVM_REGISTER_GLOBAL("make.IntImm")
TVM_REGISTER_NODE_TYPE(IntImmNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
......@@ -104,8 +104,8 @@ TVM_REGISTER_GLOBAL("make.FloatImm")
TVM_REGISTER_NODE_TYPE(FloatImmNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
auto& stream = p->stream;
switch (op->dtype.bits()) {
......@@ -134,8 +134,8 @@ Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
......@@ -159,15 +159,15 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
return GlobalVar(name);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")";
});
// Container printer
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ArrayNode*>(node.get());
p->stream << '[';
for (size_t i = 0 ; i < op->data.size(); ++i) {
......@@ -179,8 +179,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ']';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
......@@ -194,8 +194,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << '}';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StrMapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
......
......@@ -434,8 +434,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd")
mod->ImportFromStd(path);
});;
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IRModuleNode*>(ref.get());
p->stream << "IRModuleNode( " << node->functions << ")";
});
......
......@@ -227,8 +227,8 @@ TVM_REGISTER_NODE_TYPE(OpNode)
return static_cast<const OpNode*>(n)->name;
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const OpNode*>(ref.get());
p->stream << "Op(" << node->name << ")";
});
......
......@@ -48,8 +48,8 @@ SourceName SourceName::Get(const std::string& name) {
TVM_REGISTER_GLOBAL("relay._make.SourceName")
.set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SourceNameNode*>(ref.get());
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
......@@ -73,8 +73,8 @@ TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_GLOBAL("relay._make.Span")
.set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
p->stream << "Span(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
......
......@@ -27,7 +27,7 @@
namespace tvm {
using tvm::NodePrinter;
using tvm::ReprPrinter;
using namespace tvm::runtime;
TensorType::TensorType(Array<PrimExpr> shape, DataType dtype) {
......@@ -60,8 +60,8 @@ TVM_REGISTER_GLOBAL("relay._make.TensorType")
return TensorType(shape, dtype);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TensorTypeNode*>(ref.get());
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
......
......@@ -24,7 +24,7 @@
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/ir/transform.h>
// TODO(tqchen): Update to use String container after it is merged.
......@@ -38,7 +38,7 @@ namespace transform {
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
using tvm::NodePrinter;
using tvm::ReprPrinter;
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
......@@ -341,8 +341,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Info")
*ret = pass->Info();
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) {
auto* node = static_cast<const PassInfoNode*>(ref.get());
p->stream << "The meta data of the pass: ";
p->stream << "pass name: " << node->name;
......@@ -371,8 +371,8 @@ TVM_REGISTER_GLOBAL("relay._transform.RunPass")
*ret = pass(mod);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModulePassNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ModulePassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Module pass: " << info->name
......@@ -391,8 +391,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Sequential")
*ret = Sequential(passes, pass_info);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SequentialNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SequentialNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Sequential pass: " << info->name
......@@ -421,8 +421,8 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
*ret = pctx;
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PassContextNode*>(ref.get());
p->stream << "Pass context information: " << "\n";
p->stream << "\topt_level: " << node->opt_level << "\n";
......
......@@ -38,8 +38,8 @@ TVM_REGISTER_GLOBAL("relay._make.PrimType")
return PrimType(dtype);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PrimTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrimTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PrimTypeNode*>(ref.get());
p->stream << node->dtype;
});
......@@ -59,8 +59,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeVar")
return TypeVar(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVar(" << node->name_hint << ", "
<< node->kind << ")";
......@@ -81,8 +81,8 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
return GlobalTypeVar(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
p->stream << "GlobalTypeVar(" << node->name_hint << ", "
<< node->kind << ")";
......@@ -110,8 +110,8 @@ TVM_REGISTER_GLOBAL("relay._make.FuncType")
return FuncType(arg_types, ret_type, type_params, type_constraints);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FuncTypeNode*>(ref.get());
p->stream << "FuncType(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", "
......@@ -136,8 +136,8 @@ TVM_REGISTER_GLOBAL("relay._make.TupleType")
return TupleType(fields);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TupleTypeNode*>(ref.get());
p->stream << "TupleTypeNode(" << node->fields << ")";
});
......@@ -156,8 +156,8 @@ TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
return IncompleteType(static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
......@@ -176,8 +176,8 @@ TVM_REGISTER_GLOBAL("relay._make.RefType")
TVM_REGISTER_NODE_TYPE(RelayRefTypeNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RelayRefTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RelayRefTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RelayRefTypeNode*>(ref.get());
p->stream << "RelayRefTypeNode(" << node->value << ")";
});
......
......@@ -40,8 +40,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeCall")
return TypeCall(func, type);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeCallNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeCallNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeCallNode*>(ref.get());
p->stream << "TypeCallNode(" << node->func << ", "
<< node->args << ")";
......@@ -69,8 +69,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeRelation")
return TypeRelation(func, args, num_inputs, attrs);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeRelationNode*>(ref.get());
p->stream << "TypeRelationNode("
<< node->func->name
......
......@@ -19,13 +19,13 @@
/*!
* Printer utilities
* \file node/printer.cc
* \file node/repr_printer.cc
*/
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
namespace tvm {
void NodePrinter::Print(const ObjectRef& node) {
void ReprPrinter::Print(const ObjectRef& node) {
static const FType& f = vtable();
if (!node.defined()) {
stream << "(nullptr)";
......@@ -39,13 +39,13 @@ void NodePrinter::Print(const ObjectRef& node) {
}
}
void NodePrinter::PrintIndent() {
void ReprPrinter::PrintIndent() {
for (int i = 0; i < indent; ++i) {
stream << ' ';
}
}
NodePrinter::FType& NodePrinter::vtable() {
ReprPrinter::FType& ReprPrinter::vtable() {
static FType inst;
return inst;
}
......
......@@ -20,43 +20,82 @@
/*!
* \file src/tvm/relay/doc.cc
* \brief Doc ADT used for pretty printing.
* Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf.
*
* Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98
*/
#include <memory>
#include <tvm/runtime/packed_func.h>
#include <vector>
#include <sstream>
#include "doc.h"
namespace tvm {
namespace relay {
// Text constructor
DocAtom Text(const std::string& str) {
return std::make_shared<TextNode>(str);
}
/*!
* \brief Represent a piece of text in the doc.
*/
class DocTextNode : public DocAtomNode {
public:
/*! \brief The str content in the text. */
std::string str;
explicit DocTextNode(std::string str_val)
: str(str_val) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
}
}
// Line constructor
DocAtom Line(int indent = 0) {
return std::make_shared<LineNode>(indent);
}
static constexpr const char* _type_key = "printer.DocText";
TVM_DECLARE_FINAL_OBJECT_INFO(DocTextNode, DocAtomNode);
};
Doc::Doc(const std::string& str) {
if (str == "\n") {
this->stream_ = {Line()};
} else {
this->stream_ = {Text(str)};
TVM_REGISTER_OBJECT_TYPE(DocTextNode);
class DocText : public DocAtom {
public:
explicit DocText(std::string str) {
data_ = runtime::make_object<DocTextNode>(str);
}
}
// DSL function implementations
TVM_DEFINE_OBJECT_REF_METHODS(DocText, DocAtom, DocTextNode);
};
/*!
* \brief Represent a line breaker in the doc.
*/
class DocLineNode : public DocAtomNode {
public:
/*! \brief The amount of indent in newline. */
int indent;
explicit DocLineNode(int indent)
: indent(indent) {}
static constexpr const char* _type_key = "printer.DocLine";
TVM_DECLARE_FINAL_OBJECT_INFO(DocLineNode, DocAtomNode);
};
TVM_REGISTER_OBJECT_TYPE(DocLineNode);
class DocLine : public DocAtom {
public:
explicit DocLine(int indent) {
data_ = runtime::make_object<DocLineNode>(indent);
}
TVM_DEFINE_OBJECT_REF_METHODS(DocLine, DocAtom, DocLineNode);
};
// DSL function implementations
Doc& Doc::operator<<(const Doc& right) {
CHECK(this != &right);
this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end());
this->stream_.insert(
this->stream_.end(), right.stream_.begin(), right.stream_.end());
return *this;
}
Doc& Doc::operator<<(const std::string& right) {
return *this << Doc(right);
Doc& Doc::operator<<(std::string right) {
return *this << DocText(right);
}
Doc& Doc::operator<<(const DocAtom& right) {
......@@ -64,63 +103,71 @@ Doc& Doc::operator<<(const DocAtom& right) {
return *this;
}
Doc Indent(int indent, const Doc& doc) {
Doc ret;
for (auto atom : doc.stream_) {
if (auto text = std::dynamic_pointer_cast<TextNode>(atom)) {
ret.stream_.push_back(text);
} else if (auto line = std::dynamic_pointer_cast<LineNode>(atom)) {
ret.stream_.push_back(Line(indent + line->indent));
} else {CHECK(false);}
}
return ret;
}
std::string Doc::str() {
std::ostringstream os;
for (auto atom : this->stream_) {
if (auto text = std::dynamic_pointer_cast<TextNode>(atom)) {
if (auto* text = atom.as<DocTextNode>()) {
os << text->str;
} else if (auto line = std::dynamic_pointer_cast<LineNode>(atom)) {
} else if (auto* line = atom.as<DocLineNode>()) {
os << "\n" << std::string(line->indent, ' ');
} else {CHECK(false);}
} else {
LOG(FATAL) << "do not expect type " << atom->GetTypeKey();
}
}
return os.str();
}
Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc seq;
if (vec.size() != 0) {
seq = vec[0];
for (size_t i = 1; i < vec.size(); i++) {
seq << sep << vec[i];
}
}
return seq;
Doc Doc::NewLine(int indent) {
return Doc() << DocLine(indent);
}
Doc PrintBool(bool value) {
if (value) {
return Doc("True");
} else {
return Doc("False");
}
Doc Doc::Text(std::string text) {
return Doc() << DocText(text);
}
Doc PrintDType(DataType dtype) {
return Doc(runtime::DLDataType2String(dtype));
Doc Doc::Indent(int indent, Doc doc) {
for (size_t i = 0; i < doc.stream_.size(); ++i) {
if (auto* line = doc.stream_[i].as<DocLineNode>()) {
doc.stream_[i] = DocLine(indent + line->indent);
}
}
return doc;
}
Doc PrintString(const std::string& value) {
Doc Doc::StrLiteral(const std::string& value, std::string quote) {
// TODO(M.K.): add escape.
Doc doc;
return doc << "\"" << value << "\"";
return doc << quote << value << quote;
}
Doc PrintNewLine(int ident) {
Doc Doc::PyBoolLiteral(bool value) {
if (value) {
return Doc::Text("True");
} else {
return Doc::Text("False");
}
}
Doc Doc::Brace(std::string open,
const Doc& body,
std::string close,
int indent) {
Doc doc;
return doc << Line(ident);
doc << open;
doc << Indent(indent, NewLine() << body) << NewLine();
doc << close;
return doc;
}
} // namespace relay
Doc Doc::Concat(const std::vector<Doc>& vec, const Doc& sep) {
Doc seq;
if (vec.size() != 0) {
if (vec.size() == 1) return vec[0];
seq << vec[0];
for (size_t i = 1; i < vec.size(); ++i) {
seq << sep << vec[i];
}
}
return seq;
}
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/printer/doc.h
* \brief Doc ADT used for pretty printing.
*
* Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98
*/
#ifndef TVM_PRINTER_DOC_H_
#define TVM_PRINTER_DOC_H_
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <string>
#include <vector>
#include <type_traits>
namespace tvm {
/*!
* \brief Doc atom node for the ADT.
* \sa DocAtom
*/
class DocAtomNode : public Object {
public:
static constexpr const char* _type_key = "printer.DocAtom";
TVM_DECLARE_BASE_OBJECT_INFO(DocAtomNode, Object);
};
/*!
* \brief Managed reference to DocAtomNode.
* \sa DocAtomNode.
*/
class DocAtom : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(DocAtom, ObjectRef, DocAtomNode);
};
/*!
* \brief Stream-like interface for Doc DSL.
*
* The Doc DSL de-couples the layout decision from the printing decision.
*
* The layout(code formating) decisions include:
* - Change indentation.
* - Break single line into multiple ones(subjected to future improvements).
*/
class Doc {
public:
/*! \brief default constructor */
Doc() {}
/*!
* \brief Append right to the end of the current doc stream.
* \param right The doc to be appended.
* \return reference to self.
*/
Doc& operator<<(const Doc& right);
/*!
* \brief Append right to the end of the current doc stream.
* \param right The doc to be appended.
* \return reference to self.
* \note pass by value to allow copy elison optimization.
*/
Doc& operator<<(std::string right);
/*!
* \brief Append right to the end of the current doc stream.
* \param right The doc to be appended.
* \return reference to self.
*/
Doc& operator<<(const DocAtom& right);
/*!
* \brief Convert value to string via std::ostreamstream
* the append to the current doc stream.
* \param right The doc to be appended.
* \tparam T the type of the value.
* \return reference to self.
*/
template<typename T,
typename = typename std::enable_if<!std::is_class<T>::value>::type>
Doc& operator<<(const T& value) {
std::ostringstream os;
os << value;
return *this << os.str();
}
/*!
* \brief Convert the doc stream into string.
* \return The string representation.
*/
std::string str();
/*!
* \brief Create a doc that represents text content.
* \return The created doc.
*/
static Doc Text(std::string value);
/*!
* \brief Create a doc that represents a new line.
* \return The created doc.
*/
static Doc NewLine(int indent = 0);
/*!
* \brief Create a new doc that adds indentation to everyline of the doc.
* \param indent The indent to be added.
* \param doc The doc to be indented.
* \return The created doc.
* \note pass by value to allow copy elison optimization.
*/
static Doc Indent(int indent, Doc doc);
/*!
* \brief Create a Doc that represents a string literal.
* \param value The content of the string literal.
* \param quote The quote in the literal.
* \return The created doc.
*/
static Doc StrLiteral(const std::string& value, std::string quote = "\"");
/*!
* \brief Create a Doc that represents a boolean literal in python syntax.
* \param value The bool value.
* \return The created doc.
*/
static Doc PyBoolLiteral(bool value);
/*!
* \brief Enclose body by brace and add indent.
* \param body The body
* \param open The open brace.
* \param close The close brace.
* \param indent amount of indentation.
* \return The created doc.
*/
static Doc Brace(std::string open,
const Doc& body,
std::string close,
int indent = 2);
/*!
* \brief Create a doc by concatenating together with separator.
* \param vec The docs to be concatenated.
* \param sep The seperator.
* \return The created doc.
*/
static Doc Concat(const std::vector<Doc>& vec, const Doc& sep = Text(", "));
private:
/*! \brief Internal doc stream. */
std::vector<DocAtom> stream_;
};
} // namespace tvm
#endif // TVM_PRINTER_DOC_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/printer/meta_data.h
* \brief Meta data context for printers.
*/
#ifndef TVM_PRINTER_META_DATA_H_
#define TVM_PRINTER_META_DATA_H_
#include <tvm/node/serialization.h>
#include <tvm/node/container.h>
#include <string>
#include <unordered_map>
#include "doc.h"
namespace tvm {
/*!
* \brief Meta data context for Printers
*
* This is an important part to enable bi-directional serializability.
* We use tvm's Node system to build the current IR.
* It can be hard to design a text format for all the possible nodes
* as the set of nodes can grow when we do more extensions.
*
* Instead of trying to design readable text format for every node,
* we support a meta data section in the text format.
* We allow the text format to refer to a node in the meta data section.
*
* The meta data section is a json serialized string of an Map<string, Array<NodeRef>>.
* Each element in the meta data section can be referenced by the text format.
* Each meta data node is printed in the following format.
*
* meta[type-key-of-node>][<index-in-meta-section>]
*
* Specifically, consider the following IR(constructed by python).
*
* \code
*
* n = tvm.var("n")
* x = tvm.relay.var("x", shape=(n, 1))
* f = tvm.relay.Function([x], x)
* print(f.astext())
*
* \endcode
*
* The corresponding text format is shown in the following code block.
*
* \code
*
* fn (%x: Tensor[(meta[Variable][0],), float32]) {
* %x
* }
* # Meta data section is a json-serialized string
* # of the following array.
* # [tvm.var("n")]
*
* \endcode
*
* Note that we store tvm.var("n") in the meta data section.
* Since it is stored in the index-0 in the meta data section,
* we print it as meta[Variable][0].
*
* The text parser can recover this object by loading from the corresponding
* location in the meta data section.
*
* This is is a design trade-off.
* It allows us to embedded any meta data in the text format,
* while still being able to tweak the text part of the printed IR easily.
*/
class TextMetaDataContext {
public:
/*!
* \brief Get text representation of meta node.
* \param node The node to be converted to meta node.
* \return A string representation of the meta node.
*/
Doc GetMetaNode(const ObjectRef& node) {
auto it = meta_repr_.find(node);
if (it != meta_repr_.end()) {
return it->second;
}
std::string type_key = node->GetTypeKey();
CHECK(!type_key.empty());
Array<ObjectRef>& mvector =
meta_data_[type_key];
int64_t index = static_cast<int64_t>(mvector.size());
mvector.push_back(node);
Doc doc;
doc << "meta[" << type_key << "][" << index << "]";
meta_repr_[node] = doc;
return meta_repr_[node];
}
/*!
* \brief Print a key value pair
*/
Doc PrintKeyValue(const std::string& str, const Doc& v) const {
return Doc() << "\"" << str << "\": " << v;
}
/*!
* \brief Get the metadata section in json format.
* \return the meta data string.
*/
Doc GetMetaSection() const {
if (meta_data_.size() == 0) return Doc();
return Doc::Text(
SaveJSON(Map<std::string, ObjectRef>(meta_data_.begin(), meta_data_.end())));
}
/*! \return whether the meta data context is empty. */
bool empty() const {
return meta_data_.empty();
}
private:
/*! \brief additional metadata stored in TVM json format */
std::unordered_map<std::string, Array<ObjectRef> > meta_data_;
/*! \brief map from meta data into its string representation */
std::unordered_map<ObjectRef, Doc, ObjectHash, ObjectEqual> meta_repr_;
};
} // namespace tvm
#endif // TVM_PRINTER_META_DATA_H_
......@@ -47,8 +47,8 @@ InterpreterClosure::InterpreterClosure(tvm::Map<Var, ObjectRef> env,
data_ = std::move(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<InterpreterClosureObj >([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<InterpreterClosureObj >([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const InterpreterClosureObj*>(ref.get());
p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")";
});
......@@ -68,8 +68,8 @@ RecClosure::RecClosure(InterpreterClosure clos, Var bind) {
data_ = std::move(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RecClosureObj>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RecClosureObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RecClosureObj*>(ref.get());
p->stream << "RecClosureObj(" << node->clos << ")";
});
......@@ -87,8 +87,8 @@ TVM_REGISTER_GLOBAL("relay._make.RefValue")
TVM_REGISTER_NODE_TYPE(RefValueObj);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefValueObj>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefValueObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RefValueObj*>(ref.get());
p->stream << "RefValueObj(" << node->value << ")";
});
......@@ -111,8 +111,8 @@ TVM_REGISTER_GLOBAL("relay._make.ConstructorValue")
TVM_REGISTER_NODE_TYPE(ConstructorValueObj);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorValueObj>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstructorValueObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ConstructorValueObj*>(ref.get());
p->stream << "ConstructorValueObj(" << node->tag << ","
<< node->fields << ")";
......
......@@ -37,8 +37,8 @@ TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_GLOBAL("relay._make.PatternWildcard")
.set_body_typed(PatternWildcardNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, ReprPrinter* p) {
p->stream << "PatternWildcardNode()";
});
......@@ -53,8 +53,8 @@ TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_GLOBAL("relay._make.PatternVar")
.set_body_typed(PatternVarNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PatternVarNode*>(ref.get());
p->stream << "PatternVarNode(" << node->var << ")";
});
......@@ -72,8 +72,8 @@ TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.PatternConstructor")
.set_body_typed(PatternConstructorNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PatternConstructorNode*>(ref.get());
p->stream << "PatternConstructorNode(" << node->constructor
<< ", " << node->patterns << ")";
......@@ -90,8 +90,8 @@ TVM_REGISTER_NODE_TYPE(PatternTupleNode);
TVM_REGISTER_GLOBAL("relay._make.PatternTuple")
.set_body_typed(PatternTupleNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PatternTupleNode*>(ref.get());
p->stream << "PatternTupleNode(" << node->patterns << ")";
});
......@@ -108,8 +108,8 @@ TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_GLOBAL("relay._make.Clause")
.set_body_typed(ClauseNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ClauseNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ClauseNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ClauseNode*>(ref.get());
p->stream << "ClauseNode(" << node->lhs << ", "
<< node->rhs << ")";
......@@ -128,8 +128,8 @@ TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_GLOBAL("relay._make.Match")
.set_body_typed(MatchNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MatchNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MatchNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const MatchNode*>(ref.get());
p->stream << "MatchNode(" << node->data << ", "
<< node->clauses << ", " << node->complete << ")";
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/doc.h
* \brief Doc ADT used for pretty printing.
* Based on Section 1 of
* https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf, but with
* a vector instead of an implicitly linked list.
*/
#ifndef TVM_RELAY_IR_DOC_H_
#define TVM_RELAY_IR_DOC_H_
#include <tvm/relay/expr.h>
#include <memory>
#include <string>
#include <vector>
namespace tvm {
namespace relay {
// Doc Atom ADT
struct DocAtomNode {
virtual ~DocAtomNode() = default;
};
using DocAtom = std::shared_ptr<DocAtomNode>;
struct TextNode : DocAtomNode {
std::string str;
explicit TextNode(const std::string& str) : str(str) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
}
}
};
struct LineNode : DocAtomNode {
int indent;
explicit LineNode(int indent) : indent(indent) {}
};
// Doc is a stream-like interface
class Doc {
public:
Doc() {}
explicit Doc(const std::string& str);
template<typename T>
explicit Doc(const T& str) {
(*this) << str;
}
// Append right to this.
Doc& operator<<(const Doc& right);
// Like above.
Doc& operator<<(const std::string& right);
// Like above.
Doc& operator<<(const DocAtom& right);
// Like above, but converts right to a string first.
template<typename T>
Doc& operator<<(const T& right) {
std::ostringstream os;
os << right;
return *this << os.str();
}
// Indent a doc stream.
friend Doc Indent(int indent, const Doc& doc);
// Wadler's `layout`
std::string str();
private:
std::vector<DocAtom> stream_;
};
// DSL functions
// Render vectors of docs with a separator. e.g. PrintSep([1, 2, 3], f) -> 1f2f3
Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep = Doc(", "));
// Print a constant bool value.
Doc PrintBool(bool value);
// Print a data type.
Doc PrintDType(DataType dtype);
// Print a string.
Doc PrintString(const std::string& value);
// Print a newline.
Doc PrintNewLine(int indent = 0);
/*!
* \brief special method to print out const scalar
* \param dtype The data type
* \param data The pointer to hold the data.
*/
template<typename T>
Doc PrintConstScalar(DataType dtype, const T* data) {
std::ostringstream os;
if (dtype == DataType::Int(32)) {
os << data[0];
} else if (dtype == DataType::Float(32)) {
os << data[0] << 'f';
} else if (dtype == DataType::Bool()) {
return PrintBool(data[0] != 0);
} else {
// todo(@M.K.) this is unsafe. fix.
os << data[0];
}
return Doc(os.str());
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_IR_DOC_H_
......@@ -21,12 +21,13 @@
* \file src/tvm/relay/ir/expr.cc
* \brief The expression AST nodes of Relay.
*/
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
using tvm::NodePrinter;
using tvm::ReprPrinter;
using namespace tvm::runtime;
Constant ConstantNode::make(runtime::NDArray data) {
......@@ -40,8 +41,8 @@ TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_GLOBAL("relay._make.Constant")
.set_body_typed(ConstantNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstantNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ConstantNode*>(ref.get());
const PackedFunc* fprint = Registry::Get("relay._constant_repr");
CHECK(fprint) << "unable to find printing function for constants";
......@@ -73,8 +74,8 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_GLOBAL("relay._make.Tuple")
.set_body_typed(TupleNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TupleNode*>(ref.get());
p->stream << "Tuple(" << node->fields << ")";
});
......@@ -98,8 +99,8 @@ TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_GLOBAL("relay._make.Var")
.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<VarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const VarNode*>(ref.get());
p->stream << "Var(" << node->name_hint();
if (node->type_annotation.defined()) {
......@@ -208,8 +209,8 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function")
.set_body_typed(FunctionNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionNode*>(ref.get());
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ", "
......@@ -231,8 +232,8 @@ TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_GLOBAL("relay._make.Call")
.set_body_typed(CallNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const CallNode*>(ref.get());
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
......@@ -251,8 +252,8 @@ TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_GLOBAL("relay._make.Let")
.set_body_typed(LetNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const LetNode*>(ref.get());
p->stream << "LetNode(" << node->var << ", " << node->value
<< ", " << node->body << ")";
......@@ -271,8 +272,8 @@ TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_GLOBAL("relay._make.If")
.set_body_typed(IfNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IfNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IfNode*>(ref.get());
p->stream << "IfNode(" << node->cond << ", " << node->true_branch
<< ", " << node->false_branch << ")";
......@@ -290,8 +291,8 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_GLOBAL("relay._make.TupleGetItem")
.set_body_typed(TupleGetItemNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TupleGetItemNode*>(ref.get());
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});
......@@ -307,8 +308,8 @@ TVM_REGISTER_NODE_TYPE(RefCreateNode);
TVM_REGISTER_GLOBAL("relay._make.RefCreate")
.set_body_typed(RefCreateNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefCreateNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RefCreateNode*>(ref.get());
p->stream << "RefCreateNode(" << node->value << ")";
});
......@@ -324,8 +325,8 @@ TVM_REGISTER_NODE_TYPE(RefReadNode);
TVM_REGISTER_GLOBAL("relay._make.RefRead")
.set_body_typed(RefReadNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefReadNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RefReadNode*>(ref.get());
p->stream << "RefReadNode(" << node->ref << ")";
});
......@@ -342,8 +343,8 @@ TVM_REGISTER_NODE_TYPE(RefWriteNode);
TVM_REGISTER_GLOBAL("relay._make.RefWrite")
.set_body_typed(RefWriteNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefWriteNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefWriteNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const RefWriteNode*>(ref.get());
p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
});
......
......@@ -23,7 +23,7 @@
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/relay/transform.h>
......@@ -157,8 +157,8 @@ TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass")
.set_body_typed(FunctionPassNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionPassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Function pass: " << info->name
......
......@@ -957,7 +957,7 @@ class FuseMutator : private ExprMutator {
// Debug function, dump the group assignment in text.
void DebugDumpGroup(const Expr& body) {
std::string text = AsText(body, false, [this](const Expr& expr) -> std::string {
std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string {
auto it = gmap_.find(expr.get());
if (it == gmap_.end()) return "";
std::ostringstream os;
......
......@@ -116,8 +116,8 @@ QConfig& QConfig::Current() {
TVM_REGISTER_NODE_TYPE(QConfigNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<QConfigNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<QConfigNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* op = static_cast<const QConfigNode*>(ref.get());
p->stream << "qconfig(";
p->stream << "nbit_input=" << op->nbit_input << ", ";
......
......@@ -23,7 +23,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/node/node.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/target/target.h>
#include <tvm/target/generic_func.h>
#include <tvm/runtime/registry.h>
......
......@@ -23,7 +23,7 @@
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
......@@ -39,8 +39,8 @@ using runtime::PackedFunc;
TVM_REGISTER_NODE_TYPE(TargetNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TargetNode*>(node.get());
p->stream << op->str();
});
......@@ -381,8 +381,8 @@ tvm::BuildConfig BuildConfig::Current() {
TVM_REGISTER_NODE_TYPE(BuildConfigNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BuildConfigNode*>(node.get());
p->stream << "build_config(";
p->stream << "data_alignment=" << op->data_alignment << ", ";
......
......@@ -21,13 +21,13 @@
* \file target/target_info.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/target/target_info.h>
namespace tvm {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MemoryInfoNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MemoryInfoNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MemoryInfoNode*>(node.get());
p->stream << "mem-info("
<< "unit_bits=" << op->unit_bits << ", "
......
......@@ -453,8 +453,8 @@ Buffer BufferNode::make(Var data,
return Buffer(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BufferNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferNode*>(node.get());
p->stream << "buffer(" << op->name << ", " << op << ")";
});
......
......@@ -198,8 +198,8 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
return -1;
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* l = static_cast<const LayoutNode*>(node.get());
p->stream << "Layout(" << l->name << ")";
});
......@@ -365,8 +365,8 @@ BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
return BijectiveLayout(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
p->stream << "BijectiveLayout(" << b->src_layout.name()
<< "->" << b->dst_layout.name() << ")";
......
......@@ -24,8 +24,8 @@
namespace tvm {
namespace tir {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LoweredFuncNode*>(node.get());
p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
});
......
......@@ -248,8 +248,8 @@ Stmt EvaluateNode::make(PrimExpr value) {
// Printers
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LetStmtNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LetStmtNode*>(node.get());
p->PrintIndent();
p->stream << "let " << op->var << " = ";
......@@ -258,8 +258,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<AttrStmtNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AttrStmtNode*>(node.get());
p->PrintIndent();
p->stream << "// attr [";
......@@ -271,8 +271,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<AssertStmtNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AssertStmtNode*>(node.get());
p->PrintIndent();
p->stream << "assert(";
......@@ -283,8 +283,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ProducerConsumerNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ProducerConsumerNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ProducerConsumerNode*>(node.get());
if (op->is_producer) {
p->PrintIndent();
......@@ -317,8 +317,8 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
return out;
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ForNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ForNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ForNode*>(node.get());
p->PrintIndent();
p->stream << op->for_type << " (" << op->loop_var << ", ";
......@@ -335,8 +335,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StoreNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StoreNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer_var << "[";
......@@ -350,8 +350,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ProvideNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ProvideNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ProvideNode*>(node.get());
p->PrintIndent();
p->stream << op->func->func_name() << "(";
......@@ -368,8 +368,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<AllocateNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AllocateNode*>(node.get());
p->PrintIndent();
p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
......@@ -386,16 +386,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->body);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FreeNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FreeNode*>(node.get());
p->PrintIndent();
p->stream << "free " << op->buffer_var;
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RealizeNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RealizeNode*>(node.get());
p->PrintIndent();
p->stream << "realize " << op->func->func_name() << "(";
......@@ -425,8 +425,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PrefetchNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const PrefetchNode*>(node.get());
p->PrintIndent();
p->stream << "prefetch " << op->func->func_name() << "(";
......@@ -444,16 +444,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
}
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SeqStmtNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SeqStmtNode*>(node.get());
for (Stmt stmt : op->seq) {
p->Print(stmt);
}
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IfThenElseNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IfThenElseNode*>(node.get());
p->PrintIndent();
while (true) {
......@@ -483,8 +483,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "}\n";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<EvaluateNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EvaluateNode*>(node.get());
p->PrintIndent();
p->Print(op->value);
......@@ -492,7 +492,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
});
template<typename T>
void PrintList(const Array<T> &exprs, NodePrinter* p) {
void PrintList(const Array<T> &exprs, ReprPrinter* p) {
for (size_t i = 0; i < exprs.size(); ++i) {
p->Print(exprs[i]);
if (i < exprs.size() - 1) {
......@@ -501,8 +501,8 @@ void PrintList(const Array<T> &exprs, NodePrinter* p) {
}
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ShuffleNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ShuffleNode*>(node.get());
p->stream << "shuffle(";
PrintList(op->vectors, p);
......
......@@ -39,8 +39,8 @@ namespace tvm {
namespace top {
using namespace tir;
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ComputeOpNode*>(node.get());
p->stream << "compute(" << op->name << ", " << op << ")";
});
......
......@@ -31,8 +31,8 @@ namespace tvm {
namespace top {
using namespace tir;
// ExternOpNode
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ExternOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ExternOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ExternOpNode*>(node.get());
p->stream << "extern(" << op->name << ", " << op << ")";
});
......
......@@ -37,8 +37,8 @@ namespace tvm {
namespace top {
using namespace tir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<HybridOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<HybridOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const HybridOpNode*>(node.get());
p->stream << "hybrid(" << op->name << ", " << op << ")";
});
......
......@@ -27,8 +27,8 @@ namespace tvm {
namespace top {
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const PlaceholderOpNode*>(node.get());
p->stream << "placeholder(" << op->name << ", " << op << ")";
});
......
......@@ -31,8 +31,8 @@ namespace tvm {
namespace top {
using namespace tir;
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ScanOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ScanOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ScanOpNode*>(node.get());
p->stream << "scan(" << op->name << ", " << op << ")";
});
......
......@@ -34,8 +34,8 @@ namespace tvm {
namespace top {
using namespace tir;
// TensorComputeOpNode
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TensorComputeOpNode*>(node.get());
p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
});
......
......@@ -795,8 +795,8 @@ TVM_REGISTER_NODE_TYPE(SingletonNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
// Printer
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StageNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StageNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StageNode*>(node.get());
if (op->op.defined()) {
p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
......@@ -804,11 +804,11 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "group-stage(" << op << ")";
}
})
.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IterVarAttrNode*>(node.get());
p->stream << IterVarType2String(op->iter_type);
})
.set_dispatch<SplitNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<SplitNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SplitNode*>(node.get());
p->stream << "split(parent=";
p->Print(op->parent);
......@@ -818,7 +818,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->inner);
p->stream << ')';
})
.set_dispatch<FuseNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<FuseNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FuseNode*>(node.get());
p->stream << "split(";
p->stream << "outer=";
......@@ -829,7 +829,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->fused);
p->stream << ')';
})
.set_dispatch<RebaseNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<RebaseNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RebaseNode*>(node.get());
p->stream << "rebase(";
p->stream << "parent=";
......@@ -838,13 +838,13 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->Print(op->rebased);
p->stream << ')';
})
.set_dispatch<SingletonNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<SingletonNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SingletonNode*>(node.get());
p->stream << "singleton(";
p->Print(op->iter);
p->stream << ')';
})
.set_dispatch<ScheduleNode>([](const ObjectRef& node, NodePrinter* p) {
.set_dispatch<ScheduleNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ScheduleNode*>(node.get());
p->stream << "schedule(" << op << ")";
});
......
......@@ -82,8 +82,8 @@ Tensor TensorNode::make(Array<PrimExpr> shape,
return Tensor(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* t = static_cast<const TensorNode*>(node.get());
p->stream << "Tensor(shape=" << t->shape
<< ", op.name=" << t->op->name << ')';
......@@ -114,8 +114,8 @@ TensorIntrin TensorIntrinNode::make(std::string name,
return TensorIntrin(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TensorIntrinNode*>(node.get());
p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
});
......@@ -139,8 +139,8 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
return TensorIntrinCall(n);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment