Commit a808a987 by Albin Joy, committed by Tianqi Chen

[NNVM][TENSORFLOW] LSTM operator and PTB word prediction frontend (#1389)

parent f7d05b7c
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines
"""TF: Tensorflow frontend."""
from __future__ import absolute_import as _abs
from __future__ import print_function
@@ -457,6 +457,190 @@ def _shape():
return inputs[0]
return _impl
def _fill():
def _impl(inputs, attr, params):
fill_arg = params.pop(inputs.pop(1).list_output_names()[0])
new_inputs = []
return AttrCvt(
op_name='full',
extras={'shape':inputs[0],
'fill_value':fill_arg.asnumpy()[0], 'dtype':attr['T'].name},
ignores=['index_type', 'T'])(new_inputs, attr)
return _impl
def _gather_v2():
"Tensorflow now support only gatherv2"
def _impl(inputs, attr, params):
axis = params[inputs.pop(2).list_output_names()[0]].asnumpy()[0]
new_input = []
new_input.append(inputs.pop(0))
new_input.append(inputs.pop(0))
return AttrCvt(
op_name="take",
extras={'axis':axis},
ignores=['Tindices', 'Tparams', 'validate_indices', \
'Taxis', '_class'])(new_input, attr)
return _impl
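#Illustrative note: GatherV2(params, indices, axis) is lowered to the NNVM
#'take' op, with the axis read from the third input; e.g. for params of shape
#(4,) and indices [1] with axis=0, both return the element at position 1
#(compare test_forward_gather in the frontend tests).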
def _infer_out_shapes(inputs, params):
"""A method to get the output shape of an intermediate node in the NNVM graph."""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
return out_shapes
def _stridedSlice():
def _impl(inputs, attr, params):
"""Strided Slice.
Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/core/util/strided_slice_op.cc#L147-L368
"""
begin = params.pop(inputs[1].list_output_names()[0]).asnumpy().tolist()
end = params.pop(inputs[2].list_output_names()[0]).asnumpy().tolist()
stride = params.pop(inputs[3].list_output_names()[0]).asnumpy().tolist()
begin_mask = int(attr.get('begin_mask', 0))
end_mask = int(attr.get('end_mask', 0))
ellipsis_mask = int(attr.get('ellipsis_mask', 0))
new_axis_mask = int(attr.get('new_axis_mask', 0))
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape[0])
stride_dim = len(stride)
def _transform_mask(stride_dim, ellipsis_mask):
"""Handle mask inputs to create new begin, end, stride and output shape"""
m_begin = [0] * data_dim
m_end = [0] * data_dim
m_stride = [0] * data_dim
#Count the new axes after the ellipsis; they are considered while applying ellipsis_mask.
ellipsis_seen = False
new_axes_after_ellipsis = 0
for i in range(stride_dim):
mask = 1 << i
if ellipsis_seen and (mask & new_axis_mask) != 0:
new_axes_after_ellipsis += 1
if (mask & ellipsis_mask) != 0:
ellipsis_seen = True
if not ellipsis_seen:
#Used later for extending the stride attributes in the below loop.
ellipsis_mask |= (1 << stride_dim)
stride_dim += 1
final_index = 0
for index in range(stride_dim):
mask = 1 << index
if mask & ellipsis_mask:
#Identify the end index for applying ellipsis_mask
to_index = min(((data_dim - (stride_dim-index)) + 1 \
+ new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index):
m_begin[final_index] = 0
m_end[final_index] = data_shape[0][final_index]
m_stride[final_index] = 1
final_index += 1
elif not mask & new_axis_mask:
if final_index == len(m_begin):
break
if mask & begin_mask:
m_begin[final_index] = data_shape[0][final_index] \
if stride[index] < 0 else 0
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
m_end[final_index] = 0 if stride[index] < 0 \
else data_shape[0][final_index]
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
if mask & shrink_axis_mask:
#Tensorflow treats an axis with shrink_axis_mask set as dimension 1
m_begin[final_index] = data_shape[0][final_index] + begin[index] \
if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1
final_index += 1
return m_begin, m_end, m_stride
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride = _transform_mask(stride_dim, ellipsis_mask)
out = _sym.strided_slice(inputs[0], begin=begin, end=end, stride=stride)
out_shape = _infer_out_shapes(out, params)[0]
#Create final output shape.
final_output = []
out_index = 0
index = 0
while out_index != len(out_shape):
#Axes with shrink_axis_mask set have dimension 1 and are skipped in the final shape.
mask = 1 << index
if (new_axis_mask & mask) and not ellipsis_mask & mask:
final_output.append(1)
elif (not mask & shrink_axis_mask) or index >= stride_dim:
#Shrinking is only considered up to stride_dim
final_output.append(out_shape[out_index])
out_index += 1
index += 1
return _sym.reshape(out, shape=tuple(final_output))
return _impl
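#Worked example (illustrative): for data of shape (3, 4) with begin=[1, 0],
#end=[4, 4], stride=[1, 1] and shrink_axis_mask=2, bit 1 marks axis 1 as
#shrunk, so _transform_mask rewrites that axis to begin=0, end=1, stride=1
#(a single element) and strided_slice is applied with begin=[1, 0],
#end=[4, 1], stride=[1, 1]; see test_forward_stridedslice for this and
#other mask combinations.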
def _LSTMBlockCell():
def _impl(inputs, in_state_c, in_state_h, attr, params):
"""LSTM Block cell.
Calculations are described in: https://github.com/tensorflow/tensorflow/blob/
r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
Parameters
----------
inputs : nnvm.Symbol
Input data
in_state_c: list of nnvm.Symbol
Cell state input values for all the layers
in_state_h: list of nnvm.Symbol
Hidden state input values for all the layers
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
Returns
-------
sym : nnvm.Symbol
Converted nnvm Symbol
output: nnvm.Symbol
Output state value.
"""
in_data = inputs[0]
in_weight = inputs[3]
in_bias = inputs[7]
forget_bias = attr.pop('forget_bias')
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
batch_size, input_size = input_shape[0][0], input_shape[0][1]
num_hidden_layers = weight_shape[0][1]
num_hidden = num_hidden_layers // 4
in_data = _sym.reshape(in_data,
shape=(batch_size, input_size))
ixh = _sym.concatenate(*[in_data, in_state_h], axis=1)
in_weight = _sym.transpose(in_weight)
gates = _sym.dense(ixh, in_weight, in_bias, use_bias=True,
units=num_hidden_layers, name="dense")
gate_list = _sym.split(gates, indices_or_sections=4, axis=1)
in_gate = _sym.sigmoid(gate_list[0])
in_transform = _sym.tanh(gate_list[1])
#Apply forget_bias to the forget-gate logits before the sigmoid
forget_gate = _sym.sigmoid(gate_list[2] + forget_bias)
out_gate = _sym.sigmoid(gate_list[3])
next_c = _sym.broadcast_add(_sym.broadcast_mul(forget_gate, in_state_c),
_sym.broadcast_mul(in_gate, in_transform))
next_h = out_gate * _sym.tanh(next_c)
out_state = _sym.concatenate(*[next_c, next_h])
out_state = _sym.reshape(out_state,
shape=(2, batch_size, num_hidden))
return next_h, out_state
return _impl
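#For reference, the recurrence computed above (gate order [i, g, f, o]):
#  i, g, f, o = split(dense(concat(x, h_prev)), 4)
#  c_next = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * tanh(g)
#  h_next = sigmoid(o) * tanh(c_next)
#The returned out_state stacks (c_next, h_next) into shape (2, batch_size, num_hidden).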
# compatible operators that do NOT require any conversion.
_identity_list = []
@@ -493,8 +677,192 @@ _convert_map = {
'DepthwiseConv2dNative' : _depthwise_conv(),
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
'Fill' : _fill(),
'GatherV2' : _gather_v2(),
'StridedSlice' : _stridedSlice(),
}
# _convert_map_rnn maps RNN operator names to
# converter functors (callables) for 1-to-1 mapping.
_convert_map_rnn = {
'LSTMBlockCell' : _LSTMBlockCell(),
}
class RecurrentNetworks(object):
"""Recurrent network layer handlers.
Handle Layer operations.
ToDo: Operators like RNN/GRU layers can also be handled here
Parameters
----------
nodes : list
list of graph nodes used for tensorflow parsing.
out_rnn : list
List of RecurrentNetwork outputs. These outputs will be appended to the
'head' nodes of the graph.
graph : tensorflow graph definition object
The loaded tensorflow GraphDef
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to nnvm; the callables are functions which
take attrs and return (new_op_name, new_attrs)
"""
def __init__(self, nodes, out_rnn, graph, convert_map):
self._graph = graph
self._convert_map = convert_map
self._nodes = nodes
self._out_rnn = out_rnn
self._cur_lstm_layer = 0
self._layer_name_list = []
self._recurrent_ops_layer_map = {
'LSTMBlockCell' : self._LSTMBlockCellLayer(),
}
def _LSTMBlockCellLayer(self):
"""LSTMBlockCell layer handler.
Parameters
----------
op_name : str
Operator name, e.g. LSTMBlockCell
layer_name : str list
Layer name is used for creating the state input placeholder.
inputs : nnvm.Symbol
Input data
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
num_layers : int
Total number of LSTM layers present in the graph
Returns
-------
sym : nnvm.sym.Symbol
The returned nnvm symbol
"""
def _impl(op_name, layer_name, inputs, attrs, params, num_layers):
in_state_c_name = layer_name+'_c'
in_state_h_name = layer_name+'_h'
def _init_state(num_layers, batch_size, num_hidden):
"""Create the initial states for the first layer in the graph."""
in_state_c = _sym.Variable(in_state_c_name,
shape=(num_layers, batch_size, num_hidden))
in_state_h = _sym.Variable(in_state_h_name,
shape=(num_layers, batch_size, num_hidden))
return in_state_c, in_state_h
def _get_cur_input_state(in_state_c, in_state_h, num_layers,
layer, batch_size, num_hidden):
"""Select the appropriate states for the current layer"""
in_state_c_tup = _sym.split(in_state_c,
indices_or_sections=num_layers, axis=0)
in_state_h_tup = _sym.split(in_state_h,
indices_or_sections=num_layers, axis=0)
cur_in_state_c = _sym.reshape(in_state_c_tup[layer],
shape=(batch_size, num_hidden))
cur_in_state_h = _sym.reshape(in_state_h_tup[layer],
shape=(batch_size, num_hidden))
return cur_in_state_c, cur_in_state_h
def _LSTMBlockCellWrapper(inputs, attr, params,
num_layers, layer):
"""LSTM cell warapper to prepare the inputs"""
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
batch_size = input_shape[0][0]
num_hidden = weight_shape[0][1] // 4
if layer == 0:
#Create the initial state placeholders for the first layer
in_state_c, in_state_h = _init_state(num_layers,
batch_size, num_hidden)
else:
in_state_c = self._nodes[in_state_c_name]
in_state_h = self._nodes[in_state_h_name]
cur_in_state_c, cur_in_state_h = _get_cur_input_state( \
in_state_c, in_state_h,
num_layers, layer,
batch_size, num_hidden)
output, out_state = self._convert_map[op_name](inputs, cur_in_state_c,
cur_in_state_h,
attr, params)
return output, out_state, in_state_c, in_state_h
sym, cur_out_state, in_state_c, in_state_h = \
_LSTMBlockCellWrapper(inputs, attrs, params,
num_layers, self._cur_lstm_layer)
self._nodes[in_state_c_name] = in_state_c
self._nodes[in_state_h_name] = in_state_h
cur_out_state = _sym.expand_dims(cur_out_state, axis=0, num_newaxis=1)
self._out_rnn.append(cur_out_state)
self._cur_lstm_layer += 1
return sym
return _impl
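#Note on state handling: the first LSTM layer creates the state placeholders
#'<layer_name>_c' and '<layer_name>_h' of shape (num_layers, batch_size,
#num_hidden); every layer then slices out its own (batch_size, num_hidden)
#state, and its output state is appended to self._out_rnn, which
#from_tensorflow later concatenates and exposes as an extra graph output.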
def process_op(self, op_name, inputs, attrs, params):
"""Process recurrent layer operators.
The dict '_recurrent_ops_layer_map' maps each layer-based operator to its
layer handler. The total number of layers is calculated to form the input
data shapes.
Parameters
----------
op_name : str
Operator name, such as LSTMBlockCell
inputs : nnvm.Symbol
Input data
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
Returns
-------
sym : nnvm.sym.Symbol
The returned nnvm symbol
"""
def _get_abs_layer_name(node):
"""Identify the layer name is already handled. Return the absolute name
"""
if not self._layer_name_list:
self._layer_name_list.append(node.name)
return node.name
for _name in self._layer_name_list:
if _name in node.name:
abs_name = _name
else:
self._layer_name_list.append(node.name)
abs_name = node.name
return abs_name
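#Illustrative example: in the PTB graph the two cell nodes
#'.../cell_0/lstm_cell/LSTMBlockCell' and '.../cell_0/lstm_cell/LSTMBlockCell_1'
#share the first node's name as a prefix, so both resolve to the same
#absolute layer name and hence to a single pair of state placeholders
#covering all layers.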
#Find the number of layers of this same operator in the graph
#and resolve the absolute layer name for the current op.
num_layers = 0
for _, node in enumerate(self._graph.node):
if node.op == op_name:
layer_name = _get_abs_layer_name(node)
num_layers += 1
sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs,
params, num_layers)
return sym
class GraphProto(object):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
@@ -510,6 +878,7 @@ class GraphProto(object):
self._num_input = 0
self._num_param = 0
self._input_node = ''
self._num_rnn_layer = False
def from_tensorflow(self, graph):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
@@ -553,16 +922,19 @@ class GraphProto(object):
for node in graph.node:
# Tensorflow doesn't have a separate list for params extraction.
# Operators named 'Const' are treated as parameters to build the NNVM params dict.
input_shapes = {}
if node.op == "Placeholder":
# Assuming the graph has only one input of type 'Placeholder'
self._input_node = node.name
self._num_input += 1
try:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in self._parse_attr(node.attr)['_output_shapes']]
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._output_shapes[node.name][0])
input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
except KeyError:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
@@ -580,7 +952,6 @@ class GraphProto(object):
if node.name not in self._nodes:
raise NotImplementedError( \
"Const {} couldn't be converted to Param.".format(node.name))
attr = self._parse_attr(node.attr)
#A Variable converted to Const carries attrs other than just 'value'
if 'value' in attr:
@@ -611,9 +982,16 @@ class GraphProto(object):
# Pass the node name too in attr
attr["_node_name"] = node.name
#ToDo: Some tensorflow operators internally maintain execution layers
#and their output names carry the layer number along with the graph
#node name, e.g. node name 'Model/RNN/cell_0/RnnCell' but output name
#'Model/RNN/cell_0/RnnCell:0'. In this case the digit has to be ignored.
if ":" in node.input[0]:
in_name, _ = node.input[0].split(':')
node.input[0] = in_name
try:
inputs = [self._nodes[i] for i in node.input]
input_shapes = {}
for i in node.input:
if i not in self._params:
input_shapes[self._nodes[i]] = self._output_shapes[i]
@@ -624,12 +1002,20 @@ class GraphProto(object):
inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph)
# Assuming only one output.
self._nodes[node.name] = op
node_output = op
# Assume the final node is the output node
out = node_output
#Also add the RNN state outputs to the 'head' nodes of the nnvm graph
if self._num_rnn_layer:
out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
out = [out, out_rnn]
if isinstance(out, list):
out = _sym.Group(out)
return out, self._params
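#With RNN layers present, the grouped output exposes the model output at
#index 0 and the stacked RNN states at index 1; run_tvm_graph in the tests
#reads these with get_output(0) and get_output(1).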
def _parse_param(self, key, value, name):
@@ -651,7 +1037,7 @@ class GraphProto(object):
self._nodes[name] = _sym.Variable(name=name,
shape=self._params[name].shape)
else:
if key != 'dtype' and key != '_output_shapes' and key != '_class':
raise NotImplementedError \
("Other attributes for a Const(param) Node {} ? .".format(key))
@@ -706,7 +1092,44 @@ class GraphProto(object):
return attrs
def _convert_rnn_operator(self, op_name, inputs,
attrs, params, graph, convert_map):
"""Convert RNN and its variant operators to NNVM operators.
This converter reads the input states of each layer and
also maintains the output states of each layer in a list.
Parameters
----------
op_name : str
Operator name, such as LSTMBlockCell
inputs : list of nnvm.Symbol
List of input symbols.
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
graph : Tensorflow graph object
The graph is used to count occurrences of the same operator,
which gives the number of layers.
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to nnvm; the callables are functions which
take attrs and return (new_op_name, new_attrs)
Returns
-------
sym : nnvm.Symbol
Converted nnvm Symbol
"""
if not self._num_rnn_layer:
self._out_rnn = []
self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map)
self._num_rnn_layer = True
sym = self.rnn.process_op(op_name, inputs, attrs, params)
return sym
def _convert_operator(self, op_name, inputs, attrs,
graph, identity_list=None, convert_map=None):
"""Convert from Tensorflow operator to nnvm operator.
The converter must specify conversions explicitly for incompatible names, and
apply handlers to operator attributes.
@@ -733,10 +1156,15 @@ class GraphProto(object):
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
convert_map_rnn = _convert_map_rnn
if op_name in identity_list:
sym = get_nnvm_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
elif op_name in convert_map_rnn:
sym = self._convert_rnn_operator(op_name, inputs, attrs,
self._params, graph,
convert_map_rnn)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))
return sym
@@ -6,6 +6,8 @@ Some helper definitions for tensorflow models.
"""
import re
import os.path
import collections
import numpy as np
# Tensorflow imports
import tensorflow as tf
@@ -134,3 +136,143 @@ def get_workload(model_path):
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return graph_def
#######################################################################
# PTB LSTMBlockCell Model
# -----------------------
class PTBSmallConfig(object):
"""Small config.
These configurations were used when training the model
"""
num_layers = 2
num_steps = 1
hidden_size = 200
batch_size = 1
vocab_size = 10000
init_scale = 0.1
def get_config():
"""Configuration used for training the model"""
return PTBSmallConfig()
def pick_from_weight(weight, pows=1.0):
"""Identify token from Softmax output.
This token will be mapped to word in the vocabulary.
"""
weight = weight**pows
t = np.cumsum(weight)
s = np.sum(weight)
return int(np.searchsorted(t, 0.5 * s))
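#Example (illustrative): for weight = [0.1, 0.7, 0.2] with pows=1.0,
#cumsum = [0.1, 0.8, 1.0] and 0.5 * sum = 0.5, so searchsorted returns 1,
#i.e. the token holding the bulk of the probability mass.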
def do_tf_sample(session, data, in_states, num_samples):
"""Sampled from the model"""
samples = []
sample = None
#Cell inputs c and h should be passed for each layer explicitly.
state_input_name = ['Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0',
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0',
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0',
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0']
state = session.run(state_input_name)
#Graph nodes to be fetched as run outputs. Tensorflow LSTMBlockCell creates internal
#nodes for intermediate operations (gates) in the cell during the run.
#Cell state (c) is ':1' and cell output (h) is ':6' for each layer.
fetches = [['Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6'],
'Model/Softmax:0']
def _get_feed_dict(input_name, input_data):
"""Create feed dict"""
feed_dict = {}
if isinstance(input_data, list):
for i, e in enumerate(input_name):
feed_dict[e] = input_data[i]
else:
feed_dict[input_name] = input_data
return feed_dict
for x in data:
feed_dict = _get_feed_dict(state_input_name, state)
feed_dict['Model/Placeholder:0'] = [[x]]
state, probs = session.run(fetches, feed_dict)
sample = pick_from_weight(probs[0])
if sample is not None:
samples.append(sample)
else:
samples.append(0)
k = 1
while k < num_samples:
feed_dict = _get_feed_dict(state_input_name, state)
feed_dict['Model/Placeholder:0'] = [[samples[-1]]]
state, probs = session.run(fetches, feed_dict)
sample = pick_from_weight(probs[0])
samples.append(sample)
k += 1
return samples, state
def _create_ptb_vocabulary(data_dir):
"""Read the PTB sample data input to create vocabulary"""
data_path = data_dir+'simple-examples/data/'
file_name = 'ptb.train.txt'
def _read_words(filename):
"""Read the data for creating vocabulary"""
with tf.gfile.GFile(filename, "r") as f:
return f.read().encode("utf-8").decode("utf-8").replace("\n", "<eos>").split()
def _build_vocab(filename):
"""Create vocabulary"""
data = _read_words(filename)
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
word_to_id = dict(zip(words, range(len(words))))
#for python 3.x
id_to_word = dict((v, k) for k, v in word_to_id.items())
return word_to_id, id_to_word
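#The vocabulary is ordered by descending frequency (ties broken
#alphabetically), so the most frequent token in the training text gets id 0.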
def ptb_raw_data(data_path, file_name):
"""Read the sample data and create vocabulary"""
train_path = os.path.join(data_path, file_name)
word_to_id, id_2_word = _build_vocab(train_path)
return word_to_id, id_2_word
return ptb_raw_data(data_path, file_name)
def get_workload_ptb():
""" Import ptb workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for ptb.
word_to_id : dict
English word to integer id mapping
id_to_word : dict
Integer id to English word mapping
"""
sample_repo = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/'
sample_data_file = 'simple-examples.tgz'
sample_url = sample_repo+sample_data_file
ptb_model_file = 'RNN/ptb/ptb_model_with_lstmblockcell.pb'
import tarfile
from tvm.contrib.download import download
DATA_DIR = './ptb_data/'
if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR)
download(sample_url, DATA_DIR+sample_data_file)
t = tarfile.open(DATA_DIR+sample_data_file, 'r')
t.extractall(DATA_DIR)
word_to_id, id_to_word = _create_ptb_vocabulary(DATA_DIR)
return word_to_id, id_to_word, get_workload(ptb_model_file)
@@ -10,12 +10,14 @@ import nnvm.compiler
import tvm
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import init_ops
from tensorflow.core.framework import graph_pb2
import nnvm.testing.tf
@@ -55,8 +57,15 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype)
# execute
m.run()
# get outputs
if isinstance(output_shape, list) and isinstance(output_dtype, list):
tvm_output_list = []
for i, s in enumerate(output_shape):
tvm_output = m.get_output(i, tvm.nd.empty((s), output_dtype[i]))
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
else:
tvm_output = m.get_output(0, tvm.nd.empty((output_shape), output_dtype))
return tvm_output.asnumpy()
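# When the converted graph has extra RNN state outputs (see test_forward_lstm
# and test_forward_ptb below), output_shape and output_dtype are passed as
# lists and each output is fetched separately with get_output(i).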
def run_tf_graph(sess, input_data, input_node, output_node):
""" Generic function to execute tensorflow """
@@ -434,6 +443,159 @@ def test_forward_variable():
#######################################################################
# LSTM
# ----
def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
tf.reset_default_graph()
input_size = num_hidden
input_data = np.full((batch_size, input_size), 1., dtype=dtype)
in_state_c = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
in_state_h = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
def _get_tensorflow_output():
with tf.Session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
m0 = array_ops.zeros([batch_size, num_hidden])
m1 = array_ops.zeros([batch_size, num_hidden])
x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype)
g, ((out_m0, out_m1)) = \
tf.contrib.rnn.LSTMBlockCell(num_hidden,
forget_bias=forget_bias)(x, ((m0, m1)))
sess.run([variables.global_variables_initializer()])
res = sess.run([g, out_m0, out_m1], {
x.name: np.array([[1., 1.]]),
m0.name: 0.1 * np.ones([batch_size, num_hidden]),
m1.name: 0.1 * np.ones([batch_size, num_hidden]),
})
graph_def = sess.graph.as_graph_def(add_shapes=True)
final_graph_def = graph_util.convert_variables_to_constants(
sess,
graph_def,
['root/lstm_cell/LSTMBlockCell'])
return final_graph_def, res
graph_def, tf_out = _get_tensorflow_output()
tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h],
['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c',
'root/lstm_cell/LSTMBlockCell_h'],
[tf_out[0].shape, (2, batch_size, num_hidden)],
[tf_out[0].dtype, tf_out[1].dtype])
if isinstance(tvm_output, list):
out = tvm_output[0]
out_state = tvm_output[1]
out_state_tup = np.split(out_state, indices_or_sections=2, axis=0)
out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden))
out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden))
tvm_out = [out, out_state_c, out_state_h]
np.testing.assert_allclose(tf_out, tvm_out, rtol=1e-3, atol=1e-3)
def test_forward_lstm():
'''test LSTM block cell'''
_test_lstm_cell(1, 2, 1, 0.0, 'float32')
#######################################################################
# StridedSlice
# ------------
def _test_stridedslice(ip_shape, begin, end, stride, dtype,
begin_mask=0, end_mask=0, new_axis_mask=0,
shrink_axis_mask=0, ellipsis_mask=0):
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask,
end_mask=end_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask,
ellipsis_mask=ellipsis_mask, name="strided_slice")
np_data = np.random.uniform(size=ip_shape).astype(dtype)
with tf.Session() as sess:
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['strided_slice'])
tf_output = run_tf_graph(sess, np_data,
'in_data:0', 'strided_slice:0')
tvm_output = run_tvm_graph(final_graph_def, np_data,
"in_data", tf_output.shape, np_data.dtype)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
sess.close()
def test_forward_stridedslice():
'''test StridedSlice'''
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=4)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=2)
_test_stridedslice((3,4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=1)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], 'float32', shrink_axis_mask=5, new_axis_mask=1)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=5, new_axis_mask=1, ellipsis_mask=2, begin_mask=8, end_mask=8)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=5)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=16, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=5)
_test_stridedslice((3, 4, 5, 4, 5, 6), [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1],
'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5,
end_mask=8)
#######################################################################
# Gather
# ------
def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
indices = tf.placeholder("int32", indice_shape, name="indices")
tf.gather(in_data, indices, axis=axis)
np_data = np.random.uniform(size=ip_shape).astype(dtype)
def _fill_indices(indice_value):
indices = np.array(ip_shape, dtype=dtype)
if isinstance(indice_value, int):
indices = np.array([indice_value], dtype='int32')
else:
indices = np.asarray(indice_value, dtype='int32')
return indices
np_indices = _fill_indices(indice_value)
with tf.Session() as sess:
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['GatherV2'])
tf_output = run_tf_graph(sess, [np_data, np_indices], ['in_data:0',
'indices:0'], 'GatherV2:0')
tvm_output = run_tvm_graph(final_graph_def, [np_data, np_indices],
['in_data', 'indices'], tf_output.shape, dtype)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
sess.close()
def test_forward_gather():
'''test gather layer'''
_test_gather((4,), (1,), 1, 0, 'int32')
_test_gather((4,), (1,), 1, 0, 'float32')
_test_gather((1,4), (1,), [0], 0, 'int32')
_test_gather((4,), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'int32')
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 1, 'int32')
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 0, 'int32')
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
#######################################################################
# Multi Input to graph
# --------------------
@@ -584,6 +746,115 @@ def test_forward_mobilenet():
np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
#######################################################################
# PTB
# ---
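#Touching tf.contrib triggers its lazy loading so that contrib ops such as
#LSTMBlockCell are registered before the PTB graph is imported.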
dir(tf.contrib)
def test_forward_ptb():
'''test ptb model'''
config = nnvm.testing.tf.get_config()
num_steps = config.num_steps
num_hidden = config.hidden_size
num_layers = config.num_layers
batch_size = config.batch_size
vocab_size = config.vocab_size
out_sample_shape = (batch_size, vocab_size)
out_state_shape = (num_layers, 2, batch_size, num_hidden)
#Sample input
inpt = "we have no useful information on"
cnt_sample = 20
def _pretty_print(items, is_char_model, id2word):
if not is_char_model:
return ' '.join([id2word[x] for x in items])
else:
return ''.join([id2word[x] for x in items]).replace('_', ' ')
def _get_tvm_graph_module(graph_def):
sym, params = nnvm.frontend.from_tensorflow(graph_def)
#Cell inputs 'c' and 'h' hold the state values of all layers
shape_dict = {'Model/Placeholder': (batch_size, num_steps),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':(num_layers, batch_size, num_hidden),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':(num_layers, batch_size, num_hidden)}
dtype_dict = {'Model/Placeholder': 'int32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'}
target = 'llvm'
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
dtype=dtype_dict, params=params)
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
return params, graph_runtime.create(graph, lib, ctx)
def _do_tvm_sample(model, data, in_states, params, num_samples):
"""Sampled from the model"""
samples = []
state = in_states
sample = None
def _get_sample(data, state):
input_data = np.full((batch_size, num_steps), data, dtype="int32")
in_state_tup = np.split(state, indices_or_sections=2, axis=1)
in_state_c = np.reshape(in_state_tup[0], (num_layers, batch_size, num_hidden))
in_state_h = np.reshape(in_state_tup[1], (num_layers, batch_size, num_hidden))
model.set_input('Model/Placeholder', tvm.nd.array(input_data.astype("int32")))
model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c',
tvm.nd.array(in_state_c.astype("float32")))
model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h',
tvm.nd.array(in_state_h.astype("float32")))
model.set_input(**params)
model.run()
tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape,
"float32")).asnumpy()
state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
"float32")).asnumpy()
sample = nnvm.testing.tf.pick_from_weight(tvm_output[0])
return sample, state_output
for x in data:
sample, state = _get_sample(x, state)
if sample is not None:
samples.append(sample)
else:
samples.append(0)
k = 1
while k < num_samples:
sample, state = _get_sample(samples[-1], state)
samples.append(sample)
k += 1
return samples, state
with tf.Graph().as_default():
word_to_id, id_to_word, graph_def = nnvm.testing.tf.get_workload_ptb()
vocab_size = len(word_to_id)
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
sess = tf.Session()
#TVM graph module creation
params, m = _get_tvm_graph_module(graph_def)
# Create 10 predicted statements of 20 words each
cnt_stm = 0
while cnt_stm < 10:
cnt_stm += 1
in_state = np.full((num_layers, 2, batch_size, num_hidden), 0, dtype="float32")
seed_for_sample = inpt.split()
tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word] \
for word in seed_for_sample],
in_state, params, cnt_sample)
tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
tf_samples, tf_state = nnvm.testing.tf.do_tf_sample(sess,
[word_to_id[word] for word in seed_for_sample],
in_state, cnt_sample)
tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
inpt = tvm_sample_str
np.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
assert tvm_sample_str == tf_sample_str
#######################################################################
# Main
# ----
if __name__ == '__main__':
@@ -600,3 +871,7 @@ if __name__ == '__main__':
test_forward_mobilenet()
test_forward_variable()
test_forward_resize_bilinear()
test_forward_lstm()
test_forward_stridedslice()
test_forward_gather()
test_forward_ptb()