Commit 8b6f80c2 by Joshua Z. Zhang Committed by Tianqi Chen

[Frontend] ONNX frontend v0.2 support (#202)

* update

* fix

* use generated onnx model

* fix tests

* fix lint

* remove log filter

* add vgg

* fix tests

* update tests

* fix download

* fix ci

* fix tutorial url

* clean cache
parent f5d158b7
"""Shared functions and classes for frontends.""" """Shared functions and classes for frontends."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import warnings import logging
from nnvm import sym as _sym
from .._base import string_types from .._base import string_types
def get_nnvm_op(op_name):
op = getattr(_sym, op_name)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op
class Renamer(object): class Renamer(object):
"""A simply renamer for operators. """A simply renamer for operators.
...@@ -14,8 +21,8 @@ class Renamer(object): ...@@ -14,8 +21,8 @@ class Renamer(object):
def __init__(self, new_name): def __init__(self, new_name):
self._new_name = new_name self._new_name = new_name
def __call__(self, attrs): def __call__(self, inputs, attrs, *args):
return self._new_name, attrs return get_nnvm_op(self._new_name)(*inputs, **attrs)
class AttrConverter(object): class AttrConverter(object):
...@@ -40,9 +47,9 @@ class AttrConverter(object): ...@@ -40,9 +47,9 @@ class AttrConverter(object):
A list of excluded attributes that should `NOT` appear. A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occured. Raise NotImplementedError if occured.
disables : list disables : list
A list of attributes that is disabled in nnvm. Raise warnings. A list of attributes that is disabled in nnvm. Log warnings.
ignores : list ignores : list
A list of attributes that is ignored in nnvm. Silent. A list of attributes that is ignored in nnvm. Debug level logging.
extras : dict extras : dict
A series of additional attributes should be added anyway to the returned A series of additional attributes should be added anyway to the returned
attribute dict. attribute dict.
...@@ -61,7 +68,7 @@ class AttrConverter(object): ...@@ -61,7 +68,7 @@ class AttrConverter(object):
self._extras = extras if extras else {} self._extras = extras if extras else {}
self._custom_check = custom_check self._custom_check = custom_check
def __call__(self, attrs): def __call__(self, inputs, attrs, *args):
# apply custom check # apply custom check
if self._custom_check: if self._custom_check:
func, msg = self._custom_check func, msg = self._custom_check
...@@ -79,9 +86,9 @@ class AttrConverter(object): ...@@ -79,9 +86,9 @@ class AttrConverter(object):
if k in self._excludes: if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k)) raise NotImplementedError("Attribute {} not supported yet.".format(k))
elif k in self._disables: elif k in self._disables:
warnings.warn("Attribute {} is disabled in nnvm.sym.{}".format(k, op_name)) logging.warning("Attribute %s is disabled in nnvm.sym.%s", k, op_name)
elif k in self._ignores: elif k in self._ignores:
pass logging.debug("Attribute %s is ignored in nnvm.sym.%s", k, op_name)
elif k in self._transforms: elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k]) new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None: if defaults is None:
...@@ -97,7 +104,7 @@ class AttrConverter(object): ...@@ -97,7 +104,7 @@ class AttrConverter(object):
new_attrs[k] = attrs[k] new_attrs[k] = attrs[k]
# add extras # add extras
new_attrs.update(self._extras) new_attrs.update(self._extras)
return op_name, new_attrs return get_nnvm_op(op_name)(*inputs, **new_attrs)
def _parse_default(self, target): def _parse_default(self, target):
"""Helper function to parse default values.""" """Helper function to parse default values."""
......
pip2 install onnx pip2 install onnx>=0.2.0
pip3 install onnx pip3 install onnx>=0.2.0
pip2 install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl pip2 install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
pip2 install torchvision pip2 install torchvision
......
"""Store for onnx examples and common models.""" """Store for onnx examples and common models."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import os import os
import logging
from .super_resolution import get_super_resolution from .super_resolution import get_super_resolution
__all__ = ['super_resolution'] def _download(url, filename, overwrite=False):
if os.path.isfile(filename) and not overwrite:
logging.debug('File %s existed, skip.', filename)
return
logging.debug('Downloading from url %s to %s', url, filename)
try:
import urllib.request
urllib.request.urlretrieve(url, filename)
except:
import urllib
urllib.urlretrieve(url, filename)
def _as_abs_path(fname): def _as_abs_path(fname):
cur_dir = os.path.abspath(os.path.dirname(__file__)) cur_dir = os.path.abspath(os.path.dirname(__file__))
return os.path.join(cur_dir, fname) return os.path.join(cur_dir, fname)
# a pair of onnx pb file and corresponding nnvm symbol
super_resolution = (_as_abs_path('super_resolution.onnx'), get_super_resolution()) URLS = {
'super_resolution.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/super_resolution_0.2.onnx',
'squeezenet1_1.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/squeezenet1_1_0.2.onnx',
'lenet.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/lenet_0.2.onnx'}
# download and add paths
for k, v in URLS.items():
name = k.split('.')[0]
path = _as_abs_path(k)
_download(v, path, False)
locals()[name] = path
# symbol for graph comparison
super_resolution_sym = get_super_resolution()
"""NNVM symbol corresponding to super_resolution.onnx example.""" """NNVM symbol corresponding to super_resolution.onnx example."""
from nnvm import sym from nnvm import sym
def get_super_resolution_deprecated():
factor = 3
size = 224
data = sym.Variable(name='9')
conv1 = sym.conv2d(data, channels=64, kernel_size=(5, 5), padding=(2, 2))
relu1 = sym.relu(conv1)
conv2 = sym.conv2d(relu1, channels=64, kernel_size=(3, 3), padding=(1, 1))
relu2 = sym.relu(conv2)
conv3 = sym.conv2d(relu2, channels=32, kernel_size=(3, 3), padding=(1, 1))
relu3 = sym.relu(conv3)
conv4 = sym.conv2d(relu3, channels=factor**2, kernel_size=(3, 3), padding=(1, 1))
r1 = sym.reshape(conv4, shape=(0, 1, factor, factor, size, size))
t1 = sym.transpose(r1, axes=(0, 1, 4, 2, 5, 3))
r2 = sym.reshape(t1, shape=(0, 1, size * factor, size * factor))
return r2
def get_super_resolution(): def get_super_resolution():
factor = 3 factor = 3
size = 224 size = 224
data = sym.Variable(name='9') data = sym.Variable(name='9')
conv1 = sym.conv2d(data, channels=64, kernel_size=(5, 5), padding=(2, 2), use_bias=False) conv1 = sym.conv2d(data, channels=64, kernel_size=(5, 5), padding=(2, 2), use_bias=False)
relu1 = sym.relu(conv1 + sym.Variable(name='2')) relu1 = sym.relu(conv1 + sym.expand_dims(sym.Variable(name='2', shape=(64)), axis=1, num_newaxis=2))
conv2 = sym.conv2d(relu1, channels=64, kernel_size=(3, 3), padding=(1, 1), use_bias=False) conv2 = sym.conv2d(relu1, channels=64, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
relu2 = sym.relu(conv2 + sym.Variable(name='4')) relu2 = sym.relu(conv2 + sym.expand_dims(sym.Variable(name='4', shape=(64)), axis=1, num_newaxis=2))
conv3 = sym.conv2d(relu2, channels=32, kernel_size=(3, 3), padding=(1, 1), use_bias=False) conv3 = sym.conv2d(relu2, channels=32, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
relu3 = sym.relu(conv3 + sym.Variable(name='6')) relu3 = sym.relu(conv3 + sym.expand_dims(sym.Variable(name='6', shape=(32)), axis=1, num_newaxis=2))
conv4 = sym.conv2d(relu3, channels=factor**2, kernel_size=(3, 3), padding=(1, 1), use_bias=False) conv4 = sym.conv2d(relu3, channels=factor**2, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
conv4 = conv4 + sym.Variable(name='8') conv4 = conv4 + sym.expand_dims(sym.Variable(name='8', shape=(factor**2)), axis=1, num_newaxis=2)
# TODO(zhreshold): allow shape inference for batch size > 1 # TODO(zhreshold): allow shape inference for batch size > 1
r1 = sym.reshape(conv4, shape=(1, 1, factor, factor, size, size)) r1 = sym.reshape(conv4, shape=(1, 1, factor, factor, size, size))
t1 = sym.transpose(r1, axes=(0, 1, 4, 2, 5, 3)) t1 = sym.transpose(r1, axes=(0, 1, 4, 2, 5, 3))
......
...@@ -4,13 +4,13 @@ import tvm ...@@ -4,13 +4,13 @@ import tvm
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from nnvm.testing.config import ctx_list from nnvm.testing.config import ctx_list
import onnx import onnx
from model_zoo import super_resolution from model_zoo import super_resolution, squeezenet1_1, lenet
def verify_onnx_forward_impl(graph_file, data_shape, out_shape): def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
import onnx_caffe2.backend import onnx_caffe2.backend
def get_caffe2_output(graph, x, dtype='float32'): def get_caffe2_output(model, x, dtype='float32'):
prepared_backend = onnx_caffe2.backend.prepare(graph) prepared_backend = onnx_caffe2.backend.prepare(model)
W = {graph.input[-1]: x.astype(dtype)} W = {model.graph.input[0].name: x.astype(dtype)}
c2_out = prepared_backend.run(W)[0] c2_out = prepared_backend.run(W)[0]
return c2_out return c2_out
...@@ -29,14 +29,22 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape): ...@@ -29,14 +29,22 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
dtype = 'float32' dtype = 'float32'
x = np.random.uniform(size=data_shape) x = np.random.uniform(size=data_shape)
graph = onnx.load(graph_file) model = onnx.load(graph_file)
c2_out = get_caffe2_output(graph, x, dtype) c2_out = get_caffe2_output(model, x, dtype)
for target, ctx in ctx_list(): for target, ctx in ctx_list():
tvm_out = get_tvm_output(graph, x, target, ctx, dtype) tvm_out = get_tvm_output(model, x, target, ctx, dtype)
np.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) np.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
def verify_super_resolution_example(): def verify_super_resolution_example():
verify_onnx_forward_impl(super_resolution[0], (1, 1, 224, 224), (1, 1, 672, 672)) verify_onnx_forward_impl(super_resolution, (1, 1, 224, 224), (1, 1, 672, 672))
def verify_squeezenet1_1():
verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000))
def verify_lenet():
verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10))
if __name__ == '__main__': if __name__ == '__main__':
verify_super_resolution_example() verify_super_resolution_example()
verify_squeezenet1_1()
verify_lenet()
...@@ -2,14 +2,9 @@ ...@@ -2,14 +2,9 @@
import nnvm import nnvm
import onnx import onnx
from nnvm.compiler import graph_util, graph_attr from nnvm.compiler import graph_util, graph_attr
from model_zoo import super_resolution from model_zoo import super_resolution, super_resolution_sym
def compare_graph(onnx_file, nnvm_sym, ishape): def compare_graph(onnx_file, nnvm_sym, ishape):
onnx_vars = [int(n) for n in onnx.__version__.split('.')] if hasattr(onnx, "__version__") else []
if len(onnx_vars) >= 2 and (onnx_vars[0] > 0 or onnx_vars[1] >= 2): # version >= 0.2
onnx_model = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_model.graph)
else:
onnx_graph = onnx.load(onnx_file) onnx_graph = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph) onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph)
g1 = nnvm.graph.create(onnx_sym) g1 = nnvm.graph.create(onnx_sym)
...@@ -22,7 +17,7 @@ def compare_graph(onnx_file, nnvm_sym, ishape): ...@@ -22,7 +17,7 @@ def compare_graph(onnx_file, nnvm_sym, ishape):
graph_util.check_graph_equal(g1, g2) graph_util.check_graph_equal(g1, g2)
def test_super_resolution_example(): def test_super_resolution_example():
fname, symbol = super_resolution fname, symbol = super_resolution, super_resolution_sym
compare_graph(fname, symbol, ishape=(1, 1, 224, 224)) compare_graph(fname, symbol, ishape=(1, 1, 224, 224))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -21,12 +21,17 @@ import numpy as np ...@@ -21,12 +21,17 @@ import numpy as np
from PIL import Image from PIL import Image
def download(url, path, overwrite=False): def download(url, path, overwrite=False):
import urllib2, os import os
if os.path.exists(path) and not overwrite: if os.path.isfile(path) and not overwrite:
print('File {} existed, skip.'.format(path))
return return
print('Downloading {} to {}.'.format(url, path)) print('Downloading from url {} to {}'.format(url, path))
with open(path, 'w') as f: try:
f.write(urllib2.urlopen(url).read()) import urllib.request
urllib.request.urlretrieve(url, path)
except:
import urllib
urllib.urlretrieve(url, path)
###################################################################### ######################################################################
# Load pretrained CoreML model # Load pretrained CoreML model
......
...@@ -20,12 +20,17 @@ import onnx ...@@ -20,12 +20,17 @@ import onnx
import numpy as np import numpy as np
def download(url, path, overwrite=False): def download(url, path, overwrite=False):
import urllib2, os import os
if os.path.exists(path) and not overwrite: if os.path.isfile(path) and not overwrite:
print('File {} existed, skip.'.format(path))
return return
print('Downloading {} to {}.'.format(url, path)) print('Downloading from url {} to {}'.format(url, path))
with open(path, 'w') as f: try:
f.write(urllib2.urlopen(url).read()) import urllib.request
urllib.request.urlretrieve(url, path)
except:
import urllib
urllib.urlretrieve(url, path)
###################################################################### ######################################################################
# Load pretrained ONNX model # Load pretrained ONNX model
...@@ -35,9 +40,9 @@ def download(url, path, overwrite=False): ...@@ -35,9 +40,9 @@ def download(url, path, overwrite=False):
# we skip the pytorch model construction part, and download the saved onnx model # we skip the pytorch model construction part, and download the saved onnx model
model_url = ''.join(['https://gist.github.com/zhreshold/', model_url = ''.join(['https://gist.github.com/zhreshold/',
'bcda4716699ac97ea44f791c24310193/raw/', 'bcda4716699ac97ea44f791c24310193/raw/',
'41b443bf2b6cf795892d98edd28bacecd8eb0d8d/', '93672b029103648953c4e5ad3ac3aadf346a4cdc/',
'super_resolution.onnx']) 'super_resolution_0.2.onnx'])
download(model_url, 'super_resolution.onnx') download(model_url, 'super_resolution.onnx', True)
# now you have super_resolution.onnx on disk # now you have super_resolution.onnx on disk
onnx_graph = onnx.load('super_resolution.onnx') onnx_graph = onnx.load('super_resolution.onnx')
# we can load the graph as NNVM compatible model # we can load the graph as NNVM compatible model
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment