Commit c846d17c by Yuwei Hu Committed by Wuwei Lin

[TOPI] Improve conv2d_transpose schedule on X86 and CUDA (#3948)

* improve conv2d_transpose x86 performance by reusing conv2d schedule

* parallelize across batches to make large-batch conv2d and conv2d_transpose faster

* improve doc for autotvm.task.space.FallbackConfigEntity.fallback_with_reference_log

* add fallback schedule for schedule_conv2d_transpose_nchw_cuda

* fix pylint

* fix pylint

* unify conv2d_transpose declaration in topi.nn and topi.x86
parent b577171d
......@@ -1003,6 +1003,9 @@ class FallbackConfigEntity(ConfigSpace):
We use tuned parameters from TopHub as reference data.
For an unseen shape, we find the most similar tuned one from TopHub and
mimic its parameters.
Note that we are not matching by workload (e.g., input size, kernel size),
but instead matching by configuration space. The idea is that if two workloads have
similar configuration space, their optimal configurations are also likely to be similar.
Parameters
----------
......
......@@ -19,10 +19,11 @@
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn, generic
from ..util import equal_const_int, get_const_tuple, traverse_inline
@autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct")
def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
"""Transposed 2D convolution nchw forward operator.
......@@ -129,6 +130,36 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _fallback_schedule(N, F, Y, X):
# pylint: disable=unused-argument
# split N (batch dimension)
if N > 1:
cfg["tile_n"] = SplitEntity([-1, 1, 1, 4])
else:
cfg["tile_n"] = SplitEntity([1, 1, 1, 1])
# split F (output channel dimension)
cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
# split Y (height dimension)
y_split_factor = 1
for candidate in range(5, 17):
if Y % candidate == 0:
y_split_factor = candidate
break
cfg["tile_y"] = SplitEntity([-1, 1, 1, y_split_factor])
# split X (width dimension)
x_split_factor = 1
for candidate in range(5, 17):
if X % candidate == 0:
x_split_factor = candidate
break
cfg["tile_x"] = SplitEntity([-1, x_split_factor, 1, 1])
# split RC (input channel dimension, which is a reduction axis)
cfg["tile_rc"] = SplitEntity([-1, 1, 16])
# other configurations
cfg["fuse_yx"] = OtherOptionEntity(False)
cfg["unroll_explicit"] = OtherOptionEntity(True)
cfg["auto_unroll_max_step"] = OtherOptionEntity(1500)
def _callback(op):
if op.tag == 'conv2d_transpose_nchw':
pad_data = op.input_tensors[0]
......@@ -150,6 +181,11 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
cfg.define_knob("unroll_explicit", [1])
else:
cfg.define_knob("unroll_explicit", [0, 1])
if cfg.is_fallback:
N, F, Y, X = get_const_tuple(conv.shape)
_fallback_schedule(N, F, Y, X)
##### space definition end #####
if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
......
......@@ -14,11 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Transposed 2D convolution operators (sometimes called Deconvolution)."""
from __future__ import absolute_import as _abs
import tvm
from .dilate import dilate
from .pad import pad
from .util import get_pad_tuple
......@@ -53,27 +52,44 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
"""
return declaration_conv2d_transpose_impl(Input, Filter, strides, padding, out_dtype)
def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype):
"""Implementation of conv2d transpose"""
def conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype):
"""Preprocess data and kernel to make the compute pattern
of conv2d_transpose the same as conv2d"""
batch, in_c, in_h, in_w = data.shape
_, out_c, filter_h, filter_w = kernel.shape
stride_h, stride_w = strides
# dilate stage
DilatedInput = dilate(data, [1, 1, stride_h, stride_w], name='DilatedInput')
# padding stage
# dilate data
data_dilate = dilate(data, [1, 1, stride_h, stride_w], name='data_dilate')
# pad data
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
bpad_top = filter_h - 1 - fpad_top
bpad_bottom = filter_h - 1 - fpad_bottom
bpad_left = filter_w - 1 - fpad_left
bpad_right = filter_w - 1 - fpad_right
PaddedInput = pad(DilatedInput, \
[0, 0, bpad_top, bpad_left], \
[0, 0, bpad_bottom, bpad_right], \
name='PaddedInput')
data_pad = pad(data_dilate, \
[0, 0, bpad_top, bpad_left], \
[0, 0, bpad_bottom, bpad_right], \
name='data_pad')
# transform kernel layout from IOHW to OIHW, and rotate kernel by 180 degrees
kernel_transform = tvm.compute((out_c, in_c, filter_h, filter_w), \
lambda o, i, h, w: kernel[i][o][filter_h-1-h][filter_w-1-w], \
name='kernel_transform')
return data_pad, kernel_transform
def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype):
"""Implementation of conv2d transpose"""
data_pad, kernel_transform = \
conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype)
batch, in_c, in_h, in_w = data_pad.shape
out_c, _, filter_h, filter_w = kernel_transform.shape
stride_h, stride_w = strides
# convolution stage
out_c = simplify(out_c)
out_h = simplify((in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h)
out_w = simplify((in_w - 1) * stride_w - fpad_left - fpad_right + filter_w)
out_h = simplify(in_h - filter_h + 1)
out_w = simplify(in_w - filter_w + 1)
dc = tvm.reduce_axis((0, in_c), name='dc')
dh = tvm.reduce_axis((0, filter_h), name='dh')
dw = tvm.reduce_axis((0, filter_w), name='dw')
......@@ -81,8 +97,8 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype)
Output = tvm.compute(
(batch, out_c, out_h, out_w),
lambda b, c, h, w: tvm.sum(
PaddedInput[b, dc, h+dh, w+dw].astype(out_dtype) *
kernel[dc, c, filter_h-1-dh, filter_w-1-dw].astype(out_dtype),
data_pad[b, dc, h+dh, w+dw].astype(out_dtype) *
kernel_transform[c, dc, dh, dw].astype(out_dtype),
axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
return Output
......@@ -15,5 +15,5 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_NCHWc
from .dense import _schedule_dense, _schedule_dense_pack, _schedule_dense_nopack
from .batch_matmul import schedule_batch_matmul
from .roi_align import roi_align_nchw
from .conv2d_transpose import schedule_conv2d_transpose
from .conv2d_transpose import _schedule_conv2d_transpose_nchw
from .sparse import *
......@@ -73,7 +73,7 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu
if DOPAD:
s[A0].compute_inline()
batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
parallel_axis = s[A1].fuse(ic_chunk, ih)
parallel_axis = s[A1].fuse(batch, ic_chunk, ih)
s[A1].parallel(parallel_axis)
# schedule kernel pack
......@@ -115,7 +115,7 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh_outer)
parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
......
......@@ -72,7 +72,7 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu
if DOPAD:
s[A0].compute_inline()
batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
parallel_axis = s[A1].fuse(ic_chunk, ih)
parallel_axis = s[A1].fuse(batch, ic_chunk, ih)
s[A1].parallel(parallel_axis)
# schedule kernel pack
......@@ -117,7 +117,7 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu
ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
parallel_axis = s[O].fuse(oc_chunk, oh)
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
s[C].compute_at(s[O], parallel_axis)
s[O].vectorize(oc_block)
......@@ -135,7 +135,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, ic_block = s[A].op.axis
parallel_axis = s[A].fuse(ic_chunk, ih)
parallel_axis = s[A].fuse(batch, ic_chunk, ih)
s[A].parallel(parallel_axis)
# schedule 5-D NCHW[x]c conv
......@@ -194,7 +194,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
A = data
if isinstance(s[A].op, tvm.tensor.ComputeOp):
batch, ic_chunk, ih, iw, _ = s[A].op.axis
parallel_axis = s[A].fuse(ic_chunk, ih)
parallel_axis = s[A].fuse(batch, ic_chunk, ih)
s[A].parallel(parallel_axis)
# schedule 5-D NCHW[x]c conv
......
......@@ -16,50 +16,60 @@
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D Transpose schedule on x86"""
import tvm
from tvm import autotvm
from .. import generic, tag
from ..nn.conv2d_transpose import conv2d_transpose_nchw, declaration_conv2d_transpose_impl
from .. import generic
from ..util import get_const_tuple, traverse_inline
from ..nn import conv2d_transpose_nchw_preprocess, conv2d_transpose_nchw
from . import conv2d_avx_1x1, conv2d_avx_common
from .conv2d import _declaration_conv_impl, \
_create_tuning_space as _create_tuning_space_conv2d, \
_get_default_config as _get_default_config_conv2d
@autotvm.register_topi_compute(conv2d_transpose_nchw, 'cpu', ['direct'])
def _declaration_conv2d_transpose(cfg, data, kernel, strides, padding, out_dtype):
# TODO cfg is not used for now
return declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype)
def _conv2d_transpose_nchw(cfg, data, kernel, strides, padding, out_dtype):
data_pad, kernel_transform = \
conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype)
# reuse conv2d implementation
_create_tuning_space_conv2d(cfg, data_pad, kernel_transform, strides=(1, 1), \
padding=(0, 0), dilation=(1, 1), layout="NCHW")
if cfg.is_fallback:
_get_default_config_conv2d(cfg, data_pad, kernel_transform, strides=(1, 1), \
padding=(0, 0), out_dtype=out_dtype, layout='NCHW')
return _declaration_conv_impl(cfg, data_pad, kernel_transform, strides=(1, 1), \
padding=(0, 0), dilation=(1, 1), layout="NCHW", \
out_dtype=out_dtype)
@autotvm.register_topi_schedule(generic.schedule_conv2d_transpose_nchw, 'cpu', ['direct'])
def schedule_conv2d_transpose(cfg, outs):
def _schedule_conv2d_transpose_nchw(cfg, outs):
"""Create schedule for tensors"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_injective(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d_transpose_nchw' in op.tag:
C = op.output(0)
N, OC, OH, OW = C.op.axis
rc, ry, rx = C.op.reduce_axis
OH, oh = s[C].split(OH, factor=2)
OC, oc = s[C].split(OC, factor=32)
IC, ic = s[C].split(rc, factor=32)
s[C].reorder(N, OC, OH, OW, oc, IC, ry, rx, ic)
N = s[C].fuse(N, OC)
s[C].vectorize(oc)
s[C].parallel(N)
scheduled_ops.append(op)
def _callback(op):
# reuse conv2d schedule
if 'conv2d_nchw' in op.tag:
output = op.output(0)
conv_out = op.input_tensors[0]
# retrieve data
data_vec = conv_out.op.input_tensors[0]
data_pad = data_vec.op.input_tensors[0]
data_dilate = data_pad.op.input_tensors[0]
s[data_dilate].compute_inline()
# retrieve kernel
kernel_vec = conv_out.op.input_tensors[1]
kernel_transform = kernel_vec.op.input_tensors[0]
s[kernel_transform].compute_inline()
# call conv2d schedule
_, _, kh, kw = get_const_tuple(kernel_transform.shape)
is_kernel_1x1 = kh == 1 and kw == 1
args = [s, cfg, data_dilate, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]]
if is_kernel_1x1:
conv2d_avx_1x1._schedule_conv(*args)
else:
conv2d_avx_common._schedule_conv(*args)
traverse(outs[0].op)
traverse_inline(s, outs[0].op, _callback)
return s
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment