Commit d2f29ba5 by Yao Wang, committed by Tianqi Chen

[Object Detection] Gluoncv SSD support on CPU (#2353)

parent f7eff095
......@@ -58,19 +58,42 @@ struct MultiBoxTransformLocAttrs
}
};
/*! \brief Attributes used in non_maximum_suppression operators */
struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
double overlap_threshold;
/*! \brief Attributes used in get_valid_counts operator */
struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
double score_threshold;
TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") {
TVM_ATTR_FIELD(score_threshold).set_default(0.0)
.describe("Lower limit of score for valid bounding boxes.");
}
};
/*! \brief Attributes used in non_maximum_suppression operator */
struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
int max_output_size;
double iou_threshold;
bool force_suppress;
int topk;
TVM_DECLARE_ATTRS(NMSAttrs, "relay.attrs.NMSAttrs") {
TVM_ATTR_FIELD(overlap_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
TVM_ATTR_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(topk).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
int top_k;
int id_index;
bool return_indices;
bool invalid_to_bottom;
TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
TVM_ATTR_FIELD(max_output_size).set_default(-1)
.describe("Max number of output valid boxes for each instance."
"By default all valid boxes are returned.");
TVM_ATTR_FIELD(iou_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
TVM_ATTR_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
TVM_ATTR_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
TVM_ATTR_FIELD(return_indices).set_default(true)
.describe("Whether to return box indices in input data.");
TVM_ATTR_FIELD(invalid_to_bottom).set_default(false)
.describe("Whether to move all invalid bounding boxes to the bottom.");
}
};
......
......@@ -443,17 +443,30 @@ struct MultiBoxTransformLocParam : public dmlc::Parameter<MultiBoxTransformLocPa
}
};
struct NMSParam : public dmlc::Parameter<NMSParam> {
float nms_threshold;
struct NonMaximumSuppressionParam : public dmlc::Parameter<NonMaximumSuppressionParam> {
bool return_indices;
float iou_threshold;
bool force_suppress;
int nms_topk;
DMLC_DECLARE_PARAMETER(NMSParam) {
DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5)
int top_k;
int id_index;
int max_output_size;
bool invalid_to_bottom;
DMLC_DECLARE_PARAMETER(NonMaximumSuppressionParam) {
DMLC_DECLARE_FIELD(max_output_size).set_default(-1)
.describe("Max number of output valid boxes for each instance."
"By default all valid boxes are returned.");
DMLC_DECLARE_FIELD(iou_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
DMLC_DECLARE_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
DMLC_DECLARE_FIELD(nms_topk).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
.describe("Suppress all detections regardless of class_id.");
DMLC_DECLARE_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
DMLC_DECLARE_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
DMLC_DECLARE_FIELD(return_indices).set_default(true)
.describe("Whether to return box indices in input data.");
DMLC_DECLARE_FIELD(invalid_to_bottom).set_default(false)
.describe("Whether to move all invalid bounding boxes to the bottom.");
}
};
......
......@@ -245,11 +245,11 @@ def _contrib_multibox_detection(inputs, attrs):
if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2)
nms_topk = attrs.get('nms_topk') or -1
new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances}
new_attrs1 = {'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress,
'nms_topk': int(nms_topk)}
new_attrs1 = {'return_indices': False, 'iou_threshold': float(nms_threshold),
'force_suppress': force_suppress, 'top_k': int(nms_topk)}
data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1],
inputs[2], **new_attrs0)
return _get_nnvm_op('nms')(data, valid_count, **new_attrs1)
return _get_nnvm_op('non_max_suppression')(data, valid_count, **new_attrs1)
def _elemwise_sum(inputs, _):
new_attrs = {'num_args':len(inputs)}
......
......@@ -61,20 +61,25 @@ def compute_multibox_transform_loc(attrs, inputs, _):
reg.register_pattern("multibox_detection", OpPattern.OPAQUE)
# non-maximum suppression
@reg.register_schedule("nms")
@reg.register_schedule("non_max_suppression")
def schedule_nms(_, outs, target):
"""Schedule definition of nms"""
"""Schedule definition of non_max_suppression"""
with tvm.target.create(target):
return topi.generic.schedule_nms(outs)
@reg.register_compute("nms")
@reg.register_compute("non_max_suppression")
def compute_nms(attrs, inputs, _):
"""Compute definition of nms"""
nms_threshold = attrs.get_float('nms_threshold')
"""Compute definition of non_max_suppression"""
return_indices = attrs.get_bool('return_indices')
max_output_size = attrs.get_int('max_output_size')
iou_threshold = attrs.get_float('iou_threshold')
force_suppress = attrs.get_bool('force_suppress')
nms_topk = attrs.get_int('nms_topk')
top_k = attrs.get_int('top_k')
id_index = attrs.get_int('id_index')
invalid_to_bottom = attrs.get_bool('invalid_to_bottom')
return topi.vision.nms(inputs[0], inputs[1], nms_threshold,
force_suppress, nms_topk)
return topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
reg.register_pattern("nms", OpPattern.OPAQUE)
reg.register_pattern("non_max_suppression", OpPattern.OPAQUE)
......@@ -19,11 +19,13 @@ using compiler::FTVMCompute;
using tvm::Tensor;
using tvm::Array;
DMLC_REGISTER_PARAMETER(NMSParam);
DMLC_REGISTER_PARAMETER(NonMaximumSuppressionParam);
bool NMSShape(const NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const NonMaximumSuppressionParam& param =
nnvm::get<NonMaximumSuppressionParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]";
TShape dshape = in_attrs->at(0);
TShape vshape = in_attrs->at(1);
......@@ -33,7 +35,14 @@ bool NMSShape(const NodeAttrs& attrs,
"(batch_size, num_anchors, 6).";
CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch.";
out_attrs->clear();
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape);
if (param.return_indices) {
TShape oshape = TShape(2);
oshape[0] = dshape[0];
oshape[1] = dshape[1];
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
} else {
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape);
}
return true;
}
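The branch above gives non_max_suppression two possible output layouts. A tiny Python sketch of the rule; the helper name is hypothetical and only illustrates the two cases.

def nms_output_shape(dshape, return_indices):
    # Indices output is (batch_size, num_anchors) int32; otherwise the
    # (batch_size, num_anchors, 6) box layout of the input is kept.
    batch_size, num_anchors, _ = dshape
    return (batch_size, num_anchors) if return_indices else dshape

assert nms_output_shape((1, 5, 6), True) == (1, 5)
assert nms_output_shape((1, 5, 6), False) == (1, 5, 6)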
......@@ -56,15 +65,15 @@ inline bool NMSInferLayout(const NodeAttrs& attrs,
return true;
}
NNVM_REGISTER_OP(nms)
NNVM_REGISTER_OP(non_max_suppression)
.describe(R"doc("Non-maximum suppression."
)doc" NNVM_ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NMSParam>)
.set_attr_parser(ParamParser<NonMaximumSuppressionParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<NMSParam>)
.add_arguments(NMSParam::__FIELDS__())
ParamGetAttrDict<NonMaximumSuppressionParam>)
.add_arguments(NonMaximumSuppressionParam::__FIELDS__())
.add_argument("data", "Tensor", "Input data.")
.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
......
......@@ -550,7 +550,7 @@ def test_multibox_transform_loc():
anchors = sym.Variable("anchors")
transform_loc_data, valid_count = sym.multibox_transform_loc(cls_prob=cls_prob, loc_pred=loc_preds,
anchor=anchors)
out = sym.nms(data=transform_loc_data, valid_count=valid_count)
out = sym.non_max_suppression(data=transform_loc_data, valid_count=valid_count, return_indices=False)
# Manually create test case
np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]])
......@@ -573,22 +573,22 @@ def test_multibox_transform_loc():
out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
def test_nms():
def test_non_max_suppression():
dshape = (1, 5, 6)
data = sym.Variable("data")
valid_count = sym.Variable("valid_count", dtype="int32")
nms_threshold = 0.7
iou_threshold = 0.7
force_suppress = True
nms_topk = 2
out = sym.nms(data=data, valid_count=valid_count, nms_threshold=nms_threshold,
force_suppress=force_suppress, nms_topk=nms_topk)
top_k = 2
out = sym.non_max_suppression(data=data, valid_count=valid_count, return_indices=False,
iou_threshold=iou_threshold, force_suppress=force_suppress, top_k=top_k)
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32")
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
target = "llvm"
......@@ -726,7 +726,7 @@ if __name__ == "__main__":
test_flip()
test_multibox_prior()
test_multibox_transform_loc()
test_nms()
test_non_max_suppression()
test_slice_like()
test_where()
test_argmax()
......
......@@ -315,4 +315,3 @@ if __name__ == '__main__':
test_forward_slice()
test_forward_maximum()
test_forward_minimum()
......@@ -328,13 +328,14 @@ def _mx_multibox_detection(inputs, attrs):
0.2, 0.2))
new_attrs1 = {}
new_attrs1["overlap_threshold"] = attrs.get_float("nms_threshold", 0.5)
new_attrs1["return_indices"] = False
new_attrs1["iou_threshold"] = attrs.get_float("nms_threshold", 0.5)
new_attrs1["force_suppress"] = attrs.get_bool("force_suppress", False)
new_attrs1["topk"] = attrs.get_int("nms_topk", -1)
new_attrs1["top_k"] = attrs.get_int("nms_topk", -1)
ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1],
inputs[2], **new_attrs0)
return _op.vision.nms(ret[0], ret[1], **new_attrs1)
return _op.vision.non_max_suppression(ret[0], ret[1], **new_attrs1)
def _mx_batch_dot(inputs, attrs):
......@@ -399,6 +400,49 @@ def _mx_proposal(inputs, attrs):
return _op.vision.proposal(inputs[0], inputs[1], inputs[2], **new_attrs)
def _mx_box_nms(inputs, attrs):
force_suppress = attrs.get_bool("force_suppress", False)
iou_thresh = attrs.get_float('overlap_thresh', 0.5)
top_k = attrs.get_int('topk', -1)
valid_thresh = attrs.get_float('valid_thresh', 0)
coord_start = attrs.get_int('coord_start', 2)
score_index = attrs.get_int('score_index', 1)
id_index = attrs.get_int('id_index', -1)
in_format = attrs.get_str('in_format', 'corner')
out_format = attrs.get_str('out_format', 'corner')
if coord_start != 2:
raise RuntimeError('coord_start %s is not supported.' % coord_start)
if score_index != 1:
raise RuntimeError('score_index %s is not supported.' % score_index)
if id_index != -1 and int(id_index) != 0:
raise RuntimeError('id_index %s is not supported.' % id_index)
if in_format != 'corner':
raise RuntimeError('in_format %s is not supported.' % in_format)
if out_format != 'corner':
raise RuntimeError('out_format %s is not supported.' % out_format)
ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh)
nms_out = _op.vision.non_max_suppression(ret[1],
ret[0],
iou_threshold=iou_thresh,
force_suppress=force_suppress,
top_k=top_k,
id_index=id_index,
return_indices=False,
invalid_to_bottom=True)
return nms_out
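A hypothetical end-to-end sketch of the conversion above (not taken from the commit): an MXNet symbol using box_nms, imported with relay.frontend.from_mxnet. The parameter values respect the restrictions checked in _mx_box_nms.

import mxnet as mx
from tvm import relay

boxes = mx.sym.var("data")   # expected layout (batch, num_boxes, 6)
sym = mx.sym.contrib.box_nms(boxes, overlap_thresh=0.5, topk=-1,
                             valid_thresh=0.01, id_index=0, force_suppress=False)
net, params = relay.frontend.from_mxnet(sym, shape={"data": (1, 5, 6)})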
def _mx_l2_normalize(inputs, attrs):
new_attrs = {}
mode = attrs.get_str('mode', 'instance')
if mode != 'channel':
raise RuntimeError('mode %s is not supported.' % mode)
new_attrs['eps'] = attrs.get_float('eps', 1e-10)
new_attrs['axis'] = [1]
return _op.nn.l2_normalize(inputs[0], **new_attrs)
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
......@@ -497,6 +541,7 @@ _convert_map = {
"BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm,
"LRN" : _mx_lrn,
"L2Normalization" : _mx_l2_normalize,
"slice" : _mx_slice,
"slice_like" : _mx_slice_like,
"slice_axis" : _mx_slice_axis,
......@@ -520,6 +565,7 @@ _convert_map = {
"_contrib_ROIAlign" : _mx_roi_align,
"_contrib_Proposal" : _mx_proposal,
"_contrib_MultiProposal" : _mx_proposal,
"_contrib_box_nms" : _mx_box_nms,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
......@@ -662,6 +708,8 @@ def from_mxnet(symbol,
params[k] = _nd.array(v.data().asnumpy())
data = mx.sym.Variable("data")
sym = symbol(data)
if isinstance(sym, (list, tuple)):
sym = mx.sym.Group(sym)
shape, dtype = _update_shape_dtype(shape, dtype, params)
sym = _from_mxnet_impl(sym, shape, dtype)
elif isinstance(symbol, mx.gluon.Block):
......
......@@ -525,7 +525,7 @@ def strided_slice(data, begin, end, strides=None):
The indices to begin with in the slicing.
end: list of int
Indicies indicating end of the slice.
Indices indicating end of the slice.
strides: list of int, optional
Specifies the stride values, it can be negative in that case,
......
......@@ -6,6 +6,6 @@ from .multibox import *
from .nms import *
from .rcnn import *
from .yolo import *
from . import _multibox
from . import _rcnn
from . import _yolo
from . import _vision
......@@ -54,24 +54,46 @@ reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE)
reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE)
# Get counts of valid boxes
@reg.register_schedule("vision.get_valid_counts")
def schedule_get_valid_counts(_, outs, target):
"""Schedule definition of get_valid_counts"""
with target:
return topi.generic.schedule_get_valid_counts(outs)
@reg.register_compute("vision.get_valid_counts")
def compute_get_valid_counts(attrs, inputs, _, target):
"""Compute definition of get_valid_counts"""
score_threshold = get_const_float(attrs.score_threshold)
return topi.vision.get_valid_counts(inputs[0], score_threshold)
reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)
# non-maximum suppression
@reg.register_schedule("vision.nms")
@reg.register_schedule("vision.non_max_suppression")
def schedule_nms(_, outs, target):
"""Schedule definition of nms"""
with target:
return topi.generic.schedule_nms(outs)
@reg.register_compute("vision.nms")
@reg.register_compute("vision.non_max_suppression")
def compute_nms(attrs, inputs, _, target):
"""Compute definition of nms"""
overlap_threshold = get_const_float(attrs.overlap_threshold)
return_indices = bool(get_const_int(attrs.return_indices))
max_output_size = get_const_int(attrs.max_output_size)
iou_threshold = get_const_float(attrs.iou_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress))
topk = get_const_int(attrs.topk)
top_k = get_const_int(attrs.top_k)
id_index = get_const_int(attrs.id_index)
invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
return [
topi.vision.nms(inputs[0], inputs[1], overlap_threshold,
force_suppress, topk)
topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
]
reg.register_pattern("vision.nms", OpPattern.OPAQUE)
reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE)
"""Non-maximum suppression operations."""
from __future__ import absolute_import as _abs
from . import _make
from ...expr import TupleWrapper
def nms(data,
valid_count,
overlap_threshold=0.5,
force_suppress=False,
topk=-1):
def get_valid_counts(data,
score_threshold):
"""Get valid count of bounding boxes given a score threshold.
Also moves valid boxes to the top of input data.
Parameters
----------
data : relay.Expr
Input data. 3-D tensor with shape [batch_size, num_anchors, 6].
score_threshold : optional, float
Lower limit of score for valid bounding boxes.
Returns
-------
valid_count : relay.Expr
1-D tensor for valid number of boxes.
out_tensor : relay.Expr
Rearranged data tensor.
"""
return TupleWrapper(_make.get_valid_counts(data, score_threshold), 2)
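A minimal usage sketch for this wrapper, assuming a TVM build that includes this change; it mirrors the Relay tests later in the diff.

from tvm import relay

x = relay.var("x", shape=(1, 2500, 6), dtype="float32")
z = relay.vision.get_valid_counts(x, score_threshold=0.5)
valid_count, sorted_boxes = z[0], z[1]   # TupleWrapper: valid box count, rearranged data
func = relay.Function([x], z.astuple())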
def non_max_suppression(data,
valid_count,
max_output_size=-1,
iou_threshold=0.5,
force_suppress=False,
top_k=-1,
id_index=0,
return_indices=True,
invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection.
Parameters
......@@ -19,18 +48,33 @@ def nms(data,
valid_count : relay.Expr
1-D tensor for valid number of boxes.
overlap_threshold : float, optional
max_output_size : int, optional
Max number of output valid boxes for each instance.
By default all valid boxes are returned.
iou_threshold : float, optional
Non-maximum suppression threshold.
force_suppress : bool, optional
Suppress all detections regardless of class_id.
topk : int, optional
top_k : int, optional
Keep maximum top k detections before nms, -1 for no limit.
id_index : int, optional
Index of the class categories, -1 to disable.
return_indices : bool, optional
Whether to return box indices in input data.
invalid_to_bottom : bool, optional
Whether to move all invalid bounding boxes to the bottom.
Returns
-------
out : relay.Expr
3-D tensor with shape [batch_size, num_anchors, 6].
"""
return _make.nms(data, valid_count, overlap_threshold, force_suppress, topk)
return _make.non_max_suppression(data, valid_count, max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
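Likewise, a minimal sketch of calling the Relay wrapper above; keyword names follow the signature in this file, and the values mirror the tests later in the diff.

from tvm import relay

boxes = relay.var("boxes", shape=(1, 5, 6), dtype="float32")   # [id, score, left, top, right, bottom]
count = relay.var("count", shape=(1,), dtype="int32")
out = relay.vision.non_max_suppression(boxes, count, iou_threshold=0.5,
                                       force_suppress=True, top_k=2,
                                       return_indices=False)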
......@@ -1516,6 +1516,16 @@ RELAY_REGISTER_OP("broadcast_to_like")
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
// Adapter function to make int array.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
for (size_t i = 0; i < arr.size(); ++i) {
CHECK(!arr[i].defined() || arr[i].as<IntImm>())
<< "Expect an int array";
}
return Array<Integer>(arr.node_);
}
// strided_slice
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
bool StridedSliceRel(const Array<Type>& types,
......@@ -1870,15 +1880,6 @@ Expr MakeSliceLike(Expr data,
return CallNode::make(op, {data, shape_like}, Attrs(attrs), {});
}
// Adapter function to make int array.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
for (size_t i = 0; i < arr.size(); ++i) {
CHECK(!arr[i].defined() || arr[i].as<IntImm>())
<< "Expect an int array";
}
return Array<Integer>(arr.node_);
}
Array<Tensor> SliceLikeCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
......
......@@ -70,8 +70,10 @@ RELAY_REGISTER_OP("vision.multibox_prior")
TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs);
bool MultiBoxTransformLocRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs, const TypeReporter& reporter) {
bool MultiBoxTransformLocRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* cls_prob = types[0].as<TensorTypeNode>();
......
......@@ -9,7 +9,54 @@
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(NMSAttrs);
TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs);
bool GetValidCountRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
const auto& dshape = data->shape;
CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
std::vector<IndexExpr> oshape({data->shape[0]});
std::vector<Type> fields;
fields.push_back(TensorTypeNode::make(oshape, Int(32)));
fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
// assign output type
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
return true;
}
Expr MakeGetValidCounts(Expr data,
double score_threshold) {
auto attrs = make_node<GetValidCountsAttrs>();
attrs->score_threshold = score_threshold;
static const Op& op = Op::Get("vision.get_valid_counts");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.vision._make.get_valid_counts")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeGetValidCounts, args, rv);
});
RELAY_REGISTER_OP("vision.get_valid_counts")
.describe(R"doc(Get valid count of bounding boxes given
a score threshold. Also moves valid boxes to the top of
input data.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(5)
.add_type_rel("GetValidCount", GetValidCountRel);
TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs);
bool NMSRel(const Array<Type>& types,
int num_inputs,
......@@ -18,39 +65,56 @@ bool NMSRel(const Array<Type>& types,
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* valid_count = types[1].as<TensorTypeNode>();
const NonMaximumSuppressionAttrs* param =
attrs.as<NonMaximumSuppressionAttrs>();
const auto& dshape = data->shape;
const auto& vshape = valid_count->shape;
CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
CHECK_EQ(vshape.size(), 1) << "Input valid count should be 1-D.";
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
if (param->return_indices) {
std::vector<IndexExpr> oshape({dshape[0], dshape[1]});
reporter->Assign(types[2], TensorTypeNode::make(oshape, Int(32)));
} else {
reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
}
return true;
}
Expr MakeNMS(Expr data,
Expr valid_count,
double overlap_threshold,
int max_output_size,
double iou_threshold,
bool force_suppress,
int topk) {
auto attrs = make_node<NMSAttrs>();
attrs->overlap_threshold = overlap_threshold;
int top_k,
int id_index,
bool return_indices,
bool invalid_to_bottom) {
auto attrs = make_node<NonMaximumSuppressionAttrs>();
attrs->max_output_size = max_output_size;
attrs->iou_threshold = iou_threshold;
attrs->force_suppress = force_suppress;
attrs->topk = topk;
static const Op& op = Op::Get("vision.nms");
attrs->top_k = top_k;
attrs->id_index = id_index;
attrs->return_indices = return_indices;
attrs->invalid_to_bottom = invalid_to_bottom;
static const Op& op = Op::Get("vision.non_max_suppression");
return CallNode::make(op, {data, valid_count}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.vision._make.nms")
TVM_REGISTER_API("relay.op.vision._make.non_max_suppression")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeNMS, args, rv);
runtime::detail::unpack_call<Expr, 9>(MakeNMS, args, rv);
});
RELAY_REGISTER_OP("vision.nms")
.describe(R"doc("Non-maximum suppression."
RELAY_REGISTER_OP("vision.non_max_suppression")
.describe(R"doc(Non-maximum suppression. The input boxes should
be in the format of [class_id, score, left, top, right, bottom].
Set id_index to be -1 to ignore class_id axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.")
......
......@@ -374,6 +374,11 @@ def test_forward_slice_like():
verify((3, 4), (2, 3), (0))
verify((3, 4), (2, 3), (-1))
def test_forward_l2_normalize():
data = mx.sym.var('data')
mx_sym = mx.sym.L2Normalization(data, mode="channel")
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5))
if __name__ == '__main__':
test_forward_mlp()
......@@ -401,5 +406,6 @@ if __name__ == '__main__':
test_forward_broadcast_ops()
test_forward_elemwise_ops()
test_forward_scalar_ops()
test_forward_slice_axis()
test_forward_slice_like()
test_forward_slice_axis()
test_forward_l2_normalize()
......@@ -2,6 +2,7 @@
"""
import numpy as np
import tvm
import topi.testing
from tvm import relay
from tvm.relay.testing import ctx_list
import topi
......
......@@ -135,56 +135,107 @@ def test_multibox_prior():
verify_multibox_prior(x, dshape, ref_res, clip=False, check_type_only=True)
def test_nms():
def verify_nms(x0_data, x1_data, dshape, ref_res, valid_count,
overlap_threshold=0.5, force_suppress=False, topk=-1,
def test_get_valid_counts():
def verify_get_valid_counts(dshape, score_threshold):
dtype = "float32"
batch_size, num_anchor, elem_length = dshape
np_data = np.random.uniform(size=dshape).astype(dtype)
np_out1 = np.zeros(shape=(batch_size,))
np_out2 = np.zeros(shape=dshape).astype(dtype)
for i in range(batch_size):
np_out1[i] = 0
inter_idx = 0
for j in range(num_anchor):
score = np_data[i, j, 1]
if score >= score_threshold:
for k in range(elem_length):
np_out2[i, inter_idx, k] = np_data[i, j, k]
np_out1[i] += 1
inter_idx += 1
if j >= np_out1[i]:
for k in range(elem_length):
np_out2[i, j, k] = -1
x = relay.var("x", relay.ty.TensorType(dshape, dtype))
z = relay.vision.get_valid_counts(x, score_threshold)
assert "score_threshold" in z.astext()
func = relay.Function([x], z.astuple())
func = relay.ir_pass.infer_type(func)
ctx_list = [("llvm", tvm.cpu(0))]
for target, ctx in ctx_list:
intrp = relay.create_executor("debug", ctx=ctx, target=target)
out = intrp.evaluate(func)(np_data)
tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3)
verify_get_valid_counts((1, 2500, 6), 0)
verify_get_valid_counts((1, 2500, 6), -1)
verify_get_valid_counts((3, 1000, 6), 0.55)
verify_get_valid_counts((16, 500, 6), 0.95)
def test_non_max_suppression():
def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res,
iou_threshold=0.5, force_suppress=False, top_k=-1,
check_type_only=False):
x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int"))
z = relay.vision.nms(x0, x1, overlap_threshold, force_suppress, topk)
assert "overlap_threshold" in z.astext()
z = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k, return_indices=False)
z_indices = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k)
assert "iou_threshold" in z.astext()
assert "iou_threshold" in z_indices.astext()
zz = relay.ir_pass.infer_type(z)
zz_indices = relay.ir_pass.infer_type(z_indices)
assert zz.checked_type == relay.ty.TensorType(dshape, "float32")
assert zz_indices.checked_type == relay.ty.TensorType((dshape[0], dshape[1]), "int32")
if check_type_only:
return
func = relay.Function([x0, x1], z)
func = relay.ir_pass.infer_type(func)
func_indices = relay.Function([x0, x1], z_indices)
func_indices = relay.ir_pass.infer_type(func_indices)
ctx_list = [("llvm", tvm.cpu(0))]
for target, ctx in ctx_list:
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x0_data, x1_data)
op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
tvm.testing.assert_allclose(op_indices_res1.asnumpy(), ref_indices_res, rtol=1e-5)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res2 = intrp2.evaluate(func)(x0_data, x1_data)
op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
tvm.testing.assert_allclose(op_indices_res2.asnumpy(), ref_indices_res, rtol=1e-5)
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32")
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, -1, -1, -1]])
num_anchors = 5
dshape = (tvm.var("n"), num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0],
force_suppress=True, topk=2, check_type_only=True)
verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result,
force_suppress=True, top_k=2, check_type_only=True)
dshape = (1, num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0],
force_suppress=True, topk=2, check_type_only=False)
verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result,
force_suppress=True, top_k=2, check_type_only=False)
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[1, 0.7, 30, 60, 50, 80], [-1, 0.9, 35, 61, 52, 79],
[1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, 1, -1, -1]])
dshape = (tvm.var("n"), num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0],
check_type_only=True)
verify_nms(np_data, np_valid_count, dshape, np_result,
np_indices_result, check_type_only=True)
dshape = (1, num_anchors, 6)
verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0],
topk=3)
verify_nms(np_data, np_valid_count, dshape, np_result,
np_indices_result, top_k=3)
def test_multibox_transform_loc():
......@@ -226,7 +277,7 @@ def test_multibox_transform_loc():
assert ret.checked_type == ref_type
nms = relay.vision.nms(mtl[0], mtl[1])
nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False)
func = relay.Function([cls_prob, loc_pred, anchors], nms)
func = relay.ir_pass.infer_type(func)
ctx_list = [("llvm", tvm.cpu(0))]
......@@ -411,8 +462,9 @@ if __name__ == "__main__":
test_resize()
test_multibox_prior()
test_multibox_transform_loc()
test_nms()
test_get_valid_counts()
test_roi_align()
test_proposal()
test_yolo_reorg_infer_shape()
test_yolo_reorg()
test_non_max_suppression()
......@@ -30,7 +30,12 @@ inline Tensor l2_normalize(const Tensor& data,
const Array<Integer>& axis,
std::string name = "tensor",
std::string tag = "l2_normalize") {
CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input";
for (size_t i = 0; i < axis.size(); ++i) {
int ax = topi::detail::GetConstInt(axis[i]);
CHECK_LT(ax, data->shape.size()) <<
"Axis " << ax << " exceeds input data dim " <<
data->shape.size();
}
auto input_shape = data->shape;
Tensor dot_value = topi::power(data, static_cast<float>(2.0));
Tensor sum_value = topi::sum(dot_value, axis, true);
......
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Non-maximum suppression operator"""
import math
import tvm
from tvm import api
from topi.vision import nms
from topi.vision import non_max_suppression
from ..util import get_const_tuple
def sort_ir(data, index, output):
......@@ -181,13 +181,14 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n
return body
@nms.register(["cuda", "gpu"])
def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1):
@non_max_suppression.register(["cuda", "gpu"])
def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress=False,
topk=-1, id_index=0, invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection.
Parameters
----------
data: tvm.Tensor
data : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
......@@ -195,15 +196,24 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
nms_threshold : float
return_indices : boolean
Whether to return box indices in input data.
iou_threshold : optional, float
Non-maximum suppression threshold.
force_suppress : boolean
force_suppress : optional, boolean
Whether to suppress all detections regardless of class_id.
nms_topk : int
topk : optional, int
Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
Index of the class categories, -1 to disable.
invalid_to_bottom : optional, boolean
Whether to move all invalid bounding boxes to the bottom.
Returns
-------
out : tvm.Tensor
......@@ -216,14 +226,13 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
# An example to use nms
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder(
(dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
iou_threshold = 0.7
force_suppress = True
nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(size=dshape).astype("float32")
np_valid_count = np.array([4]).astype("int32")
topk = -1
out = nms(data, valid_count, iou_threshold, force_suppress, topk)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
......@@ -263,8 +272,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
tvm.extern(data.shape,
[data, sort_tensor, valid_count],
lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], nms_threshold,
force_suppress, nms_topk),
ins[0], ins[1], ins[2], outs[0], iou_threshold,
force_suppress, topk),
dtype="float32",
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
tag="nms")
......
......@@ -11,7 +11,7 @@ import topi
from topi.vision.ssd import multibox_prior
from topi.vision.ssd import multibox_detection
from topi.vision.ssd import multibox_transform_loc
from ..nms import nms
from ..nms import non_max_suppression
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
......@@ -437,6 +437,6 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = nms(
out = non_max_suppression(
inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
return out
......@@ -162,3 +162,20 @@ def schedule_proposal(outs):
scheduled_ops.append(op)
traverse(outs[0].op)
return s
@generic.schedule_get_valid_counts.register(["cuda", "gpu"])
def schedule_get_valid_counts(outs):
"""Schedule for get_valid_counts operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of get_valid_counts
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs)
......@@ -37,6 +37,23 @@ def schedule_reorg(outs):
return cpp.generic.default_schedule(cpp_target, outs, False)
@tvm.target.generic_func
def schedule_get_valid_counts(outs):
"""Schedule for get_valid_counts
Parameters
----------
outs: Array of Tensor
The computation graph description of get_valid_counts
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_nms(outs):
"""Schedule for non-maximum suppression
......
......@@ -20,3 +20,4 @@ from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python
from .batch_matmul import batch_matmul
from .slice_axis_python import slice_axis_python
"""Slice axis in python"""
def slice_axis_python(data, axis, begin, end=None):
"""Slice input array along specific axis.
Parameters
----------
data : numpy.ndarray
The source array to be sliced.
axis : int
Axis to be sliced.
begin: int
The index to begin with in the slicing.
end: int, optional
The index indicating end of the slice.
Returns
-------
ret : numpy.ndarray
The computed result.
"""
dshape = data.shape
if axis < 0:
axis += len(dshape)
if begin < 0:
begin += dshape[axis]
if end <= 0:
end += dshape[axis]
slc = [slice(None)] * len(dshape)
slc[axis] = slice(begin, end)
return data[tuple(slc)]
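A quick usage sketch of the helper above, for illustration only.

import numpy as np
a = np.arange(12).reshape(3, 4)
print(slice_axis_python(a, axis=1, begin=1, end=3))   # columns 1 and 2 of every row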
......@@ -8,11 +8,62 @@ import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from topi.vision import ssd, nms
from topi.vision import ssd, non_max_suppression, get_valid_counts
def verify_get_valid_counts(dshape, score_threshold):
dtype = "float32"
batch_size, num_anchor, elem_length = dshape
np_data = np.random.uniform(size=dshape).astype(dtype)
np_out1 = np.zeros(shape=(batch_size,))
np_out2 = np.zeros(shape=dshape).astype(dtype)
for i in range(batch_size):
np_out1[i] = 0
inter_idx = 0
for j in range(num_anchor):
score = np_data[i, j, 1]
if score > score_threshold:
for k in range(elem_length):
np_out2[i, inter_idx, k] = np_data[i, j, k]
np_out1[i] += 1
inter_idx += 1
if j >= np_out1[i]:
for k in range(elem_length):
np_out2[i, j, k] = -1.0
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
data = tvm.placeholder(dshape, name="data", dtype=dtype)
outs = get_valid_counts(data, score_threshold)
s = topi.generic.schedule_multibox_prior(outs)
tvm_input_data = tvm.nd.array(np_data, ctx)
tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx)
tvm_out2 = tvm.nd.array(np.zeros(np_out2.shape, dtype=dtype), ctx)
f = tvm.build(s, [data, outs[0], outs[1]], device)
f(tvm_input_data, tvm_out1, tvm_out2)
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
def test_nms():
for device in ['llvm']:
check_device(device)
def test_get_valid_counts():
verify_get_valid_counts((1, 2500, 6), 0)
verify_get_valid_counts((1, 2500, 6), -1)
verify_get_valid_counts((3, 1000, 6), 0.55)
verify_get_valid_counts((16, 500, 6), 0.95)
def test_non_max_suppression():
dshape = (1, 5, 6)
indices_dshape = (1, 5)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7
......@@ -24,8 +75,9 @@ def test_nms():
[1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype)
np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, -1, -1, -1]])
def check_device(device):
ctx = tvm.context(device, 0)
......@@ -35,18 +87,27 @@ def test_nms():
print("Running on target: %s" % device)
with tvm.target.create(device):
if device == 'llvm':
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False)
indices_out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk)
else:
out = topi.cuda.nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False)
indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk)
s = topi.generic.schedule_nms(out)
indices_s = topi.generic.schedule_nms(indices_out)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f = tvm.build(s, [data, valid_count, out], device)
f(tvm_data, tvm_valid_count, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx)
f = tvm.build(indices_s, [data, valid_count, indices_out], device)
f(tvm_data, tvm_valid_count, tvm_indices_out)
tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)
for device in ['llvm']:
check_device(device)
......@@ -274,7 +335,8 @@ def test_proposal():
if __name__ == "__main__":
test_nms()
test_get_valid_counts()
test_non_max_suppression()
test_multibox_prior()
test_multibox_detection()
test_roi_align()
......
"""
Deploy Single Shot Multibox Detector (SSD) model
===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_
This article is an introductory tutorial to deploy SSD models with TVM.
We will use a GluonCV pre-trained SSD model and convert it to Relay IR.
"""
import tvm
from matplotlib import pyplot as plt
from nnvm import compiler
from nnvm.frontend import from_mxnet
from nnvm.testing.config import ctx_list
from tvm import relay
from tvm.contrib import graph_runtime
from gluoncv import model_zoo, data, utils
######################################################################
# Preliminaries and set parameters
# ------------------------------
# We should build TVM with sort support; from the TVM root directory:
#
# .. code-block:: bash
#
# echo "set(USE_SORT ON)" > config.mk
# make -j8
#
# .. note::
#
# Currently we support compiling SSD on CPU only.
# GPU support is in progress.
#
# To get the best inference performance on CPU, change the
# target argument according to your device and follow
# :ref:`tune_relay_x86` to tune x86 CPUs and
# :ref:`tune_relay_arm` for ARM CPUs.
#
# SSD with VGG as the body network is not supported yet since
# the x86 conv2d schedule doesn't support dilation.
supported_model = [
'ssd_512_resnet18_v1_voc',
'ssd_512_resnet18_v1_coco',
'ssd_512_resnet50_v1_voc',
'ssd_512_resnet50_v1_coco',
'ssd_512_resnet101_v2_voc',
'ssd_512_mobilenet1_0_voc',
'ssd_512_mobilenet1_0_coco',
]
model_name = "ssd_512_resnet50_v1_voc"
dshape = (1, 3, 512, 512)
dtype = "float32"
target_list = ctx_list()
######################################################################
# Download and pre-process demo image
im_fname = utils.download('https://github.com/dmlc/web-data/blob/master/' +
'gluoncv/detection/street_small.jpg?raw=true',
path='street_small.jpg')
x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)
######################################################################
# Convert and compile model for CPU.
block = model_zoo.get_model(model_name, pretrained=True)
def compile(target):
net, params = relay.frontend.from_mxnet(block, {"data": dshape})
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(net, target, params=params)
return graph, lib, params
######################################################################
# Create TVM runtime and do inference
def run(graph, lib, params, ctx):
# Build TVM runtime
m = graph_runtime.create(graph, lib, ctx)
tvm_input = tvm.nd.array(x.asnumpy(), ctx=ctx)
m.set_input('data', tvm_input)
m.set_input(**params)
# execute
m.run()
# get outputs
class_IDs, scores, bounding_boxs = m.get_output(0), m.get_output(1), m.get_output(2)
return class_IDs, scores, bounding_boxs
for target, ctx in target_list:
if target == "cuda":
print("GPU not supported yet, skip.")
continue
graph, lib, params = compile(target)
class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx)
######################################################################
# Display result
ax = utils.viz.plot_bbox(img, bounding_boxs.asnumpy()[0], scores.asnumpy()[0],
class_IDs.asnumpy()[0], class_names=block.classes)
plt.show()
......@@ -61,7 +61,7 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \
image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \
"cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg"
inference_symbol_folder = \
"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \
"archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
......