Commit 2d0010f3 by Zhao Wu Committed by Thierry Moreau

[ARM CPU] Fix infer shape error of depthwise (#4384)

* [ARM CPU] Fix contrib_spatial_pack error

* PyLint error fix

* diable no-else-return as other files

* Change the test case split OC not be 1 to cover 5D weight layout
parent 651bdf2f
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, too-many-arguments
# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments
"""Backend compiler related feature registration"""
from __future__ import absolute_import
......@@ -163,10 +163,17 @@ def compute_conv2d(attrs, inputs, out_type, target):
def _get_out_depth():
weight_shape = get_const_tuple(inputs[1].shape)
# NHWC layout
if kernel_layout.startswith("HW"):
return weight_shape[2] * weight_shape[3]
return weight_shape[0] * weight_shape[1]
# NCHW layout.
# in ARM CPU contrib_spatial_pack schedule, we will prepack weight layout
if len(weight_shape) == 4:
return weight_shape[0] * weight_shape[1]
else:
assert len(weight_shape) == 5
C, M, _, _, VC = weight_shape
return C * VC * M
if groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding,
......
......@@ -158,7 +158,7 @@ def test_conv2d_run():
["depthwise_conv2d_nchw", [1, 512, 32, 32, "float32"], \
[512, 1, 3, 3, "float32"], [1, 1], [1, 1], [1, 1], "float32"], \
{"i": 743640, "t": "contrib_spatial_pack", "c": null, \
"e": [["tile_co", "sp", [512, 1]], ["tile_oh", "sp", [8, 1]], \
"e": [["tile_co", "sp", [32, 16]], ["tile_oh", "sp", [8, 1]], \
["tile_ow", "sp", [1, 8]], \
["reorder_0", "re", [0, 1, 2, 3, 4, 5, 8, 6, 7]], \
["reorder_1", "re", [0, 1, 2, 3, 6, 4, 5]], \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment