Commit df8e3382 by Siju Committed by Tianqi Chen

[DARKNET]RNN Support for darknet (#1443)

parent ca5397d5
......@@ -479,6 +479,7 @@ void top_predictions(network *net, int n, int *index);
void free_image(image m);
image load_image_color(char *filename, int w, int h);
float *network_predict_image(network *net, image im);
float *network_predict(network *net, float *input);
network *make_network(int n);
layer make_convolutional_layer(int batch, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation, int batch_normalize, int adam);
......@@ -488,6 +489,8 @@ layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int
layer make_batchnorm_layer(int batch, int w, int h, int c);
layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse, int flatten, int extra);
layer make_region_layer(int batch, int w, int h, int n, int classes, int coords);
layer make_softmax_layer(int batch, int inputs, int groups);
layer make_rnn_layer(int batch, int inputs, int outputs, int steps, ACTIVATION activation, int batch_normalize, int adam);
void free_network(network *net);
"""
)
......@@ -7,18 +7,20 @@ by the script.
"""
import os
import requests
import sys
import urllib
import numpy as np
import tvm
from tvm.contrib import graph_runtime
from nnvm import frontend
from nnvm.testing.darknet import __darknetffi__
import nnvm.compiler
import tvm
import sys
import urllib
if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2
def _download(url, path, overwrite=False, sizecompare=False):
''' Download from internet'''
if os.path.isfile(path) and not overwrite:
......@@ -48,43 +50,31 @@ DARKNETLIB_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' \
_download(DARKNETLIB_URL, DARKNET_LIB)
LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
def _get_tvm_output(net, data):
'''Compute TVM output'''
dtype = 'float32'
sym, params = frontend.darknet.from_darknet(net, dtype)
target = 'llvm'
shape_dict = {'data': data.shape}
graph, library, params = nnvm.compiler.build(sym, target, shape_dict, dtype, params=params)
# Execute on TVM
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, library, ctx)
# set inputs
m.set_input('data', tvm.nd.array(data.astype(dtype)))
m.set_input(**params)
m.run()
# get outputs
out_shape = (net.outputs,)
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
return tvm_out
def test_forward(net):
'''Test network with given input image on both darknet and tvm'''
def get_darknet_output(net, img):
return LIB.network_predict_image(net, img)
def get_tvm_output(net, img):
'''Compute TVM output'''
dtype = 'float32'
batch_size = 1
sym, params = frontend.darknet.from_darknet(net, dtype)
data = np.empty([batch_size, img.c, img.h, img.w], dtype)
i = 0
for c in range(img.c):
for h in range(img.h):
for k in range(img.w):
data[0][c][h][k] = img.data[i]
i = i + 1
target = 'llvm'
shape_dict = {'data': data.shape}
#with nnvm.compiler.build_config(opt_level=2):
graph, library, params = nnvm.compiler.build(sym, target, shape_dict, dtype, params=params)
######################################################################
# Execute on TVM
# ---------------
# The process is no different from other examples.
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, library, ctx)
# set inputs
m.set_input('data', tvm.nd.array(data.astype(dtype)))
m.set_input(**params)
m.run()
# get outputs
out_shape = (net.outputs,)
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
return tvm_out
dtype = 'float32'
test_image = 'dog.jpg'
img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + test_image +'?raw=true'
......@@ -94,9 +84,35 @@ def test_forward(net):
darknet_out = np.zeros(net.outputs, dtype='float32')
for i in range(net.outputs):
darknet_out[i] = darknet_output[i]
tvm_out = get_tvm_output(net, img)
batch_size = 1
data = np.empty([batch_size, img.c, img.h, img.w], dtype)
i = 0
for c in range(img.c):
for h in range(img.h):
for k in range(img.w):
data[0][c][h][k] = img.data[i]
i = i + 1
tvm_out = _get_tvm_output(net, data)
np.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-3, atol=1e-3)
def test_rnn_forward(net):
'''Test network with given input data on both darknet and tvm'''
def get_darknet_network_predict(net, data):
return LIB.network_predict(net, data)
from cffi import FFI
ffi = FFI()
np_arr = np.zeros([1, net.inputs], dtype='float32')
np_arr[0, 84] = 1
cffi_arr = ffi.cast('float*', np_arr.ctypes.data)
tvm_out = _get_tvm_output(net, np_arr)
darknet_output = get_darknet_network_predict(net, cffi_arr)
darknet_out = np.zeros(net.outputs, dtype='float32')
for i in range(net.outputs):
darknet_out[i] = darknet_output[i]
np.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-4, atol=1e-4)
def test_forward_extraction():
'''test extraction model'''
model_name = 'extraction'
......@@ -289,6 +305,25 @@ def test_forward_softmax_temperature():
test_forward(net)
LIB.free_network(net)
def test_forward_rnn():
'''test softmax layer'''
net = LIB.make_network(1)
batch = 1
inputs = 256
outputs = 256
steps = 1
activation = 1
batch_normalize = 0
adam = 0
layer_1 = LIB.make_rnn_layer(batch, inputs, outputs, steps, activation, batch_normalize, adam)
net.layers[0] = layer_1
net.inputs = inputs
net.outputs = outputs
net.w = net.h = 0
LIB.resize_network(net, net.w, net.h)
test_rnn_forward(net)
LIB.free_network(net)
if __name__ == '__main__':
test_forward_resnet50()
test_forward_alexnet()
......@@ -303,6 +338,7 @@ if __name__ == '__main__':
test_forward_dense_batchnorm()
test_forward_softmax()
test_forward_softmax_temperature()
test_forward_rnn()
test_forward_reorg()
test_forward_region()
test_forward_elu()
"""
Compile Darknet Models for RNN
==============================
**Author**: `Siju Samuel <https://siju-samuel.github.io/>`_
This article is an introductory tutorial to deploy darknet rnn models with NNVM.
This script will run a character prediction model
Each module consists of 3 fully-connected layers. The input layer propagates information from the
input to the current state. The recurrent layer propagates information through time from the
previous state to the current one.
The input to the network is a 1-hot encoding of ASCII characters. We train the network to predict
the next character in a stream of characters. The output is constrained to be a probability
distribution using a softmax layer.
Since each recurrent layer contains information about the current character and the past
characters, it can use this context to predict the future characters in a word or phrase.
All the required models and libraries will be downloaded from the internet
by the script.
"""
import random
import numpy as np
from mxnet.gluon.utils import download
import tvm
from tvm.contrib import graph_runtime
from nnvm.testing.darknet import __darknetffi__
import nnvm
import nnvm.frontend.darknet
# Set the parameters
# -----------------------
# Set the seed value and the number of characters to predict
#Model name
MODEL_NAME = 'rnn'
#Seed value
seed = 'Thus'
#Number of characters to predict
num = 1000
# Download required files
# -----------------------
# Download cfg and weights file if first time.
CFG_NAME = MODEL_NAME + '.cfg'
WEIGHTS_NAME = MODEL_NAME + '.weights'
REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/'
CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true'
WEIGHTS_URL = REPO_URL + 'weights/' + WEIGHTS_NAME + '?raw=true'
download(CFG_URL, CFG_NAME)
download(WEIGHTS_URL, WEIGHTS_NAME)
# Download and Load darknet library
DARKNET_LIB = 'libdarknet.so'
DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true'
download(DARKNET_URL, DARKNET_LIB)
DARKNET_LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
cfg = "./" + str(CFG_NAME)
weights = "./" + str(WEIGHTS_NAME)
net = DARKNET_LIB.load_network(cfg.encode('utf-8'), weights.encode('utf-8'), 0)
dtype = 'float32'
batch_size = 1
# Import the graph to NNVM
# ------------------------
# Import darknet graph definition to nnvm.
#
# Results:
# sym: nnvm graph for rnn model
# params: params converted from darknet weights
print("Converting darknet rnn model to nnvm symbols...")
sym, params = nnvm.frontend.darknet.from_darknet(net, dtype)
# Compile the model on NNVM
data = np.empty([1, net.inputs], dtype)#net.inputs
target = 'llvm'
shape = {'data': data.shape}
print("Compiling the model...")
shape_dict = {'data': data.shape}
dtype_dict = {'data': data.dtype}
with nnvm.compiler.build_config(opt_level=2):
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype_dict, params)
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the NNVM compiled model on cpu target.
# Set the cpu context
ctx = tvm.cpu(0)
# Create graph runtime
m = graph_runtime.create(graph, lib, ctx)
# Set the params to runtime
m.set_input(**params)
def _init_state_memory(rnn_cells_count, dtype):
'''Initialize memory for states'''
states = {}
state_shape = (1024,)
for i in range(rnn_cells_count):
k = 'rnn' + str(i) + '_state'
states[k] = tvm.nd.array(np.zeros(state_shape, dtype).astype(dtype))
return states
def _set_state_input(runtime, states):
'''Set the state inputs'''
for state in states:
runtime.set_input(state, states[state])
def _get_state_output(runtime, states):
'''Get the state outputs and save'''
i = 1
for state in states:
data = states[state]
states[state] = runtime.get_output((i), tvm.nd.empty(data.shape, data.dtype))
i += 1
def _proc_rnn_output(out_data):
'''Generate the characters from the output array'''
sum_array = 0
n = out_data.size
r = random.uniform(0, 1)
for j in range(n):
if out_data[j] < 0.0001:
out_data[j] = 0
sum_array += out_data[j]
for j in range(n):
out_data[j] *= float(1.0) / sum_array
r = r - out_data[j]
if r <= 0:
return j
return n-1
print("RNN generaring text...")
out_shape = (net.outputs,)
rnn_cells_count = 3
# Initialize state memory
# -----------------------
states = _init_state_memory(rnn_cells_count, dtype)
len_seed = len(seed)
count = len_seed + num
out_txt = ""
#Initialize random seed
random.seed(0)
c = ord(seed[0])
inp_data = np.zeros([net.inputs], dtype)
# Run the model
# -------------
# Predict character by character till `num`
for i in range(count):
inp_data[c] = 1
# Set the input data
m.set_input('data', tvm.nd.array(inp_data.astype(dtype)))
inp_data[c] = 0
# Set the state inputs
_set_state_input(m, states)
# Run the model
m.run()
# Get the output
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
# Get the state outputs
_get_state_output(m, states)
# Get the predicted character and keep buffering it
c = ord(seed[i]) if i < len_seed else _proc_rnn_output(tvm_out)
out_txt += chr(c)
print("Predicted Text =", out_txt)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment