Unverified Commit c4c61cb7 by Leyuan Wang Committed by GitHub

[Fix] Fix get_valid_count flaky test for cuda (#4901)

* get_valid_count accuracy issue fixed for individual tests but not for all tests running together

* minor fix

* initialize valid_count and PrefixSum buffers

* test updated

* update relay test as well

* update document

* fix lint

* address comment

* fix lint

* correct atomicAdd identifier name
parent 8290eaba
@@ -221,8 +221,6 @@ def test_get_valid_counts():
func = relay.Function([x], z.astuple())
func = run_infer_type(func)
for target, ctx in ctx_list():
if target == 'cuda':
return
intrp = relay.create_executor("debug", ctx=ctx, target=target)
out = intrp.evaluate(func)(np_data)
tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
@@ -21,29 +21,46 @@ import math
import tvm
from tvm import api
from tvm.generic import cast
from tvm.intrin import if_then_else, log, power
from tvm.intrin import if_then_else
from topi.vision import non_max_suppression, get_valid_counts
from .sort import argsort
from .. import tag
def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index):
"""Low level IR to Prepare get valid count of bounding boxes
given a score threshold. Also moves valid boxes to the
def cuda_atomic_add_rule(op):
if op.dtype == "float32":
return tvm.call_pure_extern("float32", "atomicAdd", op.args[0], op.args[1])
if op.dtype == "float64":
return tvm.call_pure_extern("float64", "atomicAdd", op.args[0], op.args[1])
if op.dtype == "int32":
return tvm.call_pure_extern("int32", "atomicAdd", op.args[0], op.args[1])
raise RuntimeError("only support int32, float32 and float64")
tvm.target.intrin.register_intrin_rule(
"cuda", "atomic_add", cuda_atomic_add_rule, override=True)
def atomic_add(x, y):
return tvm.call_pure_intrin(y.dtype, "atomic_add", x, y)
def get_valid_counts_ir(data, valid_count, flag, score_threshold, id_index, score_index):
"""Low level IR to get valid count of bounding boxes
given a score threshold. Also prepares to move valid boxes to the
top of input data.
Parameters
----------
data: Buffer
3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
data : Buffer
Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
valid_count : Buffer
1D buffer for valid number of boxes with shape [batch_size, ].
flag : Buffer
2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
idx : Buffer
2D Buffer of valid data indices with shape [batch_size, num_anchors].
score_threshold : float32
Lower limit of score for valid bounding boxes.
@@ -60,18 +77,24 @@ def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
box_data_length = data.shape[2]
elem_length = data.shape[2]
ib = tvm.ir_builder.create()
data = ib.buffer_ptr(data)
valid_count = ib.buffer_ptr(valid_count)
flag = ib.buffer_ptr(flag)
idx = ib.buffer_ptr(idx)
score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
atomic_add_return = ib.allocate(
valid_count.dtype, (1,), name='atomic_add_return', scope='local')
one_count = tvm.const(1, dtype=valid_count.dtype)
score_threshold = tvm.make.node(
"FloatImm", dtype="float32", value=score_threshold)
id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
max_threads = int(tvm.target.Target.current(
allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
@@ -79,163 +102,52 @@ def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
idxd = tvm.indexdiv
# initialize valid_count
with ib.if_scope(tid < batch_size):
valid_count[tid] = 0
# initialize flag
with ib.if_scope(tid < batch_size * num_anchors):
with ib.if_scope(tvm.all(data[tid * box_data_length + score_index] > score_threshold, \
tvm.any(id_index < 0, data[tid * box_data_length + id_index] >= 0))):
flag[tid] = 0
with ib.if_scope(tid < batch_size * num_anchors):
i = idxd(tid, num_anchors)
with ib.if_scope(tvm.all(data[tid * elem_length + score_index] > score_threshold,
tvm.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
flag[tid] = 1
idx[tid] = 1
with ib.else_scope():
flag[tid] = 0
idx[tid] = 0
atomic_add_return[0] = atomic_add(tvm.call_pure_intrin("handle", "tvm_address_of",
valid_count[i]), one_count)
return ib.get()
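For reference, a plain NumPy sketch of what this kernel produces (illustrative only, not the generated CUDA): every anchor whose score exceeds the threshold, and whose class id is non-negative when id_index >= 0, gets flag 1, and valid_count is the per-batch sum of those flags accumulated through atomicAdd.

import numpy as np

def get_valid_counts_ir_ref(np_data, score_threshold, id_index, score_index):
    # np_data: (batch_size, num_anchors, elem_length) float32
    flag = np_data[:, :, score_index] > score_threshold
    if id_index >= 0:
        flag &= np_data[:, :, id_index] >= 0
    flag = flag.astype("int32")
    valid_count = flag.sum(axis=1).astype("int32")
    return valid_count, flag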
def get_valid_counts_upsweep(data, idx_in, idx, partial):
"""Low level IR of first step of scan: unsweep.
Parameters
----------
data: Buffer
3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
idx_in : Buffer
2D Buffer of valid data indices with shape [batch_size, num_anchors].
idx : Buffer
2D Buffer of valid data indices with shape [batch_size, num_anchors].
partial : Buffer
2D Buffer of valid data indices with shape [batch_size, new_range].
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
ib = tvm.ir_builder.create()
data = ib.buffer_ptr(data)
idx_in = ib.buffer_ptr(idx_in)
idx = ib.buffer_ptr(idx)
partial = ib.buffer_ptr(partial)
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
elem_per_thread = num_anchors // max_threads + 1
nthread_tx = max_threads
nthread_bx = batch_size
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
new_range = num_anchors // elem_per_thread + 1
# Scan: Upsweep:
with ib.if_scope(tvm.all(bx < batch_size, tx < new_range)):
with ib.for_range(0, elem_per_thread) as i:
with ib.if_scope(bx * num_anchors + \
tx * elem_per_thread + i < batch_size * num_anchors):
with ib.if_scope(i == 0):
partial[bx * new_range + tx] = idx_in[bx * num_anchors + tx * elem_per_thread]
idx[bx * num_anchors + tx * elem_per_thread] = \
idx_in[bx * num_anchors + tx * elem_per_thread]
with ib.else_scope():
partial[bx * new_range + tx] += \
idx_in[bx * num_anchors + tx * elem_per_thread + i]
idx[bx * num_anchors + tx * elem_per_thread + i] = \
idx[bx * num_anchors + tx * elem_per_thread + i - 1] + \
idx_in[bx * num_anchors + tx * elem_per_thread + i]
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
return ib.get()
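A hedged NumPy model of the removed upsweep kernel above (not the CUDA code): each thread owns a chunk of elem_per_thread anchors, writes an inclusive running sum of its chunk into idx, and stores the chunk total in partial; the sketch clips the last chunk at the row boundary.

import numpy as np

def upsweep_ref(idx_in, elem_per_thread):
    # idx_in: (batch_size, num_anchors) of 0/1 flags
    batch_size, num_anchors = idx_in.shape
    new_range = num_anchors // elem_per_thread + 1
    idx = np.zeros_like(idx_in)
    partial = np.zeros((batch_size, new_range), dtype=idx_in.dtype)
    for b in range(batch_size):
        for t in range(new_range):
            start = t * elem_per_thread
            chunk = idx_in[b, start:start + elem_per_thread]
            if chunk.size:
                idx[b, start:start + chunk.size] = np.cumsum(chunk)
                partial[b, t] = chunk.sum()
    return idx, partial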
def get_valid_counts_scan(data, partial_in, partial):
"""Low level IR to do scan.
def flag_scan(flag, prefix_sum):
"""Low level IR to calculate correct positions for valid boxes.
Parameters
----------
data: Buffer
3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
idx_in : Buffer
2D Buffer of valid data indices with shape [batch_size, num_anchors].
idx : Buffer
2D Buffer of valid data indices with shape [batch_size, num_anchors].
flag : Buffer
2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
partial : Buffer
2D Buffer of valid data indices with shape [batch_size, new_range].
prefix_sum : Buffer
2D Buffer of prefix sum of flags indicating new locations of valid boxes
with same shape as flag.
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
ib = tvm.ir_builder.create()
partial_in = ib.buffer_ptr(partial_in)
partial = ib.buffer_ptr(partial)
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
elem_per_thread = num_anchors // max_threads + 1
nthread_tx = max_threads
nthread_bx = batch_size
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
var = tvm.make.node("FloatImm", dtype="float32", value=2)
new_range = num_anchors // elem_per_thread + 1
iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32")
# Scan: Kogge-Stone adder
with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
with ib.for_range(0, iteration) as k:
with ib.if_scope(k == 0):
with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
partial[bx * new_range + tx] = \
partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
with ib.else_scope():
partial[bx * new_range] = partial_in[bx * new_range]
with ib.else_scope():
with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"), \
tx < tvm.min(new_range, num_anchors))):
partial[bx * new_range + tx] += \
partial[bx * new_range + tx - cast(power(var, k), "int32")]
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
return ib.get()
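The removed kernel above implemented a Kogge-Stone scan over the per-thread partial sums. A plain-Python sketch of that step pattern (illustrative, not the CUDA code): after ceil(log2(n)) rounds it yields an inclusive prefix sum.

import math

def kogge_stone_scan(vals):
    # Each round adds the value `stride` positions to the left; doubling the
    # stride every round gives an inclusive prefix sum in O(log n) rounds.
    out = list(vals)
    n = len(out)
    steps = int(math.ceil(math.log2(n))) if n > 1 else 0
    for k in range(steps):
        stride = 1 << k
        prev = list(out)
        for t in range(stride, n):
            out[t] = prev[t] + prev[t - stride]
    return out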
def get_valid_counts_downsweep(data, idx_in, partial, idx):
"""Low level IR to do downsweep of scan.
Parameters
----------
data: Buffer
3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
idx_in : Buffer
2D Buffer of valid data indices with shape [batch_size, num_anchors].
batch_size = flag.shape[0]
num_anchors = flag.shape[1]
partial : Buffer
2D Buffer of valid data indices with shape [batch_size, new_range].
ib = tvm.ir_builder.create()
idx : Buffer
2D Buffer of valid data indices with shape [batch_size, num_anchors].
flag = ib.buffer_ptr(flag)
prefix_sum = ib.buffer_ptr(prefix_sum)
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
ib = tvm.ir_builder.create()
idx_in = ib.buffer_ptr(idx_in)
idx = ib.buffer_ptr(idx)
partial = ib.buffer_ptr(partial)
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
elem_per_thread = num_anchors // max_threads + 1
max_threads = int(tvm.target.Target.current(
allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
@@ -243,23 +155,23 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
new_range = num_anchors // elem_per_thread + 1
idxd = tvm.indexdiv
idxm = tvm.indexmod
# Scan: Downsweep:
with ib.if_scope(tid < batch_size * num_anchors):
i = idxd(tid, num_anchors) # number of batches
j = idxm(tid, num_anchors) # number of anchors
with ib.if_scope(j < elem_per_thread):
idx[tid] = idx_in[tid]
with ib.else_scope():
idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1]
# initialize prefix_sum
with ib.if_scope(tid < batch_size * num_anchors):
prefix_sum[tid] = 0
with ib.if_scope(tid < batch_size * num_anchors):
i = idxd(tid, num_anchors)
j = idxm(tid, num_anchors)
with ib.for_range(0, j) as r:
prefix_sum[tid] += flag[i * num_anchors + r]
return ib.get()
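In NumPy terms, flag_scan is an exclusive prefix sum of the flags along the anchor axis: prefix_sum[i, j] counts the valid boxes before position j and is therefore the destination slot for box j. An illustrative reference, not the generated kernel:

import numpy as np

def flag_scan_ref(flag):
    # flag: (batch_size, num_anchors) of 0/1 int32
    prefix_sum = np.cumsum(flag, axis=1) - flag   # exclusive scan per row
    return prefix_sum.astype("int32")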
def get_valid_counts_ir(data, flag, idx, valid_count, out):
"""Low level IR to get valid count of bounding boxes
given a score threshold. Also moves valid boxes to the
def out_rewrite(data, flag, prefix_sum, valid_count, out):
"""Low level IR to move valid boxes to the
top of input data.
Parameters
@@ -270,11 +182,12 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
flag : Buffer
2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
idx : Buffer
2D Buffer of valid data indices with shape [batch_size, num_anchors].
prefix_sum : Buffer
2D Buffer of prefix sum of flags indicating new locations of valid boxes
with same shape as flag.
valid_count : Buffer
1-D buffer for valid number of boxes.
1D buffer for valid number of boxes with shape [batch_size, ].
out : Buffer
Rearranged data buffer.
@@ -284,28 +197,28 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
stmt : Stmt
The result IR statement.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
elem_length = data.shape[2]
size = batch_size * num_anchors * elem_length
batch_size = out.shape[0]
num_anchors = out.shape[1]
elem_length = out.shape[2]
ib = tvm.ir_builder.create()
one = tvm.const(1, dtype=out.dtype)
data = ib.buffer_ptr(data)
flag = ib.buffer_ptr(flag)
idx = ib.buffer_ptr(idx)
valid_count = ib.buffer_ptr(valid_count)
prefix_sum = ib.buffer_ptr(prefix_sum)
out = ib.buffer_ptr(out)
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
max_threads = int(tvm.target.Target.current(
allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
nthread_bx = batch_size * num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
idxd = tvm.indexdiv
idxm = tvm.indexmod
@@ -313,17 +226,15 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
i = idxd(tid, num_anchors)
j = idxm(tid, num_anchors)
base_idx = i * num_anchors * elem_length
with ib.if_scope(flag[tid] > 0):
with ib.if_scope(tvm.all(flag[tid] > 0, prefix_sum[tid] >= 0,
prefix_sum[tid] < num_anchors)):
with ib.for_range(0, elem_length) as k:
out[base_idx + prefix_sum[tid] * elem_length +
k] = data[tid * elem_length + k]
with ib.if_scope(j >= valid_count[i]):
with ib.for_range(0, elem_length) as k:
with ib.if_scope(base_idx + (idx[tid] - 1) * elem_length + k < size):
out[base_idx + (idx[tid] - 1) * elem_length + k] =\
data[base_idx + j * elem_length + k]
with ib.if_scope(j == 0):
valid_count[i] = idx[tid + num_anchors - 1]
with ib.if_scope(j >= idx[i * num_anchors + num_anchors - 1]):
with ib.for_range(0, elem_length) as l:
with ib.if_scope(tid * elem_length + l < size):
out[tid * elem_length + l] = -1.0
out[tid * elem_length + k] = -one
return ib.get()
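An illustrative NumPy model of out_rewrite (not the generated kernel): each flagged box is copied to row prefix_sum[i, j] of its batch, and every row past the valid boxes is filled with -1.

import numpy as np

def out_rewrite_ref(np_data, flag, prefix_sum):
    # np_data: (batch_size, num_anchors, elem_length)
    # flag, prefix_sum: (batch_size, num_anchors)
    out = np.full_like(np_data, -1.0)
    batch_size, num_anchors = flag.shape
    for i in range(batch_size):
        for j in range(num_anchors):
            if flag[i, j] > 0:
                out[i, prefix_sum[i, j]] = np_data[i, j]
    return out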
@@ -356,56 +267,47 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
elem_per_thread = num_anchors // max_threads + 1
new_range = num_anchors // elem_per_thread + 1
data_buf = api.decl_buffer(
data.shape, data.dtype, "data_buf", data_alignment=8)
valid_count_buf = api.decl_buffer(
(batch_size,), "int32", "valid_count_buf", data_alignment=8)
temp_flag_buf = api.decl_buffer(
(batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8)
temp_idx_buf = api.decl_buffer(
(batch_size, num_anchors,), "int32", "temp_idx", data_alignment=8)
temp_partial_buf = api.decl_buffer(
(batch_size, new_range), "int32", "temp_partial", data_alignment=8)
data_buf = api.decl_buffer(
data.shape, data.dtype, "data_buf", data_alignment=8)
(batch_size, num_anchors), "int32", "temp_partial", data_alignment=8)
out_buf = api.decl_buffer(
data.shape, data.dtype, "out_buf", data_alignment=8)
temp_flag, temp_idx = \
tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
lambda ins, outs: get_valid_counts_pre(
ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
dtype=["int32", "int32"],
out_buffers=[temp_flag_buf, temp_idx_buf],
name="get_valid_counts_phase_one")
temp_idx_new, temp_partial = \
tvm.extern([(batch_size, num_anchors,), (batch_size, new_range)], [data, temp_idx],
lambda ins, outs: get_valid_counts_upsweep(
ins[0], ins[1], outs[0], outs[1]),
dtype=["int32", "int32"],
out_buffers=[temp_idx_buf, temp_partial_buf],
name="get_valid_counts_phase_two")
temp_partial_new = \
tvm.extern([(batch_size, new_range)], [data, temp_partial],
lambda ins, outs: get_valid_counts_scan(
ins[0], ins[1], outs[0]),
dtype=["int32"],
out_buffers=[temp_partial_buf],
name="get_valid_counts_phase_three")
temp_idx_final = \
tvm.extern([(batch_size, num_anchors)], [data, temp_idx_new, temp_partial_new],
lambda ins, outs: get_valid_counts_downsweep(
ins[0], ins[1], ins[2], outs[0]),
dtype=["int32"],
out_buffers=[temp_idx_buf],
name="get_valid_counts_phase_four")
valid_count, out_tensor = \
tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx_final],
lambda ins, outs: get_valid_counts_ir(
ins[0], ins[1], ins[2], outs[0], outs[1]),
dtype=["int32", data.dtype],
in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
name="get_valid_counts_phase_five",
valid_count, temp_flag = \
tvm.extern([(batch_size,), (batch_size, num_anchors)], [data],
lambda ins, outs: get_valid_counts_ir(
ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
dtype=["int32", "int32"],
in_buffers=[data_buf],
out_buffers=[valid_count_buf, temp_flag_buf],
name="get_valid_counts",
tag="get_valid_counts_gpu")
return [valid_count, out_tensor]
temp_partial = \
tvm.extern([(batch_size, num_anchors)], [temp_flag],
lambda ins, outs: flag_scan(
ins[0], outs[0]),
dtype=["int32"],
in_buffers=[temp_flag_buf],
out_buffers=[temp_partial_buf],
name="flag_scan")
out = \
tvm.extern([data.shape], [data, temp_flag, temp_partial, valid_count],
lambda ins, outs: out_rewrite(
ins[0], ins[1], ins[2], ins[3], outs[0]),
dtype=[data.dtype],
in_buffers=[data_buf, temp_flag_buf,
temp_partial_buf, valid_count_buf],
out_buffers=[out_buf],
name="out_rewrite")
return [valid_count, out]
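A minimal driver sketch for the rewritten op, modeled on the existing verify_get_valid_counts test and assuming the pre-0.7 TVM/TOPI API; the shape and random data here are illustrative and not part of this patch.

import numpy as np
import tvm
import topi
from topi.vision import get_valid_counts

dshape = (1, 2500, 6)          # (batch_size, num_anchors, elem_length)
np_data = np.random.uniform(low=-2, high=2, size=dshape).astype("float32")
data = tvm.placeholder(dshape, name="data", dtype="float32")
with tvm.target.create("cuda"):
    outs = get_valid_counts(data, score_threshold=0, id_index=0, score_index=1)
    s = topi.generic.schedule_get_valid_counts(outs)
f = tvm.build(s, [data, outs[0], outs[1]], "cuda")
ctx = tvm.gpu(0)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np.zeros((dshape[0],), dtype="int32"), ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
f(tvm_data, tvm_valid_count, tvm_out)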
def nms_ir(data, sorted_index, valid_count, out, box_indices,
@@ -479,7 +381,8 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
valid_count = ib.buffer_ptr(valid_count)
out = ib.buffer_ptr(out)
box_indices = ib.buffer_ptr(box_indices)
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
num_valid_boxes = ib.allocate(
"int32", (1,), name="num_valid_boxes", scope="local")
max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads)
@@ -491,26 +394,29 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
ib.scope_attr(bx, "thread_extent", nthread_bx)
j = bx * max_threads + tx
iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
iou_threshold = tvm.make.node(
"FloatImm", dtype="float32", value=iou_threshold)
top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
force_suppress = tvm.make.node(
"IntImm", dtype="int32", value=1 if force_suppress else 0)
with ib.for_range(0, batch_size, for_type="unroll") as i:
base_idx = i * num_anchors * box_data_length
with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)):
# Reorder output
nkeep = if_then_else( \
tvm.all(top_k > 0, top_k < valid_count[i]),
top_k, valid_count[i])
nkeep = if_then_else(
tvm.all(top_k > 0, top_k < valid_count[i]),
top_k, valid_count[i])
with ib.if_scope(j < nkeep):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + j * box_data_length + k)] = \
data[(base_idx + sorted_index[i * num_anchors + j] \
* box_data_length + k)]
box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
data[(base_idx + sorted_index[i * num_anchors + j]
* box_data_length + k)]
box_indices[i * num_anchors +
j] = sorted_index[i * num_anchors + j]
with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
with ib.if_scope(j < valid_count[i] - nkeep):
with ib.for_range(0, box_data_length) as k:
@@ -519,16 +425,18 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
# Apply nms
with ib.for_range(0, valid_count[i]) as k:
offset_k = k * box_data_length
with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, \
tvm.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0))):
with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0,
tvm.any(id_index < 0, out[base_idx +
offset_k + id_index] >= 0))):
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.if_scope(tvm.all(j > k, \
out[base_idx + offset_j + score_index] > 0, \
tvm.any(id_index < 0, \
out[base_idx + offset_j + id_index] >= 0), \
tvm.any(force_suppress > 0, id_index < 0, \
out[base_idx + offset_k + id_index] == \
with ib.if_scope(tvm.all(j > k,
out[base_idx + offset_j +
score_index] > 0,
tvm.any(id_index < 0,
out[base_idx + offset_j + id_index] >= 0),
tvm.any(force_suppress > 0, id_index < 0,
out[base_idx + offset_k + id_index] ==
out[base_idx + offset_j + id_index]))):
iou = calculate_overlap(out, base_idx + offset_j + coord_start,
base_idx + offset_k + coord_start)
@@ -541,12 +449,14 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.for_range(0, box_data_length) as k:
out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
out[(base_idx + offset_j + k)
] = data[base_idx + offset_j + k]
box_indices[i * num_anchors + j] = j
# Set invalid entry to be -1
with ib.if_scope(j < num_anchors - valid_count[i]):
with ib.for_range(0, box_data_length) as k:
out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
out[base_idx + (j + valid_count[i]) *
box_data_length + k] = -1.0
box_indices[i * num_anchors + j + valid_count[i]] = -1
# Only return max_output_size number of valid boxes
num_valid_boxes[0] = 0
@@ -671,7 +581,7 @@ def invalid_to_bottom_ir(data, flag, idx, out):
with ib.if_scope(flag[i * num_anchors + j] > 0):
with ib.for_range(0, elem_length) as k:
out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \
= data[base_idx + j * elem_length + k]
= data[base_idx + j * elem_length + k]
return ib.get()
@@ -756,8 +666,10 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
"valid_count_buf", data_alignment=4)
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
score_tensor = tvm.compute(
score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
sort_tensor = argsort(
score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)
@@ -795,7 +707,8 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
ins[0], outs[0], outs[1]),
dtype=["int32", "int32"],
in_buffers=[out_buf],
out_buffers=[temp_flag_buf, temp_idx_buf],
out_buffers=[
temp_flag_buf, temp_idx_buf],
name="invalid_to_bottom_phase_one")
output = tvm.extern([data.shape], [out, temp_flag, temp_idx],
@@ -67,8 +67,8 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
for device in ['llvm', 'cuda', 'opencl']:
# Disable gpu test for now
if device != "llvm":
# Disable opencl test for now
if device != "llvm" and device != "cuda":
continue
check_device(device)