Commit d184d2f8 by Neo Chien Committed by Zhi

[Relay][Frontend][Tensorflow] Fix type assignment for operator 'tf.range' (#4294)

parent 62521453
...@@ -1075,6 +1075,7 @@ def _rank(): ...@@ -1075,6 +1075,7 @@ def _rank():
return _impl return _impl
def _range(): def _range():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
start = _get_param(params, inputs[0])[0] start = _get_param(params, inputs[0])[0]
...@@ -1082,7 +1083,7 @@ def _range(): ...@@ -1082,7 +1083,7 @@ def _range():
if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \ if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \
else params.pop('Rank').asnumpy()[0] else params.pop('Rank').asnumpy()[0]
delta = _get_param(params, inputs[2])[0] delta = _get_param(params, inputs[2])[0]
dtype = attr['dtype'].name if 'dtype' in attr else "int32" dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype)
return AttrCvt( return AttrCvt(
op_name="arange", op_name="arange",
ignores=['Tidx'], ignores=['Tidx'],
...@@ -1092,6 +1093,7 @@ def _range(): ...@@ -1092,6 +1093,7 @@ def _range():
'dtype': dtype})([], attr) 'dtype': dtype})([], attr)
return _impl return _impl
def _elu(): def _elu():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
dtype = attr['T'].name dtype = attr['T'].name
...@@ -1202,7 +1204,7 @@ def _topk(): ...@@ -1202,7 +1204,7 @@ def _topk():
raise tvm.error.OpAttributeInvalid( raise tvm.error.OpAttributeInvalid(
'Attribute k must be positive in operator TopKV2') 'Attribute k must be positive in operator TopKV2')
if attr['sorted'] is False: if attr['sorted'] is False:
raise tvm.error.OpAttributeUnimplemented( raise tvm.error.OpAttributeUnImplemented(
'Attribute sorted=False is not supported in operator TopKV2') 'Attribute sorted=False is not supported in operator TopKV2')
return AttrCvt(op_name='topk', return AttrCvt(op_name='topk',
ignores=['sorted'], ignores=['sorted'],
......
...@@ -1638,6 +1638,11 @@ def test_forward_range(): ...@@ -1638,6 +1638,11 @@ def test_forward_range():
tf.range(1, 18, 3, name="range") tf.range(1, 18, 3, name="range")
compare_tf_with_tvm([], [], 'range:0') compare_tf_with_tvm([], [], 'range:0')
"""test type assignment for operator Range"""
tf.reset_default_graph()
tf.range(1, 256 + 1, 1, dtype=tf.float32)
compare_tf_with_tvm([], [], 'range:0')
####################################################################### #######################################################################
# Pad # Pad
# --- # ---
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment