Commit 672b0909 by 黎明灰烬 Committed by Tianqi Chen

[TOPI][AutoTVM] NHWC conv2d templates for ARM (#3859)

* [AutoTVM][TOPI] NHWC conv2d templates (spatial pack) for ARM

As some frontends (tflite for example) are using NHWC as the default
layout, we are enabling NHWC schedule templates in TOPI and AutoTVM.

* some comments fix
parent 373ab314
...@@ -182,12 +182,15 @@ class TaskExtractEnv: ...@@ -182,12 +182,15 @@ class TaskExtractEnv:
args = deserialize_args(args) args = deserialize_args(args)
A, W = args[:2] A, W = args[:2]
layout = args[-2] layout = args[-2]
assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
C = topi.nn.conv2d(*args, **kwargs) C = topi.nn.conv2d(*args, **kwargs)
if layout == 'NCHW': if layout == 'NCHW':
s = topi.generic.schedule_conv2d_nchw([C]) s = topi.generic.schedule_conv2d_nchw([C])
else: elif layout == 'HWCN':
s = topi.generic.schedule_conv2d_hwcn([C]) s = topi.generic.schedule_conv2d_hwcn([C])
elif layout == 'NHWC':
s = topi.generic.schedule_conv2d_nhwc([C])
else:
raise ValueError("Unsupported layout {}".format(layout))
return s, [A, W, C] return s, [A, W, C]
@register("topi_nn_depthwise_conv2d_nchw") @register("topi_nn_depthwise_conv2d_nchw")
......
...@@ -24,7 +24,8 @@ import tvm ...@@ -24,7 +24,8 @@ import tvm
from tvm import autotvm from tvm import autotvm
import tvm.contrib.nnpack import tvm.contrib.nnpack
from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \ from ..generic import schedule_conv2d_nchw, schedule_conv2d_nhwc, \
schedule_conv2d_winograd_without_weight_transform, \
schedule_conv2d_winograd_nnpack_without_weight_transform schedule_conv2d_winograd_nnpack_without_weight_transform
from ..util import traverse_inline, get_const_tuple from ..util import traverse_inline, get_const_tuple
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
...@@ -34,7 +35,9 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \ ...@@ -34,7 +35,9 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
from ..nn.util import get_const_int, get_pad_tuple from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices from ..nn.winograd_util import winograd_transform_matrices
from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \ from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
schedule_conv2d_spatial_pack_nchw conv2d_spatial_pack_nhwc, \
schedule_conv2d_spatial_pack_nchw, \
schedule_conv2d_spatial_pack_nhwc
logger = logging.getLogger('topi') logger = logging.getLogger('topi')
...@@ -78,6 +81,9 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt ...@@ -78,6 +81,9 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt
if layout == 'NCHW': if layout == 'NCHW':
return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
dilation, out_dtype, num_tile=2) dilation, out_dtype, num_tile=2)
elif layout == 'NHWC':
return conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding,
dilation, out_dtype)
else: else:
raise ValueError("Unsupported layout {}".format(layout)) raise ValueError("Unsupported layout {}".format(layout))
...@@ -136,6 +142,34 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs): ...@@ -136,6 +142,34 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
traverse_inline(s, outs[0].op, _callback) traverse_inline(s, outs[0].op, _callback)
return s return s
@autotvm.register_topi_schedule(schedule_conv2d_nhwc, 'arm_cpu', ['direct'])
def schedule_conv2d_nhwc_arm_cpu(cfg, outs):
"""TOPI schedule callback for conv2d
Parameters
----------
cfg: ConfigEntity
The config for this template
outs: Array of Tensor
The computation graph description of conv2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv2d.
"""
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if 'spatial_conv_output_NHWC' in op.tag:
schedule_conv2d_spatial_pack_nhwc(cfg, s, op, outs[0])
traverse_inline(s, outs[0].op, _callback)
return s
@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd']) @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
""" TOPI compute callback. Use winograd template """ """ TOPI compute callback. Use winograd template """
......
...@@ -196,3 +196,163 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, ...@@ -196,3 +196,163 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
s[kernel_vec].parallel(co) s[kernel_vec].parallel(co)
return s return s
def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""Spatial pack compute for Conv2d NHWC"""
out_dtype = out_dtype or data.dtype
N, IH, IW, IC = get_const_tuple(data.shape)
assert len(kernel.shape) == 4, "AlterOpLayout not enabled for NHWC yet"
KH, KW, _, OC = get_const_tuple(kernel.shape)
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = \
get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0])
# ==================== define configuration space ====================
n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
oco, oci = cfg.define_split('tile_co', oc, num_outputs=2)
oho, ohi = cfg.define_split('tile_oh', oh, num_outputs=2)
owo, owi = cfg.define_split('tile_ow', ow, num_outputs=2)
cfg.define_reorder('reorder_conv',
[n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
policy='candidate', candidate=[
[n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
[n, oho, owo, oco, ohi, kh, kw, ic, owi, oci],
[n, oho, owo, oco, ohi, kh, kw, owi, ic, oci],
[n, oho, owo, ohi, oco, kh, kw, owi, ic, oci]])
cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
cfg.define_annotate("ann_spatial", [ohi, owi, oci], policy='try_unroll_vec')
# ====================================================================
OCI = cfg['tile_co'].size[-1]
OHI = cfg['tile_oh'].size[-1]
OWI = cfg['tile_ow'].size[-1]
OCO = OC // OCI
OHO = OH // OHI
OWO = OW // OWI
kvshape = (OCO, KH, KW, IC, OCI)
ovshape = (N, OHO, OWO, OCO, OHI, OWI, OCI)
oshape = (N, OH, OW, OC)
if dilation_h != 1 or dilation_w != 1:
# undilate input data
dvshape = (N, OHO, OWO, KH, KW, IC, OHI, OWI)
data_vec = tvm.compute(dvshape, lambda n, oho, owo, kh, kw, ic, ohi, owi:
data_pad[n][(oho*OHI+ohi)*HSTR+kh*dilation_h]
[(owo*OWI+owi)*WSTR+kw*dilation_w][ic],
name='data_vec_undilated')
else:
dvshape = (N, OHO, OWO, KH + (OHI-1)*HSTR, KW + (OWI-1)*WSTR, IC)
data_vec = tvm.compute(dvshape, lambda n, oho, owo, ohi, owi, ic:
data_pad[n][oho*OHI*HSTR+ohi][owo*OWI*WSTR+owi][ic],
name='data_vec')
kernel_vec = tvm.compute(kvshape, lambda oco, kh, kw, ic, oci: \
kernel[kh][kw][ic][oco*OCI+oci],
name='kernel_vec')
ic = tvm.reduce_axis((0, IC), name='ic')
kh = tvm.reduce_axis((0, KH), name='kh')
kw = tvm.reduce_axis((0, KW), name='kw')
if dilation_h != 1 or dilation_w != 1:
conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
tvm.sum(data_vec[n, oho, owo, kh, kw, ohi, owi, ic].astype(out_dtype) *
kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
axis=[ic, kh, kw]), name='conv')
else:
conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
tvm.sum(data_vec[n, oho, owo, ohi*HSTR+kh, owi*WSTR+kw, ic].astype(out_dtype) *
kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
axis=[ic, kh, kw]), name='conv')
idiv = tvm.indexdiv
imod = tvm.indexmod
output = tvm.compute(oshape, lambda n, oho, owo, oc:
conv[n][idiv(oho, OHI)][idiv(owo, OWI)][idiv(oc, OCI)]\
[imod(oho, OHI)][imod(owo, OWI)][imod(oc, OCI)],
name='output_unpack', tag='spatial_conv_output_NHWC')
return output
def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
"""Spatial Pack schedule for Conv2d NHWC"""
unpack = op.output(0)
conv = unpack.op.input_tensors[0]
data_vec = conv.op.input_tensors[0]
kernel_vec = conv.op.input_tensors[1]
data_pad = data_vec.op.input_tensors[0]
OHI = cfg['tile_oh'].size[-1]
OWI = cfg['tile_ow'].size[-1]
OCI = cfg['tile_co'].size[-1]
# schedule unpack/output
if output != unpack:
s[unpack].compute_inline()
n, oh, ow, oc = s[output].op.axis
oco, oci = cfg['tile_co'].apply(s, output, oc)
oho, ohi = cfg['tile_oh'].apply(s, output, oh)
owo, owi = cfg['tile_ow'].apply(s, output, ow)
s[output].reorder(n, oho, owo, oco, ohi, owi, oci)
cfg['ann_spatial'].apply(s, output, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
max_unroll=16, cfg=cfg)
cfg.define_knob('compat', [0, 1, 2])
if cfg['compat'].val < 2:
compat_axis = [owo, oco][cfg['compat'].val] # pylint: disable=R1706
s[conv].compute_at(s[output], compat_axis)
paxis = s[output].fuse(n, oho)
s[output].parallel(paxis)
# schedule conv
n, oho, owo, oco, ohi, owi, oci = s[conv].op.axis
ic, kh, kw = s[conv].op.reduce_axis
cfg['reorder_conv'].apply(s, conv, [n, oho, owo, oco, kh, kw, ohi, owi, ic, oci])
cfg['ann_reduce'].apply(s, conv, [kh, kw],
axis_lens=[get_const_int(kh.dom.extent),
get_const_int(kw.dom.extent)],
max_unroll=16,
cfg=cfg)
cfg['ann_spatial'].apply(s, conv, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
max_unroll=16, cfg=cfg)
if cfg['compat'].val < 2:
compat_axis = [owo, oco][cfg['compat'].val] # pylint: disable=R1706
s[kernel_vec].compute_at(s[conv], compat_axis)
s[data_vec].compute_at(s[conv], compat_axis)
# schedule kernel pack
oco, kh, kw, ic, oci = kernel_vec.op.axis
s[kernel_vec].vectorize(oci)
s[kernel_vec].unroll(ic)
if cfg['compat'].val == 2:
s[kernel_vec].parallel(oco)
# schedule data pack
if data_vec.op.name == 'data_vec_undilated':
n, oho, owo, kh, kw, ic, ohi, owi = s[data_vec].op.axis
s[data_vec].vectorize(owi)
s[data_vec].unroll(ohi)
else:
n, oho, owo, ohi, owi, ic = s[data_vec].op.axis
s[data_vec].vectorize(ic)
s[data_vec].unroll(owi)
if cfg['compat'].val == 2:
paxis = s[data_vec].fuse(n, oho)
s[data_vec].parallel(paxis)
return s
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment