Commit 165aa0db by hlu1, committed by Tianqi Chen

fast tanh (#3255)

parent 29b0b4c1
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
  *
  *   http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -31,6 +31,7 @@
 #include "tvm/tvm.h"
 #include "tvm/ir.h"
 #include "tvm/ir_pass.h"
+#include "broadcast.h"
 
 namespace topi {
 using namespace tvm;
@@ -46,7 +47,6 @@ using namespace tvm;
 }
 
 TOPI_DECLARE_UNARY_OP(exp);
-TOPI_DECLARE_UNARY_OP(tanh);
 TOPI_DECLARE_UNARY_OP(sigmoid);
 TOPI_DECLARE_UNARY_OP(sqrt);
 TOPI_DECLARE_UNARY_OP(log);
@@ -56,6 +56,74 @@ TOPI_DECLARE_UNARY_OP(round);
 TOPI_DECLARE_UNARY_OP(trunc);
 TOPI_DECLARE_UNARY_OP(abs);
 
+/*!
+ * \brief Fast_tanh_float implementation from Eigen
+ * https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26
+ */
+inline Tensor fast_tanh_float(const Tensor& in,
+                              std::string name,
+                              std::string tag) {
+  // Clamp the inputs to the range [-9, 9] since anything outside
+  // this range is +/-1.0f in single-precision.
+  auto x = maximum(minimum(in, make_const(in->dtype, 9.0)), make_const(in->dtype, -9.0));
+
+  // The monomial coefficients of the numerator polynomial (odd).
+  auto alpha_1 = make_const(in->dtype, 4.89352455891786e-03);
+  auto alpha_3 = make_const(in->dtype, 6.37261928875436e-04);
+  auto alpha_5 = make_const(in->dtype, 1.48572235717979e-05);
+  auto alpha_7 = make_const(in->dtype, 5.12229709037114e-08);
+  auto alpha_9 = make_const(in->dtype, -8.60467152213735e-11);
+  auto alpha_11 = make_const(in->dtype, 2.00018790482477e-13);
+  auto alpha_13 = make_const(in->dtype, -2.76076847742355e-16);
+
+  // The monomial coefficients of the denominator polynomial (even).
+  auto beta_0 = make_const(in->dtype, 4.89352518554385e-03);
+  auto beta_2 = make_const(in->dtype, 2.26843463243900e-03);
+  auto beta_4 = make_const(in->dtype, 1.18534705686654e-04);
+  auto beta_6 = make_const(in->dtype, 1.19825839466702e-06);
+
+  return compute(x->shape,
+                 [&](const Array<Var>& i) {
+                   auto x2 = x(i) * x(i);
+                   auto p = x2 * alpha_13 + alpha_11;
+                   p = x2 * p + alpha_9;
+                   p = x2 * p + alpha_7;
+                   p = x2 * p + alpha_5;
+                   p = x2 * p + alpha_3;
+                   p = x2 * p + alpha_1;
+                   p = x(i) * p;
+
+                   auto q = x2 * beta_6 + beta_4;
+                   q = x2 * q + beta_2;
+                   q = x2 * q + beta_0;
+                   return p / q;
+                 },
+                 name, tag);
+}
+
+/*!
+ * \brief Creates an operation that returns the hyperbolic tangent of a given tensor
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is tanh
+ */
+inline Tensor tanh(const Tensor& x,
+                   std::string name = "T_tanh",
+                   std::string tag = kElementWise) {
+  if (x->dtype == Float(32)) {
+    // invoke fast_tanh_float implementation
+    return fast_tanh_float(x, name, tag);
+  } else {
+    // fallback to default implementation
+    return compute(x->shape, [&](const Array<Var>& i) {
+      return ::tvm::tanh(x(i));
+    }, name, tag);
+  }
+}
+
 /*!
  * \brief Creates an operation that returns identity of a given tensor
  *
...
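The committed fast path is a clamped rational approximation evaluated with Horner's rule. As a sanity check outside of TVM, here is a NumPy sketch (not part of the commit; fast_tanh_ref and the tolerances are my own) that reproduces the computation with the same Eigen coefficients and compares it against np.tanh:

import numpy as np

# Coefficients copied from the commit (alpha_1..alpha_13, beta_0..beta_6).
ALPHA = [4.89352455891786e-03, 6.37261928875436e-04, 1.48572235717979e-05,
         5.12229709037114e-08, -8.60467152213735e-11, 2.00018790482477e-13,
         -2.76076847742355e-16]
BETA = [4.89352518554385e-03, 2.26843463243900e-03,
        1.18534705686654e-04, 1.19825839466702e-06]

def fast_tanh_ref(x):
    # Clamp to [-9, 9]: beyond that, float32 tanh has saturated to +/-1.
    x = np.clip(x, -9.0, 9.0)
    x2 = x * x
    # Odd numerator polynomial p(x) = x * (alpha_1 + alpha_3*x^2 + ...),
    # evaluated with Horner's rule from the highest coefficient down.
    p = np.full_like(x, ALPHA[-1])
    for a in reversed(ALPHA[:-1]):
        p = x2 * p + a
    p = x * p
    # Even denominator polynomial q(x) = beta_0 + beta_2*x^2 + ...
    q = np.full_like(x, BETA[-1])
    for b in reversed(BETA[:-1]):
        q = x2 * q + b
    return p / q

x = np.random.uniform(-10.0, 10.0, 10000).astype(np.float32)
np.testing.assert_allclose(fast_tanh_ref(x), np.tanh(x), rtol=1e-5, atol=1e-6)

The clamp is what keeps the low-degree fit honest: the numerator and denominator polynomials only have to be accurate on [-9, 9].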
@@ -29,13 +29,21 @@ def test_util():
 def test_ewise():
-    m = tvm.var('m')
-    l = tvm.var('l')
-    A = tvm.placeholder((m, l), name='A')
-
-    shape = (20, 3)
-
-    def test_apply(func, name, f_numpy, low, high, check_round=False, skip_name_check=False):
+    def test_apply(
+        func,
+        name,
+        f_numpy,
+        low,
+        high,
+        shape=(20, 3),
+        dtype=tvm.float32,
+        check_round=False,
+        skip_name_check=False,
+    ):
+        m = tvm.var("m")
+        l = tvm.var("l")
+        A = tvm.placeholder((m, l), dtype=dtype, name="A")
         B = func(A)
         assert tuple(B.shape) == tuple(A.shape)
         if not skip_name_check:
@@ -63,7 +71,6 @@ def test_ewise():
     for device in get_all_backend():
         check_device(device)
 
-
     test_apply(topi.floor, "floor", np.floor, -100, 100)
     test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
     test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
@@ -71,11 +78,12 @@ def test_ewise():
     test_apply(topi.abs, "fabs", np.abs, -100, 100)
     test_apply(topi.round, "round", np.round, -100, 100, check_round=True)
     test_apply(topi.exp, "exp", np.exp, -1, 1)
-    test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
-    test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
+    test_apply(topi.tanh, "tanh", np.tanh, -10, 10, shape=(128, 128))
+    test_apply(topi.tanh, "tanh", np.tanh, -10, 10, shape=(128, 128), dtype="float64")
+    test_apply(topi.sigmoid, "sigmoid", lambda x: 1 / (1 + np.exp(-x)), -1, 1)
     test_apply(topi.log, "log", np.log, 0, 100)
     test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
-    test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)
+    test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
 
 
 def test_cast():
@@ -93,7 +101,7 @@ def test_cast():
     b_np = a_np.astype(to_dtype)
 
     for device in get_all_backend():
         ctx = tvm.context(device, 0)
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             continue
...
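The reworked test parameterizes shape and dtype so it can hit both dispatch paths: float32 on a (128, 128) tensor goes through fast_tanh_float, while float64 falls back to ::tvm::tanh. A minimal standalone version of that check, assuming the TVM 0.x Python API this commit targets (tvm.placeholder / tvm.create_schedule / tvm.build) and an enabled llvm target:

import numpy as np
import tvm
import topi

for dtype in ("float32", "float64"):
    A = tvm.placeholder((128, 128), dtype=dtype, name="A")
    B = topi.tanh(A)                      # dispatches on A's dtype
    s = tvm.create_schedule(B.op)
    f = tvm.build(s, [A, B], "llvm")

    ctx = tvm.cpu(0)
    a_np = np.random.uniform(-10, 10, size=(128, 128)).astype(dtype)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros((128, 128), dtype=dtype), ctx)
    f(a, b)
    np.testing.assert_allclose(b.asnumpy(), np.tanh(a_np), rtol=1e-5, atol=1e-5)

Both dtypes should match NumPy to the same tolerance; only the generated code differs.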