Unverified Commit 13140916 by Alex Gladkov Committed by GitHub

Fast exponent (#4790)

parent a43e326f
@@ -377,5 +377,85 @@ inline Tensor full_like(const Tensor& x,
}, name, tag);
}

/*!
* \brief Fast exponential function implementation
*
* \param _x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
 * \return A Tensor whose op member is the exponential operation
 *
 * \note The function computes
 *   e^x = 2^(x * log2(e))
 * and splits the power x * log2(e) into an integer part n and a remainder:
 *   n = floor(x * log2(e) + 1/2)
 *   f = x - n * ln(2)          (so x = n * ln(2) + f and |f| <= ln(2)/2)
 * which gives
 *   exp(x) = 2^n * exp(f)
 * The factor exp(f) is approximated by a degree-7 polynomial with
 * Cephes-style coefficients p[0..5]:
 *   exp(f) ~= 1 + f + f^2 * (p[5] + p[4]*f + p[3]*f^2
 *                            + p[2]*f^3 + p[1]*f^4 + p[0]*f^5)
*/
inline Tensor fast_exp_float32(const Tensor& _x,
                               std::string name,
                               std::string tag) {
  // Clamp bounds: float32 exp(x) over/underflows outside roughly +-88.38.
  auto x_hi = make_const(DataType::Float(32), 88.3762626647950f);
  auto x_lo = make_const(DataType::Float(32), -88.3762626647949f);
  auto log2e = make_const(DataType::Float(32), 1.44269504088896341f);
  auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f);
  // Cephes-style polynomial coefficients for exp(f), |f| <= ln(2)/2.
  PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f),
                   make_const(DataType::Float(32), 1.3981999507E-3f),
                   make_const(DataType::Float(32), 8.3334519073E-3f),
                   make_const(DataType::Float(32), 4.1665795894E-2f),
                   make_const(DataType::Float(32), 1.6666665459E-1f),
                   make_const(DataType::Float(32), 5.0000001201E-1f)};
  auto one = make_const(DataType::Float(32), 1.0f);
  auto one_half = make_const(DataType::Float(32), 0.5f);
  // IEEE-754 float32 exponent bias, used to construct 2^n bit-wise.
  auto b = make_const(DataType::Float(32), 127.0f);

  return compute(_x->shape,
                 [&](const Array<Var>& i) {
                   // clamp x to the representable range
                   auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo);
                   // integer part: n = floor(x * log2(e) + 1/2)
                   auto n = ::tvm::floor(x * log2e + one_half);
                   // remainder: f = x - n * ln(2)
                   auto f = x - n * ln2;
                   // polynomial approximation of exp(f)
                   auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f
                             + p[4]) * f + p[5]) * f * f + f + one;
                   // 2^n * exp(f): shift the biased exponent (n + 127)
                   // into the float32 exponent field.
                   auto ef = tvm::reinterpret(
                       DataType::Float(32),
                       ::tvm::cast(DataType::Int(32), n + b) << 23);
                   return ::tvm::max(ef * y, _x(i));  // NOLINT(*)
                 },
                 name, tag);
}
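
As a sanity check on the derivation above, the same range reduction and polynomial can be modeled outside TVM. Below is a minimal numpy sketch (fast_exp_ref is an illustrative name, not part of this patch); it mirrors the compute lambda line by line, including the final max against the unclamped input.

import numpy as np

def fast_exp_ref(x0):
    """Reference model of fast_exp_float32: exp(x) ~= 2^n * exp(f)."""
    x = np.clip(x0, -88.3762626647949, 88.3762626647950).astype(np.float32)
    log2e = np.float32(1.44269504088896341)
    ln2 = np.float32(0.6931471805599453)
    p = np.array([1.9875691500e-4, 1.3981999507e-3, 8.3334519073e-3,
                  4.1665795894e-2, 1.6666665459e-1, 5.0000001201e-1],
                 dtype=np.float32)
    n = np.floor(x * log2e + np.float32(0.5))   # integer part
    f = x - n * ln2                             # remainder, |f| <= ln(2)/2
    y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f
          + p[4]) * f + p[5]) * f * f + f + np.float32(1.0)
    # 2^n built bit-wise: put the biased exponent (n + 127) into bits 30..23.
    ef = ((n + np.float32(127.0)).astype(np.int32) << 23).view(np.float32)
    return np.maximum(ef * y, x0)

xs = np.linspace(-10.0, 10.0, 10001).astype(np.float32)
assert np.allclose(fast_exp_ref(xs), np.exp(xs), rtol=1e-5, atol=1e-5)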
/*!
 * \brief Fast exponential function: fast float32 path, exact exp otherwise
 *
 * \param x The input tensor
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the exponential operation
 */
inline Tensor fast_exp(const Tensor& x,
                       std::string name = "T_fast_exp",
                       std::string tag = kElementWise) {
  if (x->dtype == DataType::Float(32)) {
    return fast_exp_float32(x, name, tag);
  } else {
    // Fall back to the exact exponential for other dtypes.
    return compute(x->shape, [&](const Array<Var>& i) {
      return ::tvm::exp(x(i));
    }, name, tag);
  }
}
} // namespace topi
#endif // TOPI_ELEMWISE_H_
@@ -451,3 +451,19 @@ def reinterpret(x, dtype):
        The result.
    """
    return cpp.reinterpret(x, dtype)


def fast_exp(x):
    """Take the exponential of input x using the fast_exp implementation.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)
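
For orientation, this wrapper is used like any other TOPI element-wise op. A minimal usage sketch, assuming the 0.6-era tvm/topi API that the test in this patch also uses:

import numpy as np
import tvm
import topi

# Build fast_exp for a 1-D float32 tensor and compare against numpy.
A = tvm.placeholder((1024,), dtype="float32", name="A")
B = topi.fast_exp(A)
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.linspace(-10, 10, 1024).astype("float32"), ctx)
b = tvm.nd.array(np.zeros(1024, dtype="float32"), ctx)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), np.exp(a.asnumpy()),
                           rtol=1e-5, atol=1e-5)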
@@ -165,6 +165,11 @@ TVM_REGISTER_GLOBAL("topi.exp")
  *rv = exp(args[0]);
  });

TVM_REGISTER_GLOBAL("topi.fast_exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = fast_exp(args[0]);
  });

TVM_REGISTER_GLOBAL("topi.erf")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = erf(args[0]);
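
This registration exposes the C++ implementation as a packed function, which is what the Python wrapper above calls through cpp.fast_exp. A small sketch of invoking it directly (hypothetical usage, assuming tvm.get_global_func from this era):

import tvm
import topi  # importing topi loads the C++ registrations

f = tvm.get_global_func("topi.fast_exp")
A = tvm.placeholder((16,), dtype="float32", name="A")
B = f(A)           # same Tensor that topi.fast_exp(A) returns
print(B.op.tag)    # "elemwise" (kElementWise)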
@@ -185,7 +185,45 @@ def test_cast():
    verify("bool", "int32")


def test_fastmath():
    def test_apply(
            func,
            name,
            f_numpy,
            low,
            high,
            step,
            dtype=tvm.float32
    ):
        a_np = np.arange(low, high, step).astype(dtype)
        b_np = f_numpy(a_np)
        A = tvm.placeholder(a_np.shape, dtype=dtype, name="A")
        B = func(A)
        assert tuple(B.shape) == tuple(A.shape)

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            with tvm.target.create(device):
                s = topi.generic.schedule_injective(B)
            func = tvm.build(s, [A, B], device, name=name)
            a = tvm.nd.array(a_np, ctx)
            b = tvm.nd.array(np.zeros_like(b_np), ctx)
            func(a, b)
            tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

        check_device('llvm')
        check_device('llvm -device=arm-cpu')

    test_apply(topi.fast_exp, "fast_exp", np.exp,
               low=-88, high=88,
               step=0.01)


if __name__ == "__main__":
    test_util()
    test_ewise()
    test_cast()
    test_fastmath()