Unverified Commit afb8bf06 by Samuel Committed by GitHub

[BUGFIX]bugfix in tensorflow space_to_batch_nd (#5175)

* [BUGFIX]bugfix in tensorflow space_to_batch_nd

* Test case added
parent b2a32ddf
......@@ -1516,7 +1516,7 @@ def _space_to_batch_nd():
paddings = _infer_value(inputs[2], params).asnumpy()
paddings = np.squeeze(paddings)
if len(paddings.shape) == 1:
paddings = np.expand_dims(paddings, exis=0)
paddings = np.expand_dims(paddings, axis=0)
paddings = paddings.tolist()
N = len(input_shape)
M = len(block_shape)
......
......@@ -593,6 +593,17 @@ def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
compare_tf_with_tvm(data, in_data.name, out.name)
def _test_space_to_batch_nd_infer_paddings(input_shape, block_shape, dtype='int32'):
data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
padding_np = np.array([0, 1]).astype(np.int32).reshape((1, 2))
with tf.Graph().as_default():
in_data = tf.placeholder(shape=input_shape, dtype=dtype)
const1 = tf.constant(padding_np, dtype=tf.int32)
# make paddings an input to tf.transpose, but not an input to the graph,
# so it can be extracted with infer_value_simulated
paddings = tf.reverse(const1, axis=[-1])
out = tf.space_to_batch_nd(in_data, block_shape, paddings)
compare_tf_with_tvm(data, in_data.name, out.name)
def test_forward_space_to_batch_nd():
# test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d
......@@ -637,6 +648,11 @@ def test_forward_space_to_batch_nd():
dtype='float64'
)
_test_space_to_batch_nd_infer_paddings(
input_shape=[2, 3, 2],
block_shape=[2]
)
#######################################################################
# BatchToSpaceND
# --------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment