Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
59cf5735
Commit
59cf5735
authored
Oct 02, 2019
by
Wei Chen
Committed by
Haichen Shen
Oct 02, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TF][Op] Op where (#4045)
* [TF][Op] Add TF op Where * improve tests * add tests for vm
parent
2d537621
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
21 deletions
+74
-21
python/tvm/relay/frontend/tensorflow.py
+3
-0
tests/python/frontend/tensorflow/test_forward.py
+71
-21
No files found.
python/tvm/relay/frontend/tensorflow.py
View file @
59cf5735
...
@@ -937,6 +937,8 @@ def _transpose():
...
@@ -937,6 +937,8 @@ def _transpose():
def
_where
():
def
_where
():
def
_impl
(
inputs
,
attr
,
params
):
def
_impl
(
inputs
,
attr
,
params
):
if
len
(
inputs
)
==
1
:
return
AttrCvt
(
op_name
=
"argwhere"
)(
inputs
,
attr
)
return
AttrCvt
(
op_name
=
"where"
)(
inputs
,
attr
)
return
AttrCvt
(
op_name
=
"where"
)(
inputs
,
attr
)
return
_impl
return
_impl
...
@@ -1354,6 +1356,7 @@ _convert_map = {
...
@@ -1354,6 +1356,7 @@ _convert_map = {
'Transpose'
:
_transpose
(),
'Transpose'
:
_transpose
(),
'TruncateMod'
:
_elemwise
(
'mod'
),
'TruncateMod'
:
_elemwise
(
'mod'
),
'Unpack'
:
_unpack
(),
'Unpack'
:
_unpack
(),
'Where'
:
_where
(),
'ZerosLike'
:
AttrCvt
(
'zeros_like'
),
'ZerosLike'
:
AttrCvt
(
'zeros_like'
),
}
}
...
...
tests/python/frontend/tensorflow/test_forward.py
View file @
59cf5735
...
@@ -46,8 +46,34 @@ def convert_to_list(x):
...
@@ -46,8 +46,34 @@ def convert_to_list(x):
x
=
[
x
]
x
=
[
x
]
return
x
return
x
def
vmobj_to_list
(
o
):
if
isinstance
(
o
,
tvm
.
relay
.
backend
.
vmobj
.
TensorObject
):
return
[
o
.
asnumpy
()
.
tolist
()]
elif
isinstance
(
o
,
tvm
.
relay
.
backend
.
vmobj
.
DatatypeObject
):
result
=
[]
for
f
in
o
:
result
.
extend
(
vmobj_to_list
(
f
))
return
result
elif
isinstance
(
o
,
tvm
.
relay
.
backend
.
interpreter
.
TupleValue
):
result
=
[]
for
f
in
o
.
fields
:
result
.
append
(
vmobj_to_list
(
f
))
return
result
elif
isinstance
(
o
,
tvm
.
relay
.
backend
.
interpreter
.
ConstructorValue
):
if
o
.
constructor
.
name_hint
==
'cons'
:
tl
=
vmobj_to_list
(
o
.
fields
[
1
])
hd
=
vmobj_to_list
(
o
.
fields
[
0
])
hd
.
extend
(
tl
)
return
hd
elif
o
.
constructor
.
name_hint
==
'nil'
:
return
[]
elif
isinstance
(
o
,
tvm
.
relay
.
backend
.
interpreter
.
TensorValue
):
return
[
o
.
data
.
asnumpy
()]
else
:
raise
RuntimeError
(
"Unknown object type:
%
s"
%
type
(
o
))
def
run_tvm_graph
(
graph_def
,
input_data
,
input_node
,
num_output
=
1
,
def
run_tvm_graph
(
graph_def
,
input_data
,
input_node
,
num_output
=
1
,
target
=
'llvm'
,
out_names
=
None
,
opt_level
=
3
):
target
=
'llvm'
,
out_names
=
None
,
opt_level
=
3
,
mode
=
'graph_runtime'
):
""" Generic function to compile on relay and execute on tvm """
""" Generic function to compile on relay and execute on tvm """
input_data
=
convert_to_list
(
input_data
)
input_data
=
convert_to_list
(
input_data
)
input_node
=
convert_to_list
(
input_node
)
input_node
=
convert_to_list
(
input_node
)
...
@@ -63,24 +89,32 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
...
@@ -63,24 +89,32 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
layout
=
layout
,
layout
=
layout
,
shape
=
shape_dict
,
shape
=
shape_dict
,
outputs
=
out_names
)
outputs
=
out_names
)
with
relay
.
build_config
(
opt_level
=
opt_level
):
if
mode
in
[
'debug'
,
'vm'
]:
graph
,
lib
,
params
=
relay
.
build
(
mod
,
target
,
target_host
,
params
)
ex
=
relay
.
create_executor
(
mode
,
mod
=
mod
,
ctx
=
tvm
.
cpu
(),
target
=
"llvm"
)
inputs
=
[]
ctx
=
tvm
.
context
(
target
,
0
)
for
param
in
mod
[
'main'
]
.
params
:
from
tvm.contrib
import
graph_runtime
inputs
.
append
(
tvm
.
nd
.
array
(
params
[
param
.
name_hint
]))
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
result
=
ex
.
evaluate
()(
*
inputs
)
# set inputs
return
vmobj_to_list
(
result
)
for
e
,
i
in
zip
(
input_node
,
input_data
):
else
:
m
.
set_input
(
e
,
tvm
.
nd
.
array
(
i
))
with
relay
.
build_config
(
opt_level
=
opt_level
):
graph
,
lib
,
params
=
relay
.
build
(
mod
,
target
,
target_host
,
params
)
m
.
set_input
(
**
params
)
# execute
ctx
=
tvm
.
context
(
target
,
0
)
m
.
run
()
from
tvm.contrib
import
graph_runtime
# get outputs
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
assert
out_names
is
None
or
num_output
==
len
(
out_names
),
(
# set inputs
"out_names: {} num_output: {}"
.
format
(
out_names
,
num_output
))
for
e
,
i
in
zip
(
input_node
,
input_data
):
tvm_output_list
=
[
m
.
get_output
(
i
)
.
asnumpy
()
for
i
in
range
(
num_output
)]
m
.
set_input
(
e
,
tvm
.
nd
.
array
(
i
))
return
tvm_output_list
m
.
set_input
(
**
params
)
# execute
m
.
run
()
# get outputs
assert
out_names
is
None
or
num_output
==
len
(
out_names
),
(
"out_names: {} num_output: {}"
.
format
(
out_names
,
num_output
))
tvm_output_list
=
[
m
.
get_output
(
i
)
.
asnumpy
()
for
i
in
range
(
num_output
)]
return
tvm_output_list
def
run_tf_graph
(
sess
,
input_data
,
input_node
,
output_node
):
def
run_tf_graph
(
sess
,
input_data
,
input_node
,
output_node
):
""" Generic function to execute tensorflow """
""" Generic function to execute tensorflow """
...
@@ -97,7 +131,7 @@ def run_tf_graph(sess, input_data, input_node, output_node):
...
@@ -97,7 +131,7 @@ def run_tf_graph(sess, input_data, input_node, output_node):
def
compare_tf_with_tvm
(
in_data
,
in_name
,
out_name
,
init_global_variables
=
False
,
def
compare_tf_with_tvm
(
in_data
,
in_name
,
out_name
,
init_global_variables
=
False
,
no_gpu
=
False
,
opt_level
=
3
):
no_gpu
=
False
,
opt_level
=
3
,
mode
=
'graph_runtime'
):
"""Generic function to generate and compare tensorflow and TVM output"""
"""Generic function to generate and compare tensorflow and TVM output"""
def
name_without_num
(
name
):
def
name_without_num
(
name
):
return
name
.
split
(
':'
)[
0
]
if
":"
in
name
else
name
return
name
.
split
(
':'
)[
0
]
if
":"
in
name
else
name
...
@@ -128,7 +162,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
...
@@ -128,7 +162,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
tvm_output
=
run_tvm_graph
(
final_graph_def
,
in_data
,
in_node
,
tvm_output
=
run_tvm_graph
(
final_graph_def
,
in_data
,
in_node
,
target
=
device
,
out_names
=
out_name
,
target
=
device
,
out_names
=
out_name
,
num_output
=
len
(
out_name
),
opt_level
=
opt_level
)
num_output
=
len
(
out_name
),
opt_level
=
opt_level
,
mode
=
mode
)
# since the names from tensorflow and relay runs are not exactly same,
# since the names from tensorflow and relay runs are not exactly same,
# first len(tf_output) will be compared
# first len(tf_output) will be compared
for
i
in
range
(
len
(
tf_output
)):
for
i
in
range
(
len
(
tf_output
)):
...
@@ -325,6 +359,22 @@ def test_forward_biasadd():
...
@@ -325,6 +359,22 @@ def test_forward_biasadd():
_test_biasadd
([
4
,
17
,
17
,
19
],
'NHWC'
)
_test_biasadd
([
4
,
17
,
17
,
19
],
'NHWC'
)
_test_biasadd
([
4
,
3
,
3
,
124
],
'NHWC'
)
_test_biasadd
([
4
,
3
,
3
,
124
],
'NHWC'
)
def
_test_forward_where
(
input_shape
):
with
tf
.
Graph
()
.
as_default
():
dtype
=
tf
.
float32
t
=
tf
.
constant
(
np
.
random
.
choice
([
0
,
1
,
-
2
,
3
,
-
1
,
0.1
,
-
0.2
],
size
=
input_shape
)
.
astype
(
dtype
.
name
))
out
=
tf
.
where
(
t
)
compare_tf_with_tvm
([],
[],
out
.
name
,
mode
=
'debug'
)
compare_tf_with_tvm
([],
[],
out
.
name
,
mode
=
'vm'
)
def
test_forward_argwhere
():
_test_forward_where
((
5
,))
_test_forward_where
((
5
,
5
))
_test_forward_where
((
5
,
5
,
5
))
_test_forward_where
((
5
,
5
,
5
,
5
))
_test_forward_where
((
5
,
5
,
5
,
5
,
5
))
#######################################################################
#######################################################################
# SpaceToBatchND
# SpaceToBatchND
# --------------
# --------------
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment