Commit 60769b77 by Leyuan Wang, committed by Tianqi Chen

Fixed bugs for SSD sorting and multibox detection (#1578)

parent 19cf5c66
@@ -7,90 +7,105 @@ from tvm import api
from topi.vision import nms
def sort_ir(data, index, output, axis, is_descend):
"""Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
def sort_pre_ir(index, sizes_out, axis_mul_before, axis_mul_after):
"""Low level IR routing subfunction 1/4 for computing segments' staring locatons.
Parameters
----------
data: Buffer
2D Buffer of input boxes' score with shape [batch_size, num_anchors].
index : Buffer
Buffer of number of valid number of boxes.
Buffer of number of valid output boxes.
output : Buffer
Output buffer of indicies of sorted tensor.
sizes_out : Buffer
Output buffer of start locations of each sorting segment.
axis : int
The axis used for sorting.
axis_mul_before : int
The multiplication result of axis dimensions before axis.
is_descend : bool
If the sorted data is in descending order.
axis_mul_after : int
The multiplication result of axis dimensions after axis.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data)
p_index = ib.buffer_ptr(index)
p_out = ib.buffer_ptr(output)
ndim = len(data.shape)
assert data.dtype == "float32", "Currently only supports input dtype to be float32"
assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim
axis_mul_before = 1
axis_mul_after = 1
if axis < 0:
axis = ndim + axis
for i in range(0, ndim):
if i < axis:
axis_mul_before *= data.shape[i]
elif i > axis:
axis_mul_after *= data.shape[i]
dshape = 0
for i in range(0, len(index.shape)):
dshape += index.shape[i]
dshape = tvm.select(dshape > axis_mul_before*axis_mul_after, dshape,
axis_mul_before*axis_mul_after)
sizes_temp = ib.allocate(
"int32", dshape, name="sizes_temp", scope="global")
sizes = ib.allocate("int32", dshape, name="sizes", scope="global")
temp_index = ib.allocate("int32", dshape, name="temp_index", scope="local")
temp_data = ib.allocate("float32", dshape, name="temp_data", scope="local")
data_new = ib.allocate("float32", dshape, name="data_new", scope="global")
index_new = ib.allocate("int32", dshape, name="index_new", scope="global")
dshape = sizes_out.shape
sizes = ib.buffer_ptr(sizes_out)
nthread_tx = max_threads
nthread_bx = dshape // max_threads + 1
nthread_bx = dshape[0] // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
sizes[tid] = p_index[tid]
sizes_temp[tid] = p_index[tid]
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
with ib.for_range(0, tvm.floor(tvm.sqrt((axis_mul_before * axis_mul_after) \
.astype("float32"))) + 1, name="k") as k:
with ib.if_scope(tid - (tvm.const(1, "int32") << k) >= 0):
with ib.if_scope(k % 2 == 0):
sizes[tid] += sizes_temp[tid - (
tvm.const(1, "int32") << k)]
sizes_temp[tid] = sizes[tid]
with ib.else_scope():
sizes_temp[tid] += sizes[tid - (
tvm.const(1, "int32") << k)]
sizes[tid] = sizes_temp[tid]
# scan
with ib.if_scope(tid < 1):
with ib.for_range(0, axis_mul_before * axis_mul_after - 1, name="k") as k:
sizes[k + 1] += sizes[k]
body = ib.get()
return body
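For reference, the segment start locations this kernel produces are just an inclusive prefix sum of the per-segment valid counts; the NumPy snippet below is an editor's host-side illustration, not part of the patch.
import numpy as np

def segment_sizes_ref(valid_counts):
    # Inclusive prefix sum over the per-segment valid counts; segment i
    # starts at sizes[i - 1] (0 for i == 0) in the flattened buffers.
    return np.cumsum(np.asarray(valid_counts, dtype=np.int32))

# Example: valid_counts = [3, 5, 2] -> sizes = [3, 8, 10]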
def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \
axis, axis_mul_before, axis_mul_after):
"""Low level IR routing subfunction 2/4 for flattening data and indices into segmented format.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
index : Buffer
Buffer of number of valid output boxes.
sizes_in : Buffer
Buffer of start locations of each sorting segment.
data_out : Buffer
Buffer of flattened segmented data.
index_out : Buffer
Buffer of flattened segmented indices.
axis : int
The axis used for sorting.
axis_mul_before : int
The multiplication result of axis dimensions before axis.
axis_mul_after : int
The multiplication result of axis dimensions after axis.
Returns
-------
stmt : Stmt
The result IR statement.
"""
ib = tvm.ir_builder.create()
sizes = ib.buffer_ptr(sizes_in)
p_index = ib.buffer_ptr(index)
p_data = ib.buffer_ptr(data)
data_new = ib.buffer_ptr(data_out)
index_new = ib.buffer_ptr(index_out)
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
dshape = tvm.max(sizes_in.shape[0], p_index[0])
nthread_tx = max_threads
nthread_bx = dshape // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(axis_mul_before * axis_mul_after > 1):
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
i = tid / axis_mul_after
j = tid % axis_mul_after
@@ -104,7 +119,77 @@ def sort_ir(data, index, output, axis, is_descend):
start = sizes[tid-1]
index_new[start + k] = k
data_new[start + k] = p_data[full_idx]
with ib.else_scope():
with ib.if_scope(tid == 0):
with ib.for_range(0, p_index[0], name="k") as k:
index_new[k] = k
body = ib.get()
return body
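A host-side sketch of the same flattening step (editor's illustration; the function name is hypothetical): each segment's first valid_count scores are packed back to back at the offsets computed above, and the per-segment indices are reset to 0..n-1.
import numpy as np

def flatten_segments_ref(scores, valid_counts):
    # scores: [num_segments, num_anchors]; valid_counts: [num_segments]
    sizes = np.cumsum(valid_counts)
    starts = np.concatenate(([0], sizes[:-1]))
    data_new = np.empty(int(sizes[-1]), dtype=scores.dtype)
    index_new = np.empty(int(sizes[-1]), dtype=np.int32)
    for seg in range(len(valid_counts)):
        s, n = int(starts[seg]), int(valid_counts[seg])
        data_new[s:s + n] = scores[seg, :n]   # pack valid scores contiguously
        index_new[s:s + n] = np.arange(n)     # reset indices within the segment
    return data_new, index_new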
def sort_oet_ir(data, index, new_data, new_index, loc, out_index, axis_mul_before, \
axis_mul_after, axis, is_descend):
"""Low level IR routing subfunction 3/4 for Odd-Even-Transposition sorting.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
index : Buffer
Buffer of number of valid output boxes.
new_data : Buffer
Buffer of flattened segmented data.
new_index : Buffer
Buffer of flattened segmented indices.
loc : Buffer
Buffer of start locations of each sorting segment.
out_index : Buffer
Output buffer of output box indices sorted by score, in a flattened segmented format.
axis_mul_before : int
The multiplication result of axis dimensions before axis.
axis_mul_after : int
The multiplication result of axis dimensions after axis.
axis : int
The axis used for sorting.
is_descend : bool
If the sorted data is in descending order.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
dshape = loc.shape
fshape = data.shape[axis] * dshape[0]
temp_data = ib.allocate(
"float32", dshape, name="temp_data", scope="local")
p_data = ib.buffer_ptr(data)
p_index = ib.buffer_ptr(index)
data_new = ib.buffer_ptr(new_data)
index_new = ib.buffer_ptr(new_index)
index_out = ib.buffer_ptr(out_index)
sizes = ib.buffer_ptr(loc)
nthread_tx = max_threads
nthread_bx = fshape // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(axis_mul_before * axis_mul_after > 1):
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
with ib.if_scope(tid == 0):
start = 0
@@ -113,20 +198,117 @@ def sort_ir(data, index, output, axis, is_descend):
# OddEvenTransposeSort
with ib.for_range(0, p_index[tid], name="k") as k:
with ib.for_range(0, p_index[tid] - 1, name="i") as i:
with ib.if_scope(i % 2 == (k & 1)):
with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) ^
is_descend) == False):
with ib.if_scope(i % 2 == k % 2):
with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) == is_descend)):
temp_data[tid] = data_new[i+start]
data_new[i+start] = data_new[i+start+1]
data_new[i+start+1] = temp_data[tid]
temp_index[tid] = index_new[i+start]
index_out[tid] = index_new[i+start]
index_new[i+start] = index_new[i+start+1]
index_new[i+start+1] = temp_index[tid]
index_new[i+start+1] = index_out[tid]
with ib.if_scope(tid < 1):
with ib.for_range(0, sizes[dshape[0] - 1], name="i") as i:
index_out[i] = index_new[i]
with ib.else_scope():
with ib.for_range(0, fshape, name="k", for_type="unroll") as k:
with ib.if_scope(tvm.all(k % 2 == tid % 2, tid < fshape)):
with ib.if_scope(k % 2 == 0):
with ib.if_scope(tvm.all(tid + 1 < fshape, (p_data[tid] < p_data[tid+1]) \
== is_descend)):
data_new[tid] = p_data[tid+1]
index_out[tid] = index_new[tid+1]
with ib.else_scope():
data_new[tid] = p_data[tid]
index_out[tid] = index_new[tid]
with ib.else_scope():
with ib.if_scope(tvm.all(tid + 1 < fshape, (data_new[tid] < data_new[tid+1]) \
== is_descend)):
p_data[tid] = data_new[tid+1]
index_new[tid] = index_out[tid+1]
with ib.else_scope():
p_data[tid] = data_new[tid]
index_new[tid] = index_out[tid]
with ib.if_scope(tvm.all(k % 2 != tid % 2, tid < fshape)):
with ib.if_scope(k % 2 == 0):
with ib.if_scope(tvm.all(tid > 0, (p_data[tid-1] < p_data[tid]) == is_descend)):
data_new[tid] = p_data[tid-1]
index_out[tid] = index_new[tid-1]
with ib.else_scope():
data_new[tid] = p_data[tid]
index_out[tid] = index_new[tid]
with ib.else_scope():
with ib.if_scope(tvm.all(tid > 0, (data_new[tid-1] < data_new[tid]) \
== is_descend)):
p_data[tid] = data_new[tid-1]
index_new[tid] = index_out[tid-1]
with ib.else_scope():
p_data[tid] = data_new[tid]
index_new[tid] = index_out[tid]
with ib.if_scope(fshape % 2 == 1):
with ib.if_scope(tid < 1):
with ib.for_range(0, fshape, name="k") as k:
index_out[tid] = index_new[tid]
body = ib.get()
return body
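The compare-and-swap pattern emitted above is a standard odd-even transposition sort; a plain-Python reference for one segment (editor's sketch, descending when is_descend is True) looks like this:
def oet_sort_ref(vals, ids, is_descend=True):
    # n passes of alternating odd/even compare-and-swap phases.
    n = len(vals)
    for k in range(n):
        for i in range(k % 2, n - 1, 2):
            if (vals[i] < vals[i + 1]) == is_descend:
                vals[i], vals[i + 1] = vals[i + 1], vals[i]
                ids[i], ids[i + 1] = ids[i + 1], ids[i]
    return vals, ids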
def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_after, axis):
"""Low level IR routing subfunction 4/4 for writing sorted indices to output format.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
index : Buffer
Buffer of number of valid output boxes.
new_index : Buffer
Buffer of sorted indices in a flattened format.
loc : Buffer
Buffer of start locations of each sorting segment.
output : Buffer
Output buffer of output box indices sorted by score.
axis_mul_before : int
The multiplication result of axis dimensions before axis.
axis_mul_after : int
The multiplication result of axis dimensions after axis.
axis : int
The axis used for sorting.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
dshape = tvm.max(loc.shape[0], data.shape[axis])
p_index = ib.buffer_ptr(index)
index_new = ib.buffer_ptr(new_index)
sizes = ib.buffer_ptr(loc)
p_out = ib.buffer_ptr(output)
nthread_tx = max_threads
nthread_bx = dshape // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(axis_mul_before * axis_mul_after > 1):
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
i = tid / axis_mul_after
j = tid % axis_mul_after
current_sort_num = p_index[tid]
base_idx = i * data.shape[axis] * axis_mul_after + j
with ib.for_range(0, data.shape[axis], name="k") as k:
with ib.if_scope(tid == 0):
@@ -134,12 +316,119 @@ def sort_ir(data, index, output, axis, is_descend):
with ib.else_scope():
start = sizes[tid-1]
p_out[base_idx + k * axis_mul_after] = tvm.select(
k < current_sort_num,
index_new[k+start], k)
k < p_index[tid], index_new[k+start], k)
with ib.else_scope():
with ib.if_scope(tid < data.shape[axis]):
p_out[tid] = tvm.select(tid < p_index[0], index_new[tid], tid)
body = ib.get()
return body
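What this last stage computes, per segment: the sorted indices are scattered back into a [num_segments, num_anchors] layout, and positions beyond a segment's valid count simply keep their own index. An illustrative NumPy version (editor's sketch):
import numpy as np

def write_output_ref(index_new, valid_counts, num_anchors):
    sizes = np.cumsum(valid_counts)
    starts = np.concatenate(([0], sizes[:-1]))
    # Default: identity indices for the padding positions.
    out = np.tile(np.arange(num_anchors, dtype=np.int32),
                  (len(valid_counts), 1))
    for seg in range(len(valid_counts)):
        s, n = int(starts[seg]), int(valid_counts[seg])
        out[seg, :n] = index_new[s:s + n]
    return out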
def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend):
"""Function to generate low level IR to do sorting on the GPU, use it by calling sort_gpu.
Parameters
----------
data: tvm.Tensor
2-D tensor of input boxes' scores with shape [batch_size, num_anchors].
data_buf: Buffer
2-D Buffer of input boxes' scores with shape [batch_size, num_anchors].
index : tvm.Tensor
1-D tensor of the number of valid boxes per batch.
index_buf : Buffer
Buffer of the number of valid boxes per batch.
output_buf : Buffer
Output buffer of indices of the sorted tensor.
axis : int
The axis used for sorting.
is_descend : bool
If the sorted data is in descending order.
Returns
-------
out : tvm.Tensor
2-D tensor with shape [batch_size, num_anchors].
"""
ndim = len(data.shape)
assert data.dtype == "float32", "Currently only float32 input dtype is supported"
assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim
axis_mul_before = 1
axis_mul_after = 1
if axis < 0:
axis = ndim + axis
for i in range(0, ndim):
if i < axis:
axis_mul_before *= data.shape[i]
elif i > axis:
axis_mul_after *= data.shape[i]
dshape = axis_mul_before*axis_mul_after
fshape = data.shape[axis] * dshape
loc_buf = api.decl_buffer(dshape, index.dtype, "sizes", data_alignment=8)
new_index_buf = api.decl_buffer(
fshape, index.dtype, "index_new", data_alignment=8)
out_index_buf = api.decl_buffer(
fshape, index.dtype, "index_out", data_alignment=8)
new_data_buf = api.decl_buffer(
dshape, data.dtype, "data_new", data_alignment=8)
loc = \
tvm.extern([(dshape,)],
[index],
lambda ins, outs: sort_pre_ir(
ins[0], outs[0], axis_mul_before, axis_mul_after),
dtype=[index.dtype],
in_buffers=index_buf,
out_buffers=[loc_buf],
tag="sorting_prepare")
data_new, index_new = \
tvm.extern([(dshape,), (fshape,)],
[data, index, loc],
lambda ins, outs: sort_pre_ir_data(
ins[0], ins[1], ins[2], outs[0], outs[1], axis,
axis_mul_before, axis_mul_after),
dtype=[data.dtype, index.dtype],
in_buffers=[data_buf, index_buf, loc_buf],
out_buffers=[new_data_buf, new_index_buf],
tag="sorting_data")
index_out = \
tvm.extern([(fshape,)],
[data, index, data_new, index_new, loc],
lambda ins, outs: sort_oet_ir(
ins[0], ins[1], ins[2], ins[3], ins[4], outs[0],
axis_mul_before, axis_mul_after, axis, is_descend),
dtype=[index.dtype],
in_buffers=[data_buf, index_buf,
new_data_buf, new_index_buf, loc_buf],
out_buffers=[out_index_buf],
tag="sorting_oet")
out = \
tvm.extern([data.shape],
[data, index, index_out, loc],
lambda ins, outs: sort_ir_out(
ins[0], ins[1], ins[2], ins[3], outs[0],
axis_mul_before, axis_mul_after, axis),
dtype=[index.dtype],
in_buffers=[data_buf, index_buf, out_index_buf, loc_buf],
out_buffers=output_buf,
tag="sorting_output")
return out
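End to end, sort_gpu behaves like a per-row argsort restricted to the first valid_count entries; the NumPy reference below is an editor's sketch (it ignores tie-breaking differences with the odd-even sort above).
import numpy as np

def sort_gpu_ref(scores, valid_counts, is_descend=True):
    # scores: [batch_size, num_anchors]; valid_counts: [batch_size]
    batch, num_anchors = scores.shape
    out = np.tile(np.arange(num_anchors, dtype=np.int32), (batch, 1))
    for b in range(batch):
        n = int(valid_counts[b])
        order = np.argsort(scores[b, :n]).astype(np.int32)
        out[b, :n] = order[::-1] if is_descend else order
    return out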
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
"""Low level IR routing for transform location in multibox_detection operator.
@@ -333,15 +622,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
"sort_tensor_buf", data_alignment=8)
sort_tensor = \
tvm.extern(score_shape,
[score_tensor, valid_count],
lambda ins, outs: sort_ir(
ins[0], ins[1], outs[0], score_axis, True),
dtype=sort_tensor_dtype,
in_buffers=[score_tensor_buf, valid_count_buf],
out_buffers=sort_tensor_buf,
name="nms_sort")
sort_tensor = sort_gpu(score_tensor, score_tensor_buf, valid_count,
valid_count_buf, sort_tensor_buf, score_axis, True)
out = \
tvm.extern(data.shape,
[data, sort_tensor, valid_count],
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, too-many-function-args
"""SSD multibox operators"""
from __future__ import absolute_import as _abs
import math
@@ -13,6 +13,7 @@ from topi.vision.ssd import multibox_detection
from topi.vision.ssd import multibox_transform_loc
from ..nms import nms
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
"""Low level IR routing for multibox_prior operator.
@@ -41,7 +42,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
stmt : Stmt
The result IR statement.
"""
max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y")
bx = tvm.thread_axis("blockIdx.x")
@@ -76,7 +78,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
for k in range(num_sizes + num_ratios - 1):
w = tvm.select(k < num_sizes,
size_ratio_concat[k] * in_height / in_width / 2.0,
size_ratio_concat[
k] * in_height / in_width / 2.0,
size_ratio_concat[0] * in_height / in_width *
math.sqrt(size_ratio_concat[k + 1]) / 2.0)
h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0,
@@ -93,7 +96,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
@multibox_prior.register(["cuda", "gpu"])
def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), \
def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1),
offsets=(0.5, 0.5), clip=False):
"""Generate prior(anchor) boxes from data, sizes and ratios.
@@ -124,31 +127,114 @@ def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), \
"""
num_sizes = len(sizes)
num_ratios = len(ratios)
oshape = (1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
oshape = (
1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
out = tvm.extern(oshape, [data], lambda ins, outs:
multibox_prior_ir(ins[0], outs[0], sizes, ratios, steps, offsets),
multibox_prior_ir(
ins[0], outs[0], sizes, ratios, steps, offsets),
tag="multibox_prior")
if clip:
out = topi.clip(out, 0, 1)
return out
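For readers following the width/height selects in multibox_prior_ir: each feature-map location emits num_sizes + num_ratios - 1 boxes, sized roughly by the rule sketched below (editor's illustration; the ratio-box height branch is collapsed in the hunk above, so that line is an assumption based on the standard SSD formula).
import math

def prior_half_extents(sizes, ratios, in_height, in_width, k):
    num_sizes = len(sizes)
    concat = list(sizes) + list(ratios)
    if k < num_sizes:
        w = concat[k] * in_height / in_width / 2.0   # size boxes
        h = concat[k] / 2.0
    else:
        # ratio boxes reuse sizes[0]; height branch assumed (elided above)
        w = concat[0] * in_height / in_width * math.sqrt(concat[k + 1]) / 2.0
        h = concat[0] / math.sqrt(concat[k + 1]) / 2.0
    return w, h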
def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, threshold, variances):
"""Low level IR routing for transform location in multibox_detection operator.
def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, threshold):
"""Low level IR routing for transform location data preparation.
Parameters
----------
cls_prob : Buffer
Buffer of class probabilities.
valid_count : Buffer
Buffer of number of valid output boxes.
temp_flag : Buffer
Output intermediate buffer of positive-prediction flags.
temp_id : Buffer
Output intermediate buffer of per-anchor best class ids.
temp_score_out : Buffer
Output buffer of per-anchor best class scores.
threshold : float
Threshold to be a positive prediction.
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1]
num_anchors = cls_prob.shape[2]
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
ib = tvm.ir_builder.create()
score = ib.buffer_ptr(temp_score_out)
cls_id = ib.buffer_ptr(temp_id)
flag = ib.buffer_ptr(temp_flag)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
nthread_tx = max_threads
nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
p_cls_prob = ib.buffer_ptr(cls_prob)
p_valid_count = ib.buffer_ptr(valid_count)
with ib.if_scope(tid < batch_size * num_anchors):
n = tid / num_anchors  # batch index
i = tid % num_anchors  # anchor index
score[i] = -1.0
cls_id[i] = 0
p_valid_count[n] = 0
with ib.for_range(0, num_classes-1, name="k") as k:
temp = p_cls_prob[n * num_anchors * num_classes + (k + 1) * num_anchors + i]
with ib.if_scope(temp > score[i]):
cls_id[i] = k + 1
score[i] = temp
with ib.if_scope(tvm.all(cls_id[i] > 0, score[i] < threshold)):
cls_id[i] = 0
with ib.if_scope(cls_id[i] > 0):
flag[i] = 1
with ib.else_scope():
flag[i] = 0
with ib.if_scope(tid < batch_size):
with ib.for_range(0, num_anchors, name="k") as k:
with ib.if_scope(k > 0):
flag[tid * num_anchors + k] += flag[tid * num_anchors + k - 1]
p_valid_count[tid] = flag[tid * num_anchors + num_anchors - 1]
body = ib.get()
return body
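In effect, this first step picks the best non-background class per anchor, zeroes it out below threshold, and counts the positives per batch; a NumPy reference (editor's sketch):
import numpy as np

def transform_loc_pre_ref(cls_prob, threshold):
    # cls_prob: [batch, num_classes, num_anchors]; class 0 is background.
    score = cls_prob[:, 1:, :].max(axis=1)            # best foreground score
    cls_id = cls_prob[:, 1:, :].argmax(axis=1) + 1    # best foreground class
    cls_id = np.where(score < threshold, 0, cls_id)   # suppress weak predictions
    valid_count = (cls_id > 0).sum(axis=1).astype(np.int32)
    return cls_id, score, valid_count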
def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
out, clip, variances, batch_size, num_classes, num_anchors):
"""Low level IR routing for transform location in multibox_detection operator.
Parameters
----------
loc_pred : Buffer
Buffer of location regression predictions.
anchor : Buffer
Buffer of prior anchor boxes.
valid_count : Buffer
Buffer of number of valid output boxes.
temp_flag : Buffer
Intermediate result buffer.
temp_id : Buffer
Intermediate result buffer.
temp_score_in : Buffer
Input buffer which stores intermediate results.
out : Buffer
Output buffer.
@@ -156,12 +242,18 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
clip : boolean
Whether to clip out-of-boundary boxes.
threshold : float
Threshold to be a positive prediction.
variances : tuple of float
Variances to be decoded from box regression output.
batch_size : int
Batch size.
num_classes : int
Number of classes.
num_anchors : int
Number of anchors.
Returns
-------
stmt : Stmt
@@ -191,17 +283,12 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh)
batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1]
num_anchors = cls_prob.shape[2]
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
ib = tvm.ir_builder.create()
temp_score = ib.allocate('float32', (batch_size * (num_classes -1) * num_anchors, \
), name="temp_score", scope="global")
score = ib.allocate('float32', (batch_size * num_anchors, ), name="score", scope="local")
cls_id = ib.allocate('int32', (batch_size * num_anchors, ), name="id", scope="local")
flag = ib.allocate('int32', (batch_size * num_anchors, ), name="flag", scope="global")
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
score = ib.buffer_ptr(temp_score_in)
cls_id = ib.buffer_ptr(temp_id)
flag = ib.buffer_ptr(temp_flag)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
nthread_tx = max_threads
@@ -209,42 +296,13 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
p_cls_prob = ib.buffer_ptr(cls_prob)
p_loc_pred = ib.buffer_ptr(loc_pred)
p_anchor = ib.buffer_ptr(anchor)
p_valid_count = ib.buffer_ptr(valid_count)
p_out = ib.buffer_ptr(out)
with ib.if_scope(tid < batch_size * num_anchors * num_classes):
n = tid / (num_anchors * num_classes)
j = (tid % (num_anchors * num_classes)) / num_anchors
i = tid % num_anchors
with ib.if_scope(j > 0):
temp_score[n * num_anchors * num_classes + i * (num_classes - 1) + j-1] = \
p_cls_prob[tid]
p_valid_count[n] = 0
with ib.if_scope(tid < batch_size * num_anchors):
n = tid / num_anchors
i = tid % num_anchors
score[tid] = -1.0
cls_id[tid] = 0
with ib.for_range(0, num_classes-1, name="k") as k:
temp = temp_score[tid * (num_classes-1) + k]
cls_id[tid] = tvm.select(temp > score[tid], k + 1, cls_id[tid])
score[tid] = tvm.make.Max(temp, score[tid])
with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)):
cls_id[tid] = 0
with ib.if_scope(cls_id[tid] > 0):
flag[tid] = 1
with ib.else_scope():
flag[tid] = 0
with ib.if_scope(tid < batch_size):
with ib.for_range(0, num_anchors, name="k") as k:
with ib.if_scope(k > 0):
flag[tid * num_anchors + k] += flag[tid * num_anchors + k - 1]
p_valid_count[tid] = flag[tid * num_anchors + num_anchors - 1]
with ib.if_scope(tid < batch_size * num_anchors):
n = tid / num_anchors
i = tid % num_anchors
n = tid / num_anchors  # batch index
i = tid % num_anchors  # anchor index
with ib.if_scope(cls_id[tid] > 0):
with ib.if_scope(tid == 0):
out_base_idx = n * num_anchors * 6
@@ -253,17 +311,17 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
p_out[out_base_idx] = cls_id[tid] - 1.0
p_out[out_base_idx + 1] = score[tid]
p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \
p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4, p_anchor, i*4,
clip, variances[0], variances[1],
variances[2], variances[3])
p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4,
p_anchor, i*4, clip, variances[0],
variances[1], variances[2], variances[3])
body = ib.get()
return body
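The transform_loc helper (its body is collapsed in the hunk above) applies the usual SSD box decoding with the anchor corners and the four variances; the sketch below is the editor's reference for that decoding, not the exact IR.
import math

def decode_box_ref(pred, anchor, variances, clip=True):
    px, py, pw, ph = pred            # location regression output for one box
    al, at, ar, ab = anchor          # anchor corners (left, top, right, bottom)
    aw, ah = ar - al, ab - at
    ax, ay = al + aw / 2.0, at + ah / 2.0
    ox = px * variances[0] * aw + ax
    oy = py * variances[1] * ah + ay
    ow = math.exp(pw * variances[2]) * aw / 2.0
    oh = math.exp(ph * variances[3]) * ah / 2.0
    box = [ox - ow, oy - oh, ox + ow, oy + oh]
    return [min(max(v, 0.0), 1.0) for v in box] if clip else box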
@multibox_transform_loc.register(["cuda", "gpu"])
def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
variances=(0.1, 0.1, 0.2, 0.2)):
def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \
threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)):
"""Location transformation for multibox detection
Parameters
@@ -297,20 +355,42 @@ def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=
1-D tensor with shape (batch_size,), number of valid anchor boxes.
"""
batch_size = cls_prob.shape[0]
num_anchors = anchor.shape[1]
num_classes = cls_prob.shape[1]
num_anchors = cls_prob.shape[2]
oshape = (batch_size, num_anchors, 6)
# Define data alignment for intermediate buffer
valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype,
"valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(oshape, cls_prob.dtype, "out_buf", data_alignment=8)
valid_count, out = \
tvm.extern([(batch_size,), oshape],
[cls_prob, loc_pred, anchor],
out_buf = api.decl_buffer(
oshape, cls_prob.dtype, "out_buf", data_alignment=8)
size = num_anchors
temp_flag_buf = api.decl_buffer(
(size,), valid_count_dtype, "flag", data_alignment=8)
temp_id_buf = api.decl_buffer(
(size,), valid_count_dtype, "cls_id", data_alignment=8)
temp_score_buf = api.decl_buffer(
(size,), cls_prob.dtype, "score", data_alignment=8)
valid_count, temp_flag, temp_id, temp_score = \
tvm.extern([(batch_size,), (size,), (size,), (size,)],
[cls_prob],
lambda ins, outs: transform_loc_pre(
ins[0], outs[0], outs[1], outs[2], outs[3], threshold),
dtype=[valid_count_dtype,
valid_count_dtype, valid_count_dtype, cls_prob.dtype],
out_buffers=[valid_count_buf,
temp_flag_buf, temp_id_buf, temp_score_buf],
tag="multibox_transform_loc_first_step")
out = \
tvm.extern([oshape],
[loc_pred, anchor, temp_flag, temp_id, temp_score],
lambda ins, outs: transform_loc_ir(
ins[0], ins[1], ins[2], outs[0], outs[1], clip, threshold, variances),
dtype=[valid_count_dtype, cls_prob.dtype],
out_buffers=[valid_count_buf, out_buf],
ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, \
variances, batch_size, num_classes, num_anchors),
dtype=[cls_prob.dtype],
out_buffers=[out_buf],
tag="multibox_transform_loc")
return [out, valid_count]
@@ -356,5 +436,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
out = nms(
inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
return out
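A minimal usage sketch for the GPU path (editor's illustration; the placeholder shapes and the generic topi.vision.ssd entry point are assumptions about the TVM tree this patch targets):
import tvm
import topi

batch, num_classes, num_anchors = 1, 21, 8732
cls_prob = tvm.placeholder((batch, num_classes, num_anchors), name="cls_prob")
loc_pred = tvm.placeholder((batch, num_anchors * 4), name="loc_pred")
anchors = tvm.placeholder((1, num_anchors, 4), name="anchors")

with tvm.target.create("cuda"):
    # Dispatches to multibox_detection_gpu via the register() decorators above;
    # clip/threshold/nms_threshold keep the defaults shown in this file.
    det = topi.vision.ssd.multibox_detection(cls_prob, loc_pred, anchors)
# det: (batch, num_anchors, 6) rows of
# [class_id, score, xmin, ymin, xmax, ymax]; suppressed rows have class_id -1.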