Commit ee3c1b09 by Yao Wang, committed by Tianqi Chen

[TOPI]Add where operator (#1416)

parent 6ea74d41
...@@ -72,3 +72,7 @@ reg.register_schedule("strided_slice", _fschedule_injective)
# slice_like
reg.register_pattern("slice_like", OpPattern.INJECTIVE)
reg.register_schedule("slice_like", _fschedule_injective)
# where
reg.register_pattern("where", OpPattern.INJECTIVE)
reg.register_schedule("where", _fschedule_injective)
...@@ -1125,8 +1125,8 @@ Examples::
DMLC_REGISTER_PARAMETER(SliceLikeParam);
inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
...@@ -1221,5 +1221,98 @@ NNVM_REGISTER_OP(slice_like)
})
.set_support_level(4);
// where
inline bool WhereShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
const TShape& cond_shape = in_attrs->at(0);
const TShape& x_shape = in_attrs->at(1);
const TShape& y_shape = in_attrs->at(2);
CHECK_EQ(x_shape, y_shape) << "x and y must have the same shape: "
<< x_shape << " vs " << y_shape;
if (cond_shape != x_shape) {
CHECK_EQ(cond_shape.ndim(), 1)
<< "Shape of condition " << cond_shape
<< " must either be equal to the shape of x or be 1-D.";
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, x_shape);
return true;
}
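As a hedged illustration of the shape rule above, here is a NumPy-level mirror of WhereShape (the name where_shape and plain tuples are assumptions of this sketch; note that the length of a 1-D condition is only checked later, in topi::where):

def where_shape(cond_shape, x_shape, y_shape):
    # x and y must agree exactly; the output takes x's shape.
    assert x_shape == y_shape, "x and y must have the same shape"
    if cond_shape != x_shape:
        # A mismatched condition is only allowed when it is 1-D.
        assert len(cond_shape) == 1, "condition must match x's shape or be 1-D"
    return x_shape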
inline bool WhereInferType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(1));
return true;
}
inline bool WhereCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), last_ilayouts->size());
CHECK_EQ(olayouts->size(), 1U);
for (size_t i = 0; i < ilayouts->size(); ++i) {
const Layout& input = last_ilayouts->at(i).defined() ?
last_ilayouts->at(i) : ilayouts->at(i);
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
}
return true;
}
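The layout rule is a straight pass-through: each input keeps the layout it last had, if one was recorded. A hypothetical Python paraphrase (names assumed for the sketch):

def where_correct_layout(ilayouts, last_ilayouts):
    # Prefer the previously recorded layout of each input; otherwise
    # keep whatever layout has been inferred so far.
    return [last if last is not None else cur
            for cur, last in zip(ilayouts, last_ilayouts)]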
NNVM_REGISTER_OP(where)
.describe(R"code(
Return the elements, either from x or y, depending on the condition.

Given three ndarrays, condition, x, and y, return an ndarray with elements
from x or y, depending on whether the corresponding elements of condition
are true or false. x and y must have the same shape. If condition has the
same shape as x, each element of the output array is taken from x if the
corresponding element of condition is true, and from y if it is false.

If condition does not have the same shape as x, it must be a 1-D array
whose size matches the size of x's first dimension. Each row of the output
array is taken from x's row if the corresponding element of condition is
true, and from y's row if it is false.

Note that all non-zero values in condition are interpreted as True.
Examples::
x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
cond = [[0, 1], [-1, 0]]
where(cond, x, y) = [[5, 2], [3, 8]]
cond = [1, 0]
where(cond, x, y) = [[1, 2], [7, 8]]
)code" NNVM_ADD_FILELINE)
.add_argument("condition", "Tensor", "Condition array")
.add_argument("x", "Tensor", "First array to be selected")
.add_argument("y", "Tensor", "Second array to be selected")
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", WhereShape)
.set_attr<FInferType>("FInferType", WhereInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", WhereCorrectLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{
topi::where(inputs[0], inputs[1], inputs[2])
};
})
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"condition", "x", "y"};
})
.set_support_level(4);
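To make the documented semantics concrete, here is a hedged NumPy reference (where_ref is a hypothetical helper, not part of this patch). The 1-D case selects whole rows along the first axis, which is why the tests below cannot use plain np.where for it: np.where would broadcast a 1-D condition along the trailing axis instead.

import numpy as np

def where_ref(condition, x, y):
    if condition.shape == x.shape:
        # Elementwise select; any non-zero value of condition counts as True.
        return np.where(condition != 0, x, y)
    # 1-D condition: reshape to (n, 1, ..., 1) so broadcasting selects
    # entire rows of x or y along the first axis.
    cond = (condition != 0).reshape((-1,) + (1,) * (x.ndim - 1))
    return np.where(cond, x, y)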
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
...@@ -645,6 +645,36 @@ def test_slice_like(): ...@@ -645,6 +645,36 @@ def test_slice_like():
axis = (2, 3) axis = (2, 3)
verify_slice_like(np_data, np_shape_like, axis) verify_slice_like(np_data, np_shape_like, axis)
def verify_where(condition, x, y):
dtype = "float32"
if len(condition.shape) == 1:
np_out = np.array([xv if c else yv for (c, xv, yv) in zip(condition, x, y)])
else:
np_out = np.where(condition, x, y)
cond_var = sym.Variable("condition")
x_var = sym.Variable("x")
y_var = sym.Variable("y")
net = sym.where(cond_var, x_var, y_var)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(net, target, {"condition": condition.shape,
"x": x.shape, "y": y.shape})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"condition": condition, "x": x, "y": y})
m.run()
out = m.get_output(0, tvm.nd.empty(x.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
def test_where():
shape = (13, 8, 224, 224, 6)
condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
if __name__ == "__main__":
test_reshape()
...@@ -665,4 +695,5 @@ if __name__ == "__main__":
test_multibox_transform_loc()
test_nms()
test_slice_like()
test_where()
print(nnvm.compiler.engine.dump())
...@@ -575,5 +575,53 @@ inline Tensor take(const Tensor& a,
}, name, tag);
}
/*!
* \brief Return the elements, either from x or y, depending on the condition.
*
* \param condition The condition array.
* \param x First array to be selected.
* \param y Second array to be selected.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor selected from x or y depending on condition.
*/
inline Tensor where(const Tensor& condition,
const Tensor& x,
const Tensor& y,
std::string name = "tensor",
std::string tag = kInjective) {
CHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape. Got different numbers of dimensions: "
<< x->shape.size() << " vs " << y->shape.size();
CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: "
<< x->dtype << " vs " << y->dtype;
Array<Expr> oshape = x->shape;
Tensor out;
if (condition->shape.size() != 1) {
CHECK_EQ(condition->shape.size(), x->shape.size())
<< "condition array must either have the same shape as x or be a "
"1-D array. Got different numbers of dimensions: "
<< condition->shape.size() << " vs " << x->shape.size();
out = compute(
oshape, [&](const Array<Var>& indices) {
return tvm::select(condition(indices) != 0, x(indices), y(indices));
}, name, tag);
} else {
CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
<< "If condition is 1-D, its length must equal x's first dimension: "
<< condition->shape[0] << " vs " << x->shape[0];
out = compute(
oshape, [&](const Array<Var>& indices) {
Array<Expr> condition_idx{indices[0]};
return tvm::select(condition(condition_idx) != 0,
x(indices), y(indices));
}, name, tag);
}
return out;
}
} // namespace topi
#endif // TOPI_TRANSFORM_H_
...@@ -280,6 +280,11 @@ TVM_REGISTER_GLOBAL("topi.take")
}
});
TVM_REGISTER_GLOBAL("topi.where")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = where(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.strided_slice") TVM_REGISTER_GLOBAL("topi.strided_slice")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3]); *rv = strided_slice(args[0], args[1], args[2], args[3]);
......
...@@ -206,6 +206,35 @@ def verify_take(src_shape, indices_src, axis=None): ...@@ -206,6 +206,35 @@ def verify_take(src_shape, indices_src, axis=None):
for device in ["llvm", "opencl"]: for device in ["llvm", "opencl"]:
check_device(device) check_device(device)
def verify_where(condition, x, y):
dtype = "float32"
if len(condition.shape) == 1:
np_out = np.array([xv if c else yv for (c, xv, yv) in zip(condition, x, y)])
else:
np_out = np.where(condition, x, y)
A = tvm.placeholder(shape=condition.shape, dtype=dtype, name="condition")
B = tvm.placeholder(shape=x.shape, dtype=dtype, name="x")
C = tvm.placeholder(shape=y.shape, dtype=dtype, name="y")
out_tensor = topi.cpp.where(A, B, C)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(out_tensor)
foo = tvm.build(s, [A, B, C, out_tensor], device, name="where")
tvm_out = tvm.nd.empty(x.shape, ctx=ctx, dtype=dtype)
foo(tvm.nd.array(condition, ctx), tvm.nd.array(x, ctx),
tvm.nd.array(y, ctx), tvm_out)
np.testing.assert_allclose(tvm_out.asnumpy(), np_out)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def verify_concatenate_split(shapes, axis, indices_or_sections):
tensor_l_concatenate = []
for i, shape in enumerate(shapes):
...@@ -324,6 +353,18 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)
def test_where():
shape = (10, 3, 7, 13)
condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
def test_regression_1():
verify_concatenate_split([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1, [3, 7])
verify_concatenate_split([(3, 4), (2, 4), (3, 4)], 0, [1, 2, 3, 4])
...@@ -340,5 +381,6 @@ if __name__ == "__main__":
test_squeeze()
test_split()
test_take()
test_where()
test_regression_1()
test_regression_2()