Commit a808a987 by Albin Joy Committed by Tianqi Chen

[NNVM][TENSORFLOW] LSTM operator and PTB word prediction frontend (#1389)

parent f7d05b7c
......@@ -6,6 +6,8 @@ Some helper definitions for tensorflow models.
"""
import re
import os.path
import collections
import numpy as np
# Tensorflow imports
import tensorflow as tf
......@@ -134,3 +136,143 @@ def get_workload(model_path):
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
return graph_def
#######################################################################
# PTB LSTMBlockCell Model
# -----------------------
class PTBSmallConfig(object):
"""Small config.
This configurations are used when training the model
"""
num_layers = 2
num_steps = 1
hidden_size = 200
batch_size = 1
vocab_size = 10000
init_scale = 0.1
def get_config():
"""Configuration used for training the model"""
return PTBSmallConfig()
def pick_from_weight(weight, pows=1.0):
"""Identify token from Softmax output.
This token will be mapped to word in the vocabulary.
"""
weight = weight**pows
t = np.cumsum(weight)
s = np.sum(weight)
return int(np.searchsorted(t, 0.5 * s))
def do_tf_sample(session, data, in_states, num_samples):
"""Sampled from the model"""
samples = []
sample = None
#Cell inputs c and h should be passed for each layer explicitly.
state_input_name = ['Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0',
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0',
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0',
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0']
state = session.run(state_input_name)
#Graph nodes to be fetched as run output. Tensorflow LSTMBlockCell create internal
#nodes for intermediate operations (gates) in the cell during run.
#Cell state (c) is ':1'and cell output (h) is ':6' for each layer.
fetches = [['Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6'],
'Model/Softmax:0']
def _get_feed_dict(input_name, input_data):
"""Create feed dict"""
feed_dict = {}
if isinstance(input_data, list):
for i, e in enumerate(input_name):
feed_dict[e] = input_data[i]
else:
feed_dict[input_name] = input_data
return feed_dict
for x in data:
feed_dict = _get_feed_dict(state_input_name, state)
feed_dict['Model/Placeholder:0'] = [[x]]
state, probs = session.run(fetches, feed_dict)
sample = pick_from_weight(probs[0])
if sample is not None:
samples.append(sample)
else:
samples.append(0)
k = 1
while k < num_samples:
feed_dict = _get_feed_dict(state_input_name, state)
feed_dict['Model/Placeholder:0'] = [[samples[-1]]]
state, probs = session.run(fetches, feed_dict)
sample = pick_from_weight(probs[0])
samples.append(sample)
k += 1
return samples, state
def _create_ptb_vocabulary(data_dir):
"""Read the PTB sample data input to create vocabulary"""
data_path = data_dir+'simple-examples/data/'
file_name = 'ptb.train.txt'
def _read_words(filename):
"""Read the data for creating vocabulary"""
with tf.gfile.GFile(filename, "r") as f:
return f.read().encode("utf-8").decode("utf-8").replace("\n", "<eos>").split()
def _build_vocab(filename):
"""Create vocabulary"""
data = _read_words(filename)
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
word_to_id = dict(zip(words, range(len(words))))
#for python 3.x
id_to_word = dict((v, k) for k, v in word_to_id.items())
return word_to_id, id_to_word
def ptb_raw_data(data_path, file_name):
"""Read the sample data and create vocabulary"""
train_path = os.path.join(data_path, file_name)
word_to_id, id_2_word = _build_vocab(train_path)
return word_to_id, id_2_word
return ptb_raw_data(data_path, file_name)
def get_workload_ptb():
""" Import ptb workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for ptb.
word_to_id : dict
English word to integer id mapping
id_to_word : dict
Integer id to English word mapping
"""
sample_repo = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/'
sample_data_file = 'simple-examples.tgz'
sample_url = sample_repo+sample_data_file
ptb_model_file = 'RNN/ptb/ptb_model_with_lstmblockcell.pb'
import tarfile
from tvm.contrib.download import download
DATA_DIR = './ptb_data/'
if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR)
download(sample_url, DATA_DIR+sample_data_file)
t = tarfile.open(DATA_DIR+sample_data_file, 'r')
t.extractall(DATA_DIR)
word_to_id, id_to_word = _create_ptb_vocabulary(DATA_DIR)
return word_to_id, id_to_word, get_workload(ptb_model_file)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment