Commit 9b4b360f by Wuwei Lin Committed by Tianqi Chen

[TOPI][CUDA] Fix nms block extent and type mismatch in multibox (#2320)

parent 57506af1
...@@ -99,7 +99,7 @@ def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \ ...@@ -99,7 +99,7 @@ def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \
tvm.target.current_target(allow_none=False).max_num_threads) tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x") bx = tvm.thread_axis("blockIdx.x")
dshape = tvm.max(sizes_in.shape[0], p_index[0]) dshape = axis_mul_before * axis_mul_after
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = dshape // max_threads + 1 nthread_bx = dshape // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(tx, "thread_extent", nthread_tx)
...@@ -331,9 +331,7 @@ def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend): ...@@ -331,9 +331,7 @@ def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend):
Parameters Parameters
---------- ----------
data: tvm.Tensor data: tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6]. 2-D tensor of input boxes' score with shape [batch_size, num_anchors].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
data_buf: Buffer data_buf: Buffer
2D Buffer of input boxes' score with shape [batch_size, num_anchors]. 2D Buffer of input boxes' score with shape [batch_size, num_anchors].
...@@ -595,8 +593,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk ...@@ -595,8 +593,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
force_suppress = True force_suppress = True
nms_topk = -1 nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(dshape) np_data = np.random.uniform(size=dshape).astype("float32")
np_valid_count = np.array([4]) np_valid_count = np.array([4]).astype("int32")
s = topi.generic.schedule_nms(out) s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm") f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu() ctx = tvm.cpu()
......
...@@ -278,10 +278,10 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \ ...@@ -278,10 +278,10 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
oy = py * vy * ah + ay oy = py * vy * ah + ay
ow = tvm.exp(pw * vw) * aw / 2.0 ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0 oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \ return tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox - ow)), ox - ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \ tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy - oh)), oy - oh), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \ tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox + ow)), ox + ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh) tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy + oh)), oy + oh)
max_threads = int( max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads) tvm.target.current_target(allow_none=False).max_num_threads)
......
...@@ -145,8 +145,8 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1) ...@@ -145,8 +145,8 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
force_suppress = True force_suppress = True
nms_topk = -1 nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(dshape) np_data = np.random.uniform(size=dshape).astype("float32")
np_valid_count = np.array([4]) np_valid_count = np.array([4]).astype("int32")
s = topi.generic.schedule_nms(out) s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm") f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu() ctx = tvm.cpu()
......
...@@ -46,7 +46,7 @@ def test_nms(): ...@@ -46,7 +46,7 @@ def test_nms():
f(tvm_data, tvm_valid_count, tvm_out) f(tvm_data, tvm_valid_count, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
for device in ['llvm', 'opencl']: for device in ['llvm', 'opencl', 'cuda']:
check_device(device) check_device(device)
...@@ -105,7 +105,7 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offse ...@@ -105,7 +105,7 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offse
f(tvm_input_data, tvm_out) f(tvm_input_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-3) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-3)
for device in ['llvm', 'opencl']: for device in ['llvm', 'opencl', 'cuda']:
check_device(device) check_device(device)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment