Unverified Commit 2bd2f998 by abergeron Committed by GitHub

[TOPI][Relay] Add bitwise ops (#4815)

* Add bitwise ops to topi

* Add the bitwise ops to relay.
parent 19d0d157
......@@ -44,6 +44,7 @@ register_schedule("sign", schedule_broadcast)
register_schedule("abs", schedule_broadcast)
register_schedule("tanh", schedule_broadcast)
register_schedule("logical_not", schedule_broadcast)
register_schedule("bitwise_not", schedule_broadcast)
register_schedule("negative", schedule_broadcast)
register_schedule("copy", schedule_broadcast)
......@@ -57,6 +58,9 @@ register_schedule("mod", schedule_broadcast)
register_schedule("floor_mod", schedule_broadcast)
register_schedule("logical_and", schedule_broadcast)
register_schedule("logical_or", schedule_broadcast)
register_schedule("bitwise_and", schedule_broadcast)
register_schedule("bitwise_or", schedule_broadcast)
register_schedule("bitwise_xor", schedule_broadcast)
register_schedule("equal", schedule_broadcast)
register_schedule("not_equal", schedule_broadcast)
register_schedule("less", schedule_broadcast)
......@@ -194,6 +198,9 @@ register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("bitwise_and", False, broadcast_shape_func)
register_shape_func("bitwise_or", False, broadcast_shape_func)
register_shape_func("bitwise_xor", False, broadcast_shape_func)
register_shape_func("equal", False, broadcast_shape_func)
register_shape_func("not_equal", False, broadcast_shape_func)
register_shape_func("less", False, broadcast_shape_func)
......
......@@ -318,6 +318,22 @@ def logical_not(data):
return _make.logical_not(data)
def bitwise_not(data):
"""Compute element-wise bitwise not of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_not(data)
def add(lhs, rhs):
"""Addition with numpy-style broadcasting.
......@@ -506,6 +522,60 @@ def logical_or(lhs, rhs):
return _make.logical_or(lhs, rhs)
def bitwise_and(lhs, rhs):
"""bitwise AND with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_and(lhs, rhs)
def bitwise_or(lhs, rhs):
"""bitwise OR with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_or(lhs, rhs)
def bitwise_xor(lhs, rhs):
"""bitwise XOR with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_xor(lhs, rhs)
def equal(lhs, rhs):
"""Broadcasted elementwise test for (lhs == rhs).
......
......@@ -124,6 +124,24 @@ RELAY_REGISTER_BINARY_OP("logical_or")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));
RELAY_REGISTER_BINARY_OP("bitwise_and")
.describe("Elementwise bitwise AND with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and));
RELAY_REGISTER_BINARY_OP("bitwise_or")
.describe("Elementwise bitwise OR with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or));
RELAY_REGISTER_BINARY_OP("bitwise_xor")
.describe("Elementwise bitwise XOR with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor));
RELAY_REGISTER_CMP_OP("equal")
.describe("Elementwise equal compare with broadcasting")
.set_support_level(4)
......
......@@ -266,13 +266,24 @@ RELAY_REGISTER_UNARY_OP("logical_not")
.describe(R"code(Returns the logical inverse of input array, computed element-wise.
.. math::
~(x)
!(x)
)code" TVM_ADD_FILELINE)
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not));
RELAY_REGISTER_UNARY_OP("bitwise_not")
.describe(R"code(Returns the bitwise inverse of input array, computed element-wise.
.. math::
~(x)
)code" TVM_ADD_FILELINE)
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not));
// shape_of
TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);
......
......@@ -141,6 +141,48 @@ TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);
/*!
* \fn bitwise_and
* \brief Compute A & B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(bitwise_and, { return a & b; });
TOPI_DEFINE_OP_OVERLOAD(operator&, bitwise_and);
/*!
* \fn bitwise_or
* \brief Compute A | B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(bitwise_or, { return a | b; });
TOPI_DEFINE_OP_OVERLOAD(operator|, bitwise_or);
/*!
* \fn bitwise_xor
* \brief Compute A ^ B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(bitwise_xor, { return a ^ b; });
TOPI_DEFINE_OP_OVERLOAD(operator^, bitwise_xor);
/*!
* \fn add
* \brief Compute A + B with auto-broadcasting.
*
......
......@@ -179,6 +179,23 @@ inline Tensor logical_not(const Tensor& x,
}
/*!
* \brief Creates an operation that returns the bitwise NOT of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the bitwise NOT operation
*/
inline Tensor bitwise_not(const Tensor& x,
std::string name = "T_bitwise_not",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return ~x(i);
}, name, tag);
}
/*!
* \brief Returns the sign of the tensor
*
* \param x The input tensor
......
......@@ -420,6 +420,63 @@ def logical_or(lhs, rhs):
return _cpp.logical_or(lhs, rhs)
def bitwise_and(lhs, rhs):
"""Compute element-wise bitwise and of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_and(lhs, rhs)
def bitwise_or(lhs, rhs):
"""Compute element-wise bitwise or of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_or(lhs, rhs)
def bitwise_xor(lhs, rhs):
"""Compute element-wise bitwise xor of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_xor(lhs, rhs)
def logical_not(data):
"""Compute element-wise logical not of data.
......@@ -434,3 +491,19 @@ def logical_not(data):
Otherwise returns Tensor.
"""
return _cpp.logical_not(data)
def bitwise_not(data):
"""Compute element-wise bitwise not of data.
Parameters
----------
data : tvm.Tensor or Expr
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if the operand are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_not(data)
......@@ -133,6 +133,9 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
......@@ -151,6 +154,11 @@ TVM_REGISTER_GLOBAL("topi.logical_not")
*rv = logical_not(args[0]);
});
TVM_REGISTER_GLOBAL("topi.bitwise_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = bitwise_not(args[0]);
});
/* Ops from elemwise.h */
TVM_REGISTER_GLOBAL("topi.exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......
......@@ -270,6 +270,47 @@ def test_logical_single_ele():
test_apply(topi.logical_not, "logical_not", np.logical_not, np.array(np.arange(5) < 3))
def test_bitwise_not():
def test_apply(
func,
name,
f_numpy,
shape,
dtype="int32",
):
# Build the logic and compile the function
A = tvm.placeholder(shape=shape, name="A", dtype=dtype)
B = func(A)
if isinstance(A, tvm.expr.PrimExpr):
assert (isinstance(B, tvm.expr.PrimExpr))
return
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
foo = tvm.build(s, [A, B], device, name=name)
data_npy = np.random.uniform(size=shape).astype(A.dtype)
data_nd = tvm.nd.array(data_npy, ctx)
out_npy = f_numpy(data_npy)
out_nd = tvm.nd.array(np.empty(data_npy.shape).astype(B.dtype), ctx)
foo(data_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in get_all_backend():
check_device(device)
test_apply(topi.bitwise_not, "bitwise_not", np.bitwise_not, ())
test_apply(topi.bitwise_not, "bitwise_not", np.bitwise_not, (2, 1, 2))
def test_logical_binary_ele():
def test_apply(
func,
......@@ -314,6 +355,33 @@ def test_logical_binary_ele():
test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False])
def test_bitwise_and():
verify_broadcast_binary_ele(
None, None, topi.bitwise_and, np.bitwise_and,
dtype="int32")
verify_broadcast_binary_ele(
(2, 1, 2), (2, 1, 2), topi.bitwise_and, np.bitwise_and,
dtype="int32")
def test_bitwise_or():
verify_broadcast_binary_ele(
None, None, topi.bitwise_or, np.bitwise_or,
dtype="int32")
verify_broadcast_binary_ele(
(2, 1, 2), (2, 1, 2), topi.bitwise_or, np.bitwise_or,
dtype="int32")
def test_bitwise_xor():
verify_broadcast_binary_ele(
None, None, topi.bitwise_xor, np.bitwise_xor,
dtype="int32")
verify_broadcast_binary_ele(
(2, 1, 2), (2, 1, 2), topi.bitwise_xor, np.bitwise_xor,
dtype="int32")
if __name__ == "__main__":
test_add()
test_shift()
......@@ -328,4 +396,8 @@ if __name__ == "__main__":
test_power()
test_broadcast_to()
test_logical_single_ele()
test_bitwise_not()
test_logical_binary_ele()
test_bitwise_and()
test_bitwise_or()
test_bitwise_xor()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment