Commit 493fc040 by Zhi Committed by Tianqi Chen

Add relay.where (#1869)

parent 65016b65
......@@ -98,6 +98,7 @@ This level enables additional math and transform operators.
tvm.relay.maximum
tvm.relay.minimum
tvm.relay.pow
tvm.relay.where
**Level 5: Vision/Image Operators**
......@@ -173,6 +174,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.maximum
.. autofunction:: tvm.relay.minimum
.. autofunction:: tvm.relay.pow
.. autofunction:: tvm.relay.where
Level 5 Definitions
-------------------
......
......@@ -180,3 +180,42 @@ def full_like(data, fill_value):
The resulting tensor.
"""
return _make.full_like(data, fill_value)
def where(condition, x, y):
"""Selecting elements from either x or y depending on the value of the
condition.
Parameters
----------
condition : relay.Expr
The condition array. The n-th element in `y` is selected when the n-th
value in the `condition` array is zero. Otherwise, the corresponding
element from `x` will be picked.
x : relay.Expr
The first array to be selected.
y : relay.Expr
The second array to be selected.
Returns
-------
result : relay.Expr
The selected array.
Examples
--------
.. code-block:: python
x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
condition = [[0, 1], [-1, 0]]
relay.where(conditon, x, y) = [[5, 2], [3, 8]]
condition = [1, 0]
relay.where(conditon, x, y) = [[1, 2], [7, 8]]
Note that the shape of condition, x, and y needs to be the same.
"""
return _make.where(condition, x, y)
......@@ -498,5 +498,85 @@ and type as the input array.
.set_support_level(3)
.add_type_rel("FullLike", FullLikeRel);
// where operator
bool WhereRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4U);
const auto* condition = types[0].as<TensorTypeNode>();
const auto* x = types[1].as<TensorTypeNode>();
const auto* y = types[2].as<TensorTypeNode>();
CHECK(condition != nullptr && x != nullptr && y != nullptr);
const auto& cond_shape = condition->shape;
const auto& x_shape = x->shape;
const auto& y_shape = y->shape;
CHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size";
if (cond_shape.size() != x_shape.size()) {
CHECK_EQ(cond_shape.size(), 1)
<< "Shape of condition " << condition->shape
<< " must be either equal to x or has dimension of 1.";
}
for (size_t i = 0; i < x_shape.size(); i++) {
CHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
<< "x and y must have the same shape: " << x_shape << " vs " << y_shape;
CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
<< "Shape of condition " << condition->shape
<< " must be either equal to x or has dimension of 1.";
}
reporter->Assign(types[3], TensorTypeNode::make(x_shape, x->dtype));
return true;
}
// Positional relay function to create where operator.
Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) {
static const Op& op = Op::Get("where");
return CallNode::make(op, {condition, x, y});
}
TVM_REGISTER_API("relay.op._make.where")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeWhere, args, rv);
});
RELAY_REGISTER_OP("where")
.describe(R"code(
Return the elements, either from x or y, depending on the condition.
Given three ndarrays, condition, x, and y, return an ndarray with the elements
from x or y, depending on the elements from condition are true or false.
x and y must have the same shape. If condition has the same shape as x,
each element in the output array is from x if the corresponding element
in the condition is true, and from y if false.
If condition does not have the same shape as x, it must be a 1D array whose
size is the same as x’s first dimension size. Each row of the output array
is from x’s row if the corresponding element from condition is true, and
from y’s row if false.
Note that all non-zero values are interpreted as True in condition.
Examples::
x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
cond = [[0, 1], [-1, 0]]
where(cond, x, y) = [[5, 2], [3, 8]]
cond = [1, 0]
where(cond, x, y) = [[1, 2], [7, 8]]
)code" TVM_ADD_FILELINE)
.add_argument("condition", "Tensor", "Condition array")
.add_argument("x", "Tensor", "First array to be selected")
.add_argument("y", "Tensor", "Second array to be selected")
.set_num_inputs(3)
.set_support_level(4)
.add_type_rel("Where", WhereRel);
} // namespace relay
} // namespace tvm
......@@ -125,8 +125,22 @@ def test_binary_broadcast():
assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32")
def test_where():
ib = relay.ir_builder.IRBuilder()
cond = ib.param("cond", relay.TensorType((3, 4), "float32"))
x = ib.param("x", relay.TensorType((3, 4), "float32"))
y = ib.param("y", relay.TensorType((3, 4), "float32"))
with ib.function(cond, x, y) as func:
ib.ret(relay.where(cond.var, x.var, y.var))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((3, 4), "float32")
if __name__ == "__main__":
test_cmp_type()
test_binary_broadcast()
test_binary_op()
test_binary_broadcast_op()
test_where()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment