/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file multibox_op.cc
 * \brief Property def of SSD multibox related operators.
 */
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/op.h>
#include <nnvm/top/nn.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include "../../op_common.h"
#include "../../elemwise_op_common.h"

namespace nnvm {
namespace top {
using compiler::FTVMCompute;
using tvm::Tensor;
using tvm::Array;

DMLC_REGISTER_PARAMETER(MultiBoxPriorParam);

/*!
 * \brief Shape inference for the multibox_prior operator.
 *
 * Validates the single input shape and the `steps` parameter, then assigns
 * the output shape
 *   [1, in_height * in_width * (num_sizes + num_ratios - 1), 4]
 * where in_height / in_width are dimensions 2 / 3 of the input shape.
 *
 * \param attrs Node attributes; attrs.parsed must hold a MultiBoxPriorParam.
 * \param in_attrs Input shapes; exactly one entry is expected.
 * \param out_attrs Output shapes; entry 0 receives the inferred shape.
 * \return Always true; every violated precondition aborts through CHECK_*.
 */
bool MultiBoxPriorShape(const NodeAttrs& attrs,
                        std::vector<TShape> *in_attrs,
                        std::vector<TShape> *out_attrs) {
  const MultiBoxPriorParam& param = nnvm::get<MultiBoxPriorParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 1U) << "Inputs: [data]" << in_attrs->size();
  const TShape& dshape = in_attrs->at(0);
  // NOTE(review): the message says "4D" but the check accepts ndim >= 4.
  CHECK_GE(dshape.ndim(), 4U) << "Input data should be 4D: "
      "[batch, channel, height, width]";
  int in_height = dshape[2];
  CHECK_GT(in_height, 0) << "Input height should > 0";
  int in_width = dshape[3];
  CHECK_GT(in_width, 0) << "Input width should > 0";
  // since input sizes are same in each batch, we could share MultiBoxPrior
  TShape oshape = TShape(3);
  int num_sizes = param.sizes.ndim();
  int num_ratios = param.ratios.ndim();
  oshape[0] = 1;
  oshape[1] = in_height * in_width * (num_sizes + num_ratios - 1);
  oshape[2] = 4;
  CHECK_EQ(param.steps.ndim(), 2) << "Step ndim must be 2: (step_y, step_x)";
  // A non-negative product means both steps carry the same sign.
  // Presumably a negative step marks "unspecified" -- confirm against the
  // MultiBoxPriorParam defaults.
  CHECK_GE(param.steps[0] * param.steps[1], 0) << "Must specify both "
      "step_y and step_x";
  // The assignment below addresses element 0 of *out_attrs, so that element
  // has to exist. The previous `clear()` left the vector empty, making the
  // access out of range; `resize(1)` guarantees the slot and keeps any
  // already-known output shape so it can be checked for consistency.
  out_attrs->resize(1);
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
  return true;
}

/*!
 * \brief Layout inference for multibox_prior: requests NCHW for the input.
 *
 * \param attrs Node attributes (unused).
 * \param ilayouts Input layouts; entry 0 is assigned "NCHW".
 * \param last_ilayouts Previously inferred input layouts (unused).
 * \param olayouts Output layouts; only its size is checked, it is not modified.
 * \return Always true.
 */
inline bool MultiBoxPriorLayout(const NodeAttrs& attrs,
                                std::vector<Layout> *ilayouts,
                                const std::vector<Layout> *last_ilayouts,
                                std::vector<Layout> *olayouts) {
  static const Layout kNCHW("NCHW");
  CHECK_EQ(ilayouts->size(), 1U);
  CHECK_EQ(olayouts->size(), 1U);
  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kNCHW);
  return true;
}

NNVM_REGISTER_OP(multibox_prior)
.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." )doc" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<MultiBoxPriorParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<MultiBoxPriorParam>)
.add_arguments(MultiBoxPriorParam::__FIELDS__())
.add_argument("data", "Tensor", "Input data")
.set_attr<FInferShape>("FInferShape", MultiBoxPriorShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", MultiBoxPriorLayout)
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
    // This op declares a single input (set_num_inputs(1)), so exactly one
    // gradient entry is returned: a zero tensor shaped like `data`.
    // The previous version appended `ograds[0]` as a second entry, which
    // does not correspond to any input of this operator.
    return std::vector<NodeEntry>{
      MakeNode("zeros_like", n->attrs.name + "_zero_grad",
               {n->inputs[0]})
    };
})
.set_support_level(4);

DMLC_REGISTER_PARAMETER(MultiBoxTransformLocParam);

/*!
 * \brief Shape inference for the multibox_transform_loc operator.
 *
 * Expects three inputs [cls_prob, loc_pred, anchor] with ranks 3, 2 and 3,
 * and checks that
 *   cls_prob[2] == anchor[1], cls_prob[2] * 4 == loc_pred[1],
 *   anchor[1] > 0 and anchor[2] == 4.
 * Outputs:
 *   0: [cls_prob[0], anchor[1], 6]  -- [id, prob, xmin, ymin, xmax, ymax]
 *   1: [cls_prob[0]]
 *
 * \param attrs Node attributes (forwarded to the shape-assignment macro).
 * \param in_attrs Input shapes; exactly three entries are expected.
 * \param out_attrs Output shapes; entries 0 and 1 receive the inferred shapes.
 * \return Always true; every violated precondition aborts through CHECK_*.
 */
bool MultiBoxTransformLocShape(const NodeAttrs& attrs,
                               std::vector<TShape> *in_attrs,
                               std::vector<TShape> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 3U) << "Inputs: [cls_prob, loc_pred, anchor]";
  const TShape& cshape = in_attrs->at(0);
  const TShape& lshape = in_attrs->at(1);
  const TShape& ashape = in_attrs->at(2);
  CHECK_EQ(cshape.ndim(), 3U) << "Class probability should be 3-D.";
  CHECK_EQ(lshape.ndim(), 2U) << "Location prediction should be 2-D.";
  CHECK_EQ(ashape.ndim(), 3U) << "Anchor should be 3-D.";
  CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch.";
  CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc.";
  CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0.";
  CHECK_EQ(ashape[2], 4U);
  TShape oshape0 = TShape(3);
  oshape0[0] = cshape[0];
  oshape0[1] = ashape[1];
  oshape0[2] = 6;  // [id, prob, xmin, ymin, xmax, ymax]
  TShape oshape1 = TShape(1);
  oshape1[0] = cshape[0];
  // Elements 0 and 1 of *out_attrs are addressed below, so both must exist.
  // The previous `clear()` emptied the vector first, making both accesses
  // out of range; `resize(2)` guarantees the slots and keeps any
  // already-known output shapes so they can be checked for consistency.
  out_attrs->resize(2);
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape0);
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 1, oshape1);
  return true;
}

/*!
 * \brief Layout inference for multibox_transform_loc.
 *
 * Copies every defined entry of last_ilayouts into the matching slot of
 * ilayouts; output layouts are left untouched.
 *
 * \param attrs Node attributes (unused).
 * \param ilayouts Input layouts; must have three entries.
 * \param last_ilayouts Previously inferred input layouts; must have three entries.
 * \param olayouts Output layouts; must have two entries, not modified.
 * \return Always true.
 */
inline bool MultiBoxTransformLocLayout(const NodeAttrs& attrs,
                                       std::vector<Layout> *ilayouts,
                                       const std::vector<Layout> *last_ilayouts,
                                       std::vector<Layout> *olayouts) {
  CHECK_EQ(ilayouts->size(), 3U);
  CHECK_EQ(last_ilayouts->size(), 3U);
  CHECK_EQ(olayouts->size(), 2U);
  for (size_t i = 0; i < last_ilayouts->size(); ++i) {
    const Layout& last_layout = last_ilayouts->at(i);
    if (last_layout.defined()) {
      NNVM_ASSIGN_LAYOUT(*ilayouts, i, last_layout);
    }
  }
  return true;
}

/*!
 * \brief Type inference for multibox_transform_loc.
 *
 * Output 0 takes the type of input 0; output 1 is assigned the fixed type
 * code 4.
 *
 * \param attrs Node attributes (unused).
 * \param in_attrs Input type codes; entry 0 is read.
 * \param out_attrs Output type codes; entries 0 and 1 are assigned.
 * \return Always true.
 */
inline bool MultiBoxTransformLocInferType(const NodeAttrs &attrs,
                                          std::vector<int> *in_attrs,
                                          std::vector<int> *out_attrs) {
  DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(0));
  // NOTE(review): 4 is presumably the int32 type code -- confirm against the
  // project's type enum and consider using the named constant.
  DTYPE_ASSIGN(out_attrs->at(1), 4U);
  return true;
}

NNVM_REGISTER_OP(multibox_transform_loc)
.describe(R"doc("Location transformation for multibox detection." 
)doc" NNVM_ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr_parser(ParamParser<MultiBoxTransformLocParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
                        ParamGetAttrDict<MultiBoxTransformLocParam>)
.add_arguments(MultiBoxTransformLocParam::__FIELDS__())
.add_argument("cls_prob", "Tensor", "Class probabilities.")
.add_argument("loc_pred", "Tensor", "Location regression predictions.")
.add_argument("anchor", "Tensor", "Multibox prior anchor boxes")
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"cls_prob", "loc_pred", "anchor"};
})
.set_attr<FInferShape>("FInferShape", MultiBoxTransformLocShape)
.set_attr<FInferType>("FInferType", MultiBoxTransformLocInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", MultiBoxTransformLocLayout)
.set_support_level(4);

}  // namespace top
}  // namespace nnvm