Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
313e1d99
Commit
313e1d99
authored
Nov 13, 2018
by
Yao Wang
Committed by
Tianqi Chen
Nov 13, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][OP]NMS (#1929)
parent
1f2c8156
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
146 additions
and
3 deletions
+146
-3
include/tvm/relay/attrs/vision.h
+16
-0
python/tvm/relay/op/vision/__init__.py
+1
-0
python/tvm/relay/op/vision/nms.py
+36
-0
src/relay/op/vision/multibox_op.cc
+1
-2
src/relay/op/vision/nms.cc
+62
-0
tests/python/relay/test_op_level5.py
+30
-1
No files found.
include/tvm/relay/attrs/vision.h
View file @
313e1d99
...
...
@@ -40,6 +40,22 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode<MultiBoxPriorAttrs> {
}
};
/*! \brief Attributes used in non_maximum_suppression operators */
struct
NMSAttrs
:
public
tvm
::
AttrsNode
<
NMSAttrs
>
{
double
overlap_threshold
;
bool
force_suppress
;
int
topk
;
TVM_DECLARE_ATTRS
(
NMSAttrs
,
"relay.attrs.NMSAttrs"
)
{
TVM_ATTR_FIELD
(
overlap_threshold
).
set_default
(
0
.
5
)
.
describe
(
"Non-maximum suppression threshold."
);
TVM_ATTR_FIELD
(
force_suppress
).
set_default
(
false
)
.
describe
(
"Suppress all detections regardless of class_id."
);
TVM_ATTR_FIELD
(
topk
).
set_default
(
-
1
)
.
describe
(
"Keep maximum top k detections before nms, -1 for no limit."
);
}
};
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_ATTRS_VISION_H_
python/tvm/relay/op/vision/__init__.py
View file @
313e1d99
...
...
@@ -3,3 +3,4 @@
from
__future__
import
absolute_import
as
_abs
from
.multibox
import
*
from
.nms
import
*
python/tvm/relay/op/vision/nms.py
0 → 100644
View file @
313e1d99
"""Non-maximum suppression operations."""
from
__future__
import
absolute_import
as
_abs
from
.
import
_make
def
nms
(
data
,
valid_count
,
overlap_threshold
=
0.5
,
force_suppress
=
False
,
topk
=-
1
):
"""Non-maximum suppression operator for object detection.
Parameters
----------
data : relay.Expr
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
valid_count : relay.Expr
1-D tensor for valid number of boxes.
overlap_threshold : float, optional
Non-maximum suppression threshold.
force_suppress : bool, optional
Suppress all detections regardless of class_id.
topk : int, optional
Keep maximum top k detections before nms, -1 for no limit.
Returns
-------
out : relay.Expr
3-D tensor with shape [batch_size, num_anchors, 6].
"""
return
_make
.
nms
(
data
,
valid_count
,
overlap_threshold
,
force_suppress
,
topk
)
src/relay/op/vision/multibox_op.cc
View file @
313e1d99
...
...
@@ -5,7 +5,6 @@
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>
#include <vector>
namespace
tvm
{
namespace
relay
{
...
...
@@ -66,7 +65,7 @@ RELAY_REGISTER_OP("vision.multibox_prior")
.
set_attrs_type_key
(
"relay.attrs.MultiBoxPriorAttrs"
)
.
set_num_inputs
(
1
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
set_support_level
(
4
)
.
set_support_level
(
5
)
.
add_type_rel
(
"MultiBoxPrior"
,
MultiboxPriorRel
);
}
// namespace relay
...
...
src/relay/op/vision/nms.cc
0 → 100644
View file @
313e1d99
/*!
* Copyright (c) 2018 by Contributors
* \file nms.cc
* \brief Non-maximum suppression operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/vision.h>
namespace
tvm
{
namespace
relay
{
TVM_REGISTER_NODE_TYPE
(
NMSAttrs
);
bool
NMSRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
3
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
const
auto
*
valid_count
=
types
[
1
].
as
<
TensorTypeNode
>
();
const
auto
&
dshape
=
data
->
shape
;
const
auto
&
vshape
=
valid_count
->
shape
;
CHECK_EQ
(
dshape
.
size
(),
3
)
<<
"Input data should be 3-D."
;
CHECK_EQ
(
vshape
.
size
(),
1
)
<<
"Input valid count should be 1-D."
;
// assign output type
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
dshape
,
data
->
dtype
));
return
true
;
}
Expr
MakeNMS
(
Expr
data
,
Expr
valid_count
,
double
overlap_threshold
,
bool
force_suppress
,
int
topk
)
{
auto
attrs
=
make_node
<
NMSAttrs
>
();
attrs
->
overlap_threshold
=
overlap_threshold
;
attrs
->
force_suppress
=
force_suppress
;
attrs
->
topk
=
topk
;
static
const
Op
&
op
=
Op
::
Get
(
"vision.nms"
);
return
CallNode
::
make
(
op
,
{
data
,
valid_count
},
Attrs
(
attrs
),
{});
}
TVM_REGISTER_API
(
"relay.op.vision._make.nms"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
5
>
(
MakeNMS
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"vision.nms"
)
.
describe
(
R"doc("Non-maximum suppression."
)doc"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
2
)
.
add_argument
(
"data"
,
"Tensor"
,
"Input data."
)
.
add_argument
(
"valid_count"
,
"Tensor"
,
"Number of valid anchor boxes."
)
.
set_support_level
(
5
)
.
add_type_rel
(
"NMS"
,
NMSRel
);
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_op_level5.py
View file @
313e1d99
...
...
@@ -18,7 +18,6 @@ def test_resize_infer_type():
assert
zz
.
checked_type
==
relay
.
TensorType
((
n
,
c
,
100
,
200
),
"int8"
)
def
test_multibox_prior
():
sizes
=
(
0.3
,
1.5
,
0.7
)
ratios
=
(
1.3
,
2.4
)
...
...
@@ -44,6 +43,36 @@ def test_multibox_prior():
(
1
,
h
*
w
,
4
),
"float32"
)
def
test_nms
():
num_anchors
=
60
overlap_threshold
=
0.5
force_suppress
=
True
nms_topk
=
10
n
=
tvm
.
var
(
"n"
)
x0
=
relay
.
var
(
"x0"
,
relay
.
ty
.
TensorType
((
n
,
num_anchors
,
6
),
"float32"
))
x1
=
relay
.
var
(
"x1"
,
relay
.
ty
.
TensorType
((
n
,),
"int"
))
z
=
relay
.
vision
.
nms
(
x0
,
x1
,
overlap_threshold
,
force_suppress
,
nms_topk
)
assert
"overlap_threshold"
in
z
.
astext
()
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
ty
.
TensorType
(
(
n
,
num_anchors
,
6
),
"float32"
)
n
=
tvm
.
var
(
"n"
)
x0
=
relay
.
var
(
"x0"
,
relay
.
ty
.
TensorType
((
n
,
num_anchors
,
6
),
"float32"
))
x1
=
relay
.
var
(
"x1"
,
relay
.
ty
.
TensorType
((
n
,),
"int"
))
z
=
relay
.
vision
.
nms
(
x0
,
x1
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
ty
.
TensorType
(
(
n
,
num_anchors
,
6
),
"float32"
)
if
__name__
==
"__main__"
:
test_resize_infer_type
()
test_multibox_prior
()
test_nms
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment