Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d2f29ba5
Commit
d2f29ba5
authored
Mar 11, 2019
by
Yao Wang
Committed by
Tianqi Chen
Mar 11, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Object Detection] Gluoncv SSD support on CPU (#2353)
parent
f7eff095
Show whitespace changes
Inline
Side-by-side
Showing
30 changed files
with
1048 additions
and
374 deletions
+1048
-374
include/tvm/relay/attrs/vision.h
+30
-7
nnvm/include/nnvm/top/nn.h
+19
-6
nnvm/python/nnvm/frontend/mxnet.py
+3
-3
nnvm/python/nnvm/top/vision.py
+14
-9
nnvm/src/top/vision/nms.cc
+14
-5
nnvm/tests/python/compiler/test_top_level4.py
+8
-8
nnvm/tests/python/frontend/mxnet/test_forward.py
+0
-1
python/tvm/relay/frontend/mxnet.py
+51
-3
python/tvm/relay/op/transform.py
+1
-1
python/tvm/relay/op/vision/__init__.py
+1
-1
python/tvm/relay/op/vision/_vision.py
+29
-7
python/tvm/relay/op/vision/nms.py
+50
-6
src/relay/op/tensor/transform.cc
+10
-9
src/relay/op/vision/multibox_op.cc
+4
-2
src/relay/op/vision/nms.cc
+75
-11
tests/python/frontend/mxnet/test_forward.py
+7
-1
tests/python/relay/test_op_level10.py
+1
-0
tests/python/relay/test_op_level5.py
+69
-17
topi/include/topi/nn/l2_normalize.h
+6
-1
topi/python/topi/cuda/nms.py
+26
-17
topi/python/topi/cuda/ssd/multibox.py
+2
-2
topi/python/topi/cuda/vision.py
+17
-0
topi/python/topi/generic/vision.py
+17
-0
topi/python/topi/testing/__init__.py
+1
-0
topi/python/topi/testing/slice_axis_python.py
+34
-0
topi/python/topi/vision/nms.py
+248
-106
topi/python/topi/vision/ssd/multibox.py
+138
-144
topi/tests/python/test_topi_vision.py
+68
-6
tutorials/frontend/deploy_ssd_gluoncv.py
+104
-0
tutorials/nnvm/deploy_ssd_mxnet.py
+1
-1
No files found.
include/tvm/relay/attrs/vision.h
View file @
d2f29ba5
...
@@ -58,19 +58,42 @@ struct MultiBoxTransformLocAttrs
...
@@ -58,19 +58,42 @@ struct MultiBoxTransformLocAttrs
}
}
};
};
/*! \brief Attributes used in non_maximum_suppression operators */
/*! \brief Attributes used in get_valid_counts operator */
struct
NMSAttrs
:
public
tvm
::
AttrsNode
<
NMSAttrs
>
{
struct
GetValidCountsAttrs
:
public
tvm
::
AttrsNode
<
GetValidCountsAttrs
>
{
double
overlap_threshold
;
double
score_threshold
;
TVM_DECLARE_ATTRS
(
GetValidCountsAttrs
,
"relay.attrs.GetValidCountsAttrs"
)
{
TVM_ATTR_FIELD
(
score_threshold
).
set_default
(
0
.
0
)
.
describe
(
"Lower limit of score for valid bounding boxes."
);
}
};
/*! \brief Attributes used in non_maximum_suppression operator */
struct
NonMaximumSuppressionAttrs
:
public
tvm
::
AttrsNode
<
NonMaximumSuppressionAttrs
>
{
int
max_output_size
;
double
iou_threshold
;
bool
force_suppress
;
bool
force_suppress
;
int
topk
;
int
top_k
;
int
id_index
;
bool
return_indices
;
bool
invalid_to_bottom
;
TVM_DECLARE_ATTRS
(
NMSAttrs
,
"relay.attrs.NMSAttrs"
)
{
TVM_DECLARE_ATTRS
(
NonMaximumSuppressionAttrs
,
"relay.attrs.NonMaximumSuppressionAttrs"
)
{
TVM_ATTR_FIELD
(
overlap_threshold
).
set_default
(
0
.
5
)
TVM_ATTR_FIELD
(
max_output_size
).
set_default
(
-
1
)
.
describe
(
"Max number of output valid boxes for each instance."
"By default all valid boxes are returned."
);
TVM_ATTR_FIELD
(
iou_threshold
).
set_default
(
0
.
5
)
.
describe
(
"Non-maximum suppression threshold."
);
.
describe
(
"Non-maximum suppression threshold."
);
TVM_ATTR_FIELD
(
force_suppress
).
set_default
(
false
)
TVM_ATTR_FIELD
(
force_suppress
).
set_default
(
false
)
.
describe
(
"Suppress all detections regardless of class_id."
);
.
describe
(
"Suppress all detections regardless of class_id."
);
TVM_ATTR_FIELD
(
top
k
).
set_default
(
-
1
)
TVM_ATTR_FIELD
(
top_
k
).
set_default
(
-
1
)
.
describe
(
"Keep maximum top k detections before nms, -1 for no limit."
);
.
describe
(
"Keep maximum top k detections before nms, -1 for no limit."
);
TVM_ATTR_FIELD
(
id_index
).
set_default
(
0
)
.
describe
(
"Axis index of id."
);
TVM_ATTR_FIELD
(
return_indices
).
set_default
(
true
)
.
describe
(
"Whether to return box indices in input data."
);
TVM_ATTR_FIELD
(
invalid_to_bottom
).
set_default
(
false
)
.
describe
(
"Whether to move all invalid bounding boxes to the bottom."
);
}
}
};
};
...
...
nnvm/include/nnvm/top/nn.h
View file @
d2f29ba5
...
@@ -443,17 +443,30 @@ struct MultiBoxTransformLocParam : public dmlc::Parameter<MultiBoxTransformLocPa
...
@@ -443,17 +443,30 @@ struct MultiBoxTransformLocParam : public dmlc::Parameter<MultiBoxTransformLocPa
}
}
};
};
struct
NMSParam
:
public
dmlc
::
Parameter
<
NMSParam
>
{
struct
NonMaximumSuppressionParam
:
public
dmlc
::
Parameter
<
NonMaximumSuppressionParam
>
{
float
nms_threshold
;
bool
return_indices
;
float
iou_threshold
;
bool
force_suppress
;
bool
force_suppress
;
int
nms_topk
;
int
top_k
;
DMLC_DECLARE_PARAMETER
(
NMSParam
)
{
int
id_index
;
DMLC_DECLARE_FIELD
(
nms_threshold
).
set_default
(
0
.
5
)
int
max_output_size
;
bool
invalid_to_bottom
;
DMLC_DECLARE_PARAMETER
(
NonMaximumSuppressionParam
)
{
DMLC_DECLARE_FIELD
(
max_output_size
).
set_default
(
-
1
)
.
describe
(
"Max number of output valid boxes for each instance."
"By default all valid boxes are returned."
);
DMLC_DECLARE_FIELD
(
iou_threshold
).
set_default
(
0
.
5
)
.
describe
(
"Non-maximum suppression threshold."
);
.
describe
(
"Non-maximum suppression threshold."
);
DMLC_DECLARE_FIELD
(
force_suppress
).
set_default
(
false
)
DMLC_DECLARE_FIELD
(
force_suppress
).
set_default
(
false
)
.
describe
(
"Suppress all detections regardless of class_id."
);
.
describe
(
"Suppress all detections regardless of class_id."
);
DMLC_DECLARE_FIELD
(
nms_top
k
).
set_default
(
-
1
)
DMLC_DECLARE_FIELD
(
top_
k
).
set_default
(
-
1
)
.
describe
(
"Keep maximum top k detections before nms, -1 for no limit."
);
.
describe
(
"Keep maximum top k detections before nms, -1 for no limit."
);
DMLC_DECLARE_FIELD
(
id_index
).
set_default
(
0
)
.
describe
(
"Axis index of id."
);
DMLC_DECLARE_FIELD
(
return_indices
).
set_default
(
true
)
.
describe
(
"Whether to return box indices in input data."
);
DMLC_DECLARE_FIELD
(
invalid_to_bottom
).
set_default
(
false
)
.
describe
(
"Whether to move all invalid bounding boxes to the bottom."
);
}
}
};
};
...
...
nnvm/python/nnvm/frontend/mxnet.py
View file @
d2f29ba5
...
@@ -245,11 +245,11 @@ def _contrib_multibox_detection(inputs, attrs):
...
@@ -245,11 +245,11 @@ def _contrib_multibox_detection(inputs, attrs):
if
attrs
.
get
(
'variances'
)
is
not
None
else
(
0.1
,
0.1
,
0.2
,
0.2
)
if
attrs
.
get
(
'variances'
)
is
not
None
else
(
0.1
,
0.1
,
0.2
,
0.2
)
nms_topk
=
attrs
.
get
(
'nms_topk'
)
or
-
1
nms_topk
=
attrs
.
get
(
'nms_topk'
)
or
-
1
new_attrs0
=
{
'clip'
:
clip
,
'threshold'
:
float
(
threshold
),
'variances'
:
variances
}
new_attrs0
=
{
'clip'
:
clip
,
'threshold'
:
float
(
threshold
),
'variances'
:
variances
}
new_attrs1
=
{
'
nms_threshold'
:
float
(
nms_threshold
),
'force_suppress'
:
force_suppress
,
new_attrs1
=
{
'
return_indices'
:
False
,
'iou_threshold'
:
float
(
nms_threshold
)
,
'
nms_top
k'
:
int
(
nms_topk
)}
'
force_suppress'
:
force_suppress
,
'top_
k'
:
int
(
nms_topk
)}
data
,
valid_count
=
_get_nnvm_op
(
'multibox_transform_loc'
)(
inputs
[
0
],
inputs
[
1
],
data
,
valid_count
=
_get_nnvm_op
(
'multibox_transform_loc'
)(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
**
new_attrs0
)
inputs
[
2
],
**
new_attrs0
)
return
_get_nnvm_op
(
'n
ms
'
)(
data
,
valid_count
,
**
new_attrs1
)
return
_get_nnvm_op
(
'n
on_max_suppression
'
)(
data
,
valid_count
,
**
new_attrs1
)
def
_elemwise_sum
(
inputs
,
_
):
def
_elemwise_sum
(
inputs
,
_
):
new_attrs
=
{
'num_args'
:
len
(
inputs
)}
new_attrs
=
{
'num_args'
:
len
(
inputs
)}
...
...
nnvm/python/nnvm/top/vision.py
View file @
d2f29ba5
...
@@ -61,20 +61,25 @@ def compute_multibox_transform_loc(attrs, inputs, _):
...
@@ -61,20 +61,25 @@ def compute_multibox_transform_loc(attrs, inputs, _):
reg
.
register_pattern
(
"multibox_detection"
,
OpPattern
.
OPAQUE
)
reg
.
register_pattern
(
"multibox_detection"
,
OpPattern
.
OPAQUE
)
# non-maximum suppression
# non-maximum suppression
@reg.register_schedule
(
"n
ms
"
)
@reg.register_schedule
(
"n
on_max_suppression
"
)
def
schedule_nms
(
_
,
outs
,
target
):
def
schedule_nms
(
_
,
outs
,
target
):
"""Schedule definition of n
ms
"""
"""Schedule definition of n
on_max_suppression
"""
with
tvm
.
target
.
create
(
target
):
with
tvm
.
target
.
create
(
target
):
return
topi
.
generic
.
schedule_nms
(
outs
)
return
topi
.
generic
.
schedule_nms
(
outs
)
@reg.register_compute
(
"n
ms
"
)
@reg.register_compute
(
"n
on_max_suppression
"
)
def
compute_nms
(
attrs
,
inputs
,
_
):
def
compute_nms
(
attrs
,
inputs
,
_
):
"""Compute definition of nms"""
"""Compute definition of non_max_suppression"""
nms_threshold
=
attrs
.
get_float
(
'nms_threshold'
)
return_indices
=
attrs
.
get_bool
(
'return_indices'
)
max_output_size
=
attrs
.
get_int
(
'max_output_size'
)
iou_threshold
=
attrs
.
get_float
(
'iou_threshold'
)
force_suppress
=
attrs
.
get_bool
(
'force_suppress'
)
force_suppress
=
attrs
.
get_bool
(
'force_suppress'
)
nms_topk
=
attrs
.
get_int
(
'nms_topk'
)
top_k
=
attrs
.
get_int
(
'top_k'
)
id_index
=
attrs
.
get_int
(
'id_index'
)
invalid_to_bottom
=
attrs
.
get_bool
(
'invalid_to_bottom'
)
return
topi
.
vision
.
nms
(
inputs
[
0
],
inputs
[
1
],
nms_threshold
,
return
topi
.
vision
.
non_max_suppression
(
inputs
[
0
],
inputs
[
1
],
max_output_size
,
force_suppress
,
nms_topk
)
iou_threshold
,
force_suppress
,
top_k
,
id_index
,
return_indices
,
invalid_to_bottom
)
reg
.
register_pattern
(
"n
ms
"
,
OpPattern
.
OPAQUE
)
reg
.
register_pattern
(
"n
on_max_suppression
"
,
OpPattern
.
OPAQUE
)
nnvm/src/top/vision/nms.cc
View file @
d2f29ba5
...
@@ -19,11 +19,13 @@ using compiler::FTVMCompute;
...
@@ -19,11 +19,13 @@ using compiler::FTVMCompute;
using
tvm
::
Tensor
;
using
tvm
::
Tensor
;
using
tvm
::
Array
;
using
tvm
::
Array
;
DMLC_REGISTER_PARAMETER
(
N
MS
Param
);
DMLC_REGISTER_PARAMETER
(
N
onMaximumSuppression
Param
);
bool
NMSShape
(
const
NodeAttrs
&
attrs
,
bool
NMSShape
(
const
NodeAttrs
&
attrs
,
std
::
vector
<
TShape
>
*
in_attrs
,
std
::
vector
<
TShape
>
*
in_attrs
,
std
::
vector
<
TShape
>
*
out_attrs
)
{
std
::
vector
<
TShape
>
*
out_attrs
)
{
const
NonMaximumSuppressionParam
&
param
=
nnvm
::
get
<
NonMaximumSuppressionParam
>
(
attrs
.
parsed
);
CHECK_EQ
(
in_attrs
->
size
(),
2U
)
<<
"Inputs: [data, valid_count]"
;
CHECK_EQ
(
in_attrs
->
size
(),
2U
)
<<
"Inputs: [data, valid_count]"
;
TShape
dshape
=
in_attrs
->
at
(
0
);
TShape
dshape
=
in_attrs
->
at
(
0
);
TShape
vshape
=
in_attrs
->
at
(
1
);
TShape
vshape
=
in_attrs
->
at
(
1
);
...
@@ -33,7 +35,14 @@ bool NMSShape(const NodeAttrs& attrs,
...
@@ -33,7 +35,14 @@ bool NMSShape(const NodeAttrs& attrs,
"(batch_size, num_anchors, 6)."
;
"(batch_size, num_anchors, 6)."
;
CHECK_EQ
(
dshape
[
0
],
vshape
[
0
])
<<
"batch_size mismatch."
;
CHECK_EQ
(
dshape
[
0
],
vshape
[
0
])
<<
"batch_size mismatch."
;
out_attrs
->
clear
();
out_attrs
->
clear
();
if
(
param
.
return_indices
)
{
TShape
oshape
=
TShape
(
2
);
oshape
[
0
]
=
dshape
[
0
];
oshape
[
1
]
=
dshape
[
1
];
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_attrs
,
0
,
oshape
);
}
else
{
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_attrs
,
0
,
dshape
);
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_attrs
,
0
,
dshape
);
}
return
true
;
return
true
;
}
}
...
@@ -56,15 +65,15 @@ inline bool NMSInferLayout(const NodeAttrs& attrs,
...
@@ -56,15 +65,15 @@ inline bool NMSInferLayout(const NodeAttrs& attrs,
return
true
;
return
true
;
}
}
NNVM_REGISTER_OP
(
n
ms
)
NNVM_REGISTER_OP
(
n
on_max_suppression
)
.
describe
(
R"doc("Non-maximum suppression."
.
describe
(
R"doc("Non-maximum suppression."
)doc"
NNVM_ADD_FILELINE
)
)doc"
NNVM_ADD_FILELINE
)
.
set_num_inputs
(
2
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
set_num_outputs
(
1
)
.
set_attr_parser
(
ParamParser
<
N
MS
Param
>
)
.
set_attr_parser
(
ParamParser
<
N
onMaximumSuppression
Param
>
)
.
set_attr
<
FGetAttrDict
>
(
"FGetAttrDict"
,
.
set_attr
<
FGetAttrDict
>
(
"FGetAttrDict"
,
ParamGetAttrDict
<
N
MS
Param
>
)
ParamGetAttrDict
<
N
onMaximumSuppression
Param
>
)
.
add_arguments
(
N
MS
Param
::
__FIELDS__
())
.
add_arguments
(
N
onMaximumSuppression
Param
::
__FIELDS__
())
.
add_argument
(
"data"
,
"Tensor"
,
"Input data."
)
.
add_argument
(
"data"
,
"Tensor"
,
"Input data."
)
.
add_argument
(
"valid_count"
,
"Tensor"
,
"Number of valid anchor boxes."
)
.
add_argument
(
"valid_count"
,
"Tensor"
,
"Number of valid anchor boxes."
)
.
set_attr
<
FListInputNames
>
(
"FListInputNames"
,
[](
const
NodeAttrs
&
attrs
)
{
.
set_attr
<
FListInputNames
>
(
"FListInputNames"
,
[](
const
NodeAttrs
&
attrs
)
{
...
...
nnvm/tests/python/compiler/test_top_level4.py
View file @
d2f29ba5
...
@@ -550,7 +550,7 @@ def test_multibox_transform_loc():
...
@@ -550,7 +550,7 @@ def test_multibox_transform_loc():
anchors
=
sym
.
Variable
(
"anchors"
)
anchors
=
sym
.
Variable
(
"anchors"
)
transform_loc_data
,
valid_count
=
sym
.
multibox_transform_loc
(
cls_prob
=
cls_prob
,
loc_pred
=
loc_preds
,
transform_loc_data
,
valid_count
=
sym
.
multibox_transform_loc
(
cls_prob
=
cls_prob
,
loc_pred
=
loc_preds
,
anchor
=
anchors
)
anchor
=
anchors
)
out
=
sym
.
n
ms
(
data
=
transform_loc_data
,
valid_count
=
valid_count
)
out
=
sym
.
n
on_max_suppression
(
data
=
transform_loc_data
,
valid_count
=
valid_count
,
return_indices
=
False
)
# Manually create test case
# Manually create test case
np_cls_prob
=
np
.
array
([[[
0.2
,
0.5
,
0.3
],
[
0.25
,
0.3
,
0.45
],
[
0.7
,
0.1
,
0.2
]]])
np_cls_prob
=
np
.
array
([[[
0.2
,
0.5
,
0.3
],
[
0.25
,
0.3
,
0.45
],
[
0.7
,
0.1
,
0.2
]]])
...
@@ -573,22 +573,22 @@ def test_multibox_transform_loc():
...
@@ -573,22 +573,22 @@ def test_multibox_transform_loc():
out
=
m
.
get_output
(
0
,
tvm
.
nd
.
empty
(
expected_np_out
.
shape
,
dtype
))
out
=
m
.
get_output
(
0
,
tvm
.
nd
.
empty
(
expected_np_out
.
shape
,
dtype
))
tvm
.
testing
.
assert_allclose
(
out
.
asnumpy
(),
expected_np_out
,
atol
=
1e-5
,
rtol
=
1e-5
)
tvm
.
testing
.
assert_allclose
(
out
.
asnumpy
(),
expected_np_out
,
atol
=
1e-5
,
rtol
=
1e-5
)
def
test_n
ms
():
def
test_n
on_max_suppression
():
dshape
=
(
1
,
5
,
6
)
dshape
=
(
1
,
5
,
6
)
data
=
sym
.
Variable
(
"data"
)
data
=
sym
.
Variable
(
"data"
)
valid_count
=
sym
.
Variable
(
"valid_count"
,
dtype
=
"int32"
)
valid_count
=
sym
.
Variable
(
"valid_count"
,
dtype
=
"int32"
)
nms
_threshold
=
0.7
iou
_threshold
=
0.7
force_suppress
=
True
force_suppress
=
True
nms_top
k
=
2
top_
k
=
2
out
=
sym
.
n
ms
(
data
=
data
,
valid_count
=
valid_count
,
nms_threshold
=
nms_threshold
,
out
=
sym
.
n
on_max_suppression
(
data
=
data
,
valid_count
=
valid_count
,
return_indices
=
False
,
force_suppress
=
force_suppress
,
nms_topk
=
nms_top
k
)
iou_threshold
=
iou_threshold
,
force_suppress
=
force_suppress
,
top_k
=
top_
k
)
np_data
=
np
.
array
([[[
0
,
0.8
,
1
,
20
,
25
,
45
],
[
1
,
0.7
,
30
,
60
,
50
,
80
],
np_data
=
np
.
array
([[[
0
,
0.8
,
1
,
20
,
25
,
45
],
[
1
,
0.7
,
30
,
60
,
50
,
80
],
[
0
,
0.4
,
4
,
21
,
19
,
40
],
[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
0
,
0.4
,
4
,
21
,
19
,
40
],
[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
1
,
0.5
,
100
,
60
,
70
,
110
]]])
.
astype
(
"float32"
)
[
1
,
0.5
,
100
,
60
,
70
,
110
]]])
.
astype
(
"float32"
)
np_valid_count
=
np
.
array
([
4
])
.
astype
(
"int32"
)
np_valid_count
=
np
.
array
([
4
])
.
astype
(
"int32"
)
np_result
=
np
.
array
([[[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
0
,
0.8
,
1
,
20
,
25
,
45
],
np_result
=
np
.
array
([[[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
0
,
0.8
,
1
,
20
,
25
,
45
],
[
0
,
0.4
,
4
,
21
,
19
,
40
],
[
-
1
,
0.9
,
35
,
61
,
52
,
79
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]]])
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]]])
target
=
"llvm"
target
=
"llvm"
...
@@ -726,7 +726,7 @@ if __name__ == "__main__":
...
@@ -726,7 +726,7 @@ if __name__ == "__main__":
test_flip
()
test_flip
()
test_multibox_prior
()
test_multibox_prior
()
test_multibox_transform_loc
()
test_multibox_transform_loc
()
test_n
ms
()
test_n
on_max_suppression
()
test_slice_like
()
test_slice_like
()
test_where
()
test_where
()
test_argmax
()
test_argmax
()
...
...
nnvm/tests/python/frontend/mxnet/test_forward.py
View file @
d2f29ba5
...
@@ -315,4 +315,3 @@ if __name__ == '__main__':
...
@@ -315,4 +315,3 @@ if __name__ == '__main__':
test_forward_slice
()
test_forward_slice
()
test_forward_maximum
()
test_forward_maximum
()
test_forward_minimum
()
test_forward_minimum
()
python/tvm/relay/frontend/mxnet.py
View file @
d2f29ba5
...
@@ -328,13 +328,14 @@ def _mx_multibox_detection(inputs, attrs):
...
@@ -328,13 +328,14 @@ def _mx_multibox_detection(inputs, attrs):
0.2
,
0.2
))
0.2
,
0.2
))
new_attrs1
=
{}
new_attrs1
=
{}
new_attrs1
[
"overlap_threshold"
]
=
attrs
.
get_float
(
"nms_threshold"
,
0.5
)
new_attrs1
[
"return_indices"
]
=
False
new_attrs1
[
"iou_threshold"
]
=
attrs
.
get_float
(
"nms_threshold"
,
0.5
)
new_attrs1
[
"force_suppress"
]
=
attrs
.
get_bool
(
"force_suppress"
,
False
)
new_attrs1
[
"force_suppress"
]
=
attrs
.
get_bool
(
"force_suppress"
,
False
)
new_attrs1
[
"topk"
]
=
attrs
.
get_int
(
"nms_topk"
,
-
1
)
new_attrs1
[
"top
_
k"
]
=
attrs
.
get_int
(
"nms_topk"
,
-
1
)
ret
=
_op
.
vision
.
multibox_transform_loc
(
inputs
[
0
],
inputs
[
1
],
ret
=
_op
.
vision
.
multibox_transform_loc
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
**
new_attrs0
)
inputs
[
2
],
**
new_attrs0
)
return
_op
.
vision
.
n
ms
(
ret
[
0
],
ret
[
1
],
**
new_attrs1
)
return
_op
.
vision
.
n
on_max_suppression
(
ret
[
0
],
ret
[
1
],
**
new_attrs1
)
def
_mx_batch_dot
(
inputs
,
attrs
):
def
_mx_batch_dot
(
inputs
,
attrs
):
...
@@ -399,6 +400,49 @@ def _mx_proposal(inputs, attrs):
...
@@ -399,6 +400,49 @@ def _mx_proposal(inputs, attrs):
return
_op
.
vision
.
proposal
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
**
new_attrs
)
return
_op
.
vision
.
proposal
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
**
new_attrs
)
def
_mx_box_nms
(
inputs
,
attrs
):
force_suppress
=
attrs
.
get_bool
(
"force_suppress"
,
False
)
iou_thresh
=
attrs
.
get_float
(
'overlap_thresh'
,
0.5
)
top_k
=
attrs
.
get_int
(
'topk'
,
-
1
)
valid_thresh
=
attrs
.
get_float
(
'valid_thresh'
,
0
)
coord_start
=
attrs
.
get_int
(
'coord_start'
,
2
)
score_index
=
attrs
.
get_int
(
'score_index'
,
1
)
id_index
=
attrs
.
get_int
(
'id_index'
,
-
1
)
in_format
=
attrs
.
get_str
(
'in_format'
,
'corner'
)
out_format
=
attrs
.
get_str
(
'out_format'
,
'corner'
)
if
coord_start
!=
2
:
raise
RuntimeError
(
'coord_start
%
s is not supported.'
%
coord_start
)
if
score_index
!=
1
:
raise
RuntimeError
(
'score_index
%
s is not supported.'
%
score_index
)
if
id_index
!=
-
1
and
int
(
id_index
)
!=
0
:
raise
RuntimeError
(
'id_index
%
s is not supported.'
%
id_index
)
if
in_format
!=
'corner'
:
raise
RuntimeError
(
'in_format
%
s is not supported.'
%
in_format
)
if
out_format
!=
'corner'
:
raise
RuntimeError
(
'out_format
%
s is not supported.'
%
out_format
)
ret
=
_op
.
vision
.
get_valid_counts
(
inputs
[
0
],
score_threshold
=
valid_thresh
)
nms_out
=
_op
.
vision
.
non_max_suppression
(
ret
[
1
],
ret
[
0
],
iou_threshold
=
iou_thresh
,
force_suppress
=
force_suppress
,
top_k
=
top_k
,
id_index
=
id_index
,
return_indices
=
False
,
invalid_to_bottom
=
True
)
return
nms_out
def
_mx_l2_normalize
(
inputs
,
attrs
):
new_attrs
=
{}
mode
=
attrs
.
get_str
(
'mode'
,
'instance'
)
if
mode
!=
'channel'
:
raise
RuntimeError
(
'mode
%
s is not supported.'
%
mode
)
new_attrs
[
'eps'
]
=
attrs
.
get_float
(
'eps'
,
1e-10
)
new_attrs
[
'axis'
]
=
[
1
]
return
_op
.
nn
.
l2_normalize
(
inputs
[
0
],
**
new_attrs
)
# Note: due to attribute conversion constraint
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
# ops in the identity set must be attribute free
_identity_list
=
[
_identity_list
=
[
...
@@ -497,6 +541,7 @@ _convert_map = {
...
@@ -497,6 +541,7 @@ _convert_map = {
"BatchNorm"
:
_mx_batch_norm
,
"BatchNorm"
:
_mx_batch_norm
,
"BatchNorm_v1"
:
_mx_batch_norm
,
"BatchNorm_v1"
:
_mx_batch_norm
,
"LRN"
:
_mx_lrn
,
"LRN"
:
_mx_lrn
,
"L2Normalization"
:
_mx_l2_normalize
,
"slice"
:
_mx_slice
,
"slice"
:
_mx_slice
,
"slice_like"
:
_mx_slice_like
,
"slice_like"
:
_mx_slice_like
,
"slice_axis"
:
_mx_slice_axis
,
"slice_axis"
:
_mx_slice_axis
,
...
@@ -520,6 +565,7 @@ _convert_map = {
...
@@ -520,6 +565,7 @@ _convert_map = {
"_contrib_ROIAlign"
:
_mx_roi_align
,
"_contrib_ROIAlign"
:
_mx_roi_align
,
"_contrib_Proposal"
:
_mx_proposal
,
"_contrib_Proposal"
:
_mx_proposal
,
"_contrib_MultiProposal"
:
_mx_proposal
,
"_contrib_MultiProposal"
:
_mx_proposal
,
"_contrib_box_nms"
:
_mx_box_nms
,
# List of missing operators that are present in NNVMv1
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
# TODO(tvm-tvm): support all operators.
#
#
...
@@ -662,6 +708,8 @@ def from_mxnet(symbol,
...
@@ -662,6 +708,8 @@ def from_mxnet(symbol,
params
[
k
]
=
_nd
.
array
(
v
.
data
()
.
asnumpy
())
params
[
k
]
=
_nd
.
array
(
v
.
data
()
.
asnumpy
())
data
=
mx
.
sym
.
Variable
(
"data"
)
data
=
mx
.
sym
.
Variable
(
"data"
)
sym
=
symbol
(
data
)
sym
=
symbol
(
data
)
if
isinstance
(
sym
,
(
list
,
tuple
)):
sym
=
mx
.
sym
.
Group
(
sym
)
shape
,
dtype
=
_update_shape_dtype
(
shape
,
dtype
,
params
)
shape
,
dtype
=
_update_shape_dtype
(
shape
,
dtype
,
params
)
sym
=
_from_mxnet_impl
(
sym
,
shape
,
dtype
)
sym
=
_from_mxnet_impl
(
sym
,
shape
,
dtype
)
elif
isinstance
(
symbol
,
mx
.
gluon
.
Block
):
elif
isinstance
(
symbol
,
mx
.
gluon
.
Block
):
...
...
python/tvm/relay/op/transform.py
View file @
d2f29ba5
...
@@ -525,7 +525,7 @@ def strided_slice(data, begin, end, strides=None):
...
@@ -525,7 +525,7 @@ def strided_slice(data, begin, end, strides=None):
The indices to begin with in the slicing.
The indices to begin with in the slicing.
end: list of int
end: list of int
Indic
i
es indicating end of the slice.
Indices indicating end of the slice.
strides: list of int, optional
strides: list of int, optional
Specifies the stride values, it can be negative in that case,
Specifies the stride values, it can be negative in that case,
...
...
python/tvm/relay/op/vision/__init__.py
View file @
d2f29ba5
...
@@ -6,6 +6,6 @@ from .multibox import *
...
@@ -6,6 +6,6 @@ from .multibox import *
from
.nms
import
*
from
.nms
import
*
from
.rcnn
import
*
from
.rcnn
import
*
from
.yolo
import
*
from
.yolo
import
*
from
.
import
_multibox
from
.
import
_rcnn
from
.
import
_rcnn
from
.
import
_yolo
from
.
import
_yolo
from
.
import
_vision
python/tvm/relay/op/vision/_
multibox
.py
→
python/tvm/relay/op/vision/_
vision
.py
View file @
d2f29ba5
...
@@ -54,24 +54,46 @@ reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE)
...
@@ -54,24 +54,46 @@ reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE)
reg
.
register_pattern
(
"vision.multibox_detection"
,
OpPattern
.
OPAQUE
)
reg
.
register_pattern
(
"vision.multibox_detection"
,
OpPattern
.
OPAQUE
)
# Get counts of valid boxes
@reg.register_schedule("vision.get_valid_counts")
def schedule_get_valid_counts(_, outs, target):
    """Build the schedule for get_valid_counts.

    Delegates to the generic topi schedule while ``target`` is the active
    context; the first argument is unused.
    """
    with target:
        valid_counts_schedule = topi.generic.schedule_get_valid_counts(outs)
        return valid_counts_schedule
@reg.register_compute("vision.get_valid_counts")
def compute_get_valid_counts(attrs, inputs, _, target):
    """Build the compute for get_valid_counts.

    Reads ``score_threshold`` from the op attributes as a constant float and
    forwards it, with the first input, to the topi implementation.
    """
    boxes = inputs[0]
    threshold = get_const_float(attrs.score_threshold)
    return topi.vision.get_valid_counts(boxes, threshold)
# Register get_valid_counts with the OPAQUE op pattern.
reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)
# non-maximum suppression
# non-maximum suppression
@reg.register_schedule
(
"vision.n
ms
"
)
@reg.register_schedule
(
"vision.n
on_max_suppression
"
)
def
schedule_nms
(
_
,
outs
,
target
):
def
schedule_nms
(
_
,
outs
,
target
):
"""Schedule definition of nms"""
"""Schedule definition of nms"""
with
target
:
with
target
:
return
topi
.
generic
.
schedule_nms
(
outs
)
return
topi
.
generic
.
schedule_nms
(
outs
)
@reg.register_compute
(
"vision.n
ms
"
)
@reg.register_compute
(
"vision.n
on_max_suppression
"
)
def
compute_nms
(
attrs
,
inputs
,
_
,
target
):
def
compute_nms
(
attrs
,
inputs
,
_
,
target
):
"""Compute definition of nms"""
"""Compute definition of nms"""
overlap_threshold
=
get_const_float
(
attrs
.
overlap_threshold
)
return_indices
=
bool
(
get_const_int
(
attrs
.
return_indices
))
max_output_size
=
get_const_int
(
attrs
.
max_output_size
)
iou_threshold
=
get_const_float
(
attrs
.
iou_threshold
)
force_suppress
=
bool
(
get_const_int
(
attrs
.
force_suppress
))
force_suppress
=
bool
(
get_const_int
(
attrs
.
force_suppress
))
topk
=
get_const_int
(
attrs
.
topk
)
top_k
=
get_const_int
(
attrs
.
top_k
)
id_index
=
get_const_int
(
attrs
.
id_index
)
invalid_to_bottom
=
bool
(
get_const_int
(
attrs
.
invalid_to_bottom
))
return
[
return
[
topi
.
vision
.
nms
(
inputs
[
0
],
inputs
[
1
],
overlap_threshold
,
topi
.
vision
.
non_max_suppression
(
inputs
[
0
],
inputs
[
1
],
max_output_size
,
force_suppress
,
topk
)
iou_threshold
,
force_suppress
,
top_k
,
id_index
,
return_indices
,
invalid_to_bottom
)
]
]
reg
.
register_pattern
(
"vision.n
ms
"
,
OpPattern
.
OPAQUE
)
reg
.
register_pattern
(
"vision.n
on_max_suppression
"
,
OpPattern
.
OPAQUE
)
python/tvm/relay/op/vision/nms.py
View file @
d2f29ba5
"""Non-maximum suppression operations."""
"""Non-maximum suppression operations."""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
.
import
_make
from
.
import
_make
from
...expr
import
TupleWrapper
def
nms
(
data
,
def get_valid_counts(data, score_threshold=0.0):
    """Get valid count of bounding boxes given a score threshold.
    Also moves valid boxes to the top of input data.

    Parameters
    ----------
    data : relay.Expr
        Input data. 3-D tensor with shape [batch_size, num_anchors, 6].

    score_threshold : float, optional
        Lower limit of score for valid bounding boxes.
        Defaults to 0.0, the default of the operator attribute.

    Returns
    -------
    valid_count : relay.Expr
        1-D tensor for valid number of boxes.

    out_tensor : relay.Expr
        Rearranged data tensor.
    """
    # The op yields a 2-tuple (valid_count, out_tensor); TupleWrapper lets
    # callers index or unpack the two fields.
    return TupleWrapper(_make.get_valid_counts(data, score_threshold), 2)
def
non_max_suppression
(
data
,
valid_count
,
valid_count
,
overlap_threshold
=
0.5
,
max_output_size
=-
1
,
iou_threshold
=
0.5
,
force_suppress
=
False
,
force_suppress
=
False
,
topk
=-
1
):
top_k
=-
1
,
id_index
=
0
,
return_indices
=
True
,
invalid_to_bottom
=
False
):
"""Non-maximum suppression operator for object detection.
"""Non-maximum suppression operator for object detection.
Parameters
Parameters
...
@@ -19,18 +48,33 @@ def nms(data,
...
@@ -19,18 +48,33 @@ def nms(data,
valid_count : relay.Expr
valid_count : relay.Expr
1-D tensor for valid number of boxes.
1-D tensor for valid number of boxes.
overlap_threshold : float, optional
max_output_size : int, optional
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
iou_threshold : float, optional
Non-maximum suppression threshold.
Non-maximum suppression threshold.
force_suppress : bool, optional
force_suppress : bool, optional
Suppress all detections regardless of class_id.
Suppress all detections regardless of class_id.
topk : int, optional
top
_
k : int, optional
Keep maximum top k detections before nms, -1 for no limit.
Keep maximum top k detections before nms, -1 for no limit.
id_index : int, optional
index of the class categories, -1 to disable.
return_indices : bool, optional
Whether to return box indices in input data.
invalid_to_bottom : bool, optional
Whether to move all valid bounding boxes to the top.
Returns
Returns
-------
-------
out : relay.Expr
out : relay.Expr
3-D tensor with shape [batch_size, num_anchors, 6].
3-D tensor with shape [batch_size, num_anchors, 6].
"""
"""
return
_make
.
nms
(
data
,
valid_count
,
overlap_threshold
,
force_suppress
,
topk
)
return
_make
.
non_max_suppression
(
data
,
valid_count
,
max_output_size
,
iou_threshold
,
force_suppress
,
top_k
,
id_index
,
return_indices
,
invalid_to_bottom
)
src/relay/op/tensor/transform.cc
View file @
d2f29ba5
...
@@ -1516,6 +1516,16 @@ RELAY_REGISTER_OP("broadcast_to_like")
...
@@ -1516,6 +1516,16 @@ RELAY_REGISTER_OP("broadcast_to_like")
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kBroadcast
);
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kBroadcast
);
// Adapter function to make int array.
// Checks that every defined element of `arr` is an IntImm (undefined
// elements are allowed), then returns an Array<Integer> that shares the
// same underlying node as `arr`.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
  const size_t count = arr.size();
  for (size_t idx = 0; idx != count; ++idx) {
    const auto elem = arr[idx];
    if (elem.defined()) {
      CHECK(elem.as<IntImm>())
        << "Expect an int array";
    }
  }
  return Array<Integer>(arr.node_);
}
// strided_slice
// strided_slice
TVM_REGISTER_NODE_TYPE
(
StridedSliceAttrs
);
TVM_REGISTER_NODE_TYPE
(
StridedSliceAttrs
);
bool
StridedSliceRel
(
const
Array
<
Type
>&
types
,
bool
StridedSliceRel
(
const
Array
<
Type
>&
types
,
...
@@ -1870,15 +1880,6 @@ Expr MakeSliceLike(Expr data,
...
@@ -1870,15 +1880,6 @@ Expr MakeSliceLike(Expr data,
return
CallNode
::
make
(
op
,
{
data
,
shape_like
},
Attrs
(
attrs
),
{});
return
CallNode
::
make
(
op
,
{
data
,
shape_like
},
Attrs
(
attrs
),
{});
}
}
// Adapter function to make int array.
Array
<
Integer
>
GetIntArray
(
Array
<
IndexExpr
>
arr
)
{
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
++
i
)
{
CHECK
(
!
arr
[
i
].
defined
()
||
arr
[
i
].
as
<
IntImm
>
())
<<
"Expect an int array"
;
}
return
Array
<
Integer
>
(
arr
.
node_
);
}
Array
<
Tensor
>
SliceLikeCompute
(
const
Attrs
&
attrs
,
Array
<
Tensor
>
SliceLikeCompute
(
const
Attrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
inputs
,
const
Type
&
out_type
,
const
Type
&
out_type
,
...
...
src/relay/op/vision/multibox_op.cc
View file @
d2f29ba5
...
@@ -70,8 +70,10 @@ RELAY_REGISTER_OP("vision.multibox_prior")
...
@@ -70,8 +70,10 @@ RELAY_REGISTER_OP("vision.multibox_prior")
TVM_REGISTER_NODE_TYPE
(
MultiBoxTransformLocAttrs
);
TVM_REGISTER_NODE_TYPE
(
MultiBoxTransformLocAttrs
);
bool
MultiBoxTransformLocRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
bool
MultiBoxTransformLocRel
(
const
Array
<
Type
>&
types
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
4
);
CHECK_EQ
(
types
.
size
(),
4
);
const
auto
*
cls_prob
=
types
[
0
].
as
<
TensorTypeNode
>
();
const
auto
*
cls_prob
=
types
[
0
].
as
<
TensorTypeNode
>
();
...
...
src/relay/op/vision/nms.cc
View file @
d2f29ba5
...
@@ -9,7 +9,54 @@
...
@@ -9,7 +9,54 @@
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
TVM_REGISTER_NODE_TYPE
(
NMSAttrs
);
TVM_REGISTER_NODE_TYPE
(
GetValidCountsAttrs
);
/*!
 * \brief Type relation of vision.get_valid_counts.
 *
 * types[0] is the input data type and types[1] the output type. The output
 * is a tuple of (valid_count, out_tensor): an int32 tensor of shape
 * [data.shape[0]] and a tensor with the same shape and dtype as the input.
 *
 * \param types The input type followed by the output type (size 2).
 * \param num_inputs Unused.
 * \param attrs Unused.
 * \param reporter Receives the inferred output type.
 * \return true once the output type has been assigned; false when the
 *         input is not (yet) a tensor type, so nothing can be inferred.
 */
bool GetValidCountRel(const Array<Type>& types,
                      int num_inputs,
                      const Attrs& attrs,
                      const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  // as<>() yields nullptr when the input type is not a TensorTypeNode
  // (e.g. still unresolved); bail out instead of dereferencing it.
  if (data == nullptr) {
    return false;
  }
  const auto& dshape = data->shape;
  CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";

  // valid_count holds one entry per element of the leading dimension.
  std::vector<IndexExpr> oshape({data->shape[0]});
  std::vector<Type> fields;
  fields.push_back(TensorTypeNode::make(oshape, Int(32)));
  fields.push_back(TensorTypeNode::make(data->shape, data->dtype));

  // assign output type
  reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
  return true;
}
Expr
MakeGetValidCounts
(
Expr
data
,
double
score_threshold
)
{
auto
attrs
=
make_node
<
GetValidCountsAttrs
>
();
attrs
->
score_threshold
=
score_threshold
;
static
const
Op
&
op
=
Op
::
Get
(
"vision.get_valid_counts"
);
return
CallNode
::
make
(
op
,
{
data
},
Attrs
(
attrs
),
{});
}
TVM_REGISTER_API
(
"relay.op.vision._make.get_valid_counts"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
2
>
(
MakeGetValidCounts
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"vision.get_valid_counts"
)
.
describe
(
R"doc(Get valid count of bounding boxes given
a score threshold. Also moves valid boxes to the top of
input data.
)doc"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
1
)
.
add_argument
(
"data"
,
"Tensor"
,
"Input data."
)
.
set_support_level
(
5
)
.
add_type_rel
(
"GetValidCount"
,
GetValidCountRel
);
TVM_REGISTER_NODE_TYPE
(
NonMaximumSuppressionAttrs
);
bool
NMSRel
(
const
Array
<
Type
>&
types
,
bool
NMSRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
int
num_inputs
,
...
@@ -18,39 +65,56 @@ bool NMSRel(const Array<Type>& types,
...
@@ -18,39 +65,56 @@ bool NMSRel(const Array<Type>& types,
CHECK_EQ
(
types
.
size
(),
3
);
CHECK_EQ
(
types
.
size
(),
3
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
const
auto
*
valid_count
=
types
[
1
].
as
<
TensorTypeNode
>
();
const
auto
*
valid_count
=
types
[
1
].
as
<
TensorTypeNode
>
();
const
NonMaximumSuppressionAttrs
*
param
=
attrs
.
as
<
NonMaximumSuppressionAttrs
>
();
const
auto
&
dshape
=
data
->
shape
;
const
auto
&
dshape
=
data
->
shape
;
const
auto
&
vshape
=
valid_count
->
shape
;
const
auto
&
vshape
=
valid_count
->
shape
;
CHECK_EQ
(
dshape
.
size
(),
3
)
<<
"Input data should be 3-D."
;
CHECK_EQ
(
dshape
.
size
(),
3
)
<<
"Input data should be 3-D."
;
CHECK_EQ
(
vshape
.
size
(),
1
)
<<
"Input valid count should be 1-D."
;
CHECK_EQ
(
vshape
.
size
(),
1
)
<<
"Input valid count should be 1-D."
;
// assign output type
// assign output type
if
(
param
->
return_indices
)
{
std
::
vector
<
IndexExpr
>
oshape
({
dshape
[
0
],
dshape
[
1
]});
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
oshape
,
Int
(
32
)));
}
else
{
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
dshape
,
data
->
dtype
));
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
dshape
,
data
->
dtype
));
}
return
true
;
return
true
;
}
}
Expr
MakeNMS
(
Expr
data
,
Expr
MakeNMS
(
Expr
data
,
Expr
valid_count
,
Expr
valid_count
,
double
overlap_threshold
,
int
max_output_size
,
double
iou_threshold
,
bool
force_suppress
,
bool
force_suppress
,
int
topk
)
{
int
top_k
,
auto
attrs
=
make_node
<
NMSAttrs
>
();
int
id_index
,
attrs
->
overlap_threshold
=
overlap_threshold
;
bool
return_indices
,
bool
invalid_to_bottom
)
{
auto
attrs
=
make_node
<
NonMaximumSuppressionAttrs
>
();
attrs
->
max_output_size
=
max_output_size
;
attrs
->
iou_threshold
=
iou_threshold
;
attrs
->
force_suppress
=
force_suppress
;
attrs
->
force_suppress
=
force_suppress
;
attrs
->
topk
=
topk
;
attrs
->
top_k
=
top_k
;
static
const
Op
&
op
=
Op
::
Get
(
"vision.nms"
);
attrs
->
id_index
=
id_index
;
attrs
->
return_indices
=
return_indices
;
attrs
->
invalid_to_bottom
=
invalid_to_bottom
;
static
const
Op
&
op
=
Op
::
Get
(
"vision.non_max_suppression"
);
return
CallNode
::
make
(
op
,
{
data
,
valid_count
},
Attrs
(
attrs
),
{});
return
CallNode
::
make
(
op
,
{
data
,
valid_count
},
Attrs
(
attrs
),
{});
}
}
TVM_REGISTER_API
(
"relay.op.vision._make.n
ms
"
)
TVM_REGISTER_API
(
"relay.op.vision._make.n
on_max_suppression
"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
5
>
(
MakeNMS
,
args
,
rv
);
runtime
::
detail
::
unpack_call
<
Expr
,
9
>
(
MakeNMS
,
args
,
rv
);
});
});
RELAY_REGISTER_OP
(
"vision.nms"
)
RELAY_REGISTER_OP
(
"vision.non_max_suppression"
)
.
describe
(
R"doc("Non-maximum suppression."
.
describe
(
R"doc(Non-maximum suppression. The input boxes should
be in the format of [class_id, score, left, top, right, bottom].
Set id_index to be -1 to ignore class_id axis.
)doc"
TVM_ADD_FILELINE
)
)doc"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
2
)
.
set_num_inputs
(
2
)
.
add_argument
(
"data"
,
"Tensor"
,
"Input data."
)
.
add_argument
(
"data"
,
"Tensor"
,
"Input data."
)
...
...
tests/python/frontend/mxnet/test_forward.py
View file @
d2f29ba5
...
@@ -374,6 +374,11 @@ def test_forward_slice_like():
...
@@ -374,6 +374,11 @@ def test_forward_slice_like():
verify
((
3
,
4
),
(
2
,
3
),
(
0
))
verify
((
3
,
4
),
(
2
,
3
),
(
0
))
verify
((
3
,
4
),
(
2
,
3
),
(
-
1
))
verify
((
3
,
4
),
(
2
,
3
),
(
-
1
))
def
test_forward_l2_normalize
():
data
=
mx
.
sym
.
var
(
'data'
)
mx_sym
=
mx
.
sym
.
L2Normalization
(
data
,
mode
=
"channel"
)
verify_mxnet_frontend_impl
(
mx_sym
,
(
2
,
3
,
4
,
5
),
(
2
,
3
,
4
,
5
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_forward_mlp
()
test_forward_mlp
()
...
@@ -401,5 +406,6 @@ if __name__ == '__main__':
...
@@ -401,5 +406,6 @@ if __name__ == '__main__':
test_forward_broadcast_ops
()
test_forward_broadcast_ops
()
test_forward_elemwise_ops
()
test_forward_elemwise_ops
()
test_forward_scalar_ops
()
test_forward_scalar_ops
()
test_forward_slice_axis
()
test_forward_slice_like
()
test_forward_slice_like
()
test_forward_slice_axis
()
test_forward_l2_normalize
()
tests/python/relay/test_op_level10.py
View file @
d2f29ba5
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
"""
"""
import
numpy
as
np
import
numpy
as
np
import
tvm
import
tvm
import
topi.testing
from
tvm
import
relay
from
tvm
import
relay
from
tvm.relay.testing
import
ctx_list
from
tvm.relay.testing
import
ctx_list
import
topi
import
topi
...
...
tests/python/relay/test_op_level5.py
View file @
d2f29ba5
...
@@ -135,56 +135,107 @@ def test_multibox_prior():
...
@@ -135,56 +135,107 @@ def test_multibox_prior():
verify_multibox_prior
(
x
,
dshape
,
ref_res
,
clip
=
False
,
check_type_only
=
True
)
verify_multibox_prior
(
x
,
dshape
,
ref_res
,
clip
=
False
,
check_type_only
=
True
)
def
test_nms
():
def
test_get_valid_counts
():
def
verify_nms
(
x0_data
,
x1_data
,
dshape
,
ref_res
,
valid_count
,
def
verify_get_valid_counts
(
dshape
,
score_threshold
):
overlap_threshold
=
0.5
,
force_suppress
=
False
,
topk
=-
1
,
dtype
=
"float32"
batch_size
,
num_anchor
,
elem_length
=
dshape
np_data
=
np
.
random
.
uniform
(
size
=
dshape
)
.
astype
(
dtype
)
np_out1
=
np
.
zeros
(
shape
=
(
batch_size
,))
np_out2
=
np
.
zeros
(
shape
=
dshape
)
.
astype
(
dtype
)
for
i
in
range
(
batch_size
):
np_out1
[
i
]
=
0
inter_idx
=
0
for
j
in
range
(
num_anchor
):
score
=
np_data
[
i
,
j
,
1
]
if
score
>=
score_threshold
:
for
k
in
range
(
elem_length
):
np_out2
[
i
,
inter_idx
,
k
]
=
np_data
[
i
,
j
,
k
]
np_out1
[
i
]
+=
1
inter_idx
+=
1
if
j
>=
np_out1
[
i
]:
for
k
in
range
(
elem_length
):
np_out2
[
i
,
j
,
k
]
=
-
1
x
=
relay
.
var
(
"x"
,
relay
.
ty
.
TensorType
(
dshape
,
dtype
))
z
=
relay
.
vision
.
get_valid_counts
(
x
,
score_threshold
)
assert
"score_threshold"
in
z
.
astext
()
func
=
relay
.
Function
([
x
],
z
.
astuple
())
func
=
relay
.
ir_pass
.
infer_type
(
func
)
ctx_list
=
[(
"llvm"
,
tvm
.
cpu
(
0
))]
for
target
,
ctx
in
ctx_list
:
intrp
=
relay
.
create_executor
(
"debug"
,
ctx
=
ctx
,
target
=
target
)
out
=
intrp
.
evaluate
(
func
)(
np_data
)
tvm
.
testing
.
assert_allclose
(
out
[
0
]
.
asnumpy
(),
np_out1
,
rtol
=
1e-3
)
tvm
.
testing
.
assert_allclose
(
out
[
1
]
.
asnumpy
(),
np_out2
,
rtol
=
1e-3
)
verify_get_valid_counts
((
1
,
2500
,
6
),
0
)
verify_get_valid_counts
((
1
,
2500
,
6
),
-
1
)
verify_get_valid_counts
((
3
,
1000
,
6
),
0.55
)
verify_get_valid_counts
((
16
,
500
,
6
),
0.95
)
def
test_non_max_suppression
():
def
verify_nms
(
x0_data
,
x1_data
,
dshape
,
ref_res
,
ref_indices_res
,
iou_threshold
=
0.5
,
force_suppress
=
False
,
top_k
=-
1
,
check_type_only
=
False
):
check_type_only
=
False
):
x0
=
relay
.
var
(
"x0"
,
relay
.
ty
.
TensorType
(
dshape
,
"float32"
))
x0
=
relay
.
var
(
"x0"
,
relay
.
ty
.
TensorType
(
dshape
,
"float32"
))
x1
=
relay
.
var
(
"x1"
,
relay
.
ty
.
TensorType
((
dshape
[
0
],),
"int"
))
x1
=
relay
.
var
(
"x1"
,
relay
.
ty
.
TensorType
((
dshape
[
0
],),
"int"
))
z
=
relay
.
vision
.
nms
(
x0
,
x1
,
overlap_threshold
,
force_suppress
,
topk
)
z
=
relay
.
vision
.
non_max_suppression
(
x0
,
x1
,
-
1
,
iou_threshold
,
force_suppress
,
top_k
,
return_indices
=
False
)
assert
"overlap_threshold"
in
z
.
astext
()
z_indices
=
relay
.
vision
.
non_max_suppression
(
x0
,
x1
,
-
1
,
iou_threshold
,
force_suppress
,
top_k
)
assert
"iou_threshold"
in
z
.
astext
()
assert
"iou_threshold"
in
z_indices
.
astext
()
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
zz_indices
=
relay
.
ir_pass
.
infer_type
(
z_indices
)
assert
zz
.
checked_type
==
relay
.
ty
.
TensorType
(
dshape
,
"float32"
)
assert
zz
.
checked_type
==
relay
.
ty
.
TensorType
(
dshape
,
"float32"
)
assert
zz_indices
.
checked_type
==
relay
.
ty
.
TensorType
((
dshape
[
0
],
dshape
[
1
]),
"int32"
)
if
check_type_only
:
if
check_type_only
:
return
return
func
=
relay
.
Function
([
x0
,
x1
],
z
)
func
=
relay
.
Function
([
x0
,
x1
],
z
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func_indices
=
relay
.
Function
([
x0
,
x1
],
z_indices
)
func_indices
=
relay
.
ir_pass
.
infer_type
(
func_indices
)
ctx_list
=
[(
"llvm"
,
tvm
.
cpu
(
0
))]
ctx_list
=
[(
"llvm"
,
tvm
.
cpu
(
0
))]
for
target
,
ctx
in
ctx_list
:
for
target
,
ctx
in
ctx_list
:
intrp1
=
relay
.
create_executor
(
"graph"
,
ctx
=
ctx
,
target
=
target
)
intrp1
=
relay
.
create_executor
(
"graph"
,
ctx
=
ctx
,
target
=
target
)
op_res1
=
intrp1
.
evaluate
(
func
)(
x0_data
,
x1_data
)
op_res1
=
intrp1
.
evaluate
(
func
)(
x0_data
,
x1_data
)
op_indices_res1
=
intrp1
.
evaluate
(
func_indices
)(
x0_data
,
x1_data
)
tvm
.
testing
.
assert_allclose
(
op_res1
.
asnumpy
(),
ref_res
,
rtol
=
1e-5
)
tvm
.
testing
.
assert_allclose
(
op_res1
.
asnumpy
(),
ref_res
,
rtol
=
1e-5
)
tvm
.
testing
.
assert_allclose
(
op_indices_res1
.
asnumpy
(),
ref_indices_res
,
rtol
=
1e-5
)
intrp2
=
relay
.
create_executor
(
"debug"
,
ctx
=
ctx
,
target
=
target
)
intrp2
=
relay
.
create_executor
(
"debug"
,
ctx
=
ctx
,
target
=
target
)
op_res2
=
intrp2
.
evaluate
(
func
)(
x0_data
,
x1_data
)
op_res2
=
intrp2
.
evaluate
(
func
)(
x0_data
,
x1_data
)
op_indices_res2
=
intrp2
.
evaluate
(
func_indices
)(
x0_data
,
x1_data
)
tvm
.
testing
.
assert_allclose
(
op_res2
.
asnumpy
(),
ref_res
,
rtol
=
1e-5
)
tvm
.
testing
.
assert_allclose
(
op_res2
.
asnumpy
(),
ref_res
,
rtol
=
1e-5
)
tvm
.
testing
.
assert_allclose
(
op_indices_res2
.
asnumpy
(),
ref_indices_res
,
rtol
=
1e-5
)
np_data
=
np
.
array
([[[
0
,
0.8
,
1
,
20
,
25
,
45
],
[
1
,
0.7
,
30
,
60
,
50
,
80
],
np_data
=
np
.
array
([[[
0
,
0.8
,
1
,
20
,
25
,
45
],
[
1
,
0.7
,
30
,
60
,
50
,
80
],
[
0
,
0.4
,
4
,
21
,
19
,
40
],
[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
0
,
0.4
,
4
,
21
,
19
,
40
],
[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
1
,
0.5
,
100
,
60
,
70
,
110
]]])
.
astype
(
"float32"
)
[
1
,
0.5
,
100
,
60
,
70
,
110
]]])
.
astype
(
"float32"
)
np_valid_count
=
np
.
array
([
4
])
.
astype
(
"int32"
)
np_valid_count
=
np
.
array
([
4
])
.
astype
(
"int32"
)
np_result
=
np
.
array
([[[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
0
,
0.8
,
1
,
20
,
25
,
45
],
np_result
=
np
.
array
([[[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
0
,
0.8
,
1
,
20
,
25
,
45
],
[
0
,
0.4
,
4
,
21
,
19
,
40
],
[
-
1
,
0.9
,
35
,
61
,
52
,
79
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]]])
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]]])
np_indices_result
=
np
.
array
([[
3
,
0
,
-
1
,
-
1
,
-
1
]])
num_anchors
=
5
num_anchors
=
5
dshape
=
(
tvm
.
var
(
"n"
),
num_anchors
,
6
)
dshape
=
(
tvm
.
var
(
"n"
),
num_anchors
,
6
)
verify_nms
(
np_data
,
np_valid_count
,
dshape
,
np_result
,
dshape
[
0
]
,
verify_nms
(
np_data
,
np_valid_count
,
dshape
,
np_result
,
np_indices_result
,
force_suppress
=
True
,
topk
=
2
,
check_type_only
=
True
)
force_suppress
=
True
,
top
_
k
=
2
,
check_type_only
=
True
)
dshape
=
(
1
,
num_anchors
,
6
)
dshape
=
(
1
,
num_anchors
,
6
)
verify_nms
(
np_data
,
np_valid_count
,
dshape
,
np_result
,
dshape
[
0
]
,
verify_nms
(
np_data
,
np_valid_count
,
dshape
,
np_result
,
np_indices_result
,
force_suppress
=
True
,
topk
=
2
,
check_type_only
=
False
)
force_suppress
=
True
,
top
_
k
=
2
,
check_type_only
=
False
)
np_result
=
np
.
array
([[[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
0
,
0.8
,
1
,
20
,
25
,
45
],
np_result
=
np
.
array
([[[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
0
,
0.8
,
1
,
20
,
25
,
45
],
[
1
,
0.7
,
30
,
60
,
50
,
80
],
[
-
1
,
0.9
,
35
,
61
,
52
,
79
],
[
1
,
0.7
,
30
,
60
,
50
,
80
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]]])
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]]])
np_indices_result
=
np
.
array
([[
3
,
0
,
1
,
-
1
,
-
1
]])
dshape
=
(
tvm
.
var
(
"n"
),
num_anchors
,
6
)
dshape
=
(
tvm
.
var
(
"n"
),
num_anchors
,
6
)
verify_nms
(
np_data
,
np_valid_count
,
dshape
,
np_result
,
dshape
[
0
],
verify_nms
(
np_data
,
np_valid_count
,
dshape
,
np_result
,
check_type_only
=
True
)
np_indices_result
,
check_type_only
=
True
)
dshape
=
(
1
,
num_anchors
,
6
)
dshape
=
(
1
,
num_anchors
,
6
)
verify_nms
(
np_data
,
np_valid_count
,
dshape
,
np_result
,
dshape
[
0
],
verify_nms
(
np_data
,
np_valid_count
,
dshape
,
np_result
,
top
k
=
3
)
np_indices_result
,
top_
k
=
3
)
def
test_multibox_transform_loc
():
def
test_multibox_transform_loc
():
...
@@ -226,7 +277,7 @@ def test_multibox_transform_loc():
...
@@ -226,7 +277,7 @@ def test_multibox_transform_loc():
assert
ret
.
checked_type
==
ref_type
assert
ret
.
checked_type
==
ref_type
nms
=
relay
.
vision
.
n
ms
(
mtl
[
0
],
mtl
[
1
]
)
nms
=
relay
.
vision
.
n
on_max_suppression
(
mtl
[
0
],
mtl
[
1
],
return_indices
=
False
)
func
=
relay
.
Function
([
cls_prob
,
loc_pred
,
anchors
],
nms
)
func
=
relay
.
Function
([
cls_prob
,
loc_pred
,
anchors
],
nms
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
ctx_list
=
[(
"llvm"
,
tvm
.
cpu
(
0
))]
ctx_list
=
[(
"llvm"
,
tvm
.
cpu
(
0
))]
...
@@ -411,8 +462,9 @@ if __name__ == "__main__":
...
@@ -411,8 +462,9 @@ if __name__ == "__main__":
test_resize
()
test_resize
()
test_multibox_prior
()
test_multibox_prior
()
test_multibox_transform_loc
()
test_multibox_transform_loc
()
test_
nm
s
()
test_
get_valid_count
s
()
test_roi_align
()
test_roi_align
()
test_proposal
()
test_proposal
()
test_yolo_reorg_infer_shape
()
test_yolo_reorg_infer_shape
()
test_yolo_reorg
()
test_yolo_reorg
()
test_non_max_suppression
()
topi/include/topi/nn/l2_normalize.h
View file @
d2f29ba5
...
@@ -30,7 +30,12 @@ inline Tensor l2_normalize(const Tensor& data,
...
@@ -30,7 +30,12 @@ inline Tensor l2_normalize(const Tensor& data,
const
Array
<
Integer
>&
axis
,
const
Array
<
Integer
>&
axis
,
std
::
string
name
=
"tensor"
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
"l2_normalize"
)
{
std
::
string
tag
=
"l2_normalize"
)
{
CHECK_EQ
(
data
->
shape
.
size
(),
4
)
<<
"L2 normalization requires 4-D input"
;
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
++
i
)
{
int
ax
=
topi
::
detail
::
GetConstInt
(
axis
[
i
]);
CHECK_LT
(
ax
,
data
->
shape
.
size
())
<<
"Axis "
<<
ax
<<
" exceeds input data dim "
<<
data
->
shape
.
size
();
}
auto
input_shape
=
data
->
shape
;
auto
input_shape
=
data
->
shape
;
Tensor
dot_value
=
topi
::
power
(
data
,
static_cast
<
float
>
(
2
.
0
));
Tensor
dot_value
=
topi
::
power
(
data
,
static_cast
<
float
>
(
2
.
0
));
Tensor
sum_value
=
topi
::
sum
(
dot_value
,
axis
,
true
);
Tensor
sum_value
=
topi
::
sum
(
dot_value
,
axis
,
true
);
...
...
topi/python/topi/cuda/nms.py
View file @
d2f29ba5
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison
, unused-argument
"""Non-maximum suppression operator"""
"""Non-maximum suppression operator"""
import
math
import
math
import
tvm
import
tvm
from
tvm
import
api
from
tvm
import
api
from
topi.vision
import
n
ms
from
topi.vision
import
n
on_max_suppression
from
..util
import
get_const_tuple
from
..util
import
get_const_tuple
def
sort_ir
(
data
,
index
,
output
):
def
sort_ir
(
data
,
index
,
output
):
...
@@ -181,13 +181,14 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
...
@@ -181,13 +181,14 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
return
body
return
body
@nms.register
([
"cuda"
,
"gpu"
])
@non_max_suppression.register
([
"cuda"
,
"gpu"
])
def
nms_gpu
(
data
,
valid_count
,
nms_threshold
=
0.5
,
force_suppress
=
False
,
nms_topk
=-
1
):
def
nms_gpu
(
data
,
valid_count
,
return_indices
,
iou_threshold
=
0.5
,
force_suppress
=
False
,
topk
=-
1
,
id_index
=
0
,
invalid_to_bottom
=
False
):
"""Non-maximum suppression operator for object detection.
"""Non-maximum suppression operator for object detection.
Parameters
Parameters
----------
----------
data: tvm.Tensor
data
: tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
[class_id, score, box_left, box_top, box_right, box_bottom].
...
@@ -195,15 +196,24 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
...
@@ -195,15 +196,24 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
valid_count : tvm.Tensor
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
1-D tensor for valid number of boxes.
nms_threshold : float
return_indices : boolean
Whether to return box indices in input data.
iou_threshold : optional, float
Non-maximum suppression threshold.
Non-maximum suppression threshold.
force_suppress : boolean
force_suppress :
optional,
boolean
Whether to suppress all detections regardless of class_id.
Whether to suppress all detections regardless of class_id.
nms_topk :
int
topk : optional,
int
Keep maximum top k detections before nms, -1 for no limit.
Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
index of the class categories, -1 to disable.
invalid_to_bottom : optional, boolean
Whether to move all valid bounding boxes to the top.
Returns
Returns
-------
-------
out : tvm.Tensor
out : tvm.Tensor
...
@@ -216,14 +226,13 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
...
@@ -216,14 +226,13 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
# An example to use nms
# An example to use nms
dshape = (1, 5, 6)
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder(
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
(dshape[0],), dtype="int32", name="valid_count")
iou_threshold = 0.7
nms_threshold = 0.7
force_suppress = True
force_suppress = True
nms_
topk = -1
topk = -1
out = nms(data, valid_count,
nms_threshold, force_suppress, nms_
topk)
out = nms(data, valid_count,
iou_threshold, force_suppress,
topk)
np_data = np.random.uniform(
size=dshape).astype("float32"
)
np_data = np.random.uniform(
dshape
)
np_valid_count = np.array([4])
.astype("int32")
np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out)
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
ctx = tvm.cpu()
...
@@ -263,8 +272,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
...
@@ -263,8 +272,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
tvm
.
extern
(
data
.
shape
,
tvm
.
extern
(
data
.
shape
,
[
data
,
sort_tensor
,
valid_count
],
[
data
,
sort_tensor
,
valid_count
],
lambda
ins
,
outs
:
nms_ir
(
lambda
ins
,
outs
:
nms_ir
(
ins
[
0
],
ins
[
1
],
ins
[
2
],
outs
[
0
],
nms
_threshold
,
ins
[
0
],
ins
[
1
],
ins
[
2
],
outs
[
0
],
iou
_threshold
,
force_suppress
,
nms_
topk
),
force_suppress
,
topk
),
dtype
=
"float32"
,
dtype
=
"float32"
,
in_buffers
=
[
data_buf
,
sort_tensor_buf
,
valid_count_buf
],
in_buffers
=
[
data_buf
,
sort_tensor_buf
,
valid_count_buf
],
tag
=
"nms"
)
tag
=
"nms"
)
...
...
topi/python/topi/cuda/ssd/multibox.py
View file @
d2f29ba5
...
@@ -11,7 +11,7 @@ import topi
...
@@ -11,7 +11,7 @@ import topi
from
topi.vision.ssd
import
multibox_prior
from
topi.vision.ssd
import
multibox_prior
from
topi.vision.ssd
import
multibox_detection
from
topi.vision.ssd
import
multibox_detection
from
topi.vision.ssd
import
multibox_transform_loc
from
topi.vision.ssd
import
multibox_transform_loc
from
..nms
import
n
ms
from
..nms
import
n
on_max_suppression
def
multibox_prior_ir
(
data
,
out
,
sizes
,
ratios
,
steps
,
offsets
):
def
multibox_prior_ir
(
data
,
out
,
sizes
,
ratios
,
steps
,
offsets
):
...
@@ -437,6 +437,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
...
@@ -437,6 +437,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
"""
"""
inter_out
=
multibox_transform_loc
(
cls_prob
,
loc_pred
,
anchor
,
inter_out
=
multibox_transform_loc
(
cls_prob
,
loc_pred
,
anchor
,
clip
,
threshold
,
variances
)
clip
,
threshold
,
variances
)
out
=
n
ms
(
out
=
n
on_max_suppression
(
inter_out
[
0
],
inter_out
[
1
],
nms_threshold
,
force_suppress
,
nms_topk
)
inter_out
[
0
],
inter_out
[
1
],
nms_threshold
,
force_suppress
,
nms_topk
)
return
out
return
out
topi/python/topi/cuda/vision.py
View file @
d2f29ba5
...
@@ -162,3 +162,20 @@ def schedule_proposal(outs):
...
@@ -162,3 +162,20 @@ def schedule_proposal(outs):
scheduled_ops
.
append
(
op
)
scheduled_ops
.
append
(
op
)
traverse
(
outs
[
0
]
.
op
)
traverse
(
outs
[
0
]
.
op
)
return
s
return
s
@generic.schedule_get_valid_counts.register
([
"cuda"
,
"gpu"
])
def
schedule_get_valid_counts
(
outs
):
"""Schedule for get_valid_counts operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of get_valid_counts
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return
_default_schedule
(
outs
)
topi/python/topi/generic/vision.py
View file @
d2f29ba5
...
@@ -37,6 +37,23 @@ def schedule_reorg(outs):
...
@@ -37,6 +37,23 @@ def schedule_reorg(outs):
return
cpp
.
generic
.
default_schedule
(
cpp_target
,
outs
,
False
)
return
cpp
.
generic
.
default_schedule
(
cpp_target
,
outs
,
False
)
@tvm.target.generic_func
@tvm.target.generic_func
def
schedule_get_valid_counts
(
outs
):
"""Schedule for get_valid_counts
Parameters
----------
outs: Array of Tensor
The computation graph description of nms
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return
_default_schedule
(
outs
,
False
)
@tvm.target.generic_func
def
schedule_nms
(
outs
):
def
schedule_nms
(
outs
):
"""Schedule for non-maximum suppression
"""Schedule for non-maximum suppression
...
...
topi/python/topi/testing/__init__.py
View file @
d2f29ba5
...
@@ -20,3 +20,4 @@ from .l2_normalize_python import l2_normalize_python
...
@@ -20,3 +20,4 @@ from .l2_normalize_python import l2_normalize_python
from
.gather_nd_python
import
gather_nd_python
from
.gather_nd_python
import
gather_nd_python
from
.strided_slice_python
import
strided_slice_python
from
.strided_slice_python
import
strided_slice_python
from
.batch_matmul
import
batch_matmul
from
.batch_matmul
import
batch_matmul
from
.slice_axis_python
import
slice_axis_python
topi/python/topi/testing/slice_axis_python.py
0 → 100644
View file @
d2f29ba5
"""Slice axis in python"""
def
slice_axis_python
(
data
,
axis
,
begin
,
end
=
None
):
"""Slice input array along specific axis.
Parameters
----------
data : numpy.ndarray
The source array to be sliced.
axis : int
Axis to be sliced.
begin: int
The index to begin with in the slicing.
end: int, optional
The index indicating end of the slice.
Returns
-------
ret : numpy.ndarray
The computed result.
"""
dshape
=
data
.
shape
if
axis
<
0
:
axis
+=
len
(
dshape
)
if
begin
<
0
:
begin
+=
dshape
[
axis
]
if
end
<=
0
:
end
+=
dshape
[
axis
]
slc
=
[
slice
(
None
)]
*
len
(
dshape
)
slc
[
axis
]
=
slice
(
begin
,
end
)
return
data
[
tuple
(
slc
)]
topi/python/topi/vision/nms.py
View file @
d2f29ba5
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments
, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements
"""Non-maximum suppression operator"""
"""Non-maximum suppression operator"""
import
tvm
import
tvm
from
tvm
import
api
from
tvm
import
api
,
hybrid
def
nms_ir
(
data
,
sort_result
,
valid_count
,
out
,
nms_threshold
,
force_suppress
,
nms_topk
):
@hybrid.script
"""Low level IR routing for transform location in multibox_detection operator.
def
hybrid_rearrange_out
(
data
):
"""Hybrid routine to rearrange nms output to
move all valid entries to top.
Parameters
Parameters
----------
----------
data: Buffer
data : tvm.Tensor or numpy NDArray
Buffer of output boxes with class and score.
NMS output. 3-D tensor with shape
[batch_size, num_anchors, 6].
sort_result : Buffer
Returns
Buffer of output box indexes sorted by score.
-------
output : tvm.Tensor or numpy NDArray
Transformed NMS output. 3-D tensor with shape
[batch_size, num_anchors, 6].
"""
batch_size
=
data
.
shape
[
0
]
num_anchors
=
data
.
shape
[
1
]
elem_length
=
data
.
shape
[
2
]
output
=
output_tensor
((
batch_size
,
num_anchors
,
elem_length
),
data
.
dtype
)
valid_count : Buffer
for
i
in
parallel
(
batch_size
):
Buffer of number of valid output boxes.
valid_idx
=
0
for
j
in
range
(
num_anchors
):
if
data
[
i
,
j
,
0
]
>=
0
:
for
k
in
range
(
elem_length
):
output
[
i
,
valid_idx
,
k
]
=
data
[
i
,
j
,
k
]
valid_idx
+=
1
if
j
>=
valid_idx
:
for
k
in
range
(
elem_length
):
output
[
i
,
j
,
k
]
=
-
1.0
return
output
out : Buffer
Output buffer.
nms_threshold : float
@hybrid.script
Non-maximum suppression threshold.
def
hybrid_get_valid_counts
(
data
,
score_threshold
):
"""Hybrid routine to get valid count of bounding boxes
given a score threshold. Also moves valid boxes to the
top of input data.
Parameters
----------
data : tvm.Tensor or numpy NDArray
Input data. 3-D tensor with shape [batch_size, num_anchors, 6].
score_threshold : tvm.const
Lower limit of score for valid bounding boxes.
Returns
-------
out_tensor : tvm.Tensor or numpy NDArray
Rearranged data tensor.
valid_count : tvm.Tensor or numpy NDArray
1-D tensor for valid number of boxes.
"""
batch_size
=
data
.
shape
[
0
]
num_anchors
=
data
.
shape
[
1
]
box_data_length
=
data
.
shape
[
2
]
valid_count
=
output_tensor
((
batch_size
,),
"int32"
)
out_tensor
=
output_tensor
((
batch_size
,
num_anchors
,
box_data_length
),
data
.
dtype
)
for
i
in
parallel
(
batch_size
):
valid_count
[
i
]
=
0
for
j
in
range
(
num_anchors
):
score
=
data
[
i
,
j
,
1
]
if
score
>
score_threshold
:
for
k
in
range
(
box_data_length
):
out_tensor
[
i
,
valid_count
[
i
],
k
]
=
data
[
i
,
j
,
k
]
valid_count
[
i
]
+=
1
if
j
>=
valid_count
[
i
]:
for
k
in
range
(
box_data_length
):
out_tensor
[
i
,
j
,
k
]
=
-
1.0
return
valid_count
,
out_tensor
@tvm.target.generic_func
def
get_valid_counts
(
data
,
score_threshold
=
0
):
"""Get valid count of bounding boxes given a score threshold.
Also moves valid boxes to the top of input data.
Parameters
----------
data : tvm.Tensor
Input data. 3-D tensor with shape [batch_size, num_anchors, 6].
score_threshold : optional, float
Lower limit of score for valid bounding boxes.
Returns
-------
out_tensor : tvm.Tensor
Rearranged data tensor.
force_suppress : boolean
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
"""
score_threshold_const
=
tvm
.
const
(
score_threshold
,
"float"
)
return
hybrid_get_valid_counts
(
data
,
score_threshold_const
)
@hybrid.script
def
hybrid_nms
(
data
,
sorted_index
,
valid_count
,
max_output_size
,
iou_threshold
,
force_suppress
,
top_k
,
id_index
):
"""Hybrid routing for non-maximum suppression.
Parameters
----------
data: tvm.Tensor or numpy NDArray
Bounding boxes with class and score. 3-D tensor with shape
[batch_size, num_anchors, 6].
sorted_index : tvm.Tensor or numpy NDArray
Bounding box indexes sorted by score, with shape
[batch_size, num_anchors].
valid_count : tvm.Tensor or numpy NDArray
1-D tensor for valid number of boxes.
max_output_size : tvm.const
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
iou_threshold : tvm.const
Overlapping(IoU) threshold to suppress object with smaller score.
force_suppress : tvm.const
Whether to suppress all detections regardless of class_id.
Whether to suppress all detections regardless of class_id.
nms_topk : in
t
top_k : tvm.cons
t
Keep maximum top k detections before nms, -1 for no limit.
Keep maximum top k detections before nms, -1 for no limit.
id_index : tvm.const
index of the class categories, -1 to disable.
Returns
Returns
-------
-------
stmt : Stmt
output : tvm.Tensor
The result IR statement
.
3-D tensor with shape [batch_size, num_anchors, 6]
.
"""
def
calculate_overlap
(
out_tensor
,
box_a_idx
,
box_b_idx
):
box_indices: tvm.Tensor
"""Calculate overlap of two boxes
.
2-D tensor with shape [batch_size, num_anchors]
.
"""
"""
w
=
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
out_tensor
[
box_a_idx
+
2
],
out_tensor
[
box_b_idx
+
2
])
batch_size
=
data
.
shape
[
0
]
-
tvm
.
make
.
Max
(
out_tensor
[
box_a_idx
],
out_tensor
[
box_b_idx
]))
num_anchors
=
data
.
shape
[
1
]
h
=
tvm
.
make
.
Max
(
0.0
,
tvm
.
make
.
Min
(
out_tensor
[
box_a_idx
+
3
],
out_tensor
[
box_b_idx
+
3
])
box_data_length
=
data
.
shape
[
2
]
-
tvm
.
make
.
Max
(
out_tensor
[
box_a_idx
+
1
],
out_tensor
[
box_b_idx
+
1
]))
box_indices
=
output_tensor
((
batch_size
,
num_anchors
),
"int32"
)
i
=
w
*
h
output
=
output_tensor
((
batch_size
,
u
=
(
out_tensor
[
box_a_idx
+
2
]
-
out_tensor
[
box_a_idx
])
*
\
num_anchors
,
(
out_tensor
[
box_a_idx
+
3
]
-
out_tensor
[
box_a_idx
+
1
])
+
\
box_data_length
,),
(
out_tensor
[
box_b_idx
+
2
]
-
out_tensor
[
box_b_idx
])
*
\
data
.
dtype
)
(
out_tensor
[
box_b_idx
+
3
]
-
out_tensor
[
box_b_idx
+
1
])
-
i
return
tvm
.
expr
.
Select
(
u
<=
0.0
,
0.0
,
i
/
u
)
for
i
in
parallel
(
batch_size
):
if
iou_threshold
>
0
:
ib
=
tvm
.
ir_builder
.
create
()
if
valid_count
[
i
]
>
0
:
p_data
=
ib
.
buffer_ptr
(
data
)
p_sort_result
=
ib
.
buffer_ptr
(
sort_result
)
p_valid_count
=
ib
.
buffer_ptr
(
valid_count
)
p_out
=
ib
.
buffer_ptr
(
out
)
batch_size
=
out
.
shape
[
0
]
num_anchors
=
out
.
shape
[
1
]
nms_threshold_node
=
tvm
.
make
.
node
(
"FloatImm"
,
dtype
=
"float32"
,
value
=
nms_threshold
)
nms_topk_node
=
tvm
.
make
.
node
(
"IntImm"
,
dtype
=
"int32"
,
value
=
nms_topk
)
force_suppress_node
=
tvm
.
make
.
node
(
"IntImm"
,
dtype
=
"int32"
,
value
=
1
if
force_suppress
else
0
)
with
ib
.
for_range
(
0
,
batch_size
,
for_type
=
"parallel"
,
name
=
"n"
)
as
n
:
with
ib
.
if_scope
(
tvm
.
all
(
nms_threshold_node
>
0
,
nms_threshold_node
<
1
,
p_valid_count
[
0
]
>
0
)):
# Reorder output
# Reorder output
nkeep
=
tvm
.
if_then_else
(
nkeep
=
valid_count
[
i
]
tvm
.
all
(
nms_topk_node
>
0
,
nms_topk
<
p_valid_count
[
n
]),
if
0
<
top_k
<
nkeep
:
nms_topk
,
p_valid_count
[
n
])
nkeep
=
top_k
with
ib
.
for_range
(
0
,
nkeep
,
name
=
"l"
)
as
l
:
for
j
in
range
(
nkeep
):
with
ib
.
for_range
(
0
,
6
,
name
=
"m"
)
as
m
:
for
k
in
range
(
box_data_length
):
p_out
[(
n
*
num_anchors
*
6
output
[
i
,
j
,
k
]
=
data
[
i
,
sorted_index
[
i
,
j
],
k
]
+
l
*
6
+
m
)]
=
p_data
[(
n
*
num_anchors
*
6
box_indices
[
i
,
j
]
=
sorted_index
[
i
,
j
]
+
p_sort_result
[
n
*
num_anchors
+
l
]
*
6
+
m
)]
if
0
<
top_k
<
valid_count
[
i
]:
with
ib
.
if_scope
(
tvm
.
all
(
nms_topk_node
>
0
,
nms_topk
<
p_valid_count
[
n
])):
for
j
in
range
(
valid_count
[
i
]
-
nkeep
):
with
ib
.
for_range
(
0
,
p_valid_count
[
n
]
-
nkeep
,
name
=
"l"
)
as
l
:
for
k
in
range
(
box_data_length
):
with
ib
.
for_range
(
0
,
6
,
name
=
"m"
)
as
m
:
output
[
i
,
j
+
nkeep
,
k
]
=
-
1.0
p_out
[(
n
*
num_anchors
*
6
box_indices
[
i
,
j
+
nkeep
]
=
-
1
+
(
l
+
nkeep
)
*
6
+
m
)]
=
p_data
[(
n
*
num_anchors
*
6
+
(
l
+
nkeep
)
*
6
+
m
)]
# Apply nms
# Apply nms
with
ib
.
for_range
(
0
,
p_valid_count
[
n
],
name
=
"l"
)
as
l
:
for
j
in
range
(
valid_count
[
i
]):
offset_l
=
l
*
6
if
output
[
i
,
j
,
0
]
>=
0
:
with
ib
.
if_scope
(
p_out
[
n
*
num_anchors
*
6
+
offset_l
]
>=
0
):
for
k
in
range
(
valid_count
[
i
]):
with
ib
.
for_range
(
0
,
p_valid_count
[
n
],
name
=
"m"
)
as
m
:
check_iou
=
0
offset_m
=
m
*
6
if
k
>
j
and
output
[
i
,
k
,
0
]
>=
0
:
with
ib
.
if_scope
(
tvm
.
all
(
m
>
l
,
p_out
[
n
*
num_anchors
*
6
if
force_suppress
:
+
offset_m
]
>=
0
)):
check_iou
=
1
with
ib
.
if_scope
(
tvm
.
any
(
force_suppress_node
>
0
,
elif
id_index
<
0
or
output
[
i
,
j
,
0
]
==
output
[
i
,
k
,
0
]:
p_out
[
n
*
num_anchors
*
6
+
offset_l
]
==
check_iou
=
1
p_out
[
n
*
num_anchors
*
6
+
offset_m
])):
if
check_iou
>
0
:
# When force_suppress == True or class_id equals
batch_idx
=
i
iou
=
calculate_overlap
(
p_out
,
n
*
num_anchors
*
6
+
offset_l
+
2
,
box_a_idx
=
j
n
*
num_anchors
*
6
+
offset_m
+
2
)
box_b_idx
=
k
with
ib
.
if_scope
(
iou
>=
nms_threshold
):
box_start_idx
=
2
p_out
[
n
*
num_anchors
*
6
+
offset_m
]
=
-
1.0
a_t
=
output
[
batch_idx
,
box_a_idx
,
box_start_idx
+
1
]
with
ib
.
else_scope
():
a_b
=
output
[
batch_idx
,
box_a_idx
,
box_start_idx
+
3
]
with
ib
.
for_range
(
0
,
p_valid_count
[
n
],
name
=
"l"
)
as
l
:
a_l
=
output
[
batch_idx
,
box_a_idx
,
box_start_idx
]
with
ib
.
for_range
(
0
,
6
,
name
=
"m"
)
as
m
:
a_r
=
output
[
batch_idx
,
box_a_idx
,
box_start_idx
+
2
]
p_out
[(
n
*
num_anchors
*
6
b_t
=
output
[
batch_idx
,
box_b_idx
,
box_start_idx
+
1
]
+
l
*
6
+
m
)]
=
p_data
[
n
*
num_anchors
*
6
+
l
*
6
+
m
]
b_b
=
output
[
batch_idx
,
box_b_idx
,
box_start_idx
+
3
]
b_l
=
output
[
batch_idx
,
box_b_idx
,
box_start_idx
]
b_r
=
output
[
batch_idx
,
box_b_idx
,
box_start_idx
+
2
]
w
=
max
(
0.0
,
min
(
a_r
,
b_r
)
-
max
(
a_l
,
b_l
))
h
=
max
(
0.0
,
min
(
a_b
,
b_b
)
-
max
(
a_t
,
b_t
))
area
=
h
*
w
u
=
(
a_r
-
a_l
)
*
(
a_b
-
a_t
)
+
(
b_r
-
b_l
)
*
(
b_b
-
b_t
)
-
area
iou
=
0.0
if
u
<=
0.0
else
area
/
u
if
iou
>=
iou_threshold
:
output
[
i
,
k
,
0
]
=
-
1.0
box_indices
[
i
,
k
]
=
-
1
else
:
for
j
in
range
(
valid_count
[
i
]):
for
k
in
range
(
box_data_length
):
output
[
i
,
j
,
k
]
=
data
[
i
,
j
,
k
]
box_indices
[
i
,
j
]
=
j
# Set invalid entry to be -1
# Set invalid entry to be -1
with
ib
.
for_range
(
0
,
num_anchors
-
p_valid_count
[
n
],
name
=
"l"
)
as
l
:
for
j
in
range
(
num_anchors
-
valid_count
[
i
]):
with
ib
.
for_range
(
0
,
6
,
name
=
"m"
)
as
m
:
for
k
in
range
(
box_data_length
):
p_out
[
n
*
num_anchors
*
6
+
(
l
+
p_valid_count
[
n
])
*
6
+
m
]
=
-
1.0
output
[
i
,
j
+
valid_count
[
i
],
k
]
=
-
1.0
return
ib
.
get
()
box_indices
[
i
,
j
+
valid_count
[
i
]]
=
-
1
# Only return max_output_size valid boxes
num_valid_boxes
=
0
if
max_output_size
>
0
:
for
j
in
range
(
valid_count
[
i
]):
if
output
[
i
,
j
,
0
]
>=
0
:
if
num_valid_boxes
==
max_output_size
:
for
k
in
range
(
box_data_length
):
output
[
i
,
j
,
k
]
=
-
1.0
box_indices
[
i
,
j
]
=
-
1
else
:
num_valid_boxes
+=
1
return
output
,
box_indices
@tvm.target.generic_func
@tvm.target.generic_func
def
nms
(
data
,
valid_count
,
nms_threshold
=
0.5
,
force_suppress
=
False
,
nms_topk
=-
1
):
def
non_max_suppression
(
data
,
valid_count
,
max_output_size
=-
1
,
iou_threshold
=
0.5
,
force_suppress
=
False
,
top_k
=-
1
,
id_index
=
0
,
return_indices
=
True
,
invalid_to_bottom
=
False
):
"""Non-maximum suppression operator for object detection.
"""Non-maximum suppression operator for object detection.
Parameters
Parameters
----------
----------
data: tvm.Tensor
data
: tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
[class_id, score, box_left, box_top, box_right, box_bottom].
...
@@ -120,15 +249,28 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
...
@@ -120,15 +249,28 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
valid_count : tvm.Tensor
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
1-D tensor for valid number of boxes.
nms_threshold : float
max_output_size : optional, int
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
iou_threshold : optional, float
Non-maximum suppression threshold.
Non-maximum suppression threshold.
force_suppress : boolean
force_suppress :
optional,
boolean
Whether to suppress all detections regardless of class_id.
Whether to suppress all detections regardless of class_id.
nms_topk :
int
top_k : optional,
int
Keep maximum top k detections before nms, -1 for no limit.
Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
index of the class categories, -1 to disable.
return_indices : optional, boolean
Whether to return box indices in input data.
invalid_to_bottom : optional, boolean
Whether to move all valid bounding boxes to the top.
Returns
Returns
-------
-------
out : tvm.Tensor
out : tvm.Tensor
...
@@ -138,16 +280,17 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
...
@@ -138,16 +280,17 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
--------
--------
.. code-block:: python
.. code-block:: python
# An example to use n
ms
# An example to use n
on_max_suppression
dshape = (1, 5, 6)
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
nms
_threshold = 0.7
iou
_threshold = 0.7
force_suppress = True
force_suppress = True
nms_topk = -1
top_k = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
out = non_max_suppression(data, valid_count, iou_threshold=iou_threshold,
np_data = np.random.uniform(size=dshape).astype("float32")
force_suppress=force_suppress, top_k=top_k)
np_valid_count = np.array([4]).astype("int32")
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out)
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
ctx = tvm.cpu()
...
@@ -161,7 +304,6 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
...
@@ -161,7 +304,6 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
valid_count_dtype
=
"int32"
valid_count_dtype
=
"int32"
valid_count_buf
=
api
.
decl_buffer
(
valid_count
.
shape
,
valid_count_dtype
,
valid_count_buf
=
api
.
decl_buffer
(
valid_count
.
shape
,
valid_count_dtype
,
"valid_count_buf"
,
data_alignment
=
4
)
"valid_count_buf"
,
data_alignment
=
4
)
data_buf
=
api
.
decl_buffer
(
data
.
shape
,
data
.
dtype
,
"data_buf"
,
data_alignment
=
8
)
score_axis
=
1
score_axis
=
1
score_shape
=
(
batch_size
,
num_anchors
)
score_shape
=
(
batch_size
,
num_anchors
)
score_tensor
=
tvm
.
compute
(
score_shape
,
lambda
i
,
j
:
data
[
i
,
j
,
score_axis
])
score_tensor
=
tvm
.
compute
(
score_shape
,
lambda
i
,
j
:
data
[
i
,
j
,
score_axis
])
...
@@ -180,13 +322,13 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
...
@@ -180,13 +322,13 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
in_buffers
=
[
score_tensor_buf
,
valid_count_buf
],
in_buffers
=
[
score_tensor_buf
,
valid_count_buf
],
out_buffers
=
sort_tensor_buf
,
out_buffers
=
sort_tensor_buf
,
name
=
"nms_sort"
)
name
=
"nms_sort"
)
out
=
\
out
,
box_indices
=
hybrid_nms
(
data
,
sort_tensor
,
valid_count
,
tvm
.
extern
(
data
.
shape
,
tvm
.
const
(
max_output_size
,
dtype
=
"int32"
)
,
[
data
,
sort_tensor
,
valid_count
]
,
tvm
.
const
(
iou_threshold
,
dtype
=
"float32"
)
,
lambda
ins
,
outs
:
nms_ir
(
tvm
.
const
(
force_suppress
,
dtype
=
"bool"
),
ins
[
0
],
ins
[
1
],
ins
[
2
],
outs
[
0
],
nms_threshold
,
tvm
.
const
(
top_k
,
dtype
=
"int32"
)
,
force_suppress
,
nms_topk
),
tvm
.
const
(
id_index
,
dtype
=
"int32"
))
dtype
=
"float32"
,
if
not
return_indices
and
invalid_to_bottom
:
in_buffers
=
[
data_buf
,
sort_tensor_buf
,
valid_count_buf
],
out
=
hybrid_rearrange_out
(
out
)
tag
=
"nms"
)
return
out
return
box_indices
if
return_indices
else
out
topi/python/topi/vision/ssd/multibox.py
View file @
d2f29ba5
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments
, undefined-variable
"""SSD multibox operators"""
"""SSD multibox operators"""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
math
import
tvm
import
tvm
from
tvm
import
api
from
tvm
import
hybrid
from
tvm.intrin
import
exp
,
sqrt
import
topi
import
topi
from
..nms
import
n
ms
from
..nms
import
n
on_max_suppression
def
multibox_prior_ir
(
data
,
out
,
sizes
,
ratios
,
steps
,
offsets
):
@hybrid.script
"""Low level IR routing for multibox_prior operator.
def
hybrid_multibox_prior
(
data
,
sizes
,
ratios
,
steps
,
offsets
):
"""Hybrid routing for multibox_prior operator.
Parameters
Parameters
----------
----------
data :
Buffer
data :
tvm.Tensor or numpy NDArray
Input data buffer.
4-D tensor with shape [batch, channel, height, width]]
out : Buffe
r
sizes : tvm ConsExp
r
Output buffer
.
Sizes for anchor boxes
.
sizes : tuple of float
ratios : tvm ConsExpr
Tuple of sizes for anchor boxes.
Ratios for anchor boxes.
ratios : tuple of float
Tuple of ratios for anchor boxes.
steps :
Tuple of float
steps :
tvm ConsExpr
Priorbox step across y and x, -1 for auto calculation.
Priorbox step across y and x, -1 for auto calculation.
offsets : t
uple of int
offsets : t
vm ConsExpr
Priorbox center offsets, y and x respectively.
Priorbox center offsets, y and x respectively.
Returns
Returns
-------
-------
stmt : Stmt
output : tvm.Tensor or numpy NDArray
The result IR statement.
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
"""
"""
ib
=
tvm
.
ir_builder
.
create
()
p_out
=
ib
.
buffer_ptr
(
out
)
in_height
=
data
.
shape
[
2
]
in_height
=
data
.
shape
[
2
]
in_width
=
data
.
shape
[
3
]
in_width
=
data
.
shape
[
3
]
num_sizes
=
len
(
sizes
)
num_sizes
=
len
(
sizes
)
num_ratios
=
len
(
ratios
)
num_ratios
=
len
(
ratios
)
size_ratio_concat
=
sizes
+
ratios
num_boxes
=
in_height
*
in_width
*
(
num_sizes
+
num_ratios
-
1
)
steps_h
=
steps
[
0
]
if
steps
[
0
]
>
0
else
1.0
/
in_height
output
=
output_tensor
((
1
,
num_boxes
,
4
),
"float32"
)
steps_w
=
steps
[
1
]
if
steps
[
1
]
>
0
else
1.0
/
in_width
steps_h
=
steps
[
0
]
*
1.0
if
steps
[
0
]
>
0
else
1.0
/
in_height
steps_w
=
steps
[
1
]
*
1.0
if
steps
[
1
]
>
0
else
1.0
/
in_width
offset_h
=
offsets
[
0
]
offset_h
=
offsets
[
0
]
offset_w
=
offsets
[
1
]
offset_w
=
offsets
[
1
]
with
ib
.
for_range
(
0
,
in_height
,
for_type
=
"parallel"
,
name
=
"i"
)
as
i
:
# Need to define var out of const_range + if
w
=
0.0
h
=
0.0
for
i
in
parallel
(
in_height
):
center_h
=
(
i
+
offset_h
)
*
steps_h
center_h
=
(
i
+
offset_h
)
*
steps_h
with
ib
.
for_range
(
0
,
in_width
,
name
=
"j"
)
as
j
:
for
j
in
range
(
in_width
)
:
center_w
=
(
j
+
offset_w
)
*
steps_w
center_w
=
(
j
+
offset_w
)
*
steps_w
for
k
in
range
(
num_sizes
+
num_ratios
-
1
):
for
k
in
const_
range
(
num_sizes
+
num_ratios
-
1
):
w
=
tvm
.
if_then_else
(
k
<
num_sizes
,
if
k
<
num_sizes
:
size_ratio_concat
[
k
]
*
in_height
/
in_width
/
2.0
,
w
=
sizes
[
k
]
*
in_height
/
in_width
/
2.0
size_ratio_concat
[
0
]
*
in_height
/
in_width
*
h
=
sizes
[
k
]
/
2.0
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
else
:
h
=
tvm
.
if_then_else
(
w
=
sizes
[
0
]
*
in_height
/
in_width
\
k
<
num_sizes
,
size_ratio_concat
[
k
]
/
2.0
,
*
sqrt
(
ratios
[
k
-
num_sizes
+
1
]
*
1.0
)
/
2.0
size_ratio_concat
[
0
]
/
math
.
sqrt
(
size_ratio_concat
[
k
+
1
])
/
2.0
)
h
=
sizes
[
0
]
/
sqrt
(
ratios
[
k
-
num_sizes
+
1
]
*
1.0
)
/
2.0
count
=
(
i
*
in_width
*
(
num_sizes
+
num_ratios
-
1
)
+
count
=
i
*
in_width
*
(
num_sizes
+
num_ratios
-
1
)
\
j
*
(
num_sizes
+
num_ratios
-
1
)
+
k
)
*
4
+
j
*
(
num_sizes
+
num_ratios
-
1
)
+
k
p_out
[
count
]
=
center_w
-
w
output
[
0
,
count
,
0
]
=
center_w
-
w
p_out
[
count
+
1
]
=
center_h
-
h
output
[
0
,
count
,
1
]
=
center_h
-
h
p_out
[
count
+
2
]
=
center_w
+
w
output
[
0
,
count
,
2
]
=
center_w
+
w
p_out
[
count
+
3
]
=
center_h
+
h
output
[
0
,
count
,
3
]
=
center_h
+
h
return
ib
.
get
()
return
output
@tvm.target.generic_func
@tvm.target.generic_func
...
@@ -101,115 +102,120 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
...
@@ -101,115 +102,120 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
out : tvm.Tensor
out : tvm.Tensor
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
"""
"""
num_sizes
=
len
(
sizes
)
out
=
hybrid_multibox_prior
(
data
,
tvm
.
convert
(
sizes
),
tvm
.
convert
(
ratios
),
num_ratios
=
len
(
ratios
)
tvm
.
convert
(
steps
),
tvm
.
convert
(
offsets
))
oshape
=
(
1
,
data
.
shape
[
2
]
*
data
.
shape
[
3
]
*
(
num_sizes
+
num_ratios
-
1
),
4
)
out
=
tvm
.
extern
(
oshape
,
[
data
],
lambda
ins
,
outs
:
multibox_prior_ir
(
ins
[
0
],
outs
[
0
],
sizes
,
ratios
,
steps
,
offsets
),
tag
=
"multibox_prior"
)
if
clip
:
if
clip
:
out
=
topi
.
clip
(
out
,
0
,
1
)
out
=
topi
.
clip
(
out
,
0
,
1
)
return
out
return
out
def
transform_loc_ir
(
cls_prob
,
loc_pred
,
anchor
,
valid_count
,
out
,
clip
,
threshold
,
variances
):
@hybrid.script
"""Low level IR routing for transform location in multibox_detection operator.
def
_hybridy_transform_loc
(
box
,
pred_loc
,
variance
,
clip
):
"""Transform prior anchor box to output box through location predictions.
"""
al
=
box
[
0
]
at
=
box
[
1
]
ar
=
box
[
2
]
ab
=
box
[
3
]
Parameters
px
=
pred_loc
[
0
]
----------
py
=
pred_loc
[
1
]
cls_prob : Buffer
pw
=
pred_loc
[
2
]
Buffer of class probabilities.
ph
=
pred_loc
[
3
]
loc_pred : Buffer
vx
=
variance
[
0
]
Buffer of location regression predictions.
vy
=
variance
[
1
]
vw
=
variance
[
2
]
vh
=
variance
[
3
]
anchor : Buffer
output
=
output_tensor
((
4
,),
pred_loc
.
dtype
)
Buffer of prior anchor boxes.
valid_count : Buffer
aw
=
ar
-
al
Buffer of number of valid output boxes.
ah
=
ab
-
at
ax
=
(
al
+
ar
)
/
2.0
ay
=
(
at
+
ab
)
/
2.0
ox
=
px
*
vx
*
aw
+
ax
oy
=
py
*
vy
*
ah
+
ay
ow
=
exp
(
pw
*
vw
)
*
aw
/
2.0
oh
=
exp
(
ph
*
vh
)
*
ah
/
2.0
output
[
0
]
=
max
(
0.0
,
min
(
1.0
,
ox
-
ow
))
if
clip
else
ox
-
ow
output
[
1
]
=
max
(
0.0
,
min
(
1.0
,
oy
-
oh
))
if
clip
else
oy
-
oh
output
[
2
]
=
max
(
0.0
,
min
(
1.0
,
ox
+
ow
))
if
clip
else
ox
+
ow
output
[
3
]
=
max
(
0.0
,
min
(
1.0
,
oy
+
oh
))
if
clip
else
oy
+
oh
return
output
@hybrid.script
def
hybrid_multibox_transform_loc
(
cls_prob
,
loc_pred
,
anchor
,
clip
,
threshold
,
variances
):
"""Hybrid routing for transform location in multibox_detection operator.
out : Buffer
Parameters
Output buffer.
----------
cls_prob : tvm.Tensor or numpy NDArray
3-D tensor of class probabilities.
clip : boolean
loc_pred : tvm.Tensor or numpy NDArray
2-D tensor of location regression predictions.
anchor : tvm.Tensor or numpy NDArray
3-D tensor of prior anchor boxes.
clip : tvm.const
Whether to clip out-of-boundary boxes.
Whether to clip out-of-boundary boxes.
threshold :
floa
t
threshold :
tvm.cons
t
Threshold to be a positive prediction.
Threshold to be a positive prediction.
variances : t
uple of float
variances : t
vm.ndarray
Variances to be decoded from box regression output.
Variances to be decoded from box regression output.
Returns
Returns
-------
-------
stmt : Stmt
out_loc : tvm.Tensor or numpy NDArray
The result IR statement.
3-D tensor of transformed location.
"""
def
transform_loc
(
loc
,
loc_base_idx
,
anchor
,
anchor_base_idx
,
clip
,
vx
,
vy
,
vw
,
vh
):
"""Transform prior anchor box to output box through location predictions.
"""
al
=
anchor
[
anchor_base_idx
]
at
=
anchor
[
anchor_base_idx
+
1
]
ar
=
anchor
[
anchor_base_idx
+
2
]
ab
=
anchor
[
anchor_base_idx
+
3
]
aw
=
ar
-
al
ah
=
ab
-
at
ax
=
(
al
+
ar
)
/
2.0
ay
=
(
at
+
ab
)
/
2.0
px
=
loc
[
loc_base_idx
]
py
=
loc
[
loc_base_idx
+
1
]
pw
=
loc
[
loc_base_idx
+
2
]
ph
=
loc
[
loc_base_idx
+
3
]
ox
=
px
*
vx
*
aw
+
ax
oy
=
py
*
vy
*
ah
+
ay
ow
=
tvm
.
exp
(
pw
*
vw
)
*
aw
/
2.0
oh
=
tvm
.
exp
(
ph
*
vh
)
*
ah
/
2.0
return
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
ox
-
ow
)),
ox
-
ow
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
if_then_else
(
clip
,
tvm
.
max
(
0
,
tvm
.
min
(
1
,
oy
+
oh
)),
oy
+
oh
)
valid_count : tvm.Tensor or numpy NDArray
1_d tensor of valid counts for boxes.
"""
batch_size
=
cls_prob
.
shape
[
0
]
batch_size
=
cls_prob
.
shape
[
0
]
num_classes
=
cls_prob
.
shape
[
1
]
num_classes
=
cls_prob
.
shape
[
1
]
num_anchors
=
cls_prob
.
shape
[
2
]
num_anchors
=
cls_prob
.
shape
[
2
]
box_coord
=
allocate
((
4
,),
loc_pred
.
dtype
)
ib
=
tvm
.
ir_builder
.
create
()
pred_coord
=
allocate
((
4
,),
loc_pred
.
dtype
)
p_cls_prob
=
ib
.
buffer_ptr
(
cls_prob
)
out_loc
=
output_tensor
((
batch_size
,
num_anchors
,
6
),
p_loc_pred
=
ib
.
buffer_ptr
(
loc_pred
)
loc_pred
.
dtype
)
p_anchor
=
ib
.
buffer_ptr
(
anchor
)
valid_count
=
output_tensor
((
batch_size
,),
"int32"
)
p_valid_count
=
ib
.
buffer_ptr
(
valid_count
)
p_out
=
ib
.
buffer_ptr
(
out
)
for
i
in
parallel
(
batch_size
):
with
ib
.
for_range
(
0
,
batch_size
,
for_type
=
"parallel"
,
name
=
"n"
)
as
n
:
valid_count
[
i
]
=
0
p_valid_count
[
n
]
=
0
for
j
in
range
(
num_anchors
):
with
ib
.
for_range
(
0
,
num_anchors
,
name
=
"i"
)
as
i
:
# Find the predicted class id and probability
# Find the predicted class id and probability
score
=
ib
.
allocate
(
'float32'
,
(
1
,),
name
=
"score"
,
scope
=
"local"
)
score
=
-
1.0
cls_id
=
ib
.
allocate
(
'int32'
,
(
1
,),
name
=
"id"
,
scope
=
"local"
)
cls_id
=
0
score
[
0
]
=
-
1.0
for
k
in
range
(
num_classes
):
cls_id
[
0
]
=
0
if
k
>
0
:
with
ib
.
for_range
(
0
,
num_classes
,
name
=
"j"
)
as
j
:
temp
=
cls_prob
[
i
,
k
,
j
]
with
ib
.
if_scope
(
j
>
0
):
cls_id
=
k
if
temp
>
score
else
cls_id
temp
=
p_cls_prob
[
n
*
num_anchors
*
num_classes
+
j
*
num_anchors
+
i
]
score
=
max
(
temp
,
score
)
cls_id
[
0
]
=
tvm
.
if_then_else
(
temp
>
score
[
0
],
j
,
cls_id
[
0
])
if
cls_id
>
0
and
score
<
threshold
:
score
[
0
]
=
tvm
.
max
(
temp
,
score
[
0
])
cls_id
=
0
with
ib
.
if_scope
(
tvm
.
all
(
cls_id
[
0
]
>
0
,
score
[
0
]
<
threshold
)):
cls_id
[
0
]
=
0
# [id, prob, xmin, ymin, xmax, ymax]
# [id, prob, xmin, ymin, xmax, ymax]
# Remove background, restore original id
# Remove background, restore original id
with
ib
.
if_scope
(
cls_id
[
0
]
>
0
):
if
cls_id
>
0
:
out_base_idx
=
n
*
num_anchors
*
6
+
p_valid_count
[
n
]
*
6
out_loc
[
i
,
valid_count
[
i
],
0
]
=
cls_id
-
1.0
p_out
[
out_base_idx
]
=
cls_id
[
0
]
-
1.0
out_loc
[
i
,
valid_count
[
i
],
1
]
=
score
p_out
[
out_base_idx
+
1
]
=
score
[
0
]
for
l
in
range
(
4
):
offset
=
i
*
4
box_coord
[
l
]
=
anchor
[
0
,
j
,
l
]
p_out
[
out_base_idx
+
2
],
p_out
[
out_base_idx
+
3
],
p_out
[
out_base_idx
+
4
],
\
pred_coord
[
l
]
=
loc_pred
[
i
,
j
*
4
+
l
]
p_out
[
out_base_idx
+
5
]
=
transform_loc
(
p_loc_pred
,
n
*
num_anchors
*
4
+
offset
,
out_coord
=
_hybridy_transform_loc
(
box_coord
,
pred_coord
,
p_anchor
,
offset
,
clip
,
variances
[
0
],
variances
,
clip
)
variances
[
1
],
variances
[
2
],
variances
[
3
])
out_loc
[
i
,
valid_count
[
i
],
2
]
=
out_coord
[
0
]
p_valid_count
[
n
]
+=
1
out_loc
[
i
,
valid_count
[
i
],
3
]
=
out_coord
[
1
]
out_loc
[
i
,
valid_count
[
i
],
4
]
=
out_coord
[
2
]
return
ib
.
get
()
out_loc
[
i
,
valid_count
[
i
],
5
]
=
out_coord
[
3
]
valid_count
[
i
]
+=
1
return
out_loc
,
valid_count
@tvm.target.generic_func
@tvm.target.generic_func
def
multibox_transform_loc
(
cls_prob
,
loc_pred
,
anchor
,
clip
=
True
,
threshold
=
0.01
,
def
multibox_transform_loc
(
cls_prob
,
loc_pred
,
anchor
,
clip
=
True
,
threshold
=
0.01
,
...
@@ -240,24 +246,10 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
...
@@ -240,24 +246,10 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
-------
-------
ret : tuple of tvm.Tensor
ret : tuple of tvm.Tensor
"""
"""
batch_size
=
cls_prob
.
shape
[
0
]
return
hybrid_multibox_transform_loc
(
cls_prob
,
loc_pred
,
anchor
,
num_anchors
=
anchor
.
shape
[
1
]
tvm
.
const
(
clip
,
"bool"
),
oshape
=
(
batch_size
,
num_anchors
,
6
)
tvm
.
const
(
threshold
,
"float32"
),
# Define data alignment for intermediate buffer
tvm
.
convert
(
variances
))
valid_count_dtype
=
"int32"
valid_count_buf
=
api
.
decl_buffer
((
batch_size
,),
valid_count_dtype
,
"valid_count_buf"
,
data_alignment
=
4
)
out_buf
=
api
.
decl_buffer
(
oshape
,
cls_prob
.
dtype
,
"out_buf"
,
data_alignment
=
8
)
valid_count
,
out
=
\
tvm
.
extern
([(
batch_size
,),
oshape
],
[
cls_prob
,
loc_pred
,
anchor
],
lambda
ins
,
outs
:
transform_loc_ir
(
ins
[
0
],
ins
[
1
],
ins
[
2
],
outs
[
0
],
outs
[
1
],
clip
,
threshold
,
variances
),
dtype
=
[
valid_count_dtype
,
cls_prob
.
dtype
],
out_buffers
=
[
valid_count_buf
,
out_buf
],
tag
=
"multibox_transform_loc"
)
return
[
out
,
valid_count
]
@tvm.target.generic_func
@tvm.target.generic_func
def
multibox_detection
(
cls_prob
,
loc_pred
,
anchor
,
clip
=
True
,
threshold
=
0.01
,
nms_threshold
=
0.5
,
def
multibox_detection
(
cls_prob
,
loc_pred
,
anchor
,
clip
=
True
,
threshold
=
0.01
,
nms_threshold
=
0.5
,
...
@@ -300,5 +292,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
...
@@ -300,5 +292,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
"""
"""
inter_out
=
multibox_transform_loc
(
cls_prob
,
loc_pred
,
anchor
,
inter_out
=
multibox_transform_loc
(
cls_prob
,
loc_pred
,
anchor
,
clip
,
threshold
,
variances
)
clip
,
threshold
,
variances
)
out
=
nms
(
inter_out
[
0
],
inter_out
[
1
],
nms_threshold
,
force_suppress
,
nms_topk
)
out
=
non_max_suppression
(
inter_out
[
0
],
inter_out
[
1
],
-
1
,
nms_threshold
,
force_suppress
,
nms_topk
,
return_indices
=
False
)
return
out
return
out
topi/tests/python/test_topi_vision.py
View file @
d2f29ba5
...
@@ -8,11 +8,62 @@ import topi.testing
...
@@ -8,11 +8,62 @@ import topi.testing
from
tvm.contrib.pickle_memoize
import
memoize
from
tvm.contrib.pickle_memoize
import
memoize
from
topi.util
import
get_const_tuple
from
topi.util
import
get_const_tuple
from
topi.vision
import
ssd
,
nms
from
topi.vision
import
ssd
,
non_max_suppression
,
get_valid_counts
def
verify_get_valid_counts
(
dshape
,
score_threshold
):
dtype
=
"float32"
batch_size
,
num_anchor
,
elem_length
=
dshape
np_data
=
np
.
random
.
uniform
(
size
=
dshape
)
.
astype
(
dtype
)
np_out1
=
np
.
zeros
(
shape
=
(
batch_size
,))
np_out2
=
np
.
zeros
(
shape
=
dshape
)
.
astype
(
dtype
)
for
i
in
range
(
batch_size
):
np_out1
[
i
]
=
0
inter_idx
=
0
for
j
in
range
(
num_anchor
):
score
=
np_data
[
i
,
j
,
1
]
if
score
>
score_threshold
:
for
k
in
range
(
elem_length
):
np_out2
[
i
,
inter_idx
,
k
]
=
np_data
[
i
,
j
,
k
]
np_out1
[
i
]
+=
1
inter_idx
+=
1
if
j
>=
np_out1
[
i
]:
for
k
in
range
(
elem_length
):
np_out2
[
i
,
j
,
k
]
=
-
1.0
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
data
=
tvm
.
placeholder
(
dshape
,
name
=
"data"
,
dtype
=
dtype
)
outs
=
get_valid_counts
(
data
,
score_threshold
)
s
=
topi
.
generic
.
schedule_multibox_prior
(
outs
)
tvm_input_data
=
tvm
.
nd
.
array
(
np_data
,
ctx
)
tvm_out1
=
tvm
.
nd
.
array
(
np
.
zeros
(
np_out1
.
shape
,
dtype
=
"int32"
),
ctx
)
tvm_out2
=
tvm
.
nd
.
array
(
np
.
zeros
(
np_out2
.
shape
,
dtype
=
dtype
),
ctx
)
f
=
tvm
.
build
(
s
,
[
data
,
outs
[
0
],
outs
[
1
]],
device
)
f
(
tvm_input_data
,
tvm_out1
,
tvm_out2
)
tvm
.
testing
.
assert_allclose
(
tvm_out1
.
asnumpy
(),
np_out1
,
rtol
=
1e-3
)
tvm
.
testing
.
assert_allclose
(
tvm_out2
.
asnumpy
(),
np_out2
,
rtol
=
1e-3
)
def
test_nms
():
for
device
in
[
'llvm'
]:
check_device
(
device
)
def
test_get_valid_counts
():
verify_get_valid_counts
((
1
,
2500
,
6
),
0
)
verify_get_valid_counts
((
1
,
2500
,
6
),
-
1
)
verify_get_valid_counts
((
3
,
1000
,
6
),
0.55
)
verify_get_valid_counts
((
16
,
500
,
6
),
0.95
)
def
test_non_max_suppression
():
dshape
=
(
1
,
5
,
6
)
dshape
=
(
1
,
5
,
6
)
indices_dshape
=
(
1
,
5
)
data
=
tvm
.
placeholder
(
dshape
,
name
=
"data"
)
data
=
tvm
.
placeholder
(
dshape
,
name
=
"data"
)
valid_count
=
tvm
.
placeholder
((
dshape
[
0
],),
dtype
=
"int32"
,
name
=
"valid_count"
)
valid_count
=
tvm
.
placeholder
((
dshape
[
0
],),
dtype
=
"int32"
,
name
=
"valid_count"
)
nms_threshold
=
0.7
nms_threshold
=
0.7
...
@@ -24,8 +75,9 @@ def test_nms():
...
@@ -24,8 +75,9 @@ def test_nms():
[
1
,
0.5
,
100
,
60
,
70
,
110
]]])
.
astype
(
data
.
dtype
)
[
1
,
0.5
,
100
,
60
,
70
,
110
]]])
.
astype
(
data
.
dtype
)
np_valid_count
=
np
.
array
([
4
])
.
astype
(
valid_count
.
dtype
)
np_valid_count
=
np
.
array
([
4
])
.
astype
(
valid_count
.
dtype
)
np_result
=
np
.
array
([[[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
0
,
0.8
,
1
,
20
,
25
,
45
],
np_result
=
np
.
array
([[[
2
,
0.9
,
35
,
61
,
52
,
79
],
[
0
,
0.8
,
1
,
20
,
25
,
45
],
[
0
,
0.4
,
4
,
21
,
19
,
40
],
[
-
1
,
0.9
,
35
,
61
,
52
,
79
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]]])
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]]])
np_indices_result
=
np
.
array
([[
3
,
0
,
-
1
,
-
1
,
-
1
]])
def
check_device
(
device
):
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
ctx
=
tvm
.
context
(
device
,
0
)
...
@@ -35,18 +87,27 @@ def test_nms():
...
@@ -35,18 +87,27 @@ def test_nms():
print
(
"Running on target:
%
s"
%
device
)
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
with
tvm
.
target
.
create
(
device
):
if
device
==
'llvm'
:
if
device
==
'llvm'
:
out
=
nms
(
data
,
valid_count
,
nms_threshold
,
force_suppress
,
nms_topk
)
out
=
non_max_suppression
(
data
,
valid_count
,
-
1
,
nms_threshold
,
force_suppress
,
nms_topk
,
return_indices
=
False
)
indices_out
=
non_max_suppression
(
data
,
valid_count
,
-
1
,
nms_threshold
,
force_suppress
,
nms_topk
)
else
:
else
:
out
=
topi
.
cuda
.
nms
(
data
,
valid_count
,
nms_threshold
,
force_suppress
,
nms_topk
)
out
=
topi
.
cuda
.
non_max_suppression
(
data
,
valid_count
,
-
1
,
nms_threshold
,
force_suppress
,
nms_topk
,
return_indices
=
False
)
indices_out
=
topi
.
cuda
.
non_max_suppression
(
data
,
valid_count
,
-
1
,
nms_threshold
,
force_suppress
,
nms_topk
)
s
=
topi
.
generic
.
schedule_nms
(
out
)
s
=
topi
.
generic
.
schedule_nms
(
out
)
indices_s
=
topi
.
generic
.
schedule_nms
(
indices_out
)
tvm_data
=
tvm
.
nd
.
array
(
np_data
,
ctx
)
tvm_data
=
tvm
.
nd
.
array
(
np_data
,
ctx
)
tvm_valid_count
=
tvm
.
nd
.
array
(
np_valid_count
,
ctx
)
tvm_valid_count
=
tvm
.
nd
.
array
(
np_valid_count
,
ctx
)
tvm_out
=
tvm
.
nd
.
array
(
np
.
zeros
(
dshape
,
dtype
=
data
.
dtype
),
ctx
)
tvm_out
=
tvm
.
nd
.
array
(
np
.
zeros
(
dshape
,
dtype
=
data
.
dtype
),
ctx
)
f
=
tvm
.
build
(
s
,
[
data
,
valid_count
,
out
],
device
)
f
=
tvm
.
build
(
s
,
[
data
,
valid_count
,
out
],
device
)
f
(
tvm_data
,
tvm_valid_count
,
tvm_out
)
f
(
tvm_data
,
tvm_valid_count
,
tvm_out
)
tvm
.
testing
.
assert_allclose
(
tvm_out
.
asnumpy
(),
np_result
,
rtol
=
1e-4
)
tvm
.
testing
.
assert_allclose
(
tvm_out
.
asnumpy
(),
np_result
,
rtol
=
1e-4
)
tvm_indices_out
=
tvm
.
nd
.
array
(
np
.
zeros
(
indices_dshape
,
dtype
=
"int32"
),
ctx
)
f
=
tvm
.
build
(
indices_s
,
[
data
,
valid_count
,
indices_out
],
device
)
f
(
tvm_data
,
tvm_valid_count
,
tvm_indices_out
)
tvm
.
testing
.
assert_allclose
(
tvm_indices_out
.
asnumpy
(),
np_indices_result
,
rtol
=
1e-4
)
for
device
in
[
'llvm'
]:
for
device
in
[
'llvm'
]:
check_device
(
device
)
check_device
(
device
)
...
@@ -274,7 +335,8 @@ def test_proposal():
...
@@ -274,7 +335,8 @@ def test_proposal():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_nms
()
test_get_valid_counts
()
test_non_max_suppression
()
test_multibox_prior
()
test_multibox_prior
()
test_multibox_detection
()
test_multibox_detection
()
test_roi_align
()
test_roi_align
()
...
...
tutorials/frontend/deploy_ssd_gluoncv.py
0 → 100644
View file @
d2f29ba5
"""
Deploy Single Shot Multibox Detector(SSD) model
===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_
This article is an introductory tutorial to deploy SSD models with TVM.
We will use a GluonCV pre-trained SSD model and convert it to Relay IR.
"""
import
tvm
from
matplotlib
import
pyplot
as
plt
from
nnvm
import
compiler
from
nnvm.frontend
import
from_mxnet
from
nnvm.testing.config
import
ctx_list
from
tvm
import
relay
from
tvm.contrib
import
graph_runtime
from
gluoncv
import
model_zoo
,
data
,
utils
######################################################################
# Preliminary and Set parameters
# ------------------------------
# TVM must be built with sort support. In the TVM root directory, run:
#
# .. code-block:: bash
#
# echo "set(USE_SORT ON)" > config.mk
# make -j8
#
# .. note::
#
# Currently we support compiling SSD on CPU only.
# GPU support is in progress.
#
# To get the best inference performance on CPU, change the
# target argument according to your device and
# follow :ref:`tune_relay_x86` to tune for x86 CPUs and
# :ref:`tune_relay_arm` for ARM CPUs.
#
# SSD with VGG as body network is not supported yet since
# x86 conv2d schedule doesn't support dilation.
# GluonCV SSD model names this tutorial lists as supported.
supported_model = [
    'ssd_512_resnet18_v1_voc',
    'ssd_512_resnet18_v1_coco',
    'ssd_512_resnet50_v1_voc',
    'ssd_512_resnet50_v1_coco',
    'ssd_512_resnet101_v2_voc',
    'ssd_512_mobilenet1_0_voc',
    'ssd_512_mobilenet1_0_coco',
]

# Name of the pre-trained model fetched from the GluonCV model zoo below.
model_name = "ssd_512_resnet50_v1_voc"

# Shape registered for the "data" input when converting the model to Relay.
dshape = (1, 3, 512, 512)
dtype = "float32"

# Iterated below as (target, ctx) pairs.
target_list = ctx_list()
######################################################################
# Download and pre-process demo image
# Fetch the demo image to a local file named 'street_small.jpg'.
im_fname = utils.download(
    'https://github.com/dmlc/web-data/blob/master/' +
    'gluoncv/detection/street_small.jpg?raw=true',
    path='street_small.jpg',
)

# Load the image with GluonCV's SSD test preset (short=512).
# `x` is later fed to the runtime as the 'data' input; `img` is used for plotting.
x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)
######################################################################
# Convert and compile model for CPU.
# Pre-trained GluonCV network selected by `model_name`.
block = model_zoo.get_model(
    model_name,
    pretrained=True,
)
def compile(target):
    """Convert the GluonCV ``block`` to Relay and build it for ``target``.

    The module-level ``block`` is converted with
    ``relay.frontend.from_mxnet`` using ``dshape`` as the shape of the
    ``"data"`` input, then built with ``relay.build`` under
    ``relay.build_config(opt_level=3)``.

    Parameters
    ----------
    target
        Build target forwarded unchanged to ``relay.build``.

    Returns
    -------
    tuple
        ``(graph, lib, params)`` as produced by ``relay.build``.
    """
    input_shapes = {"data": dshape}
    net, frontend_params = relay.frontend.from_mxnet(block, input_shapes)

    with relay.build_config(opt_level=3):
        graph, lib, built_params = relay.build(net, target, params=frontend_params)

    return graph, lib, built_params
######################################################################
# Create TVM runtime and do inference
def run(graph, lib, params, ctx):
    """Run one inference of the compiled model on ``ctx``.

    A graph runtime module is created from ``graph`` and ``lib``, the
    module-level preprocessed image ``x`` is bound to the ``'data'`` input,
    ``params`` are loaded, and the module is executed once.

    Parameters
    ----------
    graph, lib, params
        Build artifacts, as returned by ``compile``.
    ctx
        Context on which the runtime is created and the input array placed.

    Returns
    -------
    tuple
        ``(class_IDs, scores, bounding_boxs)``: runtime outputs 0, 1 and 2,
        in that order.
    """
    # Build TVM runtime
    module = graph_runtime.create(graph, lib, ctx)

    # Bind the image and the model parameters
    module.set_input('data', tvm.nd.array(x.asnumpy(), ctx=ctx))
    module.set_input(**params)

    # execute
    module.run()

    # get outputs 0..2 in order
    return tuple(module.get_output(index) for index in range(3))
# Compile and run the model on every (target, ctx) pair except "cuda".
# The outputs of the last executed target are kept for display afterwards.
for target, ctx in target_list:
    if target != "cuda":
        graph, lib, params = compile(target)
        class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx)
    else:
        print("GPU not supported yet, skip.")
######################################################################
# Display result
# Draw the detections on the image; only element 0 of each output is plotted.
ax = utils.viz.plot_bbox(
    img,
    bounding_boxs.asnumpy()[0],
    scores.asnumpy()[0],
    class_IDs.asnumpy()[0],
    class_names=block.classes,
)
plt.show()
tutorials/nnvm/deploy_ssd.py
→
tutorials/nnvm/deploy_ssd
_mxnet
.py
View file @
d2f29ba5
...
@@ -61,7 +61,7 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \
...
@@ -61,7 +61,7 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \
image_url
=
"https://cloud.githubusercontent.com/assets/3307514/20012567/"
\
image_url
=
"https://cloud.githubusercontent.com/assets/3307514/20012567/"
\
"cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg"
"cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg"
inference_symbol_folder
=
\
inference_symbol_folder
=
\
"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
inference_symbol_url
=
"https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/"
\
inference_symbol_url
=
"https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/"
\
"archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
"archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment