Commit 7ec898d5 by Pariksheet Pinjari Committed by Tianqi Chen

[FRONTEND] DarkNet Yolo2 Frontend Support (#377)

parent 2e836ca7
......@@ -56,7 +56,7 @@ endif
all: lib/libnnvm.a lib/libnnvm_compiler.$(SHARED_LIBRARY_SUFFIX)
SRC = $(wildcard src/*.cc src/c_api/*.cc src/core/*.cc src/pass/*.cc)
SRC_COMPILER = $(wildcard src/top/*/*.cc src/compiler/*.cc src/compiler/*/*.cc)
SRC_COMPILER = $(wildcard src/top/*/*.cc wildcard src/top/vision/*/*.cc src/compiler/*.cc src/compiler/*/*.cc)
ALL_OBJ = $(patsubst %.cc, build/%.o, $(SRC))
TOP_OBJ = $(patsubst %.cc, build/%.o, $(SRC_COMPILER))
ALL_DEP = $(ALL_OBJ)
......
......@@ -4,3 +4,4 @@ from .mxnet import from_mxnet
from .onnx import from_onnx
from .coreml import from_coreml
from .keras import from_keras
from .darknet import from_darknet
......@@ -7,3 +7,5 @@ from . import mobilenet
from . import mlp
from . import resnet
from . import vgg
from . import darknet
from . import yolo2_detection
# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
"""
Yolo detection boxes helper functions
====================
DarkNet helper functions for yolo and image loading.
This functions will not be loaded by default.
These are utility functions used for testing and tutorial file.
"""
from __future__ import division
import math
from collections import namedtuple
import numpy as np
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
def _entry_index(batch, w, h, outputs, classes, coords, location, entry):
n = int(location/(w*h))
loc = location%(w*h)
return batch*outputs + n*w*h*(coords+classes+1) + entry*w*h + loc
Box = namedtuple('Box', ['x', 'y', 'w', 'h'])
def _get_region_box(x, biases, n, index, i, j, w, h, stride):
b = Box(0, 0, 0, 0)
b = b._replace(x=(i + x[index + 0*stride]) / w)
b = b._replace(y=(j + x[index + 1*stride]) / h)
b = b._replace(w=np.exp(x[index + 2*stride]) * biases[2*n] / w)
b = b._replace(h=np.exp(x[index + 3*stride]) * biases[2*n+1] / h)
return b
def _correct_region_boxes(boxes, n, w, h, netw, neth, relative):
new_w, new_h = (netw, (h*netw)/w) if (netw/w < neth/h) else ((w*neth/h), neth)
for i in range(n):
b = boxes[i]
b = boxes[i]
b = b._replace(x=(b.x - (netw - new_w)/2/netw) / (new_w/netw))
b = b._replace(y=(b.y - (neth - new_h)/2/neth) / (new_h/neth))
b = b._replace(w=b.w * netw/new_w)
b = b._replace(h=b.h * neth/new_h)
if not relative:
b = b._replace(x=b.x * w)
b = b._replace(w=b.w * w)
b = b._replace(y=b.y * h)
b = b._replace(h=b.h * h)
boxes[i] = b
def _overlap(x1, w1, x2, w2):
l1 = x1 - w1/2
l2 = x2 - w2/2
left = l1 if l1 > l2 else l2
r1 = x1 + w1/2
r2 = x2 + w2/2
right = r1 if r1 < r2 else r2
return right - left
def _box_intersection(a, b):
w = _overlap(a.x, a.w, b.x, b.w)
h = _overlap(a.y, a.h, b.y, b.h)
if w < 0 or h < 0:
return 0
return w*h
def _box_union(a, b):
i = _box_intersection(a, b)
u = a.w*a.h + b.w*b.h - i
return u
def _box_iou(a, b):
return _box_intersection(a, b)/_box_union(a, b)
def get_region_boxes(layer_in, imw, imh, netw, neth, thresh, probs,
boxes, relative, tvm_out):
"To get the boxes for the image based on the prediction"
lw = layer_in.w
lh = layer_in.h
probs = [[0 for i in range(layer_in.classes + 1)] for y in range(lw*lh*layer_in.n)]
boxes = [Box(0, 0, 0, 0) for i in range(lw*lh*layer_in.n)]
for i in range(lw*lh):
row = int(i / lw)
col = int(i % lw)
for n in range(layer_in.n):
index = n*lw*lh + i
obj_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
layer_in.coords, n*lw*lh + i, layer_in.coords)
box_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
layer_in.coords, n*lw*lh + i, 0)
mask_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
layer_in.coords, n*lw*lh + i, 4)
scale = 1 if layer_in.background else tvm_out[obj_index]
boxes[index] = _get_region_box(tvm_out, layer_in.biases, n, box_index, col,
row, lw, lh, lw*lh)
if not layer_in.softmax_tree:
max_element = 0
for j in range(layer_in.classes):
class_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
layer_in.coords, n*lw*lh + i, layer_in.coords+1+j)
prob = scale*tvm_out[class_index]
probs[index][j] = prob if prob > thresh else 0
max_element = max(max_element, prob)
probs[index][layer_in.classes] = max_element
_correct_region_boxes(boxes, lw*lh*layer_in.n, imw, imh, netw, neth, relative)
return boxes, probs
def do_nms_sort(boxes, probs, total, classes, thresh):
"Does the sorting based on the threshold values"
SortableBbox = namedtuple('SortableBbox', ['index_var', 'class_var', 'probs'])
s = [SortableBbox(0, 0, []) for i in range(total)]
for i in range(total):
s[i] = s[i]._replace(index_var=i)
s[i] = s[i]._replace(class_var=0)
s[i] = s[i]._replace(probs=probs)
for k in range(classes):
for i in range(total):
s[i] = s[i]._replace(class_var=k)
s = sorted(s, key=lambda x: x.probs[x.index_var][x.class_var], reverse=True)
for i in range(total):
if probs[s[i].index_var][k] == 0:
continue
a = boxes[s[i].index_var]
for j in range(i+1, total):
b = boxes[s[j].index_var]
if _box_iou(a, b) > thresh:
probs[s[j].index_var][k] = 0
return boxes, probs
def draw_detections(im, num, thresh, boxes, probs, names, classes):
"Draw the markings around the detected region"
for i in range(num):
labelstr = []
category = -1
for j in range(classes):
if probs[i][j] > thresh:
if category == -1:
category = j
labelstr.append(names[j])
if category > -1:
imc, imh, imw = im.shape
width = int(imh * 0.006)
offset = category*123457 % classes
red = _get_color(2, offset, classes)
green = _get_color(1, offset, classes)
blue = _get_color(0, offset, classes)
rgb = [red, green, blue]
b = boxes[i]
left = int((b.x-b.w/2.)*imw)
right = int((b.x+b.w/2.)*imw)
top = int((b.y-b.h/2.)*imh)
bot = int((b.y+b.h/2.)*imh)
if left < 0:
left = 0
if right > imw-1:
right = imw-1
if top < 0:
top = 0
if bot > imh-1:
bot = imh-1
_draw_box_width(im, left, top, right, bot, width, red, green, blue)
label = _get_label(''.join(labelstr), rgb)
_draw_label(im, top + width, left, label, rgb)
def _get_pixel(im, x, y, c):
return im[c][y][x]
def _set_pixel(im, x, y, c, val):
if x < 0 or y < 0 or c < 0 or x >= im.shape[2] or y >= im.shape[1] or c >= im.shape[0]:
return
im[c][y][x] = val
def _draw_label(im, r, c, label, rgb):
w = label.shape[2]
h = label.shape[1]
if (r - h) >= 0:
r = r - h
for j in range(h):
if j < h and (j + r) < im.shape[1]:
for i in range(w):
if i < w and (i + c) < im.shape[2]:
for k in range(label.shape[0]):
val = _get_pixel(label, i, j, k)
_set_pixel(im, i+c, j+r, k, val)#rgb[k] * val)
def _get_label(labelstr, rgb):
text = labelstr
colorText = "black"
testDraw = ImageDraw.Draw(Image.new('RGB', (1, 1)))
font = ImageFont.truetype("arial.ttf", 25)
width, height = testDraw.textsize(labelstr, font=font)
img = Image.new('RGB', (width, height), color=(int(rgb[0]*255), int(rgb[1]*255),
int(rgb[2]*255)))
d = ImageDraw.Draw(img)
d.text((0, 0), text, fill=colorText, font=font)
opencvImage = np.divide(np.asarray(img), 255)
return opencvImage.transpose(2, 0, 1)
def _get_color(c, x, max_value):
c = int(c)
colors = [[1, 0, 1], [0, 0, 1], [0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 0]]
ratio = (float(x)/float(max_value)) * 5
i = int(math.floor(ratio))
j = int(math.ceil(ratio))
ratio -= i
r = (1-ratio) * colors[i][c] + ratio*colors[j][c]
return r
def _draw_box(im, x1, y1, x2, y2, r, g, b):
y1 = int(y1)
y2 = int(y2)
x1 = int(x1)
x2 = int(x2)
ac, ah, aw = im.shape
if x1 < 0:
x1 = 0
if x1 >= aw:
y1 = 0
if y1 >= ah:
y1 = ah - 1
if y2 < 0:
y2 = 0
if y2 >= ah:
y2 = ah - 1
for i in range(x1, x2):
im[0][y1][i] = r
im[0][y2][i] = r
im[1][y1][i] = g
im[1][y2][i] = g
im[2][y1][i] = b
im[2][y2][i] = b
for i in range(y1, y2):
im[0][i][x1] = r
im[0][i][x2] = r
im[1][i][x1] = g
im[1][i][x2] = g
im[2][i][x1] = b
im[2][i][x2] = b
def _draw_box_width(im, x1, y1, x2, y2, w, r, g, b):
for i in range(int(w)):
_draw_box(im, x1+i, y1+i, x2-i, y2-i, r, g, b)
......@@ -7,6 +7,7 @@ from . import tensor
from . import nn
from . import transform
from . import reduction
from . import vision
from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern
# pylint: disable=invalid-name, unused-argument
"""Definition of nn ops"""
from __future__ import absolute_import
import topi
import tvm
from . import registry as reg
from .registry import OpPattern
@reg.register_compute("yolo2_reorg")
def compute_reorg(attrs, inputs, _):
"""Compute definition of reorg"""
return topi.vision.reorg(inputs[0], attrs.get_int("stride"))
@reg.register_schedule("yolo2_reorg")
def schedule_reorg(attrs, outs, target):
"""Schedule definition of reorg"""
with tvm.target.create(target):
return topi.generic.schedule_injective(outs)
reg.register_pattern("yolo2_reorg", OpPattern.INJECTIVE)
@reg.register_compute("yolo2_region")
def compute_region(attrs, inputs, _):
"""Compute definition of region"""
n = attrs.get_int("n")
classes = attrs.get_int("classes")
coords = attrs.get_int("coords")
background = attrs.get_int("background")
softmax = attrs.get_int("softmax")
return topi.vision.yolo2.region(inputs[0], n, classes, coords, background, softmax)
@reg.register_schedule("yolo2_region")
def schedule_region(attrs, outs, target):
"""Schedule definition of region"""
with tvm.target.create(target):
return topi.generic.vision.schedule_region(outs)
reg.register_pattern("yolo2_region", OpPattern.OPAQUE)
/*!
* Copyright (c) 2018 by Contributors
* \file region.cc
* \brief Property def of pooling operators.
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "../../op_common.h"
#include "region.h"
namespace nnvm {
namespace top {
NNVM_REGISTER_OP(yolo2_region)
.describe(R"code(Region layer
)code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(5)
.add_argument("data", "Tensor", "Input data")
.set_attr<FInferType>("FInferType", RegionType<1, 1>)
.set_attr<FInferShape>("FInferShape", RegionShape<1, 1>)
.set_attr<FInplaceOption>(
"FInplaceOption",
[](const NodeAttrs &attrs) {
return std::vector<std::pair<int, int>>{{0, 0}, {1, 0}};
})
.set_attr<FGradient>("FGradient", [](const NodePtr &n,
const std::vector<NodeEntry> &ograds) {
return std::vector<NodeEntry>{ograds[0], ograds[0]};
});
} // namespace top
} // namespace nnvm
/*!
* Copyright (c) 2018 by Contributors
* \file region.h
*/
#ifndef NNVM_TOP_VISION_YOLO2_REGION_H_
#define NNVM_TOP_VISION_YOLO2_REGION_H_
#include <string>
#include <vector>
#include <utility>
#include <iostream>
#include <sstream>
namespace nnvm {
namespace top {
template <typename AttrType,
bool (*is_none)(const AttrType &),
bool (*assign)(AttrType *,
const AttrType &),
bool reverse_infer,
std::string (*attr_string)(const AttrType &),
int n_in = -1,
int n_out = -1>
inline bool RegionAttr(const nnvm::NodeAttrs &attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType &none) {
AttrType dattr = none;
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
if (n_in != -1) {
in_size = static_cast<size_t>(n_in);
}
if (n_out != -1) {
out_size = static_cast<size_t>(n_out);
}
auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
if (i == 0)
CHECK(assign(&dattr, (*vec)[i]))
<< "Incompatible attr in node " << attrs.name << " at " << i
<< "-th " << name << ": "
<< "expected " << attr_string(dattr) << ", got "
<< attr_string((*vec)[i]);
}
};
deduce(in_attrs, in_size, "input");
auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(*vec)[i], dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": "
<< "expected " << attr_string(dattr) << ", got "
<< attr_string((*vec)[i]);
}
};
write(out_attrs, out_size, "output");
if (is_none(dattr)) {
return false;
}
return true;
}
template <int n_in, int n_out>
inline bool RegionShape(const NodeAttrs &attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
<< " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
<< " in operator " << attrs.name;
}
return RegionAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
attrs, in_attrs, out_attrs, TShape());
}
template <int n_in, int n_out>
inline bool RegionType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
<< " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
<< " in operator " << attrs.name;
}
return RegionAttr<int, type_is_none, type_assign, true, type_string>(
attrs, in_attrs, out_attrs, -1);
}
} // namespace top
} // namespace nnvm
#endif // NNVM_TOP_VISION_YOLO2_REGION_H_
/*!
* Copyright (c) 2018 by Contributors
* \file reorg.cc
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "../../op_common.h"
#include "../../elemwise_op_common.h"
#include "reorg.h"
namespace nnvm {
namespace top {
// reorg
DMLC_REGISTER_PARAMETER(ReorgParam);
inline bool ReorgInferShape(const nnvm::NodeAttrs &attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
const ReorgParam &param = nnvm::get<ReorgParam>(attrs.parsed);
TShape dshape = in_shape->at(0);
if (dshape.ndim() == 0)
return false;
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);
CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D";
CHECK_GT(param.stride, 0U) << "Stride value cannot be 0";
TShape oshape({dshape[0], 0, 0, 0});
oshape[1] = dshape[1] * param.stride * param.stride;
oshape[2] = dshape[2] / param.stride;
oshape[3] = dshape[3] / param.stride;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
}
NNVM_REGISTER_OP(yolo2_reorg)
.describe(R"(Perform reorg operation on input array based on the stride value.
- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width).
- **out**: Output is 4D array of shape (batch_size, channels/(stride*stride), in_height*stride, in_width*stride).
)" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(5)
.add_argument("data", "Tensor", "Data input to reorganize")
.set_attr_parser(ParamParser<ReorgParam>)
.add_arguments(ReorgParam::__FIELDS__())
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReorgParam>)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferShape>("FInferShape", ReorgInferShape);
} // namespace top
} // namespace nnvm
/*!
* Copyright (c) 2018 by Contributors
* \file reorg.h
*/
#ifndef NNVM_TOP_VISION_YOLO2_REORG_H_
#define NNVM_TOP_VISION_YOLO2_REORG_H_
#include <string>
#include <vector>
#include <utility>
#include <iostream>
#include <sstream>
namespace nnvm {
namespace top {
template <typename AttrType,
bool (*is_none)(const AttrType &),
bool (*assign)(AttrType *,
const AttrType &),
bool reverse_infer,
std::string (*attr_string)(const AttrType &),
int n_in = -1,
int n_out = -1>
inline bool ReorgAttr(const nnvm::NodeAttrs &attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType &none) {
AttrType dattr = none;
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
if (n_in != -1) {
in_size = static_cast<size_t>(n_in);
}
if (n_out != -1) {
out_size = static_cast<size_t>(n_out);
}
auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
if (i == 0) {
CHECK(assign(&dattr, (*vec)[i]))
<< "Incompatible attr in node " << attrs.name << " at " << i
<< "-th " << name << ": "
<< "expected " << attr_string(dattr) << ", got "
<< attr_string((*vec)[i]);
}
}
};
deduce(in_attrs, in_size, "input");
auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(*vec)[i], dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": "
<< "expected " << attr_string(dattr) << ", got "
<< attr_string((*vec)[i]);
}
};
write(out_attrs, out_size, "output");
if (is_none(dattr)) {
return false;
}
return true;
}
template <int n_in, int n_out>
inline bool ReorgShape(const NodeAttrs &attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
<< " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
<< " in operator " << attrs.name;
}
return ReorgAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
attrs, in_attrs, out_attrs, TShape());
}
template <int n_in, int n_out>
inline bool ReorgType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
<< " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
<< " in operator " << attrs.name;
}
return ReorgAttr<int, type_is_none, type_assign, true, type_string>(
attrs, in_attrs, out_attrs, -1);
}
struct ReorgParam : public dmlc::Parameter<ReorgParam> {
int stride;
DMLC_DECLARE_PARAMETER(ReorgParam) {
DMLC_DECLARE_FIELD(stride).set_default(1).describe("Stride value");
}
};
} // namespace top
} // namespace nnvm
#endif // NNVM_TOP_VISION_YOLO2_REORG_H_
......@@ -41,6 +41,9 @@ RUN bash /install/ubuntu_install_coreml.sh
COPY install/ubuntu_install_keras.sh /install/ubuntu_install_keras.sh
RUN bash /install/ubuntu_install_keras.sh
COPY install/ubuntu_install_darknet.sh /install/ubuntu_install_darknet.sh
RUN bash /install/ubuntu_install_darknet.sh
RUN pip install Pillow
# Environment variables
......
#install the necessary dependancies, cffi, opencv
wget 'https://github.com/siju-samuel/darknet/blob/master/lib/libdarknet.so?raw=true' -O libdarknet.so
pip2 install opencv-python cffi
pip3 install opencv-python cffi
"""
Compile Darknet Models
=====================
This article is a test script to test darknet models with NNVM.
All the required models and libraries will be downloaded from the internet
by the script.
"""
import os
import requests
import numpy as np
from nnvm import frontend
from nnvm.testing.darknet import __darknetffi__
import nnvm.compiler
import tvm
import sys
import urllib
if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2
def _download(url, path, overwrite=False, sizecompare=False):
''' Download from internet'''
if os.path.isfile(path) and not overwrite:
if sizecompare:
file_size = os.path.getsize(path)
res_head = requests.head(url)
res_get = requests.get(url, stream=True)
if 'Content-Length' not in res_head.headers:
res_get = urllib2.urlopen(url)
urlfile_size = int(res_get.headers['Content-Length'])
if urlfile_size != file_size:
print("exist file got corrupted, downloading", path, " file freshly")
_download(url, path, True, False)
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path)
print('')
except:
urllib.urlretrieve(url, path)
DARKNET_LIB = 'libdarknet.so'
DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \
+ DARKNET_LIB + '?raw=true'
_download(DARKNETLIB_URL, DARKNET_LIB)
LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
def test_forward(net):
'''Test network with given input image on both darknet and tvm'''
def get_darknet_output(net, img):
return LIB.network_predict_image(net, img)
def get_tvm_output(net, img):
'''Compute TVM output'''
dtype = 'float32'
batch_size = 1
sym, params = frontend.darknet.from_darknet(net, dtype)
data = np.empty([batch_size, img.c, img.h, img.w], dtype)
i = 0
for c in range(img.c):
for h in range(img.h):
for k in range(img.w):
data[0][c][h][k] = img.data[i]
i = i + 1
target = 'llvm'
shape_dict = {'data': data.shape}
#with nnvm.compiler.build_config(opt_level=2):
graph, library, params = nnvm.compiler.build(sym, target, shape_dict, dtype, params=params)
######################################################################
# Execute on TVM
# ---------------
# The process is no different from other examples.
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, library, ctx)
# set inputs
m.set_input('data', tvm.nd.array(data.astype(dtype)))
m.set_input(**params)
m.run()
# get outputs
out_shape = (net.outputs,)
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
return tvm_out
test_image = 'dog.jpg'
img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + test_image +'?raw=true'
_download(img_url, test_image)
img = LIB.letterbox_image(LIB.load_image_color(test_image.encode('utf-8'), 0, 0), net.w, net.h)
darknet_output = get_darknet_output(net, img)
darknet_out = np.zeros(net.outputs, dtype='float32')
for i in range(net.outputs):
darknet_out[i] = darknet_output[i]
tvm_out = get_tvm_output(net, img)
np.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-3, atol=1e-3)
def test_forward_extraction():
'''test extraction model'''
model_name = 'extraction'
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
test_forward(net)
LIB.free_network(net)
def test_forward_alexnet():
'''test alexnet model'''
model_name = 'alexnet'
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
test_forward(net)
LIB.free_network(net)
def test_forward_resnet50():
'''test resnet50 model'''
model_name = 'resnet50'
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
test_forward(net)
LIB.free_network(net)
def test_forward_yolo():
'''test yolo model'''
model_name = 'yolo'
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
test_forward(net)
LIB.free_network(net)
def test_forward_convolutional():
'''test convolutional layer'''
net = LIB.make_network(1)
layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
net.layers[0] = layer
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_dense():
'''test fully connected layer'''
net = LIB.make_network(1)
layer = LIB.make_connected_layer(1, 75, 20, 1, 0, 0)
net.layers[0] = layer
net.w = net.h = 5
LIB.resize_network(net, 5, 5)
test_forward(net)
LIB.free_network(net)
def test_forward_maxpooling():
'''test maxpooling layer'''
net = LIB.make_network(1)
layer = LIB.make_maxpool_layer(1, 224, 224, 3, 2, 2, 0)
net.layers[0] = layer
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_avgpooling():
'''test avgerage pooling layer'''
net = LIB.make_network(1)
layer = LIB.make_avgpool_layer(1, 224, 224, 3)
net.layers[0] = layer
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_batch_norm():
'''test batch normalization layer'''
net = LIB.make_network(1)
layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 1, 0, 0, 0)
for i in range(32):
layer.rolling_mean[i] = np.random.rand(1)
layer.rolling_variance[i] = np.random.rand(1)
net.layers[0] = layer
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_shortcut():
'''test shortcut layer'''
net = LIB.make_network(3)
layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
layer_2 = LIB.make_convolutional_layer(1, 111, 111, 32, 32, 1, 1, 1, 0, 1, 0, 0, 0, 0)
layer_3 = LIB.make_shortcut_layer(1, 0, 111, 111, 32, 111, 111, 32)
layer_3.activation = 1
net.layers[0] = layer_1
net.layers[1] = layer_2
net.layers[2] = layer_3
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_reorg():
'''test reorg layer'''
net = LIB.make_network(2)
layer_1 = LIB.make_convolutional_layer(1, 222, 222, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
layer_2 = LIB.make_reorg_layer(1, 110, 110, 32, 2, 0, 0, 0)
net.layers[0] = layer_1
net.layers[1] = layer_2
net.w = net.h = 222
LIB.resize_network(net, 222, 222)
test_forward(net)
LIB.free_network(net)
def test_forward_region():
'''test region layer'''
net = LIB.make_network(2)
layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 8, 1, 3, 2, 0, 1, 0, 0, 0, 0)
layer_2 = LIB.make_region_layer(1, 111, 111, 2, 2, 1)
layer_2.softmax = 1
net.layers[0] = layer_1
net.layers[1] = layer_2
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
if __name__ == '__main__':
test_forward_resnet50()
test_forward_alexnet()
test_forward_extraction()
test_forward_yolo()
test_forward_convolutional()
test_forward_maxpooling()
test_forward_avgpooling()
test_forward_batch_norm()
test_forward_shortcut()
test_forward_dense()
test_forward_reorg()
test_forward_region()
"""
Tutorial for running Yolo-V2 in Darknet Models
=====================
**Author**: `Siju Samuel <https://siju-samuel.github.io/>`_
This article is an introductory tutorial to deploy darknet models with NNVM.
All the required models and libraries will be downloaded from the internet
by the script.
This script runs the YOLO-V2 Model with the bounding boxes
Darknet parsing have dependancy with CFFI and CV2 library
Please install CFFI and CV2 before executing this script
pip install cffi
pip install opencv-python
"""
from ctypes import *
import math
import random
import nnvm
import nnvm.frontend.darknet
from nnvm.testing.darknet import __darknetffi__
import matplotlib.pyplot as plt
import numpy as np
import tvm
import os, sys, time, urllib, requests
if sys.version_info >= (3,):
import urllib.request as urllib2
import urllib.parse as urlparse
else:
import urllib2
import urlparse
######################################################################
# Set the parameters here.
# Supported models alexnet, resnet50, resnet152, extraction, yolo
######################################################################
model_name = 'yolo'
test_image = 'dog.jpg'
target = 'llvm'
ctx = tvm.cpu(0)
######################################################################
def dlProgress(count, block_size, total_size):
"""Show the download progress."""
global start_time
if count == 0:
start_time = time.time()
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = int(count * block_size * 100 / total_size)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()
def download(url, path, overwrite=False, sizecompare=False):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
Parameters
----------
url : str
Operator name, such as Convolution, Connected, etc
path : str
List of input symbols.
overwrite : dict
Dict of operator attributes
sizecompare : dict
Dict of operator attributes
Returns
-------
out_name : converted out name of operation
sym : nnvm.Symbol
Converted nnvm Symbol
"""
if os.path.isfile(path) and not overwrite:
if (sizecompare):
fileSize = os.path.getsize(path)
resHead = requests.head(url)
resGet = requests.get(url,stream=True)
if 'Content-Length' not in resHead.headers :
resGet = urllib2.urlopen(url)
urlFileSize = int(resGet.headers['Content-Length'])
if urlFileSize != fileSize:
print ("exist file got corrupted, downloading", path , " file freshly")
download(url, path, True, False)
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path, reporthook=dlProgress)
print('')
except:
urllib.urlretrieve(url, path, reporthook=dlProgress)
######################################################################
# Prepare cfg and weights file
# Pretrained model available https://pjreddie.com/darknet/imagenet/
# --------------------------------------------------------------------
# Download cfg and weights file first time.
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/siju-samuel/darknet/blob/master/cfg/' + \
cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
download(cfg_url, cfg_name)
download(weights_url, weights_name)
######################################################################
# Download and Load darknet library
# ---------------------------------
darknet_lib = 'libdarknet.so'
darknetlib_url = 'https://github.com/siju-samuel/darknet/blob/master/lib/' + \
darknet_lib + '?raw=true'
download(darknetlib_url, darknet_lib)
#if the file doesnt exist, then exit normally.
if os.path.isfile('./' + darknet_lib) is False:
exit(0)
darknet_lib = __darknetffi__.dlopen('./' + darknet_lib)
cfg = "./" + str(cfg_name)
weights = "./" + str(weights_name)
net = darknet_lib.load_network(cfg.encode('utf-8'), weights.encode('utf-8'), 0)
dtype = 'float32'
batch_size = 1
print("Converting darknet to nnvm symbols...")
sym, params = nnvm.frontend.darknet.from_darknet(net, dtype)
######################################################################
# Compile the model on NNVM
# --------------------------------------------------------------------
# compile the model
data = np.empty([batch_size, net.c ,net.h, net.w], dtype);
shape = {'data': data.shape}
print("Compiling the model...")
with nnvm.compiler.build_config(opt_level=2):
graph, lib, params = nnvm.compiler.build(sym, target, shape, dtype, params)
#####################################################################
# Save the json
# --------------------------------------------------------------------
def save_lib():
#Save the graph, params and .so to the current directory
print("Saving the compiled output...")
path_name = 'nnvm_darknet_' + model_name
path_lib = path_name + '_deploy_lib.so'
lib.export_library(path_lib)
with open(path_name
+ "deploy_graph.json", "w") as fo:
fo.write(graph.json())
with open(path_name
+ "deploy_param.params", "wb") as fo:
fo.write(nnvm.compiler.save_param_dict(params))
#save_lib()
######################################################################
# Load a test image
# --------------------------------------------------------------------
print("Loading the test image...")
img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + \
test_image +'?raw=true'
download(img_url, test_image)
data = nnvm.testing.darknet.load_image(test_image, net.w, net.h)
######################################################################
# Execute on TVM
# --------------------------------------------------------------------
# The process is no different from other examples.
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('data', tvm.nd.array(data.astype(dtype)))
m.set_input(**params)
# execute
print("Running the test image...")
m.run()
# get outputs
out_shape = (net.outputs,)
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
#do the detection and bring up the bounding boxes
thresh = 0.24
hier_thresh = 0.5
img = nnvm.testing.darknet.load_image_color(test_image)
_, im_h, im_w = img.shape
probs= []
boxes = []
region_layer = net.layers[net.n - 1]
boxes, probs = nnvm.testing.yolo2_detection.get_region_boxes(region_layer, im_w, im_h, net.w, net.h,
thresh, probs, boxes, 1, tvm_out)
boxes, probs = nnvm.testing.yolo2_detection.do_nms_sort(boxes, probs,
region_layer.w*region_layer.h*region_layer.n, region_layer.classes, 0.3)
coco_name = 'coco.names'
coco_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + coco_name +'?raw=true'
font_name = 'arial.ttf'
font_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + font_name +'?raw=true'
download(coco_url, coco_name)
download(font_url, font_name)
with open(coco_name) as f:
content = f.readlines()
names = [x.strip() for x in content]
nnvm.testing.yolo2_detection.draw_detections(img, region_layer.w*region_layer.h*region_layer.n,
thresh, boxes, probs, names, region_layer.classes)
plt.imshow(img.transpose(1,2,0))
plt.show()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment