/*!
 *  Copyright (c) 2018 by Contributors
 * \file reorg.cc
 * \brief Shape inference and operator registration for yolo2_reorg.
 */
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include <vector>
#include "../../op_common.h"
#include "../../elemwise_op_common.h"
#include "reorg.h"

namespace nnvm {
namespace top {

// reorg
DMLC_REGISTER_PARAMETER(ReorgParam);

/*!
 * \brief Infer the output shape of the reorg operator from its single input.
 *
 * For a 4D input (d0, d1, d2, d3) and a stride s, the output shape is
 * (d0, d1 * s * s, d2 / s, d3 / s).
 *
 * \param attrs Node attributes; attrs.parsed must hold a ReorgParam.
 * \param in_shape Input shapes; only element 0 is read.
 * \param out_shape Output shapes; element 0 is assigned.
 * \return false if the input shape has zero dimensions (presumably not
 *         known yet), true once the output shape has been assigned.
 *
 * A CHECK failure is raised if the input is not 4D or the stride is 0.
 */
inline bool ReorgInferShape(const nnvm::NodeAttrs &attrs,
                            std::vector<TShape> *in_shape,
                            std::vector<TShape> *out_shape) {
  const ReorgParam &param = nnvm::get<ReorgParam>(attrs.parsed);
  TShape dshape = in_shape->at(0);
  // Nothing can be inferred from an input shape with no dimensions.
  if (dshape.ndim() == 0)
    return false;
  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);
  CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D";
  // The stride is used as a divisor below, so it must be non-zero.
  CHECK_GT(param.stride, 0U) << "Stride value cannot be 0";
  TShape oshape({dshape[0], 0, 0, 0});
  // Axis 1 grows by stride^2 while axes 2 and 3 are each divided by stride.
  // NOTE(review): no divisibility check is performed on axes 2 and 3 here.
  oshape[1] = dshape[1] * param.stride * param.stride;
  oshape[2] = dshape[2] / param.stride;
  oshape[3] = dshape[3] / param.stride;
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
}

// The output shape in the description below mirrors ReorgInferShape.
NNVM_REGISTER_OP(yolo2_reorg)
.describe(R"(Perform reorg operation on input array based on the stride value.
- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width).
- **out**: Output is 4D array of shape (batch_size, channels*stride*stride, in_height/stride, in_width/stride).
)" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(5)
.add_argument("data", "Tensor", "Data input to reorganize")
.set_attr_parser(ParamParser<ReorgParam>)
.add_arguments(ReorgParam::__FIELDS__())
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReorgParam>)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferShape>("FInferShape", ReorgInferShape);

}  // namespace top
}  // namespace nnvm