# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
import keras

try:
    import tensorflow.compat.v1 as tf
except ImportError:
    import tensorflow as tf

from tensorflow import keras as tf_keras
from packaging import version as package_version
# Prevent Keras/TensorFlow from reserving all GPU memory at start-up, so that
# TVM (which runs on the same GPUs in these tests) still has memory to use.
if tf.executing_eagerly():
    # Eager mode: let each visible GPU grow its allocation on demand.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
else:
    # Graph mode: install a session for the classic keras backend whose
    # per-process GPU memory fraction is capped at 0.5.
    from keras.backend.tensorflow_backend import set_session
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    set_session(tf.Session(config=config))


def pytest_generate_tests(metafunc):
    """Parametrize every test of a class over that class's ``scenarios``.

    This follows the pytest "testscenarios" example
    (https://docs.pytest.org/en/latest/example/parametrize.html). Each
    scenario is a ``(test_id, {argname: value, ...})`` pair. The argument
    names are taken from the first scenario, and every scenario's values are
    looked up by those names, so differing dict orders cannot misalign them.

    Test functions that are not inside a class, or whose class defines no (or
    an empty) ``scenarios`` attribute, are left unparametrized. Previously
    they raised AttributeError or NameError during collection.
    """
    # getattr(None, ...) falls back to the default, covering module-level tests.
    scenarios = getattr(metafunc.cls, "scenarios", None)
    if not scenarios:
        return
    idlist = [scenario_id for scenario_id, _ in scenarios]
    argnames = list(scenarios[0][1])
    argvalues = [[params[name] for name in argnames] for _, params in scenarios]
    metafunc.parametrize(argnames, argvalues, ids=idlist, scope="class")


# Test scenarios, each a (test id, {argument name: value}) pair that
# TestKeras.scenarios hands to pytest_generate_tests:
# - "keras":    the standalone package, imported via "import keras"
# - "tf_keras": the bundled one, "from tensorflow import keras as tf_keras"
using_classic_keras = ("keras", dict(keras=keras))
using_tensorflow_keras = ("tf_keras", dict(keras=tf_keras))


def verify_keras_frontend(keras_model, need_transpose=True, layout='NCHW'):
    """Compile ``keras_model`` with the Relay Keras frontend and compare outputs.

    Random inputs drawn uniformly from [-1, 1) are run through
    ``keras_model.predict`` and through the TVM graph runtime for every
    ``(target, ctx)`` pair from ``ctx_list()``. Every output must match to
    ``rtol=atol=1e-5``, and both sides must produce the same number of outputs.

    Parameters
    ----------
    keras_model : keras model
        Model under test. Input dimensions that are ``None`` are replaced by 1.
    need_transpose : bool
        If true, inputs are moved from channels-last to channels-first before
        being fed to TVM, and TVM outputs are moved back before comparison.
        Forced to False when ``layout`` is not ``'NCHW'``.
    layout : str
        Layout passed to ``relay.frontend.from_keras``.
    """
    # Keras frontend currently supports tensorflow backend only.
    assert keras.backend.backend() == 'tensorflow'

    if layout != 'NCHW':
        need_transpose = False

    in_shapes = []
    for layer in keras_model._input_layers:
        if tf.executing_eagerly():
            # Eager mode: the shape's entries are used directly.
            dims = list(layer.input.shape)
        else:
            # Graph mode: each entry is a Dimension whose .value holds the size.
            dims = [dim.value for dim in layer.input.shape]
        in_shapes.append(tuple(dim if dim is not None else 1 for dim in dims))

    def get_keras_output(xs):
        return keras_model.predict(xs)

    def get_tvm_output(xs, target, ctx, dtype='float32'):
        shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
        mod, params = relay.frontend.from_keras(keras_model, shape_dict, layout=layout)
        with relay.transform.build_config(opt_level=2):
            graph, lib, params = relay.build(mod, target, params=params)
        m = graph_runtime.create(graph, lib, ctx)
        for name, x in zip(keras_model.input_names, xs):
            m.set_input(name, tvm.nd.array(x.astype(dtype)))
        m.set_input(**params)
        m.run()
        return [m.get_output(i).asnumpy() for i in range(m.get_num_outputs())]

    def to_channels_first(arr):
        # (N, ..., C) -> (N, C, ...)
        return arr.transpose([0, -1] + list(range(1, arr.ndim - 1)))

    def to_channels_last(arr):
        # (N, C, ...) -> (N, ..., C)
        return arr.transpose([0] + list(range(2, arr.ndim)) + [1])

    xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
    keras_out = get_keras_output(xs)
    keras_out = keras_out if isinstance(keras_out, list) else [keras_out]
    # The TVM inputs do not depend on the target, so build them once.
    inputs = [to_channels_first(x) for x in xs] if need_transpose else xs
    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(inputs, target, ctx)
        # zip() would silently skip missing or extra outputs; fail loudly instead.
        assert len(tvm_out) == len(keras_out), \
            "output count mismatch on %s: keras=%d, tvm=%d" % (
                target, len(keras_out), len(tvm_out))
        for kout, tout in zip(keras_out, tvm_out):
            if need_transpose:
                tout = to_channels_last(tout)
            tvm.testing.assert_allclose(kout, tout, rtol=1e-5, atol=1e-5)


class TestKeras:
    """Relay Keras-frontend tests.

    Every test takes a ``keras`` module argument. ``pytest_generate_tests``
    runs each test once per entry in ``scenarios``, covering both the classic
    ``keras`` package and ``tf.keras``.
    """
    scenarios = [using_classic_keras, using_tensorflow_keras]

    def test_forward_merge(self, keras):
        data = keras.layers.Input(shape=(32, 32, 3))
        x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
        y = keras.layers.Conv2D(8, (3, 3), padding="same")(x)
        z = keras.layers.Conv2D(8, (3, 3), padding="same")(y)
        merge_funcs = [keras.layers.Add(),
                       keras.layers.Subtract(),
                       keras.layers.Multiply(),
                       keras.layers.Maximum(),
                       keras.layers.Average(),
                       keras.layers.Concatenate()]
        for merge_func in merge_funcs:
            class_name = type(merge_func).__name__
            # Subtract (and Dot) are fed exactly two inputs; the rest get three.
            if class_name in ('Subtract', 'Dot'):
                out = merge_func([x, y])
            else:
                out = merge_func([x, y, z])
            keras_model = keras.models.Model(data, out)
            verify_keras_frontend(keras_model)

    def test_forward_merge_dot(self, keras):
        data1 = keras.layers.Input(shape=(2, 2))
        data2 = keras.layers.Input(shape=(2, 2))
        merge_funcs = [keras.layers.Dot(axes=[1, 2]),
                       keras.layers.Dot(axes=[2, 1]),
                       keras.layers.Dot(axes=[1, 1]),
                       keras.layers.Dot(axes=[2, 2]),
                       keras.layers.Dot(axes=1),
                       keras.layers.Dot(axes=2)]
        for merge_func in merge_funcs:
            out = merge_func([data1, data2])
            keras_model = keras.models.Model([data1, data2], out)
            verify_keras_frontend(keras_model)

    def test_forward_activations(self, keras):
        data = keras.layers.Input(shape=(32, 32, 3))
        act_funcs = [keras.layers.Activation('softmax'),
                     keras.layers.Softmax(),
                     keras.layers.Softmax(axis=-1),
                     keras.layers.Softmax(axis=1),
                     keras.layers.Softmax(axis=2),
                     keras.layers.Softmax(axis=3),
                     keras.layers.Activation('softplus'),
                     keras.layers.Activation('relu'),
                     keras.layers.Activation('softsign'),
                     keras.layers.Activation('hard_sigmoid'),
                     keras.layers.Activation('sigmoid'),
                     keras.layers.Activation('tanh'),
                     keras.layers.Activation('linear'),
                     keras.layers.Activation('selu'),
                     keras.layers.ReLU(),
                     keras.layers.ReLU(max_value=6.),
                     keras.layers.ReLU(max_value=6., threshold=0.),
                     keras.layers.ReLU(max_value=6., threshold=1.),
                     keras.layers.ReLU(max_value=6., threshold=1., negative_slope=0.),
                     keras.layers.ReLU(max_value=6., threshold=1., negative_slope=0.5),
                     keras.layers.ReLU(max_value=6., threshold=1., negative_slope=1.),
                     keras.layers.LeakyReLU(alpha=0.3),
                     keras.layers.PReLU(weights=np.random.rand(1, 32, 32, 3)),
                     keras.layers.ELU(alpha=0.5),
                     keras.layers.ThresholdedReLU(theta=0.5)]
        for act_func in act_funcs:
            x = act_func(data)
            keras_model = keras.models.Model(data, x)
            verify_keras_frontend(keras_model)

    def test_forward_dense(self, keras):
        data = keras.layers.Input(shape=(32, 32, 1))
        x = keras.layers.Flatten()(data)
        x = keras.layers.Dropout(0.5)(x)
        x = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')(x)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model)

    def test_forward_permute(self, keras):
        data = keras.layers.Input(shape=(2, 3, 4))
        x = keras.layers.Permute([2, 3, 1])(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model, need_transpose=False)

    def test_forward_sequential(self, keras):
        keras_model = keras.models.Sequential([
            keras.layers.Dense(16, input_dim=32, activation='relu'),
            keras.layers.Dropout(0.5),
            keras.layers.Dense(8, activation='relu'),
            keras.layers.Dropout(0.5),
            keras.layers.Dense(1, activation='sigmoid')
        ])
        verify_keras_frontend(keras_model)

    def test_forward_pool(self, keras):
        data = keras.layers.Input(shape=(32, 32, 1))
        # maxpool
        x = keras.layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model)
        # avgpool
        y = keras.layers.AveragePooling2D((3, 3), strides=(1, 1), padding='same')(data)
        keras_model = keras.models.Model(data, y)
        verify_keras_frontend(keras_model)

    def test_forward_conv(self, keras):
        data = keras.layers.Input(shape=(32, 32, 3))
        conv_funcs = [keras.layers.Conv2D(filters=10, kernel_size=(3, 3),
                                          strides=(2, 2), padding='same'),
                      keras.layers.Conv2D(filters=10, kernel_size=(3, 3),
                                          dilation_rate=(2, 2), padding='same'),
                      keras.layers.Conv2D(filters=1, kernel_size=(3, 3), padding='same'),
                      keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
                      keras.layers.Conv2DTranspose(filters=10, kernel_size=(3, 3), padding='valid'),
                      keras.layers.SeparableConv2D(filters=10, kernel_size=(3, 3), padding='same')]
        for conv_func in conv_funcs:
            x = conv_func(data)
            keras_model = keras.models.Model(data, x)
            verify_keras_frontend(keras_model)

    def test_forward_batch_norm(self, keras):
        data = keras.layers.Input(shape=(32, 32, 3))
        # Every (center, scale) combination, in the same order as before.
        batch_norm_funcs = [
            keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
                                            center=center, scale=scale,
                                            beta_initializer='zeros',
                                            gamma_initializer='ones',
                                            moving_mean_initializer='zeros',
                                            moving_variance_initializer='ones')
            for center, scale in [(True, False), (True, True), (False, True), (False, False)]]
        for batch_norm_func in batch_norm_funcs:
            x = batch_norm_func(data)
            keras_model = keras.models.Model(data, x)
            # Verify every configuration. Previously this call sat outside the
            # loop, so only the last one was checked.
            verify_keras_frontend(keras_model)

    def test_forward_upsample(self, keras, interpolation='nearest'):
        data = keras.layers.Input(shape=(32, 32, 3))
        x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model)

    def test_forward_reshape(self, keras):
        # input_shape len is 3, target_shape len is 3
        data = keras.layers.Input(shape=(32, 32, 3))
        x = keras.layers.Reshape(target_shape=(16, 64, 3))(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model)
        # input_shape len is 3, target_shape len is 2
        data = keras.layers.Input(shape=(32, 8, 3))
        x = keras.layers.Reshape(target_shape=(256, 3))(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model)
        # input_shape len is 2, target_shape len is 3
        data = keras.layers.Input(shape=(256, 3))
        x = keras.layers.Reshape(target_shape=(8, 32, 3))(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model)
        # input_shape len is 2, target_shape len is 1
        data = keras.layers.Input(shape=(2, 8))
        x = keras.layers.Reshape(target_shape=(16,))(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model, need_transpose=False)
        # input_shape len is 1, target_shape len is 2
        data = keras.layers.Input(shape=(16,))
        x = keras.layers.Reshape(target_shape=(4, 4))(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model, need_transpose=False)
        # input_shape len is 2, target_shape len is 2
        data = keras.layers.Input(shape=(2, 8))
        x = keras.layers.Reshape(target_shape=(4, 4))(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model, need_transpose=False)

    def test_forward_crop(self, keras):
        data = keras.layers.Input(shape=(32, 32, 3))
        x = keras.layers.Cropping2D(cropping=((1, 1), (1, 1)))(data)
        x = keras.layers.Cropping2D(cropping=(1, 1))(x)
        x = keras.layers.Cropping2D(cropping=1)(x)
        x = keras.layers.Cropping2D(cropping=((0, 1), (1, 0)))(x)
        x = keras.layers.Cropping2D(cropping=(1, 0))(x)
        x = keras.layers.Cropping2D(cropping=0)(x)
        x = keras.layers.Add()([x, x])
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model)

    def test_forward_multi_inputs(self, keras):
        data1 = keras.layers.Input(shape=(32, 32, 3))
        data2 = keras.layers.Input(shape=(32, 32, 3))
        x = keras.layers.Conv2D(8, (3, 3), padding="same")(data1)
        y = keras.layers.Conv2D(8, (3, 3), padding="same")(data2)
        z = keras.layers.Average()([x, y])
        z = keras.layers.GlobalAveragePooling2D()(z)
        keras_model = keras.models.Model([data1, data2], z)
        verify_keras_frontend(keras_model)

    def test_forward_multi_outputs(self, keras):
        data = keras.layers.Input(shape=(32, 32, 3))
        x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
        x = keras.layers.GlobalAveragePooling2D()(x)
        y = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
        y = keras.layers.GlobalAveragePooling2D()(y)
        keras_model = keras.models.Model(data, [x, y])
        verify_keras_frontend(keras_model)

    def test_forward_reuse_layers(self, keras):
        # reuse conv2d
        data = keras.layers.Input(shape=(32, 32, 3))
        conv2d = keras.layers.Conv2D(8, (3, 3), padding="same")
        x = conv2d(data)
        y = conv2d(data)
        z = keras.layers.Add()([x, y])
        z = keras.layers.GlobalAveragePooling2D()(z)
        keras_model = keras.models.Model(data, z)
        verify_keras_frontend(keras_model)
        # reuse add
        data = keras.layers.Input(shape=(32, 32, 3))
        x = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
        add = keras.layers.Add()
        x = add([x, x])
        x = add([x, x])
        z = keras.layers.GlobalAveragePooling2D()(x)
        keras_model = keras.models.Model(data, z)
        verify_keras_frontend(keras_model)

    def test_forward_rnn(self, keras):
        data = keras.layers.Input(shape=(1, 32))
        rnn_funcs = [keras.layers.LSTM(units=16, return_state=False,
                                       recurrent_activation='sigmoid', activation='tanh'),
                     keras.layers.SimpleRNN(units=16, return_state=False,
                                            activation='tanh'),
                     keras.layers.GRU(units=16, return_state=False,
                                      recurrent_activation='sigmoid', activation='tanh',
                                      reset_after=False)]
        for rnn_func in rnn_funcs:
            x = rnn_func(data)
            keras_model = keras.models.Model(data, x)
            verify_keras_frontend(keras_model, need_transpose=False)

    def test_forward_vgg16(self, keras, layout='NCHW'):
        keras_model = keras.applications.VGG16(include_top=True, weights='imagenet',
                                               input_shape=(224, 224, 3), classes=1000)
        verify_keras_frontend(keras_model, layout=layout)

    def test_forward_xception(self, keras, layout='NCHW'):
        keras_model = keras.applications.Xception(include_top=True, weights='imagenet',
                                                  input_shape=(299, 299, 3), classes=1000)
        verify_keras_frontend(keras_model, layout=layout)

    def test_forward_resnet50(self, keras, layout='NCHW'):
        keras_model = keras.applications.ResNet50(include_top=True, weights='imagenet',
                                                  input_shape=(224, 224, 3), classes=1000)
        verify_keras_frontend(keras_model, layout=layout)

    def test_forward_mobilenet(self, keras, layout='NCHW'):
        keras_model = keras.applications.MobileNet(include_top=True, weights='imagenet',
                                                   input_shape=(224, 224, 3), classes=1000)
        verify_keras_frontend(keras_model, layout=layout)

    def test_forward_conv3d(self, keras):
        data = keras.layers.Input(shape=(32, 32, 32, 3))
        conv_funcs = [keras.layers.Conv3D(filters=10,
                                          kernel_size=(3, 3, 3),
                                          strides=(2, 2, 2),
                                          padding='same'),
                      keras.layers.Conv3D(filters=10,
                                          kernel_size=(3, 3, 3),
                                          dilation_rate=(2, 2, 2),
                                          padding='same'),
                      keras.layers.Conv3D(filters=1,
                                          kernel_size=(3, 3, 3),
                                          padding='valid',
                                          use_bias=False),
                      keras.layers.Conv3D(filters=10,
                                          kernel_size=(2, 2, 2),
                                          padding='valid'),
                      ]
        for conv_func in conv_funcs:
            x = conv_func(data)
            keras_model = keras.models.Model(data, x)
            verify_keras_frontend(keras_model, layout='NDHWC')

    def test_forward_pool3d(self, keras):
        data = keras.layers.Input(shape=(32, 32, 32, 1))
        pool_funcs = [  # maxpool
                      keras.layers.MaxPooling3D(pool_size=(2, 2, 2),
                                                strides=(1, 1, 1),
                                                padding='same'),
                      keras.layers.MaxPooling3D(pool_size=(3, 3, 3),
                                                strides=(2, 2, 2),
                                                padding='valid'),
                      # avgpool
                      keras.layers.AveragePooling3D(pool_size=(3, 3, 3),
                                                    strides=(2, 2, 2),
                                                    padding='same'),
                      keras.layers.AveragePooling3D(pool_size=(2, 2, 2),
                                                    strides=(1, 1, 1),
                                                    padding='valid'),
                      ]
        for pool_func in pool_funcs:
            x = pool_func(data)
            keras_model = keras.models.Model(data, x)
            verify_keras_frontend(keras_model, layout='NDHWC')

    def test_forward_upsample3d(self, keras):
        data = keras.layers.Input(shape=(32, 32, 32, 3))
        x = keras.layers.UpSampling3D(size=(2, 3, 4))(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model, layout='NDHWC')

    def test_forward_zero_padding3d(self, keras):
        data = keras.layers.Input(shape=(32, 32, 32, 3))
        pad_funcs = [  # Integer
                     keras.layers.ZeroPadding3D(padding=2),
                     # tuple of 3 ints
                     keras.layers.ZeroPadding3D(padding=(1, 2, 3)),
                     # tuple of 3 tuples of 2 ints
                     keras.layers.ZeroPadding3D(padding=((1, 1), (2, 2), (2, 2))),
                     # tuple of 3 tuples of 2 ints different values
                     keras.layers.ZeroPadding3D(padding=((1, 2), (2, 3), (3, 2))),
                     ]
        for pad_func in pad_funcs:
            x = pad_func(data)
            keras_model = keras.models.Model(data, x)
            verify_keras_frontend(keras_model, layout='NDHWC')

if __name__ == '__main__':
    # Run the whole suite without pytest, once for the classic keras package
    # and once for tf.keras. The NHWC variants of the application models are
    # only exercised here, since pytest uses the default layout.
    for k in [keras, tf_keras]:
        sut = TestKeras()
        sut.test_forward_merge_dot(keras=k)
        sut.test_forward_merge(keras=k)
        sut.test_forward_activations(keras=k)
        sut.test_forward_dense(keras=k)
        sut.test_forward_permute(keras=k)
        sut.test_forward_sequential(keras=k)
        sut.test_forward_pool(keras=k)
        sut.test_forward_conv(keras=k)
        sut.test_forward_batch_norm(keras=k)
        sut.test_forward_upsample(keras=k, interpolation='nearest')
        sut.test_forward_upsample(keras=k, interpolation='bilinear')
        sut.test_forward_reshape(keras=k)
        sut.test_forward_crop(keras=k)
        sut.test_forward_multi_inputs(keras=k)
        sut.test_forward_multi_outputs(keras=k)
        sut.test_forward_reuse_layers(keras=k)
        sut.test_forward_rnn(keras=k)
        sut.test_forward_vgg16(keras=k)
        sut.test_forward_vgg16(keras=k, layout='NHWC')
        sut.test_forward_xception(keras=k)
        sut.test_forward_resnet50(keras=k)
        sut.test_forward_resnet50(keras=k, layout='NHWC')
        sut.test_forward_mobilenet(keras=k)
        sut.test_forward_mobilenet(keras=k, layout='NHWC')
        sut.test_forward_conv3d(keras=k)
        sut.test_forward_pool3d(keras=k)
        sut.test_forward_upsample3d(keras=k)
        sut.test_forward_zero_padding3d(keras=k)