Unverified Commit f5b02fdb by Haichen Shen Committed by GitHub

[Relay][OP] Add fast_erf implementation (#5241)

* add fast erf

* doc

* lint

* fix

* fix indent
parent 869b718a
...@@ -72,7 +72,7 @@ class GenericFunc : public ObjectRef {
 *
 * \code
 * // Example code on how to call generic function
 * void CallGeneric(GenericFunc f) {
 * // call like normal functions by pass in arguments
 * // return value is automatically converted back
 * int rvalue = f(1, 2.0);
......
...@@ -76,6 +76,7 @@ register_injective_schedule("shape_of")
register_injective_schedule("ndarray_size")
register_broadcast_schedule("fast_exp")
register_broadcast_schedule("fast_tanh")
register_broadcast_schedule("fast_erf")
# zeros
...@@ -222,3 +223,4 @@ register_shape_func("exp", False, elemwise_shape_func)
register_shape_func("tan", False, elemwise_shape_func)
register_shape_func("fast_exp", False, elemwise_shape_func)
register_shape_func("fast_tanh", False, elemwise_shape_func)
register_shape_func("fast_erf", False, elemwise_shape_func)
...@@ -128,6 +128,17 @@ RELAY_REGISTER_UNARY_OP("erf")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf));
RELAY_REGISTER_UNARY_OP("fast_erf")
.describe(R"code(Returns the error function value for input array, computed element-wise.
.. math::
\fast_erf(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf));
RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the sqrt input array, computed element-wise.
......
...@@ -35,11 +35,14 @@ class FastMathMutator : public ExprRewriter {
 public:
  FastMathMutator()
      : exp_op_(Op::Get("exp")),
erf_op_(Op::Get("erf")),
        tanh_op_(Op::Get("tanh")) {}
  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
    if (pre->op == exp_op_) {
      return FastExp(post.as<CallNode>()->args[0]);
} else if (pre->op == erf_op_) {
return FastErf(post.as<CallNode>()->args[0]);
    } else if (pre->op == tanh_op_) {
      return FastTanh(post.as<CallNode>()->args[0]);
    }
...@@ -51,6 +54,7 @@ class FastMathMutator : public ExprRewriter {
  // operator equivalence checking so that the registry lookup overhead can be
  // reduced.
  const Op& exp_op_;
const Op& erf_op_;
  const Op& tanh_op_;
};
......
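Note: the rewrite above is driven by the existing FastMath pass; with this patch it now maps erf onto fast_erf as well. A minimal sketch of exercising it from Python, assuming the pass is exposed as relay.transform.FastMath the same way the fast_exp/fast_tanh rewrites are tested (the shape and dtype here are illustrative only):

import tvm
import tvm.relay as relay

# Build a tiny Relay function that calls the exact erf op.
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.erf(x)))

# FastMath rewrites erf -> fast_erf (and exp/tanh to their fast variants).
fast_mod = relay.transform.FastMath()(mod)
assert "fast_erf" in fast_mod.astext()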
...@@ -322,6 +322,11 @@ inline Expr FastExp(Expr e) {
  return Call(op, {e});
}
inline Expr FastErf(Expr e) {
static const Op& op = Op::Get("fast_erf");
return Call(op, {e});
}
inline Expr FastTanh(Expr e) {
  static const Op& op = Op::Get("fast_tanh");
  return Call(op, {e});
......
...@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import scipy
from scipy import special
import tvm
import tvm.relay as relay
import topi
...@@ -52,6 +54,7 @@ def test_fastmath():
                        rtol=1e-5, atol=1e-5)
    test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
test_apply(relay.erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
    test_apply(relay.tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)
......
...@@ -27,6 +27,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <topi/tags.h>
#include <algorithm>
#include <string>
#include "broadcast.h"
...@@ -63,7 +64,7 @@ TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(isfinite);
TOPI_DECLARE_UNARY_OP(isinf);
/*!
 * \brief Fast_tanh_float implementation from Eigen
 * https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26
 */
...@@ -461,5 +462,75 @@ inline Tensor fast_exp(const Tensor& x,
  }
}
/*!
 * \brief Fast_erf_float implementation from Eigen
* https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290
*/
inline Tensor fast_erf_float32(const Tensor& data,
std::string name,
std::string tag) {
auto plus_4 = make_const(DataType::Float(32), 4.f);
auto minus_4 = make_const(DataType::Float(32), -4.f);
// The monomial coefficients of the numerator polynomial (odd).
auto alpha_1 = make_const(DataType::Float(32), -1.60960333262415e-02f);
auto alpha_3 = make_const(DataType::Float(32), -2.95459980854025e-03f);
auto alpha_5 = make_const(DataType::Float(32), -7.34990630326855e-04f);
auto alpha_7 = make_const(DataType::Float(32), -5.69250639462346e-05f);
auto alpha_9 = make_const(DataType::Float(32), -2.10102402082508e-06f);
auto alpha_11 = make_const(DataType::Float(32), 2.77068142495902e-08f);
auto alpha_13 = make_const(DataType::Float(32), -2.72614225801306e-10f);
// The monomial coefficients of the denominator polynomial (even).
auto beta_0 = make_const(DataType::Float(32), -1.42647390514189e-02f);
auto beta_2 = make_const(DataType::Float(32), -7.37332916720468e-03f);
auto beta_4 = make_const(DataType::Float(32), -1.68282697438203e-03f);
auto beta_6 = make_const(DataType::Float(32), -2.13374055278905e-04f);
auto beta_8 = make_const(DataType::Float(32), -1.45660718464996e-05f);
return compute(data->shape, [&](const Array<Var> &i) {
// clamp x
auto x = tvm::max(tvm::min(data(i), plus_4), minus_4);
auto x2 = x * x;
// Evaluate the numerator polynomial p.
auto p = x2 * alpha_13 + alpha_11;
p = x2 * p + alpha_9;
p = x2 * p + alpha_7;
p = x2 * p + alpha_5;
p = x2 * p + alpha_3;
p = x2 * p + alpha_1;
p = x * p;
// Evaluate the denominator polynomial q.
auto q = x2 * beta_8 + beta_6;
q = x2 * q + beta_4;
q = x2 * q + beta_2;
q = x2 * q + beta_0;
return p / q;
}, name, tag);
}
/*!
* \brief Fast erf implementation
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is erf operation
*/
inline Tensor fast_erf(const Tensor& x,
std::string name = "T_fast_erf",
std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
auto ret = fast_erf_float32(x, name, tag);
return ret;
} else {
return topi::erf(x);
}
}
} // namespace topi
#endif // TOPI_ELEMWISE_H_
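The kernel above is a clamped rational approximation: the input is clipped to [-4, 4], an odd polynomial p(x) forms the numerator, an even polynomial q(x) the denominator, and erf(x) is approximated by p(x) / q(x); non-float32 inputs fall back to the exact topi::erf. A rough NumPy mirror of the same evaluation, handy for eyeballing the error against scipy.special.erf (the helper name and tolerances are ours, not part of this patch):

import numpy as np
from scipy import special

def fast_erf_ref(x):
    # Clamp to [-4, 4] exactly as fast_erf_float32 does.
    x = np.clip(np.asarray(x, dtype=np.float32), -4.0, 4.0)
    x2 = x * x
    # Coefficients copied from the kernel: alpha_13 ... alpha_1 (odd numerator),
    # beta_8 ... beta_0 (even denominator), evaluated with Horner's scheme.
    alphas = [-2.72614225801306e-10, 2.77068142495902e-08, -2.10102402082508e-06,
              -5.69250639462346e-05, -7.34990630326855e-04, -2.95459980854025e-03,
              -1.60960333262415e-02]
    betas = [-1.45660718464996e-05, -2.13374055278905e-04, -1.68282697438203e-03,
             -7.37332916720468e-03, -1.42647390514189e-02]
    p = np.full_like(x, alphas[0])
    for a in alphas[1:]:
        p = x2 * p + a
    p = x * p
    q = np.full_like(x, betas[0])
    for b in betas[1:]:
        q = x2 * q + b
    return p / q

xs = np.arange(-10, 10, 0.01, dtype="float32")
np.testing.assert_allclose(fast_erf_ref(xs), special.erf(xs), rtol=1e-5, atol=1e-5)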
...@@ -534,3 +534,19 @@ def fast_tanh(x):
        The result.
    """
    return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)
def fast_erf(x):
"""Take gauss error function of input x using fast_erf implementation.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return cpp.fast_erf(x, x.dtype, tag.ELEMWISE)
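At the TE level the new operator can be called like any other elementwise op. A minimal usage sketch on CPU (the placeholder shape, plain create_schedule, and llvm target are illustrative assumptions, not part of this patch):

import numpy as np
from scipy import special
import tvm
from tvm import te
import topi

# 1-D float32 tensor passed through the fast erf approximation.
A = te.placeholder((1024,), dtype="float32", name="A")
B = topi.fast_erf(A)
s = te.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(-4, 4, size=1024).astype("float32"), ctx)
b = tvm.nd.array(np.empty(1024, dtype="float32"), ctx)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), special.erf(a.asnumpy()), rtol=1e-5, atol=1e-5)

For non-float32 inputs the C++ dispatch shown earlier falls back to the exact erf, so the result dtype and semantics are unchanged.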
...@@ -46,6 +46,11 @@ TVM_REGISTER_GLOBAL("topi.erf")
  *rv = erf(args[0]);
});
TVM_REGISTER_GLOBAL("topi.fast_erf")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_erf(args[0]);
});
TVM_REGISTER_GLOBAL("topi.tan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = tan(args[0]);
......
...@@ -16,6 +16,7 @@
# under the License.
import numpy as np
import scipy
from scipy import special
import tvm
from tvm import te
import topi
...@@ -238,11 +239,11 @@ def test_fastmath():
    test_apply(topi.fast_exp, "fast_exp", np.exp,
               low=-88, high=88, step=0.01)
    test_apply(topi.fast_erf, "fast_erf", scipy.special.erf,
               low=-10, high=10, step=0.01)
    test_apply(topi.fast_tanh, "fast_tanh", np.tanh,
               low=-10, high=10, step=0.01)
if __name__ == "__main__":
    test_util()
......