Commit 48c16a17 by Leyuan Wang Committed by masahi

[BugFix] SSD fully supported on GPUs, updated deploy_ssd tutorial (#2510)

* nms fixed for gpu, tested on cuda and opencl devices, ssd now can run fully on the gpu

* sort updated to use virtual thread

* typo fixed

* fix lint

* fix lint

* add support when batch_size > 1

* intel graphics conv2d bugs fixed for inception_v3

* intel conv2d api updated, nn input size 4 condition added

* review addressed

* move conv_tags to attributes

* opencl ctx fixed

* nms_ir index simplified
parent 881a78b3
......@@ -5,427 +5,66 @@ import tvm
from tvm import api
from topi.vision import nms
from ..util import get_const_tuple
def sort_pre_ir(index, sizes_out, axis_mul_before, axis_mul_after):
"""Low level IR routing subfunction 1/4 for computing segments' staring locatons.
Parameters
----------
index : Buffer
Buffer of number of valid output boxes.
sizes_out : Buffer
Output buffer of start locations of each sorting segment.
axis_mul_before : int
The multiplication result of axis dimensions before axis.
axis_mul_after : int
The multiplication result of axis dimensions after axis.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
p_index = ib.buffer_ptr(index)
dshape = sizes_out.shape
sizes = ib.buffer_ptr(sizes_out)
nthread_tx = max_threads
nthread_bx = dshape[0] // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
sizes[tid] = p_index[tid]
# scan
with ib.if_scope(tid < 1):
with ib.for_range(0, axis_mul_before * axis_mul_after - 1, name="k") as k:
sizes[k + 1] += sizes[k]
body = ib.get()
return body
def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \
axis, axis_mul_before, axis_mul_after):
"""Low level IR routing subfunction 2/4 for flattening data and indices into segmented format.
def sort_ir(data, index, output):
"""Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
2D Buffer of input boxes' score with shape [batch_size, num_anchors].
index : Buffer
Buffer of number of valid output boxes.
sizes_in : Buffer
Buffer of start locations of each sorting segment.
data_out : Buffer
Buffer of flattened segmented data.
index_out : Buffer
Buffer of flattened segmented indices.
axis : int
The axis used for sorting.
1D Buffer of number of valid number of boxes.
axis_mul_before : int
The multiplication result of axis dimensions before axis.
axis_mul_after : int
The multiplication result of axis dimensions after axis.
output : Buffer
2D Output buffer of indicies of sorted tensor with shape [batch_size, num_anchors].
Returns
-------
stmt : Stmt
The result IR statement.
"""
ib = tvm.ir_builder.create()
sizes = ib.buffer_ptr(sizes_in)
p_index = ib.buffer_ptr(index)
p_data = ib.buffer_ptr(data)
data_new = ib.buffer_ptr(data_out)
index_new = ib.buffer_ptr(index_out)
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
dshape = axis_mul_before * axis_mul_after
nthread_tx = max_threads
nthread_bx = dshape // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(axis_mul_before * axis_mul_after > 1):
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
i = tid / axis_mul_after
j = tid % axis_mul_after
current_sort_num = p_index[tid]
base_idx = i * data.shape[axis] * axis_mul_after + j
with ib.for_range(0, current_sort_num, name="k") as k:
full_idx = base_idx + k * axis_mul_after
with ib.if_scope(tid == 0):
start = 0
with ib.else_scope():
start = sizes[tid-1]
index_new[start + k] = k
data_new[start + k] = p_data[full_idx]
with ib.else_scope():
with ib.if_scope(tid == 0):
with ib.for_range(0, p_index[0], name="k") as k:
index_new[k] = k
body = ib.get()
return body
def sort_oet_ir(data, index, new_data, new_index, loc, out_index, axis_mul_before, \
axis_mul_after, axis, is_descend):
"""Low level IR routing subfunction 3/4 for Odd-Even-Transposition sorting.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
index : Buffer
Buffer of number of valid output boxes.
new_data : Buffer
Buffer of flattened segmented data.
new_index : Buffer
Buffer of flattened segmented indices.
loc : Buffer
Buffer of start locations of each sorting segment.
out_index : Buffer
Output buffer of output box indexes sorted by score in a flattened segmented format.
axis_mul_before : int
The multiplication result of axis dimensions before axis.
axis_mul_after : int
The multiplication result of axis dimensions after axis.
axis : int
The axis used for sorting.
is_descend : bool
If the sorted data is in descending order.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
assert data.dtype == "float32", "Currently only supports input dtype to be float32"
batch, num_anchors = get_const_tuple(data.shape)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
ib = tvm.ir_builder.create()
dshape = loc.shape
fshape = data.shape[axis] * dshape[0]
temp_data = ib.allocate(
"float32", dshape, name="temp_data", scope="local")
p_data = ib.buffer_ptr(data)
p_index = ib.buffer_ptr(index)
data_new = ib.buffer_ptr(new_data)
index_new = ib.buffer_ptr(new_index)
index_out = ib.buffer_ptr(out_index)
sizes = ib.buffer_ptr(loc)
nthread_tx = max_threads
nthread_bx = fshape // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(axis_mul_before * axis_mul_after > 1):
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
with ib.if_scope(tid == 0):
start = 0
with ib.else_scope():
start = sizes[tid-1]
# OddEvenTransposeSort
with ib.for_range(0, p_index[tid], name="k") as k:
with ib.for_range(0, p_index[tid] - 1, name="i") as i:
with ib.if_scope(i % 2 == k % 2):
with ib.if_scope(((data_new[i+start] < data_new[i+start+1]) == is_descend)):
temp_data[tid] = data_new[i+start]
data_new[i+start] = data_new[i+start+1]
data_new[i+start+1] = temp_data[tid]
index_out[tid] = index_new[i+start]
index_new[i+start] = index_new[i+start+1]
index_new[i+start+1] = index_out[tid]
with ib.if_scope(tid < 1):
with ib.for_range(0, sizes[dshape[0] - 1], name="i") as i:
index_out[i] = index_new[i]
with ib.else_scope():
with ib.for_range(0, fshape, name="k", for_type="unroll") as k:
with ib.if_scope(tvm.all(k % 2 == tid % 2, tid < fshape)):
with ib.if_scope(k % 2 == 0):
with ib.if_scope(tvm.all(tid + 1 < fshape, (p_data[tid] < p_data[tid+1]) \
== is_descend)):
data_new[tid] = p_data[tid+1]
index_out[tid] = index_new[tid+1]
with ib.else_scope():
data_new[tid] = p_data[tid]
index_out[tid] = index_new[tid]
with ib.else_scope():
with ib.if_scope(tvm.all(tid + 1 < fshape, (data_new[tid] < data_new[tid+1]) \
== is_descend)):
p_data[tid] = data_new[tid+1]
index_new[tid] = index_out[tid+1]
with ib.else_scope():
p_data[tid] = data_new[tid]
index_new[tid] = index_out[tid]
with ib.if_scope(tvm.all(k % 2 != tid % 2, tid < fshape)):
with ib.if_scope(k % 2 == 0):
with ib.if_scope(tvm.all(tid > 0, (p_data[tid-1] < p_data[tid]) == is_descend)):
data_new[tid] = p_data[tid-1]
index_out[tid] = index_new[tid-1]
with ib.else_scope():
data_new[tid] = p_data[tid]
index_out[tid] = index_new[tid]
with ib.else_scope():
with ib.if_scope(tvm.all(tid > 0, (data_new[tid-1] < data_new[tid]) \
== is_descend)):
p_data[tid] = data_new[tid-1]
index_new[tid] = index_out[tid-1]
with ib.else_scope():
p_data[tid] = data_new[tid]
index_new[tid] = index_out[tid]
with ib.if_scope(fshape % 2 == 1):
with ib.if_scope(tid < 1):
with ib.for_range(0, fshape, name="k") as k:
index_out[tid] = index_new[tid]
body = ib.get()
return body
def sort_ir_out(data, index, new_index, loc, output, axis_mul_before, axis_mul_after, axis):
"""Low level IR routing subfunction 4/4 for writing sorted indices to output format.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
index : Buffer
Buffer of number of valid output boxes.
new_index : Buffer
Buffer of sorted indices in a flatten format.
loc : Buffer
Buffer of start locations of each sorting segment.
output : Buffer
Output buffer of output box indexes sorted by score.
axis_mul_before : int
The multiplication result of axis dimensions before axis.
axis_mul_after : int
The multiplication result of axis dimensions after axis.
axis : int
The axis used for sorting.
is_descend : bool
If the sorted data is in descending order.
Returns
-------
stmt : Stmt
The result IR statement.
"""
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
dshape = tvm.max(loc.shape[0], data.shape[axis])
p_index = ib.buffer_ptr(index)
index_new = ib.buffer_ptr(new_index)
sizes = ib.buffer_ptr(loc)
p_out = ib.buffer_ptr(output)
nthread_tx = max_threads
nthread_bx = dshape // max_threads + 1
nthread_bx = num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("vthread")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(axis_mul_before * axis_mul_after > 1):
with ib.if_scope(tid < axis_mul_before * axis_mul_after):
i = tid / axis_mul_after
j = tid % axis_mul_after
base_idx = i * data.shape[axis] * axis_mul_after + j
with ib.for_range(0, data.shape[axis], name="k") as k:
with ib.if_scope(tid == 0):
start = 0
with ib.else_scope():
start = sizes[tid-1]
p_out[base_idx + k * axis_mul_after] = tvm.if_then_else(
k < p_index[tid], index_new[k+start], k)
with ib.else_scope():
with ib.if_scope(tid < data.shape[axis]):
p_out[tid] = tvm.if_then_else(tid < p_index[0], index_new[tid], tid)
body = ib.get()
return body
def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend):
"""Function to generate low level IR to do sorting on the GPU, use it by calling sort_gpu.
Parameters
----------
data: tvm.Tensor
2-D tensor of input boxes' score with shape [batch_size, num_anchors].
data_buf: Buffer
2D Buffer of input boxes' score with shape [batch_size, num_anchors].
index : tvm.Tensor
1-D tensor for valid number of boxes.
index_buf : Buffer
Buffer of number of valid number of boxes.
output_buf : Buffer
Output buffer of indicies of sorted tensor.
axis : int
The axis used for sorting.
is_descend : bool
If the sorted data is in descending order.
Returns
-------
out : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors].
"""
ndim = len(data.shape)
assert data.dtype == "float32", "Currently only supports input dtype to be float32"
assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim
axis_mul_before = 1
axis_mul_after = 1
if axis < 0:
axis = ndim + axis
for i in range(0, ndim):
if i < axis:
axis_mul_before *= data.shape[i]
elif i > axis:
axis_mul_after *= data.shape[i]
dshape = axis_mul_before*axis_mul_after
fshape = data.shape[axis] * dshape
loc_buf = api.decl_buffer(dshape, index.dtype, "sizes", data_alignment=8)
new_index_buf = api.decl_buffer(
fshape, index.dtype, "index_new", data_alignment=8)
out_index_buf = api.decl_buffer(
fshape, index.dtype, "index_out", data_alignment=8)
new_data_buf = api.decl_buffer(
dshape, data.dtype, "data_new", data_alignment=8)
loc = \
tvm.extern([(dshape,)],
[index],
lambda ins, outs: sort_pre_ir(
ins[0], outs[0], axis_mul_before, axis_mul_after),
dtype=[index.dtype],
in_buffers=index_buf,
out_buffers=[loc_buf],
tag="sorting_prepare")
data_new, index_new = \
tvm.extern([(dshape,), (fshape,)],
[data, index, loc],
lambda ins, outs: sort_pre_ir_data(
ins[0], ins[1], ins[2], outs[0], outs[1], axis,
axis_mul_before, axis_mul_after),
dtype=[data.dtype, index.dtype],
in_buffers=[data_buf, index_buf, loc_buf],
out_buffers=[new_data_buf, new_index_buf],
tag="sorting_data")
index_out = \
tvm.extern([(fshape,)],
[data, index, data_new, index_new, loc],
lambda ins, outs: sort_oet_ir(
ins[0], ins[1], ins[2], ins[3], ins[4], outs[0],
axis_mul_before, axis_mul_after, axis, is_descend),
dtype=[index.dtype],
in_buffers=[data_buf, index_buf,
new_data_buf, new_index_buf, loc_buf],
out_buffers=[out_index_buf],
tag="sorting_oet")
out = \
tvm.extern([data.shape],
[data, index, index_out, loc],
lambda ins, outs: sort_ir_out(
ins[0], ins[1], ins[2], ins[3], outs[0],
axis_mul_before, axis_mul_after, axis),
dtype=[index.dtype],
in_buffers=[data_buf, index_buf, out_index_buf, loc_buf],
out_buffers=output_buf,
tag="sorting_output")
return out
ib.scope_attr(bx, "virtual_thread", nthread_bx)
tid = bx * nthread_tx + tx
temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
with ib.for_range(0, batch, for_type="unroll") as b:
start = b * num_anchors
with ib.if_scope(tid < num_anchors):
p_out[start + tid] = tid
# OddEvenTransposeSort
with ib.for_range(0, p_index[b]) as k:
with ib.if_scope(tid < (p_index[b] + 1) // 2):
offset = start + 2 * tid + (k % 2)
with ib.if_scope( \
tvm.all(offset + 1 < p_index[0], p_data[offset] < p_data[offset + 1])):
temp_data[0] = p_data[offset]
p_data[offset] = p_data[offset + 1]
p_data[offset + 1] = temp_data[0]
temp_index[0] = p_out[offset]
p_out[offset] = p_out[offset + 1]
p_out[offset + 1] = temp_index[0]
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
return ib.get()
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
"""Low level IR routing for transform location in multibox_detection operator.
......@@ -461,10 +100,10 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
"""Calculate overlap of two boxes.
"""
w = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
- tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
h = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
- tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
- tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
- tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
i = w * h
u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
(out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
......@@ -475,9 +114,7 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x")
ty = tvm.thread_axis("threadIdx.y")
bx = tvm.thread_axis("blockIdx.x")
by = tvm.thread_axis("blockIdx.y")
ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data)
p_sort_result = ib.buffer_ptr(sort_result)
......@@ -487,67 +124,57 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
num_anchors = out.shape[1]
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
nthread_ty = max_threads
nthread_by = 6 // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(ty, "thread_extent", nthread_ty)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
i = bx * max_threads + tx
j = by * max_threads + ty
nms_threshold_node = tvm.make.node(
"FloatImm", dtype="float32", value=nms_threshold)
nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
force_suppress_node = tvm.make.node(
"IntImm", dtype="int32", value=1 if force_suppress else 0)
with ib.for_range(0, batch_size, for_type="unroll", name="n") as n:
with ib.if_scope(
tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
p_valid_count[0] > 0)):
with ib.for_range(0, batch_size, for_type="unroll") as b:
base_idx = b * num_anchors * 6
with ib.if_scope( \
tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
p_valid_count[0] > 0)):
# Reorder output
nkeep = tvm.if_then_else(
tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
nms_topk, p_valid_count[n])
with ib.if_scope(i < nkeep):
with ib.if_scope(j < 6):
p_out[(n * num_anchors * 6
+ i * 6 + j)] = p_data[(n * num_anchors * 6
+ p_sort_result[n * num_anchors + i] * 6 + j)]
with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])):
with ib.if_scope(i < p_valid_count[n] - nkeep):
with ib.if_scope(j < 6):
p_out[(n * num_anchors * 6
+ (i + nkeep) * 6 + j)] = p_data[(n * num_anchors * 6
+ (i + nkeep) * 6 + j)]
nkeep = tvm.if_then_else( \
tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b]),
nms_topk, p_valid_count[b])
with ib.for_range(0, nkeep) as l:
with ib.if_scope(i < 6):
p_out[(base_idx + l * 6 + i)] = \
p_data[(base_idx + p_sort_result[b * num_anchors + l] * 6 + i)]
with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b])):
with ib.for_range(0, p_valid_count[b] - nkeep) as l:
with ib.if_scope(i < 6):
p_out[(base_idx + (l + nkeep) * 6 + i)] = \
p_data[(base_idx + (l + nkeep) * 6 + i)]
# Apply nms
with ib.if_scope(i < p_valid_count[n]):
offset_i = i * 6
with ib.if_scope(p_out[n * num_anchors * 6 + offset_i] >= 0):
with ib.if_scope(j < p_valid_count[n]):
offset_j = j * 6
with ib.if_scope(tvm.all(j > i, p_out[n * num_anchors * 6
+ offset_j] >= 0)):
with ib.for_range(0, p_valid_count[b]) as l:
offset_l = l * 6
with ib.if_scope(p_out[base_idx + offset_l] >= 0):
with ib.if_scope(i < p_valid_count[b]):
offset_i = i * 6
with ib.if_scope(tvm.all(i > l, p_out[base_idx
+ offset_i] >= 0)):
with ib.if_scope(tvm.any(force_suppress_node > 0,
p_out[n * num_anchors * 6 + offset_i] ==
p_out[n * num_anchors * 6 + offset_j])):
p_out[base_idx + offset_l] ==
p_out[base_idx + offset_i])):
# When force_suppress == True or class_id equals
iou = calculate_overlap(
p_out, n * num_anchors * 6 + offset_i + 2,
n * num_anchors * 6 + offset_j + 2)
iou = calculate_overlap(p_out, base_idx + offset_l + 2,
base_idx + offset_i + 2)
with ib.if_scope(iou >= nms_threshold):
p_out[
n * num_anchors * 6 + offset_j] = -1.0
p_out[base_idx + offset_i] = -1.0
with ib.else_scope():
with ib.if_scope(i < p_valid_count[n]):
with ib.if_scope(j < 6):
p_out[(n * num_anchors * 6
+ i * 6 + j)] = p_data[n * num_anchors * 6 + i * 6 + j]
with ib.for_range(0, p_valid_count[b]) as c:
with ib.if_scope(i < 6):
p_out[(base_idx + c * 6 + i)] = p_data[base_idx + c * 6 + i]
# Set invalid entry to be -1
with ib.if_scope(i < num_anchors - p_valid_count[n]):
with ib.if_scope(j < 6):
p_out[n * num_anchors * 6 + (i +
p_valid_count[n]) * 6 + j] = -1.0
with ib.for_range(0, num_anchors - p_valid_count[b]) as c:
with ib.if_scope(i < 6):
p_out[base_idx + (c + p_valid_count[b]) * 6 + i] = -1.0
body = ib.get()
return body
......@@ -610,18 +237,26 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
"valid_count_buf", data_alignment=4)
data_buf = api.decl_buffer(
data.shape, data.dtype, "data_buf", data_alignment=8)
score_axis = 1
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(
score_shape, lambda i, j: data[i, j, score_axis], name="score_tensor")
score_shape, lambda i, j: data[i, j, 1], name="score_tensor")
score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
"score_tensor_buf", data_alignment=8)
sort_tensor_dtype = "int32"
sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
"sort_tensor_buf", data_alignment=8)
sort_tensor = sort_gpu(score_tensor, score_tensor_buf, valid_count,
valid_count_buf, sort_tensor_buf, score_axis, True)
sort_tensor = \
tvm.extern(score_shape,
[score_tensor, valid_count],
lambda ins, outs: sort_ir(
ins[0], ins[1], outs[0]),
dtype=sort_tensor_dtype,
in_buffers=[score_tensor_buf, valid_count_buf],
out_buffers=sort_tensor_buf,
name="nms_sort")
out = \
tvm.extern(data.shape,
[data, sort_tensor, valid_count],
......
"""
Deploy Single Shot Multibox Detector(SSD) model
===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_
**Author**: `Yao Wang <https://github.com/kevinthesun>`_, \
`Leyuan Wang <https://github.com/Laurawly>`_
This article is an introductory tutorial to deploy SSD models with TVM.
We will use mxnet pretrained SSD model with Resnet50 as body network and
......@@ -32,17 +33,20 @@ from mxnet.model import load_checkpoint
# echo "set(USE_SORT ON)" > config.mk
# make -j8
#
# .. note::
#
# Currently we support compiling SSD on CPU only.
# GPU support is in progress.
#
model_name = "ssd_resnet50_512"
model_file = "%s.zip" % model_name
test_image = "dog.jpg"
dshape = (1, 3, 512, 512)
dtype = "float32"
# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#ctx = tvm.gpu(0)
# Use these commented settings to build for opencl.
#target = 'opencl'
#ctx = tvm.opencl(0)
target = "llvm"
ctx = tvm.cpu()
......@@ -56,7 +60,8 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \
"resnet50_ssd_512_voc0712_trainval.zip"
image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \
"cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg"
inference_symbol_folder = "c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
inference_symbol_folder = \
"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \
"archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
......@@ -92,7 +97,8 @@ parser.add_argument(
default="nnvm")
args = parser.parse_args()
if args.frontend == "relay":
net, params = relay.frontend.from_mxnet(sym, {"data": dshape}, arg_params=arg_params, aux_params=aux_params)
net, params = relay.frontend.from_mxnet(sym, {"data": dshape}, arg_params=arg_params, \
aux_params=aux_params)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(net, target, params=params)
elif args.frontend == "nnvm":
......@@ -134,7 +140,7 @@ def display(img, out, thresh=0.5):
import random
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.figsize'] = (10,10)
mpl.rcParams['figure.figsize'] = (10, 10)
pens = dict()
plt.clf()
plt.imshow(img)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment