Commit 9fd8e3c5 by Yong Wu, committed by Haichen Shen

[Relay][TOPI] operator All (#3124)

* [Relay][TOPI] operator All

* Update tests/python/frontend/tensorflow/test_forward.py

Co-Authored-By: yongwww <55wuyong@163.com>

* fix comments

* change to level 4
parent 3a9de905
...@@ -88,6 +88,7 @@ List of operators
topi.not_equal
topi.greater_equal
topi.less_equal
topi.all
topi.logical_and
topi.logical_or
topi.logical_not
...@@ -140,6 +141,7 @@ topi
.. autofunction:: topi.gather_nd
.. autofunction:: topi.full
.. autofunction:: topi.full_like
.. autofunction:: topi.all
.. autofunction:: topi.max
.. autofunction:: topi.sum
.. autofunction:: topi.min
......
...@@ -135,6 +135,7 @@ This level enables additional math and transform operators.
tvm.relay.greater_equal
tvm.relay.less
tvm.relay.less_equal
tvm.relay.all
tvm.relay.logical_and
tvm.relay.logical_or
tvm.relay.logical_not
...@@ -277,6 +278,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.greater_equal
.. autofunction:: tvm.relay.less
.. autofunction:: tvm.relay.less_equal
.. autofunction:: tvm.relay.all
.. autofunction:: tvm.relay.logical_and
.. autofunction:: tvm.relay.logical_or
.. autofunction:: tvm.relay.logical_not
......
...@@ -429,6 +429,13 @@ TVM_DLL Expr abs(Expr x);
TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
/*!
* \brief logical AND of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
TVM_DLL Expr all(Expr source, Array<IterVar> axis);
/*!
* \brief max of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
......
...@@ -767,6 +767,17 @@ def _sum():
ignores=['name', 'Tidx'])([inputs[0]], attr)
return _impl
def _reduce_all():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()
axis = tuple(axis)
return AttrCvt(
op_name='all',
extras={'axis': axis},
transforms={'keep_dims': 'keepdims'},
ignores=['name', 'Tidx'])([inputs[0]], attr)
return _impl
def _square():
def _impl(inputs, attr, params):
return _op.multiply(inputs[0], inputs[0])
...@@ -1180,6 +1191,7 @@ _identity_list = []
# for N to 1 mapping, currently not supported(?)
_convert_map = {
'Add' : _elemwise('add'),
'All' : _reduce_all(),
'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'),
'AvgPool' : _pooling('avg_pool'),
......
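A minimal sketch of the converter path added above, assuming TF 1.x and this TVM revision; the graph and tensor names ("in_data", "all") are illustrative, not part of the diff. tf.reduce_all emits an `All` node whose second input is the constant axis, which `_reduce_all` pops from params and forwards to relay.all:

.. code-block:: python

    import tensorflow as tf
    from tvm import relay

    # build a tiny TF 1.x graph containing a single All (reduce_all) node
    with tf.Graph().as_default() as graph:
        in_data = tf.placeholder(tf.bool, (2, 3), name="in_data")
        tf.reduce_all(in_data, axis=1, name="all")
        graph_def = graph.as_graph_def()

    # _reduce_all() pops the constant axis input and renames keep_dims to
    # keepdims; the return value layout (function vs. module) depends on
    # the exact TVM revision
    net, params = relay.frontend.from_tensorflow(
        graph_def, shape={"in_data": (2, 3)})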
...@@ -30,6 +30,7 @@ def _schedule_reduce(_, outs, target):
_reg.register_schedule("argmax", _schedule_reduce)
_reg.register_schedule("argmin", _schedule_reduce)
_reg.register_schedule("sum", _schedule_reduce)
_reg.register_schedule("all", _schedule_reduce)
_reg.register_schedule("max", _schedule_reduce) _reg.register_schedule("max", _schedule_reduce)
_reg.register_schedule("min", _schedule_reduce) _reg.register_schedule("min", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce) _reg.register_schedule("prod", _schedule_reduce)
......
...@@ -111,6 +111,58 @@ def sum(data, axis=None, keepdims=False, exclude=False):
return _make.sum(data, axis, keepdims, exclude)
def all(data, axis=None, keepdims=False, exclude=False):
"""Computes the logical AND of boolean array elements over given axes.
Parameters
----------
data : relay.Expr
The input boolean tensor
axis : None or int or tuple of int
Axis or axes along which the logical AND is performed. The default,
axis=None, will perform the logical AND over all elements of the input
array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one. With this option, the result will broadcast
correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
data = relay.Constant(tvm.nd.array([[[ True, True, True],
[ True, True, True],
[False, True, False]],
[[ True, False, False],
[ True, True, False],
[False, True, True]]]))
relay.all(data, axis=1)
# [[False, True, False],
# [False, False, False]]
relay.all(data, axis=0)
# [[ True, False, False],
# [ True, True, False],
# [False, True, False]]
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.all(data, axis, keepdims, exclude)
def max(data, axis=None, keepdims=False, exclude=False):
""" Computes the max of array elements over given axes.
......
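A short, hedged sketch of exercising the new relay.all with the interpreter-based relay.create_executor of this era; the shape and variable name are illustrative:

.. code-block:: python

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", relay.TensorType((2, 3, 3), "bool"))
    func = relay.Function([x], relay.all(x, axis=1))

    x_data = np.random.choice([True, False], size=(2, 3, 3))
    intrp = relay.create_executor()
    res = intrp.evaluate(func)(x_data)
    # matches NumPy's reduction semantics
    np.testing.assert_equal(res.asnumpy(), x_data.all(axis=1))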
...@@ -393,6 +393,16 @@ Expr sum(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
Expr all(Expr source, Array<IterVar> rdom) {
CHECK(source.type().is_bool());
Var x("x", source.type()), y("y", source.type());
Expr result = ir::And::make(x, y);
Expr identity_element = make_const(source.type(), true);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
Expr max(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Max::make(x, y);
......
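The C++ reducer above follows the standard CommReducer recipe: combine with And, identity element true (the AND over an empty reduction domain is true). A hedged Python-side sketch of the same construction, assuming the tvm.comm_reducer and tvm.make APIs of this TVM generation:

.. code-block:: python

    import tvm

    # combine = logical AND, identity = True (mirrors the C++ combiner)
    all_reducer = tvm.comm_reducer(
        lambda x, y: tvm.make.And(x, y),
        lambda t: tvm.const(True, t),
        name="all")

    n = tvm.var("n")
    A = tvm.placeholder((n,), dtype="bool", name="A")
    k = tvm.reduce_axis((0, n), name="k")
    # reduce A over k using the AND combiner
    B = tvm.compute((1,), lambda i: all_reducer(A[k], axis=k), name="B")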
...@@ -355,6 +355,43 @@ Example::
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
Array<Tensor> AllCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::all);
}
RELAY_REGISTER_REDUCE_OP("all")
.describe(R"code(Computes the logical AND of boolean array elements over given axes.
Example::
data = [[[ True, True, True],
[ True, True, True],
[False, True, False]],
[[ True, False, False],
[ True, True, False],
[False, True, True]]]
all(data, axis=1)
[[False, True, False],
[False, False, False]]
all(data, axis=0)
[[ True, False, False],
[ True, True, False],
[False, True, False]]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", AllCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
Array<Tensor> MaxCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
......
...@@ -1598,6 +1598,17 @@ def test_forward_mean():
check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True)
#######################################################################
# All
# ---
def test_forward_all():
"""Test the All operator."""
np_data = np.random.choice([True, False], size=(5, 7, 11))
tf.reset_default_graph()
in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
tf.reduce_all(in_data, name="all")
compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
#######################################################################
# Relational operators
# --------------------
def _test_forward_rel_op(data, func):
...@@ -1718,6 +1729,7 @@ if __name__ == '__main__':
test_forward_reduce()
test_forward_mean()
test_forward_reduce_prod()
test_forward_all()
# General
test_forward_multi_input()
......
...@@ -138,6 +138,7 @@ def test_where():
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
test_func = funcs[0]
ref_func = funcs[1]
dtype = "bool" if ref_func in [np.all] else dtype
x = relay.var("x", relay.TensorType(data, dtype))
z = test_func(x, axis, keepdims, exclude)
...@@ -155,7 +156,9 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
return
func = relay.Function([x], z)
x_data = np.random.choice([True, False], size=data) if ref_func in [np.all] \
else np.random.uniform(size=data).astype(dtype)
if ref_func in [np.sum]:
ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
elif ref_func in [np.max, np.min, np.mean, np.prod]:
...@@ -194,6 +197,7 @@ def test_reduce_functions():
[relay.min, np.min],
[relay.mean, np.mean],
[relay.prod, np.prod],
[relay.all, np.all],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
...@@ -203,6 +207,7 @@ def test_reduce_functions():
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), -1, True, False, (2, 3, 1))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, False, False, ())
verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
......
...@@ -369,6 +369,27 @@ inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
}
/*!
* \brief Creates an operation that computes the logical AND of elements
* over a given axis
*
* \param data The input boolean tensor
* \param axis The axes to reduce. If axis is empty, the operation will
* perform logical AND over all elements of the array.
* \param keepdims If this is set to true, the axes which are reduced are
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \param atleast1d Whether the output needs to be at least one-dimensional.
*
* \return A Tensor whose op member is the all operation
*/
inline Tensor all(const Tensor& data,
const Array<Integer>& axis,
bool keepdims = false,
bool atleast1d = false) {
return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
}
/*!
* \brief Creates an operation that finds the minimum of elements over
* a given axis.
*
......
...@@ -65,6 +65,31 @@ def sum(data, axis=None, keepdims=False):
return cpp.sum(data, axis, keepdims)
def all(data, axis=None, keepdims=False):
"""Logical AND of array elements over a given axis or a list of axes
Parameters
----------
data : tvm.Tensor
The input tvm boolean tensor
axis : None or int or tuple of int
Axis or axes along which a logical AND is performed.
The default, axis=None, will perform logical AND over all elements of the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
Returns
-------
ret : tvm.Tensor
"""
return cpp.all(data, axis, keepdims)
def max(data, axis=None, keepdims=False):
"""Maximum of array elements over a given axis or a list of axes
......
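A hedged usage sketch for the wrapper above; the tensor names and the 'llvm' target are assumptions, and a plain create_schedule stands in for the registered reduce schedules:

.. code-block:: python

    import numpy as np
    import tvm
    import topi

    A = tvm.placeholder((4, 5), dtype="bool", name="A")
    B = topi.all(A, axis=1)

    # default schedule; device-specific reduce schedules are registered
    # separately (see the register_schedule("all", ...) hunk above)
    s = tvm.create_schedule(B.op)
    f = tvm.build(s, [A, B], "llvm")

    a = tvm.nd.array(np.random.choice([True, False], size=(4, 5)))
    b = tvm.nd.array(np.empty((4,), dtype="bool"))
    f(a, b)
    np.testing.assert_equal(b.asnumpy(), a.asnumpy().all(axis=1))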
...@@ -265,6 +265,11 @@ TVM_REGISTER_GLOBAL("topi.prod")
*rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.all")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
});
/* Ops from transform.h */
TVM_REGISTER_GLOBAL("topi.expand_dims")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......
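The TVM_REGISTER_GLOBAL entry is what the Python wrapper's cpp.all call resolves to. As a hedged illustration, the same PackedFunc is also reachable through the generic registry once the topi library has been loaded:

.. code-block:: python

    import tvm
    import topi  # importing topi loads the library that registers "topi.all"

    fall = tvm.get_global_func("topi.all")
    # fall is the PackedFunc wrapping the C++ lambda registered above;
    # topi.all(data, axis, keepdims) forwards to it via the cpp module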
...@@ -50,6 +50,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
out_dtype = dtype
if type == "sum":
B = topi.sum(A1, axis=axis, keepdims=keepdims)
elif type == "all":
B = topi.all(A, axis=axis, keepdims=keepdims)
elif type == "max": elif type == "max":
B = topi.max(A1, axis=axis, keepdims=keepdims) B = topi.max(A1, axis=axis, keepdims=keepdims)
elif type == "min": elif type == "min":
...@@ -74,10 +76,16 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32") ...@@ -74,10 +76,16 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
foo = tvm.build(s, [A, B], device, name=type) foo = tvm.build(s, [A, B], device, name=type)
# Test # Test
if dtype == 'bool':
in_npy_map = in_npy = np.random.choice([True, False], size=in_shape)
else:
in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)
if type == "sum": if type == "sum":
out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims) out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
elif type == "all" and dtype == 'bool':
out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
elif type == "max": elif type == "max":
out_npy = in_npy_map.max(axis=axis, keepdims=keepdims) out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
elif type == "min": elif type == "min":
...@@ -113,6 +121,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32") ...@@ -113,6 +121,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
def test_reduce_map(): def test_reduce_map():
verify_reduce_map_ele(in_shape=(32,), verify_reduce_map_ele(in_shape=(32,),
axis=0, axis=0,
keepdims=False, keepdims=False,
...@@ -121,6 +130,11 @@ def test_reduce_map(): ...@@ -121,6 +130,11 @@ def test_reduce_map():
axis=(1, 2, 3), axis=(1, 2, 3),
keepdims=True, keepdims=True,
type="sum") type="sum")
verify_reduce_map_ele(in_shape=(2, 3),
axis=None,
keepdims=True,
type="all",
dtype='bool')
verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
axis=(1,),
keepdims=False,
...@@ -129,6 +143,11 @@ def test_reduce_map():
axis=None,
keepdims=True,
type="sum")
verify_reduce_map_ele(in_shape=(32, 128, 24),
axis=None,
keepdims=True,
dtype='bool',
type="all")
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
axis=(0, 2),
keepdims=False,
......