Unverified Commit 6e36da35 by Samuel Committed by GitHub

[TOPI][PYTORCH]Logical & Bitwise operator support (#5341)

parent cc8cacb1
...@@ -99,6 +99,7 @@ List of operators ...@@ -99,6 +99,7 @@ List of operators
topi.logical_and topi.logical_and
topi.logical_or topi.logical_or
topi.logical_not topi.logical_not
topi.logical_xor
topi.arange topi.arange
topi.stack topi.stack
topi.repeat topi.repeat
...@@ -193,6 +194,7 @@ topi ...@@ -193,6 +194,7 @@ topi
.. autofunction:: topi.logical_and .. autofunction:: topi.logical_and
.. autofunction:: topi.logical_or .. autofunction:: topi.logical_or
.. autofunction:: topi.logical_not .. autofunction:: topi.logical_not
.. autofunction:: topi.logical_xor
topi.nn topi.nn
~~~~~~~ ~~~~~~~
......
...@@ -150,6 +150,7 @@ This level enables additional math and transform operators. ...@@ -150,6 +150,7 @@ This level enables additional math and transform operators.
tvm.relay.logical_and tvm.relay.logical_and
tvm.relay.logical_or tvm.relay.logical_or
tvm.relay.logical_not tvm.relay.logical_not
tvm.relay.logical_xor
tvm.relay.maximum tvm.relay.maximum
tvm.relay.minimum tvm.relay.minimum
tvm.relay.power tvm.relay.power
......
...@@ -1168,7 +1168,6 @@ def _ceil(): ...@@ -1168,7 +1168,6 @@ def _ceil():
def _clamp(): def _clamp():
def _impl(inputs, input_types): def _impl(inputs, input_types):
print(inputs, input_types)
data = inputs[0] data = inputs[0]
amin = inputs[1] if inputs[1] else np.finfo(np.float32).min amin = inputs[1] if inputs[1] else np.finfo(np.float32).min
amax = inputs[2] if inputs[2] else np.finfo(np.float32).max amax = inputs[2] if inputs[2] else np.finfo(np.float32).max
...@@ -1298,6 +1297,67 @@ def _mm(): ...@@ -1298,6 +1297,67 @@ def _mm():
return _impl return _impl
def _bitwise_not():
def _impl(inputs, input_types):
data = inputs[0]
# The input tensor must be of integral or Boolean types.
# For bool tensors, it computes the logical NOT
if input_types[0] == "bool":
out = _op.logical_not(_op.cast(data, "bool"))
else:
out = _op.bitwise_not(_op.cast(data, "int"))
return out
return _impl
def _bitwise_xor():
def _impl(inputs, input_types):
lhs = inputs[0]
import torch
if isinstance(inputs[1], _expr.Var):
rhs = inputs[1]
elif isinstance(inputs[1], torch.Tensor):
rhs = _wrap_const(inputs[1].numpy())
else:
msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1]))
raise AssertionError(msg)
lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
return _op.bitwise_xor(lhs, rhs)
return _impl
def _logical_not():
def _impl(inputs, input_types):
data = inputs[0]
return _op.logical_not(_op.cast(data, "bool"))
return _impl
def _logical_xor():
def _impl(inputs, input_types):
lhs = _op.cast(inputs[0], "bool")
import torch
if isinstance(inputs[1], _expr.Var):
rhs = inputs[1]
elif isinstance(inputs[1], torch.Tensor):
rhs = _wrap_const(inputs[1].numpy())
else:
msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1]))
raise AssertionError(msg)
rhs = _op.cast(rhs, "bool")
return _op.logical_xor(lhs, rhs)
return _impl
def _isfinite(): def _isfinite():
def _impl(inputs, input_types): def _impl(inputs, input_types):
return _op.isfinite(inputs[0]) return _op.isfinite(inputs[0])
...@@ -1524,6 +1584,10 @@ def _get_convert_map(prelude): ...@@ -1524,6 +1584,10 @@ def _get_convert_map(prelude):
"aten::ge" : _elemwise("greater_equal"), "aten::ge" : _elemwise("greater_equal"),
"aten::ne" : _elemwise("not_equal"), "aten::ne" : _elemwise("not_equal"),
"aten::eq" : _elemwise("equal"), "aten::eq" : _elemwise("equal"),
"aten::logical_not" : _logical_not(),
"aten::logical_xor" : _logical_xor(),
"aten::bitwise_not" : _bitwise_not(),
"aten::bitwise_xor" : _bitwise_xor(),
"aten::isfinite" : _isfinite(), "aten::isfinite" : _isfinite(),
"aten::isnan" : _isnan(), "aten::isnan" : _isnan(),
"aten::Bool" : _Bool(), "aten::Bool" : _Bool(),
......
...@@ -53,6 +53,7 @@ register_broadcast_schedule("copy") ...@@ -53,6 +53,7 @@ register_broadcast_schedule("copy")
register_broadcast_schedule("logical_not") register_broadcast_schedule("logical_not")
register_broadcast_schedule("logical_and") register_broadcast_schedule("logical_and")
register_broadcast_schedule("logical_or") register_broadcast_schedule("logical_or")
register_broadcast_schedule("logical_xor")
register_broadcast_schedule("bitwise_not") register_broadcast_schedule("bitwise_not")
register_broadcast_schedule("bitwise_and") register_broadcast_schedule("bitwise_and")
register_broadcast_schedule("bitwise_or") register_broadcast_schedule("bitwise_or")
...@@ -205,6 +206,7 @@ register_shape_func("mod", False, broadcast_shape_func) ...@@ -205,6 +206,7 @@ register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("floor_mod", False, broadcast_shape_func) register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func) register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func) register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("logical_xor", False, broadcast_shape_func)
register_shape_func("bitwise_not", False, broadcast_shape_func) register_shape_func("bitwise_not", False, broadcast_shape_func)
register_shape_func("bitwise_and", False, broadcast_shape_func) register_shape_func("bitwise_and", False, broadcast_shape_func)
register_shape_func("bitwise_or", False, broadcast_shape_func) register_shape_func("bitwise_or", False, broadcast_shape_func)
......
...@@ -537,6 +537,23 @@ def logical_or(lhs, rhs): ...@@ -537,6 +537,23 @@ def logical_or(lhs, rhs):
return _make.logical_or(lhs, rhs) return _make.logical_or(lhs, rhs)
def logical_xor(lhs, rhs):
"""logical XOR with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.logical_xor(lhs, rhs)
def bitwise_and(lhs, rhs): def bitwise_and(lhs, rhs):
"""bitwise AND with numpy-style broadcasting. """bitwise AND with numpy-style broadcasting.
......
...@@ -123,6 +123,12 @@ RELAY_REGISTER_BINARY_OP("logical_or") ...@@ -123,6 +123,12 @@ RELAY_REGISTER_BINARY_OP("logical_or")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or)); .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));
RELAY_REGISTER_BINARY_OP("logical_xor")
.describe("Elementwise logical XOR with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor));
RELAY_REGISTER_BINARY_OP("bitwise_and") RELAY_REGISTER_BINARY_OP("bitwise_and")
.describe("Elementwise bitwise AND with broadcasting") .describe("Elementwise bitwise AND with broadcasting")
.set_support_level(4) .set_support_level(4)
......
...@@ -159,7 +159,7 @@ def verify_model(model_name, input_data=[], ...@@ -159,7 +159,7 @@ def verify_model(model_name, input_data=[],
if isinstance(baseline_outputs, tuple): if isinstance(baseline_outputs, tuple):
baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs) baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
else: else:
baseline_outputs = (baseline_outputs.float().cpu().numpy(),) baseline_outputs = (baseline_outputs.cpu().numpy(),)
trace = torch.jit.trace(baseline_model, baseline_input).float().eval() trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
...@@ -1600,6 +1600,95 @@ def test_forward_topk(): ...@@ -1600,6 +1600,95 @@ def test_forward_topk():
verify_model(Topk6().float().eval(), input_data=input_data) verify_model(Topk6().float().eval(), input_data=input_data)
def test_forward_logical_not():
torch.set_grad_enabled(False)
class LogicalNot1(Module):
def forward(self, *args):
return torch.logical_not(args[0])
input_data = torch.tensor([True, False])
verify_model(LogicalNot1().float().eval(), input_data=input_data)
input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
verify_model(LogicalNot1().float().eval(), input_data=input_data)
input_data = torch.tensor([0., 1.5, -10.], dtype=torch.double)
verify_model(LogicalNot1().float().eval(), input_data=input_data)
input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
verify_model(LogicalNot1().float().eval(), input_data=input_data)
def test_forward_bitwise_not():
torch.set_grad_enabled(False)
class BitwiseNot1(Module):
def forward(self, *args):
return torch.bitwise_not(args[0])
input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
verify_model(BitwiseNot1().float().eval(), input_data=input_data)
input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
verify_model(BitwiseNot1().float().eval(), input_data=input_data)
input_data = torch.tensor([True, False])
verify_model(BitwiseNot1().float().eval(), input_data=input_data)
def test_forward_bitwise_xor():
torch.set_grad_enabled(False)
class BitwiseXor1(Module):
def forward(self, *args):
return torch.bitwise_xor(args[0], args[1])
class BitwiseXor2(Module):
def forward(self, *args):
rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
if torch.cuda.is_available():
rhs = rhs.cuda()
return torch.bitwise_xor(args[0], rhs)
lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])
lhs = torch.tensor([True, True, False])
rhs = torch.tensor([False, True, False])
verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])
lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
verify_model(BitwiseXor2().float().eval(), input_data=[lhs])
def test_forward_logical_xor():
torch.set_grad_enabled(False)
class LogicalXor1(Module):
def forward(self, *args):
return torch.logical_xor(args[0], args[1])
class LogicalXor2(Module):
def forward(self, *args):
rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
if torch.cuda.is_available():
rhs = rhs.cuda()
return torch.logical_xor(args[0], rhs)
lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])
lhs = torch.tensor([True, True, False])
rhs = torch.tensor([False, True, False])
verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])
lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
verify_model(LogicalXor2().float().eval(), input_data=[lhs])
if __name__ == "__main__": if __name__ == "__main__":
# Single operator tests # Single operator tests
test_forward_add() test_forward_add()
...@@ -1663,6 +1752,10 @@ if __name__ == "__main__": ...@@ -1663,6 +1752,10 @@ if __name__ == "__main__":
test_forward_clamp() test_forward_clamp()
test_forward_floor() test_forward_floor()
test_forward_round() test_forward_round()
test_forward_logical_not()
test_forward_bitwise_not()
test_forward_bitwise_xor()
test_forward_logical_xor()
test_forward_isfinite() test_forward_isfinite()
test_forward_isnan() test_forward_isnan()
test_forward_isinf() test_forward_isinf()
......
...@@ -141,6 +141,19 @@ TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; }); ...@@ -141,6 +141,19 @@ TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or); TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);
/*! /*!
* \fn logical_xor
* \brief Compute A ^ B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(logical_xor, { return a ^ b; });
/*!
* \fn bitwise_and * \fn bitwise_and
* \brief Compute A & B with auto-broadcasting. * \brief Compute A & B with auto-broadcasting.
* *
......
...@@ -420,6 +420,25 @@ def logical_or(lhs, rhs): ...@@ -420,6 +420,25 @@ def logical_or(lhs, rhs):
return _cpp.logical_or(lhs, rhs) return _cpp.logical_or(lhs, rhs)
def logical_xor(lhs, rhs):
"""Compute element-wise logical xor of data.
Parameters
----------
lhs : tvm.te.Tensor or Expr
The left operand
rhs : tvm.te.Tensor or Expr
The right operand
Returns
-------
ret : tvm.te.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.logical_xor(lhs, rhs)
def bitwise_and(lhs, rhs): def bitwise_and(lhs, rhs):
"""Compute element-wise bitwise and of data. """Compute element-wise bitwise and of data.
......
...@@ -65,6 +65,7 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power); ...@@ -65,6 +65,7 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift); TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and); TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or); TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.logical_xor", topi::logical_xor);
TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and); TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or); TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor); TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
......
...@@ -355,6 +355,8 @@ def test_logical_binary_ele(): ...@@ -355,6 +355,8 @@ def test_logical_binary_ele():
test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False]) test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False])
test_apply(topi.logical_or, "logical_or", np.logical_or, True, False) test_apply(topi.logical_or, "logical_or", np.logical_or, True, False)
test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False]) test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False])
test_apply(topi.logical_xor, "logical_xor", np.logical_xor, True, False)
test_apply(topi.logical_xor, "logical_xor", np.logical_xor, [True, False], [False, False])
def test_bitwise_and(): def test_bitwise_and():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment