Unverified Commit 72f2aea2 by Ramana Radhakrishnan Committed by GitHub

Tf2 test fixups (#5391)

* Fix oversight in importing tf.compat.v1 as tf.

* Actually disable test for lstm in TF2.1

Since the testing framework actually uses pytest, the version
check needs to be moved.
parent be54c984
......@@ -22,7 +22,10 @@ in TensorFlow frontend when mean and variance are not given.
"""
import tvm
import numpy as np
import tensorflow as tf
try:
import tensorflow.compat.v1 as tf
except ImportError:
import tensorflow as tf
from tvm import relay
from tensorflow.python.framework import graph_util
......
......@@ -1901,7 +1901,9 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
def test_forward_lstm():
'''test LSTM block cell'''
_test_lstm_cell(1, 2, 1, 0.5, 'float32')
if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
#in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
_test_lstm_cell(1, 2, 1, 0.5, 'float32')
#######################################################################
......@@ -3308,9 +3310,7 @@ if __name__ == '__main__':
test_forward_ptb()
# RNN
if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
#in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
test_forward_lstm()
test_forward_lstm()
# Elementwise
test_forward_ceil()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment