Unverified Commit 9d646543 by Josh Fromm Committed by GitHub

[Relay][Frontend][Keras] NHWC import support. (#4899)

* Basic test working

* Almost all tests working.

* all tests passing.

* Fixed lint.

* Improved Style.
parent d1e1ac49
...@@ -21,13 +21,18 @@ from tvm.contrib import graph_runtime ...@@ -21,13 +21,18 @@ from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list from tvm.relay.testing.config import ctx_list
import keras import keras
# prevent Keras from using up all gpu memory
import tensorflow as tf import tensorflow as tf
from tensorflow import keras as tf_keras from tensorflow import keras as tf_keras
from keras.backend.tensorflow_backend import set_session # prevent Keras from using up all gpu memory
config = tf.ConfigProto() if tf.executing_eagerly():
config.gpu_options.per_process_gpu_memory_fraction = 0.5 gpus = tf.config.list_physical_devices('GPU')
set_session(tf.Session(config=config)) for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
else:
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.5
set_session(tf.Session(config=config))
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
...@@ -52,20 +57,27 @@ using_classic_keras = ("keras", {"keras": keras}) ...@@ -52,20 +57,27 @@ using_classic_keras = ("keras", {"keras": keras})
using_tensorflow_keras = ("tf_keras", {"keras": tf_keras}) using_tensorflow_keras = ("tf_keras", {"keras": tf_keras})
def verify_keras_frontend(keras_model, need_transpose=True): def verify_keras_frontend(keras_model, need_transpose=True, layout='NCHW'):
# Keras frontend currently supports tensorflow backend only. # Keras frontend currently supports tensorflow backend only.
assert(keras.backend.backend() == 'tensorflow') assert(keras.backend.backend() == 'tensorflow')
if layout != 'NCHW':
need_transpose = False
in_shapes = [] in_shapes = []
for layer in keras_model._input_layers: for layer in keras_model._input_layers:
if tf.executing_eagerly():
in_shapes.append(tuple(dim if dim is not None else 1 for dim in layer.input.shape))
else:
in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape)) in_shapes.append(tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape))
def get_keras_output(xs, dtype='float32'): def get_keras_output(xs, dtype='float32'):
return keras_model.predict(xs) return keras_model.predict(xs)
def get_tvm_output(xs, target, ctx, dtype='float32'): def get_tvm_output(xs, target, ctx, dtype='float32'):
shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)} shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
mod, params = relay.frontend.from_keras(keras_model, shape_dict) mod, params = relay.frontend.from_keras(keras_model, shape_dict, layout=layout)
with relay.transform.build_config(opt_level=2): with relay.transform.build_config(opt_level=2):
graph, lib, params = relay.build(mod, graph, lib, params = relay.build(mod,
target, target,
...@@ -357,28 +369,28 @@ class TestKeras: ...@@ -357,28 +369,28 @@ class TestKeras:
verify_keras_frontend(keras_model, need_transpose=False) verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_vgg16(self, keras): def test_forward_vgg16(self, keras, layout='NCHW'):
keras_model = keras.applications.VGG16(include_top=True, weights='imagenet', keras_model = keras.applications.VGG16(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000) input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model, layout=layout)
def test_forward_xception(self, keras): def test_forward_xception(self, keras, layout='NCHW'):
keras_model = keras.applications.Xception(include_top=True, weights='imagenet', keras_model = keras.applications.Xception(include_top=True, weights='imagenet',
input_shape=(299, 299, 3), classes=1000) input_shape=(299, 299, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model, layout=layout)
def test_forward_resnet50(self, keras): def test_forward_resnet50(self, keras, layout='NCHW'):
keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet', keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000) input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model, layout=layout)
def test_forward_mobilenet(self, keras): def test_forward_mobilenet(self, keras, layout='NCHW'):
keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet', keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet',
input_shape=(224, 224, 3), classes=1000) input_shape=(224, 224, 3), classes=1000)
verify_keras_frontend(keras_model) verify_keras_frontend(keras_model, layout=layout)
if __name__ == '__main__': if __name__ == '__main__':
...@@ -402,6 +414,9 @@ if __name__ == '__main__': ...@@ -402,6 +414,9 @@ if __name__ == '__main__':
sut.test_forward_reuse_layers(keras=k) sut.test_forward_reuse_layers(keras=k)
sut.test_forward_rnn(keras=k) sut.test_forward_rnn(keras=k)
sut.test_forward_vgg16(keras=k) sut.test_forward_vgg16(keras=k)
sut.test_forward_vgg16(keras=k, layout='NHWC')
sut.test_forward_xception(keras=k) sut.test_forward_xception(keras=k)
sut.test_forward_resnet50(keras=k) sut.test_forward_resnet50(keras=k)
sut.test_forward_resnet50(keras=k, layout='NHWC')
sut.test_forward_mobilenet(keras=k) sut.test_forward_mobilenet(keras=k)
sut.test_forward_mobilenet(keras=k, layout='NHWC')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment