Commit 9151d435 by Alex Gladkov Committed by Yao Wang

Additional MXNet Convolution and Deconvolution tests (#4026)

Add different batch sizes and channel numbers to
MXNet Convolution and Deconvolution tests.
parent 18188f4b
......@@ -833,7 +833,7 @@ def test_forward_slice():
def test_forward_convolution():
def verify(data_shape, kernel_size, stride, pad, num_filter):
weight_shape=(num_filter,1,) + kernel_size
weight_shape=(num_filter, data_shape[1],) + kernel_size
x = np.random.uniform(size=data_shape).astype("float32")
weight = np.random.uniform(size=weight_shape).astype("float32")
bias = np.random.uniform(size=num_filter).astype("float32")
......@@ -852,11 +852,17 @@ def test_forward_convolution():
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(1,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(20,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(1, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
def test_forward_deconvolution():
def verify(data_shape, kernel_size, stride, pad, num_filter):
weight_shape=(1, num_filter) + kernel_size
weight_shape=(data_shape[1], num_filter) + kernel_size
x = np.random.uniform(size=data_shape).astype("float32")
weight = np.random.uniform(size=weight_shape).astype("float32")
bias = np.random.uniform(size=num_filter).astype("float32")
......@@ -875,7 +881,13 @@ def test_forward_deconvolution():
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(1,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(20,8,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(1, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(20, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment