Commit 1ae44cf0 by Alex Gladkov Committed by Wuwei Lin

Fix Tensorflow conv3d pad bug, add non-cubic data and kernel tests (#4772)

parent 4d4346d1
......@@ -498,7 +498,7 @@ def _conv3d(opname):
pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_v[0], pad_v[1], pad_h[1]]
attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
else:
msg = 'Value {} in attribute "padding" of operator Conv is not ' \
......@@ -509,7 +509,7 @@ def _conv3d(opname):
attr['kernel_layout'] = 'DHWIO' if attr['data_format'] == 'NDHWC' else 'OIDHW'
use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
channel_axis = 1 if attr['data_format'] == "NCDHW" else 3
channel_axis = 1 if attr['data_format'] == "NCDHW" else 4
# Ignore the new attributes from TF2.0, for now.
out = AttrCvt(
......
......@@ -25,10 +25,17 @@ from topi.util import get_const_tuple
def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
in_depth = in_height = in_width = in_size
if isinstance(in_size, tuple):
in_depth, in_height, in_width = in_size
else:
in_depth = in_height = in_width = in_size
if isinstance(kernel, tuple):
kernel_depth, kernel_height, kernel_width = kernel
else:
kernel_depth = kernel_height = kernel_width = kernel
A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
W = tvm.placeholder((kernel, kernel, kernel, in_channel, num_filter), name='W')
W = tvm.placeholder((kernel_depth, kernel_height, kernel_width, in_channel, num_filter), name='W')
B = topi.nn.conv3d_ndhwc(A, W, stride, padding, dilation)
a_shape = get_const_tuple(A.shape)
......@@ -74,6 +81,12 @@ def test_conv3d_ndhwc():
# dilation = 2
verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "SAME", dilation=2)
verify_conv3d_ndhwc(1, 1, (20, 256, 256), 32, (1, 3, 3), (1, 2, 2), "SAME")
verify_conv3d_ndhwc(1, 1, (20, 256, 256), 32,
(1, 6, 6), (1, 2, 2), (0, 2, 2))
verify_conv3d_ndhwc(1, 4, (20, 256, 256), 8,
(1, 5, 5), (1, 2, 2), (0, 2, 2))
if __name__ == "__main__":
test_conv3d_ndhwc()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment