Unverified Commit 33260318 by Samuel Committed by GitHub

[TFLITE]TOP_K op parser support (#5051)

* [TFLITE]TOP_K op parser support

* Testcase updated
parent ae482a32
......@@ -129,6 +129,7 @@ class OperatorConverter(object):
'TAN': self.convert_tan,
'TANH':self.convert_tanh,
'TILE': self.convert_tile,
'TOPK_V2': self.convert_topk_v2,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'TRANSPOSE': self.convert_transpose,
'UNPACK': self.convert_unpack,
......@@ -1550,6 +1551,24 @@ class OperatorConverter(object):
return out
def convert_topk_v2(self, op):
""" Convert TFLite TOPK_v2 """
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx
in_expr = self.get_expr(input_tensor_idx)
k = self.get_tensor_value(input_tensors[1])
out = _op.topk(in_expr, int(k))
return out
def convert_pool2d(self, op, pool_type):
"""pool2d implementation."""
try:
......
......@@ -273,6 +273,24 @@ def test_forward_slice():
_test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])
#######################################################################
# Topk
# ----
def _test_topk(in_shape, k=1):
""" One iteration of TOPK """
data = np.random.uniform(size=in_shape).astype('float32')
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = nn_ops.top_k(in_data, k, name='TopK')
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out[0]])
def test_forward_topk():
""" TOPK """
_test_topk((3,), 1)
_test_topk((3,), 3)
_test_topk((3, 5, 7), 3)
_test_topk((3, 5, 7), 3)
#######################################################################
# transpose
# ---------
......@@ -1775,6 +1793,7 @@ if __name__ == '__main__':
test_all_resize()
test_forward_squeeze()
test_forward_slice()
test_forward_topk()
test_forward_depthtospace()
test_forward_spacetodepth()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment