Unverified Commit 7d670b04 by Animesh Jain Committed by GitHub

Legalize - Use Non-recursive Rewriter. (#5296)

* Legalize - Use Non-recursive Rewriter.

* Cleanup.
parent 2b968204
...@@ -330,7 +330,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator { ...@@ -330,7 +330,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
* *
* ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order. * ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
* *
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will * The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original * non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
* node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the * node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex * ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
...@@ -408,7 +408,7 @@ class ExprRewriter { ...@@ -408,7 +408,7 @@ class ExprRewriter {
/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes /*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
* *
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the * PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call, * ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
* PostOrderRewrite provides the original node and the node with altered inputs for use by the * PostOrderRewrite provides the original node and the node with altered inputs for use by the
* ExprRewriter. * ExprRewriter.
......
...@@ -35,19 +35,18 @@ namespace legalize { ...@@ -35,19 +35,18 @@ namespace legalize {
// Call registered FTVMLegalize of an op // Call registered FTVMLegalize of an op
// Returns the legalized expression // Returns the legalized expression
class Legalizer : public ExprMutator { class Legalizer : public ExprRewriter {
public: public:
explicit Legalizer(const std::string& legalize_map_attr_name) explicit Legalizer(const std::string& legalize_map_attr_name)
: legalize_map_attr_name_{legalize_map_attr_name} {} : legalize_map_attr_name_{legalize_map_attr_name} {}
Expr VisitExpr_(const CallNode* call_node) { Expr Rewrite_(const CallNode* call_node, const Expr& post) override {
// Get the new_call node without any changes to current call node. // Get the new_call node without any changes to current call node.
Expr new_e = ExprMutator::VisitExpr_(call_node); Call new_call = Downcast<Call>(post);
Call new_call = Downcast<Call>(new_e);
// Check if the string is registered in the OpRegistry. // Check if the string is registered in the OpRegistry.
if (!Op::HasAttr(legalize_map_attr_name_)) { if (!Op::HasAttr(legalize_map_attr_name_)) {
return new_e; return post;
} }
// Collect the registered legalize function. // Collect the registered legalize function.
...@@ -70,19 +69,18 @@ class Legalizer : public ExprMutator { ...@@ -70,19 +69,18 @@ class Legalizer : public ExprMutator {
// Transform the op by calling the registered legalize function. // Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types); Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
// Reassign new_e if the transformation succeeded. // Return the new expr if the transformation succeeded.
if (legalized_value.defined()) { if (legalized_value.defined()) {
// Check that the returned Expr from legalize is CallNode. // Check that the returned Expr from legalize is CallNode.
const CallNode* legalized_call_node = legalized_value.as<CallNode>(); const CallNode* legalized_call_node = legalized_value.as<CallNode>();
CHECK(legalized_call_node) CHECK(legalized_call_node)
<< "Can only replace the original operator with another call node"; << "Can only replace the original operator with another call node";
return legalized_value;
new_e = legalized_value;
} }
} }
} }
return new_e; return post;
} }
private: private:
...@@ -90,7 +88,8 @@ class Legalizer : public ExprMutator { ...@@ -90,7 +88,8 @@ class Legalizer : public ExprMutator {
}; };
Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) { Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
return Legalizer(legalize_map_attr_name).Mutate(expr); auto rewriter = Legalizer(legalize_map_attr_name);
return PostOrderRewrite(expr, &rewriter);
} }
} // namespace legalize } // namespace legalize
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment