/*!
 *  Copyright (c) 2017 by Contributors
 * \file reduce.cc
 * \brief reduce operator.
 */
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include <numeric>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/detail/constant_utils.h"
#include "topi/elemwise.h"
#include "topi/reduction.h"
#include "topi/transform.h"

namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;

// reduce
DMLC_REGISTER_PARAMETER(ReduceParam);

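// Resolves the user-facing `axis` parameter into an explicit, sorted list of
// axes to reduce over. A sketch of the behavior (illustrative values, not
// part of the original source):
//   GetReduceAxes(3, {}, false)   -> {0, 1, 2}  // empty axis: reduce all dims
//   GetReduceAxes(3, {-1}, false) -> {2}        // negative axes wrap around
//   GetReduceAxes(3, {1}, true)   -> {0, 2}     // exclude: the complement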
inline TShape GetReduceAxes(const uint32_t indim,
                            const TShape& axis,
                            bool exclude) {
  if (axis.ndim() == 0) {
    TShape r_axes(indim);
    std::iota(r_axes.begin(), r_axes.end(), 0);
    return r_axes;
  }

  CHECK_LT(axis[axis.ndim() - 1], indim)
    << "Reduction axis " << axis[axis.ndim() - 1]
    << " exceeds input dimensions " << indim;

  TShape in_axis = axis;
  for (auto& i : in_axis) {
    i = i < 0 ? i + indim : i;
    CHECK_GE(i, 0) << "axis out of bounds in reduce operator";
    CHECK_LT(i, indim) << "axis out of bounds in reduce operator";
  }
  std::sort(in_axis.begin(), in_axis.end());
  if (!exclude) return in_axis;
  TShape r_axis(indim - in_axis.ndim());
  for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
    if (j < in_axis.ndim() && i == in_axis[j]) {
      ++j;
      continue;
    }
    r_axis[k++] = i;
  }
  return r_axis;
}

inline TShape ReduceShapeImpl(const TShape& ishape,
                              const TShape& axis,
                              bool keepdims,
                              bool exclude) {
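  // Illustrative shape cases (assuming a (2, 3, 4) input, not from the
  // original source):
  //   axis={1}, keepdims=false -> (2, 4)
  //   axis={1}, keepdims=true  -> (2, 1, 4)
  //   axis={1}, exclude=true   -> reduces axes {0, 2}, leaving (3,)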
  uint32_t indim = ishape.ndim();
  TShape r_axes = GetReduceAxes(indim, axis, exclude);
  if (!r_axes.ndim()) return ishape;
  if (r_axes.ndim() == indim)
    return TShape(keepdims ? indim : 1);

  CHECK_LT(r_axes.ndim(), indim);
  if (keepdims) {
    TShape oshape(ishape);
    for (unsigned i = 0, j = 0; i < indim; ++i) {
      if (j >= r_axes.ndim() || i != r_axes[j]) continue;
      oshape[i] = 1;
      ++j;
    }
    return oshape;
  }

  TShape oshape(indim - r_axes.ndim());
  for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
    if (j < r_axes.ndim() && i == r_axes[j]) {
      ++j;
      continue;
    }
    oshape[k++] = ishape[i];
  }
  return oshape;
}

inline bool ReduceShape(const nnvm::NodeAttrs& attrs,
                        std::vector<TShape>* in_attrs,
                        std::vector<TShape>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  if ((*in_attrs)[0].ndim() == 0) return false;
  const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
  NNVM_ASSIGN_OUTPUT_SHAPE(
      attrs, *out_attrs, 0,
      ReduceShapeImpl((*in_attrs)[0], param.axis,
                      param.keepdims, param.exclude));
  return true;
}

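// Shape inference for collapse_sum: the output simply takes the shape of the
// second input (the reference tensor), e.g. collapsing a (2, 3) lhs against
// a (3,) reference yields a (3,) output.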
inline bool CollapseShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape>* in_attrs,
                          std::vector<TShape>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  if ((*in_attrs)[0].ndim() == 1) return false;
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, (*in_attrs)[1]);
  return true;
}

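// Parses ReduceParam-style attributes from the attribute dictionary and sorts
// `axis` ascending, so e.g. axis="(2, 0)" is stored as {0, 2} and downstream
// code can assume ordered axes.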
template<typename PType>
inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
  PType param;
  param.Init(attrs->dict);
  std::sort(param.axis.begin(), param.axis.end());
  attrs->parsed = std::move(param);
}

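// Registration helpers: NNVM_REGISTER_BASE_REDUCE_OP wires up the shared
// ReduceParam parsing, while NNVM_REGISTER_REDUCE_OP additionally declares
// the single "data" input plus reduce shape/type inference. Typical use:
//   NNVM_REGISTER_REDUCE_OP(sum)
//   .set_attr<FTVMCompute>("FTVMCompute", ...);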
#define NNVM_REGISTER_BASE_REDUCE_OP(op)                                 \
  NNVM_REGISTER_OP(op)                                                   \
  .add_arguments(ReduceParam::__FIELDS__())                              \
  .set_attr_parser(AxesParamParser<ReduceParam>)                         \
  .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \
  .set_num_outputs(1)

#define NNVM_REGISTER_REDUCE_OP(op)                                     \
  NNVM_REGISTER_BASE_REDUCE_OP(op)                                      \
  .add_argument("data", "Tensor", "The input")                          \
  .set_attr<FInferShape>("FInferShape", ReduceShape)                    \
  .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)               \
  .set_attr<FCorrectLayout>("FCorrectLayout",                           \
    ElemwiseFixedLayoutUnknownOut<1, 1>)                                \
  .set_num_inputs(1)

NNVM_REGISTER_REDUCE_OP(sum)
.describe(R"code(Computes the sum of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  sum(data, axis=1)
  [[  4.   8.]
   [ 10.   9.]
   [ 21.   6.]]

  sum(data, axis=[1,2])
  [ 12.  19.  27.]

)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                  param.axis, param.exclude);
    if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
    auto axis = ShapeToIntArray(r_axes);
    return Array<Tensor>{
      topi::sum(inputs[0], axis, param.keepdims, true) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
    const ReduceParam& param = nnvm::get<ReduceParam>(n->attrs.parsed);
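    // The gradient of sum broadcasts ograd back to the input shape via
    // expand_like. A full reduction (no axis given) is rewritten below as
    // exclude=true with an empty axis list, i.e. "exclude nothing".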
    bool exclude = param.exclude;
    TShape p_axis = param.axis;
    if (!param.exclude && param.axis.ndim() == 0) {
      exclude = true;
      p_axis = TShape();
    }
    std::ostringstream axis; axis << p_axis;
    return std::vector<NodeEntry>{
      MakeNode("expand_like", n->attrs.name + "_grad",
               {ograds[0], n->inputs[0]},
               {{"axis", axis.str()},
                {"exclude", std::to_string(exclude)}})
    };
});

NNVM_REGISTER_REDUCE_OP(max)
.describe(R"code(Computes the max of array elements over given axes.

)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                  param.axis, param.exclude);
    auto axis = ShapeToIntArray(r_axes);
    return Array<Tensor>{
      topi::max(inputs[0], axis, param.keepdims, true) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
    const ReduceParam& param = nnvm::get<ReduceParam>(n->attrs.parsed);
    std::ostringstream axis; axis << param.axis;
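    // sub0 broadcasts the output gradient back to the input shape; sub1
    // applies the internal _max_mask op, intended to zero out positions that
    // did not attain the maximum. Their elementwise product is the gradient.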
    NodeEntry sub0 = MakeNode("expand_like", n->attrs.name + "_grad_sub0",
                              {ograds[0], n->inputs[0]},
                              {{"axis", axis.str()},
                               {"exclude", std::to_string(param.exclude)}});
    NodeEntry sub1 = MakeNode("_max_mask", n->attrs.name + "_grad_sub1",
                              {ograds[0]},
                              {{"axis", axis.str()},
                               {"exclude", std::to_string(param.exclude)}});
    return std::vector<NodeEntry>{
      MakeNode("elemwise_mul", n->attrs.name + "_grad", {sub0, sub1})
    };
});

NNVM_REGISTER_REDUCE_OP(min)
.describe(R"code(Computes the min of array elements over given axes.

)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                  param.axis, param.exclude);
    auto axis = ShapeToIntArray(r_axes);
    return Array<Tensor>{
      topi::min(inputs[0], axis, param.keepdims, true) };
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds){
    const ReduceParam& param = nnvm::get<ReduceParam>(n->attrs.parsed);
    std::ostringstream axis; axis << param.axis;
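    // Mirrors the max gradient above, with _min_mask selecting the minima.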
    NodeEntry sub0 = MakeNode("expand_like", n->attrs.name + "_grad_sub0",
                              {ograds[0], n->inputs[0]},
                              {{"axis", axis.str()},
                               {"exclude", std::to_string(param.exclude)}});
    NodeEntry sub1 = MakeNode("_min_mask", n->attrs.name + "_grad_sub1",
                              {ograds[0]},
                              {{"axis", axis.str()},
                               {"exclude", std::to_string(param.exclude)}});
    return std::vector<NodeEntry>{
      MakeNode("elemwise_mul", n->attrs.name + "_grad", {sub0, sub1})
    };
});

NNVM_REGISTER_BASE_REDUCE_OP(collapse_sum)
.add_argument("data", "Tensor", "The input")
.add_argument("as", "Tensor", "The reference")
.set_attr<FInferShape>("FInferShape", CollapseShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<2, 1>)
.set_num_inputs(2)
.describe(R"code(Reduces lhs to the shape of rhs via sum)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    return Array<Tensor>{ topi::collapse_sum(inputs[0], inputs[1]->shape) };
});

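// argmax/argmin produce indices, so the output dtype is taken from the
// `dtype` parameter rather than inferred from the input element type.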
inline bool InferFixedType(const NodeAttrs& attrs,
                           std::vector<int>* in_attrs,
                           std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
  NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, param.dtype);
  return true;
}

NNVM_REGISTER_BASE_REDUCE_OP(argmax)
.describe(R"code(Creates an operation that finds the indices of the maximum
values over a given axis.

)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "The input")
.set_attr<FInferShape>("FInferShape", ReduceShape)
.set_attr<FInferType>("FInferType", InferFixedType)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                  param.axis, param.exclude);
    auto axis = ShapeToIntArray(r_axes);
    Tensor out = topi::argmax(inputs[0], axis, param.keepdims, true);
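    // topi::argmax yields integer indices; if a float32 output dtype was
    // requested, cast the indices to the inferred output type.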
    if (param.dtype == kFloat32) out = topi::cast(out, out_info[0]->dtype);
    return Array<Tensor>{out};
});

NNVM_REGISTER_BASE_REDUCE_OP(argmin)
.describe(R"code(Creates an operation that finds the indices of the minimum
values over a given axis.

)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "The input")
.set_attr<FInferShape>("FInferShape", ReduceShape)
.set_attr<FInferType>("FInferType", InferFixedType)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                  param.axis, param.exclude);
    auto axis = ShapeToIntArray(r_axes);
    Tensor out = topi::argmin(inputs[0], axis, param.keepdims, true);
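    // As with argmax, cast the integer indices when a float32 output dtype
    // was requested.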
    if (param.dtype == kFloat32) out = topi::cast(out, out_info[0]->dtype);
    return Array<Tensor>{out};
});

NNVM_REGISTER_REDUCE_OP(mean)
  .describe(R"code(Computes the mean of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  mean(data)
  [3.22]

  mean(data, axis=[1,2])
  [ 2.  3.16666667  4.5]

)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                  param.axis, param.exclude);
    if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
    auto axis = ShapeToIntArray(r_axes);

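    // mean = sum / count, where count is the product of the reduced extents;
    // e.g. a (3, 3, 2) input reduced over axes {1, 2} gives count = 3 * 2 = 6.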
    Expr count = make_const(inputs[0]->dtype, 1);
    for (auto& i : r_axes) {
      count *= cast(inputs[0]->dtype, inputs[0]->shape[i]);
    }

    return Array<Tensor>{
      topi::divide(topi::sum(inputs[0], axis, param.keepdims, true), count) };
});

NNVM_REGISTER_REDUCE_OP(prod)
  .describe(R"code(Computes the products of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  prod(data)
  [35562240]

  prod(data, axis=[1,2])
  [ 36  480  2058]

)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
                                  param.axis, param.exclude);
    if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
    auto axis = ShapeToIntArray(r_axes);
    return Array<Tensor>{
      topi::prod(inputs[0], axis, param.keepdims, true) };
});

}  // namespace top
}  // namespace nnvm