Commit dc97e527 by Josh Pollock Committed by Tianqi Chen

[Relay][Text Format] Pretty Printer Smart Inlining (#2881)

parent fefbb006
...@@ -925,39 +925,6 @@ def eliminate_common_subexpr(expr, fskip=None): ...@@ -925,39 +925,6 @@ def eliminate_common_subexpr(expr, fskip=None):
""" """
return _ir_pass.eliminate_common_subexpr(expr, fskip) return _ir_pass.eliminate_common_subexpr(expr, fskip)
def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True):
"""
THIS SHOULD BE USED ONLY FOR DEBUGGING, NOT AS AN INTERCHANGE FORMAT!
USE `.astext()` INSTEAD!
A version of the pretty printer intended for debugging passes. Contains
advanced printing options.
Parameters
----------
ast : Union[relay.Expr, relay.Module, relay.Type]
The relay fragment to be turned into text.
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
gnf : bool
Whether to print in GNF. If it is disabled, pointers are left implicit.
Returns
-------
text : str
A text representation of `ast`.
"""
return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf)
def partial_evaluate(expr): def partial_evaluate(expr):
""" """
Evaluate the static fragment of the code. Evaluate the static fragment of the code.
......
...@@ -22,12 +22,22 @@ ...@@ -22,12 +22,22 @@
* \file pretty_printer.cc * \file pretty_printer.cc
* \brief Pretty printer for Relay programs * \brief Pretty printer for Relay programs
* Supports ANF, GNF, and metadata. * Supports ANF, GNF, and metadata.
*
* Inlining heuristics:
* - Always inline:
* - GlobalVar
* - Constant
* - Op
* - Var
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/ */
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include "doc.h" #include "doc.h"
#include "type_functor.h" #include "type_functor.h"
#include "../pass/dependency_graph.h"
#include "../../lang/attr_functor.h" #include "../../lang/attr_functor.h"
namespace tvm { namespace tvm {
...@@ -135,10 +145,8 @@ class PrettyPrinter : ...@@ -135,10 +145,8 @@ class PrettyPrinter :
public TypeFunctor<Doc(const Type&)>, public TypeFunctor<Doc(const Type&)>,
public AttrFunctor<Doc(const NodeRef&)> { public AttrFunctor<Doc(const NodeRef&)> {
public: public:
explicit PrettyPrinter(bool GNF, explicit PrettyPrinter(bool show_meta_data,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) : runtime::TypedPackedFunc<std::string(Expr)> annotate) :
GNF_(GNF),
show_meta_data_(show_meta_data), show_meta_data_(show_meta_data),
annotate_(annotate) {} annotate_(annotate) {}
...@@ -150,10 +158,9 @@ class PrettyPrinter : ...@@ -150,10 +158,9 @@ class PrettyPrinter :
Doc doc; Doc doc;
// additional information in comment. // additional information in comment.
if (annotate_ != nullptr) { if (annotate_ != nullptr) {
return doc << " // " << annotate_(expr); return doc << " /* " << annotate_(expr) << " */";
} else if (expr->checked_type_.defined()) { } else if (expr->checked_type_.defined()) {
doc << " // ty="; return doc << " /* ty=" << Print(expr->checked_type()) << " */";
return doc << Print(expr->checked_type());
} else { } else {
return doc; return doc;
} }
...@@ -176,13 +183,18 @@ class PrettyPrinter : ...@@ -176,13 +183,18 @@ class PrettyPrinter :
// print in a new scope // print in a new scope
doc_stack_.push_back(Doc()); doc_stack_.push_back(Doc());
// must print first so doc_stack_.back() reference doesn't become stale // must print first so doc_stack_.back() reference doesn't become stale
Doc doc = Print(node); Doc doc = Print(node, false, true);
doc = doc_stack_.back() << doc; doc = doc_stack_.back() << doc;
doc_stack_.pop_back(); doc_stack_.pop_back();
return doc; return doc;
} }
Doc PrintFinal(const NodeRef& node) { Doc PrintFinal(const NodeRef& node) {
if (node.as_derived<ExprNode>()) {
Expr expr = Downcast<Expr>(node);
dg_ = DependencyGraph::Create(&arena_, expr);
}
Doc doc; Doc doc;
doc << PrintScope(node); doc << PrintScope(node);
if (!meta_.empty()) { if (!meta_.empty()) {
...@@ -200,9 +212,9 @@ class PrettyPrinter : ...@@ -200,9 +212,9 @@ class PrettyPrinter :
Doc PrintAttrs(const Attrs& attrs, const Expr& op); Doc PrintAttrs(const Attrs& attrs, const Expr& op);
Doc Print(const NodeRef& node, bool meta = false) { Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
if (node.as_derived<ExprNode>()) { if (node.as_derived<ExprNode>()) {
return PrintExpr(Downcast<Expr>(node), meta); return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) { } else if (node.as_derived<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta); return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<ModuleNode>()) { } else if (node.as_derived<ModuleNode>()) {
...@@ -308,7 +320,12 @@ class PrettyPrinter : ...@@ -308,7 +320,12 @@ class PrettyPrinter :
return val; return val;
} }
inline bool IsAtomicExpr(const Expr& expr) { bool IsUnique(const Expr& expr) {
return !(dg_.expr_node.at(expr)->parents.head &&
dg_.expr_node.at(expr)->parents.head->next);
}
bool AlwaysInline(const Expr& expr) {
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
expr.as<OpNode>() || expr.as<VarNode>(); expr.as<OpNode>() || expr.as<VarNode>();
} }
...@@ -316,17 +333,25 @@ class PrettyPrinter : ...@@ -316,17 +333,25 @@ class PrettyPrinter :
//------------------------------------ //------------------------------------
// Overload of Expr printing functions // Overload of Expr printing functions
//------------------------------------ //------------------------------------
Doc PrintExpr(const Expr& expr, bool meta) { Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
// Exploit memoization to print GNF. // Exploit memoization to print GNF.
// The first time we visit an expression, we need to allocate a temp var // The first time we visit an expression, we need to allocate a temp var
// for it. Every subsequent time we can just use its assigned variable. // for it. Every subsequent time we can just use its assigned variable.
// This works since hashing uses pointer equality. // This works since hashing uses pointer equality.
// determine whether to inline
bool inline_expr = AlwaysInline(expr);
if (try_inline) {
inline_expr |= IsUnique(expr);
}
auto it = memo_.find(expr); auto it = memo_.find(expr);
if (it != memo_.end()) return it->second; if (it != memo_.end()) return it->second;
Doc printed_expr; Doc printed_expr;
if (meta) { if (meta) {
printed_expr = meta_.GetMetaNode(GetRef<NodeRef>(expr.get())); printed_expr = meta_.GetMetaNode(GetRef<NodeRef>(expr.get()));
} else if (GNF_ && expr.as<LetNode>()) { } else if (!inline_expr && expr.as<LetNode>()) {
// wrap GNFed let in brackets // wrap GNFed let in brackets
Doc body; Doc body;
printed_expr << "{"; printed_expr << "{";
...@@ -335,28 +360,26 @@ class PrettyPrinter : ...@@ -335,28 +360,26 @@ class PrettyPrinter :
} else { } else {
printed_expr = VisitExpr(expr); printed_expr = VisitExpr(expr);
} }
// we choose to inline atomic exprs
if (GNF_ && !IsAtomicExpr(expr)) {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr;
if (expr.as<CallNode>()) { if (expr.as<CallNode>()) {
doc_stack_.back() << PrintOptionalInfo(expr); printed_expr << PrintOptionalInfo(expr);
} }
doc_stack_.back() << "\n";
return temp_var; // add expr to doc
} else if (expr.as<VarNode>()) { if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case // This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free. // in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << "\n"; doc_stack_.back() << "free_var " << printed_expr << "\n";
// Memoization is done in AllocVar. // Memoization is done in AllocVar.
return memo_[expr]; return memo_[expr];
} else { } else if (inline_expr) {
memo_[expr] = printed_expr; memo_[expr] = printed_expr;
if (GNF_ && expr.as<CallNode>()) {
printed_expr << PrintOptionalInfo(expr);
}
return printed_expr; return printed_expr;
} else {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr << "\n";
return temp_var;
} }
} }
...@@ -420,8 +443,9 @@ class PrettyPrinter : ...@@ -420,8 +443,9 @@ class PrettyPrinter :
Doc VisitExpr_(const LetNode* op) final { Doc VisitExpr_(const LetNode* op) final {
Doc doc; Doc doc;
doc << "let " << AllocVar(op->var) << " = " << Print(op->value) << "\n"; doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << "\n";
// we use a scope here so GNF hoisting doesn't escape too far // we use a scope here so GNF hoisting doesn't escape too far
// and nested, unique lets are not hoisted
doc << PrintScope(op->body); doc << PrintScope(op->body);
return doc; return doc;
} }
...@@ -456,6 +480,8 @@ class PrettyPrinter : ...@@ -456,6 +480,8 @@ class PrettyPrinter :
Doc doc; Doc doc;
int counter = 0; int counter = 0;
for (const auto& kv : mod->functions) { for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);
std::ostringstream os; std::ostringstream os;
if (counter++ != 0) { if (counter++ != 0) {
doc << "\n"; doc << "\n";
...@@ -664,8 +690,6 @@ class PrettyPrinter : ...@@ -664,8 +690,6 @@ class PrettyPrinter :
} }
private: private:
/*! \brief Whether to use GNF. */
bool GNF_;
/*! \brief Whether to print meta data. */ /*! \brief Whether to print meta data. */
bool show_meta_data_; bool show_meta_data_;
/*! \brief additional comment function */ /*! \brief additional comment function */
...@@ -682,6 +706,10 @@ class PrettyPrinter : ...@@ -682,6 +706,10 @@ class PrettyPrinter :
TextMetaDataContext meta_; TextMetaDataContext meta_;
/*! \brief counter of temporary variable */ /*! \brief counter of temporary variable */
size_t temp_var_counter_{0}; size_t temp_var_counter_{0};
/*! \brief arena for dependency graph */
common::Arena arena_;
/*! \brief dependency graph of the expr */
DependencyGraph dg_;
class AttrPrinter; class AttrPrinter;
friend class AttrPrinter; friend class AttrPrinter;
}; };
...@@ -751,25 +779,17 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { ...@@ -751,25 +779,17 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) {
std::string PrettyPrint_(const NodeRef& node, std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data, bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate, runtime::TypedPackedFunc<std::string(Expr)> annotate) {
bool gnf) {
Doc doc; Doc doc;
doc << "v0.0.1" << "\n" doc << "v0.0.1" << "\n"
<< PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node); << PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str(); return doc.str();
} }
std::string AsText(const NodeRef& node, std::string AsText(const NodeRef& node,
bool show_meta_data, bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) { runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return PrettyPrint_(node, show_meta_data, annotate, true); return PrettyPrint_(node, show_meta_data, annotate);
}
std::string PassDebugPrint(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate,
bool gnf) {
return PrettyPrint_(node, show_meta_data, annotate, gnf);
} }
TVM_REGISTER_API("relay._expr.AsText") TVM_REGISTER_API("relay._expr.AsText")
...@@ -777,11 +797,5 @@ TVM_REGISTER_API("relay._expr.AsText") ...@@ -777,11 +797,5 @@ TVM_REGISTER_API("relay._expr.AsText")
bool, bool,
runtime::TypedPackedFunc<std::string(Expr)>)>(AsText); runtime::TypedPackedFunc<std::string(Expr)>)>(AsText);
TVM_REGISTER_API("relay._ir_pass.pass_debug_print")
.set_body_typed<std::string(const NodeRef&,
bool,
runtime::TypedPackedFunc<std::string(Expr)>,
bool)>(PassDebugPrint);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/relay/pass/dependency_graph.cc
* \brief
*/
#include "dependency_graph.h"
#include <tvm/relay/expr_functor.h>
#include <unordered_set>
#include <utility>
namespace tvm {
namespace relay {
// Creator of DependencyGraph
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
public:
explicit Creator(common::Arena* arena)
: arena_(arena) {}
DependencyGraph Create(const Expr& body) {
this->VisitExpr(body);
return std::move(graph_);
}
private:
/*! \brief allocator of all the internal node object */
common::Arena* arena_;
// The output.
DependencyGraph graph_;
// Update the message stored at the node.
void Depend(DependencyGraph::Node* parent, const Expr& child) {
VisitExpr(child);
CHECK_NE(graph_.expr_node.count(child), 0);
Depend(parent, graph_.expr_node[child]);
}
void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) {
auto* parent_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
parent_link->value = parent;
child->parents.Push(parent_link);
auto* child_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
child_link->value = child;
parent->children.Push(child_link);
}
std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
DependencyGraph::Node* NewNode(bool new_scope) {
auto* ret = arena_->make<DependencyGraph::Node>();
ret->new_scope = new_scope;
return ret;
}
void VisitExpr(const Expr& e) final {
if (visited_.count(e) == 0) {
if (graph_.expr_node.count(e) == 0) {
graph_.expr_node[e] = NewNode(false);
}
visited_.insert(e);
ExprFunctor<void(const Expr&)>::VisitExpr(e);
graph_.post_dfs_order.push_back(graph_.expr_node[e]);
}
}
void VisitExpr_(const CallNode* c) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(c)];
Depend(n, c->op);
for (const auto& a : c->args) {
Depend(n, a);
}
}
void VisitExpr_(const TupleNode* t) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
for (const auto& a : t->fields) {
Depend(n, a);
}
}
void VisitExpr_(const TupleGetItemNode* t) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
Depend(n, t->tuple);
}
void VisitExpr_(const RefCreateNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->value);
}
void VisitExpr_(const RefReadNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->ref);
}
void VisitExpr_(const RefWriteNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->ref);
Depend(n, r->value);
}
void VisitExpr_(const IfNode* i) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(i)];
DependencyGraph::Node* t = NewNode(true);
DependencyGraph::Node* f = NewNode(true);
Depend(n, i->cond);
Depend(n, t);
Depend(n, f);
Depend(t, i->true_branch);
Depend(f, i->false_branch);
graph_.post_dfs_order.push_back(f);
graph_.post_dfs_order.push_back(t);
}
void VisitExpr_(const FunctionNode* f) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, f->body);
graph_.post_dfs_order.push_back(b);
}
void VisitExpr_(const LetNode* l) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(l)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, l->value);
Depend(b, l->body);
graph_.post_dfs_order.push_back(b);
}
void VisitExpr_(const MatchNode* m) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(m)];
Depend(n, m->data);
std::vector<DependencyGraph::Node*> v;
for (const Clause& c : m->clauses) {
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, c->rhs);
v.push_back(b);
}
for (auto it = v.rbegin(); it != v.rend(); ++it) {
graph_.post_dfs_order.push_back(*it);
}
}
void VisitExpr_(const VarNode* v) final { }
void VisitExpr_(const GlobalVarNode* v) final { }
void VisitExpr_(const ConstantNode* c) final { }
void VisitExpr_(const OpNode* o) final { }
void VisitExpr_(const ConstructorNode* c) final { }
};
DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) {
return Creator(arena).Create(body);
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2019 by Contributors.
* \file tvm/relay/pass/dependency_graph.h
* \brief
*/
#ifndef TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
#define TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
#include <tvm/relay/expr.h>
#include <unordered_map>
#include <vector>
#include "let_list.h"
#include "../../common/arena.h"
namespace tvm {
namespace relay {
using common::LinkNode;
using common::LinkedList;
/* DependencyGraph track input and output of an Expr.
* Additionally, dummy scope is created to model scope.
* It allow us to traverse the graph in reverse order.
*/
class DependencyGraph {
public:
/*! \brief A node in the graph. */
struct Node {
// Determine scope boundaries. Used for calculating scopes, not for
// constructing dependency graph.
bool new_scope = false;
// incoming edges
LinkedList<Node*> children;
// outgoing edges
LinkedList<Node*> parents;
};
/*! \brief Maps a Relay Expr to its node in the dependency graph. */
std::unordered_map<Expr, Node*, NodeHash, NodeEqual> expr_node;
/*! \brief The dependency graph in post DFS order. */
std::vector<Node*> post_dfs_order;
/*!
* \brief Create a dependency graph.
* \param arena The arena used for data allocation.
* \param body The body of the expression to create a graph.
*/
static DependencyGraph Create(common::Arena* arena, const Expr& body);
private:
class Creator;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
...@@ -29,193 +29,11 @@ ...@@ -29,193 +29,11 @@
#include "let_list.h" #include "let_list.h"
#include "../../common/arena.h" #include "../../common/arena.h"
#include "pass_util.h" #include "pass_util.h"
#include "dependency_graph.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using common::LinkNode;
using common::LinkedList;
/* DependencyGraph track input and output of an Expr.
* Additionally, dummy scope is created to model scope.
* It allow us to traverse the graph in reverse order.
*/
class DependencyGraph {
public:
/*! \brief A node in the graph. */
struct Node {
bool new_scope = false;
LinkedList<Node*> input;
LinkedList<Node*> output;
};
/*! \brief The node map that maps node to graph */
std::unordered_map<Expr, Node*, NodeHash, NodeEqual> expr_node;
/*! \brief All the nodes in post DFS order */
std::vector<Node*> post_dfs_order;
/*!
* \brief create a dependency graph.
* \param arena The arena used for data allocation.
* \param body The body of the expression to create a graph.
*/
static DependencyGraph Create(common::Arena* arena, const Expr& body);
private:
class Creator;
};
// Creator of DependencyGraph
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
public:
explicit Creator(common::Arena* arena)
: arena_(arena) {}
DependencyGraph Create(const Expr& body) {
this->VisitExpr(body);
return std::move(graph_);
}
private:
/*! \brief allocator of all the internal node object */
common::Arena* arena_;
// The output.
DependencyGraph graph_;
// Update the message stored at the node.
void Depend(DependencyGraph::Node* parent, const Expr& child) {
VisitExpr(child);
CHECK_NE(graph_.expr_node.count(child), 0);
Depend(parent, graph_.expr_node[child]);
}
void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) {
auto* parent_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
parent_link->value = parent;
child->output.Push(parent_link);
auto* child_link = arena_->make<LinkNode<DependencyGraph::Node*> >();
child_link->value = child;
parent->input.Push(child_link);
}
std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
DependencyGraph::Node* NewNode(bool new_scope) {
auto* ret = arena_->make<DependencyGraph::Node>();
ret->new_scope = new_scope;
return ret;
}
void VisitExpr(const Expr& e) final {
if (visited_.count(e) == 0) {
if (graph_.expr_node.count(e) == 0) {
graph_.expr_node[e] = NewNode(false);
}
visited_.insert(e);
ExprFunctor<void(const Expr&)>::VisitExpr(e);
graph_.post_dfs_order.push_back(graph_.expr_node[e]);
}
}
void VisitExpr_(const CallNode* c) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(c)];
Depend(n, c->op);
for (const auto& a : c->args) {
Depend(n, a);
}
}
void VisitExpr_(const TupleNode* t) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
for (const auto& a : t->fields) {
Depend(n, a);
}
}
void VisitExpr_(const TupleGetItemNode* t) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(t)];
Depend(n, t->tuple);
}
void VisitExpr_(const RefCreateNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->value);
}
void VisitExpr_(const RefReadNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->ref);
}
void VisitExpr_(const RefWriteNode* r) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(r)];
Depend(n, r->ref);
Depend(n, r->value);
}
void VisitExpr_(const IfNode* i) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(i)];
DependencyGraph::Node* t = NewNode(true);
DependencyGraph::Node* f = NewNode(true);
Depend(n, i->cond);
Depend(n, t);
Depend(n, f);
Depend(t, i->true_branch);
Depend(f, i->false_branch);
graph_.post_dfs_order.push_back(f);
graph_.post_dfs_order.push_back(t);
}
void VisitExpr_(const FunctionNode* f) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, f->body);
graph_.post_dfs_order.push_back(b);
}
void VisitExpr_(const LetNode* l) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(l)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, l->value);
Depend(b, l->body);
graph_.post_dfs_order.push_back(b);
}
void VisitExpr_(const MatchNode* m) final {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(m)];
Depend(n, m->data);
std::vector<DependencyGraph::Node*> v;
for (const Clause& c : m->clauses) {
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, c->rhs);
v.push_back(b);
}
for (auto it = v.rbegin(); it != v.rend(); ++it) {
graph_.post_dfs_order.push_back(*it);
}
}
void VisitExpr_(const VarNode* v) final { }
void VisitExpr_(const GlobalVarNode* v) final { }
void VisitExpr_(const ConstantNode* c) final { }
void VisitExpr_(const OpNode* o) final { }
void VisitExpr_(const ConstructorNode* c) final { }
};
DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) {
return Creator(arena).Create(body);
}
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv); Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
struct ScopeNode; struct ScopeNode;
...@@ -256,7 +74,7 @@ std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGrap ...@@ -256,7 +74,7 @@ std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGrap
Scope global_scope = std::make_shared<ScopeNode>(); Scope global_scope = std::make_shared<ScopeNode>();
for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) { for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
DependencyGraph::Node* n = *it; DependencyGraph::Node* n = *it;
auto iit = n->output.head; auto iit = n->parents.head;
Scope s; Scope s;
if (iit == nullptr) { if (iit == nullptr) {
s = global_scope; s = global_scope;
...@@ -313,7 +131,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -313,7 +131,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Scope GetSubScope(const Expr& e, size_t i) { Scope GetSubScope(const Expr& e, size_t i) {
DependencyGraph::Node* n = dg_.expr_node.at(e); DependencyGraph::Node* n = dg_.expr_node.at(e);
auto h = n->input.head; auto h = n->children.head;
while (i != 0) { while (i != 0) {
CHECK(h); CHECK(h);
--i; --i;
......
...@@ -50,8 +50,8 @@ def test_env(): ...@@ -50,8 +50,8 @@ def test_env():
text = env.astext() text = env.astext()
assert "def @myf" in text assert "def @myf" in text
assert "def @myf" in str(env) assert "def @myf" in str(env)
assert "%1 = add(%0, %0) // ty=float32" in text assert "add(%0, %0) /* ty=float32 */" in text
assert "%1 = add(%0, %0) // ty=float32" in str(env) assert "add(%0, %0) /* ty=float32 */" in str(env)
show(env.astext(annotate=lambda x: str(x.checked_type.dtype))) show(env.astext(annotate=lambda x: str(x.checked_type.dtype)))
show(text) show(text)
...@@ -112,7 +112,7 @@ def test_let_if_scope(): ...@@ -112,7 +112,7 @@ def test_let_if_scope():
f = relay.Function([x, y, cond], result) f = relay.Function([x, y, cond], result)
text = f.astext() text = f.astext()
assert text.count("{") == 6 assert text.count("{") == 4
assert "%cond: bool" in text assert "%cond: bool" in text
show(f.astext()) show(f.astext())
...@@ -180,8 +180,19 @@ def test_call_node_order(): ...@@ -180,8 +180,19 @@ def test_call_node_order():
"%2 = fn (%x) {\n" "%2 = fn (%x) {\n"
" %x\n" " %x\n"
"}\n" "}\n"
"%3 = %2(%1)\n" "%2(%1)")
"%3")
def test_let_inlining():
tup = relay.Tuple([relay.const(0), relay.const(0)])
x = relay.var("x")
assert relay.Let(x, tup, tup).astext() == SEMVER + \
("%0 = (0, 0)\n"
"let %x = %0\n"
"%0")
assert relay.Let(x, tup, x).astext() == SEMVER + \
("let %x = (0, 0)\n"
"%x")
if __name__ == "__main__": if __name__ == "__main__":
do_print[0] = True do_print[0] = True
...@@ -201,3 +212,4 @@ if __name__ == "__main__": ...@@ -201,3 +212,4 @@ if __name__ == "__main__":
test_let_if_scope() test_let_if_scope()
test_variable_name() test_variable_name()
test_call_node_order() test_call_node_order()
test_let_inlining()
...@@ -38,7 +38,7 @@ def test_unary_op(): ...@@ -38,7 +38,7 @@ def test_unary_op():
x = relay.var("x", tp) x = relay.var("x", tp)
y = opfunc(x) y = opfunc(x)
# test printer # test printer
assert ("%0 = {}(%x)".format(y.op.name)) in y.astext() assert ("{}(%x)".format(y.op.name)) in y.astext()
# test type inference # test type inference
assert relay.ir_pass.infer_type(y).checked_type == tp assert relay.ir_pass.infer_type(y).checked_type == tp
...@@ -78,7 +78,7 @@ def test_binary_op(): ...@@ -78,7 +78,7 @@ def test_binary_op():
y = relay.var("y", t2) y = relay.var("y", t2)
z = opfunc(x, y) z = opfunc(x, y)
# test printer # test printer
assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext() assert ("{}(%x, %y)".format(z.op.name)) in z.astext()
assert relay.ir_pass.infer_type(z).checked_type == t1 assert relay.ir_pass.infer_type(z).checked_type == t1
if ref is not None: if ref is not None:
......
...@@ -29,7 +29,7 @@ def test_binary_op(): ...@@ -29,7 +29,7 @@ def test_binary_op():
y = relay.var("y", t2) y = relay.var("y", t2)
z = opfunc(x, y) z = opfunc(x, y)
# test printer # test printer
assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext() assert ("{}(%x, %y)".format(z.op.name)) in z.astext()
assert relay.ir_pass.infer_type(z).checked_type == t1 assert relay.ir_pass.infer_type(z).checked_type == t1
if ref is not None: if ref is not None:
......
...@@ -44,7 +44,7 @@ def initialize_box_adt(mod): ...@@ -44,7 +44,7 @@ def initialize_box_adt(mod):
def test_monomorphic_let(): def test_monomorphic_let():
"Program: let x = 1; x" "Program: let %x = 1; %x"
sb = relay.ScopeBuilder() sb = relay.ScopeBuilder()
x = sb.let('x', relay.const(1.0, "float64")) x = sb.let('x', relay.const(1.0, "float64"))
sb.ret(x) sb.ret(x)
...@@ -53,7 +53,7 @@ def test_monomorphic_let(): ...@@ -53,7 +53,7 @@ def test_monomorphic_let():
def test_single_op(): def test_single_op():
"Program: fn (x : float32) { let t1 = f(x); t1 }" "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }"
x = relay.var('x', shape=[]) x = relay.var('x', shape=[])
func = relay.Function([x], op.log(x)) func = relay.Function([x], op.log(x))
ttype = relay.TensorType([], dtype='float32') ttype = relay.TensorType([], dtype='float32')
...@@ -63,8 +63,9 @@ def test_single_op(): ...@@ -63,8 +63,9 @@ def test_single_op():
def test_add_broadcast_op(): def test_add_broadcast_op():
""" """
Program: Program:
fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32])
x + y -> Tensor[(5, 10, 4), float32] {
%x + %y
} }
""" """
x = relay.var('x', shape=(10, 4)) x = relay.var('x', shape=(10, 4))
...@@ -80,10 +81,10 @@ def test_add_broadcast_op(): ...@@ -80,10 +81,10 @@ def test_add_broadcast_op():
def test_dual_op(): def test_dual_op():
"""Program: """Program:
fn (x : Tensor[f32, (10, 10)]) { fn (%x : Tensor[(10, 10), float32]) {
let t1 = log(x); let %t1 = log(x);
let t2 = add(t1, x); let %t2 = add(%t1, %x);
t1 %t1
} }
""" """
tp = relay.TensorType((10, 10), "float32") tp = relay.TensorType((10, 10), "float32")
...@@ -99,8 +100,8 @@ def test_dual_op(): ...@@ -99,8 +100,8 @@ def test_dual_op():
def test_decl(): def test_decl():
"""Program: """Program:
def f(x : Tensor[(10, 10), f32]) { def @f(%x : Tensor[(10, 10), float32]) {
log(x) log(%x)
} }
""" """
tp = relay.TensorType((10, 10)) tp = relay.TensorType((10, 10))
...@@ -113,11 +114,11 @@ def test_decl(): ...@@ -113,11 +114,11 @@ def test_decl():
def test_recursion(): def test_recursion():
""" """
Program: Program:
def f(n: i32, data: f32) -> f32 { def @f(%n: int32, %data: float32) -> float32 {
if (n == 0) { if (%n == 0) {
data %data
} else { } else {
f(n - 1, log(data)) @f(%n - 1, log(%data))
} }
} }
""" """
...@@ -134,7 +135,7 @@ def test_recursion(): ...@@ -134,7 +135,7 @@ def test_recursion():
sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
mod = relay.Module() mod = relay.Module()
mod[f] = relay.Function([n, data], sb.get()) mod[f] = relay.Function([n, data], sb.get())
assert "%3 = @f(%1, %2)" in mod.astext() assert "@f(%1, %2) /* ty=float32 */" in mod.astext()
assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment