Unverified Commit a6cb4b8d by Bing Xu Committed by GitHub

[Intrin] Adding a few missing math intrin (#5011)

* [intrin] exp2

* [intrin] exp10

* [intrin] log2/10

* [intrins] exp10

* [test] math intrin
parent 07469675
......@@ -508,16 +508,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
} \
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(exp2);
TVM_DECLARE_INTRIN_UNARY(exp10);
TVM_DECLARE_INTRIN_UNARY(erf);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(log2);
TVM_DECLARE_INTRIN_UNARY(log10);
TVM_DECLARE_INTRIN_UNARY(popcount);
TVM_DECLARE_INTRIN_UNARY(tan);
TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(cosh);
TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(sinh);
TVM_DECLARE_INTRIN_UNARY(atan);
namespace tir {
......
......@@ -33,7 +33,9 @@ from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_li
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import exp, exp2, exp10, log, log2, log10
from .op import cos, sin, cosh, sinh, tan, tanh, atan
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
......
......@@ -330,6 +330,38 @@ def exp(x):
return call_pure_intrin(x.dtype, "exp", x)
def exp2(x):
"""Calculate 2**x
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "exp2", x)
def exp10(x):
"""Calculate 10**x
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "exp10", x)
def erf(x):
"""Take gauss error function of the input x.
......@@ -393,6 +425,38 @@ def log(x):
"""
return call_pure_intrin(x.dtype, "log", x)
def log2(x):
"""Take log2 of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "log2", x)
def log10(x):
"""Take log10 of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "log10", x)
def tan(x):
"""Take tan of input x.
......@@ -424,6 +488,23 @@ def cos(x):
"""
return call_pure_intrin(x.dtype, "cos", x)
def cosh(x):
"""Take cosh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "cosh", x)
def sin(x):
"""Take sin of input x.
......@@ -439,6 +520,23 @@ def sin(x):
"""
return call_pure_intrin(x.dtype, "sin", x)
def sinh(x):
"""Take sin of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "sinh", x)
def atan(x):
"""Take atan of input x.
......
......@@ -35,12 +35,35 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
using tir::make_zero;
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
PrimExpr ret = tir::CallNode::make(
x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic);
*rv = ret;
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
......@@ -108,9 +131,45 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
using tir::make_zero;
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_one = make_const(x.dtype(), -1);
PrimExpr exp_negx = tir::CallNode::make(
x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
PrimExpr exp_posx = tir::CallNode::make(
x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
PrimExpr ret = (exp_posx + exp_negx) / two;
*rv = ret;
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
using tir::make_zero;
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_one = make_const(x.dtype(), -1);
PrimExpr exp_negx = tir::CallNode::make(
x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
PrimExpr exp_posx = tir::CallNode::make(
x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
PrimExpr ret = (exp_posx - exp_negx) / two;
*rv = ret;
});
} // namespace llvm
} // namespace codegen
} // namespace tvm
......
......@@ -63,6 +63,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf")
.set_body(DispatchExternLibDevice);
......@@ -72,6 +78,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt")
.set_body(DispatchExternLibDevice);
......@@ -87,9 +99,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan")
.set_body(DispatchExternLibDevice);
......
......@@ -62,6 +62,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf")
.set_body(DispatchExternOCML);
......@@ -71,6 +77,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt")
.set_body(DispatchExternOCML);
......@@ -86,9 +98,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan")
.set_body(DispatchExternOCML);
......
......@@ -107,21 +107,39 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan")
.set_body(DispatchExtern<CUDAFastMathTan>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan")
.set_body(DispatchExtern<CUDAMath>);
......
......@@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh")
.set_body(DispatchExtern<Direct>);
......@@ -63,6 +75,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh")
.set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
.set_body(DispatchExtern<Direct>);
......@@ -63,6 +75,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh")
.set_body(DispatchExtern<Direct>);
// There is no warp shuffle instruction in standard OpenCL
// When shuffle is used, we assume it is intel's shuffle extension
struct IntelShuffle {
......
......@@ -36,9 +36,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh")
.set_body(DispatchExtern<Direct>);
......@@ -51,6 +63,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sin")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sinh")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cos")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cosh")
.set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh")
.set_body(DispatchExtern<Direct>);
......@@ -60,6 +72,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh")
.set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
......
......@@ -45,5 +45,33 @@ def test_nearbyint():
a_rounded.asnumpy(), np.rint(a.asnumpy()))
def test_unary_intrin():
test_funcs = [
(tvm.tir.exp10, lambda x : np.power(10, x)),
(tvm.tir.log2, lambda x : np.log2(x)),
(tvm.tir.log10, lambda x : np.log10(x)),
(tvm.tir.sinh, lambda x : np.sinh(x)),
(tvm.tir.cosh, lambda x : np.cosh(x)),
]
def run_test(tvm_intrin, np_func):
m = te.var("m",)
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name='B')
s = te.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")
ctx = tvm.cpu(0)
n = 10
a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx)
b = tvm.nd.array( \
np.random.uniform(size=n).astype(A.dtype), ctx)
f(a, b)
tvm.testing.assert_allclose(
b.asnumpy(), np_func(a.asnumpy()), atol=1e-5, rtol=1e-5)
for func in test_funcs:
run_test(*func);
if __name__ == "__main__":
test_nearbyint()
test_unary_intrin()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment