Commit e722dbcb by Wenhao Hu, committed by Tianqi Chen

Onnx opset support (#416)

parent f4789db6
@@ -4,87 +4,119 @@ from __future__ import absolute_import as _abs
import tvm
from .. import symbol as _sym
from .. import graph as _graph
from ..compiler import graph_util
from .common import get_nnvm_op, Renamer, AttrConverter as AttrCvt

__all__ = ['from_onnx']


class OnnxOpConverter(object):
    """ A helper class for holding onnx op converters.
    """

    @classmethod
    def get_converter(cls, opset):
        """ Get the converter matching a given opset.

        :param opset: opset from the model.
        :return: converter, which should be `_impl_vx`, where x is the largest
            supported version that is smaller than or equal to the given opset.
        """
        versions = [
            int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d
        ]
        versions = sorted(versions + [opset])
        version = versions[
            max([i for i, v in enumerate(versions) if v == opset]) - 1]
        if hasattr(cls, '_impl_v{}'.format(version)):
            return getattr(cls, '_impl_v{}'.format(version))
        else:
            raise NotImplementedError(
                'opset version {} of {} not implemented'.format(
                    version, cls.__name__))
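

# --- Editor's illustration (not part of this commit) -------------------------
# With the resolution rule above, a converter class that defines _impl_v1 and
# _impl_v5 serves opsets 1-4 with _impl_v1 and opset 5 or newer with _impl_v5.
# The class and return values below are hypothetical, purely for illustration.
class _FakeOp(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return 'v1'

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        return 'v5'


assert _FakeOp.get_converter(3)(None, None, None) == 'v1'  # 1 <= 3 < 5
assert _FakeOp.get_converter(7)(None, None, None) == 'v5'  # newest version <= 7
# ------------------------------------------------------------------------------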
class Elemwise(OnnxOpConverter):
    """ A helper class for elemwise op converters.
    """

    name = ''

    @classmethod
    def _math_name_picker(cls, suffix):

        def _impl(attr):
            if attr.get('broadcast', 0):
                return 'broadcast_' + suffix
            return 'elemwise_' + suffix

        return _impl

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(
            len(inputs))
        op_name = cls._math_name_picker(cls.name)(attr)
        axis = int(attr.get('axis', 0))
        conv_ops = ["conv2d", "conv2d_transpose"]
        if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
            # TODO(zhreshold): remove hard coded infershape
            inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
        return get_nnvm_op(op_name)(*inputs)


class Pool(OnnxOpConverter):
    """ A helper class for pool op converters.
    """

    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(
            op_name=_dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', (0, 0), _revert_caffe2_pad)
            },
            # very weird attributes here in onnx, force check
            ignores=['dilations'],
            # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
            extras={'ceil_mode': False},
            custom_check=_dimension_constraint())(inputs, attr, params)


class Absolute(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.relu(inputs[0]) + _sym.relu(_sym.negative(inputs[0]))


class Add(Elemwise):
    name = 'add'


class AveragePool(Pool):
    name = 'avg_pool'


class BatchNorm(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # TODO(zhreshold): 'spatial' is not properly handled here.
        return AttrCvt(
            op_name='batch_norm',
            disables=['momentum'],
            ignores=['spatial', 'is_test', 'consumed_inputs'])(inputs, attr,
                                                               params)


class Conv(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = _infer_channels(inputs[1], params)
        attr['channels'] = channels
@@ -94,13 +126,16 @@ def _conv():
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                'pads': ('padding', (0, 0), _revert_caffe2_pad),
                'group': ('groups', 1)
            },
            extras={'use_bias': len(inputs) == 3},
            custom_check=_dimension_constraint())(inputs, attr, params)


class ConvTranspose(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = _infer_channels(inputs[1], params, True)
        attr['channels'] = channels
@@ -111,31 +146,34 @@ def _conv_transpose():
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                'pads': ('padding', (0, 0), _revert_caffe2_pad)
            },
            disables=['output_shape'],
            extras={'use_bias': len(inputs) == 3},
            custom_check=_dimension_constraint())(inputs, attr, params)


class Div(Elemwise):
    name = 'div'


class Elu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(
            inputs[0])


class Gemm(OnnxOpConverter):
    """ Operator converter for Gemm.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(
            len(inputs))
        # Y = alpha * A * B + beta * C
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
@@ -147,217 +185,325 @@ def _gemm():
            inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
        if not transB:
            inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
        return _sym.dense(
            alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)


class MaxPool(Pool):
    name = 'max_pool'


class Mul(Elemwise):
    name = 'mul'
class Pad(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(
            op_name='pad',
            transforms={
                'value': 'pad_value',
                'pads': 'pad_width'
            },
            custom_check=lambda attrs: attrs.get('mode') == 'constant')(
                inputs, attr, params)
class ParametricSoftPlus(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _sym.log(_sym.exp(beta * inputs[0]) + 1) * alpha


class Prelu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(
            len(inputs))
        channels = _infer_channels(inputs[1], params, False)
        if channels == 1:
            return inputs[0] * inputs[1]
        return _sym.broadcast_mul(inputs[0], inputs[1])


class Reciprocal(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return 1.0 / inputs[0]


class Reshape(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.reshape(inputs[0], shape=attr['shape'])
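
    # Editor's note (not part of this commit): from opset 5 onward, ONNX
    # Reshape carries the target shape as a second constant input rather than
    # a 'shape' attribute, so _impl_v5 below looks the shape tensor up in the
    # graph params via the name of that input symbol.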
    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        return _sym.reshape(
            inputs[0],
            shape=tuple(params[inputs[1].list_output_names()[0]].asnumpy()))
class Scale(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        scale = float(attr.get('scale', 1.0))
        return inputs[0] * scale


class Selu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.6732))
        gamma = float(attr.get('gamma', 1.0507))
        return gamma * (
            -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))


class ScaledTanh(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _sym.tanh(beta * inputs[0]) * alpha


class SoftPlus(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.log(_sym.exp(inputs[0]) + 1)


class Softsign(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return inputs[0] / (1 + Absolute.get_converter(1)(inputs, attr, params))


class Sub(Elemwise):
    name = 'sub'


class Sum(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Onnx Sum Operator
        for in_index in range(len(inputs) - 1):
            inputs[in_index + 1] = _sym.broadcast_add(inputs[in_index],
                                                      inputs[in_index + 1])

        return inputs[len(inputs) - 1]


class ThresholdedRelu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 0.0))
        return _sym.relu(inputs[0] - alpha)
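

# Editor's note (illustration, not part of this commit): the helpers below feed
# AttrCvt in the converters above; for a 2-D kernel, for example,
# _dimension_picker('max_pool')({'kernel_shape': [2, 2]}) returns 'max_pool2d',
# and _dimension_constraint() supplies the matching "2d kernels only" check.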
def _revert_caffe2_pad(attr):
    """Caffe2 requires two times the normal padding."""
    if len(attr) == 4:
        attr = attr[:2]
    elif len(attr) == 2:
        pass
    else:
        raise ValueError("Invalid caffe2 type padding: {}".format(attr))
    return attr


def _broadcast_constraint():

    def _broadcast_check(attrs):
        if attrs.get('axis', None):
            return False
        return True

    return _broadcast_check, "Specifying broadcast axis not allowed."


def _dimension_picker(prefix, surfix=''):

    def _impl(attr):
        kernel = attr['kernel_shape']
        if len(kernel) == 2:
            return prefix + '2d' + surfix
        else:
            raise NotImplementedError("Only 2d kernel supported.")

    return _impl


def _dimension_constraint():

    def _dim_check(attrs):
        if len(attrs['kernel_shape']) == 2:
            return True
        return False

    return _dim_check, "Only 2d kernel supported."


def _infer_channels(inputs, params, transpose=False):
    """A hack for getting 'channels' or 'units' since onnx doesn't provide
    these attributes. We check the shape of weights provided to get the number.
    """
    g = _graph.create(inputs)
    shape_dict = {k: v.shape for k, v in params.items()}
    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
    return channels


def _fully_connected(opset):

    def _impl(inputs, attr, params):
        # get number of channels
        channels = _infer_channels(inputs[1], params)
        attr['units'] = channels
        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)

    return _impl
# compatible operators that do NOT require any conversion.
_identity_list = []


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
def _get_convert_map(opset):
    return {
        # defs/experimental
        'Identity': Renamer('copy'),
        # 'Affine'
        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
        'ScaledTanh': ScaledTanh.get_converter(opset),
        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
        # 'ConstantFill'
        # 'GivenTensorFill'
        'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
        'Scale': Scale.get_converter(opset),
        # 'GRUUnit'
        # 'ATen'
        # 'ImageScaler'
        # 'MeanVarianceNormalization'
        # 'Crop'
        # 'Embedding'
        # 'Upsample'
        'SpatialBN': BatchNorm.get_converter(opset),

        # defs/generator
        # 'Constant'
        # 'RandomUniform'
        # 'RandomNormal'
        # 'RandomUniformLike'
        # 'RandomNormalLike'

        # defs/logical

        # defs/math
        'Add': Add.get_converter(opset),
        'Sub': Sub.get_converter(opset),
        'Mul': Mul.get_converter(opset),
        'Div': Div.get_converter(opset),
        'Neg': Renamer('negative'),
        'Abs': Absolute.get_converter(opset),
        'Reciprocal': Reciprocal.get_converter(opset),
        # 'Floor'
        # 'Ceil'
        'Sqrt': Renamer('sqrt'),
        'Relu': Renamer('relu'),
        'LeakyRelu': Renamer('leaky_relu'),
        'Selu': Selu.get_converter(opset),
        'Elu': Elu.get_converter(opset),
        'Exp': Renamer('exp'),
        'Log': Renamer('log'),
        'Tanh': Renamer('tanh'),
        # 'Pow'
        'PRelu': Prelu.get_converter(opset),
        'Sigmoid': Renamer('sigmoid'),
        # 'HardSigmoid'
        # 'Max' : this is the elemwise maximum
        # 'Min' : this is the elemwise minimum
        'Sum': Sum.get_converter(opset),
        # 'Mean'
        # 'Clip'
        # softmax default axis is different in onnx
        'Softmax': AttrCvt('softmax', {'axis': ('axis', 1)}),
        'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
        # 'Hardmax'
        'Softsign': Softsign.get_converter(opset),
        'SoftPlus': SoftPlus.get_converter(opset),
        'Gemm': Gemm.get_converter(opset),
        # 'MatMul' batch stacked dot operation

        # defs/nn
        'AveragePool': AveragePool.get_converter(opset),
        'MaxPool': MaxPool.get_converter(opset),
        'Conv': Conv.get_converter(opset),
        'ConvTranspose': ConvTranspose.get_converter(opset),
        'GlobalAveragePool': Renamer('global_avg_pool2d'),
        'GlobalMaxPool': Renamer('global_max_pool2d'),
        'BatchNormalization': BatchNorm.get_converter(opset),
        # 'InstanceNormalization'
        # 'LpNormalization'
        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
        'Flatten': Renamer('flatten'),
        # 'LRN'

        # defs/reduction
        'ReduceMax': AttrCvt('max', {'axes', 'axis'}),
        'ReduceMin': AttrCvt('min', {'axes', 'axis'}),
        'ReduceSum': AttrCvt('sum', {'axes', 'axis'}),
        # 'ReduceMean'
        # 'ReduceProd'
        # 'ReduceLogSumExp'
        # 'ArgMax'
        # 'ArgMin'

        # defs/tensor
        'Cast': AttrCvt('cast', {'to': 'dtype'}),
        'Reshape': Reshape.get_converter(opset),
        'Concat': Renamer('concatenate'),
        'Split': AttrCvt('split', {'split': 'indices_or_sections'}),
        # 'Slice'
        'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
        # 'Gather'
        # 'Squeeze'
        'Pad': Pad.get_converter(opset),
    }
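

# Editor's sketch (not part of this commit; the opset and attribute values are
# only examples): a node is converted by looking its op type up in the map
# returned by _get_convert_map and calling the resulting functor with the
# node's nnvm inputs, its parsed attributes, and the graph params,
#
#     convert_map = _get_convert_map(opset=7)
#     sym = convert_map['Softmax'](inputs, {'axis': 1}, params)
#
# which is what GraphProto._convert_operator below does.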
class GraphProto(object):
    """A helper class for handling nnvm graph copying from pb2.GraphProto.
    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
    """

    def __init__(self):
        self._nodes = {}
        self._params = {}
@@ -365,7 +511,7 @@ class GraphProto(object):
        self._num_input = 0
        self._num_param = 0

    def from_onnx(self, graph, opset):
        """Construct nnvm nodes from onnx graph.
        The inputs from onnx graph are vague, only providing "1", "2"...
        For convenience, we rename the `real` input names to "input_0",
@@ -375,6 +521,7 @@ class GraphProto(object):
        ----------
        graph : onnx protobuf object
            The loaded onnx graph
        opset : int
            Opset version

        Returns
        -------
@@ -410,7 +557,7 @@ class GraphProto(object):
            op_name = node.op_type
            attr = self._parse_attr(node.attribute)
            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
            op = self._convert_operator(op_name, inputs, attr, opset)
            node_output = self._fix_outputs(op_name, node.output)
            assert len(node_output) == len(op.list_output_names()), (
                "Number of output mismatch {} vs {} in {}.".format(
@@ -438,7 +585,8 @@ class GraphProto(object):
        try:
            from onnx.numpy_helper import to_array
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx which is required {}".format(e))
        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
        return tvm.nd.array(np_array)
@@ -455,15 +603,23 @@ class GraphProto(object):
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['t', 'g']:
                if a.HasField(f):
                    raise NotImplementedError(
                        "Field {} is not supported in nnvm.".format(f))
            for f in ['tensors', 'graphs']:
                if list(getattr(a, f)):
                    raise NotImplementedError(
                        "Field {} is not supported in nnvm.".format(f))
            if a.name not in attrs:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs

    def _convert_operator(self,
                          op_name,
                          inputs,
                          attrs,
                          opset,
                          identity_list=None,
                          convert_map=None):
        """Convert from onnx operator to nnvm operator.
        The converter must specify conversions explicitly for incompatible names,
        and apply handlers to operator attributes.
@@ -476,6 +632,8 @@ class GraphProto(object):
            List of input symbols.
        attrs : dict
            Dict of operator attributes
        opset : int
            Opset version
        identity_list : list
            List of operators that don't require conversion
        convert_map : dict
@@ -489,13 +647,14 @@ class GraphProto(object):
            Converted nnvm Symbol
        """
        identity_list = identity_list if identity_list else _identity_list
        convert_map = convert_map if convert_map else _get_convert_map(opset)
        if op_name in identity_list:
            sym = get_nnvm_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            sym = convert_map[op_name](inputs, attrs, self._params)
        else:
            raise NotImplementedError(
                "Operator {} not implemented.".format(op_name))
        return sym

    def _fix_outputs(self, op_name, outputs):
@@ -510,7 +669,7 @@ class GraphProto(object):
        return outputs


def from_onnx(model):
    """Load onnx graph which is a python protobuf object into nnvm graph.
    The companion parameters will be handled automatically.
    The inputs from onnx graph are vague, only providing "1", "2"...
@@ -519,8 +678,8 @@ def from_onnx(graph):
    Parameters
    ----------
    model : protobuf object
        ONNX ModelProto after ONNX v1.1.0

    Returns
    -------
@@ -531,8 +690,7 @@ def from_onnx(graph):
        Dict of converted parameters stored in tvm.ndarray format
    """
    g = GraphProto()
    graph = model.graph
    opset = model.opset_import[0].version if model.opset_import else 1
    sym, params = g.from_onnx(graph, opset)
    return sym, params
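

# Editor's usage sketch (not part of this commit; the file name is
# hypothetical): from_onnx now expects an onnx ModelProto rather than a bare
# GraphProto, because the opset the converters dispatch on lives on the model.
#
#     import onnx
#     import nnvm
#
#     model = onnx.load('model.onnx')
#     print(model.opset_import[0].version)   # the opset passed to get_converter
#     sym, params = nnvm.frontend.from_onnx(model)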
pip2 install onnx>=1.1.0
pip3 install onnx>=1.1.0
pip2 install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
pip2 install torchvision
......
@@ -14,8 +14,8 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
    c2_out = prepared_backend.run(W)[0]
    return c2_out


def get_tvm_output(model, x, target, ctx, dtype='float32'):
    new_sym, params = nnvm.frontend.from_onnx(model)
    shape_dict = {'input_0': x.shape}
    graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
    m = graph_runtime.create(graph, lib, ctx)
......
@@ -5,8 +5,8 @@ from nnvm.compiler import graph_util, graph_attr
from model_zoo import super_resolution, super_resolution_sym


def compare_graph(onnx_file, nnvm_sym, ishape):
    onnx_model = onnx.load(onnx_file)
    onnx_sym, params = nnvm.frontend.from_onnx(onnx_model)
    g1 = nnvm.graph.create(onnx_sym)
    g2 = nnvm.graph.create(nnvm_sym)
    ishapes = {'input_0': ishape}
......
@@ -44,9 +44,9 @@ model_url = ''.join(['https://gist.github.com/zhreshold/',
                     'super_resolution_0.2.onnx'])
download(model_url, 'super_resolution.onnx', True)
# now you have super_resolution.onnx on disk
onnx_model = onnx.load('super_resolution.onnx')
# we can load the graph as NNVM compatible model
sym, params = nnvm.frontend.from_onnx(onnx_model)

######################################################################
# Load a test image
......