Commit 7880b50c by Wuwei Lin, committed by Tianqi Chen

[Relay][Pass] Fix CombineParallelConv2D (#2167)

parent db74c997
python/tvm/relay/build_module.py
@@ -13,9 +13,9 @@ from .backend import graph_runtime_codegen as _graph_gen
 # List of optimization pass and level when switch on
 OPT_PASS_LEVEL = {
     "SimplifyInference": 0,
+    "CombineParallelConv2D": 4,
     "OpFusion": 1,
     "FoldConstant": 2,
-    "CombineParallelConv2D": 3,
     "FoldScaleAxis": 3,
 }
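For context, OPT_PASS_LEVEL gates each optimization: by default a pass runs only when the opt_level configured through relay.build_config is at least the pass's entry in this table. The sketch below shows how CombineParallelConv2D would be enabled explicitly under the pre-pass-manager Relay API of this era; the toy network, shapes, and variable names are illustrative, not taken from the commit.

    from tvm import relay

    # Two parallel 1x1 convolutions reading the same input: the shape of
    # graph that CombineParallelConv2D is meant to merge.
    x = relay.var("x", shape=(1, 16, 32, 32))
    w1 = relay.var("w1", shape=(8, 16, 1, 1))
    w2 = relay.var("w2", shape=(8, 16, 1, 1))
    out = relay.Tuple([relay.nn.conv2d(x, w1), relay.nn.conv2d(x, w2)])
    func = relay.Function(relay.ir_pass.free_vars(out), out)

    # A pass runs only when the configured opt_level is at least its
    # OPT_PASS_LEVEL entry, so opt_level=4 enables every pass listed above,
    # CombineParallelConv2D included.
    with relay.build_config(opt_level=4):
        graph, lib, params = relay.build(func, target="llvm")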
src/relay/pass/combine_parallel_conv2d.cc
@@ -47,17 +47,22 @@ using Group = std::vector<Branch>;
 class BranchGroupFinder : private ExprVisitor {
  public:
   std::vector<Group> Find(const Expr& expr) {
+    static const Op& conv2d = Op::Get("nn.conv2d");
+
     this->VisitExpr(expr);
 
     std::vector<Group> groups;
     for (const auto& root : conv_roots_) {
-      const auto& convs = children_map_.at(root);
-      for (const CallNode* conv : convs) {
-        auto&& branch = CreateBranch(conv);
+      const auto& children = children_map_.at(root);
+      size_t ngroups = groups.size();
+      for (const CallNode* child : children) {
+        if (!child->op.same_as(conv2d)) continue;
+
+        auto&& branch = CreateBranch(child);
         // add the branch to a group, or create a new group
-        auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) {
+        auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) {
           CHECK(!group.empty() && !group[0].empty());
-          return IsCompatibleConv2D(conv, group[0][0]);
+          return IsCompatibleConv2D(child, group[0][0]);
         });
         if (it != groups.end()) {
           it->push_back(branch);
@@ -108,7 +113,7 @@ class BranchGroupFinder : private ExprVisitor {
       const CallNode* call = it->second[0];
       auto pattern = fpattern[Downcast<Op>(call->op)];
       if (pattern <= kBroadcast) {
-        branch.push_back(it->second[0]);
+        branch.push_back(call);
         it = children_map_.find(GetRef<Expr>(branch.back()));
       } else {
         break;
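The substance of the fix is in BranchGroupFinder::Find: children of a shared input are now filtered to nn.conv2d calls before a branch is created, and ngroups restricts group matching to branches discovered for the current root, so branches hanging off different inputs are never merged. The new test case below adds a max_pool2d that consumes the same input as the convolutions; a standalone sketch of that graph shape, written against the old relay.ir_pass API with illustrative shapes, looks like this:

    from tvm import relay
    from tvm.relay import ir_pass

    x = relay.var("x", shape=(1, 16, 32, 32))
    w1 = relay.var("w1", shape=(8, 16, 3, 3))
    w2 = relay.var("w2", shape=(8, 16, 3, 3))

    y1 = relay.nn.conv2d(x, w1, padding=(1, 1))
    y2 = relay.nn.conv2d(x, w2, padding=(1, 1))
    y3 = relay.nn.max_pool2d(x)   # non-conv2d consumer of the same input

    out = relay.Tuple([y1, y2, y3])
    func = relay.Function(ir_pass.free_vars(out), out)
    func = ir_pass.infer_type(func)
    # y1 and y2 should be merged into a single conv2d; y3 must be skipped
    # rather than treated as a convolution branch.
    func = ir_pass.combine_parallel_conv2d(func)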
tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -11,7 +11,8 @@ def test_combine_parallel_conv2d():
         # y3 cannot be combined
         y3 = relay.nn.conv2d(x, w3)
         y4 = relay.nn.conv2d(x, w4)
-        y = relay.Tuple((y1, y2, y3, y4))
+        y5 = relay.nn.max_pool2d(x)
+        y = relay.Tuple((y1, y2, y3, y4, y5))
         return relay.Function(args, y)
 
     def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
@@ -24,7 +25,8 @@ def test_combine_parallel_conv2d():
         y3 = relay.nn.conv2d(x, w3)
         y4 = relay.strided_slice(y, [0, channels1 + channels2],
                                  [None, channels1 + channels2 + channels4])
-        y = relay.Tuple((y1, y2, y3, y4))
+        y5 = relay.nn.max_pool2d(x)
+        y = relay.Tuple((y1, y2, y3, y4, y5))
         return relay.Function(args, y)
 
     def check(x_shape, channels1, channels2, channels3, channels4):
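For reference, the expected function in this test encodes the combined form the pass should produce: weights of compatible branches are concatenated along the output-channel axis, a single conv2d is emitted, and each original branch output is recovered with a strided_slice over the channel dimension. A hedged, self-contained sketch of that shape for two combined branches (illustrative shapes; NCHW data and OIHW weights assumed, mirroring the structure of expected() rather than copying it):

    from tvm import relay

    channels1, channels2 = 8, 8
    x = relay.var("x", shape=(1, 16, 32, 32))
    w1 = relay.var("w1", shape=(channels1, 16, 1, 1))
    w2 = relay.var("w2", shape=(channels2, 16, 1, 1))

    # Stack the output channels of both branches and run one convolution.
    w = relay.concatenate((w1, w2), axis=0)
    y = relay.nn.conv2d(x, w, channels=channels1 + channels2, kernel_size=(1, 1))

    # Slice the fused result back into the two original branch outputs.
    y1 = relay.strided_slice(y, [0, 0], [None, channels1])
    y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])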