Commit 53ac89ed by Wuwei Lin Committed by Tianqi Chen

[RELAY][PASS] CombineParallelConv2D (#2089)

parent ff5dffa4
......@@ -13,6 +13,7 @@ from .backend import graph_runtime_codegen as _graph_gen
# List of optimization pass and level when switch on
"SimplifyInference": 0,
"CombineParallelConv2D": 1,
"OpFusion": 1,
"FoldConstant": 2,
"FoldScaleAxis": 3,
......@@ -144,6 +145,10 @@ def optimize(func, params=None):
func = ir_pass.infer_type(func)
func = ir_pass.simplify_inference(func)
if cfg.pass_enabled("CombineParallelConv2D"):
func = ir_pass.infer_type(func)
func = ir_pass.combine_parallel_conv2d(func)
if cfg.pass_enabled("FoldScaleAxis"):
func = ir_pass.infer_type(func)
func = ir_pass.backward_fold_scale_axis(func)
......@@ -292,3 +292,19 @@ def fuse_ops(expr, opt_level=1):
Transformed expression, containing fused result.
return _ir_pass.FuseOps(expr, opt_level)
def combine_parallel_conv2d(expr):
    """Fold multiple conv2d ops that share the same input into one.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression
    """
    return _ir_pass.CombineParallelConv2D(expr)
* Copyright (c) 2018 by Contributors
* \file
* \brief Combine parallel 2d convolutions into a single convolution.
* This pass replaces convolutions that share the same input node and the same
* arguments (except that the number of output channels can be different) with a
* single convolution. The weight of the new 2d convolution is the concatenation
* of the original weights. Elemwise and broadcast ops following conv2d are also
* combined if possible.
* This prevents launching multiple kernels in networks with multiple
* convolution branches, such as Inception block.
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
namespace tvm {
namespace relay {
using Branch = std::vector<const CallNode*>;
using Group = std::vector<Branch>;
Find parallel branches starting with conv2d as shown below and then group branches by kernel
shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops.
Intermediate nodes have exactly one successor. It is possible that branches meet at a point,
which should be handled in ParallelConv2DCombiner.
/ \
conv2d conv2d
| |
op op
| |
class BranchGroupFinder : private ExprVisitor {
std::vector<Group> Find(const Expr& expr) {
std::vector<Group> groups;
for (const auto& root : conv_roots_) {
const auto& convs =;
for (const CallNode* conv : convs) {
auto&& branch = CreateBranch(conv);
// add the branch to a group, or create a new group
auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) {
CHECK(!group.empty() && !group[0].empty());
return IsCompatibleConv2D(conv, group[0][0]);
if (it != groups.end()) {
} else {
// each group has at least one branch
return groups;
std::unordered_set<Expr, NodeHash, NodeEqual> conv_roots_;
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> children_map_;
// Decide whether two conv2d calls may be merged into a single convolution.
// Two 2d convolutions can be combined if they have the same attributes or
// only have different output channels.
// a, b: the conv2d call nodes to compare. Their weight arguments (args[1]) are
//       read through type_as<TensorTypeNode>() -- presumably this requires that
//       type inference has already run; confirm against the caller.
// Returns true when strides, padding, dilation, groups, the three layouts,
// out_dtype and the kernel height/width all agree.
// NOTE(review): `a-><Conv2DAttrs>()` below looks truncated by text extraction
// (presumably `a->attrs.as<Conv2DAttrs>()`) -- confirm against upstream.
bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
static const Layout kOIHW("OIHW");
const auto* attrs_a = a-><Conv2DAttrs>();
const auto* attrs_b = b-><Conv2DAttrs>();
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
// Bring both weight shapes into OIHW order so that indices 2 and 3 below
// address the kernel height and width whatever weight_layout is in use.
const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW);
const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW);
// Index 0 (output channels) is intentionally not compared: it may differ.
return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
eq(attrs_a->data_layout, attrs_b->data_layout) &&
eq(attrs_a->weight_layout, attrs_b->weight_layout) &&
eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) &&
eq(shape_a[3], shape_b[3]);
// Create a branch starting from conv2d.
Branch CreateBranch(const CallNode* conv) {
static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
// each branch has at least one element, the first element is always conv2d
Branch branch{conv};
auto it = children_map_.find(GetRef<Expr>(branch.back()));
while (it != children_map_.end() && it->second.size() == 1) {
const CallNode* call = it->second[0];
auto pattern = fpattern[Downcast<Op>(call->op)];
if (pattern <= kBroadcast) {
it = children_map_.find(GetRef<Expr>(branch.back()));
} else {
return branch;
void VisitExpr_(const CallNode* n) final {
static const Op& conv2d = Op::Get("nn.conv2d");
if (n->op.same_as(conv2d) && n-><Conv2DAttrs>()->groups == 1) {
} else {
for (size_t i = 0; i < n->args.size(); i++) {
class ParallelConv2DCombiner {
Expr Combine(const Expr& expr) {
auto groups = BranchGroupFinder().Find(expr);
for (const Group& group : groups) {
if (group.size() < 2) continue;
return ExprSubst(expr, std::move(subst_map_));
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t num_filters = 0; // number of filters of the transformed weight
Array<Expr> weights;
for (const auto& branch : branches) {
auto conv2d = branch[0];
auto channels = GetConv2DSuperChannelsDim(conv2d);
num_filters += channels;
auto index = branches[0][0]-><Conv2DAttrs>()->weight_layout.find('O');
CHECK_NE(index, std::string::npos);
return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
MakeConstScalar(Int(32), num_filters));
// Build the single conv2d call that replaces the leading conv2d of every branch.
// The data input and every attribute except `channels` are copied from the first
// branch's conv2d; the weight and `channels` come from TransformWeight().
// branches: the group to combine; branches[0][0] is used as the template, so the
//           group must be non-empty.
// Returns the new nn.conv2d call (no type arguments are attached).
// NOTE(review): `group_root-><Conv2DAttrs>()` looks truncated by text extraction
// (presumably `group_root->attrs.as<Conv2DAttrs>()`) -- confirm against upstream.
Call MakeCombinedConv2D(const Group& branches) {
static const Op& conv2d = Op::Get("nn.conv2d");
// Input of the first conv2d; presumably shared by every branch of the group.
Expr data = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_channels;
std::tie(new_weight, new_channels) = TransformWeight(branches);
const CallNode* group_root = branches[0][0];
const auto* attrs = group_root-><Conv2DAttrs>();
const auto new_attrs = make_node<Conv2DAttrs>();
new_attrs->strides = attrs->strides;
new_attrs->padding = attrs->padding;
new_attrs->dilation = attrs->dilation;
new_attrs->groups = attrs->groups;
new_attrs->kernel_size = attrs->kernel_size;
new_attrs->data_layout = attrs->data_layout;
new_attrs->weight_layout = attrs->weight_layout;
new_attrs->out_layout = attrs->out_layout;
new_attrs->out_dtype = attrs->out_dtype;
// The only attribute that differs from the template: total output channels.
new_attrs->channels = new_channels;
return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
// Check whether argument `index` of calls `a` and `b` can be concatenated along
// the channel dimension so that the two calls can be merged.
// a, b:        the two calls being compared (same level of two branches).
// index:       position of the argument to compare in each call's args.
// channel_pos: position of the channel dimension in the calls' output.
// Returns true when both arguments have the same dtype and rank, each carries
// the full (non-broadcast) channel dimension of its own call's output, and all
// remaining dimensions are equal.
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) {
AttrsEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
auto toutput_a = a->type_as<TensorTypeNode>();
auto toutput_b = b->type_as<TensorTypeNode>();
if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size())
return false;
// Position of the 'C' dimension in the argument: the same distance from the
// last dimension as in the output. Only a's ranks are needed because ta and tb
// were just checked to have equal rank.
size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size();
// Channel super-dimension should be present and not broadcasted
if ((arg_channel_pos > channel_pos) || // size_t overflow
!eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) ||
!eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos]))
return false;
// Every dimension other than the channel one must match exactly.
for (size_t i = 0; i < ta->shape.size(); i++) {
if (i == arg_channel_pos) continue;
if (!eq(ta->shape[i], tb->shape[i]))
return false;
return true;
// Check if ops in depth-th level can be combined
bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) {
const CallNode* call = branches[0][depth];
AttrsEqual attrs_equal;
// check if all branches in current depth can be combined
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
const Branch& branch = *it;
if (!branch[depth]->op.same_as(call->op) ||
!attrs_equal(branch[depth]->attrs, call->attrs) ||
branch[depth]->args.size() != call->args.size()) {
return false;
if (branch[depth]->args[parent_index].get() != branch[depth - 1])
return false;
// Check args
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) continue;
if (!IsArgCompatible(call, branch[depth], i, channel_pos) ||
!attrs_equal(call->attrs, branch[depth]->attrs)) {
return false;
return true;
// Combine args and make the combined CallNode
Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos,
size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
size_t ndim = call->type_as<TensorTypeNode>()->shape.size();
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) {
size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
size_t arg_channel_pos = channel_pos - ndim + arg_ndim;
Array<Expr> tuple;
for (const auto& branch : branches) {
auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos);
return CallNode::make(call->op, new_args, call->attrs, {});
// Replace output of each branch with slices of the combined output
void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
size_t channel_pos) {
int64_t index = 0;
for (const auto& branch : branches) {
const CallNode* conv2d = branch[0];
int64_t channels = GetConv2DSuperChannelsDim(conv2d);
Array<Integer> begin;
Array<Integer> end;
for (size_t i = 0; i < channel_pos; i++) {
index += channels;
auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{});
subst_map_[GetRef<Expr>(branch[depth])] = slice;
// Combine branches in a group. Conv2d in different branches in the same group are safe to
// combine. Subsequent ops may or may not be combined. We start from conv2d and try to
// combine ops from all branches in the same depth.
// branches: the group to combine; the conv2d of each branch is always merged, and
//           following levels are merged until one fails CheckLevel or the
//           shortest branch ends. The outputs are then rewired by UpdateGroupOutput.
// NOTE(review): `combined-><Conv2DAttrs>()` looks truncated by text extraction
// (presumably `combined->attrs.as<Conv2DAttrs>()`) -- confirm against upstream.
void CombineBranches(const Group& branches) {
Call combined = MakeCombinedConv2D(branches);
auto conv_param = combined-><Conv2DAttrs>();
// An empty out_layout means the output uses data_layout.
const std::string& layout =
conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout;
size_t channel_pos = layout.find('C');
CHECK_NE(channel_pos, std::string::npos);
// Only levels present in every branch can be merged, so the shortest branch
// bounds the search.
auto it = std::min_element(branches.begin(), branches.end(),
[](const Branch& branch_a,
const Branch& branch_b) {
return branch_a.size() < branch_b.size();
size_t depth = it->size();
// Declared outside the loop: its final value marks the first level that was
// NOT combined and is used after the loop.
size_t i;
// starting from 1 to skip the conv2d
for (i = 1; i < depth; i++) {
// Find which argument of the op at level i is the previous op of the branch.
size_t parent_index;
for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) {
if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break;
CHECK_NE(parent_index, branches[0][i]->args.size());
if (!CheckLevel(branches, i, channel_pos, parent_index)) break;
combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index);
// i - 1 is the deepest level that was merged in every branch.
UpdateGroupOutput(combined, branches, i - 1, channel_pos);
// Pass entry point: runs a fresh ParallelConv2DCombiner over `expr` and returns
// the expression it produces.
Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); }
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CombineParallelConv2D(args[0]);
} // namespace relay
} // namespace tvm
* Copyright (c) 2018 by Contributors
* \file expr_subst.cc
* \brief Utility functions for substituting expressions.
#include <tvm/relay/expr_functor.h>
#include "./expr_subst.h"
namespace tvm {
namespace relay {
/*!
 * \brief Mutator that swaps every expression found in a substitution map for
 *  its mapped replacement and falls back to ExprMutator for everything else.
 *
 *  A replacement is returned as-is; it is not visited again.
 */
class ExprSubstituter : public ExprMutator {
 public:
  /*!
   * \brief Construct the substituter.
   * \param subst_map Maps each expression to be replaced to its replacement.
   */
  explicit ExprSubstituter(std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map)
      : subst_map_(std::move(subst_map)) {}  // take ownership instead of copying

  Expr VisitExpr(const Expr& expr) final {
    const auto it = subst_map_.find(expr);
    if (it != subst_map_.end()) {
      return it->second;
    }
    return ExprMutator::VisitExpr(expr);
  }

 private:
  /*! \brief Pending substitutions, keyed with the same hash/equality as the caller's map. */
  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
};
// Rewrite `expr` by running an ExprSubstituter built from `subst_map` over it.
// expr:      the expression to rewrite.
// subst_map: maps each expression to be replaced to its replacement; taken by
//            value and moved into the substituter.
// Returns the result of ExprSubstituter::Mutate on `expr`.
Expr ExprSubst(const Expr& expr, std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map) {
return ExprSubstituter(std::move(subst_map)).Mutate(expr);
} // namespace relay
} // namespace tvm
* Copyright (c) 2018 by Contributors
* \file expr_subst.h
* \brief Utility functions for substituting expressions.
#include <tvm/relay/expr.h>
#include <unordered_map>
namespace tvm {
namespace relay {
/*!
 * \brief Substitute sub-expressions of an expression.
 * \param expr The expression to rewrite.
 * \param subst_map Maps each expression to be replaced to its replacement.
 * \return The rewritten expression.
 */
Expr ExprSubst(const Expr& expr, std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map);
} // namespace relay
} // namespace tvm
......@@ -11,6 +11,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
#include <string>
#include "../op/layout.h"
......@@ -120,6 +121,19 @@ inline bool IsDepthwiseConv2D(const Call& call,
is_const_int(wshape[1], 1);
* \brief Get super-dimension of output channels of conv2d
* \param call The conv2d call.
* \return Super-dimension size of output channels of conv2d.
inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
auto param = call-><Conv2DAttrs>();
auto tweight = call->args[1]->type_as<TensorTypeNode>();
auto index = param->weight_layout.find('O');
CHECK_NE(index, std::string::npos);
auto channels = as_const_int(tweight->shape[index]);
return *channels;
* \brief Create a Constant with a scalar
......@@ -172,6 +186,10 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
} // namespace relay
} // namespace tvm
from tvm import relay
import numpy as np
def test_combine_parallel_conv2d():
    """Simple testcase.

    Four conv2d ops read the same input. Three use 1x1 kernels and are expected
    to be merged into one conv2d followed by strided slices; the one with a 3x3
    kernel is expected to stay as it is.
    """
    def before(x, w1, w2, w3, w4):
        args = [x, w1, w2, w3, w4]
        y1 = relay.nn.conv2d(x, w1)
        y2 = relay.nn.conv2d(x, w2)
        # y3 cannot be combined
        y3 = relay.nn.conv2d(x, w3)
        y4 = relay.nn.conv2d(x, w4)
        y = relay.Tuple((y1, y2, y3, y4))
        return relay.Function(args, y)

    def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
        # use a fixed order of args so alpha equal check can pass
        args = [x, w1, w2, w3, w4]
        w = relay.concatenate((w1, w2, w4), axis=0)
        y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4)
        y1 = relay.strided_slice(y, [0, 0], [None, channels1])
        y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
        y3 = relay.nn.conv2d(x, w3)
        y4 = relay.strided_slice(y, [0, channels1 + channels2],
                                 [None, channels1 + channels2 + channels4])
        y = relay.Tuple((y1, y2, y3, y4))
        return relay.Function(args, y)

    def check(x_shape, channels1, channels2, channels3, channels4):
        x = relay.var("x", shape=x_shape)
        in_c = x_shape[1]
        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
        # 3x3 kernel: differs from the 1x1 kernels of the other three weights
        w3 = relay.var("w3", shape=(channels3, in_c, 3, 3))
        w4 = relay.var("w4", shape=(channels4, in_c, 1, 1))

        y_before = before(x, w1, w2, w3, w4)
        y = relay.ir_pass.infer_type(y_before)
        y = relay.ir_pass.combine_parallel_conv2d(y)
        y = relay.ir_pass.infer_type(y)
        y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
        y_expected = relay.ir_pass.infer_type(y_expected)
        assert relay.ir_pass.alpha_equal(y, y_expected)

    check((1, 4, 16, 16), 4, 4, 4, 4)
    check((1, 4, 16, 16), 4, 8, 4, 7)
def test_combine_parallel_conv2d_scale_relu():
    """Testcase of combining conv2d + scale + relu

    Both branches run conv2d -> multiply -> relu; only the second one adds a
    bias afterwards. The expected graph merges the three shared levels and
    applies the bias add to the slice that belongs to the second branch.
    """
    def before(x, w1, w2, scale1, scale2, bias):
        args = [x, w1, w2, scale1, scale2, bias]
        y1 = relay.nn.conv2d(x, w1)
        y1 = relay.multiply(y1, scale1)
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(x, w2)
        y2 = relay.multiply(y2, scale2)
        y2 = relay.nn.relu(y2)
        y2 = relay.add(y2, bias)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def expected(x, w1, w2, scale1, scale2, bias, channels1, channels2):
        args = [x, w1, w2, scale1, scale2, bias]
        w = relay.concatenate((w1, w2), axis=0)
        scale = relay.concatenate((scale1, scale2), axis=0)
        y = relay.nn.conv2d(x, w, channels=channels1 + channels2)
        y = relay.multiply(y, scale)
        y = relay.nn.relu(y)
        y1 = relay.strided_slice(y, [0, 0], [None, channels1])
        y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
        y2 = relay.add(y2, bias)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def check(x_shape, channels1, channels2):
        x = relay.var("x", shape=x_shape)
        in_c = x_shape[1]
        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
        # per-channel scales: the leading dimension equals the channel count
        scale1 = relay.var("scale1", shape=(channels1, 1, 1))
        scale2 = relay.var("scale2", shape=(channels2, 1, 1))
        bias = relay.var("bias", shape=(channels2, 1, 1))
        y_before = before(x, w1, w2, scale1, scale2, bias)
        y = relay.ir_pass.infer_type(y_before)
        y = relay.ir_pass.combine_parallel_conv2d(y)
        y = relay.ir_pass.infer_type(y)
        y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
        y_expected = relay.ir_pass.infer_type(y_expected)
        assert relay.ir_pass.alpha_equal(y, y_expected)

    check((1, 4, 16, 16), 4, 8)
def test_combine_parallel_conv2d_scale():
    """Testcase of un-combinable scale

    The scales have shape (1,), i.e. no channel dimension, so the expected graph
    merges only the conv2d ops and keeps one multiply per branch after slicing.
    """
    def before(x, w1, w2, scale1, scale2):
        args = [x, w1, w2, scale1, scale2]
        y1 = relay.nn.conv2d(x, w1)
        y1 = relay.multiply(y1, scale1)
        y2 = relay.nn.conv2d(x, w2)
        y2 = relay.multiply(y2, scale2)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def expected(x, w1, w2, scale1, scale2, channels1, channels2):
        args = [x, w1, w2, scale1, scale2]
        w = relay.concatenate((w1, w2), axis=0)
        y = relay.nn.conv2d(x, w, channels=channels1 + channels2)
        y1 = relay.strided_slice(y, [0, 0], [None, channels1])
        y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
        y1 = relay.multiply(y1, scale1)
        y2 = relay.multiply(y2, scale2)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def check(x_shape, channels1, channels2):
        x = relay.var("x", shape=x_shape)
        in_c = x_shape[1]
        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
        scale1 = relay.var("scale1", shape=(1,))
        scale2 = relay.var("scale2", shape=(1,))
        y_before = before(x, w1, w2, scale1, scale2)
        y = relay.ir_pass.infer_type(y_before)
        y = relay.ir_pass.combine_parallel_conv2d(y)
        y = relay.ir_pass.infer_type(y)
        y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
        y_expected = relay.ir_pass.infer_type(y_expected)
        assert relay.ir_pass.alpha_equal(y, y_expected)

    check((1, 4, 16, 16), 4, 8)
if __name__ == "__main__":
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment