Commit da1ea262 by Leyuan Wang Committed by Wuwei Lin

Non_maximum_suppression and get_valid_counts add new parameters (#3335)

parent 124f9b7f
......@@ -27,7 +27,7 @@ from .sort import argsort
from .. import tag
def get_valid_counts_pre(data, flag, idx, score_threshold):
def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index):
"""Low level IR to Prepare get valid count of bounding boxes
given a score threshold. Also moves valid boxes to the
top of input data.
......@@ -46,6 +46,12 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
score_threshold : float32
Lower limit of score for valid bounding boxes.
id_index : optional, int
index of the class categories, -1 to disable.
score_index: optional, int
Index of the scores/confidence of boxes.
Returns
-------
stmt : Stmt
......@@ -61,6 +67,8 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
flag = ib.buffer_ptr(flag)
idx = ib.buffer_ptr(idx)
score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
nthread_tx = max_threads
......@@ -72,7 +80,8 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size * num_anchors):
with ib.if_scope(data[tid * box_data_length + 1] > score_threshold):
with ib.if_scope(tvm.all(data[tid * box_data_length + score_index] > score_threshold, \
tvm.any(id_index < 0, data[tid * box_data_length + id_index] >= 0))):
flag[tid] = 1
idx[tid] = 1
with ib.else_scope():
......@@ -356,7 +365,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
temp_flag, temp_idx = \
tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
lambda ins, outs: get_valid_counts_pre(
ins[0], outs[0], outs[1], score_threshold),
ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
dtype=["int32", "int32"],
out_buffers=[temp_flag_buf, temp_idx_buf],
name="get_valid_counts_phase_one")
......@@ -395,7 +404,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
def nms_ir(data, sorted_index, valid_count, out, box_indices,
max_output_size, iou_threshold, force_suppress,
top_k, coord_start, id_index):
top_k, coord_start, id_index, score_index):
"""Low level IR routing for transform location in multibox_detection operator.
Parameters
......@@ -431,6 +440,9 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
id_index : int
index of the class categories, -1 to disable.
score_index : optional, int
Index of the scores/confidence of boxes.
Returns
-------
stmt : Stmt
......@@ -477,6 +489,7 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
with ib.for_range(0, batch_size, for_type="unroll") as i:
......@@ -498,20 +511,26 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
box_indices[i * num_anchors + (j + nkeep)] = -1
# Apply nms
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.if_scope(out[base_idx + offset_j] >= 0):
with ib.for_range(0, valid_count[i]) as k:
offset_k = k * box_data_length
with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \
with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, \
tvm.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0))):
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.if_scope(tvm.all(j > k, \
out[base_idx + offset_j + score_index] > 0, \
tvm.any(id_index < 0, \
out[base_idx + offset_j + id_index] >= 0), \
tvm.any(force_suppress > 0, id_index < 0, \
out[base_idx + offset_j] == \
out[base_idx + offset_k]))):
iou = calculate_overlap(out, base_idx + offset_k + coord_start,
base_idx + offset_j + coord_start)
out[base_idx + offset_k + id_index] == \
out[base_idx + offset_j + id_index]))):
iou = calculate_overlap(out, base_idx + offset_j + coord_start,
base_idx + offset_k + coord_start)
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_k] = -1.0
box_indices[i * num_anchors + k] = -1
out[base_idx + offset_j + score_index] = -1.0
with ib.if_scope(id_index >= 0):
out[base_idx + offset_j + id_index] = -1.0
box_indices[i * num_anchors + j] = -1
with ib.else_scope():
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
......@@ -749,7 +768,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], outs[1],
max_output_size, iou_threshold, force_suppress,
top_k, coord_start, id_index),
top_k, coord_start, id_index, score_index),
dtype=[data.dtype, "int32"],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
name="nms",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment