Commit 48c16a17 by Leyuan Wang Committed by masahi

[BugFix] SSD fully supported on GPUs, updated deploy_ssd tutorial (#2510)

* nms fixed for gpu and tested on cuda and opencl devices; ssd can now run fully on the gpu

* sort updated to use virtual thread

* typo fixed

* fix lint

* fix lint

* add support when batch_size > 1

* intel graphics conv2d bugs fixed for inception_v3

* intel conv2d api updated, nn input size 4 condition added

* review addressed

* move conv_tags to attributes

* opencl ctx fixed

* nms_ir index simplified
parent 881a78b3
""" """
Deploy Single Shot Multibox Detector(SSD) model Deploy Single Shot Multibox Detector(SSD) model
=============================================== ===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_ **Author**: `Yao Wang <https://github.com/kevinthesun>`_, \
`Leyuan Wang <https://github.com/Laurawly>`_
This article is an introductory tutorial to deploy SSD models with TVM. This article is an introductory tutorial to deploy SSD models with TVM.
We will use mxnet pretrained SSD model with Resnet50 as body network and We will use mxnet pretrained SSD model with Resnet50 as body network and
...@@ -32,17 +33,20 @@ from mxnet.model import load_checkpoint ...@@ -32,17 +33,20 @@ from mxnet.model import load_checkpoint
# echo "set(USE_SORT ON)" > config.mk # echo "set(USE_SORT ON)" > config.mk
# make -j8 # make -j8
# #
# .. note::
#
# Currently we support compiling SSD on CPU only.
# GPU support is in progress.
#
model_name = "ssd_resnet50_512" model_name = "ssd_resnet50_512"
model_file = "%s.zip" % model_name model_file = "%s.zip" % model_name
test_image = "dog.jpg" test_image = "dog.jpg"
dshape = (1, 3, 512, 512) dshape = (1, 3, 512, 512)
dtype = "float32" dtype = "float32"
# Target settings
# Use these commented settings to build for cuda.
#target = 'cuda'
#ctx = tvm.gpu(0)
# Use these commented settings to build for opencl.
#target = 'opencl'
#ctx = tvm.opencl(0)
target = "llvm" target = "llvm"
ctx = tvm.cpu() ctx = tvm.cpu()
...@@ -56,7 +60,8 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \ ...@@ -56,7 +60,8 @@ model_url = "https://github.com/zhreshold/mxnet-ssd/releases/download/v0.6/" \
"resnet50_ssd_512_voc0712_trainval.zip" "resnet50_ssd_512_voc0712_trainval.zip"
image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \ image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \
"cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg" "cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg"
inference_symbol_folder = "c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26" inference_symbol_folder = \
"c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26"
inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \ inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \
"archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip" "archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip"
...@@ -92,7 +97,8 @@ parser.add_argument( ...@@ -92,7 +97,8 @@ parser.add_argument(
default="nnvm") default="nnvm")
args = parser.parse_args() args = parser.parse_args()
if args.frontend == "relay": if args.frontend == "relay":
net, params = relay.frontend.from_mxnet(sym, {"data": dshape}, arg_params=arg_params, aux_params=aux_params) net, params = relay.frontend.from_mxnet(sym, {"data": dshape}, arg_params=arg_params, \
aux_params=aux_params)
with relay.build_config(opt_level=3): with relay.build_config(opt_level=3):
graph, lib, params = relay.build(net, target, params=params) graph, lib, params = relay.build(net, target, params=params)
elif args.frontend == "nnvm": elif args.frontend == "nnvm":
...@@ -134,7 +140,7 @@ def display(img, out, thresh=0.5): ...@@ -134,7 +140,7 @@ def display(img, out, thresh=0.5):
import random import random
import matplotlib as mpl import matplotlib as mpl
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
mpl.rcParams['figure.figsize'] = (10,10) mpl.rcParams['figure.figsize'] = (10, 10)
pens = dict() pens = dict()
plt.clf() plt.clf()
plt.imshow(img) plt.imshow(img)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment