Commit 7880b50c by Wuwei Lin Committed by Tianqi Chen

[Relay][Pass] Fix CombineParallelConv2D (#2167)

parent db74c997
......@@ -13,9 +13,9 @@ from .backend import graph_runtime_codegen as _graph_gen
# List of optimization pass and level when switch on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"CombineParallelConv2D": 4,
"OpFusion": 1,
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
}
......
......@@ -47,17 +47,22 @@ using Group = std::vector<Branch>;
class BranchGroupFinder : private ExprVisitor {
public:
std::vector<Group> Find(const Expr& expr) {
static const Op& conv2d = Op::Get("nn.conv2d");
this->VisitExpr(expr);
std::vector<Group> groups;
for (const auto& root : conv_roots_) {
const auto& convs = children_map_.at(root);
for (const CallNode* conv : convs) {
auto&& branch = CreateBranch(conv);
const auto& children = children_map_.at(root);
size_t ngroups = groups.size();
for (const CallNode* child : children) {
if (!child->op.same_as(conv2d)) continue;
auto&& branch = CreateBranch(child);
// add the branch to a group, or create a new group
auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) {
auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) {
CHECK(!group.empty() && !group[0].empty());
return IsCompatibleConv2D(conv, group[0][0]);
return IsCompatibleConv2D(child, group[0][0]);
});
if (it != groups.end()) {
it->push_back(branch);
......@@ -108,7 +113,7 @@ class BranchGroupFinder : private ExprVisitor {
const CallNode* call = it->second[0];
auto pattern = fpattern[Downcast<Op>(call->op)];
if (pattern <= kBroadcast) {
branch.push_back(it->second[0]);
branch.push_back(call);
it = children_map_.find(GetRef<Expr>(branch.back()));
} else {
break;
......
......@@ -11,7 +11,8 @@ def test_combine_parallel_conv2d():
# y3 cannot be combined
y3 = relay.nn.conv2d(x, w3)
y4 = relay.nn.conv2d(x, w4)
y = relay.Tuple((y1, y2, y3, y4))
y5 = relay.nn.max_pool2d(x)
y = relay.Tuple((y1, y2, y3, y4, y5))
return relay.Function(args, y)
def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
......@@ -24,7 +25,8 @@ def test_combine_parallel_conv2d():
y3 = relay.nn.conv2d(x, w3)
y4 = relay.strided_slice(y, [0, channels1 + channels2],
[None, channels1 + channels2 + channels4])
y = relay.Tuple((y1, y2, y3, y4))
y5 = relay.nn.max_pool2d(x)
y = relay.Tuple((y1, y2, y3, y4, y5))
return relay.Function(args, y)
def check(x_shape, channels1, channels2, channels3, channels4):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment