Commit a808a987 by Albin Joy, committed by Tianqi Chen

[NNVM][TENSORFLOW] LSTM operator and PTB word prediction frontend (#1389)

parent f7d05b7c
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines
"""TF: Tensorflow frontend."""
from __future__ import absolute_import as _abs
from __future__ import print_function
...@@ -457,6 +457,190 @@ def _shape():
return inputs[0]
return _impl
def _fill():
def _impl(inputs, attr, params):
fill_arg = params.pop(inputs.pop(1).list_output_names()[0])
new_inputs = []
return AttrCvt(
op_name='full',
extras={'shape':inputs[0],
'fill_value':fill_arg.asnumpy()[0], 'dtype':attr['T'].name},
ignores=['index_type', 'T'])(new_inputs, attr)
return _impl
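#A hedged, illustrative sketch (not invoked by the frontend): the Fill conversion
#above is expected to behave like numpy.full, with the shape taken from the first
#input and the scalar fill value popped from params. The helper name and toy values
#below are assumptions for illustration only.
def _example_fill_semantics():
    """Reference of what 'full' should produce for a Fill node like tf.fill([2, 3], 7)."""
    import numpy as np
    return np.full((2, 3), 7, dtype='int32')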
def _gather_v2():
"Tensorflow now support only gatherv2"
def _impl(inputs, attr, params):
axis = params[inputs.pop(2).list_output_names()[0]].asnumpy()[0]
new_input = []
new_input.append(inputs.pop(0))
new_input.append(inputs.pop(0))
return AttrCvt(
op_name="take",
extras={'axis':axis},
ignores=['Tindices', 'Tparams', 'validate_indices', \
'Taxis', '_class'])(new_input, attr)
return _impl
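#A hedged reference for the GatherV2 -> take mapping above: 'take' along an axis is
#expected to match numpy.take. This helper is an illustrative assumption and is not
#used by the converter.
def _example_gather_semantics():
    """Gather columns by index, mirroring tf.gather(data, indices, axis=1)."""
    import numpy as np
    data = np.arange(6).reshape(2, 3)
    indices = np.array([1, 0])
    return np.take(data, indices, axis=1)  # columns 1 and 0, shape (2, 2)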
def _infer_out_shapes(inputs, params):
"""A method to get the output shape of an intermediate node in the NNVM graph."""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
return out_shapes
def _stridedSlice():
def _impl(inputs, attr, params):
"""Strided Slice.
Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/core/util/strided_slice_op.cc#L147-L368
"""
begin = params.pop(inputs[1].list_output_names()[0]).asnumpy().tolist()
end = params.pop(inputs[2].list_output_names()[0]).asnumpy().tolist()
stride = params.pop(inputs[3].list_output_names()[0]).asnumpy().tolist()
begin_mask = int(attr.get('begin_mask', 0))
end_mask = int(attr.get('end_mask', 0))
ellipsis_mask = int(attr.get('ellipsis_mask', 0))
new_axis_mask = int(attr.get('new_axis_mask', 0))
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape[0])
stride_dim = len(stride)
def _transform_mask(stride_dim, ellipsis_mask):
"""Handle mask inputs to create new begin, end, stride and output shape"""
m_begin = [0] * data_dim
m_end = [0] * data_dim
m_stride = [0] * data_dim
#Count new axes that appear after the ellipsis; they are considered while applying ellipsis_mask.
ellipsis_seen = False
new_axes_after_ellipsis = 0
for i in range(stride_dim):
mask = 1 << i
if ellipsis_seen and (mask & new_axis_mask) != 0:
new_axes_after_ellipsis += 1
if (mask & ellipsis_mask) != 0:
ellipsis_seen = True
if not ellipsis_seen:
#Used later for extending the stride attributes in the below loop.
ellipsis_mask |= (1 << stride_dim)
stride_dim += 1
final_index = 0
for index in range(stride_dim):
mask = 1 << index
if mask & ellipsis_mask:
#Identify the end index for applying ellipsis_mask
to_index = min(((data_dim - (stride_dim-index)) + 1 \
+ new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index):
m_begin[final_index] = 0
m_end[final_index] = data_shape[0][final_index]
m_stride[final_index] = 1
final_index += 1
elif not mask & new_axis_mask:
if final_index == len(m_begin):
break
if mask & begin_mask:
m_begin[final_index] = data_shape[0][final_index] \
if stride[index] < 0 else 0
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
m_end[final_index] = 0 if stride[index] < 0 \
else data_shape[0][final_index]
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
if mask & shrink_axis_mask:
#Tensorflow makes an axis with shrink_axis_mask a dimension of size 1
m_begin[final_index] = data_shape[0][final_index] + begin[index] \
if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1
final_index += 1
return m_begin, m_end, m_stride
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride = _transform_mask(stride_dim, ellipsis_mask)
out = _sym.strided_slice(inputs[0], begin=begin, end=end, stride=stride)
out_shape = _infer_out_shapes(out, params)[0]
#Create final output shape.
final_output = []
out_index = 0
index = 0
while out_index != len(out_shape):
#An axis with shrink_axis_mask has dimension 1 and is ignored in the output.
mask = 1 << index
if (new_axis_mask & mask) and not ellipsis_mask & mask:
final_output.append(1)
elif (not mask & shrink_axis_mask) or index >= stride_dim:
#Shrink is considered only up to stride_dim
final_output.append(out_shape[out_index])
out_index += 1
index += 1
return _sym.reshape(out, shape=tuple(final_output))
return _impl
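#A hedged illustration of the mask handling above: every mask is a bit field over
#the slice dimensions, e.g. shrink_axis_mask with bit 0 set drops the first sliced
#axis, like integer indexing in numpy. Assumption-based sketch, not converter code.
def _example_stridedslice_shrink_axis():
    """tf.strided_slice(data, [1, 0, 0], [2, 3, 4], [1, 1, 1], shrink_axis_mask=1)
    is expected to equal data[1, 0:3, 0:4], i.e. shape (3, 4) with axis 0 removed."""
    import numpy as np
    data = np.arange(24).reshape(2, 3, 4)
    return data[1, 0:3, 0:4]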
def _LSTMBlockCell():
def _impl(inputs, in_state_c, in_state_h, attr, params):
"""LSTM Block cell.
Calculations are described in: https://github.com/tensorflow/tensorflow/blob/
r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
Parameters
----------
inputs : nnvm.Symbol
Input data
in_state_c: list of nnvm.Symbol
Cell state input values for all the layers
in_state_h: list of nnvm.Symbol
Hidden state input values for all the layers
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
Returns
-------
sym : nnvm.Symbol
Converted nnvm Symbol
output: nnvm.Symbol
Output state value.
"""
in_data = inputs[0]
in_weight = inputs[3]
in_bias = inputs[7]
forget_bias = attr.pop('forget_bias')
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
batch_size, input_size = input_shape[0][0], input_shape[0][1]
num_hidden_layers = weight_shape[0][1]
num_hidden = num_hidden_layers // 4
in_data = _sym.reshape(in_data,
shape=(batch_size, input_size))
ixh = _sym.concatenate(*[in_data, in_state_h], axis=1)
in_weight = _sym.transpose(in_weight)
gates = _sym.dense(ixh, in_weight, in_bias, use_bias=True,
units=num_hidden_layers, name="dense")
gate_list = _sym.split(gates, indices_or_sections=4, axis=1)
in_gate = _sym.sigmoid(gate_list[0])
in_transform = _sym.tanh(gate_list[1])
forget_gate = _sym.sigmoid(gate_list[2])
forget_gate = forget_gate + forget_bias
out_gate = _sym.sigmoid(gate_list[3])
next_c = _sym.broadcast_add(_sym.broadcast_mul(forget_gate, in_state_c),
_sym.broadcast_mul(in_gate, in_transform))
next_h = out_gate * _sym.tanh(next_c)
out_state = _sym.concatenate(*[next_c, next_h])
out_state = _sym.reshape(out_state,
shape=(2, batch_size, num_hidden))
return next_h, out_state
return _impl
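#A hedged numpy reference of the cell computation above, following the TF
#LSTMBlockCell definition linked in the docstring (gate order i, ci, f, o as
#produced by the split). This sketch applies forget_bias inside the sigmoid as the
#TF reference does; it is illustrative only and not called by the frontend.
def _example_lstm_block_cell(x, h_prev, c_prev, weight, bias, forget_bias=0.0):
    """One LSTMBlockCell step: returns (next_c, next_h)."""
    import numpy as np
    def _sigmoid(v):
        return 1.0 / (1.0 + np.exp(-v))
    gates = np.concatenate([x, h_prev], axis=1).dot(weight) + bias
    i, ci, f, o = np.split(gates, 4, axis=1)
    next_c = _sigmoid(f + forget_bias) * c_prev + _sigmoid(i) * np.tanh(ci)
    next_h = _sigmoid(o) * np.tanh(next_c)
    return next_c, next_h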
# compatible operators that do NOT require any conversion.
_identity_list = []
...@@ -493,8 +677,192 @@ _convert_map = {
'DepthwiseConv2dNative' : _depthwise_conv(),
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
'Fill' : _fill(),
'GatherV2' : _gather_v2(),
'StridedSlice' : _stridedSlice(),
}
# _convert_map_rnn defines maps of rnn operator name to
# converter functor(callable) for 1 to 1 mapping.
_convert_map_rnn = {
'LSTMBlockCell' : _LSTMBlockCell(),
}
class RecurrentNetworks(object):
"""Recurrent network layer handlers.
Handle Layer operations.
ToDo: Operators like RNN/GRU layers can also be handled here.
Parameters
----------
nodes : list
list of graph nodes used for tensorflow parsing.
out_rnn : list
List of RecurrentNetwork outputs. This output will be appended to the
'head' nodes of the graph.
graph : tensorflow graph definition object
The loaded tensorflow GraphDef
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to nnvm, and callable is a function which
takes attrs and returns (new_op_name, new_attrs)
"""
def __init__(self, nodes, out_rnn, graph, convert_map):
self._graph = graph
self._convert_map = convert_map
self._nodes = nodes
self._out_rnn = out_rnn
self._cur_lstm_layer = 0
self._layer_name_list = []
self._recurrent_ops_layer_map = {
'LSTMBlockCell' : self._LSTMBlockCellLayer(),
}
def _LSTMBlockCellLayer(self):
"""LSTMBlockCell layer handler.
Parameters
----------
op_name : str
Operator name, e.g. LSTMBlockCell
layer_name : str list
Layer name is used for creating the state input placeholder.
inputs : nnvm.Symbol
Input data
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
num_layers : int
Total number of LSTM layers present in the graph
Returns
-------
sym : nnvm.sym.Symbol
The returned nnvm symbol
"""
def _impl(op_name, layer_name, inputs, attrs, params, num_layers):
in_state_c_name = layer_name+'_c'
in_state_h_name = layer_name+'_h'
def _init_state(num_layers, batch_size, num_hidden):
"""Create the initial states for the first layer in the graph."""
in_state_c = _sym.Variable(in_state_c_name,
shape=(num_layers, batch_size, num_hidden))
in_state_h = _sym.Variable(in_state_h_name,
shape=(num_layers, batch_size, num_hidden))
return in_state_c, in_state_h
def _get_cur_input_state(in_state_c, in_state_h, num_layers,
layer, batch_size, num_hidden):
"""Select the appropriate states for the current layer"""
in_state_c_tup = _sym.split(in_state_c,
indices_or_sections=num_layers, axis=0)
in_state_h_tup = _sym.split(in_state_h,
indices_or_sections=num_layers, axis=0)
cur_in_state_c = _sym.reshape(in_state_c_tup[layer],
shape=(batch_size, num_hidden))
cur_in_state_h = _sym.reshape(in_state_h_tup[layer],
shape=(batch_size, num_hidden))
return cur_in_state_c, cur_in_state_h
def _LSTMBlockCellWrapper(inputs, attr, params,
num_layers, layer):
"""LSTM cell warapper to prepare the inputs"""
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
batch_size = input_shape[0][0]
num_hidden = weight_shape[0][1] // 4
if layer == 0:
#Create the initial state placeholders for the first layer
in_state_c, in_state_h = _init_state(num_layers,
batch_size, num_hidden)
else:
in_state_c = self._nodes[in_state_c_name]
in_state_h = self._nodes[in_state_h_name]
cur_in_state_c, cur_in_state_h = _get_cur_input_state( \
in_state_c, in_state_h,
num_layers, layer,
batch_size, num_hidden)
output, out_state = self._convert_map[op_name](inputs, cur_in_state_c,
cur_in_state_h,
attr, params)
return output, out_state, in_state_c, in_state_h
sym, cur_out_state, in_state_c, in_state_h = \
_LSTMBlockCellWrapper(inputs, attrs, params,
num_layers, self._cur_lstm_layer)
self._nodes[in_state_c_name] = in_state_c
self._nodes[in_state_h_name] = in_state_h
cur_out_state = _sym.expand_dims(cur_out_state, axis=0, num_newaxis=1)
self._out_rnn.append(cur_out_state)
self._cur_lstm_layer += 1
return sym
return _impl
def process_op(self, op_name, inputs, attrs, params):
"""Process recurrent layer operators.
List '_recurrent_ops_layer_map' map each Layer based operators with its
layer handlers. Total number of layers are calculated to form the input
data shapes.
Parameters
----------
op_name : str
Operator name, such as LSTMBlockCell
inputs : nnvm.Symbol
Input data
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
Returns
-------
sym : nnvm.sym.Symbol
The returned nnvm symbol
"""
def _get_abs_layer_name(node):
"""Identify the layer name is already handled. Return the absolute name
"""
if not self._layer_name_list:
self._layer_name_list.append(node.name)
return node.name
for _name in self._layer_name_list:
if _name in node.name:
abs_name = _name
else:
self._layer_name_list.append(node.name)
abs_name = node.name
return abs_name
#Find number of layers of this same operator node in the graph
#and also read the inputs name for the current op.
num_layers = 0
for _, node in enumerate(self._graph.node):
if node.op == op_name:
layer_name = _get_abs_layer_name(node)
num_layers += 1
sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs,
params, num_layers)
return sym
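#A hedged sketch of the state layout assumed by the layer handler above: all layer
#states are stacked on axis 0 as (num_layers, batch_size, num_hidden) and each layer
#slices out its own (batch_size, num_hidden) piece. numpy-only illustration, not
#part of the converter.
def _example_layer_state_selection(stacked_state, layer):
    """Select one layer's (batch, hidden) state from a stacked state tensor."""
    import numpy as np
    num_layers = stacked_state.shape[0]
    per_layer = np.split(stacked_state, num_layers, axis=0)
    return np.reshape(per_layer[layer], stacked_state.shape[1:])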
class GraphProto(object):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
...@@ -510,6 +878,7 @@ class GraphProto(object):
self._num_input = 0
self._num_param = 0
self._input_node = ''
self._num_rnn_layer = False
def from_tensorflow(self, graph):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
...@@ -553,16 +922,19 @@ class GraphProto(object):
for node in graph.node:
# Tensorflow doesn't have a separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes = {}
if node.op == "Placeholder":
# Assuming only one input graph with type 'Placeholder'
self._input_node = node.name
self._num_input += 1
try:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in self._parse_attr(node.attr)['_output_shapes']]
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._output_shapes[node.name][0])
input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
except KeyError:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
...@@ -580,7 +952,6 @@ class GraphProto(object):
if node.name not in self._nodes:
raise NotImplementedError( \
"Const {} couldn't be converted to Param.".format(node.name))
attr = self._parse_attr(node.attr)
#A Variable converted to Const will not have only the value attr
if 'value' in attr:
...@@ -611,9 +982,16 @@ class GraphProto(object):
# Pass the node name too in attr
attr["_node_name"] = node.name
#ToDo: Some tensorflow operators internally maintain execution layers,
#and their output name carries the layer number along with the graph
#node name. e.g. node name 'Model/RNN/cell_0/RnnCell' gives the
#output name 'Model/RNN/cell_0/RnnCell:0'. In this case,
#the digit has to be ignored.
if ":" in node.input[0]:
in_name, _ = node.input[0].split(':')
node.input[0] = in_name
try:
inputs = [self._nodes[i] for i in node.input]
for i in node.input:
if i not in self._params:
input_shapes[self._nodes[i]] = self._output_shapes[i]
...@@ -624,12 +1002,20 @@ class GraphProto(object):
inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph)
# Assuming only one output.
self._nodes[node.name] = op
node_output = op
# Assume the final node is the output node
out = node_output
#Add the RNN outputs as well to the 'head' nodes of the nnvm graph
if self._num_rnn_layer:
out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
out = [out, out_rnn]
if isinstance(out, list):
out = _sym.Group(out)
return out, self._params
def _parse_param(self, key, value, name):
...@@ -651,7 +1037,7 @@ class GraphProto(object):
self._nodes[name] = _sym.Variable(name=name,
shape=self._params[name].shape)
else:
if key != 'dtype' and key != '_output_shapes' and key != '_class':
raise NotImplementedError \
("Other attributes for a Const(param) Node {} ? .".format(key))
...@@ -706,7 +1092,44 @@ class GraphProto(object):
return attrs
def _convert_rnn_operator(self, op_name, inputs,
attrs, params, graph, convert_map):
"""Convert RNN and its variant operators to NNVM operators.
This converter read the input states of each layers and
also maintain the output states of each layer in a list.
Parameters
----------
op_name : str
Operator name, such as LSTMBlockCell
inputs : list of nnvm.Symbol
List of input symbols.
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
graph : Tensorflow graph object
The graph is used to count occurrences of the same operator,
which gives the number of layers.
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to nnvm, and callable is a function which
takes attrs and returns (new_op_name, new_attrs)
Returns
-------
sym : nnvm.Symbol
Converted nnvm Symbol
"""
if not self._num_rnn_layer:
self._out_rnn = []
self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map)
self._num_rnn_layer = True
sym = self.rnn.process_op(op_name, inputs, attrs, params)
return sym
def _convert_operator(self, op_name, inputs, attrs,
graph, identity_list=None, convert_map=None):
"""Convert from Tensorflow operator to nnvm operator. """Convert from Tensorflow operator to nnvm operator.
The converter must specify conversions explicity for incompatible name, and The converter must specify conversions explicity for incompatible name, and
apply handlers to operator attributes. apply handlers to operator attributes.
...@@ -733,10 +1156,15 @@ class GraphProto(object): ...@@ -733,10 +1156,15 @@ class GraphProto(object):
""" """
identity_list = identity_list if identity_list else _identity_list identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map convert_map = convert_map if convert_map else _convert_map
convert_map_rnn = _convert_map_rnn
if op_name in identity_list:
sym = get_nnvm_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
elif op_name in convert_map_rnn:
sym = self._convert_rnn_operator(op_name, inputs, attrs,
self._params, graph,
convert_map_rnn)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))
return sym
......
...@@ -6,6 +6,8 @@ Some helper definitions for tensorflow models.
"""
import re
import os.path
import collections
import numpy as np
# Tensorflow imports
import tensorflow as tf
...@@ -134,3 +136,143 @@ def get_workload(model_path):
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return graph_def
#######################################################################
# PTB LSTMBlockCell Model
# -----------------------
class PTBSmallConfig(object):
"""Small config.
This configurations are used when training the model
"""
num_layers = 2
num_steps = 1
hidden_size = 200
batch_size = 1
vocab_size = 10000
init_scale = 0.1
def get_config():
"""Configuration used for training the model"""
return PTBSmallConfig()
def pick_from_weight(weight, pows=1.0):
"""Identify token from Softmax output.
This token will be mapped to word in the vocabulary.
"""
weight = weight**pows
t = np.cumsum(weight)
s = np.sum(weight)
return int(np.searchsorted(t, 0.5 * s))
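#Hedged usage sketch: pick_from_weight deterministically returns the index where the
#cumulative (optionally sharpened) weight first reaches half of the total mass. The
#probability vector below is illustrative only.
def _example_pick_from_weight():
    probs = np.array([0.1, 0.2, 0.7])
    return pick_from_weight(probs, pows=1.0)  # returns 2 for this vector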
def do_tf_sample(session, data, in_states, num_samples):
"""Sampled from the model"""
samples = []
sample = None
#Cell inputs c and h should be passed for each layer explicitly.
state_input_name = ['Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0',
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0',
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0',
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0']
state = session.run(state_input_name)
#Graph nodes to be fetched as run output. Tensorflow LSTMBlockCell creates internal
#nodes for intermediate operations (gates) in the cell during run.
#Cell state (c) is ':1' and cell output (h) is ':6' for each layer.
fetches = [['Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6'],
'Model/Softmax:0']
def _get_feed_dict(input_name, input_data):
"""Create feed dict"""
feed_dict = {}
if isinstance(input_data, list):
for i, e in enumerate(input_name):
feed_dict[e] = input_data[i]
else:
feed_dict[input_name] = input_data
return feed_dict
for x in data:
feed_dict = _get_feed_dict(state_input_name, state)
feed_dict['Model/Placeholder:0'] = [[x]]
state, probs = session.run(fetches, feed_dict)
sample = pick_from_weight(probs[0])
if sample is not None:
samples.append(sample)
else:
samples.append(0)
k = 1
while k < num_samples:
feed_dict = _get_feed_dict(state_input_name, state)
feed_dict['Model/Placeholder:0'] = [[samples[-1]]]
state, probs = session.run(fetches, feed_dict)
sample = pick_from_weight(probs[0])
samples.append(sample)
k += 1
return samples, state
def _create_ptb_vocabulary(data_dir):
"""Read the PTB sample data input to create vocabulary"""
data_path = data_dir+'simple-examples/data/'
file_name = 'ptb.train.txt'
def _read_words(filename):
"""Read the data for creating vocabulary"""
with tf.gfile.GFile(filename, "r") as f:
return f.read().encode("utf-8").decode("utf-8").replace("\n", "<eos>").split()
def _build_vocab(filename):
"""Create vocabulary"""
data = _read_words(filename)
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
word_to_id = dict(zip(words, range(len(words))))
#for python 3.x
id_to_word = dict((v, k) for k, v in word_to_id.items())
return word_to_id, id_to_word
def ptb_raw_data(data_path, file_name):
"""Read the sample data and create vocabulary"""
train_path = os.path.join(data_path, file_name)
word_to_id, id_2_word = _build_vocab(train_path)
return word_to_id, id_2_word
return ptb_raw_data(data_path, file_name)
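#Hedged illustration of the vocabulary construction above: words are ranked by
#frequency, ties broken alphabetically, and mapped to consecutive integer ids. The
#toy corpus is for illustration only.
def _example_build_vocab():
    toy = "the cat sat on the mat <eos>".split()
    counter = collections.Counter(toy)
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
    words, _ = list(zip(*count_pairs))
    return dict(zip(words, range(len(words))))  # e.g. {'the': 0, '<eos>': 1, ...}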
def get_workload_ptb():
""" Import ptb workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for ptb.
word_to_id : dict
English word to integer id mapping
id_to_word : dict
Integer id to English word mapping
"""
sample_repo = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/'
sample_data_file = 'simple-examples.tgz'
sample_url = sample_repo+sample_data_file
ptb_model_file = 'RNN/ptb/ptb_model_with_lstmblockcell.pb'
import tarfile
from tvm.contrib.download import download
DATA_DIR = './ptb_data/'
if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR)
download(sample_url, DATA_DIR+sample_data_file)
t = tarfile.open(DATA_DIR+sample_data_file, 'r')
t.extractall(DATA_DIR)
word_to_id, id_to_word = _create_ptb_vocabulary(DATA_DIR)
return word_to_id, id_to_word, get_workload(ptb_model_file)
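#Hedged usage sketch for the helper above; the unpacking order matches its return
#statement:
#   word_to_id, id_to_word, graph_def = get_workload_ptb()
#   print(id_to_word[0])  # most frequent word in the PTB training data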
...@@ -10,12 +10,14 @@ import nnvm.compiler
import tvm
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import init_ops
from tensorflow.core.framework import graph_pb2
import nnvm.testing.tf
...@@ -55,6 +57,13 @@ def run_tvm_graph(graph_def, input_data, input_node, output_shape, output_dtype)
# execute
m.run()
# get outputs
if isinstance(output_shape, list) and isinstance(output_dtype, list):
tvm_output_list = []
for i, s in enumerate(output_shape):
tvm_output = m.get_output(i, tvm.nd.empty((s), output_dtype[i]))
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
else:
tvm_output = m.get_output(0, tvm.nd.empty((output_shape), output_dtype))
return tvm_output.asnumpy()
...@@ -434,6 +443,159 @@ def test_forward_variable():
#######################################################################
# LSTM
# ----
def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
tf.reset_default_graph()
input_size = num_hidden
input_data = np.full((batch_size, input_size), 1., dtype=dtype)
in_state_c = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
in_state_h = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
def _get_tensorflow_output():
with tf.Session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
m0 = array_ops.zeros([batch_size, num_hidden])
m1 = array_ops.zeros([batch_size, num_hidden])
x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype)
g, ((out_m0, out_m1)) = \
tf.contrib.rnn.LSTMBlockCell(num_hidden,
forget_bias=forget_bias)(x, ((m0, m1)))
sess.run([variables.global_variables_initializer()])
res = sess.run([g, out_m0, out_m1], {
x.name: np.array([[1., 1.]]),
m0.name: 0.1 * np.ones([batch_size, num_hidden]),
m1.name: 0.1 * np.ones([batch_size, num_hidden]),
})
graph_def = sess.graph.as_graph_def(add_shapes=True)
final_graph_def = graph_util.convert_variables_to_constants(
sess,
graph_def,
['root/lstm_cell/LSTMBlockCell'])
return final_graph_def, res
graph_def, tf_out = _get_tensorflow_output()
tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h],
['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c',
'root/lstm_cell/LSTMBlockCell_h'],
[tf_out[0].shape, (2, batch_size, num_hidden)],
[tf_out[0].dtype, tf_out[1].dtype])
if isinstance(tvm_output, list):
out = tvm_output[0]
out_state = tvm_output[1]
out_state_tup = np.split(out_state, indices_or_sections=2, axis=0)
out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden))
out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden))
tvm_out = [out, out_state_c, out_state_h]
np.testing.assert_allclose(tf_out, tvm_out, rtol=1e-3, atol=1e-3)
def test_forward_lstm():
'''test LSTM block cell'''
_test_lstm_cell(1, 2, 1, 0.0, 'float32')
#######################################################################
# StridedSlice
# ------------
def _test_stridedslice(ip_shape, begin, end, stride, dtype,
begin_mask=0, end_mask=0, new_axis_mask=0,
shrink_axis_mask=0, ellipsis_mask=0):
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask,
end_mask=end_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask,
ellipsis_mask=ellipsis_mask, name="strided_slice")
np_data = np.random.uniform(size=ip_shape).astype(dtype)
with tf.Session() as sess:
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['strided_slice'])
tf_output = run_tf_graph(sess, np_data,
'in_data:0', 'strided_slice:0')
tvm_output = run_tvm_graph(final_graph_def, np_data,
"in_data", tf_output.shape, np_data.dtype)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
sess.close()
def test_forward_stridedslice():
'''test StridedSlice'''
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5)
_test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=4)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3)
_test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=2)
_test_stridedslice((3,4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1, new_axis_mask=2)
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2, new_axis_mask=1)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], 'float32', shrink_axis_mask=5, new_axis_mask=1)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=5, new_axis_mask=1, ellipsis_mask=2, begin_mask=8, end_mask=8)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=5)
_test_stridedslice((3, 4, 5, 4, 5, 6), [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1],
'float32', shrink_axis_mask=16, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=5)
_test_stridedslice((3, 4, 5, 4, 5, 6), [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1],
'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5,
end_mask=8)
#######################################################################
# Gather
# ------
def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
indices = tf.placeholder("int32", indice_shape, name="indices")
tf.gather(in_data, indices, axis=axis)
np_data = np.random.uniform(size=ip_shape).astype(dtype)
def _fill_indices(indice_value):
indices = np.array(ip_shape, dtype=dtype)
if isinstance(indice_value, int):
indices = np.array([indice_value], dtype='int32')
else:
indices = np.asarray(indice_value, dtype='int32')
return indices
np_indices = _fill_indices(indice_value)
with tf.Session() as sess:
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['GatherV2'])
tf_output = run_tf_graph(sess, [np_data, np_indices], ['in_data:0',
'indices:0'], 'GatherV2:0')
tvm_output = run_tvm_graph(final_graph_def, [np_data, np_indices],
['in_data', 'indices'], tf_output.shape, dtype)
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
sess.close()
def test_forward_gather():
'''test gather layer'''
_test_gather((4,), (1,), 1, 0, 'int32')
_test_gather((4,), (1,), 1, 0, 'float32')
_test_gather((1,4), (1,), [0], 0, 'int32')
_test_gather((4,), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'int32')
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 1, 'int32')
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 0, 'int32')
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
#######################################################################
# Multi Input to graph
# --------------------
...@@ -584,6 +746,115 @@ def test_forward_mobilenet():
np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
#######################################################################
# PTB
# ---
#Touch tf.contrib so that contrib ops such as LSTMBlockCell get registered.
dir(tf.contrib)
def test_forward_ptb():
'''test ptb model'''
config = nnvm.testing.tf.get_config()
num_steps = config.num_steps
num_hidden = config.hidden_size
num_layers = config.num_layers
batch_size = config.batch_size
vocab_size = config.vocab_size
out_sample_shape = (batch_size, vocab_size)
out_state_shape = (num_layers, 2, batch_size, num_hidden)
#Sample input
inpt = "we have no useful information on"
cnt_sample = 20
def _pretty_print(items, is_char_model, id2word):
if not is_char_model:
return ' '.join([id2word[x] for x in items])
else:
return ''.join([id2word[x] for x in items]).replace('_', ' ')
def _get_tvm_graph_module(graph_def):
sym, params = nnvm.frontend.from_tensorflow(graph_def)
#Cell inputs 'c' and 'h' consist of all layers' values
shape_dict = {'Model/Placeholder': (batch_size, num_steps),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':(num_layers, batch_size, num_hidden),
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':(num_layers, batch_size, num_hidden)}
dtype_dict = {'Model/Placeholder': 'int32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'}
target = 'llvm'
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict,
dtype=dtype_dict, params=params)
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
return params, graph_runtime.create(graph, lib, ctx)
def _do_tvm_sample(model, data, in_states, params, num_samples):
"""Sampled from the model"""
samples = []
state = in_states
sample = None
def _get_sample(data, state):
input_data = np.full((batch_size, num_steps), data, dtype="int32")
in_state_tup = np.split(state, indices_or_sections=2, axis=1)
in_state_c = np.reshape(in_state_tup[0], (num_layers, batch_size, num_hidden))
in_state_h = np.reshape(in_state_tup[1], (num_layers, batch_size, num_hidden))
model.set_input('Model/Placeholder', tvm.nd.array(input_data.astype("int32")))
model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c',
tvm.nd.array(in_state_c.astype("float32")))
model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h',
tvm.nd.array(in_state_h.astype("float32")))
model.set_input(**params)
model.run()
tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape,
"float32")).asnumpy()
state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
"float32")).asnumpy()
sample = nnvm.testing.tf.pick_from_weight(tvm_output[0])
return sample, state_output
for x in data:
sample, state = _get_sample(x, state)
if sample is not None:
samples.append(sample)
else:
samples.append(0)
k = 1
while k < num_samples:
sample, state = _get_sample(samples[-1], state)
samples.append(sample)
k += 1
return samples, state
with tf.Graph().as_default():
word_to_id, id_to_word, graph_def = nnvm.testing.tf.get_workload_ptb()
vocab_size = len(word_to_id)
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
sess = tf.Session()
#TVM graph module creation
params, m = _get_tvm_graph_module(graph_def)
# Create 10 predicted statements of 20 words
cnt_stm = 0
while cnt_stm < 10:
cnt_stm += 1
in_state = np.full((num_layers, 2, batch_size, num_hidden), 0, dtype="float32")
seed_for_sample = inpt.split()
tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word] \
for word in seed_for_sample],
in_state, params, cnt_sample)
tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
tf_samples, tf_state = nnvm.testing.tf.do_tf_sample(sess,
[word_to_id[word] for word in seed_for_sample],
in_state, cnt_sample)
tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
inpt = tvm_sample_str
np.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
assert(tvm_sample_str == tf_sample_str)
#######################################################################
# Main
# ----
if __name__ == '__main__':
...@@ -600,3 +871,7 @@ if __name__ == '__main__':
test_forward_mobilenet()
test_forward_variable()
test_forward_resize_bilinear()
test_forward_lstm()
test_forward_stridedslice()
test_forward_gather()
test_forward_ptb()