Commit 9bfdc55c by Hiroyuki Makino Committed by Tianqi Chen

[Relay][TOPI] Add rsqrt operator (#2949)

parent fed1c08e
...@@ -36,6 +36,7 @@ List of operators ...@@ -36,6 +36,7 @@ List of operators
topi.tanh topi.tanh
topi.log topi.log
topi.sqrt topi.sqrt
topi.rsqrt
topi.sigmoid topi.sigmoid
topi.clip topi.clip
topi.cast topi.cast
...@@ -122,6 +123,7 @@ topi ...@@ -122,6 +123,7 @@ topi
.. autofunction:: topi.tanh .. autofunction:: topi.tanh
.. autofunction:: topi.log .. autofunction:: topi.log
.. autofunction:: topi.sqrt .. autofunction:: topi.sqrt
.. autofunction:: topi.rsqrt
.. autofunction:: topi.sigmoid .. autofunction:: topi.sigmoid
.. autofunction:: topi.clip .. autofunction:: topi.clip
.. autofunction:: topi.cast .. autofunction:: topi.cast
......
...@@ -41,6 +41,7 @@ This level enables fully connected multi-layer perceptron. ...@@ -41,6 +41,7 @@ This level enables fully connected multi-layer perceptron.
tvm.relay.log tvm.relay.log
tvm.relay.sqrt tvm.relay.sqrt
tvm.relay.rsqrt
tvm.relay.exp tvm.relay.exp
tvm.relay.sigmoid tvm.relay.sigmoid
tvm.relay.add tvm.relay.add
...@@ -186,6 +187,7 @@ Level 1 Definitions ...@@ -186,6 +187,7 @@ Level 1 Definitions
------------------- -------------------
.. autofunction:: tvm.relay.log .. autofunction:: tvm.relay.log
.. autofunction:: tvm.relay.sqrt .. autofunction:: tvm.relay.sqrt
.. autofunction:: tvm.relay.rsqrt
.. autofunction:: tvm.relay.exp .. autofunction:: tvm.relay.exp
.. autofunction:: tvm.relay.sigmoid .. autofunction:: tvm.relay.sigmoid
.. autofunction:: tvm.relay.add .. autofunction:: tvm.relay.add
......
...@@ -486,6 +486,7 @@ TVM_DECLARE_INTRIN_UNARY(exp); ...@@ -486,6 +486,7 @@ TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt); TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log); TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(popcount); TVM_DECLARE_INTRIN_UNARY(popcount);
......
...@@ -52,6 +52,22 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar ...@@ -52,6 +52,22 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
return numpy.zeros(shape).astype(dtype) return numpy.zeros(shape).astype(dtype)
def rsqrt(x):
"""
Computes reciprocal of square root of x element-wise
Parameters
----------
x: Tensor
Returns
-------
res: Tensor
The result of reciprocal of square root of x
"""
return numpy.ones_like(x) / numpy.sqrt(x)
def popcount(x): def popcount(x):
""" """
Count ones in the binary representation of number x Count ones in the binary representation of number x
...@@ -103,6 +119,7 @@ HYBRID_GLOBALS = { ...@@ -103,6 +119,7 @@ HYBRID_GLOBALS = {
'allocate' : allocate, 'allocate' : allocate,
'output_tensor' : allocate, 'output_tensor' : allocate,
'sqrt' : numpy.sqrt, 'sqrt' : numpy.sqrt,
'rsqrt' : rsqrt,
'log' : numpy.log, 'log' : numpy.log,
'tanh' : numpy.tanh, 'tanh' : numpy.tanh,
'power' : numpy.power, 'power' : numpy.power,
......
...@@ -260,7 +260,7 @@ def log(x): ...@@ -260,7 +260,7 @@ def log(x):
def sqrt(x): def sqrt(x):
"""Take log of input x. """Take square root of input x.
Parameters Parameters
---------- ----------
...@@ -275,6 +275,22 @@ def sqrt(x): ...@@ -275,6 +275,22 @@ def sqrt(x):
return call_pure_intrin(x.dtype, "sqrt", x) return call_pure_intrin(x.dtype, "sqrt", x)
def rsqrt(x):
"""Take reciprocal of square root of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "rsqrt", x)
def floor(x): def floor(x):
"""Take floor of float input x. """Take floor of float input x.
......
...@@ -27,6 +27,7 @@ schedule_elemwise = schedule_injective ...@@ -27,6 +27,7 @@ schedule_elemwise = schedule_injective
register_schedule("log", schedule_broadcast) register_schedule("log", schedule_broadcast)
register_schedule("exp", schedule_broadcast) register_schedule("exp", schedule_broadcast)
register_schedule("sqrt", schedule_broadcast) register_schedule("sqrt", schedule_broadcast)
register_schedule("rsqrt", schedule_broadcast)
register_schedule("sigmoid", schedule_broadcast) register_schedule("sigmoid", schedule_broadcast)
register_schedule("floor", schedule_broadcast) register_schedule("floor", schedule_broadcast)
register_schedule("ceil", schedule_broadcast) register_schedule("ceil", schedule_broadcast)
......
...@@ -79,6 +79,26 @@ def sqrt(data): ...@@ -79,6 +79,26 @@ def sqrt(data):
return _make.sqrt(data) return _make.sqrt(data)
def rsqrt(data):
"""Compute elementwise rsqrt of data.
.. math::
1/sqrt(x)
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.rsqrt(data)
def sigmoid(data): def sigmoid(data):
"""Compute elementwise sigmoid of data. """Compute elementwise sigmoid of data.
......
...@@ -40,6 +40,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") ...@@ -40,6 +40,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
auto one = make_const(call->args[0].type(), 1);
*rv = one / sqrt(call->args[0]);
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
.set_body(DispatchExtern<FloatSuffix>); .set_body(DispatchExtern<FloatSuffix>);
......
...@@ -64,7 +64,7 @@ RELAY_REGISTER_UNARY_OP("exp") ...@@ -64,7 +64,7 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
RELAY_REGISTER_UNARY_OP("sqrt") RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the rsqrt input array, computed element-wise. .describe(R"code(Returns the sqrt input array, computed element-wise.
.. math:: .. math::
sqrt(x) sqrt(x)
...@@ -73,6 +73,15 @@ RELAY_REGISTER_UNARY_OP("sqrt") ...@@ -73,6 +73,15 @@ RELAY_REGISTER_UNARY_OP("sqrt")
.set_support_level(1) .set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt));
RELAY_REGISTER_UNARY_OP("rsqrt")
.describe(R"code(Returns the rsqrt input array, computed element-wise.
.. math::
1/sqrt(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt));
RELAY_REGISTER_UNARY_OP("zeros_like") RELAY_REGISTER_UNARY_OP("zeros_like")
.describe(R"code(Returns an array of zeros, with same type and shape as the input. .describe(R"code(Returns an array of zeros, with same type and shape as the input.
......
...@@ -30,7 +30,7 @@ def test_op_attr(): ...@@ -30,7 +30,7 @@ def test_op_attr():
def test_op_level1(): def test_op_level1():
x = relay.Var("x") x = relay.Var("x")
for op_name in ["log", "exp", "sqrt", "tanh"]: for op_name in ["log", "exp", "sqrt", "rsqrt","tanh"]:
y = getattr(relay, op_name)(x) y = getattr(relay, op_name)(x)
assert y.op.name == op_name assert y.op.name == op_name
assert y.op.support_level == 1 assert y.op.support_level == 1
......
...@@ -30,6 +30,10 @@ def relu(x): ...@@ -30,6 +30,10 @@ def relu(x):
np.maximum(x_copy, 0, x_copy) np.maximum(x_copy, 0, x_copy)
return x_copy return x_copy
def rsqrt(x):
one = np.ones_like(x)
return one / np.sqrt(x)
def test_unary_op(): def test_unary_op():
def check_single_op(opfunc, ref): def check_single_op(opfunc, ref):
shape = (10, 4) shape = (10, 4)
...@@ -57,6 +61,7 @@ def test_unary_op(): ...@@ -57,6 +61,7 @@ def test_unary_op():
for opfunc, ref in [(tvm.relay.log, np.log), for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp), (tvm.relay.exp, np.exp),
(tvm.relay.sqrt, np.sqrt), (tvm.relay.sqrt, np.sqrt),
(tvm.relay.rsqrt, rsqrt),
(tvm.relay.sigmoid, sigmoid), (tvm.relay.sigmoid, sigmoid),
(tvm.relay.tanh, np.tanh), (tvm.relay.tanh, np.tanh),
(relay.nn.relu, relu)]: (relay.nn.relu, relu)]:
......
...@@ -130,6 +130,24 @@ inline Tensor sign(const Tensor& x, ...@@ -130,6 +130,24 @@ inline Tensor sign(const Tensor& x,
} }
/*! /*!
* \brief Creates an operation that returns rsqrt of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the rsqrt operation
*/
inline Tensor rsqrt(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
Expr one = make_const(x->dtype, 1);
return one/tvm::sqrt(x(i));
}, name, tag);
}
/*!
* \brief Creates an operation that clips each element of a tensor to * \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max] * the interval [a_min, a_max]
* *
......
...@@ -225,6 +225,23 @@ def sqrt(x): ...@@ -225,6 +225,23 @@ def sqrt(x):
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def rsqrt(x):
"""Take inverse square root of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.rsqrt(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def sigmoid(x): def sigmoid(x):
"""Take sigmoid tanh of input x. """Take sigmoid tanh of input x.
......
...@@ -163,6 +163,11 @@ TVM_REGISTER_GLOBAL("topi.sqrt") ...@@ -163,6 +163,11 @@ TVM_REGISTER_GLOBAL("topi.sqrt")
*rv = sqrt(args[0]); *rv = sqrt(args[0]);
}); });
TVM_REGISTER_GLOBAL("topi.rsqrt")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = rsqrt(args[0]);
});
TVM_REGISTER_GLOBAL("topi.log") TVM_REGISTER_GLOBAL("topi.log")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log(args[0]); *rv = log(args[0]);
......
...@@ -40,6 +40,7 @@ def test_ewise(): ...@@ -40,6 +40,7 @@ def test_ewise():
test_apply(topi.sigmoid, "sigmoid") test_apply(topi.sigmoid, "sigmoid")
test_apply(topi.log, "log") test_apply(topi.log, "log")
test_apply(topi.sqrt, "sqrt") test_apply(topi.sqrt, "sqrt")
test_apply(topi.rsqrt, "rsqrt")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -75,6 +75,7 @@ def test_ewise(): ...@@ -75,6 +75,7 @@ def test_ewise():
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1) test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
test_apply(topi.log, "log", np.log, 0, 100) test_apply(topi.log, "log", np.log, 0, 100)
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100) test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)
if __name__ == "__main__": if __name__ == "__main__":
test_util() test_util()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment