Commit 654192de by lixiaoquan Committed by Tianqi Chen

Fix a tensorflow test bug. (#3165)

Length of input_shape isn't always 4.
parent 95a323aa
......@@ -185,7 +185,7 @@ def _test_pooling_iteration(input_shape, **kwargs):
def _test_pooling(input_shape, **kwargs):
_test_pooling_iteration(input_shape, **kwargs)
if is_gpu_available():
if is_gpu_available() and (len(input_shape) == 4):
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
kwargs['data_format'] = 'NCHW'
_test_pooling_iteration(input_shape, **kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment