Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
06bb17ec
Unverified
Commit
06bb17ec
authored
Mar 17, 2020
by
Samuel
Committed by
GitHub
Mar 16, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Tensorflow script upgrade from 1.13.1 to 2.0.0, so that it can run in both versionsw (#4963)
parent
11ee1a0e
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
37 additions
and
14 deletions
+37
-14
python/tvm/relay/frontend/tensorflow.py
+1
-1
python/tvm/relay/frontend/tensorflow_parser.py
+1
-1
python/tvm/relay/testing/tf.py
+2
-2
tests/python/frontend/keras/test_forward.py
+8
-3
tests/python/frontend/tensorflow/test_control_flow.py
+5
-1
tests/python/frontend/tensorflow/test_debugging.py
+5
-1
tests/python/frontend/tensorflow/test_forward.py
+0
-0
tests/python/frontend/tensorflow/test_no_op.py
+4
-1
tests/python/frontend/tflite/test_forward.py
+5
-2
tutorials/frontend/from_tflite.py
+6
-2
No files found.
python/tvm/relay/frontend/tensorflow.py
View file @
06bb17ec
...
...
@@ -1259,7 +1259,7 @@ def _broadcast(name):
def
_impl
(
inputs
,
attr
,
params
):
return
AttrCvt
(
op_name
=
name
,
ignores
=
[
'name'
,
'Tidx'
]
ignores
=
[
'name'
,
'
incompatible_shape_error'
,
'
Tidx'
]
)(
inputs
,
attr
)
return
_impl
...
...
python/tvm/relay/frontend/tensorflow_parser.py
View file @
06bb17ec
...
...
@@ -73,7 +73,7 @@ class TFParser(object):
def
_get_output_names
(
self
):
"""Return the concatenated output names"""
try
:
import
tensorflow
as
tf
import
tensorflow
.compat.v1
as
tf
except
ImportError
:
raise
ImportError
(
"InputConfiguration: Unable to import tensorflow which is "
...
...
python/tvm/relay/testing/tf.py
View file @
06bb17ec
...
...
@@ -219,9 +219,9 @@ def get_workload(model_path, model_sub_path=None):
# Creates graph from saved graph_def.pb.
with
tf_compat_v1
.
gfile
.
FastGFile
(
path_model
,
'rb'
)
as
f
:
graph_def
=
tf
.
GraphDef
()
graph_def
=
tf
_compat_v1
.
GraphDef
()
graph_def
.
ParseFromString
(
f
.
read
())
graph
=
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
graph
=
tf
_compat_v1
.
import_graph_def
(
graph_def
,
name
=
''
)
return
graph_def
#######################################################################
...
...
tests/python/frontend/keras/test_forward.py
View file @
06bb17ec
...
...
@@ -22,11 +22,16 @@ from tvm.contrib import graph_runtime
from
tvm.relay.testing.config
import
ctx_list
import
keras
import
tensorflow
as
tf
try
:
import
tensorflow.compat.v1
as
tf
except
ImportError
:
import
tensorflow
as
tf
from
tensorflow
import
keras
as
tf_keras
from
packaging
import
version
as
package_version
# prevent Keras from using up all gpu memory
if
tf
.
executing_eagerly
():
gpus
=
tf
.
config
.
list_physical_devices
(
'GPU'
)
gpus
=
tf
.
config
.
experimental
.
list_physical_devices
(
'GPU'
)
for
gpu
in
gpus
:
tf
.
config
.
experimental
.
set_memory_growth
(
gpu
,
True
)
else
:
...
...
@@ -363,7 +368,7 @@ class TestKeras:
keras
.
layers
.
SimpleRNN
(
units
=
16
,
return_state
=
False
,
activation
=
'tanh'
),
keras
.
layers
.
GRU
(
units
=
16
,
return_state
=
False
,
recurrent_activation
=
'sigmoid'
,
activation
=
'tanh'
)]
recurrent_activation
=
'sigmoid'
,
activation
=
'tanh'
,
reset_after
=
False
)]
for
rnn_func
in
rnn_funcs
:
x
=
rnn_func
(
data
)
keras_model
=
keras
.
models
.
Model
(
data
,
x
)
...
...
tests/python/frontend/tensorflow/test_control_flow.py
View file @
06bb17ec
...
...
@@ -16,7 +16,11 @@
# under the License.
"""Unit tests for converting TensorFlow control flow op to Relay."""
import
pytest
import
tensorflow
as
tf
try
:
import
tensorflow.compat.v1
as
tf
tf
.
disable_v2_behavior
()
except
ImportError
:
import
tensorflow
as
tf
import
numpy
as
np
from
tvm
import
nd
from
tvm
import
relay
...
...
tests/python/frontend/tensorflow/test_debugging.py
View file @
06bb17ec
...
...
@@ -15,7 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay."""
import
tensorflow
as
tf
try
:
import
tensorflow.compat.v1
as
tf
tf
.
disable_v2_behavior
()
except
ImportError
:
import
tensorflow
as
tf
import
numpy
as
np
from
tvm
import
relay
from
tvm.relay.frontend.tensorflow
import
from_tensorflow
...
...
tests/python/frontend/tensorflow/test_forward.py
View file @
06bb17ec
This diff is collapsed.
Click to expand it.
tests/python/frontend/tensorflow/test_no_op.py
View file @
06bb17ec
...
...
@@ -15,7 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay."""
import
tensorflow
as
tf
try
:
import
tensorflow.compat.v1
as
tf
except
ImportError
:
import
tensorflow
as
tf
import
numpy
as
np
from
tvm
import
relay
from
tvm.relay.frontend.tensorflow
import
from_tensorflow
...
...
tests/python/frontend/tflite/test_forward.py
View file @
06bb17ec
...
...
@@ -26,7 +26,10 @@ import numpy as np
import
tvm
from
tvm
import
te
from
tvm
import
relay
import
tensorflow
as
tf
try
:
import
tensorflow.compat.v1
as
tf
except
ImportError
:
import
tensorflow
as
tf
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
ops
from
tensorflow.python.ops
import
math_ops
...
...
@@ -156,7 +159,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
if
init_global_variables
:
sess
.
run
(
variables
.
global_variables_initializer
())
# convert to tflite model
converter
=
interpreter_wrapper
.
TFLiteConverter
.
from_session
(
converter
=
tf
.
lite
.
TFLiteConverter
.
from_session
(
sess
,
input_tensors
,
output_tensors
)
if
quantized
:
...
...
tutorials/frontend/from_tflite.py
View file @
06bb17ec
...
...
@@ -99,8 +99,12 @@ tflite_model_file = os.path.join(model_dir, "mobilenet_v1_1.0_224.tflite")
tflite_model_buf
=
open
(
tflite_model_file
,
"rb"
)
.
read
()
# Get TFLite model from buffer
import
tflite.Model
tflite_model
=
tflite
.
Model
.
Model
.
GetRootAsModel
(
tflite_model_buf
,
0
)
try
:
import
tflite
tflite_model
=
tflite
.
Model
.
GetRootAsModel
(
tflite_model_buf
,
0
)
except
AttributeError
:
import
tflite.Model
tflite_model
=
tflite
.
Model
.
Model
.
GetRootAsModel
(
tflite_model_buf
,
0
)
######################################################################
# Load a test image
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment