Commit cc112c10 by Ashutosh Parkhi Committed by Tianqi Chen

Support for sign (#2775)

parent d441d445
...@@ -11,6 +11,7 @@ List of operators ...@@ -11,6 +11,7 @@ List of operators
topi.negative topi.negative
topi.floor topi.floor
topi.ceil topi.ceil
topi.sign
topi.trunc topi.trunc
topi.round topi.round
topi.abs topi.abs
...@@ -96,6 +97,7 @@ topi ...@@ -96,6 +97,7 @@ topi
.. autofunction:: topi.identity .. autofunction:: topi.identity
.. autofunction:: topi.floor .. autofunction:: topi.floor
.. autofunction:: topi.ceil .. autofunction:: topi.ceil
.. autofunction:: topi.sign
.. autofunction:: topi.trunc .. autofunction:: topi.trunc
.. autofunction:: topi.round .. autofunction:: topi.round
.. autofunction:: topi.abs .. autofunction:: topi.abs
......
...@@ -81,6 +81,7 @@ This level enables additional math and transform operators. ...@@ -81,6 +81,7 @@ This level enables additional math and transform operators.
tvm.relay.squeeze tvm.relay.squeeze
tvm.relay.floor tvm.relay.floor
tvm.relay.ceil tvm.relay.ceil
tvm.relay.sign
tvm.relay.trunc tvm.relay.trunc
tvm.relay.clip tvm.relay.clip
tvm.relay.round tvm.relay.round
...@@ -213,6 +214,7 @@ Level 3 Definitions ...@@ -213,6 +214,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.squeeze .. autofunction:: tvm.relay.squeeze
.. autofunction:: tvm.relay.floor .. autofunction:: tvm.relay.floor
.. autofunction:: tvm.relay.ceil .. autofunction:: tvm.relay.ceil
.. autofunction:: tvm.relay.sign
.. autofunction:: tvm.relay.trunc .. autofunction:: tvm.relay.trunc
.. autofunction:: tvm.relay.clip .. autofunction:: tvm.relay.clip
.. autofunction:: tvm.relay.round .. autofunction:: tvm.relay.round
......
...@@ -16,6 +16,7 @@ register_schedule("floor", schedule_broadcast) ...@@ -16,6 +16,7 @@ register_schedule("floor", schedule_broadcast)
register_schedule("ceil", schedule_broadcast) register_schedule("ceil", schedule_broadcast)
register_schedule("trunc", schedule_broadcast) register_schedule("trunc", schedule_broadcast)
register_schedule("round", schedule_broadcast) register_schedule("round", schedule_broadcast)
register_schedule("sign", schedule_broadcast)
register_schedule("abs", schedule_broadcast) register_schedule("abs", schedule_broadcast)
register_schedule("tanh", schedule_broadcast) register_schedule("tanh", schedule_broadcast)
register_schedule("logical_not", schedule_broadcast) register_schedule("logical_not", schedule_broadcast)
......
...@@ -158,6 +158,20 @@ def abs(data): ...@@ -158,6 +158,20 @@ def abs(data):
""" """
return _make.abs(data) return _make.abs(data)
def sign(data):
"""Compute element-wise absolute of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.sign(data)
def tanh(data): def tanh(data):
"""Compute element-wise tanh of data. """Compute element-wise tanh of data.
......
...@@ -146,6 +146,16 @@ RELAY_REGISTER_UNARY_OP("round") ...@@ -146,6 +146,16 @@ RELAY_REGISTER_UNARY_OP("round")
.set_support_level(3) .set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round));
RELAY_REGISTER_UNARY_OP("sign")
.describe(R"code(Returns the sign of input array, computed element-wise.
.. numpy::
sign(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign));
RELAY_REGISTER_UNARY_OP("abs") RELAY_REGISTER_UNARY_OP("abs")
.describe(R"code(Returns the abs of input array, computed element-wise. .describe(R"code(Returns the abs of input array, computed element-wise.
......
...@@ -25,7 +25,8 @@ def test_unary_identity(): ...@@ -25,7 +25,8 @@ def test_unary_identity():
(relay.round, np.round), (relay.round, np.round),
(relay.abs, np.abs), (relay.abs, np.abs),
(relay.copy, None), # np.copy (relay.copy, None), # np.copy
(relay.negative, np.negative)]: (relay.negative, np.negative),
(relay.sign, np.sign)]:
shape = (8, 9, 4) shape = (8, 9, 4)
x = relay.var("x", relay.TensorType(shape, "float32")) x = relay.var("x", relay.TensorType(shape, "float32"))
y = op(x) y = op(x)
......
...@@ -89,6 +89,28 @@ inline Tensor logical_not(const Tensor& x, ...@@ -89,6 +89,28 @@ inline Tensor logical_not(const Tensor& x,
} }
/*! /*!
* \brief Returns the sign of the tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sign
*/
inline Tensor sign(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
Expr zero = make_zero(x->dtype);
Expr one = make_const(x->dtype, 1);
Expr minus_one = make_const(x->dtype, -1);
auto s1 = tvm::ir::Select::make((x(i) < zero), minus_one, zero);
auto s2 = tvm::ir::Select::make((x(i) > zero), one, s1);
return s2;
}, name, tag);
}
/*!
* \brief Creates an operation that clips each element of a tensor to * \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max] * the interval [a_min, a_max]
* *
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from . import tag from . import tag
from . import cpp
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def identity(x): def identity(x):
...@@ -107,6 +108,20 @@ def ceil(x): ...@@ -107,6 +108,20 @@ def ceil(x):
""" """
return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i)))
def sign(x):
"""Returns -1, 0, 1 based on sign of x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return cpp.sign(x)
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def trunc(x): def trunc(x):
......
...@@ -173,6 +173,11 @@ TVM_REGISTER_GLOBAL("topi.elemwise_sum") ...@@ -173,6 +173,11 @@ TVM_REGISTER_GLOBAL("topi.elemwise_sum")
*rv = elemwise_sum(args[0]); *rv = elemwise_sum(args[0]);
}); });
TVM_REGISTER_GLOBAL("topi.sign")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sign(args[0]);
});
TVM_REGISTER_GLOBAL("topi.full") TVM_REGISTER_GLOBAL("topi.full")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full(args[0], args[1], args[2]); *rv = full(args[0], args[1], args[2]);
......
...@@ -18,10 +18,11 @@ def test_ewise(): ...@@ -18,10 +18,11 @@ def test_ewise():
shape = (20, 3) shape = (20, 3)
def test_apply(func, name, f_numpy, low, high, check_round=False): def test_apply(func, name, f_numpy, low, high, check_round=False, skip_name_check=False):
B = func(A) B = func(A)
assert tuple(B.shape) == tuple(A.shape) assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name if not skip_name_check:
assert B.op.body[0].name == name
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10 a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
# avoid round check too close to boundary # avoid round check too close to boundary
if check_round: if check_round:
...@@ -49,6 +50,7 @@ def test_ewise(): ...@@ -49,6 +50,7 @@ def test_ewise():
test_apply(topi.floor, "floor", np.floor, -100, 100) test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100) test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.abs, "fabs", np.abs, -100, 100) test_apply(topi.abs, "fabs", np.abs, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100, check_round=True) test_apply(topi.round, "round", np.round, -100, 100, check_round=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment