Commit ca352770 by 雾雨魔理沙 Committed by masahi

[Relay] Fix operator fusion for multiple output (#3871)

* save

* add test

* refactor

* fix indent

* save

* refactor
parent 57cd27f1
...@@ -304,14 +304,16 @@ class PrettyPrinter : ...@@ -304,14 +304,16 @@ class PrettyPrinter :
* \return The corresponding name. * \return The corresponding name.
*/ */
Doc AllocTypeVar(const TypeVar& var) { Doc AllocTypeVar(const TypeVar& var) {
if (memo_type_.count(var)) {
Doc val = memo_type_[var];
val << "-malformed-ir";
return val;
}
std::string name = var->var->name_hint; std::string name = var->var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) { if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name; name = "t" + name;
} }
Doc val = GetUniqueName("%" + name); Doc val = GetUniqueName("%" + name);
if (memo_type_.count(var)) {
val << "-malformed-ir";
}
memo_type_[var] = val; memo_type_[var] = val;
if (var->kind != kType) { if (var->kind != kType) {
val << ": " << Print(var->kind); val << ": " << Print(var->kind);
...@@ -325,16 +327,18 @@ class PrettyPrinter : ...@@ -325,16 +327,18 @@ class PrettyPrinter :
* \return The corresponding name. * \return The corresponding name.
*/ */
Doc AllocVar(const Var& var) { Doc AllocVar(const Var& var) {
// still print if ir is malformed, but show the error.
if (memo_.count(var)) {
Doc val = memo_[var];
val << "-malformed-ir";
return val;
}
std::string name = var->name_hint(); std::string name = var->name_hint();
// always make sure first name is alpha // always make sure first name is alpha
if (name.length() == 0 || !std::isalpha(name[0])) { if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name; name = "v" + name;
} }
Doc val = GetUniqueName("%" + name); Doc val = GetUniqueName("%" + name);
// still print if ir is malformed, but show the error.
if (memo_.count(var)) {
val << "-malformed-ir";
}
memo_[var] = val; memo_[var] = val;
if (var->type_annotation.defined()) { if (var->type_annotation.defined()) {
val << ": " << Print(var->type_annotation); val << ": " << Print(var->type_annotation);
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2019 by Contributors
* *
* \file src/tvm/relay/pass/fuse_ops.cc * \file src/tvm/relay/pass/fuse_ops.cc
* *
...@@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -247,11 +247,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
node->pattern = op_pattern; node->pattern = op_pattern;
this->Update(call->op, nullptr, kOpaque); this->Update(call->op, nullptr, kOpaque);
const auto* rtype = call->checked_type().as<TensorTypeNode>(); const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the message back to all the children it references. // pass the analysis back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) { for (size_t i = 0; i < call->args.size(); ++i) {
const auto* arg_type = const auto* arg_type =
call->args[i]->checked_type().as<TensorTypeNode>(); call->args[i]->checked_type().as<TensorTypeNode>();
// specifically check if result type // specifically check if result type is the same as arguments type
OpPatternKind edge_pattern = op_pattern; OpPatternKind edge_pattern = op_pattern;
if (edge_pattern == kBroadcast && if (edge_pattern == kBroadcast &&
arg_type != nullptr && arg_type != nullptr &&
...@@ -403,12 +403,12 @@ class DominatorTree { ...@@ -403,12 +403,12 @@ class DominatorTree {
return rhs; return rhs;
} }
/*! /*!
* \brief Find the least common acenstor of the two nodes. * \brief Find the least common ancestor of the two nodes.
* \param lhs The left node. * \param lhs The left node.
* \param rhs The right node. * \param rhs The right node.
* \param edge_pattern * \param edge_pattern
* The combined edge pattern across all the parents. * The combined edge pattern across all the parents.
* \return The least common ancestor of thw two. * \return The least common ancestor of the two.
*/ */
static Node* LeastCommonAncestor( static Node* LeastCommonAncestor(
Node* lhs, Node* lhs,
...@@ -436,17 +436,43 @@ class DominatorTree { ...@@ -436,17 +436,43 @@ class DominatorTree {
} }
return lhs; return lhs;
} }
}; /*!
* \brief Find the least common ancestor of a list of nodes.
DominatorTree DominatorTree::PostDom(common::Arena* arena, * \param nodes the nodes.
const IndexedForwardGraph& graph) { * \param edge_pattern
DominatorTree tree; * The combined edge pattern across all the parents.
tree.nodes.resize(graph.post_dfs_order.size(), nullptr); * \return The least common ancestor of all nodes.
// reverse topo order */
for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { Node* LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& input_nodes,
size_t index = i - 1; OpPatternKind* edge_pattern) {
auto link = input_nodes.head;
if (link == nullptr) {
return nullptr;
}
auto get_node = [&](const IndexedForwardGraph::Edge& edge) {
size_t oindex = edge.node->index;
CHECK_LT(oindex, nodes.size());
Node* onode = nodes[oindex];
CHECK(onode != nullptr);
return onode;
};
Node* parent = get_node(link->value);
*edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
link = link->next;
for (; link != nullptr; link = link->next) {
parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
*edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
}
return parent;
}
/*!
* \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node.
* \param arena The Arena.
* \param gnode An IndexedForwardGraph Node.
* \return The DominatorTree Node.
*/
Node* GetNode(common::Arena* arena, IndexedForwardGraph::Node* gnode) {
Node* tnode = arena->make<Node>(); Node* tnode = arena->make<Node>();
auto* gnode = graph.post_dfs_order[index];
tnode->gnode = gnode; tnode->gnode = gnode;
if (gnode->extern_ref) { if (gnode->extern_ref) {
tnode->depth = 1; tnode->depth = 1;
...@@ -455,24 +481,24 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena, ...@@ -455,24 +481,24 @@ DominatorTree DominatorTree::PostDom(common::Arena* arena,
} else { } else {
// find the LCAs of all outputs. // find the LCAs of all outputs.
OpPatternKind pattern = kElemWise; OpPatternKind pattern = kElemWise;
Node* parent = nullptr; Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
for (auto link = gnode->outputs.head; link != nullptr; link= link->next) {
size_t oindex = link->value.node->index;
CHECK_LT(oindex, tree.nodes.size());
Node* onode = tree.nodes[oindex];
CHECK(onode != nullptr);
if (parent != nullptr) {
parent = LeastCommonAncestor(parent, onode, &pattern);
} else {
parent = onode;
}
pattern = CombinePattern(pattern, link->value.pattern);
}
tnode->depth = parent ? parent->depth + 1 : 1; tnode->depth = parent ? parent->depth + 1 : 1;
tnode->parent = parent; tnode->parent = parent;
tnode->pattern = pattern; tnode->pattern = pattern;
} }
tree.nodes[index] = tnode; return tnode;
}
};
DominatorTree DominatorTree::PostDom(common::Arena* arena,
const IndexedForwardGraph& graph) {
DominatorTree tree;
tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
// reverse topo order
for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
size_t index = i - 1;
tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
} }
return tree; return tree;
} }
...@@ -614,7 +640,7 @@ class GraphPartitioner { ...@@ -614,7 +640,7 @@ class GraphPartitioner {
// merge the current group to the parent if possible. // merge the current group to the parent if possible.
MergeFromTo(gnode, target); MergeFromTo(gnode, target);
for (auto link = src->outputs.head; link != nullptr; link = link->next) { for (auto link = src->outputs.head; link != nullptr; link = link->next) {
CommitFuse_(link->value.node, sink, target);; CommitFuse_(link->value.node, sink, target);
} }
} }
/*! /*!
...@@ -851,7 +877,7 @@ class FuseMutator : private ExprMutator { ...@@ -851,7 +877,7 @@ class FuseMutator : private ExprMutator {
Expr VisitExpr_(const TupleNode* tuple) { Expr VisitExpr_(const TupleNode* tuple) {
auto* ret_group = gmap_.at(tuple)->FindRoot(); auto* ret_group = gmap_.at(tuple)->FindRoot();
if (ret_group == gmap_.at(tuple)) { if (ret_group->root_ref == tuple) {
return ExprMutator::VisitExpr_(tuple); return ExprMutator::VisitExpr_(tuple);
} }
// This tuple is an intermediate node in the group // This tuple is an intermediate node in the group
...@@ -863,7 +889,7 @@ class FuseMutator : private ExprMutator { ...@@ -863,7 +889,7 @@ class FuseMutator : private ExprMutator {
auto* ret_group = gmap_.at(tuple_get)->FindRoot(); auto* ret_group = gmap_.at(tuple_get)->FindRoot();
auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0]; auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index); auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
if (ret_group == gmap_.at(tuple_get)) { if (ret_group->root_ref == tuple_get) {
if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) { if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
// Isolated. This case occurs when tuple is created by an Opaque op // Isolated. This case occurs when tuple is created by an Opaque op
// e.g. multibox_transform_loc // e.g. multibox_transform_loc
...@@ -922,45 +948,8 @@ class FuseMutator : private ExprMutator { ...@@ -922,45 +948,8 @@ class FuseMutator : private ExprMutator {
} }
}; };
// Temporary solution, should be handled by implementing a "FunctionPass"
// which applies fusion to each function.
struct GlobalVarLiveness : ExprVisitor {
Module module;
std::set<GlobalVar> visited;
explicit GlobalVarLiveness(const Module& mod) : module(mod), visited() {}
void VisitExpr_(const GlobalVarNode* gvar_node) {
auto gvar = GetRef<GlobalVar>(gvar_node);
if (visited.find(gvar) == visited.end()) {
visited.insert(gvar);
this->VisitExpr(this->module->Lookup(gvar));
}
}
};
std::set<GlobalVar> LiveGlobals(const Module& mod, const Expr& expr) {
auto gvl = GlobalVarLiveness(mod);
gvl.VisitExpr(expr);
return gvl.visited;
}
Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) { Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
// First we convert all chains of fusable ops into return FuseMutator().Transform(expr, fuse_opt_level);
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
if (!module.defined()) {
return FuseMutator().Transform(expr, fuse_opt_level);
} else {
auto lgvs = LiveGlobals(module, expr);
for (auto lv : lgvs) {
auto body = module->Lookup(lv);
auto e = FuseMutator().Transform(body, fuse_opt_level);
module->Add(lv, Downcast<Function>(e), true);
}
return FuseMutator().Transform(expr, fuse_opt_level);
}
} }
namespace transform { namespace transform {
......
...@@ -541,6 +541,18 @@ def test_immutable(): ...@@ -541,6 +541,18 @@ def test_immutable():
assert relay.analysis.alpha_equal(new_mod, expected()) assert relay.analysis.alpha_equal(new_mod, expected())
def test_split():
"""Test that the result is well formed."""
x = relay.var("x", shape=(6, 9))
y = relay.split(x, 3).astuple()
a = relay.TupleGetItem(y, 0)
b = relay.TupleGetItem(y, 1)
c = relay.TupleGetItem(y, 2)
mod = relay.module.Module()
mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
mod = transform.FuseOps()(mod)
if __name__ == "__main__": if __name__ == "__main__":
test_fuse_simple() test_fuse_simple()
test_conv2d_fuse() test_conv2d_fuse()
...@@ -555,3 +567,4 @@ if __name__ == "__main__": ...@@ -555,3 +567,4 @@ if __name__ == "__main__":
test_inception_like() test_inception_like()
test_fuse_parallel_injective() test_fuse_parallel_injective()
test_immutable() test_immutable()
test_split()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment