Unverified Commit 06bb17ec by Samuel Committed by GitHub

Tensorflow script upgrade from 1.13.1 to 2.0.0, so that it can run in both versionsw (#4963)

parent 11ee1a0e
...@@ -1259,7 +1259,7 @@ def _broadcast(name): ...@@ -1259,7 +1259,7 @@ def _broadcast(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
return AttrCvt( return AttrCvt(
op_name=name, op_name=name,
ignores=['name', 'Tidx'] ignores=['name', 'incompatible_shape_error', 'Tidx']
)(inputs, attr) )(inputs, attr)
return _impl return _impl
......
...@@ -73,7 +73,7 @@ class TFParser(object): ...@@ -73,7 +73,7 @@ class TFParser(object):
def _get_output_names(self): def _get_output_names(self):
"""Return the concatenated output names""" """Return the concatenated output names"""
try: try:
import tensorflow as tf import tensorflow.compat.v1 as tf
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"InputConfiguration: Unable to import tensorflow which is " "InputConfiguration: Unable to import tensorflow which is "
......
...@@ -219,9 +219,9 @@ def get_workload(model_path, model_sub_path=None): ...@@ -219,9 +219,9 @@ def get_workload(model_path, model_sub_path=None):
# Creates graph from saved graph_def.pb. # Creates graph from saved graph_def.pb.
with tf_compat_v1.gfile.FastGFile(path_model, 'rb') as f: with tf_compat_v1.gfile.FastGFile(path_model, 'rb') as f:
graph_def = tf.GraphDef() graph_def = tf_compat_v1.GraphDef()
graph_def.ParseFromString(f.read()) graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='') graph = tf_compat_v1.import_graph_def(graph_def, name='')
return graph_def return graph_def
####################################################################### #######################################################################
......
...@@ -22,11 +22,16 @@ from tvm.contrib import graph_runtime ...@@ -22,11 +22,16 @@ from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list from tvm.relay.testing.config import ctx_list
import keras import keras
import tensorflow as tf try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from tensorflow import keras as tf_keras from tensorflow import keras as tf_keras
from packaging import version as package_version
# prevent Keras from using up all gpu memory # prevent Keras from using up all gpu memory
if tf.executing_eagerly(): if tf.executing_eagerly():
gpus = tf.config.list_physical_devices('GPU') gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus: for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True) tf.config.experimental.set_memory_growth(gpu, True)
else: else:
...@@ -363,7 +368,7 @@ class TestKeras: ...@@ -363,7 +368,7 @@ class TestKeras:
keras.layers.SimpleRNN(units=16, return_state=False, keras.layers.SimpleRNN(units=16, return_state=False,
activation='tanh'), activation='tanh'),
keras.layers.GRU(units=16, return_state=False, keras.layers.GRU(units=16, return_state=False,
recurrent_activation='sigmoid', activation='tanh')] recurrent_activation='sigmoid', activation='tanh', reset_after=False)]
for rnn_func in rnn_funcs: for rnn_func in rnn_funcs:
x = rnn_func(data) x = rnn_func(data)
keras_model = keras.models.Model(data, x) keras_model = keras.models.Model(data, x)
......
...@@ -16,7 +16,11 @@ ...@@ -16,7 +16,11 @@
# under the License. # under the License.
"""Unit tests for converting TensorFlow control flow op to Relay.""" """Unit tests for converting TensorFlow control flow op to Relay."""
import pytest import pytest
import tensorflow as tf try:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
except ImportError:
import tensorflow as tf
import numpy as np import numpy as np
from tvm import nd from tvm import nd
from tvm import relay from tvm import relay
......
...@@ -15,7 +15,11 @@ ...@@ -15,7 +15,11 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay.""" """Unit tests for converting TensorFlow debugging ops to Relay."""
import tensorflow as tf try:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
except ImportError:
import tensorflow as tf
import numpy as np import numpy as np
from tvm import relay from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow from tvm.relay.frontend.tensorflow import from_tensorflow
......
...@@ -15,7 +15,10 @@ ...@@ -15,7 +15,10 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay.""" """Unit tests for converting TensorFlow debugging ops to Relay."""
import tensorflow as tf try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
import numpy as np import numpy as np
from tvm import relay from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow from tvm.relay.frontend.tensorflow import from_tensorflow
......
...@@ -26,7 +26,10 @@ import numpy as np ...@@ -26,7 +26,10 @@ import numpy as np
import tvm import tvm
from tvm import te from tvm import te
from tvm import relay from tvm import relay
import tensorflow as tf try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
...@@ -156,7 +159,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, ...@@ -156,7 +159,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
if init_global_variables: if init_global_variables:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
# convert to tflite model # convert to tflite model
converter = interpreter_wrapper.TFLiteConverter.from_session( converter = tf.lite.TFLiteConverter.from_session(
sess, input_tensors, output_tensors) sess, input_tensors, output_tensors)
if quantized: if quantized:
......
...@@ -99,8 +99,12 @@ tflite_model_file = os.path.join(model_dir, "mobilenet_v1_1.0_224.tflite") ...@@ -99,8 +99,12 @@ tflite_model_file = os.path.join(model_dir, "mobilenet_v1_1.0_224.tflite")
tflite_model_buf = open(tflite_model_file, "rb").read() tflite_model_buf = open(tflite_model_file, "rb").read()
# Get TFLite model from buffer # Get TFLite model from buffer
import tflite.Model try:
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) import tflite
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
###################################################################### ######################################################################
# Load a test image # Load a test image
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment