Commit d72cdfa6 by alexgl-github, committed by Tianqi Chen

Add support for Tensorflow operators log1p, cos, sin (#3614)

The patch adds support for the TensorFlow operators log1p, cos, and sin.
TensorFlow log1p is described at https://www.tensorflow.org/api_docs/python/tf/math/log1p
TensorFlow cos is described at https://www.tensorflow.org/api_docs/python/tf/math/cos
TensorFlow sin is described at https://www.tensorflow.org/api_docs/python/tf/math/sin
parent 331585f4
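For context, a minimal end-to-end sketch of how the new operators flow through the TensorFlow frontend; the graph, shapes, and 0.6-era API usage below are illustrative, not part of the patch:

    import tensorflow as tf
    import tvm.relay as relay

    tf.reset_default_graph()
    inp = tf.placeholder(tf.float32, (2, 3), name="inp")
    tf.sin(tf.cos(tf.log1p(inp)), name="out")
    graph_def = tf.get_default_graph().as_graph_def()
    # the Log1p, Cos, and Sin nodes are now handled by the converter
    mod, params = relay.frontend.from_tensorflow(
        graph_def, shape={"inp": (2, 3)}, outputs=["out"])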
@@ -517,7 +517,8 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
 TVM_DECLARE_INTRIN_UNARY(rsqrt);
 TVM_DECLARE_INTRIN_UNARY(log);
 TVM_DECLARE_INTRIN_UNARY(popcount);
+TVM_DECLARE_INTRIN_UNARY(cos);
+TVM_DECLARE_INTRIN_UNARY(sin);
 
 // Implementation details after this
 inline bool is_const(const Expr& x) {
...
@@ -258,6 +258,35 @@ def log(x):
     """
     return call_pure_intrin(x.dtype, "log", x)
 
+def cos(x):
+    """Take cos of input x.
+
+    Parameters
+    ----------
+    x : Expr
+        Input argument.
+
+    Returns
+    -------
+    y : Expr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "cos", x)
+
+def sin(x):
+    """Take sin of input x.
+
+    Parameters
+    ----------
+    x : Expr
+        Input argument.
+
+    Returns
+    -------
+    y : Expr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "sin", x)
+
 def sqrt(x):
     """Take square root of input x.
...
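The new tvm.cos and tvm.sin wrappers plug into tvm.compute like the existing intrinsics. A minimal sketch, assuming the 0.x expression API:

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A", dtype="float32")
    # elementwise cosine through the new intrinsic wrapper
    B = tvm.compute(A.shape, lambda i: tvm.cos(A[i]), name="B")
    s = tvm.create_schedule(B.op)
    print(tvm.lower(s, [A, B], simple_mode=True))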
@@ -1301,6 +1301,13 @@ def _prod():
         return _op.prod(inputs[0], int(axis), keepdims=keepdims)
     return _impl
 
+def _log1p():
+    # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p
+    def _impl(inputs, attr, params):
+        one = tvm.relay.const(1, attr['T'].name)
+        add_out = _get_relay_op('add')(inputs[0], one)
+        return _get_relay_op('log')(add_out)
+    return _impl
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
@@ -1354,6 +1361,9 @@ _convert_map = {
     'Less'        : _broadcast('less'),
     'LessEqual'   : _broadcast('less_equal'),
     'Log'         : AttrCvt('log'),
+    'Log1p'       : _log1p(),
+    'Cos'         : AttrCvt('cos'),
+    'Sin'         : AttrCvt('sin'),
     'LogicalAnd'  : _logical('logical_and'),
     'LogicalOr'   : _logical('logical_or'),
     'LogicalNot'  : _logical('logical_not'),
...
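Relay has no dedicated log1p operator, so the converter decomposes TF's Log1p into log(1 + x). Roughly the expression it emits, as a sketch rather than the frontend's literal output:

    from tvm import relay

    x = relay.var("x", shape=(2, 3), dtype="float32")
    one = relay.const(1, "float32")      # dtype taken from attr['T'] in the converter
    y = relay.log(relay.add(x, one))     # log1p(x) == log(1 + x)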
@@ -25,6 +25,9 @@ schedule_broadcast = schedule_injective
 schedule_elemwise = schedule_injective
 
 register_schedule("log", schedule_broadcast)
+register_schedule("log1p", schedule_broadcast)
+register_schedule("cos", schedule_broadcast)
+register_schedule("sin", schedule_broadcast)
 register_schedule("exp", schedule_broadcast)
 register_schedule("sqrt", schedule_broadcast)
 register_schedule("rsqrt", schedule_broadcast)
...
@@ -20,7 +20,7 @@ from __future__ import absolute_import
 from ..expr import const
 from .op import register_gradient
 from .transform import collapse_sum_like, broadcast_to_like, where
-from .tensor import exp, negative, power, less
+from .tensor import exp, negative, power, less, cos, sin
 from .tensor import zeros_like, ones_like
 from . import nn as _nn
@@ -31,6 +31,18 @@ def log_grad(orig, grad):
     x = orig.args[0]
     return [grad * ones_like(x) / x]
 
+@register_gradient("cos")
+def cos_grad(orig, grad):
+    """Returns [grad * (-sin(x))]"""
+    x = orig.args[0]
+    ones = ones_like(x)
+    return [grad * (-ones * sin(x))]
+
+@register_gradient("sin")
+def sin_grad(orig, grad):
+    """Returns [grad * cos(x)]"""
+    x = orig.args[0]
+    return [grad * cos(x)]
+
 @register_gradient("exp")
 def exp_grad(orig, grad):
...
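A quick numpy sanity check of the identities the new gradients encode, d/dx cos x = -sin x and d/dx sin x = cos x; purely illustrative, not part of the patch:

    import numpy as np

    x = np.linspace(-np.pi, np.pi, 7).astype("float32")
    eps = 1e-3
    # central finite differences against the closed-form gradients
    assert np.allclose((np.cos(x + eps) - np.cos(x - eps)) / (2 * eps), -np.sin(x), atol=1e-3)
    assert np.allclose((np.sin(x + eps) - np.sin(x - eps)) / (2 * eps), np.cos(x), atol=1e-3)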
@@ -46,6 +46,35 @@ def log(data):
     """
     return _make.log(data)
 
+def cos(data):
+    """Compute elementwise cos of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.cos(data)
+
+def sin(data):
+    """Compute elementwise sin of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.sin(data)
+
 def exp(data):
     """Compute elementwise exp of data.
...
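A minimal sketch exercising the new Relay wrappers with the interpreter, assuming the 0.6-era relay.create_executor API:

    import numpy as np
    from tvm import relay

    x = relay.var("x", shape=(4,), dtype="float32")
    func = relay.Function([x], relay.sin(relay.cos(x)))
    data = np.random.uniform(-1, 1, size=4).astype("float32")
    out = relay.create_executor(kind="debug").evaluate(func)(data)
    np.testing.assert_allclose(out.asnumpy(), np.sin(np.cos(data)), rtol=1e-5)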
@@ -37,6 +37,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
 .set_body(DispatchExtern<FloatSuffix>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
+.set_body(DispatchExtern<FloatSuffix>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
+.set_body(DispatchExtern<FloatSuffix>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
 .set_body(DispatchExtern<FloatSuffix>);
...
@@ -95,6 +95,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
 .set_body(DispatchExtern<CUDAFastMath>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
+.set_body(DispatchExtern<CUDAFastMath>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin")
+.set_body(DispatchExtern<CUDAFastMath>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
 .set_body(DispatchExtern<CUDAMath>);
...
@@ -86,6 +86,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin")
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
+
 }  // namespace llvm
 }  // namespace codegen
 }  // namespace tvm
...
@@ -79,6 +79,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh")
 .set_body(DispatchExternLibDevice);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
+.set_body(DispatchExternLibDevice);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin")
+.set_body(DispatchExternLibDevice);
+
 }  // namespace llvm
 }  // namespace codegen
 }  // namespace tvm
...
@@ -53,6 +53,28 @@ RELAY_REGISTER_UNARY_OP("log")
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));
 
+RELAY_REGISTER_UNARY_OP("cos")
+.describe(R"code(Returns the cos of input array, computed element-wise.
+
+.. math::
+   Y = cos(X)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos));
+
+RELAY_REGISTER_UNARY_OP("sin")
+.describe(R"code(Returns the sin of input array, computed element-wise.
+
+.. math::
+   Y = sin(X)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin));
+
 RELAY_REGISTER_UNARY_OP("exp")
 .describe(R"code(Returns the exp input array, computed element-wise.
...
@@ -1869,6 +1869,30 @@ def test_forward_log():
     tf.log(in_data, name="log")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
 
+def test_forward_log1p():
+    """test operator Log1p """
+    np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+    tf.log1p(in_data, name="log1p")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'log1p:0')
+
+def test_forward_cos():
+    """test operator cos """
+    np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+    tf.cos(in_data, name="cos")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')
+
+def test_forward_sin():
+    """test operator sin """
+    np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+    tf.sin(in_data, name="sin")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'sin:0')
+
 def test_forward_negative():
     """test tf operator Neg """
     np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32)
@@ -2159,6 +2183,9 @@ if __name__ == '__main__':
     test_forward_pow_exp()
     test_forward_sign()
     test_forward_log()
+    test_forward_log1p()
+    test_forward_cos()
+    test_forward_sin()
     test_forward_negative()
     test_forward_divide()
     test_forward_abs()
...
@@ -56,7 +56,9 @@ def test_unary_op():
                       (tvm.relay.tanh, lambda x: 1 - np.tanh(x) * np.tanh(x)),
                       (tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
                       (tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
-                      (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x)))]:
+                      (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
+                      (tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
+                      (tvm.relay.sin, lambda x: np.cos(x))]:
         check_single_op(opfunc, ref)
...
@@ -71,7 +71,9 @@ def test_unary_op():
                       (tvm.relay.rsqrt, rsqrt),
                       (tvm.relay.sigmoid, sigmoid),
                       (tvm.relay.tanh, np.tanh),
-                      (relay.nn.relu, relu)]:
+                      (relay.nn.relu, relu),
+                      (tvm.relay.cos, np.cos),
+                      (tvm.relay.sin, np.sin)]:
         check_single_op(opfunc, ref)
...
@@ -54,6 +54,8 @@ TOPI_DECLARE_UNARY_OP(ceil);
 TOPI_DECLARE_UNARY_OP(round);
 TOPI_DECLARE_UNARY_OP(trunc);
 TOPI_DECLARE_UNARY_OP(abs);
+TOPI_DECLARE_UNARY_OP(cos);
+TOPI_DECLARE_UNARY_OP(sin);
 
 /*
  * \brief Fast_tanh_float implementation from Eigen
...
@@ -90,6 +90,37 @@ def tanh(x):
     """
     return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))
 
+@tvm.tag_scope(tag=tag.ELEMWISE)
+def cos(x):
+    """Take cos of input x.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return tvm.compute(x.shape, lambda *i: tvm.cos(x(*i)))
+
+@tvm.tag_scope(tag=tag.ELEMWISE)
+def sin(x):
+    """Take sin of input x.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return tvm.compute(x.shape, lambda *i: tvm.sin(x(*i)))
+
 @tvm.tag_scope(tag=tag.ELEMWISE)
 def floor(x):
@@ -206,7 +237,6 @@ def log(x):
     """
     return tvm.compute(x.shape, lambda *i: tvm.log(x(*i)))
 
-
 @tvm.tag_scope(tag=tag.ELEMWISE)
 def sqrt(x):
     """Take square root of input x.
...
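And a sketch of the TOPI path end to end, compiling topi.cos for CPU; the 0.6-era tvm/topi API is assumed:

    import numpy as np
    import tvm
    import topi

    A = tvm.placeholder((16,), name="A", dtype="float32")
    B = topi.cos(A)
    s = tvm.create_schedule(B.op)
    f = tvm.build(s, [A, B], "llvm")
    a = tvm.nd.array(np.random.uniform(size=16).astype("float32"))
    b = tvm.nd.array(np.zeros(16, dtype="float32"))
    f(a, b)
    np.testing.assert_allclose(b.asnumpy(), np.cos(a.asnumpy()), rtol=1e-5)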
@@ -148,6 +148,16 @@ TVM_REGISTER_GLOBAL("topi.exp")
   *rv = exp(args[0]);
   });
 
+TVM_REGISTER_GLOBAL("topi.cos")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = cos(args[0]);
+  });
+
+TVM_REGISTER_GLOBAL("topi.sin")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = sin(args[0]);
+  });
+
 TVM_REGISTER_GLOBAL("topi.tanh")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = tanh(args[0]);
...
@@ -41,6 +41,8 @@ def test_ewise():
     test_apply(topi.log, "log")
     test_apply(topi.sqrt, "sqrt")
     test_apply(topi.rsqrt, "rsqrt")
+    test_apply(topi.sin, "sin")
+    test_apply(topi.cos, "cos")
 
 if __name__ == "__main__":
...
@@ -84,6 +84,8 @@ def test_ewise():
     test_apply(topi.log, "log", np.log, 0, 100)
     test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
     test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
+    test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
+    test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
 
 def test_cast():
...