Commit ab1853c2 by Neo Chien, committed by Yao Wang

[TOPI] operator support: logical_and, logical_or, logical_not (#3929)

* [TOPI] operator support: logical_and, logical_or, logical_not

* [TOPI] operator support: logical_and, logical_or, logical_not

* [TOPI] fix test cases for operator support: logical_and, logical_or, logical_not

* [TOPI] fix test cases for operator support: logical_not
parent 26eaea4a
@@ -175,6 +175,9 @@ topi
.. autofunction:: topi.topk
.. autofunction:: topi.sequence_mask
.. autofunction:: topi.one_hot
.. autofunction:: topi.logical_and
.. autofunction:: topi.logical_or
.. autofunction:: topi.logical_not
topi.nn
~~~~~~~
@@ -18,6 +18,7 @@
from __future__ import absolute_import as _abs
from . import cpp as _cpp


def broadcast_to(data, shape):
    """Broadcast the src to the target shape
@@ -341,3 +342,57 @@ def less_equal(lhs, rhs):
        Otherwise returns Tensor.
    """
    return _cpp.less_equal(lhs, rhs)
def logical_and(lhs, rhs):
    """Compute the element-wise logical AND of the operands.

    Parameters
    ----------
    lhs : tvm.Tensor or Expr
        The left operand
    rhs : tvm.Tensor or Expr
        The right operand

    Returns
    -------
    ret : tvm.Tensor or Expr
        Returns Expr if both operands are Expr.
        Otherwise returns Tensor.
    """
    return _cpp.logical_and(lhs, rhs)
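A minimal usage sketch for the Tensor path (the shapes and names below are illustrative, not part of this change): with two tensor operands the result is a broadcast boolean tensor.

import tvm
import topi

A = tvm.placeholder((2, 3), name="A", dtype="bool")
B = tvm.placeholder((2, 3), name="B", dtype="bool")
C = topi.logical_and(A, B)  # a tvm.Tensor with dtype "bool"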
def logical_or(lhs, rhs):
    """Compute the element-wise logical OR of the operands.

    Parameters
    ----------
    lhs : tvm.Tensor or Expr
        The left operand
    rhs : tvm.Tensor or Expr
        The right operand

    Returns
    -------
    ret : tvm.Tensor or Expr
        Returns Expr if both operands are Expr.
        Otherwise returns Tensor.
    """
    return _cpp.logical_or(lhs, rhs)
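A matching sketch of the Expr path (variable names illustrative): with two scalar expressions the result is itself an Expr, as the docstring above states.

x = tvm.var("x", dtype="bool")
y = tvm.var("y", dtype="bool")
z = topi.logical_or(x, y)  # a tvm.Expr, not a Tensor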
def logical_not(data):
    """Compute the element-wise logical NOT of the input.

    Parameters
    ----------
    data : tvm.Tensor or Expr
        The input operand

    Returns
    -------
    ret : tvm.Tensor or Expr
        Returns Expr if the operand is an Expr.
        Otherwise returns Tensor.
    """
    return _cpp.logical_not(data)
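logical_not dispatches the same way; a brief sketch under the same illustrative assumptions:

A = tvm.placeholder((4,), name="A", dtype="bool")
B = topi.logical_not(A)   # Tensor path: element-wise NOT
x = tvm.var("x", dtype="bool")
y = topi.logical_not(x)   # Expr path: returns an Expr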
@@ -21,6 +21,7 @@ import tvm
from . import tag
from . import cpp


@tvm.tag_scope(tag=tag.ELEMWISE)
def identity(x):
    """Take identity of input x.
@@ -107,6 +108,7 @@ def tanh(x):
"""
return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def cos(x):
"""Take cos of input x.
@@ -123,6 +125,7 @@ def cos(x):
"""
return tvm.compute(x.shape, lambda *i: tvm.cos(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def sin(x):
"""Take sin of input x.
@@ -139,6 +142,7 @@ def sin(x):
"""
return tvm.compute(x.shape, lambda *i: tvm.sin(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def floor(x):
"""Take floor of input x.
@@ -172,6 +176,7 @@ def ceil(x):
"""
return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i)))
def sign(x):
"""Returns -1, 0, 1 based on sign of x.
@@ -187,6 +192,7 @@ def sign(x):
"""
return cpp.sign(x)
@tvm.tag_scope(tag=tag.ELEMWISE)
def trunc(x):
"""Take truncated value of the input of x, element-wise.
@@ -254,6 +260,7 @@ def log(x):
"""
return tvm.compute(x.shape, lambda *i: tvm.log(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def sqrt(x):
"""Take square root of input x.
@@ -391,6 +398,7 @@ def cast(x, dtype):
            x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
    return tvm.make._cast(dtype, x)


def reinterpret(x, dtype):
    """Reinterpret input to specified data type.
@@ -118,11 +118,6 @@ TVM_REGISTER_GLOBAL("topi.TEST_create_target")
} \
}); \
TVM_REGISTER_GLOBAL("topi.broadcast_to")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_to(args[0], args[1]);
});
TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
@@ -142,6 +137,16 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal);
TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);
TVM_REGISTER_GLOBAL("topi.broadcast_to")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_to(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.logical_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = logical_not(args[0]);
});
/* Ops from elemwise.h */
TVM_REGISTER_GLOBAL("topi.exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
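The TVM_REGISTER_GLOBAL entries above are what the Python wrappers call through topi.cpp. As a hedged sketch (assuming a TVM build with TOPI compiled in), the same packed function can also be fetched directly from the global registry:

import tvm
import topi  # loads the TOPI packed functions into the registry

f = tvm.get_global_func("topi.logical_not")
A = tvm.placeholder((4,), name="A", dtype="bool")
B = f(A)  # equivalent to topi.logical_not(A)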
@@ -20,6 +20,7 @@ import numpy as np
import tvm
import topi
def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
    # Build the logic and compile the function
    A = tvm.placeholder(shape=in_shape, name="A")
@@ -99,18 +100,21 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
        check_device(target)
    check_device("sdaccel")


def test_broadcast_to():
    verify_broadcast_to_ele((1,), (10,), topi.broadcast_to)
    verify_broadcast_to_ele((), (10,), topi.broadcast_to)
    verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4), topi.broadcast_to)
    verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32), topi.broadcast_to)


def test_add():
    verify_broadcast_binary_ele(
        (), (), topi.add, np.add)
    verify_broadcast_binary_ele(
        (5, 2, 3), (2, 1), topi.add, np.add)


def test_subtract():
    verify_broadcast_binary_ele(
        (5, 2, 3), (), topi.subtract, np.subtract)
@@ -121,10 +125,12 @@ def test_subtract():
    verify_broadcast_binary_ele(
        (1, 32), (64, 32), topi.subtract, np.subtract)


def test_multiply():
    verify_broadcast_binary_ele(
        (5, 64, 128), (2, 5, 64, 1), topi.multiply, np.multiply)


def test_divide():
    verify_broadcast_binary_ele(
        None, (10,), topi.divide, np.divide, rhs_min=0.0001)
@@ -133,32 +139,41 @@ def test_divide():
    verify_broadcast_binary_ele(
        (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)


def test_maximum_minmum():
    verify_broadcast_binary_ele(
        (32,), (64, 32), topi.maximum, np.maximum)
    verify_broadcast_binary_ele(
        (1, 2, 2, 1, 32), (64, 32), topi.minimum, np.minimum)


def test_power():
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.power, np.power, lhs_min=0.001, rhs_min=0.001, rhs_max=2)


def test_mod():
    verify_broadcast_binary_ele(
        (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")


def test_cmp():
    # explicitly specify the output type
    def greater(x, y):
        return topi.greater(x, y).astype("int8")

    def less(x, y):
        return topi.less(x, y).astype("int8")

    def equal(x, y):
        return topi.equal(x, y).astype("int8")

    def not_equal(x, y):
        return topi.not_equal(x, y).astype("int8")

    def greater_equal(x, y):
        return topi.greater_equal(x, y).astype("int8")

    def less_equal(x, y):
        return topi.less_equal(x, y).astype("int8")

    verify_broadcast_binary_ele(
@@ -178,6 +193,7 @@ def test_cmp():
        (7, 1, 5), (7, 3, 1), less_equal, np.less_equal,
        lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')


def test_shift():
    # explicitly specify the output type
    verify_broadcast_binary_ele(
@@ -193,6 +209,90 @@ def test_shift():
dtype="int8", rhs_min=0, rhs_max=32)
def test_logical_single_ele():
    def test_apply(
            func,
            name,
            f_numpy,
            indata,
            dtype="bool",
    ):
        # Build the logic and compile the function
        A = tvm.placeholder(shape=indata.shape, name="A", dtype=dtype)
        B = func(A)
        # A is a Tensor placeholder, so the Expr short-circuit below is
        # skipped and the device path runs.
        if isinstance(A, tvm.expr.Expr):
            assert isinstance(B, tvm.expr.Expr)
            return

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            print("Running on target: %s" % device)
            with tvm.target.create(device):
                s = topi.generic.schedule_broadcast(B)
            foo = tvm.build(s, [A, B], device, name=name)

            data_npy = indata.astype(A.dtype)
            data_nd = tvm.nd.array(data_npy, ctx)

            out_npy = f_numpy(indata)
            out_nd = tvm.nd.array(np.empty(data_npy.shape).astype(B.dtype), ctx)
            foo(data_nd, out_nd)
            tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

        for device in get_all_backend():
            check_device(device)

    test_apply(topi.logical_not, "logical_not", np.logical_not, np.array([True, False, 0, 1]))
    test_apply(topi.logical_not, "logical_not", np.logical_not, np.array(np.arange(5) < 3))
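For reference, the first case mixes bools and ints, which numpy promotes to an integer array before negating; the expected baseline is:

>>> np.logical_not(np.array([True, False, 0, 1]))
array([False,  True,  True, False])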
def test_logical_binary_ele():
    def test_apply(
            func,
            name,
            f_numpy,
            lhs,
            rhs,
            dtype="bool",
    ):
        # Build the logic and compile the function
        A = tvm.var("A", dtype=dtype)
        B = tvm.var("B", dtype=dtype)
        C = func(A, B)
        # A and B are scalar Expr vars, so only the Expr path is exercised
        # here; the device path below is never reached.
        if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
            assert isinstance(C, tvm.expr.Expr)
            return

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            print("Running on target: %s" % device)
            with tvm.target.create(device):
                s = topi.generic.schedule_broadcast(C)
            foo = tvm.build(s, [A, B, C], device, name=name)

            lhs_nd = tvm.nd.array(lhs, ctx)
            rhs_nd = tvm.nd.array(rhs, ctx)

            out_npy = f_numpy(lhs, rhs)
            out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
            foo(lhs_nd, rhs_nd, out_nd)
            tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)

        for device in get_all_backend():
            check_device(device)

    test_apply(topi.logical_and, "logical_and", np.logical_and, True, False)
    test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False])
    test_apply(topi.logical_or, "logical_or", np.logical_or, True, False)
    test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False])
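Because A and B are tvm.var expressions, test_apply returns on the Expr path and the device loop above is never reached for these cases. A hypothetical Tensor-path variant (names illustrative, not part of this change) would swap the vars for placeholders so that schedule_broadcast and tvm.build are actually exercised:

def tensor_path_sketch():
    A = tvm.placeholder((2,), name="A", dtype="bool")
    B = tvm.placeholder((2,), name="B", dtype="bool")
    return topi.logical_and(A, B)  # a Tensor, schedulable via schedule_broadcast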
if __name__ == "__main__":
    test_add()
    test_shift()
@@ -204,3 +304,5 @@ if __name__ == "__main__":
    test_maximum_minmum()
    test_power()
    test_broadcast_to()
    test_logical_single_ele()
    test_logical_binary_ele()