Unverified Commit 0d1babce by MORITA Kazutaka Committed by GitHub

[FRONTEND][TENSORFLOW] Fix gather_nd indices (#5279)

* [FRONTEND][TENSORFLOW] Fix gather_nd indices

* retrigger CI
parent 00014e20
...@@ -1127,9 +1127,11 @@ def _gather(): ...@@ -1127,9 +1127,11 @@ def _gather():
def _gather_nd(): def _gather_nd():
"""GatherNd""" """GatherNd"""
def _impl(inputs, attr, params, mod): def _impl(inputs, attr, params, mod):
indices_dims = len(_infer_shape(inputs[1], mod))
indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims-1)))
return AttrCvt(op_name="gather_nd", return AttrCvt(op_name="gather_nd",
ignores=['Tindices', 'Tparams',\ ignores=['Tindices', 'Tparams',\
'Taxis', '_class'])(inputs, attr) 'Taxis', '_class'])([inputs[0], indices], attr)
return _impl return _impl
def _stridedSlice(): def _stridedSlice():
......
...@@ -1365,11 +1365,11 @@ def test_forward_gather(): ...@@ -1365,11 +1365,11 @@ def test_forward_gather():
def test_forward_gather_nd(): def test_forward_gather_nd():
"""test operator GatherNd""" """test operator GatherNd"""
np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(2, 2, 2)).astype(np.float32)
tf.reset_default_graph() tf.reset_default_graph()
with tf.Graph().as_default(): with tf.Graph().as_default():
in_data = tf.placeholder(tf.float32, (2, 2), name="in_data") in_data = tf.placeholder(tf.float32, (2, 2, 2), name="in_data")
tf.gather_nd(in_data, indices=[[1, 0], [0, 1]], name="gather_nd") tf.gather_nd(in_data, indices=[[1, 0, 0], [0, 0, 0]], name="gather_nd")
compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment