Commit cc112c10 by Ashutosh Parkhi Committed by Tianqi Chen

Support for sign (#2775)

parent d441d445
......@@ -11,6 +11,7 @@ List of operators
topi.negative
topi.floor
topi.ceil
topi.sign
topi.trunc
topi.round
topi.abs
......@@ -96,6 +97,7 @@ topi
.. autofunction:: topi.identity
.. autofunction:: topi.floor
.. autofunction:: topi.ceil
.. autofunction:: topi.sign
.. autofunction:: topi.trunc
.. autofunction:: topi.round
.. autofunction:: topi.abs
......
......@@ -81,6 +81,7 @@ This level enables additional math and transform operators.
tvm.relay.squeeze
tvm.relay.floor
tvm.relay.ceil
tvm.relay.sign
tvm.relay.trunc
tvm.relay.clip
tvm.relay.round
......@@ -213,6 +214,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.squeeze
.. autofunction:: tvm.relay.floor
.. autofunction:: tvm.relay.ceil
.. autofunction:: tvm.relay.sign
.. autofunction:: tvm.relay.trunc
.. autofunction:: tvm.relay.clip
.. autofunction:: tvm.relay.round
......
......@@ -16,6 +16,7 @@ register_schedule("floor", schedule_broadcast)
register_schedule("ceil", schedule_broadcast)
register_schedule("trunc", schedule_broadcast)
register_schedule("round", schedule_broadcast)
register_schedule("sign", schedule_broadcast)
register_schedule("abs", schedule_broadcast)
register_schedule("tanh", schedule_broadcast)
register_schedule("logical_not", schedule_broadcast)
......
......@@ -158,6 +158,20 @@ def abs(data):
"""
return _make.abs(data)
def sign(data):
"""Compute element-wise absolute of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.sign(data)
def tanh(data):
"""Compute element-wise tanh of data.
......
......@@ -146,6 +146,16 @@ RELAY_REGISTER_UNARY_OP("round")
.set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round));
RELAY_REGISTER_UNARY_OP("sign")
.describe(R"code(Returns the sign of input array, computed element-wise.
.. numpy::
sign(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign));
RELAY_REGISTER_UNARY_OP("abs")
.describe(R"code(Returns the abs of input array, computed element-wise.
......
......@@ -25,7 +25,8 @@ def test_unary_identity():
(relay.round, np.round),
(relay.abs, np.abs),
(relay.copy, None), # np.copy
(relay.negative, np.negative)]:
(relay.negative, np.negative),
(relay.sign, np.sign)]:
shape = (8, 9, 4)
x = relay.var("x", relay.TensorType(shape, "float32"))
y = op(x)
......
......@@ -89,6 +89,28 @@ inline Tensor logical_not(const Tensor& x,
}
/*!
* \brief Returns the sign of the tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sign
*/
inline Tensor sign(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
Expr zero = make_zero(x->dtype);
Expr one = make_const(x->dtype, 1);
Expr minus_one = make_const(x->dtype, -1);
auto s1 = tvm::ir::Select::make((x(i) < zero), minus_one, zero);
auto s2 = tvm::ir::Select::make((x(i) > zero), one, s1);
return s2;
}, name, tag);
}
/*!
* \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max]
*
......
......@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from . import tag
from . import cpp
@tvm.tag_scope(tag=tag.ELEMWISE)
def identity(x):
......@@ -107,6 +108,20 @@ def ceil(x):
"""
return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i)))
def sign(x):
"""Returns -1, 0, 1 based on sign of x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return cpp.sign(x)
@tvm.tag_scope(tag=tag.ELEMWISE)
def trunc(x):
......
......@@ -173,6 +173,11 @@ TVM_REGISTER_GLOBAL("topi.elemwise_sum")
*rv = elemwise_sum(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sign")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sign(args[0]);
});
TVM_REGISTER_GLOBAL("topi.full")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full(args[0], args[1], args[2]);
......
......@@ -18,9 +18,10 @@ def test_ewise():
shape = (20, 3)
def test_apply(func, name, f_numpy, low, high, check_round=False):
def test_apply(func, name, f_numpy, low, high, check_round=False, skip_name_check=False):
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
if not skip_name_check:
assert B.op.body[0].name == name
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
# avoid round check too close to boundary
......@@ -49,6 +50,7 @@ def test_ewise():
test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.abs, "fabs", np.abs, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100, check_round=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment