Commit 8b6f80c2 by Joshua Z. Zhang, committed by Tianqi Chen

[Frontend] ONNX frontend v0.2 support (#202)

* update

* fix

* use generated onnx model

* fix tests

* fix lint

* remove log filter

* add vgg

* fix tests

* update tests

* fix download

* fix ci

* fix tutorial url

* clean cache
parent f5d158b7
"""Shared functions and classes for frontends."""
from __future__ import absolute_import as _abs
import warnings
import logging
from nnvm import sym as _sym
from .._base import string_types
def get_nnvm_op(op_name):
op = getattr(_sym, op_name, None)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op
class Renamer(object):
"""A simply renamer for operators.
......@@ -14,8 +21,8 @@ class Renamer(object):
def __init__(self, new_name):
self._new_name = new_name
def __call__(self, attrs):
return self._new_name, attrs
def __call__(self, inputs, attrs, *args):
return get_nnvm_op(self._new_name)(*inputs, **attrs)
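# A minimal usage sketch (mirroring the 'Flatten' entry in the onnx frontend's
# _convert_map): ONNX 'Flatten' maps 1-to-1 onto nnvm 'flatten', so the
# converter only swaps the operator name and forwards inputs/attrs unchanged:
#   flatten = Renamer('flatten')
#   sym = flatten(inputs, attrs)   # equivalent to _sym.flatten(*inputs, **attrs)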
class AttrConverter(object):
......@@ -40,9 +47,9 @@ class AttrConverter(object):
A list of excluded attributes that should `NOT` appear.
Raises NotImplementedError if one occurs.
disables : list
A list of attributes that are disabled in nnvm. Raises warnings.
A list of attributes that are disabled in nnvm. Logs warnings.
ignores : list
A list of attributes that are ignored in nnvm. Silent.
A list of attributes that are ignored in nnvm. Logged at debug level.
extras : dict
A series of additional attributes that should always be added to the returned
attribute dict.
......@@ -61,7 +68,7 @@ class AttrConverter(object):
self._extras = extras if extras else {}
self._custom_check = custom_check
def __call__(self, attrs):
def __call__(self, inputs, attrs, *args):
# apply custom check
if self._custom_check:
func, msg = self._custom_check
......@@ -79,9 +86,9 @@ class AttrConverter(object):
if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k))
elif k in self._disables:
warnings.warn("Attribute {} is disabled in nnvm.sym.{}".format(k, op_name))
logging.warning("Attribute %s is disabled in nnvm.sym.%s", k, op_name)
elif k in self._ignores:
pass
logging.debug("Attribute %s is ignored in nnvm.sym.%s", k, op_name)
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
......@@ -97,7 +104,7 @@ class AttrConverter(object):
new_attrs[k] = attrs[k]
# add extras
new_attrs.update(self._extras)
return op_name, new_attrs
return get_nnvm_op(op_name)(*inputs, **new_attrs)
def _parse_default(self, target):
"""Helper function to parse default values."""
......
# pylint: disable=import-self, invalid-name
# pylint: disable=import-self, invalid-name, unused-argument
"""ONNX: Open Neural Network Exchange frontend."""
from __future__ import absolute_import as _abs
import tvm
from .. import symbol as _sym
from .. import graph as _graph
from .. compiler import graph_util
from .common import Renamer, AttrConverter as AttrCvt
from .common import get_nnvm_op, Renamer, AttrConverter as AttrCvt
__all__ = ['from_onnx']
......@@ -49,11 +49,26 @@ def _dimension_constraint():
return False
return _dim_check, "Only 2d kernel supported."
def _infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channles' or 'units' since onnx don't provide
these attributes. We check the shape of weights provided to get the number.
"""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
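# Shape sketch for the hack above (shapes are illustrative): a conv2d weight
# symbol has shape (out_channels, in_channels, kH, kW), so inferring its shape
# and reading out_shapes[0][0] recovers 'channels'; a conv2d_transpose weight
# is laid out as (in_channels, out_channels, kH, kW), hence out_shapes[0][1]
# when transpose=True. A dense weight is (units, in_dim), so out_shapes[0][0]
# likewise recovers 'units'.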
def _elemwise(name):
return AttrCvt(
op_name=_math_name_picker(name),
disables=['axis'],
ignores=['broadcast'])
def _impl(inputs, attr, *args):
assert len(inputs) == 2, "Math op takes 2 inputs, {} given".format(len(inputs))
op_name = _math_name_picker(name)(attr)
axis = int(attr.get('axis', 0))
if op_name == 'broadcast_add' and inputs[0].attr('op_name') == 'conv2d':
# TODO(zhreshold): remove hard coded infershape
inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
return get_nnvm_op(op_name)(*inputs)
return _impl
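# Worked example for the broadcast hack above (shapes are illustrative): a
# conv2d output in NCHW layout is e.g. (1, 64, 224, 224) while the ONNX bias
# is 1-D with shape (64,) and axis=1. expand_dims(bias, axis=1, num_newaxis=2)
# turns (64,) into (64, 1, 1), which broadcast_add can then broadcast against
# the (C, H, W) trailing axes of the conv output.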
def _pooling(name):
return AttrCvt(
......@@ -68,6 +83,10 @@ def _pooling(name):
custom_check=_dimension_constraint())
def _conv():
def _impl(inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params)
attr['channels'] = channels
return AttrCvt(
op_name=_dimension_picker('conv'),
transforms={
......@@ -75,9 +94,15 @@ def _conv():
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), _revert_caffe2_pad),
'group': ('groups', 1)},
custom_check=_dimension_constraint())
extras={'use_bias': False},
custom_check=_dimension_constraint())(inputs, attr)
return _impl
def _conv_transpose():
def _impl(inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params)
attr['channels'] = channels
return AttrCvt(
op_name=_dimension_picker('conv', '_transpose'),
transforms={
......@@ -85,7 +110,17 @@ def _conv_transpose():
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), _revert_caffe2_pad)},
disables=['output_shape'],
custom_check=_dimension_constraint())
extras={'use_bias': False},
custom_check=_dimension_constraint())(inputs, attr)
return _impl
def _fully_connected():
def _impl(inputs, attr, params):
# get number of channels
channels = _infer_channels(inputs[1], params)
attr['units'] = channels
return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
return _impl
def _batch_norm():
# TODO(zhreshold): 'spatial' is not properly handled here.
......@@ -95,10 +130,32 @@ def _batch_norm():
ignores=['spatial', 'is_test', 'consumed_inputs'])
def _gemm():
def _impl(inputs, attr, params):
assert len(inputs) == 3, "Gemm op takes 3 inputs, {} given".format(len(inputs))
# Y = alpha * A * B + beta * C
alpha = float(attr.get('alpha', 1.0))
beta = float(attr.get('beta', 1.0))
transA = int(attr.get('transA', 0))
transB = int(attr.get('transB', 0))
# get number of channels
channels = _infer_channels(inputs[1], params, not transB)
if transA:
inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
if not transB:
inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
return _sym.dense(alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)
return _impl
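# The mapping above follows Y = alpha * A * B + beta * C. nnvm dense computes
# X * W^T + b with W of shape (units, in_dim), i.e. it expects ONNX's B already
# transposed; hence inputs[1] is transposed whenever transB is unset, and alpha
# and beta are folded into the operands:
#   dense(alpha * A, B^T, beta * C, units=N)   # N inferred from B's shape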
# compatible operators that do NOT require any conversion.
_identity_list = []
# _convert_map defines a map of operator name to converter functor (callable):
# for 1-to-1 mapping, use Renamer if nothing but the name differs,
# use AttrCvt if attributes need to be converted;
# for 1-to-N mapping (composed), use custom callable functions;
# N-to-1 mapping is currently not supported.
_convert_map = {
# defs/experimental
'FC' : AttrCvt('dense', ignores=['axis', 'axis_w']),
......@@ -140,6 +197,7 @@ _convert_map = {
# 'Sum' : elemwise sum
# softmax default axis is different in onnx
'Softmax' : AttrCvt('softmax', {'axis': ('axis', 1)}),
'Gemm' : _gemm(),
# defs/nn
'AveragePool' : _pooling('avg_pool'),
......@@ -151,6 +209,7 @@ _convert_map = {
'BatchNormalization': _batch_norm(),
'Dropout' : AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten' : Renamer('flatten'),
# 'LRN'
# defs/reduction
'ReduceMax' : AttrCvt('max', {'axes', 'axis'}),
......@@ -173,42 +232,6 @@ _convert_map = {
# 'Squeeze'
}
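# Dispatch sketch (see GraphProto._convert_operator below): every value in this
# table is a callable taking (inputs, attrs, *args); factories such as _gemm()
# are invoked once while building the table, so it stores the returned _impl
# closures, e.g.
#   _convert_map['Gemm'](inputs, attrs, params)   # -> nnvm dense symbol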
def _convert_operator(op_name, attrs, identity_list=None, convert_map=None):
"""Convert from onnx operator to nnvm operator.
The converter must specify conversions explicitly for incompatible names, and
apply handlers to operator attributes.
Parameters
----------
op_name : str
Operator name, such as Convolution, FullyConnected
attrs : dict
Dict of operator attributes
identity_list : list
List of operators that don't require conversion
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to nnvm; callables are functions which
take attrs and return (new_op_name, new_attrs)
Returns
-------
(op_name, attrs)
Converted (op_name, attrs) for nnvm.
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
if op_name in identity_list:
pass
elif op_name in convert_map:
op_name, attrs = convert_map[op_name](attrs)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))
op = getattr(_sym, op_name, None)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op, attrs
class GraphProto(object):
"""A helper class for handling nnvm graph copying from pb2.GraphProto.
......@@ -247,16 +270,14 @@ class GraphProto(object):
for i in graph.input:
# from onnx v0.2, GraphProto.input has type ValueInfoProto,
# and the name is 'i.name'
try:
i_name = i.name
except AttributeError:
i_name = i
i_name = self._parse_value_proto(i)
if i_name in self._params:
# i is a param instead of input
name_param = 'param_{}'.format(self._num_param)
self._num_param += 1
self._params[name_param] = self._params.pop(i_name)
self._nodes[name_param] = _sym.Variable(name=name_param)
self._nodes[name_param] = _sym.Variable(
name=name_param, shape=self._params[name_param].shape)
self._renames[i_name] = name_param
else:
name_input = 'input_{}'.format(self._num_input)
......@@ -264,18 +285,11 @@ class GraphProto(object):
self._nodes[name_input] = _sym.Variable(name=name_input)
self._renames[i_name] = name_input
# construct nodes, nodes are stored as directed acyclic graph
for idx, node in enumerate(graph.node):
for node in graph.node:
op_name = node.op_type
node_name = node.name.strip()
node_name = node_name if node_name else None
attr = self._parse_attr(node.attribute)
new_op, new_attr = _convert_operator(op_name, attr)
inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
# some hacks for onnx problem
new_attr = self._fix_bias(new_op, new_attr, len(inputs))
new_attr = self._fix_channels(new_op, new_attr, list(node.input))
self._fix_bias_shape(node.op_type, graph.node[idx-1].op_type, node.input)
op = new_op(name=node_name, *inputs, **new_attr)
op = self._convert_operator(op_name, inputs, attr)
node_output = self._fix_outputs(op_name, node.output)
assert len(node_output) == len(op.list_output_names()), (
"Number of output mismatch {} vs {} in {}.".format(
......@@ -283,13 +297,21 @@ class GraphProto(object):
for i, k in enumerate(node_output):
self._nodes[k] = op[i]
# now return the outputs
out = [self._nodes[i] for i in graph.output]
out = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
if len(out) > 1:
out = _sym.Group(out)
else:
out = out[0]
return out, self._params
def _parse_value_proto(self, value_proto):
"""Parse ValueProto or raw str."""
try:
name = value_proto.name
except AttributeError:
name = value_proto
return name
def _parse_array(self, tensor_proto):
"""Grab data in TensorProto and convert to numpy array."""
try:
......@@ -320,64 +342,52 @@ class GraphProto(object):
raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
return attrs
def _fix_outputs(self, op, outputs):
def _convert_operator(self, op_name, inputs, attrs, identity_list=None, convert_map=None):
"""Convert from onnx operator to nnvm operator.
The converter must specify conversions explicitly for incompatible names, and
apply handlers to operator attributes.
Parameters
----------
op_name : str
Operator name, such as Convolution, FullyConnected
inputs : list of nnvm.Symbol
List of input symbols.
attrs : dict
Dict of operator attributes
identity_list : list
List of operators that don't require conversion
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to nnvm; callables are functions which
take (inputs, attrs, params) and return a converted nnvm.Symbol
Returns
-------
sym : nnvm.Symbol
Converted nnvm Symbol
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
if op_name in identity_list:
sym = get_nnvm_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))
return sym
def _fix_outputs(self, op_name, outputs):
"""A hack to handle dropout or similar operator that have more than one out
in ONNX.
"""
if op == 'Dropout':
assert len(outputs) == 2, "ONNX has two outputs for the dropout layer."
if op_name == 'Dropout':
if len(outputs) == 1:
return outputs
# TODO(zhreshold): support dropout mask?
outputs = outputs[:-1]
return outputs
def _fix_bias(self, op, attrs, num_inputs):
"""A hack for 'use_bias' attribute since onnx don't provide this attribute,
we have to check the number of inputs to decide it."""
if op not in [_sym.conv2d, _sym.conv2d_transpose, _sym.dense]:
return attrs
if num_inputs == 3:
attrs['use_bias'] = True
elif num_inputs == 2:
attrs['use_bias'] = False
else:
raise ValueError("Unexpected number of inputs for: {}".format(op))
return attrs
def _fix_bias_shape(self, op_name, last_op_name, inputs):
"""A hack to reshape bias term to (1, num_channel)."""
if op_name == 'Add' and last_op_name == 'Conv':
assert len(list(inputs)) == 2
bias_name = self._renames.get(inputs[1], inputs[1])
bias = self._params[bias_name]
assert len(bias.shape) == 1
# reshape to (1, n)
bias = tvm.nd.array(bias.asnumpy().reshape((1, -1, 1, 1)))
self._params[bias_name] = bias
def _fix_channels(self, op, attrs, inputs):
"""A hack for getting 'channles' or 'units' since onnx don't provide
these attributes. We check the shape of weights provided to get the number.
"""
if op not in [_sym.conv2d, _sym.conv2d_transpose, _sym.dense]:
return attrs
if inputs[1] not in self._renames:
assert inputs[1] in self._nodes
g = _graph.create(self._nodes[inputs[1]])
shape_dict = {k: v.shape for k, v in self._params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0]
else:
weight_name = self._renames[inputs[1]]
if weight_name not in self._params:
raise ValueError("Unable to get channels/units attr from onnx graph.")
else:
wshape = self._params[weight_name].shape
assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)
channels = wshape[0]
if op in [_sym.dense]:
attrs['units'] = channels
else:
attrs['channels'] = channels
return attrs
def from_onnx(graph):
"""Load onnx graph which is a python protobuf object in to nnvm graph.
......@@ -389,7 +399,7 @@ def from_onnx(graph):
Parameters
----------
graph : protobuf object
ONNX graph
ONNX GraphProto, or ONNX ModelProto after ONNX v0.2
Returns
-------
......@@ -400,5 +410,8 @@ def from_onnx(graph):
Dict of converted parameters stored in tvm.ndarray format
"""
g = GraphProto()
if hasattr(graph, 'graph'):
# it's a ModelProto wrapper
graph = graph.graph
sym, params = g.from_onnx(graph)
return sym, params
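# A minimal usage sketch (the model file name is an assumption):
#   import onnx
#   model = onnx.load('super_resolution.onnx')   # ModelProto in onnx >= 0.2
#   sym, params = from_onnx(model)               # ModelProto unwrapped above
#   graph = _graph.create(sym)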
pip2 install onnx
pip3 install onnx
pip2 install 'onnx>=0.2.0'
pip3 install 'onnx>=0.2.0'
pip2 install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
pip2 install torchvision
......
"""Store for onnx examples and common models."""
from __future__ import absolute_import as _abs
import os
import logging
from .super_resolution import get_super_resolution
__all__ = ['super_resolution']
def _download(url, filename, overwrite=False):
if os.path.isfile(filename) and not overwrite:
logging.debug('File %s exists, skipping download.', filename)
return
logging.debug('Downloading from url %s to %s', url, filename)
try:
import urllib.request
urllib.request.urlretrieve(url, filename)
except ImportError:
import urllib
urllib.urlretrieve(url, filename)
def _as_abs_path(fname):
cur_dir = os.path.abspath(os.path.dirname(__file__))
return os.path.join(cur_dir, fname)
# a pair of onnx pb file and corresponding nnvm symbol
super_resolution = (_as_abs_path('super_resolution.onnx'), get_super_resolution())
URLS = {
'super_resolution.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/super_resolution_0.2.onnx',
'squeezenet1_1.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/squeezenet1_1_0.2.onnx',
'lenet.onnx': 'https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/93672b029103648953c4e5ad3ac3aadf346a4cdc/lenet_0.2.onnx'}
# download and add paths
for k, v in URLS.items():
name = k.split('.')[0]
path = _as_abs_path(k)
_download(v, path, False)
locals()[name] = path
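# After the loop above, this module exposes e.g. 'super_resolution',
# 'squeezenet1_1' and 'lenet' as local file paths, so tests can simply do:
#   from model_zoo import super_resolution, squeezenet1_1, lenet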
# symbol for graph comparison
super_resolution_sym = get_super_resolution()
"""NNVM symbol corresponding to super_resolution.onnx example."""
from nnvm import sym
def get_super_resolution_deprecated():
factor = 3
size = 224
data = sym.Variable(name='9')
conv1 = sym.conv2d(data, channels=64, kernel_size=(5, 5), padding=(2, 2))
relu1 = sym.relu(conv1)
conv2 = sym.conv2d(relu1, channels=64, kernel_size=(3, 3), padding=(1, 1))
relu2 = sym.relu(conv2)
conv3 = sym.conv2d(relu2, channels=32, kernel_size=(3, 3), padding=(1, 1))
relu3 = sym.relu(conv3)
conv4 = sym.conv2d(relu3, channels=factor**2, kernel_size=(3, 3), padding=(1, 1))
r1 = sym.reshape(conv4, shape=(0, 1, factor, factor, size, size))
t1 = sym.transpose(r1, axes=(0, 1, 4, 2, 5, 3))
r2 = sym.reshape(t1, shape=(0, 1, size * factor, size * factor))
return r2
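# The reshape/transpose/reshape trio above implements depth-to-space (pixel
# shuffle). A quick numpy check of the same shape gymnastics on a toy 4x4 map:
#   import numpy as np
#   x = np.arange(9 * 16, dtype='float32').reshape(1, 9, 4, 4)
#   r1 = x.reshape(1, 1, 3, 3, 4, 4)
#   t1 = r1.transpose(0, 1, 4, 2, 5, 3)      # (1, 1, 4, 3, 4, 3)
#   r2 = t1.reshape(1, 1, 12, 12)            # 3x upsampled output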
def get_super_resolution():
factor = 3
size = 224
data = sym.Variable(name='9')
conv1 = sym.conv2d(data, channels=64, kernel_size=(5, 5), padding=(2, 2), use_bias=False)
relu1 = sym.relu(conv1 + sym.Variable(name='2'))
relu1 = sym.relu(conv1 + sym.expand_dims(sym.Variable(name='2', shape=(64,)), axis=1, num_newaxis=2))
conv2 = sym.conv2d(relu1, channels=64, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
relu2 = sym.relu(conv2 + sym.Variable(name='4'))
relu2 = sym.relu(conv2 + sym.expand_dims(sym.Variable(name='4', shape=(64,)), axis=1, num_newaxis=2))
conv3 = sym.conv2d(relu2, channels=32, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
relu3 = sym.relu(conv3 + sym.Variable(name='6'))
relu3 = sym.relu(conv3 + sym.expand_dims(sym.Variable(name='6', shape=(32,)), axis=1, num_newaxis=2))
conv4 = sym.conv2d(relu3, channels=factor**2, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
conv4 = conv4 + sym.Variable(name='8')
conv4 = conv4 + sym.expand_dims(sym.Variable(name='8', shape=(factor**2,)), axis=1, num_newaxis=2)
# TODO(zhreshold): allow shape inference for batch size > 1
r1 = sym.reshape(conv4, shape=(1, 1, factor, factor, size, size))
t1 = sym.transpose(r1, axes=(0, 1, 4, 2, 5, 3))
......
......@@ -4,13 +4,13 @@ import tvm
from tvm.contrib import graph_runtime
from nnvm.testing.config import ctx_list
import onnx
from model_zoo import super_resolution
from model_zoo import super_resolution, squeezenet1_1, lenet
def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
import onnx_caffe2.backend
def get_caffe2_output(graph, x, dtype='float32'):
prepared_backend = onnx_caffe2.backend.prepare(graph)
W = {graph.input[-1]: x.astype(dtype)}
def get_caffe2_output(model, x, dtype='float32'):
prepared_backend = onnx_caffe2.backend.prepare(model)
W = {model.graph.input[0].name: x.astype(dtype)}
c2_out = prepared_backend.run(W)[0]
return c2_out
......@@ -29,14 +29,22 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
dtype = 'float32'
x = np.random.uniform(size=data_shape)
graph = onnx.load(graph_file)
c2_out = get_caffe2_output(graph, x, dtype)
model = onnx.load(graph_file)
c2_out = get_caffe2_output(model, x, dtype)
for target, ctx in ctx_list():
tvm_out = get_tvm_output(graph, x, target, ctx, dtype)
tvm_out = get_tvm_output(model, x, target, ctx, dtype)
np.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
def verify_super_resolution_example():
verify_onnx_forward_impl(super_resolution[0], (1, 1, 224, 224), (1, 1, 672, 672))
verify_onnx_forward_impl(super_resolution, (1, 1, 224, 224), (1, 1, 672, 672))
def verify_squeezenet1_1():
verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000))
def verify_lenet():
verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10))
if __name__ == '__main__':
verify_super_resolution_example()
verify_squeezenet1_1()
verify_lenet()
......@@ -2,14 +2,9 @@
import nnvm
import onnx
from nnvm.compiler import graph_util, graph_attr
from model_zoo import super_resolution
from model_zoo import super_resolution, super_resolution_sym
def compare_graph(onnx_file, nnvm_sym, ishape):
onnx_vars = [int(n) for n in onnx.__version__.split('.')] if hasattr(onnx, "__version__") else []
if len(onnx_vars) >= 2 and (onnx_vars[0] > 0 or onnx_vars[1] >= 2): # version >= 0.2
onnx_model = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_model.graph)
else:
onnx_graph = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph)
g1 = nnvm.graph.create(onnx_sym)
......@@ -22,7 +17,7 @@ def compare_graph(onnx_file, nnvm_sym, ishape):
graph_util.check_graph_equal(g1, g2)
def test_super_resolution_example():
fname, symbol = super_resolution
fname, symbol = super_resolution, super_resolution_sym
compare_graph(fname, symbol, ishape=(1, 1, 224, 224))
if __name__ == '__main__':
......
......@@ -21,12 +21,17 @@ import numpy as np
from PIL import Image
def download(url, path, overwrite=False):
import urllib2, os
if os.path.exists(path) and not overwrite:
import os
if os.path.isfile(path) and not overwrite:
print('File {} exists, skipping download.'.format(path))
return
print('Downloading {} to {}.'.format(url, path))
with open(path, 'w') as f:
f.write(urllib2.urlopen(url).read())
print('Downloading from url {} to {}'.format(url, path))
try:
import urllib.request
urllib.request.urlretrieve(url, path)
except ImportError:
import urllib
urllib.urlretrieve(url, path)
######################################################################
# Load pretrained CoreML model
......
......@@ -20,12 +20,17 @@ import onnx
import numpy as np
def download(url, path, overwrite=False):
import urllib2, os
if os.path.exists(path) and not overwrite:
import os
if os.path.isfile(path) and not overwrite:
print('File {} exists, skipping download.'.format(path))
return
print('Downloading {} to {}.'.format(url, path))
with open(path, 'w') as f:
f.write(urllib2.urlopen(url).read())
print('Downloading from url {} to {}'.format(url, path))
try:
import urllib.request
urllib.request.urlretrieve(url, path)
except ImportError:
import urllib
urllib.urlretrieve(url, path)
######################################################################
# Load pretrained ONNX model
......@@ -35,9 +40,9 @@ def download(url, path, overwrite=False):
# we skip the pytorch model construction part, and download the saved onnx model
model_url = ''.join(['https://gist.github.com/zhreshold/',
'bcda4716699ac97ea44f791c24310193/raw/',
'41b443bf2b6cf795892d98edd28bacecd8eb0d8d/',
'super_resolution.onnx'])
download(model_url, 'super_resolution.onnx')
'93672b029103648953c4e5ad3ac3aadf346a4cdc/',
'super_resolution_0.2.onnx'])
download(model_url, 'super_resolution.onnx', True)
# now you have super_resolution.onnx on disk
onnx_graph = onnx.load('super_resolution.onnx')
# we can load the graph as NNVM compatible model
......