Commit 993fe12f by Zhi Committed by Yizhi Liu

[relay][frontend] Enable ssd test by attaching schedules to multibox and ssd ops (#2322)

* add ssd ops to mxnet.py

* add ssd ops to mxnet.py

* add result check for multibox and nms unit tests

* add result check for multibox and nms unit tests

* address @kevinthesun's comments

* Disable cuda test for nms for now.
parent 98ce9ea0
......@@ -106,6 +106,30 @@ class StrAttrsDict(object):
raise AttributeError("Required attribute {} not found.".format(key))
return default
def get_float_tuple(self, key, default=RequiredAttr()):
"""Get float tuple attribute
Parameters
----------
key : str
The attribute key
default : float
The default value.
Returns
-------
value : The result
"""
if key in self.attrs:
tshape = self.attrs[key]
return tuple(float(x.strip()) for x in
tshape.strip('()[]').split(','))
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
def get_tuple_tuple_int(self, key, default=RequiredAttr()):
"""Get int list attribute
......
......@@ -241,6 +241,33 @@ def _mx_lrn(inputs, attrs):
return _op.nn.lrn(inputs[0], **new_attrs)
def _mx_multibox_prior(inputs, attrs):
new_attrs = {}
new_attrs["sizes"] = attrs.get_float_tuple("sizes", (1.0, ))
new_attrs["steps"] = attrs.get_float_tuple("steps", (-1.0, -1.0))
new_attrs["offsets"] = attrs.get_float_tuple("offsets", (0.5, 0.5))
new_attrs["ratios"] = attrs.get_float_tuple("ratios", (1.0, ))
new_attrs["clip"] = attrs.get_bool("clip", False)
return _op.vision.multibox_prior(inputs[0], **new_attrs)
def _mx_multibox_detection(inputs, attrs):
new_attrs0 = {}
new_attrs0["clip"] = attrs.get_bool("clip", True)
new_attrs0["threshold"] = attrs.get_float("threshold", 0.01)
new_attrs0["variances"] = attrs.get_float_tuple("variances", (0.1, 0.1,
0.2, 0.2))
new_attrs1 = {}
new_attrs1["overlap_threshold"] = attrs.get_float("nms_threshold", 0.5)
new_attrs1["force_suppress"] = attrs.get_bool("force_suppress", False)
new_attrs1["topk"] = attrs.get_int("nms_topk", -1)
ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1],
inputs[2], **new_attrs0)
return _op.vision.nms(ret[0], ret[1], **new_attrs1)
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
......@@ -327,13 +354,14 @@ _convert_map = {
"LeakyReLU" : _mx_leaky_relu,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
# vision
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
"_contrib_MultiBoxDetection" : _mx_multibox_detection,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
# "broadcast_to",
# "gather_nd",
# "_contrib_MultiBoxPrior" : _rename("multibox_prior"),
# "_contrib_MultiBoxDetection" : _contrib_multibox_detection,
# "Crop" : _crop_like,
}
......
......@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs
from .multibox import *
from .nms import *
from . import _multibox
# pylint: disable=invalid-name, unused-argument
"""Definition of vision ops"""
from __future__ import absolute_import
import topi
from topi.util import get_const_int, get_const_float, get_float_tuple
from .. import op as reg
from ..op import OpPattern
@reg.register_schedule("vision.multibox_prior")
def schedule_multibox_prior(_, outs, target):
"""Schedule definition of multibox_prior"""
with target:
return topi.generic.schedule_multibox_prior(outs)
@reg.register_compute("vision.multibox_prior")
def compute_multibox_prior(attrs, inputs, _, target):
"""Compute definition of multibox_prior"""
sizes = get_float_tuple(attrs.sizes)
ratios = get_float_tuple(attrs.ratios)
steps = get_float_tuple(attrs.steps)
offsets = get_float_tuple(attrs.offsets)
clip = bool(get_const_int(attrs.clip))
return [
topi.vision.ssd.multibox_prior(inputs[0], sizes, ratios, steps,
offsets, clip)
]
reg.register_pattern("vision.multibox_prior", OpPattern.OPAQUE)
# multibox_transform_loc
@reg.register_schedule("vision.multibox_transform_loc")
def schedule_multibox_transform_loc(_, outs, target):
"""Schedule definition of multibox_detection"""
with target:
return topi.generic.schedule_multibox_transform_loc(outs)
@reg.register_compute("vision.multibox_transform_loc")
def compute_multibox_transform_loc(attrs, inputs, _, target):
"""Compute definition of multibox_detection"""
clip = bool(get_const_int(attrs.clip))
threshold = get_const_float(attrs.threshold)
variances = get_float_tuple(attrs.variances)
return topi.vision.ssd.multibox_transform_loc(
inputs[0], inputs[1], inputs[2], clip, threshold, variances)
reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE)
reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE)
# non-maximum suppression
@reg.register_schedule("vision.nms")
def schedule_nms(_, outs, target):
"""Schedule definition of nms"""
with target:
return topi.generic.schedule_nms(outs)
@reg.register_compute("vision.nms")
def compute_nms(attrs, inputs, _, target):
"""Compute definition of nms"""
overlap_threshold = get_const_float(attrs.overlap_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress))
topk = get_const_int(attrs.topk)
return [
topi.vision.nms(inputs[0], inputs[1], overlap_threshold,
force_suppress, topk)
]
reg.register_pattern("vision.nms", OpPattern.OPAQUE)
"""Multibox operations."""
from __future__ import absolute_import as _abs
from . import _make
from ...expr import TupleWrapper
def multibox_prior(data,
sizes=(1.0,),
......@@ -43,7 +44,7 @@ def multibox_transform_loc(cls_prob,
anchor,
clip=True,
threshold=0.01,
variance=(0.1, 0.1, 0.2, 0.2)):
variances=(0.1, 0.1, 0.2, 0.2)):
"""Location transformation for multibox detection
Parameters
......@@ -63,12 +64,13 @@ def multibox_transform_loc(cls_prob,
threshold : double, optional
Threshold to be a positive prediction.
variance : Tuple of float, optional
Variances to be decoded from box regression output.
variances : Tuple of float, optional
variances to be decoded from box regression output.
Returns
-------
ret : tuple of tvm.relay.Expr
"""
return _make.multibox_transform_loc(cls_prob, loc_pred, anchor, clip,
threshold, variance)
return TupleWrapper(_make.multibox_transform_loc(cls_prob, loc_pred,
anchor, clip, threshold,
variances), 2)
......@@ -78,6 +78,28 @@ def get_const_int(expr):
return int(expr.value)
def get_const_float(expr):
"""Verifies expr is a floating point and get the constant value.
Parameters
----------
expr : tvm.Expr or float
The input expression.
Returns
-------
out_value : float
The output.
"""
if isinstance(expr, float):
return float(expr)
if not isinstance(expr, tvm.expr.FloatImm):
expr = tvm.ir_pass.Simplify(expr)
if not isinstance(expr, tvm.expr.FloatImm):
raise ValueError("Expect value to be constant float")
return float(expr.value)
def equal_const_int(expr, value):
"""Returns if expr equals value.
......@@ -120,6 +142,26 @@ def get_const_tuple(in_tuple):
return out_tuple
def get_float_tuple(in_tuple):
"""Verifies input tuple is FloatImm, returns tuple of float.
Parameters
----------
in_tuple : tuple of Expr
The input.
Returns
-------
out_tuple : tuple of float
The output.
"""
out_tuple = ()
for elem in in_tuple:
value = get_const_float(elem)
out_tuple = out_tuple + (value, )
return out_tuple
def simplify(expr):
"""Simplify the expression if it is Expr, directly return if it is int.
......
......@@ -5,7 +5,7 @@ Deploy Single Shot Multibox Detector(SSD) model
This article is an introductory tutorial to deploy SSD models with TVM.
We will use mxnet pretrained SSD model with Resnet50 as body network and
convert it to NNVM graph.
convert it to NNVM graph;
"""
import os
import zipfile
......@@ -16,6 +16,7 @@ import numpy as np
from nnvm import compiler
from nnvm.frontend import from_mxnet
from tvm import relay
from tvm.contrib.download import download
from tvm.contrib import graph_runtime
from mxnet.model import load_checkpoint
......@@ -58,7 +59,7 @@ image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \
inference_symbol_folder = "c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \
"archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
dir = "ssd_model"
if not os.path.exists(dir):
os.makedirs(dir)
......@@ -77,13 +78,31 @@ zip_ref.extractall(dir)
zip_ref.close()
######################################################################
# Convert and compile model with NNVM for CPU.
# Convert and compile model with NNVM or Relay for CPU.
sym = mx.sym.load("%s/%s/ssd_resnet50_inference.json" % (dir, inference_symbol_folder))
_, arg_params, aux_params = load_checkpoint("%s/%s" % (dir, model_name), 0)
net, params = from_mxnet(sym, arg_params, aux_params)
with compiler.build_config(opt_level=3):
graph, lib, params = compiler.build(net, target, {"data": dshape}, params=params)
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"-f", "--frontend",
help="Frontend for compilation, nnvm or relay",
type=str,
default="nnvm")
args = parser.parse_args()
if args.frontend == "relay":
net, params = relay.frontend.from_mxnet(sym, {"data": dshape}, arg_params=arg_params, aux_params=aux_params)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(net, target, params=params)
elif args.frontend == "nnvm":
net, params = from_mxnet(sym, arg_params, aux_params)
with compiler.build_config(opt_level=3):
graph, lib, params = compiler.build(
net, target, {"data": dshape}, params=params)
else:
parser.print_help()
parser.exit()
######################################################################
# Create TVM runtime and do inference
......@@ -141,4 +160,3 @@ def display(img, out, thresh=0.5):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
display(image, tvm_output.asnumpy()[0], thresh=0.45)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment