Commit 1fb2d7e2 by Pariksheet Pinjari Committed by Tianqi Chen

Add support for absolute opeartion (#1406)

parent 6bda4e33
...@@ -14,7 +14,7 @@ tvm.intrin ...@@ -14,7 +14,7 @@ tvm.intrin
tvm.ceil tvm.ceil
tvm.trunc tvm.trunc
tvm.round tvm.round
tvm.abs
.. autofunction:: tvm.call_packed .. autofunction:: tvm.call_packed
.. autofunction:: tvm.call_pure_intrin .. autofunction:: tvm.call_pure_intrin
...@@ -26,3 +26,4 @@ tvm.intrin ...@@ -26,3 +26,4 @@ tvm.intrin
.. autofunction:: tvm.ceil .. autofunction:: tvm.ceil
.. autofunction:: tvm.trunc .. autofunction:: tvm.trunc
.. autofunction:: tvm.round .. autofunction:: tvm.round
.. autofunction:: tvm.abs
...@@ -13,6 +13,7 @@ List of operators ...@@ -13,6 +13,7 @@ List of operators
topi.ceil topi.ceil
topi.trunc topi.trunc
topi.round topi.round
topi.abs
topi.exp topi.exp
topi.tanh topi.tanh
topi.log topi.log
...@@ -84,6 +85,7 @@ topi ...@@ -84,6 +85,7 @@ topi
.. autofunction:: topi.ceil .. autofunction:: topi.ceil
.. autofunction:: topi.trunc .. autofunction:: topi.trunc
.. autofunction:: topi.round .. autofunction:: topi.round
.. autofunction:: topi.abs
.. autofunction:: topi.exp .. autofunction:: topi.exp
.. autofunction:: topi.tanh .. autofunction:: topi.tanh
.. autofunction:: topi.log .. autofunction:: topi.log
......
...@@ -79,6 +79,7 @@ This level enables typical convnet models. ...@@ -79,6 +79,7 @@ This level enables typical convnet models.
nnvm.symbol.ceil nnvm.symbol.ceil
nnvm.symbol.round nnvm.symbol.round
nnvm.symbol.trunc nnvm.symbol.trunc
nnvm.symbol.abs
nnvm.symbol.leaky_relu nnvm.symbol.leaky_relu
nnvm.symbol.__add_scalar__ nnvm.symbol.__add_scalar__
nnvm.symbol.__sub_scalar__ nnvm.symbol.__sub_scalar__
...@@ -157,6 +158,7 @@ Detailed Definitions ...@@ -157,6 +158,7 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.ceil .. autofunction:: nnvm.symbol.ceil
.. autofunction:: nnvm.symbol.round .. autofunction:: nnvm.symbol.round
.. autofunction:: nnvm.symbol.trunc .. autofunction:: nnvm.symbol.trunc
.. autofunction:: nnvm.symbol.abs
.. autofunction:: nnvm.symbol.leaky_relu .. autofunction:: nnvm.symbol.leaky_relu
.. autofunction:: nnvm.symbol.__add_scalar__ .. autofunction:: nnvm.symbol.__add_scalar__
.. autofunction:: nnvm.symbol.__sub_scalar__ .. autofunction:: nnvm.symbol.__sub_scalar__
......
...@@ -18,7 +18,6 @@ using HalideIR::likely_if_innermost; ...@@ -18,7 +18,6 @@ using HalideIR::likely_if_innermost;
using HalideIR::cast; using HalideIR::cast;
using HalideIR::min; using HalideIR::min;
using HalideIR::max; using HalideIR::max;
using HalideIR::abs;
using HalideIR::select; using HalideIR::select;
/*! /*!
...@@ -71,6 +70,26 @@ inline Expr pow(Expr x, Expr y) { ...@@ -71,6 +70,26 @@ inline Expr pow(Expr x, Expr y) {
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic); return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
} }
/*!
* \brief Calculate absolute value of x, elementwise
* \param x The input data
*
* \return The aboslute value of input data x
*/
inline Expr abs(Expr x) {
if (x.type().is_int()) {
return select(x >= make_zero(x.type()), x, -x);
} else if (x.type().is_float()) {
return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
} else if (x.type().is_uint()) {
return x;
} else {
LOG(WARNING) << "Warning: Data type " << x.type()
<<" not supported for absolute op. Skipping absolute op...";
return x;
}
}
} // namespace tvm } // namespace tvm
#endif // TVM_IR_OPERATOR_H_ #endif // TVM_IR_OPERATOR_H_
...@@ -68,6 +68,10 @@ reg.register_schedule("ceil", _fschedule_broadcast) ...@@ -68,6 +68,10 @@ reg.register_schedule("ceil", _fschedule_broadcast)
reg.register_pattern("round", OpPattern.ELEMWISE) reg.register_pattern("round", OpPattern.ELEMWISE)
reg.register_schedule("round", _fschedule_broadcast) reg.register_schedule("round", _fschedule_broadcast)
# abs
reg.register_pattern("abs", OpPattern.ELEMWISE)
reg.register_schedule("abs", _fschedule_broadcast)
# trunc # trunc
reg.register_pattern("trunc", OpPattern.ELEMWISE) reg.register_pattern("trunc", OpPattern.ELEMWISE)
reg.register_schedule("trunc", _fschedule_broadcast) reg.register_schedule("trunc", _fschedule_broadcast)
......
...@@ -81,6 +81,18 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(round) ...@@ -81,6 +81,18 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(round)
return Array<Tensor>{ topi::round(inputs[0]) }; return Array<Tensor>{ topi::round(inputs[0]) };
}); });
// abs
NNVM_REGISTER_ELEMWISE_UNARY_OP(abs)
.describe(R"code(Take absolute value of elements of the input.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::abs(inputs[0]) };
});
// sigmoid // sigmoid
NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid) NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
.describe(R"code(Computes sigmoid. .describe(R"code(Computes sigmoid.
......
...@@ -28,6 +28,10 @@ def test_trunc(): ...@@ -28,6 +28,10 @@ def test_trunc():
def test_round(): def test_round():
check_map(sym.round, np.round) check_map(sym.round, np.round)
def test_abs():
check_map(sym.abs, np.abs)
check_map(sym.abs, np.abs, dtype = "int32")
check_map(sym.abs, np.abs, dtype = "int8")
def test_shift(): def test_shift():
n = 3 n = 3
...@@ -40,4 +44,5 @@ if __name__ == "__main__": ...@@ -40,4 +44,5 @@ if __name__ == "__main__":
test_floor() test_floor()
test_ceil() test_ceil()
test_round() test_round()
test_abs()
test_trunc() test_trunc()
...@@ -285,6 +285,22 @@ def trunc(x): ...@@ -285,6 +285,22 @@ def trunc(x):
return call_pure_intrin(x.dtype, "trunc", x) return call_pure_intrin(x.dtype, "trunc", x)
def abs(x):
"""Get absolute value of the input element-wise.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return _make.abs(x)
def round(x): def round(x):
"""Round elements of the array to the nearest integer. """Round elements of the array to the nearest integer.
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <ir/IROperator.h> #include <ir/IROperator.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <tvm/ir_operator.h>
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -16,6 +17,11 @@ TVM_REGISTER_API("_Var") ...@@ -16,6 +17,11 @@ TVM_REGISTER_API("_Var")
*ret = Variable::make(args[1], args[0]); *ret = Variable::make(args[1], args[0]);
}); });
TVM_REGISTER_API("make.abs")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::abs(args[0]);
});
TVM_REGISTER_API("make._range_by_min_extent") TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Range::make_by_min_extent(args[0], args[1]); *ret = Range::make_by_min_extent(args[0], args[1]);
......
...@@ -64,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil") ...@@ -64,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc")
.set_body(DispatchExtern<CUDAMath>); .set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
.set_body(DispatchExtern<CUDAMath>); .set_body(DispatchExtern<CUDAMath>);
......
...@@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil") ...@@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
......
...@@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil") ...@@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
......
...@@ -34,6 +34,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil") ...@@ -34,6 +34,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
......
...@@ -39,6 +39,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round") ...@@ -39,6 +39,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
.set_body(DispatchExternLibDevice); .set_body(DispatchExternLibDevice);
......
...@@ -38,6 +38,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round") ...@@ -38,6 +38,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML); .set_body(DispatchExternOCML);
......
...@@ -41,6 +41,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round") ...@@ -41,6 +41,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc") TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>); .set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs")
.set_body(DispatchGLSLPureIntrin<GLSLstd450FAbs>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>); .set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
......
...@@ -35,6 +35,7 @@ TOPI_DECLARE_UNARY_OP(floor); ...@@ -35,6 +35,7 @@ TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil); TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round); TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc); TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);
/*! /*!
* \brief Creates an operation that returns identity of a given tensor * \brief Creates an operation that returns identity of a given tensor
......
...@@ -126,6 +126,23 @@ def trunc(x): ...@@ -126,6 +126,23 @@ def trunc(x):
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def abs(x):
"""Take absolute value of the input of x, element-wise.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.abs(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def round(x): def round(x):
"""Round elements of x to nearest integer. """Round elements of x to nearest integer.
......
...@@ -46,6 +46,7 @@ def test_ewise(): ...@@ -46,6 +46,7 @@ def test_ewise():
test_apply(topi.floor, "floor", np.floor, -100, 100) test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100) test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.abs, "fabs", np.abs, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100) test_apply(topi.round, "round", np.round, -100, 100)
test_apply(topi.exp, "exp", np.exp, -1, 1) test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10) test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment