Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
a808a987
Commit
a808a987
authored
Jul 25, 2018
by
Albin Joy
Committed by
Tianqi Chen
Jul 25, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM][TENSORFLOW] LSTM operator and PTB word prediction frontend (#1389)
parent
f7d05b7c
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
142 additions
and
0 deletions
+142
-0
nnvm/python/nnvm/frontend/tensorflow.py
+0
-0
nnvm/python/nnvm/testing/tf.py
+142
-0
nnvm/tests/python/frontend/tensorflow/test_forward.py
+0
-0
No files found.
nnvm/python/nnvm/frontend/tensorflow.py
View file @
a808a987
This diff is collapsed.
Click to expand it.
nnvm/python/nnvm/testing/tf.py
View file @
a808a987
...
@@ -6,6 +6,8 @@ Some helper definitions for tensorflow models.
...
@@ -6,6 +6,8 @@ Some helper definitions for tensorflow models.
"""
"""
import
re
import
re
import
os.path
import
os.path
import
collections
import
numpy
as
np
# Tensorflow imports
# Tensorflow imports
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -134,3 +136,143 @@ def get_workload(model_path):
...
@@ -134,3 +136,143 @@ def get_workload(model_path):
graph_def
.
ParseFromString
(
f
.
read
())
graph_def
.
ParseFromString
(
f
.
read
())
graph
=
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
graph
=
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
return
graph_def
return
graph_def
#######################################################################
# PTB LSTMBlockCell Model
# -----------------------
class
PTBSmallConfig
(
object
):
"""Small config.
This configurations are used when training the model
"""
num_layers
=
2
num_steps
=
1
hidden_size
=
200
batch_size
=
1
vocab_size
=
10000
init_scale
=
0.1
def
get_config
():
"""Configuration used for training the model"""
return
PTBSmallConfig
()
def
pick_from_weight
(
weight
,
pows
=
1.0
):
"""Identify token from Softmax output.
This token will be mapped to word in the vocabulary.
"""
weight
=
weight
**
pows
t
=
np
.
cumsum
(
weight
)
s
=
np
.
sum
(
weight
)
return
int
(
np
.
searchsorted
(
t
,
0.5
*
s
))
def
do_tf_sample
(
session
,
data
,
in_states
,
num_samples
):
"""Sampled from the model"""
samples
=
[]
sample
=
None
#Cell inputs c and h should be passed for each layer explicitly.
state_input_name
=
[
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0'
,
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0'
,
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0'
,
'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0'
]
state
=
session
.
run
(
state_input_name
)
#Graph nodes to be fetched as run output. Tensorflow LSTMBlockCell create internal
#nodes for intermediate operations (gates) in the cell during run.
#Cell state (c) is ':1'and cell output (h) is ':6' for each layer.
fetches
=
[[
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1'
,
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6'
,
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1'
,
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6'
],
'Model/Softmax:0'
]
def
_get_feed_dict
(
input_name
,
input_data
):
"""Create feed dict"""
feed_dict
=
{}
if
isinstance
(
input_data
,
list
):
for
i
,
e
in
enumerate
(
input_name
):
feed_dict
[
e
]
=
input_data
[
i
]
else
:
feed_dict
[
input_name
]
=
input_data
return
feed_dict
for
x
in
data
:
feed_dict
=
_get_feed_dict
(
state_input_name
,
state
)
feed_dict
[
'Model/Placeholder:0'
]
=
[[
x
]]
state
,
probs
=
session
.
run
(
fetches
,
feed_dict
)
sample
=
pick_from_weight
(
probs
[
0
])
if
sample
is
not
None
:
samples
.
append
(
sample
)
else
:
samples
.
append
(
0
)
k
=
1
while
k
<
num_samples
:
feed_dict
=
_get_feed_dict
(
state_input_name
,
state
)
feed_dict
[
'Model/Placeholder:0'
]
=
[[
samples
[
-
1
]]]
state
,
probs
=
session
.
run
(
fetches
,
feed_dict
)
sample
=
pick_from_weight
(
probs
[
0
])
samples
.
append
(
sample
)
k
+=
1
return
samples
,
state
def
_create_ptb_vocabulary
(
data_dir
):
"""Read the PTB sample data input to create vocabulary"""
data_path
=
data_dir
+
'simple-examples/data/'
file_name
=
'ptb.train.txt'
def
_read_words
(
filename
):
"""Read the data for creating vocabulary"""
with
tf
.
gfile
.
GFile
(
filename
,
"r"
)
as
f
:
return
f
.
read
()
.
encode
(
"utf-8"
)
.
decode
(
"utf-8"
)
.
replace
(
"
\n
"
,
"<eos>"
)
.
split
()
def
_build_vocab
(
filename
):
"""Create vocabulary"""
data
=
_read_words
(
filename
)
counter
=
collections
.
Counter
(
data
)
count_pairs
=
sorted
(
counter
.
items
(),
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
]))
words
,
_
=
list
(
zip
(
*
count_pairs
))
word_to_id
=
dict
(
zip
(
words
,
range
(
len
(
words
))))
#for python 3.x
id_to_word
=
dict
((
v
,
k
)
for
k
,
v
in
word_to_id
.
items
())
return
word_to_id
,
id_to_word
def
ptb_raw_data
(
data_path
,
file_name
):
"""Read the sample data and create vocabulary"""
train_path
=
os
.
path
.
join
(
data_path
,
file_name
)
word_to_id
,
id_2_word
=
_build_vocab
(
train_path
)
return
word_to_id
,
id_2_word
return
ptb_raw_data
(
data_path
,
file_name
)
def
get_workload_ptb
():
""" Import ptb workload from frozen protobuf
Parameters
----------
Nothing.
Returns
-------
graph_def: graphdef
graph_def is the tensorflow workload for ptb.
word_to_id : dict
English word to integer id mapping
id_to_word : dict
Integer id to English word mapping
"""
sample_repo
=
'http://www.fit.vutbr.cz/~imikolov/rnnlm/'
sample_data_file
=
'simple-examples.tgz'
sample_url
=
sample_repo
+
sample_data_file
ptb_model_file
=
'RNN/ptb/ptb_model_with_lstmblockcell.pb'
import
tarfile
from
tvm.contrib.download
import
download
DATA_DIR
=
'./ptb_data/'
if
not
os
.
path
.
exists
(
DATA_DIR
):
os
.
mkdir
(
DATA_DIR
)
download
(
sample_url
,
DATA_DIR
+
sample_data_file
)
t
=
tarfile
.
open
(
DATA_DIR
+
sample_data_file
,
'r'
)
t
.
extractall
(
DATA_DIR
)
word_to_id
,
id_to_word
=
_create_ptb_vocabulary
(
DATA_DIR
)
return
word_to_id
,
id_to_word
,
get_workload
(
ptb_model_file
)
nnvm/tests/python/frontend/tensorflow/test_forward.py
View file @
a808a987
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment