/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file nms.cc
 * \brief Non-maximum suppression operators
 */
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>

namespace tvm {
namespace relay {

30 31 32 33 34 35 36 37 38 39 40 41 42
TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs);

bool GetValidCountRel(const Array<Type>& types,
                      int num_inputs,
                      const Attrs& attrs,
                      const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto& dshape = data->shape;
  CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";

  std::vector<IndexExpr> oshape({data->shape[0]});
  std::vector<Type> fields;
43 44
  fields.push_back(TensorType(oshape, DataType::Int(32)));
  fields.push_back(TensorType(data->shape, data->dtype));
45 46

  // assign output type
47
  reporter->Assign(types[1], TupleType(Array<Type>(fields)));
48 49 50 51
  return true;
}

/*!
 * \brief Build a call to the vision.get_valid_counts operator.
 * \param data The input tensor expression.
 * \param score_threshold Stored in GetValidCountsAttrs::score_threshold.
 * \param id_index Stored in GetValidCountsAttrs::id_index.
 * \param score_index Stored in GetValidCountsAttrs::score_index.
 * \return A Call of the registered op with \p data as its only argument.
 */
Expr MakeGetValidCounts(Expr data,
                        double score_threshold,
                        int id_index,
                        int score_index) {
  auto attrs = make_object<GetValidCountsAttrs>();
  attrs->score_threshold = score_threshold;
  attrs->id_index = id_index;
  attrs->score_index = score_index;
  static const Op& op = Op::Get("vision.get_valid_counts");
  return Call(op, {data}, Attrs(attrs), {});
}


64
TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts")
65
.set_body_typed(MakeGetValidCounts);
66 67 68 69 70 71 72 73 74 75 76 77 78 79


RELAY_REGISTER_OP("vision.get_valid_counts")
.describe(R"doc(Get valid count of bounding boxes given
a score threshold. Also moves valid boxes to the top of
input data.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(5)
.add_type_rel("GetValidCount", GetValidCountRel);


// Register the attribute node type used by vision.non_max_suppression.
TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs);

bool NMSRel(const Array<Type>& types,
            int num_inputs,
            const Attrs& attrs,
            const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* valid_count = types[1].as<TensorTypeNode>();
88 89
  const NonMaximumSuppressionAttrs* param =
    attrs.as<NonMaximumSuppressionAttrs>();
Yao Wang committed
90 91 92 93 94 95
  const auto& dshape = data->shape;
  const auto& vshape = valid_count->shape;
  CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
  CHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D.";

  // assign output type
96 97
  if (param->return_indices) {
    std::vector<IndexExpr> oshape({dshape[0], dshape[1]});
98
    reporter->Assign(types[2], TensorType(oshape, DataType::Int(32)));
99
  } else {
100
    reporter->Assign(types[2], TensorType(dshape, data->dtype));
101
  }
Yao Wang committed
102 103 104 105 106 107
  return true;
}


/*!
 * \brief Build a call to the vision.non_max_suppression operator.
 *
 * Each scalar argument is copied verbatim into the same-named field of
 * NonMaximumSuppressionAttrs.
 *
 * \param data The input box tensor expression.
 * \param valid_count The valid-count tensor expression.
 * \param max_output_size Stored in attrs->max_output_size.
 * \param iou_threshold Stored in attrs->iou_threshold.
 * \param force_suppress Stored in attrs->force_suppress.
 * \param top_k Stored in attrs->top_k.
 * \param coord_start Stored in attrs->coord_start.
 * \param score_index Stored in attrs->score_index.
 * \param id_index Stored in attrs->id_index.
 * \param return_indices Stored in attrs->return_indices; also selects the
 *        output type in NMSRel.
 * \param invalid_to_bottom Stored in attrs->invalid_to_bottom.
 * \return A Call of the registered op with arguments {data, valid_count}.
 */
Expr MakeNMS(Expr data,
             Expr valid_count,
             int max_output_size,
             double iou_threshold,
             bool force_suppress,
             int top_k,
             int coord_start,
             int score_index,
             int id_index,
             bool return_indices,
             bool invalid_to_bottom) {
  auto attrs = make_object<NonMaximumSuppressionAttrs>();
  attrs->max_output_size = max_output_size;
  attrs->iou_threshold = iou_threshold;
  attrs->force_suppress = force_suppress;
  attrs->top_k = top_k;
  attrs->coord_start = coord_start;
  attrs->score_index = score_index;
  attrs->id_index = id_index;
  attrs->return_indices = return_indices;
  attrs->invalid_to_bottom = invalid_to_bottom;
  static const Op& op = Op::Get("vision.non_max_suppression");
  return Call(op, {data, valid_count}, Attrs(attrs), {});
}


132
TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression")
133
.set_body_typed(MakeNMS);
Yao Wang committed
134 135


136 137 138 139
RELAY_REGISTER_OP("vision.non_max_suppression")
.describe(R"doc(Non-maximum suppression. The input boxes should
be in the format of [class_id, score, left, top, right, bottom].
Set id_index to be -1 to ignore class_id axis.
Yao Wang committed
140 141 142 143 144 145 146 147 148
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.")
.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
.set_support_level(5)
.add_type_rel("NMS", NMSRel);

}  // namespace relay
}  // namespace tvm