Commit 9049d669 by Alexander Pivovarov Committed by Yao Wang

[Relay][Legalize] Legalize conv2d_transpose for NHWC (#4399)

parent 87bd799e
...@@ -278,6 +278,26 @@ def schedule_conv2d_transpose(attrs, outs, target): ...@@ -278,6 +278,26 @@ def schedule_conv2d_transpose(attrs, outs, target):
return topi.generic.schedule_conv2d_transpose_nchw(outs) return topi.generic.schedule_conv2d_transpose_nchw(outs)
@reg.register_legalize("nn.conv2d_transpose")
def legalize_conv2d_transpose(attrs, inputs, types):
"""Legalize conv2d_transpose op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current Transposed convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)
reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
# bias_add # bias_add
......
...@@ -284,3 +284,8 @@ class BinaryConv2DAttrs(Attrs): ...@@ -284,3 +284,8 @@ class BinaryConv2DAttrs(Attrs):
@register_relay_attr_node @register_relay_attr_node
class BinaryDenseAttrs(Attrs): class BinaryDenseAttrs(Attrs):
"""Attributes used in bitserial dense operators""" """Attributes used in bitserial dense operators"""
@register_relay_attr_node
class Conv2DTransposeAttrs(Attrs):
"""Attributes used in Transposed Conv2D operators"""
...@@ -311,8 +311,8 @@ def test_conv2d_transpose_infer_type(): ...@@ -311,8 +311,8 @@ def test_conv2d_transpose_infer_type():
(10, 15, 3, 3), "float32") (10, 15, 3, 3), "float32")
# infer by shape of w, mixed precision # infer by shape of w, mixed precision
n, c, h, w = tvm.var("n"), 10, 10, 12 n, h, w, c = tvm.var("n"), 10, 10, 12
x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32")) w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32"))
y = relay.nn.conv2d_transpose(x, w, y = relay.nn.conv2d_transpose(x, w,
output_padding=(1, 1), output_padding=(1, 1),
...@@ -323,7 +323,7 @@ def test_conv2d_transpose_infer_type(): ...@@ -323,7 +323,7 @@ def test_conv2d_transpose_infer_type():
(n, 15, 15, 11), "float32") (n, 15, 15, 11), "float32")
def test_conv2d_transpose_run(): def test_conv2d_transpose_nchw_run():
dshape = (1, 3, 18, 18) dshape = (1, 3, 18, 18)
kshape = (3, 10, 3, 3) kshape = (3, 10, 3, 3)
oshape = (1, 10, 37, 37) oshape = (1, 10, 37, 37)
...@@ -348,6 +348,33 @@ def test_conv2d_transpose_run(): ...@@ -348,6 +348,33 @@ def test_conv2d_transpose_run():
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
def test_conv2d_transpose_nhwc_run():
dshape_nhwc = (1, 18, 18, 3)
kshape_hwoi = (3, 3, 10, 3)
oshape_nhwc = (1, 37, 37, 10)
x = relay.var("x", shape=dshape_nhwc)
w = relay.var("w")
# kshape and kernel_layout should have swapped IO.
# kshape is HWOI and kernel_layout is HWIO
y = relay.nn.conv2d_transpose(x, w,
channels=10, kernel_size=(3, 3), strides=(2, 2),
padding=(1, 1), output_padding=(2, 2),
data_layout="NHWC", kernel_layout="HWIO")
func = relay.Function([x, w], y)
dtype = "float32"
data = np.random.uniform(size=dshape_nhwc).astype(dtype)
kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
# use true kshape layout here - HWOI
c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1)
d_np = np.zeros(shape=oshape_nhwc)
d_np[:,0:c_np.shape[1],0:c_np.shape[2],:] = c_np
ref_res = d_np
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data, kernel)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
def test_upsampling_infer_type(): def test_upsampling_infer_type():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
...@@ -819,7 +846,8 @@ if __name__ == "__main__": ...@@ -819,7 +846,8 @@ if __name__ == "__main__":
test_pad_infer_type() test_pad_infer_type()
test_pad_run() test_pad_run()
test_conv2d_transpose_infer_type() test_conv2d_transpose_infer_type()
test_conv2d_transpose_run() test_conv2d_transpose_nchw_run()
test_conv2d_transpose_nhwc_run()
test_conv2d_run() test_conv2d_run()
test_conv2d_winograd() test_conv2d_winograd()
test_bitserial_conv2d_infer_type() test_bitserial_conv2d_infer_type()
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
"""Transposed 2D convolution operators (sometimes called Deconvolution).""" """Transposed 2D convolution operators (sometimes called Deconvolution)."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from tvm import relay
from .dilate import dilate from .dilate import dilate
from .pad import pad from .pad import pad
from .util import get_pad_tuple from .util import get_pad_tuple
...@@ -102,3 +103,62 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype) ...@@ -102,3 +103,62 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype)
axis=[dc, dh, dw]), tag="conv2d_transpose_nchw") axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
return Output return Output
@tvm.target.generic_func
def conv2d_transpose_legalize(attrs, inputs, types):
"""Legalizes Transposed 2D convolution op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current Transposed 2D convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
if attrs['data_layout'] == 'NHWC':
data, kernel = inputs
kernel_layout = attrs['kernel_layout']
# Convert Kernel layout to IOHW
# kernel_layout is different from input kernel layout - IO is swapped
if kernel_layout == 'HWIO':
# input kernel layout is swapped to HWOI
# output kernel layout will be IOHW
kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
elif kernel_layout == 'HWOI':
# input kernel layout is swapped to HWIO
# output kernel layout will be IOHW
kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
elif kernel_layout == 'IOHW':
# input kernel layout is swapped to OIHW
# output kernel layout will be IOHW
kernel = relay.transpose(kernel, axes=(1, 0, 2, 3))
elif kernel_layout == 'OIHW':
# input kernel layout is swapped to IOHW
# output kernel layout will be IOHW
pass
else:
# Skip legalize. Let relay.nn.conv2d_transpose to handle the case
return None
# Set new attrs for conv2d_transpose.
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['data_layout'] = 'NCHW'
# layout of kernel should be IOHW, but kernel_layout should be swapped - OIHW
new_attrs['kernel_layout'] = 'OIHW'
# Convert data to NCHW.
data = relay.transpose(data, axes=(0, 3, 1, 2))
deconv = relay.nn.conv2d_transpose(data, kernel, **new_attrs)
# Convert back to original NHWC layout.
out = relay.transpose(deconv, axes=(0, 2, 3, 1))
return out
return None
...@@ -24,7 +24,7 @@ from __future__ import absolute_import as _abs ...@@ -24,7 +24,7 @@ from __future__ import absolute_import as _abs
from .conv2d_hwcn_python import conv2d_hwcn_python from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python from .conv2d_nchw_python import conv2d_nchw_python
from .conv2d_nhwc_python import conv2d_nhwc_python from .conv2d_nhwc_python import conv2d_nhwc_python
from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python from .dilate_python import dilate_python
......
...@@ -73,3 +73,50 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding): ...@@ -73,3 +73,50 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
padded_a_np[n, c], w_np[c, f], mode='valid') padded_a_np[n, c], w_np[c, f], mode='valid')
b_np[n, f] += out b_np[n, f] += out
return b_np return b_np
def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding):
"""Transposed convolution operator in NHWC layout.
Parameters
----------
a_nhwc : numpy.ndarray
4-D with shape [batch, in_height, in_width, in_channel]
weight : numpy.ndarray
4-D in formats HWIO, HWOI, OIHW or IOHW
weight_format : str
['HWIO', 'HWOI', 'OIHW', 'IOHW']
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
assert a_nhwc.ndim == 4, "a_nhwc number of dimensions should be 4"
assert weight.ndim == 4, "weight number of dimensions should be 4"
a_nchw = np.transpose(a_nhwc, (0, 3, 1, 2))
# conv2d_transpose_nchw_python needs kernel layout to be IOHW
if weight_format == 'HWIO':
w_iohw = np.transpose(weight, (2, 3, 0, 1))
elif weight_format == 'HWOI':
w_iohw = np.transpose(weight, (3, 2, 0, 1))
elif weight_format == 'OIHW':
w_iohw = np.transpose(weight, (1, 0, 2, 3))
elif weight_format == 'IOHW':
w_iohw = weight
else:
raise ValueError('Valid weight_formats are HWIO, HWOI, OIHW or IOHW')
res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding)
res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
return res_nhwc
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment