Unverified Commit f7beea4b by Tianqi Chen Committed by GitHub

[OP] Fix reduce op problem when axis=None (#2436)

parent ac3b5bb3
...@@ -101,6 +101,7 @@ inline std::vector<int64_t> GetReduceAxes(const uint32_t indim, ...@@ -101,6 +101,7 @@ inline std::vector<int64_t> GetReduceAxes(const uint32_t indim,
// Get axis under exclude condition. // Get axis under exclude condition.
Array<Integer> GetExcludeAxes(size_t indim, Array<Integer> GetExcludeAxes(size_t indim,
const Array<Integer>& inaxis) { const Array<Integer>& inaxis) {
CHECK(inaxis.defined()) << "Cannot set exclude when axis=None";
std::vector<bool> axis_flag(indim, true); std::vector<bool> axis_flag(indim, true);
for (auto i : inaxis) { for (auto i : inaxis) {
int64_t axis = i->value; int64_t axis = i->value;
...@@ -137,9 +138,9 @@ Array<Tensor> ReduceCompute(const Attrs& attrs, ...@@ -137,9 +138,9 @@ Array<Tensor> ReduceCompute(const Attrs& attrs,
auto axes = param->axis; auto axes = param->axis;
if (param->exclude) { if (param->exclude) {
axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
} if (axes.size() == 0) {
if (axes.size() == 0) { return { topi::identity(inputs[0]) };
return { topi::identity(inputs[0]) }; }
} }
return { f(inputs[0], axes, param->keepdims, false) }; return { f(inputs[0], axes, param->keepdims, false) };
} }
......
...@@ -135,7 +135,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") ...@@ -135,7 +135,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype
assert zz.checked_type == relay.ty.TensorType(output, out_type) assert zz.checked_type == relay.ty.TensorType(output, out_type)
if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or len(output) == 0: if all(isinstance(v, tvm.expr.Var) == 1 for v in data):
return return
func = relay.Function([x], z) func = relay.Function([x], z)
...@@ -187,7 +187,7 @@ def test_reduce_functions(): ...@@ -187,7 +187,7 @@ def test_reduce_functions():
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4)) verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4)) verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ()) verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, False, True, ()) verify_reduce(func, (4, 4, 3), None, False, False, ())
verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,)) verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,)) verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,))
verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,)) verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment