Commit 7ec898d5 by Pariksheet Pinjari Committed by Tianqi Chen

[FRONTEND] DarkNet Yolo2 Frontend Support (#377)

parent 2e836ca7
......@@ -56,7 +56,7 @@ endif
all: lib/libnnvm.a lib/libnnvm_compiler.$(SHARED_LIBRARY_SUFFIX)
SRC = $(wildcard src/*.cc src/c_api/*.cc src/core/*.cc src/pass/*.cc)
SRC_COMPILER = $(wildcard src/top/*/*.cc src/compiler/*.cc src/compiler/*/*.cc)
SRC_COMPILER = $(wildcard src/top/*/*.cc wildcard src/top/vision/*/*.cc src/compiler/*.cc src/compiler/*/*.cc)
ALL_OBJ = $(patsubst %.cc, build/%.o, $(SRC))
TOP_OBJ = $(patsubst %.cc, build/%.o, $(SRC_COMPILER))
ALL_DEP = $(ALL_OBJ)
......
......@@ -4,3 +4,4 @@ from .mxnet import from_mxnet
from .onnx import from_onnx
from .coreml import from_coreml
from .keras import from_keras
from .darknet import from_darknet
"""
DarkNet symbol frontend.
"""
from __future__ import absolute_import as _abs
from enum import IntEnum
import numpy as np
import tvm
from .. import symbol as _sym
class LAYERTYPE(IntEnum):
"""Darknet LAYERTYPE Class constant."""
CONVOLUTIONAL = 0
DECONVOLUTIONAL = 1
CONNECTED = 2
MAXPOOL = 3
SOFTMAX = 4
DETECTION = 5
DROPOUT = 6
CROP = 7
ROUTE = 8
COST = 9
NORMALIZATION = 10
AVGPOOL = 11
LOCAL = 12
SHORTCUT = 13
ACTIVE = 14
RNN = 15
GRU = 16
LSTM = 17
CRNN = 18
BATCHNORM = 19
NETWORK = 20
XNOR = 21
REGION = 22
REORG = 23
BLANK = 24
class ACTIVATION(IntEnum):
"""Darknet ACTIVATION Class constant."""
LOGISTIC = 0
RELU = 1
RELIE = 2
LINEAR = 3
RAMP = 4
TANH = 5
PLSE = 6
LEAKY = 7
ELU = 8
LOGGY = 9
STAIR = 10
HARDTAN = 11
LHTAN = 12
__all__ = ['from_darknet']
def _darknet_get_nnvm_op(op_name):
"""Get the nnvm operation from opname, raise error if not supported."""
op = getattr(_sym, op_name)
if not op:
raise RuntimeError("Not to map op_name {} to nnvm.sym".format(op_name))
return op
def _darknet_required_attr(attr, key):
"""Check the attribute exists and return if exists, if not return error."""
assert isinstance(attr, dict)
if key not in attr:
raise AttributeError("Required attribute {} not found.".format(key))
return attr[key]
def _darknet_raise_not_supported(attr, op='nnvm'):
"""Raise error if any operation is not supported."""
err = "{} is not supported in {}.".format(attr, op)
raise NotImplementedError(err)
def _darknet_warn_not_used(attr, op='nnvm'):
"""Raise warning if any operation not supported."""
import warnings
err = "{} is ignored in {}.".format(attr, op)
warnings.warn(err)
def _darknet_parse_tshape(tshape):
"""Parse tshape in string."""
return [int(x.strip()) for x in tshape.strip('()').split(',')]
def _darknet_parse_bool_str(attr, key, default='False'):
"""Parse bool string to boolean."""
return attr.get(key, default).strip().lower() in \
['true', '1', 't', 'y', 'yes']
def _darknet_maxpooling(inputs, attrs):
"""Process the max pool 2d operation."""
kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel'))
if len(kernel) != 1:
_darknet_raise_not_supported('non-2d kernel', 'pool_2d')
op_name, new_attrs = 'max_pool2d', {}
strides = int(attrs.get('stride', (1, 1)))
pads = int(attrs.get('pad', (0, 0)))
new_attrs['pool_size'] = [kernel[0], kernel[0]]
new_attrs['strides'] = str((strides, strides))
new_attrs['padding'] = str((pads, pads))
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_avgpooling(inputs, attrs):
"""Process the average pool 2d operation."""
kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel'))
if len(kernel) != 1:
_darknet_raise_not_supported('non-2d kernel', 'pool_2d')
op_name, new_attrs = 'avg_pool2d', {}
strides = int(attrs.get('stride', (1, 1)))
pads = int(attrs.get('pad', (0, 0)))
new_attrs['pool_size'] = [kernel[0], kernel[0]]
new_attrs['strides'] = str((strides, strides))
new_attrs['padding'] = str((pads, pads))
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_batch_norm(inputs, attrs):
"""Process the batchnormalization operation."""
op_name, new_attrs = 'darknet_batch_norm', {}
new_attrs['axis'] = attrs.get('axis', 1)
new_attrs['epsilon'] = attrs.get('eps', 0.000001)
new_attrs['center'] = True
new_attrs['scale'] = True
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_conv2d(inputs, attrs):
"""Process the convolution 2d operation."""
kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel'))
if len(kernel) != 1:
_darknet_raise_not_supported('non 2d kernel', 'conv2d')
layout = attrs.get('layout', 'NCHW')
if layout not in ['NCHW', 'NHWC']:
_darknet_raise_not_supported('layout: ' + layout, 'conv2d')
strides = int(attrs.get('stride', (1, 1)))
pads = int(attrs.get('pad', (0, 0)))
op_name, new_attrs = 'conv2d', {}
new_attrs['channels'] = _darknet_required_attr(attrs, 'num_filter')
new_attrs['kernel_size'] = [kernel[0], kernel[0]]
new_attrs['strides'] = (strides, strides)
new_attrs['padding'] = (pads, pads)
new_attrs['dilation'] = attrs.get('dilate', (1, 1))
new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout
if attrs.get('use_batchNorm', False) is True:
new_attrs['use_bias'] = False
else:
new_attrs['use_bias'] = True
out_name = {}
sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
out_name[0] = sym.list_output_names()[0].replace('_output', '')
if attrs.get('use_batchNorm', False) is True:
op_name, new_attrs = 'batch_norm', {}
new_attrs['epsilon'] = 0.000001
sym = _darknet_get_nnvm_op(op_name)(*sym, **new_attrs)
out_name[1] = sym.list_output_names()[0].replace('_output', '')
if 'activation' in attrs:
new_attrs = {}
new_attrs['activation'] = attrs['activation']
new_attrs['slope'] = 0.1
sym, _ = _darknet_activations(sym, new_attrs)
return sym, out_name
def _darknet_conv2d_transpose(inputs, attrs):
"""Process the convolution 2d transpose operation."""
if 'target_shape' in attrs:
_darknet_raise_not_supported('target_shape', 'conv2d_transpose')
kernel = _darknet_parse_tshape(_darknet_required_attr(attrs, 'kernel'))
if len(kernel) != 2:
_darknet_raise_not_supported('non-2d kernel', 'conv2d_transpose')
layout = attrs.get('layout', 'NCHW')
if layout not in ['NCHW', 'NHWC']:
_darknet_raise_not_supported('layout: ' + layout, 'conv2d_transpose')
op_name, new_attrs = 'conv2d_transpose', {}
new_attrs['channels'] = _darknet_required_attr(attrs, 'num_filter')
new_attrs['kernel_size'] = kernel
new_attrs['strides'] = attrs.get('stride', (1, 1))
new_attrs['output_padding'] = attrs.get('adj', (0, 0))
new_attrs['padding'] = attrs.get('pad', (0, 0))
new_attrs['dilation'] = attrs.get('dilate', (1, 1))
new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout
new_attrs['use_bias'] = not _darknet_parse_bool_str(attrs, 'no_bias')
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_shortcut(inputs, attrs):
"""Process the shortcut operation."""
op_name, new_attrs = 'elemwise_add', {}
input_0 = inputs[0]
input_1 = inputs[1]
input_0_channel = int(attrs['out_channel'])
input_1_channel = int(attrs['add_out_channel'])
input_0_size = int(attrs['out_size'])
input_1_size = int(attrs['add_out_size'])
if input_0_size > input_1_size:
scale = int(input_0_size/input_1_size)
input_1 = _sym.upsampling(input_1, scale=scale, name="_upsampling")
elif input_0_size < input_1_size:
stride = int(input_1_size/input_0_size)
input_1 = _sym.avg_pool2d(input_1, pool_size=(1, 1),
strides=(stride, stride), padding=(0, 0), name="_downsampling")
if input_0_channel != input_1_channel:
pad_channel = input_0_channel - input_1_channel
input_1 = _sym.pad(input_1, pad_width=((0, 0), (0, pad_channel), (0, 0), (0, 0)),
pad_value=0.)
new_inputs = _as_list([input_0, input_1])
sym = _darknet_get_nnvm_op(op_name)(*new_inputs, **new_attrs)
out_name = sym.list_output_names()[0].replace('_output', '')
if 'activation' in attrs:
new_attrs['activation'] = attrs['activation']
sym, _ = _darknet_activations(sym, new_attrs)
return sym, out_name
def _darknet_dense(inputs, attrs):
"""Process the dense operation."""
op_name, new_attrs = 'dense', {}
new_attrs['units'] = _darknet_required_attr(attrs, 'num_hidden')
if attrs.get('use_bias', False) is True:
new_attrs['use_bias'] = True
if attrs.get('use_flatten', False) is True:
inputs[0] = _sym.flatten(inputs[0])
sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
out_name = sym.list_output_names()[0].replace('_output', '')
if 'activation' in attrs:
new_attrs = {}
new_attrs['activation'] = attrs['activation']
sym, _ = _darknet_activations(sym, new_attrs)
return sym, out_name
def _darknet_dropout(inputs, attrs):
"""Process the dropout operation, its a blank operation."""
op_name, new_attrs = 'dropout', {}
new_attrs['rate'] = attrs.get('p', 0.5)
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_reshape(inputs, attrs):
"""Process the reshape operation."""
if _darknet_parse_bool_str(attrs, 'reverse'):
_darknet_raise_not_supported('reverse', 'reshape')
op_name, new_attrs = 'reshape', {}
new_attrs['shape'] = _darknet_required_attr(attrs, 'shape')
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_softmax_output(inputs, attrs):
"""Process the softmax operation."""
op_name, new_attrs = 'softmax', {}
if _darknet_parse_bool_str(attrs, 'multi_output'):
new_attrs['axis'] = 1
if attrs.get('use_flatten', False) is True:
inputs[0] = _sym.flatten(inputs[0])
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_route(inputs, attrs):
"""Process the route operation, which is equivalent to concat."""
op_name = 'concatenate'
new_attrs = {'axis': attrs.get('dim', 1)}
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_reorg(inputs, attrs):
"""Process the reorg operation."""
op_name, new_attrs = 'yolo2_reorg', {}
if 'stride' in attrs:
new_attrs = {'stride': attrs.get('stride', 1)}
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_region(inputs, attrs):
"""Process the region operation."""
op_name, new_attrs = 'yolo2_region', {}
if 'n' in attrs:
new_attrs['n'] = attrs.get('n', 1)
if 'classes' in attrs:
new_attrs['classes'] = attrs.get('classes', 1)
if 'coords' in attrs:
new_attrs['coords'] = attrs.get('coords', 0)
if 'background' in attrs:
new_attrs['background'] = attrs.get('background', 0)
if 'softmax' in attrs:
new_attrs['softmax'] = attrs.get('softmax', 0)
return _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs), None
def _darknet_activations(inputs, attrs):
"""Process the activation function."""
act = _darknet_required_attr(attrs, 'activation')
if ACTIVATION.RELU == act:
act_type = 'relu'
elif ACTIVATION.TANH == act:
act_type = 'tanh'
elif ACTIVATION.LINEAR == act:
return inputs, None
elif ACTIVATION.LEAKY == act:
act_type = 'leaky_relu'
else:
_darknet_raise_not_supported('act: ' + act)
if act_type in ['relu', 'tanh']:
op_name, new_attrs = act_type, {}
sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
elif act_type in ['leaky_relu']:
op_name, new_attrs = act_type, {}
new_attrs['alpha'] = attrs.get('slope', 0.1)
sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs)
else:
_darknet_raise_not_supported('act_type: ' + act_type)
return sym, None
def _darknet_op_not_support(inputs, attrs):
"""Raise exception if the operation is not supported."""
err = "{} is not supported in {}.".format(attrs, inputs)
raise NotImplementedError(err)
_DARKNET_CONVERT_MAP = {
'CONVOLUTIONAL' : _darknet_conv2d,
'DECONVOLUTIONAL' : _darknet_conv2d_transpose,
'CONNECTED' : _darknet_dense,
'MAXPOOL' : _darknet_maxpooling,
'SOFTMAX' : _darknet_softmax_output,
'DROPOUT' : _darknet_dropout,
'AVGPOOL' : _darknet_avgpooling,
'BATCHNORM' : _darknet_batch_norm,
'RESHAPE' : _darknet_reshape,
'ROUTE' : _darknet_route,
'REORG' : _darknet_reorg,
'REGION' : _darknet_region,
'ACTIVATION' : _darknet_activations,
'SHORTCUT' : _darknet_shortcut,
'DETECTION' : _darknet_op_not_support,
'CROP' : _darknet_op_not_support,
'COST' : _darknet_op_not_support,
'NORMALIZATION' : _darknet_op_not_support,
'LOCAL' : _darknet_op_not_support,
'ACTIVE' : _darknet_op_not_support,
'RNN' : _darknet_op_not_support,
'GRU' : _darknet_op_not_support,
'LSTM' : _darknet_op_not_support,
'CRNN' : _darknet_op_not_support,
'NETWORK' : _darknet_op_not_support,
'XNOR' : _darknet_op_not_support,
'BLANK' : _darknet_op_not_support,
}
def _darknet_convert_symbol(op_name, inputs, attrs):
"""Convert from darknet op to nnvm op.
The converter must specify some conversions explicitly to
support gluon format ops such as conv2d...
Parameters
----------
op_name : str
Operator name, such as Convolution, Connected, etc
inputs : list of nnvm.Symbol
List of input symbols.
attrs : dict
Dict of operator attributes
Returns
-------
out_name : converted out name of operation
sym : nnvm.Symbol
Converted nnvm Symbol
"""
if op_name in _DARKNET_CONVERT_MAP:
sym, out_name = _DARKNET_CONVERT_MAP[op_name](inputs, attrs)
else:
_darknet_raise_not_supported('Operator: ' + op_name)
if out_name is None:
out_name = sym.list_output_names()[0].replace('_output', '')
return out_name, sym
def _as_list(arr):
"""Force being a list, ignore if already is."""
if isinstance(arr, list):
return arr
return [arr]
def _read_memory_buffer(shape, data, dtype):
length = 1
for x in shape:
length *= x
data_np = np.zeros(length, dtype=dtype)
for i in range(length):
data_np[i] = data[i]
return data_np.reshape(shape)
def _get_darknet_layername(layer_type):
"""Get the layer name from the darknet enums."""
return str((LAYERTYPE(layer_type))).replace('LAYERTYPE.', '')
def _get_convolution_weights(layer, opname, params, dtype):
"""Get the convolution layer weights and biases."""
if layer.nweights == 0:
return
if (layer.n * layer.c * layer.size * layer.size) != layer.nweights:
raise RuntimeError("layer weights size not matching with n c h w")
weights = _read_memory_buffer((layer.n, layer.c, layer.size, layer.size), layer.weights, dtype)
biases = _read_memory_buffer((layer.n, ), layer.biases, dtype)
k = _get_tvm_params_name(opname[0], 'weight')
params[k] = tvm.nd.array(weights)
if layer.batch_normalize == 1 and layer.dontloadscales != 1:
_get_batchnorm_weights(layer, opname[1], params, layer.n, dtype)
k = _get_tvm_params_name(opname[1], 'beta')
params[k] = tvm.nd.array(biases)
else:
k = _get_tvm_params_name(opname[0], 'bias')
params[k] = tvm.nd.array(biases)
def _get_connected_weights(layer, opname, params, dtype):
"""Parse the weights and biases for fully connected or dense layer."""
size = layer.outputs * layer.inputs
if size == 0:
return
weights = _read_memory_buffer((layer.outputs, layer.inputs), layer.weights, dtype)
biases = _read_memory_buffer((layer.outputs, ), layer.biases, dtype)
k = _get_tvm_params_name(opname, 'weight')
params[k] = tvm.nd.array(weights)
k = _get_tvm_params_name(opname, 'bias')
params[k] = tvm.nd.array(biases)
if layer.batch_normalize == 1 and layer.dontloadscales != 1:
_get_batchnorm_weights(layer, opname, params, layer.outputs, dtype)
def _get_batchnorm_weights(layer, opname, params, size, dtype):
"""Parse the weights for batchnorm, which includes, scales, moving mean
and moving variances."""
scales = _read_memory_buffer((size, ), layer.scales, dtype)
rolling_mean = _read_memory_buffer((size, ), layer.rolling_mean, dtype)
rolling_variance = _read_memory_buffer((size, ), layer.rolling_variance, dtype)
k = _get_tvm_params_name(opname, 'moving_mean')
params[k] = tvm.nd.array(rolling_mean)
k = _get_tvm_params_name(opname, 'moving_var')
params[k] = tvm.nd.array(rolling_variance)
k = _get_tvm_params_name(opname, 'gamma')
params[k] = tvm.nd.array(scales)
def _get_darknet_attrs(net, layer_num):
"""Parse attributes of each layer and return."""
attr = {}
use_flatten = True
layer = net.layers[layer_num]
op_name = _get_darknet_layername(layer.type)
if LAYERTYPE.CONVOLUTIONAL == layer.type:
attr.update({'layout' : 'NCHW'})
attr.update({'pad' : str(layer.pad)})
attr.update({'num_group' : str(layer.groups)})
attr.update({'num_filter' : str(layer.n)})
attr.update({'stride' : str(layer.stride)})
attr.update({'kernel' : str(layer.size)})
attr.update({'activation' : (layer.activation)})
if layer.nbiases == 0:
attr.update({'use_bias' : False})
else:
attr.update({'use_bias' : True})
if layer.batch_normalize == 1 and layer.dontloadscales != 1:
attr.update({'use_batchNorm' : True})
attr.update({'use_scales' : True})
#elif LAYERTYPE.BATCHNORM == layer.type:
# attr.update({'flatten' : str('True')})
elif LAYERTYPE.CONNECTED == layer.type:
attr.update({'num_hidden' : str(layer.outputs)})
attr.update({'activation' : (layer.activation)})
if layer_num != 0:
layer_prev = net.layers[layer_num - 1]
if (layer_prev.out_h == layer.h and
layer_prev.out_w == layer.w and
layer_prev.out_c == layer.c):
use_flatten = False
attr.update({'use_flatten' : use_flatten})
if layer.nbiases == 0:
attr.update({'use_bias' : False})
else:
attr.update({'use_bias' : True})
if layer.batch_normalize == 1 and layer.dontloadscales != 1:
attr.update({'use_batchNorm' : True})
attr.update({'use_scales' : True})
elif LAYERTYPE.MAXPOOL == layer.type:
attr.update({'pad' : str(layer.pad)})
attr.update({'stride' : str(layer.stride)})
attr.update({'kernel' : str(layer.size)})
elif LAYERTYPE.AVGPOOL == layer.type:
attr.update({'pad' : str(layer.pad)})
if layer.stride == 0:
attr.update({'stride' : str(1)})
else:
attr.update({'stride' : str(layer.stride)})
if layer.size == 0 and layer.h == layer.w:
attr.update({'kernel' : str(layer.h)})
else:
attr.update({'kernel' : str(layer.size)})
elif LAYERTYPE.DROPOUT == layer.type:
attr.update({'p' : str(layer.probability)})
elif LAYERTYPE.SOFTMAX == layer.type:
attr.update({'axis' : 1})
attr.update({'use_flatten' : True})
elif LAYERTYPE.SHORTCUT == layer.type:
add_layer = net.layers[layer.index]
attr.update({'activation' : (layer.activation)})
attr.update({'out_channel' : (layer.out_c)})
attr.update({'out_size' : (layer.out_h)})
attr.update({'add_out_channel' : (add_layer.out_c)})
attr.update({'add_out_size' : (add_layer.out_h)})
elif LAYERTYPE.ROUTE == layer.type:
pass
elif LAYERTYPE.COST == layer.type:
pass
elif LAYERTYPE.REORG == layer.type:
attr.update({'stride' : layer.stride})
elif LAYERTYPE.REGION == layer.type:
attr.update({'n' : layer.n})
attr.update({'classes' : layer.classes})
attr.update({'coords' : layer.coords})
attr.update({'background' : layer.background})
attr.update({'softmax' : layer.softmax})
else:
err = "Darknet layer {} is not supported in nnvm.".format(op_name)
raise NotImplementedError(err)
return op_name, attr
def _get_tvm_params_name(opname, arg_name):
"""Makes the params name for the k,v pair."""
return opname + '_'+ arg_name
def _get_darknet_params(layer, opname, tvmparams, dtype='float32'):
"""To parse and get the darknet params."""
if LAYERTYPE.CONVOLUTIONAL == layer.type:
_get_convolution_weights(layer, opname, tvmparams, dtype)
#elif LAYERTYPE.BATCHNORM == layer.type:
# size = layer.outputs
# _get_batchnorm_weights(layer, opname, tvmparams, size, dtype)
elif LAYERTYPE.CONNECTED == layer.type:
_get_connected_weights(layer, opname, tvmparams, dtype)
def _preproc_layer(net, i, sym_array):
"""To preprocess each darknet layer, some layer doesnt need processing."""
layer = net.layers[i]
if i == 0:
name = 'data'
attribute = {}
sym = [_sym.Variable(name, **attribute)]
else:
sym = sym_array[i - 1]
skip_layer = False
if LAYERTYPE.ROUTE == layer.type:
sym = []
for j in range(layer.n):
sym.append(sym_array[layer.input_layers[j]])
if layer.n == 1:
skip_layer = True
elif LAYERTYPE.COST == layer.type:
skip_layer = True
elif LAYERTYPE.SHORTCUT == layer.type:
sym = [sym, sym_array[layer.index]]
elif LAYERTYPE.BLANK == layer.type:
skip_layer = True
if skip_layer is True:
sym_array[i] = sym
return skip_layer, sym
def _from_darknet(net, dtype='float32'):
"""To convert the darknet symbol to nnvm symbols."""
sym_array = {}
tvmparams = {}
for i in range(net.n):
need_skip, sym = _preproc_layer(net, i, sym_array)
if need_skip is True:
continue
op_name, attr = _get_darknet_attrs(net, i)
layer_name, sym = _darknet_convert_symbol(op_name, _as_list(sym), attr)
_get_darknet_params(net.layers[i], layer_name, tvmparams, dtype)
sym_array[i] = sym
return sym, tvmparams
def from_darknet(net, dtype='float32'):
"""Convert from darknet's model into compatible NNVM format.
Reconstruct a nnvm symbol by traversing the darknet input.
Parameters
----------
net : ctype Pointer to network
Darknet parsed symbols
dtype : str
Datatype of the input net structure, default is float32
Returns
-------
sym : nnvm.Symbol
Compatible nnvm symbol
params : dict of str to tvm.NDArray
The parameter dict to be used by nnvm
"""
return _from_darknet(net, dtype)
......@@ -7,3 +7,5 @@ from . import mobilenet
from . import mlp
from . import resnet
from . import vgg
from . import darknet
from . import yolo2_detection
# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
"""
Compile DarkNet Models
====================
DarkNet helper functions for darknet model parsing and image loading.
This functions will not be loaded by default.
These are utility functions used for testing and tutorial file.
"""
from __future__ import division
from enum import IntEnum
import math
import numpy as np
import cv2
from cffi import FFI
def _resize_image(img, w_in, h_in):
"""Resize the image to the given height and width."""
imc, imh, imw = img.shape
h_in = int(h_in)
w_in = int(w_in)
part = np.zeros((imc, imh, w_in))
resized = np.zeros((imc, h_in, w_in))
w_scale = (imw - 1) / (w_in - 1)
h_scale = (imh - 1) / (h_in - 1)
for k in range(imc):
for j in range(imh):
for c in range(w_in):
if c == w_in - 1 or imw == 1:
part[k][j][c] = img[k][j][imw - 1]
else:
fdx, idx = math.modf(c * w_scale)
part[k][j][c] = (1 - fdx) * img[k][j][int(idx)] + \
fdx * img[k][j][int(idx) + 1]
for k in range(imc):
for j in range(h_in):
fdy, idy = math.modf(j * h_scale)
for c in range(w_in):
resized[k][j][c] = (1 - fdy)*part[k][int(idy)][c]
if (j == h_in - 1) or (imh == 1):
continue
for c in range(w_in):
resized[k][j][c] += fdy * part[k][int(idy) + 1][c]
return resized
def load_image_color(test_image):
"""To load the image using opencv api and do preprocessing."""
imagex = cv2.imread(test_image)
imagex = np.array(imagex)
imagex = imagex.transpose((2, 0, 1))
imagex = np.divide(imagex, 255.0)
imagex = np.flip(imagex, 0)
return imagex
def _letterbox_image(img, w_in, h_in):
"""To get the image in boxed format."""
imc, imh, imw = img.shape
if (w_in / imw) < (h_in / imh):
new_w = w_in
new_h = imh * w_in / imw
else:
new_h = h_in
new_w = imw * h_in/imh
resized = _resize_image(img, new_w, new_h)
boxed = np.full((imc, h_in, w_in), 0.5, dtype=float)
_, resizedh, resizedw = resized.shape
boxed[:, int((h_in - new_h) / 2)
:int((h_in - new_h) / 2) + resizedh, int((w_in - new_w) / 2)
:int((w_in - new_w) / 2) + resizedw] = resized
return boxed
def load_image(image, resize_width, resize_height):
"""Load the image and convert to the darknet model format.
The image processing of darknet is different from normal.
Parameters
----------
image : string
The image file name with path
resize_width : integer
The width to which the image needs to be resized
resize_height : integer
The height to which the image needs to be resized
Returns
-------
img : Float array
Array of processed image
"""
img = load_image_color(image)
return _letterbox_image(img, resize_width, resize_height)
class LAYERTYPE(IntEnum):
"""Darknet LAYERTYPE Class constant."""
CONVOLUTIONAL = 0
DECONVOLUTIONAL = 1
CONNECTED = 2
MAXPOOL = 3
SOFTMAX = 4
DETECTION = 5
DROPOUT = 6
CROP = 7
ROUTE = 8
COST = 9
NORMALIZATION = 10
AVGPOOL = 11
LOCAL = 12
SHORTCUT = 13
ACTIVE = 14
RNN = 15
GRU = 16
LSTM = 17
CRNN = 18
BATCHNORM = 19
NETWORK = 20
XNOR = 21
REGION = 22
REORG = 23
BLANK = 24
class ACTIVATION(IntEnum):
"""Darknet ACTIVATION Class constant."""
LOGISTIC = 0
RELU = 1
RELIE = 2
LINEAR = 3
RAMP = 4
TANH = 5
PLSE = 6
LEAKY = 7
ELU = 8
LOGGY = 9
STAIR = 10
HARDTAN = 11
LHTAN = 12
__darknetffi__ = FFI()
__darknetffi__.cdef("""
typedef struct network network;
typedef struct layer layer;
typedef struct{
int *leaf;
int n;
int *parent;
int *child;
int *group;
char **name;
int groups;
int *group_size;
int *group_offset;
} tree;
typedef enum{
LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN
} ACTIVATION;
typedef enum {
CONVOLUTIONAL,
DECONVOLUTIONAL,
CONNECTED,
MAXPOOL,
SOFTMAX,
DETECTION,
DROPOUT,
CROP,
ROUTE,
COST,
NORMALIZATION,
AVGPOOL,
LOCAL,
SHORTCUT,
ACTIVE,
RNN,
GRU,
LSTM,
CRNN,
BATCHNORM,
NETWORK,
XNOR,
REGION,
REORG,
BLANK
} LAYERTYPE;
typedef enum{
SSE, MASKED, LONE, SEG, SMOOTH
} COSTTYPE;
struct layer{
LAYERTYPE type;
ACTIVATION activation;
COSTTYPE cost_type;
void (*forward);
void (*backward);
void (*update);
void (*forward_gpu);
void (*backward_gpu);
void (*update_gpu);
int batch_normalize;
int shortcut;
int batch;
int forced;
int flipped;
int inputs;
int outputs;
int nweights;
int nbiases;
int extra;
int truths;
int h,w,c;
int out_h, out_w, out_c;
int n;
int max_boxes;
int groups;
int size;
int side;
int stride;
int reverse;
int flatten;
int spatial;
int pad;
int sqrt;
int flip;
int index;
int binary;
int xnor;
int steps;
int hidden;
int truth;
float smooth;
float dot;
float angle;
float jitter;
float saturation;
float exposure;
float shift;
float ratio;
float learning_rate_scale;
int softmax;
int classes;
int coords;
int background;
int rescore;
int objectness;
int does_cost;
int joint;
int noadjust;
int reorg;
int log;
int tanh;
float alpha;
float beta;
float kappa;
float coord_scale;
float object_scale;
float noobject_scale;
float mask_scale;
float class_scale;
int bias_match;
int random;
float thresh;
int classfix;
int absolute;
int onlyforward;
int stopbackward;
int dontload;
int dontloadscales;
float temperature;
float probability;
float scale;
char * cweights;
int * indexes;
int * input_layers;
int * input_sizes;
int * map;
float * rand;
float * cost;
float * state;
float * prev_state;
float * forgot_state;
float * forgot_delta;
float * state_delta;
float * combine_cpu;
float * combine_delta_cpu;
float * concat;
float * concat_delta;
float * binary_weights;
float * biases;
float * bias_updates;
float * scales;
float * scale_updates;
float * weights;
float * weight_updates;
float * delta;
float * output;
float * squared;
float * norms;
float * spatial_mean;
float * mean;
float * variance;
float * mean_delta;
float * variance_delta;
float * rolling_mean;
float * rolling_variance;
float * x;
float * x_norm;
float * m;
float * v;
float * bias_m;
float * bias_v;
float * scale_m;
float * scale_v;
float *z_cpu;
float *r_cpu;
float *h_cpu;
float * prev_state_cpu;
float *temp_cpu;
float *temp2_cpu;
float *temp3_cpu;
float *dh_cpu;
float *hh_cpu;
float *prev_cell_cpu;
float *cell_cpu;
float *f_cpu;
float *i_cpu;
float *g_cpu;
float *o_cpu;
float *c_cpu;
float *dc_cpu;
float * binary_input;
struct layer *input_layer;
struct layer *self_layer;
struct layer *output_layer;
struct layer *reset_layer;
struct layer *update_layer;
struct layer *state_layer;
struct layer *input_gate_layer;
struct layer *state_gate_layer;
struct layer *input_save_layer;
struct layer *state_save_layer;
struct layer *input_state_layer;
struct layer *state_state_layer;
struct layer *input_z_layer;
struct layer *state_z_layer;
struct layer *input_r_layer;
struct layer *state_r_layer;
struct layer *input_h_layer;
struct layer *state_h_layer;
struct layer *wz;
struct layer *uz;
struct layer *wr;
struct layer *ur;
struct layer *wh;
struct layer *uh;
struct layer *uo;
struct layer *wo;
struct layer *uf;
struct layer *wf;
struct layer *ui;
struct layer *wi;
struct layer *ug;
struct layer *wg;
tree *softmax_tree;
size_t workspace_size;
};
typedef enum {
CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM
} LEARNINGRATEPOLICY;
typedef struct network{
int n;
int batch;
size_t *seen;
int *t;
float epoch;
int subdivisions;
layer *layers;
float *output;
LEARNINGRATEPOLICY policy;
float learning_rate;
float momentum;
float decay;
float gamma;
float scale;
float power;
int time_steps;
int step;
int max_batches;
float *scales;
int *steps;
int num_steps;
int burn_in;
int adam;
float B1;
float B2;
float eps;
int inputs;
int outputs;
int truths;
int notruth;
int h, w, c;
int max_crop;
int min_crop;
float max_ratio;
float min_ratio;
int center;
float angle;
float aspect;
float exposure;
float saturation;
float hue;
int random;
int gpu_index;
tree *hierarchy;
float *input;
float *truth;
float *delta;
float *workspace;
int train;
int index;
float *cost;
} network;
typedef struct {
int w;
int h;
int c;
float *data;
} image;
network *load_network(char *cfg, char *weights, int clear);
image letterbox_image(image im, int w, int h);
int resize_network(network *net, int w, int h);
void top_predictions(network *net, int n, int *index);
void free_image(image m);
image load_image_color(char *filename, int w, int h);
float *network_predict_image(network *net, image im);
network *make_network(int n);
layer make_convolutional_layer(int batch, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation, int batch_normalize, int adam);
layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride, int padding);
layer make_avgpool_layer(int batch, int w, int h, int c);
layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2);
layer make_batchnorm_layer(int batch, int w, int h, int c);
layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse, int flatten, int extra);
layer make_region_layer(int batch, int w, int h, int n, int classes, int coords);
void free_network(network *net);
"""
)
# pylint: disable=invalid-name, unused-variable, unused-argument, no-init
"""
Yolo detection boxes helper functions
====================
DarkNet helper functions for yolo and image loading.
This functions will not be loaded by default.
These are utility functions used for testing and tutorial file.
"""
from __future__ import division
import math
from collections import namedtuple
import numpy as np
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
def _entry_index(batch, w, h, outputs, classes, coords, location, entry):
n = int(location/(w*h))
loc = location%(w*h)
return batch*outputs + n*w*h*(coords+classes+1) + entry*w*h + loc
Box = namedtuple('Box', ['x', 'y', 'w', 'h'])
def _get_region_box(x, biases, n, index, i, j, w, h, stride):
b = Box(0, 0, 0, 0)
b = b._replace(x=(i + x[index + 0*stride]) / w)
b = b._replace(y=(j + x[index + 1*stride]) / h)
b = b._replace(w=np.exp(x[index + 2*stride]) * biases[2*n] / w)
b = b._replace(h=np.exp(x[index + 3*stride]) * biases[2*n+1] / h)
return b
def _correct_region_boxes(boxes, n, w, h, netw, neth, relative):
new_w, new_h = (netw, (h*netw)/w) if (netw/w < neth/h) else ((w*neth/h), neth)
for i in range(n):
b = boxes[i]
b = boxes[i]
b = b._replace(x=(b.x - (netw - new_w)/2/netw) / (new_w/netw))
b = b._replace(y=(b.y - (neth - new_h)/2/neth) / (new_h/neth))
b = b._replace(w=b.w * netw/new_w)
b = b._replace(h=b.h * neth/new_h)
if not relative:
b = b._replace(x=b.x * w)
b = b._replace(w=b.w * w)
b = b._replace(y=b.y * h)
b = b._replace(h=b.h * h)
boxes[i] = b
def _overlap(x1, w1, x2, w2):
l1 = x1 - w1/2
l2 = x2 - w2/2
left = l1 if l1 > l2 else l2
r1 = x1 + w1/2
r2 = x2 + w2/2
right = r1 if r1 < r2 else r2
return right - left
def _box_intersection(a, b):
w = _overlap(a.x, a.w, b.x, b.w)
h = _overlap(a.y, a.h, b.y, b.h)
if w < 0 or h < 0:
return 0
return w*h
def _box_union(a, b):
i = _box_intersection(a, b)
u = a.w*a.h + b.w*b.h - i
return u
def _box_iou(a, b):
return _box_intersection(a, b)/_box_union(a, b)
def get_region_boxes(layer_in, imw, imh, netw, neth, thresh, probs,
boxes, relative, tvm_out):
"To get the boxes for the image based on the prediction"
lw = layer_in.w
lh = layer_in.h
probs = [[0 for i in range(layer_in.classes + 1)] for y in range(lw*lh*layer_in.n)]
boxes = [Box(0, 0, 0, 0) for i in range(lw*lh*layer_in.n)]
for i in range(lw*lh):
row = int(i / lw)
col = int(i % lw)
for n in range(layer_in.n):
index = n*lw*lh + i
obj_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
layer_in.coords, n*lw*lh + i, layer_in.coords)
box_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
layer_in.coords, n*lw*lh + i, 0)
mask_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
layer_in.coords, n*lw*lh + i, 4)
scale = 1 if layer_in.background else tvm_out[obj_index]
boxes[index] = _get_region_box(tvm_out, layer_in.biases, n, box_index, col,
row, lw, lh, lw*lh)
if not layer_in.softmax_tree:
max_element = 0
for j in range(layer_in.classes):
class_index = _entry_index(0, lw, lh, layer_in.outputs, layer_in.classes,
layer_in.coords, n*lw*lh + i, layer_in.coords+1+j)
prob = scale*tvm_out[class_index]
probs[index][j] = prob if prob > thresh else 0
max_element = max(max_element, prob)
probs[index][layer_in.classes] = max_element
_correct_region_boxes(boxes, lw*lh*layer_in.n, imw, imh, netw, neth, relative)
return boxes, probs
def do_nms_sort(boxes, probs, total, classes, thresh):
"Does the sorting based on the threshold values"
SortableBbox = namedtuple('SortableBbox', ['index_var', 'class_var', 'probs'])
s = [SortableBbox(0, 0, []) for i in range(total)]
for i in range(total):
s[i] = s[i]._replace(index_var=i)
s[i] = s[i]._replace(class_var=0)
s[i] = s[i]._replace(probs=probs)
for k in range(classes):
for i in range(total):
s[i] = s[i]._replace(class_var=k)
s = sorted(s, key=lambda x: x.probs[x.index_var][x.class_var], reverse=True)
for i in range(total):
if probs[s[i].index_var][k] == 0:
continue
a = boxes[s[i].index_var]
for j in range(i+1, total):
b = boxes[s[j].index_var]
if _box_iou(a, b) > thresh:
probs[s[j].index_var][k] = 0
return boxes, probs
def draw_detections(im, num, thresh, boxes, probs, names, classes):
"Draw the markings around the detected region"
for i in range(num):
labelstr = []
category = -1
for j in range(classes):
if probs[i][j] > thresh:
if category == -1:
category = j
labelstr.append(names[j])
if category > -1:
imc, imh, imw = im.shape
width = int(imh * 0.006)
offset = category*123457 % classes
red = _get_color(2, offset, classes)
green = _get_color(1, offset, classes)
blue = _get_color(0, offset, classes)
rgb = [red, green, blue]
b = boxes[i]
left = int((b.x-b.w/2.)*imw)
right = int((b.x+b.w/2.)*imw)
top = int((b.y-b.h/2.)*imh)
bot = int((b.y+b.h/2.)*imh)
if left < 0:
left = 0
if right > imw-1:
right = imw-1
if top < 0:
top = 0
if bot > imh-1:
bot = imh-1
_draw_box_width(im, left, top, right, bot, width, red, green, blue)
label = _get_label(''.join(labelstr), rgb)
_draw_label(im, top + width, left, label, rgb)
def _get_pixel(im, x, y, c):
return im[c][y][x]
def _set_pixel(im, x, y, c, val):
if x < 0 or y < 0 or c < 0 or x >= im.shape[2] or y >= im.shape[1] or c >= im.shape[0]:
return
im[c][y][x] = val
def _draw_label(im, r, c, label, rgb):
w = label.shape[2]
h = label.shape[1]
if (r - h) >= 0:
r = r - h
for j in range(h):
if j < h and (j + r) < im.shape[1]:
for i in range(w):
if i < w and (i + c) < im.shape[2]:
for k in range(label.shape[0]):
val = _get_pixel(label, i, j, k)
_set_pixel(im, i+c, j+r, k, val)#rgb[k] * val)
def _get_label(labelstr, rgb):
text = labelstr
colorText = "black"
testDraw = ImageDraw.Draw(Image.new('RGB', (1, 1)))
font = ImageFont.truetype("arial.ttf", 25)
width, height = testDraw.textsize(labelstr, font=font)
img = Image.new('RGB', (width, height), color=(int(rgb[0]*255), int(rgb[1]*255),
int(rgb[2]*255)))
d = ImageDraw.Draw(img)
d.text((0, 0), text, fill=colorText, font=font)
opencvImage = np.divide(np.asarray(img), 255)
return opencvImage.transpose(2, 0, 1)
def _get_color(c, x, max_value):
c = int(c)
colors = [[1, 0, 1], [0, 0, 1], [0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 0]]
ratio = (float(x)/float(max_value)) * 5
i = int(math.floor(ratio))
j = int(math.ceil(ratio))
ratio -= i
r = (1-ratio) * colors[i][c] + ratio*colors[j][c]
return r
def _draw_box(im, x1, y1, x2, y2, r, g, b):
y1 = int(y1)
y2 = int(y2)
x1 = int(x1)
x2 = int(x2)
ac, ah, aw = im.shape
if x1 < 0:
x1 = 0
if x1 >= aw:
y1 = 0
if y1 >= ah:
y1 = ah - 1
if y2 < 0:
y2 = 0
if y2 >= ah:
y2 = ah - 1
for i in range(x1, x2):
im[0][y1][i] = r
im[0][y2][i] = r
im[1][y1][i] = g
im[1][y2][i] = g
im[2][y1][i] = b
im[2][y2][i] = b
for i in range(y1, y2):
im[0][i][x1] = r
im[0][i][x2] = r
im[1][i][x1] = g
im[1][i][x2] = g
im[2][i][x1] = b
im[2][i][x2] = b
def _draw_box_width(im, x1, y1, x2, y2, w, r, g, b):
for i in range(int(w)):
_draw_box(im, x1+i, y1+i, x2-i, y2-i, r, g, b)
......@@ -7,6 +7,7 @@ from . import tensor
from . import nn
from . import transform
from . import reduction
from . import vision
from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern
# pylint: disable=invalid-name, unused-argument
"""Definition of nn ops"""
from __future__ import absolute_import
import topi
import tvm
from . import registry as reg
from .registry import OpPattern
@reg.register_compute("yolo2_reorg")
def compute_reorg(attrs, inputs, _):
"""Compute definition of reorg"""
return topi.vision.reorg(inputs[0], attrs.get_int("stride"))
@reg.register_schedule("yolo2_reorg")
def schedule_reorg(attrs, outs, target):
"""Schedule definition of reorg"""
with tvm.target.create(target):
return topi.generic.schedule_injective(outs)
reg.register_pattern("yolo2_reorg", OpPattern.INJECTIVE)
@reg.register_compute("yolo2_region")
def compute_region(attrs, inputs, _):
"""Compute definition of region"""
n = attrs.get_int("n")
classes = attrs.get_int("classes")
coords = attrs.get_int("coords")
background = attrs.get_int("background")
softmax = attrs.get_int("softmax")
return topi.vision.yolo2.region(inputs[0], n, classes, coords, background, softmax)
@reg.register_schedule("yolo2_region")
def schedule_region(attrs, outs, target):
"""Schedule definition of region"""
with tvm.target.create(target):
return topi.generic.vision.schedule_region(outs)
reg.register_pattern("yolo2_region", OpPattern.OPAQUE)
/*!
* Copyright (c) 2018 by Contributors
* \file region.cc
* \brief Property def of pooling operators.
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "../../op_common.h"
#include "region.h"
namespace nnvm {
namespace top {
NNVM_REGISTER_OP(yolo2_region)
.describe(R"code(Region layer
)code" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(5)
.add_argument("data", "Tensor", "Input data")
.set_attr<FInferType>("FInferType", RegionType<1, 1>)
.set_attr<FInferShape>("FInferShape", RegionShape<1, 1>)
.set_attr<FInplaceOption>(
"FInplaceOption",
[](const NodeAttrs &attrs) {
return std::vector<std::pair<int, int>>{{0, 0}, {1, 0}};
})
.set_attr<FGradient>("FGradient", [](const NodePtr &n,
const std::vector<NodeEntry> &ograds) {
return std::vector<NodeEntry>{ograds[0], ograds[0]};
});
} // namespace top
} // namespace nnvm
/*!
* Copyright (c) 2018 by Contributors
* \file region.h
*/
#ifndef NNVM_TOP_VISION_YOLO2_REGION_H_
#define NNVM_TOP_VISION_YOLO2_REGION_H_
#include <string>
#include <vector>
#include <utility>
#include <iostream>
#include <sstream>
namespace nnvm {
namespace top {
template <typename AttrType,
bool (*is_none)(const AttrType &),
bool (*assign)(AttrType *,
const AttrType &),
bool reverse_infer,
std::string (*attr_string)(const AttrType &),
int n_in = -1,
int n_out = -1>
inline bool RegionAttr(const nnvm::NodeAttrs &attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType &none) {
AttrType dattr = none;
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
if (n_in != -1) {
in_size = static_cast<size_t>(n_in);
}
if (n_out != -1) {
out_size = static_cast<size_t>(n_out);
}
auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
if (i == 0)
CHECK(assign(&dattr, (*vec)[i]))
<< "Incompatible attr in node " << attrs.name << " at " << i
<< "-th " << name << ": "
<< "expected " << attr_string(dattr) << ", got "
<< attr_string((*vec)[i]);
}
};
deduce(in_attrs, in_size, "input");
auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(*vec)[i], dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": "
<< "expected " << attr_string(dattr) << ", got "
<< attr_string((*vec)[i]);
}
};
write(out_attrs, out_size, "output");
if (is_none(dattr)) {
return false;
}
return true;
}
template <int n_in, int n_out>
inline bool RegionShape(const NodeAttrs &attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
<< " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
<< " in operator " << attrs.name;
}
return RegionAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
attrs, in_attrs, out_attrs, TShape());
}
template <int n_in, int n_out>
inline bool RegionType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
<< " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
<< " in operator " << attrs.name;
}
return RegionAttr<int, type_is_none, type_assign, true, type_string>(
attrs, in_attrs, out_attrs, -1);
}
} // namespace top
} // namespace nnvm
#endif // NNVM_TOP_VISION_YOLO2_REGION_H_
/*!
* Copyright (c) 2018 by Contributors
* \file reorg.cc
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "../../op_common.h"
#include "../../elemwise_op_common.h"
#include "reorg.h"
namespace nnvm {
namespace top {
// reorg
DMLC_REGISTER_PARAMETER(ReorgParam);
inline bool ReorgInferShape(const nnvm::NodeAttrs &attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
const ReorgParam &param = nnvm::get<ReorgParam>(attrs.parsed);
TShape dshape = in_shape->at(0);
if (dshape.ndim() == 0)
return false;
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);
CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D";
CHECK_GT(param.stride, 0U) << "Stride value cannot be 0";
TShape oshape({dshape[0], 0, 0, 0});
oshape[1] = dshape[1] * param.stride * param.stride;
oshape[2] = dshape[2] / param.stride;
oshape[3] = dshape[3] / param.stride;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
}
NNVM_REGISTER_OP(yolo2_reorg)
.describe(R"(Perform reorg operation on input array based on the stride value.
- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width).
- **out**: Output is 4D array of shape (batch_size, channels/(stride*stride), in_height*stride, in_width*stride).
)" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(5)
.add_argument("data", "Tensor", "Data input to reorganize")
.set_attr_parser(ParamParser<ReorgParam>)
.add_arguments(ReorgParam::__FIELDS__())
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReorgParam>)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferShape>("FInferShape", ReorgInferShape);
} // namespace top
} // namespace nnvm
/*!
* Copyright (c) 2018 by Contributors
* \file reorg.h
*/
#ifndef NNVM_TOP_VISION_YOLO2_REORG_H_
#define NNVM_TOP_VISION_YOLO2_REORG_H_
#include <string>
#include <vector>
#include <utility>
#include <iostream>
#include <sstream>
namespace nnvm {
namespace top {
template <typename AttrType,
bool (*is_none)(const AttrType &),
bool (*assign)(AttrType *,
const AttrType &),
bool reverse_infer,
std::string (*attr_string)(const AttrType &),
int n_in = -1,
int n_out = -1>
inline bool ReorgAttr(const nnvm::NodeAttrs &attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType &none) {
AttrType dattr = none;
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
if (n_in != -1) {
in_size = static_cast<size_t>(n_in);
}
if (n_out != -1) {
out_size = static_cast<size_t>(n_out);
}
auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
if (i == 0) {
CHECK(assign(&dattr, (*vec)[i]))
<< "Incompatible attr in node " << attrs.name << " at " << i
<< "-th " << name << ": "
<< "expected " << attr_string(dattr) << ", got "
<< attr_string((*vec)[i]);
}
}
};
deduce(in_attrs, in_size, "input");
auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(*vec)[i], dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": "
<< "expected " << attr_string(dattr) << ", got "
<< attr_string((*vec)[i]);
}
};
write(out_attrs, out_size, "output");
if (is_none(dattr)) {
return false;
}
return true;
}
template <int n_in, int n_out>
inline bool ReorgShape(const NodeAttrs &attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
<< " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
<< " in operator " << attrs.name;
}
return ReorgAttr<TShape, shape_is_none, shape_assign, true, shape_string>(
attrs, in_attrs, out_attrs, TShape());
}
template <int n_in, int n_out>
inline bool ReorgType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in))
<< " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out))
<< " in operator " << attrs.name;
}
return ReorgAttr<int, type_is_none, type_assign, true, type_string>(
attrs, in_attrs, out_attrs, -1);
}
struct ReorgParam : public dmlc::Parameter<ReorgParam> {
int stride;
DMLC_DECLARE_PARAMETER(ReorgParam) {
DMLC_DECLARE_FIELD(stride).set_default(1).describe("Stride value");
}
};
} // namespace top
} // namespace nnvm
#endif // NNVM_TOP_VISION_YOLO2_REORG_H_
......@@ -41,6 +41,9 @@ RUN bash /install/ubuntu_install_coreml.sh
COPY install/ubuntu_install_keras.sh /install/ubuntu_install_keras.sh
RUN bash /install/ubuntu_install_keras.sh
COPY install/ubuntu_install_darknet.sh /install/ubuntu_install_darknet.sh
RUN bash /install/ubuntu_install_darknet.sh
RUN pip install Pillow
# Environment variables
......
#install the necessary dependancies, cffi, opencv
wget 'https://github.com/siju-samuel/darknet/blob/master/lib/libdarknet.so?raw=true' -O libdarknet.so
pip2 install opencv-python cffi
pip3 install opencv-python cffi
"""
Compile Darknet Models
=====================
This article is a test script to test darknet models with NNVM.
All the required models and libraries will be downloaded from the internet
by the script.
"""
import os
import requests
import numpy as np
from nnvm import frontend
from nnvm.testing.darknet import __darknetffi__
import nnvm.compiler
import tvm
import sys
import urllib
if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2
def _download(url, path, overwrite=False, sizecompare=False):
''' Download from internet'''
if os.path.isfile(path) and not overwrite:
if sizecompare:
file_size = os.path.getsize(path)
res_head = requests.head(url)
res_get = requests.get(url, stream=True)
if 'Content-Length' not in res_head.headers:
res_get = urllib2.urlopen(url)
urlfile_size = int(res_get.headers['Content-Length'])
if urlfile_size != file_size:
print("exist file got corrupted, downloading", path, " file freshly")
_download(url, path, True, False)
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path)
print('')
except:
urllib.urlretrieve(url, path)
DARKNET_LIB = 'libdarknet.so'
DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \
+ DARKNET_LIB + '?raw=true'
_download(DARKNETLIB_URL, DARKNET_LIB)
LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
def test_forward(net):
'''Test network with given input image on both darknet and tvm'''
def get_darknet_output(net, img):
return LIB.network_predict_image(net, img)
def get_tvm_output(net, img):
'''Compute TVM output'''
dtype = 'float32'
batch_size = 1
sym, params = frontend.darknet.from_darknet(net, dtype)
data = np.empty([batch_size, img.c, img.h, img.w], dtype)
i = 0
for c in range(img.c):
for h in range(img.h):
for k in range(img.w):
data[0][c][h][k] = img.data[i]
i = i + 1
target = 'llvm'
shape_dict = {'data': data.shape}
#with nnvm.compiler.build_config(opt_level=2):
graph, library, params = nnvm.compiler.build(sym, target, shape_dict, dtype, params=params)
######################################################################
# Execute on TVM
# ---------------
# The process is no different from other examples.
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, library, ctx)
# set inputs
m.set_input('data', tvm.nd.array(data.astype(dtype)))
m.set_input(**params)
m.run()
# get outputs
out_shape = (net.outputs,)
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
return tvm_out
test_image = 'dog.jpg'
img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + test_image +'?raw=true'
_download(img_url, test_image)
img = LIB.letterbox_image(LIB.load_image_color(test_image.encode('utf-8'), 0, 0), net.w, net.h)
darknet_output = get_darknet_output(net, img)
darknet_out = np.zeros(net.outputs, dtype='float32')
for i in range(net.outputs):
darknet_out[i] = darknet_output[i]
tvm_out = get_tvm_output(net, img)
np.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-3, atol=1e-3)
def test_forward_extraction():
'''test extraction model'''
model_name = 'extraction'
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
test_forward(net)
LIB.free_network(net)
def test_forward_alexnet():
'''test alexnet model'''
model_name = 'alexnet'
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
test_forward(net)
LIB.free_network(net)
def test_forward_resnet50():
'''test resnet50 model'''
model_name = 'resnet50'
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
test_forward(net)
LIB.free_network(net)
def test_forward_yolo():
'''test yolo model'''
model_name = 'yolo'
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
_download(cfg_url, cfg_name)
_download(weights_url, weights_name)
net = LIB.load_network(cfg_name.encode('utf-8'), weights_name.encode('utf-8'), 0)
test_forward(net)
LIB.free_network(net)
def test_forward_convolutional():
'''test convolutional layer'''
net = LIB.make_network(1)
layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
net.layers[0] = layer
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_dense():
'''test fully connected layer'''
net = LIB.make_network(1)
layer = LIB.make_connected_layer(1, 75, 20, 1, 0, 0)
net.layers[0] = layer
net.w = net.h = 5
LIB.resize_network(net, 5, 5)
test_forward(net)
LIB.free_network(net)
def test_forward_maxpooling():
'''test maxpooling layer'''
net = LIB.make_network(1)
layer = LIB.make_maxpool_layer(1, 224, 224, 3, 2, 2, 0)
net.layers[0] = layer
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_avgpooling():
'''test avgerage pooling layer'''
net = LIB.make_network(1)
layer = LIB.make_avgpool_layer(1, 224, 224, 3)
net.layers[0] = layer
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_batch_norm():
'''test batch normalization layer'''
net = LIB.make_network(1)
layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 1, 0, 0, 0)
for i in range(32):
layer.rolling_mean[i] = np.random.rand(1)
layer.rolling_variance[i] = np.random.rand(1)
net.layers[0] = layer
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_shortcut():
'''test shortcut layer'''
net = LIB.make_network(3)
layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
layer_2 = LIB.make_convolutional_layer(1, 111, 111, 32, 32, 1, 1, 1, 0, 1, 0, 0, 0, 0)
layer_3 = LIB.make_shortcut_layer(1, 0, 111, 111, 32, 111, 111, 32)
layer_3.activation = 1
net.layers[0] = layer_1
net.layers[1] = layer_2
net.layers[2] = layer_3
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
def test_forward_reorg():
'''test reorg layer'''
net = LIB.make_network(2)
layer_1 = LIB.make_convolutional_layer(1, 222, 222, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0)
layer_2 = LIB.make_reorg_layer(1, 110, 110, 32, 2, 0, 0, 0)
net.layers[0] = layer_1
net.layers[1] = layer_2
net.w = net.h = 222
LIB.resize_network(net, 222, 222)
test_forward(net)
LIB.free_network(net)
def test_forward_region():
'''test region layer'''
net = LIB.make_network(2)
layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 8, 1, 3, 2, 0, 1, 0, 0, 0, 0)
layer_2 = LIB.make_region_layer(1, 111, 111, 2, 2, 1)
layer_2.softmax = 1
net.layers[0] = layer_1
net.layers[1] = layer_2
net.w = net.h = 224
LIB.resize_network(net, 224, 224)
test_forward(net)
LIB.free_network(net)
if __name__ == '__main__':
test_forward_resnet50()
test_forward_alexnet()
test_forward_extraction()
test_forward_yolo()
test_forward_convolutional()
test_forward_maxpooling()
test_forward_avgpooling()
test_forward_batch_norm()
test_forward_shortcut()
test_forward_dense()
test_forward_reorg()
test_forward_region()
"""
Tutorial for running Yolo-V2 in Darknet Models
=====================
**Author**: `Siju Samuel <https://siju-samuel.github.io/>`_
This article is an introductory tutorial to deploy darknet models with NNVM.
All the required models and libraries will be downloaded from the internet
by the script.
This script runs the YOLO-V2 Model with the bounding boxes
Darknet parsing have dependancy with CFFI and CV2 library
Please install CFFI and CV2 before executing this script
pip install cffi
pip install opencv-python
"""
from ctypes import *
import math
import random
import nnvm
import nnvm.frontend.darknet
from nnvm.testing.darknet import __darknetffi__
import matplotlib.pyplot as plt
import numpy as np
import tvm
import os, sys, time, urllib, requests
if sys.version_info >= (3,):
import urllib.request as urllib2
import urllib.parse as urlparse
else:
import urllib2
import urlparse
######################################################################
# Set the parameters here.
# Supported models alexnet, resnet50, resnet152, extraction, yolo
######################################################################
model_name = 'yolo'
test_image = 'dog.jpg'
target = 'llvm'
ctx = tvm.cpu(0)
######################################################################
def dlProgress(count, block_size, total_size):
"""Show the download progress."""
global start_time
if count == 0:
start_time = time.time()
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = int(count * block_size * 100 / total_size)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()
def download(url, path, overwrite=False, sizecompare=False):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
Parameters
----------
url : str
Operator name, such as Convolution, Connected, etc
path : str
List of input symbols.
overwrite : dict
Dict of operator attributes
sizecompare : dict
Dict of operator attributes
Returns
-------
out_name : converted out name of operation
sym : nnvm.Symbol
Converted nnvm Symbol
"""
if os.path.isfile(path) and not overwrite:
if (sizecompare):
fileSize = os.path.getsize(path)
resHead = requests.head(url)
resGet = requests.get(url,stream=True)
if 'Content-Length' not in resHead.headers :
resGet = urllib2.urlopen(url)
urlFileSize = int(resGet.headers['Content-Length'])
if urlFileSize != fileSize:
print ("exist file got corrupted, downloading", path , " file freshly")
download(url, path, True, False)
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path, reporthook=dlProgress)
print('')
except:
urllib.urlretrieve(url, path, reporthook=dlProgress)
######################################################################
# Prepare cfg and weights file
# Pretrained model available https://pjreddie.com/darknet/imagenet/
# --------------------------------------------------------------------
# Download cfg and weights file first time.
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/siju-samuel/darknet/blob/master/cfg/' + \
cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
download(cfg_url, cfg_name)
download(weights_url, weights_name)
######################################################################
# Download and Load darknet library
# ---------------------------------
darknet_lib = 'libdarknet.so'
darknetlib_url = 'https://github.com/siju-samuel/darknet/blob/master/lib/' + \
darknet_lib + '?raw=true'
download(darknetlib_url, darknet_lib)
#if the file doesnt exist, then exit normally.
if os.path.isfile('./' + darknet_lib) is False:
exit(0)
darknet_lib = __darknetffi__.dlopen('./' + darknet_lib)
cfg = "./" + str(cfg_name)
weights = "./" + str(weights_name)
net = darknet_lib.load_network(cfg.encode('utf-8'), weights.encode('utf-8'), 0)
dtype = 'float32'
batch_size = 1
print("Converting darknet to nnvm symbols...")
sym, params = nnvm.frontend.darknet.from_darknet(net, dtype)
######################################################################
# Compile the model on NNVM
# --------------------------------------------------------------------
# compile the model
data = np.empty([batch_size, net.c ,net.h, net.w], dtype);
shape = {'data': data.shape}
print("Compiling the model...")
with nnvm.compiler.build_config(opt_level=2):
graph, lib, params = nnvm.compiler.build(sym, target, shape, dtype, params)
#####################################################################
# Save the json
# --------------------------------------------------------------------
def save_lib():
#Save the graph, params and .so to the current directory
print("Saving the compiled output...")
path_name = 'nnvm_darknet_' + model_name
path_lib = path_name + '_deploy_lib.so'
lib.export_library(path_lib)
with open(path_name
+ "deploy_graph.json", "w") as fo:
fo.write(graph.json())
with open(path_name
+ "deploy_param.params", "wb") as fo:
fo.write(nnvm.compiler.save_param_dict(params))
#save_lib()
######################################################################
# Load a test image
# --------------------------------------------------------------------
print("Loading the test image...")
img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + \
test_image +'?raw=true'
download(img_url, test_image)
data = nnvm.testing.darknet.load_image(test_image, net.w, net.h)
######################################################################
# Execute on TVM
# --------------------------------------------------------------------
# The process is no different from other examples.
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('data', tvm.nd.array(data.astype(dtype)))
m.set_input(**params)
# execute
print("Running the test image...")
m.run()
# get outputs
out_shape = (net.outputs,)
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
#do the detection and bring up the bounding boxes
thresh = 0.24
hier_thresh = 0.5
img = nnvm.testing.darknet.load_image_color(test_image)
_, im_h, im_w = img.shape
probs= []
boxes = []
region_layer = net.layers[net.n - 1]
boxes, probs = nnvm.testing.yolo2_detection.get_region_boxes(region_layer, im_w, im_h, net.w, net.h,
thresh, probs, boxes, 1, tvm_out)
boxes, probs = nnvm.testing.yolo2_detection.do_nms_sort(boxes, probs,
region_layer.w*region_layer.h*region_layer.n, region_layer.classes, 0.3)
coco_name = 'coco.names'
coco_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + coco_name +'?raw=true'
font_name = 'arial.ttf'
font_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + font_name +'?raw=true'
download(coco_url, coco_name)
download(font_url, font_name)
with open(coco_name) as f:
content = f.readlines()
names = [x.strip() for x in content]
nnvm.testing.yolo2_detection.draw_detections(img, region_layer.w*region_layer.h*region_layer.n,
thresh, boxes, probs, names, region_layer.classes)
plt.imshow(img.transpose(1,2,0))
plt.show()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment