Commit 4f92cfe5 by Alex Gladkov Committed by Wuwei Lin

Improve CUDA conv2d_transpose_nchw (#4762)

- combine pad and dilate;
- fix for the issue https://discuss.tvm.ai/t/compile-error-for-cuda-target/4164
- fix for the issue https://github.com/apache/incubator-tvm/pull/4472
parent cf3e7865
......@@ -21,11 +21,11 @@ import tvm
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn, generic
from ..util import equal_const_int, get_const_tuple, traverse_inline
from ..util import get_const_tuple, traverse_inline
@autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct")
def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
def conv2d_transpose_nchw_cuda(cfg, data, kernel, stride, padding, out_dtype):
"""Transposed 2D convolution nchw forward operator.
Parameters
......@@ -48,67 +48,58 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_c, in_h, in_w = get_const_tuple(Input.shape)
_, out_c, filter_h, filter_w = get_const_tuple(Filter.shape)
stride_h, stride_w = strides
# attach stride info to config, this is used in schedule space definition
cfg.stride = strides
# padding stage
fpad_top, fpad_left, fpad_bottom, fpad_right = nn.get_pad_tuple(padding, (filter_h, filter_w))
bpad_top = filter_h - 1 - fpad_top
bpad_bottom = filter_h - 1 - fpad_bottom
bpad_left = filter_w - 1 - fpad_left
bpad_right = filter_w - 1 - fpad_right
# padding stage
FirstPad = nn.pad(Input,
[0, 0, (bpad_top + stride_h - 1) // stride_h,
(bpad_left + stride_w - 1) // stride_w],
[0, 0, (bpad_bottom + stride_h - 1) // stride_h,
(bpad_right + stride_w - 1) // stride_w], name='FirstPad')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# remove extra padding introduced by dilatation
border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)
# dilation stage
data = FirstPad
strides = [1, 1, stride_h, stride_w]
n = len(data.shape)
def _dilate(*indices):
not_zero = []
index_tuple = []
for i in range(n):
if not equal_const_int(strides[i], 1):
index_tuple.append(idxdiv(indices[i], strides[i]))
not_zero.append(idxmod(indices[i], strides[i]).equal(0))
else:
index_tuple.append(indices[i])
if not_zero:
not_zero = tvm.all(*not_zero)
return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
return data(*index_tuple)
# convolution stage
out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
dc = tvm.reduce_axis((0, in_c), name='dc')
dh = tvm.reduce_axis((0, filter_h), name='dh')
dw = tvm.reduce_axis((0, filter_w), name='dw')
Output = tvm.compute(
(batch, out_c, out_h, out_w),
batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape)
_, out_channels, kernel_height, kernel_width = get_const_tuple(kernel.shape)
stride_height, stride_width = stride
cfg.stride = stride
pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
padding, (kernel_height, kernel_width))
out_width = (inp_width - 1) * stride_width + \
kernel_width - pad_left - pad_right
pad_left = kernel_width - 1 - pad_left
pad_right = kernel_width - 1 - pad_right
dilated_width = stride_width * (inp_width - 1) + 1
out_height = (inp_height - 1) * stride_height + \
kernel_height - pad_top - pad_bottom
pad_top = kernel_height - 1 - pad_top
pad_bottom = kernel_height - 1 - pad_bottom
dilated_height = stride_height * (inp_height - 1) + 1
# compute pad
data = tvm.compute(
(batch, inp_channels,
pad_top + dilated_height + pad_bottom,
pad_left + dilated_width + pad_right),
lambda n, c, y, x: tvm.if_then_else(
tvm.all(x >= pad_left,
x < pad_left + dilated_width,
tvm.indexmod(x - pad_left, stride_width).equal(0),
y >= pad_top,
y < pad_top + dilated_height,
tvm.indexmod(y - pad_top, stride_height).equal(0)),
data[n, c,
tvm.indexdiv(y - pad_top, stride_height),
tvm.indexdiv(x - pad_left, stride_width)],
tvm.const(0., "float32")),
name='data_pad')
# compute transposed conv
dc = tvm.reduce_axis((0, inp_channels), name='dc')
dh = tvm.reduce_axis((0, kernel_height), name='dh')
dw = tvm.reduce_axis((0, kernel_width), name='dw')
data_out = tvm.compute(
(batch, out_channels, out_height, out_width),
lambda b, c, h, w: tvm.sum(
_dilate(b, dc, h + dh + border_h, w + dw + border_w).astype(out_dtype) *
Filter[dc, c, filter_h - 1 - dh, filter_w - 1 - dw].astype(out_dtype),
data[b, dc, h + dh, w + dw].astype(out_dtype) *
kernel[dc,
c,
kernel_height - 1 - dh,
kernel_width - 1 - dw].astype(out_dtype),
axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
return Output
return data_out
@autotvm.task.register_topi_schedule(generic.schedule_conv2d_transpose_nchw,
['cuda', 'gpu'], 'direct')
......@@ -140,6 +131,7 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
else:
cfg["tile_n"] = SplitEntity([1, 1, 1, 1])
# split F (output channel dimension)
if F > 1:
cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
# split Y (height dimension)
y_split_factor = 1
......@@ -185,24 +177,6 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
cfg.define_knob("unroll_explicit", [0, 1])
if cfg.is_fallback:
ko = int(kernel.shape[1])
kh = int(kernel.shape[2])
kw = int(kernel.shape[3])
stride_h, stride_w = cfg.stride
# Workaround to make CUDA compilation work. Issue #4470
# TODO make _fallback_schedule work for all kernel/strides combinations
# after issue #4470 is resolved
do_fallback = True
if ko == 1:
do_fallback = False
elif (kh, kw) == (1, 1):
do_fallback = True
elif (stride_h, stride_w) == (2, 2):
do_fallback = False
elif (kh, kw) == (stride_h, stride_w):
do_fallback = False
if do_fallback:
N, F, Y, X = get_const_tuple(conv.shape)
_fallback_schedule(N, F, Y, X)
......
......@@ -25,10 +25,13 @@ from topi.util import get_const_tuple
from common import get_all_backend
def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
in_height = in_width = in_size
in_height, in_width = in_size
kernel_height, kernel_width = kernel
stride_height, stride_width = stride
pad_top, pad_left, pad_bottom, pad_right = padding
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W')
W = tvm.placeholder((in_channel, num_filter, kernel_height, kernel_width), name='W')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
......@@ -51,7 +54,10 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
return
print("Running on target: %s" % device)
with tvm.target.create(device):
B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype)
B = topi.nn.conv2d_transpose_nchw(A, W,
[stride_height, stride_width],
[pad_top, pad_left, pad_bottom, pad_right],
A.dtype)
C = topi.nn.relu(B)
s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
......@@ -66,18 +72,21 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
func2(a, w, c)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in get_all_backend():
check_device(device)
def test_conv2d_transpose_nchw():
verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0)
verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1)
verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0)
verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0)
verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1)
verify_conv2d_transpose_nchw(1, 3, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1))
verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (2, 2), (2, 2), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (1, 1), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (2, 2), (1, 1, 1, 1))
verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0))
verify_conv2d_transpose_nchw(16, 512, (8, 1), 128, (31, 1), (2, 1), (14, 0, 15, 0))
if __name__ == "__main__":
test_conv2d_transpose_nchw()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment