Commit 17c2c0a1 by Kimish Patel Committed by Tianqi Chen

Expose llvm.nearbyint intrinsic. This is a faster alternative to rounding. (#4001)

* Expose llvm.nearbyint intrinsic. This is a faster alternative to rounding.

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Added python binding. Added test.

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
parent 9756b067
...@@ -35,6 +35,7 @@ tvm.intrin ...@@ -35,6 +35,7 @@ tvm.intrin
tvm.ceil tvm.ceil
tvm.trunc tvm.trunc
tvm.round tvm.round
tvm.nearbyint
tvm.abs tvm.abs
tvm.isnan tvm.isnan
...@@ -52,5 +53,6 @@ tvm.intrin ...@@ -52,5 +53,6 @@ tvm.intrin
.. autofunction:: tvm.ceil .. autofunction:: tvm.ceil
.. autofunction:: tvm.trunc .. autofunction:: tvm.trunc
.. autofunction:: tvm.round .. autofunction:: tvm.round
.. autofunction:: tvm.nearbyint
.. autofunction:: tvm.abs .. autofunction:: tvm.abs
.. autofunction:: tvm.isnan .. autofunction:: tvm.isnan
...@@ -543,6 +543,14 @@ TVM_DLL Expr ceil(Expr x); ...@@ -543,6 +543,14 @@ TVM_DLL Expr ceil(Expr x);
TVM_DLL Expr round(Expr x); TVM_DLL Expr round(Expr x);
/*!
* \brief Calculate std::nearbyint(x), a faster alternative to round.
* \note Unlike round, the result follows the current rounding mode.
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr nearbyint(Expr x);
/*!
* \brief Calculate trunc(x) * \brief Calculate trunc(x)
* \param x The input expression. * \param x The input expression.
* \return The result expression. * \return The result expression.
......
...@@ -434,6 +434,29 @@ def round(x): ...@@ -434,6 +434,29 @@ def round(x):
return _make.round(x) return _make.round(x)
def nearbyint(x):
    """Round each element of the input to the nearest integer.

    This intrinsic lowers to llvm.nearbyint rather than llvm.round.
    It is faster, but its results may differ from tvm.round:
    nearbyint rounds according to the current rounding mode,
    whereas tvm.round (llvm.round) ignores the rounding mode.
    For the differences between the two see:
    https://en.cppreference.com/w/cpp/numeric/math/round
    https://en.cppreference.com/w/cpp/numeric/math/nearbyint

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    rounded = _make.nearbyint(x)
    return rounded
def isnan(x): def isnan(x):
"""Check if input value is Nan. """Check if input value is Nan.
......
...@@ -50,6 +50,9 @@ TVM_REGISTER_API("make.ceil") ...@@ -50,6 +50,9 @@ TVM_REGISTER_API("make.ceil")
TVM_REGISTER_API("make.round") TVM_REGISTER_API("make.round")
.set_body_typed(tvm::round); .set_body_typed(tvm::round);
// Register tvm::nearbyint under the API name "make.nearbyint";
// the Python wrapper reaches it as _make.nearbyint.
TVM_REGISTER_API("make.nearbyint")
.set_body_typed(tvm::nearbyint);
TVM_REGISTER_API("make.trunc") TVM_REGISTER_API("make.trunc")
.set_body_typed(tvm::trunc); .set_body_typed(tvm::trunc);
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -59,6 +59,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs") ...@@ -59,6 +59,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
// LLVM lowering rule for the "nearbyint" intrinsic: dispatch to
// ::llvm::Intrinsic::nearbyint, mirroring the rule registered for "round".
// NOTE(review): the template argument 1 is forwarded to DispatchLLVMPureIntrin
// unchanged from the "round" rule; confirm its exact meaning there.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) { .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0]; Expr e = targs[0];
......
...@@ -527,6 +527,13 @@ Expr round(Expr x) { ...@@ -527,6 +527,13 @@ Expr round(Expr x) {
return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic); return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic);
} }
/*!
 * \brief Build the expression nearbyint(x).
 * \param x The input expression.
 * \return A folded FloatImm when x is a floating-point immediate,
 *         otherwise a pure-intrinsic call named "nearbyint".
 */
Expr nearbyint(Expr x) {
  // Floating-point immediates are folded on the spot with std::nearbyint.
  if (const ir::FloatImm* imm = x.as<ir::FloatImm>()) {
    return ir::FloatImm::make(x.type(), std::nearbyint(imm->value));
  }
  // Anything else becomes an intrinsic call that keeps the operand's type.
  return ir::Call::make(x.type(), "nearbyint", {x}, ir::Call::PureIntrinsic);
}
Expr trunc(Expr x) { Expr trunc(Expr x) {
using ir::FloatImm; using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>(); const FloatImm* fx = x.as<FloatImm>();
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import topi
from tvm.contrib import util, clang
import numpy as np
import ctypes
import math
def test_nearbyint():
    """Build tvm.nearbyint for the llvm target and compare it with numpy.rint.

    The input mixes random values with explicit halfway cases. The
    halfway cases are what distinguish nearbyint (ties to even under the
    default rounding mode) from round (ties away from zero); random
    samples alone would almost never land exactly on a tie.
    """
    m = tvm.var("m")
    A = tvm.placeholder((m,), name='A')
    # Give the output its own name instead of reusing the placeholder's 'A'.
    A_rounded = tvm.compute((m,), lambda *i: tvm.nearbyint(A(*i)), name='A_rounded')
    s = tvm.create_schedule(A_rounded.op)
    f = tvm.build(s, [A, A_rounded], "llvm")
    ctx = tvm.cpu(0)
    n = 10
    # Exact halfway values, all representable in floating point.
    ties = np.array([0.5, 1.5, 2.5, 3.5, -0.5, -1.5, -2.5])
    values = np.concatenate([np.random.uniform(high=100, size=n), ties])
    a = tvm.nd.array(values.astype(A.dtype), ctx)
    a_rounded = tvm.nd.array(
        np.zeros(values.shape, dtype=A_rounded.dtype), ctx)
    f(a, a_rounded)
    # numpy's rint rounds to the nearest integer and breaks halfway
    # ties by rounding to the nearest even value, so 1.5 and 2.5 both
    # round to 2. This is also the default rounding mode in libc.
    # If a different rounding mode is set, the nearbyint result may
    # differ from numpy's.
    tvm.testing.assert_allclose(
        a_rounded.asnumpy(), np.rint(a.asnumpy()))


if __name__ == "__main__":
    test_nearbyint()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment