Unverified Commit f3ae3f20 by Tianqi Chen Committed by GitHub

[TOPI] Fix atleast1d for reduce and squeeze (#2147)

parent bac22073
@@ -28,6 +28,17 @@ inline tvm::Array<tvm::Expr> ShapeToArray(TShape shape) {
   return result;
 }
 
+/*
+ * \brief Helper function to convert TShape to TVM array. Useful for
+ * passing data from NNVM param structures to TOPI ops.
+ *
+ * \param shape The shape to convert
+ *
+ * \return An Array of Expr, where each element is a constant int32
+ */
+inline tvm::Array<tvm::Integer> ShapeToIntArray(TShape shape) {
+  return tvm::Array<tvm::Integer>(ShapeToArray(shape).node_);
+}
+
 }  // namespace compiler
 }  // namespace nnvm
 #endif  // NNVM_COMPILER_UTIL_H_
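For orientation, a hypothetical usage sketch of the new helper (the `axes` value is invented; the call simply mirrors how the reduce.cc changes below forward NNVM params to TOPI):

```cpp
// Illustration only: convert an NNVM TShape such as (1, 3) into the
// Array<Integer> that the updated TOPI reduce and squeeze ops expect.
nnvm::TShape axes{1, 3};
tvm::Array<tvm::Integer> axis = nnvm::compiler::ShapeToIntArray(axes);
// axis now holds the constant integers 1 and 3.
```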
@@ -3,9 +3,6 @@
  * \file reduce.cc
  * \brief reduce operator.
  */
-// Enforce TOPI to use old behavior that reduces to at least 1d
-#define TOPI_REDUCE_ATLEAST1D 1
-
 #include <nnvm/op.h>
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
@@ -20,13 +17,12 @@
 #include "topi/reduction.h"
 #include "topi/transform.h"
 
-static_assert(TOPI_REDUCE_ATLEAST1D, "need to use legacy reduce behavior");
 
 namespace nnvm {
 namespace top {
 using namespace tvm;
 using namespace nnvm::compiler;
 
 // reduce
 DMLC_REGISTER_PARAMETER(ReduceParam);
@@ -168,9 +164,9 @@ Example::
     TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                   param.axis, param.exclude);
     if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
-    auto axis = ShapeToArray(r_axes);
+    auto axis = ShapeToIntArray(r_axes);
     return Array<Tensor>{
-      topi::sum(inputs[0], axis, param.keepdims) };
+      topi::sum(inputs[0], axis, param.keepdims, true) };
   })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
@@ -202,9 +198,9 @@ NNVM_REGISTER_REDUCE_OP(max)
     const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
     TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                   param.axis, param.exclude);
-    auto axis = ShapeToArray(r_axes);
+    auto axis = ShapeToIntArray(r_axes);
     return Array<Tensor>{
-      topi::max(inputs[0], axis, param.keepdims) };
+      topi::max(inputs[0], axis, param.keepdims, true) };
   })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
@@ -235,9 +231,9 @@ NNVM_REGISTER_REDUCE_OP(min)
     const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
     TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                   param.axis, param.exclude);
-    auto axis = ShapeToArray(r_axes);
+    auto axis = ShapeToIntArray(r_axes);
     return Array<Tensor>{
-      topi::min(inputs[0], axis, param.keepdims) };
+      topi::min(inputs[0], axis, param.keepdims, true) };
   })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
@@ -299,8 +295,8 @@ values over a given axis.
     const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
     TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                   param.axis, param.exclude);
-    auto axis = ShapeToArray(r_axes);
-    Tensor out = topi::argmax(inputs[0], axis, param.keepdims);
+    auto axis = ShapeToIntArray(r_axes);
+    Tensor out = topi::argmax(inputs[0], axis, param.keepdims, true);
     if (param.dtype == kFloat32) out = topi::cast(out, out_info[0]->dtype);
     return Array<Tensor>{out};
   });
@@ -322,8 +318,8 @@ values over a given axis.
     const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
     TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                   param.axis, param.exclude);
-    auto axis = ShapeToArray(r_axes);
-    Tensor out = topi::argmin(inputs[0], axis, param.keepdims);
+    auto axis = ShapeToIntArray(r_axes);
+    Tensor out = topi::argmin(inputs[0], axis, param.keepdims, true);
     if (param.dtype == kFloat32) out = topi::cast(out, out_info[0]->dtype);
     return Array<Tensor>{out};
   });
@@ -352,7 +348,7 @@ Example::
     TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                   param.axis, param.exclude);
     if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
-    auto axis = ShapeToArray(r_axes);
+    auto axis = ShapeToIntArray(r_axes);
 
     Expr count = make_const(inputs[0]->dtype, 1);
     for (auto& i : r_axes) {
@@ -360,7 +356,7 @@ Example::
     }
     return Array<Tensor>{
-      topi::divide(topi::sum(inputs[0], axis, param.keepdims), count) };
+      topi::divide(topi::sum(inputs[0], axis, param.keepdims, true), count) };
   });
 
 NNVM_REGISTER_REDUCE_OP(prod)
@@ -387,9 +383,9 @@ Example::
     TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                   param.axis, param.exclude);
     if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
-    auto axis = ShapeToArray(r_axes);
+    auto axis = ShapeToIntArray(r_axes);
     return Array<Tensor>{
-      topi::prod(inputs[0], axis, param.keepdims) };
+      topi::prod(inputs[0], axis, param.keepdims, true) };
   });
......
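The NNVM operators above keep the old reduction semantics by passing `atleast1d = true` explicitly; inside TOPI the flag now defaults to `false`, so a full reduction produces a genuine rank-0 tensor. A minimal sketch of the difference, assuming the usual topi/tvm headers (the tensor name and the (3, 4) shape are invented for illustration):

```cpp
#include <topi/reduction.h>
#include <tvm/operation.h>

void sum_example() {
  using namespace tvm;
  Tensor a = placeholder({3, 4}, Float(32), "a");

  // An empty axis array means "reduce over every dimension".
  Tensor scalar_sum = topi::sum(a, Array<Integer>());               // shape ()   new default
  Tensor legacy_sum = topi::sum(a, Array<Integer>(), false, true);  // shape (1,) legacy, as NNVM requests
}
```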
@@ -756,8 +756,8 @@ Examples::
     const Array<Tensor>& inputs,
     const Array<Tensor>& out_info) {
       const SqueezeParam& param = nnvm::get<SqueezeParam>(attrs.parsed);
-      auto axis = ShapeToArray(param.axis);
-      return Array<Tensor>{ topi::squeeze(inputs[0], axis) };
+      auto axis = ShapeToIntArray(param.axis);
+      return Array<Tensor>{ topi::squeeze(inputs[0], axis, true) };
   })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
......
@@ -21,15 +21,9 @@ using namespace tvm;
  * \return The fused iteration variable
  */
 inline IterVar Fuse(Stage stage, const Array<IterVar>& args) {
-  CHECK_GE(args.size(), 1) << "Fuse requires at least 1 arg";
-
-  auto fused = args[0];
-  for (size_t i = 1; i < args.size(); ++i) {
-    IterVar out;
-    stage.fuse(fused, args[i], &out);
-    fused = out;
-  }
-  return fused;
+  IterVar res;
+  stage.fuse(args, &res);
+  return res;
 }
 
 }  // namespace detail
......
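The helper now delegates to the array overload of `Stage::fuse` instead of folding the axes pairwise. A rough sketch of calling that overload directly, assuming a simple elementwise compute (tensor names, shapes, and the wrapper function are invented for illustration):

```cpp
#include <tvm/operation.h>
#include <tvm/schedule.h>

void fuse_example() {
  using namespace tvm;
  Tensor A = placeholder({16, 16}, Float(32), "A");
  Tensor B = compute(A->shape,
                     [&](Var i, Var j) { return A(i, j) + A(i, j); }, "B");
  Schedule s = create_schedule({B->op});

  // Fuse both spatial axes of B in a single call.
  const auto* op = B->op.as<ComputeOpNode>();
  IterVar fused;
  s[B].fuse({op->axis[0], op->axis[1]}, &fused);
}
```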
@@ -27,7 +27,7 @@ using namespace tvm;
  */
 inline Tensor l2_normalize(const Tensor& data,
                            float eps,
-                           const Array<Expr>& axis,
+                           const Array<Integer>& axis,
                            std::string name = "tensor",
                            std::string tag = "l2_normalize") {
   CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input";
......
@@ -40,7 +40,7 @@ inline Tensor softmax(const Tensor &x,
   auto k1 = tvm::reduce_axis(Range(0, input_shape[axis]), "k1");
   auto k2 = tvm::reduce_axis(Range(0, input_shape[axis]), "k2");
-  auto reduced_shape = MakeReduceTargetShape({axis}, x, false);
+  auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);
 
   auto insert_reduce_index = [axis, ndim](const Array<Var> &indices,
                                           const IterVar &reduce_index) {
......
@@ -8,7 +8,6 @@
 #include <algorithm>
 #include <string>
-#include <set>
 #include <vector>
 #include <iterator>
@@ -20,13 +19,6 @@
 #include "topi/detail/constant_utils.h"
 #include "tvm/tvm.h"
 
-/*!
- * \brief macro flag to enable some legacy behavior which requires
- * reduction result to be at least 1d.
- */
-#ifndef TOPI_REDUCE_ATLEAST1D
-#define TOPI_REDUCE_ATLEAST1D 0
-#endif
 
 namespace topi {
 using namespace tvm;
@@ -42,30 +34,34 @@ using FCommReduce = std::function<
  * \brief Convert a reduction axis which could be empty or have negative
  * elements into a real axis with valid dimension indices.
  *
+ * \param ndim Number of dimensions in the target.
+ * \param axis The axis parameter.
+ *
  * \return A non-empty sorted array of valid dimension indices, with no duplicates.
  * If the input axis is empty, the result will be an axis including all dimensions.
  * If any input element is negative, it will be treated as an offset from the
  * last dimension (same as python indexing rules).
  */
-inline std::vector<int> GetRealAxis(int ndim, const std::vector<int>& axis) {
+inline std::vector<int> GetRealAxis(int ndim, const Array<Integer>& axis) {
   std::vector<int> real_axis;
-  if (axis.size() == 0) {
+  if (!axis.defined() || axis.size() == 0) {
     for (int i = 0; i < ndim; ++i) {
       real_axis.push_back(i);
     }
   } else {
     // Use a set so duplicates are removed and the dims are sorted
-    std::set<int> dims;
-    for (auto ele : axis) {
-      if (ele < 0) {
-        ele += ndim;
-      }
-      if (ele >= ndim) {
-        LOG(ERROR) << ele << " exceeds the maximum dimension " << ndim;
-      }
-      dims.emplace(ele);
+    for (auto elem : axis) {
+      int64_t val = elem->value;
+      if (val < 0) {
+        val += ndim;
+      }
+      CHECK_LE(val, ndim) << " exceeds the maximum dimension " << ndim;
+      CHECK_GE(val, 0);
+      real_axis.push_back(static_cast<int>(val));
     }
-    std::copy(dims.begin(), dims.end(), std::back_inserter(real_axis));
+    std::sort(real_axis.begin(), real_axis.end());
+    real_axis.resize(
+        std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin());
   }
   return real_axis;
 }
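For reference (not part of the patch), a small illustration of how the `Array<Integer>`-based normalization behaves for a 4-d tensor:

```cpp
// Negative entries count from the last dimension; the result is sorted
// and de-duplicated, so {-1, 2, 2} on a 4-d tensor becomes {2, 3}.
std::vector<int> real =
    topi::GetRealAxis(4, {tvm::Integer(-1), tvm::Integer(2), tvm::Integer(2)});

// An undefined or empty axis array selects every dimension: {0, 1, 2, 3}.
std::vector<int> all = topi::GetRealAxis(4, tvm::Array<tvm::Integer>());
```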
@@ -84,7 +80,8 @@ inline Array<IterVar> MakeReduceAxes(const std::vector<int>& real_axis, const Tensor& data) {
 /*! \brief Calculate the target shape for a reduce op */
 inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis,
                                          const Tensor& data,
-                                         bool keepdims) {
+                                         bool keepdims,
+                                         bool atleast1d) {
   auto ndim = data->shape.size();
   Array<Expr> target_shape;
   if (keepdims) {
@@ -104,7 +101,7 @@ inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis,
       }
     }
   }
-  if (target_shape.size() == 0 && TOPI_REDUCE_ATLEAST1D) {
+  if (target_shape.size() == 0 && atleast1d) {
     target_shape.push_back(1);
   }
   return target_shape;
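To make the new flag concrete, an illustrative sketch (the (2, 3, 4) shape and the names are invented) of the three target shapes produced when reducing over every axis:

```cpp
#include <topi/reduction.h>
#include <tvm/operation.h>

void target_shape_example() {
  using namespace tvm;
  Tensor data = placeholder({2, 3, 4}, Float(32), "data");
  std::vector<int> all_axes = {0, 1, 2};

  auto kept   = topi::MakeReduceTargetShape(all_axes, data, true,  false);  // (1, 1, 1)
  auto scalar = topi::MakeReduceTargetShape(all_axes, data, false, false);  // ()   rank-0 result
  auto legacy = topi::MakeReduceTargetShape(all_axes, data, false, true);   // (1,) old at-least-1d behavior
}
```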
@@ -163,18 +160,19 @@ inline Tensor DoCommReduce(const Tensor& data,
  * \param keepdims If this is set to true, the axes which are reduced are
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
  *
  * \return The result tensor.
  */
 inline Tensor CommReduce(const Tensor& data,
-                         const Array<Expr>& axis,
+                         const Array<Integer>& axis,
                          FReduce func,
-                         bool keepdims = false) {
+                         bool keepdims,
+                         bool atleast1d) {
   auto ndim = data->shape.size();
   CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
-  auto axis_val = detail::GetConstIntValues(axis, "axis");
-  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis_val);
-  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims);
+  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
+  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
   return DoCommReduce(data, func, target_shape, real_axis,
                       keepdims ? std::vector<int>() : real_axis);
 }
@@ -188,19 +186,20 @@ inline Tensor CommReduce(const Tensor& data,
  * \param keepdims If this is set to true, the axes which are reduced are
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
  *
  * \return The result tensor.
  */
 inline Tensor CommReduceIdx(const Tensor& data,
-                            const Array<Expr>& axis,
+                            const Array<Integer>& axis,
                             FCommReduce func,
-                            bool keepdims = false) {
+                            bool keepdims,
+                            bool atleast1d) {
   auto ndim = data->shape.size();
   CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
-  auto axis_val = detail::GetConstIntValues(axis, "axis");
-  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis_val);
+  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
   auto reduce_axes = MakeReduceAxes(real_axis, data);
-  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims);
+  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);
   auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data]
                  (const Array<Var>& indices) {
@@ -311,11 +310,15 @@ inline Expr ProdOp(Expr source, Array<IterVar> axis) {
  * \param keepdims If this is set to true, the axes which are reduced are
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
  *
  * \return A Tensor whose op member is the sum operation
  */
-inline Tensor sum(const Tensor& data, Array<Expr> axis, bool keepdims = false) {
-  return CommReduce(data, axis, tvm::sum, keepdims);
+inline Tensor sum(const Tensor& data,
+                  const Array<Integer>& axis,
+                  bool keepdims = false,
+                  bool atleast1d = false) {
+  return CommReduce(data, axis, tvm::sum, keepdims, atleast1d);
 }
 
 inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
@@ -356,11 +359,15 @@ inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
  * \param keepdims If this is set to true, the axes which are reduced are
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
  *
  * \return A Tensor whose op member is the min operation
  */
-inline Tensor min(const Tensor& data, Array<Expr> axis, bool keepdims = false) {
-  return CommReduce(data, axis, MinOp, keepdims);
+inline Tensor min(const Tensor& data,
+                  const Array<Integer>& axis,
+                  bool keepdims = false,
+                  bool atleast1d = false) {
+  return CommReduce(data, axis, MinOp, keepdims, atleast1d);
 }
 
 /*!
@@ -373,11 +380,15 @@ inline Tensor min(const Tensor& data, Array<Expr> axis, bool keepdims = false) {
  * \param keepdims If this is set to true, the axes which are reduced are
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
  *
  * \return A Tensor whose op member is the max operation
  */
-inline Tensor max(const Tensor& data, Array<Expr> axis, bool keepdims = false) {  // NOLINT(*)
-  return CommReduce(data, axis, MaxOp, keepdims);
+inline Tensor max(const Tensor& data,
+                  const Array<Integer>& axis,
+                  bool keepdims = false,
+                  bool atleast1d = false) {
+  return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
 }
 
 /*!
@@ -390,10 +401,14 @@ inline Tensor max(const Tensor& data, Array<Expr> axis, bool keepdims = false) {
  * \param keepdims If this is set to true, the axes which are reduced are
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
  *
  * \return A Tensor whose op member is the argmin operation
  */
-inline Tensor argmin(const Tensor& data, Array<Expr> axis, bool keepdims = false) {
+inline Tensor argmin(const Tensor& data,
+                     const Array<Integer>& axis,
+                     bool keepdims = false,
+                     bool atleast1d = false) {
   auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
     Array<Expr> result;
     result.push_back(tvm::select(lhs[1] <= rhs[1], lhs[0], rhs[0]));  // idx
@@ -407,7 +422,7 @@ inline Tensor argmin(const Tensor& data, Array<Expr> axis, bool keepdims = false
     return result;
   };
   auto func = MakeCommReducer(fcombine, fidentity, "argmin");
-  return CommReduceIdx(data, axis, func, keepdims);
+  return CommReduceIdx(data, axis, func, keepdims, atleast1d);
 }
 
 /*!
@@ -420,10 +435,14 @@ inline Tensor argmin(const Tensor& data, Array<Expr> axis, bool keepdims = false
  * \param keepdims If this is set to true, the axes which are reduced are
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
  *
  * \return A Tensor whose op member is the argmax operation
  */
-inline Tensor argmax(const Tensor& data, Array<Expr> axis, bool keepdims = false) {
+inline Tensor argmax(const Tensor& data,
+                     const Array<Integer>& axis,
+                     bool keepdims = false,
+                     bool atleast1d = false) {
   auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
     Array<Expr> result;
     result.push_back(tvm::select(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // idx
@@ -437,7 +456,7 @@ inline Tensor argmax(const Tensor& data, Array<Expr> axis, bool keepdims = false
     return result;
   };
   auto func = MakeCommReducer(fcombine, fidentity, "argmax");
-  return CommReduceIdx(data, axis, func, keepdims);
+  return CommReduceIdx(data, axis, func, keepdims, atleast1d);
 }
 
 /*!
@@ -449,11 +468,15 @@ inline Tensor argmax(const Tensor& data, Array<Expr> axis, bool keepdims = false
  * \param keepdims If this is set to true, the axes which are reduced are
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
+ * \param atleast1d Whether the output need to be atleast1d.
  *
  * \return A Tensor whose op member is the prod operation
  */
-inline Tensor prod(const Tensor& data, Array<Expr> axis, bool keepdims = false) {  // NOLINT(*)
-  return CommReduce(data, axis, ProdOp, keepdims);
+inline Tensor prod(const Tensor& data,
+                   const Array<Integer>& axis,
+                   bool keepdims = false,
+                   bool atleast1d = false) {
+  return CommReduce(data, axis, ProdOp, keepdims, atleast1d);
 }
 
 }  // namespace topi
......
@@ -196,30 +196,34 @@ inline Tensor reshape(const Tensor& x,
  * \param x The input tensor
  * \param axis Indices of the dimensions to remove. If this is empty,
  * all entries with a constant size of 1 will be removed.
+ * \param atleast1d Whether the output need to be atleast1d.
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
  * \return A Tensor whose op member is the squeeze operation
  */
 inline Tensor squeeze(const Tensor& x,
-                      Array<Expr> axis,
+                      Array<Integer> axis,
+                      bool atleast1d = false,
                       std::string name = "tensor",
                       std::string tag = kInjective) {
-  auto axis_val = GetConstIntValues(axis, "axis");
   auto ndim = x->shape.size();
-  if (axis_val.size() == 0) {
+  std::vector<int> axis_val;
+  if (!axis.defined() || axis.size() == 0) {
     for (size_t i = 0; i < ndim; ++i) {
       if (IsConstInt(x->shape[i]) && GetConstInt(x->shape[i]) == 1) {
         axis_val.push_back(static_cast<int>(i));
       }
     }
   } else {
-    for (size_t i = 0; i < axis_val.size(); ++i) {
-      if (axis_val[i] < 0) {
-        axis_val[i] += static_cast<int>(x->shape.size());
+    for (size_t i = 0; i < axis.size(); ++i) {
+      int64_t val = axis[i]->value;
+      if (val < 0) {
+        val += static_cast<int>(x->shape.size());
       }
-      CHECK_EQ(GetConstInt(x->shape[axis_val[i]]), 1) <<
-        "Dimension " << axis[i] << " must have size 1";
+      CHECK_EQ(GetConstInt(x->shape[val]), 1) <<
+        "Dimension " << val << " must have size 1";
+      axis_val.push_back(val);
     }
   }
@@ -231,7 +235,7 @@ inline Tensor squeeze(const Tensor& x,
       out_shape.push_back(x->shape[i]);
     }
   }
-  if (out_shape.size() == 0) {
+  if (out_shape.size() == 0 && atleast1d) {
     out_shape.push_back(1);
   }
......
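A small illustration (shapes and names invented) of the resulting `topi::squeeze` behavior: by default a tensor whose every dimension is squeezed away now comes back rank-0, and callers that still need the old shape pass `atleast1d = true`, as the NNVM wrapper above does:

```cpp
#include <topi/transform.h>
#include <tvm/operation.h>

void squeeze_example() {
  using namespace tvm;
  Tensor x = placeholder({1, 1}, Float(32), "x");

  Tensor scalar  = topi::squeeze(x, Array<Integer>());        // shape ()   new default
  Tensor legacy  = topi::squeeze(x, Array<Integer>(), true);  // shape (1,) legacy behavior
  Tensor partial = topi::squeeze(x, {Integer(1)});            // shape (1,) only axis 1 removed
}
```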
@@ -63,10 +63,12 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
             sch[temp_val_input].compute_at(sch[real_output], outer_in)
     else:
         if is_idx_reduce:
+            spatial_axis = sch[real_output].fuse(*(sch[real_output].op.axis))
+            sch[real_output].bind(spatial_axis, tvm.thread_axis("blockIdx.x"))
             sch[temp_idx_input].compute_at(sch[real_output],
-                                           sch[real_output].op.axis[0])
+                                           spatial_axis)
             sch[temp_val_input].compute_at(sch[real_output],
-                                           sch[real_output].op.axis[0])
+                                           spatial_axis)
         sch[real_output].set_store_predicate(thread_x.equal(0))
     return sch
......
@@ -59,9 +59,9 @@ using namespace tvm;
 using namespace tvm::runtime;
 
 /*! \brief Canonicalize an argument that may be Array<Expr> or int to Array<Expr> */
-Array<Expr> ArrayOrInt(TVMArgValue arg) {
+Array<Integer> ArrayOrInt(TVMArgValue arg) {
   if (arg.type_code() == kDLInt || arg.type_code() == kDLUInt) {
-    Array<Expr> result;
+    Array<Integer> result;
     result.push_back(arg.operator int());
     return result;
   } else {
......
@@ -97,6 +97,10 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32"):
 def test_reduce_map():
+    verify_reduce_map_ele(in_shape=(32,),
+                          axis=0,
+                          keepdims=False,
+                          type="argmax")
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
                           axis=(1, 2, 3),
                           keepdims=True,
......
@@ -91,9 +91,6 @@ def verify_squeeze(src_shape, axis):
     data_npy = np.random.normal(size=src_shape).astype(A.dtype)
     out_npy = np.squeeze(data_npy, axis=axis)
     data_nd = tvm.nd.array(data_npy, ctx)
-    if out_npy.shape == ():
-        out_nd_shape = (1,)
-    else:
-        out_nd_shape = out_npy.shape
+    out_nd_shape = out_npy.shape
     out_nd = tvm.nd.empty(out_nd_shape, ctx=ctx, dtype=B.dtype)
     foo(data_nd, out_nd)
......
@@ -100,9 +100,6 @@ def verify_squeeze(src_shape, axis):
     data_npy = np.random.normal(size=src_shape).astype(A.dtype)
     out_npy = np.squeeze(data_npy, axis=axis)
     data_nd = tvm.nd.array(data_npy, ctx)
-    if out_npy.shape == ():
-        out_nd_shape = (1,)
-    else:
-        out_nd_shape = out_npy.shape
+    out_nd_shape = out_npy.shape
    out_nd = tvm.nd.empty(out_nd_shape, ctx=ctx, dtype=B.dtype)
    foo(data_nd, out_nd)
......