Unverified Commit 8efd5460 by Samuel Committed by GitHub

[FRONTEND][TFLITE]Gather, StridedSlice op support added (#4788)

* [FRONTEND][TFLITE]Gather, StridedSlice op added

* Review comments fixed
parent ba382229
......@@ -292,6 +292,79 @@ def test_forward_topk():
_test_topk((3, 5, 7), 3)
#######################################################################
# Gather
# ------
def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False):
""" One iteration of Gather """
indices = np.asarray(indices).astype('int32')
data = np.random.uniform(1, 10, size=dshape)
data = data.astype(np.uint8) if quantized else data.astype(dtype)
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
if axis:
out = array_ops.gather(in_data, indices, axis=axis)
else:
out = array_ops.gather(in_data, indices) #tflite conversion fails for None axis
input_range = {'in_data': (-100, 100)} if quantized else None
try:
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out],
quantized=quantized, input_range=input_range)
except ValueError as e:
if not oob:
raise e
except Exception as e:
raise e
def test_forward_gather():
""" GATHER """
for quantized in [False, True]:
_test_gather((4,), [1], 0, 'float32', quantized)
_test_gather((4,), [1], None, 'int32', quantized)
_test_gather((1, 4), [0], 0, 'int32', quantized)
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized)
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized)
_test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized)
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized)
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized)
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized)
_test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized)
_test_gather((4,), [16], 0, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True)
_test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)
#######################################################################
# StridedSlice
# ------------
def _test_stridedslice(ip_shape, begin, end, stride, dtype,
begin_mask=0, end_mask=0, new_axis_mask=0,
shrink_axis_mask=0, ellipsis_mask=0, quantized=False):
""" One iteration of a Stridedslice """
data = np.random.uniform(size=ip_shape).astype(dtype)
data = data.astype(np.uint8) if quantized else data.astype(dtype)
with tf.Graph().as_default():
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
out = array_ops.strided_slice(in_data, begin, end, stride,
begin_mask=begin_mask,
end_mask=end_mask,
new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask,
ellipsis_mask=ellipsis_mask)
input_range = {'in_data': (-100, 100)} if quantized else None
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=quantized,
input_range=input_range)
def test_forward_stridedslice():
'''test StridedSlice'''
for quantized in [False, True]:
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1, quantized=quantized)
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32', quantized=quantized)
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=0, quantized=quantized)
_test_stridedslice((4, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2, quantized=quantized)
#######################################################################
# transpose
# ---------
......@@ -1855,6 +1928,8 @@ if __name__ == '__main__':
test_forward_squeeze()
test_forward_slice()
test_forward_topk()
test_forward_gather()
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment