Commit 60769b77 by Leyuan Wang, committed by Tianqi Chen

Fixed bugs for SSD sorting and multibox detection (#1578)

parent 19cf5c66
@@ -7,19 +7,155 @@ from tvm import api
from topi.vision import nms
def sort_pre_ir(index, sizes_out, axis_mul_before, axis_mul_after):
    """Low level IR routing subfunction 1/4 for computing segments' starting locations.
    Parameters
    ----------
    index : Buffer
        Buffer of number of valid output boxes.
    sizes_out : Buffer
        Output buffer of start locations of each sorting segment.
    axis_mul_before : int
        The multiplication result of axis dimensions before axis.
    axis_mul_after : int
        The multiplication result of axis dimensions after axis.
    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib = tvm.ir_builder.create()
    p_index = ib.buffer_ptr(index)
    dshape = sizes_out.shape
    sizes = ib.buffer_ptr(sizes_out)
    nthread_tx = max_threads
    nthread_bx = dshape[0] // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    with ib.if_scope(tid < axis_mul_before * axis_mul_after):
        sizes[tid] = p_index[tid]
    # scan
    with ib.if_scope(tid < 1):
        with ib.for_range(0, axis_mul_before * axis_mul_after - 1, name="k") as k:
            sizes[k + 1] += sizes[k]
    body = ib.get()
    return body
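For intuition, the following NumPy sketch (not part of the patch, array values invented) shows what this kernel computes: an inclusive prefix sum over the per-segment valid counts, so that `sizes[i - 1]` is the start offset of segment `i` in the flattened buffers.

import numpy as np

valid_counts = np.array([3, 5, 2], dtype=np.int32)  # hypothetical per-segment valid box counts
sizes = np.cumsum(valid_counts)                     # inclusive scan -> [3, 8, 10]
# segment i occupies [start, start + valid_counts[i]) in the flattened buffers,
# where start = 0 for i == 0 and sizes[i - 1] otherwise
starts = np.concatenate(([0], sizes[:-1]))          # -> [0, 3, 8]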
def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \
                     axis, axis_mul_before, axis_mul_after):
    """Low level IR routing subfunction 2/4 for flattening data and indices into segmented format.
    Parameters
    ----------
    data : Buffer
        Buffer of output boxes with class and score.
    index : Buffer
        Buffer of number of valid output boxes.
    sizes_in : Buffer
        Buffer of start locations of each sorting segment.
    data_out : Buffer
        Buffer of flattened segmented data.
    index_out : Buffer
        Buffer of flattened segmented indices.
    axis : int
        The axis used for sorting.
    axis_mul_before : int
        The multiplication result of axis dimensions before axis.
    axis_mul_after : int
        The multiplication result of axis dimensions after axis.
    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.ir_builder.create()
    sizes = ib.buffer_ptr(sizes_in)
    p_index = ib.buffer_ptr(index)
    p_data = ib.buffer_ptr(data)
    data_new = ib.buffer_ptr(data_out)
    index_new = ib.buffer_ptr(index_out)
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    dshape = tvm.max(sizes_in.shape[0], p_index[0])
    nthread_tx = max_threads
    nthread_bx = dshape // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    with ib.if_scope(axis_mul_before * axis_mul_after > 1):
        with ib.if_scope(tid < axis_mul_before * axis_mul_after):
            i = tid / axis_mul_after
            j = tid % axis_mul_after
            current_sort_num = p_index[tid]
            base_idx = i * data.shape[axis] * axis_mul_after + j
            with ib.for_range(0, current_sort_num, name="k") as k:
                full_idx = base_idx + k * axis_mul_after
                with ib.if_scope(tid == 0):
                    start = 0
                with ib.else_scope():
                    start = sizes[tid - 1]
                index_new[start + k] = k
                data_new[start + k] = p_data[full_idx]
    with ib.else_scope():
        with ib.if_scope(tid == 0):
            with ib.for_range(0, p_index[0], name="k") as k:
                index_new[k] = k
    body = ib.get()
    return body
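As an editorial aside, the flattening step can be pictured with NumPy (array contents invented for illustration): only the first `valid_count` scores of each segment are packed back to back, and each segment's indices are reset to 0..k-1.

import numpy as np

scores = np.array([[0.9, 0.1, 0.4, 0.0],   # hypothetical [num_segments, num_anchors] scores
                   [0.2, 0.8, 0.0, 0.0]], dtype=np.float32)
valid_counts = np.array([3, 2], dtype=np.int32)

# pack the valid scores contiguously and restart the index sequence per segment
data_new = np.concatenate([scores[i, :c] for i, c in enumerate(valid_counts)])
index_new = np.concatenate([np.arange(c) for c in valid_counts])
# data_new  -> [0.9, 0.1, 0.4, 0.2, 0.8]
# index_new -> [0, 1, 2, 0, 1]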
def sort_oet_ir(data, index, new_data, new_index, loc, out_index, axis_mul_before, \
                axis_mul_after, axis, is_descend):
    """Low level IR routing subfunction 3/4 for Odd-Even-Transposition sorting.
    Parameters
    ----------
    data : Buffer
        Buffer of output boxes with class and score.
    index : Buffer
        Buffer of number of valid output boxes.
    new_data : Buffer
        Buffer of flattened segmented data.
    new_index : Buffer
        Buffer of flattened segmented indices.
    loc : Buffer
        Buffer of start locations of each sorting segment.
    out_index : Buffer
        Output buffer of output box indexes sorted by score in a flattened segmented format.
    axis_mul_before : int
        The multiplication result of axis dimensions before axis.
    axis_mul_after : int
        The multiplication result of axis dimensions after axis.
    axis : int
        The axis used for sorting.
@@ -32,15 +168,197 @@ def sort_ir(data, index, output, axis, is_descend):
    stmt : Stmt
        The result IR statement.
    """
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib = tvm.ir_builder.create()
    dshape = loc.shape
    fshape = data.shape[axis] * dshape[0]
    temp_data = ib.allocate(
        "float32", dshape, name="temp_data", scope="local")
    p_data = ib.buffer_ptr(data)
    p_index = ib.buffer_ptr(index)
    data_new = ib.buffer_ptr(new_data)
    index_new = ib.buffer_ptr(new_index)
    index_out = ib.buffer_ptr(out_index)
    sizes = ib.buffer_ptr(loc)
    nthread_tx = max_threads
    nthread_bx = fshape // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    with ib.if_scope(axis_mul_before * axis_mul_after > 1):
        with ib.if_scope(tid < axis_mul_before * axis_mul_after):
            with ib.if_scope(tid == 0):
                start = 0
            with ib.else_scope():
                start = sizes[tid - 1]
            # OddEvenTransposeSort
            with ib.for_range(0, p_index[tid], name="k") as k:
                with ib.for_range(0, p_index[tid] - 1, name="i") as i:
                    with ib.if_scope(i % 2 == k % 2):
                        with ib.if_scope((data_new[i + start] < data_new[i + start + 1]) == is_descend):
                            temp_data[tid] = data_new[i + start]
                            data_new[i + start] = data_new[i + start + 1]
                            data_new[i + start + 1] = temp_data[tid]
                            index_out[tid] = index_new[i + start]
                            index_new[i + start] = index_new[i + start + 1]
                            index_new[i + start + 1] = index_out[tid]
        with ib.if_scope(tid < 1):
            with ib.for_range(0, sizes[dshape[0] - 1], name="i") as i:
                index_out[i] = index_new[i]
    with ib.else_scope():
        with ib.for_range(0, fshape, name="k", for_type="unroll") as k:
            with ib.if_scope(tvm.all(k % 2 == tid % 2, tid < fshape)):
                with ib.if_scope(k % 2 == 0):
                    with ib.if_scope(tvm.all(tid + 1 < fshape, (p_data[tid] < p_data[tid + 1]) \
                                             == is_descend)):
                        data_new[tid] = p_data[tid + 1]
                        index_out[tid] = index_new[tid + 1]
                    with ib.else_scope():
                        data_new[tid] = p_data[tid]
                        index_out[tid] = index_new[tid]
                with ib.else_scope():
                    with ib.if_scope(tvm.all(tid + 1 < fshape, (data_new[tid] < data_new[tid + 1]) \
                                             == is_descend)):
                        p_data[tid] = data_new[tid + 1]
                        index_new[tid] = index_out[tid + 1]
                    with ib.else_scope():
                        p_data[tid] = data_new[tid]
                        index_new[tid] = index_out[tid]
            with ib.if_scope(tvm.all(k % 2 != tid % 2, tid < fshape)):
                with ib.if_scope(k % 2 == 0):
                    with ib.if_scope(tvm.all(tid > 0, (p_data[tid - 1] < p_data[tid]) == is_descend)):
                        data_new[tid] = p_data[tid - 1]
                        index_out[tid] = index_new[tid - 1]
                    with ib.else_scope():
                        data_new[tid] = p_data[tid]
                        index_out[tid] = index_new[tid]
                with ib.else_scope():
                    with ib.if_scope(tvm.all(tid > 0, (data_new[tid - 1] < data_new[tid]) \
                                             == is_descend)):
                        p_data[tid] = data_new[tid - 1]
                        index_new[tid] = index_out[tid - 1]
                    with ib.else_scope():
                        p_data[tid] = data_new[tid]
                        index_new[tid] = index_out[tid]
        with ib.if_scope(fshape % 2 == 1):
            with ib.if_scope(tid < 1):
                with ib.for_range(0, fshape, name="k") as k:
                    index_out[tid] = index_new[tid]
    body = ib.get()
    return body
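The kernel above is a segmented GPU variant of odd-even transposition sort. A plain-Python sketch of the sequential algorithm on a single segment (helper name and values are invented) looks like this:

def oet_sort(values, indices, descend=True):
    n = len(values)
    for k in range(n):                      # n passes guarantee a fully sorted result
        for i in range(n - 1):
            if i % 2 == k % 2:              # alternate between even and odd compare pairs
                out_of_order = (values[i] < values[i + 1]) == descend
                if out_of_order:
                    values[i], values[i + 1] = values[i + 1], values[i]
                    indices[i], indices[i + 1] = indices[i + 1], indices[i]
    return values, indices

print(oet_sort([0.1, 0.9, 0.4], [0, 1, 2]))  # -> ([0.9, 0.4, 0.1], [1, 2, 0])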
def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_after, axis):
    """Low level IR routing subfunction 4/4 for writing sorted indices to output format.
    Parameters
    ----------
    data : Buffer
        Buffer of output boxes with class and score.
    index : Buffer
        Buffer of number of valid output boxes.
    new_index : Buffer
        Buffer of sorted indices in a flattened format.
    loc : Buffer
        Buffer of start locations of each sorting segment.
    output : Buffer
        Output buffer of output box indexes sorted by score.
    axis_mul_before : int
        The multiplication result of axis dimensions before axis.
    axis_mul_after : int
        The multiplication result of axis dimensions after axis.
    axis : int
        The axis used for sorting.
    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib = tvm.ir_builder.create()
    dshape = tvm.max(loc.shape[0], data.shape[axis])
    p_index = ib.buffer_ptr(index)
    index_new = ib.buffer_ptr(new_index)
    sizes = ib.buffer_ptr(loc)
    p_out = ib.buffer_ptr(output)
    nthread_tx = max_threads
    nthread_bx = dshape // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    with ib.if_scope(axis_mul_before * axis_mul_after > 1):
        with ib.if_scope(tid < axis_mul_before * axis_mul_after):
            i = tid / axis_mul_after
            j = tid % axis_mul_after
            base_idx = i * data.shape[axis] * axis_mul_after + j
            with ib.for_range(0, data.shape[axis], name="k") as k:
                with ib.if_scope(tid == 0):
                    start = 0
                with ib.else_scope():
                    start = sizes[tid - 1]
                p_out[base_idx + k * axis_mul_after] = tvm.select(
                    k < p_index[tid], index_new[k + start], k)
    with ib.else_scope():
        with ib.if_scope(tid < data.shape[axis]):
            p_out[tid] = tvm.select(tid < p_index[0], index_new[tid], tid)
    body = ib.get()
    return body
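A NumPy sketch (illustrative only, values invented) of the final scatter for a single segment: positions below the valid count receive the sorted index, the remaining positions keep the identity index, mirroring the tvm.select above.

import numpy as np

num_anchors = 5
valid_count = 3
sorted_idx = np.array([1, 2, 0])             # hypothetical result of the segmented sort

padded = np.pad(sorted_idx, (0, num_anchors - valid_count))
out = np.where(np.arange(num_anchors) < valid_count, padded, np.arange(num_anchors))
# out -> [1, 2, 0, 3, 4]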
def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend):
    """Generate low level IR to sort data on the GPU, the GPU counterpart of
    tvm.contrib.sort.argsort on the CPU.
    Parameters
    ----------
    data : tvm.Tensor
        2-D tensor of input boxes' scores with shape [batch_size, num_anchors].
    data_buf : Buffer
        2-D Buffer of input boxes' scores with shape [batch_size, num_anchors].
    index : tvm.Tensor
        1-D tensor for valid number of boxes.
    index_buf : Buffer
        Buffer of the number of valid boxes.
    output_buf : Buffer
        Output buffer of indices of the sorted tensor.
    axis : int
        The axis used for sorting.
    is_descend : bool
        Whether to sort in descending order.
    Returns
    -------
    out : tvm.Tensor
        2-D tensor with shape [batch_size, num_anchors].
    """
    ndim = len(data.shape)
    assert data.dtype == "float32", "Currently only supports input dtype to be float32"
    assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim
@@ -55,89 +373,60 @@ def sort_ir(data, index, output, axis, is_descend):
        elif i > axis:
            axis_mul_after *= data.shape[i]
    dshape = axis_mul_before * axis_mul_after
    fshape = data.shape[axis] * dshape

    loc_buf = api.decl_buffer(dshape, index.dtype, "sizes", data_alignment=8)
    new_index_buf = api.decl_buffer(
        fshape, index.dtype, "index_new", data_alignment=8)
    out_index_buf = api.decl_buffer(
        fshape, index.dtype, "index_out", data_alignment=8)
    new_data_buf = api.decl_buffer(
        dshape, data.dtype, "data_new", data_alignment=8)

    loc = \
        tvm.extern([(dshape,)],
                   [index],
                   lambda ins, outs: sort_pre_ir(
                       ins[0], outs[0], axis_mul_before, axis_mul_after),
                   dtype=[index.dtype],
                   in_buffers=index_buf,
                   out_buffers=[loc_buf],
                   tag="sorting_prepare")

    data_new, index_new = \
        tvm.extern([(dshape,), (fshape,)],
                   [data, index, loc],
                   lambda ins, outs: sort_pre_ir_data(
                       ins[0], ins[1], ins[2], outs[0], outs[1], axis,
                       axis_mul_before, axis_mul_after),
                   dtype=[data.dtype, index.dtype],
                   in_buffers=[data_buf, index_buf, loc_buf],
                   out_buffers=[new_data_buf, new_index_buf],
                   tag="sorting_data")

    index_out = \
        tvm.extern([(fshape,)],
                   [data, index, data_new, index_new, loc],
                   lambda ins, outs: sort_oet_ir(
                       ins[0], ins[1], ins[2], ins[3], ins[4], outs[0],
                       axis_mul_before, axis_mul_after, axis, is_descend),
                   dtype=[index.dtype],
                   in_buffers=[data_buf, index_buf,
                               new_data_buf, new_index_buf, loc_buf],
                   out_buffers=[out_index_buf],
                   tag="sorting_oet")
    out = \
        tvm.extern([data.shape],
                   [data, index, index_out, loc],
                   lambda ins, outs: sort_ir_out(
                       ins[0], ins[1], ins[2], ins[3], outs[0],
                       axis_mul_before, axis_mul_after, axis),
                   dtype=[index.dtype],
                   in_buffers=[data_buf, index_buf, out_index_buf, loc_buf],
                   out_buffers=output_buf,
                   tag="sorting_output")
    return out
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
@@ -333,15 +622,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
    sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
                                      "sort_tensor_buf", data_alignment=8)

    sort_tensor = sort_gpu(score_tensor, score_tensor_buf, valid_count,
                           valid_count_buf, sort_tensor_buf, score_axis, True)

    out = \
        tvm.extern(data.shape,
                   [data, sort_tensor, valid_count],
...
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, too-many-function-args
"""SSD multibox operators"""
from __future__ import absolute_import as _abs
import math
@@ -13,6 +13,7 @@ from topi.vision.ssd import multibox_detection
from topi.vision.ssd import multibox_transform_loc
from ..nms import nms
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
    """Low level IR routing for multibox_prior operator.
@@ -41,7 +42,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
    stmt : Stmt
        The result IR statement.
    """
    max_threads = int(math.sqrt(
        tvm.target.current_target(allow_none=False).max_num_threads))
    tx = tvm.thread_axis("threadIdx.x")
    ty = tvm.thread_axis("threadIdx.y")
    bx = tvm.thread_axis("blockIdx.x")
@@ -76,7 +78,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
        for k in range(num_sizes + num_ratios - 1):
            w = tvm.select(k < num_sizes,
                           size_ratio_concat[k] * in_height / in_width / 2.0,
                           size_ratio_concat[0] * in_height / in_width *
                           math.sqrt(size_ratio_concat[k + 1]) / 2.0)
            h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0,
@@ -93,7 +96,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
@multibox_prior.register(["cuda", "gpu"])
def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1),
                       offsets=(0.5, 0.5), clip=False):
    """Generate prior(anchor) boxes from data, sizes and ratios.
@@ -124,31 +127,114 @@ def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), \
    """
    num_sizes = len(sizes)
    num_ratios = len(ratios)
    oshape = (
        1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
    out = tvm.extern(oshape, [data], lambda ins, outs:
                     multibox_prior_ir(
                         ins[0], outs[0], sizes, ratios, steps, offsets),
                     tag="multibox_prior")
    if clip:
        out = topi.clip(out, 0, 1)
    return out
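For reference, the prior boxes generated per feature-map cell follow the usual SSD size/ratio scheme. The sketch below is an editorial plain-Python illustration: the helper name, the default size/ratio values, the assumption that negative steps fall back to 1/feature-map size, and the height formula for ratio boxes (not visible in the collapsed hunk) are all assumptions, not part of the patch.

import math

def cell_prior_boxes(i, j, in_height, in_width,
                     sizes=(0.2, 0.272), ratios=(1.0, 2.0, 0.5), offsets=(0.5, 0.5)):
    """Hypothetical helper: corner-form prior boxes for feature-map cell (i, j)."""
    center_y = (i + offsets[0]) / in_height      # steps assumed to default to 1 / feature-map size
    center_x = (j + offsets[1]) / in_width
    concat = list(sizes) + list(ratios)
    boxes = []
    for k in range(len(sizes) + len(ratios) - 1):
        if k < len(sizes):
            w = concat[k] * in_height / in_width / 2.0
            h = concat[k] / 2.0
        else:
            w = concat[0] * in_height / in_width * math.sqrt(concat[k + 1]) / 2.0
            h = concat[0] / (2.0 * math.sqrt(concat[k + 1]))   # assumed, not shown in the hunk
        boxes.append((center_x - w, center_y - h, center_x + w, center_y + h))
    return boxes

print(len(cell_prior_boxes(0, 0, 10, 10)))   # -> 4 boxes per cell for 2 sizes + 3 ratios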
def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, threshold):
    """Low level IR routing for transform location data preparation.
    Parameters
    ----------
    cls_prob : Buffer
        Buffer of class probabilities.
    valid_count : Buffer
        Buffer of number of valid output boxes.
    temp_flag : Buffer
        Output intermediate result buffer.
    temp_id : Buffer
        Output intermediate result buffer.
    temp_score_out : Buffer
        Output buffer.
    threshold : float
        Threshold to be a positive prediction.
    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = cls_prob.shape[0]
    num_classes = cls_prob.shape[1]
    num_anchors = cls_prob.shape[2]
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    ib = tvm.ir_builder.create()
    score = ib.buffer_ptr(temp_score_out)
    cls_id = ib.buffer_ptr(temp_id)
    flag = ib.buffer_ptr(temp_flag)
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    nthread_tx = max_threads
    nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx
    p_cls_prob = ib.buffer_ptr(cls_prob)
    p_valid_count = ib.buffer_ptr(valid_count)

    with ib.if_scope(tid < batch_size * num_anchors):
        n = tid / num_anchors  # batch index
        i = tid % num_anchors  # anchor index
        score[i] = -1.0
        cls_id[i] = 0
        p_valid_count[n] = 0
        with ib.for_range(0, num_classes - 1, name="k") as k:
            temp = p_cls_prob[n * num_anchors * num_classes + (k + 1) * num_anchors + i]
            with ib.if_scope(temp > score[i]):
                cls_id[i] = k + 1
                score[i] = temp
        with ib.if_scope(tvm.all(cls_id[i] > 0, score[i] < threshold)):
            cls_id[i] = 0
        with ib.if_scope(cls_id[i] > 0):
            flag[i] = 1
        with ib.else_scope():
            flag[i] = 0

    with ib.if_scope(tid < batch_size):
        with ib.for_range(0, num_anchors, name="k") as k:
            with ib.if_scope(k > 0):
                flag[tid * num_anchors + k] += flag[tid * num_anchors + k - 1]
        p_valid_count[tid] = flag[tid * num_anchors + num_anchors - 1]
    body = ib.get()
    return body
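An editorial NumPy sketch (values invented) of what this pass computes for one batch element: the best non-background class per anchor, the confidence threshold, and the running-sum flags whose last entry becomes valid_count.

import numpy as np

cls_prob = np.array([[0.7, 0.2, 0.6],    # hypothetical [num_classes, num_anchors] slice
                     [0.2, 0.5, 0.1],    # for one batch element; row 0 is background
                     [0.1, 0.3, 0.3]], dtype=np.float32)
threshold = 0.25

score = cls_prob[1:].max(axis=0)              # best foreground score per anchor
cls_id = cls_prob[1:].argmax(axis=0) + 1      # 1-based foreground class id
cls_id[score < threshold] = 0                 # suppress low-confidence anchors
flag = (cls_id > 0).astype(np.int32)
valid_count = int(np.cumsum(flag)[-1])        # running sum; last entry is the valid box count
# score -> [0.2, 0.5, 0.3], cls_id -> [0, 1, 2], flag -> [0, 1, 1], valid_count -> 2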
def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \
                     out, clip, variances, batch_size, num_classes, num_anchors):
    """Low level IR routing for transform location in multibox_detection operator.
    Parameters
    ----------
    loc_pred : Buffer
        Buffer of location regression predictions.
    anchor : Buffer
        Buffer of prior anchor boxes.
    temp_flag : Buffer
        Intermediate result buffer.
    temp_id : Buffer
        Intermediate result buffer.
    temp_score_in : Buffer
        Input buffer which stores intermediate results.
    out : Buffer
        Output buffer.
@@ -156,12 +242,18 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
    clip : boolean
        Whether to clip out-of-boundary boxes.
    variances : tuple of float
        Variances to be decoded from box regression output.
    batch_size : int
        Batch size.
    num_classes : int
        Number of classes.
    num_anchors : int
        Number of anchors.
    Returns
    -------
    stmt : Stmt
@@ -187,21 +279,16 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
        ow = tvm.exp(pw * vw) * aw / 2.0
        oh = tvm.exp(ph * vh) * ah / 2.0
        return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \
            tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \
            tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \
            tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh)

    ib = tvm.ir_builder.create()
    score = ib.buffer_ptr(temp_score_in)
    cls_id = ib.buffer_ptr(temp_id)
    flag = ib.buffer_ptr(temp_flag)
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    nthread_tx = max_threads
@@ -209,42 +296,13 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx
    p_loc_pred = ib.buffer_ptr(loc_pred)
    p_anchor = ib.buffer_ptr(anchor)
    p_out = ib.buffer_ptr(out)

    with ib.if_scope(tid < batch_size * num_anchors):
        n = tid / num_anchors  # batch index
        i = tid % num_anchors  # anchor index
        with ib.if_scope(cls_id[tid] > 0):
            with ib.if_scope(tid == 0):
                out_base_idx = n * num_anchors * 6
@@ -253,17 +311,17 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
                p_out[out_base_idx] = cls_id[tid] - 1.0
                p_out[out_base_idx + 1] = score[tid]
                p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \
                    p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4,
                                                            p_anchor, i * 4, clip, variances[0],
                                                            variances[1], variances[2], variances[3])
    body = ib.get()
    return body
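The inner transform_loc helper decodes a regression output against its anchor. The plain-Python sketch below mirrors that arithmetic; the helper name is invented, the anchor is assumed to already be in centre/size form, and the centre decode is the standard SSD convention since that part of the helper is collapsed in the hunk above.

import math

def decode_box(px, py, pw, ph, ax, ay, aw, ah,
               variances=(0.1, 0.1, 0.2, 0.2), clip=True):
    """Hypothetical helper mirroring transform_loc's arithmetic."""
    vx, vy, vw, vh = variances
    ox = px * vx * aw + ax                  # decoded centre (assumed standard SSD decode)
    oy = py * vy * ah + ay
    ow = math.exp(pw * vw) * aw / 2.0       # decoded half-extents, as in the IR above
    oh = math.exp(ph * vh) * ah / 2.0
    corners = [ox - ow, oy - oh, ox + ow, oy + oh]
    if clip:
        corners = [min(1.0, max(0.0, c)) for c in corners]
    return tuple(corners)

print(decode_box(0.0, 0.0, 0.0, 0.0, ax=0.5, ay=0.5, aw=0.2, ah=0.2))
# -> (0.4, 0.4, 0.6, 0.6): a zero regression output reproduces the anchor box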
@multibox_transform_loc.register(["cuda", "gpu"])
def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \
                               threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)):
    """Location transformation for multibox detection
    Parameters
@@ -297,20 +355,42 @@ def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=
        1-D tensor with shape (batch_size,), number of valid anchor boxes.
    """
    batch_size = cls_prob.shape[0]
    num_classes = cls_prob.shape[1]
    num_anchors = cls_prob.shape[2]
    oshape = (batch_size, num_anchors, 6)
    # Define data alignment for intermediate buffer
    valid_count_dtype = "int32"
    valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype,
                                      "valid_count_buf", data_alignment=4)
    out_buf = api.decl_buffer(
        oshape, cls_prob.dtype, "out_buf", data_alignment=8)
    size = num_anchors
    temp_flag_buf = api.decl_buffer(
        (size,), valid_count_dtype, "flag", data_alignment=8)
    temp_id_buf = api.decl_buffer(
        (size,), valid_count_dtype, "cls_id", data_alignment=8)
    temp_score_buf = api.decl_buffer(
        (size,), cls_prob.dtype, "score", data_alignment=8)

    valid_count, temp_flag, temp_id, temp_score = \
        tvm.extern([(batch_size,), (size,), (size,), (size,)],
                   [cls_prob],
                   lambda ins, outs: transform_loc_pre(
                       ins[0], outs[0], outs[1], outs[2], outs[3], threshold),
                   dtype=[valid_count_dtype,
                          valid_count_dtype, valid_count_dtype, cls_prob.dtype],
                   out_buffers=[valid_count_buf,
                                temp_flag_buf, temp_id_buf, temp_score_buf],
                   tag="multibox_transform_loc_first_step")
    out = \
        tvm.extern([oshape],
                   [loc_pred, anchor, temp_flag, temp_id, temp_score],
                   lambda ins, outs: transform_loc_ir(
                       ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, \
                       variances, batch_size, num_classes, num_anchors),
                   dtype=[cls_prob.dtype],
                   out_buffers=[out_buf],
                   tag="multibox_transform_loc")
    return [out, valid_count]
@@ -356,5 +436,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
    """
    inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
                                       clip, threshold, variances)
    out = nms(
        inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
    return out