Commit 8e2f229a by Yao Wang Committed by Zhi

[Topi]Allow empty tensor for reshape, tile and strided_slice (#4618)

* Support empty tensor

* Fix schedule

* Refactor

* Minor fix

* Fix pylint

* Merge cpp and python is_empty_shape
parent d5d63a44
...@@ -888,6 +888,7 @@ bool TakeRel(const Array<Type>& types, ...@@ -888,6 +888,7 @@ bool TakeRel(const Array<Type>& types,
CHECK(data != nullptr); CHECK(data != nullptr);
const auto* indices = types[1].as<TensorTypeNode>(); const auto* indices = types[1].as<TensorTypeNode>();
CHECK(indices != nullptr); CHECK(indices != nullptr);
CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
const auto param = attrs.as<TakeAttrs>(); const auto param = attrs.as<TakeAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
...@@ -1648,6 +1649,9 @@ bool SqueezeRel(const Array<Type>& types, ...@@ -1648,6 +1649,9 @@ bool SqueezeRel(const Array<Type>& types,
// if axes is None, squeeze all axes of dimension 1 // if axes is None, squeeze all axes of dimension 1
if (!param->axis.defined()) { if (!param->axis.defined()) {
for (const auto& e : data->shape) { for (const auto& e : data->shape) {
if (!e.as<IntImm>()) {
LOG(FATAL) << "axis needs to be defined for dynamic input.";
}
const int64_t* axis_ptr = as_const_int(e); const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete"; CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
if (*axis_ptr != 1) { if (*axis_ptr != 1) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tensor_utils.h
* \brief Utility functions for handling tensor
*/
#ifndef TOPI_DETAIL_TENSOR_UTILS_H_
#define TOPI_DETAIL_TENSOR_UTILS_H_
namespace topi {
namespace detail {
using namespace tvm;
/*!
* \brief Check whether input shape has dimension of size 0;
*
* \param x Input shape
*
* \return True if the input shape is empty.
*/
inline bool is_empty_shape(const Array<Expr>& x) {
bool is_empty = false;
for (const auto& dim : x) {
if (auto int_dim = dim.as<IntImm>()) {
if (int_dim->value == 0) {
is_empty = true;
break;
}
}
}
return is_empty;
}
} // namespace detail
} // namespace topi
#endif // TOPI_DETAIL_TENSOR_UTILS_H_
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "topi/tags.h" #include "topi/tags.h"
#include "topi/detail/ravel_unravel.h" #include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h" #include "topi/detail/constant_utils.h"
#include "topi/detail/tensor_utils.h"
#include "tvm/operation.h" #include "tvm/operation.h"
#include "tvm/expr_operator.h" #include "tvm/expr_operator.h"
#include "tvm/data_layout.h" #include "tvm/data_layout.h"
...@@ -207,16 +208,28 @@ inline Tensor reshape(const Tensor& x, ...@@ -207,16 +208,28 @@ inline Tensor reshape(const Tensor& x,
std::string name = "T_reshape", std::string name = "T_reshape",
std::string tag = kInjective) { std::string tag = kInjective) {
auto x_shape = x->shape; auto x_shape = x->shape;
Array<Expr> newshape_int32; Array<Expr> target_shape;
for (const auto &ele : newshape) { for (const auto &ele : newshape) {
newshape_int32.push_back(cast(DataType::Int(32), ele)); if (ele.as<IntImm>()) {
target_shape.push_back(cast(DataType::Int(32), ele));
} else {
target_shape.push_back(ele);
}
} }
if (is_empty_shape(target_shape)) {
return compute(target_shape,
[&](const Array<Var> &indices) { return tvm::cast(x->dtype, 0); },
name, tag);
} else {
return compute( return compute(
newshape_int32, [&](const Array<Var>& indices) { target_shape, [&](const Array<Var>& indices) {
return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape_int32), return x(UnravelIndex(
RavelIndex(Array<Expr>{indices.begin(), indices.end()}, target_shape),
x_shape)); x_shape));
}, name, tag); }, name, tag);
}
} }
/*! /*!
...@@ -556,7 +569,7 @@ inline Tensor strided_slice(const Tensor& x, ...@@ -556,7 +569,7 @@ inline Tensor strided_slice(const Tensor& x,
int interval = std::abs(end_i - begin_i); int interval = std::abs(end_i - begin_i);
int slice_size = static_cast<int>((interval int slice_size = static_cast<int>((interval
+ std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
CHECK(stride_vec[i] < 0 ? (end_i < begin_i) : (begin_i < end_i)) CHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
<< "] is invalid for axis=" << i; << "] is invalid for axis=" << i;
...@@ -938,6 +951,11 @@ inline Tensor tile(const Tensor& x, ...@@ -938,6 +951,11 @@ inline Tensor tile(const Tensor& x,
for (size_t i = 0; i < tdim; ++i) for (size_t i = 0; i < tdim; ++i)
new_shape.push_back(data_shape[i] * reps_shape[i]); new_shape.push_back(data_shape[i] * reps_shape[i]);
if (is_empty_shape(new_shape)) {
return compute(new_shape,
[&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0);},
name, tag);
} else {
return compute( return compute(
new_shape, [&](const Array<Var>& indices) { new_shape, [&](const Array<Var>& indices) {
Array<Expr> idx; Array<Expr> idx;
...@@ -950,6 +968,7 @@ inline Tensor tile(const Tensor& x, ...@@ -950,6 +968,7 @@ inline Tensor tile(const Tensor& x,
} }
return x(idx); return x(idx);
}, name, tag); }, name, tag);
}
} }
/*! /*!
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
"""Schedule for pooling operators""" """Schedule for pooling operators"""
import tvm import tvm
from .. import generic from .. import generic
from ..util import is_empty_shape
@generic.schedule_injective_from_existing.register(["arm_cpu"]) @generic.schedule_injective_from_existing.register(["arm_cpu"])
def schedule_injective_from_existing(sch, out): def schedule_injective_from_existing(sch, out):
...@@ -68,6 +69,8 @@ def schedule_injective(outs): ...@@ -68,6 +69,8 @@ def schedule_injective(outs):
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 8) (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
s[x].vectorize(ii) s[x].vectorize(ii)
tvm.schedule.AutoInlineInjective(s) tvm.schedule.AutoInlineInjective(s)
if not is_empty_shape(x.shape):
schedule_injective_from_existing(s, x) schedule_injective_from_existing(s, x)
return s return s
......
...@@ -24,3 +24,4 @@ from . import x86 ...@@ -24,3 +24,4 @@ from . import x86
from . import generic from . import generic
from . import rocm from . import rocm
from . import image from . import image
from . import util
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI for TOPI utility functions"""
from tvm._ffi.function import _init_api_prefix
_init_api_prefix("topi.cpp.util", "topi.util")
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
"""Schedule for composition of injective operator""" """Schedule for composition of injective operator"""
import tvm import tvm
from .. import generic, util from .. import generic, util
from ..util import is_empty_shape
@generic.schedule_injective_from_existing.register(["cuda", "gpu"]) @generic.schedule_injective_from_existing.register(["cuda", "gpu"])
def schedule_injective_from_existing(sch, out): def schedule_injective_from_existing(sch, out):
...@@ -79,6 +80,7 @@ def schedule_injective(outs): ...@@ -79,6 +80,7 @@ def schedule_injective(outs):
tvm.schedule.AutoInlineInjective(s) tvm.schedule.AutoInlineInjective(s)
for out in outs: for out in outs:
if not is_empty_shape(out.shape):
schedule_injective_from_existing(s, out) schedule_injective_from_existing(s, out)
return s return s
......
...@@ -21,7 +21,7 @@ from numbers import Integral ...@@ -21,7 +21,7 @@ from numbers import Integral
import tvm import tvm
from tvm.api import layout, bijective_layout from tvm.api import layout, bijective_layout
from . import tag from . import tag, cpp
class InvalidShapeError(ValueError): class InvalidShapeError(ValueError):
"""Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)""" """Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
...@@ -417,3 +417,19 @@ def make_idx(b, e, s, z, i): ...@@ -417,3 +417,19 @@ def make_idx(b, e, s, z, i):
(b - i) // tvm.abs(s), (b - i) // tvm.abs(s),
(i - b) // s) (i - b) // s)
return tvm.if_then_else(tvm.expr.Or(bc, ec), 88, ss) return tvm.if_then_else(tvm.expr.Or(bc, ec), 88, ss)
def is_empty_shape(shape):
"""Check whether an input shape has dimesion with size 0.
Parameter
---------
shape : list of Expr
Input shape
Returns
-------
is_empty: bool
Whether input shape is empty or has dimesion with size 0.
"""
return cpp.util.is_empty_shape(shape)
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from .. import generic from .. import generic
from ..util import is_empty_shape
@generic.schedule_injective_from_existing.register(["cpu"]) @generic.schedule_injective_from_existing.register(["cpu"])
def schedule_injective_from_existing(sch, out): def schedule_injective_from_existing(sch, out):
...@@ -65,6 +66,8 @@ def schedule_injective(outs): ...@@ -65,6 +66,8 @@ def schedule_injective(outs):
x = outs[0] x = outs[0]
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s) tvm.schedule.AutoInlineInjective(s)
if not is_empty_shape(x.shape):
schedule_injective_from_existing(s, x) schedule_injective_from_existing(s, x)
return s return s
......
...@@ -72,6 +72,8 @@ ...@@ -72,6 +72,8 @@
#include <topi/rocm/softmax.h> #include <topi/rocm/softmax.h>
#include <topi/rocm/normalization.h> #include <topi/rocm/normalization.h>
#include <topi/detail/tensor_utils.h>
namespace topi { namespace topi {
using namespace tvm; using namespace tvm;
...@@ -740,6 +742,12 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_l2_normalize") ...@@ -740,6 +742,12 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_l2_normalize")
*rv = topi::cuda::schedule_l2_normalize(args[0], args[1]); *rv = topi::cuda::schedule_l2_normalize(args[0], args[1]);
}); });
/* Utility functions */
TVM_REGISTER_GLOBAL("topi.util.is_empty_shape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::detail::is_empty_shape(args[0]);
});
/*! \brief Builder function for instantiating schedules. */ /*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function< using FTVMScheduleBuilder = std::function<
tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>; tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>;
......
...@@ -555,6 +555,7 @@ def test_strided_slice(): ...@@ -555,6 +555,7 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2]) verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2])
verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]) verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
def test_strided_set(): def test_strided_set():
verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2]) verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2])
...@@ -596,6 +597,7 @@ def test_reshape(): ...@@ -596,6 +597,7 @@ def test_reshape():
verify_reshape((4, 2, 3, 4), (2, 4, 12)) verify_reshape((4, 2, 3, 4), (2, 4, 12))
verify_reshape((4, 2, 3, 4), (2, 48)) verify_reshape((4, 2, 3, 4), (2, 48))
verify_reshape((16, ), (2, 2, 2, 2)) verify_reshape((16, ), (2, 2, 2, 2))
verify_reshape((4, 0), (2, 0, 2))
def test_where(): def test_where():
...@@ -718,6 +720,7 @@ def test_tile(): ...@@ -718,6 +720,7 @@ def test_tile():
verify_tile((3, 2), (2, 3)) verify_tile((3, 2), (2, 3))
verify_tile((3, 2, 5), (2,)) verify_tile((3, 2, 5), (2,))
verify_tile((3, ), (2, 3, 3)) verify_tile((3, ), (2, 3, 3))
verify_tile((4, 0), (5,))
def test_layout_transform(): def test_layout_transform():
in_shape = (1, 32, 8, 8) in_shape = (1, 32, 8, 8)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment