Commit 4dc21bdb by Pariksheet Pinjari Committed by Tianqi Chen

[NNVM][DARKNET]Yolo and Upsample frontend support (#1501)

* Yolo and Upsample frontend support

* Lint fix

* Mac support added

* Code clean and trigger CI
parent 20c495e9
...@@ -32,8 +32,12 @@ class LAYERTYPE(object): ...@@ -32,8 +32,12 @@ class LAYERTYPE(object):
NETWORK = 20 NETWORK = 20
XNOR = 21 XNOR = 21
REGION = 22 REGION = 22
REORG = 23 YOLO = 23
BLANK = 24 REORG = 24
UPSAMPLE = 25
LOGXENT = 26
L2NORM = 27
BLANK = 28
class ACTIVATION(object): class ACTIVATION(object):
"""Darknet ACTIVATION Class constant.""" """Darknet ACTIVATION Class constant."""
...@@ -257,6 +261,12 @@ def _darknet_reshape(inputs, attrs): ...@@ -257,6 +261,12 @@ def _darknet_reshape(inputs, attrs):
new_attrs['shape'] = _darknet_required_attr(attrs, 'shape') new_attrs['shape'] = _darknet_required_attr(attrs, 'shape')
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_upsampling(inputs, attrs):
"""Process the upsampling operation."""
op_name, new_attrs = 'upsampling', {}
new_attrs['scale'] = attrs.get('scale', 1)
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_softmax_output(inputs, attrs): def _darknet_softmax_output(inputs, attrs):
"""Process the softmax operation.""" """Process the softmax operation."""
temperature = attrs.get('temperature', 1) temperature = attrs.get('temperature', 1)
...@@ -298,6 +308,15 @@ def _darknet_region(inputs, attrs): ...@@ -298,6 +308,15 @@ def _darknet_region(inputs, attrs):
new_attrs['softmax'] = attrs.get('softmax', 0) new_attrs['softmax'] = attrs.get('softmax', 0)
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_yolo(inputs, attrs):
"""Process the yolo operation."""
op_name, new_attrs = 'yolov3_yolo', {}
if 'n' in attrs:
new_attrs['n'] = attrs.get('n', 1)
if 'classes' in attrs:
new_attrs['classes'] = attrs.get('classes', 1)
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_activations(inputs, attrs): def _darknet_activations(inputs, attrs):
"""Process the activation function.""" """Process the activation function."""
act = _darknet_required_attr(attrs, 'activation') act = _darknet_required_attr(attrs, 'activation')
...@@ -350,6 +369,8 @@ _DARKNET_CONVERT_MAP = { ...@@ -350,6 +369,8 @@ _DARKNET_CONVERT_MAP = {
LAYERTYPE.REORG : _darknet_reorg, LAYERTYPE.REORG : _darknet_reorg,
LAYERTYPE.REGION : _darknet_region, LAYERTYPE.REGION : _darknet_region,
LAYERTYPE.SHORTCUT : _darknet_shortcut, LAYERTYPE.SHORTCUT : _darknet_shortcut,
LAYERTYPE.UPSAMPLE : _darknet_upsampling,
LAYERTYPE.YOLO : _darknet_yolo,
LAYERTYPE.DETECTION : _darknet_op_not_support, LAYERTYPE.DETECTION : _darknet_op_not_support,
LAYERTYPE.CROP : _darknet_op_not_support, LAYERTYPE.CROP : _darknet_op_not_support,
LAYERTYPE.COST : _darknet_op_not_support, LAYERTYPE.COST : _darknet_op_not_support,
...@@ -575,6 +596,13 @@ class GraphProto(object): ...@@ -575,6 +596,13 @@ class GraphProto(object):
attr.update({'coords' : layer.coords}) attr.update({'coords' : layer.coords})
attr.update({'background' : layer.background}) attr.update({'background' : layer.background})
attr.update({'softmax' : layer.softmax}) attr.update({'softmax' : layer.softmax})
elif LAYERTYPE.YOLO == layer.type:
attr.update({'n' : layer.n})
attr.update({'classes' : layer.classes})
elif LAYERTYPE.UPSAMPLE == layer.type:
attr.update({'scale' : layer.stride})
else: else:
err = "Darknet layer type {} is not supported in nnvm.".format(layer.type) err = "Darknet layer type {} is not supported in nnvm.".format(layer.type)
raise NotImplementedError(err) raise NotImplementedError(err)
......
...@@ -115,8 +115,12 @@ class LAYERTYPE(object): ...@@ -115,8 +115,12 @@ class LAYERTYPE(object):
NETWORK = 20 NETWORK = 20
XNOR = 21 XNOR = 21
REGION = 22 REGION = 22
REORG = 23 YOLO = 23
BLANK = 24 REORG = 24
UPSAMPLE = 25
LOGXENT = 26
L2NORM = 27
BLANK = 28
class ACTIVATION(object): class ACTIVATION(object):
"""Darknet ACTIVATION Class constant.""" """Darknet ACTIVATION Class constant."""
...@@ -182,12 +186,16 @@ typedef enum { ...@@ -182,12 +186,16 @@ typedef enum {
NETWORK, NETWORK,
XNOR, XNOR,
REGION, REGION,
YOLO,
REORG, REORG,
UPSAMPLE,
LOGXENT,
L2NORM,
BLANK BLANK
} LAYERTYPE; } LAYERTYPE;
typedef enum{ typedef enum{
SSE, MASKED, LONE, SEG, SMOOTH SSE, MASKED, L1, SEG, SMOOTH, WGAN
} COSTTYPE; } COSTTYPE;
...@@ -241,18 +249,20 @@ struct layer{ ...@@ -241,18 +249,20 @@ struct layer{
float shift; float shift;
float ratio; float ratio;
float learning_rate_scale; float learning_rate_scale;
float clip;
int softmax; int softmax;
int classes; int classes;
int coords; int coords;
int background; int background;
int rescore; int rescore;
int objectness; int objectness;
int does_cost;
int joint; int joint;
int noadjust; int noadjust;
int reorg; int reorg;
int log; int log;
int tanh; int tanh;
int *mask;
int total;
float alpha; float alpha;
float beta; float beta;
...@@ -265,13 +275,17 @@ struct layer{ ...@@ -265,13 +275,17 @@ struct layer{
float class_scale; float class_scale;
int bias_match; int bias_match;
int random; int random;
float ignore_thresh;
float truth_thresh;
float thresh; float thresh;
float focus;
int classfix; int classfix;
int absolute; int absolute;
int onlyforward; int onlyforward;
int stopbackward; int stopbackward;
int dontload; int dontload;
int dontsave;
int dontloadscales; int dontloadscales;
float temperature; float temperature;
...@@ -309,6 +323,7 @@ struct layer{ ...@@ -309,6 +323,7 @@ struct layer{
float * delta; float * delta;
float * output; float * output;
float * loss;
float * squared; float * squared;
float * norms; float * norms;
...@@ -462,6 +477,7 @@ typedef struct network{ ...@@ -462,6 +477,7 @@ typedef struct network{
int train; int train;
int index; int index;
float *cost; float *cost;
float clip;
} network; } network;
...@@ -491,6 +507,7 @@ layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse, ...@@ -491,6 +507,7 @@ layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse,
layer make_region_layer(int batch, int w, int h, int n, int classes, int coords); layer make_region_layer(int batch, int w, int h, int n, int classes, int coords);
layer make_softmax_layer(int batch, int inputs, int groups); layer make_softmax_layer(int batch, int inputs, int groups);
layer make_rnn_layer(int batch, int inputs, int outputs, int steps, ACTIVATION activation, int batch_normalize, int adam); layer make_rnn_layer(int batch, int inputs, int outputs, int steps, ACTIVATION activation, int batch_normalize, int adam);
layer make_yolo_layer(int batch, int w, int h, int n, int total, int *mask, int classes);
layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, ACTIVATION activation, int batch_normalize); layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, ACTIVATION activation, int batch_normalize);
layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize, int adam); layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize, int adam);
layer make_gru_layer(int batch, int inputs, int outputs, int steps, int batch_normalize, int adam); layer make_gru_layer(int batch, int inputs, int outputs, int steps, int batch_normalize, int adam);
......
...@@ -38,6 +38,21 @@ def schedule_region(attrs, outs, target): ...@@ -38,6 +38,21 @@ def schedule_region(attrs, outs, target):
reg.register_pattern("yolo_region", OpPattern.OPAQUE) reg.register_pattern("yolo_region", OpPattern.OPAQUE)
@reg.register_compute("yolov3_yolo")
def compute_yolo(attrs, inputs, _):
"""Compute definition of yolo"""
n = attrs.get_int("n")
classes = attrs.get_int("classes")
return topi.vision.yolo.yolo(inputs[0], n, classes)
@reg.register_schedule("yolov3_yolo")
def schedule_yolo(attrs, outs, target):
"""Schedule definition of yolo"""
with tvm.target.create(target):
return topi.generic.schedule_injective(outs)
reg.register_pattern("yolov3_yolo", OpPattern.OPAQUE)
# multibox_prior # multibox_prior
@reg.register_schedule("multibox_prior") @reg.register_schedule("multibox_prior")
def schedule_multibox_prior(_, outs, target): def schedule_multibox_prior(_, outs, target):
......
/*!
* Copyright (c) 2018 by Contributors
* \file yolo.cc
* \brief Property def of yolo operators.
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "../../elemwise_op_common.h"
namespace nnvm {
namespace top {
NNVM_REGISTER_OP(yolov3_yolo)
.describe(R"code(Yolo layer
)code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(5)
.add_argument("data", "Tensor", "Input data")
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInplaceOption>(
"FInplaceOption",
[](const NodeAttrs &attrs) {
return std::vector<std::pair<int, int>>{{0, 0}, {1, 0}};
})
.set_attr<FGradient>("FGradient", [](const NodePtr &n,
const std::vector<NodeEntry> &ograds) {
return std::vector<NodeEntry>{ograds[0], ograds[0]};
});
} // namespace top
} // namespace nnvm
...@@ -44,7 +44,7 @@ def _download(url, path, overwrite=False, sizecompare=False): ...@@ -44,7 +44,7 @@ def _download(url, path, overwrite=False, sizecompare=False):
except: except:
urllib.urlretrieve(url, path) urllib.urlretrieve(url, path)
DARKNET_LIB = 'libdarknet.so' DARKNET_LIB = 'libdarknet2.0.so'
DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \ DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \
+ DARKNET_LIB + '?raw=true' + DARKNET_LIB + '?raw=true'
_download(DARKNETLIB_URL, DARKNET_LIB) _download(DARKNETLIB_URL, DARKNET_LIB)
...@@ -239,6 +239,8 @@ def test_forward_shortcut(): ...@@ -239,6 +239,8 @@ def test_forward_shortcut():
layer_2 = LIB.make_convolutional_layer(1, 111, 111, 32, 32, 1, 1, 1, 0, 1, 0, 0, 0, 0) layer_2 = LIB.make_convolutional_layer(1, 111, 111, 32, 32, 1, 1, 1, 0, 1, 0, 0, 0, 0)
layer_3 = LIB.make_shortcut_layer(1, 0, 111, 111, 32, 111, 111, 32) layer_3 = LIB.make_shortcut_layer(1, 0, 111, 111, 32, 111, 111, 32)
layer_3.activation = 1 layer_3.activation = 1
layer_3.alpha = 1
layer_3.beta = 1
net.layers[0] = layer_1 net.layers[0] = layer_1
net.layers[1] = layer_2 net.layers[1] = layer_2
net.layers[2] = layer_3 net.layers[2] = layer_3
...@@ -272,6 +274,30 @@ def test_forward_region(): ...@@ -272,6 +274,30 @@ def test_forward_region():
test_forward(net) test_forward(net)
LIB.free_network(net) LIB.free_network(net)
def test_forward_yolo_op():
'''test yolo layer'''
net = LIB.make_network(2)
layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 14, 1, 3, 2, 0, 1, 0, 0, 0, 0)
a = []
layer_2 = LIB.make_yolo_layer(1, 111, 111, 2, 0, a, 2)
net.layers[0] = layer_1
net.layers[1] = layer_2
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_upsample():
'''test upsample layer'''
net = LIB.make_network(1)
layer = LIB.make_upsample_layer(1, 19, 19, 3, 3)
layer.scale = 1
net.layers[0] = layer
net.w = net.h = 19
LIB.resize_network(net, 19, 19)
test_forward(net)
LIB.free_network(net)
def test_forward_elu(): def test_forward_elu():
'''test elu activation layer''' '''test elu activation layer'''
net = LIB.make_network(1) net = LIB.make_network(1)
...@@ -428,6 +454,8 @@ if __name__ == '__main__': ...@@ -428,6 +454,8 @@ if __name__ == '__main__':
test_forward_rnn() test_forward_rnn()
test_forward_reorg() test_forward_reorg()
test_forward_region() test_forward_region()
test_forward_yolo_op()
test_forward_upsample()
test_forward_elu() test_forward_elu()
test_forward_rnn() test_forward_rnn()
test_forward_crnn() test_forward_crnn()
......
...@@ -22,54 +22,48 @@ import matplotlib.pyplot as plt ...@@ -22,54 +22,48 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import tvm import tvm
import os import os
import sys
from ctypes import * from ctypes import *
from tvm.contrib.download import download from tvm.contrib.download import download
from nnvm.testing.darknet import __darknetffi__ from nnvm.testing.darknet import __darknetffi__
###################################################################### #Model name
# Set the parameters here. MODEL_NAME = 'yolo'
# Supported models alexnet, resnet50, resnet152, extraction, yolo
#
model_name = 'yolo'
test_image = 'dog.jpg'
target = 'llvm'
ctx = tvm.cpu(0)
###################################################################### ######################################################################
# Prepare cfg and weights file # Download required files
# ---------------------------- # -----------------------
# Pretrained model available https://pjreddie.com/darknet/imagenet/ # Download cfg and weights file if first time.
# Download cfg and weights file first time. CFG_NAME = MODEL_NAME + '.cfg'
WEIGHTS_NAME = MODEL_NAME + '.weights'
REPO_URL = 'https://github.com/siju-samuel/darknet/blob/master/'
CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true'
WEIGHTS_URL = REPO_URL + 'weights/' + WEIGHTS_NAME + '?raw=true'
download(CFG_URL, CFG_NAME)
download(WEIGHTS_URL, WEIGHTS_NAME)
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/siju-samuel/darknet/blob/master/cfg/' + \
cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
download(cfg_url, cfg_name)
download(weights_url, weights_name)
######################################################################
# Download and Load darknet library # Download and Load darknet library
# --------------------------------- if sys.platform in ['linux', 'linux2']:
DARKNET_LIB = 'libdarknet2.0.so'
darknet_lib = 'libdarknet.so' DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true'
darknetlib_url = 'https://github.com/siju-samuel/darknet/blob/master/lib/' + \ elif sys.platform == 'darwin':
darknet_lib + '?raw=true' DARKNET_LIB = 'libdarknet_mac2.0.so'
download(darknetlib_url, darknet_lib) DARKNET_URL = REPO_URL + 'lib_osx/' + DARKNET_LIB + '?raw=true'
else:
#if the file doesnt exist, then exit normally. err = "Darknet lib is not supported on {} platform".format(sys.platform)
if os.path.isfile('./' + darknet_lib) is False: raise NotImplementedError(err)
exit(0)
download(DARKNET_URL, DARKNET_LIB)
darknet_lib = __darknetffi__.dlopen('./' + darknet_lib)
cfg = "./" + str(cfg_name) DARKNET_LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
weights = "./" + str(weights_name) cfg = "./" + str(CFG_NAME)
net = darknet_lib.load_network(cfg.encode('utf-8'), weights.encode('utf-8'), 0) weights = "./" + str(WEIGHTS_NAME)
net = DARKNET_LIB.load_network(cfg.encode('utf-8'), weights.encode('utf-8'), 0)
dtype = 'float32' dtype = 'float32'
batch_size = 1 batch_size = 1
print("Converting darknet to nnvm symbols...") print("Converting darknet to nnvm symbols...")
sym, params = nnvm.frontend.darknet.from_darknet(net, dtype) sym, params = nnvm.frontend.darknet.from_darknet(net, dtype)
...@@ -77,7 +71,9 @@ sym, params = nnvm.frontend.darknet.from_darknet(net, dtype) ...@@ -77,7 +71,9 @@ sym, params = nnvm.frontend.darknet.from_darknet(net, dtype)
# Compile the model on NNVM # Compile the model on NNVM
# ------------------------- # -------------------------
# compile the model # compile the model
data = np.empty([batch_size, net.c ,net.h, net.w], dtype); target = 'llvm'
ctx = tvm.cpu(0)
data = np.empty([batch_size, net.c, net.h, net.w], dtype)
shape = {'data': data.shape} shape = {'data': data.shape}
print("Compiling the model...") print("Compiling the model...")
with nnvm.compiler.build_config(opt_level=2): with nnvm.compiler.build_config(opt_level=2):
...@@ -103,6 +99,7 @@ def save_lib(): ...@@ -103,6 +99,7 @@ def save_lib():
###################################################################### ######################################################################
# Load a test image # Load a test image
# -------------------------------------------------------------------- # --------------------------------------------------------------------
test_image = 'dog.jpg'
print("Loading the test image...") print("Loading the test image...")
img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + \ img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + \
test_image +'?raw=true' test_image +'?raw=true'
...@@ -134,7 +131,7 @@ thresh = 0.24 ...@@ -134,7 +131,7 @@ thresh = 0.24
hier_thresh = 0.5 hier_thresh = 0.5
img = nnvm.testing.darknet.load_image_color(test_image) img = nnvm.testing.darknet.load_image_color(test_image)
_, im_h, im_w = img.shape _, im_h, im_w = img.shape
probs= [] probs = []
boxes = [] boxes = []
region_layer = net.layers[net.n - 1] region_layer = net.layers[net.n - 1]
boxes, probs = nnvm.testing.yolo2_detection.get_region_boxes(region_layer, im_w, im_h, net.w, net.h, boxes, probs = nnvm.testing.yolo2_detection.get_region_boxes(region_layer, im_w, im_h, net.w, net.h,
...@@ -157,5 +154,5 @@ names = [x.strip() for x in content] ...@@ -157,5 +154,5 @@ names = [x.strip() for x in content]
nnvm.testing.yolo2_detection.draw_detections(img, region_layer.w*region_layer.h*region_layer.n, nnvm.testing.yolo2_detection.draw_detections(img, region_layer.w*region_layer.h*region_layer.n,
thresh, boxes, probs, names, region_layer.classes) thresh, boxes, probs, names, region_layer.classes)
plt.imshow(img.transpose(1,2,0)) plt.imshow(img.transpose(1, 2, 0))
plt.show() plt.show()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment