Commit c418e916 by Yao Wang Committed by Tianqi Chen

SSD support in NNVM (#1214)

parent f969a689
......@@ -319,6 +319,55 @@ struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
}
};
struct MultiBoxPriorParam : public dmlc::Parameter<MultiBoxPriorParam> {
Tuple<float> sizes;
Tuple<float> ratios;
Tuple<float> steps;
Tuple<float> offsets;
bool clip;
DMLC_DECLARE_PARAMETER(MultiBoxPriorParam) {
DMLC_DECLARE_FIELD(sizes).set_default(Tuple<float>({1.0}))
.describe("List of sizes of generated MultiBoxPriores.");
DMLC_DECLARE_FIELD(ratios).set_default(Tuple<float>({1.0}))
.describe("List of aspect ratios of generated MultiBoxPriores.");
DMLC_DECLARE_FIELD(steps).set_default(Tuple<float>({-1.0, -1.0}))
.describe("Priorbox step across y and x, -1 for auto calculation.");
DMLC_DECLARE_FIELD(offsets).set_default(Tuple<float>({0.5, 0.5}))
.describe("Priorbox center offsets, y and x respectively.");
DMLC_DECLARE_FIELD(clip).set_default(false)
.describe("Whether to clip out-of-boundary boxes.");
}
};
struct MultiBoxTransformLocParam : public dmlc::Parameter<MultiBoxTransformLocParam> {
bool clip;
float threshold;
Tuple<float> variances;
DMLC_DECLARE_PARAMETER(MultiBoxTransformLocParam) {
DMLC_DECLARE_FIELD(clip).set_default(true)
.describe("Clip out-of-boundary boxes.");
DMLC_DECLARE_FIELD(threshold).set_default(0.01)
.describe("Threshold to be a positive prediction.");
DMLC_DECLARE_FIELD(variances).set_default(Tuple<float>{0.1, 0.1, 0.2, 0.2})
.describe("Variances to be decoded from box regression output.");
}
};
struct NMSParam : public dmlc::Parameter<NMSParam> {
float nms_threshold;
bool force_suppress;
int nms_topk;
DMLC_DECLARE_PARAMETER(NMSParam) {
DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
DMLC_DECLARE_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
DMLC_DECLARE_FIELD(nms_topk).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
}
};
} // namespace top
} // namespace nnvm
......
......@@ -71,7 +71,7 @@ def _batch_norm(inputs, attrs):
new_attrs['axis'] = attrs.get('axis', 1)
new_attrs['epsilon'] = attrs.get('eps', 0.001)
new_attrs['center'] = True
new_attrs['scale'] = True
new_attrs['scale'] = not _parse_bool_str(attrs, 'fix_gamma', default="False")
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _concat(inputs, attrs):
......@@ -195,6 +195,12 @@ def _split(inputs, attrs):
new_attrs['axis'] = attrs.get('axis', 1)
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _softmax_activation(inputs, attrs):
op_name, new_attrs = 'softmax', {}
mode = attrs.get('mode', 'instance')
new_attrs['axis'] = 0 if mode == 'instance' else 1
return _get_nnvm_op(op_name)(inputs[0], **new_attrs)
def _softmax_output(inputs, attrs):
op_name, new_attrs = 'softmax', {}
if _parse_bool_str(attrs, 'multi_output'):
......@@ -212,6 +218,25 @@ def _clip(inputs, attrs):
new_attrs['a_max'] = _required_attr(attrs, 'a_max')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _contrib_multibox_detection(inputs, attrs):
clip = _parse_bool_str(attrs, 'clip', default='True')
threshold = attrs.get('threshold') or 0.01
nms_threshold = attrs.get('nms_threshold') or 0.5
force_suppress = _parse_bool_str(attrs, 'force_suppress', default='False')
variances = tuple([float(x.strip()) for x in attrs.get('variances').strip('()').split(',')]) \
if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2)
nms_topk = attrs.get('nms_topk') or -1
new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances}
new_attrs1 = {'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress,
'nms_topk': int(nms_topk)}
data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1],
inputs[2], **new_attrs0)
return _get_nnvm_op('nms')(data, valid_count, **new_attrs1)
def _elemwise_sum(inputs, _):
new_attrs = {'num_args':len(inputs)}
return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
......@@ -224,12 +249,15 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
_convert_map = {
'_copy' : _rename('copy'),
'_div_scalar' : _rename('__div_scalar__'),
'_minus_scalar' : _rename('__sub_scalar__'),
'_mul_scalar' : _rename('__mul_scalar__'),
'_plus_scalar' : _rename('__add_scalar__'),
'_rdiv_scalar' : _rename('__rdiv_scalar__'),
'_rminus_scalar': _rename('__rsub_scalar__'),
'_contrib_MultiBoxPrior' : _rename('multibox_prior'),
'_contrib_MultiBoxDetection' : _contrib_multibox_detection,
'Activation' : _activations,
'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm,
......@@ -248,7 +276,9 @@ _convert_map = {
'SliceChannel' : _split,
'split' : _split,
'Softmax' : _rename('softmax'),
'SoftmaxActivation' : _softmax_activation,
'SoftmaxOutput' : _softmax_output,
'add_n' : _elemwise_sum,
'concat' : _concat,
'max_axis' : _rename('max'),
'min_axis' : _rename('min'),
......
# pylint: disable=invalid-name, no-member, import-error, no-name-in-module, global-variable-undefined, bare-except
"""Helper utility for downloading"""
from __future__ import print_function
from __future__ import absolute_import as _abs
import os
import sys
import time
import urllib
import requests
if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2
def _download_progress(count, block_size, total_size):
"""Show the download progress.
"""
global start_time
if count == 0:
start_time = time.time()
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = int(count * block_size * 100 / total_size)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()
def download(url, path, overwrite=False, size_compare=False):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
Parameters
----------
url : str
Download url.
path : str
Local file path to save downloaded file
overwrite : bool, optional
Whether to overwrite existing file
size_compare : bool, optional
Whether to do size compare to check downloaded file.
"""
if os.path.isfile(path) and not overwrite:
if size_compare:
file_size = os.path.getsize(path)
res_head = requests.head(url)
res_get = requests.get(url, stream=True)
if 'Content-Length' not in res_head.headers:
res_get = urllib2.urlopen(url)
url_file_size = int(res_get.headers['Content-Length'])
if url_file_size != file_size:
print("exist file got corrupted, downloading %s file freshly..." % path)
download(url, path, True, False)
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path, reporthook=_download_progress)
print('')
except:
urllib.urlretrieve(url, path, reporthook=_download_progress)
......@@ -108,7 +108,7 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape,
num_unit = len(units)
assert num_unit == num_stages
data = sym.Variable(name='data')
data = sym.batch_norm(data=data, epsilon=2e-5, name='bn_data')
data = sym.batch_norm(data=data, epsilon=2e-5, scale=False, name='bn_data')
(_, height, _) = image_shape
if height <= 32: # such as cifar10
body = sym.conv2d(
......
......@@ -83,6 +83,21 @@ class AttrDict(object):
"""
return int(self[key])
def get_float_tuple(self, key):
"""Get tuple of float from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
tuple : tuple of float
The result tuple
"""
return tuple(float(x) for x in self[key][1:-1].split(",") if x)
def get_float(self, key):
"""Get float from attr dict
......
# pylint: disable=invalid-name, unused-argument
"""Definition of nn ops"""
from __future__ import absolute_import
import topi
import tvm
import topi
from . import registry as reg
from .registry import OpPattern
......@@ -38,3 +37,62 @@ def schedule_region(attrs, outs, target):
return topi.generic.vision.schedule_region(outs)
reg.register_pattern("yolo2_region", OpPattern.OPAQUE)
# multibox_prior
@reg.register_schedule("multibox_prior")
def schedule_multibox_prior(_, outs, target):
"""Schedule definition of multibox_prior"""
with tvm.target.create(target):
return topi.generic.schedule_multibox_prior(outs)
@reg.register_compute("multibox_prior")
def compute_multibox_prior(attrs, inputs, _):
"""Compute definition of multibox_prior"""
sizes = attrs.get_float_tuple('sizes')
ratios = attrs.get_float_tuple('ratios')
steps = attrs.get_float_tuple('steps')
offsets = attrs.get_float_tuple('offsets')
clip = attrs.get_bool('clip')
return topi.vision.ssd.multibox_prior(inputs[0], sizes, ratios,
steps, offsets, clip)
reg.register_pattern("multibox_prior", OpPattern.OPAQUE)
# multibox_transform_loc
@reg.register_schedule("multibox_transform_loc")
def schedule_multibox_transform_loc(_, outs, target):
"""Schedule definition of multibox_detection"""
with tvm.target.create(target):
return topi.generic.schedule_multibox_transform_loc(outs)
@reg.register_compute("multibox_transform_loc")
def compute_multibox_transform_loc(attrs, inputs, _):
"""Compute definition of multibox_detection"""
clip = attrs.get_bool('clip')
threshold = attrs.get_float('threshold')
variance = attrs.get_float_tuple('variances')
return topi.vision.ssd.multibox_transform_loc(inputs[0], inputs[1], inputs[2],
clip, threshold, variance)
reg.register_pattern("multibox_detection", OpPattern.OPAQUE)
# non-maximum suppression
@reg.register_schedule("nms")
def schedule_nms(_, outs, target):
"""Schedule definition of nms"""
with tvm.target.create(target):
return topi.generic.schedule_nms(outs)
@reg.register_compute("nms")
def compute_nms(attrs, inputs, _):
"""Compute definition of nms"""
nms_threshold = attrs.get_float('nms_threshold')
force_suppress = attrs.get_bool('force_suppress')
nms_topk = attrs.get_int('nms_topk')
return topi.vision.nms(inputs[0], inputs[1], nms_threshold,
force_suppress, nms_topk)
reg.register_pattern("nms", OpPattern.OPAQUE)
/*!
* Copyright (c) 2017 by Contributors
* \file nms.cc
* \brief Property def of SSD non-maximum suppression operator.
*/
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/op.h>
#include <nnvm/top/nn.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"
namespace nnvm {
namespace top {
using compiler::FTVMCompute;
using tvm::Tensor;
using tvm::Array;
DMLC_REGISTER_PARAMETER(NMSParam);
bool NMSShape(const NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]";
TShape dshape = in_attrs->at(0);
TShape vshape = in_attrs->at(1);
CHECK_EQ(dshape.ndim(), 3U) << "Input data should be 3-D.";
CHECK_EQ(vshape.ndim(), 1U) << "Input valid count should be 1-D.";
CHECK_EQ(dshape[2], 6U) << "Data input should have shape "
"(batch_size, num_anchors, 6).";
CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch.";
out_attrs->clear();
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape);
return true;
}
inline bool NMSInferType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(0));
return true;
}
inline bool NMSInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
static const Layout kNCHW("NCHW");
CHECK_EQ(ilayouts->size(), 2U);
CHECK_EQ(olayouts->size(), 1U);
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kNCHW);
NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kNCHW);
return true;
}
NNVM_REGISTER_OP(nms)
.describe(R"doc("Non-maximum suppression."
)doc" NNVM_ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NMSParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<NMSParam>)
.add_arguments(NMSParam::__FIELDS__())
.add_argument("data", "Tensor", "Input data.")
.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "valid_count"};
})
.set_attr<FInferShape>("FInferShape", NMSShape)
.set_attr<FInferType>("FInferType", NMSInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", NMSInferLayout)
.set_support_level(4);
} // namespace top
} // namespace nnvm
/*!
* Copyright (c) 2017 by Contributors
* \file multibox_op.cc
* \brief Property def of SSD multibox related operators.
*/
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/op.h>
#include <nnvm/top/nn.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include "../../op_common.h"
#include "../../elemwise_op_common.h"
namespace nnvm {
namespace top {
using compiler::FTVMCompute;
using tvm::Tensor;
using tvm::Array;
DMLC_REGISTER_PARAMETER(MultiBoxPriorParam);
bool MultiBoxPriorShape(const NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const MultiBoxPriorParam& param = nnvm::get<MultiBoxPriorParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U) << "Inputs: [data]" << in_attrs->size();
TShape dshape = in_attrs->at(0);
CHECK_GE(dshape.ndim(), 4U) << "Input data should be 4D: "
"[batch, channel, height, width]";
int in_height = dshape[2];
CHECK_GT(in_height, 0) << "Input height should > 0";
int in_width = dshape[3];
CHECK_GT(in_width, 0) << "Input width should > 0";
// since input sizes are same in each batch, we could share MultiBoxPrior
TShape oshape = TShape(3);
int num_sizes = param.sizes.ndim();
int num_ratios = param.ratios.ndim();
oshape[0] = 1;
oshape[1] = in_height * in_width * (num_sizes + num_ratios - 1);
oshape[2] = 4;
CHECK_EQ(param.steps.ndim(), 2) << "Step ndim must be 2: (step_y, step_x)";
CHECK_GE(param.steps[0] * param.steps[1], 0) << "Must specify both "
"step_y and step_x";
out_attrs->clear();
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
return true;
}
inline bool MultiBoxPriorLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
static const Layout kNCHW("NCHW");
CHECK_EQ(ilayouts->size(), 1U);
CHECK_EQ(olayouts->size(), 1U);
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kNCHW);
return true;
}
NNVM_REGISTER_OP(multibox_prior)
.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios."
)doc" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<MultiBoxPriorParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<MultiBoxPriorParam>)
.add_arguments(MultiBoxPriorParam::__FIELDS__())
.add_argument("data", "Tensor", "Input data")
.set_attr<FInferShape>("FInferShape", MultiBoxPriorShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", MultiBoxPriorLayout)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("zeros_like", n->attrs.name + "_zero_grad",
{n->inputs[0]}),
ograds[0]
};
})
.set_support_level(4);
DMLC_REGISTER_PARAMETER(MultiBoxTransformLocParam);
bool MultiBoxTransformLocShape(const NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 3U) << "Inputs: [cls_prob, loc_pred, anchor]";
TShape cshape = in_attrs->at(0);
TShape lshape = in_attrs->at(1);
TShape ashape = in_attrs->at(2);
CHECK_EQ(cshape.ndim(), 3U) << "Class probability should be 3-D.";
CHECK_EQ(lshape.ndim(), 2U) << "Location prediction should be 2-D.";
CHECK_EQ(ashape.ndim(), 3U) << "Anchor should be 3-D.";
CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch.";
CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc.";
CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0.";
CHECK_EQ(ashape[2], 4U);
TShape oshape0 = TShape(3);
oshape0[0] = cshape[0];
oshape0[1] = ashape[1];
oshape0[2] = 6; // [id, prob, xmin, ymin, xmax, ymax]
TShape oshape1 = TShape(1);
oshape1[0] = cshape[0];
out_attrs->clear();
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape0);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 1, oshape1);
return true;
}
inline bool MultiBoxTransformLocLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), 3U);
CHECK_EQ(last_ilayouts->size(), 3U);
CHECK_EQ(olayouts->size(), 2U);
for (size_t i = 0; i < last_ilayouts->size(); ++i) {
const Layout& last_layout = last_ilayouts->at(i);
if (last_layout.defined()) {
NNVM_ASSIGN_LAYOUT(*ilayouts, i, last_layout);
}
}
return true;
}
inline bool MultiBoxTransformLocInferType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(0));
DTYPE_ASSIGN(out_attrs->at(1), 4U);
return true;
}
NNVM_REGISTER_OP(multibox_transform_loc)
.describe(R"doc("Location transformation for multibox detection."
)doc" NNVM_ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr_parser(ParamParser<MultiBoxTransformLocParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<MultiBoxTransformLocParam>)
.add_arguments(MultiBoxTransformLocParam::__FIELDS__())
.add_argument("cls_prob", "Tensor", "Class probabilities.")
.add_argument("loc_pred", "Tensor", "Location regression predictions.")
.add_argument("anchor", "Tensor", "Multibox prior anchor boxes")
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"cls_prob", "loc_pred", "anchor"};
})
.set_attr<FInferShape>("FInferShape", MultiBoxTransformLocShape)
.set_attr<FInferType>("FInferType", MultiBoxTransformLocInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", MultiBoxTransformLocLayout)
.set_support_level(4);
} // namespace top
} // namespace nnvm
import math
import numpy as np
import tvm
from tvm.contrib import graph_runtime
......@@ -356,6 +357,118 @@ def test_full():
np.full(shape, fill_value=0, dtype=dtype),
atol=1e-5, rtol=1e-5)
def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1),
offsets=(0.5, 0.5), clip=False):
data = sym.Variable("data")
out = sym.multibox_prior(data=data, sizes=sizes, ratios=ratios, steps=steps,
offsets=offsets, clip=clip)
in_height = dshape[2]
in_width = dshape[3]
num_sizes = len(sizes)
num_ratios = len(ratios)
size_ratio_concat = sizes + ratios
steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
offset_h = offsets[0]
offset_w = offsets[1]
oshape = (1, in_height * in_width * (num_sizes + num_ratios - 1), 4)
dtype = "float32"
np_out = np.zeros(oshape).astype(dtype)
for i in range(in_height):
center_h = (i + offset_h) * steps_h
for j in range(in_width):
center_w = (j + offset_w) * steps_w
for k in range(num_sizes + num_ratios - 1):
w = size_ratio_concat[k] * in_height / in_width / 2.0 if k < num_sizes else \
size_ratio_concat[0] * in_height / in_width * math.sqrt(size_ratio_concat[k + 1]) / 2.0
h = size_ratio_concat[k] / 2.0 if k < num_sizes else \
size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0
count = i * in_width * (num_sizes + num_ratios - 1) + j * (num_sizes + num_ratios - 1) + k
np_out[0][count][0] = center_w - w
np_out[0][count][1] = center_h - h
np_out[0][count][2] = center_w + w
np_out[0][count][3] = center_h + h
if clip:
np_out = np.clip(np_out, 0, 1)
target = "llvm"
ctx = tvm.cpu()
graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape})
m = graph_runtime.create(graph, lib, ctx)
m.set_input("data", np.random.uniform(size=dshape).astype(dtype))
m.run()
out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
def test_multibox_prior():
verify_multibox_prior((1, 3, 50, 50))
verify_multibox_prior((1, 3, 224, 224), sizes=(0.5, 0.25, 0.1), ratios=(1, 2, 0.5))
verify_multibox_prior((1, 32, 32, 32), sizes=(0.5, 0.25), ratios=(1, 2), steps=(2, 2), clip=True)
def test_multibox_transform_loc():
batch_size = 1
num_anchors = 3
num_classes = 3
cls_prob = sym.Variable("cls_prob")
loc_preds = sym.Variable("loc_preds")
anchors = sym.Variable("anchors")
transform_loc_data, valid_count = sym.multibox_transform_loc(cls_prob=cls_prob, loc_pred=loc_preds,
anchor=anchors)
out = sym.nms(data=transform_loc_data, valid_count=valid_count)
# Manually create test case
np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]])
np_loc_preds = np.array([[0.1, -0.2, 0.3, 0.2, 0.2, 0.4, 0.5, -0.3, 0.7, -0.2, -0.4, -0.8]])
np_anchors = np.array([[[-0.1, -0.1, 0.1, 0.1], [-0.2, -0.2, 0.2, 0.2], [1.2, 1.2, 1.5, 1.5]]])
expected_np_out = np.array([[[1, 0.69999999, 0, 0, 0.10818365, 0.10008108],
[0, 0.44999999, 1, 1, 1, 1],
[0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]])
target = "llvm"
dtype = "float32"
ctx = tvm.cpu()
graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes),
"loc_preds": (batch_size, num_anchors * 4),
"anchors": (1, num_anchors, 4)})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)})
m.run()
out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
def test_nms():
dshape = (1, 5, 6)
data = sym.Variable("data")
valid_count = sym.Variable("valid_count", dtype="int32")
nms_threshold = 0.7
force_suppress = True
nms_topk = 2
out = sym.nms(data=data, valid_count=valid_count, nms_threshold=nms_threshold,
force_suppress=force_suppress, nms_topk=nms_topk)
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32")
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79],
[-1, -1, -1, -1, -1, -1]]])
target = "llvm"
ctx = tvm.cpu()
graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)},
dtype={"data": "float32", "valid_count": "int32"})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"data": np_data, "valid_count": np_valid_count})
m.run()
out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
np.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
if __name__ == "__main__":
test_reshape()
......@@ -370,4 +483,7 @@ if __name__ == "__main__":
test_block_grad()
test_full()
test_flip()
test_multibox_prior()
test_multibox_transform_loc()
test_nms()
print(nnvm.compiler.engine.dump())
......@@ -27,7 +27,7 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
ratios : tuple of float
Tuple of ratios for anchor boxes.
steps : Tuple of int
steps : Tuple of float
Priorbox step across y and x, -1 for auto calculation.
offsets : tuple of int
......@@ -86,7 +86,7 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
ratios : tuple of float
Tuple of ratios for anchor boxes.
steps : Tuple of int
steps : Tuple of float
Priorbox step across y and x, -1 for auto calculation.
offsets : tuple of int
......@@ -211,8 +211,8 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho
@tvm.target.generic_func
def mutibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
variances=(0.1, 0.1, 0.2, 0.2)):
def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
variances=(0.1, 0.1, 0.2, 0.2)):
"""Location transformation for multibox detection
Parameters
......@@ -237,11 +237,7 @@ def mutibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
Returns
-------
out : tvm.Tensor
3-D tensor with shape (batch_size, num_anchors, 6)
valid_count : tvm.Tensor
1-D tensor with shape (batch_size,), number of valid anchor boxes.
ret : tuple of tvm.Tensor
"""
batch_size = cls_prob.shape[0]
num_anchors = anchor.shape[1]
......@@ -259,7 +255,7 @@ def mutibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
dtype=[valid_count_dtype, cls_prob.dtype],
out_buffers=[valid_count_buf, out_buf],
tag="multibox_transform_loc")
return out, valid_count
return [out, valid_count]
@tvm.target.generic_func
......@@ -301,7 +297,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
out : tvm.Tensor
3-D tensor with shape (batch_size, num_anchors, 6)
"""
inter_out, valid_count = mutibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = nms(inter_out, valid_count, nms_threshold, force_suppress, nms_topk)
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = nms(inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk)
return out
"""
Deploy Single Shot Multibox Detector(SSD) model
===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_
This article is an introductory tutorial to deploy SSD models with TVM.
We will use mxnet pretrained SSD model with Resnet50 as body network and
convert it to NNVM graph.
"""
import os
import zipfile
import tvm
import mxnet as mx
import cv2
import numpy as np
from nnvm import compiler
from nnvm.frontend import from_mxnet
from nnvm.testing.download import download
from tvm.contrib import graph_runtime
from mxnet.model import load_checkpoint
######################################################################
# Set the parameters here
# -----------------------
# .. note::
#
# Currently we support compiling SSD on CPU only.
# GPU support is in progress.
model_name = "ssd_resnet50_512"
model_file = "%s.zip" % model_name
test_image = "dog.jpg"
dshape = (1, 3, 512, 512)
dtype = "float32"
target = "llvm"
ctx = tvm.cpu()
######################################################################
# Download MXNet SSD pre-trained model and demo image
# ---------------------------------------------------
# Pre-trained model available at
# https://github.com/apache/incubator-\mxnet/tree/master/example/ssd
model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \
"resnet50_ssd_512_voc0712_trainval.zip"
image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \
"cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg"
inference_symbol_folder = "c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \
"archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
dir = "ssd_model"
if not os.path.exists(dir):
os.makedirs(dir)
model_file_path = "%s/%s" % (dir, model_file)
test_image_path = "%s/%s" % (dir, test_image)
inference_symbol_path = "%s/inference_model.zip" % dir
download(model_url, model_file_path)
download(image_url, test_image_path)
download(inference_symbol_url, inference_symbol_path)
zip_ref = zipfile.ZipFile(model_file_path, 'r')
zip_ref.extractall(dir)
zip_ref.close()
zip_ref = zipfile.ZipFile(inference_symbol_path)
zip_ref.extractall(dir)
zip_ref.close()
######################################################################
# Convert and compile model with NNVM for CPU.
sym = mx.sym.load("%s/%s/ssd_resnet50_inference.json" % (dir, inference_symbol_folder))
_, arg_params, aux_params = load_checkpoint("%s/%s" % (dir, model_name), 0)
net, params = from_mxnet(sym, arg_params, aux_params)
with compiler.build_config(opt_level=3):
graph, lib, params = compiler.build(net, target, {"data": dshape}, params=params)
######################################################################
# Create TVM runtime and do inference
# Preprocess image
image = cv2.imread(test_image_path)
img_data = cv2.resize(image, (dshape[2], dshape[3]))
img_data = img_data[:, :, (2, 1, 0)].astype(np.float32)
img_data -= np.array([123, 117, 104])
img_data = np.transpose(np.array(img_data), (2, 0, 1))
img_data = np.expand_dims(img_data, axis=0)
# Build TVM runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input('data', tvm.nd.array(img_data.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
_, oshape = compiler.graph_util.infer_shape(graph, shape={"data": dshape})
tvm_output = m.get_output(0, tvm.nd.empty(tuple(oshape[0]), dtype))
######################################################################
# Display result
class_names = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair",
"cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant",
"sheep", "sofa", "train", "tvmonitor"]
def display(img, out, thresh=0.5):
import random
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.figsize'] = (10,10)
pens = dict()
plt.clf()
plt.imshow(img)
for det in out:
cid = int(det[0])
if cid < 0:
continue
score = det[1]
if score < thresh:
continue
if cid not in pens:
pens[cid] = (random.random(), random.random(), random.random())
scales = [img.shape[1], img.shape[0]] * 2
xmin, ymin, xmax, ymax = [int(p * s) for p, s in zip(det[2:6].tolist(), scales)]
rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False,
edgecolor=pens[cid], linewidth=3)
plt.gca().add_patch(rect)
text = class_names[cid]
plt.gca().text(xmin, ymin-2, '{:s} {:.3f}'.format(text, score),
bbox=dict(facecolor=pens[cid], alpha=0.5),
fontsize=12, color='white')
plt.show()
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
display(image, tvm_output.asnumpy()[0], thresh=0.45)
......@@ -14,23 +14,18 @@ Please install CFFI and CV2 before executing this script
pip install cffi
pip install opencv-python
"""
from ctypes import *
import math
import random
import nnvm
import nnvm.frontend.darknet
import nnvm.testing.darknet
from nnvm.testing.darknet import __darknetffi__
import matplotlib.pyplot as plt
import numpy as np
import tvm
import os, sys, time, urllib, requests
if sys.version_info >= (3,):
import urllib.request as urllib2
import urllib.parse as urlparse
else:
import urllib2
import urlparse
import os
from ctypes import *
from nnvm.testing.download import download
from nnvm.testing.darknet import __darknetffi__
######################################################################
# Set the parameters here.
......@@ -41,62 +36,6 @@ test_image = 'dog.jpg'
target = 'llvm'
ctx = tvm.cpu(0)
def dlProgress(count, block_size, total_size):
"""Show the download progress."""
global start_time
if count == 0:
start_time = time.time()
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = int(count * block_size * 100 / total_size)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()
def download(url, path, overwrite=False, sizecompare=False):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
Parameters
----------
url : str
Operator name, such as Convolution, Connected, etc
path : str
List of input symbols.
overwrite : dict
Dict of operator attributes
sizecompare : dict
Dict of operator attributes
Returns
-------
out_name : converted out name of operation
sym : nnvm.Symbol
Converted nnvm Symbol
"""
if os.path.isfile(path) and not overwrite:
if (sizecompare):
fileSize = os.path.getsize(path)
resHead = requests.head(url)
resGet = requests.get(url,stream=True)
if 'Content-Length' not in resHead.headers :
resGet = urllib2.urlopen(url)
urlFileSize = int(resGet.headers['Content-Length'])
if urlFileSize != fileSize:
print ("exist file got corrupted, downloading", path , " file freshly")
download(url, path, True, False)
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path, reporthook=dlProgress)
print('')
except:
urllib.urlretrieve(url, path, reporthook=dlProgress)
######################################################################
# Prepare cfg and weights file
# ----------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment