Commit 4f92cfe5 authored by Alex Gladkov, committed by Wuwei Lin

Improve CUDA conv2d_transpose_nchw (#4762)

- combine pad and dilate into a single compute stage (see the sketch below);
- fix for the issue reported at https://discuss.tvm.ai/t/compile-error-for-cuda-target/4164;
- fix for the issue described in https://github.com/apache/incubator-tvm/pull/4472.
parent cf3e7865
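The first bullet is the substance of the change: the old implementation padded the input with nn.pad and then ran a separate dilation stage, while the new one emits a single data_pad compute that pads and dilates in one conditional gather. A minimal NumPy sketch of the same idea (a standalone illustration, not code from this commit; the function name and shapes are made up):

import numpy as np

def pad_and_dilate(x, stride, pad):
    """Combine zero-padding and dilation of an NCHW tensor in one step.

    x: (N, C, H, W); stride: (sh, sw); pad: (top, left, bottom, right).
    """
    n, c, h, w = x.shape
    sh, sw = stride
    top, left, bottom, right = pad
    dil_h = sh * (h - 1) + 1   # spatial extent after dilation, before padding
    dil_w = sw * (w - 1) + 1
    out = np.zeros((n, c, top + dil_h + bottom, left + dil_w + right), x.dtype)
    # every stride-th position inside the dilated region maps back to the input;
    # everything else stays zero, which is what the fused data_pad stage computes
    out[:, :, top:top + dil_h:sh, left:left + dil_w:sw] = x
    return out

x = np.arange(2 * 3 * 4 * 4, dtype="float32").reshape(2, 3, 4, 4)
print(pad_and_dilate(x, stride=(2, 2), pad=(1, 1, 1, 1)).shape)  # (2, 3, 9, 9)

The data_pad stage in the diff below expresses the same mapping symbolically with tvm.if_then_else, tvm.indexmod and tvm.indexdiv instead of a strided slice.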
@@ -21,11 +21,11 @@ import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from .. import nn, generic
-from ..util import equal_const_int, get_const_tuple, traverse_inline
+from ..util import get_const_tuple, traverse_inline


 @autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct")
-def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
+def conv2d_transpose_nchw_cuda(cfg, data, kernel, stride, padding, out_dtype):
     """Transposed 2D convolution nchw forward operator.

     Parameters
@@ -48,67 +48,58 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
     Output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    batch, in_c, in_h, in_w = get_const_tuple(Input.shape)
-    _, out_c, filter_h, filter_w = get_const_tuple(Filter.shape)
-    stride_h, stride_w = strides
-
-    # attach stride info to config, this is used in schedule space definition
-    cfg.stride = strides
-
-    # padding stage
-    fpad_top, fpad_left, fpad_bottom, fpad_right = nn.get_pad_tuple(padding, (filter_h, filter_w))
-    bpad_top = filter_h - 1 - fpad_top
-    bpad_bottom = filter_h - 1 - fpad_bottom
-    bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
-
-    # padding stage
-    FirstPad = nn.pad(Input,
-                      [0, 0, (bpad_top + stride_h - 1) // stride_h,
-                       (bpad_left + stride_w - 1) // stride_w],
-                      [0, 0, (bpad_bottom + stride_h - 1) // stride_h,
-                       (bpad_right + stride_w - 1) // stride_w], name='FirstPad')
-
-    idxdiv = tvm.indexdiv
-    idxmod = tvm.indexmod
-    # remove extra padding introduced by dilatation
-    border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
-    border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)
-
-    # dilation stage
-    data = FirstPad
-    strides = [1, 1, stride_h, stride_w]
-    n = len(data.shape)
-
-    def _dilate(*indices):
-        not_zero = []
-        index_tuple = []
-        for i in range(n):
-            if not equal_const_int(strides[i], 1):
-                index_tuple.append(idxdiv(indices[i], strides[i]))
-                not_zero.append(idxmod(indices[i], strides[i]).equal(0))
-            else:
-                index_tuple.append(indices[i])
-        if not_zero:
-            not_zero = tvm.all(*not_zero)
-            return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
-        return data(*index_tuple)
-
-    # convolution stage
-    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
-    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
-    dc = tvm.reduce_axis((0, in_c), name='dc')
-    dh = tvm.reduce_axis((0, filter_h), name='dh')
-    dw = tvm.reduce_axis((0, filter_w), name='dw')
-    Output = tvm.compute(
-        (batch, out_c, out_h, out_w),
+    batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape)
+    _, out_channels, kernel_height, kernel_width = get_const_tuple(kernel.shape)
+    stride_height, stride_width = stride
+    cfg.stride = stride
+    pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
+        padding, (kernel_height, kernel_width))
+
+    out_width = (inp_width - 1) * stride_width + \
+        kernel_width - pad_left - pad_right
+    pad_left = kernel_width - 1 - pad_left
+    pad_right = kernel_width - 1 - pad_right
+    dilated_width = stride_width * (inp_width - 1) + 1
+
+    out_height = (inp_height - 1) * stride_height + \
+        kernel_height - pad_top - pad_bottom
+    pad_top = kernel_height - 1 - pad_top
+    pad_bottom = kernel_height - 1 - pad_bottom
+    dilated_height = stride_height * (inp_height - 1) + 1
+
+    # compute pad
+    data = tvm.compute(
+        (batch, inp_channels,
+         pad_top + dilated_height + pad_bottom,
+         pad_left + dilated_width + pad_right),
+        lambda n, c, y, x: tvm.if_then_else(
+            tvm.all(x >= pad_left,
+                    x < pad_left + dilated_width,
+                    tvm.indexmod(x - pad_left, stride_width).equal(0),
+                    y >= pad_top,
+                    y < pad_top + dilated_height,
+                    tvm.indexmod(y - pad_top, stride_height).equal(0)),
+            data[n, c,
+                 tvm.indexdiv(y - pad_top, stride_height),
+                 tvm.indexdiv(x - pad_left, stride_width)],
+            tvm.const(0., "float32")),
+        name='data_pad')
+
+    # compute transposed conv
+    dc = tvm.reduce_axis((0, inp_channels), name='dc')
+    dh = tvm.reduce_axis((0, kernel_height), name='dh')
+    dw = tvm.reduce_axis((0, kernel_width), name='dw')
+    data_out = tvm.compute(
+        (batch, out_channels, out_height, out_width),
         lambda b, c, h, w: tvm.sum(
-            _dilate(b, dc, h + dh + border_h, w + dw + border_w).astype(out_dtype) *
-            Filter[dc, c, filter_h - 1 - dh, filter_w - 1 - dw].astype(out_dtype),
+            data[b, dc, h + dh, w + dw].astype(out_dtype) *
+            kernel[dc,
+                   c,
+                   kernel_height - 1 - dh,
+                   kernel_width - 1 - dw].astype(out_dtype),
             axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
-    return Output
+    return data_out


 @autotvm.task.register_topi_schedule(generic.schedule_conv2d_transpose_nchw,
                                      ['cuda', 'gpu'], 'direct')
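Note that with the rewritten pads the padded extent above works out to out_height + kernel_height - 1 rows (and analogously for columns), which is exactly the index range swept by h + dh in the reduction, so the old border_h / border_w trimming is no longer needed. As a quick sanity check of the output-size relation out = (in - 1) * stride + kernel - pad_top - pad_bottom, applied to the asymmetric test shape added further down (an illustrative calculation, not code from the commit):

# out_height for the (16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0)) case
inp_height, stride_height, kernel_height = 8192, 2, 31
pad_top, pad_bottom = 14, 15
out_height = (inp_height - 1) * stride_height + kernel_height - pad_top - pad_bottom
print(out_height)  # 16384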
@@ -140,6 +131,7 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
         else:
             cfg["tile_n"] = SplitEntity([1, 1, 1, 1])
         # split F (output channel dimension)
-        cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
+        if F > 1:
+            cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
         # split Y (height dimension)
         y_split_factor = 1
@@ -185,24 +177,6 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
         cfg.define_knob("unroll_explicit", [0, 1])

         if cfg.is_fallback:
-            ko = int(kernel.shape[1])
-            kh = int(kernel.shape[2])
-            kw = int(kernel.shape[3])
-            stride_h, stride_w = cfg.stride
-            # Workaround to make CUDA compilation work. Issue #4470
-            # TODO make _fallback_schedule work for all kernel/strides combinations
-            # after issue #4470 is resolved
-            do_fallback = True
-            if ko == 1:
-                do_fallback = False
-            elif (kh, kw) == (1, 1):
-                do_fallback = True
-            elif (stride_h, stride_w) == (2, 2):
-                do_fallback = False
-            elif (kh, kw) == (stride_h, stride_w):
-                do_fallback = False
-
-            if do_fallback:
-                N, F, Y, X = get_const_tuple(conv.shape)
-                _fallback_schedule(N, F, Y, X)
+            N, F, Y, X = get_const_tuple(conv.shape)
+            _fallback_schedule(N, F, Y, X)
@@ -25,10 +25,13 @@ from topi.util import get_const_tuple
 from common import get_all_backend


 def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
-    in_height = in_width = in_size
+    in_height, in_width = in_size
+    kernel_height, kernel_width = kernel
+    stride_height, stride_width = stride
+    pad_top, pad_left, pad_bottom, pad_right = padding
     A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
-    W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W')
+    W = tvm.placeholder((in_channel, num_filter, kernel_height, kernel_width), name='W')
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -51,7 +54,10 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype)
+            B = topi.nn.conv2d_transpose_nchw(A, W,
+                                              [stride_height, stride_width],
+                                              [pad_top, pad_left, pad_bottom, pad_right],
+                                              A.dtype)
             C = topi.nn.relu(B)
             s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
             s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
@@ -66,18 +72,21 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
         func2(a, w, c)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

     for device in get_all_backend():
         check_device(device)


 def test_conv2d_transpose_nchw():
-    verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0)
-    verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1)
-    verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0)
-    verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0)
-    verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1)
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (2, 2), (2, 2), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (2, 2), (1, 1, 1, 1))
+    verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0))
+    verify_conv2d_transpose_nchw(16, 512, (8, 1), 128, (31, 1), (2, 1), (14, 0, 15, 0))


 if __name__ == "__main__":
     test_conv2d_transpose_nchw()
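For completeness, here is how the new operator could be exercised standalone on one of the asymmetric shapes above, mirroring what check_device() does in this test. This is a sketch against the pre-0.7 tvm/topi API used in these files and assumes a CUDA-enabled build with a visible GPU:

import numpy as np
import tvm
import topi

# shapes taken from the (16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0)) test case
A = tvm.placeholder((16, 32, 8192, 1), name='A')
W = tvm.placeholder((32, 8, 31, 1), name='W')
with tvm.target.create('cuda'):
    B = topi.nn.conv2d_transpose_nchw(A, W, [2, 1], [14, 0, 15, 0], A.dtype)
    s = topi.generic.schedule_conv2d_transpose_nchw([B])
func = tvm.build(s, [A, W, B], 'cuda')

ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(16, 32, 8192, 1)).astype(A.dtype), ctx)
w = tvm.nd.array(np.random.uniform(size=(32, 8, 31, 1)).astype(W.dtype), ctx)
b = tvm.nd.array(np.zeros((16, 8, 16384, 1), dtype=B.dtype), ctx)
func(a, w, b)
print(b.shape)  # (16, 8, 16384, 1): out_height matches the formula checked above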