Unverified Commit 2bd2f998 by abergeron Committed by GitHub

[TOPI][Relay] Add bitwise ops (#4815)

* Add bitwise ops to topi

* Add the bitwise ops to relay.
parent 19d0d157
...@@ -44,6 +44,7 @@ register_schedule("sign", schedule_broadcast) ...@@ -44,6 +44,7 @@ register_schedule("sign", schedule_broadcast)
register_schedule("abs", schedule_broadcast) register_schedule("abs", schedule_broadcast)
register_schedule("tanh", schedule_broadcast) register_schedule("tanh", schedule_broadcast)
register_schedule("logical_not", schedule_broadcast) register_schedule("logical_not", schedule_broadcast)
register_schedule("bitwise_not", schedule_broadcast)
register_schedule("negative", schedule_broadcast) register_schedule("negative", schedule_broadcast)
register_schedule("copy", schedule_broadcast) register_schedule("copy", schedule_broadcast)
...@@ -57,6 +58,9 @@ register_schedule("mod", schedule_broadcast) ...@@ -57,6 +58,9 @@ register_schedule("mod", schedule_broadcast)
register_schedule("floor_mod", schedule_broadcast) register_schedule("floor_mod", schedule_broadcast)
register_schedule("logical_and", schedule_broadcast) register_schedule("logical_and", schedule_broadcast)
register_schedule("logical_or", schedule_broadcast) register_schedule("logical_or", schedule_broadcast)
register_schedule("bitwise_and", schedule_broadcast)
register_schedule("bitwise_or", schedule_broadcast)
register_schedule("bitwise_xor", schedule_broadcast)
register_schedule("equal", schedule_broadcast) register_schedule("equal", schedule_broadcast)
register_schedule("not_equal", schedule_broadcast) register_schedule("not_equal", schedule_broadcast)
register_schedule("less", schedule_broadcast) register_schedule("less", schedule_broadcast)
...@@ -194,6 +198,9 @@ register_shape_func("mod", False, broadcast_shape_func) ...@@ -194,6 +198,9 @@ register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("floor_mod", False, broadcast_shape_func) register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func) register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func) register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("bitwise_and", False, broadcast_shape_func)
register_shape_func("bitwise_or", False, broadcast_shape_func)
register_shape_func("bitwise_xor", False, broadcast_shape_func)
register_shape_func("equal", False, broadcast_shape_func) register_shape_func("equal", False, broadcast_shape_func)
register_shape_func("not_equal", False, broadcast_shape_func) register_shape_func("not_equal", False, broadcast_shape_func)
register_shape_func("less", False, broadcast_shape_func) register_shape_func("less", False, broadcast_shape_func)
......
...@@ -318,6 +318,22 @@ def logical_not(data): ...@@ -318,6 +318,22 @@ def logical_not(data):
return _make.logical_not(data) return _make.logical_not(data)
def bitwise_not(data):
"""Compute element-wise bitwise not of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_not(data)
def add(lhs, rhs): def add(lhs, rhs):
"""Addition with numpy-style broadcasting. """Addition with numpy-style broadcasting.
...@@ -506,6 +522,60 @@ def logical_or(lhs, rhs): ...@@ -506,6 +522,60 @@ def logical_or(lhs, rhs):
return _make.logical_or(lhs, rhs) return _make.logical_or(lhs, rhs)
def bitwise_and(lhs, rhs):
"""bitwise AND with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_and(lhs, rhs)
def bitwise_or(lhs, rhs):
"""bitwise OR with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_or(lhs, rhs)
def bitwise_xor(lhs, rhs):
"""bitwise XOR with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_xor(lhs, rhs)
def equal(lhs, rhs): def equal(lhs, rhs):
"""Broadcasted elementwise test for (lhs == rhs). """Broadcasted elementwise test for (lhs == rhs).
......
...@@ -124,6 +124,24 @@ RELAY_REGISTER_BINARY_OP("logical_or") ...@@ -124,6 +124,24 @@ RELAY_REGISTER_BINARY_OP("logical_or")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));
RELAY_REGISTER_BINARY_OP("bitwise_and")
.describe("Elementwise bitwise AND with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and));
RELAY_REGISTER_BINARY_OP("bitwise_or")
.describe("Elementwise bitwise OR with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or));
RELAY_REGISTER_BINARY_OP("bitwise_xor")
.describe("Elementwise bitwise XOR with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor));
RELAY_REGISTER_CMP_OP("equal") RELAY_REGISTER_CMP_OP("equal")
.describe("Elementwise equal compare with broadcasting") .describe("Elementwise equal compare with broadcasting")
.set_support_level(4) .set_support_level(4)
......
...@@ -266,13 +266,24 @@ RELAY_REGISTER_UNARY_OP("logical_not") ...@@ -266,13 +266,24 @@ RELAY_REGISTER_UNARY_OP("logical_not")
.describe(R"code(Returns the logical inverse of input array, computed element-wise. .describe(R"code(Returns the logical inverse of input array, computed element-wise.
.. math:: .. math::
~(x) !(x)
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_support_level(4) .set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not));
RELAY_REGISTER_UNARY_OP("bitwise_not")
.describe(R"code(Returns the bitwise inverse of input array, computed element-wise.
.. math::
~(x)
)code" TVM_ADD_FILELINE)
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not));
// shape_of // shape_of
TVM_REGISTER_NODE_TYPE(ShapeOfAttrs); TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);
......
...@@ -141,6 +141,48 @@ TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; }); ...@@ -141,6 +141,48 @@ TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or); TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);
/*! /*!
* \fn bitwise_and
* \brief Compute A & B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(bitwise_and, { return a & b; });
TOPI_DEFINE_OP_OVERLOAD(operator&, bitwise_and);
/*!
* \fn bitwise_or
* \brief Compute A | B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(bitwise_or, { return a | b; });
TOPI_DEFINE_OP_OVERLOAD(operator|, bitwise_or);
/*!
* \fn bitwise_xor
* \brief Compute A ^ B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(bitwise_xor, { return a ^ b; });
TOPI_DEFINE_OP_OVERLOAD(operator^, bitwise_xor);
/*!
* \fn add * \fn add
* \brief Compute A + B with auto-broadcasting. * \brief Compute A + B with auto-broadcasting.
* *
......
...@@ -179,6 +179,23 @@ inline Tensor logical_not(const Tensor& x, ...@@ -179,6 +179,23 @@ inline Tensor logical_not(const Tensor& x,
} }
/*! /*!
* \brief Creates an operation that returns the bitwise NOT of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the bitwise NOT operation
*/
inline Tensor bitwise_not(const Tensor& x,
std::string name = "T_bitwise_not",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return ~x(i);
}, name, tag);
}
/*!
* \brief Returns the sign of the tensor * \brief Returns the sign of the tensor
* *
* \param x The input tensor * \param x The input tensor
......
...@@ -420,6 +420,63 @@ def logical_or(lhs, rhs): ...@@ -420,6 +420,63 @@ def logical_or(lhs, rhs):
return _cpp.logical_or(lhs, rhs) return _cpp.logical_or(lhs, rhs)
def bitwise_and(lhs, rhs):
"""Compute element-wise bitwise and of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_and(lhs, rhs)
def bitwise_or(lhs, rhs):
"""Compute element-wise bitwise or of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_or(lhs, rhs)
def bitwise_xor(lhs, rhs):
"""Compute element-wise bitwise xor of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_xor(lhs, rhs)
def logical_not(data): def logical_not(data):
"""Compute element-wise logical not of data. """Compute element-wise logical not of data.
...@@ -434,3 +491,19 @@ def logical_not(data): ...@@ -434,3 +491,19 @@ def logical_not(data):
Otherwise returns Tensor. Otherwise returns Tensor.
""" """
return _cpp.logical_not(data) return _cpp.logical_not(data)
def bitwise_not(data):
"""Compute element-wise bitwise not of data.
Parameters
----------
data : tvm.Tensor or Expr
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if the operand are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_not(data)
...@@ -133,6 +133,9 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power); ...@@ -133,6 +133,9 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift); TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and); TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or); TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift); TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater); TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
TOPI_REGISTER_BCAST_OP("topi.less", topi::less); TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
...@@ -151,6 +154,11 @@ TVM_REGISTER_GLOBAL("topi.logical_not") ...@@ -151,6 +154,11 @@ TVM_REGISTER_GLOBAL("topi.logical_not")
*rv = logical_not(args[0]); *rv = logical_not(args[0]);
}); });
TVM_REGISTER_GLOBAL("topi.bitwise_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = bitwise_not(args[0]);
});
/* Ops from elemwise.h */ /* Ops from elemwise.h */
TVM_REGISTER_GLOBAL("topi.exp") TVM_REGISTER_GLOBAL("topi.exp")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
......
...@@ -270,6 +270,47 @@ def test_logical_single_ele(): ...@@ -270,6 +270,47 @@ def test_logical_single_ele():
test_apply(topi.logical_not, "logical_not", np.logical_not, np.array(np.arange(5) < 3)) test_apply(topi.logical_not, "logical_not", np.logical_not, np.array(np.arange(5) < 3))
def test_bitwise_not():
def test_apply(
func,
name,
f_numpy,
shape,
dtype="int32",
):
# Build the logic and compile the function
A = tvm.placeholder(shape=shape, name="A", dtype=dtype)
B = func(A)
if isinstance(A, tvm.expr.PrimExpr):
assert (isinstance(B, tvm.expr.PrimExpr))
return
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
foo = tvm.build(s, [A, B], device, name=name)
data_npy = np.random.uniform(size=shape).astype(A.dtype)
data_nd = tvm.nd.array(data_npy, ctx)
out_npy = f_numpy(data_npy)
out_nd = tvm.nd.array(np.empty(data_npy.shape).astype(B.dtype), ctx)
foo(data_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in get_all_backend():
check_device(device)
test_apply(topi.bitwise_not, "bitwise_not", np.bitwise_not, ())
test_apply(topi.bitwise_not, "bitwise_not", np.bitwise_not, (2, 1, 2))
def test_logical_binary_ele(): def test_logical_binary_ele():
def test_apply( def test_apply(
func, func,
...@@ -314,6 +355,33 @@ def test_logical_binary_ele(): ...@@ -314,6 +355,33 @@ def test_logical_binary_ele():
test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False]) test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False])
def test_bitwise_and():
verify_broadcast_binary_ele(
None, None, topi.bitwise_and, np.bitwise_and,
dtype="int32")
verify_broadcast_binary_ele(
(2, 1, 2), (2, 1, 2), topi.bitwise_and, np.bitwise_and,
dtype="int32")
def test_bitwise_or():
verify_broadcast_binary_ele(
None, None, topi.bitwise_or, np.bitwise_or,
dtype="int32")
verify_broadcast_binary_ele(
(2, 1, 2), (2, 1, 2), topi.bitwise_or, np.bitwise_or,
dtype="int32")
def test_bitwise_xor():
verify_broadcast_binary_ele(
None, None, topi.bitwise_xor, np.bitwise_xor,
dtype="int32")
verify_broadcast_binary_ele(
(2, 1, 2), (2, 1, 2), topi.bitwise_xor, np.bitwise_xor,
dtype="int32")
if __name__ == "__main__": if __name__ == "__main__":
test_add() test_add()
test_shift() test_shift()
...@@ -328,4 +396,8 @@ if __name__ == "__main__": ...@@ -328,4 +396,8 @@ if __name__ == "__main__":
test_power() test_power()
test_broadcast_to() test_broadcast_to()
test_logical_single_ele() test_logical_single_ele()
test_bitwise_not()
test_logical_binary_ele() test_logical_binary_ele()
test_bitwise_and()
test_bitwise_or()
test_bitwise_xor()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment