keras.py 31.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21
# pylint: disable=invalid-name, import-self
"""Keras frontend."""
from __future__ import absolute_import as _abs
import sys
import numpy as np
Zhi committed
22
import tvm
Zhi committed
23
from .. import analysis
24
from .. import expr as _expr
25
from .. import module as _module
26 27
from .. import op as _op
from ... import nd as _nd
28
from .common import ExprTable, new_var
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48

# Names exported by ``from <module> import *``; only the model-level entry point is listed.
__all__ = ['from_keras']


def _check_data_format(keras_layer):
    """Reject a layer whose ``data_format`` is not ``'channels_last'``.

    A layer that has no ``data_format`` attribute is accepted as-is.

    Raises
    ------
    ValueError
        If the layer has a ``data_format`` other than ``'channels_last'``.
    """
    # A missing attribute is treated the same as the supported layout.
    layout = getattr(keras_layer, 'data_format', 'channels_last')
    if layout == 'channels_last':
        return
    raise ValueError("Keras frontend currently supports data_format = channels_last only.")


def _get_pad_pair(input1d, kernel1d, stride1d):
    """Compute the [before, after] padding for one spatial axis.

    Parameters
    ----------
    input1d : int
        Input extent along the axis.

    kernel1d : int
        Kernel (or pooling window) extent along the axis.

    stride1d : int
        Stride along the axis.

    Returns
    -------
    list
        ``[pad_before, pad_after]``; when the total padding is odd the extra
        element goes to ``pad_after``.
    """
    # Number of output positions: input1d / stride1d rounded up, in integer arithmetic.
    n_out = (input1d + stride1d - 1) // stride1d
    # Extent the windows need to cover, minus what the input already provides.
    shortfall = (n_out - 1) * stride1d + kernel1d - input1d
    total = np.maximum(shortfall, 0)
    leading = total // 2
    return [leading, total - leading]


def _get_elu(inexpr, alpha):
    """A helper method for elu.

    Builds ``-alpha * relu(1 - exp(inexpr)) + relu(inexpr)``.

    Parameters
    ----------
    inexpr : relay.expr.Expr
        The input Relay expression.

    alpha : relay.expr.Expr
        The scale applied to the ``relu(1 - exp(inexpr))`` term.
    """
    return (_op.negative(alpha) * _op.nn.relu(_expr.const(1., dtype='float32') -
                                              _op.exp(inexpr))
            + _op.nn.relu(inexpr))


def _as_list(arr):
    """Return ``arr`` unchanged when it is a list, otherwise wrap it in a one-element list."""
    return arr if isinstance(arr, list) else [arr]


def _convert_recurrent_activation(inexpr, keras_layer):
    """Apply the layer's ``recurrent_activation`` to ``inexpr``.

    The activation is looked up by the function's ``__name__`` and handed to
    ``_convert_activation`` as a plain string.
    """
    recurrent_act_name = keras_layer.recurrent_activation.__name__
    return _convert_activation(inexpr, recurrent_act_name, None)


def _convert_activation(inexpr, keras_layer, _):
    """Convert a Keras activation into a Relay expression.

    Parameters
    ----------
    inexpr : relay.expr.Expr
        The input Relay expression.

    keras_layer : keras.layers or str
        Either a layer exposing an ``activation`` function, or the activation
        name itself (used when an activation is split off another layer).

    _ : object
        Unused; kept so every converter shares the same signature.

    Returns
    -------
    relay.expr.Expr
        ``inexpr`` with the activation applied. A ``'linear'`` activation given
        by name returns ``inexpr`` unchanged.

    Raises
    ------
    tvm.error.OpNotImplemented
        If the activation name is not handled below.
    """
    if isinstance(keras_layer, str):
        act_type = keras_layer
    else:
        # Python 2 functions expose their name as ``func_name``.
        if sys.version_info.major < 3:
            act_type = keras_layer.activation.func_name
        else:
            act_type = keras_layer.activation.__name__
    if act_type == 'linear':
        if isinstance(keras_layer, str):
            return inexpr
        # alpha * x + beta, falling back to the identity when the layer has neither.
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
        beta = keras_layer.beta if hasattr(keras_layer, 'beta') else 0.
        alpha = _expr.const(alpha, dtype='float32')
        beta = _expr.const(beta, dtype='float32')
        return _op.add(_op.multiply(inexpr, alpha), beta)
    if act_type == 'softmax':
        return _op.nn.softmax(inexpr, axis=1)
    if act_type == 'sigmoid':
        return _op.sigmoid(inexpr)
    if act_type == 'tanh':
        return _op.tanh(inexpr)
    if act_type == 'relu':
        return _op.nn.relu(inexpr)
    if act_type == 'softplus':
        return _op.log(_op.add(_op.exp(inexpr), _expr.const(1., dtype='float32')))
    if act_type == 'elu':
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
        alpha = _expr.const(alpha, dtype='float32')
        return _get_elu(inexpr, alpha)
    if act_type == 'selu':
        # Alpha, Gamma values obtained from https://arxiv.org/abs/1706.02515
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') \
            else 1.6732632423543772848170429916717
        gamma = keras_layer.gamma if hasattr(keras_layer, 'gamma') \
            else 1.0507009873554804934193349852946
        alpha = _expr.const(alpha, dtype='float32')
        gamma = _expr.const(gamma, dtype='float32')
        return gamma * _get_elu(inexpr, alpha)
    if act_type == 'relu6':
        return _op.clip(inexpr, a_min=0., a_max=6.)
    if act_type == 'softsign':
        return inexpr / (_expr.const(1., dtype='float32') + _op.abs(inexpr))
    if act_type == 'hard_sigmoid':
        x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32')
        return _op.clip(x, a_min=0., a_max=1.)

    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported in frontend Keras.'.format(act_type))
114 115 116 117


def _convert_advanced_activation(inexpr, keras_layer, etab):
    """Convert a Keras activation *layer* (dispatched on its class name) to Relay.

    Parameters
    ----------
    inexpr : relay.expr.Expr
        The input Relay expression.

    keras_layer : keras.layers
        The activation layer; its class name selects the branch.

    etab : relay.frontend.common.ExprTable
        Expression table, used to register the PReLU ``alpha`` weights as a constant.

    Raises
    ------
    tvm.error.OpNotImplemented
        If the layer class is not one of the handled activation layers.
    """
    act_type = type(keras_layer).__name__

    if act_type == 'Softmax':
        return _op.nn.softmax(inexpr, axis=1)
    if act_type == 'ReLU':
        # A truthy max_value turns the ReLU into a clip to [0, max_value].
        if keras_layer.max_value:
            return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
        return _op.nn.relu(inexpr)
    if act_type == 'LeakyReLU':
        return _op.nn.leaky_relu(inexpr, alpha=float(keras_layer.alpha))
    if act_type == 'ELU':
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
        alpha = _expr.const(alpha, dtype='float32')
        return _get_elu(inexpr, alpha)
    if act_type == 'PReLU':
        assert hasattr(keras_layer, 'alpha'), "alpha required for PReLU."
        _check_data_format(keras_layer)
        size = len(keras_layer.alpha.shape)
        # Rotate the weight axes by one so the last axis becomes the first.
        alpha = etab.new_const(keras_layer.get_weights()[0] \
                               .transpose(np.roll(range(size), 1)))
        return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
    if act_type == 'ThresholdedReLU':
        theta = keras_layer.theta if hasattr(keras_layer, 'theta') else 1.
        # Multiply by a 0/1 mask of the elements strictly greater than theta.
        return _op.multiply(inexpr, _op.greater(inexpr, \
            _expr.const(theta, dtype='float32')).astype('float32'))

    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported in frontend Keras.'.format(act_type))
145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161


def _convert_merge(inexpr, keras_layer, _):
    """Convert an element-wise merge layer, dispatched on the layer's class name.

    Parameters
    ----------
    inexpr : list of relay.expr.Expr
        The inputs to merge; they are folded left to right starting from ``inexpr[0]``.

    keras_layer : keras.layers
        One of Subtract, Add, Multiply, Maximum or Average.

    _ : object
        Unused; kept so every converter shares the same signature.

    Raises
    ------
    tvm.error.OpNotImplemented
        If the layer class is not one of the merge types above.
    """
    merge_type = type(keras_layer).__name__
    ret = inexpr[0]
    if merge_type == 'Subtract':
        assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
        ret = _op.subtract(ret, inexpr[1])
    elif merge_type in ['Add', 'Multiply', 'Maximum']:
        op_map = {'Add':_op.add, 'Multiply':_op.multiply, 'Maximum':_op.maximum}
        for i in range(1, len(inexpr)):
            ret = op_map[merge_type](ret, inexpr[i])
    elif merge_type == 'Average':
        # Sum every input, then divide by the number of inputs.
        for i in range(1, len(inexpr)):
            ret = _op.add(ret, inexpr[i])
        ret = ret / _expr.const(len(inexpr), dtype='float32')
    else:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported in frontend Keras.'.format(merge_type))
    return ret

166 167
def _convert_permute(inexpr, keras_layer, _):
    """Transpose ``inexpr`` with axis 0 kept first, followed by the layer's ``dims``."""
    axis_order = (0,) + keras_layer.dims
    return _op.transpose(inexpr, axes=axis_order)
168 169 170 171 172 173 174 175 176 177 178

def _convert_dense(inexpr, keras_layer, etab):
    """Convert a Dense layer: dense, optional bias add, optional activation.

    Parameters
    ----------
    inexpr : relay.expr.Expr
        The input Relay expression.

    keras_layer : keras.layers
        The Dense layer to convert.

    etab : relay.frontend.common.ExprTable
        Expression table used to register the weights as constants.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If the input has more than two dimensions and is not of shape (1, 1, n).
    """
    weightList = keras_layer.get_weights()
    # The kernel is transposed with axes [1, 0] before being passed to nn.dense.
    weight = etab.new_const(weightList[0].transpose([1, 0]))
    params = {'weight':weight, 'units':weightList[0].shape[1]}
    input_shape = keras_layer.input_shape
    input_dim = len(input_shape)
    # In case of RNN dense, input shape will be (1, 1, n)
    if input_dim > 2:
        # Unknown (falsy) dimensions are treated as 1.
        input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
        if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:
            raise tvm.error.OpAttributeInvalid(
                'Input shape {} is not valid for operator Dense.'.format(input_shape))
        inexpr = _op.squeeze(inexpr, axis=0)
    out = _op.nn.dense(data=inexpr, **params)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[1])
        out = _op.nn.bias_add(out, bias)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != 'linear':
        out = _convert_activation(out, act_type, etab)
    if input_dim > 2:
        # Restore the axis removed by the squeeze above.
        out = _op.expand_dims(out, axis=0)
    return out


def _convert_convolution(inexpr, keras_layer, etab):
    """Convert Conv2D, Conv2DTranspose and DepthwiseConv2D layers.

    Parameters
    ----------
    inexpr : relay.expr.Expr
        The input Relay expression.

    keras_layer : keras.layers
        The convolution layer; its class name selects the transpose/depthwise handling.

    etab : relay.frontend.common.ExprTable
        Expression table used to register the weights as constants.

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        If the layer's padding is neither ``'valid'`` nor ``'same'``.
    """
    _check_data_format(keras_layer)
    is_deconv = type(keras_layer).__name__ == 'Conv2DTranspose'
    is_depthconv = type(keras_layer).__name__ == 'DepthwiseConv2D'
    weightList = keras_layer.get_weights()
    # Each variant unpacks the kernel shape differently and reorders the kernel axes.
    if is_deconv:
        kernel_h, kernel_w, n_filters, in_channels = weightList[0].shape
        weight = weightList[0].transpose([3, 2, 0, 1])
    elif is_depthconv:
        kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
        weight = weightList[0].transpose([2, 3, 0, 1])
    else:
        kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape
        weight = weightList[0].transpose([3, 2, 0, 1])
    if isinstance(keras_layer.dilation_rate, (list, tuple)):
        dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]]
    else:
        dilation = [keras_layer.dilation_rate, keras_layer.dilation_rate]
    # Effective kernel extent once dilation is applied; used for 'same' padding.
    dilated_kernel_h = (kernel_h - 1) * dilation[0] + 1
    dilated_kernel_w = (kernel_w - 1) * dilation[1] + 1
    stride_h, stride_w = keras_layer.strides
    params = {'weight': etab.new_const(weight),
              'kernel_size': [kernel_h, kernel_w],
              'strides': [stride_h, stride_w],
              'dilation': dilation,
              'padding': [0, 0]}
    if is_depthconv:
        params['channels'] = in_channels * depth_mult
        params['groups'] = in_channels
    else:
        params['channels'] = n_filters
    if keras_layer.padding == 'valid':
        pass
    # we insert a separate pad operator
    elif keras_layer.padding == 'same':
        in_h = keras_layer.input_shape[1]
        in_w = keras_layer.input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
        # Symmetric padding goes into the conv attribute; otherwise pad explicitly.
        if pad_t == pad_b and pad_l == pad_r:
            params['padding'] = (pad_t, pad_l)
        else:
            inexpr = _op.nn.pad(data=inexpr, pad_width=(
                (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
    else:
        msg = 'Padding with {} is not supported for operator Convolution ' \
              'in frontend Keras.'
        raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))
    if is_deconv:
        out = _op.nn.conv2d_transpose(data=inexpr, **params)
    else:
        out = _op.nn.conv2d(data=inexpr, **params)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[1])
        out = _op.nn.bias_add(out, bias)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != 'linear':
        out = _convert_activation(out, act_type, etab)
    return out


def _convert_separable_convolution(inexpr, keras_layer, etab):
    """Convert SeparableConv2D as a grouped conv2d followed by a 1x1 conv2d.

    Parameters
    ----------
    inexpr : relay.expr.Expr
        The input Relay expression.

    keras_layer : keras.layers
        The SeparableConv2D layer. Its weights are read as
        ``[depthwise kernel, pointwise kernel, bias (if use_bias)]``.

    etab : relay.frontend.common.ExprTable
        Expression table used to register the weights as constants.

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        If the layer's padding is neither ``'valid'`` nor ``'same'``.
    """
    _check_data_format(keras_layer)
    weightList = keras_layer.get_weights()
    # depthwise conv
    kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
    stride_h, stride_w = keras_layer.strides
    weight0 = weightList[0].transpose([2, 3, 0, 1])
    # NOTE: dilation is fixed to [1, 1] here; the layer's dilation_rate is not read.
    params0 = {'weight': etab.new_const(weight0),
               'channels': in_channels * depth_mult,
               'groups': in_channels,
               'kernel_size': [kernel_h, kernel_w],
               'strides': [stride_h, stride_w],
               'dilation': [1, 1],
               'padding': [0, 0]}
    if keras_layer.padding == 'valid':
        pass
    # we insert a separate pad operator
    elif keras_layer.padding == 'same':
        in_h = keras_layer.input_shape[1]
        in_w = keras_layer.input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, kernel_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, kernel_w, stride_w)
        # Symmetric padding goes into the conv attribute; otherwise pad explicitly.
        if pad_t == pad_b and pad_l == pad_r:
            params0['padding'] = (pad_t, pad_l)
        else:
            inexpr = _op.nn.pad(data=inexpr, pad_width=(
                (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
    else:
        msg = 'Padding with {} is not supported for operator Separable ' \
              'Convolution in frontend Keras.'
        raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))

    depthconv = _op.nn.conv2d(data=inexpr, **params0)
    # pointwise conv
    weight1 = weightList[1].transpose([3, 2, 0, 1])
    params1 = {'weight': etab.new_const(weight1),
               'channels': weight1.shape[0],
               'groups': 1,
               'kernel_size': [1, 1],
               'strides': [1, 1],
               'dilation': [1, 1]}
    out = _op.nn.conv2d(data=depthconv, **params1)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[2])
        out = _op.nn.bias_add(out, bias)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != 'linear':
        out = _convert_activation(out, act_type, etab)
    return out


def _convert_flatten(inexpr, keras_layer, _):
    """Convert Flatten: move axis 1 to the end, then batch-flatten."""
    _check_data_format(keras_layer)
    # NCHW -> NHWC so that a following dense layer sees the element order it expects.
    channels_last = _op.transpose(inexpr, axes=[0, 2, 3, 1])
    return _op.nn.batch_flatten(channels_last)


def _convert_pooling(inexpr, keras_layer, etab):
    """Convert 2D max/average pooling layers, including the global variants.

    Parameters
    ----------
    inexpr : relay.expr.Expr
        The input Relay expression.

    keras_layer : keras.layers
        One of MaxPooling2D, AveragePooling2D, GlobalMaxPooling2D or
        GlobalAveragePooling2D.

    etab : relay.frontend.common.ExprTable
        Expression table, forwarded to the flatten converter for global pooling.

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        If the layer's padding is neither ``'valid'`` nor ``'same'``.
    tvm.error.OpNotImplemented
        If the layer class is not one of the pooling types above.
    """
    _check_data_format(keras_layer)
    pool_type = type(keras_layer).__name__
    # global pool in keras = global pool + flatten in nnvm/relay
    if pool_type == 'GlobalMaxPooling2D':
        return _convert_flatten(_op.nn.global_max_pool2d(inexpr), keras_layer, etab)
    if pool_type == 'GlobalAveragePooling2D':
        return _convert_flatten(_op.nn.global_avg_pool2d(inexpr), keras_layer, etab)
    pool_h, pool_w = keras_layer.pool_size
    stride_h, stride_w = keras_layer.strides
    params = {'pool_size': [pool_h, pool_w],
              'strides': [stride_h, stride_w],
              'padding': [0, 0]}
    if keras_layer.padding == 'valid':
        pass
    elif keras_layer.padding == 'same':
        in_h = keras_layer.input_shape[1]
        in_w = keras_layer.input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
        params['padding'] = [pad_t, pad_l, pad_b, pad_r]
    else:
        # Same exception class the convolution converters use for unsupported padding.
        raise tvm.error.OpAttributeUnImplemented(
            'Padding with {} is not supported in operator Pooling.'.format(keras_layer.padding))
    if pool_type == 'MaxPooling2D':
        return _op.nn.max_pool2d(inexpr, **params)
    if pool_type == 'AveragePooling2D':
        params['count_include_pad'] = False
        return _op.nn.avg_pool2d(inexpr, **params)
    # Report the operator name, as the other converters do, rather than the layer object.
    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported for frontend Keras.'.format(pool_type))
356 357 358 359 360 361 362 363 364 365 366


def _convert_upsample(inexpr, keras_layer, _):
    """Convert UpSampling1D/2D/3D layers to ``nn.upsampling`` with a single scale.

    Parameters
    ----------
    inexpr : relay.expr.Expr
        The input Relay expression.

    keras_layer : keras.layers
        The upsampling layer; its ``size`` provides the scale.

    _ : object
        Unused; kept so every converter shares the same signature.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If the per-axis sizes of a 2D or 3D layer are not all equal.
    tvm.error.OpNotImplemented
        If the layer class is not an UpSampling1D/2D/3D layer.
    """
    _check_data_format(keras_layer)
    upsample_type = type(keras_layer).__name__
    if upsample_type == 'UpSampling1D':
        h = keras_layer.size
        params = {'scale': h}
    elif upsample_type == 'UpSampling2D':
        h, w = keras_layer.size
        if h != w:
            raise tvm.error.OpAttributeInvalid(
                'Height must equal width for operator Upsample.')
        params = {'scale': h}

        # Layers without an 'interpolation' attribute leave the method at the op's default.
        if hasattr(keras_layer, 'interpolation'):
            interpolation = keras_layer.interpolation
            if interpolation == 'nearest':
                params['method'] = 'NEAREST_NEIGHBOR'
            else:
                params['method'] = 'BILINEAR'

    elif upsample_type == 'UpSampling3D':
        h, w, d = keras_layer.size
        if h != w or w != d:
            raise tvm.error.OpAttributeInvalid(
                'Height, width, and depth must all be equal for operator Upsample.')
        params = {'scale': h}
    else:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend Keras.'.format(upsample_type))
    return _op.nn.upsampling(inexpr, **params)


def _convert_cropping(inexpr, keras_layer, _):
    """Convert Cropping2D to a strided slice over the last two axes.

    Parameters
    ----------
    inexpr : relay.expr.Expr
        The input Relay expression.

    keras_layer : keras.layers
        The Cropping2D layer; ``cropping`` is read as
        ``((top, bottom), (left, right))``.

    _ : object
        Unused; kept so every converter shares the same signature.

    Raises
    ------
    tvm.error.OpNotImplemented
        If the layer class is not Cropping2D.
    """
    _check_data_format(keras_layer)
    crop_type = type(keras_layer).__name__
    if crop_type == 'Cropping2D':
        (_, in_h, in_w, _) = keras_layer.input_shape
        ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping
    else:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend Keras.'.format(crop_type))
    # int32 max as the end index keeps the first two axes whole.
    int32_max = np.iinfo(np.int32).max
    return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \
        end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])


def _convert_batchnorm(inexpr, keras_layer, etab):
    """Convert BatchNormalization to ``nn.batch_norm`` and return its first output.

    The layer weights are consumed in order: gamma (only when ``scale`` is set),
    beta (only when ``center`` is set), then the moving mean and moving variance.
    """
    bn_args = {'scale': False,
               'center': False,
               'epsilon': keras_layer.epsilon}
    weights = keras_layer.get_weights()
    cursor = 0
    if keras_layer.scale:
        bn_args['scale'] = True
        bn_args['gamma'] = etab.new_const(weights[cursor])
        cursor += 1
    if keras_layer.center:
        bn_args['center'] = True
        bn_args['beta'] = etab.new_const(weights[cursor])
        cursor += 1
    mean_value = weights[cursor]
    var_value = weights[cursor + 1]
    bn_args['moving_mean'] = etab.new_const(mean_value)
    bn_args['moving_var'] = etab.new_const(var_value)
    # batch_norm yields three outputs; only the normalized data is returned.
    normalized, _, _ = _op.nn.batch_norm(inexpr, **bn_args)
    return normalized


def _convert_padding(inexpr, keras_layer, _):
    """Convert ZeroPadding2D to ``nn.pad`` on the last two axes.

    Parameters
    ----------
    inexpr : relay.expr.Expr
        The input Relay expression.

    keras_layer : keras.layers
        The ZeroPadding2D layer. ``padding`` may be an int, a pair of ints
        (applied symmetrically), or a pair of ``(before, after)`` tuples.

    _ : object
        Unused; kept so every converter shares the same signature.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If ``padding`` has none of the accepted forms.
    tvm.error.OpNotImplemented
        If the layer class is not ZeroPadding2D.
    """
    _check_data_format(keras_layer)
    padding_type = type(keras_layer).__name__
    padding = keras_layer.padding
    top = left = bottom = right = 0
    if padding_type == 'ZeroPadding2D':
        if isinstance(padding, int):
            top = left = bottom = right = padding
        elif isinstance(padding, tuple):
            if isinstance(padding[0], int):
                # (height_pad, width_pad): the same amount on both sides of each axis.
                top, left = padding
                bottom, right = padding
            elif isinstance(padding[0], tuple):
                top, bottom = padding[0]
                left, right = padding[1]
            else:
                msg = 'Value {} in attribute "padding" of operator Padding ' \
                      'is not valid.'
                raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
        else:
            msg = 'Value {} in attribute "padding" of operator Padding is ' \
                  'not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
    else:
        msg = 'Operator {} is not supported in frontend Keras.'
        raise tvm.error.OpNotImplemented(msg.format(padding_type))
    return _op.nn.pad(data=inexpr,
                      pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))
455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567


def _convert_concat(inexpr, keras_layer, _):
    """Convert Concatenate by joining the input(s) along axis 1."""
    _check_data_format(keras_layer)
    # A lone expression is wrapped so concatenate always receives a list.
    tensors = inexpr if isinstance(inexpr, list) else [inexpr]
    return _op.concatenate(tensors, axis=1)


def _convert_reshape(inexpr, keras_layer, _):
    """Convert Reshape when the last target dimension equals the last input dimension.

    The new shape is ``(-1, ch)`` followed by every target dimension except the last.
    """
    _check_data_format(keras_layer)
    channels = keras_layer.input_shape[-1]
    target = keras_layer.target_shape
    assert channels == target[-1], \
        "Only supports last dimension in target shape being equal to " \
        "the channel number of input tensor."
    new_shape = (-1, channels) + keras_layer.target_shape[:-1]
    return _op.reshape(inexpr, newshape=new_shape)


def _convert_lstm(inexpr, keras_layer, etab):
    """Convert an LSTM layer by unrolling it over the time steps.

    When ``inexpr`` is not a list, zero-filled ``(1, units)`` constants are used
    as the initial hidden and cell states. At each step the summed input and
    recurrent projections are split into four parts along axis 1: parts 0, 1
    and 3 go through the recurrent activation, part 2 through the layer
    activation.

    Returns
    -------
    list
        ``[output, hidden_state, cell_state]`` after the final step.
    """
    _check_data_format(keras_layer)
    if not isinstance(inexpr, list):
        zero_state = np.zeros((1, keras_layer.units), 'float32')
        # Constants are registered cell-first, then hidden.
        init_cell = etab.new_const(zero_state)
        init_hidden = etab.new_const(zero_state)
        inexpr = [inexpr, init_hidden, init_cell]
    sequence = inexpr[0]
    hidden = inexpr[1]
    cell = inexpr[2]
    weights = keras_layer.get_weights()
    # Unknown (falsy) dimensions are treated as 1.
    in_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.input_shape)[0])
    w_input = etab.new_const(weights[0].transpose([1, 0]))
    w_recurrent = etab.new_const(weights[1].transpose([1, 0]))
    bias = etab.new_const(weights[2])
    n_out = list(weights[0].shape)[1]
    n_steps = in_shape[1]
    sequence = _op.squeeze(sequence, axis=[0])
    steps = _op.split(sequence, indices_or_sections=n_steps, axis=0)
    # One iteration per time step.
    for step in steps:
        from_input = _op.nn.dense(step, w_input, units=n_out)
        from_hidden = _op.nn.bias_add(_op.nn.dense(hidden, w_recurrent, units=n_out), bias=bias)
        summed = from_input + from_hidden
        parts = _op.split(summed, indices_or_sections=4, axis=1)
        gate_in = _convert_recurrent_activation(parts[0], keras_layer)
        gate_keep = _convert_recurrent_activation(parts[1], keras_layer)
        cell = gate_keep * cell + gate_in * _convert_activation(parts[2], keras_layer, None)
        gate_out = _convert_recurrent_activation(parts[3], keras_layer)
        hidden = gate_out * _convert_activation(cell, keras_layer, None)
    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
    reshaped = _op.reshape(hidden, newshape=out_shape)
    return [reshaped, hidden, cell]


def _convert_simple_rnn(inexpr, keras_layer, etab):
    """Convert a SimpleRNN layer as a single step.

    When ``inexpr`` is not a list, a zero-filled ``(1, units)`` constant is used
    as the previous state. The result is
    ``activation(dense(x) + bias + dense(prev_state))`` reshaped to the layer's
    output shape.

    Returns
    -------
    list
        ``[output, output]``; the same expression serves as output and state.
    """
    _check_data_format(keras_layer)
    if not isinstance(inexpr, list):
        zero_state = np.zeros((1, keras_layer.units), 'float32')
        inexpr = [inexpr, etab.new_const(zero_state)]
    features = inexpr[0]
    state = inexpr[1]
    weights = keras_layer.get_weights()
    w_input = etab.new_const(weights[0].transpose([1, 0]))
    w_recurrent = etab.new_const(weights[1].transpose([1, 0]))
    bias = etab.new_const(weights[2])
    n_out = list(weights[0].shape)[1]
    features = _op.nn.batch_flatten(features)
    from_input = _op.nn.bias_add(_op.nn.dense(features, w_input, units=n_out), bias=bias)
    state = _op.nn.batch_flatten(state)
    from_state = _op.nn.dense(state, w_recurrent, units=n_out)
    activated = _convert_activation(from_input + from_state, keras_layer, None)
    # Unknown (falsy) dimensions are treated as 1.
    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
    result = _op.reshape(activated, newshape=out_shape)
    return [result, result]


def _convert_gru(inexpr, keras_layer, etab):
    """Convert a GRU layer as a single step.

    Parameters
    ----------
    inexpr : relay.expr.Expr or a list of it
        The input expression, optionally followed by the previous hidden state.
        Without a state, a zero-filled ``(1, units)`` constant is used.

    keras_layer : keras.layers
        The GRU layer to convert.

    etab : relay.frontend.common.ExprTable
        Expression table used to register the weights as constants.

    Returns
    -------
    list
        ``[output, output]``; the same expression serves as output and state.
    """
    _check_data_format(keras_layer)
    if not isinstance(inexpr, list):
        buf = np.zeros((1, keras_layer.units), 'float32')
        h_tm1 = etab.new_const(buf)
        inexpr = [inexpr, h_tm1]
    in_data = inexpr[0]
    h_tm1_op = inexpr[1]
    weightList = keras_layer.get_weights()
    kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
    recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
    in_bias = etab.new_const(weightList[2])
    units = list(weightList[0].shape)[1]
    in_data = _op.nn.batch_flatten(in_data)
    matrix_x = _op.nn.bias_add(_op.nn.dense(in_data, kernel_weight, units=units), in_bias)
    # inputs projected by all gate matrices at once
    split_indices = [keras_layer.units, 2 * keras_layer.units]
    gates = _op.split(matrix_x, indices_or_sections=split_indices, axis=1)
    x_z = gates[0]
    x_r = gates[1]
    x_h = gates[2]
    # hidden state projected separately for update/reset and new
    units = 2 * keras_layer.units
    split_indices = [units]
    rec_weights = _op.split(recurrent_weight, indices_or_sections=split_indices, axis=0)
    h_tm1_op = _op.nn.batch_flatten(h_tm1_op)
    matrix_inner = _op.nn.dense(h_tm1_op, rec_weights[0], units=units)
    split_indices = [keras_layer.units]
    recurrent = _op.split(matrix_inner, indices_or_sections=split_indices, axis=1)
    recurrent_z = recurrent[0]
    recurrent_r = recurrent[1]
    rec_act_z = _convert_recurrent_activation(x_z + recurrent_z, keras_layer)
    rec_act_r = _convert_recurrent_activation(x_r + recurrent_r, keras_layer)
    units = keras_layer.units
    recurrent_h = _op.nn.dense(rec_act_r * h_tm1_op, rec_weights[1], units=units)
    act_hh = _convert_activation(x_h + recurrent_h, keras_layer, None)
    # previous and candidate state mixed by update gate
    output = rec_act_z * h_tm1_op + (_expr.const(1., dtype='float32') - rec_act_z) * act_hh
    # Unknown (falsy) dimensions are treated as 1.
    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
    output = _op.reshape(output, newshape=out_shape)
    return [output, output]


def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
    """Pass-through converter for train-time-only layers: returns ``inexpr`` untouched."""
    return inexpr


# Dispatch table from Keras layer class name to its converter function.
# Commented-out entries are layer types without a registered converter.
_convert_map = {
    'Dense'                    : _convert_dense,
    'Activation'               : _convert_activation,
    'Softmax'                  : _convert_advanced_activation,
    'ReLU'                     : _convert_advanced_activation,
    'LeakyReLU'                : _convert_advanced_activation,
    'PReLU'                    : _convert_advanced_activation,
    'ELU'                      : _convert_advanced_activation,
    'ThresholdedReLU'          : _convert_advanced_activation,

    'AveragePooling2D'         : _convert_pooling,
    'MaxPooling2D'             : _convert_pooling,
    'GlobalAveragePooling2D'   : _convert_pooling,
    'GlobalMaxPooling2D'       : _convert_pooling,
    'Conv2D'                   : _convert_convolution,
    'Conv2DTranspose'          : _convert_convolution,
    'DepthwiseConv2D'          : _convert_convolution,
    'SeparableConv2D'          : _convert_separable_convolution,

    'Flatten'                  : _convert_flatten,
    'Reshape'                  : _convert_reshape,
    'Concatenate'              : _convert_concat,
    'BatchNormalization'       : _convert_batchnorm,

    'Add'                      : _convert_merge,
    'Subtract'                 : _convert_merge,
    'Multiply'                 : _convert_merge,
    'ZeroPadding2D'            : _convert_padding,
    'UpSampling2D'             : _convert_upsample,
    'Cropping2D'               : _convert_cropping,

    # 'ZeroPadding1D'          : _convert_padding,
    # 'AveragePooling1D'       : _convert_pooling,
    # 'MaxPooling1D'           : _convert_pooling,
    # 'GlobalAveragePooling1D' : _convert_pooling,
    # 'GlobalMaxPooling1D'     : _convert_pooling,
    # 'Cropping1D'             : _convert_cropping,
    # 'UpSampling1D'           : _convert_upsample,
    # 'UpSampling3D'           : _convert_upsample,
    # 'Conv1D'                 : _convert_convolution1d,

    'SimpleRNN'                : _convert_simple_rnn,
    'LSTM'                     : _convert_lstm,
    'GRU'                      : _convert_gru,
    # 'Bidirectional'          : _convert_bidirectional,
    # 'TimeDistributed'        : _default_skip,

    'Average'                  : _convert_merge,
    'Maximum'                  : _convert_merge,
    # 'Dot'                    : _convert_merge,
    'Permute'                  : _convert_permute,
    # 'Embedding'              : _convert_embedding,
    # 'RepeatVector'           : _convert_repeat_vector,

    'InputLayer'               : _default_skip,
    'Dropout'                  : _default_skip,
    'SpatialDropout2D'         : _default_skip,
    'SpatialDropout1D'         : _default_skip,
}


def _check_unsupported_layers(model):
    """Raise if the model contains any layer type without a registered converter.

    Every missing layer class name is collected first, so the error lists them
    all at once.

    Raises
    ------
    NotImplementedError
        If at least one layer's class name is not a key of ``_convert_map``.
    """
    missing_ops = set()
    for layer in model.layers:
        op_name = type(layer).__name__
        if op_name not in _convert_map:
            missing_ops.add(op_name)

    if missing_ops:
        raise NotImplementedError( \
            "The following operators are not implemented: {}".format(missing_ops))
650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668


def keras_op_to_relay(inexpr, keras_layer, outname, etab):
    """Convert a Keras layer to a Relay expression and update the expression table.

    Each output of the converter is stored in ``etab`` under
    ``outname + ":" + <output index>``.

    Parameters
    ----------
    inexpr : relay.expr.Expr or a list of it
        The input Relay expression(s).

    keras_layer : keras.layers
        The Keras layer to be converted.

    outname : str
        Name of the output Relay expression.

    etab : relay.frontend.common.ExprTable
        The global expression table to be updated.

    Raises
    ------
    tvm.error.OpNotImplemented
        If the layer's class name has no entry in ``_convert_map``.
    """
    op_name = type(keras_layer).__name__
    if op_name not in _convert_map:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend Keras.'.format(op_name))
    outs = _convert_map[op_name](inexpr, keras_layer, etab)
    # Converters may return one expression or a list; handle both the same way.
    outs = _as_list(outs)
    for t_idx, out in enumerate(outs):
        name = outname + ":" + str(t_idx)
        etab.set_expr(name, out)


680
def from_keras(model, shape=None):
    """Convert keras model to relay Function.

    Parameters
    ----------
    model : keras.engine.training.Model
        The keras model to be converted.

    shape: dict of str to int list/tuple
        Input shapes of the model, optional

    Returns
    -------
    mod : tvm.relay.Module
        The relay module for compilation.

    params : dict of str to tvm.NDArray
        The parameter dict to be used by Relay.

    Raises
    ------
    ImportError
        If Keras cannot be imported.
    ValueError
        If the Keras backend is not tensorflow, or the image data format is not
        channels_last.
    NotImplementedError
        If the model contains a layer type without a registered converter.
    TypeError
        If a layer exposes neither ``inbound_nodes`` nor ``_inbound_nodes``.
    """
    try:
        import keras
    except ImportError:
        raise ImportError('Keras must be installed')
    assert isinstance(model, keras.engine.training.Model)
    if keras.backend.backend() != 'tensorflow':
        raise ValueError("Keras frontend currently supports tensorflow backend only.")
    if keras.backend.image_data_format() != 'channels_last':
        raise ValueError("Keras frontend currently supports data_format = channels_last only.")
    _check_unsupported_layers(model)

    def _convert_input_layer(keras_layer):
        # Register the input as a Relay variable, using the user-supplied shape if given.
        input_name = keras_layer.name
        input_shape = shape[input_name] if shape is not None and input_name in shape else None
        etab.set_expr(input_name, new_var(input_name, shape=input_shape))

    etab = ExprTable()
    for keras_layer in model.layers:
        if isinstance(keras_layer, keras.engine.InputLayer):
            _convert_input_layer(keras_layer)
        else:
            # The attribute is public or underscore-prefixed depending on the layer object.
            inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
                       else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
                       else None
            if inbound_nodes is None:
                raise TypeError("Unknown layer type or unsupported Keras version : {}"
                                .format(keras_layer))
            for node_idx, node in enumerate(inbound_nodes):
                # If some nodes in imported model is not relevant to the current model,
                # skip such layers. model._network_nodes contains keys of all nodes relevant
                # to the current model.
                if not model._node_key(keras_layer, node_idx) in model._network_nodes:
                    continue
                inexpr = []
                # Since Keras allows creating multiple layers from the same name instance,
                # we append node index to the expr name to make it unique.
                # The one exception is InputLayer. Changing input variable names after conversion
                # would confuse users, so we should keep them as far as possible. Fortunately,
                # they are named uniquely to input_1, input_2, input_3... by default.
                zip_node = zip(node.node_indices, node.tensor_indices, node.inbound_layers)
                for n_idx, t_idx, inbound_layer in zip_node:
                    if isinstance(inbound_layer, keras.engine.InputLayer):
                        expr_name = inbound_layer.name
                        _convert_input_layer(inbound_layer)
                    else:
                        expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx)
                    expr = etab.get_expr(expr_name)
                    inexpr.append(expr)
                if len(inexpr) == 1:
                    inexpr = inexpr[0]
                keras_op_to_relay(inexpr, keras_layer, keras_layer.name + ':' + str(node_idx), etab)
    # model._output_coordinates contains out_node(oc[0]), node_index(oc[1]) and tensor_index(oc[2])
    # Get all output nodes in etab using the name made from above values.
    # The out exprs were added to etab in keras_op_to_relay using this name.
    outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \
               for oc in model._output_coordinates]
    outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
    func = _expr.Function(analysis.free_vars(outexpr), outexpr)
    # Every registered constant is exported as a float32 array.
    params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
    return _module.Module.from_expr(func), params