Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
baf7a729
Commit
baf7a729
authored
Mar 15, 2019
by
Wuwei Lin
Committed by
masahi
Mar 15, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TOPI, Relay] ROI Pool operator (#2811)
parent
c0a5a9be
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
351 additions
and
0 deletions
+351
-0
include/tvm/relay/attrs/vision.h
+20
-0
python/tvm/relay/frontend/mxnet.py
+9
-0
python/tvm/relay/op/vision/_rcnn.py
+16
-0
python/tvm/relay/op/vision/rcnn.py
+27
-0
src/relay/op/vision/rcnn_op.cc
+52
-0
tests/python/relay/test_op_level5.py
+33
-0
topi/python/topi/cuda/vision.py
+4
-0
topi/python/topi/generic/vision.py
+17
-0
topi/python/topi/testing/__init__.py
+1
-0
topi/python/topi/testing/roi_pool_python.py
+47
-0
topi/python/topi/vision/rcnn/__init__.py
+1
-0
topi/python/topi/vision/rcnn/roi_pool.py
+77
-0
topi/tests/python/test_topi_vision.py
+47
-0
No files found.
include/tvm/relay/attrs/vision.h
View file @
baf7a729
...
...
@@ -121,6 +121,26 @@ struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
}
};
/*! \brief Attributes used in roi_pool operators */
struct
ROIPoolAttrs
:
public
tvm
::
AttrsNode
<
ROIPoolAttrs
>
{
Array
<
IndexExpr
>
pooled_size
;
double
spatial_scale
;
std
::
string
layout
;
TVM_DECLARE_ATTRS
(
ROIPoolAttrs
,
"relay.attrs.ROIPoolAttrs"
)
{
TVM_ATTR_FIELD
(
pooled_size
).
describe
(
"Output size of roi pool."
);
TVM_ATTR_FIELD
(
spatial_scale
)
.
describe
(
"Ratio of input feature map height (or w) to raw image height (or w). "
"Equals the reciprocal of total stride in convolutional layers, which should be "
"in range (0.0, 1.0]"
);
TVM_ATTR_FIELD
(
layout
).
set_default
(
"NCHW"
).
describe
(
"Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions."
);
}
};
/*! \brief Attributes used in yolo reorg operators */
struct
YoloReorgAttrs
:
public
tvm
::
AttrsNode
<
YoloReorgAttrs
>
{
Integer
stride
;
...
...
python/tvm/relay/frontend/mxnet.py
View file @
baf7a729
...
...
@@ -438,6 +438,14 @@ def _mx_roi_align(inputs, attrs):
return
_op
.
vision
.
roi_align
(
inputs
[
0
],
inputs
[
1
],
**
new_attrs
)
def
_mx_roi_pooling
(
inputs
,
attrs
):
new_attrs
=
{}
new_attrs
[
"pooled_size"
]
=
attrs
.
get_int_tuple
(
"pooled_size"
)
new_attrs
[
"spatial_scale"
]
=
attrs
.
get_float
(
"spatial_scale"
)
new_attrs
[
"layout"
]
=
"NCHW"
return
_op
.
vision
.
roi_pool
(
inputs
[
0
],
inputs
[
1
],
**
new_attrs
)
def
_mx_proposal
(
inputs
,
attrs
):
new_attrs
=
{}
new_attrs
[
"scales"
]
=
attrs
.
get_float_tuple
(
"scales"
,
(
4.0
,
8.0
,
16.0
,
32.0
))
...
...
@@ -641,6 +649,7 @@ _convert_map = {
"_contrib_MultiBoxPrior"
:
_mx_multibox_prior
,
"_contrib_MultiBoxDetection"
:
_mx_multibox_detection
,
"_contrib_ROIAlign"
:
_mx_roi_align
,
"ROIPooling"
:
_mx_roi_pooling
,
"_contrib_Proposal"
:
_mx_proposal
,
"_contrib_MultiProposal"
:
_mx_proposal
,
"_contrib_box_nms"
:
_mx_box_nms
,
...
...
python/tvm/relay/op/vision/_rcnn.py
View file @
baf7a729
...
...
@@ -22,6 +22,22 @@ def schedule_roi_align(_, outs, target):
reg
.
register_pattern
(
"vision.roi_align"
,
OpPattern
.
OUT_ELEMWISE_FUSABLE
)
@reg.register_compute
(
"vision.roi_pool"
)
def
compute_roi_pool
(
attrs
,
inputs
,
_
,
target
):
"""Compute definition of roi_pool"""
assert
attrs
.
layout
==
"NCHW"
return
[
topi
.
vision
.
rcnn
.
roi_pool_nchw
(
inputs
[
0
],
inputs
[
1
],
pooled_size
=
get_const_tuple
(
attrs
.
pooled_size
),
spatial_scale
=
attrs
.
spatial_scale
)]
@reg.register_schedule
(
"vision.roi_pool"
)
def
schedule_roi_pool
(
_
,
outs
,
target
):
"""Schedule definition of roi_pool"""
with
target
:
return
topi
.
generic
.
vision
.
schedule_roi_pool
(
outs
)
reg
.
register_pattern
(
"vision.roi_pool"
,
OpPattern
.
OUT_ELEMWISE_FUSABLE
)
@reg.register_compute
(
"vision.proposal"
)
def
compute_proposal
(
attrs
,
inputs
,
_
,
target
):
"""Compute definition of proposal"""
...
...
python/tvm/relay/op/vision/rcnn.py
View file @
baf7a729
...
...
@@ -32,6 +32,33 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='N
return
_make
.
roi_align
(
data
,
rois
,
pooled_size
,
spatial_scale
,
sample_ratio
,
layout
)
def
roi_pool
(
data
,
rois
,
pooled_size
,
spatial_scale
,
layout
=
'NCHW'
):
"""ROI pool operator.
Parameters
----------
data : relay.Expr
4-D tensor with shape [batch, channel, height, width]
rois : relay.Expr
2-D tensor with shape [num_roi, 5]. The last dimension should be in format of
[batch_index, w_start, h_start, w_end, h_end]
pooled_size : list/tuple of two ints
output size
spatial_scale : float
Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal
of total stride in convolutional layers, which should be in range (0.0, 1.0]
Returns
-------
output : relay.Expr
4-D tensor with shape [num_roi, channel, pooled_size, pooled_size]
"""
return
_make
.
roi_pool
(
data
,
rois
,
pooled_size
,
spatial_scale
,
layout
)
def
proposal
(
cls_prob
,
bbox_pred
,
im_info
,
...
...
src/relay/op/vision/rcnn_op.cc
View file @
baf7a729
...
...
@@ -63,6 +63,58 @@ RELAY_REGISTER_OP("vision.roi_align")
.
set_support_level
(
5
)
.
add_type_rel
(
"ROIAlign"
,
ROIAlignRel
);
TVM_REGISTER_NODE_TYPE
(
ROIPoolAttrs
);
bool
ROIPoolRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
auto
roi_pool_attrs
=
attrs
.
as
<
ROIPoolAttrs
>
();
CHECK_EQ
(
types
.
size
(),
3
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
const
auto
*
rois
=
types
[
1
].
as
<
TensorTypeNode
>
();
const
auto
&
dshape
=
data
->
shape
;
const
auto
&
rshape
=
rois
->
shape
;
CHECK
(
roi_pool_attrs
);
CHECK_EQ
(
dshape
.
size
(),
4
)
<<
"Input data should be 4-D."
;
CHECK_EQ
(
rshape
.
size
(),
2
)
<<
"Input rois should be 2-D."
;
CHECK_EQ
(
roi_pool_attrs
->
layout
,
"NCHW"
)
<<
"ROI Pool only supports NCHW layout"
;
// assign output type
std
::
vector
<
IndexExpr
>
oshape
(
{
rshape
[
0
],
dshape
[
1
],
roi_pool_attrs
->
pooled_size
[
0
],
roi_pool_attrs
->
pooled_size
[
1
]});
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
));
return
true
;
}
Expr
MakeROIPool
(
Expr
data
,
Expr
rois
,
Array
<
IndexExpr
>
pooled_size
,
double
spatial_scale
,
std
::
string
layout
)
{
auto
attrs
=
make_node
<
ROIPoolAttrs
>
();
attrs
->
pooled_size
=
pooled_size
;
attrs
->
spatial_scale
=
spatial_scale
;
attrs
->
layout
=
layout
;
static
const
Op
&
op
=
Op
::
Get
(
"vision.roi_pool"
);
return
CallNode
::
make
(
op
,
{
data
,
rois
},
Attrs
(
attrs
),
{});
}
TVM_REGISTER_API
(
"relay.op.vision._make.roi_pool"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
5
>
(
MakeROIPool
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"vision.roi_pool"
)
.
describe
(
R"doc(ROI Pool operator.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **rois**: 2D array of shape (num_roi, 5). The last dimension should be in format of
[batch_index, w_start, h_start, w_end, h_end].
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`.
)doc"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
2
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"rois"
,
"Tensor"
,
"The input rois"
)
.
set_support_level
(
5
)
.
add_type_rel
(
"ROIPool"
,
ROIPoolRel
);
TVM_REGISTER_NODE_TYPE
(
ProposalAttrs
);
bool
ProposalRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
...
...
tests/python/relay/test_op_level5.py
View file @
baf7a729
...
...
@@ -357,6 +357,38 @@ def test_roi_align():
verify_roi_align
((
4
,
4
,
16
,
16
),
(
32
,
5
),
pooled_size
=
7
,
spatial_scale
=
0.5
,
sample_ratio
=
2
)
def
test_roi_pool
():
def
verify_roi_pool
(
data_shape
,
rois_shape
,
pooled_size
,
spatial_scale
):
data
=
relay
.
var
(
"data"
,
relay
.
ty
.
TensorType
(
data_shape
,
"float32"
))
rois
=
relay
.
var
(
"rois"
,
relay
.
ty
.
TensorType
(
rois_shape
,
"float32"
))
z
=
relay
.
vision
.
roi_pool
(
data
,
rois
,
pooled_size
=
(
pooled_size
,
pooled_size
),
spatial_scale
=
spatial_scale
,
layout
=
"NCHW"
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
batch
,
channel
,
in_size
,
_
=
data_shape
num_roi
=
rois_shape
[
0
]
assert
zz
.
checked_type
==
relay
.
ty
.
TensorType
(
(
num_roi
,
channel
,
pooled_size
,
pooled_size
),
"float32"
)
func
=
relay
.
Function
([
data
,
rois
],
z
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
np_data
=
np
.
random
.
uniform
(
size
=
data_shape
)
.
astype
(
"float32"
)
np_rois
=
np
.
random
.
uniform
(
size
=
rois_shape
)
.
astype
(
'float32'
)
*
in_size
np_rois
[:,
0
]
=
np
.
random
.
randint
(
low
=
0
,
high
=
batch
,
size
=
num_roi
)
.
astype
(
'float32'
)
ref_res
=
topi
.
testing
.
roi_pool_nchw_python
(
np_data
,
np_rois
,
pooled_size
=
pooled_size
,
spatial_scale
=
spatial_scale
)
for
target
,
ctx
in
ctx_list
():
intrp1
=
relay
.
create_executor
(
"graph"
,
ctx
=
ctx
,
target
=
target
)
op_res1
=
intrp1
.
evaluate
(
func
)(
np_data
,
np_rois
)
tvm
.
testing
.
assert_allclose
(
op_res1
.
asnumpy
(),
ref_res
,
rtol
=
1e-4
)
intrp2
=
relay
.
create_executor
(
"debug"
,
ctx
=
ctx
,
target
=
target
)
op_res2
=
intrp2
.
evaluate
(
func
)(
np_data
,
np_rois
)
tvm
.
testing
.
assert_allclose
(
op_res2
.
asnumpy
(),
ref_res
,
rtol
=
1e-4
)
verify_roi_pool
((
1
,
4
,
16
,
16
),
(
32
,
5
),
pooled_size
=
7
,
spatial_scale
=
1.0
)
verify_roi_pool
((
4
,
4
,
16
,
16
),
(
32
,
5
),
pooled_size
=
7
,
spatial_scale
=
0.5
)
def
test_proposal
():
def
verify_proposal
(
np_cls_prob
,
np_bbox_pred
,
np_im_info
,
np_out
,
attrs
):
cls_prob
=
relay
.
var
(
"cls_prob"
,
relay
.
ty
.
TensorType
(
np_cls_prob
.
shape
,
"float32"
))
...
...
@@ -464,6 +496,7 @@ if __name__ == "__main__":
test_multibox_transform_loc
()
test_get_valid_counts
()
test_roi_align
()
test_roi_pool
()
test_proposal
()
test_yolo_reorg_infer_shape
()
test_yolo_reorg
()
...
...
topi/python/topi/cuda/vision.py
View file @
baf7a729
...
...
@@ -134,6 +134,10 @@ def schedule_multibox_detection(outs):
def
schedule_roi_align
(
outs
):
return
schedule_pool
(
outs
,
'NCHW'
)
@generic.schedule_roi_pool.register
([
"cuda"
,
"gpu"
])
def
schedule_roi_pool
(
outs
):
return
schedule_pool
(
outs
,
'NCHW'
)
@generic.schedule_proposal.register
([
"cuda"
,
"gpu"
])
def
schedule_proposal
(
outs
):
"""Schedule for proposal operator.
...
...
topi/python/topi/generic/vision.py
View file @
baf7a729
...
...
@@ -140,6 +140,23 @@ def schedule_roi_align(outs):
return
_default_schedule
(
outs
,
False
)
@tvm.target.generic_func
def
schedule_roi_pool
(
outs
):
"""Schedule for roi_align
Parameters
----------
outs: Array of Tensor
The computation graph description of roi_pool
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return
_default_schedule
(
outs
,
False
)
@tvm.target.generic_func
def
schedule_proposal
(
outs
):
"""Schedule for proposal operator.
...
...
topi/python/topi/testing/__init__.py
View file @
baf7a729
...
...
@@ -15,6 +15,7 @@ from .upsampling_python import upsampling_python
from
.bilinear_resize_python
import
bilinear_resize_python
from
.reorg_python
import
reorg_python
from
.roi_align_python
import
roi_align_nchw_python
from
.roi_pool_python
import
roi_pool_nchw_python
from
.lrn_python
import
lrn_python
from
.l2_normalize_python
import
l2_normalize_python
from
.gather_nd_python
import
gather_nd_python
...
...
topi/python/topi/testing/roi_pool_python.py
0 → 100644
View file @
baf7a729
# pylint: disable=invalid-name, too-many-nested-blocks
"Roi pool in python"
import
math
import
numpy
as
np
def
roi_pool_nchw_python
(
a_np
,
rois_np
,
pooled_size
,
spatial_scale
):
"""Roi pool in python"""
_
,
channel
,
height
,
width
=
a_np
.
shape
num_roi
=
rois_np
.
shape
[
0
]
b_np
=
np
.
zeros
((
num_roi
,
channel
,
pooled_size
,
pooled_size
),
dtype
=
a_np
.
dtype
)
if
isinstance
(
pooled_size
,
int
):
pooled_size_h
=
pooled_size_w
=
pooled_size
else
:
pooled_size_h
,
pooled_size_w
=
pooled_size
for
i
in
range
(
num_roi
):
roi
=
rois_np
[
i
]
batch_index
=
int
(
roi
[
0
])
roi_start_w
=
int
(
round
(
roi
[
1
]
*
spatial_scale
))
roi_start_h
=
int
(
round
(
roi
[
2
]
*
spatial_scale
))
roi_end_w
=
int
(
round
(
roi
[
3
]
*
spatial_scale
))
roi_end_h
=
int
(
round
(
roi
[
4
]
*
spatial_scale
))
roi_h
=
max
(
roi_end_h
-
roi_start_h
+
1
,
1
)
roi_w
=
max
(
roi_end_w
-
roi_start_w
+
1
,
1
)
bin_h
=
float
(
roi_h
)
/
pooled_size_h
bin_w
=
float
(
roi_w
)
/
pooled_size_w
for
ph
in
range
(
pooled_size_h
):
for
pw
in
range
(
pooled_size_w
):
hstart
=
int
(
math
.
floor
(
ph
*
bin_h
))
wstart
=
int
(
math
.
floor
(
pw
*
bin_w
))
hend
=
int
(
math
.
ceil
((
ph
+
1
)
*
bin_h
))
wend
=
int
(
math
.
ceil
((
pw
+
1
)
*
bin_w
))
hstart
=
min
(
max
(
hstart
+
roi_start_h
,
0
),
height
)
hend
=
min
(
max
(
hend
+
roi_start_h
,
0
),
height
)
wstart
=
min
(
max
(
wstart
+
roi_start_w
,
0
),
width
)
wend
=
min
(
max
(
wend
+
roi_start_w
,
0
),
width
)
is_empty
=
(
hend
<=
hstart
)
or
(
wend
<=
wstart
)
for
c
in
range
(
channel
):
if
is_empty
:
b_np
[
i
,
c
,
ph
,
pw
]
=
0.
else
:
b_np
[
i
,
c
,
ph
,
pw
]
=
np
.
max
(
a_np
[
batch_index
,
c
,
hstart
:
hend
,
wstart
:
wend
])
return
b_np
topi/python/topi/vision/rcnn/__init__.py
View file @
baf7a729
# pylint: disable=wildcard-import
"""Faster R-CNN and Mask R-CNN operators"""
from
.roi_align
import
*
from
.roi_pool
import
*
from
.proposal
import
*
topi/python/topi/vision/rcnn/roi_pool.py
0 → 100644
View file @
baf7a729
# pylint: disable=invalid-name
"""ROI pool operator"""
import
tvm
from
...util
import
get_const_tuple
@tvm.target.generic_func
def
roi_pool_nchw
(
data
,
rois
,
pooled_size
,
spatial_scale
):
"""ROI pool operator in NCHW layout.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, height, width]
rois : tvm.Tensor
2-D with shape [num_roi, 5]. The last dimension should be in format of
[batch_index, w_start, h_start, w_end, h_end]
pooled_size : int or list/tuple of two ints
output size, or [out_height, out_width]
spatial_scale : float
Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal
of total stride in convolutional layers, which should be in range (0.0, 1.0]
Returns
-------
output : tvm.Tensor
4-D with shape [num_roi, channel, pooled_size, pooled_size]
"""
dtype
=
rois
.
dtype
_
,
channel
,
height
,
width
=
get_const_tuple
(
data
.
shape
)
num_roi
,
_
=
get_const_tuple
(
rois
.
shape
)
if
isinstance
(
pooled_size
,
int
):
pooled_size_h
=
pooled_size_w
=
pooled_size
else
:
pooled_size_h
,
pooled_size_w
=
pooled_size
def
_pool
(
i
,
c
,
ph
,
pw
):
roi
=
rois
[
i
]
batch_index
=
roi
[
0
]
.
astype
(
'int32'
)
roi_start_w
,
roi_start_h
,
roi_end_w
,
roi_end_h
=
roi
[
1
],
roi
[
2
],
roi
[
3
],
roi
[
4
]
roi_start_h
=
tvm
.
round
(
roi_start_h
*
spatial_scale
)
.
astype
(
'int32'
)
roi_start_w
=
tvm
.
round
(
roi_start_w
*
spatial_scale
)
.
astype
(
'int32'
)
roi_end_h
=
tvm
.
round
(
roi_end_h
*
spatial_scale
)
.
astype
(
'int32'
)
roi_end_w
=
tvm
.
round
(
roi_end_w
*
spatial_scale
)
.
astype
(
'int32'
)
# force malformed ROIs to be 1x1
roi_h
=
tvm
.
max
(
roi_end_h
-
roi_start_h
+
1
,
tvm
.
const
(
1
,
'int32'
))
roi_w
=
tvm
.
max
(
roi_end_w
-
roi_start_w
+
1
,
tvm
.
const
(
1
,
'int32'
))
bin_h
=
roi_h
.
astype
(
dtype
)
/
pooled_size_h
bin_w
=
roi_w
.
astype
(
dtype
)
/
pooled_size_w
# use epsilon to prevent floating point precision loss in floor/ceil
epsilon
=
tvm
.
const
(
0.00001
,
dtype
)
hstart
=
tvm
.
floor
(
ph
*
bin_h
+
epsilon
)
.
astype
(
'int32'
)
wstart
=
tvm
.
floor
(
pw
*
bin_w
+
epsilon
)
.
astype
(
'int32'
)
hend
=
tvm
.
ceil
((
ph
+
1
)
*
bin_h
-
epsilon
)
.
astype
(
'int32'
)
wend
=
tvm
.
ceil
((
pw
+
1
)
*
bin_w
-
epsilon
)
.
astype
(
'int32'
)
hstart
=
tvm
.
min
(
tvm
.
max
(
hstart
+
roi_start_h
,
0
),
height
)
wstart
=
tvm
.
min
(
tvm
.
max
(
wstart
+
roi_start_w
,
0
),
width
)
hend
=
tvm
.
min
(
tvm
.
max
(
hend
+
roi_start_h
,
0
),
height
)
wend
=
tvm
.
min
(
tvm
.
max
(
wend
+
roi_start_w
,
0
),
width
)
non_empty
=
tvm
.
all
(
hstart
<
hend
,
wstart
<
wend
)
min_value
=
lambda
dtype
:
tvm
.
if_then_else
(
non_empty
,
tvm
.
min_value
(
dtype
),
tvm
.
const
(
0.0
,
dtype
))
# pylint: disable=unnecessary-lambda
_max
=
tvm
.
comm_reducer
(
lambda
x
,
y
:
tvm
.
make
.
_OpMax
(
x
,
y
),
min_value
,
name
=
'max'
)
rh
=
tvm
.
reduce_axis
((
0
,
hend
-
hstart
),
'rh'
)
rw
=
tvm
.
reduce_axis
((
0
,
wend
-
wstart
),
'rw'
)
return
_max
(
data
[
batch_index
,
c
,
hstart
+
rh
,
wstart
+
rw
],
axis
=
[
rh
,
rw
])
return
tvm
.
compute
((
num_roi
,
channel
,
pooled_size_h
,
pooled_size_w
),
_pool
,
tag
=
"pool,roi_pool"
)
topi/tests/python/test_topi_vision.py
View file @
baf7a729
...
...
@@ -268,6 +268,53 @@ def test_roi_align():
verify_roi_align
(
4
,
16
,
32
,
64
,
7
,
0.5
,
2
)
def
verify_roi_pool
(
batch
,
in_channel
,
in_size
,
num_roi
,
pooled_size
,
spatial_scale
):
a_shape
=
(
batch
,
in_channel
,
in_size
,
in_size
)
rois_shape
=
(
num_roi
,
5
)
a
=
tvm
.
placeholder
(
a_shape
)
rois
=
tvm
.
placeholder
(
rois_shape
)
@memoize
(
"topi.tests.test_topi_vision.verify_roi_pool"
)
def
get_ref_data
():
a_np
=
np
.
random
.
uniform
(
size
=
a_shape
)
.
astype
(
'float32'
)
rois_np
=
np
.
random
.
uniform
(
size
=
rois_shape
)
.
astype
(
'float32'
)
*
in_size
rois_np
[:,
0
]
=
np
.
random
.
randint
(
low
=
0
,
high
=
batch
,
size
=
num_roi
)
.
astype
(
'float32'
)
b_np
=
topi
.
testing
.
roi_pool_nchw_python
(
a_np
,
rois_np
,
pooled_size
=
pooled_size
,
spatial_scale
=
spatial_scale
)
return
a_np
,
rois_np
,
b_np
a_np
,
rois_np
,
b_np
=
get_ref_data
()
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
b
=
topi
.
vision
.
rcnn
.
roi_pool_nchw
(
a
,
rois
,
pooled_size
=
pooled_size
,
spatial_scale
=
spatial_scale
)
s
=
topi
.
generic
.
schedule_roi_pool
(
b
)
tvm_a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
tvm_rois
=
tvm
.
nd
.
array
(
rois_np
,
ctx
)
tvm_b
=
tvm
.
nd
.
array
(
np
.
zeros
(
get_const_tuple
(
b
.
shape
),
dtype
=
b
.
dtype
),
ctx
=
ctx
)
f
=
tvm
.
build
(
s
,
[
a
,
rois
,
b
],
device
)
f
(
tvm_a
,
tvm_rois
,
tvm_b
)
tvm
.
testing
.
assert_allclose
(
tvm_b
.
asnumpy
(),
b_np
,
rtol
=
1e-4
)
for
device
in
[
'cuda'
,
'llvm'
]:
check_device
(
device
)
def
test_roi_pool
():
verify_roi_pool
(
1
,
4
,
16
,
32
,
7
,
1.0
)
verify_roi_pool
(
4
,
4
,
16
,
32
,
7
,
0.5
)
def
verify_proposal
(
np_cls_prob
,
np_bbox_pred
,
np_im_info
,
np_out
,
attrs
):
cls_prob
=
tvm
.
placeholder
(
np_cls_prob
.
shape
)
bbox_pred
=
tvm
.
placeholder
(
np_bbox_pred
.
shape
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment