Commit 83da72f2 by Wang Yucheng Committed by Tianqi Chen

[Relay][Frontend][TFLite] Add constant input support for elemwise ops (#4666)

* [Relay][Frontend][TFLite] Add constant input support for elemwise ops

* modify in tflite.py
parent 02c6767a
......@@ -611,7 +611,16 @@ class OperatorConverter(object):
assert len(input_tensors) == 2, "input tensors length should be 2"
lhs_tensor = input_tensors[0]
lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
if self.has_expr(lhs_tensor.tensor_idx):
# In most cases, we can assume that TOCO fuses elemwise operators
# with constants - it means both will be tensors.
lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
else:
# However, in some corner cases, the elemwise operator is not fused,
# we can receive as constant.
lhs_type_str = self.get_tensor_type_str(lhs_tensor.tensor.Type())
lhs_expr = self.exp_tab.new_const(self.get_tensor_value(lhs_tensor),
dtype=lhs_type_str)
rhs_tensor = input_tensors[1]
if self.has_expr(rhs_tensor.tensor_idx):
......
......@@ -787,6 +787,24 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data[0], ['in_0:0'], in_data, [out])
# Test with constant and tensor
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]
if quantized:
inq_const = tf.quantization.fake_quant_with_min_max_args(data[0], min=-100, max=100, name="const_tensor")
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_1")]
# the 1st tensor is treated as constant and directly added as part of the operation
out = math_op(ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'), inq_data)
out = with_fused_activation_function(out, fused_activation_function)
out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
compare_tflite_with_tvm(data[1], ['inq_1:0'], inq_data, [out], quantized=True)
else:
out = math_op(ops.convert_to_tensor(data[0], dtype=data[0].dtype), in_data[0])
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data[1], ['in_1:0'], in_data, [out])
#######################################################################
# Add
# ---
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment