Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9b4b360f
Commit
9b4b360f
authored
Jan 07, 2019
by
Wuwei Lin
Committed by
Tianqi Chen
Jan 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TOPI][CUDA] Fix nms block extent and type mismatch in multibox (#2320)
parent
57506af1
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
14 deletions
+12
-14
topi/python/topi/cuda/nms.py
+4
-6
topi/python/topi/cuda/ssd/multibox.py
+4
-4
topi/python/topi/vision/nms.py
+2
-2
topi/tests/python/test_topi_vision.py
+2
-2
No files found.
topi/python/topi/cuda/nms.py
View file @
9b4b360f
...
...
@@ -99,7 +99,7 @@ def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
)
tx
=
tvm
.
thread_axis
(
"threadIdx.x"
)
bx
=
tvm
.
thread_axis
(
"blockIdx.x"
)
dshape
=
tvm
.
max
(
sizes_in
.
shape
[
0
],
p_index
[
0
])
dshape
=
axis_mul_before
*
axis_mul_after
nthread_tx
=
max_threads
nthread_bx
=
dshape
//
max_threads
+
1
ib
.
scope_attr
(
tx
,
"thread_extent"
,
nthread_tx
)
...
...
@@ -331,9 +331,7 @@ def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend):
Parameters
----------
data: tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
2-D tensor of input boxes' score with shape [batch_size, num_anchors].
data_buf: Buffer
2D Buffer of input boxes' score with shape [batch_size, num_anchors].
...
...
@@ -595,8 +593,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
force_suppress = True
nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(
dshape
)
np_valid_count = np.array([4])
np_data = np.random.uniform(
size=dshape).astype("float32"
)
np_valid_count = np.array([4])
.astype("int32")
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
...
...
topi/python/topi/cuda/ssd/multibox.py
View file @
9b4b360f
...
...
@@ -278,10 +278,10 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
oy
=
py
*
vy
*
ah
+
ay
ow
=
tvm
.
exp
(
pw
*
vw
)
*
aw
/
2.0
oh
=
tvm
.
exp
(
ph
*
vh
)
*
ah
/
2.0
return
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0
,
tvm
.
make
.
Min
(
1
,
ox
-
ow
)),
ox
-
ow
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0
,
tvm
.
make
.
Min
(
1
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0
,
tvm
.
make
.
Min
(
1
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0
,
tvm
.
make
.
Min
(
1
,
oy
+
oh
)),
oy
+
oh
)
return
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0
.0
,
tvm
.
make
.
Min
(
1.0
,
ox
-
ow
)),
ox
-
ow
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0
.0
,
tvm
.
make
.
Min
(
1.0
,
oy
-
oh
)),
oy
-
oh
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0
.0
,
tvm
.
make
.
Min
(
1.0
,
ox
+
ow
)),
ox
+
ow
),
\
tvm
.
select
(
clip
,
tvm
.
make
.
Max
(
0
.0
,
tvm
.
make
.
Min
(
1.0
,
oy
+
oh
)),
oy
+
oh
)
max_threads
=
int
(
tvm
.
target
.
current_target
(
allow_none
=
False
)
.
max_num_threads
)
...
...
topi/python/topi/vision/nms.py
View file @
9b4b360f
...
...
@@ -145,8 +145,8 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
force_suppress = True
nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(
dshape
)
np_valid_count = np.array([4])
np_data = np.random.uniform(
size=dshape).astype("float32"
)
np_valid_count = np.array([4])
.astype("int32")
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
...
...
topi/tests/python/test_topi_vision.py
View file @
9b4b360f
...
...
@@ -46,7 +46,7 @@ def test_nms():
f
(
tvm_data
,
tvm_valid_count
,
tvm_out
)
tvm
.
testing
.
assert_allclose
(
tvm_out
.
asnumpy
(),
np_result
,
rtol
=
1e-4
)
for
device
in
[
'llvm'
,
'opencl'
]:
for
device
in
[
'llvm'
,
'opencl'
,
'cuda'
]:
check_device
(
device
)
...
...
@@ -105,7 +105,7 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offse
f
(
tvm_input_data
,
tvm_out
)
tvm
.
testing
.
assert_allclose
(
tvm_out
.
asnumpy
(),
np_out
,
rtol
=
1e-3
)
for
device
in
[
'llvm'
,
'opencl'
]:
for
device
in
[
'llvm'
,
'opencl'
,
'cuda'
]:
check_device
(
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment