Commit e722dbcb by Wenhao Hu Committed by Tianqi Chen

Onnx opset support (#416)

parent f4789db6
......@@ -4,87 +4,119 @@ from __future__ import absolute_import as _abs
import tvm
from .. import symbol as _sym
from .. import graph as _graph
from .. compiler import graph_util
from ..compiler import graph_util
from .common import get_nnvm_op, Renamer, AttrConverter as AttrCvt
__all__ = ['from_onnx']
def _revert_caffe2_pad(attr):
"""Caffe2 require two times the normal padding."""
if len(attr) == 4:
attr = attr[:2]
elif len(attr) == 2:
pass
else:
raise ValueError("Invalid caffe2 type padding: {}".format(attr))
return attr
def _math_name_picker(surfix):
def _impl(attr):
if attr.get('broadcast', 0):
return 'broadcast_' + surfix
return 'elemwise_' + surfix
return _impl
class OnnxOpConverter(object):
""" A helper class for holding onnx op converters.
"""
def _broadcast_constraint():
def _broadcast_check(attrs):
if attrs.get('axis', None):
return False
return True
return _broadcast_check, "Specifying broadcast axis not allowed."
@classmethod
def get_converter(cls, opset):
""" Get converter matches given opset.
def _dimension_picker(prefix, surfix=''):
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
:param opset: opset from model.
:return: converter, which should be `_impl_vx`. Number x is the biggest
number smaller than or equal to opset belongs to all support versions.
"""
versions = [
int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d
]
versions = sorted(versions + [opset])
version = versions[
max([i for i, v in enumerate(versions) if v == opset]) - 1]
if hasattr(cls, '_impl_v{}'.format(version)):
return getattr(cls, '_impl_v{}'.format(version))
else:
raise NotImplementedError("Only 2d kernel supported.")
return _impl
raise NotImplementedError(
'opset version {} of {} not implemented'.format(
version, cls.__name__))
def _dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
return True
return False
return _dim_check, "Only 2d kernel supported."
def _infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channles' or 'units' since onnx don't provide
these attributes. We check the shape of weights provided to get the number.
class Elemwise(OnnxOpConverter):
""" A helper class for elemwise op converters.
"""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def _elemwise(name):
def _impl(inputs, attr, *args):
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
op_name = _math_name_picker(name)(attr)
name = ''
@classmethod
def _math_name_picker(cls, suffix):
def _impl(attr):
if attr.get('broadcast', 0):
return 'broadcast_' + suffix
return 'elemwise_' + suffix
return _impl
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(
len(inputs))
op_name = cls._math_name_picker(cls.name)(attr)
axis = int(attr.get('axis', 0))
conv_ops = ["conv2d", "conv2d_transpose"]
if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
# TODO(zhreshold): remove hard coded infershape
inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
return get_nnvm_op(op_name)(*inputs)
return _impl
def _pooling(name):
return AttrCvt(
op_name=_dimension_picker(name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), _revert_caffe2_pad)},
# very weird attributes here in onnx, force check
ignores=['dilations'],
# TODO(zhreshold): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
custom_check=_dimension_constraint())
def _conv():
def _impl(inputs, attr, params):
class Pool(OnnxOpConverter):
""" A helper class for pool op converters.
"""
name = ''
@classmethod
def _impl_v1(cls, inputs, attr, params):
return AttrCvt(
op_name=_dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), _revert_caffe2_pad)
},
# very weird attributes here in onnx, force check
ignores=['dilations'],
# TODO(zhreshold): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
custom_check=_dimension_constraint())(inputs, attr, params)
class Absolute(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _sym.relu(inputs[0]) + _sym.relu(_sym.negative(inputs[0]))
class Add(Elemwise):
name = 'add'
class AveragePool(Pool):
name = 'avg_pool'
class BatchNorm(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# TODO(zhreshold): 'spatial' is not properly handled here.
return AttrCvt(
op_name='batch_norm',
disables=['momentum'],
ignores=['spatial', 'is_test', 'consumed_inputs'])(inputs, attr,
params)
class Conv(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params)
attr['channels'] = channels
......@@ -94,13 +126,16 @@ def _conv():
'kernel_shape': 'kernel_size',
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), _revert_caffe2_pad),
'group': ('groups', 1)},
'group': ('groups', 1)
},
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr)
return _impl
custom_check=_dimension_constraint())(inputs, attr, params)
def _conv_transpose():
def _impl(inputs, attr, params):
class ConvTranspose(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params, True)
attr['channels'] = channels
......@@ -111,31 +146,34 @@ def _conv_transpose():
transforms={
'kernel_shape': 'kernel_size',
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), _revert_caffe2_pad)},
'pads': ('padding', (0, 0), _revert_caffe2_pad)
},
disables=['output_shape'],
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr)
return _impl
custom_check=_dimension_constraint())(inputs, attr, params)
def _fully_connected():
def _impl(inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params)
attr['units'] = channels
return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
return _impl
def _batch_norm():
# TODO(zhreshold): 'spatial' is not properly handled here.
return AttrCvt(
op_name='batch_norm',
disables=['momentum'],
ignores=['spatial', 'is_test', 'consumed_inputs'])
class Div(Elemwise):
name = 'div'
def _gemm():
def _impl(inputs, attr, params):
assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(len(inputs))
class Elu(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 1.0))
return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(
inputs[0])
class Gemm(OnnxOpConverter):
""" Operator converter for Gemm.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(
len(inputs))
# Y = alpha * A * B + beta * C
alpha = float(attr.get('alpha', 1.0))
beta = float(attr.get('beta', 1.0))
......@@ -147,217 +185,325 @@ def _gemm():
inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
if not transB:
inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
return _sym.dense(alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)
return _impl
return _sym.dense(
alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)
def _thresholded_relu():
def _impl(inputs, attr, params):
alpha = float(attr.get('alpha', 0.0))
return _sym.relu(inputs[0] - alpha)
return _impl
def _scaled_tanh():
def _impl(inputs, attr, params):
alpha = float(attr.get('alpha', 1.0))
beta = float(attr.get('beta', 1.0))
return _sym.tanh(beta * inputs[0]) * alpha
return _impl
class MaxPool(Pool):
name = 'max_pool'
def parametric_soft_plus():
def _impl(inputs, attr, params):
class Mul(Elemwise):
name = 'mul'
class Pad(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params, True)
attr['channels'] = channels
groups = attr.pop('group')
attr['groups'] = groups
return AttrCvt(
op_name='pad',
transforms={
'value': 'pad_value',
'pads': 'pad_width'
},
custom_check=lambda attrs: attrs.get('mode') == 'constant')(
inputs, attr, params)
class ParametricSoftPlus(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 1.0))
beta = float(attr.get('beta', 1.0))
return _sym.log(_sym.exp(beta * inputs[0]) + 1) * alpha
return _impl
def _scale():
def _impl(inputs, attr, params):
class Prelu(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(
len(inputs))
channels = _infer_channels(inputs[1], params, False)
if channels == 1:
return inputs[0] * inputs[1]
return _sym.broadcast_mul(inputs[0], inputs[1])
class Reciprocal(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
return 1.0 / inputs[0]
class Reshape(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _sym.reshape(inputs[0], shape=attr['shape'])
@classmethod
def _impl_v5(cls, inputs, attr, params):
return _sym.reshape(
inputs[0],
shape=tuple(params[inputs[1].list_output_names()[0]].asnumpy()))
class Scale(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
scale = float(attr.get('scale', 1.0))
return inputs[0] * scale
return _impl
def _absolute():
"""This is a workaround."""
def _impl(inputs, attr, params):
return _sym.relu(inputs[0]) + _sym.relu(_sym.negative(inputs[0]))
return _impl
def _reciprocal():
def _impl(inputs, attr, params):
return 1.0 / inputs[0]
return _impl
class Selu(OnnxOpConverter):
def _selu():
def _impl(inputs, attr, params):
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 1.6732))
gamma = float(attr.get('gamma', 1.0507))
return gamma * (-alpha * _sym.relu(1 - _sym.exp(inputs[0]))
+ _sym.relu(inputs[0]))
return _impl
return gamma * (
-alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))
def _elu():
def _impl(inputs, attr, params):
class ScaledTanh(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 1.0))
return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0])
return _impl
beta = float(attr.get('beta', 1.0))
return _sym.tanh(beta * inputs[0]) * alpha
def _prelu():
def _impl(inputs, attr, params):
assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs))
channels = _infer_channels(inputs[1], params, False)
if channels == 1:
return inputs[0] * inputs[1]
return _sym.broadcast_mul(inputs[0], inputs[1])
return _impl
def _softsign():
def _impl(inputs, attr, params):
return inputs[0] / (1 + _absolute()(inputs, attr, params))
return _impl
class SoftPlus(OnnxOpConverter):
def _softplus():
def _impl(inputs, attr, params):
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _sym.log(_sym.exp(inputs[0]) + 1)
return _impl
def _pad():
def _impl(inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params, True)
attr['channels'] = channels
groups = attr.pop('group')
attr['groups'] = groups
return AttrCvt(
op_name='pad',
transforms={
'value': 'pad_value',
'pads': 'pad_width'},
custom_check=lambda attrs: attrs.get('mode') == 'constant')(inputs, attr)
class Softsign(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
return inputs[0] / (1 + Absolute.get_converter(1)(inputs, attr, params))
class Sub(Elemwise):
name = 'sub'
class Sum(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# Onnx Sum Operator
for in_index in range(len(inputs) - 1):
inputs[in_index + 1] = _sym.broadcast_add(inputs[in_index],
inputs[in_index + 1])
return inputs[len(inputs) - 1]
class ThresholdedRelu(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 0.0))
return _sym.relu(inputs[0] - alpha)
def _revert_caffe2_pad(attr):
"""Caffe2 require two times the normal padding."""
if len(attr) == 4:
attr = attr[:2]
elif len(attr) == 2:
pass
else:
raise ValueError("Invalid caffe2 type padding: {}".format(attr))
return attr
def _broadcast_constraint():
def _broadcast_check(attrs):
if attrs.get('axis', None):
return False
return True
return _broadcast_check, "Specifying broadcast axis not allowed."
def _dimension_picker(prefix, surfix=''):
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
else:
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def _sum():
def _dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
return True
return False
return _dim_check, "Only 2d kernel supported."
def _infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channles' or 'units' since onnx don't provide
these attributes. We check the shape of weights provided to get the number.
"""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def _fully_connected(opset):
def _impl(inputs, attr, params):
# Onnx Sum Operator
for in_index in range(len(inputs)-1):
inputs[in_index+1] = _sym.broadcast_add(inputs[in_index], inputs[in_index+1])
# get number of channels
channels = _infer_channels(inputs[1], params)
attr['units'] = channels
return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
return inputs[len(inputs)-1]
return _impl
# compatible operators that do NOT require any conversion.
_identity_list = []
# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
_convert_map = {
# defs/experimental
'Identity' : Renamer('copy'),
# 'Affine'
'ThresholdedRelu': _thresholded_relu(),
'ScaledTanh' : _scaled_tanh(),
'ParametricSoftplus': parametric_soft_plus(),
# 'ConstantFill'
# 'GivenTensorFill'
'FC' : AttrCvt('dense', ignores=['axis', 'axis_w']),
'Scale' : _scale(),
# 'GRUUnit'
# 'ATen'
# 'ImageScaler'
# 'MeanVarianceNormalization'
# 'Crop'
# 'Embedding'
# 'Upsample'
'SpatialBN' : _batch_norm(),
# defs/generator
# 'Constant'
# 'RandomUniform'
# 'RandomNormal'
# 'RandomUniformLike'
# 'RandomNormalLike'
# defs/logical
# defs/math
'Add' : _elemwise('add'),
'Sub' : _elemwise('sub'),
'Mul' : _elemwise('mul'),
'Div' : _elemwise('div'),
'Neg' : Renamer('negative'),
'Abs' : _absolute(),
'Reciprocal' : _reciprocal(),
# 'Floor'
# 'Ceil'
'Sqrt' : Renamer('sqrt'),
'Relu' : Renamer('relu'),
'LeakyRelu' : Renamer('leaky_relu'),
'Selu' : _selu(),
'Elu' : _elu(),
'Exp' : Renamer('exp'),
'Log' : Renamer('log'),
'Tanh' : Renamer('tanh'),
# 'Pow'
'PRelu' : _prelu(),
'Sigmoid' : Renamer('sigmoid'),
# 'HardSigmoid'
# 'Max' : this is the elemwise maximum
# 'Min' : this is the elemwise minimum
'Sum' : _sum(),
# 'Mean'
# 'Clip'
# softmax default axis is different in onnx
'Softmax' : AttrCvt('softmax', {'axis': ('axis', 1)}),
'LogSoftmax' : AttrCvt('log_softmax', {'axis': ('axis', 1)}),
# 'Hardmax'
'Softsign' : _softsign(),
'SoftPlus' : _softplus(),
'Gemm' : _gemm(),
# 'MatMul' batch stacked dot operation
# defs/nn
'AveragePool' : _pooling('avg_pool'),
'MaxPool' : _pooling('max_pool'),
'Conv' : _conv(),
'ConvTranspose' : _conv_transpose(),
'GlobalAveragePool': Renamer('global_avg_pool2d'),
'GlobalMaxPool' : Renamer('global_max_pool2d'),
'BatchNormalization': _batch_norm(),
# 'InstanceNormalization'
# 'LpNormalization'
'Dropout' : AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten' : Renamer('flatten'),
# 'LRN'
# defs/reduction
'ReduceMax' : AttrCvt('max', {'axes', 'axis'}),
'ReduceMin' : AttrCvt('min', {'axes', 'axis'}),
'ReduceSum' : AttrCvt('sum', {'axes', 'axis'}),
# 'ReduceMean'
# 'ReduceProd'
# 'ReduceLogSumExp'
# 'ArgMax'
# 'ArgMin'
# defs/tensor
'Cast' : AttrCvt('cast', {'to': 'dtype'}),
'Reshape' : Renamer('reshape'),
'Concat' : Renamer('concatenate'),
'Split' : AttrCvt('split', {'split': 'indices_or_sections'}),
# 'Slice'
'Transpose' : AttrCvt('transpose', {'perm': 'axes'}),
# 'Gather'
# 'Squeeze'
'Pad' : _pad(),
}
def _get_convert_map(opset):
return {
# defs/experimental
'Identity': Renamer('copy'),
# 'Affine'
'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
'ScaledTanh': ScaledTanh.get_converter(opset),
'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
# 'ConstantFill'
# 'GivenTensorFill'
'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
'Scale': Scale.get_converter(opset),
# 'GRUUnit'
# 'ATen'
# 'ImageScaler'
# 'MeanVarianceNormalization'
# 'Crop'
# 'Embedding'
# 'Upsample'
'SpatialBN': BatchNorm.get_converter(opset),
# defs/generator
# 'Constant'
# 'RandomUniform'
# 'RandomNormal'
# 'RandomUniformLike'
# 'RandomNormalLike'
# defs/logical
# defs/math
'Add': Add.get_converter(opset),
'Sub': Sub.get_converter(opset),
'Mul': Mul.get_converter(opset),
'Div': Div.get_converter(opset),
'Neg': Renamer('negative'),
'Abs': Absolute.get_converter(opset),
'Reciprocal': Reciprocal.get_converter(opset),
# 'Floor'
# 'Ceil'
'Sqrt': Renamer('sqrt'),
'Relu': Renamer('relu'),
'LeakyRelu': Renamer('leaky_relu'),
'Selu': Selu.get_converter(opset),
'Elu': Elu.get_converter(opset),
'Exp': Renamer('exp'),
'Log': Renamer('log'),
'Tanh': Renamer('tanh'),
# 'Pow'
'PRelu': Prelu.get_converter(opset),
'Sigmoid': Renamer('sigmoid'),
# 'HardSigmoid'
# 'Max' : this is the elemwise maximum
# 'Min' : this is the elemwise minimum
'Sum': Sum.get_converter(opset),
# 'Mean'
# 'Clip'
# softmax default axis is different in onnx
'Softmax': AttrCvt('softmax', {'axis': ('axis', 1)}),
'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
# 'Hardmax'
'Softsign': Softsign.get_converter(opset),
'SoftPlus': SoftPlus.get_converter(opset),
'Gemm': Gemm.get_converter(opset),
# 'MatMul' batch stacked dot operation
# defs/nn
'AveragePool': AveragePool.get_converter(opset),
'MaxPool': MaxPool.get_converter(opset),
'Conv': Conv.get_converter(opset),
'ConvTranspose': ConvTranspose.get_converter(opset),
'GlobalAveragePool': Renamer('global_avg_pool2d'),
'GlobalMaxPool': Renamer('global_max_pool2d'),
'BatchNormalization': BatchNorm.get_converter(opset),
# 'InstanceNormalization'
# 'LpNormalization'
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten': Renamer('flatten'),
# 'LRN'
# defs/reduction
'ReduceMax': AttrCvt('max', {'axes', 'axis'}),
'ReduceMin': AttrCvt('min', {'axes', 'axis'}),
'ReduceSum': AttrCvt('sum', {'axes', 'axis'}),
# 'ReduceMean'
# 'ReduceProd'
# 'ReduceLogSumExp'
# 'ArgMax'
# 'ArgMin'
# defs/tensor
'Cast': AttrCvt('cast', {'to': 'dtype'}),
'Reshape': Reshape.get_converter(opset),
'Concat': Renamer('concatenate'),
'Split': AttrCvt('split', {'split': 'indices_or_sections'}),
# 'Slice'
'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
# 'Gather'
# 'Squeeze'
'Pad': Pad.get_converter(opset),
}
class GraphProto(object):
"""A helper class for handling nnvm graph copying from pb2.GraphProto.
Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
"""
def __init__(self):
self._nodes = {}
self._params = {}
......@@ -365,7 +511,7 @@ class GraphProto(object):
self._num_input = 0
self._num_param = 0
def from_onnx(self, graph):
def from_onnx(self, graph, opset):
"""Construct nnvm nodes from onnx graph.
The inputs from onnx graph is vague, only providing "1", "2"...
For convenience, we rename the `real` input names to "input_0",
......@@ -375,6 +521,7 @@ class GraphProto(object):
----------
graph : onnx protobuf object
The loaded onnx graph
opset : opset version
Returns
-------
......@@ -410,7 +557,7 @@ class GraphProto(object):
op_name = node.op_type
attr = self._parse_attr(node.attribute)
inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
op = self._convert_operator(op_name, inputs, attr)
op = self._convert_operator(op_name, inputs, attr, opset)
node_output = self._fix_outputs(op_name, node.output)
assert len(node_output) == len(op.list_output_names()), (
"Number of output mismatch {} vs {} in {}.".format(
......@@ -438,7 +585,8 @@ class GraphProto(object):
try:
from onnx.numpy_helper import to_array
except ImportError as e:
raise ImportError("Unable to import onnx which is required {}".format(e))
raise ImportError(
"Unable to import onnx which is required {}".format(e))
np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
return tvm.nd.array(np_array)
......@@ -455,15 +603,23 @@ class GraphProto(object):
attrs[a.name] = tuple(getattr(a, f))
for f in ['t', 'g']:
if a.HasField(f):
raise NotImplementedError("Filed {} is not supported in nnvm.".format(f))
raise NotImplementedError(
"Filed {} is not supported in nnvm.".format(f))
for f in ['tensors', 'graphs']:
if list(getattr(a, f)):
raise NotImplementedError("Filed {} is not supported in nnvm.".format(f))
raise NotImplementedError(
"Filed {} is not supported in nnvm.".format(f))
if a.name not in attrs:
raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
return attrs
def _convert_operator(self, op_name, inputs, attrs, identity_list=None, convert_map=None):
def _convert_operator(self,
op_name,
inputs,
attrs,
opset,
identity_list=None,
convert_map=None):
"""Convert from onnx operator to nnvm operator.
The converter must specify conversions explicity for incompatible name, and
apply handlers to operator attributes.
......@@ -476,6 +632,8 @@ class GraphProto(object):
List of input symbols.
attrs : dict
Dict of operator attributes
opset : int
Opset version
identity_list : list
List of operators that don't require conversion
convert_map : dict
......@@ -489,13 +647,14 @@ class GraphProto(object):
Converted nnvm Symbol
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
convert_map = convert_map if convert_map else _get_convert_map(opset)
if op_name in identity_list:
sym = get_nnvm_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))
raise NotImplementedError(
"Operator {} not implemented.".format(op_name))
return sym
def _fix_outputs(self, op_name, outputs):
......@@ -510,7 +669,7 @@ class GraphProto(object):
return outputs
def from_onnx(graph):
def from_onnx(model):
"""Load onnx graph which is a python protobuf object into nnvm graph.
The companion parameters will be handled automatically.
The inputs from onnx graph is vague, only providing "1", "2"...
......@@ -519,8 +678,8 @@ def from_onnx(graph):
Parameters
----------
graph : protobuf object
ONNX GraphProto, or ONNX ModelProto after ONNX v0.2
model : protobuf object
ONNX ModelProto after ONNX v1.1.0
Returns
-------
......@@ -531,8 +690,7 @@ def from_onnx(graph):
Dict of converted parameters stored in tvm.ndarray format
"""
g = GraphProto()
if hasattr(graph, 'graph'):
# it's a ModelProto wrapper
graph = graph.graph
sym, params = g.from_onnx(graph)
graph = model.graph
opset = model.opset_import[0].version if model.opset_import else 1
sym, params = g.from_onnx(graph, opset)
return sym, params
pip2 install onnx>=0.2.0
pip3 install onnx>=0.2.0
pip2 install onnx>=1.1.0
pip3 install onnx>=1.1.0
pip2 install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
pip2 install torchvision
......
......@@ -14,8 +14,8 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
c2_out = prepared_backend.run(W)[0]
return c2_out
def get_tvm_output(graph, x, target, ctx, dtype='float32'):
new_sym, params = nnvm.frontend.from_onnx(graph)
def get_tvm_output(model, x, target, ctx, dtype='float32'):
new_sym, params = nnvm.frontend.from_onnx(model)
shape_dict = {'input_0': x.shape}
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
......
......@@ -5,8 +5,8 @@ from nnvm.compiler import graph_util, graph_attr
from model_zoo import super_resolution, super_resolution_sym
def compare_graph(onnx_file, nnvm_sym, ishape):
onnx_graph = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph)
onnx_model = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_model)
g1 = nnvm.graph.create(onnx_sym)
g2 = nnvm.graph.create(nnvm_sym)
ishapes = {'input_0': ishape}
......
......@@ -44,9 +44,9 @@ model_url = ''.join(['https://gist.github.com/zhreshold/',
'super_resolution_0.2.onnx'])
download(model_url, 'super_resolution.onnx', True)
# now you have super_resolution.onnx on disk
onnx_graph = onnx.load('super_resolution.onnx')
onnx_model = onnx.load('super_resolution.onnx')
# we can load the graph as NNVM compatible model
sym, params = nnvm.frontend.from_onnx(onnx_graph)
sym, params = nnvm.frontend.from_onnx(onnx_model)
######################################################################
# Load a test image
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment