Commit d7bc4fdd by Yao Wang Committed by Yizhi Liu

Fix x86 depthwise conv2d alter_op_layout (#3264)

* Fix x86 depthwise conv2d alter_op_layout

* Small fix

* Add test case

* Fix test

* Assert kernel layout

* Minor fix

* Add get_shape function

* Minor change
parent 770ac84e
......@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Test alter op layout pass"""
import tvm
from tvm import relay
from tvm.relay.op import register_alter_op_layout
......@@ -513,6 +514,45 @@ def test_alter_layout_strided_slice():
assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_depthwise_conv2d():
"""Test depthwise_conv2d operator"""
def before():
x = relay.var("x", shape=(1, 32, 56, 56))
w = relay.var("w", shape=(32, 1, 3, 3))
y = relay.nn.conv2d(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3), groups=32)
y = relay.Function(free_vars(y), y)
return y
import topi
@register_alter_op_layout("nn.conv2d", level=110)
def alter_conv2d(attrs, inputs, tinfos):
with tvm.target.create("llvm"):
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, relay)
def expected():
x = relay.var("x", shape=(1, 32, 56, 56))
w = relay.var("w", shape=(32, 1, 3, 3))
x = relay.layout_transform(x, "NCHW", "NCHW8c")
w = relay.layout_transform(w, "OIHW", "OIHW1i8o")
y = relay.nn.contrib_depthwise_conv2d_nchwc(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3),
groups=32, data_layout="NCHW8c", kernel_layout="OIHW1i8o",
out_layout="NCHW8c")
y = relay.layout_transform(y, "NCHW8c", "NCHW")
y = relay.Function(free_vars(y), y)
return y
a = before()
a = infer_type(a)
a = canonicalize_ops(a)
a = infer_type(a)
a = alter_op_layout(a)
a = infer_type(a)
b = expected()
b = infer_type(b)
assert(alpha_equal(a, b))
def test_alter_layout_prelu():
"""Test PRelu operator"""
def before():
......@@ -524,7 +564,7 @@ def test_alter_layout_prelu():
y = relay.Function(free_vars(y), y)
return y
@register_alter_op_layout("nn.conv2d", level=110)
@register_alter_op_layout("nn.conv2d", level=111)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
......@@ -571,4 +611,5 @@ if __name__ == "__main__":
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
test_alter_layout_strided_slice()
test_alter_layout_depthwise_conv2d()
test_alter_layout_prelu()
......@@ -26,11 +26,11 @@ from ..util import traverse_inline, get_const_tuple, get_const_int
from ..nn.util import get_pad_tuple
# register original implementation of depthwise_conv2d_nchw since we don't need to change this part
autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct',
autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
depthwise_conv2d_nchw.fdefault)
# register customized schedule for arm cpu.
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'],
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu',
['direct', 'contrib_spatial_pack'])
def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
"""Schedule depthwise conv2d
......@@ -151,7 +151,7 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
traverse_inline(s, outs[0].op, _callback)
return s
@autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], ['contrib_spatial_pack'])
@autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', ['contrib_spatial_pack'])
def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""TOPI compute callback for depthwise_conv2d nchw
......
......@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs
from numbers import Integral
import tvm
from tvm.api import layout, bijective_layout
from . import tag
def traverse_inline(s, final_op, callback):
......@@ -289,3 +290,41 @@ def get_max_power2_factor(n, max_value=None):
x *= 2
n /= 2
return x
def get_shape(src_shape, src_layout, dst_layout):
"""Given a source shape, a source layout and a destination layout, infer
the destination shape.
Parameter
---------
src_shape : tuple of int or IntImm
Source shape
src_layout : str or Layout
Source layout
dst_layout : str or Layout
Destination layout
Returns
-------
dst_shape : tuple of int
Destination shape
"""
if src_layout == dst_layout:
return get_const_tuple(src_shape)
if isinstance(src_layout, str):
src_layout = layout(src_layout)
if isinstance(dst_layout, str):
dst_layout = layout(dst_layout)
assert len(src_layout) == len(dst_layout), \
"Incompatible layout %s vs %s" % (src_layout, dst_layout)
layout_mapping = bijective_layout(src_layout, dst_layout)
dst_indices = layout_mapping.forward_index(
tvm.convert([i for i in range(len(src_layout))]))
return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices]))
......@@ -26,7 +26,7 @@ from tvm.autotvm.task.topi_integration import deserialize_args
from tvm.autotvm.task import get_config
from .. import generic, tag
from .. import nn
from ..util import get_const_tuple
from ..util import get_const_tuple, get_shape
from ..nn.conv2d import conv2d, conv2d_NCHWc, \
conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
......@@ -415,11 +415,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
dtype = data.dtype
out_dtype = dtype if out_dtype in ("same", "") else out_dtype
is_depthwise = groups == in_channel and groups == out_channel
kshape = get_shape(kernel.shape, attrs["kernel_layout"], "OIHW")
is_depthwise = groups == kshape[0] and kshape[1] == 1
# only optimize for NCHW
if layout != 'NCHW':
if layout != 'NCHW' or attrs["kernel_layout"] != "OIHW":
return None
if groups != 1 and not is_depthwise:
return None
......
......@@ -22,11 +22,12 @@ from tvm.autotvm.task import get_config
from tvm.autotvm.task.space import SplitEntity
from tvm.autotvm.task.topi_integration import deserialize_args
from .. import generic, tag
from ..generic import schedule_depthwise_conv2d_nchw
from ..nn.pad import pad
from ..util import get_const_tuple
from ..nn.util import get_pad_tuple
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload, \
depthwise_conv2d_infer_layout
from ..nn.depthwise_conv2d import depthwise_conv2d_nchw, depthwise_conv2d_NCHWc, \
_get_workload, depthwise_conv2d_infer_layout
from .util import get_fp32_len
......@@ -70,6 +71,12 @@ def _fallback_schedule(cfg, wkl):
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
autotvm.register_topi_compute(depthwise_conv2d_nchw, 'cpu', 'direct',
depthwise_conv2d_nchw.fdefault)
autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'cpu', 'direct',
schedule_depthwise_conv2d_nchw.fdefault)
@autotvm.register_topi_compute(depthwise_conv2d_NCHWc, 'cpu', 'direct')
def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
layout, out_layout, out_dtype=None):
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for util"""
import topi
def verify_get_shape(src_shape, src_layout, dst_layout, expect_shape):
dst_shape = topi.util.get_shape(src_shape, src_layout, dst_layout)
assert dst_shape == expect_shape, \
"Shape mismatch: expecting %s but got %s" % (expect_shape, dst_shape)
def test_get_shape():
verify_get_shape((1, 3, 224, 224), "NCHW", "NCHW", (1, 3, 224, 224))
verify_get_shape((1, 3, 224, 224), "NCHW", "NHWC", (1, 224, 224, 3))
verify_get_shape((3, 2, 32, 48, 16), "NCHW16c", "NC16cWH", (3, 2, 16, 48, 32))
verify_get_shape((2, 3, 32, 32, 16, 8), "OIHW16i8o", "HWO8oI16i", (32, 32, 2, 8, 3, 16))
if __name__ == "__main__":
test_get_shape()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment