Commit 30a5a600 by Joshua Z. Zhang Committed by Haichen Shen

[RELAY][FRONTEND]Onnx to relay frontend (#2302)

parent 312802f3
......@@ -35,7 +35,7 @@ Our goal is to build the shared libraries:
.. code:: bash
sudo apt-get update
sudo apt-get install -y python python-dev python-setuptools gcc libtinfo-dev zlib1g-dev
sudo apt-get install -y python python-dev python-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake
The minimal building requirements are
......
......@@ -910,7 +910,7 @@ def test_single_ops():
model = helper.make_model(graph, producer_name='_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
x = np.random.uniform(size=in_shape).astype(dtype)
verify_single_ops("Neg",x, -x)
......@@ -918,13 +918,13 @@ def test_single_ops():
verify_single_ops("Reciprocal",x, 1/x, rtol=1e-5, atol=1e-5)
verify_single_ops("Sqrt",x, np.sqrt(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Relu",x, np.maximum(x, 0))
verify_single_ops("Exp",x, np.exp(x))
verify_single_ops("Log",x, np.log(x))
verify_single_ops("Log",x, np.log(x))
verify_single_ops("Tanh",x, np.tanh(x))
verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)))
verify_single_ops("Softsign",x, x / (1 + np.abs(x)))
verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)))
verify_single_ops("Exp",x, np.exp(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Log",x, np.log(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Log",x, np.log(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Tanh",x, np.tanh(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)), rtol=1e-5, atol=1e-5)
verify_single_ops("Softsign",x, x / (1 + np.abs(x)), rtol=1e-5, atol=1e-5)
verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)), rtol=1e-5, atol=1e-5)
def test_leaky_relu():
def leaky_relu_x(x, alpha):
......
......@@ -465,6 +465,14 @@ def const(value, dtype=None):
"""
if isinstance(value, (_base.numeric_types, (bool, list))):
value = _np.array(value, dtype=dtype)
if not dtype:
# when dtype is None: int maps to "int32", float maps to "float32"
map_dtype = {
_np.dtype('int64'): _np.int32,
_np.dtype('float64'): _np.float32
}.get(value.dtype, None)
if map_dtype:
value = value.astype(map_dtype)
if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value)
......
......@@ -9,3 +9,4 @@ from __future__ import absolute_import
from .mxnet import from_mxnet
from .keras import from_keras
from .onnx import from_onnx
"""Common utilities"""
from __future__ import absolute_import as _abs
import logging
from topi.util import get_const_tuple
from .. import expr as _expr
from .. import ir_pass
from .. import op as _op
class RequiredAttr(object):
......@@ -204,6 +209,30 @@ class StrAttrsDict(object):
raise AttributeError("Required attribute {} not found.".format(key))
return default
def get_relay_op(op_name):
"""Get the callable function from Relay based on operator name.
Parameters
----------
op_name : str
The Relay operator name.
"""
if '.' in op_name:
# explicit hierarchical modules
op = _op
try:
for opn in op_name.split('.'):
op = getattr(op, opn)
except AttributeError:
op = None
else:
# try search op in various modules
for candidate in (_op, _op.nn, _op.image):
op = getattr(candidate, op_name, None)
if op is not None:
break
if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
return op
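# A hedged usage sketch for get_relay_op (assumes relay.op.nn.conv2d and relay.op.nn.relu
# exist in this TVM build; the names below are only illustrative):
#   conv2d_fn = get_relay_op('nn.conv2d')   # explicit dotted path: _op -> _op.nn -> _op.nn.conv2d
#   relu_fn = get_relay_op('relu')          # flat name, searched in _op, _op.nn, _op.image
#   get_relay_op('no_such_op')              # raises RuntimeError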
class ExprTable(object):
"""Table storing Relay expressions by names."""
......@@ -227,3 +256,156 @@ class ExprTable(object):
def set_expr(self, name, expr):
assert isinstance(expr, _expr.Expr)
self.exprs[name] = expr
class AttrCvt(object):
"""Common attribute conveter. An AttrConverter instance is a callable:
```
attr_converter = AttrCvt(op_name, transforms={'a':'b', 'c':('d', 1)})
out = attr_converter(inputs, attrs)
```
Parameters
----------
op_name : str or callable
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If default_value is provided, then the attribute is considered optional.
If a transform function is provided, the original attribute value is handled
by the transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raises NotImplementedError if one is encountered.
disables : list
A list of attributes that are disabled in relay. Logs warnings.
ignores : list
A list of attributes that are ignored in relay. Logged at debug level.
extras : dict
Additional attributes that are added unconditionally to the returned
attribute dict.
custom_check : callable
A (function, message) pair; the function takes the attribute dict and
returns True/False. Raises RuntimeError with the message if the check fails.
"""
def __init__(self, op_name, transforms=None,
excludes=None, disables=None, ignores=None,
extras=None, custom_check=None):
self._op_name = op_name
self._transforms = transforms if transforms else {}
self._excludes = excludes if excludes else []
self._disables = disables if disables else []
self._ignores = ignores if ignores else []
self._extras = extras if extras else {}
self._custom_check = custom_check
def __call__(self, inputs, attrs, *args):
# apply custom check
if self._custom_check:
func, msg = self._custom_check
if not func(attrs):
raise RuntimeError("Check failed: {}".format(msg))
# get new op_name
if isinstance(self._op_name, str):
op_name = self._op_name
else:
assert callable(self._op_name), "op_name can either be string or callable"
op_name = self._op_name(attrs)
# convert attributes
new_attrs = {}
for k in attrs.keys():
if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k))
elif k in self._disables:
logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
elif k in self._ignores:
logging.debug("Attribute %s is ignored in relay.sym.%s", k, op_name)
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
new_attr = self._required_attr(attrs, k)
else:
new_attr = attrs.get(k, None)
if new_attr is None:
new_attrs[new_name] = defaults
else:
new_attrs[new_name] = transform(new_attr)
else:
# copy
new_attrs[k] = attrs[k]
# add extras
new_attrs.update(self._extras)
return get_relay_op(op_name)(*inputs, **new_attrs)
def _parse_default(self, target):
"""Helper function to parse default values."""
if not isinstance(target, (list, tuple)):
k, v, t = target, None, lambda x: x
elif len(target) == 1:
k, v, t = target[0], None, lambda x: x
elif len(target) == 2:
k, v, t = target[0], target[1], lambda x: x
elif len(target) > 2:
k, v, t = target[0], target[1], target[2]
else:
k = None # should raise
if not isinstance(k, str):
msg = "{} is not a valid target, (name, default) expected.".format(target)
raise ValueError(msg)
return k, v, t
def _parse_bool(self, value):
"""Helper function to parse default boolean values."""
if isinstance(value, str):
return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
return bool(value)
def _required_attr(self, attr, key):
"""Wrapper for getting required attributes."""
assert isinstance(attr, dict)
if key not in attr:
raise AttributeError("Required attribute {} not found.".format(key))
return attr[key]
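# A hedged usage sketch of AttrCvt, mirroring the Pool converter defined in onnx.py below
# (attribute values are made up for illustration):
#   cvt = AttrCvt('max_pool2d',
#                 transforms={'kernel_shape': 'pool_size',
#                             'pads': ('padding', (0, 0))})
#   out = cvt(inputs, {'kernel_shape': (2, 2), 'pads': (1, 1)})
#   # -> calls relay max_pool2d(*inputs, pool_size=(2, 2), padding=(1, 1))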
def get_name(node):
"""Return the name hint of a relay node, or an empty string if it has none."""
name = ''
if hasattr(node, "name_hint"):
name = node.name_hint
return name
def infer_shape(inputs):
"""A method to get the output shape of an intermediate node in the graph."""
out_type = ir_pass.infer_type(inputs)
out_shapes = get_const_tuple(out_type.checked_type.shape)
return out_shapes
def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number.
"""
out_type = ir_pass.infer_type(inputs)
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
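# Hedged illustration of infer_channels (the weight shape is made up):
#   weight var with inferred shape (64, 3, 3, 3)
#   infer_channels(weight)                 -> 64   (shape[0])
#   infer_channels(weight, transpose=True) -> 3    (shape[1])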
def new_var(name_hint,
type_annotation=None,
shape=None,
dtype="float32"):
return _expr.var(name_hint, type_annotation, shape, dtype)
class Renamer(object):
"""A simply renamer for operators.
Parameters
----------
new_name : str
The new name for the operator
"""
def __init__(self, new_name):
self._new_name = new_name
def __call__(self, inputs, attrs, *args):
return get_relay_op(self._new_name)(*inputs, **attrs)
......@@ -4,15 +4,7 @@ from __future__ import absolute_import as _abs
from .. import expr as _expr
from .. import op as _op
def _get_relay_op(op_name):
op = _op
for path in op_name.split("."):
op = getattr(op, path)
if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
return op
from .common import get_relay_op
def _warn_not_used(attr, op='nnvm'):
import warnings
......@@ -22,7 +14,7 @@ def _warn_not_used(attr, op='nnvm'):
def _rename(new_op):
if isinstance(new_op, str):
new_op = _get_relay_op(new_op)
new_op = get_relay_op(new_op)
# attrs are ignored.
def impl(inputs, _, _dtype='float32'):
return new_op(*inputs)
......
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
"""ONNX: Open Neural Network Exchange frontend for Relay."""
from __future__ import absolute_import as _abs
import logging
import numpy as np
from ... import nd as _nd
from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels, get_name
__all__ = ['from_onnx']
def dimension_picker(prefix, suffix=''):
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + suffix
else:
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def revert_caffe2_pad(pads):
"""Caffe2 requires two times the normal padding."""
if len(pads) == 4:
pads = pads[:2]
elif len(pads) == 2:
pass
else:
raise ValueError("Invalid caffe2 type padding: {}".format(pads))
return pads
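# Hedged examples for revert_caffe2_pad (ONNX stores begin and end padding per spatial
# dim, so only the first half is kept for the symmetric case; values are made up):
#   revert_caffe2_pad((1, 1, 1, 1))  # -> (1, 1)
#   revert_caffe2_pad((1, 1))        # -> (1, 1), passed through
#   revert_caffe2_pad((1, 1, 1))    # raises ValueError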
def dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
return True
return False
return _dim_check, "Only 2d kernel supported."
class OnnxOpConverter(object):
""" A helper class for holding onnx op converters.
"""
@classmethod
def get_converter(cls, opset):
""" Get converter matches given opset.
Parameters
----------
opset: int
opset from model.
Returns
-------
converter, which should be `_impl_vx`, where x is the largest supported
version that is smaller than or equal to the given opset.
"""
versions = [
int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d
]
versions = sorted(versions + [opset])
version = versions[
max([i for i, v in enumerate(versions) if v == opset]) - 1]
if hasattr(cls, '_impl_v{}'.format(version)):
return getattr(cls, '_impl_v{}'.format(version))
raise NotImplementedError(
'opset version {} of {} not implemented'.format(
version, cls.__name__))
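# Hedged walk-through of the version dispatch above: for a converter that defines
# _impl_v1 and _impl_v7,
#   opset=5 -> sorted versions [1, 5, 7] -> picks _impl_v1 (largest supported version <= 5)
#   opset=7 -> sorted versions [1, 7, 7] -> picks _impl_v7
#   opset=8 -> sorted versions [1, 7, 8] -> picks _impl_v7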
class Elemwise(OnnxOpConverter):
""" A helper class for elemwise op converters.
"""
name = ''
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "Math op takes 2 inputs, {} given".format(
len(inputs))
op_name = cls.name
conv_ops = ["conv2d", "conv2d_transpose"]
if attr.get('broadcast', 0) and any(x in str(inputs[0]) for x in conv_ops):
# TODO(zhreshold): remove hard coded infershape
axis = int(attr.get('axis', 0))
inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2)
return get_relay_op(op_name)(*inputs)
class Pool(OnnxOpConverter):
""" A helper class for pool op converters.
"""
name = ''
@classmethod
def _impl_v1(cls, inputs, attr, params):
return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad)
},
# very weird attributes here in onnx, force check
ignores=['dilations'],
# TODO(zhreshold): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
custom_check=dimension_constraint())(inputs, attr, params)
class Absolute(OnnxOpConverter):
""" Operator converter for Absolute.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.nn.relu(inputs[0]) + _op.nn.relu(_op.negative(inputs[0]))
class Add(Elemwise):
""" Operator converter for Add.
"""
name = 'add'
class AveragePool(Pool):
""" Operator converter for AveragePool.
"""
name = 'avg_pool'
class BatchNorm(OnnxOpConverter):
""" Operator converter for BatchNorm.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
# TODO(zhreshold): 'spatial' is not properly handled here.
out = AttrCvt(
op_name='batch_norm',
ignores=['spatial', 'is_test', 'consumed_inputs', 'momentum'])(inputs, attr,
params)
return out[0]
class Conv(OnnxOpConverter):
""" Operator converter for Conv.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
out = AttrCvt(op_name=dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), revert_caffe2_pad),
'group': ('groups', 1)},
custom_check=dimension_constraint())(inputs[:2], attr, params)
use_bias = len(inputs) == 3
if use_bias:
out = _op.nn.bias_add(out, inputs[2])
return out
class ConvTranspose(OnnxOpConverter):
""" Operator converter for ConvTranspose.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
channels = infer_channels(inputs[1], True)
attr['channels'] = channels
groups = attr.pop('group')
attr['groups'] = groups
out = AttrCvt(
op_name=dimension_picker('conv', '_transpose'),
transforms={
'kernel_shape': 'kernel_size',
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), revert_caffe2_pad)
},
disables=['output_shape'],
custom_check=dimension_constraint())(inputs[:2], attr, params)
use_bias = len(inputs) == 3
if use_bias:
out = _op.nn.bias_add(out, inputs[2])
return out
class Div(Elemwise):
name = 'divide'
class Elu(OnnxOpConverter):
""" Operator converter for Elu.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 1.0))
return _expr.const(-alpha) * _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) + \
_op.nn.relu(inputs[0])
class Gemm(OnnxOpConverter):
""" Operator converter for Gemm.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 3, "Gemm op takes 3 inputs, {} given".format(
len(inputs))
# Y = alpha * A * B + beta * C
alpha = float(attr.get('alpha', 1.0))
beta = float(attr.get('beta', 1.0))
transA = int(attr.get('transA', 0))
transB = int(attr.get('transB', 0))
# get number of channels
channels = infer_channels(inputs[1], not transB)
if transA:
inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
if not transB:
inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
inputs[0] = _op.nn.batch_flatten(inputs[0])
out = _op.nn.dense(_expr.const(alpha) * inputs[0],
inputs[1], units=channels)
return _op.nn.bias_add(out, _expr.const(beta) * inputs[2])
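# Hedged numpy reference for the Gemm formula handled above (shapes/values are made up):
#   import numpy as np
#   A, B, C = np.random.rand(2, 3), np.random.rand(3, 4), np.random.rand(4)
#   alpha, beta = 0.5, 2.0
#   Y = alpha * (A @ B) + beta * C   # transA = transB = 0 case; transposes are applied first otherwise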
class MatMul(OnnxOpConverter):
""" Operator converter for MatMul.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "MatMul op takes 2 inputs, {} given".format(len(inputs))
input_1_t = _op.transpose(inputs[1], axes=(1, 0))
return _op.nn.dense(inputs[0], input_1_t)
class MaxPool(Pool):
name = 'max_pool'
class Mul(Elemwise):
name = 'multiply'
class Pad(OnnxOpConverter):
""" Operator converter for Pad.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
pad_width = []
pads = attr.pop('paddings')
dims = int(len(pads) / 2)
for i in range(dims):
pad_width.append((pads[i], pads[i+dims]))
attr['pad_width'] = pad_width
return AttrCvt(
_op.nn.pad,
transforms={
'value': 'pad_value',
},
ignores=['mode'],
custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
'split mode != constant'))(inputs, attr, params)
@classmethod
def _impl_v2(cls, inputs, attr, params):
pad_width = []
pads = attr.pop('pads')
dims = int(len(pads) / 2)
for i in range(dims):
pad_width.append((pads[i], pads[i+dims]))
attr['pad_width'] = pad_width
return AttrCvt(
'pad',
transforms={
'value': 'pad_value',
},
ignores=['mode'],
custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
'split mode != constant'))(inputs, attr, params)
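# Hedged example of the pads -> pad_width conversion used by both _impl_v1 and _impl_v2:
# ONNX orders pads as (x1_begin, x2_begin, ..., x1_end, x2_end, ...), so for a 2-d input
#   pads = (0, 1, 2, 3)  ->  pad_width = [(0, 2), (1, 3)]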
class ParametricSoftPlus(OnnxOpConverter):
""" Operator converter for ParametricSoftPlus.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = _expr.const(float(attr.get('alpha', 1.0)))
beta = _expr.const(float(attr.get('beta', 1.0)))
return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.)) * alpha
class Prelu(OnnxOpConverter):
""" Operator converter for Prelu.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "Prelu needs 2 inputs, {} given".format(len(inputs))
return _op.nn.prelu(inputs[0], inputs[1])
class Reciprocal(OnnxOpConverter):
""" Operator converter for Reciprocal.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _expr.const(1.0) / inputs[0]
class Reshape(OnnxOpConverter):
""" Operator converter for Reshape.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if 'shape' in attr:
return _op.reshape(inputs[0], attr['shape'])
if get_name(inputs[1]) in params:
shape = tuple(params[inputs[1].name_hint].asnumpy())
out = _op.reshape(inputs[0], shape)
else:
out = _op.reshape_like(inputs[0], inputs[1])
return out
class Concat(OnnxOpConverter):
""" Operator converter for Concat.
"""
@classmethod
def _impl_v1(cls, inputs, args, params):
return AttrCvt(op_name='concatenate')((inputs,), args)
class Scale(OnnxOpConverter):
""" Operator converter for Scale.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
scale = float(attr.get('scale', 1.0))
return inputs[0] * _expr.const(scale)
class Selu(OnnxOpConverter):
""" Operator converter for Selu.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 1.6732))
gamma = float(attr.get('gamma', 1.0507))
return _expr.const(gamma) * (_expr.const(-alpha) *
_op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) +
_op.nn.relu(inputs[0]))
class ScaledTanh(OnnxOpConverter):
""" Operator converter for ScaledTanh.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 1.0))
beta = float(attr.get('beta', 1.0))
return _op.tanh(_expr.const(beta) * inputs[0]) * _expr.const(alpha)
class SoftPlus(OnnxOpConverter):
""" Operator converter for SoftPlus.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.log(_op.exp(inputs[0]) + _expr.const(1.))
class Softsign(OnnxOpConverter):
""" Operator converter for Softsign.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return inputs[0] / (_expr.const(1.) + Absolute.get_converter(1)(inputs, attr, params))
class Sub(Elemwise):
name = 'subtract'
class Sum(OnnxOpConverter):
""" Operator converter for Sum.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
# Onnx Sum Operator
for in_index in range(len(inputs) - 1):
inputs[in_index + 1] = _op.add(inputs[in_index], inputs[in_index + 1])
return inputs[len(inputs) - 1]
class ThresholdedRelu(OnnxOpConverter):
""" Operator converter for ThresholdedRelu.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = float(attr.get('alpha', 0.0))
alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
return inputs[0] * mask
def _broadcast_constraint():
def _broadcast_check(attrs):
if attrs.get('axis', None):
return False
return True
return _broadcast_check, "Specifying broadcast axis not allowed."
def _fully_connected(opset):
def _impl(inputs, attr, params):
# get number of channels
channels = infer_channels(inputs[1], params)
attr['units'] = channels
return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
return _impl
class Upsample(OnnxOpConverter):
""" Operator converter for Upsample (nearest mode).
"""
@classmethod
def _impl_v7(cls, inputs, attr, params):
scales = attr.get('scales')
assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3]
mode = attr.get('mode')
if mode == b'nearest':
method = "NEAREST_NEIGHBOR"
elif mode == b'linear':
method = "BILINEAR"
else:
raise ValueError("Invalid ONNX upsample mode: {}".format(mode))
attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW'}
return AttrCvt('upsampling')(inputs, attr)
class Shape(OnnxOpConverter):
""" Operator converter for Shape.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
# Result of this operator is prominently used by reshape operator.
# Just pass the input as it is so that reshape_like can be used there.
logging.warning("Shape: Differently implemented in relay as a bypass (dummy operator)")
return inputs[0]
class Cast(OnnxOpConverter):
""" Operator converter for Cast.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)
@classmethod
def _impl_v5(cls, inputs, attr, params):
try:
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
except ImportError as e:
raise ImportError(
"Unable to import onnx.mapping which is required {}".format(e))
return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)
class Unsqueeze(OnnxOpConverter):
""" Operator converter for Unsqueeze.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
for axes in attr['axes']:
inputs[0] = _op.expand_dims(inputs[0], axis=axes, num_newaxis=1)
return inputs[0]
class Split(OnnxOpConverter):
""" Operator converter for Split.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
attr['indices_or_sections'] = []
index = 0
for i in attr['split'][:-1]:
index += i
attr['indices_or_sections'].append(index)
return AttrCvt(
'split',
ignores=['split'])(inputs, attr, params)
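# Hedged example of the split -> indices_or_sections conversion above:
#   attr['split'] = (2, 3, 4) on an axis of length 9
#   -> indices_or_sections = [2, 5], i.e. relay splits at the running offsets 2 and 2 + 3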
class Slice(OnnxOpConverter):
""" Operator converter for Slice.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if isinstance(attr['starts'], int):
attr['starts'] = (attr['starts'],)
attr['ends'] = (attr['ends'],)
try:
# Update the starts and ends according to axes if required.
if isinstance(attr['axes'], int):
attr['axes'] = (attr['axes'],)
if (max(attr['axes']) + 1) != len(attr['axes']):
new_axes = []
new_starts = []
new_ends = []
pop_index = 0
for i in range(max(attr['axes']) + 1):
if i in attr['axes']:
new_axes.append(i)
new_starts.append(attr['starts'][pop_index])
new_ends.append(attr['ends'][pop_index])
pop_index += 1
else:
new_axes.append(i)
new_starts.append(0)
new_ends.append(np.iinfo(np.int32).max)
attr['axes'] = new_axes
attr['starts'] = new_starts
attr['ends'] = new_ends
except KeyError:
pass
return AttrCvt('strided_slice',
transforms={'starts': 'begin',
'ends': 'end'},
ignores=['axes'])(inputs, attr)
class Gather(OnnxOpConverter):
""" Operator converter for Gather.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get('axis', 0)
return AttrCvt('take',
extras={'axis':axis})(inputs, {})
#return _op.take(inputs[0], inputs[1], axis)
class LRN(OnnxOpConverter):
""" Operator converter for Local Response Normalization.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
"""LRN support only NCHW format
https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN
"""
axis = 1
alpha = attr.get('alpha', 0.0001)
beta = attr.get('beta', 0.75)
bias = attr.get('bias', 1.0)
nsize = attr.get('size')
attr = {'size':nsize, 'axis':axis, 'alpha':alpha, 'beta':beta, 'bias':bias}
return AttrCvt('lrn')(inputs, attr)
class Maximum(OnnxOpConverter):
""" Operator converter for Maximum.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, list) or len(inputs) < 2:
raise ValueError("Expect at least 2 inputs")
_max = inputs[0]
for i in range(1, len(inputs)):
_max = AttrCvt('maximum')([_max, inputs[i]], {})
return _max
class Minimum(OnnxOpConverter):
""" Operator converter for Minimum.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, list) or len(inputs) < 2:
raise ValueError("Expect at least 2 inputs")
_min = inputs[0]
for i in range(1, len(inputs)):
_min = AttrCvt('minimum')([_min, inputs[i]], {})
return _min
class Mean(OnnxOpConverter):
""" Operator converter for Mean.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if not isinstance(inputs, list) or len(inputs) < 2:
raise ValueError("Expect at least 2 inputs")
# avoid overflow
concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
return _op.mean(concat, axis=0, keepdims=False)
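# Hedged numpy equivalent of the Mean lowering above (stacking avoids accumulating a
# running sum that could overflow; shapes are made up):
#   import numpy as np
#   xs = [np.random.rand(2, 3) for _ in range(3)]
#   ref = np.mean(np.stack(xs, axis=0), axis=0)   # matches concatenate(expand_dims) + mean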
class HardSigmoid(OnnxOpConverter):
""" Operator converter for HardSigmoid.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
alpha = attr.get('alpha', 0.2)
beta = attr.get('beta', 0.5)
transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta)
attr = {'a_min':0, 'a_max':1}
return AttrCvt('clip')([transformX], attr)
class Reduce(OnnxOpConverter):
""" Operator converter for reduce ops.
"""
name = ''
@classmethod
def _impl_v1(cls, inputs, attr, params):
if 'axes' in attr:
axis = attr.get('axes', 0)
else:
axis_len = len(infer_shape(inputs[0]))
axis = list(range(axis_len))
attr = {'axis':axis, 'keepdims':attr.get('keepdims', True)}
return AttrCvt(cls.name)(inputs, attr)
class ReduceMax(Reduce):
""" Operator converter for ArgMax.
"""
name = 'max'
class ReduceMin(Reduce):
""" Operator converter for ArgMax.
"""
name = 'min'
class ReduceSum(Reduce):
""" Operator converter for ArgMax.
"""
name = 'sum'
class ReduceMean(Reduce):
""" Operator converter for ArgMax.
"""
name = 'mean'
class ArgMax(OnnxOpConverter):
""" Operator converter for ArgMax.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get('axis', 0)
keepdims = attr.get('keepdims', True)
attr = {'axis':axis, 'keepdims':keepdims}
return AttrCvt('argmax')(inputs, attr)
class ArgMin(OnnxOpConverter):
""" Operator converter for ArgMin.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get('axis', 0)
keepdims = attr.get('keepdims', True)
attr = {'axis':axis, 'keepdims':keepdims}
return AttrCvt('argmin')(inputs, attr)
class Softmax(OnnxOpConverter):
""" Operator converter for Softmax.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
# set default value when axis is not set in the model
if 'axis' not in attr:
attr['axis'] = 1
return AttrCvt('softmax', transforms={'axis': ('axis', 1)})(inputs, attr, params)
class ConstantFill(OnnxOpConverter):
""" Operator converter for ConstantFill.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
num_inputs = len(inputs)
if 'shape' in attr:
if num_inputs > 1:
raise ImportError(
"Can't set shape and input tensor at a time")
shape = attr.pop('shape')
else:
if num_inputs == 1:
raise ImportError(
"Either shape attribute or input should be set")
if 'input_as_shape' in attr and attr['input_as_shape']:
shape = params[get_name(inputs[0])].asnumpy()
else:
if 'extra_shape' in attr:
raise ImportError(
"Extra Shape not supported with fill_like")
return _op.full_like(inputs[0], inputs[1])
if 'extra_shape' in attr:
shape = shape + attr.pop('extra_shape')
return _op.full(inputs[0], shape)
# compatible operators that do NOT require any conversion.
_identity_list = []
# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
def _get_convert_map(opset):
return {
# defs/experimental
'Identity': Renamer('copy'),
# 'Affine'
'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
'ScaledTanh': ScaledTanh.get_converter(opset),
'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
'ConstantFill': ConstantFill.get_converter(opset),
# 'GivenTensorFill'
'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
'Scale': Scale.get_converter(opset),
# 'GRUUnit'
# 'ATen'
# 'ImageScaler'
# 'MeanVarianceNormalization'
# 'Crop'
# 'Embedding'
'Upsample' : Upsample.get_converter(opset),
'SpatialBN': BatchNorm.get_converter(opset),
# defs/generator
# 'Constant' # Implemented
# 'RandomUniform'
# 'RandomNormal'
# 'RandomUniformLike'
# 'RandomNormalLike'
# defs/logical
# defs/math
'Add': Add.get_converter(opset),
'Sub': Sub.get_converter(opset),
'Mul': Mul.get_converter(opset),
'Div': Div.get_converter(opset),
'Neg': Renamer('negative'),
'Abs': Absolute.get_converter(opset),
'Reciprocal': Reciprocal.get_converter(opset),
'Floor': Renamer('floor'),
'Ceil': Renamer('ceil'),
'Sqrt': Renamer('sqrt'),
'Relu': Renamer('relu'),
'LeakyRelu': Renamer('leaky_relu'),
'Selu': Selu.get_converter(opset),
'Elu': Elu.get_converter(opset),
'Exp': Renamer('exp'),
'Log': Renamer('log'),
'Tanh': Renamer('tanh'),
'Pow': Renamer('power'),
'PRelu': Prelu.get_converter(opset),
'Sigmoid': Renamer('sigmoid'),
'HardSigmoid': HardSigmoid.get_converter(opset),
'Max': Maximum.get_converter(opset),
'Min': Minimum.get_converter(opset),
'Sum': Sum.get_converter(opset),
'Mean': Mean.get_converter(opset),
'Clip': AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'}),
# softmax default axis is different in onnx
'Softmax': Softmax.get_converter(opset),
'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
# 'Hardmax'
'Softsign': Softsign.get_converter(opset),
'SoftPlus': SoftPlus.get_converter(opset),
'Gemm': Gemm.get_converter(opset),
'MatMul': MatMul.get_converter(opset),
# defs/nn
'AveragePool': AveragePool.get_converter(opset),
'MaxPool': MaxPool.get_converter(opset),
'Conv': Conv.get_converter(opset),
'ConvTranspose': ConvTranspose.get_converter(opset),
'GlobalAveragePool': Renamer('global_avg_pool2d'),
'GlobalMaxPool': Renamer('global_max_pool2d'),
'BatchNormalization': BatchNorm.get_converter(opset),
# 'InstanceNormalization'
# 'LpNormalization'
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten': Renamer('flatten'),
'LRN': LRN.get_converter(opset),
# defs/reduction
'ReduceMax': ReduceMax.get_converter(opset),
'ReduceMin': ReduceMin.get_converter(opset),
'ReduceSum': ReduceSum.get_converter(opset),
'ReduceMean': ReduceMean.get_converter(opset),
# 'ReduceProd'
# 'ReduceLogSumExp'
'ArgMax': ArgMax.get_converter(opset),
'ArgMin': ArgMin.get_converter(opset),
# defs/tensor
'Cast': Cast.get_converter(opset),
'Reshape': Reshape.get_converter(opset),
'Concat': Concat.get_converter(opset),
'Split': Split.get_converter(opset),
'Slice': Slice.get_converter(opset),
'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
'Gather': Gather.get_converter(opset),
'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
'Unsqueeze': Unsqueeze.get_converter(opset),
'Pad': Pad.get_converter(opset),
# TODO(zhreshold) Shape op is implemented as bypass op in relay
# 'Shape': Shape.get_converter(opset),
}
class GraphProto(object):
"""A helper class for handling Relay expression copying from pb2.GraphProto.
Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
Parameters
----------
shape : dict of str to tuple, optional
The input shape to the graph
dtype : str or dict of str to str
The input types to the graph
"""
def __init__(self, shape, dtype):
self._nodes = {}
self._params = {}
self._renames = {}
self._num_input = 0
self._num_param = 0
self._shape = shape
self._dtype = dtype
def from_onnx(self, graph, opset):
"""Construct Relay expression from ONNX graph.
Onnx graph is a python protobuf object.
The companion parameters will be handled automatically.
However, the input names from the onnx graph are ambiguous, mixing inputs and
network weights/bias such as "1", "2"...
For convenience, we rename the `real` input names to "input_0",
"input_1"... and the parameters to "param_0", "param_1"...
Parameters
----------
graph : onnx protobuf object
The loaded onnx graph
opset : opset version
Returns
-------
sym : tvm.relay.expr.Function
The returned relay function
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
# parse network inputs to relay, aka parameters
for init_tensor in graph.initializer:
if not init_tensor.name.strip():
raise ValueError("Tensor's name is required.")
self._params[init_tensor.name] = self._parse_array(init_tensor)
for i in graph.input:
# from onnx v0.2, GraphProto.input has type ValueInfoProto,
# and the name is 'i.name'
i_name = self._parse_value_proto(i)
d_type = self._parse_dtype(i, 'float32')
if i_name in self._params:
# i is a param instead of input
self._num_param += 1
self._params[i_name] = self._params.pop(i_name)
self._nodes[i_name] = new_var(i_name,
shape=self._params[i_name].shape,
dtype=self._params[i_name].dtype)
else:
self._num_input += 1
shape = self._shape[i_name] if i_name in self._shape else ()
if isinstance(self._dtype, dict):
dtype = self._dtype[i_name] if i_name in self._dtype else d_type
else:
dtype = d_type
self._nodes[i_name] = new_var(i_name, shape=shape, dtype=dtype)
# construct nodes, nodes are stored as directed acyclic graph
for node in graph.node:
op_name = node.op_type
attr = self._parse_attr(node.attribute)
inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
if op_name == "Constant":
t_proto = self._parse_attr(node.attribute)["value"]
self._num_param += 1
self._params[node.output[0]] = self._parse_array(t_proto)
self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims))
else:
if op_name == "ConstantFill":
fill_value = attr.get('value', 0.0)
dtype = attr.get('dtype', b'int32').decode("utf-8")
i_name = node.output[0]
self._params[i_name] = fill_value
self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
inputs.append(self._nodes[i_name])
op = self._convert_operator(op_name, inputs, attr, opset)
node_output = self._fix_outputs(op_name, node.output)
if not isinstance(op, _expr.TupleWrapper):
outputs_num = 1
else:
outputs_num = len(op)
assert len(node_output) == outputs_num, (
"Number of output mismatch {} vs {} in {}.".format(
len(node_output), outputs_num, op_name))
if outputs_num == 1:
self._nodes[node_output[0]] = op
else:
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]
# now return the outputs
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(ir_pass.free_vars(outputs), outputs)
return func, self._params
def _parse_value_proto(self, value_proto):
"""Parse ValueProto or raw str."""
try:
name = value_proto.name
except AttributeError:
name = value_proto
return name
def _parse_dtype(self, value_proto, dtype):
"""Parse dtype."""
try:
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name
except AttributeError:
return dtype
def _parse_array(self, tensor_proto):
"""Grab data in TensorProto and convert to numpy array."""
try:
from onnx.numpy_helper import to_array
except ImportError as e:
raise ImportError(
"Unable to import onnx which is required {}".format(e))
np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
return _nd.array(np_array)
def _parse_attr(self, attr_proto):
"""Convert a list of AttributeProto to a dict, with names as keys."""
attrs = {}
for a in attr_proto:
for f in ['f', 'i', 's']:
if a.HasField(f):
attrs[a.name] = getattr(a, f)
for f in ['floats', 'ints', 'strings']:
if list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is allowed"
attrs[a.name] = tuple(getattr(a, f))
for f in ['t']:
if a.HasField(f):
attrs[a.name] = getattr(a, f)
for f in ['tensors']:
if list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is allowed"
attrs[a.name] = tuple(getattr(a, f))
for f in ['g']:
if a.HasField(f):
raise NotImplementedError(
"Filed {} is not supported in relay.".format(f))
for f in ['graphs']:
if list(getattr(a, f)):
raise NotImplementedError(
"Filed {} is not supported in relay.".format(f))
if a.name not in attrs:
raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
return attrs
def _convert_operator(self,
op_name,
inputs,
attrs,
opset):
"""Convert ONNX operator into a Relay operator.
The converter must specify conversions explicitly for incompatible names, and
apply handlers to operator attributes.
Parameters
----------
op_name : str
Operator name, such as Convolution, FullyConnected
inputs : list of tvm.relay.expr.Function
List of inputs.
attrs : dict
Dict of operator attributes
opset : int
Opset version
Returns
-------
sym : tvm.relay.expr.Function
Converted relay function
"""
convert_map = _get_convert_map(opset)
if op_name in _identity_list:
sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
else:
raise NotImplementedError(
"Operator {} not implemented.".format(op_name))
return sym
def _fix_outputs(self, op_name, outputs):
"""A hack to handle dropout or similar operator that have more than one out
in ONNX.
"""
if op_name == 'Dropout':
if len(outputs) == 1:
return outputs
# TODO(zhreshold): support dropout mask?
outputs = outputs[:-1]
return outputs
def from_onnx(model,
shape=None,
dtype="float32"):
"""Convert a ONNX model into an equivalent Relay Function.
ONNX graphs are represented as Python Protobuf objects.
The companion parameters will be handled automatically.
However, the input names from the onnx graph are ambiguous, mixing inputs and
network weights/bias such as "1", "2"...
For convenience, we rename the `real` input names to "input_0",
"input_1"... and the parameters to "param_0", "param_1"...
Parameters
----------
model : protobuf object
ONNX ModelProto after ONNX v1.1.0
shape : dict of str to tuple, optional
The input shape to the graph
dtype : str or dict of str to str
The input types to the graph
Returns
-------
sym : tvm.relay.expr.Function
Compatible relay function
params : dict of str to tvm.NDArray
The parameter dict to be used by relay
"""
g = GraphProto(shape, dtype)
graph = model.graph
try:
opset = model.opset_import[0].version if model.opset_import else 1
except AttributeError:
opset = 1
sym, params = g.from_onnx(graph, opset)
return sym, params
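# A hedged end-to-end usage sketch for this frontend (the file name, input name and
# shape below are placeholders):
#   import onnx
#   from tvm import relay
#   onnx_model = onnx.load_model('model.onnx')
#   sym, params = relay.frontend.from_onnx(onnx_model, shape={'input_0': (1, 3, 224, 224)})
#   with relay.build_config(opt_level=1):
#       graph, lib, params = relay.build(sym, 'llvm', params=params)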
import numpy as np
import math
import topi
import topi.testing
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from nnvm.testing.config import ctx_list
import onnx
from onnx import helper, TensorProto
import unittest
def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32'):
""" Generic function to execute and get tvm output"""
target = 'llvm'
if isinstance(input_data, list):
input_names = {}
shape_dict = {}
dtype_dict = {}
for i, _ in enumerate(input_data):
input_names[i] = graph_def.graph.input[i].name
shape_dict[input_names[i]] = input_data[i].shape
dtype_dict[input_names[i]] = input_data[i].dtype
else:
input_names = graph_def.graph.input[0].name
shape_dict = {input_names: input_data.shape}
dtype_dict = {input_names: input_data.dtype}
sym, params = relay.frontend.from_onnx(graph_def, shape_dict)
with relay.build_config(opt_level=1):
graph, lib, params = relay.build(sym, target, params=params)
ctx = tvm.cpu(0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
if isinstance(input_data, list):
for i, e in enumerate(input_names):
m.set_input(input_names[i], tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
else:
m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
if isinstance(output_shape, list) and isinstance(output_dtype, list):
tvm_output_list = []
for i, _ in enumerate(output_shape):
tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
else:
tvm_output = m.get_output(0)
return tvm_output.asnumpy()
def get_caffe2_output(model, x, dtype='float32'):
import caffe2.python.onnx.backend
prepared_backend = caffe2.python.onnx.backend.prepare(model)
W = {model.graph.input[0].name: x.astype(dtype)}
c2_out = prepared_backend.run(W)[0]
return c2_out
def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
dtype = 'float32'
x = np.random.uniform(size=data_shape)
model = onnx.load_model(graph_file)
c2_out = get_caffe2_output(model, x, dtype)
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
def verify_super_resolution_example():
verify_onnx_forward_impl(super_resolution, (1, 1, 224, 224), (1, 1, 672, 672))
def verify_squeezenet1_1():
verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000))
def verify_lenet():
verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10))
def verify_resnet18():
verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000))
def test_reshape():
in_shape = (4, 3, 3, 4)
ref_shape = (6, 2, 4, 3)
ref_array = np.array(ref_shape)
ref_node = onnx.helper.make_node('Constant',
inputs=[],
outputs=['ref_in'],
value=onnx.helper.make_tensor(name = 'const_tensor',
data_type = onnx.TensorProto.INT32,
dims = ref_array.shape,
vals = ref_array.flatten().astype(int)))
reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
graph = helper.make_graph([ref_node, reshape_node],
"reshape_test",
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(ref_shape))])
model = helper.make_model(graph, producer_name='reshape_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape).astype('int32')
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
def test_reshape_like():
in_shape = (4, 3, 3, 4)
ref_shape = (3, 4, 4, 3)
ref_array = np.random.uniform(size=ref_shape).astype('float32')
ref_node = onnx.helper.make_node('Constant',
inputs=[],
outputs=['ref_in'],
value=onnx.helper.make_tensor(name = 'const_tensor',
data_type = onnx.TensorProto.FLOAT,
dims = ref_array.shape,
vals = ref_array.flatten().astype(float)))
copy_node = helper.make_node("Identity", ["ref_in"], ["copy_in"])
reshape_node = helper.make_node("Reshape", ["in", "copy_in"], ["out"])
graph = helper.make_graph([ref_node, copy_node, reshape_node],
"reshape_like_test",
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(ref_shape))])
model = helper.make_model(graph, producer_name='reshape_like_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape).astype('float32')
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
def _test_power_iteration(x_shape, y_shape):
if isinstance(y_shape, int):
y_shape = [y_shape]
x = np.random.uniform(size=x_shape).astype(np.float32)
y = np.random.uniform(size=y_shape).astype(np.float32)
np_res = np.power(x, y).astype(np.float32)
res = helper.make_node("Pow", ['x', 'y'], ['out'])
graph = helper.make_graph([res],
'power_test',
inputs = [helper.make_tensor_value_info("x",
TensorProto.FLOAT, list(x_shape)),
helper.make_tensor_value_info("y",
TensorProto.FLOAT, list(y_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(np_res.shape))])
model = helper.make_model(graph, producer_name='power_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, y], target, ctx, np_res.shape)
tvm.testing.assert_allclose(np_res, tvm_out, rtol=1e-5, atol=1e-5)
def test_power():
_test_power_iteration((1, 3), (1))
_test_power_iteration((2, 3), (2, 3))
_test_power_iteration((2, 3), (1, 3))
def test_squeeze():
in_shape = (1, 3, 1, 3, 1, 1)
out_shape = (3, 3)
y = helper.make_node("Squeeze", ['in'], ['out'], axes=[0, 2, 4, 5])
graph = helper.make_graph([y],
'squeeze_test',
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='squeeze_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape).astype('float32')
tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32')
tvm.testing.assert_allclose(out_shape, tvm_out.shape)
def test_unsqueeze():
in_shape = (3, 3)
axis = (0, 3, 4)
out_shape = (1, 3, 3, 1, 1)
y = helper.make_node("Unsqueeze", ['in'], ['out'], axes=list(axis))
graph = helper.make_graph([y],
'squeeze_test',
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='squeeze_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape).astype('float32')
tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32')
tvm.testing.assert_allclose(out_shape, tvm_out.shape)
def verify_gather(in_shape, indices, axis, dtype):
x = np.random.uniform(size=in_shape).astype(dtype)
indices = np.array(indices, dtype="int32")
out_np = np.take(x, indices, axis=axis)
y = helper.make_node("Gather", ['in', 'indices'], ['out'], axis=axis)
graph = helper.make_graph([y],
'gather_test',
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("indices",
TensorProto.INT32, list(indices.shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_np.shape))])
model = helper.make_model(graph, producer_name='gather_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out)
def test_gather():
verify_gather((4,), [1], 0, 'int32')
verify_gather((1,4), [0], 0, 'int32')
verify_gather((4,), [[[1,0],[0,1]]], 0, 'float32')
verify_gather((2,2), [[[1,0],[0,1]]], 1, 'int32')
verify_gather((3,3,3), [[[1,0]]], -1, 'int32')
verify_gather((4,3,5,6), [[2,1,0,0]], 0, 'float32')
def _test_slice_iteration(indata, outdata, starts, ends, axes=None):
if axes:
y = helper.make_node("Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends)
else:
y = helper.make_node("Slice", ['in'], ['out'], starts=starts, ends=ends)
graph = helper.make_graph([y],
'slice_test',
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='slice_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
tvm.testing.assert_allclose(outdata, tvm_out)
def test_slice():
x = np.random.randn(20, 10, 5).astype(np.float32)
_test_slice_iteration(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
_test_slice_iteration(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
_test_slice_iteration(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1))
def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
indata = np.random.uniform(-1, 1, size=inshape).astype(dtype)
outdata = outfunc(indata, **npargs)
y = helper.make_node(opname, ['in'], ['out'], **kwargs)
graph = helper.make_graph([y],
opname+'_test',
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name=opname+'_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)
tvm.testing.assert_allclose(outdata, tvm_out)
def test_floor():
_test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, 'float32', 'Floor', {})
def test_ceil():
_test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, 'float32', 'Ceil', {})
def test_clip():
_test_onnx_op_elementwise((2, 4, 5, 6),
np.clip,
{'a_min': -1.0, 'a_max': 1.0},
'float32',
'Clip',
{'min': -1.0, 'max': 1.0})
def test_matmul():
a_shape = (4, 3)
b_shape = (3, 4)
a_array = np.random.uniform(size=a_shape).astype('float32')
b_array = np.random.uniform(size=b_shape).astype('float32')
out_np = np.matmul(a_array, b_array)
mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
graph = helper.make_graph([mul_node],
"matmul_test",
inputs = [helper.make_tensor_value_info("a",
TensorProto.FLOAT, list(a_shape)),
helper.make_tensor_value_info("b",
TensorProto.FLOAT, list(b_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_np.shape))])
model = helper.make_model(graph, producer_name='matmul_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
in_array = np.random.uniform(size=shape).astype(dtype)
if alpha is None and beta is None and bias is None:
alpha = 0.0001
beta = 0.75
bias = 1.0
node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], size=nsize)
else:
node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha,
beta=beta, bias=bias, size=nsize)
graph = helper.make_graph([node],
"lrn_test",
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))])
model = helper.make_model(graph, producer_name='lrn_test')
def _get_python_lrn():
square_sum = np.zeros(shape).astype(dtype)
for n, c, h, w in np.ndindex(in_array.shape):
square_sum[n, c, h, w] = sum(in_array[n,
max(0, c - int(math.floor((nsize - 1) / 2))): \
min(5, c + int(math.ceil((nsize - 1) / 2)) + 1),
h,
w] ** 2)
py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta)
return py_out
for target, ctx in ctx_list():
input_name = model.graph.input[0].name
py_out = _get_python_lrn()
tvm_out = get_tvm_output(model, in_array, target, ctx, py_out.shape, 'float32')
tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_lrn():
verify_lrn((5, 5, 5, 5), 3, 'float32')
verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)
def _test_upsample_nearest():
scale = 2
in_shape = (1, 1, 3, 3)
out_shape = (1, 1, 3*scale, 3*scale)
y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0])
in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.upsampling_python(in_array, scale, "NCHW")
graph = helper.make_graph([y],
'upsample_nearest_test',
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='upsample_nearest_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32')
tvm.testing.assert_allclose(out_array, tvm_out)
def _test_upsample_bilinear():
scale = 2
in_shape = (1, 1, 3, 3)
out_shape = (1, 1, 3*scale, 3*scale)
y = helper.make_node("Upsample", ['in'], ['out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0])
in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW")
graph = helper.make_graph([y],
'upsample_bilinear_test',
inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='upsample_bilinear_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32')
tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
def test_upsample():
_test_upsample_nearest()
_test_upsample_bilinear()
def _test_softmax(inshape, axis):
opname = 'Softmax'
indata = np.random.uniform(size=inshape).astype(np.float32)
outshape = inshape
outdata = topi.testing.softmax_python(indata)
if isinstance(axis, int):
y = helper.make_node(opname, ['in'], ['out'], axis = axis)
elif axis is None:
y = helper.make_node(opname, ['in'], ['out'])
graph = helper.make_graph([y],
opname+'_test',
inputs = [helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name=opname+'_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outshape, 'float32')
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
def test_softmax():
_test_softmax((1, 10), None)
_test_softmax((1, 10), 1)
def verify_min(input_dim):
dtype = 'float32'
a_np1 = np.random.uniform(size=input_dim).astype(dtype)
a_np2 = np.random.uniform(size=input_dim).astype(dtype)
a_np3 = np.random.uniform(size=input_dim).astype(dtype)
b_np = np.min((a_np1, a_np2, a_np3), axis=0)
min_node = helper.make_node("Min", ["a_np1", "a_np2", "a_np3"], ["out"])
graph = helper.make_graph([min_node],
"Min_test",
inputs = [helper.make_tensor_value_info("a_np1",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np2",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np3",
TensorProto.FLOAT, list(input_dim))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(b_np.shape))])
model = helper.make_model(graph, producer_name='Min_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_min():
verify_min((1, 3, 20, 20))
verify_min((20, 20))
def verify_max(input_dim):
dtype = 'float32'
a_np1 = np.random.uniform(size=input_dim).astype(dtype)
a_np2 = np.random.uniform(size=input_dim).astype(dtype)
a_np3 = np.random.uniform(size=input_dim).astype(dtype)
b_np = np.max((a_np1, a_np2, a_np3), axis=0)
max_node = helper.make_node("Max", ["a_np1", "a_np2", "a_np3"], ["out"])
graph = helper.make_graph([max_node],
"Max_test",
inputs = [helper.make_tensor_value_info("a_np1",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np2",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np3",
TensorProto.FLOAT, list(input_dim))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(b_np.shape))])
model = helper.make_model(graph, producer_name='Max_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_max():
verify_max((1, 3, 20, 20))
verify_max((20, 20))
def verify_mean(input_dim):
dtype = 'float32'
a_np1 = np.random.uniform(size=input_dim).astype(dtype)
a_np2 = np.random.uniform(size=input_dim).astype(dtype)
a_np3 = np.random.uniform(size=input_dim).astype(dtype)
b_np = np.mean((a_np1, a_np2, a_np3), axis=0)
mean_node = helper.make_node("Mean", ["a_np1", "a_np2", "a_np3"], ["out"])
graph = helper.make_graph([mean_node],
"Mean_test",
inputs = [helper.make_tensor_value_info("a_np1",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np2",
TensorProto.FLOAT, list(input_dim)),
helper.make_tensor_value_info("a_np3",
TensorProto.FLOAT, list(input_dim))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(b_np.shape))])
model = helper.make_model(graph, producer_name='Mean_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_mean():
verify_mean((1, 3, 20, 20))
verify_mean((20, 20))
def verify_hardsigmoid(input_dim, alpha, beta):
dtype = 'float32'
a_np1 = np.random.uniform(size=input_dim).astype(dtype)
b_np = np.clip(a_np1 * alpha + beta, 0, 1)
hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta)
graph = helper.make_graph([hardsigmoid_node],
"HardSigmoid_test",
inputs = [helper.make_tensor_value_info("a_np1",
TensorProto.FLOAT, list(input_dim))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(b_np.shape))])
model = helper.make_model(graph, producer_name='HardSigmoid_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_hardsigmoid():
verify_hardsigmoid((1, 3, 20, 20), 0.5, 0.6)
verify_hardsigmoid((20, 20), 0.3, 0.4)
def verify_argmin(input_dim, axis=None, keepdims=None):
def _argmin_numpy(data, axis=0, keepdims=True):
result = np.argmin(data, axis=axis)
if (keepdims == 1):
result = np.expand_dims(result, axis)
return result.astype(data.dtype)
a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32)
if keepdims is None and axis is None:
b_np = _argmin_numpy(a_np1)
node = onnx.helper.make_node('ArgMin',
inputs=['a_np1'],
outputs=['out'])
elif axis is None:
b_np = _argmin_numpy(a_np1, keepdims=keepdims)
node = onnx.helper.make_node('ArgMin',
inputs=['a_np1'],
outputs=['out'],
keepdims=keepdims)
elif keepdims is None:
b_np = _argmin_numpy(a_np1, axis=axis)
node = onnx.helper.make_node('ArgMin',
inputs=['a_np1'],
outputs=['out'],
axis=axis)
else:
b_np = _argmin_numpy(a_np1, axis=axis, keepdims=keepdims)
node = onnx.helper.make_node('ArgMin',
inputs=['a_np1'],
outputs=['out'],
axis=axis,
keepdims=keepdims)
graph = helper.make_graph([node],
"argmin_test",
inputs = [helper.make_tensor_value_info("a_np1",
TensorProto.INT32, list(a_np1.shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.INT32, list(b_np.shape))])
model = helper.make_model(graph, producer_name='argmin_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def verify_argmax(input_dim, axis=None, keepdims=None):
def _argmax_numpy(data, axis=0, keepdims=True):
result = np.argmax(data, axis=axis)
if (keepdims == 1):
result = np.expand_dims(result, axis)
return result.astype(data.dtype)
a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32)
if keepdims is None and axis is None:
b_np = _argmax_numpy(a_np1)
node = onnx.helper.make_node('ArgMax',
inputs=['a_np1'],
outputs=['out'])
elif axis is None:
b_np = _argmax_numpy(a_np1, keepdims=keepdims)
node = onnx.helper.make_node('ArgMax',
inputs=['a_np1'],
outputs=['out'],
keepdims=keepdims)
elif keepdims is None:
b_np = _argmax_numpy(a_np1, axis=axis)
node = onnx.helper.make_node('ArgMax',
inputs=['a_np1'],
outputs=['out'],
axis=axis)
else:
b_np = _argmax_numpy(a_np1, axis=axis, keepdims=keepdims)
node = onnx.helper.make_node('ArgMax',
inputs=['a_np1'],
outputs=['out'],
axis=axis,
keepdims=keepdims)
graph = helper.make_graph([node],
"argmax_test",
inputs = [helper.make_tensor_value_info("a_np1",
TensorProto.INT32, list(a_np1.shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.INT32, list(b_np.shape))])
model = helper.make_model(graph, producer_name='argmax_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_arg_min_max():
'''Verify argmin and argmax'''
verify_argmin([3,4,4])
verify_argmax([3,4,4])
verify_argmin([3,4,4], axis=1)
verify_argmax([3,4,4], axis=0)
verify_argmin([3,4,4], keepdims=0)
verify_argmax([3,4,4], keepdims=1)
for axis in [None, 0,1,2]:
for keepdims in [None, True,False]:
verify_argmin([3,4,4], axis, keepdims)
verify_argmax([3,4,4], axis, keepdims)
def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs):
input_a = np.random.uniform(size=input_dim).astype(dtype)
out = np.empty(shape=out_dim, dtype=dtype)
out.fill(value)
    if is_shape:
fill_node = helper.make_node("ConstantFill", [], ["out"], shape=input_dim, value=value, **kwargs)
else:
fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"], value=value, dtype=dtype, **kwargs)
graph = helper.make_graph([fill_node],
"fill_test",
inputs = [helper.make_tensor_value_info("input_a",
TensorProto.FLOAT, list(input_dim))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out.shape))])
model = helper.make_model(graph, producer_name='fill_test')
for target, ctx in ctx_list():
        if is_shape:
tvm_out = get_tvm_output(model, [], target, ctx, out.shape)
else:
tvm_out = get_tvm_output(model, [input_a], target, ctx, out.shape)
tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5)
def test_constantfill():
verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
verify_constantfill(False, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6))
def verify_pad(indata, pads, value=0.0):
indata = np.array(indata).astype(np.float32)
    # numpy expected result
len_dim = len(pads) // 2
np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)]
outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value)
# onnx graph
node = helper.make_node(
'Pad',
inputs=['input'],
outputs=['output'],
mode='constant',
pads=pads,
value=value
)
graph = helper.make_graph([node],
'pad_test',
inputs = [helper.make_tensor_value_info("input",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("output",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='pad_test')
# tvm result
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
def test_pad():
verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 0.0)
verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 0.0)
verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 5.0)
def verify_reduce_x(name, indata, axis, keepdims):
indata = np.array(indata).astype(np.float32)
    # numpy expected result
if name == 'ReduceMax':
outdata = np.maximum.reduce(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceMin':
outdata = np.minimum.reduce(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceSum':
outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceMean':
outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1)
else:
        raise Exception('unsupported op: {}'.format(name))
if len(np.asarray(outdata).shape) == 0:
outdata = np.asarray([outdata])
# onnx graph
if axis is None:
node = helper.make_node(name, inputs=['input'], outputs=['output'],
keepdims=keepdims)
else:
node = helper.make_node(name, inputs=['input'], outputs=['output'],
axes=axis, keepdims=keepdims)
graph = helper.make_graph([node],
'{}_test'.format(name),
inputs = [helper.make_tensor_value_info("input",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("output",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='{}_test'.format(name))
# tvm result
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
def test_reduce_max():
verify_reduce_x("ReduceMax",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceMax",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceMax",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def test_reduce_min():
verify_reduce_x("ReduceMin",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceMin",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceMin",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def test_reduce_sum():
verify_reduce_x("ReduceSum",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceSum",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceSum",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def test_reduce_mean():
verify_reduce_x("ReduceMean",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceMean",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceMean",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)
def verify_split(indata, outdatas, split, axis=0):
indata = np.array(indata).astype(np.float32)
outdatas = [np.array(o).astype(np.float32) for o in outdatas]
node = helper.make_node(
'Split',
inputs=['input'],
outputs=['output_{}'.format(i) for i in range(len(split))],
axis=axis,
split=split
)
graph = helper.make_graph([node],
'split_test',
inputs = [helper.make_tensor_value_info("input",
TensorProto.FLOAT, list(indata.shape))],
outputs = [helper.make_tensor_value_info("output_{}".format(i),
TensorProto.FLOAT, list(outdatas[i].shape))
for i in range(len(split))
])
model = helper.make_model(graph, producer_name='split_test')
for target, ctx in ctx_list():
output_shape = [o.shape for o in outdatas]
output_type = ['float32', 'float32', 'float32']
tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type)
for o, t in zip(outdatas, tvm_out):
tvm.testing.assert_allclose(o, t)
def test_split():
# 1D
verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0)
verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0)
# 2D
verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]],
[[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1)
def test_binary_ops():
in_shape = (1, 2, 3, 3)
dtype = "float32"
out_shape = in_shape
def verify_binary_ops(op, x, y, out_np, broadcast=None):
if broadcast is None:
z = helper.make_node(op, ['in1', 'in2'], ['out'])
else:
z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1)
graph = helper.make_graph([z],
'_test',
inputs = [helper.make_tensor_value_info("in1",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("in2",
TensorProto.FLOAT, list(in_shape))],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, y], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
x = np.random.uniform(size=in_shape).astype(dtype)
y = np.random.uniform(size=in_shape).astype(dtype)
z = np.random.uniform(size=(3,)).astype(dtype)
verify_binary_ops("Add",x, y, x + y, broadcast=None)
verify_binary_ops("Add", x, z, x + z, broadcast=True)
verify_binary_ops("Sub", x, y, x - y, broadcast=None)
verify_binary_ops("Sub", x, z, x - z, broadcast=True)
verify_binary_ops("Mul",x, y, x * y, broadcast=None)
verify_binary_ops("Mul", x, z, x * z, broadcast=True)
verify_binary_ops("Div", x, y, x / y, broadcast=None)
verify_binary_ops("Div", x, z, x / z, broadcast=True)
verify_binary_ops("Sum", x, y, x + y, broadcast=None)
def test_single_ops():
in_shape = (1, 2, 3, 3)
dtype = "float32"
out_shape = in_shape
def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5):
z = helper.make_node(op, ['in1'], ['out'])
graph = helper.make_graph([z],
'_test',
inputs = [helper.make_tensor_value_info("in1",
TensorProto.FLOAT, list(in_shape)),],
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out_shape))])
model = helper.make_model(graph, producer_name='_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
x = np.random.uniform(size=in_shape).astype(dtype)
verify_single_ops("Neg",x, -x)
verify_single_ops("Abs",x, np.abs(x))
verify_single_ops("Reciprocal",x, 1/x)
verify_single_ops("Sqrt",x, np.sqrt(x))
verify_single_ops("Relu",x, np.maximum(x, 0))
verify_single_ops("Exp",x, np.exp(x))
verify_single_ops("Log",x, np.log(x))
verify_single_ops("Log",x, np.log(x))
verify_single_ops("Tanh",x, np.tanh(x))
verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)))
verify_single_ops("Softsign",x, x / (1 + np.abs(x)))
verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)))
def test_leaky_relu():
def leaky_relu_x(x, alpha):
return np.where(x >= 0, x, x * alpha)
_test_onnx_op_elementwise((2, 4, 5, 6),
leaky_relu_x,
{'alpha': 0.25},
'float32',
'LeakyRelu',
{'alpha': 0.25})
def test_elu():
def elu_x(x, alpha):
return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
_test_onnx_op_elementwise((2, 4, 5, 6),
elu_x,
{'alpha': 0.25},
'float32',
'Elu',
{'alpha': 0.25})
def test_selu():
def selu_x(x, alpha, gamma):
return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
_test_onnx_op_elementwise((2, 4, 5, 6),
selu_x,
{'alpha': 0.25, 'gamma': 0.3},
'float32',
'Selu',
{'alpha': 0.25, 'gamma': 0.3})
def test_ThresholdedRelu():
def ThresholdedRelu_x(x, alpha):
out_np = np.clip(x, alpha, np.inf)
out_np[out_np == alpha] = 0
return out_np
_test_onnx_op_elementwise((2, 4, 5, 6),
ThresholdedRelu_x,
{'alpha': 0.25},
'float32',
'ThresholdedRelu',
{'alpha': 0.25})
def test_ScaledTanh():
def ScaledTanh_x(x, alpha, beta):
return alpha * np.tanh(beta * x)
_test_onnx_op_elementwise((2, 4, 5, 6),
ScaledTanh_x,
{'alpha': 0.25, 'beta': 0.3},
'float32',
'ScaledTanh',
{'alpha': 0.25, 'beta': 0.3})
def test_ParametricSoftplus():
def ParametricSoftplus_x(x, alpha, beta):
return alpha * np.log(np.exp(beta * x) + 1)
_test_onnx_op_elementwise((2, 4, 5, 6),
ParametricSoftplus_x,
{'alpha': 0.25, 'beta': 0.3},
'float32',
'ParametricSoftplus',
{'alpha': 0.25, 'beta': 0.3})
def test_Scale():
def Scale_x(x, scale):
return scale * x
_test_onnx_op_elementwise((2, 4, 5, 6),
Scale_x,
{'scale': 0.25},
'float32',
'Scale',
{'scale': 0.25})
def test_LogSoftmax():
_test_onnx_op_elementwise((1, 4),
topi.testing.log_softmax_python,
{},
'float32',
'LogSoftmax',
{'axis': 1})
if __name__ == '__main__':
test_reshape()
test_reshape_like()
test_power()
test_squeeze()
test_unsqueeze()
test_slice()
test_floor()
test_ceil()
test_clip()
test_matmul()
test_gather()
test_lrn()
test_upsample()
test_forward_min()
test_forward_max()
test_forward_mean()
test_forward_hardsigmoid()
test_forward_arg_min_max()
test_softmax()
test_constantfill()
test_pad()
test_reduce_max()
test_reduce_min()
test_reduce_sum()
test_reduce_mean()
test_split()
test_binary_ops()
test_single_ops()
test_leaky_relu()
test_elu()
test_selu()
test_ThresholdedRelu()
test_ScaledTanh()
test_ParametricSoftplus()
test_Scale()
test_LogSoftmax()
......@@ -32,3 +32,6 @@ python3 -m nose -v tests/python/frontend/mxnet || exit -1
echo "Running relay Keras frontend test..."
python3 -m nose -v tests/python/frontend/keras || exit -1
echo "Running relay ONNX frondend test..."
python3 -m nose -v tests/python/frontend/onnx || exit -1
"""
Compile ONNX Models
===================
**Author**: `Joshua Z. Zhang <https://zhreshold.github.io/>`_
This article is an introductory tutorial to deploy ONNX models with Relay.
To begin, the ONNX package must be installed.
A quick solution is to install the protobuf compiler and then run
.. code-block:: bash
pip install onnx --user
or please refer to the official site:
https://github.com/onnx/onnx
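
On Debian/Ubuntu systems the protobuf compiler can typically be installed
beforehand with (the package names below are an assumption, check your
distribution):

.. code-block:: bash

    sudo apt-get install -y protobuf-compiler libprotoc-dev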
"""
import onnx
import numpy as np
import tvm
import tvm.relay as relay
def download(url, path, overwrite=False):
    """Download `url` to the local file `path`, skipping if it already exists."""
    import os
if os.path.isfile(path) and not overwrite:
        print('File {} already exists, skipping download.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
import urllib.request
urllib.request.urlretrieve(url, path)
    except ImportError:
        # Python 2 fallback: urllib.request is unavailable
        import urllib
urllib.urlretrieve(url, path)
######################################################################
# Load pretrained ONNX model
# ---------------------------------------------
# The example super-resolution model used here is exactly the same model as in the ONNX tutorial
# http://pytorch.org/tutorials/advanced/super_resolution_with_caffe2.html
# We skip the PyTorch model construction part and download the saved ONNX model instead.
model_url = ''.join(['https://gist.github.com/zhreshold/',
'bcda4716699ac97ea44f791c24310193/raw/',
'93672b029103648953c4e5ad3ac3aadf346a4cdc/',
'super_resolution_0.2.onnx'])
download(model_url, 'super_resolution.onnx', False)
# now you have super_resolution.onnx on disk
onnx_model = onnx.load('super_resolution.onnx')
######################################################################
# Load a test image
# ---------------------------------------------
# A single cat dominates the examples!
from PIL import Image
img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
download(img_url, 'cat.png')
img = Image.open('cat.png').resize((224, 224))
img_ycbcr = img.convert("YCbCr") # convert to YCbCr
img_y, img_cb, img_cr = img_ycbcr.split()
x = np.array(img_y)[np.newaxis, np.newaxis, :, :]
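# x is now a (1, 1, 224, 224) array: a single-image batch with one channel
# holding the Y (luminance) values, which is the input layout this
# super-resolution network expects.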
######################################################################
# Compile the model with relay
# ---------------------------------------------
target = 'llvm'
input_name = '1'
shape_dict = {input_name: x.shape}
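# The input name '1' above is simply whatever name the exporter recorded for
# the data tensor in the ONNX graph. As an optional inspection step (not part
# of the original flow), the declared inputs can be listed to confirm it; for
# older exports the weight tensors also show up in this list:
print([i.name for i in onnx_model.graph.input])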
sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)
with relay.build_config(opt_level=1):
intrp = relay.build_module.create_executor('graph', sym, tvm.cpu(0), target)
######################################################################
# Execute on TVM
# ---------------------------------------------
dtype = 'float32'
tvm_output = intrp.evaluate(sym)(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
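# The network upscales the 224x224 luminance channel by a factor of 3 (this
# follows from the 672-pixel canvas used below), so tvm_output should come
# back as a (1, 1, 672, 672) array.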
######################################################################
# Display results
# ---------------------------------------------
# We put the input and output images side by side.
from matplotlib import pyplot as plt
out_y = Image.fromarray(np.uint8((tvm_output[0, 0]).clip(0, 255)), mode='L')
out_cb = img_cb.resize(out_y.size, Image.BICUBIC)
out_cr = img_cr.resize(out_y.size, Image.BICUBIC)
result = Image.merge('YCbCr', [out_y, out_cb, out_cr]).convert('RGB')
canvas = np.full((672, 672*2, 3), 255)
canvas[0:224, 0:224, :] = np.asarray(img)
canvas[:, 672:, :] = np.asarray(result)
plt.imshow(canvas.astype(np.uint8))
plt.show()