Unverified Commit 1265983c by Animesh Jain Committed by GitHub

[TOPI] Using x86 schedules for ARM conv2d. (#5334)

parent b1364ebb
......@@ -54,10 +54,16 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
if groups == 1:
if layout == "NCHW":
if kernel_layout == "OIHW":
# ARM conv2d spatial pack schedule.
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.arm_cpu")
# Intel x86 conv2d schedule.
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_nchw),
wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
name="conv2d_nchw.x86")
# check if winograd algorithm is applicable
_, _, kh, kw = get_const_tuple(kernel.shape)
pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw))
......@@ -100,11 +106,13 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW" or re.match(r"OIHW\d*o", kernel_layout)
# ARM conv2d depthwise schedule
if kernel_layout == "OIHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.arm_cpu")
# TODO:
# This schedule has incorrect result on some hardware platforms (like NV Jetson TX2)
# Let us comment it out but not remove.
......@@ -115,6 +123,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
# wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw_spatial_pack),
# name="depthwise_conv2d_nchw_spatial_pack.arm_cpu",
# plevel=15)
# Intel x86 depthwise conv2d schedule.
channel_multiplier = get_const_tuple(inputs[1].shape)[1]
if channel_multiplier == 1 and dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.x86")
elif layout == "NHWC":
assert kernel_layout == "HWOI"
logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.")
......@@ -138,6 +154,26 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
format(layout))
return strategy
@conv2d_NCHWc_strategy.register("arm_cpu")
def conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d_NCHWc adopted from x86"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.x86")
return strategy
@depthwise_conv2d_NCHWc_strategy.register("arm_cpu")
def depthwise_conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
"""depthwise_conv2d_NCHWc adopted from x86"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
name="depthwise_conv2d_NCHWc.x86")
return strategy
def wrap_compute_conv2d_winograd_nnpack(topi_compute):
"""wrap topi compute for conv2d_winograd NNPack"""
def _compute_conv2d_nnpack(attrs, inputs, out_type):
......
......@@ -26,7 +26,7 @@ from tvm import autotvm
from ..nn import conv2d_alter_layout
from ..util import get_const_tuple
from ..x86.conv2d import _get_default_config as _get_x86_default_config
logger = logging.getLogger('topi')
......@@ -59,6 +59,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
data, kernel = tinfos
out_dtype = out_type.dtype
# Extract data types
data_tensor, kernel_tensor = tinfos
data_dtype = data_tensor.dtype
kernel_dtype = kernel_tensor.dtype
idxd = tvm.tir.indexdiv
if topi_tmpl == "conv2d_nchw_spatial_pack.arm_cpu":
......@@ -169,4 +174,60 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return relay.nn.conv2d(*inputs, **new_attrs)
if topi_tmpl == "conv2d_NCHWc.x86":
# Converting NCHW to NCHWc.
assert data_layout == "NCHW" and kernel_layout == "OIHW"
if cfg.is_fallback:
_get_x86_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
out_dtype, False, data_layout)
batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
# update new attrs
new_attrs['channels'] = out_channel
new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
# Store altered operator's config
new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
dtype=data_dtype)
new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn,
kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"],
new_attrs["out_layout"], out_dtype], topi_tmpl)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
if topi_tmpl == "depthwise_conv2d_NCHWc.x86":
# Converting NCHW to NCHWc.
assert data_layout == "NCHW" and kernel_layout == "OIHW"
if cfg.is_fallback:
_get_x86_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
out_dtype, True, data_layout)
batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
assert channel_multiplier == 1
# update new attrs
new_attrs['channels'] = out_channel
new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
# Store altered operator's config.
new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
dtype=data_dtype)
new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'],
new_attrs['out_layout'], out_dtype], topi_tmpl)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
return None
......@@ -169,7 +169,8 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layo
cfg.define_split("tile_ic", in_channel, num_outputs=2)
cfg.define_split("tile_oc", num_filter, num_outputs=2)
cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64,
policy="verbose")
if is_kernel_1x1:
cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment