coreml.py 19.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19
# pylint: disable=invalid-name, import-self, unused-argument, unused-variable, inconsistent-return-statements
"""CoreML frontend."""
from __future__ import absolute_import as _abs
20
import math
21
import numpy as np
Zhi committed
22
import tvm
Zhi committed
23
from .. import analysis
24
from .. import expr as _expr
25
from .. import module as _module
26 27 28 29
from .. import op as _op
from ... import nd as _nd
from ..._ffi import base as _base
from .common import ExprTable
30
from .common import infer_shape as _infer_shape
31 32 33 34 35

__all__ = ['from_coreml']


def _NeuralNetworkImageScaler(op, inexpr, etab):
    """Convert a CoreML image-scaler preprocessor: ``x * channelScale + bias``.

    The three per-channel biases are taken in blue, green, red order and
    reshaped to ``(3, 1, 1)`` before being registered as a constant.
    """
    # TODO: we need to support more colorspaces, such as RGB (bias order is BGR).
    bgr_bias = np.array([op.blueBias, op.greenBias, op.redBias]).reshape([3, 1, 1])
    bias_const = etab.new_const(bgr_bias)
    scale = _expr.const(op.channelScale, dtype='float32')
    scaled = _op.multiply(inexpr, scale)
    return _op.add(scaled, bias_const)


def _NeuralNetworkMeanImage(op, inexpr, etab):
    """Convert a CoreML mean-image preprocessor: subtract ``op.meanImage``."""
    mean_image = _expr.const(op.meanImage, dtype='float32')
    return _op.subtract(inexpr, mean_image)


def _ConvolutionLayerParams(op, inexpr, etab):
    """Convert a CoreML convolution or deconvolution layer.

    Any non-zero explicit padding (``valid`` padding with border amounts, or
    ``same`` padding) is applied as an ``nn.pad`` on the input. Then
    ``nn.conv2d`` is emitted, or ``nn.conv2d_transpose`` when
    ``op.isDeconvolution`` is set. ``nn.bias_add`` follows when the layer has
    a bias.

    Raises
    ------
    NotImplementedError
        If the padding type is neither ``valid`` nor ``same``.
    """
    # Per NeuralNetwork.proto an empty field means its default value:
    # kernelSize [3, 3], stride [1, 1], dilationFactor [1, 1].
    kernel_size = list(op.kernelSize) or [3, 3]
    strides = list(op.stride) or [1, 1]
    dilation = list(op.dilationFactor) or [1, 1]

    if op.isDeconvolution:
        weight_shape = [op.kernelChannels, op.outputChannels] + kernel_size
    else:
        weight_shape = [op.outputChannels, op.kernelChannels] + kernel_size
    weights = etab.new_const(
        np.array(list(op.weights.floatValue)).reshape(tuple(weight_shape)))

    # Spatial input sizes are needed to compute SAME padding.
    _, _, H, W = _infer_shape(inexpr)
    params = {'channels': op.outputChannels,
              'kernel_size': kernel_size,
              'strides': strides,
              'dilation': dilation,
              'groups': op.nGroups}

    padding_type = op.WhichOneof('ConvolutionPaddingType')
    if padding_type == 'valid':
        border = op.valid.paddingAmounts.borderAmounts
        pad_t = pad_b = pad_l = pad_r = 0
        if border:
            assert len(border) == 2
            # borderAmounts[0] is the height axis, borderAmounts[1] the width axis.
            pad_t, pad_b = border[0].startEdgeSize, border[0].endEdgeSize
            pad_l, pad_r = border[1].startEdgeSize, border[1].endEdgeSize
    elif padding_type == 'same':
        assert op.same.asymmetryMode == 0, "Only support BOTTOM_RIGHT_HEAVY mode, " \
                                           "which is used by tf/caffe and so on"
        # A dilated kernel spans (k - 1) * d + 1 input pixels. SAME padding must
        # be computed from that effective extent, otherwise the output is
        # smaller than ceil(input / stride).
        eff_kh = (kernel_size[0] - 1) * dilation[0] + 1
        eff_kw = (kernel_size[1] - 1) * dilation[1] + 1
        pad_t, pad_b = get_pad_value(H, eff_kh, strides[0])
        pad_l, pad_r = get_pad_value(W, eff_kw, strides[1])
        # NOTE(review): for deconvolution this pads the *input*, which may not
        # reproduce CoreML's SAME output shape for transposed convs -- confirm.
    else:
        raise NotImplementedError("Only valid/same convolution padding is implemented")

    if any(v != 0 for v in (pad_t, pad_l, pad_b, pad_r)):
        inexpr = _op.nn.pad(data=inexpr, pad_width=((0, 0),
                                                    (0, 0),
                                                    (pad_t, pad_b),
                                                    (pad_l, pad_r)))

    if op.isDeconvolution:
        ret = _op.nn.conv2d_transpose(data=inexpr, weight=weights, **params)
    else:
        ret = _op.nn.conv2d(data=inexpr, weight=weights, **params)
    if op.hasBias:
        biases = etab.new_const(list(op.bias.floatValue))
        ret = _op.nn.bias_add(ret, biases)

    return ret


def _BatchnormLayerParams(op, inexpr, etab):
    """Convert a CoreML batch-normalization layer to ``nn.batch_norm``.

    Instance normalization is rejected. Only the normalized tensor (the first
    output of ``nn.batch_norm``) is returned.

    Raises
    ------
    tvm.error.OpNotImplemented
        If ``op.instanceNormalization`` is set.
    """
    if op.instanceNormalization:
        raise tvm.error.OpNotImplemented(
            'Operator "instance normalization" is not supported in frontend CoreML.')

    gamma = etab.new_const(list(op.gamma.floatValue))
    beta = etab.new_const(list(op.beta.floatValue))
    running_mean = etab.new_const(list(op.mean.floatValue))
    running_var = etab.new_const(list(op.variance.floatValue))
    normalized, _, _ = _op.nn.batch_norm(data=inexpr,
                                         gamma=gamma,
                                         beta=beta,
                                         moving_mean=running_mean,
                                         moving_var=running_var,
                                         epsilon=op.epsilon)
    return normalized


def _ActivationParams(op, inexpr, etab):
    """Convert a CoreML activation layer to Relay.

    ``op.WhichOneof('NonlinearityType')`` selects the nonlinearity. Its
    parameters are read from the sub-message of the same name.

    Raises
    ------
    tvm.error.OpNotImplemented
        If the nonlinearity has no lowering here.
    """
    whichActivation = op.WhichOneof('NonlinearityType')
    par = getattr(op, whichActivation)

    def _f32(value):
        """Build a float32 scalar constant."""
        return _expr.const(value, dtype='float32')

    def _per_channel(values):
        """Scalar constant for one value, else a (C, 1, 1) parameter tensor."""
        if len(values) == 1:
            return _f32(values[0])
        return etab.new_const(np.array(values).reshape((len(values), 1, 1)))

    if whichActivation == 'linear':
        return _op.add(_op.multiply(inexpr, _f32(par.alpha)), _f32(par.beta))
    if whichActivation == 'ReLU':
        return _op.nn.relu(inexpr)
    if whichActivation == 'leakyReLU':
        # nn.leaky_relu takes alpha as a plain float attribute, not an Expr.
        return _op.nn.leaky_relu(inexpr, alpha=par.alpha)
    if whichActivation == 'thresholdedReLU':
        alpha_tensor = _op.full_like(inexpr, fill_value=_f32(par.alpha))
        return _op.multiply(inexpr, _op.greater(inexpr, alpha_tensor).astype('float32'))
    if whichActivation == 'PReLU':
        # alpha is a WeightParams message: one shared slope or one per channel.
        slopes = list(par.alpha.floatValue)
        if len(slopes) == 1:
            return _op.nn.leaky_relu(inexpr, alpha=slopes[0])
        return _op.nn.prelu(inexpr, alpha=etab.new_const(np.array(slopes)))
    if whichActivation == 'tanh':
        return _op.tanh(inexpr)
    if whichActivation == 'scaledTanh':
        return _op.multiply(_op.tanh(_op.multiply(inexpr, _f32(par.beta))), _f32(par.alpha))
    if whichActivation == 'sigmoid':
        return _op.sigmoid(inexpr)
    if whichActivation == 'sigmoidHard':
        transformX = _op.add(_op.multiply(inexpr, _f32(par.alpha)), _f32(par.beta))
        return _op.clip(transformX, a_min=0., a_max=1.)
    if whichActivation == 'ELU':
        # ELU(x) = x for x >= 0, alpha * (exp(x) - 1) for x < 0.
        # Split x = relu(x) + min(x, 0); exp() only sees non-positive values.
        positive_part = _op.nn.relu(inexpr)
        negative_part = _op.negative(_op.nn.relu(_op.negative(inexpr)))
        return _op.add(positive_part,
                       _op.multiply(_op.subtract(_op.exp(negative_part), _f32(1)),
                                    _f32(par.alpha)))
    if whichActivation == 'softsign':
        # x / (1 + |x|), with |x| = relu(x) + relu(-x).
        abs_x = _op.add(_op.nn.relu(inexpr), _op.nn.relu(_op.negative(inexpr)))
        return _op.divide(inexpr, _op.add(_f32(1), abs_x))
    if whichActivation == 'softplus':
        return _op.log(_op.add(_op.exp(inexpr), _f32(1)))
    if whichActivation == 'parametricSoftplus':
        # alpha * log(1 + exp(beta * x)), with scalar or per-channel alpha/beta.
        alpha_expr = _per_channel(list(par.alpha.floatValue))
        beta_expr = _per_channel(list(par.beta.floatValue))
        softplus = _op.log(_op.add(_op.exp(_op.multiply(inexpr, beta_expr)), _f32(1)))
        return _op.multiply(softplus, alpha_expr)
    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported in frontend CoreML.'.format(whichActivation))


def _ScaleLayerParams(op, inexpr, etab):
    """Convert a CoreML scale layer: ``x * scale`` (``+ bias`` if present).

    Per NeuralNetwork.proto, ``shapeScale``/``shapeBias`` are rank 1 (``[1]``
    or ``[C]``) or rank 3 (``[1, H, W]`` or ``[C, H, W]``). Trailing unit
    axes are appended only up to rank 3. The constant therefore always has
    rank 3 and broadcasts against a 4-D input without growing its rank.
    """
    def _as_rank3(values, shape):
        """Reshape a flat weight list to ``shape`` padded with 1s to rank 3."""
        dims = list(shape)
        dims += [1] * (3 - len(dims))
        return np.array(list(values)).reshape(tuple(dims))

    scale = etab.new_const(_as_rank3(op.scale.floatValue, op.shapeScale))
    ret = _op.multiply(inexpr, scale)
    if op.hasBias:
        bias = etab.new_const(_as_rank3(op.bias.floatValue, op.shapeBias))
        ret = _op.add(ret, bias)
    return ret


def _PoolingLayerParams(op, inexpr, etab):
    """Convert a CoreML pooling layer (global or windowed; max or average).

    ``op.type`` 0 maps to max pooling and 1 to average pooling. Any other
    type is rejected. For windowed pooling the padding type is validated
    before the pooling type.
    """
    if op.globalPooling:
        global_pools = {0: _op.nn.global_max_pool2d, 1: _op.nn.global_avg_pool2d}
        if op.type not in global_pools:
            raise tvm.error.OpNotImplemented(
                'Only Max and Average Pooling are supported in frontend CoreML.')
        return global_pools[op.type](inexpr)

    params = {'pool_size': list(op.kernelSize),
              'strides': list(op.stride)}

    padding_type = op.WhichOneof('PoolingPaddingType')
    if padding_type == 'valid':
        border = op.valid.paddingAmounts.borderAmounts
        if border:
            assert len(border) == 2
            # Order is (top, left, bottom, right).
            pads = [border[0].startEdgeSize, border[1].startEdgeSize,
                    border[0].endEdgeSize, border[1].endEdgeSize]
            if any(pads):
                params['padding'] = pads
    elif padding_type == 'includeLastPixel':
        # NOTE(review): the original author was unsure this mapping is correct.
        params['padding'] = list(op.includeLastPixel.paddingAmounts)
        params['ceil_mode'] = True
    else:
        msg = 'PoolingPaddingType {} is not supported in operator Pooling.'
        raise tvm.error.OpAttributeUnImplemented(msg.format(padding_type))

    windowed_pools = {0: _op.nn.max_pool2d, 1: _op.nn.avg_pool2d}
    if op.type not in windowed_pools:
        raise tvm.error.OpNotImplemented(
            'Only Max and Average Pooling are supported in CoreML.')
    return windowed_pools[op.type](inexpr, **params)


def _SoftmaxLayerParams(op, inexpr, etab):
    """Convert a CoreML softmax layer: batch-flatten, then ``nn.softmax``."""
    flattened = _op.nn.batch_flatten(inexpr)
    return _op.nn.softmax(flattened)


def _InnerProductLayerParams(op, inexpr, etab):
    """Convert a CoreML inner-product layer to ``nn.dense`` (+ ``bias_add``).

    The flat weights are reshaped to ``(outputChannels, inputChannels)``.
    """
    weight_matrix = np.array(op.weights.floatValue).reshape(
        (op.outputChannels, op.inputChannels))
    dense = _op.nn.dense(data=inexpr,
                         weight=etab.new_const(weight_matrix),
                         units=op.outputChannels)
    if not op.hasBias:
        return dense
    return _op.nn.bias_add(dense, etab.new_const(np.array(op.bias.floatValue)))


def _AddLayerParams(op, inexpr, etab):
    """Convert a CoreML elementwise add layer.

    All inputs are summed. When ``op.alpha`` is non-zero, the scalar alpha is
    added to the result as well.
    """
    inputs = inexpr if isinstance(inexpr, list) else [inexpr]
    ret = inputs[0]
    for other in inputs[1:]:
        ret = _op.add(ret, other)
    # alpha may be negative; only alpha == 0 is a no-op.
    if op.alpha != 0:
        ret = _op.add(ret, _expr.const(op.alpha, dtype='float32'))
    return ret


def _MultiplyLayerParams(op, inexpr, etab):
    """Convert a CoreML elementwise multiply layer.

    All inputs are multiplied together. The product is then scaled by
    ``op.alpha`` unless alpha is exactly 1.
    """
    operands = list(inexpr) if isinstance(inexpr, list) else [inexpr]
    product, remaining = operands[0], operands[1:]
    for factor in remaining:
        product = _op.multiply(product, factor)
    if op.alpha == 1:
        return product
    return _op.multiply(product, _expr.const(op.alpha, dtype='float32'))


def _ConcatLayerParams(op, inexpr, etab):
    """Convert a CoreML concat layer: concatenate the inputs along axis 1.

    Raises
    ------
    tvm.error.OpNotImplemented
        For sequence concatenation.
    """
    if op.sequenceConcat:
        raise tvm.error.OpNotImplemented(
            'Operator Sequence Concat is not supported in frontend CoreML.')
    pieces = inexpr if isinstance(inexpr, list) else [inexpr]
    return _op.concatenate(pieces, axis=1)


def _FlattenLayerParams(op, inexpr, etab):
    """Convert a CoreML flatten layer to ``nn.batch_flatten``.

    When ``op.mode == 1``, the first two axes are kept and the rest are merged.
    Axes 1 and 2 are then swapped before flattening, which moves axis 1 to the
    end of the flattened order.
    """
    data = inexpr
    if op.mode == 1:
        merged = _op.reshape(inexpr, newshape=(0, 0, -1))
        data = _op.transpose(merged, axes=(0, 2, 1))
    return _op.nn.batch_flatten(data)


def _PaddingLayerParams(op, inexpr, etab):
    """Convert a CoreML padding layer to ``nn.pad``.

    Only constant padding is supported. The constant fill value is forwarded
    as ``pad_value``, so non-zero fills are honoured. ``borderAmounts[0]``
    pads the third axis of the input and ``borderAmounts[1]`` pads the fourth.

    Raises
    ------
    tvm.error.OpNotImplemented
        For any non-constant padding type.
    """
    if op.WhichOneof('PaddingType') != 'constant':
        raise tvm.error.OpNotImplemented(
            'Non-constant padding is not supported in frontend CoreML.')

    border = op.paddingAmounts.borderAmounts
    pad_t, pad_b = border[0].startEdgeSize, border[0].endEdgeSize
    pad_l, pad_r = border[1].startEdgeSize, border[1].endEdgeSize
    return _op.nn.pad(data=inexpr,
                      pad_width=((0, 0),
                                 (0, 0),
                                 (pad_t, pad_b),
                                 (pad_l, pad_r)),
                      pad_value=float(op.constant.value))


def _PermuteLayerParams(op, inexpr, etab):
    """Convert a CoreML permute layer to ``transpose`` in ``op.axis`` order."""
    permutation = tuple(op.axis)
    return _op.transpose(inexpr, axes=permutation)


def _UpsampleLayerParams(op, inexpr, etab):
    """Convert a CoreML upsample layer to ``nn.upsampling``.

    Only equal height/width scaling factors are supported. ``op.mode == 0``
    selects nearest-neighbour interpolation; any other mode selects bilinear.

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        If the two scaling factors differ.
    """
    scale_h, scale_w = op.scalingFactor[0], op.scalingFactor[1]
    if scale_h != scale_w:
        # The exception class is spelled OpAttributeUnImplemented (capital I);
        # a misspelled attribute would surface as AttributeError instead.
        raise tvm.error.OpAttributeUnImplemented(
            'Upsample height and width must be equal.')
    interpolationMode = 'nearest_neighbor' if op.mode == 0 else 'bilinear'
    return _op.nn.upsampling(inexpr, scale=scale_h, method=interpolationMode)


def _L2NormalizeLayerParams(op, inexpr, etab):
    """Convert a CoreML L2-normalize layer: normalize over axis 1 with ``op.epsilon``."""
    norm_axes = [1]
    return _op.nn.l2_normalize(inexpr, eps=op.epsilon, axis=norm_axes)


def _LRNLayerParams(op, inexpr, etab):
    """Convert a CoreML local response normalization layer to ``nn.lrn``.

    CoreML's ``k`` maps to Relay's ``bias``; normalization runs over axis 1.
    """
    return _op.nn.lrn(data=inexpr,
                      size=op.localSize,
                      bias=op.k,
                      alpha=op.alpha,
                      beta=op.beta,
                      axis=1)  # default layout is nchw


def _AverageLayerParams(op, inexpr, etab):
    """Element-wise mean of two or more inputs.

    Raises
    ------
    ValueError
        If ``inexpr`` is not a list of at least two expressions.
    """
    if not (isinstance(inexpr, list) and len(inexpr) >= 2):
        raise ValueError("Expect minimum 2 inputs")
    total = inexpr[0]
    for term in inexpr[1:]:
        total = _op.add(total, term)
    return total / _expr.const(len(inexpr), dtype='float32')


def _MaxLayerParams(op, inexpr, etab):
    """Element-wise maximum over two or more inputs.

    Raises
    ------
    ValueError
        If ``inexpr`` is not a list of at least two expressions.
    """
    if not (isinstance(inexpr, list) and len(inexpr) >= 2):
        raise ValueError("Expect minimum 2 inputs")
    running_max = inexpr[0]
    for candidate in inexpr[1:]:
        running_max = _op.maximum(running_max, candidate)
    return running_max


def _MinLayerParams(op, inexpr, etab):
    """Element-wise minimum over two or more inputs.

    Raises
    ------
    ValueError
        If ``inexpr`` is not a list of at least two expressions.
    """
    if not (isinstance(inexpr, list) and len(inexpr) >= 2):
        raise ValueError("Expect minimum 2 inputs")
    running_min = inexpr[0]
    for candidate in inexpr[1:]:
        running_min = _op.minimum(running_min, candidate)
    return running_min


# Maps the protobuf class name of a CoreML layer / preprocessor message to the
# converter that lowers it to Relay (looked up by coreml_op_to_relay).
_convert_map = dict(
    NeuralNetworkMeanImage=_NeuralNetworkMeanImage,
    NeuralNetworkImageScaler=_NeuralNetworkImageScaler,
    ConvolutionLayerParams=_ConvolutionLayerParams,
    BatchnormLayerParams=_BatchnormLayerParams,
    ActivationParams=_ActivationParams,
    ScaleLayerParams=_ScaleLayerParams,
    PoolingLayerParams=_PoolingLayerParams,
    SoftmaxLayerParams=_SoftmaxLayerParams,
    InnerProductLayerParams=_InnerProductLayerParams,
    AddLayerParams=_AddLayerParams,
    MultiplyLayerParams=_MultiplyLayerParams,
    FlattenLayerParams=_FlattenLayerParams,
    ConcatLayerParams=_ConcatLayerParams,
    PaddingLayerParams=_PaddingLayerParams,
    PermuteLayerParams=_PermuteLayerParams,
    UpsampleLayerParams=_UpsampleLayerParams,
    L2NormalizeLayerParams=_L2NormalizeLayerParams,
    LRNLayerParams=_LRNLayerParams,
    AverageLayerParams=_AverageLayerParams,
    MaxLayerParams=_MaxLayerParams,
    MinLayerParams=_MinLayerParams,
)


# SAME padding: https://www.tensorflow.org/api_guides/python/nn
def get_pad_value(data, kernel, stride):
    """Compute the (before, after) padding pair for SAME padding.

    The output size is ``ceil(data / stride)``; the total padding needed to
    reach it is split evenly, with any odd remainder going after.

    Parameters
    ----------
    data:
        Input size along one spatial dimension.

    kernel:
        Kernel size along that dimension.

    stride:
        Stride along that dimension.

    Returns
    -------
        ``(pad_before, pad_after)`` tuple.
    """
    out_size = int(math.ceil(data / float(stride)))
    total = max(0, (out_size - 1) * stride + kernel - data)
    before = total // 2
    return before, total - before


def coreml_op_to_relay(op, inname, outname, etab):
    """Lower one CoreML op to Relay and record its result in the expression table.

    Parameters
    ----------
    op: a coreml protobuf bit
        Its class name selects the converter from ``_convert_map``.

    inname : str or list of str
        Name(s) of the input Relay expression(s) already stored in ``etab``.

    outname : str
        Name under which the result is stored. Nothing is stored when empty.

    etab : relay.frontend.common.ExprTable
        The global expression table to be updated.

    Raises
    ------
    tvm.error.OpNotImplemented
        If no converter is registered for the op's class.
    """
    classname = type(op).__name__
    converter = _convert_map.get(classname)
    if converter is None:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported in frontend CoreML.'.format(classname))

    insym = (etab.get_expr(inname) if isinstance(inname, _base.string_types)
             else [etab.get_expr(name) for name in inname])
    ret = converter(op, insym, etab)
    if outname:
        etab.set_expr(outname, ret, force_override=True)


def _preprocessing_targets(input_names, feature_name):
    """Return the names of the model inputs a preprocessing entry applies to.

    Parameters
    ----------
    input_names : list of str
        Names of the model inputs, in declaration order.

    feature_name : str
        The ``featureName`` of the preprocessing entry.

    Returns
    -------
    list of str
        If ``feature_name`` is empty, the sole model input. This is only
        allowed for single-input models. Otherwise, every input whose name
        equals ``feature_name``, which may be none.
    """
    if not feature_name:
        # With several inputs a preprocessor may target only one of them, so the
        # target must be named explicitly. See the verify_image_scaler unit test
        # in test_forward.py for CoreML.
        assert len(input_names) == 1, \
            "preprocessing featureName must be set when the model has multiple inputs"
        return list(input_names)
    return [name for name in input_names if name == feature_name]


def from_coreml(model, shape=None):
    """Convert from coreml model into Relay Function.

    Parameters
    ----------
    model:
        coremltools.models.MLModel of a NeuralNetworkClassifier,
        NeuralNetwork or NeuralNetworkRegressor

    shape : dict of str to int list/tuple, optional
        The input shapes

    Returns
    -------
    mod : tvm.relay.Module
        The relay module for compilation. Only the first model output is
        returned for now.

    params : dict of str to tvm.NDArray
        The parameter dict to be used by Relay.
    """
    try:
        import coremltools as cm
    except ImportError:
        raise ImportError('The coremltools package must be installed')

    assert isinstance(model, cm.models.MLModel)
    spec = model.get_spec()
    modeltype = spec.WhichOneof('Type')
    assert modeltype in ['neuralNetworkClassifier', 'neuralNetwork', 'neuralNetworkRegressor']
    cc = getattr(spec, modeltype)

    etab = ExprTable()
    for i in spec.description.input:
        input_shape = shape[i.name] if shape is not None and i.name in shape else None
        etab.set_expr(i.name, _expr.var(i.name, shape=input_shape))

    # Every preprocessor kind (scaler, mean image, ...) resolves its target
    # input(s) the same way and rewrites them in place.
    input_names = [i.name for i in spec.description.input]
    for pp in cc.preprocessing:
        whichpp = pp.WhichOneof('preprocessor')
        ppmethod = getattr(pp, whichpp)
        for name in _preprocessing_targets(input_names, pp.featureName):
            coreml_op_to_relay(ppmethod, name, name, etab)

    for layer in cc.layers:
        layertype = layer.WhichOneof('layer')
        layerop = getattr(layer, layertype)
        assert len(layer.output) == 1
        if len(layer.input) == 1:
            coreml_op_to_relay(layerop, layer.input[0], layer.output[0], etab)
        else:
            coreml_op_to_relay(layerop, list(layer.input), layer.output[0], etab)

    outexpr = [etab.get_expr(o.name) if o.name in etab.exprs else _expr.var(o.name)
               for o in spec.description.output]
    # for now return first output
    outexpr = outexpr[0]
    func = _expr.Function(analysis.free_vars(outexpr), outexpr)
    params = {k: _nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
    return _module.Module.from_expr(func), params