Commit 9bfdc55c by Hiroyuki Makino Committed by Tianqi Chen

[Relay][TOPI] Add rsqrt operator (#2949)

parent fed1c08e
......@@ -36,6 +36,7 @@ List of operators
topi.tanh
topi.log
topi.sqrt
topi.rsqrt
topi.sigmoid
topi.clip
topi.cast
......@@ -122,6 +123,7 @@ topi
.. autofunction:: topi.tanh
.. autofunction:: topi.log
.. autofunction:: topi.sqrt
.. autofunction:: topi.rsqrt
.. autofunction:: topi.sigmoid
.. autofunction:: topi.clip
.. autofunction:: topi.cast
......
......@@ -41,6 +41,7 @@ This level enables fully connected multi-layer perceptron.
tvm.relay.log
tvm.relay.sqrt
tvm.relay.rsqrt
tvm.relay.exp
tvm.relay.sigmoid
tvm.relay.add
......@@ -186,6 +187,7 @@ Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
.. autofunction:: tvm.relay.sqrt
.. autofunction:: tvm.relay.rsqrt
.. autofunction:: tvm.relay.exp
.. autofunction:: tvm.relay.sigmoid
.. autofunction:: tvm.relay.add
......
......@@ -486,6 +486,7 @@ TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(popcount);
......
......@@ -52,6 +52,22 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
return numpy.zeros(shape).astype(dtype)
def rsqrt(x):
"""
Computes reciprocal of square root of x element-wise
Parameters
----------
x: Tensor
Returns
-------
res: Tensor
The result of reciprocal of square root of x
"""
return numpy.ones_like(x) / numpy.sqrt(x)
def popcount(x):
"""
Count ones in the binary representation of number x
......@@ -103,6 +119,7 @@ HYBRID_GLOBALS = {
'allocate' : allocate,
'output_tensor' : allocate,
'sqrt' : numpy.sqrt,
'rsqrt' : rsqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
......
......@@ -260,7 +260,7 @@ def log(x):
def sqrt(x):
"""Take log of input x.
"""Take square root of input x.
Parameters
----------
......@@ -275,6 +275,22 @@ def sqrt(x):
return call_pure_intrin(x.dtype, "sqrt", x)
def rsqrt(x):
"""Take reciprocal of square root of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "rsqrt", x)
def floor(x):
"""Take floor of float input x.
......
......@@ -27,6 +27,7 @@ schedule_elemwise = schedule_injective
register_schedule("log", schedule_broadcast)
register_schedule("exp", schedule_broadcast)
register_schedule("sqrt", schedule_broadcast)
register_schedule("rsqrt", schedule_broadcast)
register_schedule("sigmoid", schedule_broadcast)
register_schedule("floor", schedule_broadcast)
register_schedule("ceil", schedule_broadcast)
......
......@@ -79,6 +79,26 @@ def sqrt(data):
return _make.sqrt(data)
def rsqrt(data):
"""Compute elementwise rsqrt of data.
.. math::
1/sqrt(x)
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.rsqrt(data)
def sigmoid(data):
"""Compute elementwise sigmoid of data.
......
......@@ -40,6 +40,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
auto one = make_const(call->args[0].type(), 1);
*rv = one / sqrt(call->args[0]);
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
.set_body(DispatchExtern<FloatSuffix>);
......
......@@ -64,7 +64,7 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the rsqrt input array, computed element-wise.
.describe(R"code(Returns the sqrt input array, computed element-wise.
.. math::
sqrt(x)
......@@ -73,6 +73,15 @@ RELAY_REGISTER_UNARY_OP("sqrt")
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt));
RELAY_REGISTER_UNARY_OP("rsqrt")
.describe(R"code(Returns the rsqrt input array, computed element-wise.
.. math::
1/sqrt(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt));
RELAY_REGISTER_UNARY_OP("zeros_like")
.describe(R"code(Returns an array of zeros, with same type and shape as the input.
......
......@@ -30,7 +30,7 @@ def test_op_attr():
def test_op_level1():
x = relay.Var("x")
for op_name in ["log", "exp", "sqrt", "tanh"]:
for op_name in ["log", "exp", "sqrt", "rsqrt","tanh"]:
y = getattr(relay, op_name)(x)
assert y.op.name == op_name
assert y.op.support_level == 1
......
......@@ -30,6 +30,10 @@ def relu(x):
np.maximum(x_copy, 0, x_copy)
return x_copy
def rsqrt(x):
one = np.ones_like(x)
return one / np.sqrt(x)
def test_unary_op():
def check_single_op(opfunc, ref):
shape = (10, 4)
......@@ -57,6 +61,7 @@ def test_unary_op():
for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp),
(tvm.relay.sqrt, np.sqrt),
(tvm.relay.rsqrt, rsqrt),
(tvm.relay.sigmoid, sigmoid),
(tvm.relay.tanh, np.tanh),
(relay.nn.relu, relu)]:
......
......@@ -130,6 +130,24 @@ inline Tensor sign(const Tensor& x,
}
/*!
* \brief Creates an operation that returns rsqrt of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the rsqrt operation
*/
inline Tensor rsqrt(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
Expr one = make_const(x->dtype, 1);
return one/tvm::sqrt(x(i));
}, name, tag);
}
/*!
* \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max]
*
......
......@@ -225,6 +225,23 @@ def sqrt(x):
@tvm.tag_scope(tag=tag.ELEMWISE)
def rsqrt(x):
"""Take inverse square root of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.rsqrt(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def sigmoid(x):
"""Take sigmoid tanh of input x.
......
......@@ -163,6 +163,11 @@ TVM_REGISTER_GLOBAL("topi.sqrt")
*rv = sqrt(args[0]);
});
TVM_REGISTER_GLOBAL("topi.rsqrt")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = rsqrt(args[0]);
});
TVM_REGISTER_GLOBAL("topi.log")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log(args[0]);
......
......@@ -40,6 +40,7 @@ def test_ewise():
test_apply(topi.sigmoid, "sigmoid")
test_apply(topi.log, "log")
test_apply(topi.sqrt, "sqrt")
test_apply(topi.rsqrt, "rsqrt")
if __name__ == "__main__":
......
......@@ -75,6 +75,7 @@ def test_ewise():
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
test_apply(topi.log, "log", np.log, 0, 100)
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)
if __name__ == "__main__":
test_util()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment