Commit 53511bf1 by Mark Rogers Committed by Tianqi Chen

Unified error handling in NNVM and Relay frontends (#2828)

parent e3206aa8
......@@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from nnvm import symbol as _sym
from nnvm.frontend.common import get_nnvm_op, Renamer, AttrConverter as AttrCvt
from .common import get_nnvm_op
from .onnx_caffe2_utils import dimension_picker, dimension_constraint, infer_channels, revert_caffe2_pad
from . import onnx
......@@ -73,8 +73,8 @@ class Caffe2OpConverter(object):
if hasattr(cls, '_impl'):
return getattr(cls, '_impl')
raise NotImplementedError('{} not implemented'.format(
cls.__name__))
raise tvm.error.OpNotImplemented(
'Operator {} is not implemented in frontend Caffe2.'.format(cls.__name__))
_caffe2_internal_args = {
......@@ -176,8 +176,7 @@ class Concat(Caffe2OpConverter):
return 1
if order == 'NHWC':
return 3
raise RuntimeError(
"Unsupported storage order: {} in caffe2".format(order))
raise tvm.error.OpAttributeInvalid('Value {} in attribute {} of operator {} is not valid.'.format(order, 'order', 'Concat'))
return AttrCvt(
op_name='concatenate',
......@@ -427,8 +426,8 @@ class Caffe2NetDef(object):
# Add a sanitizing step to convert all byte strings in args to strings
sym = convert_map[op_type](inputs, args, self._params)
else:
raise NotImplementedError(
"Operator {} not implemented.".format(op_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Caffe2.'.format(op_type))
return sym
......
......@@ -7,9 +7,25 @@ from .._base import string_types
def get_nnvm_op(op_name):
op = getattr(_sym, op_name)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
raise OpNotImplemented(
'Operator {} is not supported.'.format(op))
return op
def required_attr(attr, key, op_name):
assert isinstance(attr, dict)
if key not in attr:
raise OpAttributeRequired(
'Required attribute {} not found in operator {}'.format(key, op_name))
return attr[key]
def parse_tshape(tshape):
"""Parse tshape in string."""
return [int(x.strip()) for x in tshape.strip('()').split(',')]
def parse_bool_str(attr, key, default='False'):
"""Parse bool string to boolean."""
return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']
class Renamer(object):
"""A simply renamer for operators.
......
......@@ -2,11 +2,10 @@
"""CoreML frontend."""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from .common import SymbolTable
from .. import symbol as _sym
from .._base import string_types
from .common import SymbolTable
__all__ = ['from_coreml']
......@@ -83,7 +82,8 @@ def BatchnormLayerParams(op, insym, symtab):
"""Get layer of batchnorm parameter"""
# this changes the symbol
if op.instanceNormalization:
raise NotImplementedError("instance normalization not implemented")
msg = 'Operator "instance normalization" is not supported in frontend CoreML.'
raise tvm.error.OpNotImplemented(msg)
else:
params = {'gamma':symtab.new_const(list(op.gamma.floatValue)),
'beta':symtab.new_const(list(op.beta.floatValue)),
......@@ -136,7 +136,8 @@ def ActivationParams(op, insym, symtab):
betasym = symtab.new_const(beta)
return _sym.broadcast_mul(_sym.log(_sym.broadcast_add(
_sym.exp(insym), betasym)), alphasym)
raise NotImplementedError('%s not implemented' % whichActivation)
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend CoreML.'.format(whichActivation))
def ScaleLayerParams(op, insym, symtab):
"""Scale layer params."""
......@@ -158,7 +159,8 @@ def PoolingLayerParams(op, insym, symtab):
return _sym.global_max_pool2d(insym)
if op.type == 1:
return _sym.global_avg_pool2d(insym)
raise NotImplementedError("Only max and average pooling implemented")
raise tvm.error.OpNotImplemented(
'Operator pooling (not max or average) is not supported in frontend CoreML.')
else:
params = {'pool_size':list(op.kernelSize),
......@@ -178,7 +180,8 @@ def PoolingLayerParams(op, insym, symtab):
params['padding'] = padding
params['ceil_mode'] = True
else:
raise NotImplementedError("Other convolution padding not implemented")
msg = 'Value {} in attribute PoolingPaddingType of operator Pooling is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(op.WhichOneof('PoolingPaddingType')))
# consume padding layer
if symtab.in_padding:
......@@ -190,7 +193,8 @@ def PoolingLayerParams(op, insym, symtab):
return _sym.max_pool2d(insym, **params)
if op.type == 1:
return _sym.avg_pool2d(insym, **params)
raise NotImplementedError("Only max and average pooling implemented")
msg = 'Operator pooling (not max or average) is not supported in frontend CoreML.'
raise tvm.error.OpNotImplemented(msg)
def SoftmaxLayerParams(op, insym, symtab):
return _sym.softmax(_sym.flatten(insym))
......@@ -229,7 +233,8 @@ def ConcatLayerParams(op, insyms, symtab):
if not isinstance(insyms, list):
insyms = [insyms]
if op.sequenceConcat:
raise NotImplementedError("Sequence Concat not supported")
raise tvm.error.OpNotImplemented(
'Operator Sequence Concat is not supported in frontend CoreML.')
ret = _sym.concatenate(*insyms, axis=1)
return ret
......@@ -243,14 +248,16 @@ def PaddingLayerParams(op, insym, symtab):
if op.WhichOneof('PaddingType') == 'constant':
constant = op.constant
if constant.value != 0:
raise NotImplementedError("Padding value {} not supported.".format(constant.value))
msg = 'Value {} in attribute "padding value" of operator Padding is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(constant.value))
padding = [b.startEdgeSize for b in op.paddingAmounts.borderAmounts]
padding2 = [b.endEdgeSize for b in op.paddingAmounts.borderAmounts]
for i, j in zip(padding, padding2):
assert i == j
symtab.set_padding(padding)
else:
raise NotImplementedError("Only constant padding is supported now.")
raise tvm.error.OpNotImplemented(
'Operator "non-constant padding" is not supported in frontend CoreML.')
return insym
def PermuteLayerParams(op, insym, symtab):
......@@ -259,8 +266,8 @@ def PermuteLayerParams(op, insym, symtab):
def UpsampleLayerParams(op, insym, symtab):
if op.scalingFactor[0] != op.scalingFactor[1]:
raise NotImplementedError("Upsampling only supported with same \
height and width scaling factor.")
raise tvm.error.OpAttributeInvalid(
'Height and width scaling factors of Upsample operator must be equal.')
interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR'
return _sym.upsampling(insym, scale=op.scalingFactor[0], method=interpolationMode)
......@@ -341,7 +348,8 @@ def coreml_op_to_nnvm(op, inname, outname, symtab):
"""
classname = type(op).__name__
if classname not in _convert_map:
raise NotImplementedError("%s is not supported" % (classname))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend CoreML.'.format(classname))
if isinstance(inname, string_types):
insym = symtab.get_var(inname)
else:
......
......@@ -74,7 +74,8 @@ def _convert_activation(insym, keras_layer, _):
if act_type == 'hard_sigmoid':
transformX = (0.2 * insym) + 0.5
return _sym.clip(transformX, a_min=0, a_max=1)
raise TypeError("Unsupported activation type : {}".format(act_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Keras.'.format(act_type))
def _convert_advanced_activation(insym, keras_layer, symtab):
......@@ -100,7 +101,8 @@ def _convert_advanced_activation(insym, keras_layer, symtab):
theta = keras_layer.theta if hasattr(keras_layer, "theta") else 1.0
theta_tensor = _sym.full_like(insym[0], fill_value=float(theta))
return _sym.elemwise_mul(insym[0], _sym.greater(insym[0], theta_tensor, out_type="float32"))
raise TypeError("Unsupported advanced activation type : {}".format(act_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Keras.'.format(act_type))
def _convert_merge(insym, keras_layer, _):
......@@ -113,12 +115,9 @@ def _convert_merge(insym, keras_layer, _):
ret = _sym.elemwise_sub(ret, insym[i])
elif merge_type == 'Multiply':
ret = _sym.elemwise_mul(ret, insym[i])
elif merge_type == 'Average':
raise NotImplementedError('Average merge not implemented')
elif merge_type == 'Maximum':
raise NotImplementedError('Maximum merge not implemented')
else:
raise TypeError("Unsupported merge type : {}".format(merge_type))
raise tvm.error.OpNotImplemented(
'Operator {} Merge is not supported in frontend Keras.'.format(merge_type))
return ret
......@@ -135,7 +134,8 @@ def _convert_dense(insym, keras_layer, symtab):
if input_dim > 2:
input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:
raise ValueError("Cannot flatten the inputs with shape.", input_shape, " for dense.")
msg = 'Value {} in attribute "input_shape" of operator Dense is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(input_shape))
insym = _sym.squeeze(insym, axis=0)
out = _sym.dense(data=insym, **params)
# defuse activation
......@@ -199,7 +199,8 @@ def _convert_convolution(insym, keras_layer, symtab):
else:
insym = _sym.pad(data=insym, pad_width=((0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
else:
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
msg = 'Value {} in attribute "padding" of operator Convolution is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(keras_layer.padding))
if is_deconv:
out = _sym.conv2d_transpose(data=insym, **params)
else:
......@@ -240,7 +241,8 @@ def _convert_separable_convolution(insym, keras_layer, symtab):
insym = _sym.pad(data=insym, pad_width=(
(0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
else:
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
msg = 'Value {} in attribute "padding" of operator Separable Convolution is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(keras_layer.padding))
depthconv = _sym.conv2d(data=insym, **params0)
# pointwise conv
weight1 = weightList[1].transpose([3, 2, 0, 1])
......@@ -294,13 +296,15 @@ def _convert_pooling(insym, keras_layer, symtab):
pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
params['padding'] = [pad_t, pad_l, pad_b, pad_r]
else:
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
msg = 'Value {} in attribute "padding" of operator Pooling is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(keras_layer.padding))
if pool_type == 'MaxPooling2D':
return _sym.max_pool2d(insym, **params)
if pool_type == 'AveragePooling2D':
# TODO: in keras, padded zeros are not calculated
return _sym.avg_pool2d(insym, **params)
raise TypeError("Unsupported pooling type : {}".format(keras_layer))
msg = 'Value {} in attribute "padding" of operator Pooling is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(keras_layer.padding))
def _convert_upsample(insym, keras_layer, _):
......@@ -312,30 +316,30 @@ def _convert_upsample(insym, keras_layer, _):
elif upsample_type == "UpSampling2D":
h, w = keras_layer.size
if h != w:
raise TypeError("Unsupported upsampling type with different axes size : {}"
.format(keras_layer.size))
raise tvm.error.OpAttributeInvalid(
'Upsample height ({}) must equal width ({})'.format(h, w))
params = {'scale': h}
elif upsample_type == "UpSampling3D":
h, w, d = keras_layer.size
if h != w or w != d:
raise TypeError("Unsupported upsampling type with different axes size : {}"
.format(keras_layer.size))
raise tvm.error.OpAttributeInvalid(
'Upsample height ({}), width ({}), and depth ({}) must be equal.'.format(h, w, d))
params = {'scale': h}
else:
raise TypeError("Unsupported upsampling type : {}".format(upsample_type))
msg = 'Operator {} is not supported in frontend Keras.'
raise tvm.error.OpNotImplemented(msg.format(upsample_type))
return _sym.upsampling(insym, **params)
def _convert_cropping(insym, keras_layer, _):
_check_data_format(keras_layer)
crop_type = type(keras_layer).__name__
if crop_type == "Cropping1D":
raise NotImplementedError("Cropping1D not implemented")
elif crop_type == "Cropping2D":
if crop_type == "Cropping2D":
(_, in_h, in_w, _) = keras_layer.input_shape
((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping
else:
raise TypeError("Unrecognized cropping type : {}".format(crop_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Keras.'.format(crop_type))
int32_max = np.iinfo(np.int32).max
return _sym.strided_slice(insym, begin=[0, 0, crop_t, crop_l],
end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])
......@@ -379,13 +383,13 @@ def _convert_padding(insym, keras_layer, _):
top, bottom = padding[0]
left, right = padding[1]
else:
raise ValueError("Unrecognized padding option: {}".format(str(padding)))
msg = 'Value {} in attribute "padding" of operator {} is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(str(padding), padding_type))
else:
raise ValueError("Unrecognized padding option: {}".format(str(padding)))
elif padding_type == 'ZeroPadding1D':
raise NotImplementedError("ZeroPadding1D not implemented")
msg = 'Value {} in attribute "padding" of operator {} is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(str(padding), padding_type))
else:
raise ValueError("Unrecognized padding type: {}".format(padding_type))
raise tvm.error.OpNotImplemented('Operator {} is not supported in frontend Keras.')
return _sym.pad(data=insym, pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))
......@@ -592,8 +596,10 @@ _convert_map = {
def _check_unsupported_layers(model):
for layer in model.layers:
if type(layer).__name__ not in _convert_map:
raise ValueError("Keras layer {} not supported.".format(type(layer).__name__))
op_name = type(layer).__name__
if op_name not in _convert_map:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Keras.'.format(op_name))
def _as_list(arr):
"""Force being a list, ignore if already is."""
......@@ -618,9 +624,11 @@ def keras_op_to_nnvm(insym, keras_layer, outname, symtab):
symtab : nnvm.frontend.common.SymbolTable
The global symbol table to be updated
"""
if type(keras_layer).__name__ not in _convert_map:
raise NotImplementedError("{} is not supported".format((type(keras_layer).__name__)))
outs = _convert_map[type(keras_layer).__name__](insym, keras_layer, symtab)
op_name = type(keras_layer).__name__
if op_name not in _convert_map:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Keras.'.format(op_name))
outs = _convert_map[op_name](insym, keras_layer, symtab)
outs = _as_list(outs)
for t_idx, out in enumerate(outs):
......
......@@ -397,7 +397,8 @@ class Upsample(OnnxOpConverter):
elif mode == b'linear':
method = "BILINEAR"
else:
raise ValueError("Invalid ONNX upsample mode: {}".format(mode))
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
return _sym.upsampling(inputs[0], scale=int(scales[-1]), method=method, layout='NCHW')
......@@ -922,8 +923,8 @@ class GraphProto(object):
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
else:
raise NotImplementedError(
"Operator {} not implemented.".format(op_name))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend ONNX.')
return sym
def _fix_outputs(self, op_name, outputs):
......
......@@ -68,7 +68,8 @@ def _dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
raise NotImplementedError("Only 2d kernel supported.")
raise tvm.error.OpAttributeUnimplemented(
'Non-2D kernels are not supported for operator {}.'.format(prefix))
return _impl
def _dimension_constraint():
......@@ -129,7 +130,8 @@ def _pooling(name):
attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3])
attr['strides'] = (attr['strides'][2], attr['strides'][3])
else:
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
msg = 'Value {} in attribute "data_format" of operator Pooling is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
tmp_shape = attr['_input_shapes'][inputs[0]]
......@@ -158,7 +160,8 @@ def _pooling(name):
attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))
msg = 'Value {} in attribute "padding" of operator Pooling is not valid.'
raise tvm.error.OpAttributeUnimplemented(msg.format(attr['padding']))
if name == "avg_pool":
attr['count_include_pad'] = False
......@@ -232,7 +235,8 @@ def _conv(opname):
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
attr['strides'] = (attr['strides'][2], attr['strides'][3])
else:
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
msg = 'Value {} in attribute "data_format" of operator Conv is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if opname == 'depthwise':
......@@ -276,7 +280,8 @@ def _conv(opname):
attr['padding'] = [0, 0]
else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))
msg = 'Value {} in attribute "padding" of operator Conv is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
if 'kernel_layout' not in attr:
if opname == 'conv':
......@@ -432,7 +437,8 @@ def _reshape():
op_name="reshape",
extras={'shape':tuple(params_new[0].asnumpy().flatten())},
ignores=['Tshape'])(inputs, attr)
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
raise tvm.error.OpAttributeUnimplemented(
'Attribute "dynamic shape" of operator Reshape is not supported.')
return _impl
def _bias_add():
......@@ -736,7 +742,8 @@ def _pad(name):
if padlist_key in params:
padlist = params.pop(padlist_key).asnumpy()
else:
raise RuntimeError("Required parameter {} not fount.".format(padlist_key))
raise tvm.error.OpAttributeRequired(
'Required attribute "{}" not found in operator Pad.'.format(padlist_key))
paddings = tuple([tuple(l) for l in padlist])
attr['pad_width'] = paddings
attr['pad_value'] = 0
......@@ -1188,8 +1195,9 @@ class GraphProto(object):
missing_operators = self._parse_import_prerequisites(graph)
if missing_operators:
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))
msg = 'The following operators are not supported in frontend TensorFlow: {}'
ops = str(list(missing_operators)).strip('[,]')
raise tvm.error.OpNotImplemented(msg.format(ops))
for node in graph.node:
if node.op == 'Placeholder':
......@@ -1529,7 +1537,8 @@ class GraphProto(object):
self._params, graph,
convert_map_rnn)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend TensorFlow.'.format(op_name))
return sym
def _fix_extranodes(self, op_name, attr, inputs):
......
# pylint: disable=import-self, invalid-name, line-too-long, unused-argument
"""Caffe2 frontend"""
from __future__ import absolute_import as _abs
import tvm
from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
......@@ -15,7 +16,8 @@ def dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
raise NotImplementedError("Only 2d kernel supported.")
raise tvm.error.OpAttributeUnimplemented(
'Non-2D kernels are not supported for operator {}2d'.format(prefix))
return _impl
......@@ -27,7 +29,8 @@ def revert_caffe2_pad(pads):
elif len(pads) == 2:
pass
else:
raise ValueError("Invalid caffe2 type padding: {}".format(pads))
raise tvm.error.OpAttributeInvalid(
'Number of pads must equal 2 or 4.')
return pads
......@@ -103,8 +106,8 @@ class Caffe2OpConverter(object):
if hasattr(cls, '_impl'):
return getattr(cls, '_impl')
raise NotImplementedError('{} not implemented'.format(
cls.__name__))
raise tvm.error.OpNotInplemented(
'Operator {} is not supported in frontend Caffe2.'.format(cls.__name__))
_caffe2_internal_args = [
......@@ -224,8 +227,8 @@ class Concat(Caffe2OpConverter):
return 1
if order == 'NHWC':
return 3
raise RuntimeError(
"Unsupported storage order: {} in caffe2".format(order))
raise tvm.error.OpAttributeUnimplemented(
'Order {} is not supported in operator Concat.'.format(order))
return AttrCvt(
op_name='concatenate',
......@@ -517,8 +520,8 @@ class Caffe2NetDef(object):
# Add a sanitizing step to convert all byte strings in args to strings
func = convert_map[op_type](inputs, args, self._params)
else:
raise NotImplementedError(
"Operator {} not implemented.".format(op_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Caffe2.'.format(op_type))
return func
......
# pylint: disable=invalid-name, import-self, unused-argument, unused-variable, inconsistent-return-statements
"""CoreML frontend."""
from __future__ import absolute_import as _abs
import tvm
import numpy as np
from .. import ir_pass
from .. import expr as _expr
......@@ -81,7 +82,8 @@ def _BatchnormLayerParams(op, inexpr, etab):
"""Get layer of batchnorm parameter"""
# this changes the symbol
if op.instanceNormalization:
raise NotImplementedError("instance normalization not implemented")
raise tvm.error.OpNotImplemented(
'Operator "instance normalization" is not supported in frontend CoreML.')
else:
params = {'gamma':etab.new_const(list(op.gamma.floatValue)),
'beta':etab.new_const(list(op.beta.floatValue)),
......@@ -142,7 +144,8 @@ def _ActivationParams(op, inexpr, etab):
alpha_expr = etab.new_const(alpha)
beta_expr = etab.new_const(beta)
return _op.multiply(_op.log(_op.add(_op.exp(inexpr), beta_expr)), alpha_expr)
raise NotImplementedError('%s not implemented' % whichActivation)
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend CoreML.'.format(whichActivation))
def _ScaleLayerParams(op, inexpr, etab):
......@@ -164,7 +167,8 @@ def _PoolingLayerParams(op, inexpr, etab):
return _op.nn.global_max_pool2d(inexpr)
if op.type == 1:
return _op.nn.global_avg_pool2d(inexpr)
raise NotImplementedError("Only max and average pooling implemented")
raise tvm.error.OpNotImplemented(
'Only Max and Average Pooling are supported in frontend CoreML.')
else:
params = {'pool_size':list(op.kernelSize),
......@@ -184,7 +188,9 @@ def _PoolingLayerParams(op, inexpr, etab):
params['padding'] = padding
params['ceil_mode'] = True
else:
raise NotImplementedError("Other convolution padding not implemented")
msg = 'PoolingPaddingType {} is not supported in operator Pooling.'
op_name = op.WhichOneof('PoolingPaddingType')
raise tvm.error.OpAttributeUnimplemented(msg.format(op_name))
# consume padding layer
if etab.in_padding:
......@@ -196,7 +202,8 @@ def _PoolingLayerParams(op, inexpr, etab):
return _op.nn.max_pool2d(inexpr, **params)
if op.type == 1:
return _op.nn.avg_pool2d(inexpr, **params)
raise NotImplementedError("Only max and average pooling implemented")
raise tvm.error.OpNotImplemented(
'Only Max and Average Pooling are supported in CoreML.')
def _SoftmaxLayerParams(op, inexpr, etab):
......@@ -239,7 +246,8 @@ def _ConcatLayerParams(op, inexpr, etab):
if not isinstance(inexpr, list):
inexpr = [inexpr]
if op.sequenceConcat:
raise NotImplementedError("Sequence Concat not supported")
raise tvm.error.OpNotImplemented(
'Operator Sequence Concat is not supported in frontend CoreML.')
ret = _op.concatenate(inexpr, axis=1)
return ret
......@@ -255,14 +263,16 @@ def _PaddingLayerParams(op, inexpr, etab):
if op.WhichOneof('PaddingType') == 'constant':
constant = op.constant
if constant.value != 0:
raise NotImplementedError("Padding value {} not supported.".format(constant.value))
raise tvm.error.OpAttributeUnimplemented(
'{} is not supported in operator Padding.'.format(constant.value))
padding = [b.startEdgeSize for b in op.paddingAmounts.borderAmounts]
padding2 = [b.endEdgeSize for b in op.paddingAmounts.borderAmounts]
for i, j in zip(padding, padding2):
assert i == j
etab.set_padding(padding)
else:
raise NotImplementedError("Only constant padding is supported now.")
raise tvm.error.OpNotImplemented(
'Non-constant padding is not supported in frontend CoreML.')
return inexpr
......@@ -273,8 +283,8 @@ def _PermuteLayerParams(op, inexpr, etab):
def _UpsampleLayerParams(op, inexpr, etab):
if op.scalingFactor[0] != op.scalingFactor[1]:
raise NotImplementedError("Upsampling only supported with same \
height and width scaling factor.")
raise tvm.error.OpAttributeUnimplemented(
'Upsample height and width must be equal.')
interpolationMode = 'NEAREST_NEIGHBOR' if op.mode == 0 else 'BILINEAR'
return _op.nn.upsampling(inexpr, scale=op.scalingFactor[0], method=interpolationMode)
......@@ -364,7 +374,8 @@ def coreml_op_to_relay(op, inname, outname, etab):
"""
classname = type(op).__name__
if classname not in _convert_map:
raise NotImplementedError("%s is not supported" % (classname))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend CoreML.'.format(classname))
if isinstance(inname, _base.string_types):
insym = etab.get_expr(inname)
else:
......
......@@ -2,6 +2,7 @@
"""Keras frontend."""
from __future__ import absolute_import as _abs
import sys
import tvm
import numpy as np
from .. import ir_pass
from .. import expr as _expr
......@@ -91,7 +92,8 @@ def _convert_activation(inexpr, keras_layer, _):
x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32')
return _op.clip(x, a_min=0., a_max=1.)
raise TypeError("Unsupported activation type : {}".format(act_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Keras.'.format(act_type))
def _convert_advanced_activation(inexpr, keras_layer, etab):
......@@ -118,7 +120,8 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
return _op.multiply(inexpr, _op.greater(inexpr, \
_expr.const(theta, dtype='float32')).astype('float32'))
raise TypeError("Unsupported advanced activation type : {}".format(act_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Keras.'.format(act_type))
def _convert_merge(inexpr, keras_layer, _):
......@@ -136,7 +139,8 @@ def _convert_merge(inexpr, keras_layer, _):
ret = _op.add(ret, inexpr[i])
ret = ret / _expr.const(len(inexpr), dtype='float32')
else:
raise TypeError("Unsupported merge type : {}".format(merge_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Keras.'.format(merge_type))
return ret
......@@ -150,7 +154,8 @@ def _convert_dense(inexpr, keras_layer, etab):
if input_dim > 2:
input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:
raise ValueError("Cannot flatten the inputs with shape.", input_shape, " for dense.")
raise tvm.error.OpAttributeInvalid(
'Input shape {} is not valid for operator Dense.'.format(input_shape))
inexpr = _op.squeeze(inexpr, axis=0)
out = _op.nn.dense(data=inexpr, **params)
if keras_layer.use_bias:
......@@ -214,7 +219,9 @@ def _convert_convolution(inexpr, keras_layer, etab):
inexpr = _op.nn.pad(data=inexpr, pad_width=(
(0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
else:
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
msg = 'Padding with {} is not supported for operator Convolution ' \
'in frontend Keras.'
raise tvm.error.OpAttributeUnimplemented(msg.format(keras_layer.padding))
if is_deconv:
out = _op.nn.conv2d_transpose(data=inexpr, **params)
else:
......@@ -260,7 +267,10 @@ def _convert_separable_convolution(inexpr, keras_layer, etab):
inexpr = _op.nn.pad(data=inexpr, pad_width=(
(0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
else:
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
msg = 'Padding with {} is not supported for operator Separable ' \
'Convolution in frontend Keras.'
raise tvm.error.OpAttributeUnimplemented(msg.format(keras_layer.padding))
depthconv = _op.nn.conv2d(data=inexpr, **params0)
# pointwise conv
weight1 = weightList[1].transpose([3, 2, 0, 1])
......@@ -313,13 +323,15 @@ def _convert_pooling(inexpr, keras_layer, etab):
pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
params['padding'] = [pad_t, pad_l, pad_b, pad_r]
else:
raise TypeError("Unsupported padding type : {}".format(keras_layer.padding))
raise tvm.error.OpAttributeUnimplemented(
'Padding with {} is not supported in operator Pooling.'.format(keras_layer.padding))
if pool_type == 'MaxPooling2D':
return _op.nn.max_pool2d(inexpr, **params)
if pool_type == 'AveragePooling2D':
params['count_include_pad'] = False
return _op.nn.avg_pool2d(inexpr, **params)
raise TypeError("Unsupported pooling type : {}".format(keras_layer))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(keras_layer))
def _convert_upsample(inexpr, keras_layer, _):
......@@ -331,8 +343,8 @@ def _convert_upsample(inexpr, keras_layer, _):
elif upsample_type == 'UpSampling2D':
h, w = keras_layer.size
if h != w:
raise TypeError("Unsupported upsampling type with different axes size : {}"
.format(keras_layer.size))
raise tvm.error.OpAttributeInvalid(
'Height must equal width for operator Upsample.')
params = {'scale': h}
if hasattr(keras_layer, 'interpolation'):
......@@ -345,24 +357,24 @@ def _convert_upsample(inexpr, keras_layer, _):
elif upsample_type == 'UpSampling3D':
h, w, d = keras_layer.size
if h != w or w != d:
raise TypeError("Unsupported upsampling type with different axes size : {}"
.format(keras_layer.size))
raise tvm.error.OpAttributeInvalid(
'Height, width, and depth must all be equal for operator Upsample.')
params = {'scale': h}
else:
raise TypeError("Unsupported upsampling type : {}".format(upsample_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(upsample_type))
return _op.nn.upsampling(inexpr, **params)
def _convert_cropping(inexpr, keras_layer, _):
_check_data_format(keras_layer)
crop_type = type(keras_layer).__name__
if crop_type == 'Cropping1D':
raise NotImplementedError("Cropping1D not implemented")
elif crop_type == 'Cropping2D':
if crop_type == 'Cropping2D':
(_, in_h, in_w, _) = keras_layer.input_shape
((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping
else:
raise TypeError("Unrecognized cropping type : {}".format(crop_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(crop_type))
int32_max = np.iinfo(np.int32).max
return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \
end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])
......@@ -407,14 +419,18 @@ def _convert_padding(inexpr, keras_layer, _):
top, bottom = padding[0]
left, right = padding[1]
else:
raise ValueError("Unrecognized padding option: {}".format(str(padding)))
msg = 'Value {} in attribute "padding" of operator Padding ' \
'is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
else:
raise ValueError("Unrecognized padding option: {}".format(str(padding)))
elif padding_type == 'ZeroPadding1D':
raise NotImplementedError("ZeroPadding1D not implemented")
msg = 'Value {} in attribute "padding" of operator Padding is ' \
'not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
else:
raise ValueError("Unrecognized padding type: {}".format(padding_type))
return _op.nn.pad(data=inexpr, pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))
msg = 'Operator {} is not supported in frontend Keras.'
raise tvm.error.OpNotImplemented(msg.format(padding_type))
return _op.nn.pad(data=inexpr,
pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))
def _convert_concat(inexpr, keras_layer, _):
......@@ -601,8 +617,10 @@ _convert_map = {
def _check_unsupported_layers(model):
for layer in model.layers:
if type(layer).__name__ not in _convert_map:
raise ValueError("Keras layer {} not supported.".format(type(layer).__name__))
op_name = type(layer).__name__
if op_name not in _convert_map:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend Keras.'.format(op_name))
def keras_op_to_relay(inexpr, keras_layer, outname, etab):
......@@ -622,9 +640,11 @@ def keras_op_to_relay(inexpr, keras_layer, outname, etab):
etab : relay.frontend.common.ExprTable
The global expression table to be updated.
"""
if type(keras_layer).__name__ not in _convert_map:
raise NotImplementedError("{} is not supported".format((type(keras_layer).__name__)))
outs = _convert_map[type(keras_layer).__name__](inexpr, keras_layer, etab)
op_name = type(keras_layer).__name__
if op_name not in _convert_map:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(op_name))
outs = _convert_map[op_name](inexpr, keras_layer, etab)
outs = _as_list(outs)
for t_idx, out in enumerate(outs):
name = outname + ":" + str(t_idx)
......
......@@ -3,10 +3,12 @@
from __future__ import absolute_import as _abs
import json
import tvm
from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
from ... import nd as _nd
from .common import StrAttrsDict
from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
......@@ -41,7 +43,8 @@ def _get_channel_axis(layout, op_name):
return 1
if layout == "NHWC":
return 3
raise RuntimeError("layout: {} is not supported in {}".format(layout, op_name))
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "layout" of operator {} is not valid.'.format(layout, op_name))
def _mx_activations(inputs, attrs):
......@@ -61,7 +64,8 @@ def _mx_activations(inputs, attrs):
return _op.add(_op.log(_op.add(one, exp_neg_abs_x)),
_op.nn.relu(x))
return _stable_softrelu(inputs[0])
raise RuntimeError("Do not support act_type: {}".format(act_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend MXNet.'.format(act_type))
def _mx_compare(new_op, wrapper):
......@@ -74,7 +78,8 @@ def _mx_compare(new_op, wrapper):
def _mx_conv2d(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) != 2:
raise RuntimeError("non-2d kernel is not supported in conv2d")
raise tvm.error.OpAttributeInvalid(
'Non-2D kernels are not supported for operator Conv2D.')
data_layout = attrs.get_str("layout", "NCHW")
channel_axis = _get_channel_axis(data_layout, "conv2d")
......@@ -102,10 +107,12 @@ def _mx_conv2d(inputs, attrs):
def _mx_conv2d_transpose(inputs, attrs):
if "target_shape" in attrs.attrs:
raise RuntimeError("target_shape is not supported in conv2d_transpose")
raise tvm.error.OpAttributeUnimplemented(
'Attribute "target_shape" is not supported for operator Conv2D-transpose.')
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) != 2:
raise RuntimeError("non-2d kernel is not supported in conv2d")
raise tvm.error.OpAttributeInvalid(
'Non-2D kernels are not supported for operator Conv2D-transpose.')
data_layout = attrs.get_str("layout", "NCHW")
channel_axis = _get_channel_axis(data_layout, "conv2d_transpose")
......@@ -140,7 +147,8 @@ def _mx_pooling(inputs, attrs):
def _pool2d(new_op, is_avg):
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) != 2:
raise RuntimeError("non-2d kernel is not supported in pool2d")
raise tvm.error.OpAttributeInvalid(
'Only 2D kernels are supported for operator Pool2D.')
new_attrs = {}
new_attrs["pool_size"] = kernel_size
new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1))
......@@ -158,7 +166,8 @@ def _mx_pooling(inputs, attrs):
if global_pool:
return _op.nn.global_avg_pool2d(inputs[0])
return _pool2d(_op.nn.avg_pool2d, True)
raise RuntimeError("Do not support pool_type:{}".format(pool_type))
raise tvm.error.OpNotImplemented(
'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize()))
def _mx_dropout(inputs, attrs):
......@@ -172,7 +181,8 @@ def _mx_BlockGrad(inputs, attrs): #pylint: disable=unused-argument
def _mx_batch_norm(inputs, attrs):
if attrs.get_bool("output_mean_var", False):
raise RuntimeError("batch_norm do not support output_mean_var")
raise tvm.error.OpAttributeUnimplemented(
'Attribute "output_mean_var" is not supported for operator Batch Norm.')
if attrs.get_bool("use_global_stats", False):
_warn_not_used("use_global_stats", "batch_norm")
new_attrs = {}
......@@ -188,10 +198,18 @@ def _mx_slice(inputs, attrs):
begin = attrs.get_int_tuple('begin', None)
end = attrs.get_int_tuple('end', None)
stride = attrs.get_int_tuple('step', None)
if begin is None or end is None:
raise RuntimeError("begin and end are required parameters.")
if None in begin or None in end:
raise RuntimeError("None in begin or end is not supported yet.")
if begin is None:
raise tvm.error.OpAttributeRequired(
'Attribute "begin" not found in operator Slice.')
if end is None:
raise tvm.error.OpAttributeRequired(
'Attribute "end" not found in operator Slice.')
if None in begin:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "begin" of operator Slice is not valid.')
if None in end:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "end" of operator Slice is not valid.')
new_attrs = {'begin': begin, 'end': end}
if stride is not None:
new_attrs['strides'] = stride
......@@ -295,7 +313,8 @@ def _mx_leaky_relu(inputs, attrs):
upper_bound = attrs.get_float("upper_bound")
alpha = (lower_bound + upper_bound) / 2.0
return _op.nn.leaky_relu(inputs[0], alpha=alpha)
raise RuntimeError("act_type: {} is not supported".format(act_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend MXNet.'.format(act_type))
def _mx_make_power(power):
......@@ -389,7 +408,9 @@ def _mx_batch_dot(inputs, attrs):
transpose_a = attrs.get_bool("transpose_a", False)
transpose_b = attrs.get_bool("transpose_b", False)
if transpose_a is True:
raise RuntimeError("batch_dot: only support transpose_a=False")
msg = 'Value {} in attribute "transpose_a" of operator batch_dot ' \
'is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
if transpose_b is False:
b = _op.transpose(b, axes=[0, 2, 1])
return _op.batch_matmul(a, b)
......@@ -398,7 +419,8 @@ def _mx_batch_dot(inputs, attrs):
def _mx_arange(inputs, attrs):
assert len(inputs) == 0
if attrs.get_int("repeat", 1) != 1:
raise RuntimeError("arange doesn't support repeat")
raise tvm.error.OpAttributeUnimplemented(
'Attribute "repeat" is not supported in operator arange.')
new_attrs = {}
new_attrs["start"] = attrs.get_float("start", 0)
new_attrs["stop"] = attrs.get_float("stop")
......@@ -482,15 +504,20 @@ def _mx_box_nms(inputs, attrs):
in_format = attrs.get_str('in_format', 'corner')
out_format = attrs.get_str('out_format', 'corner')
if coord_start != 2:
raise RuntimeError('coord_start %s is not supported.' % coord_start)
raise tvm.error.OpAttributeInvalid(
'Value of attribute "coord_start" must equal 2 for operator box_nms.')
if score_index != 1:
raise RuntimeError('score_index %s is not supported.' % score_index)
raise tvm.error.OpAttributeInvalid(
'Value of attribute "score_index" must equal 1 for operator box_nms.')
if id_index != -1 and int(id_index) != 0:
raise RuntimeError('id_index %s is not supported.' % id_index)
raise tvm.error.OpAttributeInvalid(
'Value of attribute "id_index" must equal either -1 or 0 for operator box_nms.')
if in_format != 'corner':
raise RuntimeError('in_format %s is not supported.' % in_format)
raise tvm.error.OpAttributeInvalid(
'Value of attribute "in_format" must equal "corner" for operator box_nms.')
if out_format != 'corner':
raise RuntimeError('out_format %s is not supported.' % out_format)
raise tvm.error.OpAttributeInvalid(
'Value of attribute "out_format" must equal "corner" for operator box_nms.')
ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh)
nms_out = _op.vision.non_max_suppression(ret[1],
......@@ -508,7 +535,8 @@ def _mx_l2_normalize(inputs, attrs):
new_attrs = {}
mode = attrs.get_str('mode', 'instance')
if mode != 'channel':
raise RuntimeError('mode %s is not supported.' % mode)
raise tvm.error.OpAttributeInvalid(
'Value of attribute "mode" must equal "channel" for operator l2_normalize.')
new_attrs['eps'] = attrs.get_float('eps', 1e-10)
new_attrs['axis'] = [1]
return _op.nn.l2_normalize(inputs[0], **new_attrs)
......@@ -771,7 +799,8 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
raise RuntimeError("unexpected type %s" % type(res))
node_map[nid] = res
else:
raise RuntimeError("{} is not supported in relay frontend".format(op_name))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported in frontend MXNet.'.format(op_name))
outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
......
......@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
import logging
import tvm
import numpy as np
from ... import nd as _nd
from .. import ir_pass
......@@ -18,7 +19,9 @@ def dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
raise NotImplementedError("Only 2d kernel supported.")
msg = 'Only 2D kernels are supported for operator {}.'
op_name = prefix + '2d'
raise tvm.error.OpAttributeInvalid(msg.format(op_name))
return _impl
......@@ -29,7 +32,8 @@ def revert_caffe2_pad(pads):
elif len(pads) == 2:
pass
else:
raise ValueError("Invalid caffe2 type padding: {}".format(pads))
raise tvm.error.OpAttributeInvalid(
'Number of pads must be either 2 or 4.')
return pads
def dimension_constraint():
......@@ -461,7 +465,8 @@ class Upsample(OnnxOpConverter):
elif mode == b'linear':
method = "BILINEAR"
else:
raise ValueError("Invalid ONNX upsample mode: {}".format(mode))
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW'}
return AttrCvt('upsampling')(inputs, attr)
......@@ -718,8 +723,9 @@ class ConstantFill(OnnxOpConverter):
shape = params[get_name(inputs[0])].asnumpy()
else:
if 'extra_shape' in attr:
raise ImportError(
"Extra Shape not supported with fill_like")
raise tvm.error.OpAttributeInvalid('Attribute "extra_shape" not '
'supported with "fill_like" for '
'operator ConstantFill.')
return _op.full_like(inputs[0], inputs[1])
if 'extra_shape' in attr:
......
......@@ -27,7 +27,8 @@ def _get_relay_op(op_name):
op = getattr(_op.image, op_name)
if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TensorFlow.'.format(op_name))
return op
class AttrCvt(object):
......@@ -99,7 +100,8 @@ class AttrCvt(object):
new_attrs = {}
for k in attrs.keys():
if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k))
raise tvm.error.OpAttributeUnimplemented(
'Attribute {} in operator {} is not supported.'.format(k, op_name))
elif k in self._disables:
logging.warning("Attribute %s is disabled in relay.%s", k, op_name)
elif k in self._ignores:
......@@ -148,7 +150,8 @@ class AttrCvt(object):
"""Wrapper for getting required attributes."""
assert isinstance(attr, dict)
if key not in attr:
raise AttributeError("Required attribute {} not found.".format(key))
raise tvm.error.OpAttributeRequired(
'Attribute {} not found in operator {}'.format(key, self._op_name))
return attr[key]
def _get_pad_pair(input1d, kernel1d, stride1d):
......@@ -178,7 +181,8 @@ def _dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
raise NotImplementedError("Only 2d kernel supported.")
raise tvm.error.OpAttributeInvalid(
'Only 2D kernels are supported for operator {}'.format(prefix + '2d'))
return _impl
def _dimension_constraint():
......@@ -238,7 +242,9 @@ def _pooling(name):
attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3])
attr['strides'] = (attr['strides'][2], attr['strides'][3])
else:
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
msg = 'Value {} of attribute "data_format" of operator Pooling ' \
'is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attrs['data_format']))
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
tmp_shape = attr['_input_shapes'][inputs[0]]
......@@ -267,7 +273,9 @@ def _pooling(name):
attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))
msg = 'Value {} in attribute "padding" of operator Pooling is ' \
'not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
if name == "avg_pool":
attr['count_include_pad'] = False
......@@ -341,7 +349,9 @@ def _conv(opname):
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
attr['strides'] = (attr['strides'][2], attr['strides'][3])
else:
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
msg = 'Value {} in attribute "data_format" of operator Conv is ' \
'not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if opname == 'depthwise':
......@@ -386,7 +396,9 @@ def _conv(opname):
attr['padding'] = [0, 0]
else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))
msg = 'Value {} in attribute "padding" of operator Conv is not ' \
'valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
if 'kernel_layout' not in attr:
if opname == 'conv':
......@@ -791,7 +803,8 @@ def _pad(name):
if padlist_key in params:
padlist = params.pop(padlist_key).asnumpy()
else:
raise RuntimeError("Required parameter {} not fount.".format(padlist_key))
raise tvm.error.OpAttributeRequired(
'Attribute {} not found in operator Pad.'.format(padlist_key))
paddings = tuple([tuple(l) for l in padlist])
attr['pad_width'] = paddings
attr['pad_value'] = 0
......
......@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
import math
import numpy as np
import tvm
from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
......@@ -59,8 +60,10 @@ class OperatorConverter(object):
unsupported_ops_set.add(op_code_str)
if unsupported_ops_set:
raise NotImplementedError("Unsupported Ops: %s" % (
','.join(unsupported_ops_set)))
msg = 'The following operators are not supported in frontend ' \
'TFLite: {}'
ops = str(list(unsupported_ops_set)).strip('[,]')
raise tvm.error.OpNotImplemented(msg.format(ops))
def convert_op_to_relay(self):
"""Convert TFLite ops to relay ops"""
......@@ -205,8 +208,8 @@ class OperatorConverter(object):
# finally convert back if necessary
in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
else:
raise NotImplementedError("Not support input shape length {} of reshape : "
.format(str(input_shape_length)))
msg = 'Input shape length {} for operator Reshape is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
out = _op.reshape(in_expr, newshape=tuple(target_shape))
......@@ -223,8 +226,8 @@ class OperatorConverter(object):
elif len(target_shape) == 4:
out = _op.transpose(out, axes=(0, 3, 1, 2))
else:
raise NotImplementedError("Not support to reshape to shape length {}: "
.format(str(len(target_shape))))
raise tvm.error.OpAttributeInvalid(
'Length of target shape must be between 1 and 5 for operator Reshape.')
return out
......@@ -330,8 +333,8 @@ class OperatorConverter(object):
# finally convert back if necessary
in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
else:
raise NotImplementedError("Not support input shape length {} of squeeze : "
.format(str(input_shape_length)))
msg = 'Input shape length {} for operator Squeeze is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))
out = _op.squeeze(in_expr, axis=tuple(squeeze_axis))
......@@ -348,8 +351,8 @@ class OperatorConverter(object):
elif output_shape_length == 4:
out = _op.transpose(out, axes=(0, 3, 1, 2))
else:
raise NotImplementedError("Not support to squeeze to length {} : "
.format(str(output_shape_length)))
msg = 'Output shape length {} for operator Squeeze is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length))
return out
......@@ -369,8 +372,8 @@ class OperatorConverter(object):
if fused_activation_fn == ActivationFunctionType.TANH:
return _op.tanh(in_expr)
fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
raise NotImplementedError("Unsupported fused activation fn {}"
.format(fused_activation_fn_str))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TFLite.'.format(fused_activation_fn_str))
def convert_conv(self, op, conv_type):
"""convolution implementation."""
......@@ -409,7 +412,8 @@ class OperatorConverter(object):
assert depth_multiplier == 1, "TF frontend have transformed it be 1 " \
"no matter original value be set by 0.25, 0.5 or any else"
else:
raise ValueError("Not support conv type: {}".format(conv_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TFLite.'.format(conv_type))
stride_h = conv_options.StrideH()
stride_w = conv_options.StrideW()
......@@ -466,7 +470,8 @@ class OperatorConverter(object):
(pad_top, pad_bottom),
(pad_left, pad_right)))
else:
raise NotImplementedError("Not support padding format: {}".format(padding))
raise tvm.error.OpAttributeUnimplemented(
'Padding format {} is not supported for operator Conv.'.format(padding))
out = _op.nn.conv2d(data=in_expr, weight=weight_expr, **params)
......@@ -529,14 +534,16 @@ class OperatorConverter(object):
pad_left, pad_right = get_pad_value(input_w, filter_w, stride_w)
params['padding'] = [pad_top, pad_left, pad_bottom, pad_right]
else:
raise NotImplementedError("Not support padding format: {}".format(padding))
raise tvm.error.OpAttributeUnimplemented(
'Padding format {} for operator Pool2D is not supported.'.format(padding))
if pool_type == "average":
out = _op.nn.avg_pool2d(in_expr, **params)
elif pool_type == "max":
out = _op.nn.max_pool2d(in_expr, **params)
else:
raise ValueError("Not support pool type: {}".format(pool_type))
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool'))
# If we have fused activations
if fused_activation_fn != ActivationFunctionType.NONE:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment