Unverified Commit 52bf1b35 by Samuel Committed by GitHub

[RELAY][PYTORCH]cosh,sinh,log2,log10,log1p op support (#5395)

* [RELAY][PYTORCH]cosh,sinh,log2,log10,log1p op support

* Review comment fixed

* Gradient testcase added
parent 3cc49719
...@@ -142,6 +142,14 @@ def _unary(name): ...@@ -142,6 +142,14 @@ def _unary(name):
return _impl return _impl
def _log1p():
def _impl(inputs, input_types):
# 1_plus_log x = log(x + 1)
one = _expr.const(1, dtype="float32")
return _op.log(inputs[0] + one)
return _impl
def _arange(): def _arange():
def _impl(inputs, input_types): def _impl(inputs, input_types):
if len(inputs) == 5: if len(inputs) == 5:
...@@ -1642,11 +1650,16 @@ def _get_convert_map(prelude): ...@@ -1642,11 +1650,16 @@ def _get_convert_map(prelude):
"aten::abs" : _unary("abs"), "aten::abs" : _unary("abs"),
"aten::neg" : _unary("negative"), "aten::neg" : _unary("negative"),
"aten::cos" : _unary("cos"), "aten::cos" : _unary("cos"),
"aten::cosh" : _unary("cosh"),
"aten::sin" : _unary("sin"), "aten::sin" : _unary("sin"),
"aten::sinh" : _unary("sinh"),
"aten::tan" : _unary("tan"), "aten::tan" : _unary("tan"),
"aten::tanh" : _unary("tanh"), "aten::tanh" : _unary("tanh"),
"aten::atan" : _unary("atan"), "aten::atan" : _unary("atan"),
"aten::log" : _unary("log"), "aten::log" : _unary("log"),
"aten::log2" : _unary("log2"),
"aten::log10" : _unary("log10"),
"aten::log1p" : _log1p(),
"aten::exp" : _unary("exp"), "aten::exp" : _unary("exp"),
"aten::erf" : _unary("erf"), "aten::erf" : _unary("erf"),
"aten::trunc" : _unary("trunc"), "aten::trunc" : _unary("trunc"),
......
...@@ -27,9 +27,13 @@ from .op import register_pattern, OpPattern ...@@ -27,9 +27,13 @@ from .op import register_pattern, OpPattern
register_broadcast_schedule("log") register_broadcast_schedule("log")
register_broadcast_schedule("log2")
register_broadcast_schedule("log10")
register_broadcast_schedule("tan") register_broadcast_schedule("tan")
register_broadcast_schedule("cos") register_broadcast_schedule("cos")
register_broadcast_schedule("cosh")
register_broadcast_schedule("sin") register_broadcast_schedule("sin")
register_broadcast_schedule("sinh")
register_broadcast_schedule("atan") register_broadcast_schedule("atan")
register_broadcast_schedule("exp") register_broadcast_schedule("exp")
register_broadcast_schedule("erf") register_broadcast_schedule("erf")
......
...@@ -27,12 +27,14 @@ from .op import register_gradient ...@@ -27,12 +27,14 @@ from .op import register_gradient
from .reduce import sum as _sum from .reduce import sum as _sum
from .tensor import ( from .tensor import (
cos, cos,
cosh,
exp, exp,
less, less,
negative, negative,
ones_like, ones_like,
power, power,
sin, sin,
sinh,
zeros_like, zeros_like,
equal, equal,
shape_of, shape_of,
...@@ -61,6 +63,24 @@ def log_grad(orig, grad): ...@@ -61,6 +63,24 @@ def log_grad(orig, grad):
return [grad * ones_like(x) / x] return [grad * ones_like(x) / x]
@register_gradient("log2")
def log2_grad(orig, grad):
"""Returns [grad * 1 / (log(2) * x)]"""
x = orig.args[0]
ones = ones_like(x)
two = const(2.0)
return [grad * ones / (log(two) * x)]
@register_gradient("log10")
def log10_grad(orig, grad):
"""Returns [grad * 1 / (log(10) * x)]"""
x = orig.args[0]
ones = ones_like(x)
ten = const(10.0)
return [grad * ones / (log(ten) * x)]
@register_gradient("tan") @register_gradient("tan")
def tan_grad(orig, grad): def tan_grad(orig, grad):
"""Returns [grad / (cos^2(x))]""" """Returns [grad / (cos^2(x))]"""
...@@ -76,12 +96,26 @@ def cos_grad(orig, grad): ...@@ -76,12 +96,26 @@ def cos_grad(orig, grad):
return [grad * (-ones * sin(x))] return [grad * (-ones * sin(x))]
@register_gradient("cosh")
def cosh_grad(orig, grad):
"""Returns [grad * (-sinh(x))]"""
x = orig.args[0]
ones = ones_like(x)
return [grad * (-ones * sinh(x))]
@register_gradient("sin") @register_gradient("sin")
def sin_grad(orig, grad): def sin_grad(orig, grad):
"""Returns [grad * cos(x)]""" """Returns [grad * cos(x)]"""
x = orig.args[0] x = orig.args[0]
return [grad * cos(x)] return [grad * cos(x)]
@register_gradient("sinh")
def sinh_grad(orig, grad):
"""Returns [grad * cosh(x)]"""
x = orig.args[0]
return [grad * cosh(x)]
@register_gradient("atan") @register_gradient("atan")
def atan_grad(orig, grad): def atan_grad(orig, grad):
"""Returns [grad * 1 / (1 + x ^ 2)]""" """Returns [grad * 1 / (1 + x ^ 2)]"""
......
...@@ -47,6 +47,36 @@ def log(data): ...@@ -47,6 +47,36 @@ def log(data):
""" """
return _make.log(data) return _make.log(data)
def log2(data):
"""Compute elementwise log to the base 2 of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.log2(data)
def log10(data):
"""Compute elementwise log to the base 10 of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.log10(data)
def tan(data): def tan(data):
"""Compute elementwise tan of data. """Compute elementwise tan of data.
...@@ -77,6 +107,21 @@ def cos(data): ...@@ -77,6 +107,21 @@ def cos(data):
""" """
return _make.cos(data) return _make.cos(data)
def cosh(data):
"""Compute elementwise cosh of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.cosh(data)
def sin(data): def sin(data):
"""Compute elementwise sin of data. """Compute elementwise sin of data.
...@@ -92,6 +137,21 @@ def sin(data): ...@@ -92,6 +137,21 @@ def sin(data):
""" """
return _make.sin(data) return _make.sin(data)
def sinh(data):
"""Compute elementwise sinh of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.sinh(data)
def atan(data): def atan(data):
"""Compute elementwise atan of data. """Compute elementwise atan of data.
......
...@@ -500,7 +500,7 @@ def where(condition, x, y): ...@@ -500,7 +500,7 @@ def where(condition, x, y):
Returns Returns
------- -------
result : relay.Expr result : relay.Expr
The selected array. The selected array.
Examples Examples
-------- --------
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
# expose all operators in tvm tir.op # expose all operators in tvm tir.op
from tvm.tir import any, all, min_value, max_value, trace from tvm.tir import any, all, min_value, max_value, trace
from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import sinh, cosh, log2, log10
from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from tvm.tir import isnan, isfinite, isinf from tvm.tir import isnan, isfinite, isinf
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
......
...@@ -539,7 +539,7 @@ def sin(x): ...@@ -539,7 +539,7 @@ def sin(x):
def sinh(x): def sinh(x):
"""Take sin of input x. """Take sinh of input x.
Parameters Parameters
---------- ----------
......
...@@ -51,6 +51,28 @@ RELAY_REGISTER_UNARY_OP("log") ...@@ -51,6 +51,28 @@ RELAY_REGISTER_UNARY_OP("log")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));
RELAY_REGISTER_UNARY_OP("log2")
.describe(R"code(Returns the log to base 2 of input array, computed element-wise.
.. math::
log2(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2));
RELAY_REGISTER_UNARY_OP("log10")
.describe(R"code(Returns the log to base 10 of input array, computed element-wise.
.. math::
log10(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10));
RELAY_REGISTER_UNARY_OP("tan") RELAY_REGISTER_UNARY_OP("tan")
.describe(R"code(Returns the tan of input array, computed element-wise. .describe(R"code(Returns the tan of input array, computed element-wise.
...@@ -73,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("cos") ...@@ -73,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("cos")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos));
RELAY_REGISTER_UNARY_OP("cosh")
.describe(R"code(Returns the cosh of input array, computed element-wise.
.. math::
Y = cosh(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh));
RELAY_REGISTER_UNARY_OP("sin") RELAY_REGISTER_UNARY_OP("sin")
.describe(R"code(Returns the sin of input array, computed element-wise. .describe(R"code(Returns the sin of input array, computed element-wise.
...@@ -84,6 +117,17 @@ RELAY_REGISTER_UNARY_OP("sin") ...@@ -84,6 +117,17 @@ RELAY_REGISTER_UNARY_OP("sin")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin));
RELAY_REGISTER_UNARY_OP("sinh")
.describe(R"code(Returns the sinh of input array, computed element-wise.
.. math::
Y = sinh(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh));
RELAY_REGISTER_UNARY_OP("atan") RELAY_REGISTER_UNARY_OP("atan")
.describe(R"code(Returns the atan of input array, computed element-wise. .describe(R"code(Returns the atan of input array, computed element-wise.
......
...@@ -37,6 +37,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf") ...@@ -37,6 +37,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
...@@ -49,9 +55,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan") ...@@ -49,9 +55,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
......
...@@ -1868,6 +1868,26 @@ def test_forward_unary(): ...@@ -1868,6 +1868,26 @@ def test_forward_unary():
def forward(self, *args): def forward(self, *args):
return torch.neg(args[0]) return torch.neg(args[0])
class Sinh1(Module):
def forward(self, *args):
return torch.sinh(args[0])
class Cosh1(Module):
def forward(self, *args):
return torch.cosh(args[0])
class Log2_1(Module):
def forward(self, *args):
return torch.log2(args[0])
class Log10_1(Module):
def forward(self, *args):
return torch.log10(args[0])
class Log1p_1(Module):
def forward(self, *args):
return torch.log1p(args[0])
input_shape = [1, 3, 10, 10] input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float() input_data = torch.rand(input_shape).float()
verify_model(Sqrt1().float().eval(), input_data=input_data) verify_model(Sqrt1().float().eval(), input_data=input_data)
...@@ -1876,11 +1896,16 @@ def test_forward_unary(): ...@@ -1876,11 +1896,16 @@ def test_forward_unary():
verify_model(Floor1().float().eval(), input_data=input_data) verify_model(Floor1().float().eval(), input_data=input_data)
verify_model(Round1().float().eval(), input_data=input_data) verify_model(Round1().float().eval(), input_data=input_data)
verify_model(Cos1().float().eval(), input_data=input_data) verify_model(Cos1().float().eval(), input_data=input_data)
verify_model(Cosh1().float().eval(), input_data=input_data)
verify_model(Sin1().float().eval(), input_data=input_data) verify_model(Sin1().float().eval(), input_data=input_data)
verify_model(Sinh1().float().eval(), input_data=input_data)
verify_model(Tan1().float().eval(), input_data=input_data) verify_model(Tan1().float().eval(), input_data=input_data)
verify_model(Tanh1().float().eval(), input_data=input_data) verify_model(Tanh1().float().eval(), input_data=input_data)
verify_model(ATanh1().float().eval(), input_data=input_data) verify_model(ATanh1().float().eval(), input_data=input_data)
verify_model(Log1().float().eval(), input_data=input_data) verify_model(Log1().float().eval(), input_data=input_data)
verify_model(Log2_1().float().eval(), input_data=input_data)
verify_model(Log10_1().float().eval(), input_data=input_data)
verify_model(Log1p_1().float().eval(), input_data=input_data)
verify_model(Exp1().float().eval(), input_data=input_data) verify_model(Exp1().float().eval(), input_data=input_data)
verify_model(Erf1().float().eval(), input_data=input_data) verify_model(Erf1().float().eval(), input_data=input_data)
verify_model(Trunc1().float().eval(), input_data=input_data) verify_model(Trunc1().float().eval(), input_data=input_data)
......
...@@ -65,7 +65,11 @@ def test_unary_op(): ...@@ -65,7 +65,11 @@ def test_unary_op():
(tvm.relay.cos, lambda x: -1.0 * np.sin(x)), (tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x: np.cos(x)), (tvm.relay.sin, lambda x: np.cos(x)),
(tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)), (tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0)))]: (tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0))),
(tvm.relay.log2, lambda x: 1 / (np.log(2) * x)),
(tvm.relay.log10, lambda x: 1 / (np.log(10) * x)),
(tvm.relay.cosh, lambda x: -1.0 * np.sinh(x)),
(tvm.relay.sinh, lambda x: np.cosh(x))]:
check_single_op(opfunc, ref) check_single_op(opfunc, ref)
......
...@@ -49,14 +49,18 @@ TOPI_DECLARE_UNARY_OP(erf); ...@@ -49,14 +49,18 @@ TOPI_DECLARE_UNARY_OP(erf);
TOPI_DECLARE_UNARY_OP(sigmoid); TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt); TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log); TOPI_DECLARE_UNARY_OP(log);
TOPI_DECLARE_UNARY_OP(log2);
TOPI_DECLARE_UNARY_OP(log10);
TOPI_DECLARE_UNARY_OP(floor); TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil); TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round); TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc); TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs); TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos); TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(cosh);
TOPI_DECLARE_UNARY_OP(tan); TOPI_DECLARE_UNARY_OP(tan);
TOPI_DECLARE_UNARY_OP(sin); TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(sinh);
TOPI_DECLARE_UNARY_OP(atan); TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan); TOPI_DECLARE_UNARY_OP(isnan);
TOPI_DECLARE_UNARY_OP(tanh); TOPI_DECLARE_UNARY_OP(tanh);
......
...@@ -144,6 +144,23 @@ def cos(x): ...@@ -144,6 +144,23 @@ def cos(x):
@tvm.te.tag_scope(tag=tag.ELEMWISE) @tvm.te.tag_scope(tag=tag.ELEMWISE)
def cosh(x):
"""Take cosh of input x.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.cosh(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def sin(x): def sin(x):
"""Take sin of input x. """Take sin of input x.
...@@ -161,6 +178,23 @@ def sin(x): ...@@ -161,6 +178,23 @@ def sin(x):
@tvm.te.tag_scope(tag=tag.ELEMWISE) @tvm.te.tag_scope(tag=tag.ELEMWISE)
def sinh(x):
"""Take sinh of input x.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.sinh(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def atan(x): def atan(x):
"""Take atan of input x. """Take atan of input x.
...@@ -346,6 +380,40 @@ def log(x): ...@@ -346,6 +380,40 @@ def log(x):
@tvm.te.tag_scope(tag=tag.ELEMWISE) @tvm.te.tag_scope(tag=tag.ELEMWISE)
def log2(x):
"""Take logarithm to the base 2 of input x.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.log2(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def log10(x):
"""Take logarithm to the base 10 of input x.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.log10(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def sqrt(x): def sqrt(x):
"""Take square root of input x. """Take square root of input x.
......
...@@ -61,11 +61,21 @@ TVM_REGISTER_GLOBAL("topi.cos") ...@@ -61,11 +61,21 @@ TVM_REGISTER_GLOBAL("topi.cos")
*rv = cos(args[0]); *rv = cos(args[0]);
}); });
TVM_REGISTER_GLOBAL("topi.cosh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cosh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sin") TVM_REGISTER_GLOBAL("topi.sin")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sin(args[0]); *rv = sin(args[0]);
}); });
TVM_REGISTER_GLOBAL("topi.sinh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sinh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.tanh") TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]); *rv = tanh(args[0]);
...@@ -101,6 +111,16 @@ TVM_REGISTER_GLOBAL("topi.log") ...@@ -101,6 +111,16 @@ TVM_REGISTER_GLOBAL("topi.log")
*rv = log(args[0]); *rv = log(args[0]);
}); });
TVM_REGISTER_GLOBAL("topi.log2")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log2(args[0]);
});
TVM_REGISTER_GLOBAL("topi.log10")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log10(args[0]);
});
TVM_REGISTER_GLOBAL("topi.identity") TVM_REGISTER_GLOBAL("topi.identity")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = identity(args[0]); *rv = identity(args[0]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment