Unverified Commit 90b08f5e by Bing Xu Committed by GitHub

[intrin] a few more math functions (#5468)

parent 684f2d7b
...@@ -570,7 +570,13 @@ TVM_DECLARE_INTRIN_UNARY(cos); ...@@ -570,7 +570,13 @@ TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(cosh); TVM_DECLARE_INTRIN_UNARY(cosh);
TVM_DECLARE_INTRIN_UNARY(sin); TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(sinh); TVM_DECLARE_INTRIN_UNARY(sinh);
TVM_DECLARE_INTRIN_UNARY(asin);
TVM_DECLARE_INTRIN_UNARY(acos);
TVM_DECLARE_INTRIN_UNARY(atan); TVM_DECLARE_INTRIN_UNARY(atan);
TVM_DECLARE_INTRIN_UNARY(acosh);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);
namespace tir { namespace tir {
/*! /*!
......
...@@ -37,7 +37,9 @@ from .function import PrimFunc ...@@ -37,7 +37,9 @@ from .function import PrimFunc
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import cos, sin, cosh, sinh, tan, tanh, atan, atan2 from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
from .op import tan, tanh, atan, atan2, atanh
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf, copysign from .op import isnan, isfinite, isinf, copysign
......
...@@ -522,6 +522,38 @@ def cosh(x): ...@@ -522,6 +522,38 @@ def cosh(x):
return call_pure_intrin(x.dtype, "cosh", x) return call_pure_intrin(x.dtype, "cosh", x)
def acos(x):
"""Take acos of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "acos", x)
def acosh(x):
"""Take acos of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "acosh", x)
def sin(x): def sin(x):
"""Take sin of input x. """Take sin of input x.
...@@ -554,6 +586,38 @@ def sinh(x): ...@@ -554,6 +586,38 @@ def sinh(x):
return call_pure_intrin(x.dtype, "sinh", x) return call_pure_intrin(x.dtype, "sinh", x)
def asin(x):
"""Take asin of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "asin", x)
def asinh(x):
"""Take asinh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "asinh", x)
def atan(x): def atan(x):
"""Take atan of input x. """Take atan of input x.
...@@ -570,6 +634,22 @@ def atan(x): ...@@ -570,6 +634,22 @@ def atan(x):
return call_pure_intrin(x.dtype, "atan", x) return call_pure_intrin(x.dtype, "atan", x)
def atanh(x):
"""Take atanh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "atanh", x)
def atan2(x1, x2): def atan2(x1, x2):
"""Take arctan2(x1, x2). """Take arctan2(x1, x2).
......
...@@ -52,22 +52,37 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") ...@@ -52,22 +52,37 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot")
......
...@@ -63,6 +63,12 @@ def test_unary_intrin(): ...@@ -63,6 +63,12 @@ def test_unary_intrin():
(tvm.tir.sinh, lambda x : np.sinh(x)), (tvm.tir.sinh, lambda x : np.sinh(x)),
(tvm.tir.cosh, lambda x : np.cosh(x)), (tvm.tir.cosh, lambda x : np.cosh(x)),
(tvm.tir.log1p, lambda x : np.log1p(x)), (tvm.tir.log1p, lambda x : np.log1p(x)),
(tvm.tir.asin, lambda x : np.arcsin(x)),
(tvm.tir.acos, lambda x : np.arccos(x)),
(tvm.tir.atan, lambda x : np.arctan(x)),
(tvm.tir.asinh, lambda x : np.arcsinh(x)),
(tvm.tir.acosh, lambda x : np.arccosh(x)),
(tvm.tir.atanh, lambda x : np.arctanh(x)),
] ]
def run_test(tvm_intrin, np_func): def run_test(tvm_intrin, np_func):
m = te.var("m",) m = te.var("m",)
...@@ -72,7 +78,7 @@ def test_unary_intrin(): ...@@ -72,7 +78,7 @@ def test_unary_intrin():
f = tvm.build(s, [A, B], "llvm") f = tvm.build(s, [A, B], "llvm")
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
n = 10 n = 10
a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), ctx)
b = tvm.nd.array( \ b = tvm.nd.array( \
np.random.uniform(size=n).astype(A.dtype), ctx) np.random.uniform(size=n).astype(A.dtype), ctx)
f(a, b) f(a, b)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment