Commit 313e1d99 by Yao Wang Committed by Tianqi Chen

[Relay][OP]NMS (#1929)

parent 1f2c8156
...@@ -40,6 +40,22 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode<MultiBoxPriorAttrs> { ...@@ -40,6 +40,22 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode<MultiBoxPriorAttrs> {
} }
}; };
/*! \brief Attributes used in non_maximum_suppression operators */
struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
double overlap_threshold;
bool force_suppress;
int topk;
TVM_DECLARE_ATTRS(NMSAttrs, "relay.attrs.NMSAttrs") {
TVM_ATTR_FIELD(overlap_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
TVM_ATTR_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(topk).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_VISION_H_ #endif // TVM_RELAY_ATTRS_VISION_H_
...@@ -3,3 +3,4 @@ ...@@ -3,3 +3,4 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .multibox import * from .multibox import *
from .nms import *
"""Non-maximum suppression operations."""
from __future__ import absolute_import as _abs
from . import _make
def nms(data,
valid_count,
overlap_threshold=0.5,
force_suppress=False,
topk=-1):
"""Non-maximum suppression operator for object detection.
Parameters
----------
data : relay.Expr
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
valid_count : relay.Expr
1-D tensor for valid number of boxes.
overlap_threshold : float, optional
Non-maximum suppression threshold.
force_suppress : bool, optional
Suppress all detections regardless of class_id.
topk : int, optional
Keep maximum top k detections before nms, -1 for no limit.
Returns
-------
out : relay.Expr
3-D tensor with shape [batch_size, num_anchors, 6].
"""
return _make.nms(data, valid_count, overlap_threshold, force_suppress, topk)
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
*/ */
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h> #include <tvm/relay/attrs/vision.h>
#include <vector>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -66,7 +65,7 @@ RELAY_REGISTER_OP("vision.multibox_prior") ...@@ -66,7 +65,7 @@ RELAY_REGISTER_OP("vision.multibox_prior")
.set_attrs_type_key("relay.attrs.MultiBoxPriorAttrs") .set_attrs_type_key("relay.attrs.MultiBoxPriorAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(4) .set_support_level(5)
.add_type_rel("MultiBoxPrior", MultiboxPriorRel); .add_type_rel("MultiBoxPrior", MultiboxPriorRel);
} // namespace relay } // namespace relay
......
/*!
* Copyright (c) 2018 by Contributors
* \file nms.cc
* \brief Non-maximum suppression operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(NMSAttrs);
bool NMSRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* valid_count = types[1].as<TensorTypeNode>();
const auto& dshape = data->shape;
const auto& vshape = valid_count->shape;
CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
CHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D.";
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
return true;
}
Expr MakeNMS(Expr data,
Expr valid_count,
double overlap_threshold,
bool force_suppress,
int topk) {
auto attrs = make_node<NMSAttrs>();
attrs->overlap_threshold = overlap_threshold;
attrs->force_suppress = force_suppress;
attrs->topk = topk;
static const Op& op = Op::Get("vision.nms");
return CallNode::make(op, {data, valid_count}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.vision._make.nms")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeNMS, args, rv);
});
RELAY_REGISTER_OP("vision.nms")
.describe(R"doc("Non-maximum suppression."
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.")
.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
.set_support_level(5)
.add_type_rel("NMS", NMSRel);
} // namespace relay
} // namespace tvm
...@@ -18,7 +18,6 @@ def test_resize_infer_type(): ...@@ -18,7 +18,6 @@ def test_resize_infer_type():
assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
def test_multibox_prior(): def test_multibox_prior():
sizes = (0.3, 1.5, 0.7) sizes = (0.3, 1.5, 0.7)
ratios = (1.3, 2.4) ratios = (1.3, 2.4)
...@@ -44,6 +43,36 @@ def test_multibox_prior(): ...@@ -44,6 +43,36 @@ def test_multibox_prior():
(1, h * w, 4), "float32") (1, h * w, 4), "float32")
def test_nms():
num_anchors = 60
overlap_threshold = 0.5
force_suppress = True
nms_topk = 10
n = tvm.var("n")
x0 = relay.var("x0", relay.ty.TensorType((n, num_anchors, 6), "float32"))
x1 = relay.var("x1", relay.ty.TensorType((n,), "int"))
z = relay.vision.nms(x0, x1, overlap_threshold, force_suppress, nms_topk)
assert "overlap_threshold" in z.astext()
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType(
(n, num_anchors, 6), "float32")
n = tvm.var("n")
x0 = relay.var("x0", relay.ty.TensorType((n, num_anchors, 6), "float32"))
x1 = relay.var("x1", relay.ty.TensorType((n,), "int"))
z = relay.vision.nms(x0, x1)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType(
(n, num_anchors, 6), "float32")
if __name__ == "__main__": if __name__ == "__main__":
test_resize_infer_type() test_resize_infer_type()
test_multibox_prior() test_multibox_prior()
test_nms()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment