Commit 38151abd by Yong Wu Committed by Siva

[Relay][Frontend] Support TF Gather (#2935)

* [Relay][Frontend] Support TF Gather

* fix comments
parent 4968279f
...@@ -673,10 +673,13 @@ def _square(): ...@@ -673,10 +673,13 @@ def _square():
return _op.multiply(inputs[0], inputs[0]) return _op.multiply(inputs[0], inputs[0])
return _impl return _impl
def _gather_v2(): def _gather():
"Tensorflow now support only gatherv2" "GatherV2, Gather"
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
axis = 0
if len(inputs) > 2:
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
new_input = [] new_input = []
new_input.append(inputs.pop(0)) new_input.append(inputs.pop(0))
new_input.append(inputs.pop(0)) new_input.append(inputs.pop(0))
...@@ -1013,7 +1016,8 @@ _convert_map = { ...@@ -1013,7 +1016,8 @@ _convert_map = {
'Shape' : _shape(), 'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'), 'Sigmoid' : AttrCvt('sigmoid'),
'Fill' : _fill(), 'Fill' : _fill(),
'GatherV2' : _gather_v2(), 'GatherV2' : _gather(),
'Gather' : _gather(),
'StridedSlice' : _stridedSlice(), 'StridedSlice' : _stridedSlice(),
'LRN' : _lrn(), 'LRN' : _lrn(),
'Pad' : _pad('Pad'), 'Pad' : _pad('Pad'),
......
...@@ -19,8 +19,8 @@ from tensorflow.python.ops import math_ops ...@@ -19,8 +19,8 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.core.framework import graph_pb2
from distutils.version import LooseVersion
import tvm.relay.testing.tf as tf_testing import tvm.relay.testing.tf as tf_testing
####################################################################### #######################################################################
...@@ -473,11 +473,11 @@ def test_forward_stridedslice(): ...@@ -473,11 +473,11 @@ def test_forward_stridedslice():
####################################################################### #######################################################################
# Gather # Gather, GatherV2
# ------ # ----------------
def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype): def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
""" One iteration of a Gather """ """ One iteration of a GatherV2 """
tf.reset_default_graph() tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data") in_data = tf.placeholder(dtype, ip_shape, name="in_data")
...@@ -497,7 +497,7 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype): ...@@ -497,7 +497,7 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0') compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0')
def test_forward_gather(): def test_forward_gather():
'''test gather layer''' '''test GatherV2 layer'''
_test_gather((4,), (1,), 1, 0, 'int32') _test_gather((4,), (1,), 1, 0, 'int32')
_test_gather((4,), (1,), 1, 0, 'float32') _test_gather((4,), (1,), 1, 0, 'float32')
_test_gather((1,4), (1,), [0], 0, 'int32') _test_gather((1,4), (1,), [0], 0, 'int32')
...@@ -509,6 +509,44 @@ def test_forward_gather(): ...@@ -509,6 +509,44 @@ def test_forward_gather():
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32') _test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32') _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
def _test_gather_v1(ip_shape, indice_shape, indice_value, dtype):
""" One iteration of a Gather"""
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
indices = tf.placeholder("int32", indice_shape, name="indices")
tf.gather(in_data, indices)
np_data = np.random.uniform(size=ip_shape).astype(dtype)
def _fill_indices(indice_value):
indices = np.array(ip_shape, dtype=dtype)
if isinstance(indice_value, int):
indices = np.array([indice_value], dtype='int32')
else:
indices = np.asarray(indice_value, dtype='int32')
return indices
np_indices = _fill_indices(indice_value)
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'Gather:0')
def test_forward_gather_v1():
'''test gather layer'''
if tf.__version__ < LooseVersion('1.7'):
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
_test_gather_v1((4,), (1,), 1, 'int32')
_test_gather_v1((4,), (1,), 1, 'float32')
_test_gather_v1((1, 4), (1,), [0], 'int32')
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'int32')
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'int32')
_test_gather_v1((2, 2), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
_test_gather_v1((3, 3, 3), (1, 1, 2), [[[1, 0]]], 'int32')
_test_gather_v1((3, 3, 3), (1, 1, 2), [[[1, 0]]], 'int32')
_test_gather_v1((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 'float32')
####################################################################### #######################################################################
# Split # Split
# ----- # -----
...@@ -1213,6 +1251,7 @@ if __name__ == '__main__': ...@@ -1213,6 +1251,7 @@ if __name__ == '__main__':
test_forward_crop() test_forward_crop()
test_forward_pad() test_forward_pad()
test_forward_gather() test_forward_gather()
test_forward_gather_v1()
test_forward_stridedslice() test_forward_stridedslice()
test_forward_split() test_forward_split()
test_forward_unstack() test_forward_unstack()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment