Commit 93c80170 by Yong Wu Committed by Yizhi Liu

[Relay][TensorFlow Frontend] SoftPlus Sqrt (#3187)

parent 20ddd2b5
...@@ -990,6 +990,16 @@ def _softmax(): ...@@ -990,6 +990,16 @@ def _softmax():
transforms={'axis': ('axis', 1)})([inputs[0]], attr) transforms={'axis': ('axis', 1)})([inputs[0]], attr)
return _impl return _impl
def _softplus():
# op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus
def _impl(inputs, attr, params):
exp_out = AttrCvt('exp')(inputs, attr)
inputs.append(tvm.relay.const(1, attr['T'].name))
rh = tvm.relay.const(1, attr['T'].name)
add_out = _get_relay_op('add')(exp_out, rh)
return _get_relay_op('log')(add_out)
return _impl
def _logical(name): def _logical(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr) return AttrCvt(op_name=name)(inputs, attr)
...@@ -1163,9 +1173,11 @@ _convert_map = { ...@@ -1163,9 +1173,11 @@ _convert_map = {
'Sign' : AttrCvt('sign'), 'Sign' : AttrCvt('sign'),
'Slice' : _slice(), 'Slice' : _slice(),
'Softmax' : _softmax(), 'Softmax' : _softmax(),
'Softplus' : _softplus(),
'SpaceToBatchND' : _space_to_batch_nd(), 'SpaceToBatchND' : _space_to_batch_nd(),
'Split' : _split(False), 'Split' : _split(False),
'SplitV' : _split(True), 'SplitV' : _split(True),
'Sqrt' : AttrCvt('sqrt'),
'Square' : _square(), 'Square' : _square(),
'Squeeze' : _squeeze(), 'Squeeze' : _squeeze(),
'StridedSlice' : _stridedSlice(), 'StridedSlice' : _stridedSlice(),
......
...@@ -1151,7 +1151,6 @@ def test_forward_placeholder(): ...@@ -1151,7 +1151,6 @@ def test_forward_placeholder():
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0') tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'Placeholder') tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output))
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
####################################################################### #######################################################################
...@@ -1440,22 +1439,37 @@ def test_forward_pow_exp(): ...@@ -1440,22 +1439,37 @@ def test_forward_pow_exp():
compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0') compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0')
def test_forward_log(): def test_forward_log():
"""test Log """ """test operator Log """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph() tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.log(in_data, name="log") tf.log(in_data, name="log")
compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
def test_forward_softplus():
"""test operator Softplus"""
np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.nn.softplus(in_data, name="softplus")
compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0')
def test_forward_rsqrt(): def test_forward_rsqrt():
"""test Rsqrt """ """test Rsqrt """
np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32) np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph() tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data") in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
tf.rsqrt(in_data, name="rsqrt") tf.rsqrt(in_data, name="rsqrt")
print(tf.get_default_graph().as_graph_def())
compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0') compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0')
def test_forward_sqrt():
"""test Sqrt """
np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
tf.sqrt(in_data, name="sqrt")
compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0')
####################################################################### #######################################################################
# Mean # Mean
# ---- # ----
...@@ -1561,6 +1575,8 @@ if __name__ == '__main__': ...@@ -1561,6 +1575,8 @@ if __name__ == '__main__':
test_forward_pow_exp() test_forward_pow_exp()
test_forward_sign() test_forward_sign()
test_forward_log() test_forward_log()
test_forward_softplus()
test_forward_sqrt()
test_forward_rsqrt() test_forward_rsqrt()
test_forward_expand_dims() test_forward_expand_dims()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment