"""
Compile YOLO-V2 and YOLO-V3 in DarkNet Models
=============================================
**Author**: `Siju Samuel <https://siju-samuel.github.io/>`_

This article is an introductory tutorial to deploy darknet models with NNVM.
All the required models and libraries will be downloaded from the internet by the script.
This script runs the YOLO-V2 and YOLO-V3 models and draws the detected bounding boxes.
Darknet parsing depends on the CFFI and CV2 (OpenCV) libraries.
Please install CFFI and CV2 before executing this script:

.. code-block:: bash

  pip install cffi
  pip install opencv-python
"""

import sys
from ctypes import *

import matplotlib.pyplot as plt
import nnvm
import nnvm.compiler
import nnvm.frontend.darknet
import nnvm.testing.darknet
import nnvm.testing.yolo_detection
from nnvm.testing.darknet import __darknetffi__
import numpy as np
import tvm
from tvm.contrib.download import download

# Model name. The output-decoding section below handles 'yolov2' and 'yolov3'.
MODEL_NAME = 'yolov3'

######################################################################
# Download required files
# -----------------------
# Download cfg and weights file if first time.
CFG_NAME = MODEL_NAME + '.cfg'
WEIGHTS_NAME = MODEL_NAME + '.weights'
REPO_URL = 'https://github.com/siju-samuel/darknet/blob/master/'
CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true'
WEIGHTS_URL = 'https://pjreddie.com/media/files/' + WEIGHTS_NAME

# Both files are saved into the current working directory.
download(CFG_URL, CFG_NAME)
download(WEIGHTS_URL, WEIGHTS_NAME)
# Download and load the prebuilt darknet shared library for this platform.
if sys.platform.startswith('linux'):  # 'linux' (Py3) or 'linux2' (Py2)
    DARKNET_LIB = 'libdarknet2.0.so'
    DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true'
elif sys.platform == 'darwin':
    DARKNET_LIB = 'libdarknet_mac2.0.so'
    DARKNET_URL = REPO_URL + 'lib_osx/' + DARKNET_LIB + '?raw=true'
else:
    err = "Darknet lib is not supported on {} platform".format(sys.platform)
    raise NotImplementedError(err)

download(DARKNET_URL, DARKNET_LIB)

# Note: DARKNET_LIB is rebound here from the library's file name to the
# handle returned by dlopen; all later calls go through that handle.
DARKNET_LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
cfg = "./" + CFG_NAME
weights = "./" + WEIGHTS_NAME
# Paths are handed to the C library as UTF-8 encoded bytes. The trailing 0 is
# passed through unchanged (presumably darknet's `clear` flag -- TODO confirm).
net = DARKNET_LIB.load_network(cfg.encode('utf-8'), weights.encode('utf-8'), 0)

dtype = 'float32'
batch_size = 1

print("Converting darknet to nnvm symbols...")
sym, params = nnvm.frontend.darknet.from_darknet(net, dtype)

######################################################################
# Compile the model on NNVM
# -------------------------
# compile the model
target = 'llvm'
ctx = tvm.cpu(0)
# Input layout is (batch, channels, height, width), taken from the loaded network.
shape = {'data': (batch_size, net.c, net.h, net.w)}
print("Compiling the model...")
# Declare the input dtype explicitly rather than relying on type inference;
# parameter dtypes are taken from `params` by the compiler.
dtype_dict = {'data': dtype}
with nnvm.compiler.build_config(opt_level=2):
    graph, lib, params = nnvm.compiler.build(sym, target, shape, dtype_dict, params)

# Network input height/width as defined by the cfg file
# (e.g. 608x608 for the yolov3.cfg used here -- depends on the cfg).
neth, netw = shape['data'][2:]
84 85 86
######################################################################
# Load a test image
# --------------------------------------------------------------------
87
test_image = 'dog.jpg'
88 89
print("Loading the test image...")
img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + \
90
          test_image + '?raw=true'
91 92
download(img_url, test_image)

93
data = nnvm.testing.darknet.load_image(test_image, netw, neth)
######################################################################
# Execute on TVM Runtime
# ----------------------
# The process is no different from other examples.
from tvm.contrib import graph_runtime

m = graph_runtime.create(graph, lib, ctx)

# set inputs: the image (cast to the compiled dtype) and the model weights
m.set_input('data', tvm.nd.array(data.astype(dtype)))
m.set_input(**params)
# execute
print("Running the test image...")

m.run()
# get outputs
# Each detection layer contributes its raw feature map plus metadata outputs
# (biases, layer attributes and, for YOLO-V3, a mask). These are gathered into
# one dict per layer for the post-processing helpers below.
tvm_out = []
if MODEL_NAME == 'yolov2':
    # One region layer: output 0 = feature map, 1 = biases,
    # 2 = layer attributes (n, out_c, out_h, out_w, classes, coords, background).
    layer_attr = m.get_output(2).asnumpy()
    n = layer_attr[0]
    out_shape = (n, layer_attr[1] // n, layer_attr[2], layer_attr[3])
    tvm_out.append({
        'type': 'Region',
        'biases': m.get_output(1).asnumpy(),
        'output': m.get_output(0).asnumpy().reshape(out_shape),
        'classes': layer_attr[4],
        'coords': layer_attr[5],
        'background': layer_attr[6],
    })
elif MODEL_NAME == 'yolov3':
    # Three yolo layers, each with four consecutive outputs starting at i*4:
    # +0 feature map, +1 mask, +2 biases,
    # +3 layer attributes (n, out_c, out_h, out_w, classes, total).
    for i in range(3):
        base = i * 4
        layer_attr = m.get_output(base + 3).asnumpy()
        n = layer_attr[0]
        out_shape = (n, layer_attr[1] // n, layer_attr[2], layer_attr[3])
        tvm_out.append({
            'type': 'Yolo',
            'biases': m.get_output(base + 2).asnumpy(),
            'mask': m.get_output(base + 1).asnumpy(),
            'output': m.get_output(base).asnumpy().reshape(out_shape),
            'classes': layer_attr[4],
        })
else:
    # Without this, tvm_out would stay empty and detection would silently find nothing.
    raise NotImplementedError(
        "Output decoding is not implemented for MODEL_NAME {}".format(MODEL_NAME))
# do the detection and bring up the bounding boxes
thresh = 0.5       # detection threshold, reused when drawing below
nms_thresh = 0.45  # threshold handed to non-maximum suppression

# Full-resolution colour image; its shape unpacks as (channels, height, width).
img = nnvm.testing.darknet.load_image_color(test_image)
_, im_h, im_w = img.shape

# Map the network outputs back to boxes in original-image coordinates.
dets = nnvm.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh,
                                                      1, tvm_out)
last_layer = net.layers[net.n - 1]
nnvm.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh)
coco_name = 'coco.names'
coco_url = REPO_URL + 'data/' + coco_name + '?raw=true'
font_name = 'arial.ttf'
font_url = REPO_URL + 'data/' + font_name + '?raw=true'
download(coco_url, coco_name)
# NOTE(review): the font is not referenced below; presumably draw_detections
# loads it from the working directory -- confirm.
download(font_url, font_name)

# Class labels, one per line, with surrounding whitespace stripped.
with open(coco_name) as f:
    names = [line.strip() for line in f]

nnvm.testing.yolo_detection.draw_detections(img, dets, thresh, names, last_layer.classes)
# img is channel-first; move channels last for matplotlib.
plt.imshow(img.transpose(1, 2, 0))
plt.show()