Commit d184d2f8 by Neo Chien Committed by Zhi

[Relay][Frontend][Tensorflow] Fix type assignment for operator 'tf.range' (#4294)

parent 62521453
......@@ -1075,6 +1075,7 @@ def _rank():
return _impl
def _range():
def _impl(inputs, attr, params):
start = _get_param(params, inputs[0])[0]
......@@ -1082,7 +1083,7 @@ def _range():
if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \
else params.pop('Rank').asnumpy()[0]
delta = _get_param(params, inputs[2])[0]
dtype = attr['dtype'].name if 'dtype' in attr else "int32"
dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype)
return AttrCvt(
op_name="arange",
ignores=['Tidx'],
......@@ -1092,6 +1093,7 @@ def _range():
'dtype': dtype})([], attr)
return _impl
def _elu():
def _impl(inputs, attr, params):
dtype = attr['T'].name
......@@ -1202,7 +1204,7 @@ def _topk():
raise tvm.error.OpAttributeInvalid(
'Attribute k must be positive in operator TopKV2')
if attr['sorted'] is False:
raise tvm.error.OpAttributeUnimplemented(
raise tvm.error.OpAttributeUnImplemented(
'Attribute sorted=False is not supported in operator TopKV2')
return AttrCvt(op_name='topk',
ignores=['sorted'],
......
......@@ -1638,6 +1638,11 @@ def test_forward_range():
tf.range(1, 18, 3, name="range")
compare_tf_with_tvm([], [], 'range:0')
"""test type assignment for operator Range"""
tf.reset_default_graph()
tf.range(1, 256 + 1, 1, dtype=tf.float32)
compare_tf_with_tvm([], [], 'range:0')
#######################################################################
# Pad
# ---
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment