# test_forward.py: Keras frontend tests for TVM Relay.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
import keras

# prevent keras from using up all gpu memory
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
# Cap TensorFlow's per-process GPU memory at 50% and install that session as
# the Keras backend session, so Keras does not claim the whole device.
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5))
set_session(tf.Session(config=config))


def verify_keras_frontend(keras_model, need_transpose=True):
    """Check that TVM reproduces a Keras model's predictions.

    Random inputs drawn uniformly from [-1, 1) are generated for every model
    input. Dimensions Keras leaves unknown (``None``, e.g. the batch axis) are
    fixed to 1. The inputs are run through ``keras_model.predict``. The model
    is then converted with ``relay.frontend.from_keras``, built at
    ``opt_level=2`` for every ``(target, ctx)`` in ``ctx_list()``, and
    executed. Every output must match the Keras output within
    ``rtol=atol=1e-5``.

    Parameters
    ----------
    keras_model : keras.models.Model
        Model under test. Only the tensorflow Keras backend is supported.
    need_transpose : bool
        If True, each input is moved from channels-last (Keras) to
        channels-first before it is fed to TVM. Each TVM output is moved back
        to channels-last before comparison. The RNN tests pass False.

    Raises
    ------
    AssertionError
        If the backend is not tensorflow, if TVM and Keras disagree on the
        number of outputs, or if any output differs beyond the tolerance.
    """
    # Keras frontend currently supports tensorflow backend only.
    assert keras.backend.backend() == 'tensorflow'

    # Concrete input shapes; unknown (None) dims such as the batch axis -> 1.
    in_shapes = []
    for layer in keras_model._input_layers:
        in_shapes.append(tuple(dim.value if dim.value is not None else 1
                               for dim in layer.input.shape))

    def get_keras_output(xs):
        return keras_model.predict(xs)

    def get_tvm_output(xs, target, ctx, dtype='float32'):
        # Convert, compile and run the model on one target; return all outputs.
        shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
        func, params = relay.frontend.from_keras(keras_model, shape_dict)
        with relay.build_module.build_config(opt_level=2):
            graph, lib, params = relay.build(func, target, params=params)
        m = graph_runtime.create(graph, lib, ctx)
        for name, x in zip(keras_model.input_names, xs):
            m.set_input(name, tvm.nd.array(x.astype(dtype)))
        m.set_input(**params)
        m.run()
        return [m.get_output(i).asnumpy() for i in range(m.get_num_outputs())]

    def to_channels_first(arr):
        # (N, d1, ..., dk, C) -> (N, C, d1, ..., dk)
        return arr.transpose([0, -1] + list(range(1, arr.ndim - 1)))

    def to_channels_last(arr):
        # (N, C, d1, ..., dk) -> (N, d1, ..., dk, C)
        return arr.transpose([0] + list(range(2, arr.ndim)) + [1])

    xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
    keras_out = get_keras_output(xs)
    # predict() returns a bare array for single-output models.
    keras_out = keras_out if isinstance(keras_out, list) else [keras_out]
    # The TVM-side inputs do not depend on the target, so prepare them once.
    inputs = [to_channels_first(x) for x in xs] if need_transpose else xs
    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(inputs, target, ctx)
        # zip() below would silently ignore unmatched outputs, letting a
        # missing TVM output pass unnoticed; require equal counts first.
        assert len(tvm_out) == len(keras_out), (
            "TVM produced %d output(s) on target %s, Keras produced %d"
            % (len(tvm_out), target, len(keras_out)))
        for kout, tout in zip(keras_out, tvm_out):
            if need_transpose:
                tout = to_channels_last(tout)
            tvm.testing.assert_allclose(kout, tout, rtol=1e-5, atol=1e-5)


def test_forward_merge():
    """Check each Keras merge layer on the outputs of three chained Conv2Ds.

    Subtract is applied to the first two conv outputs only, because Keras'
    Subtract combines exactly two tensors. Every other merge layer is applied
    to all three.
    """
    data = keras.layers.Input(shape=(32, 32, 3))
    conv_outputs = []
    previous = data
    for _ in range(3):
        previous = keras.layers.Conv2D(8, (3, 3), padding="same")(previous)
        conv_outputs.append(previous)
    x, y, z = conv_outputs

    merge_layers = (keras.layers.Add(),
                    keras.layers.Subtract(),
                    keras.layers.Multiply(),
                    keras.layers.Maximum(),
                    keras.layers.Average(),
                    keras.layers.Concatenate())
    for layer in merge_layers:
        is_binary = isinstance(layer, keras.layers.merge.Subtract)
        operands = [x, y] if is_binary else [x, y, z]
        verify_keras_frontend(keras.models.Model(data, layer(operands)))


def test_forward_activations():
    """Check every supported activation layer applied to a (32, 32, 3) input."""
    data = keras.layers.Input(shape=(32, 32, 3))
    activation_names = ('softmax', 'softplus', 'relu', 'softsign',
                        'hard_sigmoid', 'sigmoid', 'tanh', 'linear', 'selu')
    act_layers = [keras.layers.Activation(name) for name in activation_names]
    act_layers.extend([
        keras.layers.ReLU(),
        keras.layers.ReLU(max_value=6.),
        keras.layers.LeakyReLU(alpha=0.3),
        # NOTE(review): the leading axis of length 1 presumably makes this a
        # one-element sequence holding a single (32, 32, 3) alpha array.
        keras.layers.PReLU(weights=np.random.rand(1, 32, 32, 3)),
        keras.layers.ELU(alpha=0.5),
        keras.layers.ThresholdedReLU(theta=0.5),
    ])
    for act_layer in act_layers:
        verify_keras_frontend(keras.models.Model(data, act_layer(data)))


def test_forward_dense():
    """Check Flatten -> Dropout(0.5) -> Dense(10, relu) on a (32, 32, 1) input."""
    data = keras.layers.Input(shape=(32, 32, 1))
    stack = [keras.layers.Flatten(),
             keras.layers.Dropout(0.5),
             keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')]
    out = data
    for layer in stack:
        out = layer(out)
    verify_keras_frontend(keras.models.Model(data, out))


def test_forward_sequential():
    """Check a Sequential Dense/Dropout stack on a 32-feature input."""
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(16, input_dim=32, activation='relu'))
    model.add(keras.layers.Dropout(0.5))
    model.add(keras.layers.Dense(8, activation='relu'))
    model.add(keras.layers.Dropout(0.5))
    model.add(keras.layers.Dense(1, activation='sigmoid'))
    verify_keras_frontend(model)


def test_forward_pool():
    """Check 3x3, stride-1, 'same'-padded max pooling and then average pooling."""
    data = keras.layers.Input(shape=(32, 32, 1))
    for pool_cls in (keras.layers.MaxPooling2D, keras.layers.AveragePooling2D):
        pooled = pool_cls((3, 3), strides=(1, 1), padding='same')(data)
        verify_keras_frontend(keras.models.Model(data, pooled))


def test_forward_conv():
    """Check strided, dilated, depthwise, transposed and separable 3x3 convs."""
    data = keras.layers.Input(shape=(32, 32, 3))
    kernel = (3, 3)
    conv_layers = [
        keras.layers.Conv2D(filters=10, kernel_size=kernel,
                            strides=(2, 2), padding='same'),
        keras.layers.Conv2D(filters=10, kernel_size=kernel,
                            dilation_rate=(2, 2), padding='same'),
        keras.layers.DepthwiseConv2D(kernel_size=kernel, padding='same'),
        keras.layers.Conv2DTranspose(filters=10, kernel_size=kernel, padding='valid'),
        keras.layers.SeparableConv2D(filters=10, kernel_size=kernel, padding='same'),
    ]
    for conv in conv_layers:
        model = keras.models.Model(data, conv(data))
        verify_keras_frontend(model)


def test_forward_upsample(interpolation='nearest'):
    """Check UpSampling2D with a 3x scale factor on a (32, 32, 3) input.

    Parameters
    ----------
    interpolation : str
        Forwarded to ``keras.layers.UpSampling2D``. The ``__main__`` runner
        calls this with 'nearest' and with 'bilinear'. NOTE: pytest does not
        inject defaulted arguments, so under pytest only 'nearest' runs.
    """
    data = keras.layers.Input(shape=(32, 32, 3))
    x = keras.layers.UpSampling2D(size=(3, 3), interpolation=interpolation)(data)
    keras_model = keras.models.Model(data, x)
    verify_keras_frontend(keras_model)


def test_forward_reshape():
    """Check a Reshape whose target shape equals the input shape (32, 32, 3)."""
    shape = (32, 32, 3)
    data = keras.layers.Input(shape=shape)
    reshaped = keras.layers.Reshape(target_shape=shape)(data)
    verify_keras_frontend(keras.models.Model(data, reshaped))


def test_forward_crop():
    """Check chained Cropping2D layers, then an Add of the result with itself.

    The chain uses the nested-pair, pair and scalar forms of ``cropping``,
    including asymmetric and zero crops.
    """
    data = keras.layers.Input(shape=(32, 32, 3))
    croppings = [((1, 1), (1, 1)), (1, 1), 1, ((0, 1), (1, 0)), (1, 0), 0]
    cropped = data
    for cropping in croppings:
        cropped = keras.layers.Cropping2D(cropping=cropping)(cropped)
    doubled = keras.layers.Add()([cropped, cropped])
    verify_keras_frontend(keras.models.Model(data, doubled))


def test_forward_multi_inputs():
    """Check a two-input model that averages per-input conv features and pools."""
    inputs = [keras.layers.Input(shape=(32, 32, 3)) for _ in range(2)]
    features = [keras.layers.Conv2D(8, (3, 3), padding="same")(inp) for inp in inputs]
    averaged = keras.layers.Average()(features)
    pooled = keras.layers.GlobalAveragePooling2D()(averaged)
    verify_keras_frontend(keras.models.Model(inputs, pooled))


def test_forward_multi_outputs():
    """Check a model with two conv + global-average-pool branches as outputs."""
    data = keras.layers.Input(shape=(32, 32, 3))
    outputs = []
    for _ in range(2):
        branch = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
        outputs.append(keras.layers.GlobalAveragePooling2D()(branch))
    verify_keras_frontend(keras.models.Model(data, outputs))


def test_forward_reuse_layers():
    """Check models in which one layer instance is called more than once.

    Case 1 applies a single Conv2D twice to the same input and adds the
    results. Case 2 applies a single Add twice in sequence.
    """
    # Case 1: shared Conv2D.
    data = keras.layers.Input(shape=(32, 32, 3))
    shared_conv = keras.layers.Conv2D(8, (3, 3), padding="same")
    first = shared_conv(data)
    second = shared_conv(data)
    summed = keras.layers.Add()([first, second])
    pooled = keras.layers.GlobalAveragePooling2D()(summed)
    verify_keras_frontend(keras.models.Model(data, pooled))

    # Case 2: shared Add.
    data = keras.layers.Input(shape=(32, 32, 3))
    features = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
    shared_add = keras.layers.Add()
    for _ in range(2):
        features = shared_add([features, features])
    pooled = keras.layers.GlobalAveragePooling2D()(features)
    verify_keras_frontend(keras.models.Model(data, pooled))


def test_forward_rnn():
    """Check LSTM, SimpleRNN and GRU (16 units) on a (1, 32) sequence input."""
    data = keras.layers.Input(shape=(1, 32))
    gated_kwargs = dict(recurrent_activation='sigmoid', activation='tanh')
    rnn_layers = [
        keras.layers.LSTM(units=16, return_state=False, **gated_kwargs),
        keras.layers.SimpleRNN(units=16, return_state=False, activation='tanh'),
        keras.layers.GRU(units=16, return_state=False, **gated_kwargs),
    ]
    for rnn in rnn_layers:
        model = keras.models.Model(data, rnn(data))
        # Feed and compare the tensors as-is (no channels-first transposes).
        verify_keras_frontend(model, need_transpose=False)


def test_forward_vgg16():
    """Check keras.applications.VGG16 with ImageNet weights, 224x224x3 input, 1000 classes."""
    model = keras.applications.VGG16(
        include_top=True,
        weights='imagenet',
        input_shape=(224, 224, 3),
        classes=1000,
    )
    verify_keras_frontend(model)


def test_forward_xception():
    """Check keras.applications.Xception with ImageNet weights, 299x299x3 input, 1000 classes."""
    model = keras.applications.Xception(
        include_top=True,
        weights='imagenet',
        input_shape=(299, 299, 3),
        classes=1000,
    )
    verify_keras_frontend(model)


def test_forward_resnet50():
    """Check keras.applications.ResNet50 with ImageNet weights, 224x224x3 input, 1000 classes."""
    model = keras.applications.ResNet50(
        include_top=True,
        weights='imagenet',
        input_shape=(224, 224, 3),
        classes=1000,
    )
    verify_keras_frontend(model)


def test_forward_mobilenet():
    """Check keras.applications.MobileNet with ImageNet weights, 224x224x3 input, 1000 classes."""
    model = keras.applications.MobileNet(
        include_top=True,
        weights='imagenet',
        input_shape=(224, 224, 3),
        classes=1000,
    )
    verify_keras_frontend(model)


if __name__ == '__main__':
    # Run every test directly, without pytest. This runner also exercises the
    # 'bilinear' UpSampling2D variant, which the default argument skips.
    test_forward_merge()
    test_forward_activations()
    test_forward_dense()
    test_forward_sequential()
    test_forward_pool()
    test_forward_conv()
    test_forward_upsample(interpolation='nearest')
    test_forward_upsample(interpolation='bilinear')
    test_forward_reshape()
    test_forward_crop()
    test_forward_multi_inputs()
    test_forward_multi_outputs()
    test_forward_reuse_layers()
    test_forward_rnn()
    test_forward_vgg16()
    test_forward_xception()
    test_forward_resnet50()
    test_forward_mobilenet()