Commit a3f3dc75 by Alexander Pivovarov Committed by Tianqi Chen

Make topi cuda nms_gpu method signature similar to non_max_suppression (#2780)

parent d8abc733
......@@ -182,8 +182,15 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
@non_max_suppression.register(["cuda", "gpu"])
def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress=False,
topk=-1, id_index=0, invalid_to_bottom=False):
def nms_gpu(data,
valid_count,
max_output_size=-1,
iou_threshold=0.5,
force_suppress=False,
top_k=-1,
id_index=0,
return_indices=True,
invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection.
Parameters
......@@ -205,7 +212,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
force_suppress : optional, boolean
Whether to suppress all detections regardless of class_id.
topk : optional, int
top_k : optional, int
Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
......@@ -229,7 +236,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
iou_threshold = 0.7
force_suppress = True
topk = -1
top_k = -1
out = nms(data, valid_count, iou_threshold, force_suppress, topk)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
......@@ -273,7 +280,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
[data, sort_tensor, valid_count],
lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], iou_threshold,
force_suppress, topk),
force_suppress, top_k),
dtype="float32",
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
tag="nms")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment