Unverified Commit 52bf1b35 by Samuel Committed by GitHub

[RELAY][PYTORCH]cosh,sinh,log2,log10,log1p op support (#5395)

* [RELAY][PYTORCH]cosh,sinh,log2,log10,log1p op support

* Review comment fixed

* Gradient testcase added
parent 3cc49719
......@@ -142,6 +142,14 @@ def _unary(name):
return _impl
def _log1p():
def _impl(inputs, input_types):
# 1_plus_log x = log(x + 1)
one = _expr.const(1, dtype="float32")
return _op.log(inputs[0] + one)
return _impl
def _arange():
def _impl(inputs, input_types):
if len(inputs) == 5:
......@@ -1642,11 +1650,16 @@ def _get_convert_map(prelude):
"aten::abs" : _unary("abs"),
"aten::neg" : _unary("negative"),
"aten::cos" : _unary("cos"),
"aten::cosh" : _unary("cosh"),
"aten::sin" : _unary("sin"),
"aten::sinh" : _unary("sinh"),
"aten::tan" : _unary("tan"),
"aten::tanh" : _unary("tanh"),
"aten::atan" : _unary("atan"),
"aten::log" : _unary("log"),
"aten::log2" : _unary("log2"),
"aten::log10" : _unary("log10"),
"aten::log1p" : _log1p(),
"aten::exp" : _unary("exp"),
"aten::erf" : _unary("erf"),
"aten::trunc" : _unary("trunc"),
......
......@@ -27,9 +27,13 @@ from .op import register_pattern, OpPattern
register_broadcast_schedule("log")
register_broadcast_schedule("log2")
register_broadcast_schedule("log10")
register_broadcast_schedule("tan")
register_broadcast_schedule("cos")
register_broadcast_schedule("cosh")
register_broadcast_schedule("sin")
register_broadcast_schedule("sinh")
register_broadcast_schedule("atan")
register_broadcast_schedule("exp")
register_broadcast_schedule("erf")
......
......@@ -27,12 +27,14 @@ from .op import register_gradient
from .reduce import sum as _sum
from .tensor import (
cos,
cosh,
exp,
less,
negative,
ones_like,
power,
sin,
sinh,
zeros_like,
equal,
shape_of,
......@@ -61,6 +63,24 @@ def log_grad(orig, grad):
return [grad * ones_like(x) / x]
@register_gradient("log2")
def log2_grad(orig, grad):
"""Returns [grad * 1 / (log(2) * x)]"""
x = orig.args[0]
ones = ones_like(x)
two = const(2.0)
return [grad * ones / (log(two) * x)]
@register_gradient("log10")
def log10_grad(orig, grad):
"""Returns [grad * 1 / (log(10) * x)]"""
x = orig.args[0]
ones = ones_like(x)
ten = const(10.0)
return [grad * ones / (log(ten) * x)]
@register_gradient("tan")
def tan_grad(orig, grad):
"""Returns [grad / (cos^2(x))]"""
......@@ -76,12 +96,26 @@ def cos_grad(orig, grad):
return [grad * (-ones * sin(x))]
@register_gradient("cosh")
def cosh_grad(orig, grad):
"""Returns [grad * (-sinh(x))]"""
x = orig.args[0]
ones = ones_like(x)
return [grad * (-ones * sinh(x))]
@register_gradient("sin")
def sin_grad(orig, grad):
"""Returns [grad * cos(x)]"""
x = orig.args[0]
return [grad * cos(x)]
@register_gradient("sinh")
def sinh_grad(orig, grad):
"""Returns [grad * cosh(x)]"""
x = orig.args[0]
return [grad * cosh(x)]
@register_gradient("atan")
def atan_grad(orig, grad):
"""Returns [grad * 1 / (1 + x ^ 2)]"""
......
......@@ -47,6 +47,36 @@ def log(data):
"""
return _make.log(data)
def log2(data):
"""Compute elementwise log to the base 2 of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.log2(data)
def log10(data):
"""Compute elementwise log to the base 10 of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.log10(data)
def tan(data):
"""Compute elementwise tan of data.
......@@ -77,6 +107,21 @@ def cos(data):
"""
return _make.cos(data)
def cosh(data):
"""Compute elementwise cosh of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.cosh(data)
def sin(data):
"""Compute elementwise sin of data.
......@@ -92,6 +137,21 @@ def sin(data):
"""
return _make.sin(data)
def sinh(data):
"""Compute elementwise sinh of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.sinh(data)
def atan(data):
"""Compute elementwise atan of data.
......
......@@ -20,6 +20,7 @@
# expose all operators in tvm tir.op
from tvm.tir import any, all, min_value, max_value, trace
from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import sinh, cosh, log2, log10
from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from tvm.tir import isnan, isfinite, isinf
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
......
......@@ -539,7 +539,7 @@ def sin(x):
def sinh(x):
"""Take sin of input x.
"""Take sinh of input x.
Parameters
----------
......
......@@ -51,6 +51,28 @@ RELAY_REGISTER_UNARY_OP("log")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));
RELAY_REGISTER_UNARY_OP("log2")
.describe(R"code(Returns the log to base 2 of input array, computed element-wise.
.. math::
log2(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2));
RELAY_REGISTER_UNARY_OP("log10")
.describe(R"code(Returns the log to base 10 of input array, computed element-wise.
.. math::
log10(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10));
RELAY_REGISTER_UNARY_OP("tan")
.describe(R"code(Returns the tan of input array, computed element-wise.
......@@ -73,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("cos")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos));
RELAY_REGISTER_UNARY_OP("cosh")
.describe(R"code(Returns the cosh of input array, computed element-wise.
.. math::
Y = cosh(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh));
RELAY_REGISTER_UNARY_OP("sin")
.describe(R"code(Returns the sin of input array, computed element-wise.
......@@ -84,6 +117,17 @@ RELAY_REGISTER_UNARY_OP("sin")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin));
RELAY_REGISTER_UNARY_OP("sinh")
.describe(R"code(Returns the sinh of input array, computed element-wise.
.. math::
Y = sinh(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh));
RELAY_REGISTER_UNARY_OP("atan")
.describe(R"code(Returns the atan of input array, computed element-wise.
......
......@@ -37,6 +37,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p")
.set_body(DispatchExtern<FloatSuffix>);
......@@ -49,9 +55,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
.set_body(DispatchExtern<FloatSuffix>);
......
......@@ -1868,6 +1868,26 @@ def test_forward_unary():
def forward(self, *args):
return torch.neg(args[0])
class Sinh1(Module):
def forward(self, *args):
return torch.sinh(args[0])
class Cosh1(Module):
def forward(self, *args):
return torch.cosh(args[0])
class Log2_1(Module):
def forward(self, *args):
return torch.log2(args[0])
class Log10_1(Module):
def forward(self, *args):
return torch.log10(args[0])
class Log1p_1(Module):
def forward(self, *args):
return torch.log1p(args[0])
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(Sqrt1().float().eval(), input_data=input_data)
......@@ -1876,11 +1896,16 @@ def test_forward_unary():
verify_model(Floor1().float().eval(), input_data=input_data)
verify_model(Round1().float().eval(), input_data=input_data)
verify_model(Cos1().float().eval(), input_data=input_data)
verify_model(Cosh1().float().eval(), input_data=input_data)
verify_model(Sin1().float().eval(), input_data=input_data)
verify_model(Sinh1().float().eval(), input_data=input_data)
verify_model(Tan1().float().eval(), input_data=input_data)
verify_model(Tanh1().float().eval(), input_data=input_data)
verify_model(ATanh1().float().eval(), input_data=input_data)
verify_model(Log1().float().eval(), input_data=input_data)
verify_model(Log2_1().float().eval(), input_data=input_data)
verify_model(Log10_1().float().eval(), input_data=input_data)
verify_model(Log1p_1().float().eval(), input_data=input_data)
verify_model(Exp1().float().eval(), input_data=input_data)
verify_model(Erf1().float().eval(), input_data=input_data)
verify_model(Trunc1().float().eval(), input_data=input_data)
......
......@@ -65,7 +65,11 @@ def test_unary_op():
(tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x: np.cos(x)),
(tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0)))]:
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0))),
(tvm.relay.log2, lambda x: 1 / (np.log(2) * x)),
(tvm.relay.log10, lambda x: 1 / (np.log(10) * x)),
(tvm.relay.cosh, lambda x: -1.0 * np.sinh(x)),
(tvm.relay.sinh, lambda x: np.cosh(x))]:
check_single_op(opfunc, ref)
......
......@@ -49,14 +49,18 @@ TOPI_DECLARE_UNARY_OP(erf);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
TOPI_DECLARE_UNARY_OP(log2);
TOPI_DECLARE_UNARY_OP(log10);
TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(cosh);
TOPI_DECLARE_UNARY_OP(tan);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(sinh);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
TOPI_DECLARE_UNARY_OP(tanh);
......
......@@ -144,6 +144,23 @@ def cos(x):
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def cosh(x):
"""Take cosh of input x.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.cosh(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def sin(x):
"""Take sin of input x.
......@@ -161,6 +178,23 @@ def sin(x):
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def sinh(x):
"""Take sinh of input x.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.sinh(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def atan(x):
"""Take atan of input x.
......@@ -346,6 +380,40 @@ def log(x):
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def log2(x):
"""Take logarithm to the base 2 of input x.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.log2(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def log10(x):
"""Take logarithm to the base 10 of input x.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.log10(x(*i)))
@tvm.te.tag_scope(tag=tag.ELEMWISE)
def sqrt(x):
"""Take square root of input x.
......
......@@ -61,11 +61,21 @@ TVM_REGISTER_GLOBAL("topi.cos")
*rv = cos(args[0]);
});
TVM_REGISTER_GLOBAL("topi.cosh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cosh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sin")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sin(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sinh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sinh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
......@@ -101,6 +111,16 @@ TVM_REGISTER_GLOBAL("topi.log")
*rv = log(args[0]);
});
TVM_REGISTER_GLOBAL("topi.log2")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log2(args[0]);
});
TVM_REGISTER_GLOBAL("topi.log10")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log10(args[0]);
});
TVM_REGISTER_GLOBAL("topi.identity")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = identity(args[0]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment