Commit d2f29ba5 by Yao Wang Committed by Tianqi Chen

[Object Detection] Gluoncv SSD support on CPU (#2353)

parent f7eff095
...@@ -58,19 +58,42 @@ struct MultiBoxTransformLocAttrs ...@@ -58,19 +58,42 @@ struct MultiBoxTransformLocAttrs
} }
}; };
/*! \brief Attributes used in non_maximum_suppression operators */ /*! \brief Attributes used in get_valid_counts operator */
struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{ struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
double overlap_threshold; double score_threshold;
TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") {
TVM_ATTR_FIELD(score_threshold).set_default(0.0)
.describe("Lower limit of score for valid bounding boxes.");
}
};
/*! \brief Attributes used in non_maximum_suppression operator */
struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
int max_output_size;
double iou_threshold;
bool force_suppress; bool force_suppress;
int topk; int top_k;
int id_index;
bool return_indices;
bool invalid_to_bottom;
TVM_DECLARE_ATTRS(NMSAttrs, "relay.attrs.NMSAttrs") { TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
TVM_ATTR_FIELD(overlap_threshold).set_default(0.5) TVM_ATTR_FIELD(max_output_size).set_default(-1)
.describe("Max number of output valid boxes for each instance."
"By default all valid boxes are returned.");
TVM_ATTR_FIELD(iou_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold."); .describe("Non-maximum suppression threshold.");
TVM_ATTR_FIELD(force_suppress).set_default(false) TVM_ATTR_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id."); .describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(topk).set_default(-1) TVM_ATTR_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit."); .describe("Keep maximum top k detections before nms, -1 for no limit.");
TVM_ATTR_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
TVM_ATTR_FIELD(return_indices).set_default(true)
.describe("Whether to return box indices in input data.");
TVM_ATTR_FIELD(invalid_to_bottom).set_default(false)
.describe("Whether to move all invalid bounding boxes to the bottom.");
} }
}; };
......
...@@ -443,17 +443,30 @@ struct MultiBoxTransformLocParam : public dmlc::Parameter<MultiBoxTransformLocPa ...@@ -443,17 +443,30 @@ struct MultiBoxTransformLocParam : public dmlc::Parameter<MultiBoxTransformLocPa
} }
}; };
struct NMSParam : public dmlc::Parameter<NMSParam> { struct NonMaximumSuppressionParam : public dmlc::Parameter<NonMaximumSuppressionParam> {
float nms_threshold; bool return_indices;
float iou_threshold;
bool force_suppress; bool force_suppress;
int nms_topk; int top_k;
DMLC_DECLARE_PARAMETER(NMSParam) { int id_index;
DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5) int max_output_size;
bool invalid_to_bottom;
DMLC_DECLARE_PARAMETER(NonMaximumSuppressionParam) {
DMLC_DECLARE_FIELD(max_output_size).set_default(-1)
.describe("Max number of output valid boxes for each instance."
"By default all valid boxes are returned.");
DMLC_DECLARE_FIELD(iou_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold."); .describe("Non-maximum suppression threshold.");
DMLC_DECLARE_FIELD(force_suppress).set_default(false) DMLC_DECLARE_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id."); .describe("Suppress all detections regardless of class_id.");
DMLC_DECLARE_FIELD(nms_topk).set_default(-1) DMLC_DECLARE_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit."); .describe("Keep maximum top k detections before nms, -1 for no limit.");
DMLC_DECLARE_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
DMLC_DECLARE_FIELD(return_indices).set_default(true)
.describe("Whether to return box indices in input data.");
DMLC_DECLARE_FIELD(invalid_to_bottom).set_default(false)
.describe("Whether to move all invalid bounding boxes to the bottom.");
} }
}; };
......
...@@ -245,11 +245,11 @@ def _contrib_multibox_detection(inputs, attrs): ...@@ -245,11 +245,11 @@ def _contrib_multibox_detection(inputs, attrs):
if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2) if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2)
nms_topk = attrs.get('nms_topk') or -1 nms_topk = attrs.get('nms_topk') or -1
new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances} new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances}
new_attrs1 = {'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress, new_attrs1 = {'return_indices': False, 'iou_threshold': float(nms_threshold),
'nms_topk': int(nms_topk)} 'force_suppress': force_suppress, 'top_k': int(nms_topk)}
data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1], data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1],
inputs[2], **new_attrs0) inputs[2], **new_attrs0)
return _get_nnvm_op('nms')(data, valid_count, **new_attrs1) return _get_nnvm_op('non_max_suppression')(data, valid_count, **new_attrs1)
def _elemwise_sum(inputs, _): def _elemwise_sum(inputs, _):
new_attrs = {'num_args':len(inputs)} new_attrs = {'num_args':len(inputs)}
......
...@@ -61,20 +61,25 @@ def compute_multibox_transform_loc(attrs, inputs, _): ...@@ -61,20 +61,25 @@ def compute_multibox_transform_loc(attrs, inputs, _):
reg.register_pattern("multibox_detection", OpPattern.OPAQUE) reg.register_pattern("multibox_detection", OpPattern.OPAQUE)
# non-maximum suppression # non-maximum suppression
@reg.register_schedule("nms") @reg.register_schedule("non_max_suppression")
def schedule_nms(_, outs, target): def schedule_nms(_, outs, target):
"""Schedule definition of nms""" """Schedule definition of non_max_suppression"""
with tvm.target.create(target): with tvm.target.create(target):
return topi.generic.schedule_nms(outs) return topi.generic.schedule_nms(outs)
@reg.register_compute("nms") @reg.register_compute("non_max_suppression")
def compute_nms(attrs, inputs, _): def compute_nms(attrs, inputs, _):
"""Compute definition of nms""" """Compute definition of non_max_suppression"""
nms_threshold = attrs.get_float('nms_threshold') return_indices = attrs.get_bool('return_indices')
max_output_size = attrs.get_int('max_output_size')
iou_threshold = attrs.get_float('iou_threshold')
force_suppress = attrs.get_bool('force_suppress') force_suppress = attrs.get_bool('force_suppress')
nms_topk = attrs.get_int('nms_topk') top_k = attrs.get_int('top_k')
id_index = attrs.get_int('id_index')
invalid_to_bottom = attrs.get_bool('invalid_to_bottom')
return topi.vision.nms(inputs[0], inputs[1], nms_threshold, return topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
force_suppress, nms_topk) iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
reg.register_pattern("nms", OpPattern.OPAQUE) reg.register_pattern("non_max_suppression", OpPattern.OPAQUE)
...@@ -19,11 +19,13 @@ using compiler::FTVMCompute; ...@@ -19,11 +19,13 @@ using compiler::FTVMCompute;
using tvm::Tensor; using tvm::Tensor;
using tvm::Array; using tvm::Array;
DMLC_REGISTER_PARAMETER(NMSParam); DMLC_REGISTER_PARAMETER(NonMaximumSuppressionParam);
bool NMSShape(const NodeAttrs& attrs, bool NMSShape(const NodeAttrs& attrs,
std::vector<TShape> *in_attrs, std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) { std::vector<TShape> *out_attrs) {
const NonMaximumSuppressionParam& param =
nnvm::get<NonMaximumSuppressionParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]"; CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]";
TShape dshape = in_attrs->at(0); TShape dshape = in_attrs->at(0);
TShape vshape = in_attrs->at(1); TShape vshape = in_attrs->at(1);
...@@ -33,7 +35,14 @@ bool NMSShape(const NodeAttrs& attrs, ...@@ -33,7 +35,14 @@ bool NMSShape(const NodeAttrs& attrs,
"(batch_size, num_anchors, 6)."; "(batch_size, num_anchors, 6).";
CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch."; CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch.";
out_attrs->clear(); out_attrs->clear();
if (param.return_indices) {
TShape oshape = TShape(2);
oshape[0] = dshape[0];
oshape[1] = dshape[1];
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
} else {
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape); NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape);
}
return true; return true;
} }
...@@ -56,15 +65,15 @@ inline bool NMSInferLayout(const NodeAttrs& attrs, ...@@ -56,15 +65,15 @@ inline bool NMSInferLayout(const NodeAttrs& attrs,
return true; return true;
} }
NNVM_REGISTER_OP(nms) NNVM_REGISTER_OP(non_max_suppression)
.describe(R"doc("Non-maximum suppression." .describe(R"doc("Non-maximum suppression."
)doc" NNVM_ADD_FILELINE) )doc" NNVM_ADD_FILELINE)
.set_num_inputs(2) .set_num_inputs(2)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr_parser(ParamParser<NMSParam>) .set_attr_parser(ParamParser<NonMaximumSuppressionParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", .set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<NMSParam>) ParamGetAttrDict<NonMaximumSuppressionParam>)
.add_arguments(NMSParam::__FIELDS__()) .add_arguments(NonMaximumSuppressionParam::__FIELDS__())
.add_argument("data", "Tensor", "Input data.") .add_argument("data", "Tensor", "Input data.")
.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { .set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
......
...@@ -550,7 +550,7 @@ def test_multibox_transform_loc(): ...@@ -550,7 +550,7 @@ def test_multibox_transform_loc():
anchors = sym.Variable("anchors") anchors = sym.Variable("anchors")
transform_loc_data, valid_count = sym.multibox_transform_loc(cls_prob=cls_prob, loc_pred=loc_preds, transform_loc_data, valid_count = sym.multibox_transform_loc(cls_prob=cls_prob, loc_pred=loc_preds,
anchor=anchors) anchor=anchors)
out = sym.nms(data=transform_loc_data, valid_count=valid_count) out = sym.non_max_suppression(data=transform_loc_data, valid_count=valid_count, return_indices=False)
# Manually create test case # Manually create test case
np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]]) np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]])
...@@ -573,22 +573,22 @@ def test_multibox_transform_loc(): ...@@ -573,22 +573,22 @@ def test_multibox_transform_loc():
out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype)) out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5) tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
def test_nms(): def test_non_max_suppression():
dshape = (1, 5, 6) dshape = (1, 5, 6)
data = sym.Variable("data") data = sym.Variable("data")
valid_count = sym.Variable("valid_count", dtype="int32") valid_count = sym.Variable("valid_count", dtype="int32")
nms_threshold = 0.7 iou_threshold = 0.7
force_suppress = True force_suppress = True
nms_topk = 2 top_k = 2
out = sym.nms(data=data, valid_count=valid_count, nms_threshold=nms_threshold, out = sym.non_max_suppression(data=data, valid_count=valid_count, return_indices=False,
force_suppress=force_suppress, nms_topk=nms_topk) iou_threshold=iou_threshold, force_suppress=force_suppress, top_k=top_k)
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype("float32") [1, 0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32") np_valid_count = np.array([4]).astype("int32")
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]]) [-1, -1, -1, -1, -1, -1]]])
target = "llvm" target = "llvm"
...@@ -726,7 +726,7 @@ if __name__ == "__main__": ...@@ -726,7 +726,7 @@ if __name__ == "__main__":
test_flip() test_flip()
test_multibox_prior() test_multibox_prior()
test_multibox_transform_loc() test_multibox_transform_loc()
test_nms() test_non_max_suppression()
test_slice_like() test_slice_like()
test_where() test_where()
test_argmax() test_argmax()
......
...@@ -315,4 +315,3 @@ if __name__ == '__main__': ...@@ -315,4 +315,3 @@ if __name__ == '__main__':
test_forward_slice() test_forward_slice()
test_forward_maximum() test_forward_maximum()
test_forward_minimum() test_forward_minimum()
...@@ -328,13 +328,14 @@ def _mx_multibox_detection(inputs, attrs): ...@@ -328,13 +328,14 @@ def _mx_multibox_detection(inputs, attrs):
0.2, 0.2)) 0.2, 0.2))
new_attrs1 = {} new_attrs1 = {}
new_attrs1["overlap_threshold"] = attrs.get_float("nms_threshold", 0.5) new_attrs1["return_indices"] = False
new_attrs1["iou_threshold"] = attrs.get_float("nms_threshold", 0.5)
new_attrs1["force_suppress"] = attrs.get_bool("force_suppress", False) new_attrs1["force_suppress"] = attrs.get_bool("force_suppress", False)
new_attrs1["topk"] = attrs.get_int("nms_topk", -1) new_attrs1["top_k"] = attrs.get_int("nms_topk", -1)
ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1], ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1],
inputs[2], **new_attrs0) inputs[2], **new_attrs0)
return _op.vision.nms(ret[0], ret[1], **new_attrs1) return _op.vision.non_max_suppression(ret[0], ret[1], **new_attrs1)
def _mx_batch_dot(inputs, attrs): def _mx_batch_dot(inputs, attrs):
...@@ -399,6 +400,49 @@ def _mx_proposal(inputs, attrs): ...@@ -399,6 +400,49 @@ def _mx_proposal(inputs, attrs):
return _op.vision.proposal(inputs[0], inputs[1], inputs[2], **new_attrs) return _op.vision.proposal(inputs[0], inputs[1], inputs[2], **new_attrs)
def _mx_box_nms(inputs, attrs):
    """Convert MXNet ``_contrib_box_nms`` into Relay vision ops.

    The conversion is ``vision.get_valid_counts`` (filtering by
    ``valid_thresh``) followed by ``vision.non_max_suppression``.

    Parameters
    ----------
    inputs : list
        Converted input expressions; only ``inputs[0]`` is used.
    attrs : attribute dict
        Operator attributes, read via ``get_bool`` / ``get_float`` /
        ``get_int`` / ``get_str``.

    Returns
    -------
    The ``non_max_suppression`` expression.

    Raises
    ------
    RuntimeError
        If ``coord_start``, ``score_index``, ``id_index``, ``in_format`` or
        ``out_format`` holds a value this converter does not handle.
    """
    force_suppress = attrs.get_bool("force_suppress", False)
    iou_thresh = attrs.get_float('overlap_thresh', 0.5)
    top_k = attrs.get_int('topk', -1)
    valid_thresh = attrs.get_float('valid_thresh', 0)
    coord_start = attrs.get_int('coord_start', 2)
    score_index = attrs.get_int('score_index', 1)
    id_index = attrs.get_int('id_index', -1)
    in_format = attrs.get_str('in_format', 'corner')
    out_format = attrs.get_str('out_format', 'corner')

    # (message template, offending value, is_unsupported), checked in order
    # so the first unsupported attribute is the one reported.
    support_checks = (
        ('coord_start %s is not supported.', coord_start, coord_start != 2),
        ('score_index %s is not supported.', score_index, score_index != 1),
        ('id_index %s is not supported.', id_index,
         id_index != -1 and int(id_index) != 0),
        ('in_format %s is not supported.', in_format, in_format != 'corner'),
        ('out_format %s is not supported.', out_format, out_format != 'corner'),
    )
    for template, value, unsupported in support_checks:
        if unsupported:
            raise RuntimeError(template % value)

    # get_valid_counts yields (valid_count, rearranged data); NMS takes the
    # data tensor first and the counts second.
    ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh)
    rearranged_data = ret[1]
    valid_count = ret[0]
    nms_attrs = {
        'iou_threshold': iou_thresh,
        'force_suppress': force_suppress,
        'top_k': top_k,
        'id_index': id_index,
        'return_indices': False,
        'invalid_to_bottom': True,
    }
    return _op.vision.non_max_suppression(rearranged_data, valid_count, **nms_attrs)
def _mx_l2_normalize(inputs, attrs):
    """Convert MXNet ``L2Normalization`` into Relay ``nn.l2_normalize``.

    Only ``mode='channel'`` is handled; it is mapped to normalization over
    axis 1. Any other mode (including the default ``'instance'``) raises
    ``RuntimeError``.
    """
    mode = attrs.get_str('mode', 'instance')
    if mode != 'channel':
        raise RuntimeError('mode %s is not supported.' % mode)
    eps = attrs.get_float('eps', 1e-10)
    return _op.nn.l2_normalize(inputs[0], eps=eps, axis=[1])
# Note: due to attribute conversion constraint # Note: due to attribute conversion constraint
# ops in the identity set must be attribute free # ops in the identity set must be attribute free
_identity_list = [ _identity_list = [
...@@ -497,6 +541,7 @@ _convert_map = { ...@@ -497,6 +541,7 @@ _convert_map = {
"BatchNorm" : _mx_batch_norm, "BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm, "BatchNorm_v1" : _mx_batch_norm,
"LRN" : _mx_lrn, "LRN" : _mx_lrn,
"L2Normalization" : _mx_l2_normalize,
"slice" : _mx_slice, "slice" : _mx_slice,
"slice_like" : _mx_slice_like, "slice_like" : _mx_slice_like,
"slice_axis" : _mx_slice_axis, "slice_axis" : _mx_slice_axis,
...@@ -520,6 +565,7 @@ _convert_map = { ...@@ -520,6 +565,7 @@ _convert_map = {
"_contrib_ROIAlign" : _mx_roi_align, "_contrib_ROIAlign" : _mx_roi_align,
"_contrib_Proposal" : _mx_proposal, "_contrib_Proposal" : _mx_proposal,
"_contrib_MultiProposal" : _mx_proposal, "_contrib_MultiProposal" : _mx_proposal,
"_contrib_box_nms" : _mx_box_nms,
# List of missing operators that are present in NNVMv1 # List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators. # TODO(tvm-tvm): support all operators.
# #
...@@ -662,6 +708,8 @@ def from_mxnet(symbol, ...@@ -662,6 +708,8 @@ def from_mxnet(symbol,
params[k] = _nd.array(v.data().asnumpy()) params[k] = _nd.array(v.data().asnumpy())
data = mx.sym.Variable("data") data = mx.sym.Variable("data")
sym = symbol(data) sym = symbol(data)
if isinstance(sym, (list, tuple)):
sym = mx.sym.Group(sym)
shape, dtype = _update_shape_dtype(shape, dtype, params) shape, dtype = _update_shape_dtype(shape, dtype, params)
sym = _from_mxnet_impl(sym, shape, dtype) sym = _from_mxnet_impl(sym, shape, dtype)
elif isinstance(symbol, mx.gluon.Block): elif isinstance(symbol, mx.gluon.Block):
......
...@@ -525,7 +525,7 @@ def strided_slice(data, begin, end, strides=None): ...@@ -525,7 +525,7 @@ def strided_slice(data, begin, end, strides=None):
The indices to begin with in the slicing. The indices to begin with in the slicing.
end: list of int end: list of int
Indicies indicating end of the slice. Indices indicating end of the slice.
strides: list of int, optional strides: list of int, optional
Specifies the stride values, it can be negative in that case, Specifies the stride values, it can be negative in that case,
......
...@@ -6,6 +6,6 @@ from .multibox import * ...@@ -6,6 +6,6 @@ from .multibox import *
from .nms import * from .nms import *
from .rcnn import * from .rcnn import *
from .yolo import * from .yolo import *
from . import _multibox
from . import _rcnn from . import _rcnn
from . import _yolo from . import _yolo
from . import _vision
...@@ -54,24 +54,46 @@ reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE) ...@@ -54,24 +54,46 @@ reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE)
reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE) reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE)
# Get counts of valid boxes
@reg.register_schedule("vision.get_valid_counts")
def schedule_get_valid_counts(_, outs, target):
    """Create the schedule for ``vision.get_valid_counts``.

    The generic TOPI schedule is built while ``target`` is the active
    context, then returned.
    """
    with target:
        schedule = topi.generic.schedule_get_valid_counts(outs)
    return schedule
@reg.register_compute("vision.get_valid_counts")
def compute_get_valid_counts(attrs, inputs, _, target):
    """Compute definition of ``vision.get_valid_counts``.

    Reads ``score_threshold`` from ``attrs`` as a constant float and forwards
    it, together with the single data input, to the TOPI implementation.
    """
    threshold = get_const_float(attrs.score_threshold)
    data = inputs[0]
    return topi.vision.get_valid_counts(data, threshold)
reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)
# non-maximum suppression # non-maximum suppression
@reg.register_schedule("vision.nms") @reg.register_schedule("vision.non_max_suppression")
def schedule_nms(_, outs, target): def schedule_nms(_, outs, target):
"""Schedule definition of nms""" """Schedule definition of nms"""
with target: with target:
return topi.generic.schedule_nms(outs) return topi.generic.schedule_nms(outs)
@reg.register_compute("vision.nms") @reg.register_compute("vision.non_max_suppression")
def compute_nms(attrs, inputs, _, target): def compute_nms(attrs, inputs, _, target):
"""Compute definition of nms""" """Compute definition of nms"""
overlap_threshold = get_const_float(attrs.overlap_threshold) return_indices = bool(get_const_int(attrs.return_indices))
max_output_size = get_const_int(attrs.max_output_size)
iou_threshold = get_const_float(attrs.iou_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress)) force_suppress = bool(get_const_int(attrs.force_suppress))
topk = get_const_int(attrs.topk) top_k = get_const_int(attrs.top_k)
id_index = get_const_int(attrs.id_index)
invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
return [ return [
topi.vision.nms(inputs[0], inputs[1], overlap_threshold, topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
force_suppress, topk) iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
] ]
reg.register_pattern("vision.nms", OpPattern.OPAQUE) reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE)
"""Non-maximum suppression operations.""" """Non-maximum suppression operations."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import _make from . import _make
from ...expr import TupleWrapper
def nms(data, def get_valid_counts(data,
score_threshold):
"""Get valid count of bounding boxes given a score threshold.
Also moves valid boxes to the top of input data.
Parameters
----------
data : relay.Expr
Input data. 3-D tensor with shape [batch_size, num_anchors, 6].
score_threshold : optional, float
Lower limit of score for valid bounding boxes.
Returns
-------
valid_count : relay.Expr
1-D tensor for valid number of boxes.
out_tensor : relay.Expr
Rearranged data tensor.
"""
return TupleWrapper(_make.get_valid_counts(data, score_threshold), 2)
def non_max_suppression(data,
valid_count, valid_count,
overlap_threshold=0.5, max_output_size=-1,
iou_threshold=0.5,
force_suppress=False, force_suppress=False,
topk=-1): top_k=-1,
id_index=0,
return_indices=True,
invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection. """Non-maximum suppression operator for object detection.
Parameters Parameters
...@@ -19,18 +48,33 @@ def nms(data, ...@@ -19,18 +48,33 @@ def nms(data,
valid_count : relay.Expr valid_count : relay.Expr
1-D tensor for valid number of boxes. 1-D tensor for valid number of boxes.
overlap_threshold : float, optional max_output_size : int, optional
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
iou_threshold : float, optional
Non-maximum suppression threshold. Non-maximum suppression threshold.
force_suppress : bool, optional force_suppress : bool, optional
Suppress all detections regardless of class_id. Suppress all detections regardless of class_id.
topk : int, optional top_k : int, optional
Keep maximum top k detections before nms, -1 for no limit. Keep maximum top k detections before nms, -1 for no limit.
id_index : int, optional
index of the class categories, -1 to disable.
return_indices : bool, optional
Whether to return box indices in input data.
invalid_to_bottom : bool, optional
Whether to move all valid bounding boxes to the top.
Returns Returns
------- -------
out : relay.Expr out : relay.Expr
3-D tensor with shape [batch_size, num_anchors, 6]. 3-D tensor with shape [batch_size, num_anchors, 6].
""" """
return _make.nms(data, valid_count, overlap_threshold, force_suppress, topk) return _make.non_max_suppression(data, valid_count, max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
...@@ -1516,6 +1516,16 @@ RELAY_REGISTER_OP("broadcast_to_like") ...@@ -1516,6 +1516,16 @@ RELAY_REGISTER_OP("broadcast_to_like")
.set_attr<TOpPattern>("TOpPattern", kBroadcast); .set_attr<TOpPattern>("TOpPattern", kBroadcast);
// Adapter: view an IndexExpr array as an Integer array.
// Every element must be either undefined or an IntImm; the result is built
// from the same underlying node_ as the input.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
  const size_t count = arr.size();
  for (size_t idx = 0; idx != count; ++idx) {
    const auto item = arr[idx];
    CHECK(!item.defined() || item.as<IntImm>())
      << "Expect an int array";
  }
  return Array<Integer>(arr.node_);
}
// strided_slice // strided_slice
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
bool StridedSliceRel(const Array<Type>& types, bool StridedSliceRel(const Array<Type>& types,
...@@ -1870,15 +1880,6 @@ Expr MakeSliceLike(Expr data, ...@@ -1870,15 +1880,6 @@ Expr MakeSliceLike(Expr data,
return CallNode::make(op, {data, shape_like}, Attrs(attrs), {}); return CallNode::make(op, {data, shape_like}, Attrs(attrs), {});
} }
// Adapter: view an IndexExpr array as an Integer array.
// Every element must be either undefined or an IntImm; the result is built
// from the same underlying node_ as the input.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
  const size_t count = arr.size();
  for (size_t idx = 0; idx != count; ++idx) {
    const auto item = arr[idx];
    CHECK(!item.defined() || item.as<IntImm>())
      << "Expect an int array";
  }
  return Array<Integer>(arr.node_);
}
Array<Tensor> SliceLikeCompute(const Attrs& attrs, Array<Tensor> SliceLikeCompute(const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
const Type& out_type, const Type& out_type,
......
...@@ -70,8 +70,10 @@ RELAY_REGISTER_OP("vision.multibox_prior") ...@@ -70,8 +70,10 @@ RELAY_REGISTER_OP("vision.multibox_prior")
TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs); TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs);
bool MultiBoxTransformLocRel(const Array<Type>& types, int num_inputs, bool MultiBoxTransformLocRel(const Array<Type>& types,
const Attrs& attrs, const TypeReporter& reporter) { int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4); CHECK_EQ(types.size(), 4);
const auto* cls_prob = types[0].as<TensorTypeNode>(); const auto* cls_prob = types[0].as<TensorTypeNode>();
......
...@@ -9,7 +9,54 @@ ...@@ -9,7 +9,54 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
TVM_REGISTER_NODE_TYPE(NMSAttrs); TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs);
/*!
 * \brief Type relation for vision.get_valid_counts.
 *
 * types[0] is the input data type and types[1] is the output type. The
 * output is a tuple of:
 *   - an int32 tensor of shape (data.shape[0],), and
 *   - a tensor with the same shape and dtype as the input data.
 *
 * \return true once the output type has been assigned; false when the input
 *         type is not yet a resolved tensor type, so the relation can be
 *         retried later.
 */
bool GetValidCountRel(const Array<Type>& types,
                      int num_inputs,
                      const Attrs& attrs,
                      const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  // as<>() yields nullptr when types[0] is not (yet) a TensorType; bail out
  // instead of dereferencing a null pointer below.
  if (data == nullptr) {
    return false;
  }
  const auto& dshape = data->shape;
  CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";

  // One count per entry along the first axis of the input.
  std::vector<IndexExpr> oshape({dshape[0]});
  std::vector<Type> fields;
  fields.push_back(TensorTypeNode::make(oshape, Int(32)));
  fields.push_back(TensorTypeNode::make(dshape, data->dtype));

  // assign output type
  reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
  return true;
}
// Builds a call to "vision.get_valid_counts" on `data`, carrying
// `score_threshold` in a freshly created GetValidCountsAttrs node.
Expr MakeGetValidCounts(Expr data,
                        double score_threshold) {
  static const Op& op = Op::Get("vision.get_valid_counts");
  auto node = make_node<GetValidCountsAttrs>();
  node->score_threshold = score_threshold;
  return CallNode::make(op, {data}, Attrs(node), {});
}
// Registers MakeGetValidCounts under the global name
// "relay.op.vision._make.get_valid_counts"; the two packed arguments
// (data, score_threshold) are forwarded to it through unpack_call.
TVM_REGISTER_API("relay.op.vision._make.get_valid_counts")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeGetValidCounts, args, rv);
});
// Operator registration for vision.get_valid_counts: a single "data" input,
// support level 5, typed by GetValidCountRel.
RELAY_REGISTER_OP("vision.get_valid_counts")
.describe(R"doc(Get valid count of bounding boxes given
a score threshold. Also moves valid boxes to the top of
input data.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(5)
.add_type_rel("GetValidCount", GetValidCountRel);
TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs);
bool NMSRel(const Array<Type>& types, bool NMSRel(const Array<Type>& types,
int num_inputs, int num_inputs,
...@@ -18,39 +65,56 @@ bool NMSRel(const Array<Type>& types, ...@@ -18,39 +65,56 @@ bool NMSRel(const Array<Type>& types,
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>(); const auto* data = types[0].as<TensorTypeNode>();
const auto* valid_count = types[1].as<TensorTypeNode>(); const auto* valid_count = types[1].as<TensorTypeNode>();
const NonMaximumSuppressionAttrs* param =
attrs.as<NonMaximumSuppressionAttrs>();
const auto& dshape = data->shape; const auto& dshape = data->shape;
const auto& vshape = valid_count->shape; const auto& vshape = valid_count->shape;
CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D."; CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
CHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D."; CHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D.";
// assign output type // assign output type
if (param->return_indices) {
std::vector<IndexExpr> oshape({dshape[0], dshape[1]});
reporter->Assign(types[2], TensorTypeNode::make(oshape, Int(32)));
} else {
reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype)); reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
}
return true; return true;
} }
Expr MakeNMS(Expr data, Expr MakeNMS(Expr data,
Expr valid_count, Expr valid_count,
double overlap_threshold, int max_output_size,
double iou_threshold,
bool force_suppress, bool force_suppress,
int topk) { int top_k,
auto attrs = make_node<NMSAttrs>(); int id_index,
attrs->overlap_threshold = overlap_threshold; bool return_indices,
bool invalid_to_bottom) {
auto attrs = make_node<NonMaximumSuppressionAttrs>();
attrs->max_output_size = max_output_size;
attrs->iou_threshold = iou_threshold;
attrs->force_suppress = force_suppress; attrs->force_suppress = force_suppress;
attrs->topk = topk; attrs->top_k = top_k;
static const Op& op = Op::Get("vision.nms"); attrs->id_index = id_index;
attrs->return_indices = return_indices;
attrs->invalid_to_bottom = invalid_to_bottom;
static const Op& op = Op::Get("vision.non_max_suppression");
return CallNode::make(op, {data, valid_count}, Attrs(attrs), {}); return CallNode::make(op, {data, valid_count}, Attrs(attrs), {});
} }
TVM_REGISTER_API("relay.op.vision._make.nms") TVM_REGISTER_API("relay.op.vision._make.non_max_suppression")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeNMS, args, rv); runtime::detail::unpack_call<Expr, 9>(MakeNMS, args, rv);
}); });
RELAY_REGISTER_OP("vision.nms") RELAY_REGISTER_OP("vision.non_max_suppression")
.describe(R"doc("Non-maximum suppression." .describe(R"doc(Non-maximum suppression. The input boxes should
be in the format of [class_id, score, left, top, right, bottom].
Set id_index to be -1 to ignore class_id axis.
)doc" TVM_ADD_FILELINE) )doc" TVM_ADD_FILELINE)
.set_num_inputs(2) .set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.") .add_argument("data", "Tensor", "Input data.")
......
...@@ -374,6 +374,11 @@ def test_forward_slice_like(): ...@@ -374,6 +374,11 @@ def test_forward_slice_like():
verify((3, 4), (2, 3), (0)) verify((3, 4), (2, 3), (0))
verify((3, 4), (2, 3), (-1)) verify((3, 4), (2, 3), (-1))
def test_forward_l2_normalize():
    """Check the frontend conversion of L2Normalization in channel mode."""
    shape = (2, 3, 4, 5)
    mx_sym = mx.sym.L2Normalization(mx.sym.var('data'), mode="channel")
    verify_mxnet_frontend_impl(mx_sym, shape, shape)
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
...@@ -401,5 +406,6 @@ if __name__ == '__main__': ...@@ -401,5 +406,6 @@ if __name__ == '__main__':
test_forward_broadcast_ops() test_forward_broadcast_ops()
test_forward_elemwise_ops() test_forward_elemwise_ops()
test_forward_scalar_ops() test_forward_scalar_ops()
test_forward_slice_axis()
test_forward_slice_like() test_forward_slice_like()
test_forward_slice_axis()
test_forward_l2_normalize()
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
""" """
import numpy as np import numpy as np
import tvm import tvm
import topi.testing
from tvm import relay from tvm import relay
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list
import topi import topi
......
...@@ -135,56 +135,107 @@ def test_multibox_prior(): ...@@ -135,56 +135,107 @@ def test_multibox_prior():
verify_multibox_prior(x, dshape, ref_res, clip=False, check_type_only=True) verify_multibox_prior(x, dshape, ref_res, clip=False, check_type_only=True)
def test_nms(): def test_get_valid_counts():
def verify_nms(x0_data, x1_data, dshape, ref_res, valid_count, def verify_get_valid_counts(dshape, score_threshold):
overlap_threshold=0.5, force_suppress=False, topk=-1, dtype = "float32"
batch_size, num_anchor, elem_length = dshape
np_data = np.random.uniform(size=dshape).astype(dtype)
np_out1 = np.zeros(shape=(batch_size,))
np_out2 = np.zeros(shape=dshape).astype(dtype)
for i in range(batch_size):
np_out1[i] = 0
inter_idx = 0
for j in range(num_anchor):
score = np_data[i, j, 1]
if score >= score_threshold:
for k in range(elem_length):
np_out2[i, inter_idx, k] = np_data[i, j, k]
np_out1[i] += 1
inter_idx += 1
if j >= np_out1[i]:
for k in range(elem_length):
np_out2[i, j, k] = -1
x = relay.var("x", relay.ty.TensorType(dshape, dtype))
z = relay.vision.get_valid_counts(x, score_threshold)
assert "score_threshold" in z.astext()
func = relay.Function([x], z.astuple())
func = relay.ir_pass.infer_type(func)
ctx_list = [("llvm", tvm.cpu(0))]
for target, ctx in ctx_list:
intrp = relay.create_executor("debug", ctx=ctx, target=target)
out = intrp.evaluate(func)(np_data)
tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3)
verify_get_valid_counts((1, 2500, 6), 0)
verify_get_valid_counts((1, 2500, 6), -1)
verify_get_valid_counts((3, 1000, 6), 0.55)
verify_get_valid_counts((16, 500, 6), 0.95)
def test_non_max_suppression():
def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res,
iou_threshold=0.5, force_suppress=False, top_k=-1,
check_type_only=False): check_type_only=False):
x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32")) x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int")) x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int"))
z = relay.vision.nms(x0, x1, overlap_threshold, force_suppress, topk) z = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k, return_indices=False)
assert "overlap_threshold" in z.astext() z_indices = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k)
assert "iou_threshold" in z.astext()
assert "iou_threshold" in z_indices.astext()
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
zz_indices = relay.ir_pass.infer_type(z_indices)
assert zz.checked_type == relay.ty.TensorType(dshape, "float32") assert zz.checked_type == relay.ty.TensorType(dshape, "float32")
assert zz_indices.checked_type == relay.ty.TensorType((dshape[0], dshape[1]), "int32")
if check_type_only: if check_type_only:
return return
func = relay.Function([x0, x1], z) func = relay.Function([x0, x1], z)
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
func_indices = relay.Function([x0, x1], z_indices)
func_indices = relay.ir_pass.infer_type(func_indices)
ctx_list = [("llvm", tvm.cpu(0))] ctx_list = [("llvm", tvm.cpu(0))]
for target, ctx in ctx_list: for target, ctx in ctx_list:
intrp1 = relay.create_executor("graph", ctx=ctx, target=target) intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x0_data, x1_data) op_res1 = intrp1.evaluate(func)(x0_data, x1_data)
op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
tvm.testing.assert_allclose(op_indices_res1.asnumpy(), ref_indices_res, rtol=1e-5)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target) intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res2 = intrp2.evaluate(func)(x0_data, x1_data) op_res2 = intrp2.evaluate(func)(x0_data, x1_data)
op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
tvm.testing.assert_allclose(op_indices_res2.asnumpy(), ref_indices_res, rtol=1e-5)
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype("float32") [1, 0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32") np_valid_count = np.array([4]).astype("int32")
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]]) [-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, -1, -1, -1]])
num_anchors = 5 num_anchors = 5
dshape = (tvm.var("n"), num_anchors, 6) dshape = (tvm.var("n"), num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result,
force_suppress=True, topk=2, check_type_only=True) force_suppress=True, top_k=2, check_type_only=True)
dshape = (1, num_anchors, 6) dshape = (1, num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result,
force_suppress=True, topk=2, check_type_only=False) force_suppress=True, top_k=2, check_type_only=False)
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[1, 0.7, 30, 60, 50, 80], [-1, 0.9, 35, 61, 52, 79], [1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]]) [-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, 1, -1, -1]])
dshape = (tvm.var("n"), num_anchors, 6) dshape = (tvm.var("n"), num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], verify_nms(np_data, np_valid_count, dshape, np_result,
check_type_only=True) np_indices_result, check_type_only=True)
dshape = (1, num_anchors, 6) dshape = (1, num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], verify_nms(np_data, np_valid_count, dshape, np_result,
topk=3) np_indices_result, top_k=3)
def test_multibox_transform_loc(): def test_multibox_transform_loc():
...@@ -226,7 +277,7 @@ def test_multibox_transform_loc(): ...@@ -226,7 +277,7 @@ def test_multibox_transform_loc():
assert ret.checked_type == ref_type assert ret.checked_type == ref_type
nms = relay.vision.nms(mtl[0], mtl[1]) nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False)
func = relay.Function([cls_prob, loc_pred, anchors], nms) func = relay.Function([cls_prob, loc_pred, anchors], nms)
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
ctx_list = [("llvm", tvm.cpu(0))] ctx_list = [("llvm", tvm.cpu(0))]
...@@ -411,8 +462,9 @@ if __name__ == "__main__": ...@@ -411,8 +462,9 @@ if __name__ == "__main__":
test_resize() test_resize()
test_multibox_prior() test_multibox_prior()
test_multibox_transform_loc() test_multibox_transform_loc()
test_nms() test_get_valid_counts()
test_roi_align() test_roi_align()
test_proposal() test_proposal()
test_yolo_reorg_infer_shape() test_yolo_reorg_infer_shape()
test_yolo_reorg() test_yolo_reorg()
test_non_max_suppression()
...@@ -30,7 +30,12 @@ inline Tensor l2_normalize(const Tensor& data, ...@@ -30,7 +30,12 @@ inline Tensor l2_normalize(const Tensor& data,
const Array<Integer>& axis, const Array<Integer>& axis,
std::string name = "tensor", std::string name = "tensor",
std::string tag = "l2_normalize") { std::string tag = "l2_normalize") {
CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input"; for (size_t i = 0; i < axis.size(); ++i) {
int ax = topi::detail::GetConstInt(axis[i]);
CHECK_LT(ax, data->shape.size()) <<
"Axis " << ax << " exceeds input data dim " <<
data->shape.size();
}
auto input_shape = data->shape; auto input_shape = data->shape;
Tensor dot_value = topi::power(data, static_cast<float>(2.0)); Tensor dot_value = topi::power(data, static_cast<float>(2.0));
Tensor sum_value = topi::sum(dot_value, axis, true); Tensor sum_value = topi::sum(dot_value, axis, true);
......
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Non-maximum suppression operator""" """Non-maximum suppression operator"""
import math import math
import tvm import tvm
from tvm import api from tvm import api
from topi.vision import nms from topi.vision import non_max_suppression
from ..util import get_const_tuple from ..util import get_const_tuple
def sort_ir(data, index, output): def sort_ir(data, index, output):
...@@ -181,13 +181,14 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n ...@@ -181,13 +181,14 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
return body return body
@nms.register(["cuda", "gpu"]) @non_max_suppression.register(["cuda", "gpu"])
def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1): def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress=False,
topk=-1, id_index=0, invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection. """Non-maximum suppression operator for object detection.
Parameters Parameters
---------- ----------
data: tvm.Tensor data : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6]. 3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom]. [class_id, score, box_left, box_top, box_right, box_bottom].
...@@ -195,15 +196,24 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk ...@@ -195,15 +196,24 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
valid_count : tvm.Tensor valid_count : tvm.Tensor
1-D tensor for valid number of boxes. 1-D tensor for valid number of boxes.
nms_threshold : float return_indices : boolean
Whether to return box indices in input data.
iou_threshold : optional, float
Non-maximum suppression threshold. Non-maximum suppression threshold.
force_suppress : boolean force_suppress : optional, boolean
Whether to suppress all detections regardless of class_id. Whether to suppress all detections regardless of class_id.
nms_topk : int topk : optional, int
Keep maximum top k detections before nms, -1 for no limit. Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
index of the class categories, -1 to disable.
invalid_to_bottom : optional, boolean
Whether to move all valid bounding boxes to the top.
Returns Returns
------- -------
out : tvm.Tensor out : tvm.Tensor
...@@ -216,14 +226,13 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk ...@@ -216,14 +226,13 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
# An example to use nms # An example to use nms
dshape = (1, 5, 6) dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data") data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder( valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
(dshape[0],), dtype="int32", name="valid_count") iou_threshold = 0.7
nms_threshold = 0.7
force_suppress = True force_suppress = True
nms_topk = -1 topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) out = nms(data, valid_count, iou_threshold, force_suppress, topk)
np_data = np.random.uniform(size=dshape).astype("float32") np_data = np.random.uniform(dshape)
np_valid_count = np.array([4]).astype("int32") np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out) s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm") f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu() ctx = tvm.cpu()
...@@ -263,8 +272,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk ...@@ -263,8 +272,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
tvm.extern(data.shape, tvm.extern(data.shape,
[data, sort_tensor, valid_count], [data, sort_tensor, valid_count],
lambda ins, outs: nms_ir( lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], nms_threshold, ins[0], ins[1], ins[2], outs[0], iou_threshold,
force_suppress, nms_topk), force_suppress, topk),
dtype="float32", dtype="float32",
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
tag="nms") tag="nms")
......
...@@ -11,7 +11,7 @@ import topi ...@@ -11,7 +11,7 @@ import topi
from topi.vision.ssd import multibox_prior from topi.vision.ssd import multibox_prior
from topi.vision.ssd import multibox_detection from topi.vision.ssd import multibox_detection
from topi.vision.ssd import multibox_transform_loc from topi.vision.ssd import multibox_transform_loc
from ..nms import nms from ..nms import non_max_suppression
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
...@@ -437,6 +437,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01 ...@@ -437,6 +437,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
""" """
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances) clip, threshold, variances)
out = nms( out = non_max_suppression(
inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
return out return out
...@@ -162,3 +162,20 @@ def schedule_proposal(outs): ...@@ -162,3 +162,20 @@ def schedule_proposal(outs):
scheduled_ops.append(op) scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
@generic.schedule_get_valid_counts.register(["cuda", "gpu"])
def schedule_get_valid_counts(outs):
    """Schedule for the get_valid_counts operator on "cuda"/"gpu" targets.

    Parameters
    ----------
    outs : Array of Tensor
        The computation graph description of get_valid_counts,
        given as an array of tensors.

    Returns
    -------
    s : Schedule
        The computation schedule for the op, as produced by
        ``_default_schedule``.
    """
    schedule = _default_schedule(outs)
    return schedule
...@@ -37,6 +37,23 @@ def schedule_reorg(outs): ...@@ -37,6 +37,23 @@ def schedule_reorg(outs):
return cpp.generic.default_schedule(cpp_target, outs, False) return cpp.generic.default_schedule(cpp_target, outs, False)
@tvm.target.generic_func @tvm.target.generic_func
def schedule_get_valid_counts(outs):
"""Schedule for get_valid_counts
Parameters
----------
outs: Array of Tensor
The computation graph description of nms
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_nms(outs): def schedule_nms(outs):
"""Schedule for non-maximum suppression """Schedule for non-maximum suppression
......
...@@ -20,3 +20,4 @@ from .l2_normalize_python import l2_normalize_python ...@@ -20,3 +20,4 @@ from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python from .strided_slice_python import strided_slice_python
from .batch_matmul import batch_matmul from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
"""Slice axis in python"""
def slice_axis_python(data, axis, begin, end=None):
    """Slice input array along specific axis.

    Parameters
    ----------
    data : numpy.ndarray
        The source array to be sliced.

    axis : int
        Axis to be sliced. A negative value counts from the last axis.

    begin : int
        The index to begin with in the slicing. A negative value is
        offset by the length of the sliced axis.

    end : int, optional
        The index indicating end of the slice. A value that is zero or
        negative is offset by the length of the sliced axis, so ``end=0``
        selects through the end of the axis. When omitted (``None``), the
        slice also runs through the end of the axis.

    Returns
    -------
    ret : numpy.ndarray
        The computed result.
    """
    dshape = data.shape
    if axis < 0:
        axis += len(dshape)
    if begin < 0:
        begin += dshape[axis]
    if end is None:
        # The default must be resolved before any arithmetic: comparing or
        # adding None to an int raises TypeError.
        end = dshape[axis]
    elif end <= 0:
        end += dshape[axis]
    # Take every axis in full except the one being sliced.
    slc = [slice(None)] * len(dshape)
    slc[axis] = slice(begin, end)
    return data[tuple(slc)]
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements
"""Non-maximum suppression operator""" """Non-maximum suppression operator"""
import tvm import tvm
from tvm import api from tvm import api, hybrid
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk): @hybrid.script
"""Low level IR routing for transform location in multibox_detection operator. def hybrid_rearrange_out(data):
"""Hybrid routine to rearrange nms output to
move all valid entries to top.
Parameters Parameters
---------- ----------
data: Buffer data : tvm.Tensor or numpy NDArray
Buffer of output boxes with class and score. NMS output. 3-D tensor with shape
[batch_size, num_anchors, 6].
sort_result : Buffer Returns
Buffer of output box indexes sorted by score. -------
output : tvm.Tensor or numpy NDArray
Transformed NMS output. 3-D tensor with shape
[batch_size, num_anchors, 6].
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
elem_length = data.shape[2]
output = output_tensor((batch_size,
num_anchors,
elem_length),
data.dtype)
valid_count : Buffer for i in parallel(batch_size):
Buffer of number of valid output boxes. valid_idx = 0
for j in range(num_anchors):
if data[i, j, 0] >= 0:
for k in range(elem_length):
output[i, valid_idx, k] = data[i, j, k]
valid_idx += 1
if j >= valid_idx:
for k in range(elem_length):
output[i, j, k] = -1.0
return output
out : Buffer
Output buffer.
nms_threshold : float @hybrid.script
Non-maximum suppression threshold. def hybrid_get_valid_counts(data, score_threshold):
"""Hybrid routine to get valid count of bounding boxes
given a score threshold. Also moves valid boxes to the
top of input data.
Parameters
----------
data : tvm.Tensor or numpy NDArray
Input data. 3-D tensor with shape [batch_size, num_anchors, 6].
score_threshold : tvm.const
Lower limit of score for valid bounding boxes.
Returns
-------
out_tensor : tvm.Tensor or numpy NDArray
Rearranged data tensor.
valid_count : tvm.Tensor or numpy NDArray
1-D tensor for valid number of boxes.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
box_data_length = data.shape[2]
valid_count = output_tensor((batch_size,), "int32")
out_tensor = output_tensor((batch_size,
num_anchors,
box_data_length),
data.dtype)
for i in parallel(batch_size):
valid_count[i] = 0
for j in range(num_anchors):
score = data[i, j, 1]
if score > score_threshold:
for k in range(box_data_length):
out_tensor[i, valid_count[i], k] = data[i, j, k]
valid_count[i] += 1
if j >= valid_count[i]:
for k in range(box_data_length):
out_tensor[i, j, k] = -1.0
return valid_count, out_tensor
@tvm.target.generic_func
def get_valid_counts(data, score_threshold=0):
"""Get valid count of bounding boxes given a score threshold.
Also moves valid boxes to the top of input data.
Parameters
----------
data : tvm.Tensor
Input data. 3-D tensor with shape [batch_size, num_anchors, 6].
score_threshold : optional, float
Lower limit of score for valid bounding boxes.
Returns
-------
out_tensor : tvm.Tensor
Rearranged data tensor.
force_suppress : boolean valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
"""
score_threshold_const = tvm.const(score_threshold, "float")
return hybrid_get_valid_counts(data, score_threshold_const)
@hybrid.script
def hybrid_nms(data, sorted_index, valid_count,
max_output_size, iou_threshold, force_suppress,
top_k, id_index):
"""Hybrid routing for non-maximum suppression.
Parameters
----------
data: tvm.Tensor or numpy NDArray
Bounding boxes with class and score. 3-D tensor with shape
[batch_size, num_anchors, 6].
sorted_index : tvm.Tensor or numpy NDArray
Bounding box indexes sorted by score, with shape
[batch_size, num_anchors].
valid_count : tvm.Tensor or numpy NDArray
1-D tensor for valid number of boxes.
max_output_size : tvm.const
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
iou_threshold : tvm.const
Overlapping(IoU) threshold to suppress object with smaller score.
force_suppress : tvm.const
Whether to suppress all detections regardless of class_id. Whether to suppress all detections regardless of class_id.
nms_topk : int top_k : tvm.const
Keep maximum top k detections before nms, -1 for no limit. Keep maximum top k detections before nms, -1 for no limit.
id_index : tvm.const
index of the class categories, -1 to disable.
Returns Returns
------- -------
stmt : Stmt output : tvm.Tensor
The result IR statement. 3-D tensor with shape [batch_size, num_anchors, 6].
"""
def calculate_overlap(out_tensor, box_a_idx, box_b_idx): box_indices: tvm.Tensor
"""Calculate overlap of two boxes. 2-D tensor with shape [batch_size, num_anchors].
""" """
w = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2]) batch_size = data.shape[0]
- tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx])) num_anchors = data.shape[1]
h = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3]) box_data_length = data.shape[2]
- tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1])) box_indices = output_tensor((batch_size, num_anchors), "int32")
i = w * h output = output_tensor((batch_size,
u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \ num_anchors,
(out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \ box_data_length,),
(out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \ data.dtype)
(out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
return tvm.expr.Select(u <= 0.0, 0.0, i / u) for i in parallel(batch_size):
if iou_threshold > 0:
ib = tvm.ir_builder.create() if valid_count[i] > 0:
p_data = ib.buffer_ptr(data)
p_sort_result = ib.buffer_ptr(sort_result)
p_valid_count = ib.buffer_ptr(valid_count)
p_out = ib.buffer_ptr(out)
batch_size = out.shape[0]
num_anchors = out.shape[1]
nms_threshold_node = tvm.make.node("FloatImm", dtype="float32", value=nms_threshold)
nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
force_suppress_node = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
with ib.for_range(0, batch_size, for_type="parallel", name="n") as n:
with ib.if_scope(tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
p_valid_count[0] > 0)):
# Reorder output # Reorder output
nkeep = tvm.if_then_else( nkeep = valid_count[i]
tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]), if 0 < top_k < nkeep:
nms_topk, p_valid_count[n]) nkeep = top_k
with ib.for_range(0, nkeep, name="l") as l: for j in range(nkeep):
with ib.for_range(0, 6, name="m") as m: for k in range(box_data_length):
p_out[(n * num_anchors * 6 output[i, j, k] = data[i, sorted_index[i, j], k]
+ l * 6 + m)] = p_data[(n * num_anchors * 6 box_indices[i, j] = sorted_index[i, j]
+ p_sort_result[n * num_anchors + l] * 6 + m)] if 0 < top_k < valid_count[i]:
with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])): for j in range(valid_count[i] - nkeep):
with ib.for_range(0, p_valid_count[n] - nkeep, name="l") as l: for k in range(box_data_length):
with ib.for_range(0, 6, name="m") as m: output[i, j + nkeep, k] = -1.0
p_out[(n * num_anchors * 6 box_indices[i, j + nkeep] = -1
+ (l + nkeep) * 6 + m)] = p_data[(n * num_anchors * 6
+ (l + nkeep) * 6 + m)]
# Apply nms # Apply nms
with ib.for_range(0, p_valid_count[n], name="l") as l: for j in range(valid_count[i]):
offset_l = l * 6 if output[i, j, 0] >= 0:
with ib.if_scope(p_out[n * num_anchors * 6 + offset_l] >= 0): for k in range(valid_count[i]):
with ib.for_range(0, p_valid_count[n], name="m") as m: check_iou = 0
offset_m = m * 6 if k > j and output[i, k, 0] >= 0:
with ib.if_scope(tvm.all(m > l, p_out[n * num_anchors * 6 if force_suppress:
+ offset_m] >= 0)): check_iou = 1
with ib.if_scope(tvm.any(force_suppress_node > 0, elif id_index < 0 or output[i, j, 0] == output[i, k, 0]:
p_out[n * num_anchors * 6 + offset_l] == check_iou = 1
p_out[n * num_anchors * 6 + offset_m])): if check_iou > 0:
# When force_suppress == True or class_id equals batch_idx = i
iou = calculate_overlap(p_out, n * num_anchors * 6 + offset_l + 2, box_a_idx = j
n * num_anchors * 6 + offset_m + 2) box_b_idx = k
with ib.if_scope(iou >= nms_threshold): box_start_idx = 2
p_out[n * num_anchors * 6 + offset_m] = -1.0 a_t = output[batch_idx, box_a_idx, box_start_idx + 1]
with ib.else_scope(): a_b = output[batch_idx, box_a_idx, box_start_idx + 3]
with ib.for_range(0, p_valid_count[n], name="l") as l: a_l = output[batch_idx, box_a_idx, box_start_idx]
with ib.for_range(0, 6, name="m") as m: a_r = output[batch_idx, box_a_idx, box_start_idx + 2]
p_out[(n * num_anchors * 6 b_t = output[batch_idx, box_b_idx, box_start_idx + 1]
+ l * 6 + m)] = p_data[n * num_anchors * 6 + l * 6 + m] b_b = output[batch_idx, box_b_idx, box_start_idx + 3]
b_l = output[batch_idx, box_b_idx, box_start_idx]
b_r = output[batch_idx, box_b_idx, box_start_idx + 2]
w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
area = h * w
u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
iou = 0.0 if u <= 0.0 else area / u
if iou >= iou_threshold:
output[i, k, 0] = -1.0
box_indices[i, k] = -1
else:
for j in range(valid_count[i]):
for k in range(box_data_length):
output[i, j, k] = data[i, j, k]
box_indices[i, j] = j
# Set invalid entry to be -1 # Set invalid entry to be -1
with ib.for_range(0, num_anchors - p_valid_count[n], name="l") as l: for j in range(num_anchors - valid_count[i]):
with ib.for_range(0, 6, name="m") as m: for k in range(box_data_length):
p_out[n * num_anchors * 6 + (l + p_valid_count[n]) * 6 + m] = -1.0 output[i, j + valid_count[i], k] = -1.0
return ib.get() box_indices[i, j + valid_count[i]] = -1
# Only return max_output_size valid boxes
num_valid_boxes = 0
if max_output_size > 0:
for j in range(valid_count[i]):
if output[i, j, 0] >= 0:
if num_valid_boxes == max_output_size:
for k in range(box_data_length):
output[i, j, k] = -1.0
box_indices[i, j] = -1
else:
num_valid_boxes += 1
return output, box_indices
@tvm.target.generic_func @tvm.target.generic_func
def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1): def non_max_suppression(data, valid_count, max_output_size=-1,
iou_threshold=0.5, force_suppress=False, top_k=-1,
id_index=0, return_indices=True, invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection. """Non-maximum suppression operator for object detection.
Parameters Parameters
---------- ----------
data: tvm.Tensor data : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6]. 3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom]. [class_id, score, box_left, box_top, box_right, box_bottom].
...@@ -120,15 +249,28 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1) ...@@ -120,15 +249,28 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
valid_count : tvm.Tensor valid_count : tvm.Tensor
1-D tensor for valid number of boxes. 1-D tensor for valid number of boxes.
nms_threshold : float max_output_size : optional, int
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
iou_threshold : optional, float
Non-maximum suppression threshold. Non-maximum suppression threshold.
force_suppress : boolean force_suppress : optional, boolean
Whether to suppress all detections regardless of class_id. Whether to suppress all detections regardless of class_id.
nms_topk : int top_k : optional, int
Keep maximum top k detections before nms, -1 for no limit. Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
index of the class categories, -1 to disable.
return_indices : optional, boolean
Whether to return box indices in input data.
invalid_to_bottom : optional, boolean
Whether to move all valid bounding boxes to the top.
Returns Returns
------- -------
out : tvm.Tensor out : tvm.Tensor
...@@ -138,16 +280,17 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1) ...@@ -138,16 +280,17 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
-------- --------
.. code-block:: python .. code-block:: python
# An example to use nms # An example to use non_max_suppression
dshape = (1, 5, 6) dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data") data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7 iou_threshold = 0.7
force_suppress = True force_suppress = True
nms_topk = -1 top_k = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) out = non_max_suppression(data, valid_count, iou_threshold=iou_threshold,
np_data = np.random.uniform(size=dshape).astype("float32") force_suppress=force_suppress, top_k=top_k)
np_valid_count = np.array([4]).astype("int32") np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out) s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm") f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu() ctx = tvm.cpu()
...@@ -161,7 +304,6 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1) ...@@ -161,7 +304,6 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
valid_count_dtype = "int32" valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype, valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
"valid_count_buf", data_alignment=4) "valid_count_buf", data_alignment=4)
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
score_axis = 1 score_axis = 1
score_shape = (batch_size, num_anchors) score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis]) score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
...@@ -180,13 +322,13 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1) ...@@ -180,13 +322,13 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
in_buffers=[score_tensor_buf, valid_count_buf], in_buffers=[score_tensor_buf, valid_count_buf],
out_buffers=sort_tensor_buf, out_buffers=sort_tensor_buf,
name="nms_sort") name="nms_sort")
out = \ out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
tvm.extern(data.shape, tvm.const(max_output_size, dtype="int32"),
[data, sort_tensor, valid_count], tvm.const(iou_threshold, dtype="float32"),
lambda ins, outs: nms_ir( tvm.const(force_suppress, dtype="bool"),
ins[0], ins[1], ins[2], outs[0], nms_threshold, tvm.const(top_k, dtype="int32"),
force_suppress, nms_topk), tvm.const(id_index, dtype="int32"))
dtype="float32", if not return_indices and invalid_to_bottom:
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], out = hybrid_rearrange_out(out)
tag="nms")
return out return box_indices if return_indices else out
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable
"""SSD multibox operators""" """SSD multibox operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import math
import tvm import tvm
from tvm import api from tvm import hybrid
from tvm.intrin import exp, sqrt
import topi import topi
from ..nms import nms from ..nms import non_max_suppression
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): @hybrid.script
"""Low level IR routing for multibox_prior operator. def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
"""Hybrid routing for multibox_prior operator.
Parameters Parameters
---------- ----------
data : Buffer data : tvm.Tensor or numpy NDArray
Input data buffer. 4-D tensor with shape [batch, channel, height, width]]
out : Buffer sizes : tvm ConsExpr
Output buffer. Sizes for anchor boxes.
sizes : tuple of float ratios : tvm ConsExpr
Tuple of sizes for anchor boxes. Ratios for anchor boxes.
ratios : tuple of float
Tuple of ratios for anchor boxes.
steps : Tuple of float steps : tvm ConsExpr
Priorbox step across y and x, -1 for auto calculation. Priorbox step across y and x, -1 for auto calculation.
offsets : tuple of int offsets : tvm ConsExpr
Priorbox center offsets, y and x respectively. Priorbox center offsets, y and x respectively.
Returns Returns
------- -------
stmt : Stmt output : tvm.Tensor or numpy NDArray
The result IR statement. 3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
""" """
ib = tvm.ir_builder.create()
p_out = ib.buffer_ptr(out)
in_height = data.shape[2] in_height = data.shape[2]
in_width = data.shape[3] in_width = data.shape[3]
num_sizes = len(sizes) num_sizes = len(sizes)
num_ratios = len(ratios) num_ratios = len(ratios)
size_ratio_concat = sizes + ratios num_boxes = in_height * in_width * (num_sizes + num_ratios - 1)
steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height output = output_tensor((1, num_boxes, 4), "float32")
steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width steps_h = steps[0] * 1.0 if steps[0] > 0 else 1.0 / in_height
steps_w = steps[1] * 1.0 if steps[1] > 0 else 1.0 / in_width
offset_h = offsets[0] offset_h = offsets[0]
offset_w = offsets[1] offset_w = offsets[1]
with ib.for_range(0, in_height, for_type="parallel", name="i") as i: # Need to define var out of const_range + if
w = 0.0
h = 0.0
for i in parallel(in_height):
center_h = (i + offset_h) * steps_h center_h = (i + offset_h) * steps_h
with ib.for_range(0, in_width, name="j") as j: for j in range(in_width):
center_w = (j + offset_w) * steps_w center_w = (j + offset_w) * steps_w
for k in range(num_sizes + num_ratios - 1): for k in const_range(num_sizes + num_ratios - 1):
w = tvm.if_then_else(k < num_sizes, if k < num_sizes:
size_ratio_concat[k] * in_height / in_width / 2.0, w = sizes[k] * in_height / in_width / 2.0
size_ratio_concat[0] * in_height / in_width * h = sizes[k] / 2.0
math.sqrt(size_ratio_concat[k + 1]) / 2.0) else:
h = tvm.if_then_else( w = sizes[0] * in_height / in_width \
k < num_sizes, size_ratio_concat[k] / 2.0, * sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0) h = sizes[0] / sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
count = (i * in_width * (num_sizes + num_ratios - 1) + count = i * in_width * (num_sizes + num_ratios - 1) \
j * (num_sizes + num_ratios - 1) + k) * 4 + j * (num_sizes + num_ratios - 1) + k
p_out[count] = center_w - w output[0, count, 0] = center_w - w
p_out[count + 1] = center_h - h output[0, count, 1] = center_h - h
p_out[count + 2] = center_w + w output[0, count, 2] = center_w + w
p_out[count + 3] = center_h + h output[0, count, 3] = center_h + h
return ib.get() return output
@tvm.target.generic_func @tvm.target.generic_func
...@@ -101,115 +102,120 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, ...@@ -101,115 +102,120 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
out : tvm.Tensor out : tvm.Tensor
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4] 3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
""" """
num_sizes = len(sizes) out = hybrid_multibox_prior(data, tvm.convert(sizes), tvm.convert(ratios),
num_ratios = len(ratios) tvm.convert(steps), tvm.convert(offsets))
oshape = (1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
out = tvm.extern(oshape, [data], lambda ins, outs:
multibox_prior_ir(ins[0], outs[0], sizes, ratios, steps, offsets),
tag="multibox_prior")
if clip: if clip:
out = topi.clip(out, 0, 1) out = topi.clip(out, 0, 1)
return out return out
def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, threshold, variances): @hybrid.script
"""Low level IR routing for transform location in multibox_detection operator. def _hybridy_transform_loc(box, pred_loc, variance, clip):
"""Transform prior anchor box to output box through location predictions.
"""
al = box[0]
at = box[1]
ar = box[2]
ab = box[3]
Parameters px = pred_loc[0]
---------- py = pred_loc[1]
cls_prob : Buffer pw = pred_loc[2]
Buffer of class probabilities. ph = pred_loc[3]
loc_pred : Buffer vx = variance[0]
Buffer of location regression predictions. vy = variance[1]
vw = variance[2]
vh = variance[3]
anchor : Buffer output = output_tensor((4,), pred_loc.dtype)
Buffer of prior anchor boxes.
valid_count : Buffer aw = ar - al
Buffer of number of valid output boxes. ah = ab - at
ax = (al + ar) / 2.0
ay = (at + ab) / 2.0
ox = px * vx * aw + ax
oy = py * vy * ah + ay
ow = exp(pw * vw) * aw / 2.0
oh = exp(ph * vh) * ah / 2.0
output[0] = max(0.0, min(1.0, ox - ow)) if clip else ox - ow
output[1] = max(0.0, min(1.0, oy - oh)) if clip else oy - oh
output[2] = max(0.0, min(1.0, ox + ow)) if clip else ox + ow
output[3] = max(0.0, min(1.0, oy + oh)) if clip else oy + oh
return output
@hybrid.script
def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances):
"""Hybrid routing for transform location in multibox_detection operator.
out : Buffer Parameters
Output buffer. ----------
cls_prob : tvm.Tensor or numpy NDArray
3-D tensor of class probabilities.
clip : boolean loc_pred : tvm.Tensor or numpy NDArray
2-D tensor of location regression predictions.
anchor : tvm.Tensor or numpy NDArray
3-D tensor of prior anchor boxes.
clip : tvm.const
Whether to clip out-of-boundary boxes. Whether to clip out-of-boundary boxes.
threshold : float threshold : tvm.const
Threshold to be a positive prediction. Threshold to be a positive prediction.
variances : tuple of float variances : tvm.ndarray
Variances to be decoded from box regression output. Variances to be decoded from box regression output.
Returns Returns
------- -------
stmt : Stmt out_loc : tvm.Tensor or numpy NDArray
The result IR statement. 3-D tensor of transformed location.
"""
def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh):
"""Transform prior anchor box to output box through location predictions.
"""
al = anchor[anchor_base_idx]
at = anchor[anchor_base_idx + 1]
ar = anchor[anchor_base_idx + 2]
ab = anchor[anchor_base_idx + 3]
aw = ar - al
ah = ab - at
ax = (al + ar) / 2.0
ay = (at + ab) / 2.0
px = loc[loc_base_idx]
py = loc[loc_base_idx + 1]
pw = loc[loc_base_idx + 2]
ph = loc[loc_base_idx + 3]
ox = px * vx * aw + ax
oy = py * vy * ah + ay
ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.if_then_else(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \
tvm.if_then_else(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \
tvm.if_then_else(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \
tvm.if_then_else(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh)
valid_count : tvm.Tensor or numpy NDArray
1_d tensor of valid counts for boxes.
"""
batch_size = cls_prob.shape[0] batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1] num_classes = cls_prob.shape[1]
num_anchors = cls_prob.shape[2] num_anchors = cls_prob.shape[2]
box_coord = allocate((4,), loc_pred.dtype)
ib = tvm.ir_builder.create() pred_coord = allocate((4,), loc_pred.dtype)
p_cls_prob = ib.buffer_ptr(cls_prob) out_loc = output_tensor((batch_size, num_anchors, 6),
p_loc_pred = ib.buffer_ptr(loc_pred) loc_pred.dtype)
p_anchor = ib.buffer_ptr(anchor) valid_count = output_tensor((batch_size,), "int32")
p_valid_count = ib.buffer_ptr(valid_count)
p_out = ib.buffer_ptr(out) for i in parallel(batch_size):
with ib.for_range(0, batch_size, for_type="parallel", name="n") as n: valid_count[i] = 0
p_valid_count[n] = 0 for j in range(num_anchors):
with ib.for_range(0, num_anchors, name="i") as i:
# Find the predicted class id and probability # Find the predicted class id and probability
score = ib.allocate('float32', (1,), name="score", scope="local") score = -1.0
cls_id = ib.allocate('int32', (1,), name="id", scope="local") cls_id = 0
score[0] = -1.0 for k in range(num_classes):
cls_id[0] = 0 if k > 0:
with ib.for_range(0, num_classes, name="j") as j: temp = cls_prob[i, k, j]
with ib.if_scope(j > 0): cls_id = k if temp > score else cls_id
temp = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i] score = max(temp, score)
cls_id[0] = tvm.if_then_else(temp > score[0], j, cls_id[0]) if cls_id > 0 and score < threshold:
score[0] = tvm.max(temp, score[0]) cls_id = 0
with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)):
cls_id[0] = 0
# [id, prob, xmin, ymin, xmax, ymax] # [id, prob, xmin, ymin, xmax, ymax]
# Remove background, restore original id # Remove background, restore original id
with ib.if_scope(cls_id[0] > 0): if cls_id > 0:
out_base_idx = n * num_anchors * 6 + p_valid_count[n] * 6 out_loc[i, valid_count[i], 0] = cls_id - 1.0
p_out[out_base_idx] = cls_id[0] - 1.0 out_loc[i, valid_count[i], 1] = score
p_out[out_base_idx + 1] = score[0] for l in range(4):
offset = i * 4 box_coord[l] = anchor[0, j, l]
p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \ pred_coord[l] = loc_pred[i, j * 4 + l]
p_out[out_base_idx + 5] = transform_loc(p_loc_pred, n * num_anchors * 4 + offset, out_coord = _hybridy_transform_loc(box_coord, pred_coord,
p_anchor, offset, clip, variances[0], variances, clip)
variances[1], variances[2], variances[3]) out_loc[i, valid_count[i], 2] = out_coord[0]
p_valid_count[n] += 1 out_loc[i, valid_count[i], 3] = out_coord[1]
out_loc[i, valid_count[i], 4] = out_coord[2]
return ib.get() out_loc[i, valid_count[i], 5] = out_coord[3]
valid_count[i] += 1
return out_loc, valid_count
@tvm.target.generic_func @tvm.target.generic_func
def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
...@@ -240,24 +246,10 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01 ...@@ -240,24 +246,10 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
------- -------
ret : tuple of tvm.Tensor ret : tuple of tvm.Tensor
""" """
batch_size = cls_prob.shape[0] return hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
num_anchors = anchor.shape[1] tvm.const(clip, "bool"),
oshape = (batch_size, num_anchors, 6) tvm.const(threshold, "float32"),
# Define data alignment for intermediate buffer tvm.convert(variances))
valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype,
"valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(oshape, cls_prob.dtype, "out_buf", data_alignment=8)
valid_count, out = \
tvm.extern([(batch_size,), oshape],
[cls_prob, loc_pred, anchor],
lambda ins, outs: transform_loc_ir(
ins[0], ins[1], ins[2], outs[0], outs[1], clip, threshold, variances),
dtype=[valid_count_dtype, cls_prob.dtype],
out_buffers=[valid_count_buf, out_buf],
tag="multibox_transform_loc")
return [out, valid_count]
@tvm.target.generic_func @tvm.target.generic_func
def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5, def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5,
...@@ -300,5 +292,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm ...@@ -300,5 +292,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
""" """
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances) clip, threshold, variances)
out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) out = non_max_suppression(inter_out[0], inter_out[1], -1,
nms_threshold, force_suppress, nms_topk,
return_indices=False)
return out return out
...@@ -8,11 +8,62 @@ import topi.testing ...@@ -8,11 +8,62 @@ import topi.testing
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple from topi.util import get_const_tuple
from topi.vision import ssd, nms from topi.vision import ssd, non_max_suppression, get_valid_counts
def verify_get_valid_counts(dshape, score_threshold):
dtype = "float32"
batch_size, num_anchor, elem_length = dshape
np_data = np.random.uniform(size=dshape).astype(dtype)
np_out1 = np.zeros(shape=(batch_size,))
np_out2 = np.zeros(shape=dshape).astype(dtype)
for i in range(batch_size):
np_out1[i] = 0
inter_idx = 0
for j in range(num_anchor):
score = np_data[i, j, 1]
if score > score_threshold:
for k in range(elem_length):
np_out2[i, inter_idx, k] = np_data[i, j, k]
np_out1[i] += 1
inter_idx += 1
if j >= np_out1[i]:
for k in range(elem_length):
np_out2[i, j, k] = -1.0
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
data = tvm.placeholder(dshape, name="data", dtype=dtype)
outs = get_valid_counts(data, score_threshold)
s = topi.generic.schedule_multibox_prior(outs)
tvm_input_data = tvm.nd.array(np_data, ctx)
tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx)
tvm_out2 = tvm.nd.array(np.zeros(np_out2.shape, dtype=dtype), ctx)
f = tvm.build(s, [data, outs[0], outs[1]], device)
f(tvm_input_data, tvm_out1, tvm_out2)
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
def test_nms(): for device in ['llvm']:
check_device(device)
def test_get_valid_counts():
    """Run verify_get_valid_counts over a table of (data shape, score threshold) cases."""
    cases = [
        ((1, 2500, 6), 0),
        ((1, 2500, 6), -1),
        ((3, 1000, 6), 0.55),
        ((16, 500, 6), 0.95),
    ]
    # Cases run in the listed order; the first failing case aborts the test.
    for data_shape, threshold in cases:
        verify_get_valid_counts(data_shape, threshold)
def test_non_max_suppression():
dshape = (1, 5, 6) dshape = (1, 5, 6)
indices_dshape = (1, 5)
data = tvm.placeholder(dshape, name="data") data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7 nms_threshold = 0.7
...@@ -24,8 +75,9 @@ def test_nms(): ...@@ -24,8 +75,9 @@ def test_nms():
[1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype) [1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype)
np_valid_count = np.array([4]).astype(valid_count.dtype) np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]]) [-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, -1, -1, -1]])
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
...@@ -35,18 +87,27 @@ def test_nms(): ...@@ -35,18 +87,27 @@ def test_nms():
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
if device == 'llvm': if device == 'llvm':
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk) out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False)
indices_out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk)
else: else:
out = topi.cuda.nms(data, valid_count, nms_threshold, force_suppress, nms_topk) out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False)
indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk)
s = topi.generic.schedule_nms(out) s = topi.generic.schedule_nms(out)
indices_s = topi.generic.schedule_nms(indices_out)
tvm_data = tvm.nd.array(np_data, ctx) tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx) tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f = tvm.build(s, [data, valid_count, out], device) f = tvm.build(s, [data, valid_count, out], device)
f(tvm_data, tvm_valid_count, tvm_out) f(tvm_data, tvm_valid_count, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx)
f = tvm.build(indices_s, [data, valid_count, indices_out], device)
f(tvm_data, tvm_valid_count, tvm_indices_out)
tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)
for device in ['llvm']: for device in ['llvm']:
check_device(device) check_device(device)
...@@ -274,7 +335,8 @@ def test_proposal(): ...@@ -274,7 +335,8 @@ def test_proposal():
if __name__ == "__main__": if __name__ == "__main__":
test_nms() test_get_valid_counts()
test_non_max_suppression()
test_multibox_prior() test_multibox_prior()
test_multibox_detection() test_multibox_detection()
test_roi_align() test_roi_align()
......
"""
Deploy Single Shot Multibox Detector(SSD) model
===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_
This article is an introductory tutorial on deploying SSD models with TVM.
We will use a GluonCV pre-trained SSD model and convert it to Relay IR.
"""
import tvm
from matplotlib import pyplot as plt
from nnvm import compiler
from nnvm.frontend import from_mxnet
from nnvm.testing.config import ctx_list
from tvm import relay
from tvm.contrib import graph_runtime
from gluoncv import model_zoo, data, utils
######################################################################
# Preliminary and Set parameters
# ------------------------------
# TVM must be built with sort support. In the TVM root directory, run:
#
# .. code-block:: bash
#
# echo "set(USE_SORT ON)" > config.mk
# make -j8
#
# .. note::
#
# Currently we support compiling SSD on CPU only.
# GPU support is in progress.
#
# To get the best inference performance on CPU, change the
# target argument according to your device, and follow
# :ref:`tune_relay_x86` to tune for x86 CPUs or
# :ref:`tune_relay_arm` to tune for ARM CPUs.
#
# SSD with VGG as the body network is not supported yet, since
# the x86 conv2d schedule doesn't support dilation.
# Model names listed for the reader's reference; the list itself is not
# consulted by the code that follows.
supported_model = [
    'ssd_512_resnet18_v1_voc',
    'ssd_512_resnet18_v1_coco',
    'ssd_512_resnet50_v1_voc',
    'ssd_512_resnet50_v1_coco',
    'ssd_512_resnet101_v2_voc',
    'ssd_512_mobilenet1_0_voc',
    'ssd_512_mobilenet1_0_coco',
]
# Name handed to ``model_zoo.get_model`` when the network is loaded.
model_name = "ssd_512_resnet50_v1_voc"
# Shape registered for the "data" input when the model is converted.
dshape = (1, 3, 512, 512)
dtype = "float32"
# Iterated over as (target, ctx) pairs when compiling and running.
target_list = ctx_list()
######################################################################
# Download and pre-process demo image
# Download the demo image, passing 'street_small.jpg' as the local path.
im_fname = utils.download('https://github.com/dmlc/web-data/blob/master/' +
                          'gluoncv/detection/street_small.jpg?raw=true',
                          path='street_small.jpg')
# Load the image through GluonCV's SSD test preset with short=512.
# ``x`` is later converted with ``asnumpy()`` and fed to the runtime as the
# "data" input; ``img`` is later passed to ``plot_bbox`` for display.
x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)
######################################################################
# Convert and compile model for CPU.
# ``block`` is the pre-trained GluonCV model; it is converted in ``compile``
# and its ``classes`` attribute labels the plotted boxes.
block = model_zoo.get_model(model_name, pretrained=True)
def compile(target):
    """Convert the module-level ``block`` to Relay and build it for ``target``.

    Note that this tutorial helper shadows the ``compile`` builtin.

    Parameters
    ----------
    target : target specification forwarded to ``relay.build``.

    Returns
    -------
    (graph, lib, params) : the three values produced by ``relay.build``.
    """
    input_shapes = {"data": dshape}
    relay_net, weights = relay.frontend.from_mxnet(block, input_shapes)
    with relay.build_config(opt_level=3):
        graph_def, module_lib, built_params = relay.build(relay_net, target,
                                                          params=weights)
    return graph_def, module_lib, built_params
######################################################################
# Create TVM runtime and do inference
def run(graph, lib, params, ctx):
    """Execute the compiled model once on the module-level input ``x``.

    Parameters
    ----------
    graph, lib : build artifacts handed to ``graph_runtime.create``.
    params : mapping expanded as keyword arguments to ``set_input``.
    ctx : context on which the runtime and the input array are created.

    Returns
    -------
    (class_IDs, scores, bounding_boxs) : runtime outputs 0, 1 and 2.
    """
    # Build the TVM runtime and bind the image plus the model parameters.
    runtime = graph_runtime.create(graph, lib, ctx)
    runtime.set_input('data', tvm.nd.array(x.asnumpy(), ctx=ctx))
    runtime.set_input(**params)
    # Execute the graph.
    runtime.run()
    # Collect the three outputs in index order.
    return tuple(runtime.get_output(index) for index in range(3))
# Outputs of the most recent target that actually ran. They stay None when
# every entry of target_list is skipped.
class_IDs = scores = bounding_boxs = None
for target, ctx in target_list:
    if target == "cuda":
        print("GPU not supported yet, skip.")
        continue
    graph, lib, params = compile(target)
    class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx)

######################################################################
# Display result

# Only the results of the last target that ran are plotted. The guard keeps
# an all-skipped target list from raising NameError on the names below.
if bounding_boxs is None:
    print("No supported target was run, nothing to display.")
else:
    ax = utils.viz.plot_bbox(img, bounding_boxs.asnumpy()[0], scores.asnumpy()[0],
                             class_IDs.asnumpy()[0], class_names=block.classes)
    plt.show()
...@@ -61,7 +61,7 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \ ...@@ -61,7 +61,7 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \
image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \ image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \
"cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg" "cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg"
inference_symbol_folder = \ inference_symbol_folder = \
"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26" "c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \ inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \
"archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip" "archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment