Commit 3d010ed5 by Liangfu Chen Committed by Tianqi Chen

support equal and not_equal in topi (#1373)

parent d7c600b8
......@@ -46,6 +46,8 @@ List of operators
topi.max
topi.sum
topi.min
topi.argmax
topi.argmin
topi.broadcast_to
topi.add
topi.subtract
......@@ -57,6 +59,10 @@ List of operators
topi.power
topi.greater
topi.less
topi.equal
topi.not_equal
topi.greater_equal
topi.less_equal
topi.image.resize
......
......@@ -257,6 +257,58 @@ TOPI_DEFINE_BCAST_OP(greater, { return (a > b); });
*/
TOPI_DEFINE_BCAST_OP(less, { return (a < b); });
/*!
* \fn equal
* \brief Compute (A == B) with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(equal, { return (a == b); });
/*!
* \fn not_equal
* \brief Compute (A != B) with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(not_equal, { return (a != b); });
/*!
* \fn greater_equal
* \brief Compute (A >= B) with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(greater_equal, { return (a >= b); });
/*!
* \fn less_equal
* \brief Compute (A <= B) with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(less_equal, { return (a <= b); });
} // namespace topi
#endif // TOPI_BROADCAST_H_
......@@ -249,3 +249,79 @@ def less(lhs, rhs):
Otherwise returns Tensor.
"""
return _cpp.less(lhs, rhs)
def equal(lhs, rhs):
"""Compute (lhs==rhs) with auto-broadcasting
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.equal(lhs, rhs)
def not_equal(lhs, rhs):
"""Compute (lhs!=rhs) with auto-broadcasting
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.not_equal(lhs, rhs)
def greater_equal(lhs, rhs):
"""Compute (lhs>=rhs) with auto-broadcasting
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.greater_equal(lhs, rhs)
def less_equal(lhs, rhs):
"""Compute (lhs<=rhs) with auto-broadcasting
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.less_equal(lhs, rhs)
......@@ -116,6 +116,10 @@ TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal);
TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal);
TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);
/* Ops from elemwise.h */
TVM_REGISTER_GLOBAL("topi.exp")
......
......@@ -142,10 +142,30 @@ def test_cmp():
return topi.greater(x, y).astype("int8")
def less(x, y):
return topi.less(x, y).astype("int8")
def equal(x, y):
return topi.equal(x, y).astype("int8")
def not_equal(x, y):
return topi.not_equal(x, y).astype("int8")
def greater_equal(x, y):
return topi.greater_equal(x, y).astype("int8")
def less_equal(x, y):
return topi.less_equal(x, y).astype("int8")
verify_broadcast_binary_ele(
(1, 2, 2), (2,), greater, np.greater)
verify_broadcast_binary_ele(
(2, 1, 2), (2, 3, 1), less, np.less)
verify_broadcast_binary_ele(
(2, 1, 2), (2, 3, 1), equal, np.equal,
lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32')
verify_broadcast_binary_ele(
(2, 1, 2), (2, 3, 1), not_equal, np.not_equal,
lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32')
verify_broadcast_binary_ele(
(7, 1, 5), (7, 3, 1), greater_equal, np.greater_equal,
lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
verify_broadcast_binary_ele(
(7, 1, 5), (7, 3, 1), less_equal, np.less_equal,
lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
def test_shift():
# explicit specify the output type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment