Commit 30a5a600 by Joshua Z. Zhang Committed by Haichen Shen

[RELAY][FRONTEND]Onnx to relay frontend (#2302)

parent 312802f3
...@@ -35,7 +35,7 @@ Our goal is to build the shared libraries: ...@@ -35,7 +35,7 @@ Our goal is to build the shared libraries:
.. code:: bash .. code:: bash
sudo apt-get update sudo apt-get update
sudo apt-get install -y python python-dev python-setuptools gcc libtinfo-dev zlib1g-dev sudo apt-get install -y python python-dev python-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake
The minimal building requirements are The minimal building requirements are
......
...@@ -910,7 +910,7 @@ def test_single_ops(): ...@@ -910,7 +910,7 @@ def test_single_ops():
model = helper.make_model(graph, producer_name='_test') model = helper.make_model(graph, producer_name='_test')
for target, ctx in ctx_list(): for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x], target, ctx) tvm_out = get_tvm_output(model, [x], target, ctx)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
x = np.random.uniform(size=in_shape).astype(dtype) x = np.random.uniform(size=in_shape).astype(dtype)
verify_single_ops("Neg",x, -x) verify_single_ops("Neg",x, -x)
...@@ -918,13 +918,13 @@ def test_single_ops(): ...@@ -918,13 +918,13 @@ def test_single_ops():
verify_single_ops("Reciprocal",x, 1/x, rtol=1e-5, atol=1e-5) verify_single_ops("Reciprocal",x, 1/x, rtol=1e-5, atol=1e-5)
verify_single_ops("Sqrt",x, np.sqrt(x), rtol=1e-5, atol=1e-5) verify_single_ops("Sqrt",x, np.sqrt(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Relu",x, np.maximum(x, 0)) verify_single_ops("Relu",x, np.maximum(x, 0))
verify_single_ops("Exp",x, np.exp(x)) verify_single_ops("Exp",x, np.exp(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Log",x, np.log(x)) verify_single_ops("Log",x, np.log(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Log",x, np.log(x)) verify_single_ops("Log",x, np.log(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Tanh",x, np.tanh(x)) verify_single_ops("Tanh",x, np.tanh(x), rtol=1e-5, atol=1e-5)
verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x))) verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)), rtol=1e-5, atol=1e-5)
verify_single_ops("Softsign",x, x / (1 + np.abs(x))) verify_single_ops("Softsign",x, x / (1 + np.abs(x)), rtol=1e-5, atol=1e-5)
verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x))) verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)), rtol=1e-5, atol=1e-5)
def test_leaky_relu(): def test_leaky_relu():
def leaky_relu_x(x, alpha): def leaky_relu_x(x, alpha):
......
...@@ -465,6 +465,14 @@ def const(value, dtype=None): ...@@ -465,6 +465,14 @@ def const(value, dtype=None):
""" """
if isinstance(value, (_base.numeric_types, (bool, list))): if isinstance(value, (_base.numeric_types, (bool, list))):
value = _np.array(value, dtype=dtype) value = _np.array(value, dtype=dtype)
if not dtype:
# when dtype is None: int maps to "int32", float maps to "float32"
map_dtype = {
_np.dtype('int64'): _np.int32,
_np.dtype('float64'): _np.float32
}.get(value.dtype, None)
if map_dtype:
value = value.astype(map_dtype)
if isinstance(value, (_np.ndarray, _np.generic)): if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value) value = _nd.array(value)
......
...@@ -9,3 +9,4 @@ from __future__ import absolute_import ...@@ -9,3 +9,4 @@ from __future__ import absolute_import
from .mxnet import from_mxnet from .mxnet import from_mxnet
from .keras import from_keras from .keras import from_keras
from .onnx import from_onnx
"""Common utilities""" """Common utilities"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import logging
from topi.util import get_const_tuple
from .. import expr as _expr from .. import expr as _expr
from .. import expr as _expr
from .. import ir_pass
from .. import op as _op
class RequiredAttr(object): class RequiredAttr(object):
...@@ -204,6 +209,30 @@ class StrAttrsDict(object): ...@@ -204,6 +209,30 @@ class StrAttrsDict(object):
raise AttributeError("Required attribute {} not found.".format(key)) raise AttributeError("Required attribute {} not found.".format(key))
return default return default
def get_relay_op(op_name):
"""Get the callable function from Relay based on operator name.
Parameters
----------
op_name : str
The Relay operator name.
"""
if '.' in op_name:
# explicit hierachical modules
op = _op
try:
for opn in op_name.split('.'):
op = getattr(op, opn)
except AttributeError:
op = None
else:
# try search op in various modules
for candidate in (_op, _op.nn, _op.image):
op = getattr(candidate, op_name, None)
if op is not None:
break
if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
return op
class ExprTable(object): class ExprTable(object):
"""Table storing Relay expressions by names.""" """Table storing Relay expressions by names."""
...@@ -227,3 +256,156 @@ class ExprTable(object): ...@@ -227,3 +256,156 @@ class ExprTable(object):
def set_expr(self, name, expr): def set_expr(self, name, expr):
assert isinstance(expr, _expr.Expr) assert isinstance(expr, _expr.Expr)
self.exprs[name] = expr self.exprs[name] = expr
class AttrCvt(object):
"""Common attribute conveter. An AttrConverter instance is a callable:
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
```
Parameters
----------
op_name : str or callable
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provded, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
by transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occured.
disables : list
A list of attributes that is disabled in relay. Log warnings.
ignores : list
A list of attributes that is ignored in relay. Debug level logging.
extras : dict
A series of additional attributes should be added anyway to the returned
attribute dict.
custom_check : callable
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
"""
def __init__(self, op_name, transforms=None,
excludes=None, disables=None, ignores=None,
extras=None, custom_check=None):
self._op_name = op_name
self._transforms = transforms if transforms else {}
self._excludes = excludes if excludes else []
self._disables = disables if disables else []
self._ignores = ignores if ignores else []
self._extras = extras if extras else {}
self._custom_check = custom_check
def __call__(self, inputs, attrs, *args):
# apply custom check
if self._custom_check:
func, msg = self._custom_check
if not func(attrs):
raise RuntimeError("Check failed: {}".format(msg))
# get new op_name
if isinstance(self._op_name, str):
op_name = self._op_name
else:
assert callable(self._op_name), "op_name can either be string or callable"
op_name = self._op_name(attrs)
# convert attributes
new_attrs = {}
for k in attrs.keys():
if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k))
elif k in self._disables:
logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
elif k in self._ignores:
logging.debug("Attribute %s is ignored in relay.sym.%s", k, op_name)
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
new_attr = self._required_attr(attrs, k)
else:
new_attr = attrs.get(k, None)
if new_attr is None:
new_attrs[new_name] = defaults
else:
new_attrs[new_name] = transform(new_attr)
else:
# copy
new_attrs[k] = attrs[k]
# add extras
new_attrs.update(self._extras)
return get_relay_op(op_name)(*inputs, **new_attrs)
def _parse_default(self, target):
"""Helper function to parse default values."""
if not isinstance(target, (list, tuple)):
k, v, t = target, None, lambda x: x
elif len(target) == 1:
k, v, t = target[0], None, lambda x: x
elif len(target) == 2:
k, v, t = target[0], target[1], lambda x: x
elif len(target) > 2:
k, v, t = target[0], target[1], target[2]
else:
k = None # should raise
if not isinstance(k, str):
msg = "{} is not a valid target, (name, default) expected.".format(target)
raise ValueError(msg)
return k, v, t
def _parse_bool(self, value):
"""Helper function to parse default boolean values."""
if isinstance(value, str):
return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
return bool(value)
def _required_attr(self, attr, key):
"""Wrapper for getting required attributes."""
assert isinstance(attr, dict)
if key not in attr:
raise AttributeError("Required attribute {} not found.".format(key))
return attr[key]
def get_name(node):
name = ''
if hasattr(node, "name_hint"):
name = node.name_hint
return name
def infer_shape(inputs):
"""A method to get the output shape of an intermediate node in the graph."""
out_type = ir_pass.infer_type(inputs)
out_shapes = get_const_tuple(out_type.checked_type.shape)
return out_shapes
def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number.
"""
out_type = ir_pass.infer_type(inputs)
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def new_var(name_hint,
type_annotation=None,
shape=None,
dtype="float32"):
return _expr.var(name_hint, type_annotation, shape, dtype)
class Renamer(object):
"""A simply renamer for operators.
Parameters
----------
new_name : str
The new name for the operator
"""
def __init__(self, new_name):
self._new_name = new_name
def __call__(self, inputs, attrs, *args):
return get_relay_op(self._new_name)(*inputs, **attrs)
...@@ -4,15 +4,7 @@ from __future__ import absolute_import as _abs ...@@ -4,15 +4,7 @@ from __future__ import absolute_import as _abs
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from .common import get_relay_op
def _get_relay_op(op_name):
op = _op
for path in op_name.split("."):
op = getattr(op, path)
if not op:
raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
return op
def _warn_not_used(attr, op='nnvm'): def _warn_not_used(attr, op='nnvm'):
import warnings import warnings
...@@ -22,7 +14,7 @@ def _warn_not_used(attr, op='nnvm'): ...@@ -22,7 +14,7 @@ def _warn_not_used(attr, op='nnvm'):
def _rename(new_op): def _rename(new_op):
if isinstance(new_op, str): if isinstance(new_op, str):
new_op = _get_relay_op(new_op) new_op = get_relay_op(new_op)
# attrs are ignored. # attrs are ignored.
def impl(inputs, _, _dtype='float32'): def impl(inputs, _, _dtype='float32'):
return new_op(*inputs) return new_op(*inputs)
......
...@@ -32,3 +32,6 @@ python3 -m nose -v tests/python/frontend/mxnet || exit -1 ...@@ -32,3 +32,6 @@ python3 -m nose -v tests/python/frontend/mxnet || exit -1
echo "Running relay Keras frontend test..." echo "Running relay Keras frontend test..."
python3 -m nose -v tests/python/frontend/keras || exit -1 python3 -m nose -v tests/python/frontend/keras || exit -1
echo "Running relay ONNX frondend test..."
python3 -m nose -v tests/python/frontend/onnx || exit -1
"""
Compile ONNX Models
===================
**Author**: `Joshua Z. Zhang <https://zhreshold.github.io/>`_
This article is an introductory tutorial to deploy ONNX models with Relay.
For us to begin with, ONNX package must be installed.
A quick solution is to install protobuf compiler, and
.. code-block:: bash
pip install onnx --user
or please refer to offical site.
https://github.com/onnx/onnx
"""
import onnx
import numpy as np
import tvm
import tvm.relay as relay
def download(url, path, overwrite=False):
import os
if os.path.isfile(path) and not overwrite:
print('File {} existed, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
import urllib.request
urllib.request.urlretrieve(url, path)
except:
import urllib
urllib.urlretrieve(url, path)
######################################################################
# Load pretrained ONNX model
# ---------------------------------------------
# The example super resolution model used here is exactly the same model in onnx tutorial
# http://pytorch.org/tutorials/advanced/super_resolution_with_caffe2.html
# we skip the pytorch model construction part, and download the saved onnx model
model_url = ''.join(['https://gist.github.com/zhreshold/',
'bcda4716699ac97ea44f791c24310193/raw/',
'93672b029103648953c4e5ad3ac3aadf346a4cdc/',
'super_resolution_0.2.onnx'])
download(model_url, 'super_resolution.onnx', False)
# now you have super_resolution.onnx on disk
onnx_model = onnx.load('super_resolution.onnx')
######################################################################
# Load a test image
# ---------------------------------------------
# A single cat dominates the examples!
from PIL import Image
img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
download(img_url, 'cat.png')
img = Image.open('cat.png').resize((224, 224))
img_ycbcr = img.convert("YCbCr") # convert to YCbCr
img_y, img_cb, img_cr = img_ycbcr.split()
x = np.array(img_y)[np.newaxis, np.newaxis, :, :]
######################################################################
# Compile the model with relay
# ---------------------------------------------
target = 'llvm'
input_name = '1'
shape_dict = {input_name: x.shape}
sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)
with relay.build_config(opt_level=1):
intrp = relay.build_module.create_executor('graph', sym, tvm.cpu(0), target)
######################################################################
# Execute on TVM
# ---------------------------------------------
tvm_output = intrp.evaluate(sym)(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
######################################################################
# Display results
# ---------------------------------------------
# We put input and output image neck to neck
from matplotlib import pyplot as plt
out_y = Image.fromarray(np.uint8((tvm_output[0, 0]).clip(0, 255)), mode='L')
out_cb = img_cb.resize(out_y.size, Image.BICUBIC)
out_cr = img_cr.resize(out_y.size, Image.BICUBIC)
result = Image.merge('YCbCr', [out_y, out_cb, out_cr]).convert('RGB')
canvas = np.full((672, 672*2, 3), 255)
canvas[0:224, 0:224, :] = np.asarray(img)
canvas[:, 672:, :] = np.asarray(result)
plt.imshow(canvas.astype(np.uint8))
plt.show()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment