Commit d69c6fd8 by Animesh Jain Committed by Zhi

[Relay][AlterOp] NHWC to NCHWc support for Pool, pad, concatenate, sum. (#4059)

parent aa424139
...@@ -748,10 +748,12 @@ class OperatorConverter(object): ...@@ -748,10 +748,12 @@ class OperatorConverter(object):
elif padding == Padding.SAME: elif padding == Padding.SAME:
pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h) pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w) pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
(pad_top, pad_bottom), if do_pad:
(pad_left, pad_right), in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
(0, 0))) (pad_top, pad_bottom),
(pad_left, pad_right),
(0, 0)))
else: else:
raise tvm.error.OpAttributeUnImplemented( raise tvm.error.OpAttributeUnImplemented(
'Padding format {} is not supported for operator Conv.'.format(padding)) 'Padding format {} is not supported for operator Conv.'.format(padding))
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -47,15 +47,9 @@ Array<Array<Layout> > Pool2DInferCorrectLayout( ...@@ -47,15 +47,9 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
T *params = const_cast<T*>(attrs.as<T>()); T *params = const_cast<T*>(attrs.as<T>());
if (new_in_layouts.defined()) { if (new_in_layouts.defined()) {
// Set the pool with the new layout.
CHECK_EQ(new_in_layouts.size(), 1); CHECK_EQ(new_in_layouts.size(), 1);
params->layout = new_in_layouts[0].name();
Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout
}
} }
Layout inferred_layout(params->layout); Layout inferred_layout(params->layout);
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -119,6 +119,59 @@ Array<Integer> GetExcludeAxes(size_t indim, ...@@ -119,6 +119,59 @@ Array<Integer> GetExcludeAxes(size_t indim,
return r_axes; return r_axes;
} }
// Return the modified layout for AlterOpLayout pass.
Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>>& old_in_shapes) {
// NOTE: Discard "const" qualifier here.
ReduceAttrs* params = const_cast<ReduceAttrs*>(attrs.as<ReduceAttrs>());
// Get the reduce axes.
uint32_t indim = old_in_shapes[0].size();
auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
Layout ret = Layout::Undef();
if (new_in_layouts.defined() && r_axes.size()) {
// Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
// modified layout axes.
CHECK_EQ(new_in_layouts.size(), 1);
CHECK_EQ(old_in_layouts.size(), 1);
// 1) Collect the original axes
std::unordered_set<std::string> old_r_dims;
for (auto r_axis : r_axes) {
old_r_dims.emplace(old_in_layouts[0][r_axis].name());
}
// 2) Collect the new axes by walking new_layout.
tvm::Array<tvm::Integer> new_r_axes;
std::string new_layout_string = "";
int axis_index = 0;
for (auto iter_var : new_in_layouts[0]->axes) {
const auto& layout_axis = LayoutAxis::Get(iter_var);
const std::string& layout_dim = layout_axis.name();
if (old_r_dims.count(layout_dim)) {
new_r_axes.push_back(tvm::Integer(axis_index));
}
// Collect only the primal axis.
if (layout_axis.IsPrimal()) {
new_layout_string += layout_dim;
axis_index++;
}
}
// 3) Set the new axis and layout.
ret = Layout(new_layout_string);
params->axis = new_r_axes;
} else if (old_in_layouts.defined()) {
// If the new layout is undefined, set the old layout as the inferred layout.
CHECK_EQ(old_in_layouts.size(), 1);
ret = old_in_layouts[0];
}
return Array<Array<Layout>>{{ret}, {ret}};
}
template<typename F> template<typename F>
Array<Tensor> ReduceCompute(const Attrs& attrs, Array<Tensor> ReduceCompute(const Attrs& attrs,
...@@ -325,6 +378,7 @@ Example:: ...@@ -325,6 +378,7 @@ Example::
.set_attrs_type_key("relay.attrs.ReduceAttrs") .set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4) .set_support_level(4)
.add_type_rel("Reduce", ReduceRel) .add_type_rel("Reduce", ReduceRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
.set_attr<FTVMCompute>("FTVMCompute", SumCompute) .set_attr<FTVMCompute>("FTVMCompute", SumCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce); .set_attr<TOpPattern>("TOpPattern", kCommReduce);
......
...@@ -283,22 +283,34 @@ Array<Array<Layout>> ConcatenateLayout( ...@@ -283,22 +283,34 @@ Array<Array<Layout>> ConcatenateLayout(
const Array<Layout>& new_in_layouts, const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts, const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) { const Array<Array<IndexExpr>> &old_in_shapes) {
const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>(); ConcatenateAttrs* param = const_cast<ConcatenateAttrs*>(attrs.as<ConcatenateAttrs>());
size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() : size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
static_cast<size_t>(param->axis); static_cast<size_t>(param->axis);
Layout ret; Layout ret;
bool is_new_layout_selected = false;
if (new_in_layouts.defined()) { // this function is called after some operators are alternated. if (new_in_layouts.defined()) { // this function is called after some operators are alternated.
// If all the new input layouts are same, the new in layout gets selected. For axis, the new
// axis in the new layout is identified. The param->axis is then modified on the fly to conform
// to the new input layout.
const auto& concate_dim = old_in_layouts[0][axis]; const auto& concate_dim = old_in_layouts[0][axis];
for (size_t i = 0; i < new_in_layouts.size(); ++i) { bool all_input_layouts_same = true;
if (new_in_layouts[i].ndim() > axis && for (auto new_layout : new_in_layouts) {
new_in_layouts[i][axis] == concate_dim) { if (!new_layout.Equals(new_in_layouts[0])) {
ret = new_in_layouts[i]; all_input_layouts_same = false;
break;
} }
} }
} else { // this function is called on the original correct relay ir if (all_input_layouts_same) {
auto new_index = new_in_layouts[0].IndexOf(concate_dim);
ret = new_in_layouts[0];
param->axis = new_index;
is_new_layout_selected = true;
}
}
if (!is_new_layout_selected) {
// this function is called on the original correct relay ir
for (size_t i = 0; i < old_in_layouts.size(); ++i) { for (size_t i = 0; i < old_in_layouts.size(); ++i) {
if (old_in_layouts[i].defined()) { if (old_in_layouts[i].defined()) {
ret = old_in_layouts[i]; ret = old_in_layouts[i];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment