Unverified Commit f7beea4b by Tianqi Chen Committed by GitHub

[OP] Fix reduce op problem when axis=None (#2436)

parent ac3b5bb3
......@@ -101,6 +101,7 @@ inline std::vector<int64_t> GetReduceAxes(const uint32_t indim,
// Get axis under exclude condition.
Array<Integer> GetExcludeAxes(size_t indim,
const Array<Integer>& inaxis) {
CHECK(inaxis.defined()) << "Cannot set exclude when axis=None";
std::vector<bool> axis_flag(indim, true);
for (auto i : inaxis) {
int64_t axis = i->value;
......@@ -137,9 +138,9 @@ Array<Tensor> ReduceCompute(const Attrs& attrs,
auto axes = param->axis;
if (param->exclude) {
axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
}
if (axes.size() == 0) {
return { topi::identity(inputs[0]) };
if (axes.size() == 0) {
return { topi::identity(inputs[0]) };
}
}
return { f(inputs[0], axes, param->keepdims, false) };
}
......
......@@ -135,7 +135,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype
assert zz.checked_type == relay.ty.TensorType(output, out_type)
if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or len(output) == 0:
if all(isinstance(v, tvm.expr.Var) == 1 for v in data):
return
func = relay.Function([x], z)
......@@ -187,7 +187,7 @@ def test_reduce_functions():
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, False, True, ())
verify_reduce(func, (4, 4, 3), None, False, False, ())
verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,))
verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment