Unverified commit 6e36da35 by Samuel, committed by GitHub

[TOPI][PYTORCH]Logical & Bitwise operator support (#5341)

parent cc8cacb1
......@@ -99,6 +99,7 @@ List of operators
   topi.logical_and
   topi.logical_or
   topi.logical_not
   topi.logical_xor
   topi.arange
   topi.stack
   topi.repeat
......@@ -193,6 +194,7 @@ topi
.. autofunction:: topi.logical_and
.. autofunction:: topi.logical_or
.. autofunction:: topi.logical_not
.. autofunction:: topi.logical_xor
topi.nn
~~~~~~~
......
......@@ -150,6 +150,7 @@ This level enables additional math and transform operators.
   tvm.relay.logical_and
   tvm.relay.logical_or
   tvm.relay.logical_not
   tvm.relay.logical_xor
   tvm.relay.maximum
   tvm.relay.minimum
   tvm.relay.power
......
......@@ -1168,7 +1168,6 @@ def _ceil():
def _clamp():
    def _impl(inputs, input_types):
-       print(inputs, input_types)
        data = inputs[0]
        amin = inputs[1] if inputs[1] else np.finfo(np.float32).min
        amax = inputs[2] if inputs[2] else np.finfo(np.float32).max
......@@ -1298,6 +1297,67 @@ def _mm():
    return _impl

def _bitwise_not():
    def _impl(inputs, input_types):
        data = inputs[0]
        # The input tensor must be of an integral or Boolean type.
        # For bool tensors this computes the logical NOT.
        if input_types[0] == "bool":
            out = _op.logical_not(_op.cast(data, "bool"))
        else:
            out = _op.bitwise_not(_op.cast(data, "int"))
        return out
    return _impl
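As a side note, not part of the patch: the dtype branch above mirrors PyTorch's own semantics, where bitwise NOT on a bool tensor degenerates to logical NOT. A quick interactive check:

```python
import torch

# On bool tensors, PyTorch's bitwise NOT is logical NOT ...
print(torch.bitwise_not(torch.tensor([True, False])))
# tensor([False,  True])

# ... while on integral tensors it is two's-complement inversion, ~x == -x - 1.
print(torch.bitwise_not(torch.tensor([0, 1, -10], dtype=torch.int8)))
# tensor([-1, -2,  9], dtype=torch.int8)
```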
def _bitwise_xor():
    def _impl(inputs, input_types):
        lhs = inputs[0]
        import torch
        if isinstance(inputs[1], _expr.Var):
            rhs = inputs[1]
        elif isinstance(inputs[1], torch.Tensor):
            # Constants baked into the trace arrive as torch.Tensor values.
            rhs = _wrap_const(inputs[1].numpy())
        else:
            msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1]))
            raise AssertionError(msg)

        lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
        rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")

        return _op.bitwise_xor(lhs, rhs)
    return _impl
def _logical_not():
    def _impl(inputs, input_types):
        data = inputs[0]
        return _op.logical_not(_op.cast(data, "bool"))
    return _impl

def _logical_xor():
    def _impl(inputs, input_types):
        lhs = _op.cast(inputs[0], "bool")
        import torch
        if isinstance(inputs[1], _expr.Var):
            rhs = inputs[1]
        elif isinstance(inputs[1], torch.Tensor):
            rhs = _wrap_const(inputs[1].numpy())
        else:
            msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1]))
            raise AssertionError(msg)
        rhs = _op.cast(rhs, "bool")

        return _op.logical_xor(lhs, rhs)
    return _impl
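Again as an aside: the unconditional casts to "bool" follow torch.logical_xor, which treats every non-zero element as True regardless of the input dtype. A quick check, independent of the patch:

```python
import torch

a = torch.tensor([-1, -2, 3], dtype=torch.int8)  # all non-zero -> True, True, True
b = torch.tensor([1, 0, 3], dtype=torch.int8)    # -> True, False, True
print(torch.logical_xor(a, b))
# tensor([False,  True, False])
```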
def _isfinite():
    def _impl(inputs, input_types):
        return _op.isfinite(inputs[0])
......@@ -1524,6 +1584,10 @@ def _get_convert_map(prelude):
"aten::ge" : _elemwise("greater_equal"),
"aten::ne" : _elemwise("not_equal"),
"aten::eq" : _elemwise("equal"),
"aten::logical_not" : _logical_not(),
"aten::logical_xor" : _logical_xor(),
"aten::bitwise_not" : _bitwise_not(),
"aten::bitwise_xor" : _bitwise_xor(),
"aten::isfinite" : _isfinite(),
"aten::isnan" : _isnan(),
"aten::Bool" : _Bool(),
......
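With the four aten:: entries registered, a traced module exercising them converts straight to Relay. A minimal sketch, assuming the relay.frontend.from_pytorch signature of this patch's era, which takes a list of (input name, shape) pairs:

```python
import torch
import tvm
from tvm import relay

class XorConst(torch.nn.Module):
    def forward(self, x):
        # The constant rhs is traced as a torch.Tensor,
        # exercising the _wrap_const branch of _logical_xor.
        return torch.logical_xor(x, torch.tensor([True, False, True]))

inp = torch.tensor([True, True, False])
traced = torch.jit.trace(XorConst().eval(), inp)
mod, params = relay.frontend.from_pytorch(traced, [("x", list(inp.shape))])
print(mod["main"])  # should contain a logical_xor call
```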
......@@ -53,6 +53,7 @@ register_broadcast_schedule("copy")
register_broadcast_schedule("logical_not")
register_broadcast_schedule("logical_and")
register_broadcast_schedule("logical_or")
register_broadcast_schedule("logical_xor")
register_broadcast_schedule("bitwise_not")
register_broadcast_schedule("bitwise_and")
register_broadcast_schedule("bitwise_or")
......@@ -205,6 +206,7 @@ register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("logical_xor", False, broadcast_shape_func)
register_shape_func("bitwise_not", False, broadcast_shape_func)
register_shape_func("bitwise_and", False, broadcast_shape_func)
register_shape_func("bitwise_or", False, broadcast_shape_func)
......
......@@ -537,6 +537,23 @@ def logical_or(lhs, rhs):
    return _make.logical_or(lhs, rhs)


def logical_xor(lhs, rhs):
    """logical XOR with numpy-style broadcasting.

    Parameters
    ----------
    lhs : relay.Expr
        The left hand side input data
    rhs : relay.Expr
        The right hand side input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.logical_xor(lhs, rhs)
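For reference, a hedged sketch of using the new op at the Relay level (the debug-interpreter executor API shown is the one current at the time of this patch; dtypes and data are illustrative):

```python
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(3,), dtype="bool")
y = relay.var("y", shape=(3,), dtype="bool")
func = relay.Function([x, y], relay.logical_xor(x, y))
mod = tvm.IRModule.from_expr(func)

# Run on the debug interpreter; broadcasting follows the numpy rules.
ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
res = ex.evaluate()(np.array([True, True, False]), np.array([False, True, False]))
print(res)  # [ True False False]
```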
def bitwise_and(lhs, rhs):
    """bitwise AND with numpy-style broadcasting.
......
......@@ -123,6 +123,12 @@ RELAY_REGISTER_BINARY_OP("logical_or")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));

RELAY_REGISTER_BINARY_OP("logical_xor")
    .describe("Elementwise logical XOR with broadcasting")
    .set_support_level(4)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor));

RELAY_REGISTER_BINARY_OP("bitwise_and")
    .describe("Elementwise bitwise AND with broadcasting")
    .set_support_level(4)
......
......@@ -159,7 +159,7 @@ def verify_model(model_name, input_data=[],
    if isinstance(baseline_outputs, tuple):
        baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
    else:
-       baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
+       baseline_outputs = (baseline_outputs.cpu().numpy(),)

    trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
......@@ -1600,6 +1600,95 @@ def test_forward_topk():
verify_model(Topk6().float().eval(), input_data=input_data)
def test_forward_logical_not():
    torch.set_grad_enabled(False)

    class LogicalNot1(Module):
        def forward(self, *args):
            return torch.logical_not(args[0])

    input_data = torch.tensor([True, False])
    verify_model(LogicalNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
    verify_model(LogicalNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0., 1.5, -10.], dtype=torch.double)
    verify_model(LogicalNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0, 1, -10], dtype=torch.int32)
    verify_model(LogicalNot1().float().eval(), input_data=input_data)
def test_forward_bitwise_not():
    torch.set_grad_enabled(False)

    class BitwiseNot1(Module):
        def forward(self, *args):
            return torch.bitwise_not(args[0])

    input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
    verify_model(BitwiseNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0, 1, -10], dtype=torch.int32)
    verify_model(BitwiseNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([True, False])
    verify_model(BitwiseNot1().float().eval(), input_data=input_data)
def test_forward_bitwise_xor():
    torch.set_grad_enabled(False)

    class BitwiseXor1(Module):
        def forward(self, *args):
            return torch.bitwise_xor(args[0], args[1])

    class BitwiseXor2(Module):
        def forward(self, *args):
            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
            if torch.cuda.is_available():
                rhs = rhs.cuda()
            return torch.bitwise_xor(args[0], rhs)

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
    verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([True, True, False])
    rhs = torch.tensor([False, True, False])
    verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    verify_model(BitwiseXor2().float().eval(), input_data=[lhs])
def test_forward_logical_xor():
    torch.set_grad_enabled(False)

    class LogicalXor1(Module):
        def forward(self, *args):
            return torch.logical_xor(args[0], args[1])

    class LogicalXor2(Module):
        def forward(self, *args):
            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
            if torch.cuda.is_available():
                rhs = rhs.cuda()
            return torch.logical_xor(args[0], rhs)

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
    verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([True, True, False])
    rhs = torch.tensor([False, True, False])
    verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    verify_model(LogicalXor2().float().eval(), input_data=[lhs])
if __name__ == "__main__":
    # Single operator tests
    test_forward_add()
......@@ -1663,6 +1752,10 @@ if __name__ == "__main__":
    test_forward_clamp()
    test_forward_floor()
    test_forward_round()
    test_forward_logical_not()
    test_forward_bitwise_not()
    test_forward_bitwise_xor()
    test_forward_logical_xor()
    test_forward_isfinite()
    test_forward_isnan()
    test_forward_isinf()
......
......@@ -141,6 +141,19 @@ TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);

/*!
 * \fn logical_xor
 * \brief Compute A ^ B with auto-broadcasting.
 *
 * \param A The first tensor, or Expr
 * \param B The second tensor, or Expr
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return The result.
 */
TOPI_DEFINE_BCAST_OP(logical_xor, { return a ^ b; });

/*!
 * \fn bitwise_and
 * \brief Compute A & B with auto-broadcasting.
 *
......
......@@ -420,6 +420,25 @@ def logical_or(lhs, rhs):
    return _cpp.logical_or(lhs, rhs)


def logical_xor(lhs, rhs):
    """Compute element-wise logical xor of data.

    Parameters
    ----------
    lhs : tvm.te.Tensor or Expr
        The left operand
    rhs : tvm.te.Tensor or Expr
        The right operand

    Returns
    -------
    ret : tvm.te.Tensor or Expr
        Returns Expr if both operands are Expr.
        Otherwise returns Tensor.
    """
    return _cpp.logical_xor(lhs, rhs)
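And a matching sketch at the TE level. This assumes the standalone topi package layout of the time; the generic schedule and llvm target are illustrative only:

```python
import numpy as np
import tvm
from tvm import te
import topi

A = te.placeholder((3,), name="A", dtype="bool")
B = te.placeholder((3,), name="B", dtype="bool")
C = topi.logical_xor(A, B)

s = te.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "llvm")

a = tvm.nd.array(np.array([True, False, True]))
b = tvm.nd.array(np.array([False, False, True]))
c = tvm.nd.array(np.zeros(3, dtype="bool"))
f(a, b, c)
print(c.asnumpy())  # [ True False False]
```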
def bitwise_and(lhs, rhs):
    """Compute element-wise bitwise and of data.
......
......@@ -65,6 +65,7 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.logical_xor", topi::logical_xor);
TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
......
......@@ -355,6 +355,8 @@ def test_logical_binary_ele():
    test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False])
    test_apply(topi.logical_or, "logical_or", np.logical_or, True, False)
    test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False])
    test_apply(topi.logical_xor, "logical_xor", np.logical_xor, True, False)
    test_apply(topi.logical_xor, "logical_xor", np.logical_xor, [True, False], [False, False])
def test_bitwise_and():
......