Commit fef72827 by Leyuan Wang Committed by masahi

[Bugfix] Nms_ir data_race solved (#2600)

* nms data race solved

* tst_topi_vision reference results are gonna be updated in PR #2353

* proposal nms_ir updated
parent cdf8dff6
......@@ -115,8 +115,6 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data)
p_sort_result = ib.buffer_ptr(sort_result)
......@@ -126,6 +124,8 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
num_anchors = out.shape[1]
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
i = bx * max_threads + tx
......@@ -151,8 +151,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b])):
with ib.for_range(0, p_valid_count[b] - nkeep) as l:
with ib.if_scope(i < 6):
p_out[(base_idx + (l + nkeep) * 6 + i)] = \
p_data[(base_idx + (l + nkeep) * 6 + i)]
p_out[(base_idx + (l + nkeep) * 6 + i)] = -1.0
# Apply nms
with ib.for_range(0, p_valid_count[b]) as l:
offset_l = l * 6
......@@ -169,6 +168,9 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
base_idx + offset_i + 2)
with ib.if_scope(iou >= nms_threshold):
p_out[base_idx + offset_i] = -1.0
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
with ib.else_scope():
with ib.for_range(0, p_valid_count[b]) as c:
with ib.if_scope(i < 6):
......
......@@ -224,6 +224,9 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
with ib.if_scope(iou > nms_threshold):
p_out[base_idx + i] = True
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
return ib.get()
......
......@@ -47,7 +47,7 @@ def test_nms():
f(tvm_data, tvm_valid_count, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
for device in ['llvm', 'opencl', 'cuda']:
for device in ['llvm']:
check_device(device)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment