Unverified Commit 72f2aea2 by Ramana Radhakrishnan Committed by GitHub

Tf2 test fixups (#5391)

* Fix oversight in importing tf.compat.v1 as tf.

* Actually disable test for lstm in TF2.1

Since the testing framework actually uses pytest, the version
check needs to be moved.
parent be54c984
...@@ -22,7 +22,10 @@ in TensorFlow frontend when mean and variance are not given. ...@@ -22,7 +22,10 @@ in TensorFlow frontend when mean and variance are not given.
""" """
import tvm import tvm
import numpy as np import numpy as np
import tensorflow as tf try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from tvm import relay from tvm import relay
from tensorflow.python.framework import graph_util from tensorflow.python.framework import graph_util
......
...@@ -1901,7 +1901,9 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype): ...@@ -1901,7 +1901,9 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
def test_forward_lstm(): def test_forward_lstm():
'''test LSTM block cell''' '''test LSTM block cell'''
_test_lstm_cell(1, 2, 1, 0.5, 'float32') if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
#in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
_test_lstm_cell(1, 2, 1, 0.5, 'float32')
####################################################################### #######################################################################
...@@ -3308,9 +3310,7 @@ if __name__ == '__main__': ...@@ -3308,9 +3310,7 @@ if __name__ == '__main__':
test_forward_ptb() test_forward_ptb()
# RNN # RNN
if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'): test_forward_lstm()
#in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
test_forward_lstm()
# Elementwise # Elementwise
test_forward_ceil() test_forward_ceil()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment