/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file pooling.cc
 * \brief Property def of pooling operators.
 */
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/nn.h>
#include "nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/nn/pooling.h"

namespace nnvm {
namespace top {
using namespace tvm;
using namespace nnvm::compiler;

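// Registering the parameter struct exposes MaxPool2DParam's fields to dmlc's
// parameter machinery, which ParamParser<MaxPool2DParam> and
// MaxPool2DParam::__FIELDS__() rely on in the op registration below.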
DMLC_REGISTER_PARAMETER(MaxPool2DParam);

template <typename T>
inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs,
                             std::vector<TShape>* in_shape,
                             std::vector<TShape>* out_shape) {
  const T& param = nnvm::get<T>(attrs.parsed);
  CHECK_EQ(in_shape->size(), 1U);
  CHECK_EQ(out_shape->size(), 1U);

  TShape dshape = (*in_shape)[0];
  if (dshape.ndim() == 0) return false;

  CHECK_GE(dshape.ndim(), 2U)
    << "Pool2D only supports input >= 2-D: input must have height and width";

  Layout layout(param.layout);
  CHECK(layout.contains('H') && layout.contains('W') &&
        !layout.contains('h') && !layout.contains('w'))
    << "Invalid layout " << layout
    << ". Pool2D layout must have H and W, which cannot be split";

  const auto hidx = layout.indexof('H');
  const auto widx = layout.indexof('W');

  dim_t pad_h, pad_w;
  if (param.padding.ndim() == 1) {
    pad_h = param.padding[0] * 2;
    pad_w = param.padding[0] * 2;
  } else if (param.padding.ndim() == 2) {
    // (top, left)
    pad_h = param.padding[0] * 2;
    pad_w = param.padding[1] * 2;
  } else if (param.padding.ndim() == 4) {
    // (top, left, bottom, right)
    pad_h = param.padding[0] + param.padding[2];
    pad_w = param.padding[1] + param.padding[3];
  } else {
    return false;
  }
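  // Illustrative expansions of the cases above (assumed example values):
  //   padding = (1)          -> pad_h = 2, pad_w = 2
  //   padding = (1, 2)       -> pad_h = 2, pad_w = 4
  //   padding = (1, 2, 3, 4) -> pad_h = 1 + 3 = 4, pad_w = 2 + 4 = 6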

  TShape oshape = dshape;
  CHECK(param.pool_size[0] <= dshape[hidx] + pad_h)
      << "pool size (" << param.pool_size[0] << ") exceeds input (" << dshape[hidx]
      << " padded to " << (dshape[hidx] + pad_h) << ")";
  CHECK(param.pool_size[1] <= dshape[widx] + pad_w)
      << "pool size (" << param.pool_size[1] << ") exceeds input (" << dshape[widx]
      << " padded to " << (dshape[widx] + pad_w) << ")";

  if (!param.ceil_mode) {
    oshape[hidx] = ((dshape[hidx] + pad_h - param.pool_size[0]) /
                    param.strides[0]) + 1;
    oshape[widx] = ((dshape[widx] + pad_w - param.pool_size[1]) /
                    param.strides[1]) + 1;
  } else {
    oshape[hidx] = ((dshape[hidx] + pad_h - param.pool_size[0] +
                    param.strides[0] - 1) / param.strides[0]) + 1;
    oshape[widx] = ((dshape[widx] + pad_w - param.pool_size[1] +
                    param.strides[1] - 1) / param.strides[1]) + 1;
  }
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
}
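
// Worked example for Pool2DInferShape (illustrative values): an NCHW input of
// shape (1, 16, 32, 32) with pool_size = (2, 2), strides = (2, 2),
// padding = (0, 0) and ceil_mode = false gives
//   oshape[H] = (32 + 0 - 2) / 2 + 1 = 16 (likewise for W),
// i.e. an inferred output shape of (1, 16, 16, 16).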

template <typename T>
inline bool Pool2DCorrectLayout(const NodeAttrs& attrs,
                                std::vector<Layout> *ilayouts,
                                const std::vector<Layout> *last_ilayouts,
                                std::vector<Layout> *olayouts) {
  const T &param = nnvm::get<T>(attrs.parsed);
  CHECK_EQ(ilayouts->size(), 1);
  CHECK_EQ(last_ilayouts->size(), 1);
  CHECK_EQ(olayouts->size(), 1);

  Layout input = (*ilayouts)[0];
  const Layout layout(param.layout);

  if (input.defined()) {
    CHECK(input.convertible(layout)) << "Invalid input layout " << input;
    if (input.indexof('W') != layout.indexof('W') ||
        input.indexof('H') != layout.indexof('H') ||
        input.contains('w') || input.contains('h')) {
      // Pool2D can keep the input layout only when the width and height
      // positions are unchanged and not split; otherwise fall back to the
      // layout specified by the parameter.
      input = layout;
    }
  } else {
    input = layout;
  }

  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input);
  NNVM_ASSIGN_LAYOUT(*olayouts, 0, input);

  return true;
}
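
// For example (illustrative): with param.layout = "NCHW", an input layout of
// "NCHW16c" can be kept as-is (H and W keep their positions and are not
// split), while an "NHWC" input is reset to the parameter layout "NCHW".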

NNVM_REGISTER_OP(max_pool2d)
.describe(R"code(Max pooling operation for one dimensional data.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, channels, height, width) if `layout` is `NCHW`.
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
           (batch_size, channels, out_height, out_width)  if `layout` is `NCHW`.
           out_height and out_width are calculated as::

               out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
               out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1

           where padding will be an expanded array based on the number of values passed, as::
               one int : the same padding is used on all sides.
               two int : bottom and right use the same values as top and left.
               four int: padding in the order of (top, left, bottom, right).

           When `ceil_mode` is `True`, ceil will be used instead of floor in this
           equation.
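
           For example (illustrative values), with height = 32, pool_size = (3, 3),
           strides = (2, 2) and padding expanded to (1, 1, 1, 1)::

               out_height = floor((32+1+1-3)/2)+1 = 16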

)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_arguments(MaxPool2DParam::__FIELDS__())
.set_attr_parser(ParamParser<MaxPool2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<MaxPool2DParam>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape<MaxPool2DParam>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", Pool2DCorrectLayout<MaxPool2DParam>)
.set_attr<FTVMCompute>("FTVMCompute", [](const NodeAttrs& attrs,
                                         const Array<Tensor>& inputs,
                                         const Array<Tensor>& out_info) {
  const MaxPool2DParam& param = nnvm::get<MaxPool2DParam>(attrs.parsed);
  auto pool_size = ShapeToArray(param.pool_size);
  auto strides = ShapeToArray(param.strides);
  auto padding = ShapeToArray(param.padding);
  auto ceil_mode = param.ceil_mode;

  Layout layout(param.layout);
  CHECK(layout.convertible(Layout("NCHW")))
    << "max_pool2d currently only supports layouts that are convertible from NCHW";
  CHECK_EQ(layout.indexof('h'), -1) << "max_pool2d does not support input split on height";
  CHECK_EQ(layout.indexof('w'), -1) << "max_pool2d does not support input split on width";

  CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
    << "Pool2D only supports 4-D input (e.g., NCHW)"
    << " or 5-D input (last dimension is a split of channel)";

  if (param.padding.ndim() == 1) {
    padding.push_back(padding[0]);
    padding.push_back(padding[0]);
    padding.push_back(padding[0]);
  } else if (param.padding.ndim() == 2) {
    padding.push_back(padding[0]);
    padding.push_back(padding[1]);
  }
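  // topi::nn::pool expects four padding values in (top, left, bottom, right)
  // order; the expansion above yields, e.g., (p) -> (p, p, p, p) and
  // (t, l) -> (t, l, t, l).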

  return Array<Tensor>{
    topi::nn::pool(inputs[0], pool_size, strides, padding,
                   topi::nn::kMaxPool, ceil_mode, layout.name())};
})
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
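    // The gradient node takes three inputs: the output gradient, the forward
    // input, and the forward output (NodeEntry{n, 0, 0}), matching the
    // 3-input signature of _max_pool2d_grad registered below.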
    return MakeGradNode("_max_pool2d_grad", n,
                        {ograds[0], n->inputs[0], NodeEntry{n, 0, 0}},
                        n->attrs.dict);
})
.set_support_level(2);

NNVM_REGISTER_OP(_max_pool2d_grad)
  .describe(R"code(Max pooling 2D grad.

)code" NNVM_ADD_FILELINE)
.add_argument("ograd", "4D Tensor", "Output grad.")
.add_argument("input", "4D Tensor", "Input data of max_pool2d grad.")
.add_argument("output", "4D Tensor", "Output data of max_pool2d grad.")
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<MaxPool2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<MaxPool2DParam>)
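// AssignOutputAttr<TShape, 1, 0> takes the shape of input 1 (the forward
// input) as the shape of output 0: the gradient with respect to the input
// has the same shape as that input.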
.set_attr<FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>)
.set_attr<FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<TIsBackward>("TIsBackward", true);

DMLC_REGISTER_PARAMETER(AvgPool2DParam);

NNVM_REGISTER_OP(avg_pool2d)
.describe(R"code(Average pooling operation for one dimensional data.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, channels, height, width) if `layout` is `NCHW`.
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
           (batch_size, channels, out_height, out_width)  if `layout` is `NCHW`.
           out_height and out_width are calculated as::

               out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
               out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1

           where padding will be an expanded array based on the number of values passed, as::
               one int : the same padding is used on all sides.
               two int : bottom and right use the same values as top and left.
               four int: padding in the order of (top, left, bottom, right).

           When `ceil_mode` is `True`, ceil will be used instead of floor in this
           equation.
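
           For example (illustrative values), with height = 32, pool_size = (3, 3),
           strides = (2, 2) and zero padding::

               floor mode: out_height = floor((32-3)/2)+1 = 15
               ceil mode : out_height = ceil((32-3)/2)+1 = 16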

)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_arguments(AvgPool2DParam::__FIELDS__())
.set_attr_parser(ParamParser<AvgPool2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<AvgPool2DParam>)
.set_attr<FInferShape>("FInferShape", Pool2DInferShape<AvgPool2DParam>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", Pool2DCorrectLayout<AvgPool2DParam>)
.set_attr<FTVMCompute>("FTVMCompute", [](const NodeAttrs& attrs,
                                         const Array<Tensor>& inputs,
                                         const Array<Tensor>& out_info) {
  const AvgPool2DParam& param = nnvm::get<AvgPool2DParam>(attrs.parsed);
  auto pool_size = ShapeToArray(param.pool_size);
  auto strides = ShapeToArray(param.strides);
  auto padding = ShapeToArray(param.padding);
  auto ceil_mode = param.ceil_mode;
  auto count_include_pad = param.count_include_pad;

  Layout layout(param.layout);
  CHECK(layout.convertible(Layout("NCHW")))
    << "avg_pool2d currently only supports layouts that are convertible from NCHW";
  CHECK_EQ(layout.indexof('h'), -1) << "avg_pool2d does not support input split on height";
  CHECK_EQ(layout.indexof('w'), -1) << "avg_pool2d does not support input split on width";

  CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
    << "Pool2D only supports 4-D input (e.g., NCHW)"
    << " or 5-D input (last dimension is a split of channel)";

  if (param.padding.ndim() == 1) {
    padding.push_back(padding[0]);
    padding.push_back(padding[0]);
    padding.push_back(padding[0]);
  } else if (param.padding.ndim() == 2) {
    padding.push_back(padding[0]);
    padding.push_back(padding[1]);
  }

  return Array<Tensor>{
    topi::nn::pool(inputs[0], pool_size, strides, padding,
                   topi::nn::kAvgPool, ceil_mode, layout.name(), count_include_pad)};
})
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);


DMLC_REGISTER_PARAMETER(GlobalPool2DParam);

inline bool GlobalPool2DInferShape(const nnvm::NodeAttrs& attrs,
                                   std::vector<TShape>* in_shape,
                                   std::vector<TShape>* out_shape) {
  static const Layout kNCHW("NCHW");
  const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), 1U);
  CHECK_EQ(out_shape->size(), 1U);

  TShape dshape = (*in_shape)[0];
  if (dshape.ndim() == 0) return false;

  CHECK_GE(dshape.ndim(), 2U)
    << "Pool2D only supports input >= 2-D: input must have height and width";

  Layout layout(param.layout);
  CHECK(layout.contains('H') && layout.contains('W') &&
        !layout.contains('h') && !layout.contains('w'))
    << "Invalid layout " << layout
    << ". Pool2D layout must have H and W, which cannot be split";

  const auto hidx = layout.indexof('H');
  const auto widx = layout.indexof('W');

  TShape oshape = dshape;
  oshape[hidx] = oshape[widx] = 1;
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
}
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
}
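
// Global pooling collapses the spatial dimensions to 1x1: e.g. an NCHW input
// of (1, 16, 32, 32) infers an output shape of (1, 16, 1, 1).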

inline bool GlobalPool2DCorrectLayout(const NodeAttrs& attrs,
                                      std::vector<Layout> *ilayouts,
                                      const std::vector<Layout> *last_ilayouts,
                                      std::vector<Layout> *olayouts) {
  const GlobalPool2DParam &param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
  CHECK_EQ(ilayouts->size(), 1);
  CHECK_EQ(last_ilayouts->size(), 1);
  CHECK_EQ(olayouts->size(), 1);

  Layout input = (*ilayouts)[0];
  const Layout layout(param.layout);

  if (input.defined()) {
    CHECK(input.convertible(layout)) << "Invalid input layout " << input;
    if (input.indexof('W') != layout.indexof('W') ||
        input.indexof('H') != layout.indexof('H') ||
        input.contains('w') || input.contains('h')) {
      // Pool2D can keep the input layout only when the width and height
      // positions are unchanged and not split; otherwise fall back to the
      // layout specified by the parameter.
      input = layout;
    }
  } else {
    input = layout;
  }

  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input);
  NNVM_ASSIGN_LAYOUT(*olayouts, 0, input);

  return true;
}

NNVM_REGISTER_OP(global_max_pool2d)
.describe(R"code(Global max pooling operation for 2D data.

- **data**: This depends on the `layout` parameter. Input is a 4D array of shape
            (batch_size, channels, height, width) if `layout` is `NCHW`.
- **out**: This depends on the `layout` parameter. Output is a 4D array of shape
           (batch_size, channels, 1, 1) if `layout` is `NCHW`.

)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_arguments(GlobalPool2DParam::__FIELDS__())
.set_attr_parser(ParamParser<GlobalPool2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", GlobalPool2DCorrectLayout)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
  const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
  Layout layout(param.layout);
  CHECK(layout.convertible(Layout("NCHW")))
    << "global_max_pool2d currently only supports layouts that are convertible from NCHW";
  CHECK_EQ(layout.indexof('h'), -1)
    << "global_max_pool2d does not support input split on height";
  CHECK_EQ(layout.indexof('w'), -1)
    << "global_max_pool2d does not support input split on width";

  CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
    << "Pool2D only supports 4-D input (e.g., NCHW)"
    << " or 5-D input (last dimension is a split of channel)";

  return Array<Tensor>{
    topi::nn::global_pool(inputs[0], topi::nn::kMaxPool, layout.name()) };
})
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);


NNVM_REGISTER_OP(global_avg_pool2d)
.describe(R"code(Global average pooling operation for 2D data.

- **data**: This depends on the `layout` parameter. Input is a 4D array of shape
            (batch_size, channels, height, width) if `layout` is `NCHW`.
- **out**: This depends on the `layout` parameter. Output is a 4D array of shape
           (batch_size, channels, 1, 1) if `layout` is `NCHW`.

)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_arguments(GlobalPool2DParam::__FIELDS__())
.set_attr_parser(ParamParser<GlobalPool2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<GlobalPool2DParam>)
.set_attr<FInferShape>("FInferShape", GlobalPool2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", GlobalPool2DCorrectLayout)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
  const GlobalPool2DParam& param = nnvm::get<GlobalPool2DParam>(attrs.parsed);
  Layout layout(param.layout);
  CHECK(layout.convertible(Layout("NCHW")))
    << "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
  CHECK_EQ(layout.indexof('h'), -1)
    << "global_avg_pool2d does not support input split on height";
  CHECK_EQ(layout.indexof('w'), -1)
    << "global_avg_pool2d does not support input split on width";

  CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
    << "Pool2D only supports 4-D input (e.g., NCHW)"
    << " or 5-D input (last dimension is a split of channel)";

  return Array<Tensor>{
    topi::nn::global_pool(inputs[0], topi::nn::kAvgPool, layout.name()) };
})
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(2);

}  // namespace top
}  // namespace nnvm