Commit d20646c7 by Wuwei Lin Committed by masahi

[TOPI][CUDA] Add faster-rcnn proposal op (#2420)

* [TOPI][CUDA] Add faster-rcnn proposal op

* Fix doc

* Add global barrier

* Use vthread in argsort

* Update sort and nms ir

* Fix lint

* Update sort ir in ssd nms
parent 6b0157bf
......@@ -18,3 +18,4 @@ from .vision import *
from . import ssd
from .ssd import *
from .nms import *
from .rcnn import *
......@@ -35,7 +35,7 @@ def sort_ir(data, index, output):
p_index = ib.buffer_ptr(index)
p_out = ib.buffer_ptr(output)
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
nthread_bx = (num_anchors + 1) // 2 // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("vthread")
ib.scope_attr(tx, "thread_extent", nthread_tx)
......@@ -46,8 +46,10 @@ def sort_ir(data, index, output):
with ib.for_range(0, batch, for_type="unroll") as b:
start = b * num_anchors
with ib.if_scope(tid < num_anchors):
p_out[start + tid] = tid
for i in range(2):
bbox_id = tid * 2 + i
with ib.if_scope(bbox_id < num_anchors):
p_out[start + bbox_id] = bbox_id
# OddEvenTransposeSort
with ib.for_range(0, p_index[b]) as k:
with ib.if_scope(tid < (p_index[b] + 1) // 2):
......
# pylint: disable=wildcard-import
"""Faster R-CNN and Mask R-CNN operators"""
from .proposal import *
......@@ -151,3 +151,32 @@ def schedule_multibox_detection(outs):
@generic.schedule_roi_align.register(["cuda", "gpu"])
def schedule_roi_align(outs):
return schedule_pool(outs, 'NCHW')
@generic.schedule_proposal.register(["cuda", "gpu"])
def schedule_proposal(outs):
"""Schedule for proposal operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of proposal
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
def traverse(op):
if op.tag in ['bbox_score', 'sorted_bbox']:
_schedule_injective(op, s)
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......@@ -157,3 +157,20 @@ def schedule_roi_align(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_proposal(outs):
"""Schedule for proposal operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of proposal
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
# pylint: disable=wildcard-import
"""Faster R-CNN and Mask R-CNN operators"""
from .roi_align import *
from .proposal import *
# pylint: disable=invalid-name
"""Proposal operator"""
import math
import tvm
def generate_anchor(ratio, scale, base_size):
"""Generate anchor"""
w = h = float(base_size)
x_ctr = 0.5 * (w - 1.)
y_ctr = 0.5 * (h - 1.)
size = w * h
size_ratios = math.floor(size / ratio)
new_w = math.floor(math.sqrt(size_ratios) + 0.5) * scale
new_h = math.floor((new_w / scale * ratio) + 0.5) * scale
return (x_ctr - 0.5 * (new_w - 1.0), y_ctr - 0.5 * (new_h - 1.0),
x_ctr + 0.5 * (new_w - 1.0), y_ctr + 0.5 * (new_h - 1.0))
def reg_bbox(x1, y1, x2, y2, dx, dy, dw, dh):
"""Bounding box regression function"""
bbox_w = x2 - x1 + 1.0
bbox_h = y2 - y1 + 1.0
ctr_x = x1 + 0.5 * (bbox_w - 1.0)
ctr_y = y1 + 0.5 * (bbox_h - 1.0)
pred_ctr_x = dx * bbox_w + ctr_x
pred_ctr_y = dy * bbox_h + ctr_y
pred_w = tvm.exp(dw) * bbox_w
pred_h = tvm.exp(dh) * bbox_h
pred_x1 = pred_ctr_x - 0.5 * (pred_w - 1.0)
pred_y1 = pred_ctr_y - 0.5 * (pred_h - 1.0)
pred_x2 = pred_ctr_x + 0.5 * (pred_w - 1.0)
pred_y2 = pred_ctr_y + 0.5 * (pred_h - 1.0)
return pred_x1, pred_y1, pred_x2, pred_y2
def reg_iou(x1, y1, x2, y2, dx1, dy1, dx2, dy2):
"""Bounding box regression function"""
pred_x1 = x1 + dx1
pred_y1 = y1 + dy1
pred_x2 = x2 + dx2
pred_y2 = y2 + dy2
return pred_x1, pred_y1, pred_x2, pred_y2
@tvm.target.generic_func
def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
"""Proposal operator.
Parameters
----------
cls_prob : tvm.Tensor
4-D with shape [batch, 2 * num_anchors, height, width]
bbox_pred : tvm.Tensor
4-D with shape [batch, 4 * num_anchors, height, width]
im_info : tvm.Tensor
2-D with shape [batch, 3]
scales : list/tuple of float
Scales of anchor windoes.
ratios : list/tuple of float
Ratios of anchor windoes.
feature_stride : int
The size of the receptive field each unit in the convolution layer of the rpn, for example
the product of all stride's prior to this layer.
threshold : float
Non-maximum suppression threshold.
rpn_pre_nms_top_n : int
Number of top scoring boxes to apply NMS. -1 to use all boxes.
rpn_post_nms_top_n : int
Number of top scoring boxes to keep after applying NMS to RPN proposals.
rpn_min_size : int
Minimum height or width in proposal.
iou_loss : bool
Usage of IoU loss.
Returns
-------
out : tvm.Tensor
2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
[batch_index, w_start, h_start, w_end, h_end].
"""
# pylint: disable=unused-argument
raise ValueError("missing register for topi.vision.rcnn.proposal")
"""Test code for vision package"""
from __future__ import print_function
import math
import numpy as np
import tvm
......@@ -206,8 +207,75 @@ def test_roi_align():
verify_roi_align(4, 16, 32, 64, 7, 0.5, 2)
def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
cls_prob = tvm.placeholder(np_cls_prob.shape)
bbox_pred = tvm.placeholder(np_bbox_pred.shape)
im_info = tvm.placeholder(np_im_info.shape, dtype='int32')
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
out = topi.vision.proposal(cls_prob, bbox_pred, im_info, **attrs)
s = topi.generic.schedule_proposal(out)
f = tvm.build(s, [cls_prob, bbox_pred, im_info, out], device)
tvm_cls_prob = tvm.nd.array(np_cls_prob, ctx=ctx)
tvm_bbox_pred = tvm.nd.array(np_bbox_pred, ctx=ctx)
tvm_im_info = tvm.nd.array(np_im_info, ctx=ctx)
tvm_out = tvm.nd.empty(ctx=ctx, shape=out.shape, dtype=out.dtype)
f(tvm_cls_prob, tvm_bbox_pred, tvm_im_info, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-4)
for device in ['cuda']:
check_device(device)
def test_proposal():
attrs = {'scales': (0.5,),'ratios': (0.5,),
'feature_stride': 16,
'iou_loss': False,
'rpn_min_size': 16,
'threshold': 0.7,
'rpn_pre_nms_top_n': 200,
'rpn_post_nms_top_n': 4,
}
np_cls_prob = np.array([[
[[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]],
[[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]]
]], dtype='float32')
np_bbox_pred = np.array([[
[[0.5, 1.0, 0.6], [0.8, 1.2, 2.0], [0.9, 1.0, 0.8]],
[[0.5, 1.0, 0.7], [0.8, 1.2, 1.6], [2.1, 1.5, 0.7]],
[[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]],
[[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]],
]], dtype='float32')
np_im_info = np.array([[48, 48, 1]], dtype='int32')
np_out = np.array([
[0., 0., 2.8451548,28.38012, 18.154846],
[0., 0., 15.354933, 41.96971, 41.245064],
[0., 18.019852, 1.0538368, 51.98015, 25.946163],
[0., 27.320923, -1.266357, 55., 24.666357]
], dtype='float32')
verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs)
np_out = np.array([
[ 0., -5.25, -2.5, 21.75, 19.],
[ 0., 11.25, -2., 37.25, 18.5],
[ 0., 26.849998, -2.3000002, 53.45, 18.6],
[ 0., -4.95, 13.799999, 22.25, 35.5]
], dtype='float32')
attrs['iou_loss'] = True
verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs)
if __name__ == "__main__":
test_nms()
test_multibox_prior()
test_multibox_detection()
test_roi_align()
test_proposal()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment