Commit 93c80170 by Yong Wu Committed by Yizhi Liu

[Relay][TensorFlow Frontend] SoftPlus Sqrt (#3187)

parent 20ddd2b5
......@@ -990,6 +990,16 @@ def _softmax():
transforms={'axis': ('axis', 1)})([inputs[0]], attr)
return _impl
def _softplus():
# op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus
def _impl(inputs, attr, params):
exp_out = AttrCvt('exp')(inputs, attr)
inputs.append(tvm.relay.const(1, attr['T'].name))
rh = tvm.relay.const(1, attr['T'].name)
add_out = _get_relay_op('add')(exp_out, rh)
return _get_relay_op('log')(add_out)
return _impl
def _logical(name):
def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
......@@ -1163,9 +1173,11 @@ _convert_map = {
'Sign' : AttrCvt('sign'),
'Slice' : _slice(),
'Softmax' : _softmax(),
'Softplus' : _softplus(),
'SpaceToBatchND' : _space_to_batch_nd(),
'Split' : _split(False),
'SplitV' : _split(True),
'Sqrt' : AttrCvt('sqrt'),
'Square' : _square(),
'Squeeze' : _squeeze(),
'StridedSlice' : _stridedSlice(),
......
......@@ -1151,7 +1151,6 @@ def test_forward_placeholder():
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output))
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
#######################################################################
......@@ -1440,22 +1439,37 @@ def test_forward_pow_exp():
compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0')
def test_forward_log():
"""test Log """
"""test operator Log """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.log(in_data, name="log")
compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
def test_forward_softplus():
"""test operator Softplus"""
np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.nn.softplus(in_data, name="softplus")
compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0')
def test_forward_rsqrt():
"""test Rsqrt """
np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
tf.rsqrt(in_data, name="rsqrt")
print(tf.get_default_graph().as_graph_def())
compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0')
def test_forward_sqrt():
"""test Sqrt """
np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
tf.sqrt(in_data, name="sqrt")
compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0')
#######################################################################
# Mean
# ----
......@@ -1561,6 +1575,8 @@ if __name__ == '__main__':
test_forward_pow_exp()
test_forward_sign()
test_forward_log()
test_forward_softplus()
test_forward_sqrt()
test_forward_rsqrt()
test_forward_expand_dims()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment