# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, import-self
"""Keras frontend."""
from __future__ import absolute_import as _abs
import sys
import numpy as np
import tvm
from .. import analysis
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable, new_var

__all__ = ['from_keras']


def _check_data_format(keras_layer):
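    """Reject layers whose data_format is anything other than channels_last."""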
    if hasattr(keras_layer, 'data_format'):
        if keras_layer.data_format != 'channels_last':
            raise ValueError("Keras frontend currently supports data_format = channels_last only.")


def _get_pad_pair(input1d, kernel1d, stride1d):
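    """Compute [pad_before, pad_after] along one spatial dimension.

    This mirrors TensorFlow's 'SAME' padding rule: the output length is
    ceil(input1d / stride1d), and when the total padding is odd the extra
    element goes on the trailing side.
    """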
    out1d = (input1d + stride1d - 1) // stride1d
    pad = np.maximum((out1d - 1) * stride1d + kernel1d - input1d, 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]


def _get_elu(inexpr, alpha):
    """A helper method for elu."""
    return _op.negative(alpha) * _op.nn.relu(_expr.const(1., dtype='float32') - \
        _op.exp(inexpr)) + _op.nn.relu(inexpr)


def _as_list(arr):
    """Force being a list, ignore if already is."""
    if isinstance(arr, list):
        return arr
    return [arr]


def _convert_recurrent_activation(inexpr, keras_layer):
    act_type = keras_layer.recurrent_activation.__name__
    return _convert_activation(inexpr, act_type, None)


def _convert_activation(inexpr, keras_layer, _):
    if isinstance(keras_layer, str):
        act_type = keras_layer
    else:
        if sys.version_info.major < 3:
            act_type = keras_layer.activation.func_name
        else:
            act_type = keras_layer.activation.__name__
    if act_type == 'linear':
        if isinstance(keras_layer, str):
            return inexpr
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
        beta = keras_layer.beta if hasattr(keras_layer, 'beta') else 0.
        alpha = _expr.const(alpha, dtype='float32')
        beta = _expr.const(beta, dtype='float32')
        return _op.add(_op.multiply(inexpr, alpha), beta)
    if act_type == 'softmax':
        return _op.nn.softmax(inexpr, axis=1)
    if act_type == 'sigmoid':
        return _op.sigmoid(inexpr)
    if act_type == 'tanh':
        return _op.tanh(inexpr)
    if act_type == 'relu':
        return _op.nn.relu(inexpr)
    if act_type == 'softplus':
        return _op.log(_op.add(_op.exp(inexpr), _expr.const(1., dtype='float32')))
    if act_type == 'elu':
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
        alpha = _expr.const(alpha, dtype='float32')
        return _get_elu(inexpr, alpha)
    if act_type == 'selu':
        # Alpha, Gamma values obtained from https://arxiv.org/abs/1706.02515
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') \
            else 1.6732632423543772848170429916717
        gamma = keras_layer.gamma if hasattr(keras_layer, 'gamma') \
            else 1.0507009873554804934193349852946
        alpha = _expr.const(alpha, dtype='float32')
        gamma = _expr.const(gamma, dtype='float32')
        return gamma * _get_elu(inexpr, alpha)
    if act_type == 'relu6':
        return _op.clip(inexpr, a_min=0., a_max=6.)
    if act_type == 'softsign':
        return inexpr / (_expr.const(1., dtype='float32') + _op.abs(inexpr))
    if act_type == 'hard_sigmoid':
        x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32')
        return _op.clip(x, a_min=0., a_max=1.)

    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported in frontend Keras.'.format(act_type))


def _convert_advanced_activation(inexpr, keras_layer, etab):
    act_type = type(keras_layer).__name__

    if act_type == 'Softmax':
        axis = keras_layer.axis
        dims = len(keras_layer.input_shape)
        if isinstance(axis, list):
            raise tvm.error.OpAttributeUnImplemented(
                'Softmax with axes {} is not supported.'.format(axis))
        if axis == -1:
            axis = 1
        else:
            axis = axis + 1 if axis < dims - 1 else 1
        return _op.nn.softmax(inexpr, axis=axis)
    if act_type == 'ReLU':
        threshold = _expr.const(keras_layer.threshold, dtype='float32')
        if keras_layer.max_value and float(keras_layer.threshold) == 0:
            # f(x) = max_value, for x >= max_value
            # f(x) = x,         for threshold <= x < max_value
            return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
        elif keras_layer.max_value and float(keras_layer.threshold) != 0:
            # f(x) = negative_slope * (inexpr - threshold)
            negative_slope = _expr.const(keras_layer.negative_slope, dtype='float32')
            return _op.multiply(negative_slope, _op.subtract(inexpr, threshold))
        return _op.nn.relu(inexpr)
    if act_type == 'LeakyReLU':
        return _op.nn.leaky_relu(inexpr, alpha=float(keras_layer.alpha))
    if act_type == 'ELU':
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
        alpha = _expr.const(alpha, dtype='float32')
        return _get_elu(inexpr, alpha)
    if act_type == 'PReLU':
        assert hasattr(keras_layer, 'alpha'), "alpha required for PReLU."
        _check_data_format(keras_layer)
        size = len(keras_layer.alpha.shape)
        alpha = etab.new_const(keras_layer.get_weights()[0] \
                               .transpose(np.roll(range(size), 1)))
        return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
    if act_type == 'ThresholdedReLU':
        theta = keras_layer.theta if hasattr(keras_layer, 'theta') else 1.
        return _op.multiply(inexpr, _op.greater(inexpr, \
            _expr.const(theta, dtype='float32')).astype('float32'))

    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported in frontend Keras.'.format(act_type))


def _convert_merge(inexpr, keras_layer, _):
    merge_type = type(keras_layer).__name__
    ret = inexpr[0]
    if merge_type == 'Dot':
        axes = keras_layer.axes
        if isinstance(keras_layer.axes, int):
            axes = [keras_layer.axes, keras_layer.axes]
        if isinstance(axes, list):
            if len(axes) != 2:
                raise tvm.error.OpAttributeUnImplemented(
                    'Dot with axes {} is not supported.'.format(keras_layer.axes))
            for i, axis in enumerate(axes):
                if axis not in [1, 2]:
                    raise tvm.error.OpAttributeUnImplemented(
                        'Dot with axes {} is not supported.'.format(keras_layer.axes))
                if axes[i] == 2:
                    inexpr[i] = _op.transpose(inexpr[i], axes=[0, 2, 1])
        else:
            raise tvm.error.OpAttributeUnImplemented(
                'Dot with axes {} is not supported.'.format(keras_layer.axes))
        ret_dot = _op.nn.batch_matmul(inexpr[0], inexpr[1])
        ret = _op.transpose(ret_dot, axes=[0, 2, 1])
    elif merge_type == 'Subtract':
        assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
        ret = _op.subtract(ret, inexpr[1])
    elif merge_type in ['Add', 'Multiply', 'Maximum']:
        op_map = {'Add':_op.add, 'Multiply':_op.multiply, 'Maximum':_op.maximum}
        for i in range(1, len(inexpr)):
            ret = op_map[merge_type](ret, inexpr[i])
    elif merge_type == 'Average':
        for i in range(1, len(inexpr)):
            ret = _op.add(ret, inexpr[i])
        ret = ret / _expr.const(len(inexpr), dtype='float32')
    else:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported in frontend Keras.'.format(merge_type))
    return ret


def _convert_permute(inexpr, keras_layer, _):
    return _op.transpose(inexpr, axes=(0,) + keras_layer.dims)


def _convert_dense(inexpr, keras_layer, etab):
    weightList = keras_layer.get_weights()
    weight = etab.new_const(weightList[0].transpose([1, 0]))
    params = {'weight':weight, 'units':weightList[0].shape[1]}
    input_shape = keras_layer.input_shape
    input_dim = len(input_shape)
    # In case of RNN dense, input shape will be (1, 1, n)
    if input_dim > 2:
        input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
        if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:
            raise tvm.error.OpAttributeInvalid(
                'Input shape {} is not valid for operator Dense.'.format(input_shape))
        inexpr = _op.squeeze(inexpr, axis=0)
    out = _op.nn.dense(data=inexpr, **params)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[1])
        out = _op.nn.bias_add(out, bias)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != 'linear':
        out = _convert_activation(out, act_type, etab)
    if input_dim > 2:
        out = _op.expand_dims(out, axis=0)
    return out


def _convert_convolution(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    is_deconv = type(keras_layer).__name__ == 'Conv2DTranspose'
    is_depthconv = type(keras_layer).__name__ == 'DepthwiseConv2D'
    weightList = keras_layer.get_weights()
    if is_deconv:
        kernel_h, kernel_w, n_filters, in_channels = weightList[0].shape
        weight = weightList[0].transpose([3, 2, 0, 1])
    elif is_depthconv:
        kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
        weight = weightList[0].transpose([2, 3, 0, 1])
    else:
        kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape
        weight = weightList[0].transpose([3, 2, 0, 1])
    if isinstance(keras_layer.dilation_rate, (list, tuple)):
        dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]]
    else:
        dilation = [keras_layer.dilation_rate, keras_layer.dilation_rate]
    dilated_kernel_h = (kernel_h - 1) * dilation[0] + 1
    dilated_kernel_w = (kernel_w - 1) * dilation[1] + 1
    stride_h, stride_w = keras_layer.strides
    params = {'weight': etab.new_const(weight),
              'kernel_size': [kernel_h, kernel_w],
              'strides': [stride_h, stride_w],
              'dilation': dilation,
              'padding': [0, 0]}
    if is_depthconv:
        params['channels'] = in_channels * depth_mult
        params['groups'] = in_channels
    else:
        params['channels'] = n_filters
    if keras_layer.padding == 'valid':
        pass
    # we insert a separate pad operator
    elif keras_layer.padding == 'same':
        in_h = keras_layer.input_shape[1]
        in_w = keras_layer.input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
        if pad_t == pad_b and pad_l == pad_r:
            params['padding'] = (pad_t, pad_l)
        else:
            inexpr = _op.nn.pad(data=inexpr, pad_width=(
                (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
    else:
        msg = 'Padding with {} is not supported for operator Convolution ' \
              'in frontend Keras.'
        raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))
    if is_deconv:
        out = _op.nn.conv2d_transpose(data=inexpr, **params)
    else:
        out = _op.nn.conv2d(data=inexpr, **params)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[1])
        out = _op.nn.bias_add(out, bias)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != 'linear':
        out = _convert_activation(out, act_type, etab)
    return out


def _convert_separable_convolution(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    weightList = keras_layer.get_weights()
    # depthwise conv
    kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
    stride_h, stride_w = keras_layer.strides
    weight0 = weightList[0].transpose([2, 3, 0, 1])
    params0 = {'weight': etab.new_const(weight0),
               'channels': in_channels * depth_mult,
               'groups': in_channels,
               'kernel_size': [kernel_h, kernel_w],
               'strides': [stride_h, stride_w],
               'dilation': [1, 1],
               'padding': [0, 0]}
    if keras_layer.padding == 'valid':
        pass
    # we insert a separate pad operator
    elif keras_layer.padding == 'same':
        in_h = keras_layer.input_shape[1]
        in_w = keras_layer.input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, kernel_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, kernel_w, stride_w)
        if pad_t == pad_b and pad_l == pad_r:
            params0['padding'] = (pad_t, pad_l)
        else:
            inexpr = _op.nn.pad(data=inexpr, pad_width=(
                (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
    else:
        msg = 'Padding with {} is not supported for operator Separable ' \
              'Convolution in frontend Keras.'
        raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))

    depthconv = _op.nn.conv2d(data=inexpr, **params0)
    # pointwise conv
    weight1 = weightList[1].transpose([3, 2, 0, 1])
    params1 = {'weight': etab.new_const(weight1),
               'channels': weight1.shape[0],
               'groups': 1,
               'kernel_size': [1, 1],
               'strides': [1, 1],
               'dilation': [1, 1]}
    out = _op.nn.conv2d(data=depthconv, **params1)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[2])
        out = _op.nn.bias_add(out, bias)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != 'linear':
        out = _convert_activation(out, act_type, etab)
    return out


def _convert_flatten(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    # NCHW -> NHWC so that dense can be correctly converted
    inexpr = _op.transpose(inexpr, axes=[0, 2, 3, 1])
    return _op.nn.batch_flatten(inexpr)


def _convert_pooling(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    pool_type = type(keras_layer).__name__
    # global pool in keras = global pool + flatten in relay
    if pool_type == 'GlobalMaxPooling2D':
        return _convert_flatten(_op.nn.global_max_pool2d(inexpr), keras_layer, etab)
    if pool_type == 'GlobalAveragePooling2D':
        return _convert_flatten(_op.nn.global_avg_pool2d(inexpr), keras_layer, etab)
    pool_h, pool_w = keras_layer.pool_size
    stride_h, stride_w = keras_layer.strides
    params = {'pool_size': [pool_h, pool_w],
              'strides': [stride_h, stride_w],
              'padding': [0, 0]}
    if keras_layer.padding == 'valid':
        pass
    elif keras_layer.padding == 'same':
        in_h = keras_layer.input_shape[1]
        in_w = keras_layer.input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
        params['padding'] = [pad_t, pad_l, pad_b, pad_r]
    else:
        raise tvm.error.OpAttributeUnImplemented(
            'Padding with {} is not supported in operator Pooling.'.format(keras_layer.padding))
    if pool_type == 'MaxPooling2D':
        return _op.nn.max_pool2d(inexpr, **params)
    if pool_type == 'AveragePooling2D':
        params['count_include_pad'] = False
        return _op.nn.avg_pool2d(inexpr, **params)
    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported for frontend Keras.'.format(pool_type))


def _convert_upsample(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    upsample_type = type(keras_layer).__name__
    params = {}
    if upsample_type == 'UpSampling1D':
        h = keras_layer.size
        params['scale_h'] = h
    elif upsample_type == 'UpSampling2D':
        h, w = keras_layer.size
        if h != w:
            raise tvm.error.OpAttributeInvalid(
                'Height must equal width for operator Upsample.')
        params['scale_h'] = h
        params['scale_w'] = h

        if hasattr(keras_layer, 'interpolation'):
            interpolation = keras_layer.interpolation
            if interpolation == 'nearest':
                params['method'] = 'nearest_neighbor'
            else:
                params['method'] = 'bilinear'

    elif upsample_type == 'UpSampling3D':
        h, w, d = keras_layer.size
        if h != w or w != d:
            raise tvm.error.OpAttributeInvalid(
                'Height, width, and depth must all be equal for operator Upsample.')
        params['scale_h'] = h
        params['scale_w'] = h
    else:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend Keras.'.format(upsample_type))
    return _op.nn.upsampling(inexpr, **params)


def _convert_cropping(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    crop_type = type(keras_layer).__name__
    if crop_type == 'Cropping2D':
        (_, in_h, in_w, _) = keras_layer.input_shape
        ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping
    else:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend Keras.'.format(crop_type))
    int32_max = np.iinfo(np.int32).max
    return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \
        end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])


def _convert_batchnorm(inexpr, keras_layer, etab):
    params = {'scale': False,
              'center': False,
              'epsilon': keras_layer.epsilon}
    idx = 0
    if keras_layer.scale:
        params['scale'] = True
        gamma = keras_layer.get_weights()[idx]
        params['gamma'] = etab.new_const(gamma)
        idx += 1
    if keras_layer.center:
        params['center'] = True
        beta = keras_layer.get_weights()[idx]
        params['beta'] = etab.new_const(beta)
        idx += 1
    moving_mean = keras_layer.get_weights()[idx]
    moving_var = keras_layer.get_weights()[idx + 1]
    params['moving_mean'] = etab.new_const(moving_mean)
    params['moving_var'] = etab.new_const(moving_var)
    # in case beta or gamma is not defined
    params['beta'] = etab.new_const(np.zeros(moving_mean.shape)) if \
                     'beta' not in params else params['beta']
    params['gamma'] = etab.new_const(np.ones(moving_mean.shape)) if \
                      'gamma' not in params else params['gamma']
    result, moving_mean, moving_var = _op.nn.batch_norm(inexpr, **params)
    return result


def _convert_padding(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    padding_type = type(keras_layer).__name__
    padding = keras_layer.padding
    top = left = bottom = right = 0
    if padding_type == 'ZeroPadding2D':
        if isinstance(padding, int):
            top = left = bottom = right = padding
        elif isinstance(padding, tuple):
            if isinstance(padding[0], int):
                top, left = padding
                bottom, right = padding
            elif isinstance(padding[0], tuple):
                top, bottom = padding[0]
                left, right = padding[1]
            else:
                msg = 'Value {} in attribute "padding" of operator Padding ' \
                      'is not valid.'
                raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
        else:
            msg = 'Value {} in attribute "padding" of operator Padding is ' \
                  'not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
    else:
        msg = 'Operator {} is not supported in frontend Keras.'
        raise tvm.error.OpNotImplemented(msg.format(padding_type))
    return _op.nn.pad(data=inexpr,
                      pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))


def _convert_concat(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    return _op.concatenate(_as_list(inexpr), axis=1)


def _convert_reshape(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    inshape = keras_layer.input_shape # includes batch
    tshape = keras_layer.target_shape # no batch
    if len(inshape) == 3 and len(tshape) == 1:
        # (?, a, b) -> (-1, ab)
        shape = (-1, tshape[0])
    elif len(inshape) in [2, 3] and len(tshape) == 2:
        # (?, cc) -> (-1, c, c)
        # (?, a, b) -> (-1, c, c)
        assert tshape[0] == tshape[1], \
            "Only supports square target shapes, but got {}".format(tshape)
        shape = (-1, ) + tshape
    else:
        # (?, h, w, c) -> (-1, c, H, W)
        # (?, h, w, c) -> (-1, c, hw)
        # (?, hw, c) -> (-1, c, h, w)
        ch = inshape[-1]
        assert ch == tshape[-1], \
            "Only supports last dimension in target shape being equal to " \
            "the channel number of input tensor."
        shape = (-1, ch) + tshape[:-1]
    return _op.reshape(inexpr, newshape=shape)


def _convert_lstm(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    if not isinstance(inexpr, list):
        buf = np.zeros((1, keras_layer.units), 'float32')
        c_op = etab.new_const(buf)
        h_op = etab.new_const(buf)
        inexpr = [inexpr, h_op, c_op]
    in_data = inexpr[0]
    next_h = inexpr[1]
    next_c = inexpr[2]
    weightList = keras_layer.get_weights()
    in_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.input_shape)[0])
    kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
    recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
    in_bias = etab.new_const(weightList[2])
    units = list(weightList[0].shape)[1]
    time_steps = in_shape[1]
    in_data = _op.squeeze(in_data, axis=[0])
    in_data = _op.split(in_data, indices_or_sections=time_steps, axis=0)
    # loop for the number of time_steps
    for data in in_data:
        ixh1 = _op.nn.dense(data, kernel_weight, units=units)
        ixh2 = _op.nn.bias_add(_op.nn.dense(next_h, recurrent_weight, units=units), bias=in_bias)
        gate = ixh1 + ixh2
        gates = _op.split(gate, indices_or_sections=4, axis=1)
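        # Keras stores the LSTM gate weights in (input, forget, cell, output) order,
        # so the four splits below correspond to i, f, c and o respectively.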
        in_gate = _convert_recurrent_activation(gates[0], keras_layer)
        in_transform = _convert_recurrent_activation(gates[1], keras_layer)
        next_c = in_transform * next_c + in_gate * _convert_activation(gates[2], keras_layer, None)
        out_gate = _convert_recurrent_activation(gates[3], keras_layer)
        next_h = out_gate * _convert_activation(next_c, keras_layer, None)
    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
    out = _op.reshape(next_h, newshape=out_shape)
    return [out, next_h, next_c]


def _convert_simple_rnn(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    if not isinstance(inexpr, list):
        buf = np.zeros((1, keras_layer.units), 'float32')
        prev_op = etab.new_const(buf)
        inexpr = [inexpr, prev_op]
    in_data = inexpr[0]
    prev_op = inexpr[1]
    weightList = keras_layer.get_weights()
    kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
    recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
    in_bias = etab.new_const(weightList[2])
    units = list(weightList[0].shape)[1]
    in_data = _op.nn.batch_flatten(in_data)
    ixh = _op.nn.bias_add(_op.nn.dense(in_data, kernel_weight, units=units), bias=in_bias)
    prev_op = _op.nn.batch_flatten(prev_op)
    ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
    output = ixh + ixh2
    output = _convert_activation(output, keras_layer, None)
    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
    output = _op.reshape(output, newshape=out_shape)
    return [output, output]


def _convert_gru(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    if not isinstance(inexpr, list):
        buf = np.zeros((1, keras_layer.units), 'float32')
        h_tm1 = etab.new_const(buf)
        inexpr = [inexpr, h_tm1]
    in_data = inexpr[0]
    h_tm1_op = inexpr[1]
    weightList = keras_layer.get_weights()
    kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
    recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
    in_bias = etab.new_const(weightList[2])
    units = list(weightList[0].shape)[1]
    in_data = _op.nn.batch_flatten(in_data)
    matrix_x = _op.nn.bias_add(_op.nn.dense(in_data, kernel_weight, units=units), in_bias)
    # inputs projected by all gate matrices at once
    split_indices = [keras_layer.units, 2 * keras_layer.units]
    gates = _op.split(matrix_x, indices_or_sections=split_indices, axis=1)
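    # Keras stores the GRU gate weights in (update, reset, new) order,
    # so the three splits correspond to z, r and h respectively.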
    x_z = gates[0]
    x_r = gates[1]
    x_h = gates[2]
    # hidden state projected separately for update/reset and new
    units = 2 * keras_layer.units
    split_indices = [units]
    rec_weights = _op.split(recurrent_weight, indices_or_sections=split_indices, axis=0)
    h_tm1_op = _op.nn.batch_flatten(h_tm1_op)
    matrix_inner = _op.nn.dense(h_tm1_op, rec_weights[0], units=units)
    split_indices = [keras_layer.units]
    recurrent = _op.split(matrix_inner, indices_or_sections=split_indices, axis=1)
    recurrent_z = recurrent[0]
    recurrent_r = recurrent[1]
    rec_act_z = _convert_recurrent_activation(x_z + recurrent_z, keras_layer)
    rec_act_r = _convert_recurrent_activation(x_r + recurrent_r, keras_layer)
    units = keras_layer.units
    recurrent_h = _op.nn.dense(rec_act_r * h_tm1_op, rec_weights[1], units=units)
    act_hh = _convert_activation(x_h + recurrent_h, keras_layer, None)
    # previous and candidate state mixed by update gate
    output = rec_act_z * h_tm1_op + (_expr.const(1., dtype='float32') - rec_act_z) * act_hh
    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
    output = _op.reshape(output, newshape=out_shape)
    return [output, output]


def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
    """Layers that can be skipped because they are train time only."""
    return inexpr


_convert_map = {
    'Dense'                    : _convert_dense,
    'Activation'               : _convert_activation,
    'Softmax'                  : _convert_advanced_activation,
    'ReLU'                     : _convert_advanced_activation,
    'LeakyReLU'                : _convert_advanced_activation,
    'PReLU'                    : _convert_advanced_activation,
    'ELU'                      : _convert_advanced_activation,
    'ThresholdedReLU'          : _convert_advanced_activation,

    'AveragePooling2D'         : _convert_pooling,
    'MaxPooling2D'             : _convert_pooling,
    'GlobalAveragePooling2D'   : _convert_pooling,
    'GlobalMaxPooling2D'       : _convert_pooling,
    'Conv2D'                   : _convert_convolution,
    'Conv2DTranspose'          : _convert_convolution,
    'DepthwiseConv2D'          : _convert_convolution,
    'SeparableConv2D'          : _convert_separable_convolution,

    'Flatten'                  : _convert_flatten,
    'Reshape'                  : _convert_reshape,
    'Concatenate'              : _convert_concat,
    'BatchNormalization'       : _convert_batchnorm,

    # Specific tf.Keras terminology for batch normalization
    'BatchNormalizationV1'     : _convert_batchnorm,

    'Add'                      : _convert_merge,
    'Subtract'                 : _convert_merge,
    'Multiply'                 : _convert_merge,
    'ZeroPadding2D'            : _convert_padding,
    'UpSampling2D'             : _convert_upsample,
    'Cropping2D'               : _convert_cropping,

    # 'ZeroPadding1D'          : _convert_padding,
    # 'AveragePooling1D'       : _convert_pooling,
    # 'MaxPooling1D'           : _convert_pooling,
    # 'GlobalAveragePooling1D' : _convert_pooling,
    # 'GlobalMaxPooling1D'     : _convert_pooling,
    # 'Cropping1D'             : _convert_cropping,
    # 'UpSampling1D'           : _convert_upsample,
    # 'UpSampling3D'           : _convert_upsample,
    # 'Conv1D'                 : _convert_convolution1d,

    'SimpleRNN'                : _convert_simple_rnn,
    'LSTM'                     : _convert_lstm,
    'GRU'                      : _convert_gru,
    # 'Bidirectional'          : _convert_bidirectional,
    # 'TimeDistributed'        : _default_skip,

    'Average'                : _convert_merge,
    'Maximum'                : _convert_merge,
    'Dot'                    : _convert_merge,
    'Permute'                : _convert_permute,
    # 'Embedding'              : _convert_embedding,
    # 'RepeatVector'           : _convert_repeat_vector,

    'InputLayer'               : _default_skip,
    'Dropout'                  : _default_skip,
    'SpatialDropout2D'         : _default_skip,
    'SpatialDropout1D'         : _default_skip,
}


def _check_unsupported_layers(model):
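    """Raise NotImplementedError if the model contains layers with no converter."""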
    missing_ops = set()
    for layer in model.layers:
        op_name = type(layer).__name__
        if op_name not in _convert_map:
            missing_ops.add(op_name)

    if missing_ops:
        raise NotImplementedError( \
            "The following operators are not implemented: {}".format(missing_ops))


def keras_op_to_relay(inexpr, keras_layer, outname, etab):
    """Convert a Keras layer to a Relay expression and update the expression table.

    Parameters
    ----------
    inexpr : relay.expr.Expr or a list of it
        The input Relay expression(s).

    keras_layer : keras.layers
        The Keras layer to be converted.

    outname : str
        Name of the output Relay expression.

    etab : relay.frontend.common.ExprTable
        The global expression table to be updated.
    """
    op_name = type(keras_layer).__name__
    if op_name not in _convert_map:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend Keras.'.format(op_name))
    outs = _convert_map[op_name](inexpr, keras_layer, etab)
    outs = _as_list(outs)
    for t_idx, out in enumerate(outs):
        name = outname + ":" + str(t_idx)
        etab.set_expr(name, out)


def from_keras(model, shape=None):
    """Convert keras model to relay Function.

    Parameters
    ----------
    model : keras.engine.training.Model or tensorflow.keras.models.Model
        The keras model to be converted.

    shape : dict of str to int list/tuple, optional
        Input shapes of the model.

    Returns
    -------
    mod : tvm.relay.Module
        The relay module for compilation.

    params : dict of str to tvm.NDArray
        The parameter dict to be used by Relay.
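
    Examples
    --------
    A minimal usage sketch, assuming Keras is installed with the TensorFlow
    backend; ResNet50 is used purely as an illustration, and the default
    input tensor name ``input_1`` and an NCHW input shape are assumed::

        import keras
        from tvm import relay

        keras_model = keras.applications.ResNet50(weights=None)
        shape_dict = {'input_1': (1, 3, 224, 224)}
        mod, params = relay.frontend.from_keras(keras_model, shape=shape_dict)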
    """
    def _check_model_is_tf_keras():
        return type(model).__module__.startswith("tensorflow.python.keras")

    def _convert_input_layer(keras_layer):
        input_name = keras_layer.name
        input_shape = shape[input_name] if shape is not None and input_name in shape else None
        etab.set_expr(input_name, new_var(input_name, shape=input_shape))

    is_tf_keras = _check_model_is_tf_keras()

    if not is_tf_keras:
        # Importing from Keras
        try:
            import keras
        except ImportError:
            raise ImportError("Keras must be installed")
        if keras.backend.backend() != 'tensorflow':
            raise ValueError("Keras frontend currently supports tensorflow backend only.")
        if keras.backend.image_data_format() != 'channels_last':
            raise ValueError("Keras frontend currently supports data_format = channels_last only.")
        expected_model_class = keras.engine.training.Model
        input_layer_class = keras.engine.InputLayer
    else:
        # Importing from Tensorflow Keras (tf.keras)
        try:
            from tensorflow import keras as tf_keras
        except ImportError:
            raise ImportError("Tensorflow must be installed")
        expected_model_class = tf_keras.models.Model
        input_layer_class = tf_keras.layers.InputLayer

    assert isinstance(model, expected_model_class)

    etab = ExprTable()
    for keras_layer in model.layers:
        if isinstance(keras_layer, input_layer_class):
            _convert_input_layer(keras_layer)
        else:
            inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
                       else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
                       else None
            if inbound_nodes is None:
                raise TypeError("Unknown layer type or unsupported Keras version : {}"
                                .format(keras_layer))
            for node_idx, node in enumerate(inbound_nodes):
                # Skip nodes in the imported model that are not relevant to the current model:
                # - In Keras, model._network_nodes contains the keys of all nodes relevant to the
                #   current model;
                # - In tf.Keras, this filtering is already done as part of
                #   tensorflow.keras.network.get_config.
                if not is_tf_keras and \
                   not model._node_key(keras_layer, node_idx) in model._network_nodes:
                    continue
                inexpr = []
                # Since Keras allows reusing the same layer instance to create multiple nodes,
                # we append the node index to the expr name to make it unique.
                # The one exception is InputLayer: changing input variable names after conversion
                # would confuse users, so we keep them unchanged as far as possible. Fortunately,
                # they are uniquely named input_1, input_2, input_3, ... by default.
                zip_node = zip(node.node_indices, node.tensor_indices, node.inbound_layers)
                for n_idx, t_idx, inbound_layer in zip_node:
                    if isinstance(inbound_layer, input_layer_class):
                        expr_name = inbound_layer.name
                        _convert_input_layer(inbound_layer)
                    else:
                        expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx)
                    expr = etab.get_expr(expr_name)
                    inexpr.append(expr)
                if len(inexpr) == 1:
                    inexpr = inexpr[0]
                keras_op_to_relay(inexpr, keras_layer, keras_layer.name + ':' + str(node_idx), etab)
    # Each entry of model._output_coordinates is (out_node, node_index, tensor_index).
    # Look up the output exprs in etab by the name built from these values;
    # they were stored under that name in keras_op_to_relay above.
    outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \
               for oc in model._output_coordinates]
    outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
    func = _expr.Function(analysis.free_vars(outexpr), outexpr)
    params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
    return _module.Module.from_expr(func), params