Commit 3d74b48f by Yao Wang Committed by Tianqi Chen

Add ssd op with ir builder (#1095)

parent e8e84fa0
......@@ -126,6 +126,7 @@ stage('Build') {
echo USE_SORT=1 >> config.mk
echo USE_OPENGL=1 >> config.mk
echo LLVM_CONFIG=llvm-config-4.0 >> config.mk
echo USE_SORT=1 >> config.mk
"""
make('cpu', '-j2')
pack_lib('cpu', tvm_lib)
......
......@@ -201,6 +201,7 @@ ifeq ($(USE_GRAPH_RUNTIME_DEBUG), 1)
endif
include make/contrib/cblas.mk
include make/contrib/sort.mk
include make/contrib/random.mk
include make/contrib/nnpack.mk
include make/contrib/cudnn.mk
......
......@@ -90,3 +90,6 @@ USE_CUBLAS = 0
# Whether use rocBlas
USE_ROCBLAS = 0
# Whether use contrib sort
USE_SORT = 0
SORT_CONTRIB_SRC = $(wildcard src/contrib/sort/*.cc)
SORT_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(SORT_CONTRIB_SRC))
ifeq ($(USE_SORT), 1)
RUNTIME_DEP += $(SORT_CONTRIB_OBJ)
endif
......@@ -300,7 +300,14 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag=""):
return res[0] if len(res) == 1 else res
def extern(shape, inputs, fcompute, name="extern", dtype=None, tag=""):
def extern(shape,
inputs,
fcompute,
name="extern",
dtype=None,
in_buffers=None,
out_buffers=None,
tag=""):
"""Compute several tensor via extern function.
Parameters
......@@ -332,6 +339,12 @@ def extern(shape, inputs, fcompute, name="extern", dtype=None, tag=""):
The data types of outputs,
by default dtype will be same as inputs.
in_buffers: Buffer or list of Buffer, optional
Input buffers.
out_buffers: Buffer or list of Buffers, optional
Output buffers.
Returns
-------
tensor: Tensor or list of Tensors
......@@ -357,14 +370,25 @@ def extern(shape, inputs, fcompute, name="extern", dtype=None, tag=""):
tag = _tag.TagScope.current.tag
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
input_placeholders = []
output_placeholders = []
if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
if len(inputs) != len(in_buffers):
raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d."
% (len(inputs), len(in_buffers)))
if out_buffers is not None:
out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers
if len(shape) != len(out_buffers):
raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
% (len(shape), len(out_buffers)))
input_placeholders = in_buffers or []
output_placeholders = out_buffers or []
types = set()
for t in inputs:
if not isinstance(t, _tensor.Tensor):
raise ValueError("expect inputs to be tensor")
input_placeholders.append(
decl_buffer(t.shape, t.dtype, t.op.name))
if in_buffers is None:
input_placeholders.append(
decl_buffer(t.shape, t.dtype, t.op.name))
types.add(t.dtype)
if dtype is None:
......@@ -375,8 +399,9 @@ def extern(shape, inputs, fcompute, name="extern", dtype=None, tag=""):
if isinstance(dtype, str):
dtype = [dtype]
for shp, dt in zip(shape, dtype):
output_placeholders.append(decl_buffer(shp, dt, name))
if out_buffers is None:
for shp, dt in zip(shape, dtype):
output_placeholders.append(decl_buffer(shp, dt, name))
body = fcompute(input_placeholders, output_placeholders)
if isinstance(body, _expr.Expr):
body = _make.Evaluate(body)
......
/*!
* Copyright (c) 2017 by Contributors
* \file Use standard C library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dlpack/dlpack.h>
#include <algorithm>
#include <vector>
namespace tvm {
namespace contrib {
using namespace runtime;
template<typename DType>
bool CompareAscend(const std::pair<int32_t, DType>& lhs,
const std::pair<int32_t, DType>& rhs) {
return lhs.second < rhs.second;
}
template<typename DType>
bool CompareDescend(const std::pair<int32_t, DType>& lhs,
const std::pair<int32_t, DType>& rhs) {
return lhs.second > rhs.second;
}
// Argsort implemented C library sort.
// Return indices of sorted tensor.
// By default, the last axis will be used to sort.
// sort_num specify the number of elements to be sorted.
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and sort axis is dk. sort_num should have dimension of
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *input = args[0];
DLTensor *sort_num = args[1];
DLTensor *output = args[2];
int32_t axis = args[3];
bool is_descend = args[4];
auto dtype = input->dtype;
auto data_ptr = static_cast<float *>(input->data);
auto sort_num_ptr = static_cast<int32_t *>(sort_num->data);
std::vector<std::pair<int32_t, float>> sorter;
int64_t axis_mul_before = 1;
int64_t axis_mul_after = 1;
if (axis < 0) {
axis = input->ndim + axis;
}
// Currently only supports input dtype to be float32.
CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
"to be float32.";
CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
"to be float32.";
CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
"input ndim " << input->ndim;
for (int i = 0; i < input->ndim; ++i) {
if (i < axis) {
axis_mul_before *= input->shape[i];
} else if (i > axis) {
axis_mul_after *= input->shape[i];
}
}
for (int64_t i = 0 ; i < axis_mul_before; ++i) {
for (int64_t j = 0 ; j < axis_mul_after; ++j) {
sorter.clear();
int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j);
int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
for (int64_t k = 0; k < current_sort_num; ++k) {
int64_t full_idx = base_idx + k * axis_mul_after;
sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
}
std::stable_sort(sorter.begin(), sorter.end(),
is_descend ? CompareDescend<float>
: CompareAscend<float>);
for (int32_t k = 0; k < input->shape[axis]; ++k) {
*(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after)
= k < sorter.size() ? sorter[k].first : k;
}
}
}
});
} // namespace contrib
} // namespace tvm
import tvm
import numpy as np
def test_sort():
n = 2
l = 5
m = 3
data = tvm.placeholder((n, l, m), name='data')
sort_num = tvm.placeholder((n, m), name="sort_num", dtype="int32")
axis = 1
is_descend = True
out = tvm.extern(data.shape, [data, sort_num],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.sort.argsort", ins[0],
ins[1], outs[0], axis, is_descend),
dtype='int32', name="sort_tensor")
input = [[[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]]
sort_num_input = [[1, 2, 3], [4, 5, 5]]
sorted_index = [[[0, 1, 1], [1, 0, 0], [2, 2, 2], [3, 3, 3], [4, 4, 4]],
[[3, 4, 4], [2, 3, 3], [1, 2, 2], [0, 1, 1], [4, 0, 0]]]
ctx = tvm.cpu(0)
target = "llvm"
s = tvm.create_schedule(out.op)
f = tvm.build(s, [data, sort_num, out], target)
a = tvm.nd.array(np.array(input).astype(data.dtype), ctx)
b = tvm.nd.array(np.array(sort_num_input).astype(sort_num.dtype), ctx)
c = tvm.nd.array(np.zeros(a.shape, dtype=out.dtype), ctx)
f(a, b, c)
np.testing.assert_allclose(c.asnumpy(), np.array(sorted_index).astype(out.dtype), rtol=1e-5)
def test_sort_np():
dshape = (1, 2, 3, 4, 5, 6)
axis = 4
reduced_shape = (1, 2, 3, 4, 6)
is_descend = False
data = tvm.placeholder(dshape, name='data')
sort_num = tvm.placeholder(reduced_shape, name="sort_num", dtype="int32")
out = tvm.extern(data.shape, [data, sort_num],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.sort.argsort", ins[0],
ins[1], outs[0], axis, is_descend),
dtype='int32', name="sort_tensor")
ctx = tvm.cpu(0)
target = "llvm"
s = tvm.create_schedule(out.op)
f = tvm.build(s, [data, sort_num, out], target)
np_data = np.random.uniform(size=dshape)
np_out = np.argsort(np_data, axis=axis)
sort_num_input = np.full(reduced_shape, dshape[axis])
a = tvm.nd.array(np.array(np_data).astype(data.dtype), ctx)
b = tvm.nd.array(np.array(sort_num_input).astype(sort_num.dtype), ctx)
c = tvm.nd.array(np.zeros(a.shape, dtype=out.dtype), ctx)
f(a, b, c)
np.testing.assert_allclose(c.asnumpy(), np_out, rtol=1e-5)
if __name__ == "__main__":
test_sort()
test_sort_np()
# pylint: disable=invalid-name, unused-variable
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Schedule for vision operators"""
from __future__ import absolute_import as _abs
import tvm
......@@ -40,3 +40,37 @@ def schedule_region(outs):
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.cuda.schedule_region(cpp_target, outs)
@generic.schedule_multibox_prior.register(["cuda", "gpu"])
def schedule_multibox_prior(out):
"""Schedule for multibox_prior operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of multibox_prior
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for multibox_prior.
"""
raise RuntimeError("Currently multibox_prior only supports CPU.")
@generic.schedule_multibox_detection.register(["cuda", "gpu"])
def schedule_multibox_detection(out):
"""Schedule for multibox_detection operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of multibox_detection
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for multibox_detection.
"""
raise RuntimeError("Currently multibox_detection only supports CPU.")
# pylint: disable=invalid-name, no-member
"""Generic vision operators"""
from __future__ import absolute_import as _abs
import tvm
......@@ -70,3 +71,72 @@ def schedule_region(outs):
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)
@tvm.target.generic_func
def schedule_nms(outs):
"""Schedule for non-maximum suppression
Parameters
----------
outs: Array of Tensor
The computation graph description of nms
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_multibox_prior(outs):
"""Schedule for multibox_prior
Parameters
----------
outs: Array of Tensor
The computation graph description of multibox_prior
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_multibox_transform_loc(outs):
"""Schedule for multibox_transform_loc
Parameters
----------
outs: Array of Tensor
The computation graph description of
multibox_transform_loc in the format
of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_multibox_detection(outs):
"""Schedule for multibox_detection
Parameters
----------
outs: Array of Tensor
The computation graph description of multibox_detection
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
# pylint: disable=invalid-name, unused-variable
"""NN operator common utilities"""
from __future__ import absolute_import
import tvm
from ..util import get_const_int
def infer_pad(data, data_pad):
......@@ -53,8 +55,8 @@ def infer_stride(data, kernel, out):
_, _, IH, IW = data.shape
_, _, KH, KW = kernel.shape
_, _, OH, OW = out.shape
hstride = (IH - KH) // (OH - 1)
wstride = (IW - KW) // (OW - 1)
hstride = (IH - KH) // tvm.make.Max(OH - 1, 1) + tvm.select(OH == 1, 1, 0)
wstride = (IW - KW) // tvm.make.Max(OW - 1, 1) + tvm.select(OW == 1, 1, 0)
return get_const_int(hstride), get_const_int(wstride)
......
......@@ -2,6 +2,7 @@
"""VISION network operators"""
from __future__ import absolute_import as _abs
from . import yolo2
from . import yolo2, ssd
from .shortcut import *
from .reorg import *
from .nms import *
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments
"""Non-maximum suppression operator"""
import tvm
from tvm import api
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
"""Low level IR routing for transform location in multibox_detection operator.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
sort_result : Buffer
Buffer of output box indexes sorted by score.
valid_count : Buffer
Buffer of number of valid output boxes.
out : Buffer
Output buffer.
nms_threshold : float
Non-maximum suppression threshold.
force_suppress : boolean
Whether to suppress all detections regardless of class_id.
nms_topk : int
Keep maximum top k detections before nms, -1 for no limit.
Returns
-------
stmt : Stmt
The result IR statement.
"""
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
"""Calculate overlap of two boxes.
"""
w = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
- tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
h = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
- tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
i = w * h
u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
(out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
(out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
(out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
return tvm.select(u <= 0.0, 0.0, i / u)
ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data)
p_sort_result = ib.buffer_ptr(sort_result)
p_valid_count = ib.buffer_ptr(valid_count)
p_out = ib.buffer_ptr(out)
batch_size = out.shape[0]
num_anchors = out.shape[1]
nms_threshold_node = tvm.make.node("FloatImm", dtype="float32", value=nms_threshold)
nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
force_suppress_node = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
with ib.for_range(0, batch_size, for_type="parallel", name="n") as n:
with ib.if_scope(tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
p_valid_count[0] > 0)):
# Reorder output
nkeep = tvm.select(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
nms_topk, p_valid_count[n])
with ib.for_range(0, nkeep, name="l") as l:
with ib.for_range(0, 6, name="m") as m:
p_out[(n * num_anchors * 6
+ l * 6 + m)] = p_data[(n * num_anchors * 6
+ p_sort_result[n * num_anchors + l] * 6 + m)]
with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])):
with ib.for_range(0, p_valid_count[n] - nkeep, name="l") as l:
with ib.for_range(0, 6, name="m") as m:
p_out[(n * num_anchors * 6
+ (l + nkeep) * 6 + m)] = p_data[(n * num_anchors * 6
+ (l + nkeep) * 6 + m)]
# Apply nms
with ib.for_range(0, p_valid_count[n], name="l") as l:
offset_l = l * 6
with ib.if_scope(p_out[n * num_anchors * 6 + offset_l] >= 0):
with ib.for_range(0, p_valid_count[n], name="m") as m:
offset_m = m * 6
with ib.if_scope(tvm.all(m > l, p_out[n * num_anchors * 6
+ offset_m] >= 0)):
with ib.if_scope(tvm.any(force_suppress_node > 0,
p_out[n * num_anchors * 6 + offset_l] ==
p_out[n * num_anchors * 6 + offset_m])):
# When force_suppress == True or class_id equals
iou = calculate_overlap(p_out, n * num_anchors * 6 + offset_l + 2,
n * num_anchors * 6 + offset_m + 2)
with ib.if_scope(iou >= nms_threshold):
p_out[n * num_anchors * 6 + offset_m] = -1.0
with ib.else_scope():
with ib.for_range(0, p_valid_count[n], name="l") as l:
with ib.for_range(0, 6, name="m") as m:
p_out[(n * num_anchors * 6
+ l * 6 + m)] = p_data[n * num_anchors * 6 + l * 6 + m]
# Set invalid entry to be -1
with ib.for_range(0, num_anchors - p_valid_count[n], name="l") as l:
with ib.for_range(0, 6, name="m") as m:
p_out[n * num_anchors * 6 + (l + p_valid_count[n]) * 6 + m] = -1.0
return ib.get()
@tvm.target.generic_func
def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1):
"""Non-maximum suppression operator for object detection.
Parameters
----------
data: tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
nms_threshold : float
Non-maximum suppression threshold.
force_suppress : boolean
Whether to suppress all detections regardless of class_id.
nms_topk : int
Keep maximum top k detections before nms, -1 for no limit.
Returns
-------
out : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
Example
--------
.. code-block:: python
# An example to use nms
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7
force_suppress = True
nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f(tvm_data, tvm_valid_count, tvm_out)
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
"valid_count_buf", data_alignment=4)
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
score_axis = 1
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
"score_tensor_buf", data_alignment=8)
sort_tensor_dtype = "int32"
sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
"sort_tensor_buf", data_alignment=8)
sort_tensor = \
tvm.extern(score_shape,
[score_tensor, valid_count],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.sort.argsort", ins[0], ins[1],
outs[0], score_axis, True),
dtype=sort_tensor_dtype,
in_buffers=[score_tensor_buf, valid_count_buf],
out_buffers=sort_tensor_buf,
name="nms_sort")
out = \
tvm.extern(data.shape,
[data, sort_tensor, valid_count],
lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], nms_threshold,
force_suppress, nms_topk),
dtype="float32",
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
tag="nms")
return out
# pylint: disable=wildcard-import
"""VISION network operators"""
from __future__ import absolute_import as _abs
from .multibox import *
"""Test code for vision package"""
import numpy as np
import tvm
import topi
import math
from topi.vision import ssd, nms
def test_nms():
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7
force_suppress = True
nms_topk = 2
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype)
np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79],
[-1, -1, -1, -1, -1, -1]]])
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_nms(out)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f = tvm.build(s, [data, valid_count, out], device)
f(tvm_data, tvm_valid_count, tvm_out)
np.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
for device in ['llvm']:
check_device(device)
def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):
data = tvm.placeholder(dshape, name="data")
out = ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip)
dtype = data.dtype
input_data = np.random.uniform(size=dshape).astype(dtype)
in_height = data.shape[2].value
in_width = data.shape[3].value
num_sizes = len(sizes)
num_ratios = len(ratios)
size_ratio_concat = sizes + ratios
steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
offset_h = offsets[0]
offset_w = offsets[1]
oshape = (1, in_height * in_width * (num_sizes + num_ratios - 1), 4)
np_out = np.zeros(oshape).astype(dtype)
for i in range(in_height):
center_h = (i + offset_h) * steps_h
for j in range(in_width):
center_w = (j + offset_w) * steps_w
for k in range(num_sizes + num_ratios - 1):
w = size_ratio_concat[k] * in_height / in_width / 2.0 if k < num_sizes else \
size_ratio_concat[0] * in_height / in_width * math.sqrt(size_ratio_concat[k + 1]) / 2.0
h = size_ratio_concat[k] / 2.0 if k < num_sizes else \
size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0
count = i * in_width * (num_sizes + num_ratios - 1) + j * (num_sizes + num_ratios - 1) + k
np_out[0][count][0] = center_w - w
np_out[0][count][1] = center_h - h
np_out[0][count][2] = center_w + w
np_out[0][count][3] = center_h + h
if clip:
np_out = np.clip(np_out, 0, 1)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_multibox_prior(out)
tvm_input_data = tvm.nd.array(input_data, ctx)
tvm_out = tvm.nd.array(np.zeros(oshape, dtype=dtype), ctx)
f = tvm.build(s, [data, out], device)
f(tvm_input_data, tvm_out)
np.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-4)
for device in ['llvm']:
check_device(device)
def test_multibox_prior():
verify_multibox_prior((1, 3, 50, 50))
verify_multibox_prior((1, 3, 224, 224), sizes=(0.5, 0.25, 0.1), ratios=(1, 2, 0.5))
verify_multibox_prior((1, 32, 32, 32), sizes=(0.5, 0.25), ratios=(1, 2), steps=(2, 2), clip=True)
def test_multibox_detection():
batch_size = 1
num_anchors = 3
num_classes = 3
cls_prob = tvm.placeholder((batch_size, num_anchors, num_classes), name="cls_prob")
loc_preds = tvm.placeholder((batch_size, num_anchors * 4), name="loc_preds")
anchors = tvm.placeholder((1, num_anchors, 4), name="anchors")
out = ssd.multibox_detection(cls_prob, loc_preds, anchors)
# Manually create test case
np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]])
np_loc_preds = np.array([[0.1, -0.2, 0.3, 0.2, 0.2, 0.4, 0.5, -0.3, 0.7, -0.2, -0.4, -0.8]])
np_anchors = np.array([[[-0.1, -0.1, 0.1, 0.1], [-0.2, -0.2, 0.2, 0.2], [1.2, 1.2, 1.5, 1.5]]])
expected_np_out = np.array([[[1, 0.69999999, 0, 0, 0.10818365, 0.10008108],
[0, 0.44999999, 1, 1, 1, 1],
[0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]])
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_multibox_detection(out)
tvm_cls_prob = tvm.nd.array(np_cls_prob.astype(cls_prob.dtype), ctx)
tvm_loc_preds = tvm.nd.array(np_loc_preds.astype(loc_preds.dtype), ctx)
tvm_anchors = tvm.nd.array(np_anchors.astype(anchors.dtype), ctx)
tvm_out = tvm.nd.array(np.zeros((batch_size, num_anchors, 6)).astype(out.dtype), ctx)
f = tvm.build(s, [cls_prob, loc_preds, anchors, out], device)
f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out)
np.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, rtol=1e-4)
for device in ['llvm']:
check_device(device)
if __name__ == "__main__":
test_nms()
test_multibox_prior()
test_multibox_detection()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment