/*!
 * Copyright (c) 2018 by Contributors
 * \file nms.cc
 * \brief Non-maximum suppression operators
 */
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(NMSAttrs);

/*!
 * \brief Type relation for vision.nms.
 *
 * Expects exactly three types: [data, valid_count, output]. It checks that
 * data is a 3-D tensor and valid_count is a 1-D tensor. It then assigns the
 * output the same shape and dtype as data.
 * (Presumably data is (batch, num_anchors, elem_length) and valid_count is
 * (batch,); only the ranks are enforced here — TODO confirm.)
 *
 * \param types The input types followed by the output type.
 * \param num_inputs Number of input types (unused).
 * \param attrs The NMSAttrs of the call (unused).
 * \param reporter Used to assign the output type.
 * \return false if an input type is not yet known as a tensor type, so the
 *         type solver can retry later; true once the output type is assigned.
 */
bool NMSRel(const Array<Type>& types,
            int num_inputs,
            const Attrs& attrs,
            const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* valid_count = types[1].as<TensorTypeNode>();
  // as<>() yields nullptr while an input is still an unresolved type; defer
  // instead of dereferencing a null pointer.
  if (data == nullptr || valid_count == nullptr) {
    return false;
  }
  const auto& dshape = data->shape;
  const auto& vshape = valid_count->shape;
  CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
  CHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D.";

  // assign output type
  reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
  return true;
}

/*!
 * \brief Build a call to the vision.nms operator.
 *
 * \param data The input data tensor expression.
 * \param valid_count The valid-count tensor expression.
 * \param overlap_threshold Stored in NMSAttrs::overlap_threshold.
 * \param force_suppress Stored in NMSAttrs::force_suppress.
 * \param topk Stored in NMSAttrs::topk.
 * \return A CallNode invoking "vision.nms" on (data, valid_count) with the attrs.
 */
Expr MakeNMS(Expr data,
             Expr valid_count,
             double overlap_threshold,
             bool force_suppress,
             int topk) {
  auto attrs = make_node<NMSAttrs>();
  attrs->overlap_threshold = overlap_threshold;
  attrs->force_suppress = force_suppress;
  attrs->topk = topk;
  static const Op& op = Op::Get("vision.nms");
  return CallNode::make(op, {data, valid_count}, Attrs(attrs), {});
}

// Expose MakeNMS to the frontend as a packed function taking its 5 arguments.
TVM_REGISTER_API("relay.op.vision._make.nms")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
  runtime::detail::unpack_call<Expr, 5>(MakeNMS, args, rv);
});

// Register the op: two inputs (data, valid_count), type-checked by NMSRel.
RELAY_REGISTER_OP("vision.nms")
.describe(R"doc("Non-maximum suppression."
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.")
.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
.set_support_level(5)
.add_type_rel("NMS", NMSRel);

}  // namespace relay
}  // namespace tvm