Commit 5f79521b authored by Yuwei Hu, committed by Tianqi Chen

[TOPI] add conv2d_transpose_nchw (#586)

parent 25f95766
......@@ -12,3 +12,4 @@ from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_global_pool
+from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
#pylint: disable=invalid-name
"""Schedule for conv2d_transpose_nchw with auto fusion"""
import tvm
from .. import util
from .. import tag
from .. import generic
from .conv2d_nchw import conv2d_224_3_64, conv2d_56_64_128, conv2d_14_256_256, conv2d_56_64_64
def schedule_conv2d_transpose_small_batch(outs):
"""Create schedule for tensors or return error if batch size is larger than 1"""
s = tvm.create_schedule([x.op for x in outs])
def schedule(temp, Filter, Output):
"""Schedule conv2d_transpose_nchw"""
block_h = util.get_const_int(Output.shape[3])
block_w = util.get_const_int(temp.shape[1])
if block_h % 48 == 0:
block_h = 48
elif block_h % 32 == 0:
block_h = 32
if block_w % 48 == 0:
block_w = 48
elif block_w % 32 == 0:
block_w = 32
flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])
if flag > 768:
temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis
oic, iic = s[temp_G].split(ic, factor=4)
s[temp_G].reorder(i, h, w, oic, iic)
temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G])
elif 128 < flag < 512:
temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis
oic, iic = s[temp_G].split(ic, factor=4)
s[temp_G].reorder(i, oic, h, w, iic)
temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G])
elif util.get_const_int(Filter.shape[3]) == 7:
temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis
s[temp_G].split(w, factor=4)
temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G])
else:
s[temp].compute_inline()
temp_S = s.cache_read(temp, "shared", [Output])
temp_R = temp_S
Filter_S = s.cache_read(Filter, "shared", [Output])
if Output.op in s.outputs:
Out = Output
Out_L = s.cache_write(Out, "local")
else:
Out = outs[0].op.output(0)
s[Output].set_scope("local")
Out_L = Output
if util.get_const_int(Filter.shape[3]) == 7:
conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L)
elif 128 < flag < 512:
conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
elif flag >= 512:
conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L)
else:
conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule conv2d_transpose_nchw
if 'conv2d_transpose_nchw' in OP.tag:
temp = OP.input_tensors[0]
DilatedInput = temp.op.input_tensors[0]
s[DilatedInput].compute_inline()
Filter = OP.input_tensors[1]
Output = OP.output(0)
schedule(temp, Filter, Output)
traverse(outs[0].op)
return s
@generic.schedule_conv2d_transpose_nchw.register(["cuda", "gpu"])
def schedule_conv2d_transpose_nchw(outs):
"""Schedule for conv2d_transpose_nchw.
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_transpose_nchw
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv2d_transpose_nchw.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
batch_size = util.get_const_int(outs[0].op.output(0).shape[0])
if batch_size > 1:
raise RuntimeError("Batch size: %d is too large for this schedule" % batch_size)
return schedule_conv2d_transpose_small_batch(outs)
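For context (not part of the commit): the schedule registered above is reached through the topi.generic dispatcher whenever a CUDA or GPU target is active, which is how the test at the end of this diff drives it. A minimal sketch of that call path (assuming a CUDA-enabled TVM build), reusing the shapes of the first test case, batch 1, 3 input channels, 224x224 input, 32 filters of 3x3, stride 1, no padding:

import tvm
import topi

# Placeholders for the input [batch, in_channel, H, W] and the
# filter [num_filter, in_channel, kH, kW].
A = tvm.placeholder((1, 3, 224, 224), name='A')
W = tvm.placeholder((32, 3, 3, 3), name='W')
B = topi.nn.conv2d_transpose_nchw(A, W, [1, 1], 0)

with tvm.target.create("cuda"):
    # Dispatches to the schedule registered above for the "cuda"/"gpu" targets.
    s = topi.generic.schedule_conv2d_transpose_nchw([B])
    func = tvm.build(s, [A, W, B], "cuda")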
......@@ -8,7 +8,7 @@ def _default_schedule(outs, auto_inline):
target = tvm.target.current_target(allow_none=False)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if target.target_name != "llvm":
raise RuntimeError("schedule_pool not registered for '%s'" % target)
raise RuntimeError("schedule not registered for '%s'" % target)
s = tvm.create_schedule([x.op for x in outs])
if auto_inline:
x = outs[0]
......@@ -19,13 +19,13 @@ def _default_schedule(outs, auto_inline):
@tvm.target.generic_func
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d nchow
"""Schedule for conv2d_nchw
Parameters
----------
outs: Array of Tensor
-The computation graph description of reduce in the format
-of an array of tensors.
+The computation graph description of conv2d_nchw
+in the format of an array of tensors.
Returns
-------
......@@ -36,14 +36,32 @@ def schedule_conv2d_nchw(outs):
@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs):
"""Schedule for conv2d_transpose_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_transpose_nchw
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for conv2d nchow
"""Schedule for depthwise_conv2d_nchw
Parameters
----------
outs: Array of Tensor
-The computation graph description of reduce in the format
-of an array of tensors.
+The computation graph description of depthwise_conv2d_nchw
+in the format of an array of tensors.
Returns
-------
......@@ -55,12 +73,12 @@ def schedule_depthwise_conv2d_nchw(outs):
@tvm.target.generic_func
def schedule_depthwise_conv2d_nhwc(outs):
"""Schedule for depthwise nhcw conv2
"""Schedule for depthwise_conv2d_nhwc
Parameters
----------
outs: Array of Tensor
-The computation graph description of reduce in the format
-of an array of tensors.
+The computation graph description of depthwise_conv2d_nhwc
+in the format of an array of tensors.
Returns
-------
......@@ -77,8 +95,8 @@ def schedule_reduce(outs):
Parameters
----------
outs: Array of Tensor
-The computation graph description of reduce in the format
-of an array of tensors.
+The computation graph description of reduce
+in the format of an array of tensors.
Returns
-------
......@@ -95,8 +113,8 @@ def schedule_softmax(outs):
Parameters
----------
outs: Array of Tensor
-The computation graph description of reduce in the format
-of an array of tensors.
+The computation graph description of softmax
+in the format of an array of tensors.
Returns
-------
......@@ -113,8 +131,8 @@ def schedule_dense(outs):
Parameters
----------
outs: Array of Tensor
-The computation graph description of reduce in the format
-of an array of tensors.
+The computation graph description of dense
+in the format of an array of tensors.
Returns
-------
......@@ -131,8 +149,8 @@ def schedule_pool(outs):
Parameters
----------
outs: Array of Tensor
-The computation graph description of reduce in the format
-of an array of tensors.
+The computation graph description of pool
+in the format of an array of tensors.
Returns
-------
......@@ -149,8 +167,8 @@ def schedule_global_pool(outs):
Parameters
----------
outs: Array of Tensor
-The computation graph description of reduce in the format
-of an array of tensors.
+The computation graph description of global pool
+in the format of an array of tensors.
Returns
-------
......
......@@ -12,3 +12,4 @@ from .dense import *
from .mapping import *
from .pooling import *
from .softmax import *
+from .conv2d_transpose import *
# pylint: disable=invalid-name, unused-variable
"""Transposed 2D convolution operators (sometimes called Deconvolution)."""
from __future__ import absolute_import as _abs
import tvm
from .dilate import dilate
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify
def conv2d_transpose_nchw(Input, Filter, strides, padding):
"""Transposed 2D convolution nchw forward operator.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
Filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width]
strides : tuple of two ints
The spatial stride along height and width
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_c, in_h, in_w = Input.shape
out_c, _, filter_h, filter_w = Filter.shape
stride_h, stride_w = strides
# dilate stage
DilatedInput = dilate(Input, [1, 1, stride_h, stride_w], name='DilatedInput')
# padding stage
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
bpad_top = filter_h - 1 - fpad_top
bpad_bottom = filter_h - 1 - fpad_bottom
bpad_left = filter_w - 1 - fpad_left
bpad_right = filter_w - 1 - fpad_right
PaddedInput = pad(DilatedInput, \
[0, 0, bpad_top, bpad_left], \
[0, 0, bpad_bottom, bpad_right], \
name='PaddedInput')
# convolution stage
out_c = simplify(out_c)
out_h = simplify((in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h)
out_w = simplify((in_w - 1) * stride_w - fpad_left - fpad_right + filter_w)
dc = tvm.reduce_axis((0, in_c), name='dc')
dh = tvm.reduce_axis((0, filter_h), name='dh')
dw = tvm.reduce_axis((0, filter_w), name='dw')
Output = tvm.compute(
(batch, out_c, out_h, out_w),
lambda b, c, h, w: tvm.sum(
PaddedInput[b, dc, h+dh, w+dw] * Filter[c, dc, filter_h-1-dh, filter_w-1-dw],
axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
return Output
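As a quick sanity check of the output-size formula above (illustrative only, not part of the commit), take the fourth test case at the end of this diff: in_h = in_w = 32, stride = 2, a 5x5 filter and padding = 1, for which get_pad_tuple yields fpad_top = fpad_bottom = 1:

in_h, stride_h, filter_h = 32, 2, 5
fpad_top = fpad_bottom = 1
out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h  # 31 * 2 - 2 + 5
assert out_h == 65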
......@@ -6,6 +6,7 @@ from __future__ import absolute_import as _abs
from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
+from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
......@@ -60,5 +60,5 @@ def conv2d_nchw_python(a_np, w_np, stride, padding):
apad = a_np[n, c]
out = scipy.signal.convolve2d(
apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
-b_np[n, f] += out[::stride, ::stride]
+b_np[n, f] += out[::stride_h, ::stride_w]
return b_np
# pylint: disable=unused-variable
"""Transposed convolution in python"""
import numpy as np
import topi
from topi.nn.util import get_pad_tuple
def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
"""Transposed convolution operator in NCHW layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
w_np : numpy.ndarray
4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_c, in_h, in_w = a_np.shape
out_c, _, filter_h, filter_w = w_np.shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
# dilate stage
dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_h, stride_w])
# padding stage
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
bpad_top = filter_h - 1 - fpad_top
bpad_bottom = filter_h - 1 - fpad_bottom
bpad_left = filter_w - 1 - fpad_left
bpad_right = filter_w - 1 - fpad_right
padded_a_np = np.zeros((batch, in_c, dilated_a_np.shape[2]+bpad_top+bpad_bottom, \
dilated_a_np.shape[3]+bpad_left+bpad_right))
padded_a_np[:, :, bpad_top:dilated_a_np.shape[2]+bpad_top, \
bpad_left:dilated_a_np.shape[3]+bpad_left] = dilated_a_np
# convolution stage
rotated_w_np = np.rot90(w_np, k=2, axes=(2, 3))
b_np = topi.testing.conv2d_nchw_python(padded_a_np, rotated_w_np, stride=1, padding='VALID')
return b_np
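A minimal usage sketch of this reference implementation (not part of the commit; shapes follow the second test case at the end of this diff):

import numpy as np
import topi
import topi.testing

# 1x3x224x224 input, 32 filters of 3x3, stride 2, padding 1.
a_np = np.random.uniform(size=(1, 3, 224, 224)).astype("float32")
w_np = np.random.uniform(size=(32, 3, 3, 3)).astype("float32")
b_np = topi.testing.conv2d_transpose_nchw_python(a_np, w_np, 2, 1)
# out_h = (224 - 1) * 2 - 1 - 1 + 3 = 447
assert b_np.shape == (1, 32, 447, 447)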
"""Test code for transposed convolution."""
import numpy as np
import tvm
import topi
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], padding)
C = topi.nn.relu(B)
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_transpose.verify_conv2d_transpose_nchw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = topi.testing.conv2d_transpose_nchw_python(a_np, w_np, stride, padding)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
with tvm.build_config(auto_unroll_max_step=128,
unroll_explicit=(device != "cuda")):
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
check_device(device)
def test_conv2d_transpose_nchw():
verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0)
verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1)
verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0)
verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1)
if __name__ == "__main__":
test_conv2d_transpose_nchw()