Unverified Commit 9d646543 by Josh Fromm Committed by GitHub

[Relay][Frontend][Keras] NHWC import support. (#4899)

* Basic test working

* Almost all tests working.

* all tests passing.

* Fixed lint.

* Improved Style.
parent d1e1ac49
......@@ -21,13 +21,18 @@ from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
import keras
# prevent Keras from using up all gpu memory
import tensorflow as tf
from tensorflow import keras as tf_keras
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.5
set_session(tf.Session(config=config))
# prevent Keras from using up all gpu memory
if tf.executing_eagerly():
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
else:
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.5
set_session(tf.Session(config=config))
def pytest_generate_tests(metafunc):
......@@ -52,20 +57,27 @@ using_classic_keras = ("keras", {"keras": keras})
using_tensorflow_keras = ("tf_keras", {"keras": tf_keras})
def verify_keras_frontend(keras_model, need_transpose=True):
def verify_keras_frontend(keras_model, need_transpose=True, layout='NCHW'):
# Keras frontend currently supports tensorflow backend only.
assert(keras.backend.backend() == 'tensorflow')
if layout != 'NCHW':
need_transpose = False
in_shapes = []
for layer in keras_model._input_layers:
if tf.executing_eagerly():
in_shapes.append(tuple(dim if dim is not None else 1 for dim in layer.input.shape))
else:
in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape))
def get_keras_output(xs, dtype='float32'):
return keras_model.predict(xs)
def get_tvm_output(xs, target, ctx, dtype='float32'):
shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
mod, params = relay.frontend.from_keras(keras_model, shape_dict)
mod, params = relay.frontend.from_keras(keras_model, shape_dict, layout=layout)
with relay.transform.build_config(opt_level=2):
graph, lib, params = relay.build(mod,
target,
......@@ -357,28 +369,28 @@ class TestKeras:
verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_vgg16(self, keras):
def test_forward_vgg16(self, keras, layout='NCHW'):
keras_model = keras.applications.VGG16(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model)
verify_keras_frontend(keras_model, layout=layout)
def test_forward_xception(self, keras):
def test_forward_xception(self, keras, layout='NCHW'):
keras_model = keras.applications.Xception(include_top=True, weights='imagenet',
input_shape=(299, 299, 3), classes=1000)
verify_keras_frontend(keras_model)
verify_keras_frontend(keras_model, layout=layout)
def test_forward_resnet50(self, keras):
def test_forward_resnet50(self, keras, layout='NCHW'):
keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model)
verify_keras_frontend(keras_model, layout=layout)
def test_forward_mobilenet(self, keras):
def test_forward_mobilenet(self, keras, layout='NCHW'):
keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model)
verify_keras_frontend(keras_model, layout=layout)
if __name__ == '__main__':
......@@ -402,6 +414,9 @@ if __name__ == '__main__':
sut.test_forward_reuse_layers(keras=k)
sut.test_forward_rnn(keras=k)
sut.test_forward_vgg16(keras=k)
sut.test_forward_vgg16(keras=k, layout='NHWC')
sut.test_forward_xception(keras=k)
sut.test_forward_resnet50(keras=k)
sut.test_forward_resnet50(keras=k, layout='NHWC')
sut.test_forward_mobilenet(keras=k)
sut.test_forward_mobilenet(keras=k, layout='NHWC')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment