Unverified Commit 00a6474a by Tianqi Chen Committed by GitHub

[REFACTOR] IRPrinter->NodePrinter, move to node/printer.h (#4622)

Rationale: printer is a common infra that is shared across all nodes.
parent 81523604
......@@ -470,37 +470,6 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
}
return ret;
}
// Printer infra.
/*! \brief A Pretty printer class to print the IR. */
class IRPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};
explicit IRPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}
/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, IRPrinter *)>;
TVM_DLL static FType& vtable();
};
} // namespace tvm
namespace tvm {
namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
IRPrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm
namespace std {
......
......@@ -24,14 +24,16 @@
#define TVM_NODE_FUNCTOR_H_
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/node.h>
#include <tvm/runtime/object.h>
#include <vector>
#include <type_traits>
#include <utility>
namespace tvm {
using runtime::ObjectRef;
/*!
* \brief A dynamically dispatched functor on the type of the first argument.
*
......@@ -137,11 +139,11 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
*
* \code
* // Use NodeFunctor to implement IRPrinter similar to Visitor Pattern.
* // Use NodeFunctor to implement NodePrinter similar to Visitor Pattern.
* // vtable allows easy patch of new Node types, without changing
* // interface of IRPrinter.
* // interface of NodePrinter.
*
* class IRPrinter {
* class NodePrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
......@@ -150,18 +152,18 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* f(e, this);
* }
*
* using FType = NodeFunctor<void (const ObjectRef&, IRPrinter *)>;
* using FType = NodeFunctor<void (const ObjectRef&, NodePrinter* )>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
* NodePrinter::FType& NodePrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, IRPrinter* p) {
* TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, NodePrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a);
* p->stream << '+'
......
......@@ -38,6 +38,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/printer.h>
#include <string>
#include <vector>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/printer.h
* \brief Printer class to print repr string of each AST/IR nodes.
*/
#ifndef TVM_NODE_PRINTER_H_
#define TVM_NODE_PRINTER_H_
#include <tvm/node/functor.h>
#include <iostream>
namespace tvm {
/*! \brief A printer class to print the AST/IR nodes. */
class NodePrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};
explicit NodePrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}
/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, NodePrinter*)>;
TVM_DLL static FType& vtable();
};
} // namespace tvm
namespace tvm {
namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
NodePrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_NODE_PRINTER_H_
......@@ -51,8 +51,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
}
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ConstIntBoundNode*>(node.get());
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
......
......@@ -811,8 +811,8 @@ IntSet EvalSet(Range r,
TVM_REGISTER_NODE_TYPE(IntervalSetNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IntervalSetNode*>(node.get());
p->stream << "IntervalSet"
<< "[" << op->min_value << ", "
......
......@@ -44,8 +44,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
data_ = std::move(node);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ModularSetNode*>(node.get());
p->stream << "ModularSet("
<< "coeff=" << op->coeff << ", base="
......
......@@ -41,8 +41,8 @@ using runtime::PackedFunc;
TVM_REGISTER_NODE_TYPE(TargetNode);
TVM_REGISTER_NODE_TYPE(GenericFuncNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const TargetNode*>(node.get());
p->stream << op->str();
});
......@@ -665,8 +665,8 @@ tvm::BuildConfig BuildConfig::Current() {
TVM_REGISTER_NODE_TYPE(BuildConfigNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const BuildConfigNode*>(node.get());
p->stream << "build_config(";
p->stream << "data_alignment=" << op->data_alignment << ", ";
......
......@@ -22,6 +22,9 @@
* \brief ARM specific code generator
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/registry.h>
#include "codegen_cpu.h"
namespace tvm {
......
......@@ -22,6 +22,8 @@
* \brief X86-64 specific code generator
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/registry.h>
#include "codegen_cpu.h"
#include "llvm/MC/MCSubtargetInfo.h"
......
......@@ -22,7 +22,9 @@
* \brief LLVM runtime module for TVM
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/codegen.h>
#include <mutex>
#include "llvm_common.h"
......
......@@ -22,6 +22,7 @@
* \brief Source code module, only for viewing
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include "codegen_source_base.h"
#include "../runtime/file_util.h"
#include "../runtime/meta_data.h"
......
......@@ -20,6 +20,7 @@
/*!
* \file codegen_hybrid.cc
*/
#include <tvm/runtime/registry.h>
#include <iomanip>
#include <cctype>
#include "codegen_hybrid.h"
......
......@@ -21,6 +21,7 @@
* \brief The span data structure.
*/
#include <tvm/ir/span.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
......@@ -48,8 +49,8 @@ SourceName SourceName::Get(const std::string& name) {
TVM_REGISTER_GLOBAL("relay._make.SourceName")
.set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const SourceNameNode*>(ref.get());
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
......@@ -73,8 +74,8 @@ TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_GLOBAL("relay._make.Span")
.set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
p->stream << "Span(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
......
......@@ -22,6 +22,7 @@
* \brief Common type system AST nodes throughout the IR.
*/
#include <tvm/ir/type.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
......@@ -40,8 +41,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeVar")
return TypeVarNode::make(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVar(" << node->name_hint << ", "
<< node->kind << ")";
......@@ -61,8 +62,8 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
return GlobalTypeVarNode::make(name, static_cast<TypeKind>(kind));
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
p->stream << "GlobalTypeVar(" << node->name_hint << ", "
<< node->kind << ")";
......@@ -85,8 +86,8 @@ TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_GLOBAL("relay._make.FuncType")
.set_body_typed(FuncTypeNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const FuncTypeNode*>(ref.get());
p->stream << "FuncType(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", "
......
......@@ -61,8 +61,8 @@ Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
return Attrs(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict;
});
......
......@@ -450,8 +450,8 @@ Buffer BufferNode::make(Var data,
return Buffer(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BufferNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BufferNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const BufferNode*>(node.get());
p->stream << "buffer(" << op->name << ", " << op << ")";
});
......
......@@ -194,8 +194,8 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
return -1;
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, NodePrinter* p) {
auto* l = static_cast<const LayoutNode*>(node.get());
p->stream << "Layout(" << l->name << ")";
});
......@@ -361,8 +361,8 @@ BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
return BijectiveLayout(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, NodePrinter* p) {
auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
p->stream << "BijectiveLayout(" << b->src_layout.name()
<< "->" << b->dst_layout.name() << ")";
......
......@@ -97,33 +97,8 @@ Var var(std::string name_hint, DataType t) {
return Var(name_hint, t);
}
void IRPrinter::Print(const ObjectRef& ir) {
static const FType& f = vtable();
if (!ir.defined()) {
stream << "(nullptr)";
} else {
if (f.can_dispatch(ir)) {
f(ir, this);
} else {
// default value, output type key and addr.
stream << ir->GetTypeKey() << "(" << ir.get() << ")";
}
}
}
void IRPrinter::PrintIndent() {
for (int i = 0; i < indent; ++i) {
stream << ' ';
}
}
IRPrinter::FType& IRPrinter::vtable() {
static FType inst;
return inst;
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntImm>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntImm>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IntImm*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
......@@ -132,8 +107,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IterVarNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IterVarNode*>(node.get());
p->stream << "iter_var(";
if (op->var->name_hint.length() != 0) {
......@@ -148,8 +123,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
......
......@@ -24,8 +24,8 @@
namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const LoweredFuncNode*>(node.get());
p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
});
......
......@@ -26,8 +26,8 @@
namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<MemoryInfoNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MemoryInfoNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const MemoryInfoNode*>(node.get());
p->stream << "mem-info("
<< "unit_bits=" << op->unit_bits << ", "
......
......@@ -67,8 +67,8 @@ Tensor TensorNode::make(Array<Expr> shape,
return Tensor(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorNode>([](const ObjectRef& node, NodePrinter* p) {
auto* t = static_cast<const TensorNode*>(node.get());
p->stream << "Tensor(shape=" << t->shape
<< ", op.name=" << t->op->name << ')';
......@@ -99,8 +99,8 @@ TensorIntrin TensorIntrinNode::make(std::string name,
return TensorIntrin(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const TensorIntrinNode*>(node.get());
p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
});
......@@ -124,8 +124,8 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
return TensorIntrinCall(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, NodePrinter* p) {
auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
});
......
......@@ -26,13 +26,13 @@
namespace tvm {
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const EnvFuncNode*>(node.get());
p->stream << "EnvFunc(" << op->name << ")";
});
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Printer utilities
* \file node/printer.cc
*/
#include <tvm/node/printer.h>
namespace tvm {
void NodePrinter::Print(const ObjectRef& node) {
static const FType& f = vtable();
if (!node.defined()) {
stream << "(nullptr)";
} else {
if (f.can_dispatch(node)) {
f(node, this);
} else {
// default value, output type key and addr.
stream << node->GetTypeKey() << "(" << node.get() << ")";
}
}
}
void NodePrinter::PrintIndent() {
for (int i = 0; i < indent; ++i) {
stream << ' ';
}
}
NodePrinter::FType& NodePrinter::vtable() {
static FType inst;
return inst;
}
} // namespace tvm
......@@ -39,8 +39,8 @@ namespace tvm {
using namespace ir;
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ComputeOpNode*>(node.get());
p->stream << "compute(" << op->name << ", " << op << ")";
});
......
......@@ -30,8 +30,8 @@
namespace tvm {
using namespace ir;
// ExternOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ExternOpNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ExternOpNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ExternOpNode*>(node.get());
p->stream << "extern(" << op->name << ", " << op << ")";
});
......
......@@ -36,8 +36,8 @@
namespace tvm {
using namespace ir;
// HybridOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<HybridOpNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<HybridOpNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const HybridOpNode*>(node.get());
p->stream << "hybrid(" << op->name << ", " << op << ")";
});
......
......@@ -26,8 +26,8 @@
namespace tvm {
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const PlaceholderOpNode*>(node.get());
p->stream << "placeholder(" << op->name << ", " << op << ")";
});
......
......@@ -31,8 +31,8 @@ namespace tvm {
using namespace ir;
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ScanOpNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ScanOpNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ScanOpNode*>(node.get());
p->stream << "scan(" << op->name << ", " << op << ")";
});
......
......@@ -33,8 +33,8 @@
namespace tvm {
using namespace ir;
// TensorComputeOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const TensorComputeOpNode*>(node.get());
p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
});
......
......@@ -54,8 +54,8 @@ Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
TVM_REGISTER_GLOBAL("relay._make.Closure")
.set_body_typed(ClosureNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ClosureNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ClosureNode*>(ref.get());
p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
});
......@@ -73,8 +73,8 @@ RecClosure RecClosureNode::make(Closure clos, Var bind) {
TVM_REGISTER_GLOBAL("relay._make.RecClosure")
.set_body_typed(RecClosureNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RecClosureNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RecClosureNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RecClosureNode*>(ref.get());
p->stream << "RecClosureNode(" << node->clos << ")";
});
......@@ -88,8 +88,8 @@ TupleValue TupleValueNode::make(tvm::Array<Value> value) {
TVM_REGISTER_GLOBAL("relay._make.TupleValue")
.set_body_typed(TupleValueNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TupleValueNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleValueNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TupleValueNode*>(ref.get());
p->stream << "TupleValueNode(" << node->fields << ")";
});
......@@ -100,8 +100,8 @@ TensorValue TensorValueNode::make(runtime::NDArray data) {
return TensorValue(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorValueNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorValueNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TensorValueNode*>(ref.get());
auto to_str = GetPackedFunc("relay._tensor_value_repr");
std::string data_str = to_str(GetRef<TensorValue>(node));
......@@ -122,8 +122,8 @@ TVM_REGISTER_GLOBAL("relay._make.RefValue")
TVM_REGISTER_NODE_TYPE(RefValueNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefValueNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RefValueNode*>(ref.get());
p->stream << "RefValueNode(" << node->value << ")";
});
......@@ -143,8 +143,8 @@ TVM_REGISTER_GLOBAL("relay._make.ConstructorValue")
TVM_REGISTER_NODE_TYPE(ConstructorValueNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ConstructorValueNode*>(ref.get());
p->stream << "ConstructorValueNode(" << node->tag << ","
<< node->fields << ")";
......
......@@ -37,8 +37,8 @@ TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_GLOBAL("relay._make.PatternWildcard")
.set_body_typed(PatternWildcardNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, NodePrinter* p) {
p->stream << "PatternWildcardNode()";
});
......@@ -53,8 +53,8 @@ TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_GLOBAL("relay._make.PatternVar")
.set_body_typed(PatternVarNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PatternVarNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const PatternVarNode*>(ref.get());
p->stream << "PatternVarNode(" << node->var << ")";
});
......@@ -72,8 +72,8 @@ TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.PatternConstructor")
.set_body_typed(PatternConstructorNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const PatternConstructorNode*>(ref.get());
p->stream << "PatternConstructorNode(" << node->constructor
<< ", " << node->patterns << ")";
......@@ -90,8 +90,8 @@ TVM_REGISTER_NODE_TYPE(PatternTupleNode);
TVM_REGISTER_GLOBAL("relay._make.PatternTuple")
.set_body_typed(PatternTupleNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const PatternTupleNode*>(ref.get());
p->stream << "PatternTupleNode(" << node->patterns << ")";
});
......@@ -111,8 +111,8 @@ TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.Constructor")
.set_body_typed(ConstructorNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
......@@ -133,8 +133,8 @@ TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_GLOBAL("relay._make.TypeData")
.set_body_typed(TypeDataNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
......@@ -152,8 +152,8 @@ TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_GLOBAL("relay._make.Clause")
.set_body_typed(ClauseNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ClauseNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ClauseNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ClauseNode*>(ref.get());
p->stream << "ClauseNode(" << node->lhs << ", "
<< node->rhs << ")";
......@@ -172,8 +172,8 @@ TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_GLOBAL("relay._make.Match")
.set_body_typed(MatchNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<MatchNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MatchNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const MatchNode*>(ref.get());
p->stream << "MatchNode(" << node->data << ", "
<< node->clauses << ", " << node->complete << ")";
......
......@@ -26,7 +26,7 @@
namespace tvm {
namespace relay {
using tvm::IRPrinter;
using tvm::NodePrinter;
using namespace tvm::runtime;
Constant ConstantNode::make(runtime::NDArray data) {
......@@ -40,8 +40,8 @@ TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_GLOBAL("relay._make.Constant")
.set_body_typed(ConstantNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstantNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ConstantNode*>(ref.get());
const PackedFunc* fprint = Registry::Get("relay._constant_repr");
CHECK(fprint) << "unable to find printing function for constants";
......@@ -73,8 +73,8 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_GLOBAL("relay._make.Tuple")
.set_body_typed(TupleNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TupleNode*>(ref.get());
p->stream << "Tuple(" << node->fields << ")";
});
......@@ -98,8 +98,8 @@ TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_GLOBAL("relay._make.Var")
.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<VarNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<VarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const VarNode*>(ref.get());
p->stream << "Var(" << node->name_hint();
if (node->type_annotation.defined()) {
......@@ -120,8 +120,8 @@ TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
.set_body_typed(GlobalVarNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")";
});
......@@ -226,8 +226,8 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function")
.set_body_typed(FunctionNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const FunctionNode*>(ref.get());
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ", "
......@@ -249,8 +249,8 @@ TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_GLOBAL("relay._make.Call")
.set_body_typed(CallNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const CallNode*>(ref.get());
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
......@@ -269,8 +269,8 @@ TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_GLOBAL("relay._make.Let")
.set_body_typed(LetNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const LetNode*>(ref.get());
p->stream << "LetNode(" << node->var << ", " << node->value
<< ", " << node->body << ")";
......@@ -289,8 +289,8 @@ TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_GLOBAL("relay._make.If")
.set_body_typed(IfNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IfNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IfNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const IfNode*>(ref.get());
p->stream << "IfNode(" << node->cond << ", " << node->true_branch
<< ", " << node->false_branch << ")";
......@@ -308,8 +308,8 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_GLOBAL("relay._make.TupleGetItem")
.set_body_typed(TupleGetItemNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TupleGetItemNode*>(ref.get());
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});
......@@ -325,8 +325,8 @@ TVM_REGISTER_NODE_TYPE(RefCreateNode);
TVM_REGISTER_GLOBAL("relay._make.RefCreate")
.set_body_typed(RefCreateNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefCreateNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefCreateNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RefCreateNode*>(ref.get());
p->stream << "RefCreateNode(" << node->value << ")";
});
......@@ -342,8 +342,8 @@ TVM_REGISTER_NODE_TYPE(RefReadNode);
TVM_REGISTER_GLOBAL("relay._make.RefRead")
.set_body_typed(RefReadNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefReadNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefReadNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RefReadNode*>(ref.get());
p->stream << "RefReadNode(" << node->ref << ")";
});
......@@ -360,8 +360,8 @@ TVM_REGISTER_NODE_TYPE(RefWriteNode);
TVM_REGISTER_GLOBAL("relay._make.RefWrite")
.set_body_typed(RefWriteNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefWriteNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefWriteNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RefWriteNode*>(ref.get());
p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
});
......
......@@ -31,7 +31,7 @@
namespace tvm {
namespace relay {
using tvm::IRPrinter;
using tvm::NodePrinter;
using namespace runtime;
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
......@@ -414,8 +414,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd")
mod->ImportFromStd(path);
});;
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ModuleNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModuleNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ModuleNode*>(ref.get());
p->stream << "ModuleNode( " << node->functions << ")";
});
......
......@@ -224,8 +224,8 @@ TVM_REGISTER_NODE_TYPE(OpNode)
return static_cast<const OpNode*>(n)->name;
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const OpNode*>(ref.get());
p->stream << "Op(" << node->name << ")";
});
......
......@@ -26,7 +26,7 @@
namespace tvm {
namespace relay {
using tvm::IRPrinter;
using tvm::NodePrinter;
using namespace tvm::runtime;
TensorType TensorTypeNode::make(Array<IndexExpr> shape, DataType dtype) {
......@@ -57,8 +57,8 @@ TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TensorType")
.set_body_typed(TensorTypeNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TensorTypeNode*>(ref.get());
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
......@@ -75,8 +75,8 @@ TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_GLOBAL("relay._make.TypeCall")
.set_body_typed(TypeCallNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeCallNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeCallNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeCallNode*>(ref.get());
p->stream << "TypeCallNode(" << node->func << ", "
<< node->args << ")";
......@@ -95,8 +95,8 @@ TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
return IncompleteTypeNode::make(static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
......@@ -118,8 +118,8 @@ TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_GLOBAL("relay._make.TypeRelation")
.set_body_typed(TypeRelationNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeRelationNode*>(ref.get());
p->stream << "TypeRelationNode("
<< node->func->name
......@@ -137,8 +137,8 @@ TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TupleType")
.set_body_typed(TupleTypeNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TupleTypeNode*>(ref.get());
p->stream << "TupleTypeNode(" << node->fields << ")";
});
......@@ -154,8 +154,8 @@ TVM_REGISTER_GLOBAL("relay._make.RefType")
TVM_REGISTER_NODE_TYPE(RefTypeNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RefTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const RefTypeNode*>(ref.get());
p->stream << "RefTypeNode(" << node->value << ")";
});
......
......@@ -34,7 +34,7 @@ namespace tvm {
namespace relay {
namespace transform {
using tvm::IRPrinter;
using tvm::NodePrinter;
struct RelayPassContextThreadLocalEntry {
/*! \brief The default pass context. */
......@@ -453,8 +453,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Info")
*ret = pass->Info();
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::NodePrinter* p) {
auto* node = static_cast<const PassInfoNode*>(ref.get());
p->stream << "The meta data of the pass: ";
p->stream << "pass name: " << node->name;
......@@ -479,8 +479,8 @@ TVM_REGISTER_GLOBAL("relay._transform.RunPass")
*ret = pass(mod);
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ModulePassNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModulePassNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const ModulePassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Module pass: " << info->name
......@@ -492,8 +492,8 @@ TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass")
.set_body_typed(FunctionPassNode::make);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const FunctionPassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Function pass: " << info->name
......@@ -512,8 +512,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Sequential")
*ret = Sequential(passes, pass_info);
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SequentialNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SequentialNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const SequentialNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Sequential pass: " << info->name
......@@ -542,8 +542,8 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
*ret = pctx;
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const PassContextNode*>(ref.get());
p->stream << "Pass context information: " << "\n";
p->stream << "\topt_level: " << node->opt_level << "\n";
......
......@@ -116,8 +116,8 @@ QConfig& QConfig::Current() {
TVM_REGISTER_NODE_TYPE(QConfigNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<QConfigNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<QConfigNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* op = static_cast<const QConfigNode*>(ref.get());
p->stream << "qconfig(";
p->stream << "nbit_input=" << op->nbit_input << ", ";
......
......@@ -798,8 +798,8 @@ TVM_REGISTER_NODE_TYPE(SingletonNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
// Printer
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<StageNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StageNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const StageNode*>(node.get());
if (op->op.defined()) {
p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
......@@ -807,11 +807,11 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "group-stage(" << op << ")";
}
})
.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, IRPrinter* p) {
.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IterVarAttrNode*>(node.get());
p->stream << IterVarType2String(op->iter_type);
})
.set_dispatch<SplitNode>([](const ObjectRef& node, IRPrinter* p) {
.set_dispatch<SplitNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const SplitNode*>(node.get());
p->stream << "split(parent=";
p->Print(op->parent);
......@@ -821,7 +821,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->Print(op->inner);
p->stream << ')';
})
.set_dispatch<FuseNode>([](const ObjectRef& node, IRPrinter* p) {
.set_dispatch<FuseNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const FuseNode*>(node.get());
p->stream << "split(";
p->stream << "outer=";
......@@ -832,7 +832,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->Print(op->fused);
p->stream << ')';
})
.set_dispatch<RebaseNode>([](const ObjectRef& node, IRPrinter* p) {
.set_dispatch<RebaseNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const RebaseNode*>(node.get());
p->stream << "rebase(";
p->stream << "parent=";
......@@ -841,13 +841,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->Print(op->rebased);
p->stream << ')';
})
.set_dispatch<SingletonNode>([](const ObjectRef& node, IRPrinter* p) {
.set_dispatch<SingletonNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const SingletonNode*>(node.get());
p->stream << "singleton(";
p->Print(op->iter);
p->stream << ')';
})
.set_dispatch<ScheduleNode>([](const ObjectRef& node, IRPrinter* p) {
.set_dispatch<ScheduleNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ScheduleNode*>(node.get());
p->stream << "schedule(" << op << ")";
});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment