Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d9a9a8b6
Commit
d9a9a8b6
authored
Dec 20, 2018
by
Zhi
Committed by
Haichen Shen
Dec 20, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[relay][op] multibox_transform_loc (#2315)
parent
10b6e7e0
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
183 additions
and
0 deletions
+183
-0
include/tvm/relay/attrs/vision.h
+18
-0
python/tvm/relay/op/vision/multibox.py
+36
-0
src/relay/op/vision/multibox_op.cc
+73
-0
tests/python/relay/test_op_level5.py
+56
-0
No files found.
include/tvm/relay/attrs/vision.h
View file @
d9a9a8b6
...
...
@@ -40,6 +40,24 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode<MultiBoxPriorAttrs> {
}
};
struct
MultiBoxTransformLocAttrs
:
public
tvm
::
AttrsNode
<
MultiBoxTransformLocAttrs
>
{
bool
clip
;
double
threshold
;
Array
<
IndexExpr
>
variances
;
TVM_DECLARE_ATTRS
(
MultiBoxTransformLocAttrs
,
"relay.attrs.MultiBoxTransformLocAttrs"
)
{
TVM_ATTR_FIELD
(
clip
).
set_default
(
true
)
.
describe
(
"Clip out-of-boundary boxes."
);
TVM_ATTR_FIELD
(
threshold
).
set_default
(
0
.
01
)
.
describe
(
"Threshold to be a positive prediction."
);
TVM_ATTR_FIELD
(
variances
)
.
set_default
(
Array
<
IndexExpr
>
({
0
.
1
f
,
0
.
1
f
,
0
.
2
f
,
0
.
2
f
}))
.
describe
(
"Variances to be decoded from box regression output."
);
}
};
/*! \brief Attributes used in non_maximum_suppression operators */
struct
NMSAttrs
:
public
tvm
::
AttrsNode
<
NMSAttrs
>
{
double
overlap_threshold
;
...
...
python/tvm/relay/op/vision/multibox.py
View file @
d9a9a8b6
...
...
@@ -36,3 +36,39 @@ def multibox_prior(data,
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
"""
return
_make
.
multibox_prior
(
data
,
sizes
,
ratios
,
steps
,
offsets
,
clip
)
def
multibox_transform_loc
(
cls_prob
,
loc_pred
,
anchor
,
clip
=
True
,
threshold
=
0.01
,
variance
=
(
0.1
,
0.1
,
0.2
,
0.2
)):
"""Location transformation for multibox detection
Parameters
----------
cls_prob : tvm.relay.Expr
Class probabilities.
loc_pred : tvm.relay.Expr
Location regression predictions.
anchor : tvm.relay.Expr
Prior anchor boxes.
clip : boolean, optional
Whether to clip out-of-boundary boxes.
threshold : double, optional
Threshold to be a positive prediction.
variance : Tuple of float, optional
Variances to be decoded from box regression output.
Returns
-------
ret : tuple of tvm.relay.Expr
"""
return
_make
.
multibox_transform_loc
(
cls_prob
,
loc_pred
,
anchor
,
clip
,
threshold
,
variance
)
src/relay/op/vision/multibox_op.cc
View file @
d9a9a8b6
...
...
@@ -68,5 +68,78 @@ RELAY_REGISTER_OP("vision.multibox_prior")
.
set_support_level
(
5
)
.
add_type_rel
(
"MultiBoxPrior"
,
MultiboxPriorRel
);
TVM_REGISTER_NODE_TYPE
(
MultiBoxTransformLocAttrs
);
bool
MultiBoxTransformLocRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
4
);
const
auto
*
cls_prob
=
types
[
0
].
as
<
TensorTypeNode
>
();
const
auto
*
loc_pred
=
types
[
1
].
as
<
TensorTypeNode
>
();
const
auto
*
anchor
=
types
[
2
].
as
<
TensorTypeNode
>
();
CHECK
(
cls_prob
!=
nullptr
&&
loc_pred
!=
nullptr
&&
anchor
!=
nullptr
);
const
auto
&
cls_shape
=
cls_prob
->
shape
;
const
auto
&
loc_shape
=
loc_pred
->
shape
;
const
auto
&
anchor_shape
=
anchor
->
shape
;
CHECK_EQ
(
cls_shape
.
size
(),
3U
)
<<
"The dimension of class probability should be 3, but received "
<<
cls_shape
.
size
();
CHECK_EQ
(
loc_shape
.
size
(),
2U
)
<<
"The dimension of location prediction should be 2, but received "
<<
loc_shape
.
size
();
CHECK_EQ
(
anchor_shape
.
size
(),
3U
)
<<
"The dimension of anchor should be 3, but received "
<<
anchor_shape
.
size
();
CHECK
(
reporter
->
AssertEQ
(
cls_shape
[
2
],
anchor_shape
[
1
]))
<<
"Number of anchors mismatch found"
;
CHECK
(
reporter
->
AssertEQ
(
cls_shape
[
2
]
*
4
,
loc_shape
[
1
]))
<<
"# anchors mismatch with # loc."
;
CHECK
(
reporter
->
Assert
(
anchor_shape
[
1
]
>
0
))
<<
"Number of anchors must > 0."
;
CHECK
(
reporter
->
AssertEQ
(
anchor_shape
[
2
],
4
));
std
::
vector
<
IndexExpr
>
oshape0
({
cls_shape
[
0
],
anchor_shape
[
1
],
6
});
std
::
vector
<
IndexExpr
>
oshape1
({
cls_shape
[
0
]});
std
::
vector
<
Type
>
fields
;
fields
.
push_back
(
TensorTypeNode
::
make
(
oshape0
,
cls_prob
->
dtype
));
fields
.
push_back
(
TensorTypeNode
::
make
(
oshape1
,
Int
(
32
)));
// assign output type
reporter
->
Assign
(
types
[
3
],
TupleTypeNode
::
make
(
Array
<
Type
>
(
fields
)));
return
true
;
}
Expr
MakeMultiBoxTransformLoc
(
Expr
cls_prob
,
Expr
loc_pred
,
Expr
anchor
,
bool
clip
,
double
threshold
,
Array
<
IndexExpr
>
variances
)
{
auto
attrs
=
make_node
<
MultiBoxTransformLocAttrs
>
();
attrs
->
clip
=
std
::
move
(
clip
);
attrs
->
threshold
=
std
::
move
(
threshold
);
attrs
->
variances
=
std
::
move
(
variances
);
static
const
Op
&
op
=
Op
::
Get
(
"vision.multibox_transform_loc"
);
return
CallNode
::
make
(
op
,
{
cls_prob
,
loc_pred
,
anchor
},
Attrs
(
attrs
),
{});
}
TVM_REGISTER_API
(
"relay.op.vision._make.multibox_transform_loc"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
6
>
(
MakeMultiBoxTransformLoc
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"vision.multibox_transform_loc"
)
.
describe
(
R"doc("Location transformation for multibox detection."
)doc"
TVM_ADD_FILELINE
)
.
set_attrs_type_key
(
"relay.attrs.MultiBoxTransformLocAttrs"
)
.
set_num_inputs
(
3
)
.
add_argument
(
"cls_prob"
,
"Tensor"
,
"Class probabilities."
)
.
add_argument
(
"loc_pred"
,
"Tensor"
,
"Location regression predictions."
)
.
add_argument
(
"anchor"
,
"Tensor"
,
"Multibox prior anchor boxes"
)
.
add_type_rel
(
"MultiBoxTransformLoc"
,
MultiBoxTransformLocRel
)
.
set_support_level
(
5
);
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_op_level5.py
View file @
d9a9a8b6
...
...
@@ -102,8 +102,64 @@ def test_nms():
(
n
,
num_anchors
,
6
),
"float32"
)
def
test_multibox_transform_loc
():
def
test_default_value
():
num_anchors
=
5
num_classes
=
5
cls_prob
=
relay
.
var
(
"cls_prob"
,
relay
.
ty
.
TensorType
((
1
,
num_anchors
,
num_classes
),
"float32"
))
loc_pred
=
relay
.
var
(
"loc_pred"
,
relay
.
ty
.
TensorType
((
1
,
num_anchors
*
4
),
"float32"
))
anchors
=
relay
.
var
(
"anchors"
,
relay
.
ty
.
TensorType
((
1
,
num_anchors
,
4
),
"float32"
))
ret
=
relay
.
vision
.
multibox_transform_loc
(
cls_prob
=
cls_prob
,
loc_pred
=
loc_pred
,
anchor
=
anchors
)
ret
=
relay
.
ir_pass
.
infer_type
(
ret
)
ref_type
=
relay
.
ty
.
TupleType
(
tvm
.
convert
([
relay
.
ty
.
TensorType
((
1
,
num_anchors
,
6
),
"float32"
),
relay
.
ty
.
TensorType
((
1
,
),
"int"
)
]))
assert
ret
.
checked_type
==
ref_type
def
test_threshold
():
num_anchors
=
5
num_classes
=
5
n
=
tvm
.
var
(
"n"
)
cls_prob
=
relay
.
var
(
"cls_prob"
,
relay
.
ty
.
TensorType
((
n
,
num_anchors
,
num_classes
),
"float32"
))
loc_pred
=
relay
.
var
(
"loc_pred"
,
relay
.
ty
.
TensorType
((
n
,
num_anchors
*
4
),
"float32"
))
anchors
=
relay
.
var
(
"anchors"
,
relay
.
ty
.
TensorType
((
1
,
num_anchors
,
4
),
"float32"
))
threshold
=
0.02
variance
=
(
0.2
,
0.2
,
0.3
,
0.3
)
ret
=
relay
.
vision
.
multibox_transform_loc
(
cls_prob
=
cls_prob
,
loc_pred
=
loc_pred
,
anchor
=
anchors
,
threshold
=
threshold
,
variance
=
variance
)
ret
=
relay
.
ir_pass
.
infer_type
(
ret
)
ref_type
=
relay
.
ty
.
TupleType
(
tvm
.
convert
([
relay
.
ty
.
TensorType
((
n
,
num_anchors
,
6
),
"float32"
),
relay
.
ty
.
TensorType
((
n
,
),
"int"
)
]))
assert
ret
.
checked_type
==
ref_type
test_default_value
()
test_threshold
()
if
__name__
==
"__main__"
:
test_resize_infer_type
()
test_resize
()
test_multibox_prior
()
test_multibox_transform_loc
()
test_nms
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment