Unverified Commit 7de8a539 by Matthew Brookhart Committed by GitHub

[RELAY] Non-recursive Graph Vistor and Rewriter (#4886)

* First pass a defining a non-recursive Graph Vistor and Rewriter

autoformat

remove a currently empty test until testing is solidfied

* Make CalcDep from Dead Code Elimination non-recursive

* Partially working, not passing all tests yet

passes tests when disabling GetExprRefCount, I think I have a bug in visit counting

fix GetExprRefCount

Fix a subtle bug with nested recursive/non-recursive scopes

* Refactor

* improve comments

* respond to review comments on comments

* Fix a problem with default recursion for dataflow nodes

mark DataflowVisitor methods as override

* implement ScopeMutator

* convert forward_rewrite to ScopeMutator, remove DataflowMutator

* rewrite ExprRewriter and convert fast_math to use it

* switch BiasAddSimplifier to ExprRewriter

fix a clang warning

fix cpp lint

fix doc param error

* respond to review comments

* fix a typo in the iterative looping

* add a regression test for GetExprRefCount issue

* Normalize naming

* fix lint

* First pass a defining a non-recursive Graph Vistor and Rewriter

autoformat

remove a currently empty test until testing is solidfied

* Make CalcDep from Dead Code Elimination non-recursive

* Partially working, not passing all tests yet

passes tests when disabling GetExprRefCount, I think I have a bug in visit counting

fix GetExprRefCount

Fix a subtle bug with nested recursive/non-recursive scopes

* Refactor

* improve comments

* respond to review comments on comments

* Fix a problem with default recursion for dataflow nodes

mark DataflowVisitor methods as override

* implement ScopeMutator

* convert forward_rewrite to ScopeMutator, remove DataflowMutator

* rewrite ExprRewriter and convert fast_math to use it

* switch BiasAddSimplifier to ExprRewriter

fix a clang warning

fix cpp lint

fix doc param error

* respond to review comments

* fix a typo in the iterative looping

* add a regression test for GetExprRefCount issue

* Normalize naming

* fix lint

* respond to review comments
parent 6b840fa9
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <string> #include <string>
#include <unordered_map>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -225,6 +226,16 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr); ...@@ -225,6 +226,16 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
*/ */
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod); TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);
/*!
* \brief Get reference counter of each internal ExprNode in body.
*
* \param body The body expression.
*
* \return The reference count mapping.
*/
TVM_DLL std::unordered_map<const Object*, size_t>
GetExprRefCount(const Expr& body);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -233,6 +233,189 @@ class ExprMutator ...@@ -233,6 +233,189 @@ class ExprMutator
}; };
/*! /*!
* \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
*
* MixedModeVisitor treats Expr as dataflow graph, and visits in post-DFS order
*
* MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
*/
class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
public:
/*! \brief The constructor of MixedModeVisitor
* \param visit_limit The number of times to allow visitation to a node. Usually 1, ocassionally
* higher (i.e., 2 for dead code elimiation), limited to 10 as a sanity check.
*/
explicit MixedModeVisitor(int visit_limit = 1);
/*!
* \brief VisitExpr is finalized to preserve call expansion of dataflow regions
*/
void VisitExpr(const Expr& expr) final;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
protected:
/*!
* \brief A function to apply when reaching a leaf of the graph non-recursively
*/
virtual void VisitLeaf(const Expr& expr);
/*!
* \brief A function to determine if an expression has already been visited or needs to be
* re-visited
*/
virtual bool CheckVisited(const Expr& expr);
/*!
* \brief The max number of times to visit a node
*/
size_t visit_limit_;
};
/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* MixedModeMutator treats Expr as dataflow graph, and only Rewrites each Expr once.
* The mutated results are memoized in a map and reused so that
* local transformation on the dataflow preserves the graph structure.
*
* MixedModeMutator provides the same recursive API as ExprMutator, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
*
* Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive behavior.
*/
class MixedModeMutator : public ::tvm::relay::ExprMutator {
public:
Expr VisitExpr(const Expr& expr) final;
virtual Expr DispatchVisitExpr(const Expr& expr);
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
/*!
* \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
* able to rewrite the op only with data about the original node `pre` and the same node with
* modified inputs `post` and should not recurse.
*
* \param pre The expression node before rewriting.
* \param post The expression with rewritten inputs.
*/
virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;}
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; }
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }
protected:
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
* changed inputs.
*/
template <typename T>
Expr Rewrite(const T* op) {
Expr post = ExprMutator::VisitExpr_(op);
return Rewrite_(op, post);
}
virtual void VisitLeaf(const Expr& expr);
virtual bool CheckVisited(const Expr& expr);
};
#define RELAY_EXPR_REWRITER_DISPATCH(OP) \
vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \
return self->Rewrite_(static_cast<const OP*>(n.get()), post); \
});
#define EXPR_REWRITER_REWRITE_DEFAULT \
{ return post; }
/*! \brief A non-iterating Expression Rewriter
*
* ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
*
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
* node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
* graph rewriting.
*/
class ExprRewriter {
private:
using TSelf = ExprRewriter;
using FType = tvm::NodeFunctor<Expr(const ObjectRef& n, TSelf* self, const Expr& post)>;
public:
/*! \brief virtual destructor */
virtual ~ExprRewriter() {}
/*!
* \brief Same as call.
* \param pre The expression node before rewriting.
* \param post The expression node with rewritten inputs.
* \return The result of the call
*/
Expr operator()(const Expr& pre, const Expr& post) {
return Rewrite(pre, post);
}
/*!
* \brief The functor call.
* \param pre The expression node before rewriting.
* \param post The expression node with rewritten inputs.
* \return The result of the call
*/
virtual Expr Rewrite(const Expr& pre, const Expr& post) {
CHECK(pre.defined());
static FType vtable = InitVTable();
return vtable(pre, this, post);
}
// Functions that can be overriden by subclass, should not recurse
virtual Expr Rewrite_(const VarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const GlobalVarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const ConstantNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const FunctionNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const LetNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const IfNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const OpNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const TupleGetItemNode* pre,
const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const RefCreateNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const RefReadNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const RefWriteNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const ConstructorNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
virtual Expr Rewrite_(const MatchNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAY_EXPR_REWRITER_DISPATCH(ConstantNode);
RELAY_EXPR_REWRITER_DISPATCH(TupleNode);
RELAY_EXPR_REWRITER_DISPATCH(VarNode);
RELAY_EXPR_REWRITER_DISPATCH(GlobalVarNode);
RELAY_EXPR_REWRITER_DISPATCH(FunctionNode);
RELAY_EXPR_REWRITER_DISPATCH(CallNode);
RELAY_EXPR_REWRITER_DISPATCH(LetNode);
RELAY_EXPR_REWRITER_DISPATCH(IfNode);
RELAY_EXPR_REWRITER_DISPATCH(OpNode);
RELAY_EXPR_REWRITER_DISPATCH(TupleGetItemNode);
RELAY_EXPR_REWRITER_DISPATCH(RefCreateNode);
RELAY_EXPR_REWRITER_DISPATCH(RefReadNode);
RELAY_EXPR_REWRITER_DISPATCH(RefWriteNode);
RELAY_EXPR_REWRITER_DISPATCH(ConstructorNode);
RELAY_EXPR_REWRITER_DISPATCH(MatchNode);
return vtable;
}
};
/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
* PostOrderRewrite provides the original node and the node with altered inputs for use by the
* ExprRewriter.
*/
Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit * \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once. * Each node is guaranteed to be visited only once.
* \param node The ir to be visited. * \param node The ir to be visited.
......
...@@ -330,7 +330,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars") ...@@ -330,7 +330,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars")
*/ */
std::unordered_map<const Object*, size_t> std::unordered_map<const Object*, size_t>
GetExprRefCount(const Expr& body) { GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor { class ExprRefCounter : private MixedModeVisitor {
public: public:
std::unordered_map<const Object*, size_t> std::unordered_map<const Object*, size_t>
Get(const Expr& body) { Get(const Expr& body) {
......
...@@ -29,8 +29,162 @@ ...@@ -29,8 +29,162 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include <stack>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
/*!
* \brief A function to iteratively traverse dataflow regions of a graph
*
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
* order of nodes in an input graph.
*
* If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
* need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
* and continues iteratively to process the top of the stack. When it finds a node that doesn't
* match the dataflow types, or a node who's inputs have all been processed, it visits the current
* leaf via fvisit_leaf.
*
* This function should be used internally to other classes to implement mixed-mode traversals. The
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
* hits a non-dataflow node.
*
* fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
*/
template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
std::stack<std::pair<Expr, bool>> stack;
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
// The second state of the stack indicate whether the child has been
// expanded in the pre-order.
// NOTE: function will be inlined.
if (!fcheck_visited(expr)) {
stack.push({expr, false});
}
};
fpush_to_stack(expr);
while (stack.size() > 0) {
auto node = stack.top().first;
if (fcheck_visited(node)) {
// if this node was visited through another path
// after being added to the stack ignore it.
stack.pop();
} else if (stack.top().second) {
// all the children have already been expanded.
// we can just run post order visit on it.
fvisit_leaf(node);
stack.pop();
} else if (const CallNode* op = node.as<CallNode>()) {
// mark expanded = true
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
fpush_to_stack(*it);
}
fpush_to_stack(op->op);
} else if (const TupleNode* op = node.as<TupleNode>()) {
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
fpush_to_stack(*it);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
stack.top().second = true;
fpush_to_stack(op->tuple);
} else {
// No need to expand the children directly run visit.
fvisit_leaf(node);
stack.pop();
}
}
}
MixedModeVisitor::MixedModeVisitor(int visit_limit) {
CHECK(visit_limit > 0) << "Dataflow visit limit must be greater than 0";
CHECK(visit_limit < 10) << "Dataflow visit limit must be less than 10";
visit_limit_ = visit_limit;
}
void MixedModeVisitor::VisitLeaf(const Expr& expr) {
if (visit_counter_[expr.get()] < visit_limit_) {
ExprFunctor::VisitExpr(expr);
}
visit_counter_[expr.get()]++;
}
bool MixedModeVisitor::CheckVisited(const Expr& expr) {
if (visit_counter_[expr.get()] < visit_limit_) {
return false;
} else {
visit_counter_[expr.get()]++;
return true;
}
}
void MixedModeVisitor::VisitExpr(const Expr& expr) {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
if (visit_counter_[expr.get()] < visit_limit_) {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
}
}
// Overwrite the VisitExpr so we don't recurse for dataflow nodes
void MixedModeVisitor::VisitExpr_(const CallNode* op) {}
// Overwrite the VisitExpr so we don't recurse for dataflow nodes
void MixedModeVisitor::VisitExpr_(const TupleNode* op) {}
// Overwrite the VisitExpr so we don't recurse for dataflow nodes
void MixedModeVisitor::VisitExpr_(const TupleGetItemNode* op) {}
void MixedModeMutator::VisitLeaf(const Expr& expr) {
if (!memo_.count(expr)) {
this->DispatchVisitExpr(expr);
}
}
bool MixedModeMutator::CheckVisited(const Expr& expr) {
if (memo_.count(expr)) {
return true;
} else {
return false;
}
}
Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) {
return ExprMutator::VisitExpr(expr);
}
Expr MixedModeMutator::VisitExpr(const Expr& expr) {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
if (memo_.count(expr)) {
return memo_[expr];
} else {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
Expr ret = this->DispatchVisitExpr(expr);
memo_[expr] = ret;
return ret;
}
}
class PostOrderRewriter : public MixedModeMutator {
public:
explicit PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {}
Expr DispatchVisitExpr(const Expr& expr) final {
auto post = ExprFunctor::VisitExpr(expr);
return rewriter_->Rewrite(expr, post);
}
protected:
ExprRewriter* rewriter_;
};
Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter) {
return PostOrderRewriter(rewriter).VisitExpr(expr);
}
Expr ExprMutator::VisitExpr(const Expr& expr) { Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr); auto it = this->memo_.find(expr);
...@@ -211,12 +365,12 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) { ...@@ -211,12 +365,12 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
for (const Clause& p : m->clauses) { for (const Clause& p : m->clauses) {
clauses.push_back(VisitClause(p)); clauses.push_back(VisitClause(p));
} }
return Match(VisitExpr(m->data), clauses, m->complete); return Match(Mutate(m->data), clauses, m->complete);
} }
Clause ExprMutator::VisitClause(const Clause& c) { Clause ExprMutator::VisitClause(const Clause& c) {
Pattern p = VisitPattern(c->lhs); Pattern p = VisitPattern(c->lhs);
return Clause(p, VisitExpr(c->rhs)); return Clause(p, Mutate(c->rhs));
} }
Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
......
...@@ -32,12 +32,12 @@ ...@@ -32,12 +32,12 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
class BiasAddSimplifier : public ExprMutator { class BiasAddSimplifier : public ExprRewriter {
public: public:
BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {} BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {}
Expr VisitExpr_(const CallNode* n) { Expr Rewrite_(const CallNode* n, const Expr& post) override {
auto new_n = ExprMutator::VisitExpr_(n); auto new_n = post;
if (n->op == bias_add_op_) { if (n->op == bias_add_op_) {
Call call = Downcast<Call>(new_n); Call call = Downcast<Call>(new_n);
CHECK_EQ(call->args.size(), 2); CHECK_EQ(call->args.size(), 2);
...@@ -63,7 +63,8 @@ class BiasAddSimplifier : public ExprMutator { ...@@ -63,7 +63,8 @@ class BiasAddSimplifier : public ExprMutator {
}; };
Expr CanonicalizeOps(const Expr& e) { Expr CanonicalizeOps(const Expr& e) {
return BiasAddSimplifier().Mutate(e); auto rewriter = BiasAddSimplifier();
return PostOrderRewrite(e, &rewriter);
} }
namespace transform { namespace transform {
......
...@@ -92,7 +92,7 @@ class Eliminator : private ExprMutator { ...@@ -92,7 +92,7 @@ class Eliminator : private ExprMutator {
}; };
// calculate the dependency graph from expression // calculate the dependency graph from expression
class CalcDep : private ExprVisitor { class CalcDep : protected MixedModeVisitor {
public: public:
static Expr Eliminate(const Expr& e, bool inline_once) { static Expr Eliminate(const Expr& e, bool inline_once) {
FindDef fd; FindDef fd;
...@@ -104,11 +104,14 @@ class CalcDep : private ExprVisitor { ...@@ -104,11 +104,14 @@ class CalcDep : private ExprVisitor {
} }
private: private:
explicit CalcDep(const VarMap<Expr>& expr_map) : expr_map_(expr_map) { } explicit CalcDep(const VarMap<Expr>& expr_map)
: MixedModeVisitor(2), expr_map_(expr_map) {}
VarMap<Expr> expr_map_; VarMap<Expr> expr_map_;
VarMap<size_t> use_map_; VarMap<size_t> use_map_;
void VisitExpr(const Expr& e) final { using MixedModeVisitor::VisitExpr_;
void VisitLeaf(const Expr& e) final {
visit_counter_[e.get()]++; visit_counter_[e.get()]++;
// The dce code seprate variable into three parts: // The dce code seprate variable into three parts:
// used 0 times (remove) // used 0 times (remove)
......
...@@ -31,20 +31,19 @@ ...@@ -31,20 +31,19 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
class FastMathMutator : public ExprMutator { class FastMathMutator : public ExprRewriter {
public: public:
FastMathMutator() FastMathMutator()
: exp_op_(Op::Get("exp")), : exp_op_(Op::Get("exp")),
tanh_op_(Op::Get("tanh")) {} tanh_op_(Op::Get("tanh")) {}
Expr VisitExpr_(const CallNode* n) { Expr Rewrite_(const CallNode* pre, const Expr& post) override {
auto new_n = ExprMutator::VisitExpr_(n); if (pre->op == exp_op_) {
if (n->op == exp_op_) { return FastExp(post.as<CallNode>()->args[0]);
return FastExp(new_n.as<CallNode>()->args[0]); } else if (pre->op == tanh_op_) {
} else if (n->op == tanh_op_) { return FastTanh(post.as<CallNode>()->args[0]);
return FastTanh(new_n.as<CallNode>()->args[0]);
} }
return new_n; return post;
} }
private: private:
...@@ -56,7 +55,8 @@ class FastMathMutator : public ExprMutator { ...@@ -56,7 +55,8 @@ class FastMathMutator : public ExprMutator {
}; };
Expr FastMath(const Expr& e) { Expr FastMath(const Expr& e) {
return FastMathMutator().Mutate(e); auto rewriter = FastMathMutator();
return PostOrderRewrite(e, &rewriter);
} }
namespace transform { namespace transform {
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* \file forward_rewrite.cc * \file forward_rewrite.cc
* \brief Apply rewriting rules in a forward fashion. * \brief Apply rewriting rules in a forward fashion.
*/ */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h> #include <tvm/relay/transform.h>
...@@ -33,32 +34,25 @@ namespace relay { ...@@ -33,32 +34,25 @@ namespace relay {
// Realizer class that realizes the expression // Realizer class that realizes the expression
// Note that we can take benefit of its internal memo // Note that we can take benefit of its internal memo
// so that calling realize repeatively won't hurt perf. // so that calling realize repeatively won't hurt perf.
class TempRealizer : private ExprMutator { class TempRealizer : private MixedModeMutator {
public: public:
Expr Realize(Expr expr) { Expr Realize(Expr expr) {
return VisitExpr(expr); return Mutate(expr);
} }
private: private:
Expr VisitExpr(const Expr& expr) final { Expr DispatchVisitExpr(const Expr& expr) final {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Expr res; Expr res;
if (const auto* temp = expr.as<TempExprNode>()) { if (const auto* temp = expr.as<TempExprNode>()) {
res = temp->Realize(); res = temp->Realize();
} else { } else {
res = ExprFunctor::VisitExpr(expr); res = MixedModeMutator::DispatchVisitExpr(expr);
} }
memo_[res] = res;
return res; return res;
} }
}
}; };
class ForwardRewriter : private ExprMutator { class ForwardRewriter : private MixedModeMutator {
public: public:
ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map, ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
std::function<ObjectRef(const Call&)> fcontext, std::function<ObjectRef(const Call&)> fcontext,
...@@ -76,11 +70,11 @@ class ForwardRewriter : private ExprMutator { ...@@ -76,11 +70,11 @@ class ForwardRewriter : private ExprMutator {
// Transform expression. // Transform expression.
Expr Rewrite(Expr expr) { Expr Rewrite(const Expr& expr) {
if (fmulti_ref_trigger_ != nullptr) { if (fmulti_ref_trigger_ != nullptr) {
ref_counter_ = GetExprRefCount(expr); ref_counter_ = GetExprRefCount(expr);
} }
return this->VisitExpr(expr); return realizer_.Realize(this->VisitExpr(expr));
} }
private: private:
...@@ -96,15 +90,10 @@ class ForwardRewriter : private ExprMutator { ...@@ -96,15 +90,10 @@ class ForwardRewriter : private ExprMutator {
// internal realizer // internal realizer
TempRealizer realizer_; TempRealizer realizer_;
Expr VisitExpr(const Expr& expr) final {
// by default always realize.
return realizer_.Realize(ExprMutator::VisitExpr(expr));
}
// Visit and allow non-realized version. // Visit and allow non-realized version.
Expr GetTempExpr(const Expr& expr) { Expr GetTempExpr(const Expr& expr, const Expr& post) {
if (fmulti_ref_trigger_ != nullptr) { if (fmulti_ref_trigger_ != nullptr) {
Expr ret = ExprMutator::VisitExpr(expr); Expr ret = post;
auto it = ref_counter_.find(expr.get()); auto it = ref_counter_.find(expr.get());
CHECK(it != ref_counter_.end()); CHECK(it != ref_counter_.end());
if (it->second > 1) { if (it->second > 1) {
...@@ -112,13 +101,13 @@ class ForwardRewriter : private ExprMutator { ...@@ -112,13 +101,13 @@ class ForwardRewriter : private ExprMutator {
} }
return ret; return ret;
} else { } else {
return ExprMutator::VisitExpr(expr); return post;
} }
} }
// Automatic fold TupleGetItem. // Automatic fold TupleGetItem.
Expr VisitExpr_(const TupleGetItemNode* op) final { Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
Expr tuple = this->GetTempExpr(op->tuple); Expr tuple = this->GetTempExpr(op->tuple, post.as<TupleGetItemNode>()->tuple);
if (const auto* ptuple = tuple.as<TupleNode>()) { if (const auto* ptuple = tuple.as<TupleNode>()) {
return ptuple->fields[op->index]; return ptuple->fields[op->index];
} else { } else {
...@@ -130,13 +119,14 @@ class ForwardRewriter : private ExprMutator { ...@@ -130,13 +119,14 @@ class ForwardRewriter : private ExprMutator {
} }
} }
Expr VisitExpr_(const TupleNode* op) final { Expr Rewrite_(const TupleNode* op, const Expr& post) final {
tvm::Array<Expr> fields; tvm::Array<Expr> fields;
bool all_fields_unchanged = true; bool all_fields_unchanged = true;
for (auto field : op->fields) { const auto* post_node = post.as<TupleNode>();
auto new_field = this->GetTempExpr(field); for (size_t i = 0; i < op->fields.size(); ++i) {
auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]);
fields.push_back(new_field); fields.push_back(new_field);
all_fields_unchanged &= new_field.same_as(field); all_fields_unchanged &= new_field.same_as(op->fields[i]);
} }
if (all_fields_unchanged) { if (all_fields_unchanged) {
...@@ -146,7 +136,7 @@ class ForwardRewriter : private ExprMutator { ...@@ -146,7 +136,7 @@ class ForwardRewriter : private ExprMutator {
} }
} }
Expr VisitExpr_(const CallNode* call_node) final { Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
const Call& ref_call = GetRef<Call>(call_node); const Call& ref_call = GetRef<Call>(call_node);
PackedFunc frewrite; PackedFunc frewrite;
if (rewrite_func_) { if (rewrite_func_) {
...@@ -155,17 +145,17 @@ class ForwardRewriter : private ExprMutator { ...@@ -155,17 +145,17 @@ class ForwardRewriter : private ExprMutator {
CHECK(rewrite_map_); CHECK(rewrite_map_);
frewrite = rewrite_map_->get(call_node->op, nullptr); frewrite = rewrite_map_->get(call_node->op, nullptr);
} }
const auto* post_node = post.as<CallNode>();
auto new_op = this->Mutate(call_node->op); auto new_op = post_node->op;
bool unchanged = call_node->op.same_as(new_op); bool unchanged = call_node->op.same_as(new_op);
Array<Expr> call_args; Array<Expr> call_args;
for (auto arg : call_node->args) { for (size_t i = 0; i < call_node->args.size(); ++i) {
Expr new_arg = this->GetTempExpr(arg); Expr new_arg = this->GetTempExpr(call_node->args[i], post_node->args[i]);
if (frewrite == nullptr) { if (frewrite == nullptr) {
new_arg = realizer_.Realize(new_arg); new_arg = realizer_.Realize(new_arg);
} }
unchanged &= new_arg.same_as(arg); unchanged &= new_arg.same_as(call_node->args[i]);
call_args.push_back(new_arg); call_args.push_back(new_arg);
} }
// try to rewrite. // try to rewrite.
......
...@@ -35,14 +35,6 @@ namespace tvm { ...@@ -35,14 +35,6 @@ namespace tvm {
namespace relay { namespace relay {
/*! /*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Object*, size_t>
GetExprRefCount(const Expr& body);
/*!
* \brief Check if expr is positive constant. * \brief Check if expr is positive constant.
* \param expr The expression to be checked. * \param expr The expression to be checked.
* \return Whether all elements of expr is positive constant. * \return Whether all elements of expr is positive constant.
......
...@@ -161,6 +161,23 @@ TEST(Relay, BuildModule) { ...@@ -161,6 +161,23 @@ TEST(Relay, BuildModule) {
} }
} }
TEST(Relay, GetExprRefCount) {
auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
auto a = relay::Var("a", tensor_type);
auto add_op = relay::Op::Get("add");
auto relu_op = relay::Op::Get("nn.relu");
auto x = relay::Call(relu_op, {a}, tvm::Attrs(), {});
auto y = relay::Call(relu_op, {x}, tvm::Attrs(), {});
auto z = relay::Call(add_op, {y, x}, tvm::Attrs(), {});
auto ref_count = GetExprRefCount(z);
CHECK(ref_count[a.get()] == 1);
CHECK(ref_count[relu_op.get()] == 2);
CHECK(ref_count[add_op.get()] == 1);
CHECK(ref_count[x.get()] == 2);
CHECK(ref_count[y.get()] == 1);
CHECK(ref_count[z.get()] == 1);
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment