Commit e6863990 by Leyuan Wang Committed by Tianqi Chen

[TOPI] Add GPU SSD (#1397)

parent 3a0b757c
......@@ -15,6 +15,8 @@ from .dense import dense_cuda, schedule_dense
from .pooling import schedule_pool, schedule_global_pool
from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
from .extern import schedule_extern
from .vision import schedule_region
from .vision import schedule_reorg
from .nn import schedule_lrn, schedule_l2_normalize
from .vision import *
from . import ssd
from .ssd import *
from .nms import *
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison
"""Non-maximum suppression operator"""
import math
import tvm
from tvm import api
from topi.vision import nms
def sort_ir(data, index, output, axis, is_descend):
"""Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
Parameters
----------
data: Buffer
2D Buffer of input boxes' score with shape [batch_size, num_anchors].
index : Buffer
Buffer of number of valid number of boxes.
output : Buffer
Output buffer of indicies of sorted tensor.
axis : int
The axis used for sorting.
is_descend : bool
If the sorted data is in descending order.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data)
p_index = ib.buffer_ptr(index)
p_out = ib.buffer_ptr(output)
ndim = len(data.shape)
assert data.dtype == "float32", "Currently only supports input dtype to be float32"
assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim
axis_mul_before = 1
axis_mul_after = 1
if axis < 0:
axis = ndim + axis
for i in range(0, ndim):
if i < axis:
axis_mul_before *= data.shape[i]
elif i > axis:
axis_mul_after *= data.shape[i]
dshape = 0
for i in range(0, len(index.shape)):
dshape += index.shape[i]
dshape = tvm.select(dshape > axis_mul_before*axis_mul_after, dshape,
axis_mul_before*axis_mul_after)
sizes_temp = ib.allocate(
"int32", dshape, name="sizes_temp", scope="global")
sizes = ib.allocate("int32", dshape, name="sizes", scope="global")
temp_index = ib.allocate("int32", dshape, name="temp_index", scope="local")
temp_data = ib.allocate("float32", dshape, name="temp_data", scope="local")
data_new = ib.allocate("float32", dshape, name="data_new", scope="global")
index_new = ib.allocate("int32", dshape, name="index_new", scope="global")
nthread_tx = max_threads
nthread_bx = dshape // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
sizes[tid] = p_index[tid]
sizes_temp[tid] = p_index[tid]
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
with ib.for_range(0, tvm.floor(tvm.sqrt((axis_mul_before * axis_mul_after) \
.astype("float32"))) + 1, name="k") as k:
with ib.if_scope(tid - (tvm.const(1, "int32") << k) >= 0):
with ib.if_scope(k % 2 == 0):
sizes[tid] += sizes_temp[tid - (
tvm.const(1, "int32") << k)]
sizes_temp[tid] = sizes[tid]
with ib.else_scope():
sizes_temp[tid] += sizes[tid - (
tvm.const(1, "int32") << k)]
sizes[tid] = sizes_temp[tid]
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
i = tid / axis_mul_after
j = tid % axis_mul_after
current_sort_num = p_index[tid]
base_idx = i * data.shape[axis] * axis_mul_after + j
with ib.for_range(0, current_sort_num, name="k") as k:
full_idx = base_idx + k * axis_mul_after
with ib.if_scope(tid == 0):
start = 0
with ib.else_scope():
start = sizes[tid-1]
index_new[start + k] = k
data_new[start + k] = p_data[full_idx]
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
with ib.if_scope(tid == 0):
start = 0
with ib.else_scope():
start = sizes[tid-1]
# OddEvenTransposeSort
with ib.for_range(0, p_index[tid], name="k") as k:
with ib.for_range(0, p_index[tid] - 1, name="i") as i:
with ib.if_scope(i % 2 == (k & 1)):
with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) ^
is_descend) == False):
temp_data[tid] = data_new[i+start]
data_new[i+start] = data_new[i+start+1]
data_new[i+start+1] = temp_data[tid]
temp_index[tid] = index_new[i+start]
index_new[i+start] = index_new[i+start+1]
index_new[i+start+1] = temp_index[tid]
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
i = tid / axis_mul_after
j = tid % axis_mul_after
current_sort_num = p_index[tid]
base_idx = i * data.shape[axis] * axis_mul_after + j
with ib.for_range(0, data.shape[axis], name="k") as k:
with ib.if_scope(tid == 0):
start = 0
with ib.else_scope():
start = sizes[tid-1]
p_out[base_idx + k * axis_mul_after] = tvm.select(
k < current_sort_num,
index_new[k+start], k)
body = ib.get()
return body
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
"""Low level IR routing for transform location in multibox_detection operator.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
sort_result : Buffer
Buffer of output box indexes sorted by score.
valid_count : Buffer
Buffer of number of valid output boxes.
out : Buffer
Output buffer.
nms_threshold : float
Non-maximum suppression threshold.
force_suppress : boolean
Whether to suppress all detections regardless of class_id.
nms_topk : int
Keep maximum top k detections before nms, -1 for no limit.
Returns
-------
stmt : Stmt
The result IR statement.
"""
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
"""Calculate overlap of two boxes.
"""
w = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
- tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
h = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
- tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
i = w * h
u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
(out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
(out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
(out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
return tvm.select(u <= 0.0, 0.0, i / u)
max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y")
bx = tvm.thread_axis("blockIdx.x")
by = tvm.thread_axis("blockIdx.y")
ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data)
p_sort_result = ib.buffer_ptr(sort_result)
p_valid_count = ib.buffer_ptr(valid_count)
p_out = ib.buffer_ptr(out)
batch_size = out.shape[0]
num_anchors = out.shape[1]
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
nthread_ty = max_threads
nthread_by = 6 // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(ty, "thread_extent", nthread_ty)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
i = bx * max_threads + tx
j = by * max_threads + ty
nms_threshold_node = tvm.make.node(
"FloatImm", dtype="float32", value=nms_threshold)
nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
force_suppress_node = tvm.make.node(
"IntImm", dtype="int32", value=1 if force_suppress else 0)
with ib.for_range(0, batch_size, for_type="unroll", name="n") as n:
with ib.if_scope(
tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
p_valid_count[0] > 0)):
# Reorder output
nkeep = tvm.select(
tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
nms_topk, p_valid_count[n])
with ib.if_scope(i < nkeep):
with ib.if_scope(j < 6):
p_out[(n * num_anchors * 6
+ i * 6 + j)] = p_data[(n * num_anchors * 6
+ p_sort_result[n * num_anchors + i] * 6 + j)]
with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])):
with ib.if_scope(i < p_valid_count[n] - nkeep):
with ib.if_scope(j < 6):
p_out[(n * num_anchors * 6
+ (i + nkeep) * 6 + j)] = p_data[(n * num_anchors * 6
+ (i + nkeep) * 6 + j)]
# Apply nms
with ib.if_scope(i < p_valid_count[n]):
offset_i = i * 6
with ib.if_scope(p_out[n * num_anchors * 6 + offset_i] >= 0):
with ib.if_scope(j < p_valid_count[n]):
offset_j = j * 6
with ib.if_scope(tvm.all(j > i, p_out[n * num_anchors * 6
+ offset_j] >= 0)):
with ib.if_scope(tvm.any(force_suppress_node > 0,
p_out[n * num_anchors * 6 + offset_i] ==
p_out[n * num_anchors * 6 + offset_j])):
# When force_suppress == True or class_id equals
iou = calculate_overlap(
p_out, n * num_anchors * 6 + offset_i + 2,
n * num_anchors * 6 + offset_j + 2)
with ib.if_scope(iou >= nms_threshold):
p_out[
n * num_anchors * 6 + offset_j] = -1.0
with ib.else_scope():
with ib.if_scope(i < p_valid_count[n]):
with ib.if_scope(j < 6):
p_out[(n * num_anchors * 6
+ i * 6 + j)] = p_data[n * num_anchors * 6 + i * 6 + j]
# Set invalid entry to be -1
with ib.if_scope(i < num_anchors - p_valid_count[n]):
with ib.if_scope(j < 6):
p_out[n * num_anchors * 6 + (i +
p_valid_count[n]) * 6 + j] = -1.0
body = ib.get()
return body
@nms.register(["cuda", "gpu"])
def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1):
"""Non-maximum suppression operator for object detection.
Parameters
----------
data: tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
nms_threshold : float
Non-maximum suppression threshold.
force_suppress : boolean
Whether to suppress all detections regardless of class_id.
nms_topk : int
Keep maximum top k detections before nms, -1 for no limit.
Returns
-------
out : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
Example
--------
.. code-block:: python
# An example to use nms
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder(
(dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7
force_suppress = True
nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f(tvm_data, tvm_valid_count, tvm_out)
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
"valid_count_buf", data_alignment=4)
data_buf = api.decl_buffer(
data.shape, data.dtype, "data_buf", data_alignment=8)
score_axis = 1
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(
score_shape, lambda i, j: data[i, j, score_axis], name="score_tensor")
score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
"score_tensor_buf", data_alignment=8)
sort_tensor_dtype = "int32"
sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
"sort_tensor_buf", data_alignment=8)
sort_tensor = \
tvm.extern(score_shape,
[score_tensor, valid_count],
lambda ins, outs: sort_ir(
ins[0], ins[1], outs[0], score_axis, True),
dtype=sort_tensor_dtype,
in_buffers=[score_tensor_buf, valid_count_buf],
out_buffers=sort_tensor_buf,
name="nms_sort")
out = \
tvm.extern(data.shape,
[data, sort_tensor, valid_count],
lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], nms_threshold,
force_suppress, nms_topk),
dtype="float32",
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
tag="nms")
return out
# pylint: disable=wildcard-import
"""VISION network operators"""
from __future__ import absolute_import as _abs
from .multibox import *
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements
"""SSD multibox operators"""
from __future__ import absolute_import as _abs
import math
import tvm
from tvm import api
import topi
from topi.vision.ssd import multibox_prior
from topi.vision.ssd import multibox_detection
from topi.vision.ssd import multibox_transform_loc
from ..nms import nms
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
"""Low level IR routing for multibox_prior operator.
Parameters
----------
data : Buffer
Input data buffer.
out : Buffer
Output buffer.
sizes : tuple of float
Tuple of sizes for anchor boxes.
ratios : tuple of float
Tuple of ratios for anchor boxes.
steps : Tuple of float
Priorbox step across y and x, -1 for auto calculation.
offsets : tuple of int
Priorbox center offsets, y and x respectively.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y")
bx = tvm.thread_axis("blockIdx.x")
by = tvm.thread_axis("blockIdx.y")
ib = tvm.ir_builder.create()
p_out = ib.buffer_ptr(out)
in_height = data.shape[2]
in_width = data.shape[3]
nthread_tx = max_threads
nthread_bx = in_height // max_threads + 1
nthread_ty = max_threads
nthread_by = in_width // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(ty, "thread_extent", nthread_ty)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
num_sizes = len(sizes)
num_ratios = len(ratios)
size_ratio_concat = sizes + ratios
steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
offset_h = offsets[0]
offset_w = offsets[1]
i = bx * max_threads + tx
j = by * max_threads + ty
with ib.if_scope((i < in_height)):
with ib.if_scope((j < in_width)):
center_h = (i + offset_h) * steps_h
center_w = (j + offset_w) * steps_w
for k in range(num_sizes + num_ratios - 1):
w = tvm.select(k < num_sizes,
size_ratio_concat[k] * in_height / in_width / 2.0,
size_ratio_concat[0] * in_height / in_width *
math.sqrt(size_ratio_concat[k + 1]) / 2.0)
h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0,
size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
count = (i * in_width * (num_sizes + num_ratios - 1) +
j * (num_sizes + num_ratios - 1) + k) * 4
p_out[count] = center_w - w
p_out[count + 1] = center_h - h
p_out[count + 2] = center_w + w
p_out[count + 3] = center_h + h
body = ib.get()
return body
@multibox_prior.register(["cuda", "gpu"])
def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), \
offsets=(0.5, 0.5), clip=False):
"""Generate prior(anchor) boxes from data, sizes and ratios.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in]]
sizes : tuple of float
Tuple of sizes for anchor boxes.
ratios : tuple of float
Tuple of ratios for anchor boxes.
steps : Tuple of float
Priorbox step across y and x, -1 for auto calculation.
offsets : tuple of int
Priorbox center offsets, y and x respectively.
clip : boolean
Whether to clip out-of-boundary boxes.
Returns
-------
out : tvm.Tensor
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
"""
num_sizes = len(sizes)
num_ratios = len(ratios)
oshape = (1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
out = tvm.extern(oshape, [data], lambda ins, outs:
multibox_prior_ir(ins[0], outs[0], sizes, ratios, steps, offsets),
tag="multibox_prior")
if clip:
out = topi.clip(out, 0, 1)
return out
def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, threshold, variances):
"""Low level IR routing for transform location in multibox_detection operator.
Parameters
----------
cls_prob : Buffer
Buffer of class probabilities.
loc_pred : Buffer
Buffer of location regression predictions.
anchor : Buffer
Buffer of prior anchor boxes.
valid_count : Buffer
Buffer of number of valid output boxes.
out : Buffer
Output buffer.
clip : boolean
Whether to clip out-of-boundary boxes.
threshold : float
Threshold to be a positive prediction.
variances : tuple of float
Variances to be decoded from box regression output.
Returns
-------
stmt : Stmt
The result IR statement.
"""
def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh):
"""Transform prior anchor box to output box through location predictions.
"""
al = anchor[anchor_base_idx]
at = anchor[anchor_base_idx + 1]
ar = anchor[anchor_base_idx + 2]
ab = anchor[anchor_base_idx + 3]
aw = ar - al
ah = ab - at
ax = (al + ar) / 2.0
ay = (at + ab) / 2.0
px = loc[loc_base_idx]
py = loc[loc_base_idx + 1]
pw = loc[loc_base_idx + 2]
ph = loc[loc_base_idx + 3]
ox = px * vx * aw + ax
oy = py * vy * ah + ay
ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh)
batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1]
num_anchors = cls_prob.shape[2]
ib = tvm.ir_builder.create()
temp_score = ib.allocate('float32', (batch_size * (num_classes -1) * num_anchors, \
), name="temp_score", scope="global")
score = ib.allocate('float32', (batch_size * num_anchors, ), name="score", scope="local")
cls_id = ib.allocate('int32', (batch_size * num_anchors, ), name="id", scope="local")
flag = ib.allocate('int32', (batch_size * num_anchors, ), name="flag", scope="global")
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
nthread_tx = max_threads
nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
p_cls_prob = ib.buffer_ptr(cls_prob)
p_loc_pred = ib.buffer_ptr(loc_pred)
p_anchor = ib.buffer_ptr(anchor)
p_valid_count = ib.buffer_ptr(valid_count)
p_out = ib.buffer_ptr(out)
with ib.if_scope(tid < batch_size * num_anchors * num_classes):
n = tid / (num_anchors * num_classes)
j = (tid % (num_anchors * num_classes)) / num_anchors
i = tid % num_anchors
with ib.if_scope(j > 0):
temp_score[n * num_anchors * num_classes + i * (num_classes - 1) + j-1] = \
p_cls_prob[tid]
p_valid_count[n] = 0
with ib.if_scope(tid < batch_size * num_anchors):
n = tid / num_anchors
i = tid % num_anchors
score[tid] = -1.0
cls_id[tid] = 0
with ib.for_range(0, num_classes-1, name="k") as k:
temp = temp_score[tid * (num_classes-1) + k]
cls_id[tid] = tvm.select(temp > score[tid], k + 1, cls_id[tid])
score[tid] = tvm.make.Max(temp, score[tid])
with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)):
cls_id[tid] = 0
with ib.if_scope(cls_id[tid] > 0):
flag[tid] = 1
with ib.else_scope():
flag[tid] = 0
with ib.if_scope(tid < batch_size):
with ib.for_range(0, num_anchors, name="k") as k:
with ib.if_scope(k > 0):
flag[tid * num_anchors + k] += flag[tid * num_anchors + k - 1]
p_valid_count[tid] = flag[tid * num_anchors + num_anchors - 1]
with ib.if_scope(tid < batch_size * num_anchors):
n = tid / num_anchors
i = tid % num_anchors
with ib.if_scope(cls_id[tid] > 0):
with ib.if_scope(tid == 0):
out_base_idx = n * num_anchors * 6
with ib.else_scope():
out_base_idx = n * num_anchors * 6 + flag[tid - 1] * 6
p_out[out_base_idx] = cls_id[tid] - 1.0
p_out[out_base_idx + 1] = score[tid]
p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \
p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4, p_anchor, i*4,
clip, variances[0], variances[1],
variances[2], variances[3])
body = ib.get()
return body
@multibox_transform_loc.register(["cuda", "gpu"])
def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
variances=(0.1, 0.1, 0.2, 0.2)):
"""Location transformation for multibox detection
Parameters
----------
cls_prob : tvm.Tensor
Class probabilities.
loc_pred : tvm.Tensor
Location regression predictions.
anchor : tvm.Tensor
Prior anchor boxes.
clip : boolean
Whether to clip out-of-boundary boxes.
threshold : float
Threshold to be a positive prediction.
variances : tuple of float
Variances to be decoded from box regression output.
Returns
-------
ret : tuple of tvm.Tensor composed of
out : tvm.Tensor
3-D tensor with shape (batch_size, num_anchors, 6)
valid_count : tvm.Tensor
1-D tensor with shape (batch_size,), number of valid anchor boxes.
"""
batch_size = cls_prob.shape[0]
num_anchors = anchor.shape[1]
oshape = (batch_size, num_anchors, 6)
# Define data alignment for intermediate buffer
valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype,
"valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(oshape, cls_prob.dtype, "out_buf", data_alignment=8)
valid_count, out = \
tvm.extern([(batch_size,), oshape],
[cls_prob, loc_pred, anchor],
lambda ins, outs: transform_loc_ir(
ins[0], ins[1], ins[2], outs[0], outs[1], clip, threshold, variances),
dtype=[valid_count_dtype, cls_prob.dtype],
out_buffers=[valid_count_buf, out_buf],
tag="multibox_transform_loc")
return [out, valid_count]
@multibox_detection.register(["cuda", "gpu"])
def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5,
force_suppress=False, variances=(0.1, 0.1, 0.2, 0.2), nms_topk=-1):
"""Convert multibox detection predictions.
Parameters
----------
cls_prob : tvm.Tensor
Class probabilities.
loc_pred : tvm.Tensor
Location regression predictions.
anchor : tvm.Tensor
Prior anchor boxes.
clip : boolean
Whether to clip out-of-boundary boxes.
nms_threshold : float
Non-maximum suppression threshold.
force_suppress : boolean
Whether to suppress all detections regardless of class_id.
threshold : float
Threshold to be a positive prediction.
variances : tuple of float
Variances to be decoded from box regression output.
nms_topk : int
Keep maximum top k detections before nms, -1 for no limit.
Returns
-------
out : tvm.Tensor
3-D tensor with shape (batch_size, num_anchors, 6)
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
return out
# pylint: disable=invalid-name, unused-variable, unused-argument
# pylint: disable=invalid-name, unused-variable, unused-argument, no-member
"""Schedule for vision operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import generic
from .. import cpp
from .. import tag
def _default_schedule(outs):
"""Default schedule for gpu."""
target = tvm.target.current_target()
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
if "nms" in op.tag:
sort = op.input_tensors[1]
score = s[sort].op.input_tensors[0]
fused = s[score].fuse(*s[score].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = s[score].split(fused, factor=num_thread)
s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else:
x = op.output(0)
fused = s[x].fuse(*s[x].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = s[x].split(fused, factor=num_thread)
s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
traverse(outs[0].op)
return s
@generic.schedule_reorg.register(["cuda", "gpu"])
def schedule_reorg(outs):
......@@ -41,8 +74,25 @@ def schedule_region(outs):
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_region(cpp_target, outs)
@generic.schedule_nms.register(["cuda", "gpu"])
def schedule_nms(outs):
"""Schedule for non-maximum suppression
Parameters
----------
outs: Array of Tensor
The computation graph description of nms
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs)
@generic.schedule_multibox_prior.register(["cuda", "gpu"])
def schedule_multibox_prior(out):
def schedule_multibox_prior(outs):
"""Schedule for multibox_prior operator.
Parameters
......@@ -56,10 +106,28 @@ def schedule_multibox_prior(out):
s: Schedule
The computation schedule for multibox_prior.
"""
raise RuntimeError("Currently multibox_prior only supports CPU.")
return _default_schedule(outs)
@generic.schedule_multibox_transform_loc.register(["cuda", "gpu"])
def schedule_multibox_transform_loc(outs):
"""Schedule for multibox_transform_loc
Parameters
----------
outs: Array of Tensor
The computation graph description of
multibox_transform_loc in the format
of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs)
@generic.schedule_multibox_detection.register(["cuda", "gpu"])
def schedule_multibox_detection(out):
def schedule_multibox_detection(outs):
"""Schedule for multibox_detection operator.
Parameters
......@@ -73,4 +141,4 @@ def schedule_multibox_detection(out):
s: Schedule
The computation schedule for multibox_detection.
"""
raise RuntimeError("Currently multibox_detection only supports CPU.")
return _default_schedule(outs)
......@@ -20,23 +20,27 @@ def verify_clip(N, a_min, a_max, dtype):
a_np, b_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
ctx = tvm.cpu(0) if device == "llvm" else tvm.gpu(0)
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device, name="clip")
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm']:
for device in ['llvm', 'opencl']:
check_device(device)
def test_clip():
verify_clip(1024, -127, 127, 'int8')
verify_clip(1024, -127, 127, 'int16')
verify_clip(1024, -127, 127, 'float32')
verify_clip(1024, -127, 127, 'int16')
verify_clip(1024, -127, 127, 'int8')
if __name__ == "__main__":
......
......@@ -8,6 +8,8 @@ from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
......@@ -59,7 +61,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
def test_conv2d_nchw():
# ResNet18 worklaods
# ResNet18 workloads
verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
......@@ -72,6 +74,21 @@ def test_conv2d_nchw():
verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
# ResNet 50 workloads
verify_conv2d_nchw(1, 64, 56, 256, 1, 1, 0)
verify_conv2d_nchw(1, 256, 56, 64, 1, 1, 0)
verify_conv2d_nchw(1, 256, 56, 128, 1, 2, 0)
verify_conv2d_nchw(1, 128, 28, 512, 1, 1, 0)
verify_conv2d_nchw(1, 256, 56, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 28, 128, 1, 1, 0)
verify_conv2d_nchw(1, 512, 28, 256, 1, 2, 0)
verify_conv2d_nchw(1, 256, 14, 1024, 1, 1, 0)
verify_conv2d_nchw(1, 512, 28, 1024, 1, 2, 0)
verify_conv2d_nchw(1, 1024, 14, 256, 1, 1, 0)
verify_conv2d_nchw(1, 1024, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 2048, 1, 2, 0)
verify_conv2d_nchw(1, 1024, 14, 2048, 1, 2, 0)
verify_conv2d_nchw(1, 2048, 7, 512, 1, 1, 0)
# Vgg16 workloads
verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1)
# Super resolution workloads
......
......@@ -14,7 +14,6 @@ def test_nms():
nms_threshold = 0.7
force_suppress = True
nms_topk = 2
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
......@@ -31,6 +30,10 @@ def test_nms():
return
print("Running on target: %s" % device)
with tvm.target.create(device):
if device == 'llvm':
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
else:
out = topi.cuda.nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
s = topi.generic.schedule_nms(out)
tvm_data = tvm.nd.array(np_data, ctx)
......@@ -40,13 +43,12 @@ def test_nms():
f(tvm_data, tvm_valid_count, tvm_out)
np.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
for device in ['llvm']:
for device in ['llvm', 'opencl']:
check_device(device)
def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):
data = tvm.placeholder(dshape, name="data")
out = ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip)
dtype = data.dtype
input_data = np.random.uniform(size=dshape).astype(dtype)
......@@ -88,15 +90,19 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offse
return
print("Running on target: %s" % device)
with tvm.target.create(device):
if device == 'llvm':
out = ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip)
else:
out = topi.cuda.ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip)
s = topi.generic.schedule_multibox_prior(out)
tvm_input_data = tvm.nd.array(input_data, ctx)
tvm_out = tvm.nd.array(np.zeros(oshape, dtype=dtype), ctx)
f = tvm.build(s, [data, out], device)
f(tvm_input_data, tvm_out)
np.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-4)
np.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-3)
for device in ['llvm']:
for device in ['llvm', 'opencl']:
check_device(device)
......@@ -113,7 +119,6 @@ def test_multibox_detection():
cls_prob = tvm.placeholder((batch_size, num_anchors, num_classes), name="cls_prob")
loc_preds = tvm.placeholder((batch_size, num_anchors * 4), name="loc_preds")
anchors = tvm.placeholder((1, num_anchors, 4), name="anchors")
out = ssd.multibox_detection(cls_prob, loc_preds, anchors)
# Manually create test case
np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]])
......@@ -131,6 +136,10 @@ def test_multibox_detection():
return
print("Running on target: %s" % device)
with tvm.target.create(device):
if device == 'llvm':
out = ssd.multibox_detection(cls_prob, loc_preds, anchors)
else:
out = topi.cuda.ssd.multibox_detection(cls_prob, loc_preds, anchors)
s = topi.generic.schedule_multibox_detection(out)
tvm_cls_prob = tvm.nd.array(np_cls_prob.astype(cls_prob.dtype), ctx)
......@@ -141,7 +150,7 @@ def test_multibox_detection():
f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out)
np.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, rtol=1e-4)
for device in ['llvm']:
for device in ['llvm', 'opencl']:
check_device(device)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment