Commit a3f3dc75 by Alexander Pivovarov Committed by Tianqi Chen

Make topi cuda nms_gpu method signature similar to non_max_suppression (#2780)

parent d8abc733
...@@ -182,8 +182,15 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n ...@@ -182,8 +182,15 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
@non_max_suppression.register(["cuda", "gpu"]) @non_max_suppression.register(["cuda", "gpu"])
def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress=False, def nms_gpu(data,
topk=-1, id_index=0, invalid_to_bottom=False): valid_count,
max_output_size=-1,
iou_threshold=0.5,
force_suppress=False,
top_k=-1,
id_index=0,
return_indices=True,
invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection. """Non-maximum suppression operator for object detection.
Parameters Parameters
...@@ -205,7 +212,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress ...@@ -205,7 +212,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
force_suppress : optional, boolean force_suppress : optional, boolean
Whether to suppress all detections regardless of class_id. Whether to suppress all detections regardless of class_id.
topk : optional, int top_k : optional, int
Keep maximum top k detections before nms, -1 for no limit. Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int id_index : optional, int
...@@ -229,7 +236,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress ...@@ -229,7 +236,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
iou_threshold = 0.7 iou_threshold = 0.7
force_suppress = True force_suppress = True
topk = -1 top_k = -1
out = nms(data, valid_count, iou_threshold, force_suppress, topk) out = nms(data, valid_count, iou_threshold, force_suppress, topk)
np_data = np.random.uniform(dshape) np_data = np.random.uniform(dshape)
np_valid_count = np.array([4]) np_valid_count = np.array([4])
...@@ -273,7 +280,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress ...@@ -273,7 +280,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
[data, sort_tensor, valid_count], [data, sort_tensor, valid_count],
lambda ins, outs: nms_ir( lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], iou_threshold, ins[0], ins[1], ins[2], outs[0], iou_threshold,
force_suppress, topk), force_suppress, top_k),
dtype="float32", dtype="float32",
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
tag="nms") tag="nms")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment