Commit d2f29ba5 by Yao Wang Committed by Tianqi Chen

[Object Detection] Gluoncv SSD support on CPU (#2353)

parent f7eff095
......@@ -58,19 +58,42 @@ struct MultiBoxTransformLocAttrs
}
};
/*! \brief Attributes used in non_maximum_suppression operators */
struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
double overlap_threshold;
/*!
 * \brief Attributes used in get_valid_counts operator.
 *
 * Registered under the type key "relay.attrs.GetValidCountsAttrs".
 */
struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
  /*! \brief Lower limit of score for valid bounding boxes; defaults to 0.0. */
  double score_threshold;

  TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") {
    TVM_ATTR_FIELD(score_threshold).set_default(0.0)
      .describe("Lower limit of score for valid bounding boxes.");
  }
};
/*! \brief Attributes used in non_maximum_suppression operator */
struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
int max_output_size;
double iou_threshold;
bool force_suppress;
int topk;
int top_k;
int id_index;
bool return_indices;
bool invalid_to_bottom;
TVM_DECLARE_ATTRS(NMSAttrs, "relay.attrs.NMSAttrs") {
TVM_ATTR_FIELD(overlap_threshold).set_default(0.5)
TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
TVM_ATTR_FIELD(max_output_size).set_default(-1)
.describe("Max number of output valid boxes for each instance."
"By default all valid boxes are returned.");
TVM_ATTR_FIELD(iou_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
TVM_ATTR_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(topk).set_default(-1)
TVM_ATTR_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
TVM_ATTR_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
TVM_ATTR_FIELD(return_indices).set_default(true)
.describe("Whether to return box indices in input data.");
TVM_ATTR_FIELD(invalid_to_bottom).set_default(false)
.describe("Whether to move all invalid bounding boxes to the bottom.");
}
};
......
......@@ -443,17 +443,30 @@ struct MultiBoxTransformLocParam : public dmlc::Parameter<MultiBoxTransformLocPa
}
};
struct NMSParam : public dmlc::Parameter<NMSParam> {
float nms_threshold;
struct NonMaximumSuppressionParam : public dmlc::Parameter<NonMaximumSuppressionParam> {
bool return_indices;
float iou_threshold;
bool force_suppress;
int nms_topk;
DMLC_DECLARE_PARAMETER(NMSParam) {
DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5)
int top_k;
int id_index;
int max_output_size;
bool invalid_to_bottom;
DMLC_DECLARE_PARAMETER(NonMaximumSuppressionParam) {
DMLC_DECLARE_FIELD(max_output_size).set_default(-1)
.describe("Max number of output valid boxes for each instance."
"By default all valid boxes are returned.");
DMLC_DECLARE_FIELD(iou_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
DMLC_DECLARE_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
DMLC_DECLARE_FIELD(nms_topk).set_default(-1)
DMLC_DECLARE_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
DMLC_DECLARE_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
DMLC_DECLARE_FIELD(return_indices).set_default(true)
.describe("Whether to return box indices in input data.");
DMLC_DECLARE_FIELD(invalid_to_bottom).set_default(false)
.describe("Whether to move all invalid bounding boxes to the bottom.");
}
};
......
......@@ -245,11 +245,11 @@ def _contrib_multibox_detection(inputs, attrs):
if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2)
nms_topk = attrs.get('nms_topk') or -1
new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances}
new_attrs1 = {'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress,
'nms_topk': int(nms_topk)}
new_attrs1 = {'return_indices': False, 'iou_threshold': float(nms_threshold),
'force_suppress': force_suppress, 'top_k': int(nms_topk)}
data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1],
inputs[2], **new_attrs0)
return _get_nnvm_op('nms')(data, valid_count, **new_attrs1)
return _get_nnvm_op('non_max_suppression')(data, valid_count, **new_attrs1)
def _elemwise_sum(inputs, _):
new_attrs = {'num_args':len(inputs)}
......
......@@ -61,20 +61,25 @@ def compute_multibox_transform_loc(attrs, inputs, _):
reg.register_pattern("multibox_detection", OpPattern.OPAQUE)
# non-maximum suppression
@reg.register_schedule("nms")
@reg.register_schedule("non_max_suppression")
def schedule_nms(_, outs, target):
"""Schedule definition of nms"""
"""Schedule definition of non_max_suppression"""
with tvm.target.create(target):
return topi.generic.schedule_nms(outs)
@reg.register_compute("nms")
@reg.register_compute("non_max_suppression")
def compute_nms(attrs, inputs, _):
"""Compute definition of nms"""
nms_threshold = attrs.get_float('nms_threshold')
"""Compute definition of non_max_suppression"""
return_indices = attrs.get_bool('return_indices')
max_output_size = attrs.get_int('max_output_size')
iou_threshold = attrs.get_float('iou_threshold')
force_suppress = attrs.get_bool('force_suppress')
nms_topk = attrs.get_int('nms_topk')
top_k = attrs.get_int('top_k')
id_index = attrs.get_int('id_index')
invalid_to_bottom = attrs.get_bool('invalid_to_bottom')
return topi.vision.nms(inputs[0], inputs[1], nms_threshold,
force_suppress, nms_topk)
return topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
reg.register_pattern("nms", OpPattern.OPAQUE)
reg.register_pattern("non_max_suppression", OpPattern.OPAQUE)
......@@ -19,11 +19,13 @@ using compiler::FTVMCompute;
using tvm::Tensor;
using tvm::Array;
DMLC_REGISTER_PARAMETER(NMSParam);
DMLC_REGISTER_PARAMETER(NonMaximumSuppressionParam);
bool NMSShape(const NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const NonMaximumSuppressionParam& param =
nnvm::get<NonMaximumSuppressionParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]";
TShape dshape = in_attrs->at(0);
TShape vshape = in_attrs->at(1);
......@@ -33,7 +35,14 @@ bool NMSShape(const NodeAttrs& attrs,
"(batch_size, num_anchors, 6).";
CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch.";
out_attrs->clear();
if (param.return_indices) {
TShape oshape = TShape(2);
oshape[0] = dshape[0];
oshape[1] = dshape[1];
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
} else {
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape);
}
return true;
}
......@@ -56,15 +65,15 @@ inline bool NMSInferLayout(const NodeAttrs& attrs,
return true;
}
NNVM_REGISTER_OP(nms)
NNVM_REGISTER_OP(non_max_suppression)
.describe(R"doc("Non-maximum suppression."
)doc" NNVM_ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NMSParam>)
.set_attr_parser(ParamParser<NonMaximumSuppressionParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<NMSParam>)
.add_arguments(NMSParam::__FIELDS__())
ParamGetAttrDict<NonMaximumSuppressionParam>)
.add_arguments(NonMaximumSuppressionParam::__FIELDS__())
.add_argument("data", "Tensor", "Input data.")
.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
......
......@@ -550,7 +550,7 @@ def test_multibox_transform_loc():
anchors = sym.Variable("anchors")
transform_loc_data, valid_count = sym.multibox_transform_loc(cls_prob=cls_prob, loc_pred=loc_preds,
anchor=anchors)
out = sym.nms(data=transform_loc_data, valid_count=valid_count)
out = sym.non_max_suppression(data=transform_loc_data, valid_count=valid_count, return_indices=False)
# Manually create test case
np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]])
......@@ -573,22 +573,22 @@ def test_multibox_transform_loc():
out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
def test_nms():
def test_non_max_suppression():
dshape = (1, 5, 6)
data = sym.Variable("data")
valid_count = sym.Variable("valid_count", dtype="int32")
nms_threshold = 0.7
iou_threshold = 0.7
force_suppress = True
nms_topk = 2
out = sym.nms(data=data, valid_count=valid_count, nms_threshold=nms_threshold,
force_suppress=force_suppress, nms_topk=nms_topk)
top_k = 2
out = sym.non_max_suppression(data=data, valid_count=valid_count, return_indices=False,
iou_threshold=iou_threshold, force_suppress=force_suppress, top_k=top_k)
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32")
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
target = "llvm"
......@@ -726,7 +726,7 @@ if __name__ == "__main__":
test_flip()
test_multibox_prior()
test_multibox_transform_loc()
test_nms()
test_non_max_suppression()
test_slice_like()
test_where()
test_argmax()
......
......@@ -315,4 +315,3 @@ if __name__ == '__main__':
test_forward_slice()
test_forward_maximum()
test_forward_minimum()
......@@ -328,13 +328,14 @@ def _mx_multibox_detection(inputs, attrs):
0.2, 0.2))
new_attrs1 = {}
new_attrs1["overlap_threshold"] = attrs.get_float("nms_threshold", 0.5)
new_attrs1["return_indices"] = False
new_attrs1["iou_threshold"] = attrs.get_float("nms_threshold", 0.5)
new_attrs1["force_suppress"] = attrs.get_bool("force_suppress", False)
new_attrs1["topk"] = attrs.get_int("nms_topk", -1)
new_attrs1["top_k"] = attrs.get_int("nms_topk", -1)
ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1],
inputs[2], **new_attrs0)
return _op.vision.nms(ret[0], ret[1], **new_attrs1)
return _op.vision.non_max_suppression(ret[0], ret[1], **new_attrs1)
def _mx_batch_dot(inputs, attrs):
......@@ -399,6 +400,49 @@ def _mx_proposal(inputs, attrs):
return _op.vision.proposal(inputs[0], inputs[1], inputs[2], **new_attrs)
def _mx_box_nms(inputs, attrs):
    """Convert the MXNet ``_contrib_box_nms`` operator.

    The operator is lowered to ``vision.get_valid_counts`` (filtering by
    ``valid_thresh``) followed by ``vision.non_max_suppression`` with
    ``return_indices=False`` and ``invalid_to_bottom=True``.

    Only a restricted attribute set is accepted. A ``RuntimeError`` is
    raised when ``coord_start`` is not 2, ``score_index`` is not 1,
    ``id_index`` is neither -1 nor 0, or either box format is not
    ``'corner'``.
    """
    # Attributes are read in the same order as before so that any parsing
    # error surfaces for the same attribute.
    suppress_all = attrs.get_bool("force_suppress", False)
    overlap = attrs.get_float('overlap_thresh', 0.5)
    keep_top = attrs.get_int('topk', -1)
    score_floor = attrs.get_float('valid_thresh', 0)
    coord_start = attrs.get_int('coord_start', 2)
    score_index = attrs.get_int('score_index', 1)
    id_index = attrs.get_int('id_index', -1)
    in_format = attrs.get_str('in_format', 'corner')
    out_format = attrs.get_str('out_format', 'corner')

    # Reject every layout option the relay operators cannot express.
    if not coord_start == 2:
        raise RuntimeError('coord_start %s is not supported.' % coord_start)
    if not score_index == 1:
        raise RuntimeError('score_index %s is not supported.' % score_index)
    if not (id_index == -1 or int(id_index) == 0):
        raise RuntimeError('id_index %s is not supported.' % id_index)
    if not in_format == 'corner':
        raise RuntimeError('in_format %s is not supported.' % in_format)
    if not out_format == 'corner':
        raise RuntimeError('out_format %s is not supported.' % out_format)

    # get_valid_counts yields (valid_count, rearranged_data); NMS takes the
    # data tensor first and the count second.
    counted = _op.vision.get_valid_counts(inputs[0], score_threshold=score_floor)
    boxes = counted[1]
    valid_count = counted[0]
    return _op.vision.non_max_suppression(boxes,
                                          valid_count,
                                          iou_threshold=overlap,
                                          force_suppress=suppress_all,
                                          top_k=keep_top,
                                          id_index=id_index,
                                          return_indices=False,
                                          invalid_to_bottom=True)
def _mx_l2_normalize(inputs, attrs):
    """Convert MXNet ``L2Normalization`` to ``nn.l2_normalize``.

    Only ``mode='channel'`` is handled; it is mapped to normalization over
    axis 1. Any other mode raises ``RuntimeError``.
    """
    mode = attrs.get_str('mode', 'instance')
    if mode != 'channel':
        raise RuntimeError('mode %s is not supported.' % mode)
    return _op.nn.l2_normalize(inputs[0],
                               eps=attrs.get_float('eps', 1e-10),
                               axis=[1])
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
......@@ -497,6 +541,7 @@ _convert_map = {
"BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm,
"LRN" : _mx_lrn,
"L2Normalization" : _mx_l2_normalize,
"slice" : _mx_slice,
"slice_like" : _mx_slice_like,
"slice_axis" : _mx_slice_axis,
......@@ -520,6 +565,7 @@ _convert_map = {
"_contrib_ROIAlign" : _mx_roi_align,
"_contrib_Proposal" : _mx_proposal,
"_contrib_MultiProposal" : _mx_proposal,
"_contrib_box_nms" : _mx_box_nms,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
......@@ -662,6 +708,8 @@ def from_mxnet(symbol,
params[k] = _nd.array(v.data().asnumpy())
data = mx.sym.Variable("data")
sym = symbol(data)
if isinstance(sym, (list, tuple)):
sym = mx.sym.Group(sym)
shape, dtype = _update_shape_dtype(shape, dtype, params)
sym = _from_mxnet_impl(sym, shape, dtype)
elif isinstance(symbol, mx.gluon.Block):
......
......@@ -525,7 +525,7 @@ def strided_slice(data, begin, end, strides=None):
The indices to begin with in the slicing.
end: list of int
Indicies indicating end of the slice.
Indices indicating end of the slice.
strides: list of int, optional
Specifies the stride values, it can be negative in that case,
......
......@@ -6,6 +6,6 @@ from .multibox import *
from .nms import *
from .rcnn import *
from .yolo import *
from . import _multibox
from . import _rcnn
from . import _yolo
from . import _vision
......@@ -54,24 +54,46 @@ reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE)
reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE)
# Get counts of valid boxes
@reg.register_schedule("vision.get_valid_counts")
def schedule_get_valid_counts(_, outs, target):
    """Schedule definition of get_valid_counts.

    Enters ``target`` as a context and delegates to the generic topi
    schedule for the given outputs.
    """
    with target:
        sched = topi.generic.schedule_get_valid_counts(outs)
    return sched
@reg.register_compute("vision.get_valid_counts")
def compute_get_valid_counts(attrs, inputs, _, target):
    """Compute definition of get_valid_counts.

    Reads ``score_threshold`` from the operator attributes as a Python
    float and forwards it, with the first input, to the topi implementation.
    """
    threshold = get_const_float(attrs.score_threshold)
    data = inputs[0]
    return topi.vision.get_valid_counts(data, threshold)
reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)
# non-maximum suppression
@reg.register_schedule("vision.nms")
@reg.register_schedule("vision.non_max_suppression")
def schedule_nms(_, outs, target):
"""Schedule definition of nms"""
with target:
return topi.generic.schedule_nms(outs)
@reg.register_compute("vision.nms")
@reg.register_compute("vision.non_max_suppression")
def compute_nms(attrs, inputs, _, target):
"""Compute definition of nms"""
overlap_threshold = get_const_float(attrs.overlap_threshold)
return_indices = bool(get_const_int(attrs.return_indices))
max_output_size = get_const_int(attrs.max_output_size)
iou_threshold = get_const_float(attrs.iou_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress))
topk = get_const_int(attrs.topk)
top_k = get_const_int(attrs.top_k)
id_index = get_const_int(attrs.id_index)
invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
return [
topi.vision.nms(inputs[0], inputs[1], overlap_threshold,
force_suppress, topk)
topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
]
reg.register_pattern("vision.nms", OpPattern.OPAQUE)
reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE)
"""Non-maximum suppression operations."""
from __future__ import absolute_import as _abs
from . import _make
from ...expr import TupleWrapper
def nms(data,
def get_valid_counts(data,
                     score_threshold):
    """Get valid count of bounding boxes given a score threshold.
    Also moves valid boxes to the top of input data.

    Parameters
    ----------
    data : relay.Expr
        Input data. 3-D tensor with shape [batch_size, num_anchors, 6].

    score_threshold : float
        Lower limit of score for valid bounding boxes. This argument has
        no default in this wrapper and must be supplied by the caller.

    Returns
    -------
    valid_count : relay.Expr
        1-D tensor for valid number of boxes.

    out_tensor : relay.Expr
        Rearranged data tensor.
    """
    # The operator produces a 2-field tuple; wrap it so both fields can be
    # indexed from Python.
    call = _make.get_valid_counts(data, score_threshold)
    return TupleWrapper(call, 2)
def non_max_suppression(data,
valid_count,
overlap_threshold=0.5,
max_output_size=-1,
iou_threshold=0.5,
force_suppress=False,
topk=-1):
top_k=-1,
id_index=0,
return_indices=True,
invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection.
Parameters
......@@ -19,18 +48,33 @@ def nms(data,
valid_count : relay.Expr
1-D tensor for valid number of boxes.
overlap_threshold : float, optional
max_output_size : int, optional
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
iou_threshold : float, optional
Non-maximum suppression threshold.
force_suppress : bool, optional
Suppress all detections regardless of class_id.
topk : int, optional
top_k : int, optional
Keep maximum top k detections before nms, -1 for no limit.
id_index : int, optional
index of the class categories, -1 to disable.
return_indices : bool, optional
Whether to return box indices in input data.
invalid_to_bottom : bool, optional
Whether to move all valid bounding boxes to the top.
Returns
-------
out : relay.Expr
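3-D tensor with shape [batch_size, num_anchors, 6] when return_indices is False;
2-D int32 tensor of box indices with shape [batch_size, num_anchors] when return_indices is True (the default).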
"""
return _make.nms(data, valid_count, overlap_threshold, force_suppress, topk)
return _make.non_max_suppression(data, valid_count, max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
......@@ -1516,6 +1516,16 @@ RELAY_REGISTER_OP("broadcast_to_like")
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
// Adapter that reinterprets an Array<IndexExpr> as an Array<Integer>.
// Every defined element must be an IntImm; undefined elements are allowed.
// The returned array shares the underlying node with the input.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
  for (const auto& elem : arr) {
    CHECK(!elem.defined() || elem.as<IntImm>())
      << "Expect an int array";
  }
  return Array<Integer>(arr.node_);
}
// strided_slice
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
bool StridedSliceRel(const Array<Type>& types,
......@@ -1870,15 +1880,6 @@ Expr MakeSliceLike(Expr data,
return CallNode::make(op, {data, shape_like}, Attrs(attrs), {});
}
// Adapter that reinterprets an Array<IndexExpr> as an Array<Integer>.
// Every defined element must be an IntImm; undefined elements are allowed.
// The returned array shares the underlying node with the input.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
  for (const auto& elem : arr) {
    CHECK(!elem.defined() || elem.as<IntImm>())
      << "Expect an int array";
  }
  return Array<Integer>(arr.node_);
}
Array<Tensor> SliceLikeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
......
......@@ -70,8 +70,10 @@ RELAY_REGISTER_OP("vision.multibox_prior")
TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs);
bool MultiBoxTransformLocRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs, const TypeReporter& reporter) {
bool MultiBoxTransformLocRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* cls_prob = types[0].as<TensorTypeNode>();
......
......@@ -9,7 +9,54 @@
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(NMSAttrs);
TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs);
/*!
 * \brief Type relation for vision.get_valid_counts.
 *
 * types[0] is the input data type and types[1] is the output type, which is
 * assigned a tuple of:
 *   - an int32 tensor of shape (data.shape[0]), and
 *   - a tensor with the same shape and dtype as the input.
 *
 * \return false when the input is not (yet) a tensor type so the relation can
 *         be retried later; true once the output type has been assigned.
 */
bool GetValidCountRel(const Array<Type>& types,
                      int num_inputs,
                      const Attrs& attrs,
                      const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  // as<>() yields nullptr when types[0] is not a TensorTypeNode; bail out
  // instead of dereferencing a null pointer below.
  if (data == nullptr) return false;
  const auto& dshape = data->shape;
  CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";

  std::vector<IndexExpr> oshape({data->shape[0]});
  std::vector<Type> fields;
  fields.push_back(TensorTypeNode::make(oshape, Int(32)));
  fields.push_back(TensorTypeNode::make(data->shape, data->dtype));

  // assign output type
  reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
  return true;
}
/*!
 * \brief Build a call to the vision.get_valid_counts operator.
 * \param data The input expression.
 * \param score_threshold Value stored in GetValidCountsAttrs::score_threshold.
 * \return A CallNode applying the operator to data with the given attributes.
 */
Expr MakeGetValidCounts(Expr data,
                        double score_threshold) {
  static const Op& get_valid_counts_op = Op::Get("vision.get_valid_counts");
  auto node = make_node<GetValidCountsAttrs>();
  node->score_threshold = score_threshold;
  return CallNode::make(get_valid_counts_op, {data}, Attrs(node), {});
}
// Expose MakeGetValidCounts under the name
// "relay.op.vision._make.get_valid_counts"; unpack_call<Expr, 2> forwards the
// two packed arguments (data, score_threshold) to it.
TVM_REGISTER_API("relay.op.vision._make.get_valid_counts")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 2>(MakeGetValidCounts, args, rv);
  });

// Operator registration: a single input ("data"), support level 5, with its
// output type derived by GetValidCountRel.
RELAY_REGISTER_OP("vision.get_valid_counts")
.describe(R"doc(Get valid count of bounding boxes given
a score threshold. Also moves valid boxes to the top of
input data.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(5)
.add_type_rel("GetValidCount", GetValidCountRel);
TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs);
bool NMSRel(const Array<Type>& types,
int num_inputs,
......@@ -18,39 +65,56 @@ bool NMSRel(const Array<Type>& types,
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* valid_count = types[1].as<TensorTypeNode>();
const NonMaximumSuppressionAttrs* param =
attrs.as<NonMaximumSuppressionAttrs>();
const auto& dshape = data->shape;
const auto& vshape = valid_count->shape;
CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
CHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D.";
// assign output type
if (param->return_indices) {
std::vector<IndexExpr> oshape({dshape[0], dshape[1]});
reporter->Assign(types[2], TensorTypeNode::make(oshape, Int(32)));
} else {
reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
}
return true;
}
Expr MakeNMS(Expr data,
Expr valid_count,
double overlap_threshold,
int max_output_size,
double iou_threshold,
bool force_suppress,
int topk) {
auto attrs = make_node<NMSAttrs>();
attrs->overlap_threshold = overlap_threshold;
int top_k,
int id_index,
bool return_indices,
bool invalid_to_bottom) {
auto attrs = make_node<NonMaximumSuppressionAttrs>();
attrs->max_output_size = max_output_size;
attrs->iou_threshold = iou_threshold;
attrs->force_suppress = force_suppress;
attrs->topk = topk;
static const Op& op = Op::Get("vision.nms");
attrs->top_k = top_k;
attrs->id_index = id_index;
attrs->return_indices = return_indices;
attrs->invalid_to_bottom = invalid_to_bottom;
static const Op& op = Op::Get("vision.non_max_suppression");
return CallNode::make(op, {data, valid_count}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.vision._make.nms")
TVM_REGISTER_API("relay.op.vision._make.non_max_suppression")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeNMS, args, rv);
runtime::detail::unpack_call<Expr, 9>(MakeNMS, args, rv);
});
RELAY_REGISTER_OP("vision.nms")
.describe(R"doc("Non-maximum suppression."
RELAY_REGISTER_OP("vision.non_max_suppression")
.describe(R"doc(Non-maximum suppression. The input boxes should
be in the format of [class_id, score, left, top, right, bottom].
Set id_index to be -1 to ignore class_id axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.")
......
......@@ -374,6 +374,11 @@ def test_forward_slice_like():
verify((3, 4), (2, 3), (0))
verify((3, 4), (2, 3), (-1))
def test_forward_l2_normalize():
    """Check the frontend conversion of L2Normalization in channel mode."""
    shape = (2, 3, 4, 5)
    sym_in = mx.sym.var('data')
    normalized = mx.sym.L2Normalization(sym_in, mode="channel")
    verify_mxnet_frontend_impl(normalized, shape, shape)
if __name__ == '__main__':
test_forward_mlp()
......@@ -401,5 +406,6 @@ if __name__ == '__main__':
test_forward_broadcast_ops()
test_forward_elemwise_ops()
test_forward_scalar_ops()
test_forward_slice_axis()
test_forward_slice_like()
test_forward_slice_axis()
test_forward_l2_normalize()
......@@ -2,6 +2,7 @@
"""
import numpy as np
import tvm
import topi.testing
from tvm import relay
from tvm.relay.testing import ctx_list
import topi
......
......@@ -135,56 +135,107 @@ def test_multibox_prior():
verify_multibox_prior(x, dshape, ref_res, clip=False, check_type_only=True)
def test_nms():
def verify_nms(x0_data, x1_data, dshape, ref_res, valid_count,
overlap_threshold=0.5, force_suppress=False, topk=-1,
# Compares relay.vision.get_valid_counts with a NumPy reference on llvm/CPU.
# NOTE(review): indentation is lost in this view, so the comments below
# describe each statement without asserting the exact loop nesting.
def test_get_valid_counts():
def verify_get_valid_counts(dshape, score_threshold):
dtype = "float32"
# dshape is unpacked as (batch_size, num_anchor, elem_length).
batch_size, num_anchor, elem_length = dshape
np_data = np.random.uniform(size=dshape).astype(dtype)
# Reference outputs: np_out1 holds per-batch counts, np_out2 the rearranged data.
np_out1 = np.zeros(shape=(batch_size,))
np_out2 = np.zeros(shape=dshape).astype(dtype)
for i in range(batch_size):
np_out1[i] = 0
# inter_idx is the next free row for a box that passes the threshold.
inter_idx = 0
for j in range(num_anchor):
# The score is read from column 1 of each box.
score = np_data[i, j, 1]
if score >= score_threshold:
# Copy the passing box to row inter_idx and bump the count.
for k in range(elem_length):
np_out2[i, inter_idx, k] = np_data[i, j, k]
np_out1[i] += 1
inter_idx += 1
# Rows at or beyond the current count are filled with -1.
if j >= np_out1[i]:
for k in range(elem_length):
np_out2[i, j, k] = -1
x = relay.var("x", relay.ty.TensorType(dshape, dtype))
z = relay.vision.get_valid_counts(x, score_threshold)
# The attribute must appear in the textual form of the call.
assert "score_threshold" in z.astext()
func = relay.Function([x], z.astuple())
func = relay.ir_pass.infer_type(func)
# Local list of (target, context) pairs; only llvm on CPU is exercised here.
ctx_list = [("llvm", tvm.cpu(0))]
for target, ctx in ctx_list:
intrp = relay.create_executor("debug", ctx=ctx, target=target)
out = intrp.evaluate(func)(np_data)
# out[0] is compared with the counts, out[1] with the rearranged data.
tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3)
# Thresholds cover "keep everything" (0 and -1) and two selective cut-offs.
verify_get_valid_counts((1, 2500, 6), 0)
verify_get_valid_counts((1, 2500, 6), -1)
verify_get_valid_counts((3, 1000, 6), 0.55)
verify_get_valid_counts((16, 500, 6), 0.95)
def test_non_max_suppression():
def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res,
iou_threshold=0.5, force_suppress=False, top_k=-1,
check_type_only=False):
x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int"))
z = relay.vision.nms(x0, x1, overlap_threshold, force_suppress, topk)
assert "overlap_threshold" in z.astext()
z = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k, return_indices=False)
z_indices = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k)
assert "iou_threshold" in z.astext()
assert "iou_threshold" in z_indices.astext()
zz = relay.ir_pass.infer_type(z)
zz_indices = relay.ir_pass.infer_type(z_indices)
assert zz.checked_type == relay.ty.TensorType(dshape, "float32")
assert zz_indices.checked_type == relay.ty.TensorType((dshape[0], dshape[1]), "int32")
if check_type_only:
return
func = relay.Function([x0, x1], z)
func = relay.ir_pass.infer_type(func)
func_indices = relay.Function([x0, x1], z_indices)
func_indices = relay.ir_pass.infer_type(func_indices)
ctx_list = [("llvm", tvm.cpu(0))]
for target, ctx in ctx_list:
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x0_data, x1_data)
op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
tvm.testing.assert_allclose(op_indices_res1.asnumpy(), ref_indices_res, rtol=1e-5)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res2 = intrp2.evaluate(func)(x0_data, x1_data)
op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
tvm.testing.assert_allclose(op_indices_res2.asnumpy(), ref_indices_res, rtol=1e-5)
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32")
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, -1, -1, -1]])
num_anchors = 5
dshape = (tvm.var("n"), num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0],
force_suppress=True, topk=2, check_type_only=True)
verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result,
force_suppress=True, top_k=2, check_type_only=True)
dshape = (1, num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0],
force_suppress=True, topk=2, check_type_only=False)
verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result,
force_suppress=True, top_k=2, check_type_only=False)
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[1, 0.7, 30, 60, 50, 80], [-1, 0.9, 35, 61, 52, 79],
[1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, 1, -1, -1]])
dshape = (tvm.var("n"), num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0],
check_type_only=True)
verify_nms(np_data, np_valid_count, dshape, np_result,
np_indices_result, check_type_only=True)
dshape = (1, num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0],
topk=3)
verify_nms(np_data, np_valid_count, dshape, np_result,
np_indices_result, top_k=3)
def test_multibox_transform_loc():
......@@ -226,7 +277,7 @@ def test_multibox_transform_loc():
assert ret.checked_type == ref_type
nms = relay.vision.nms(mtl[0], mtl[1])
nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False)
func = relay.Function([cls_prob, loc_pred, anchors], nms)
func = relay.ir_pass.infer_type(func)
ctx_list = [("llvm", tvm.cpu(0))]
......@@ -411,8 +462,9 @@ if __name__ == "__main__":
test_resize()
test_multibox_prior()
test_multibox_transform_loc()
test_nms()
test_get_valid_counts()
test_roi_align()
test_proposal()
test_yolo_reorg_infer_shape()
test_yolo_reorg()
test_non_max_suppression()
......@@ -30,7 +30,12 @@ inline Tensor l2_normalize(const Tensor& data,
const Array<Integer>& axis,
std::string name = "tensor",
std::string tag = "l2_normalize") {
CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input";
for (size_t i = 0; i < axis.size(); ++i) {
int ax = topi::detail::GetConstInt(axis[i]);
CHECK_LT(ax, data->shape.size()) <<
"Axis " << ax << " exceeds input data dim " <<
data->shape.size();
}
auto input_shape = data->shape;
Tensor dot_value = topi::power(data, static_cast<float>(2.0));
Tensor sum_value = topi::sum(dot_value, axis, true);
......
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Non-maximum suppression operator"""
import math
import tvm
from tvm import api
from topi.vision import nms
from topi.vision import non_max_suppression
from ..util import get_const_tuple
def sort_ir(data, index, output):
......@@ -181,13 +181,14 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
return body
@nms.register(["cuda", "gpu"])
def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1):
@non_max_suppression.register(["cuda", "gpu"])
def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress=False,
topk=-1, id_index=0, invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection.
Parameters
----------
data: tvm.Tensor
data : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
......@@ -195,15 +196,24 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
nms_threshold : float
return_indices : boolean
Whether to return box indices in input data.
iou_threshold : optional, float
Non-maximum suppression threshold.
force_suppress : boolean
force_suppress : optional, boolean
Whether to suppress all detections regardless of class_id.
nms_topk : int
topk : optional, int
Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
index of the class categories, -1 to disable.
invalid_to_bottom : optional, boolean
Whether to move all valid bounding boxes to the top.
Returns
-------
out : tvm.Tensor
......@@ -216,14 +226,13 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
# An example to use nms
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder(
(dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
iou_threshold = 0.7
force_suppress = True
nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(size=dshape).astype("float32")
np_valid_count = np.array([4]).astype("int32")
topk = -1
out = nms(data, valid_count, iou_threshold, force_suppress, topk)
np_data = np.random.uniform(size=dshape).astype("float32")
np_valid_count = np.array([4]).astype("int32")
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
......@@ -263,8 +272,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
tvm.extern(data.shape,
[data, sort_tensor, valid_count],
lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], nms_threshold,
force_suppress, nms_topk),
ins[0], ins[1], ins[2], outs[0], iou_threshold,
force_suppress, topk),
dtype="float32",
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
tag="nms")
......
......@@ -11,7 +11,7 @@ import topi
from topi.vision.ssd import multibox_prior
from topi.vision.ssd import multibox_detection
from topi.vision.ssd import multibox_transform_loc
from ..nms import nms
from ..nms import non_max_suppression
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
......@@ -437,6 +437,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = nms(
out = non_max_suppression(
inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
return out
......@@ -162,3 +162,20 @@ def schedule_proposal(outs):
scheduled_ops.append(op)
traverse(outs[0].op)
return s
@generic.schedule_get_valid_counts.register(["cuda", "gpu"])
def schedule_get_valid_counts(outs):
    """Schedule for get_valid_counts operator.

    Registered as the "cuda" / "gpu" implementation of the generic
    ``schedule_get_valid_counts`` function.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of get_valid_counts
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    # No operator-specific schedule here: delegate to _default_schedule.
    return _default_schedule(outs)
......@@ -37,6 +37,23 @@ def schedule_reorg(outs):
return cpp.generic.default_schedule(cpp_target, outs, False)
@tvm.target.generic_func
def schedule_get_valid_counts(outs):
    """Schedule for get_valid_counts

    Generic (target-independent) entry point; target-specific schedules
    can be registered on top of it.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of get_valid_counts
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    # NOTE(review): the second argument (False) presumably disables
    # auto-inlining in _default_schedule -- confirm against its definition.
    return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_nms(outs):
"""Schedule for non-maximum suppression
......
......@@ -20,3 +20,4 @@ from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python
from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
"""Slice axis in python"""
def slice_axis_python(data, axis, begin, end=None):
    """Slice input array along specific axis.

    Parameters
    ----------
    data : numpy.ndarray
        The source array to be sliced.

    axis : int
        Axis to be sliced. Negative values count from the last axis.

    begin: int
        The index to begin with in the slicing. Negative values count
        from the end of the axis.

    end: int, optional
        The index indicating end of the slice. ``None`` (the default)
        slices to the end of the axis. Values ``<= 0`` are offset by the
        axis length, so ``0`` also means "to the end" and negative values
        count from the end of the axis.

    Returns
    -------
    ret : numpy.ndarray
        The computed result.
    """
    dshape = data.shape
    if axis < 0:
        axis += len(dshape)
    if begin < 0:
        begin += dshape[axis]
    if end is None:
        # The default must be resolved before any comparison: `None <= 0`
        # raises TypeError on Python 3.
        end = dshape[axis]
    elif end <= 0:
        end += dshape[axis]
    # Full slice on every axis except the requested one.
    slc = [slice(None)] * len(dshape)
    slc[axis] = slice(begin, end)
    return data[tuple(slc)]
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements
"""Non-maximum suppression operator"""
import tvm
from tvm import api
from tvm import api, hybrid
def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk):
"""Low level IR routing for transform location in multibox_detection operator.
@hybrid.script
def hybrid_rearrange_out(data):
"""Hybrid routine to rearrange nms output to
move all valid entries to top.
Parameters
----------
data: Buffer
Buffer of output boxes with class and score.
data : tvm.Tensor or numpy NDArray
NMS output. 3-D tensor with shape
[batch_size, num_anchors, 6].
sort_result : Buffer
Buffer of output box indexes sorted by score.
Returns
-------
output : tvm.Tensor or numpy NDArray
Transformed NMS output. 3-D tensor with shape
[batch_size, num_anchors, 6].
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
elem_length = data.shape[2]
output = output_tensor((batch_size,
num_anchors,
elem_length),
data.dtype)
valid_count : Buffer
Buffer of number of valid output boxes.
for i in parallel(batch_size):
valid_idx = 0
for j in range(num_anchors):
if data[i, j, 0] >= 0:
for k in range(elem_length):
output[i, valid_idx, k] = data[i, j, k]
valid_idx += 1
if j >= valid_idx:
for k in range(elem_length):
output[i, j, k] = -1.0
return output
out : Buffer
Output buffer.
nms_threshold : float
Non-maximum suppression threshold.
@hybrid.script
def hybrid_get_valid_counts(data, score_threshold):
    """Hybrid routine to get valid count of bounding boxes
    given a score threshold. Also moves valid boxes to the
    top of input data.

    Parameters
    ----------
    data : tvm.Tensor or numpy NDArray
        Input data. 3-D tensor with shape [batch_size, num_anchors, 6].
        The score is read from index 1 of the last dimension.

    score_threshold : tvm.const
        Lower limit of score for valid bounding boxes. A box is valid
        only when its score is strictly greater than this value.

    Returns
    -------
    valid_count : tvm.Tensor or numpy NDArray
        1-D tensor for valid number of boxes.

    out_tensor : tvm.Tensor or numpy NDArray
        Rearranged data tensor. Valid boxes come first in their original
        order; every remaining entry is filled with -1.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    box_data_length = data.shape[2]
    valid_count = output_tensor((batch_size,), "int32")
    out_tensor = output_tensor((batch_size,
                                num_anchors,
                                box_data_length),
                               data.dtype)
    for i in parallel(batch_size):
        valid_count[i] = 0
        for j in range(num_anchors):
            score = data[i, j, 1]
            if score > score_threshold:
                # Compact the valid box into the next free slot.
                for k in range(box_data_length):
                    out_tensor[i, valid_count[i], k] = data[i, j, k]
                valid_count[i] += 1
            # Slot j is at or beyond the current count: mark it invalid.
            # A later valid box can only land in a slot <= its own index,
            # so every slot below the final count is overwritten with data.
            if j >= valid_count[i]:
                for k in range(box_data_length):
                    out_tensor[i, j, k] = -1.0
    return valid_count, out_tensor
@tvm.target.generic_func
def get_valid_counts(data, score_threshold=0):
    """Get valid count of bounding boxes given a score threshold.
    Also moves valid boxes to the top of input data.

    Parameters
    ----------
    data : tvm.Tensor
        Input data. 3-D tensor with shape [batch_size, num_anchors, 6].

    score_threshold : optional, float
        Lower limit of score for valid bounding boxes.

    Returns
    -------
    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    out_tensor : tvm.Tensor
        Rearranged data tensor.
    """
    # The hybrid routine takes the threshold as a tvm constant, not a Python number.
    score_threshold_const = tvm.const(score_threshold, "float")
    return hybrid_get_valid_counts(data, score_threshold_const)
@hybrid.script
def hybrid_nms(data, sorted_index, valid_count,
max_output_size, iou_threshold, force_suppress,
top_k, id_index):
"""Hybrid routing for non-maximum suppression.
Parameters
----------
data: tvm.Tensor or numpy NDArray
Bounding boxes with class and score. 3-D tensor with shape
[batch_size, num_anchors, 6].
sorted_index : tvm.Tensor or numpy NDArray
Bounding box indexes sorted by score, with shape
[batch_size, num_anchors].
valid_count : tvm.Tensor or numpy NDArray
1-D tensor for valid number of boxes.
max_output_size : tvm.const
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
iou_threshold : tvm.const
Overlapping(IoU) threshold to suppress object with smaller score.
force_suppress : tvm.const
Whether to suppress all detections regardless of class_id.
nms_topk : int
top_k : tvm.const
Keep maximum top k detections before nms, -1 for no limit.
id_index : tvm.const
index of the class categories, -1 to disable.
Returns
-------
stmt : Stmt
The result IR statement.
"""
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
"""Calculate overlap of two boxes.
output : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
box_indices: tvm.Tensor
2-D tensor with shape [batch_size, num_anchors].
"""
w = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
- tvm.make.Max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
h = tvm.make.Max(0.0, tvm.make.Min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
- tvm.make.Max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
i = w * h
u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
(out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
(out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
(out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
return tvm.expr.Select(u <= 0.0, 0.0, i / u)
ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(data)
p_sort_result = ib.buffer_ptr(sort_result)
p_valid_count = ib.buffer_ptr(valid_count)
p_out = ib.buffer_ptr(out)
batch_size = out.shape[0]
num_anchors = out.shape[1]
nms_threshold_node = tvm.make.node("FloatImm", dtype="float32", value=nms_threshold)
nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk)
force_suppress_node = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
with ib.for_range(0, batch_size, for_type="parallel", name="n") as n:
with ib.if_scope(tvm.all(nms_threshold_node > 0, nms_threshold_node < 1,
p_valid_count[0] > 0)):
batch_size = data.shape[0]
num_anchors = data.shape[1]
box_data_length = data.shape[2]
box_indices = output_tensor((batch_size, num_anchors), "int32")
output = output_tensor((batch_size,
num_anchors,
box_data_length,),
data.dtype)
for i in parallel(batch_size):
if iou_threshold > 0:
if valid_count[i] > 0:
# Reorder output
nkeep = tvm.if_then_else(
tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n]),
nms_topk, p_valid_count[n])
with ib.for_range(0, nkeep, name="l") as l:
with ib.for_range(0, 6, name="m") as m:
p_out[(n * num_anchors * 6
+ l * 6 + m)] = p_data[(n * num_anchors * 6
+ p_sort_result[n * num_anchors + l] * 6 + m)]
with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[n])):
with ib.for_range(0, p_valid_count[n] - nkeep, name="l") as l:
with ib.for_range(0, 6, name="m") as m:
p_out[(n * num_anchors * 6
+ (l + nkeep) * 6 + m)] = p_data[(n * num_anchors * 6
+ (l + nkeep) * 6 + m)]
nkeep = valid_count[i]
if 0 < top_k < nkeep:
nkeep = top_k
for j in range(nkeep):
for k in range(box_data_length):
output[i, j, k] = data[i, sorted_index[i, j], k]
box_indices[i, j] = sorted_index[i, j]
if 0 < top_k < valid_count[i]:
for j in range(valid_count[i] - nkeep):
for k in range(box_data_length):
output[i, j + nkeep, k] = -1.0
box_indices[i, j + nkeep] = -1
# Apply nms
with ib.for_range(0, p_valid_count[n], name="l") as l:
offset_l = l * 6
with ib.if_scope(p_out[n * num_anchors * 6 + offset_l] >= 0):
with ib.for_range(0, p_valid_count[n], name="m") as m:
offset_m = m * 6
with ib.if_scope(tvm.all(m > l, p_out[n * num_anchors * 6
+ offset_m] >= 0)):
with ib.if_scope(tvm.any(force_suppress_node > 0,
p_out[n * num_anchors * 6 + offset_l] ==
p_out[n * num_anchors * 6 + offset_m])):
# When force_suppress == True or class_id equals
iou = calculate_overlap(p_out, n * num_anchors * 6 + offset_l + 2,
n * num_anchors * 6 + offset_m + 2)
with ib.if_scope(iou >= nms_threshold):
p_out[n * num_anchors * 6 + offset_m] = -1.0
with ib.else_scope():
with ib.for_range(0, p_valid_count[n], name="l") as l:
with ib.for_range(0, 6, name="m") as m:
p_out[(n * num_anchors * 6
+ l * 6 + m)] = p_data[n * num_anchors * 6 + l * 6 + m]
for j in range(valid_count[i]):
if output[i, j, 0] >= 0:
for k in range(valid_count[i]):
check_iou = 0
if k > j and output[i, k, 0] >= 0:
if force_suppress:
check_iou = 1
elif id_index < 0 or output[i, j, 0] == output[i, k, 0]:
check_iou = 1
if check_iou > 0:
batch_idx = i
box_a_idx = j
box_b_idx = k
box_start_idx = 2
a_t = output[batch_idx, box_a_idx, box_start_idx + 1]
a_b = output[batch_idx, box_a_idx, box_start_idx + 3]
a_l = output[batch_idx, box_a_idx, box_start_idx]
a_r = output[batch_idx, box_a_idx, box_start_idx + 2]
b_t = output[batch_idx, box_b_idx, box_start_idx + 1]
b_b = output[batch_idx, box_b_idx, box_start_idx + 3]
b_l = output[batch_idx, box_b_idx, box_start_idx]
b_r = output[batch_idx, box_b_idx, box_start_idx + 2]
w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
area = h * w
u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
iou = 0.0 if u <= 0.0 else area / u
if iou >= iou_threshold:
output[i, k, 0] = -1.0
box_indices[i, k] = -1
else:
for j in range(valid_count[i]):
for k in range(box_data_length):
output[i, j, k] = data[i, j, k]
box_indices[i, j] = j
# Set invalid entry to be -1
with ib.for_range(0, num_anchors - p_valid_count[n], name="l") as l:
with ib.for_range(0, 6, name="m") as m:
p_out[n * num_anchors * 6 + (l + p_valid_count[n]) * 6 + m] = -1.0
return ib.get()
for j in range(num_anchors - valid_count[i]):
for k in range(box_data_length):
output[i, j + valid_count[i], k] = -1.0
box_indices[i, j + valid_count[i]] = -1
# Only return max_output_size valid boxes
num_valid_boxes = 0
if max_output_size > 0:
for j in range(valid_count[i]):
if output[i, j, 0] >= 0:
if num_valid_boxes == max_output_size:
for k in range(box_data_length):
output[i, j, k] = -1.0
box_indices[i, j] = -1
else:
num_valid_boxes += 1
return output, box_indices
@tvm.target.generic_func
def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1):
def non_max_suppression(data, valid_count, max_output_size=-1,
iou_threshold=0.5, force_suppress=False, top_k=-1,
id_index=0, return_indices=True, invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection.
Parameters
----------
data: tvm.Tensor
data : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
......@@ -120,15 +249,28 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
nms_threshold : float
max_output_size : optional, int
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
iou_threshold : optional, float
Non-maximum suppression threshold.
force_suppress : boolean
force_suppress : optional, boolean
Whether to suppress all detections regardless of class_id.
nms_topk : int
top_k : optional, int
Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
index of the class categories, -1 to disable.
return_indices : optional, boolean
Whether to return box indices in input data.
invalid_to_bottom : optional, boolean
Whether to move all valid bounding boxes to the top.
Returns
-------
out : tvm.Tensor
......@@ -138,16 +280,17 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
--------
.. code-block:: python
# An example to use nms
# An example to use non_max_suppression
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7
iou_threshold = 0.7
force_suppress = True
nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(size=dshape).astype("float32")
np_valid_count = np.array([4]).astype("int32")
top_k = -1
out = non_max_suppression(data, valid_count, iou_threshold=iou_threshold,
force_suppress=force_suppress, top_k=top_k)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
......@@ -161,7 +304,6 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
"valid_count_buf", data_alignment=4)
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
score_axis = 1
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
......@@ -180,13 +322,13 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
in_buffers=[score_tensor_buf, valid_count_buf],
out_buffers=sort_tensor_buf,
name="nms_sort")
out = \
tvm.extern(data.shape,
[data, sort_tensor, valid_count],
lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], nms_threshold,
force_suppress, nms_topk),
dtype="float32",
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
tag="nms")
return out
out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
tvm.const(max_output_size, dtype="int32"),
tvm.const(iou_threshold, dtype="float32"),
tvm.const(force_suppress, dtype="bool"),
tvm.const(top_k, dtype="int32"),
tvm.const(id_index, dtype="int32"))
if not return_indices and invalid_to_bottom:
out = hybrid_rearrange_out(out)
return box_indices if return_indices else out
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable
"""SSD multibox operators"""
from __future__ import absolute_import as _abs
import math
import tvm
from tvm import api
from tvm import hybrid
from tvm.intrin import exp, sqrt
import topi
from ..nms import nms
from ..nms import non_max_suppression
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
"""Low level IR routing for multibox_prior operator.
@hybrid.script
def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
"""Hybrid routing for multibox_prior operator.
Parameters
----------
data : Buffer
Input data buffer.
data : tvm.Tensor or numpy NDArray
4-D tensor with shape [batch, channel, height, width]]
out : Buffer
Output buffer.
sizes : tvm ConsExpr
Sizes for anchor boxes.
sizes : tuple of float
Tuple of sizes for anchor boxes.
ratios : tuple of float
Tuple of ratios for anchor boxes.
ratios : tvm ConsExpr
Ratios for anchor boxes.
steps : Tuple of float
steps : tvm ConsExpr
Priorbox step across y and x, -1 for auto calculation.
offsets : tuple of int
offsets : tvm ConsExpr
Priorbox center offsets, y and x respectively.
Returns
-------
stmt : Stmt
The result IR statement.
output : tvm.Tensor or numpy NDArray
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
"""
ib = tvm.ir_builder.create()
p_out = ib.buffer_ptr(out)
in_height = data.shape[2]
in_width = data.shape[3]
num_sizes = len(sizes)
num_ratios = len(ratios)
size_ratio_concat = sizes + ratios
steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
num_boxes = in_height * in_width * (num_sizes + num_ratios - 1)
output = output_tensor((1, num_boxes, 4), "float32")
steps_h = steps[0] * 1.0 if steps[0] > 0 else 1.0 / in_height
steps_w = steps[1] * 1.0 if steps[1] > 0 else 1.0 / in_width
offset_h = offsets[0]
offset_w = offsets[1]
with ib.for_range(0, in_height, for_type="parallel", name="i") as i:
# Need to define var out of const_range + if
w = 0.0
h = 0.0
for i in parallel(in_height):
center_h = (i + offset_h) * steps_h
with ib.for_range(0, in_width, name="j") as j:
for j in range(in_width):
center_w = (j + offset_w) * steps_w
for k in range(num_sizes + num_ratios - 1):
w = tvm.if_then_else(k < num_sizes,
size_ratio_concat[k] * in_height / in_width / 2.0,
size_ratio_concat[0] * in_height / in_width *
math.sqrt(size_ratio_concat[k + 1]) / 2.0)
h = tvm.if_then_else(
k < num_sizes, size_ratio_concat[k] / 2.0,
size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
count = (i * in_width * (num_sizes + num_ratios - 1) +
j * (num_sizes + num_ratios - 1) + k) * 4
p_out[count] = center_w - w
p_out[count + 1] = center_h - h
p_out[count + 2] = center_w + w
p_out[count + 3] = center_h + h
return ib.get()
for k in const_range(num_sizes + num_ratios - 1):
if k < num_sizes:
w = sizes[k] * in_height / in_width / 2.0
h = sizes[k] / 2.0
else:
w = sizes[0] * in_height / in_width \
* sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
h = sizes[0] / sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
count = i * in_width * (num_sizes + num_ratios - 1) \
+ j * (num_sizes + num_ratios - 1) + k
output[0, count, 0] = center_w - w
output[0, count, 1] = center_h - h
output[0, count, 2] = center_w + w
output[0, count, 3] = center_h + h
return output
@tvm.target.generic_func
......@@ -101,115 +102,120 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
out : tvm.Tensor
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
"""
num_sizes = len(sizes)
num_ratios = len(ratios)
oshape = (1, data.shape[2] * data.shape[3] * (num_sizes + num_ratios - 1), 4)
out = tvm.extern(oshape, [data], lambda ins, outs:
multibox_prior_ir(ins[0], outs[0], sizes, ratios, steps, offsets),
tag="multibox_prior")
out = hybrid_multibox_prior(data, tvm.convert(sizes), tvm.convert(ratios),
tvm.convert(steps), tvm.convert(offsets))
if clip:
out = topi.clip(out, 0, 1)
return out
def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, threshold, variances):
"""Low level IR routing for transform location in multibox_detection operator.
@hybrid.script
def _hybridy_transform_loc(box, pred_loc, variance, clip):
"""Transform prior anchor box to output box through location predictions.
"""
al = box[0]
at = box[1]
ar = box[2]
ab = box[3]
Parameters
----------
cls_prob : Buffer
Buffer of class probabilities.
px = pred_loc[0]
py = pred_loc[1]
pw = pred_loc[2]
ph = pred_loc[3]
loc_pred : Buffer
Buffer of location regression predictions.
vx = variance[0]
vy = variance[1]
vw = variance[2]
vh = variance[3]
anchor : Buffer
Buffer of prior anchor boxes.
output = output_tensor((4,), pred_loc.dtype)
valid_count : Buffer
Buffer of number of valid output boxes.
aw = ar - al
ah = ab - at
ax = (al + ar) / 2.0
ay = (at + ab) / 2.0
ox = px * vx * aw + ax
oy = py * vy * ah + ay
ow = exp(pw * vw) * aw / 2.0
oh = exp(ph * vh) * ah / 2.0
output[0] = max(0.0, min(1.0, ox - ow)) if clip else ox - ow
output[1] = max(0.0, min(1.0, oy - oh)) if clip else oy - oh
output[2] = max(0.0, min(1.0, ox + ow)) if clip else ox + ow
output[3] = max(0.0, min(1.0, oy + oh)) if clip else oy + oh
return output
@hybrid.script
def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances):
"""Hybrid routing for transform location in multibox_detection operator.
out : Buffer
Output buffer.
Parameters
----------
cls_prob : tvm.Tensor or numpy NDArray
3-D tensor of class probabilities.
clip : boolean
loc_pred : tvm.Tensor or numpy NDArray
2-D tensor of location regression predictions.
anchor : tvm.Tensor or numpy NDArray
3-D tensor of prior anchor boxes.
clip : tvm.const
Whether to clip out-of-boundary boxes.
threshold : float
threshold : tvm.const
Threshold to be a positive prediction.
variances : tuple of float
variances : tvm.ndarray
Variances to be decoded from box regression output.
Returns
-------
stmt : Stmt
The result IR statement.
"""
def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, vh):
"""Transform prior anchor box to output box through location predictions.
"""
al = anchor[anchor_base_idx]
at = anchor[anchor_base_idx + 1]
ar = anchor[anchor_base_idx + 2]
ab = anchor[anchor_base_idx + 3]
aw = ar - al
ah = ab - at
ax = (al + ar) / 2.0
ay = (at + ab) / 2.0
px = loc[loc_base_idx]
py = loc[loc_base_idx + 1]
pw = loc[loc_base_idx + 2]
ph = loc[loc_base_idx + 3]
ox = px * vx * aw + ax
oy = py * vy * ah + ay
ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.if_then_else(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \
tvm.if_then_else(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \
tvm.if_then_else(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \
tvm.if_then_else(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh)
out_loc : tvm.Tensor or numpy NDArray
3-D tensor of transformed location.
valid_count : tvm.Tensor or numpy NDArray
1_d tensor of valid counts for boxes.
"""
batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1]
num_anchors = cls_prob.shape[2]
ib = tvm.ir_builder.create()
p_cls_prob = ib.buffer_ptr(cls_prob)
p_loc_pred = ib.buffer_ptr(loc_pred)
p_anchor = ib.buffer_ptr(anchor)
p_valid_count = ib.buffer_ptr(valid_count)
p_out = ib.buffer_ptr(out)
with ib.for_range(0, batch_size, for_type="parallel", name="n") as n:
p_valid_count[n] = 0
with ib.for_range(0, num_anchors, name="i") as i:
box_coord = allocate((4,), loc_pred.dtype)
pred_coord = allocate((4,), loc_pred.dtype)
out_loc = output_tensor((batch_size, num_anchors, 6),
loc_pred.dtype)
valid_count = output_tensor((batch_size,), "int32")
for i in parallel(batch_size):
valid_count[i] = 0
for j in range(num_anchors):
# Find the predicted class id and probability
score = ib.allocate('float32', (1,), name="score", scope="local")
cls_id = ib.allocate('int32', (1,), name="id", scope="local")
score[0] = -1.0
cls_id[0] = 0
with ib.for_range(0, num_classes, name="j") as j:
with ib.if_scope(j > 0):
temp = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i]
cls_id[0] = tvm.if_then_else(temp > score[0], j, cls_id[0])
score[0] = tvm.max(temp, score[0])
with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)):
cls_id[0] = 0
score = -1.0
cls_id = 0
for k in range(num_classes):
if k > 0:
temp = cls_prob[i, k, j]
cls_id = k if temp > score else cls_id
score = max(temp, score)
if cls_id > 0 and score < threshold:
cls_id = 0
# [id, prob, xmin, ymin, xmax, ymax]
# Remove background, restore original id
with ib.if_scope(cls_id[0] > 0):
out_base_idx = n * num_anchors * 6 + p_valid_count[n] * 6
p_out[out_base_idx] = cls_id[0] - 1.0
p_out[out_base_idx + 1] = score[0]
offset = i * 4
p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \
p_out[out_base_idx + 5] = transform_loc(p_loc_pred, n * num_anchors * 4 + offset,
p_anchor, offset, clip, variances[0],
variances[1], variances[2], variances[3])
p_valid_count[n] += 1
return ib.get()
if cls_id > 0:
out_loc[i, valid_count[i], 0] = cls_id - 1.0
out_loc[i, valid_count[i], 1] = score
for l in range(4):
box_coord[l] = anchor[0, j, l]
pred_coord[l] = loc_pred[i, j * 4 + l]
out_coord = _hybridy_transform_loc(box_coord, pred_coord,
variances, clip)
out_loc[i, valid_count[i], 2] = out_coord[0]
out_loc[i, valid_count[i], 3] = out_coord[1]
out_loc[i, valid_count[i], 4] = out_coord[2]
out_loc[i, valid_count[i], 5] = out_coord[3]
valid_count[i] += 1
return out_loc, valid_count
@tvm.target.generic_func
def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
......@@ -240,24 +246,10 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
-------
ret : tuple of tvm.Tensor
"""
batch_size = cls_prob.shape[0]
num_anchors = anchor.shape[1]
oshape = (batch_size, num_anchors, 6)
# Define data alignment for intermediate buffer
valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype,
"valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(oshape, cls_prob.dtype, "out_buf", data_alignment=8)
valid_count, out = \
tvm.extern([(batch_size,), oshape],
[cls_prob, loc_pred, anchor],
lambda ins, outs: transform_loc_ir(
ins[0], ins[1], ins[2], outs[0], outs[1], clip, threshold, variances),
dtype=[valid_count_dtype, cls_prob.dtype],
out_buffers=[valid_count_buf, out_buf],
tag="multibox_transform_loc")
return [out, valid_count]
return hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
tvm.const(clip, "bool"),
tvm.const(threshold, "float32"),
tvm.convert(variances))
@tvm.target.generic_func
def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5,
......@@ -300,5 +292,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
out = non_max_suppression(inter_out[0], inter_out[1], -1,
nms_threshold, force_suppress, nms_topk,
return_indices=False)
return out
......@@ -8,11 +8,62 @@ import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from topi.vision import ssd, nms
from topi.vision import ssd, non_max_suppression, get_valid_counts
def verify_get_valid_counts(dshape, score_threshold):
    """Check get_valid_counts against a NumPy reference implementation.

    Parameters
    ----------
    dshape : tuple of int
        Input shape as (batch_size, num_anchor, elem_length).

    score_threshold : float
        Boxes whose score (element 1 of the last axis) is strictly
        greater than this value are counted as valid.
    """
    dtype = "float32"
    batch_size, num_anchor, elem_length = dshape
    np_data = np.random.uniform(size=dshape).astype(dtype)
    # Reference outputs: np_out1 holds the valid counts, np_out2 the
    # compacted boxes with invalid slots set to -1.
    np_out1 = np.zeros(shape=(batch_size,))
    np_out2 = np.zeros(shape=dshape).astype(dtype)
    for i in range(batch_size):
        np_out1[i] = 0
        inter_idx = 0
        for j in range(num_anchor):
            score = np_data[i, j, 1]
            if score > score_threshold:
                for k in range(elem_length):
                    np_out2[i, inter_idx, k] = np_data[i, j, k]
                np_out1[i] += 1
                inter_idx += 1
            if j >= np_out1[i]:
                for k in range(elem_length):
                    np_out2[i, j, k] = -1.0

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            data = tvm.placeholder(dshape, name="data", dtype=dtype)
            outs = get_valid_counts(data, score_threshold)
            # Use the schedule that belongs to this operator rather than
            # the unrelated multibox_prior schedule.
            s = topi.generic.schedule_get_valid_counts(outs)

        tvm_input_data = tvm.nd.array(np_data, ctx)
        # outs[0] is the int32 valid count, outs[1] the rearranged boxes.
        tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx)
        tvm_out2 = tvm.nd.array(np.zeros(np_out2.shape, dtype=dtype), ctx)
        f = tvm.build(s, [data, outs[0], outs[1]], device)
        f(tvm_input_data, tvm_out1, tvm_out2)
        tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
        tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)

    for device in ['llvm']:
        check_device(device)
def test_get_valid_counts():
    """Run verify_get_valid_counts over several shapes and score thresholds."""
    # Inputs are drawn with np.random.uniform(size=...), so thresholds of
    # 0 and -1 are expected to leave (almost) every box valid.
    verify_get_valid_counts((1, 2500, 6), 0)
    verify_get_valid_counts((1, 2500, 6), -1)
    verify_get_valid_counts((3, 1000, 6), 0.55)
    verify_get_valid_counts((16, 500, 6), 0.95)
def test_non_max_suppression():
dshape = (1, 5, 6)
indices_dshape = (1, 5)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7
......@@ -24,8 +75,9 @@ def test_nms():
[1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype)
np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, -1, -1, -1]])
def check_device(device):
ctx = tvm.context(device, 0)
......@@ -35,18 +87,27 @@ def test_nms():
print("Running on target: %s" % device)
with tvm.target.create(device):
if device == 'llvm':
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False)
indices_out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk)
else:
out = topi.cuda.nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False)
indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk)
s = topi.generic.schedule_nms(out)
indices_s = topi.generic.schedule_nms(indices_out)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f = tvm.build(s, [data, valid_count, out], device)
f(tvm_data, tvm_valid_count, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx)
f = tvm.build(indices_s, [data, valid_count, indices_out], device)
f(tvm_data, tvm_valid_count, tvm_indices_out)
tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)
for device in ['llvm']:
check_device(device)
......@@ -274,7 +335,8 @@ def test_proposal():
if __name__ == "__main__":
test_nms()
test_get_valid_counts()
test_non_max_suppression()
test_multibox_prior()
test_multibox_detection()
test_roi_align()
......
"""
Deploy Single Shot Multibox Detector(SSD) model
===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_
This article is an introductory tutorial to deploy SSD models with TVM.
We will use a GluonCV pre-trained SSD model and convert it to Relay IR.
"""
import tvm
from matplotlib import pyplot as plt
from nnvm import compiler
from nnvm.frontend import from_mxnet
from nnvm.testing.config import ctx_list
from tvm import relay
from tvm.contrib import graph_runtime
from gluoncv import model_zoo, data, utils
######################################################################
# Preliminary and Set parameters
# ------------------------------
# We should build TVM with sort support, in TVM root directory
#
# .. code-block:: bash
#
# echo "set(USE_SORT ON)" >> config.cmake
# make -j8
#
# .. note::
#
# Currently we support compiling SSD on CPU only.
# GPU support is in progress.
#
# To get the best inference performance on CPU, change the
# target argument according to your device and
# follow :ref:`tune_relay_x86` to tune for x86 CPUs or
# :ref:`tune_relay_arm` for ARM CPUs.
#
# SSD with VGG as body network is not supported yet since
# x86 conv2d schedule doesn't support dilation.
supported_model = [  # GluonCV SSD model names this tutorial lists as supported
# NOTE(review): this list is not referenced by any other statement visible in this script.
'ssd_512_resnet18_v1_voc',
'ssd_512_resnet18_v1_coco',
'ssd_512_resnet50_v1_voc',
'ssd_512_resnet50_v1_coco',
'ssd_512_resnet101_v2_voc',
'ssd_512_mobilenet1_0_voc',
'ssd_512_mobilenet1_0_coco',
]
# Model name handed to model_zoo.get_model further down.
model_name = "ssd_512_resnet50_v1_voc"
# Shape registered for the "data" input when the model is converted in compile().
# Presumably NCHW (batch, channels, height, width) -- TODO confirm.
dshape = (1, 3, 512, 512)
# NOTE(review): dtype is not read by any other statement visible in this script.
dtype = "float32"
# Iterated further down as (target, ctx) pairs.
target_list = ctx_list()
######################################################################
# Download and pre-process demo image
im_fname = utils.download('https://github.com/dmlc/web-data/blob/master/' +
'gluoncv/detection/street_small.jpg?raw=true',
path='street_small.jpg')  # the URL is assembled from two string literals
# x is later converted with .asnumpy() and fed to the runtime as the 'data' input;
# img is later passed to the plotting helper.
# Presumably short=512 resizes the image's short side to 512 -- TODO confirm.
x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)
######################################################################
# Convert and compile model for CPU.
block = model_zoo.get_model(model_name, pretrained=True)  # read by compile() and, via block.classes, when plotting
def compile(target):
    """Convert the GluonCV model to Relay and build it for ``target``.

    Reads the module-level ``block`` (the model) and ``dshape`` (the shape
    registered for the ``"data"`` input).

    Parameters
    ----------
    target
        Forwarded unchanged to ``relay.build``.

    Returns
    -------
    tuple
        ``(graph, lib, params)`` as unpacked from ``relay.build``.
    """
    input_shapes = {"data": dshape}
    net, weights = relay.frontend.from_mxnet(block, input_shapes)
    with relay.build_config(opt_level=3):
        graph, lib, built_params = relay.build(net, target, params=weights)
    return graph, lib, built_params
######################################################################
# Create TVM runtime and do inference
def run(graph, lib, params, ctx):
    """Create a graph runtime on ``ctx``, feed the demo image and run once.

    Reads the module-level ``x`` (the preprocessed image) as the ``'data'``
    input.

    Parameters
    ----------
    graph, lib, params
        The three values returned by ``compile``.
    ctx
        Context on which the runtime and the input array are created.

    Returns
    -------
    tuple
        Outputs 0, 1 and 2 of the runtime, in that order; callers unpack
        them as ``class_IDs, scores, bounding_boxs``.
    """
    # Build TVM runtime
    module = graph_runtime.create(graph, lib, ctx)
    module.set_input('data', tvm.nd.array(x.asnumpy(), ctx=ctx))
    module.set_input(**params)
    # execute
    module.run()
    # get outputs
    return tuple(module.get_output(index) for index in range(3))
# Compile and run on every target except "cuda"; the outputs of the last
# target that actually ran stay bound for the plotting step.
for target, ctx in target_list:
    if target != "cuda":
        graph, lib, params = compile(target)
        class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx)
    else:
        print("GPU not supported yet, skip.")
######################################################################
# Display result
ax = utils.viz.plot_bbox(img, bounding_boxs.asnumpy()[0], scores.asnumpy()[0],  # [0] selects the first entry of each output
class_IDs.asnumpy()[0], class_names=block.classes)
# NOTE(review): class_IDs, scores and bounding_boxs are only bound if the loop
# above ran at least one non-"cuda" target; otherwise the call above raises NameError.
plt.show()
......@@ -61,7 +61,7 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \
image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \
"cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg"
inference_symbol_folder = \
"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \
"archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment