Unverified Commit a6cb4b8d by Bing Xu Committed by GitHub

[Intrin] Adding a few missing math intrin (#5011)

* [intrin] exp2

* [intrin] exp10

* [intrin] log2/10

* [intrins] exp10

* [test] math intrin
parent 07469675
...@@ -508,16 +508,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); ...@@ -508,16 +508,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
} \ } \
TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(exp2);
TVM_DECLARE_INTRIN_UNARY(exp10);
TVM_DECLARE_INTRIN_UNARY(erf); TVM_DECLARE_INTRIN_UNARY(erf);
TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt); TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt); TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log); TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(log2);
TVM_DECLARE_INTRIN_UNARY(log10);
TVM_DECLARE_INTRIN_UNARY(popcount); TVM_DECLARE_INTRIN_UNARY(popcount);
TVM_DECLARE_INTRIN_UNARY(tan); TVM_DECLARE_INTRIN_UNARY(tan);
TVM_DECLARE_INTRIN_UNARY(cos); TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(cosh);
TVM_DECLARE_INTRIN_UNARY(sin); TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(sinh);
TVM_DECLARE_INTRIN_UNARY(atan); TVM_DECLARE_INTRIN_UNARY(atan);
namespace tir { namespace tir {
......
...@@ -33,7 +33,9 @@ from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_li ...@@ -33,7 +33,9 @@ from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_li
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil from .op import exp, exp2, exp10, log, log2, log10
from .op import cos, sin, cosh, sinh, tan, tanh, atan
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum from .op import comm_reducer, min, max, sum
......
...@@ -330,6 +330,38 @@ def exp(x): ...@@ -330,6 +330,38 @@ def exp(x):
return call_pure_intrin(x.dtype, "exp", x) return call_pure_intrin(x.dtype, "exp", x)
def exp2(x):
"""Calculate 2**x
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "exp2", x)
def exp10(x):
"""Calculate 10**x
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "exp10", x)
def erf(x): def erf(x):
"""Take gauss error function of the input x. """Take gauss error function of the input x.
...@@ -393,6 +425,38 @@ def log(x): ...@@ -393,6 +425,38 @@ def log(x):
""" """
return call_pure_intrin(x.dtype, "log", x) return call_pure_intrin(x.dtype, "log", x)
def log2(x):
"""Take log2 of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "log2", x)
def log10(x):
"""Take log10 of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "log10", x)
def tan(x): def tan(x):
"""Take tan of input x. """Take tan of input x.
...@@ -424,6 +488,23 @@ def cos(x): ...@@ -424,6 +488,23 @@ def cos(x):
""" """
return call_pure_intrin(x.dtype, "cos", x) return call_pure_intrin(x.dtype, "cos", x)
def cosh(x):
"""Take cosh of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "cosh", x)
def sin(x): def sin(x):
"""Take sin of input x. """Take sin of input x.
...@@ -439,6 +520,23 @@ def sin(x): ...@@ -439,6 +520,23 @@ def sin(x):
""" """
return call_pure_intrin(x.dtype, "sin", x) return call_pure_intrin(x.dtype, "sin", x)
def sinh(x):
"""Take sin of input x.
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "sinh", x)
def atan(x): def atan(x):
"""Take atan of input x. """Take atan of input x.
......
...@@ -35,12 +35,35 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch") ...@@ -35,12 +35,35 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
using tir::make_zero;
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
PrimExpr ret = tir::CallNode::make(
x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic);
*rv = ret;
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
...@@ -108,9 +131,45 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan") ...@@ -108,9 +131,45 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
using tir::make_zero;
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_one = make_const(x.dtype(), -1);
PrimExpr exp_negx = tir::CallNode::make(
x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
PrimExpr exp_posx = tir::CallNode::make(
x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
PrimExpr ret = (exp_posx + exp_negx) / two;
*rv = ret;
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
using tir::make_zero;
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_one = make_const(x.dtype(), -1);
PrimExpr exp_negx = tir::CallNode::make(
x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
PrimExpr exp_posx = tir::CallNode::make(
x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
PrimExpr ret = (exp_posx - exp_negx) / two;
*rv = ret;
});
} // namespace llvm } // namespace llvm
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -63,6 +63,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs") ...@@ -63,6 +63,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
...@@ -72,6 +78,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma") ...@@ -72,6 +78,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
...@@ -87,9 +99,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan") ...@@ -87,9 +99,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
......
...@@ -62,6 +62,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs") ...@@ -62,6 +62,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
...@@ -71,6 +77,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma") ...@@ -71,6 +77,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
...@@ -86,9 +98,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan") ...@@ -86,9 +98,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh")
.set_body(DispatchExternOCML);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
......
...@@ -107,21 +107,39 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round") ...@@ -107,21 +107,39 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
.set_body(DispatchExtern<CUDAMath>); .set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan")
.set_body(DispatchExtern<CUDAFastMathTan>); .set_body(DispatchExtern<CUDAFastMathTan>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan")
.set_body(DispatchExtern<CUDAMath>); .set_body(DispatchExtern<CUDAMath>);
......
...@@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round") ...@@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
...@@ -63,6 +75,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount") ...@@ -63,6 +75,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh")
.set_body(DispatchExtern<Direct>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round") ...@@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
...@@ -63,6 +75,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount") ...@@ -63,6 +75,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh")
.set_body(DispatchExtern<Direct>);
// There is no warp shuffle instruction in standard OpenCL // There is no warp shuffle instruction in standard OpenCL
// When shuffle is used, we assume it is intel's shuffle extension // When shuffle is used, we assume it is intel's shuffle extension
struct IntelShuffle { struct IntelShuffle {
......
...@@ -36,9 +36,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil") ...@@ -36,9 +36,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
...@@ -51,6 +63,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow") ...@@ -51,6 +63,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sin")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sinh")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cos")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cosh")
.set_body(DispatchExtern<Direct>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round") ...@@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
...@@ -60,6 +72,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow") ...@@ -60,6 +72,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount") TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh")
.set_body(DispatchExtern<Direct>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
......
...@@ -45,5 +45,33 @@ def test_nearbyint(): ...@@ -45,5 +45,33 @@ def test_nearbyint():
a_rounded.asnumpy(), np.rint(a.asnumpy())) a_rounded.asnumpy(), np.rint(a.asnumpy()))
def test_unary_intrin():
test_funcs = [
(tvm.tir.exp10, lambda x : np.power(10, x)),
(tvm.tir.log2, lambda x : np.log2(x)),
(tvm.tir.log10, lambda x : np.log10(x)),
(tvm.tir.sinh, lambda x : np.sinh(x)),
(tvm.tir.cosh, lambda x : np.cosh(x)),
]
def run_test(tvm_intrin, np_func):
m = te.var("m",)
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name='B')
s = te.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")
ctx = tvm.cpu(0)
n = 10
a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx)
b = tvm.nd.array( \
np.random.uniform(size=n).astype(A.dtype), ctx)
f(a, b)
tvm.testing.assert_allclose(
b.asnumpy(), np_func(a.asnumpy()), atol=1e-5, rtol=1e-5)
for func in test_funcs:
run_test(*func);
if __name__ == "__main__": if __name__ == "__main__":
test_nearbyint() test_nearbyint()
test_unary_intrin()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment