Unverified Commit a94f69fa by Andrew Liu Committed by GitHub

[Relay][TF] Support for Atan/Atan2 in Relay Tensorflow frontend converter. (#5104)

* add Atan/Atan2 op

* fix bug and testing
parent 2b661231
...@@ -1535,6 +1535,11 @@ def _batch_to_space_nd(): ...@@ -1535,6 +1535,11 @@ def _batch_to_space_nd():
return _impl return _impl
def _atan2():
def _impl(inputs, attr, params):
divide = _elemwise("divide")(inputs, attr, params)
return get_relay_op("atan")(divide)
return _impl
def _prod(): def _prod():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
...@@ -1615,6 +1620,8 @@ _convert_map = { ...@@ -1615,6 +1620,8 @@ _convert_map = {
'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'), 'ArgMin' : _argx(_op.argmin, 'argmin'),
'Assert' : _assert(), 'Assert' : _assert(),
'Atan' : AttrCvt('atan'),
'Atan2' : _atan2(),
'AvgPool' : _pooling('avg_pool'), 'AvgPool' : _pooling('avg_pool'),
'AvgPool3D' : _pool3d('avg_pool3d'), 'AvgPool3D' : _pool3d('avg_pool3d'),
'BatchMatMul' : _batch_matmul(), 'BatchMatMul' : _batch_matmul(),
......
...@@ -2669,6 +2669,26 @@ def test_forward_tan(): ...@@ -2669,6 +2669,26 @@ def test_forward_tan():
tf.tan(in_data, name="tan") tf.tan(in_data, name="tan")
compare_tf_with_tvm([np_data], ['in_data:0'], 'tan:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'tan:0')
def test_forward_atan():
"""test operator tan """
tf.disable_eager_execution()
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.atan(in_data, name="atan")
compare_tf_with_tvm([np_data], ['in_data:0'], 'atan:0')
def test_forward_atan2():
"""test operator tan """
tf.disable_eager_execution()
np_data_1 = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
np_data_2 = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data_1 = tf.placeholder(tf.float32, (2, 3, 5), name="in_data_1")
in_data_2 = tf.placeholder(tf.float32, (2, 3, 5), name="in_data_2")
tf.atan2(in_data_1, in_data_2, name="atan2")
compare_tf_with_tvm([np_data_1, np_data_2], ['in_data_1:0', 'in_data_2:0'], 'atan2:0')
def test_forward_sin(): def test_forward_sin():
"""test operator sin """ """test operator sin """
...@@ -3116,6 +3136,8 @@ if __name__ == '__main__': ...@@ -3116,6 +3136,8 @@ if __name__ == '__main__':
test_forward_left_shift() test_forward_left_shift()
test_forward_truncatemod() test_forward_truncatemod()
test_forward_one_hot() test_forward_one_hot()
test_forward_atan()
test_forward_atan2()
# Activations # Activations
test_forward_sigmoid() test_forward_sigmoid()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment