Commit da1ea262 by Leyuan Wang Committed by Wuwei Lin

Non_maximum_suppression and get_valid_counts add new parameters (#3335)

parent 124f9b7f
...@@ -27,7 +27,7 @@ from .sort import argsort ...@@ -27,7 +27,7 @@ from .sort import argsort
from .. import tag from .. import tag
def get_valid_counts_pre(data, flag, idx, score_threshold): def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index):
"""Low level IR to Prepare get valid count of bounding boxes """Low level IR to Prepare get valid count of bounding boxes
given a score threshold. Also moves valid boxes to the given a score threshold. Also moves valid boxes to the
top of input data. top of input data.
...@@ -46,6 +46,12 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): ...@@ -46,6 +46,12 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
score_threshold : float32 score_threshold : float32
Lower limit of score for valid bounding boxes. Lower limit of score for valid bounding boxes.
id_index : optional, int
index of the class categories, -1 to disable.
score_index: optional, int
Index of the scores/confidence of boxes.
Returns Returns
------- -------
stmt : Stmt stmt : Stmt
...@@ -61,6 +67,8 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): ...@@ -61,6 +67,8 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
flag = ib.buffer_ptr(flag) flag = ib.buffer_ptr(flag)
idx = ib.buffer_ptr(idx) idx = ib.buffer_ptr(idx)
score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold) score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
nthread_tx = max_threads nthread_tx = max_threads
...@@ -72,7 +80,8 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): ...@@ -72,7 +80,8 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
tid = bx * max_threads + tx tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size * num_anchors): with ib.if_scope(tid < batch_size * num_anchors):
with ib.if_scope(data[tid * box_data_length + 1] > score_threshold): with ib.if_scope(tvm.all(data[tid * box_data_length + score_index] > score_threshold, \
tvm.any(id_index < 0, data[tid * box_data_length + id_index] >= 0))):
flag[tid] = 1 flag[tid] = 1
idx[tid] = 1 idx[tid] = 1
with ib.else_scope(): with ib.else_scope():
...@@ -356,7 +365,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1): ...@@ -356,7 +365,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
temp_flag, temp_idx = \ temp_flag, temp_idx = \
tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data], tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
lambda ins, outs: get_valid_counts_pre( lambda ins, outs: get_valid_counts_pre(
ins[0], outs[0], outs[1], score_threshold), ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
dtype=["int32", "int32"], dtype=["int32", "int32"],
out_buffers=[temp_flag_buf, temp_idx_buf], out_buffers=[temp_flag_buf, temp_idx_buf],
name="get_valid_counts_phase_one") name="get_valid_counts_phase_one")
...@@ -395,7 +404,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1): ...@@ -395,7 +404,7 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
def nms_ir(data, sorted_index, valid_count, out, box_indices, def nms_ir(data, sorted_index, valid_count, out, box_indices,
max_output_size, iou_threshold, force_suppress, max_output_size, iou_threshold, force_suppress,
top_k, coord_start, id_index): top_k, coord_start, id_index, score_index):
"""Low level IR routing for transform location in multibox_detection operator. """Low level IR routing for transform location in multibox_detection operator.
Parameters Parameters
...@@ -431,6 +440,9 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices, ...@@ -431,6 +440,9 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
id_index : int id_index : int
index of the class categories, -1 to disable. index of the class categories, -1 to disable.
score_index : optional, int
Index of the scores/confidence of boxes.
Returns Returns
------- -------
stmt : Stmt stmt : Stmt
...@@ -477,6 +489,7 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices, ...@@ -477,6 +489,7 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
top_k = tvm.make.node("IntImm", dtype="int32", value=top_k) top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start) coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
id_index = tvm.make.node("IntImm", dtype="int32", value=id_index) id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0) force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
with ib.for_range(0, batch_size, for_type="unroll") as i: with ib.for_range(0, batch_size, for_type="unroll") as i:
...@@ -498,20 +511,26 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices, ...@@ -498,20 +511,26 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
box_indices[i * num_anchors + (j + nkeep)] = -1 box_indices[i * num_anchors + (j + nkeep)] = -1
# Apply nms # Apply nms
with ib.if_scope(j < valid_count[i]): with ib.for_range(0, valid_count[i]) as k:
offset_j = j * box_data_length offset_k = k * box_data_length
with ib.if_scope(out[base_idx + offset_j] >= 0): with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, \
with ib.for_range(0, valid_count[i]) as k: tvm.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0))):
offset_k = k * box_data_length with ib.if_scope(j < valid_count[i]):
with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \ offset_j = j * box_data_length
with ib.if_scope(tvm.all(j > k, \
out[base_idx + offset_j + score_index] > 0, \
tvm.any(id_index < 0, \
out[base_idx + offset_j + id_index] >= 0), \
tvm.any(force_suppress > 0, id_index < 0, \ tvm.any(force_suppress > 0, id_index < 0, \
out[base_idx + offset_j] == \ out[base_idx + offset_k + id_index] == \
out[base_idx + offset_k]))): out[base_idx + offset_j + id_index]))):
iou = calculate_overlap(out, base_idx + offset_k + coord_start, iou = calculate_overlap(out, base_idx + offset_j + coord_start,
base_idx + offset_j + coord_start) base_idx + offset_k + coord_start)
with ib.if_scope(iou >= iou_threshold): with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_k] = -1.0 out[base_idx + offset_j + score_index] = -1.0
box_indices[i * num_anchors + k] = -1 with ib.if_scope(id_index >= 0):
out[base_idx + offset_j + id_index] = -1.0
box_indices[i * num_anchors + j] = -1
with ib.else_scope(): with ib.else_scope():
with ib.if_scope(j < valid_count[i]): with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length offset_j = j * box_data_length
...@@ -749,7 +768,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1, ...@@ -749,7 +768,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
lambda ins, outs: nms_ir( lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], outs[1], ins[0], ins[1], ins[2], outs[0], outs[1],
max_output_size, iou_threshold, force_suppress, max_output_size, iou_threshold, force_suppress,
top_k, coord_start, id_index), top_k, coord_start, id_index, score_index),
dtype=[data.dtype, "int32"], dtype=[data.dtype, "int32"],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
name="nms", name="nms",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment