Commit f1d2d46b by 黎明灰烬 Committed by Lianmin Zheng

[TOPI] Move conv2d spatial pack schedule to dedicated file (#3972)

More schedules are making the conv2d.py file too large, so
we'd like to move the spatial pack schedule to dedicated file
before introducing NHWC schedule. No logic change in this patch.
parent 4a3abb94
......@@ -35,6 +35,8 @@ from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
from ..nn import conv2d_legalize
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
schedule_conv2d_spatial_pack_nchw
logger = logging.getLogger('topi')
......@@ -75,8 +77,11 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
num_tile=2)
if layout == 'NCHW':
return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
dilation, out_dtype, num_tile=2)
else:
raise ValueError("Unsupported layout {}".format(layout))
@autotvm.register_topi_schedule(
......@@ -119,7 +124,8 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
_schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])
schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
conv, output, outs[0])
if 'winograd_conv2d_output' in op.tag:
output = op.output(0)
......@@ -132,179 +138,6 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
traverse_inline(s, outs[0].op, _callback)
return s
def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, num_tile):
assert layout == "NCHW", "Only support NCHW"
# create workload according to raw arguments
out_dtype = out_dtype or data.dtype
N, CI, IH, IW = get_const_tuple(data.shape)
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
if len(kernel.shape) == 4:
pre_packed = False
CO, _, KH, KW = get_const_tuple(kernel.shape)
else: # kernel tensor is pre packed
pre_packed = True
CO, _, KH, KW, VC = get_const_tuple(kernel.shape)
CO = CO * VC
dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
OH = (IH + pad_top + pad_bottom - dilated_kernel_h) // HSTR + 1
OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
data_pad = pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right])
# ==================== define configuration space ====================
n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
if num_tile == 2: # for arm cpu
co, vc = cfg.define_split('tile_co', co, num_outputs=2)
oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
elif num_tile == 3: # for mali gpu
co, _, vc = cfg.define_split('tile_co', co, num_outputs=3)
oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3)
ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3)
else:
raise RuntimeError("Invalid num_tile")
cfg.define_reorder("reorder_0",
[n, co, oh, ow, ci, kh, kw, vh, vw, vc],
policy='candidate', candidate=[
[n, co, oh, ow, ci, kh, kw, vh, vw, vc],
[n, co, oh, ow, ci, kh, kw, vc, vh, vw]])
cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
# fallback support
if cfg.is_fallback:
if num_tile == 2: # arm cpu
ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct')
cfg.fallback_with_reference_log(ref_log)
elif num_tile == 3: # mali gpu
ref_log = autotvm.tophub.load_reference_log('mali', 'rk3399', 'conv2d', 'direct')
cfg.fallback_with_reference_log(ref_log)
# ====================================================================
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
kvshape = (CO // VC, CI, KH, KW, VC)
ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC)
oshape = (N, CO, OH, OW)
if dilation_h != 1 or dilation_w != 1:
# undilate input data
dvshape = (N, OH // VH, OW // VW, CI, KH, KW, VH, VW)
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, kh, kw, vh, vw:
data_pad[n][ci][(h*VH+vh)*HSTR+kh*dilation_h]
[(w*VW+vw)*WSTR+kw*dilation_w],
name='data_vec_undilated')
else:
dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw:
data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw],
name='data_vec')
if pre_packed:
kernel_vec = kernel
else:
kernel_vec = tvm.compute(kvshape, lambda co, ci, kh, kw, vc:
kernel[co*VC+vc][ci][kh][kw],
name='kernel_vec')
ci = tvm.reduce_axis((0, CI), name='ci')
kh = tvm.reduce_axis((0, KH), name='kh')
kw = tvm.reduce_axis((0, KW), name='kw')
if dilation_h != 1 or dilation_w != 1:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype) *
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
else:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) *
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
name='output_unpack', tag='spatial_conv2d_output')
return output
def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
conv, output, last):
"""schedule implementation"""
n, co, oh, ow, vh, vw, vc = s[conv].op.axis
ci, kh, kw = s[conv].op.reduce_axis
# schedule conv
cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc])
cfg["ann_reduce"].apply(s, conv, [kh, kw],
axis_lens=[get_const_int(kh.dom.extent),
get_const_int(kw.dom.extent)],
max_unroll=16,
cfg=cfg)
cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
axis_lens=[cfg['tile_oh'].size[-1],
cfg['tile_ow'].size[-1],
cfg['tile_co'].size[-1]],
max_unroll=16,
cfg=cfg)
# schedule fusion
n, co, h, w = s[last].op.axis
co, vc = cfg['tile_co'].apply(s, last, co)
oh, vh = cfg['tile_oh'].apply(s, last, h)
ow, vw = cfg['tile_ow'].apply(s, last, w)
s[last].reorder(n, co, oh, ow, vh, vw, vc)
if last != output:
s[output].compute_inline()
cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
axis_lens=[cfg['tile_oh'].size[-1],
cfg['tile_ow'].size[-1],
cfg['tile_co'].size[-1]],
max_unroll=16,
cfg=cfg)
s[conv].compute_at(s[last], ow)
# mark parallel
p = s[last].fuse(n, co)
s[last].parallel(p)
if data_vec.op.name == 'data_vec_undilated':
n, h, _, _, _, _, _, _ = s[data_vec].op.axis
else:
n, h, _, _, _, _ = s[data_vec].op.axis
p = s[data_vec].fuse(n, h)
s[data_vec].parallel(p)
if kernel_vec.op.name == 'kernel_vec':
co, _, _, _, _ = s[kernel_vec].op.axis
if autotvm.GLOBAL_SCOPE.in_tuning:
# kernel packing will be pre-computed during compilation, so we skip
# this part to make tuning records correct
s[kernel_vec].pragma(co, 'debug_skip_region')
else:
s[kernel_vec].parallel(co)
elif kernel_vec.op.name == 'kernel_vec_conv2d_transpose': # for conv2d transpose
co, _, _, _, _ = s[kernel_vec].op.axis
s[kernel_vec].parallel(co)
return s
@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
""" TOPI compute callback. Use winograd template """
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,no-else-return
"""Conv2D spatial pack implementation for ARM CPU"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from .. import nn
from ..util import get_const_tuple
from ..nn.util import get_const_int, get_pad_tuple
def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation,
out_dtype, num_tile):
"""compute define for Conv2d Spatial Pack with NCHW layout"""
out_dtype = out_dtype or data.dtype
N, CI, IH, IW = get_const_tuple(data.shape)
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
if len(kernel.shape) == 4:
pre_packed = False
CO, _, KH, KW = get_const_tuple(kernel.shape)
else: # kernel tensor is pre packed
pre_packed = True
CO, _, KH, KW, VC = get_const_tuple(kernel.shape)
CO = CO * VC
dilated_kernel_h = (KH - 1) * dilation_h + 1
dilated_kernel_w = (KW - 1) * dilation_w + 1
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
OH = (IH + pad_top + pad_bottom - dilated_kernel_h) // HSTR + 1
OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
data_pad = nn.pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_bottom, pad_right])
# ==================== define configuration space ====================
n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW)
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
if num_tile == 2: # for arm cpu
co, vc = cfg.define_split('tile_co', co, num_outputs=2)
oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
elif num_tile == 3: # for mali gpu
co, _, vc = cfg.define_split('tile_co', co, num_outputs=3)
oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3)
ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3)
else:
raise RuntimeError("Invalid num_tile")
cfg.define_reorder("reorder_0",
[n, co, oh, ow, ci, kh, kw, vh, vw, vc],
policy='candidate', candidate=[
[n, co, oh, ow, ci, kh, kw, vh, vw, vc],
[n, co, oh, ow, ci, kh, kw, vc, vh, vw]])
cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
# fallback support
if cfg.is_fallback:
if num_tile == 2: # arm cpu
ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct')
cfg.fallback_with_reference_log(ref_log)
elif num_tile == 3: # mali gpu
ref_log = autotvm.tophub.load_reference_log('mali', 'rk3399', 'conv2d', 'direct')
cfg.fallback_with_reference_log(ref_log)
# ====================================================================
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]
kvshape = (CO // VC, CI, KH, KW, VC)
ovshape = (N, CO // VC, OH // VH, OW // VW, VH, VW, VC)
oshape = (N, CO, OH, OW)
if dilation_h != 1 or dilation_w != 1:
# undilate input data
dvshape = (N, OH // VH, OW // VW, CI, KH, KW, VH, VW)
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, kh, kw, vh, vw:
data_pad[n][ci][(h*VH+vh)*HSTR+kh*dilation_h]
[(w*VW+vw)*WSTR+kw*dilation_w],
name='data_vec_undilated')
else:
dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1)
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw:
data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw],
name='data_vec')
if pre_packed:
kernel_vec = kernel
else:
kernel_vec = tvm.compute(kvshape, lambda co, ci, kh, kw, vc:
kernel[co*VC+vc][ci][kh][kw],
name='kernel_vec')
ci = tvm.reduce_axis((0, CI), name='ci')
kh = tvm.reduce_axis((0, KH), name='kh')
kw = tvm.reduce_axis((0, KW), name='kw')
if dilation_h != 1 or dilation_w != 1:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, ci, kh, kw, vh, vw].astype(out_dtype) *
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
else:
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) *
kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
axis=[ci, kh, kw]), name='conv')
output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
name='output_unpack', tag='spatial_conv2d_output')
return output
def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
conv, output, last):
"""schedule implementation"""
n, co, oh, ow, vh, vw, vc = s[conv].op.axis
ci, kh, kw = s[conv].op.reduce_axis
# schedule conv
cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc])
cfg["ann_reduce"].apply(s, conv, [kh, kw],
axis_lens=[get_const_int(kh.dom.extent),
get_const_int(kw.dom.extent)],
max_unroll=16,
cfg=cfg)
cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
axis_lens=[cfg['tile_oh'].size[-1],
cfg['tile_ow'].size[-1],
cfg['tile_co'].size[-1]],
max_unroll=16,
cfg=cfg)
# schedule fusion
n, co, h, w = s[last].op.axis
co, vc = cfg['tile_co'].apply(s, last, co)
oh, vh = cfg['tile_oh'].apply(s, last, h)
ow, vw = cfg['tile_ow'].apply(s, last, w)
s[last].reorder(n, co, oh, ow, vh, vw, vc)
if last != output:
s[output].compute_inline()
cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
axis_lens=[cfg['tile_oh'].size[-1],
cfg['tile_ow'].size[-1],
cfg['tile_co'].size[-1]],
max_unroll=16,
cfg=cfg)
s[conv].compute_at(s[last], ow)
# mark parallel
s[last].parallel(co)
if data_vec.op.name == 'data_vec_undilated':
_, h, _, _, _, _, _, _ = s[data_vec].op.axis
else:
_, h, _, _, _, _ = s[data_vec].op.axis
s[data_vec].parallel(h)
if kernel_vec.op.name == 'kernel_vec':
co, _, _, _, _ = s[kernel_vec].op.axis
if autotvm.GLOBAL_SCOPE.in_tuning:
# kernel packing will be pre-computed during compilation, so we skip
# this part to make tuning records correct
s[kernel_vec].pragma(co, 'debug_skip_region')
else:
s[kernel_vec].parallel(co)
elif kernel_vec.op.name == 'kernel_vec_conv2d_transpose': # for conv2d transpose
co, _, _, _, _ = s[kernel_vec].op.axis
s[kernel_vec].parallel(co)
return s
......@@ -24,7 +24,7 @@ from tvm import autotvm
from ..generic import schedule_conv2d_transpose_nchw
from ..nn import conv2d_transpose_nchw, dilate, pad, get_pad_tuple
from ..util import get_const_tuple, traverse_inline
from .conv2d import _schedule_spatial_pack
from .conv2d_spatial_pack import schedule_conv2d_spatial_pack_nchw
@autotvm.task.register_topi_compute(conv2d_transpose_nchw, "arm_cpu", "direct")
def conv2d_transpose_nchw_arm(cfg, Input, Filter, strides, padding, out_dtype):
......@@ -154,7 +154,8 @@ def schedule_conv2d_transpose_arm(cfg, outs):
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
_schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])
schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
conv, output, outs[0])
traverse_inline(s, outs[0].op, _callback)
return s
......@@ -27,7 +27,8 @@ from ..nn import conv2d, conv2d_winograd_without_weight_transform, \
from ..nn.winograd_util import winograd_transform_matrices
# reuse some compute declarations from ARM CPU
from ..arm_cpu.conv2d import _decl_spatial_pack, _alter_conv2d_layout_arm
from ..arm_cpu.conv2d import _alter_conv2d_layout_arm
from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nchw
@autotvm.register_topi_compute(conv2d, 'mali', ['direct'])
......@@ -67,8 +68,11 @@ def conv2d_mali(cfg, data, kernel, strides, padding, dilation, layout, out_dtype
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
num_tile=3)
if layout == 'NCHW':
return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
dilation, out_dtype, num_tile=3)
else:
raise ValueError("Unsupported layout {}".format(layout))
@autotvm.register_topi_schedule(schedule_conv2d_nchw, 'mali', ['direct', 'winograd'])
def schedule_conv2d_nchw_mali(cfg, outs):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment