Commit 5d70b008 by Wuwei Lin Committed by Yizhi Liu

[Relay] InferCorrectLayout for strided_slice & min_num_branches option in…

[Relay] InferCorrectLayout for strided_slice & min_num_branches option in CombineParallelConv2D (#2961)

* [Relay] InferCorrectLayout for strided_slice

* Add min_num_branches option to CombineParallelConv2D

* Return undef if original layout contains splitted axes
parent 552d4aa3
......@@ -722,20 +722,23 @@ def fuse_ops(expr, opt_level=1):
return _ir_pass.FuseOps(expr, opt_level)
def combine_parallel_conv2d(expr):
"""Fold multiple conv2d into one.
def combine_parallel_conv2d(expr, min_num_branches=3):
"""Combine multiple conv2d into one.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
min_num_branches : int
The minimum number of parallel branches when the transformation should be applied.
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression
"""
return _ir_pass.CombineParallelConv2D(expr)
return _ir_pass.CombineParallelConv2D(expr, min_num_branches)
def alter_op_layout(expr):
......
......@@ -1722,6 +1722,64 @@ bool StridedSliceRel(const Array<Type>& types,
}
Array<Array<Layout> > StridedSliceInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
CHECK(old_in_layouts.defined());
CHECK_EQ(old_in_layouts.size(), 1);
CHECK(old_in_shapes.defined());
CHECK_EQ(old_in_shapes.size(), 1);
auto layout = old_in_layouts[0];
if (layout.defined() && new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 1);
auto new_layout = new_in_layouts[0];
auto shape = old_in_shapes[0];
// NOTE: Discard "const" qualifier here.
auto *params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
Array<Integer> new_begin, new_end;
for (size_t i = 0; i < params->begin.size(); i++) {
const LayoutAxis& axis = layout[i];
if (!axis.IsPrimal()) {
// original layout that contains splitted axes is not supported
return {{Layout::Undef()}, {Layout::Undef()}};
}
auto factor = new_layout.FactorOf(axis);
if (factor == -1) {
new_begin.push_back(params->begin[i]);
new_end.push_back(params->end[i]);
} else {
if (params->strides.defined() && i < params->strides.size()) {
auto stride = params->strides[i];
// arbitrary stride is not supported
if (stride.defined() && stride->value != 1) {
return {{Layout::Undef()}, {Layout::Undef()}};
}
}
int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
int64_t end = params->end[i].defined() ? params->end[i]->value :
shape[i].as<IntImm>()->value;
if (begin % factor || end % factor) {
// transform to original layout
return {{Layout::Undef()}, {Layout::Undef()}};
}
new_begin.push_back(tvm::Integer(begin / factor));
new_end.push_back(tvm::Integer(end / factor));
}
}
layout = new_layout;
params->begin = new_begin;
params->end = new_end;
}
return {{layout}, {layout}};
}
// Positional relay function to create StridedSlice operator used by frontend FFI.
Expr MakeStridedSlice(Expr data,
Array<Integer> begin,
......@@ -1783,7 +1841,8 @@ Examples::
.set_attrs_type_key("relay.attrs.StridedSliceAttrs")
.add_type_rel("StridedSlice", StridedSliceRel)
.set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);
// relay.split
......
......@@ -159,10 +159,15 @@ class BranchGroupFinder : private ExprVisitor {
class ParallelConv2DCombiner {
public:
explicit ParallelConv2DCombiner(uint64_t min_num_branches) : min_num_branches_(min_num_branches) {
}
Expr Combine(const Expr& expr) {
auto groups = BranchGroupFinder().Find(expr);
for (const Group& group : groups) {
if (group.size() < 2) continue;
if (group.size() < min_num_branches_) {
continue;
}
CombineBranches(group);
}
return ExprSubst(expr, std::move(subst_map_));
......@@ -170,6 +175,7 @@ class ParallelConv2DCombiner {
private:
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
uint64_t min_num_branches_;
std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t num_filters = 0; // number of filters of the transformed weight
......@@ -343,11 +349,14 @@ class ParallelConv2DCombiner {
}
};
Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); }
/*! \brief Combine parallel conv2d if number of branches >= min_num_branches */
Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
return ParallelConv2DCombiner(min_num_branches).Combine(expr);
}
TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CombineParallelConv2D(args[0]);
*ret = CombineParallelConv2D(args[0], args[1]);
});
} // namespace relay
......
......@@ -472,6 +472,48 @@ def test_alter_layout_nchw_upsamping_op():
assert(alpha_equal(a, b))
def test_alter_layout_strided_slice():
"""Test rewriting strided_slice during alter_iop_layout"""
def before():
x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var('weight', shape=(32, 32, 3, 3))
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
y = relay.strided_slice(y, begin=[0, 16], end=[None, None])
y = relay.Function(free_vars(y), y)
return y
@register_alter_op_layout("nn.conv2d", level=109)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW4c'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var("weight")
x = relay.layout_transform(x, "NCHW", "NCHW4c")
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
data_layout="NCHW4c")
y = relay.strided_slice(y, begin=[0, 4], end=[None, 8])
y = relay.layout_transform(y, "NCHW4c", "NCHW")
y = relay.Function(free_vars(y), y)
return y
a = before()
a = infer_type(a)
a = canonicalize_ops(a)
a = infer_type(a)
a = alter_op_layout(a)
a = infer_type(a)
b = expected()
b = infer_type(b)
assert(alpha_equal(a, b))
if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
......@@ -482,3 +524,4 @@ if __name__ == "__main__":
test_alter_layout_scalar()
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
test_alter_layout_strided_slice()
......@@ -55,7 +55,7 @@ def test_combine_parallel_conv2d():
y_before = before(x, w1, w2, w3, w4)
y = relay.ir_pass.infer_type(y_before)
y = relay.ir_pass.combine_parallel_conv2d(y)
y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
y_expected = relay.ir_pass.infer_type(y_expected)
......@@ -102,7 +102,7 @@ def test_combine_parallel_conv2d_scale_relu():
bias = relay.var("bias", shape=(channels2, 1, 1))
y_before = before(x, w1, w2, scale1, scale2, bias)
y = relay.ir_pass.infer_type(y_before)
y = relay.ir_pass.combine_parallel_conv2d(y)
y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
y_expected = relay.ir_pass.infer_type(y_expected)
......@@ -142,7 +142,7 @@ def test_combine_parallel_conv2d_scale():
scale2 = relay.var("scale2", shape=(1,))
y_before = before(x, w1, w2, scale1, scale2)
y = relay.ir_pass.infer_type(y_before)
y = relay.ir_pass.combine_parallel_conv2d(y)
y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
y_expected = relay.ir_pass.infer_type(y_expected)
......@@ -179,7 +179,7 @@ def test_combine_parallel_conv2d_multiple_blocks():
w = relay.var("w", shape=(out_c, in_c, 1, 1))
y_before = before(x, w, repeat)
y = relay.ir_pass.infer_type(y_before)
y = relay.ir_pass.combine_parallel_conv2d(y)
y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w, out_c, repeat)
y_expected = relay.ir_pass.infer_type(y_expected)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment