Commit 8e2f229a by Yao Wang Committed by Zhi

[Topi]Allow empty tensor for reshape, tile and strided_slice (#4618)

* Support empty tensor

* Fix schedule

* Refactor

* Minor fix

* Fix pylint

* Merge cpp and python is_empty_shape
parent d5d63a44
......@@ -888,6 +888,7 @@ bool TakeRel(const Array<Type>& types,
CHECK(data != nullptr);
const auto* indices = types[1].as<TensorTypeNode>();
CHECK(indices != nullptr);
CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
const auto param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);
......@@ -1648,6 +1649,9 @@ bool SqueezeRel(const Array<Type>& types,
// if axes is None, squeeze all axes of dimension 1
if (!param->axis.defined()) {
for (const auto& e : data->shape) {
if (!e.as<IntImm>()) {
LOG(FATAL) << "axis needs to be defined for dynamic input.";
}
const int64_t* axis_ptr = as_const_int(e);
CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
if (*axis_ptr != 1) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tensor_utils.h
* \brief Utility functions for handling tensor
*/
#ifndef TOPI_DETAIL_TENSOR_UTILS_H_
#define TOPI_DETAIL_TENSOR_UTILS_H_
namespace topi {
namespace detail {
using namespace tvm;
/*!
* \brief Check whether input shape has dimension of size 0;
*
* \param x Input shape
*
* \return True if the input shape is empty.
*/
inline bool is_empty_shape(const Array<Expr>& x) {
bool is_empty = false;
for (const auto& dim : x) {
if (auto int_dim = dim.as<IntImm>()) {
if (int_dim->value == 0) {
is_empty = true;
break;
}
}
}
return is_empty;
}
} // namespace detail
} // namespace topi
#endif // TOPI_DETAIL_TENSOR_UTILS_H_
......@@ -34,6 +34,7 @@
#include "topi/tags.h"
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
#include "topi/detail/tensor_utils.h"
#include "tvm/operation.h"
#include "tvm/expr_operator.h"
#include "tvm/data_layout.h"
......@@ -207,16 +208,28 @@ inline Tensor reshape(const Tensor& x,
std::string name = "T_reshape",
std::string tag = kInjective) {
auto x_shape = x->shape;
Array<Expr> newshape_int32;
Array<Expr> target_shape;
for (const auto &ele : newshape) {
newshape_int32.push_back(cast(DataType::Int(32), ele));
if (ele.as<IntImm>()) {
target_shape.push_back(cast(DataType::Int(32), ele));
} else {
target_shape.push_back(ele);
}
}
if (is_empty_shape(target_shape)) {
return compute(target_shape,
[&](const Array<Var> &indices) { return tvm::cast(x->dtype, 0); },
name, tag);
} else {
return compute(
newshape_int32, [&](const Array<Var>& indices) {
return x(UnravelIndex(RavelIndex(Array<Expr>{indices.begin(), indices.end()}, newshape_int32),
target_shape, [&](const Array<Var>& indices) {
return x(UnravelIndex(
RavelIndex(Array<Expr>{indices.begin(), indices.end()}, target_shape),
x_shape));
}, name, tag);
}
}
/*!
......@@ -556,7 +569,7 @@ inline Tensor strided_slice(const Tensor& x,
int interval = std::abs(end_i - begin_i);
int slice_size = static_cast<int>((interval
+ std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
CHECK(stride_vec[i] < 0 ? (end_i < begin_i) : (begin_i < end_i))
CHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
<< "] is invalid for axis=" << i;
......@@ -938,6 +951,11 @@ inline Tensor tile(const Tensor& x,
for (size_t i = 0; i < tdim; ++i)
new_shape.push_back(data_shape[i] * reps_shape[i]);
if (is_empty_shape(new_shape)) {
return compute(new_shape,
[&](const Array<Var>& indices) { return tvm::cast(x->dtype, 0);},
name, tag);
} else {
return compute(
new_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
......@@ -950,6 +968,7 @@ inline Tensor tile(const Tensor& x,
}
return x(idx);
}, name, tag);
}
}
/*!
......
......@@ -18,6 +18,7 @@
"""Schedule for pooling operators"""
import tvm
from .. import generic
from ..util import is_empty_shape
@generic.schedule_injective_from_existing.register(["arm_cpu"])
def schedule_injective_from_existing(sch, out):
......@@ -68,6 +69,8 @@ def schedule_injective(outs):
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
s[x].vectorize(ii)
tvm.schedule.AutoInlineInjective(s)
if not is_empty_shape(x.shape):
schedule_injective_from_existing(s, x)
return s
......
......@@ -24,3 +24,4 @@ from . import x86
from . import generic
from . import rocm
from . import image
from . import util
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI for TOPI utility functions"""
from tvm._ffi.function import _init_api_prefix
_init_api_prefix("topi.cpp.util", "topi.util")
......@@ -18,6 +18,7 @@
"""Schedule for composition of injective operator"""
import tvm
from .. import generic, util
from ..util import is_empty_shape
@generic.schedule_injective_from_existing.register(["cuda", "gpu"])
def schedule_injective_from_existing(sch, out):
......@@ -79,6 +80,7 @@ def schedule_injective(outs):
tvm.schedule.AutoInlineInjective(s)
for out in outs:
if not is_empty_shape(out.shape):
schedule_injective_from_existing(s, out)
return s
......
......@@ -21,7 +21,7 @@ from numbers import Integral
import tvm
from tvm.api import layout, bijective_layout
from . import tag
from . import tag, cpp
class InvalidShapeError(ValueError):
"""Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
......@@ -417,3 +417,19 @@ def make_idx(b, e, s, z, i):
(b - i) // tvm.abs(s),
(i - b) // s)
return tvm.if_then_else(tvm.expr.Or(bc, ec), 88, ss)
def is_empty_shape(shape):
"""Check whether an input shape has dimesion with size 0.
Parameter
---------
shape : list of Expr
Input shape
Returns
-------
is_empty: bool
Whether input shape is empty or has dimesion with size 0.
"""
return cpp.util.is_empty_shape(shape)
......@@ -19,6 +19,7 @@
from __future__ import absolute_import as _abs
import tvm
from .. import generic
from ..util import is_empty_shape
@generic.schedule_injective_from_existing.register(["cpu"])
def schedule_injective_from_existing(sch, out):
......@@ -65,6 +66,8 @@ def schedule_injective(outs):
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
if not is_empty_shape(x.shape):
schedule_injective_from_existing(s, x)
return s
......
......@@ -72,6 +72,8 @@
#include <topi/rocm/softmax.h>
#include <topi/rocm/normalization.h>
#include <topi/detail/tensor_utils.h>
namespace topi {
using namespace tvm;
......@@ -740,6 +742,12 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_l2_normalize")
*rv = topi::cuda::schedule_l2_normalize(args[0], args[1]);
});
/* Utility functions */
TVM_REGISTER_GLOBAL("topi.util.is_empty_shape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::detail::is_empty_shape(args[0]);
});
/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<
tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>;
......
......@@ -555,6 +555,7 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2])
verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
def test_strided_set():
verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2])
......@@ -596,6 +597,7 @@ def test_reshape():
verify_reshape((4, 2, 3, 4), (2, 4, 12))
verify_reshape((4, 2, 3, 4), (2, 48))
verify_reshape((16, ), (2, 2, 2, 2))
verify_reshape((4, 0), (2, 0, 2))
def test_where():
......@@ -718,6 +720,7 @@ def test_tile():
verify_tile((3, 2), (2, 3))
verify_tile((3, 2, 5), (2,))
verify_tile((3, ), (2, 3, 3))
verify_tile((4, 0), (5,))
def test_layout_transform():
in_shape = (1, 32, 8, 8)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment