Unverified Commit 06bb17ec by Samuel Committed by GitHub

Tensorflow script upgrade from 1.13.1 to 2.0.0, so that it can run in both versionsw (#4963)

parent 11ee1a0e
......@@ -1259,7 +1259,7 @@ def _broadcast(name):
def _impl(inputs, attr, params):
return AttrCvt(
op_name=name,
ignores=['name', 'Tidx']
ignores=['name', 'incompatible_shape_error', 'Tidx']
)(inputs, attr)
return _impl
......
......@@ -73,7 +73,7 @@ class TFParser(object):
def _get_output_names(self):
"""Return the concatenated output names"""
try:
import tensorflow as tf
import tensorflow.compat.v1 as tf
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import tensorflow which is "
......
......@@ -219,9 +219,9 @@ def get_workload(model_path, model_sub_path=None):
# Creates graph from saved graph_def.pb.
with tf_compat_v1.gfile.FastGFile(path_model, 'rb') as f:
graph_def = tf.GraphDef()
graph_def = tf_compat_v1.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
graph = tf_compat_v1.import_graph_def(graph_def, name='')
return graph_def
#######################################################################
......
......@@ -22,11 +22,16 @@ from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
import keras
import tensorflow as tf
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from tensorflow import keras as tf_keras
from packaging import version as package_version
# prevent Keras from using up all gpu memory
if tf.executing_eagerly():
gpus = tf.config.list_physical_devices('GPU')
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
else:
......@@ -363,7 +368,7 @@ class TestKeras:
keras.layers.SimpleRNN(units=16, return_state=False,
activation='tanh'),
keras.layers.GRU(units=16, return_state=False,
recurrent_activation='sigmoid', activation='tanh')]
recurrent_activation='sigmoid', activation='tanh', reset_after=False)]
for rnn_func in rnn_funcs:
x = rnn_func(data)
keras_model = keras.models.Model(data, x)
......
......@@ -16,7 +16,11 @@
# under the License.
"""Unit tests for converting TensorFlow control flow op to Relay."""
import pytest
import tensorflow as tf
try:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
except ImportError:
import tensorflow as tf
import numpy as np
from tvm import nd
from tvm import relay
......
......@@ -15,7 +15,11 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay."""
import tensorflow as tf
try:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
except ImportError:
import tensorflow as tf
import numpy as np
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow
......
......@@ -15,7 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay."""
import tensorflow as tf
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
import numpy as np
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow
......
......@@ -26,7 +26,10 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
import tensorflow as tf
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
......@@ -156,7 +159,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
if init_global_variables:
sess.run(variables.global_variables_initializer())
# convert to tflite model
converter = interpreter_wrapper.TFLiteConverter.from_session(
converter = tf.lite.TFLiteConverter.from_session(
sess, input_tensors, output_tensors)
if quantized:
......
......@@ -99,8 +99,12 @@ tflite_model_file = os.path.join(model_dir, "mobilenet_v1_1.0_224.tflite")
tflite_model_buf = open(tflite_model_file, "rb").read()
# Get TFLite model from buffer
import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
try:
import tflite
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
######################################################################
# Load a test image
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment