Commit fa4d3ec6 by Wei Chen Committed by Haichen Shen

[TOPI]Add op argwhere (#3994)

* Add op argwhere

* Move shape func to _algorithm.py

* Add lint rule

* Raise exception if rank is not supportted

* move argwhere to transform

* Add argwhere example

* Fix lint

* Add 1-d support

* cleanup

* Add more dtype support

* CR comment

* Improve error message

* Docs

* raise exception
parent 5cc17649
......@@ -314,6 +314,12 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
}
}; // struct OneHotAttrs
/*! \brief Attributes for ArgWhere operator */
struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") {
}
}; // struct ArgWhereAttrs
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
......@@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks
from __future__ import absolute_import
import tvm
import topi
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ._reduce import _schedule_reduce
......@@ -204,3 +206,100 @@ def take_shape_func(attrs, inputs, out_ndims):
axis += data_ndim
assert 0 <= axis < data_ndim
return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
@script
def _argwhere_shape_func_1d(condition):
out = output_tensor((2, ), "int64")
out[0] = int64(0)
out[1] = int64(1)
for i1 in range(condition.shape[0]):
if condition[i1] != 0:
out[0] += int64(1)
return out
@script
def _argwhere_shape_func_2d(condition):
out = output_tensor((2, ), "int64")
out[0] = int64(0)
out[1] = int64(2)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
if condition[i1, i2] != 0:
out[0] += int64(1)
return out
@script
def _argwhere_shape_func_3d(condition):
out = output_tensor((2, ), "int64")
out[0] = int64(0)
out[1] = int64(3)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
for i3 in range(condition.shape[2]):
if condition[i1, i2, i3] != 0:
out[0] += int64(1)
return out
@script
def _argwhere_shape_func_4d(condition):
out = output_tensor((2, ), "int64")
out[0] = int64(0)
out[1] = int64(4)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
for i3 in range(condition.shape[2]):
for i4 in range(condition.shape[3]):
if condition[i1, i2, i3, i4] != 0:
out[0] += int64(1)
return out
@script
def _argwhere_shape_func_5d(condition):
out = output_tensor((2, ), "int64")
out[0] = int64(0)
out[1] = int64(5)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
for i3 in range(condition.shape[2]):
for i4 in range(condition.shape[3]):
for i5 in range(condition.shape[4]):
if condition[i1, i2, i3, i4, i5] != 0:
out[0] += int64(1)
return out
@_reg.register_shape_func("argwhere", True)
def argwhere_shape_func(attrs, inputs, out_ndims):
"""
Shape function for argwhere.
"""
if len(inputs[0].shape) == 1:
return [_argwhere_shape_func_1d(inputs[0])]
elif len(inputs[0].shape) == 2:
return [_argwhere_shape_func_2d(inputs[0])]
elif len(inputs[0].shape) == 3:
return [_argwhere_shape_func_3d(inputs[0])]
elif len(inputs[0].shape) == 4:
return [_argwhere_shape_func_4d(inputs[0])]
elif len(inputs[0].shape) == 5:
return [_argwhere_shape_func_5d(inputs[0])]
return ValueError("Does not support rank higher than 5 in argwhere")
@_reg.register_schedule("argwhere")
def schedule_argwhere(_, outs, target):
"""Schedule definition of argwhere"""
with target:
return topi.generic.schedule_argwhere(outs)
@_reg.register_compute("argwhere")
def compute_argwhere(attrs, inputs, output_type, _):
"""Compute definition of argwhere"""
output_shape = []
for s in output_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
# see Any, replace it with a var
output_shape.append(tvm.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]
......@@ -144,7 +144,6 @@ def squeeze(data, axis=None):
"""
return _make.squeeze(data, axis)
def reshape(data, newshape):
"""Reshapes the input array.
......@@ -214,6 +213,28 @@ def reshape(data, newshape):
newshape = [newshape]
return _make.reshape(data, list(newshape))
def argwhere(condition):
"""Find the indices of elements of a tensor that are
non-zero.
Parameters
----------
condition : relay.Expr
The input condition tensor.
Returns
-------
out : relay.Expr
Tensor with the indices of elements that are non-zero.
Examples
--------
.. code-block:: python
condition = [[True, False], [False, True]]
relay.argwhere(condition) = [[0, 0], [1, 1]]
"""
return _make.argwhere(condition)
def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
......
......@@ -817,6 +817,40 @@ the input array into an output array with the same shape as the second input arr
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
// ArgWhere
bool ArgWhereRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
CHECK(tt != nullptr);
const auto& input_shape = tt->shape;
const auto& input_rank = input_shape.size();
std::vector<IndexExpr> result_shape;
result_shape.push_back(Any::make());
result_shape.push_back(IntImm::make(Int(32), input_rank));
reporter->Assign(types[1], TensorTypeNode::make(result_shape, Int(32)));
return true;
}
TVM_REGISTER_API("relay.op._make.argwhere")
.set_body_typed<Expr(Expr)>([](Expr data) {
static const Op& op = Op::Get("argwhere");
auto attrs = make_node<ArgWhereAttrs>();
return CallNode::make(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("argwhere")
.describe(R"doc(Find the indices of elements of a tensor that are
non-zero)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ArgWhereAttrs")
.add_argument("condition", "Tensor", "The input condition tensor.")
.add_type_rel("ArgWhere", ArgWhereRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);
// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);
......
......@@ -92,6 +92,36 @@ def test_any_reshape():
verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
x = relay.var('x', shape=x_shape, dtype=dtype)
y = relay.argwhere(x)
mod = relay.module.Module()
mod["main"] = relay.Function([x], y)
data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data).asnumpy()
expected = np.argwhere(data)
assert result.shape == expected.shape
tvm.testing.assert_allclose(result.flatten(), expected.flatten())
def test_any_argwhere():
verify_any_argwhere(any_dims(1), (5,))
verify_any_argwhere(any_dims(2), (5, 5))
verify_any_argwhere(any_dims(3), (5, 5, 5))
verify_any_argwhere(any_dims(4), (5, 5, 5, 5))
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5))
verify_any_argwhere(any_dims(1), (5,), "int32")
verify_any_argwhere(any_dims(2), (5, 5), "int32")
verify_any_argwhere(any_dims(3), (5, 5, 5), "int32")
verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int32")
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int32")
verify_any_argwhere(any_dims(1), (5,), "int8")
verify_any_argwhere(any_dims(2), (5, 5), "int8")
verify_any_argwhere(any_dims(3), (5, 5, 5), "int8")
verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8")
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8")
def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):
mod = relay.Module()
data = relay.var('data', shape=data_shape, dtype='float32')
......
......@@ -22,6 +22,7 @@ from .reduction import *
from .transform import *
from .broadcast import *
from .sort import *
from .argwhere import *
from . import nn
from . import x86
from . import cuda
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Argwhere operator"""
import tvm
from tvm import hybrid
@hybrid.script
def hybrid_argwhere_1d(output_shape, condition):
"""Find the indices of elements of a 1-D tensor that are non-zero.
Parameters
----------
condition : tvm.Tensor
1-D tensor with boolean values.
Returns
-------
out : tvm.Tensor
Indices of non-zero elements.
"""
a = output_tensor(output_shape, "int32")
a1 = condition.shape[0]
valid_index = 0
for i1 in range(a1):
if condition[i1] != 0:
a[valid_index, 0] = i1
valid_index += 1
return a
@hybrid.script
def hybrid_argwhere_2d(output_shape, condition):
"""Find the indices of elements of a 2-D tensor that are non-zero.
Parameters
----------
condition : tvm.Tensor
2-D tensor with boolean values.
Returns
-------
out : tvm.Tensor
Indices of non-zero elements.
"""
a = output_tensor(output_shape, "int32")
a1 = condition.shape[0]
a2 = condition.shape[1]
valid_index = 0
for i1 in range(a1):
for i2 in range(a2):
if condition[i1, i2] != 0:
a[valid_index, 0] = i1
a[valid_index, 1] = i2
valid_index += 1
return a
@hybrid.script
def hybrid_argwhere_3d(output_shape, condition):
"""Find the indices of elements of a 3-D tensor that are non-zero.
Parameters
----------
condition : tvm.Tensor
3-D tensor with boolean values.
Returns
-------
out : tvm.Tensor
Indices of non-zero elements.
"""
a = output_tensor(output_shape, "int32")
a1 = condition.shape[0]
a2 = condition.shape[1]
a3 = condition.shape[2]
valid_index = 0
for i1 in range(a1):
for i2 in range(a2):
for i3 in range(a3):
if condition[i1, i2, i3] != 0:
a[valid_index, 0] = i1
a[valid_index, 1] = i2
a[valid_index, 2] = i3
valid_index += 1
return a
@hybrid.script
def hybrid_argwhere_4d(output_shape, condition):
"""Find the indices of elements of a 4-D tensor that are non-zero.
Parameters
----------
condition : tvm.Tensor
4-D tensor with boolean values.
Returns
-------
out : tvm.Tensor
Indices of non-zero elements.
"""
a = output_tensor(output_shape, "int32")
a1 = condition.shape[0]
a2 = condition.shape[1]
a3 = condition.shape[2]
a4 = condition.shape[3]
valid_index = 0
for i1 in range(a1):
for i2 in range(a2):
for i3 in range(a3):
for i4 in range(a4):
if condition[i1, i2, i3, i4] != 0:
a[valid_index, 0] = i1
a[valid_index, 1] = i2
a[valid_index, 2] = i3
a[valid_index, 3] = i4
valid_index += 1
return a
@hybrid.script
def hybrid_argwhere_5d(output_shape, condition):
"""Find the indices of elements of a 5-D tensor that are non-zero.
Parameters
----------
condition : tvm.Tensor
5-D tensor with boolean values.
Returns
-------
out : tvm.Tensor
Indices of non-zero elements.
"""
a = output_tensor(output_shape, "int32")
a1 = condition.shape[0]
a2 = condition.shape[1]
a3 = condition.shape[2]
a4 = condition.shape[3]
a5 = condition.shape[4]
valid_index = 0
for i1 in range(a1):
for i2 in range(a2):
for i3 in range(a3):
for i4 in range(a4):
for i5 in range(a5):
if condition[i1, i2, i3, i4, i5] != 0:
a[valid_index, 0] = i1
a[valid_index, 1] = i2
a[valid_index, 2] = i3
a[valid_index, 3] = i4
a[valid_index, 4] = i5
valid_index += 1
return a
@tvm.target.generic_func
def argwhere(output_shape, condition):
"""Find the indices of elements of a tensor that are non-zero.
Parameters
----------
condition : tvm.Tensor
Tensor with boolean values.
Returns
-------
out : tvm.Tensor
Indices of non-zero elements.
"""
if len(condition.shape) == 1:
return hybrid_argwhere_1d(output_shape.shape, condition)
if len(condition.shape) == 2:
return hybrid_argwhere_2d(output_shape.shape, condition)
if len(condition.shape) == 3:
return hybrid_argwhere_3d(output_shape.shape, condition)
if len(condition.shape) == 4:
return hybrid_argwhere_4d(output_shape.shape, condition)
if len(condition.shape) == 5:
return hybrid_argwhere_5d(output_shape.shape, condition)
raise ValueError("Does not support rank higher than 5 in argwhere")
......@@ -20,3 +20,4 @@ from .injective import *
from .extern import *
from .vision import *
from .sort import *
from .search import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-member
"""Generic search operators"""
from __future__ import absolute_import as _abs
import tvm
from .vision import _default_schedule
@tvm.target.generic_func
def schedule_argwhere(outs):
"""Schedule for argwhere operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of argwhere.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment