Commit d69c6fd8 by Animesh Jain Committed by Zhi

[Relay][AlterOp] NHWC to NCHWc support for Pool, pad, concatenate, sum. (#4059)

parent aa424139
......@@ -748,6 +748,8 @@ class OperatorConverter(object):
elif padding == Padding.SAME:
pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
if do_pad:
in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
(pad_top, pad_bottom),
(pad_left, pad_right),
......
......@@ -47,15 +47,9 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
T *params = const_cast<T*>(attrs.as<T>());
if (new_in_layouts.defined()) {
// Set the pool with the new layout.
CHECK_EQ(new_in_layouts.size(), 1);
Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout
}
params->layout = new_in_layouts[0].name();
}
Layout inferred_layout(params->layout);
......
......@@ -119,6 +119,59 @@ Array<Integer> GetExcludeAxes(size_t indim,
return r_axes;
}
// Return the modified layout for AlterOpLayout pass.
Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
// NOTE: Discard "const" qualifier here.
ReduceAttrs* params = const_cast<ReduceAttrs*>(attrs.as<ReduceAttrs>());
// Get the reduce axes.
uint32_t indim = old_in_shapes[0].size();
auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
Layout ret = Layout::Undef();
if (new_in_layouts.defined() && r_axes.size()) {
// Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
// modified layout axes.
CHECK_EQ(new_in_layouts.size(), 1);
CHECK_EQ(old_in_layouts.size(), 1);
// 1) Collect the original axes
std::unordered_set<std::string> old_r_dims;
for (auto r_axis : r_axes) {
old_r_dims.emplace(old_in_layouts[0][r_axis].name());
}
// 2) Collect the new axes by walking new_layout.
tvm::Array<tvm::Integer> new_r_axes;
std::string new_layout_string = "";
int axis_index = 0;
for (auto iter_var : new_in_layouts[0]->axes) {
const auto& layout_axis = LayoutAxis::Get(iter_var);
const std::string& layout_dim = layout_axis.name();
if (old_r_dims.count(layout_dim)) {
new_r_axes.push_back(tvm::Integer(axis_index));
}
// Collect only the primal axis.
if (layout_axis.IsPrimal()) {
new_layout_string += layout_dim;
axis_index++;
}
}
// 3) Set the new axis and layout.
ret = Layout(new_layout_string);
params->axis = new_r_axes;
} else if (old_in_layouts.defined()) {
// If the new layout is undefined, set the old layout as the inferred layout.
CHECK_EQ(old_in_layouts.size(), 1);
ret = old_in_layouts[0];
}
return Array<Array<Layout>>{{ret}, {ret}};
}
template<typename F>
Array<Tensor> ReduceCompute(const Attrs& attrs,
......@@ -325,6 +378,7 @@ Example::
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
.set_attr<FTVMCompute>("FTVMCompute", SumCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
......
......@@ -283,22 +283,34 @@ Array<Array<Layout>> ConcatenateLayout(
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();
ConcatenateAttrs* param = const_cast<ConcatenateAttrs*>(attrs.as<ConcatenateAttrs>());
size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
static_cast<size_t>(param->axis);
Layout ret;
bool is_new_layout_selected = false;
if (new_in_layouts.defined()) { // this function is called after some operators are alternated.
// If all the new input layouts are same, the new in layout gets selected. For axis, the new
// axis in the new layout is identified. The param->axis is then modified on the fly to conform
// to the new input layout.
const auto& concate_dim = old_in_layouts[0][axis];
for (size_t i = 0; i < new_in_layouts.size(); ++i) {
if (new_in_layouts[i].ndim() > axis &&
new_in_layouts[i][axis] == concate_dim) {
ret = new_in_layouts[i];
break;
bool all_input_layouts_same = true;
for (auto new_layout : new_in_layouts) {
if (!new_layout.Equals(new_in_layouts[0])) {
all_input_layouts_same = false;
}
}
if (all_input_layouts_same) {
auto new_index = new_in_layouts[0].IndexOf(concate_dim);
ret = new_in_layouts[0];
param->axis = new_index;
is_new_layout_selected = true;
}
} else { // this function is called on the original correct relay ir
}
if (!is_new_layout_selected) {
// this function is called on the original correct relay ir
for (size_t i = 0; i < old_in_layouts.size(); ++i) {
if (old_in_layouts[i].defined()) {
ret = old_in_layouts[i];
......
......@@ -45,6 +45,7 @@ def test_alter_op():
y = relay.Function([x, weight], y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=100)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
......@@ -79,6 +80,7 @@ def test_alter_return_none():
called = [False]
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.global_max_pool2d", level=101)
def alter_conv2d(attrs, inputs, tinfos):
called[0] = True
......@@ -112,6 +114,7 @@ def test_alter_layout():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=102)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
......@@ -180,6 +183,7 @@ def test_alter_layout_dual_path():
y = relay.Function(analysis.free_vars(ret), ret)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=103)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
......@@ -241,6 +245,7 @@ def test_alter_layout_resnet():
y = relay.nn.global_max_pool2d(y)
return relay.Function(analysis.free_vars(y), y)
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=104)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
......@@ -291,6 +296,7 @@ def test_alter_layout_broadcast_op():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=105)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
......@@ -338,6 +344,7 @@ def test_alter_layout_scalar():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=106)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
......@@ -370,9 +377,19 @@ def test_alter_layout_scalar():
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_concatenate():
""" """
def before():
""" NCHW, NHWC and corner case concatenate layout transform."""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=107)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# NCHW layout transformation.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
weight2 = relay.var('weight2')
......@@ -388,14 +405,7 @@ def test_alter_layout_concatenate():
y = relay.Function(analysis.free_vars(ret), ret)
return y
@register_alter_op_layout("nn.conv2d", level=107)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
def expected_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
weight2 = relay.var('weight2')
......@@ -415,10 +425,57 @@ def test_alter_layout_concatenate():
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before()
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected()
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
# NHWC layout transformation.
def before_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1')
weight2 = relay.var('weight2')
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout='NHWC')
y1 = relay.nn.conv2d(y, weight2,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout='NHWC')
ret = relay.concatenate([y, y1], axis=3)
y = relay.Function(analysis.free_vars(ret), ret)
return y
def expected_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1')
weight2 = relay.var('weight2')
y = relay.layout_transform(x, "NHWC", "NCHW16c")
y = relay.nn.conv2d(y, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
y1 = relay.nn.conv2d(y, weight2,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout='NCHW16c')
ret = relay.concatenate([y, y1], axis=1)
ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
......@@ -435,6 +492,7 @@ def test_alter_layout_nchw_upsamping_op():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=108)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
......@@ -474,6 +532,7 @@ def test_alter_layout_strided_slice():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=109)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
......@@ -511,6 +570,7 @@ def test_alter_layout_depthwise_conv2d():
return y
import topi
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=110)
def alter_conv2d(attrs, inputs, tinfos):
with tvm.target.create("llvm"):
......@@ -548,6 +608,7 @@ def test_alter_layout_prelu():
y = relay.Function(analysis.free_vars(y), y)
return y
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=111)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
......@@ -580,6 +641,167 @@ def test_alter_layout_prelu():
assert(analysis.alpha_equal(a, b))
def test_alter_layout_pool():
""" Check NCHW, NHWC pool layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=113)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# Check NCHW conversion.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1))
ret = relay.nn.avg_pool2d(y, pool_size=(1, 1))
y = relay.Function(analysis.free_vars(ret), ret)
return y
def expected_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
y = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(y, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NCHW16c')
ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
# Check NHWC conversion.
def before_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1')
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout='NHWC')
ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NHWC')
y = relay.Function(analysis.free_vars(ret), ret)
return y
def expected_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1')
y = relay.layout_transform(x, "NHWC", "NCHW16c")
y = relay.nn.conv2d(y, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NCHW16c')
ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_sum():
""" Check NCHW, NHWC sum layout conversion"""
# Register alter op layout. "level" is used to override the previously registered functions.
@register_alter_op_layout("nn.conv2d", level=114)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
# Check NCHW conversion.
def before_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1))
ret = relay.sum(y, axis=1, keepdims=True)
y = relay.Function(analysis.free_vars(ret), ret)
return y
def expected_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var('weight1')
y = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(y, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
ret = relay.layout_transform(y, "NCHW16c", "NCHW")
ret = relay.sum(ret, axis=[1], keepdims=True)
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nchw()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nchw()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
# Check NHWC conversion.
def before_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1')
y = relay.nn.conv2d(x, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout='NHWC')
ret = relay.sum(y, axis=3, keepdims=True)
y = relay.Function(analysis.free_vars(ret), ret)
return y
def expected_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var('weight1')
y = relay.layout_transform(x, "NHWC", "NCHW16c")
y = relay.nn.conv2d(y, weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
ret = relay.layout_transform(y, "NCHW16c", "NCHW")
ret = relay.sum(ret, axis=[1], keepdims=True)
ret = relay.layout_transform(ret, "NCHW", "NHWC")
y = relay.Function(analysis.free_vars(ret), ret)
return y
a = before_nhwc()
a = run_opt_pass(a, transform.AlterOpLayout())
b = expected_nhwc()
b = run_opt_pass(b, transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
......@@ -593,3 +815,5 @@ if __name__ == "__main__":
test_alter_layout_strided_slice()
test_alter_layout_depthwise_conv2d()
test_alter_layout_prelu()
test_alter_layout_pool()
test_alter_layout_sum()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment