Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
770ac84e
Commit
770ac84e
authored
Jun 06, 2019
by
Alexey Romanov
Committed by
Tianqi Chen
Jun 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Frontend] Simplify parameter handling in Tensorflow frontend (#2993)
parent
5999f7a6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
40 deletions
+51
-40
python/tvm/relay/frontend/tensorflow.py
+0
-0
tests/python/frontend/tensorflow/test_forward.py
+49
-30
topi/python/topi/util.py
+2
-10
No files found.
python/tvm/relay/frontend/tensorflow.py
View file @
770ac84e
This diff is collapsed.
Click to expand it.
tests/python/frontend/tensorflow/test_forward.py
View file @
770ac84e
...
@@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
...
@@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
layout
=
None
layout
=
None
if
target
==
"cuda"
:
if
target
==
"cuda"
:
layout
=
"NCHW"
layout
=
"NCHW"
target_host
=
'llvm'
target_host
=
None
if
isinstance
(
input_data
,
list
):
shape_dict
=
{
e
:
i
.
shape
for
e
,
i
in
zip
(
input_node
,
input_data
)}
shape_dict
=
{}
dtype_dict
=
{}
for
i
,
e
in
enumerate
(
input_node
):
shape_dict
[
e
]
=
input_data
[
i
]
.
shape
dtype_dict
[
e
]
=
input_data
[
i
]
.
dtype
else
:
shape_dict
=
{
input_node
:
input_data
.
shape
}
dtype_dict
=
{
input_node
:
input_data
.
dtype
}
sym
,
params
=
relay
.
frontend
.
from_tensorflow
(
graph_def
,
sym
,
params
=
relay
.
frontend
.
from_tensorflow
(
graph_def
,
layout
=
layout
,
layout
=
layout
,
shape
=
shape_dict
,
shape
=
shape_dict
,
outputs
=
out_names
)
outputs
=
out_names
)
with
relay
.
build_config
(
opt_level
=
opt_level
):
with
relay
.
build_config
(
opt_level
=
opt_level
):
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
,
params
=
params
)
graph
,
lib
,
params
=
relay
.
build
(
sym
,
target
,
target_host
,
params
)
ctx
=
tvm
.
context
(
target
,
0
)
ctx
=
tvm
.
context
(
target
,
0
)
from
tvm.contrib
import
graph_runtime
from
tvm.contrib
import
graph_runtime
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
# set inputs
# set inputs
for
i
,
e
in
enumerate
(
input_node
):
for
e
,
i
in
zip
(
input_node
,
input_data
):
m
.
set_input
(
e
,
tvm
.
nd
.
array
(
i
nput_data
[
i
]
.
astype
(
input_data
[
i
]
.
dtype
)
))
m
.
set_input
(
e
,
tvm
.
nd
.
array
(
i
))
m
.
set_input
(
**
params
)
m
.
set_input
(
**
params
)
# execute
# execute
...
@@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
...
@@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
# get outputs
# get outputs
assert
out_names
is
None
or
num_output
==
len
(
out_names
),
(
assert
out_names
is
None
or
num_output
==
len
(
out_names
),
(
"out_names: {} num_output: {}"
.
format
(
out_names
,
num_output
))
"out_names: {} num_output: {}"
.
format
(
out_names
,
num_output
))
tvm_output_list
=
[]
tvm_output_list
=
[
m
.
get_output
(
i
)
.
asnumpy
()
for
i
in
range
(
num_output
)]
for
i
in
range
(
0
,
num_output
):
tvm_output
=
m
.
get_output
(
i
)
tvm_output_list
.
append
(
tvm_output
.
asnumpy
())
return
tvm_output_list
return
tvm_output_list
def
run_tf_graph
(
sess
,
input_data
,
input_node
,
output_node
):
def
run_tf_graph
(
sess
,
input_data
,
input_node
,
output_node
):
...
@@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node):
...
@@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node):
input_node
=
convert_to_list
(
input_node
)
input_node
=
convert_to_list
(
input_node
)
output_node
=
convert_to_list
(
output_node
)
output_node
=
convert_to_list
(
output_node
)
tensor
=
[
0
]
*
len
(
output_node
)
tensor
=
[
sess
.
graph
.
get_tensor_by_name
(
output_name
)
for
output_name
in
output_node
]
for
i
in
range
(
len
(
output_node
)):
tensor
[
i
]
=
sess
.
graph
.
get_tensor_by_name
(
output_node
[
i
])
input_dict
=
{}
input_dict
=
{
e
:
input_data
[
i
]
for
i
,
e
in
enumerate
(
input_node
)}
for
i
,
e
in
enumerate
(
input_node
):
input_dict
[
e
]
=
input_data
[
i
]
output_data
=
sess
.
run
(
tensor
,
input_dict
)
output_data
=
sess
.
run
(
tensor
,
input_dict
)
return
output_data
return
output_data
...
@@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node):
...
@@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node):
def
compare_tf_with_tvm
(
in_data
,
in_name
,
out_name
,
init_global_variables
=
False
,
def
compare_tf_with_tvm
(
in_data
,
in_name
,
out_name
,
init_global_variables
=
False
,
no_gpu
=
False
,
opt_level
=
3
):
no_gpu
=
False
,
opt_level
=
3
):
"""Generic function to generate and compare tensorflow and TVM output"""
"""Generic function to generate and compare tensorflow and TVM output"""
def
name_without_num
(
name
):
return
name
.
split
(
':'
)[
0
]
if
":"
in
name
else
name
out_name
=
convert_to_list
(
out_name
)
out_name
=
convert_to_list
(
out_name
)
out_node
=
[
0
]
*
len
(
out_name
)
out_node
=
[
name_without_num
(
name
)
for
name
in
out_name
]
for
i
in
range
(
len
(
out_name
)):
out_node
[
i
]
=
out_name
[
i
]
.
split
(
':'
)[
0
]
if
":"
in
out_name
[
i
]
else
out_name
[
i
]
in_data
=
convert_to_list
(
in_data
)
in_data
=
convert_to_list
(
in_data
)
in_name
=
convert_to_list
(
in_name
)
in_name
=
convert_to_list
(
in_name
)
in_node
=
[
0
]
*
len
(
in_name
)
in_node
=
[
name_without_num
(
name
)
for
name
in
in_name
]
for
i
in
range
(
len
(
in_name
)):
in_node
[
i
]
=
in_name
[
i
]
.
split
(
':'
)[
0
]
if
":"
in
in_name
[
i
]
else
in_name
[
i
]
with
tf
.
Session
()
as
sess
:
with
tf
.
Session
()
as
sess
:
if
init_global_variables
:
if
init_global_variables
:
sess
.
run
(
variables
.
global_variables_initializer
())
sess
.
run
(
variables
.
global_variables_initializer
())
...
@@ -578,6 +561,38 @@ def test_forward_variable():
...
@@ -578,6 +561,38 @@ def test_forward_variable():
#######################################################################
#######################################################################
# MatMul
# ------
def
_test_matmul
(
i
,
j
,
k
,
dtype
,
outer
=
None
):
""" One iteration of matmul """
A_shape_init
=
[
i
,
j
]
B_shape_init
=
[
j
,
k
]
for
transpose_a
in
[
False
,
True
]:
for
transpose_b
in
[
False
,
True
]:
outer
=
outer
or
[]
A_shape
=
outer
+
(
A_shape_init
[::
-
1
]
if
transpose_a
else
A_shape_init
)
B_shape
=
outer
+
(
B_shape_init
[::
-
1
]
if
transpose_b
else
B_shape_init
)
with
tf
.
Graph
()
.
as_default
():
A
=
tf
.
placeholder
(
shape
=
A_shape
,
dtype
=
dtype
,
name
=
'A'
)
B
=
tf
.
placeholder
(
shape
=
B_shape
,
dtype
=
dtype
,
name
=
'B'
)
result
=
tf
.
matmul
(
A
,
B
,
transpose_a
=
transpose_a
,
transpose_b
=
transpose_b
)
A_np
=
np
.
random
.
uniform
(
high
=
5.0
,
size
=
A_shape
)
.
astype
(
dtype
)
B_np
=
np
.
random
.
uniform
(
high
=
5.0
,
size
=
B_shape
)
.
astype
(
dtype
)
compare_tf_with_tvm
([
A_np
,
B_np
],
[
A
.
name
,
B
.
name
],
result
.
name
)
def
test_forward_matmul
():
""" Matmul op test"""
_test_matmul
(
1
,
3
,
6
,
'int32'
)
_test_matmul
(
5
,
3
,
1
,
'float64'
)
# TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support
#######################################################################
# StridedSlice
# StridedSlice
# ------------
# ------------
...
@@ -1785,3 +1800,6 @@ if __name__ == '__main__':
...
@@ -1785,3 +1800,6 @@ if __name__ == '__main__':
test_forward_rel_ops
()
test_forward_rel_ops
()
test_forward_logical
()
test_forward_logical
()
test_where
()
test_where
()
test_forward_matmul
()
# TODO missing tests: rank, range
\ No newline at end of file
topi/python/topi/util.py
View file @
770ac84e
...
@@ -151,11 +151,7 @@ def get_const_tuple(in_tuple):
...
@@ -151,11 +151,7 @@ def get_const_tuple(in_tuple):
out_tuple : tuple of int
out_tuple : tuple of int
The output.
The output.
"""
"""
out_tuple
=
()
return
tuple
(
get_const_int
(
elem
)
for
elem
in
in_tuple
)
for
elem
in
in_tuple
:
value
=
get_const_int
(
elem
)
out_tuple
=
out_tuple
+
(
value
,
)
return
out_tuple
def
get_float_tuple
(
in_tuple
):
def
get_float_tuple
(
in_tuple
):
...
@@ -171,11 +167,7 @@ def get_float_tuple(in_tuple):
...
@@ -171,11 +167,7 @@ def get_float_tuple(in_tuple):
out_tuple : tuple of float
out_tuple : tuple of float
The output.
The output.
"""
"""
out_tuple
=
()
return
tuple
(
get_const_float
(
elem
)
for
elem
in
in_tuple
)
for
elem
in
in_tuple
:
value
=
get_const_float
(
elem
)
out_tuple
=
out_tuple
+
(
value
,
)
return
out_tuple
def
simplify
(
expr
):
def
simplify
(
expr
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment