Commit 48c16a17 by Leyuan Wang Committed by masahi

[BugFix] SSD fully supported on GPUs, updated deploy_ssd tutorial (#2510)

* nms fixed for gpu, tested on cuda and opencl devices, ssd now can run fully on the gpu

* sort updated to use virtual thread

* typo fixed

* fix lint

* fix lint

* add support when batch_size > 1

* intel graphics conv2d bugs fixed for inception_v3

* intel conv2d api updated, nn input size 4 condition added

* review addressed

* move conv_tags to attributes

* opencl ctx fixed

* nms_ir index simplified
parent 881a78b3
...@@ -5,427 +5,66 @@ import tvm ...@@ -5,427 +5,66 @@ import tvm
from tvm import api from tvm import api
from topi.vision import nms from topi.vision import nms
from ..util import get_const_tuple
def sort_ir(data, index, output):
def sort_pre_ir(index, sizes_out, axis_mul_before, axis_mul_after): """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
"""Low level IR routing subfunction 1/4 for computing segments' staring locatons.
Parameters
----------
index : Buffer
Buffer of number of valid output boxes.
sizes_out : Buffer
Output buffer of start locations of each sorting segment.
axis_mul_before : int
The multiplication result of axis dimensions before axis.
axis_mul_after : int
The multiplication result of axis dimensions after axis.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
p_index = ib.buffer_ptr(index)
dshape = sizes_out.shape
sizes = ib.buffer_ptr(sizes_out)
nthread_tx = max_threads
nthread_bx = dshape[0] // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
sizes[tid] = p_index[tid]
# scan
with ib.if_scope(tid < 1):
with ib.for_range(0, axis_mul_before * axis_mul_after - 1, name="k") as k:
sizes[k + 1] += sizes[k]
body = ib.get()
return body
def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \
axis, axis_mul_before, axis_mul_after):
"""Low level IR routing subfunction 2/4 for flattening data and indices into segmented format.
Parameters Parameters
---------- ----------
data: Buffer data: Buffer
Buffer of output boxes with class and score. 2D Buffer of input boxes' score with shape [batch_size, num_anchors].
index : Buffer index : Buffer
Buffer of number of valid output boxes. 1D Buffer of number of valid number of boxes.
sizes_in : Buffer
Buffer of start locations of each sorting segment.
data_out : Buffer
Buffer of flattened segmented data.
index_out : Buffer
Buffer of flattened segmented indices.
axis : int
The axis used for sorting.
axis_mul_before : int output : Buffer
The multiplication result of axis dimensions before axis. 2D Output buffer of indicies of sorted tensor with shape [batch_size, num_anchors].
axis_mul_after : int
The multiplication result of axis dimensions after axis.
Returns Returns
------- -------
stmt : Stmt stmt : Stmt
The result IR statement. The result IR statement.
""" """
ib = tvm.ir_builder.create()
sizes = ib.buffer_ptr(sizes_in)
p_index = ib.buffer_ptr(index)
p_data = ib.buffer_ptr(data)
data_new = ib.buffer_ptr(data_out)
index_new = ib.buffer_ptr(index_out)
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
dshape = axis_mul_before * axis_mul_after
nthread_tx = max_threads
nthread_bx = dshape // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(axis_mul_before * axis_mul_after > 1):
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
i = tid / axis_mul_after
j = tid % axis_mul_after
current_sort_num = p_index[tid]
base_idx = i * data.shape[axis] * axis_mul_after + j
with ib.for_range(0, current_sort_num, name="k") as k:
full_idx = base_idx + k * axis_mul_after
with ib.if_scope(tid == 0):
start = 0
with ib.else_scope():
start = sizes[tid-1]
index_new[start + k] = k
data_new[start + k] = p_data[full_idx]
with ib.else_scope():
with ib.if_scope(tid == 0):
with ib.for_range(0, p_index[0], name="k") as k:
index_new[k] = k
body = ib.get()
return body
def sort_oet_ir(data, index, new_data, new_index, loc, out_index, axis_mul_before, \
axis_mul_after, axis, is_descend):
"""Low level IR routing subfunction 3/4 for Odd-Even-Transposition sorting.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
index : Buffer
Buffer of number of valid output boxes.
new_data : Buffer
Buffer of flattened segmented data.
new_index : Buffer
Buffer of flattened segmented indices.
loc : Buffer
Buffer of start locations of each sorting segment.
out_index : Buffer
Output buffer of output box indexes sorted by score in a flattened segmented format.
axis_mul_before : int assert data.dtype == "float32", "Currently only supports input dtype to be float32"
The multiplication result of axis dimensions before axis. batch, num_anchors = get_const_tuple(data.shape)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
axis_mul_after : int
The multiplication result of axis dimensions after axis.
axis : int
The axis used for sorting.
is_descend : bool
If the sorted data is in descending order.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
dshape = loc.shape
fshape = data.shape[axis] * dshape[0]
temp_data = ib.allocate(
"float32", dshape, name="temp_data", scope="local")
p_data = ib.buffer_ptr(data) p_data = ib.buffer_ptr(data)
p_index = ib.buffer_ptr(index) p_index = ib.buffer_ptr(index)
data_new = ib.buffer_ptr(new_data)
index_new = ib.buffer_ptr(new_index)
index_out = ib.buffer_ptr(out_index)
sizes = ib.buffer_ptr(loc)
nthread_tx = max_threads
nthread_bx = fshape // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(axis_mul_before * axis_mul_after > 1):
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
with ib.if_scope(tid == 0):
start = 0
with ib.else_scope():
start = sizes[tid-1]
# OddEvenTransposeSort
with ib.for_range(0, p_index[tid], name="k") as k:
with ib.for_range(0, p_index[tid] - 1, name="i") as i:
with ib.if_scope(i % 2 == k % 2):
with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) == is_descend)):
temp_data[tid] = data_new[i+start]
data_new[i+start] = data_new[i+start+1]
data_new[i+start+1] = temp_data[tid]
index_out[tid] = index_new[i+start]
index_new[i+start] = index_new[i+start+1]
index_new[i+start+1] = index_out[tid]
with ib.if_scope(tid < 1):
with ib.for_range(0, sizes[dshape[0] - 1], name="i") as i:
index_out[i] = index_new[i]
with ib.else_scope():
with ib.for_range(0, fshape, name="k", for_type="unroll") as k:
with ib.if_scope(tvm.all(k % 2 == tid % 2, tid < fshape)):
with ib.if_scope(k % 2 == 0):
with ib.if_scope(tvm.all(tid + 1 < fshape, (p_data[tid] < p_data[tid+1]) \
== is_descend)):
data_new[tid] = p_data[tid+1]
index_out[tid] = index_new[tid+1]
with ib.else_scope():
data_new[tid] = p_data[tid]
index_out[tid] = index_new[tid]
with ib.else_scope():
with ib.if_scope(tvm.all(tid + 1 < fshape, (data_new[tid] < data_new[tid+1]) \
== is_descend)):
p_data[tid] = data_new[tid+1]
index_new[tid] = index_out[tid+1]
with ib.else_scope():
p_data[tid] = data_new[tid]
index_new[tid] = index_out[tid]
with ib.if_scope(tvm.all(k % 2 != tid % 2, tid < fshape)):
with ib.if_scope(k % 2 == 0):
with ib.if_scope(tvm.all(tid > 0, (p_data[tid-1] < p_data[tid]) == is_descend)):
data_new[tid] = p_data[tid-1]
index_out[tid] = index_new[tid-1]
with ib.else_scope():
data_new[tid] = p_data[tid]
index_out[tid] = index_new[tid]
with ib.else_scope():
with ib.if_scope(tvm.all(tid > 0, (data_new[tid-1] < data_new[tid]) \
== is_descend)):
p_data[tid] = data_new[tid-1]
index_new[tid] = index_out[tid-1]
with ib.else_scope():
p_data[tid] = data_new[tid]
index_new[tid] = index_out[tid]
with ib.if_scope(fshape % 2 == 1):
with ib.if_scope(tid < 1):
with ib.for_range(0, fshape, name="k") as k:
index_out[tid] = index_new[tid]
body = ib.get()
return body
def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_after, axis):
"""Low level IR routing subfunction 4/4 for writing sorted indices to output format.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
index : Buffer
Buffer of number of valid output boxes.
new_index : Buffer
Buffer of sorted indices in a flatten format.
loc : Buffer
Buffer of start locations of each sorting segment.
output : Buffer
Output buffer of output box indexes sorted by score.
axis_mul_before : int
The multiplication result of axis dimensions before axis.
axis_mul_after : int
The multiplication result of axis dimensions after axis.
axis : int
The axis used for sorting.
is_descend : bool
If the sorted data is in descending order.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
dshape = tvm.max(loc.shape[0], data.shape[axis])
p_index = ib.buffer_ptr(index)
index_new = ib.buffer_ptr(new_index)
sizes = ib.buffer_ptr(loc)
p_out = ib.buffer_ptr(output) p_out = ib.buffer_ptr(output)
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = dshape // max_threads + 1 nthread_bx = num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("vthread")
ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(bx, "virtual_thread", nthread_bx)
tid = bx * max_threads + tx tid = bx * nthread_tx + tx
temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
with ib.if_scope(axis_mul_before * axis_mul_after > 1): temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
i = tid / axis_mul_after with ib.for_range(0, batch, for_type="unroll") as b:
j = tid % axis_mul_after start = b * num_anchors
base_idx = i * data.shape[axis] * axis_mul_after + j with ib.if_scope(tid < num_anchors):
with ib.for_range(0, data.shape[axis], name="k") as k: p_out[start + tid] = tid
with ib.if_scope(tid == 0): # OddEvenTransposeSort
start = 0 with ib.for_range(0, p_index[b]) as k:
with ib.else_scope(): with ib.if_scope(tid < (p_index[b] + 1) // 2):
start = sizes[tid-1] offset = start + 2 * tid + (k % 2)
p_out[base_idx + k * axis_mul_after] = tvm.if_then_else( with ib.if_scope( \
k < p_index[tid], index_new[k+start], k) tvm.all(offset + 1 < p_index[0], p_data[offset] < p_data[offset + 1])):
with ib.else_scope(): temp_data[0] = p_data[offset]
with ib.if_scope(tid < data.shape[axis]): p_data[offset] = p_data[offset + 1]
p_out[tid] = tvm.if_then_else(tid < p_index[0], index_new[tid], tid) p_data[offset + 1] = temp_data[0]
temp_index[0] = p_out[offset]
body = ib.get() p_out[offset] = p_out[offset + 1]
return body p_out[offset + 1] = temp_index[0]
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend): tvm.expr.Call.Intrinsic, None, 0))
"""Function to generate low level IR to do sorting on the GPU, use it by calling sort_gpu.
return ib.get()
Parameters
----------
data: tvm.Tensor
2-D tensor of input boxes' score with shape [batch_size, num_anchors].
data_buf: Buffer
2D Buffer of input boxes' score with shape [batch_size, num_anchors].
index : tvm.Tensor
1-D tensor for valid number of boxes.
index_buf : Buffer
Buffer of number of valid number of boxes.
output_buf : Buffer
Output buffer of indicies of sorted tensor.
axis : int
The axis used for sorting.
is_descend : bool
If the sorted data is in descending order.
Returns
-------
out : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors].
"""
ndim = len(data.shape)
assert data.dtype == "float32", "Currently only supports input dtype to be float32"
assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim
axis_mul_before = 1
axis_mul_after = 1
if axis < 0:
axis = ndim + axis
for i in range(0, ndim):
if i < axis:
axis_mul_before *= data.shape[i]
elif i > axis:
axis_mul_after *= data.shape[i]
dshape = axis_mul_before*axis_mul_after
fshape = data.shape[axis] * dshape
loc_buf = api.decl_buffer(dshape, index.dtype, "sizes", data_alignment=8)
new_index_buf = api.decl_buffer(
fshape, index.dtype, "index_new", data_alignment=8)
out_index_buf = api.decl_buffer(
fshape, index.dtype, "index_out", data_alignment=8)
new_data_buf = api.decl_buffer(
dshape, data.dtype, "data_new", data_alignment=8)
loc = \
tvm.extern([(dshape,)],
[index],
lambda ins, outs: sort_pre_ir(
ins[0], outs[0], axis_mul_before, axis_mul_after),
dtype=[index.dtype],
in_buffers=index_buf,
out_buffers=[loc_buf],
tag="sorting_prepare")
data_new, index_new = \
tvm.extern([(dshape,), (fshape,)],
[data, index, loc],
lambda ins, outs: sort_pre_ir_data(
ins[0], ins[1], ins[2], outs[0], outs[1], axis,
axis_mul_before, axis_mul_after),
dtype=[data.dtype, index.dtype],
in_buffers=[data_buf, index_buf, loc_buf],
out_buffers=[new_data_buf, new_index_buf],
tag="sorting_data")
index_out = \
tvm.extern([(fshape,)],
[data, index, data_new, index_new, loc],
lambda ins, outs: sort_oet_ir(
ins[0], ins[1], ins[2], ins[3], ins[4], outs[0],
axis_mul_before, axis_mul_after, axis, is_descend),
dtype=[index.dtype],
in_buffers=[data_buf, index_buf,
new_data_buf, new_index_buf, loc_buf],
out_buffers=[out_index_buf],
tag="sorting_oet")
out = \
tvm.extern([data.shape],
[data, index, index_out, loc],
lambda ins, outs: sort_ir_out(
ins[0], ins[1], ins[2], ins[3], outs[0],
axis_mul_before, axis_mul_after, axis),
dtype=[index.dtype],
in_buffers=[data_buf, index_buf, out_index_buf, loc_buf],
out_buffers=output_buf,
tag="sorting_output")
return out
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk): def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
"""Low level IR routing for transform location in multibox_detection operator. """Low level IR routing for transform location in multibox_detection operator.
...@@ -461,10 +100,10 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n ...@@ -461,10 +100,10 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
def calculate_overlap(out_tensor, box_a_idx, box_b_idx): def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
"""Calculate overlap of two boxes. """Calculate overlap of two boxes.
""" """
w = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2]) w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
- tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx])) - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
h = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3]) h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
- tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1])) - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
i = w * h i = w * h
u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \ u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
(out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \ (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
...@@ -475,9 +114,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n ...@@ -475,9 +114,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
max_threads = int(math.sqrt( max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads)) tvm.target.current_target(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y")
bx = tvm.thread_axis("blockIdx.x") bx = tvm.thread_axis("blockIdx.x")
by = tvm.thread_axis("blockIdx.y")
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data) p_data = ib.buffer_ptr(data)
p_sort_result = ib.buffer_ptr(sort_result) p_sort_result = ib.buffer_ptr(sort_result)
...@@ -487,67 +124,57 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n ...@@ -487,67 +124,57 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
num_anchors = out.shape[1] num_anchors = out.shape[1]
nthread_tx = max_threads nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1 nthread_bx = num_anchors // max_threads + 1
nthread_ty = max_threads
nthread_by = 6 // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(ty, "thread_extent", nthread_ty)
ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
i = bx * max_threads + tx i = bx * max_threads + tx
j = by * max_threads + ty
nms_threshold_node = tvm.make.node( nms_threshold_node = tvm.make.node(
"FloatImm", dtype="float32", value=nms_threshold) "FloatImm", dtype="float32", value=nms_threshold)
nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk) nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
force_suppress_node = tvm.make.node( force_suppress_node = tvm.make.node(
"IntImm", dtype="int32", value=1 if force_suppress else 0) "IntImm", dtype="int32", value=1 if force_suppress else 0)
with ib.for_range(0, batch_size, for_type="unroll", name="n") as n: with ib.for_range(0, batch_size, for_type="unroll") as b:
with ib.if_scope( base_idx = b * num_anchors * 6
tvm.all(nms_threshold_node > 0, nms_threshold_node < 1, with ib.if_scope( \
p_valid_count[0] > 0)): tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
p_valid_count[0] > 0)):
# Reorder output # Reorder output
nkeep = tvm.if_then_else( nkeep = tvm.if_then_else( \
tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]), tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b]),
nms_topk, p_valid_count[n]) nms_topk, p_valid_count[b])
with ib.if_scope(i < nkeep): with ib.for_range(0, nkeep) as l:
with ib.if_scope(j < 6): with ib.if_scope(i < 6):
p_out[(n * num_anchors * 6 p_out[(base_idx + l * 6 + i)] = \
+ i * 6 + j)] = p_data[(n * num_anchors * 6 p_data[(base_idx + p_sort_result[b * num_anchors + l] * 6 + i)]
+ p_sort_result[n * num_anchors + i] * 6 + j)] with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b])):
with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])): with ib.for_range(0, p_valid_count[b] - nkeep) as l:
with ib.if_scope(i < p_valid_count[n] - nkeep): with ib.if_scope(i < 6):
with ib.if_scope(j < 6): p_out[(base_idx + (l + nkeep) * 6 + i)] = \
p_out[(n * num_anchors * 6 p_data[(base_idx + (l + nkeep) * 6 + i)]
+ (i + nkeep) * 6 + j)] = p_data[(n * num_anchors * 6
+ (i + nkeep) * 6 + j)]
# Apply nms # Apply nms
with ib.if_scope(i < p_valid_count[n]): with ib.for_range(0, p_valid_count[b]) as l:
offset_i = i * 6 offset_l = l * 6
with ib.if_scope(p_out[n * num_anchors * 6 + offset_i] >= 0): with ib.if_scope(p_out[base_idx + offset_l] >= 0):
with ib.if_scope(j < p_valid_count[n]): with ib.if_scope(i < p_valid_count[b]):
offset_j = j * 6 offset_i = i * 6
with ib.if_scope(tvm.all(j > i, p_out[n * num_anchors * 6 with ib.if_scope(tvm.all(i > l, p_out[base_idx
+ offset_j] >= 0)): + offset_i] >= 0)):
with ib.if_scope(tvm.any(force_suppress_node > 0, with ib.if_scope(tvm.any(force_suppress_node > 0,
p_out[n * num_anchors * 6 + offset_i] == p_out[base_idx + offset_l] ==
p_out[n * num_anchors * 6 + offset_j])): p_out[base_idx + offset_i])):
# When force_suppress == True or class_id equals # When force_suppress == True or class_id equals
iou = calculate_overlap( iou = calculate_overlap(p_out, base_idx + offset_l + 2,
p_out, n * num_anchors * 6 + offset_i + 2, base_idx + offset_i + 2)
n * num_anchors * 6 + offset_j + 2)
with ib.if_scope(iou >= nms_threshold): with ib.if_scope(iou >= nms_threshold):
p_out[ p_out[base_idx + offset_i] = -1.0
n * num_anchors * 6 + offset_j] = -1.0
with ib.else_scope(): with ib.else_scope():
with ib.if_scope(i < p_valid_count[n]): with ib.for_range(0, p_valid_count[b]) as c:
with ib.if_scope(j < 6): with ib.if_scope(i < 6):
p_out[(n * num_anchors * 6 p_out[(base_idx + c * 6 + i)] = p_data[base_idx + c * 6 + i]
+ i * 6 + j)] = p_data[n * num_anchors * 6 + i * 6 + j]
# Set invalid entry to be -1 # Set invalid entry to be -1
with ib.if_scope(i < num_anchors - p_valid_count[n]): with ib.for_range(0, num_anchors - p_valid_count[b]) as c:
with ib.if_scope(j < 6): with ib.if_scope(i < 6):
p_out[n * num_anchors * 6 + (i + p_out[base_idx + (c + p_valid_count[b]) * 6 + i] = -1.0
p_valid_count[n]) * 6 + j] = -1.0
body = ib.get() body = ib.get()
return body return body
...@@ -610,18 +237,26 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk ...@@ -610,18 +237,26 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
"valid_count_buf", data_alignment=4) "valid_count_buf", data_alignment=4)
data_buf = api.decl_buffer( data_buf = api.decl_buffer(
data.shape, data.dtype, "data_buf", data_alignment=8) data.shape, data.dtype, "data_buf", data_alignment=8)
score_axis = 1
score_shape = (batch_size, num_anchors) score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute( score_tensor = tvm.compute(
score_shape, lambda i, j: data[i, j, score_axis], name="score_tensor") score_shape, lambda i, j: data[i, j, 1], name="score_tensor")
score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype, score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
"score_tensor_buf", data_alignment=8) "score_tensor_buf", data_alignment=8)
sort_tensor_dtype = "int32" sort_tensor_dtype = "int32"
sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype, sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
"sort_tensor_buf", data_alignment=8) "sort_tensor_buf", data_alignment=8)
sort_tensor = sort_gpu(score_tensor, score_tensor_buf, valid_count, sort_tensor = \
valid_count_buf, sort_tensor_buf, score_axis, True) tvm.extern(score_shape,
[score_tensor, valid_count],
lambda ins, outs: sort_ir(
ins[0], ins[1], outs[0]),
dtype=sort_tensor_dtype,
in_buffers=[score_tensor_buf, valid_count_buf],
out_buffers=sort_tensor_buf,
name="nms_sort")
out = \ out = \
tvm.extern(data.shape, tvm.extern(data.shape,
[data, sort_tensor, valid_count], [data, sort_tensor, valid_count],
......
""" """
Deploy Single Shot Multibox Detector(SSD) model Deploy Single Shot Multibox Detector(SSD) model
=============================================== ===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_ **Author**: `Yao Wang <https://github.com/kevinthesun>`_, \
`Leyuan Wang <https://github.com/Laurawly>`_
This article is an introductory tutorial to deploy SSD models with TVM. This article is an introductory tutorial to deploy SSD models with TVM.
We will use mxnet pretrained SSD model with Resnet50 as body network and We will use mxnet pretrained SSD model with Resnet50 as body network and
...@@ -32,17 +33,20 @@ from mxnet.model import load_checkpoint ...@@ -32,17 +33,20 @@ from mxnet.model import load_checkpoint
# echo "set(USE_SORT ON)" > config.mk # echo "set(USE_SORT ON)" > config.mk
# make -j8 # make -j8
# #
# .. note::
#
# Currently we support compiling SSD on CPU only.
# GPU support is in progress.
#
model_name = "ssd_resnet50_512" model_name = "ssd_resnet50_512"
model_file = "%s.zip" % model_name model_file = "%s.zip" % model_name
test_image = "dog.jpg" test_image = "dog.jpg"
dshape = (1, 3, 512, 512) dshape = (1, 3, 512, 512)
dtype = "float32" dtype = "float32"
# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#ctx = tvm.gpu(0)
# Use these commented settings to build for opencl.
#target = 'opencl'
#ctx = tvm.opencl(0)
target = "llvm" target = "llvm"
ctx = tvm.cpu() ctx = tvm.cpu()
...@@ -56,7 +60,8 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \ ...@@ -56,7 +60,8 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \
"resnet50_ssd_512_voc0712_trainval.zip" "resnet50_ssd_512_voc0712_trainval.zip"
image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \ image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \
"cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg" "cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg"
inference_symbol_folder = "c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26" inference_symbol_folder = \
"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \ inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \
"archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip" "archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
...@@ -92,7 +97,8 @@ parser.add_argument( ...@@ -92,7 +97,8 @@ parser.add_argument(
default="nnvm") default="nnvm")
args = parser.parse_args() args = parser.parse_args()
if args.frontend == "relay": if args.frontend == "relay":
net, params = relay.frontend.from_mxnet(sym, {"data": dshape}, arg_params=arg_params, aux_params=aux_params) net, params = relay.frontend.from_mxnet(sym, {"data": dshape}, arg_params=arg_params, \
aux_params=aux_params)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(net, target, params=params) graph, lib, params = relay.build(net, target, params=params)
elif args.frontend == "nnvm": elif args.frontend == "nnvm":
...@@ -134,7 +140,7 @@ def display(img, out, thresh=0.5): ...@@ -134,7 +140,7 @@ def display(img, out, thresh=0.5):
import random import random
import matplotlib as mpl import matplotlib as mpl
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
mpl.rcParams['figure.figsize'] = (10,10) mpl.rcParams['figure.figsize'] = (10, 10)
pens = dict() pens = dict()
plt.clf() plt.clf()
plt.imshow(img) plt.imshow(img)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment