Commit b07b1952 by Jon Soifer, committed by Jared Roesch

[Relay][Topi][TensorFlow][ONNX][Lang] Add support for Any op (#4205)

* Add support for Any op

* Support ONNX frontend

* Add doc

* Add to relay docs

* Dummy change to retrigger CI
parent 156aa590
......@@ -91,6 +91,7 @@ List of operators
topi.greater_equal
topi.less_equal
topi.all
topi.any
topi.logical_and
topi.logical_or
topi.logical_not
......@@ -151,6 +152,7 @@ topi
.. autofunction:: topi.full
.. autofunction:: topi.full_like
.. autofunction:: topi.all
.. autofunction:: topi.any
.. autofunction:: topi.max
.. autofunction:: topi.sum
.. autofunction:: topi.min
......
......@@ -116,6 +116,7 @@ Supported Ops
- Abs
- Add
- All
- Any
- ArgMax
- ArgMin
- AvgPool
......
......@@ -137,6 +137,7 @@ This level enables additional math and transform operators.
tvm.relay.less
tvm.relay.less_equal
tvm.relay.all
tvm.relay.any
tvm.relay.logical_and
tvm.relay.logical_or
tvm.relay.logical_not
......@@ -300,6 +301,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.less
.. autofunction:: tvm.relay.less_equal
.. autofunction:: tvm.relay.all
.. autofunction:: tvm.relay.any
.. autofunction:: tvm.relay.logical_and
.. autofunction:: tvm.relay.logical_or
.. autofunction:: tvm.relay.logical_not
......
......@@ -520,6 +520,13 @@ TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
TVM_DLL Expr all(Expr source, Array<IterVar> axis);
/*!
* \brief logical OR of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
TVM_DLL Expr any(Expr source, Array<IterVar> axis);
/*!
* \brief max of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
......
......@@ -989,6 +989,12 @@ class Where(OnnxOpConverter):
def _impl_v9(cls, inputs, attr, params):
return _op.where(inputs[0], inputs[1], inputs[2])
class Or(Elemwise):
""" Operator converter for Or.
"""
@classmethod
def _impl_v7(cls, inputs, attr, params):
return _op.logical_or(inputs[0], inputs[1])
# compatible operators that do NOT require any conversion.
_identity_list = []
......@@ -1111,7 +1117,8 @@ def _get_convert_map(opset):
'And': And.get_converter(opset),
'Tile': Tile.get_converter(opset),
'Erf': Erf.get_converter(opset),
'Where': Where.get_converter(opset)
'Where': Where.get_converter(opset),
'Or': Or.get_converter(opset)
}
......
......@@ -1330,6 +1330,7 @@ _convert_map = {
'Abs' : AttrCvt('abs'),
'Add' : _elemwise('add'),
'All' : _reduce('all'),
'Any' : _reduce('any'),
'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'),
'Assert' : _assert(),
......
......@@ -31,6 +31,7 @@ _reg.register_schedule("argmax", _schedule_reduce)
_reg.register_schedule("argmin", _schedule_reduce)
_reg.register_schedule("sum", _schedule_reduce)
_reg.register_schedule("all", _schedule_reduce)
_reg.register_schedule("any", _schedule_reduce)
_reg.register_schedule("max", _schedule_reduce)
_reg.register_schedule("min", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce)
......
......@@ -166,6 +166,58 @@ def all(data, axis=None, keepdims=False, exclude=False):
return _make.all(data, axis, keepdims, exclude)
def any(data, axis=None, keepdims=False, exclude=False):
"""Computes the logical OR of boolean array elements over given axes.
Parameters
----------
data : relay.Expr
The input boolean tensor
axis : None or int or tuple of int
Axis or axes along which a logical OR is performed. The default, axis=None,
will perform a logical OR over all of the elements of the input array. If axis is
negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one. With this option, the result will broadcast
correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
data = relay.Constant(tvm.nd.array([[[ True, True, True],
[ True, True, True],
[False, True, False]],
[[ True, False, False],
[ True, True, False],
[False, True, True]]]))
relay.any(data, axis=1)
# [[True, True, True],
# [True, True, True]]
relay.any(data, axis=0)
# [[ True, True, True],
# [ True, True, True],
# [False, True, True]]
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.any(data, axis, keepdims, exclude)
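For reference, here is a minimal end-to-end sketch (not part of the patch) of evaluating relay.any; it assumes an llvm target and the debug executor, following the conventions used in the Relay tests below:
import numpy as np
import tvm
from tvm import relay
# Sketch: evaluate relay.any on a small boolean input (assumes llvm is enabled).
data = np.array([[True, False, False],
                 [False, False, False]])
x = relay.var("x", relay.TensorType(data.shape, "bool"))
func = relay.Function([x], relay.any(x, axis=1))
intrp = relay.create_executor(kind="debug", ctx=tvm.cpu(0), target="llvm")
res = intrp.evaluate(func)(data)
print(res.asnumpy())  # expected: [ True False]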
def max(data, axis=None, keepdims=False, exclude=False):
""" Computes the max of array elements over given axes.
......
......@@ -486,6 +486,16 @@ Expr all(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
Expr any(Expr source, Array<IterVar> rdom) {
CHECK(source.type().is_bool());
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Or::make(x, y);
Expr identity_element = make_const(source.type(), false);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}
Expr max(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Max::make(x, y);
......
......@@ -420,6 +420,43 @@ Example::
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
Array<Tensor> AnyCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::any);
}
RELAY_REGISTER_REDUCE_OP("any")
.describe(R"code(Computes the logical OR of boolean array elements over given axes.
Example::
data = [[[ True, True, True],
[ True, True, True],
[False, True, False]],
[[ True, False, False],
[ True, True, False],
[False, True, True]]]
any(data, axis=1)
[[True, True, True],
[True, True, True]]
any(data, axis=0)
[[ True, True, True],
[ True, True, True],
[False, True, True]]
)code" TVM_ADD_FILELINE)
.set_attrs_type<ReduceAttrs>()
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", AnyCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
Array<Tensor> MaxCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
......
......@@ -1601,6 +1601,53 @@ def test_where():
verify_where(condition, x, y, TensorProto.FLOAT, outdata)
def verify_or(indata, dtype):
x = indata[0].astype(dtype)
y = indata[1].astype(dtype)
outdata = np.logical_or(x, y)
node = helper.make_node('Or', inputs=['in1', 'in2'], outputs=['out'])
graph = helper.make_graph([node],
'or_test',
inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)),
helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
model = helper.make_model(graph, producer_name='or_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape)
tvm.testing.assert_allclose(outdata, tvm_out)
def test_or():
# 2d
x = (np.random.randn(3, 4) > 0)
y = (np.random.randn(3, 4) > 0)
verify_or(indata=[x, y], dtype=bool)
# 3d
x = (np.random.randn(3, 4, 5) > 0)
y = (np.random.randn(3, 4, 5) > 0)
verify_or(indata=[x, y], dtype=bool)
# 4d
x = (np.random.randn(3, 4, 5, 6) > 0)
y = (np.random.randn(3, 4, 5, 6) > 0)
verify_or(indata=[x, y], dtype=bool)
# 3d vs 1d
x = (np.random.randn(3, 4, 5) > 0)
y = (np.random.randn(5) > 0)
verify_or(indata=[x, y], dtype=bool)
# 3d vs 2d
x = (np.random.randn(3, 4, 5) > 0)
y = (np.random.randn(4, 5) > 0)
verify_or(indata=[x, y], dtype=bool)
if __name__ == '__main__':
test_flatten()
test_reshape()
......@@ -1651,3 +1698,4 @@ if __name__ == '__main__':
test_tile()
test_erf()
test_where()
test_or()
......@@ -2198,7 +2198,7 @@ def test_forward_size():
check_size((10,))
#######################################################################
# All, Max, Min
# All, Any, Max, Min
# -------------------
def test_forward_reduce_all():
"""Test the All operator."""
......@@ -2208,6 +2208,14 @@ def test_forward_reduce_all():
tf.reduce_all(in_data, name="all")
compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
def test_forward_reduce_any():
"""Test the Any operator."""
np_data = np.random.choice([True, False], size=(5, 7, 11))
tf.reset_default_graph()
in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
tf.reduce_any(in_data, name="any")
compare_tf_with_tvm([np_data], ['in_data:0'], 'any:0')
def test_forward_reduce_max():
def check_max(ishape, axis, keepdims, dtype):
tf.reset_default_graph()
......@@ -2432,7 +2440,7 @@ if __name__ == '__main__':
test_forward_mean()
test_forward_reduce_prod()
test_forward_reduce_all()
test_forward_reduce_max()
test_forward_reduce_any()
test_forward_reduce_min()
# General
......
......@@ -145,7 +145,7 @@ def test_where():
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
test_func = funcs[0]
ref_func = funcs[1]
dtype = "bool" if ref_func in [np.all] else dtype
dtype = "bool" if ref_func in [np.all, np.any] else dtype
x = relay.var("x", relay.TensorType(data, dtype))
z = test_func(x, axis, keepdims, exclude)
......@@ -207,6 +207,7 @@ def test_reduce_functions():
[relay.std, np.std],
[relay.prod, np.prod],
[relay.all, np.all],
[relay.any, np.any],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
......
......@@ -391,6 +391,27 @@ inline Tensor all(const Tensor& data,
}
/*!
* \brief Creates an operation that computes the logical OR of elements
* over a given axis
*
* \param data The input boolean tensor
* \param axis The axes to reduce. If axis is empty, the operation will
* perform logical OR over all elements of the array.
* \param keepdims If this is set to true, the axes which are reduced are
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \param atleast1d Whether the output needs to be at least 1-d.
*
* \return A Tensor whose op member is the any operation
*/
inline Tensor any(const Tensor& data,
const Array<Integer>& axis,
bool keepdims = false,
bool atleast1d = false) {
return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
}
/*!
* \brief Creates an operation that finds the minimum of elements over
* a given axis.
*
......
......@@ -90,6 +90,31 @@ def all(data, axis=None, keepdims=False):
return cpp.all(data, axis, keepdims)
def any(data, axis=None, keepdims=False):
"""Logical OR of array elements over a given axis or a list of axes
Parameters
----------
data : tvm.Tensor
The input tvm boolean tensor
axis : None or int or tuple of int
Axis or axes along which a logical OR is performed.
The default, axis=None, will perform logical OR over all elements of the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
Returns
-------
ret : tvm.Tensor
"""
return cpp.any(data, axis, keepdims)
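As a usage illustration (not part of the patch), topi.any can be lowered with a default schedule and run on CPU; this sketch assumes an llvm build target and mirrors the topi reduce tests below:
import numpy as np
import tvm
import topi
# Sketch: reduce a boolean tensor along axis 1 with topi.any (assumes llvm).
A = tvm.placeholder((2, 3), name="A", dtype="bool")
B = topi.any(A, axis=1)
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")
ctx = tvm.cpu(0)
a = tvm.nd.array(np.array([[True, False, False],
                           [False, False, False]]), ctx)
b = tvm.nd.empty((2,), dtype="bool", ctx=ctx)
f(a, b)
print(b.asnumpy())  # expected: [ True False]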
def max(data, axis=None, keepdims=False):
"""Maximum of array elements over a given axis or a list of axes
......
......@@ -300,6 +300,11 @@ TVM_REGISTER_GLOBAL("topi.all")
*rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.any")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
});
/* Ops from transform.h */
TVM_REGISTER_GLOBAL("topi.expand_dims")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......
......@@ -52,6 +52,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
B = topi.sum(A1, axis=axis, keepdims=keepdims)
elif type == "all":
B = topi.all(A, axis=axis, keepdims=keepdims)
elif type == "any":
B = topi.any(A, axis=axis, keepdims=keepdims)
elif type == "max":
B = topi.max(A1, axis=axis, keepdims=keepdims)
elif type == "min":
......@@ -86,6 +88,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
elif type == "all" and dtype == 'bool':
out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
elif type == "any" and dtype == "bool":
out_npy = in_npy_map.any(axis=axis, keepdims=keepdims)
elif type == "max":
out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
elif type == "min":
......@@ -173,6 +177,26 @@ def test_reduce_map():
keepdims=True,
type="sum",
dtype="float64")
verify_reduce_map_ele(in_shape=(2, 3),
axis=None,
keepdims=True,
type="any",
dtype="bool")
verify_reduce_map_ele(in_shape=(32, 128, 24),
axis=None,
keepdims=True,
type="any",
dtype="bool")
verify_reduce_map_ele(in_shape=(1, 4, 7),
axis=1,
keepdims=True,
type="any",
dtype="bool")
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
axis=2,
keepdims=False,
type="any",
dtype="bool")
if __name__ == "__main__":
test_reduce_map()