Commit b67afcd6 by Yong Wu Committed by Tianqi Chen

[Relay] add ClipByValue and Neg in tf frontend converter (#3211)

parent 29ee8a23
...@@ -941,6 +941,13 @@ def _where(): ...@@ -941,6 +941,13 @@ def _where():
return AttrCvt(op_name="where")(inputs, attr) return AttrCvt(op_name="where")(inputs, attr)
return _impl return _impl
def _clip_by_value():
def _impl(inputs, attr, params):
a_min = params.pop(inputs[1].name_hint).asnumpy()[0]
a_max = params.pop(inputs[2].name_hint).asnumpy()[0]
return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
return _impl
def _reverse_v2(): def _reverse_v2():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = _get_num_param(params, inputs[1]) axis = _get_num_param(params, inputs[1])
...@@ -1212,6 +1219,7 @@ _convert_map = { ...@@ -1212,6 +1219,7 @@ _convert_map = {
'Cast' : _cast(), 'Cast' : _cast(),
'Ceil' : AttrCvt('ceil'), 'Ceil' : AttrCvt('ceil'),
'CheckNumerics' : _check_numerics(), 'CheckNumerics' : _check_numerics(),
'ClipByValue' : _clip_by_value(),
'Concat' : _concat(), 'Concat' : _concat(),
'ConcatV2' : _concatV2(), 'ConcatV2' : _concatV2(),
'Conv2D' : _conv('conv'), 'Conv2D' : _conv('conv'),
...@@ -1245,6 +1253,7 @@ _convert_map = { ...@@ -1245,6 +1253,7 @@ _convert_map = {
'Mean' : _mean(), 'Mean' : _mean(),
'Minimum' : _elemwise('minimum'), 'Minimum' : _elemwise('minimum'),
'Mul' : _elemwise('multiply'), 'Mul' : _elemwise('multiply'),
'Neg' : AttrCvt('negative'),
'NotEqual' : _broadcast('not_equal'), 'NotEqual' : _broadcast('not_equal'),
'Pack' : _pack(), 'Pack' : _pack(),
'Pad' : _pad('Pad'), 'Pad' : _pad('Pad'),
......
...@@ -834,6 +834,23 @@ def test_forward_tile(): ...@@ -834,6 +834,23 @@ def test_forward_tile():
####################################################################### #######################################################################
# ClipByValue
# -----------
def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype):
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.clip_by_value(in_data, clip_value_min, clip_value_max, name="ClipByValue")
np_data = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
compare_tf_with_tvm([np_data], ['in_data:0'], 'ClipByValue:0')
def test_forward_clip_by_value():
'''test ClipByValue op'''
if tf.__version__ < LooseVersion('1.9'):
_test_forward_clip_by_value((4,), .1, 5., 'float32')
_test_forward_clip_by_value((4, 4), 1, 5, 'int32')
#######################################################################
# Multi Input to graph # Multi Input to graph
# -------------------- # --------------------
...@@ -1591,6 +1608,14 @@ def test_forward_log(): ...@@ -1591,6 +1608,14 @@ def test_forward_log():
tf.log(in_data, name="log") tf.log(in_data, name="log")
compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
def test_forward_negative():
"""test tf operator Neg """
np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data")
tf.negative(in_data, name="negative")
compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0')
def test_forward_softplus(): def test_forward_softplus():
"""test operator Softplus""" """test operator Softplus"""
np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32) np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
...@@ -1738,6 +1763,7 @@ if __name__ == '__main__': ...@@ -1738,6 +1763,7 @@ if __name__ == '__main__':
test_forward_unstack() test_forward_unstack()
test_forward_tile() test_forward_tile()
test_forward_top_k_v2() test_forward_top_k_v2()
test_forward_clip_by_value()
# Activations # Activations
test_forward_sigmoid() test_forward_sigmoid()
...@@ -1753,6 +1779,7 @@ if __name__ == '__main__': ...@@ -1753,6 +1779,7 @@ if __name__ == '__main__':
test_forward_pow_exp() test_forward_pow_exp()
test_forward_sign() test_forward_sign()
test_forward_log() test_forward_log()
test_forward_negative()
test_forward_softplus() test_forward_softplus()
test_forward_sqrt() test_forward_sqrt()
test_forward_rsqrt() test_forward_rsqrt()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment