Commit ab1853c2 by Neo Chien Committed by Yao Wang

[TOPI] operator support: logical_and, logical_or, logical_not (#3929)

* [TOPI] operator support: logical_and, logical_or, logical_not

* [TOPI] operator support: logical_and, logical_or, logical_not

* [TOPI] fix test cases for operator support: logical_and, logical_or, logical_not

* [TOPI] fix test cases for operator support: logical_not
parent 26eaea4a
...@@ -175,6 +175,9 @@ topi ...@@ -175,6 +175,9 @@ topi
.. autofunction:: topi.topk .. autofunction:: topi.topk
.. autofunction:: topi.sequence_mask .. autofunction:: topi.sequence_mask
.. autofunction:: topi.one_hot .. autofunction:: topi.one_hot
.. autofunction:: topi.logical_and
.. autofunction:: topi.logical_or
.. autofunction:: topi.logical_not
topi.nn topi.nn
~~~~~~~ ~~~~~~~
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .import cpp as _cpp from .import cpp as _cpp
def broadcast_to(data, shape): def broadcast_to(data, shape):
"""Broadcast the src to the target shape """Broadcast the src to the target shape
...@@ -341,3 +342,57 @@ def less_equal(lhs, rhs): ...@@ -341,3 +342,57 @@ def less_equal(lhs, rhs):
Otherwise returns Tensor. Otherwise returns Tensor.
""" """
return _cpp.less_equal(lhs, rhs) return _cpp.less_equal(lhs, rhs)
def logical_and(lhs, rhs):
"""Compute element-wise logical and of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.logical_and(lhs, rhs)
def logical_or(lhs, rhs):
"""Compute element-wise logical or of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.logical_or(lhs, rhs)
def logical_not(data):
"""Compute element-wise logical not of data.
Parameters
----------
data : tvm.Tensor or Expr
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if the operand are Expr.
Otherwise returns Tensor.
"""
return _cpp.logical_not(data)
...@@ -21,6 +21,7 @@ import tvm ...@@ -21,6 +21,7 @@ import tvm
from . import tag from . import tag
from . import cpp from . import cpp
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def identity(x): def identity(x):
"""Take identity of input x. """Take identity of input x.
...@@ -107,6 +108,7 @@ def tanh(x): ...@@ -107,6 +108,7 @@ def tanh(x):
""" """
return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def cos(x): def cos(x):
"""Take cos of input x. """Take cos of input x.
...@@ -123,6 +125,7 @@ def cos(x): ...@@ -123,6 +125,7 @@ def cos(x):
""" """
return tvm.compute(x.shape, lambda *i: tvm.cos(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.cos(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def sin(x): def sin(x):
"""Take sin of input x. """Take sin of input x.
...@@ -139,6 +142,7 @@ def sin(x): ...@@ -139,6 +142,7 @@ def sin(x):
""" """
return tvm.compute(x.shape, lambda *i: tvm.sin(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.sin(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def floor(x): def floor(x):
"""Take floor of input x. """Take floor of input x.
...@@ -172,6 +176,7 @@ def ceil(x): ...@@ -172,6 +176,7 @@ def ceil(x):
""" """
return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i)))
def sign(x): def sign(x):
"""Returns -1, 0, 1 based on sign of x. """Returns -1, 0, 1 based on sign of x.
...@@ -187,6 +192,7 @@ def sign(x): ...@@ -187,6 +192,7 @@ def sign(x):
""" """
return cpp.sign(x) return cpp.sign(x)
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def trunc(x): def trunc(x):
"""Take truncated value of the input of x, element-wise. """Take truncated value of the input of x, element-wise.
...@@ -254,6 +260,7 @@ def log(x): ...@@ -254,6 +260,7 @@ def log(x):
""" """
return tvm.compute(x.shape, lambda *i: tvm.log(x(*i))) return tvm.compute(x.shape, lambda *i: tvm.log(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE) @tvm.tag_scope(tag=tag.ELEMWISE)
def sqrt(x): def sqrt(x):
"""Take square root of input x. """Take square root of input x.
...@@ -391,6 +398,7 @@ def cast(x, dtype): ...@@ -391,6 +398,7 @@ def cast(x, dtype):
x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE) x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
return tvm.make._cast(dtype, x) return tvm.make._cast(dtype, x)
def reinterpret(x, dtype): def reinterpret(x, dtype):
"""Reinterpret input to specified data type. """Reinterpret input to specified data type.
......
...@@ -118,11 +118,6 @@ TVM_REGISTER_GLOBAL("topi.TEST_create_target") ...@@ -118,11 +118,6 @@ TVM_REGISTER_GLOBAL("topi.TEST_create_target")
} \ } \
}); \ }); \
TVM_REGISTER_GLOBAL("topi.broadcast_to")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_to(args[0], args[1]);
});
TOPI_REGISTER_BCAST_OP("topi.add", topi::add); TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract); TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply); TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
...@@ -142,6 +137,16 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); ...@@ -142,6 +137,16 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal);
TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);
TVM_REGISTER_GLOBAL("topi.broadcast_to")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_to(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.logical_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = logical_not(args[0]);
});
/* Ops from elemwise.h */ /* Ops from elemwise.h */
TVM_REGISTER_GLOBAL("topi.exp") TVM_REGISTER_GLOBAL("topi.exp")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import tvm import tvm
import topi import topi
def verify_broadcast_to_ele(in_shape, out_shape, fbcast): def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
# Build the logic and compile the function # Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A") A = tvm.placeholder(shape=in_shape, name="A")
...@@ -99,18 +100,21 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, ...@@ -99,18 +100,21 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
check_device(target) check_device(target)
check_device("sdaccel") check_device("sdaccel")
def test_broadcast_to(): def test_broadcast_to():
verify_broadcast_to_ele((1,), (10,), topi.broadcast_to) verify_broadcast_to_ele((1,), (10,), topi.broadcast_to)
verify_broadcast_to_ele((), (10,), topi.broadcast_to) verify_broadcast_to_ele((), (10,), topi.broadcast_to)
verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4), topi.broadcast_to) verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4), topi.broadcast_to)
verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to) verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to)
def test_add(): def test_add():
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
(), (), topi.add, np.add) (), (), topi.add, np.add)
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
(5, 2, 3), (2, 1), topi.add, np.add) (5, 2, 3), (2, 1), topi.add, np.add)
def test_subtract(): def test_subtract():
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
(5, 2, 3), (), topi.subtract, np.subtract) (5, 2, 3), (), topi.subtract, np.subtract)
...@@ -121,10 +125,12 @@ def test_subtract(): ...@@ -121,10 +125,12 @@ def test_subtract():
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
(1, 32), (64, 32), topi.subtract, np.subtract) (1, 32), (64, 32), topi.subtract, np.subtract)
def test_multiply(): def test_multiply():
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
(5, 64, 128), (2, 5, 64, 1), topi.multiply, np.multiply) (5, 64, 128), (2, 5, 64, 1), topi.multiply, np.multiply)
def test_divide(): def test_divide():
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
None, (10,), topi.divide, np.divide, rhs_min=0.0001) None, (10,), topi.divide, np.divide, rhs_min=0.0001)
...@@ -133,32 +139,41 @@ def test_divide(): ...@@ -133,32 +139,41 @@ def test_divide():
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
(2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001) (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)
def test_maximum_minmum(): def test_maximum_minmum():
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
(32,), (64, 32), topi.maximum, np.maximum) (32,), (64, 32), topi.maximum, np.maximum)
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
(1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum) (1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum)
def test_power(): def test_power():
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2) (1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2)
def test_mod(): def test_mod():
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32") (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")
def test_cmp(): def test_cmp():
# explicit specify the output type # explicit specify the output type
def greater(x, y): def greater(x, y):
return topi.greater(x, y).astype("int8") return topi.greater(x, y).astype("int8")
def less(x, y): def less(x, y):
return topi.less(x, y).astype("int8") return topi.less(x, y).astype("int8")
def equal(x, y): def equal(x, y):
return topi.equal(x, y).astype("int8") return topi.equal(x, y).astype("int8")
def not_equal(x, y): def not_equal(x, y):
return topi.not_equal(x, y).astype("int8") return topi.not_equal(x, y).astype("int8")
def greater_equal(x, y): def greater_equal(x, y):
return topi.greater_equal(x, y).astype("int8") return topi.greater_equal(x, y).astype("int8")
def less_equal(x, y): def less_equal(x, y):
return topi.less_equal(x, y).astype("int8") return topi.less_equal(x, y).astype("int8")
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
...@@ -178,6 +193,7 @@ def test_cmp(): ...@@ -178,6 +193,7 @@ def test_cmp():
(7, 1, 5), (7, 3, 1), less_equal, np.less_equal, (7, 1, 5), (7, 3, 1), less_equal, np.less_equal,
lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32') lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
def test_shift(): def test_shift():
# explicit specify the output type # explicit specify the output type
verify_broadcast_binary_ele( verify_broadcast_binary_ele(
...@@ -193,6 +209,90 @@ def test_shift(): ...@@ -193,6 +209,90 @@ def test_shift():
dtype="int8", rhs_min=0, rhs_max=32) dtype="int8", rhs_min=0, rhs_max=32)
def test_logical_single_ele():
def test_apply(
func,
name,
f_numpy,
indata,
dtype="bool",
):
# Build the logic and compile the function
A = tvm.placeholder(shape=indata.shape, name="A", dtype=dtype)
B = func(A)
if isinstance(A, tvm.expr.Expr):
assert (isinstance(B, tvm.expr.Expr))
return
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
foo = tvm.build(s, [A, B], device, name=name)
data_npy = indata.astype(A.dtype)
data_nd = tvm.nd.array(data_npy, ctx)
out_npy = f_numpy(indata)
out_nd = tvm.nd.array(np.empty(data_npy.shape).astype(B.dtype), ctx)
foo(data_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in get_all_backend():
check_device(device)
test_apply(topi.logical_not, "logical_not", np.logical_not, np.array([True, False, 0, 1]))
test_apply(topi.logical_not, "logical_not", np.logical_not, np.array(np.arange(5) < 3))
def test_logical_binary_ele():
def test_apply(
func,
name,
f_numpy,
lhs,
rhs,
dtype="bool",
):
# Build the logic and compile the function
A = (tvm.var("A", dtype=dtype))
B = (tvm.var("B", dtype=dtype))
C = func(A, B)
if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
assert (isinstance(C, tvm.expr.Expr))
return
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(C)
foo = tvm.build(s, [A, B, C], device, name=name)
lhs_nd = tvm.nd.array(lhs, ctx)
rhs_nd = tvm.nd.array(rhs, ctx)
out_npy = f_numpy(lhs, rhs)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
foo(lhs_nd, rhs_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
for device in get_all_backend():
check_device(device)
test_apply(topi.logical_and, "logical_and", np.logical_and, True, False)
test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False])
test_apply(topi.logical_or, "logical_or", np.logical_or, True, False)
test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False])
if __name__ == "__main__": if __name__ == "__main__":
test_add() test_add()
test_shift() test_shift()
...@@ -204,3 +304,5 @@ if __name__ == "__main__": ...@@ -204,3 +304,5 @@ if __name__ == "__main__":
test_maximum_minmum() test_maximum_minmum()
test_power() test_power()
test_broadcast_to() test_broadcast_to()
test_logical_single_ele()
test_logical_binary_ele()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment