Unverified Commit 90b08f5e by Bing Xu Committed by GitHub

[intrin] a few more math functions (#5468)

parent 684f2d7b
......@@ -570,7 +570,13 @@ TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(cosh);
TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(sinh);
TVM_DECLARE_INTRIN_UNARY(asin);
TVM_DECLARE_INTRIN_UNARY(acos);
TVM_DECLARE_INTRIN_UNARY(atan);
TVM_DECLARE_INTRIN_UNARY(acosh);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);
namespace tir {
/*!
......
......@@ -37,7 +37,9 @@ from .function import PrimFunc
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import cos, sin, cosh, sinh, tan, tanh, atan, atan2
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
from .op import tan, tanh, atan, atan2, atanh
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf, copysign
......
......@@ -522,6 +522,38 @@ def cosh(x):
return call_pure_intrin(x.dtype, "cosh", x)
def acos(x):
"""Take acos of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "acos", x)
def acosh(x):
"""Take acos of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "acosh", x)
def sin(x):
"""Take sin of input x.
......@@ -554,6 +586,38 @@ def sinh(x):
return call_pure_intrin(x.dtype, "sinh", x)
def asin(x):
"""Take asin of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "asin", x)
def asinh(x):
"""Take asinh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "asinh", x)
def atan(x):
"""Take atan of input x.
......@@ -570,6 +634,22 @@ def atan(x):
return call_pure_intrin(x.dtype, "atan", x)
def atanh(x):
"""Take atanh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "atanh", x)
def atan2(x1, x2):
"""Take arctan2(x1, x2).
......
......@@ -52,22 +52,37 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot")
......
......@@ -63,6 +63,12 @@ def test_unary_intrin():
(tvm.tir.sinh, lambda x : np.sinh(x)),
(tvm.tir.cosh, lambda x : np.cosh(x)),
(tvm.tir.log1p, lambda x : np.log1p(x)),
(tvm.tir.asin, lambda x : np.arcsin(x)),
(tvm.tir.acos, lambda x : np.arccos(x)),
(tvm.tir.atan, lambda x : np.arctan(x)),
(tvm.tir.asinh, lambda x : np.arcsinh(x)),
(tvm.tir.acosh, lambda x : np.arccosh(x)),
(tvm.tir.atanh, lambda x : np.arctanh(x)),
]
def run_test(tvm_intrin, np_func):
m = te.var("m",)
......@@ -72,7 +78,7 @@ def test_unary_intrin():
f = tvm.build(s, [A, B], "llvm")
ctx = tvm.cpu(0)
n = 10
a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx)
a = tvm.nd.array(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), ctx)
b = tvm.nd.array( \
np.random.uniform(size=n).astype(A.dtype), ctx)
f(a, b)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment