Commit 4f664f5b by Joshua Z. Zhang Committed by Tianqi Chen

[Frontend] Onnx (#40)

* init onnx

finish onnx frontend

add onnx tests

fix various

backup

use transformer

[Frontend] graph passed

add test forward

test forward

fix doc and lint

fix test graph tuple

from_onnx now takes 2 args, outputs (sym, params)

fix rename

fix input names

fix multiple

fix lint

fix lint check

* better doc
parent dddd8d1a
"""NNVM frontends.""" """NNVM frontends."""
from __future__ import absolute_import from __future__ import absolute_import
from .mxnet import from_mxnet from .mxnet import from_mxnet
from .onnx import from_onnx
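
Per the commit message, from_onnx now returns a (sym, params) pair, mirroring from_mxnet. A minimal hedged sketch of the new entry point ('model.onnx' is a placeholder path, not a file in this commit):

# Hedged sketch; 'model.onnx' is a placeholder path.
import onnx
import nnvm

onnx_graph = onnx.load('model.onnx')
sym, params = nnvm.frontend.from_onnx(onnx_graph)
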
"""Shared functions and classes for frontends."""
from __future__ import absolute_import as _abs
import warnings
from .._base import string_types
class Renamer(object):
"""A simply renamer for operators.
Parameters
----------
new_name : str
The new name for the operator
"""
def __init__(self, new_name):
self._new_name = new_name
def __call__(self, attrs):
return self._new_name, attrs
class AttrConverter(object):
"""Common attribute conveter. An AttrConverter instance is a callable:
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
```
Parameters
----------
op_name : str or callable
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provded, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
by transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occured.
disables : list
A list of attributes that is disabled in nnvm. Raise warnings.
ignores : list
A list of attributes that is ignored in nnvm. Silent.
extras : dict
A series of additional attributes should be added anyway to the returned
attribute dict.
custom_check : callable
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
"""
def __init__(self, op_name, transforms=None,
excludes=None, disables=None, ignores=None,
extras=None, custom_check=None):
self._op_name = op_name
self._transforms = transforms if transforms else {}
self._excludes = excludes if excludes else []
self._disables = disables if disables else []
self._ignores = ignores if ignores else []
self._extras = extras if extras else {}
self._custom_check = custom_check
def __call__(self, attrs):
# apply custom check
if self._custom_check:
func, msg = self._custom_check
if not func(attrs):
raise RuntimeError("Check failed: {}".format(msg))
# get new op_name
if isinstance(self._op_name, string_types):
op_name = self._op_name
else:
assert callable(self._op_name), "op_name can either be string or callable"
op_name = self._op_name(attrs)
# convert attributes
new_attrs = {}
for k in attrs.keys():
if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k))
elif k in self._disables:
warnings.warn("Attribute {} is disabled in nnvm.sym.{}".format(k, op_name))
elif k in self._ignores:
pass
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
new_attr = self._required_attr(attrs, k)
else:
new_attr = attrs.get(k, None)
if new_attr is None:
new_attrs[new_name] = defaults
else:
new_attrs[new_name] = transform(new_attr)
else:
# copy
new_attrs[k] = attrs[k]
# add extras
new_attrs.update(self._extras)
return op_name, new_attrs
def _parse_default(self, target):
"""Helper function to parse default values."""
if not isinstance(target, (list, tuple)):
k, v, t = target, None, lambda x: x
elif len(target) == 1:
k, v, t = target[0], None, lambda x: x
elif len(target) == 2:
k, v, t = target[0], target[1], lambda x: x
elif len(target) > 2:
k, v, t = target[0], target[1], target[2]
else:
k = None # should raise
if not isinstance(k, string_types):
msg = "{} is not a valid target, (name, default) expected.".format(target)
raise ValueError(msg)
return k, v, t
def _parse_bool(self, value):
"""Helper function to parse default boolean values."""
if isinstance(value, string_types):
return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
return bool(value)
def _required_attr(self, attr, key):
"""Wrapper for getting required attributes."""
assert isinstance(attr, dict)
if key not in attr:
raise AttributeError("Required attribute {} not found.".format(key))
return attr[key]
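
As a hedged illustration of how these helpers are meant to be combined into a frontend's converter table ('SomeConv' and its attribute names are invented for this sketch, not operators defined in this commit):

# Illustrative sketch only: 'SomeConv' and its attributes are hypothetical.
_convert_map = {
    'Identity': Renamer('copy'),
    'SomeConv': AttrConverter(
        'conv2d',
        transforms={
            'kernel_shape': 'kernel_size',   # plain rename
            'pads': ('padding', (0, 0)),     # rename with a default
            'group': ('groups', 1, int),     # rename, default, and transform
        },
        excludes=['dilations'],              # raises NotImplementedError if present
        ignores=['name'],                    # silently dropped
        extras={'use_bias': False},          # always added to the result
    ),
}

op_name, new_attrs = _convert_map['SomeConv']({'kernel_shape': (3, 3), 'group': '1'})
# op_name == 'conv2d'
# new_attrs == {'kernel_size': (3, 3), 'groups': 1, 'use_bias': False}
# (optional attributes absent from the input, such as 'pads', do not appear)
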
@@ -58,12 +58,12 @@ def _pooling(attrs):
 def _batch_norm(attrs):
     if _parse_bool_str(attrs, 'output_mean_var'):
         _raise_not_supported('output_mean_var', 'batch_norm')
-    if _parse_bool_str(attrs, 'fix_gamma'):
-        _warn_not_used('fix_gamma', 'batch_norm')
+    # if _parse_bool_str(attrs, 'fix_gamma'):
+    #     _warn_not_used('fix_gamma', 'batch_norm')
     if _parse_bool_str(attrs, 'use_global_stats'):
         _warn_not_used('use_global_stats', 'batch_norm')
-    if _parse_bool_str(attrs, 'momentum'):
-        _warn_not_used('momentum', 'batch_norm')
+    # if _parse_bool_str(attrs, 'momentum'):
+    #     _warn_not_used('momentum', 'batch_norm')
     op_name, new_attrs = 'batch_norm', {}
     new_attrs['axis'] = attrs.get('axis', 1)
     new_attrs['epsilon'] = attrs.get('eps', 0.001)
...
@@ -2,4 +2,8 @@
 from __future__ import absolute_import as _abs
 from .config import ctx_list
+from .utils import create_workload
 from . import mobilenet
+from . import mlp
+from . import resnet
+from . import vgg
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
a simple multilayer perceptron
"""
from .. import symbol as sym
from . utils import create_workload
def get_symbol(num_classes=1000):
data = sym.Variable('data')
data = sym.flatten(data=data)
fc1 = sym.dense(data=data, name='fc1', units=128)
act1 = sym.relu(data=fc1, name='relu1')
fc2 = sym.dense(data=act1, name='fc2', units=64)
act2 = sym.relu(data=fc2, name='relu2')
fc3 = sym.dense(data=act2, name='fc3', units=num_classes)
mlp = sym.softmax(data=fc3, name='softmax')
return mlp
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"):
"""Get benchmark workload for a simple multilayer perceptron
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of claseses
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes)
return create_workload(net, batch_size, image_shape, dtype)
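
For context, a hedged sketch of how such a workload can be compiled and executed; it follows the same build/run pattern as the ONNX forward test further down, and assumes a TVM build with the llvm target enabled:

# Hedged sketch: compile and run the MLP workload end to end.
import numpy as np
import tvm
import nnvm.compiler
from tvm.contrib import graph_runtime
from nnvm.testing import mlp

batch_size = 1
net, params = mlp.get_workload(batch_size)
shape_dict = {'data': (batch_size, 3, 224, 224)}
graph, lib, params = nnvm.compiler.build(net, 'llvm', shape_dict, params=params)
m = graph_runtime.create(graph, lib, tvm.cpu(0))
m.set_input(**params)
m.set_input('data', tvm.nd.array(
    np.random.uniform(size=shape_dict['data']).astype('float32')))
m.run()
out = m.get_output(0, tvm.nd.empty((batch_size, 1000), 'float32'))
print(out.asnumpy().shape)   # (1, 1000)
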
@@ -2,11 +2,8 @@
 # pylint: disable=invalid-name
 from __future__ import absolute_import as _abs
-import numpy as np
-import tvm
-from ..compiler import graph_util
-from .. import graph
 from .. import symbol as sym
+from .utils import create_workload

 def conv_block(data, name, channels,
                kernel_size=(3, 3), strides=(1, 1), padding=(1, 1),
@@ -104,22 +101,5 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype=
     params : dict of str to NDArray
         The parameters.
     """
-    image_shape = (3, 224, 224)
-    data_shape = (batch_size,) + image_shape
     net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False)
-    params = {}
-    g = graph.create(net)
-    input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
-    shape_dict = dict(zip(g.index.input_names, input_shapes))
-    for k, v in shape_dict.items():
-        if k == "data":
-            continue
-        # Specially generate non-negative parameters.
-        if k.endswith("gamma"):
-            init = np.random.uniform(0.9, 1, size=v)
-        elif k.endswith("var"):
-            init = np.random.uniform(0.9, 1, size=v)
-        else:
-            init = np.random.uniform(-0.1, 0.1, size=v)
-        params[k] = tvm.nd.array(init.astype(dtype), ctx=tvm.cpu(0))
-    return net, params
+    return create_workload(net, batch_size, image_shape, dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
'''
Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
Original author Wei Wu

Implemented the following paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
import numpy as np
from .. import symbol as sym
from .utils import create_workload


def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True):
    """Return a ResNet unit symbol for building ResNet.

    Parameters
    ----------
    data : Symbol
        Input data
    num_filter : int
        Number of output channels
    bnf : int
        Bottleneck channels factor with regard to num_filter
    stride : tuple
        Stride used in convolution
    dim_match : bool
        True means the channel number between input and output is the same,
        otherwise they differ
    name : str
        Base name of the operators
    """
    if bottle_neck:
        # the same as https://github.com/facebook/fb.resnet.torch#notes,
        # a bit different from the original paper
        bn1 = sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
        act1 = sym.relu(data=bn1, name=name + '_relu1')
        conv1 = sym.conv2d(
            data=act1, channels=int(num_filter*0.25), kernel_size=(1, 1),
            strides=(1, 1), padding=(0, 0), use_bias=False, name=name + '_conv1')
        bn2 = sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
        act2 = sym.relu(data=bn2, name=name + '_relu2')
        conv2 = sym.conv2d(
            data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3),
            strides=stride, padding=(1, 1), use_bias=False, name=name + '_conv2')
        bn3 = sym.batch_norm(data=conv2, epsilon=2e-5, name=name + '_bn3')
        act3 = sym.relu(data=bn3, name=name + '_relu3')
        conv3 = sym.conv2d(
            data=act3, channels=num_filter, kernel_size=(1, 1),
            strides=(1, 1), padding=(0, 0), use_bias=False, name=name + '_conv3')
        if dim_match:
            shortcut = data
        else:
            shortcut = sym.conv2d(
                data=act1, channels=num_filter, kernel_size=(1, 1),
                strides=stride, use_bias=False, name=name+'_sc')
        return sym.elemwise_add(conv3, shortcut)
    else:
        bn1 = sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
        act1 = sym.relu(data=bn1, name=name + '_relu1')
        conv1 = sym.conv2d(
            data=act1, channels=num_filter, kernel_size=(3, 3),
            strides=stride, padding=(1, 1), use_bias=False, name=name + '_conv1')
        bn2 = sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
        act2 = sym.relu(data=bn2, name=name + '_relu2')
        conv2 = sym.conv2d(
            data=act2, channels=num_filter, kernel_size=(3, 3),
            strides=(1, 1), padding=(1, 1), use_bias=False, name=name + '_conv2')
        if dim_match:
            shortcut = data
        else:
            shortcut = sym.conv2d(
                data=act1, channels=num_filter, kernel_size=(1, 1),
                strides=stride, use_bias=False, name=name+'_sc')
        return sym.elemwise_add(conv2, shortcut)


def resnet(units, num_stages, filter_list, num_classes, image_shape,
           bottle_neck=True, dtype='float32'):
    """Return a ResNet symbol.

    Parameters
    ----------
    units : list
        Number of units in each stage
    num_stages : int
        Number of stages
    filter_list : list
        Channel size of each stage
    num_classes : int
        Output size of the symbol
    dataset : str
        Dataset type; only cifar10 and imagenet are supported
    dtype : str
        Precision (float32 or float16)
    """
    num_unit = len(units)
    assert num_unit == num_stages
    data = sym.Variable(name='data')
    if dtype == 'float16':
        data = sym.cast(data=data, dtype=np.float16)
    data = sym.batch_norm(data=data, epsilon=2e-5, name='bn_data')
    (_, height, _) = image_shape
    if height <= 32:            # such as cifar10
        body = sym.conv2d(
            data=data, channels=filter_list[0], kernel_size=(3, 3),
            strides=(1, 1), padding=(1, 1), use_bias=False, name="conv0")
    else:                       # often expected to be 224 such as imagenet
        body = sym.conv2d(
            data=data, channels=filter_list[0], kernel_size=(7, 7),
            strides=(2, 2), padding=(3, 3), use_bias=False, name="conv0")
        body = sym.batch_norm(data=body, epsilon=2e-5, name='bn0')
        body = sym.relu(data=body, name='relu0')
        body = sym.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
    for i in range(num_stages):
        body = residual_unit(
            body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2),
            False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck)
        for j in range(units[i]-1):
            body = residual_unit(
                body, filter_list[i+1], (1, 1), True,
                name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck)
    bn1 = sym.batch_norm(data=body, epsilon=2e-5, name='bn1')
    relu1 = sym.relu(data=bn1, name='relu1')
    # Although kernel is not used here when global_pool=True, we should put one
    pool1 = sym.global_avg_pool2d(data=relu1, name='pool1')
    flat = sym.flatten(data=pool1)
    fc1 = sym.dense(data=flat, units=num_classes, name='fc1')
    if dtype == 'float16':
        fc1 = sym.cast(data=fc1, dtype=np.float32)
    return sym.softmax(data=fc1, name='softmax')


def get_symbol(num_classes, num_layers=50, image_shape=(3, 224, 224), dtype='float32'):
    """
    Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
    Original author Wei Wu
    """
    (_, height, _) = image_shape
    if height <= 28:
        num_stages = 3
        if (num_layers-2) % 9 == 0 and num_layers >= 164:
            per_unit = [(num_layers-2)//9]
            filter_list = [16, 64, 128, 256]
            bottle_neck = True
        elif (num_layers-2) % 6 == 0 and num_layers < 164:
            per_unit = [(num_layers-2)//6]
            filter_list = [16, 16, 32, 64]
            bottle_neck = False
        else:
            raise ValueError("no experiments done on num_layers {}".format(num_layers))
        units = per_unit * num_stages
    else:
        if num_layers >= 50:
            filter_list = [64, 256, 512, 1024, 2048]
            bottle_neck = True
        else:
            filter_list = [64, 64, 128, 256, 512]
            bottle_neck = False
        num_stages = 4
        if num_layers == 18:
            units = [2, 2, 2, 2]
        elif num_layers == 34:
            units = [3, 4, 6, 3]
        elif num_layers == 50:
            units = [3, 4, 6, 3]
        elif num_layers == 101:
            units = [3, 4, 23, 3]
        elif num_layers == 152:
            units = [3, 8, 36, 3]
        elif num_layers == 200:
            units = [3, 24, 36, 3]
        elif num_layers == 269:
            units = [3, 30, 48, 8]
        else:
            raise ValueError("no experiments done on num_layers {}".format(num_layers))

    return resnet(units=units,
                  num_stages=num_stages,
                  filter_list=filter_list,
                  num_classes=num_classes,
                  image_shape=image_shape,
                  bottle_neck=bottle_neck,
                  dtype=dtype)


def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224),
                 dtype="float32", **kwargs):
    """Get benchmark workload for ResNet.

    Parameters
    ----------
    batch_size : int
        The batch size used in the model
    num_classes : int, optional
        Number of classes
    image_shape : tuple, optional
        The input image shape
    dtype : str, optional
        The data type
    kwargs : dict
        Extra arguments

    Returns
    -------
    net : nnvm.Symbol
        The computational graph
    params : dict of str to NDArray
        The parameters.
    """
    net = get_symbol(num_classes=num_classes, image_shape=image_shape,
                     dtype=dtype, **kwargs)
    return create_workload(net, batch_size, image_shape, dtype)
"""Helper utility to create common workload for testing."""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from ..compiler import graph_util
from ..import graph
def create_workload(net, batch_size, image_shape=(3, 224, 224), dtype="float32"):
"""Helper function to create benchmark workload for input network
Parameters
----------
net : nnvm.Symbol
The selected network symbol to use
batch_size : int
The batch size used in the model
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
params = {}
g = graph.create(net)
input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
shape_dict = dict(zip(g.index.input_names, input_shapes))
for k, v in shape_dict.items():
if k == "data":
continue
# Specially generate non-negative parameters.
if k.endswith("gamma"):
init = np.random.uniform(0.9, 1, size=v)
elif k.endswith("var"):
init = np.random.uniform(0.9, 1, size=v)
else:
init = np.random.uniform(-0.1, 0.1, size=v)
params[k] = tvm.nd.array(init.astype(dtype), ctx=tvm.cpu(0))
return net, params
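
A hedged sketch of the resulting contract (the network choice is arbitrary): every graph input except 'data' comes back as a randomly initialized NDArray, with batch-norm gamma/var kept strictly positive as the comment above notes.

# Hedged sketch: inspect the parameters produced by create_workload.
from nnvm.testing import resnet

net, params = resnet.get_workload(batch_size=1, num_layers=18)
for name, arr in params.items():
    if name.endswith(('gamma', 'var')):
        # drawn from [0.9, 1) above, so strictly positive
        assert (arr.asnumpy() > 0).all()
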
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""References:
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""
import numpy as np
from .. import symbol as sym
from . utils import create_workload
def get_feature(internel_layer, layers, filters, batch_norm=False):
"""Get VGG feature body as stacks of convoltions."""
for i, num in enumerate(layers):
for j in range(num):
internel_layer = sym.conv2d(
data=internel_layer, kernel_size=(3, 3), padding=(1, 1),
channels=filters[i], name="conv%s_%s"%(i + 1, j + 1))
if batch_norm:
internel_layer = sym.batch_norm(
data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
internel_layer = sym.relu(data=internel_layer, name="relu%s_%s" %(i + 1, j + 1))
internel_layer = sym.max_pool2d(
data=internel_layer, pool_size=(2, 2), strides=(2, 2), name="pool%s"%(i + 1))
return internel_layer
def get_classifier(input_data, num_classes):
"""Get VGG classifier layers as fc layers."""
flatten = sym.flatten(data=input_data, name="flatten")
fc6 = sym.dense(data=flatten, units=4096, name="fc6")
relu6 = sym.relu(data=fc6, name="relu6")
drop6 = sym.dropout(data=relu6, rate=0.5, name="drop6")
fc7 = sym.dense(data=drop6, units=4096, name="fc7")
relu7 = sym.relu(data=fc7, name="relu7")
drop7 = sym.dropout(data=relu7, rate=0.5, name="drop7")
fc8 = sym.dense(data=drop7, units=num_classes, name="fc8")
return fc8
def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32'):
"""
Parameters
----------
num_classes : int, default 1000
Number of classification classes.
num_layers : int
Number of layers for the variant of densenet. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
dtype: str, float32 or float16
Data precision.
"""
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
if not vgg_spec.has_key(num_layers):
raise ValueError("Invalide num_layers {}. Choices are 11,13,16,19.".format(num_layers))
layers, filters = vgg_spec[num_layers]
data = sym.Variable(name="data")
if dtype == 'float16':
data = sym.cast(data=data, dtype=np.float16)
feature = get_feature(data, layers, filters, batch_norm)
classifier = get_classifier(feature, num_classes)
if dtype == 'float16':
classifier = sym.cast(data=classifier, dtype=np.float32)
symbol = sym.softmax(data=classifier, name='softmax')
return symbol
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224),
dtype="float32", **kwargs):
"""Get benchmark workload for VGG nets.
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of claseses
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
kwargs : dict
Extra arguments
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, dtype=dtype, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
"""MXNet and NNVM model zoo."""
from __future__ import absolute_import from __future__ import absolute_import
from . import mlp, resnet, vgg from . import mlp, resnet, vgg
import nnvm.testing
__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg']
_num_class = 1000 _num_class = 1000
# mlp fc # mlp fc
mx_mlp = mlp.get_symbol(_num_class) mx_mlp = mlp.get_symbol(_num_class)
nnvm_mlp = mlp.get_symbol_nnvm(_num_class) nnvm_mlp = nnvm.testing.mlp.get_workload(1, _num_class)[0]
# resnet fc # resnet fc
mx_resnet = {} mx_resnet = {}
nnvm_resnet = {} nnvm_resnet = {}
for num_layer in [18, 34, 50, 101, 152, 200, 269]: for num_layer in [18, 34, 50, 101, 152, 200, 269]:
mx_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3,224,224') mx_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3,224,224')
nnvm_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3, 224, 224', lib='nnvm') nnvm_resnet[num_layer] = nnvm.testing.resnet.get_workload(
1, _num_class, num_layers=num_layer)[0]
# vgg fc # vgg fc
mx_vgg = {} mx_vgg = {}
nnvm_vgg = {} nnvm_vgg = {}
for num_layer in [11, 13, 16, 19]: for num_layer in [11, 13, 16, 19]:
mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer) mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
nnvm_vgg[num_layer] = vgg.get_symbol_nnvm(_num_class, num_layer) nnvm_vgg[num_layer] = nnvm.testing.vgg.get_workload(
1, _num_class, num_layers=num_layer)[0]
@@ -19,7 +19,6 @@
 a simple multilayer perceptron
 """
 import mxnet as mx
-import nnvm

 def get_symbol(num_classes=10, **kwargs):
     data = mx.symbol.Variable('data')
@@ -31,14 +30,3 @@ def get_symbol(num_classes=10, **kwargs):
     fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
     mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
     return mlp
-
-def get_symbol_nnvm(num_classes=10, **kwargs):
-    data = nnvm.symbol.Variable('data')
-    data = nnvm.sym.flatten(data=data)
-    fc1 = nnvm.symbol.dense(data = data, name='fc1', units=128)
-    act1 = nnvm.symbol.relu(data = fc1, name='relu1')
-    fc2 = nnvm.symbol.dense(data = act1, name = 'fc2', units = 64)
-    act2 = nnvm.symbol.relu(data = fc2, name='relu2')
-    fc3 = nnvm.symbol.dense(data = act2, name='fc3', units=num_classes)
-    mlp = nnvm.symbol.softmax(data = fc3, name = 'softmax')
-    return mlp
@@ -25,7 +25,6 @@ Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Re
 '''
 import mxnet as mx
 import numpy as np
-import nnvm

 def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
     """Return ResNet Unit symbol for building ResNet
@@ -86,65 +85,6 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, b
             shortcut._set_attr(mirror_stage='True')
         return conv2 + shortcut

-def residual_unit_nnvm(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
-    """Return ResNet Unit symbol for building ResNet
-    Parameters
-    ----------
-    data : str
-        Input data
-    num_filter : int
-        Number of output channels
-    bnf : int
-        Bottle neck channels factor with regard to num_filter
-    stride : tuple
-        Stride used in convolution
-    dim_match : Boolean
-        True means channel number between input and output is the same, otherwise means differ
-    name : str
-        Base name of the operators
-    workspace : int
-        Workspace used in convolution operator
-    """
-    if bottle_neck:
-        # the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
-        bn1 = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
-        act1 = nnvm.sym.relu(data=bn1, name=name + '_relu1')
-        conv1 = nnvm.sym.conv2d(data=act1, channels=int(num_filter*0.25), kernel_size=(1,1), strides=(1,1), padding=(0,0),
-                                use_bias=False, name=name + '_conv1')
-        bn2 = nnvm.sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
-        act2 = nnvm.sym.relu(data=bn2, name=name + '_relu2')
-        conv2 = nnvm.sym.conv2d(data=act2, channels=int(num_filter*0.25), kernel_size=(3,3), strides=stride, padding=(1,1),
-                                use_bias=False, name=name + '_conv2')
-        bn3 = nnvm.sym.batch_norm(data=conv2, epsilon=2e-5, name=name + '_bn3')
-        act3 = nnvm.sym.relu(data=bn3, name=name + '_relu3')
-        conv3 = nnvm.sym.conv2d(data=act3, channels=num_filter, kernel_size=(1,1), strides=(1,1), padding=(0,0), use_bias=False,
-                                name=name + '_conv3')
-        if dim_match:
-            shortcut = data
-        else:
-            shortcut = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(1,1), strides=stride, use_bias=False,
-                                       name=name+'_sc')
-        if memonger:
-            shortcut._set_attr(mirror_stage='True')
-        return nnvm.sym.elemwise_add(conv3, shortcut)
-    else:
-        bn1 = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
-        act1 = nnvm.sym.relu(data=bn1, name=name + '_relu1')
-        conv1 = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(3,3), strides=stride, padding=(1,1),
-                                use_bias=False, name=name + '_conv1')
-        bn2 = nnvm.sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
-        act2 = nnvm.sym.relu(data=bn2, name=name + '_relu2')
-        conv2 = nnvm.sym.conv2d(data=act2, channels=num_filter, kernel_size=(3,3), strides=(1,1), padding=(1,1),
-                                use_bias=False, name=name + '_conv2')
-        if dim_match:
-            shortcut = data
-        else:
-            shortcut = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(1,1), strides=stride, use_bias=False,
-                                       name=name+'_sc')
-        if memonger:
-            shortcut._set_attr(mirror_stage='True')
-        return nnvm.sym.elemwise_add(conv2, shortcut)

 def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
     """Return ResNet symbol of
     Parameters
@@ -202,64 +142,7 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck
         fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)
     return mx.sym.softmax(data=fc1, name='softmax')

-def resnet_nnvm(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
-    """Return ResNet symbol of
-    Parameters
-    ----------
-    units : list
-        Number of units in each stage
-    num_stages : int
-        Number of stage
-    filter_list : list
-        Channel size of each stage
-    num_classes : int
-        Ouput size of symbol
-    dataset : str
-        Dataset type, only cifar10 and imagenet supports
-    workspace : int
-        Workspace used in convolution operator
-    dtype : str
-        Precision (float32 or float16)
-    """
-    num_unit = len(units)
-    assert(num_unit == num_stages)
-    data = nnvm.sym.Variable(name='data')
-    if dtype == 'float32':
-        # data = nnvm.sym.identity(data=data, name='id')
-        data = data
-    else:
-        if dtype == 'float16':
-            data = nnvm.sym.cast(data=data, dtype=np.float16)
-    data = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name='bn_data')
-    (nchannel, height, width) = image_shape
-    if height <= 32:            # such as cifar10
-        body = nnvm.sym.conv2d(data=data, channels=filter_list[0], kernel_size=(3, 3), strides=(1,1), padding=(1, 1),
-                               use_bias=False, name="conv0")
-    else:                       # often expected to be 224 such as imagenet
-        body = nnvm.sym.conv2d(data=data, channels=filter_list[0], kernel_size=(7, 7), strides=(2,2), padding=(3, 3),
-                               use_bias=False, name="conv0")
-        body = nnvm.sym.batch_norm(data=body, epsilon=2e-5, name='bn0')
-        body = nnvm.sym.relu(data=body, name='relu0')
-        body = nnvm.sym.max_pool2d(data=body, pool_size=(3, 3), strides=(2,2), padding=(1,1))
-    for i in range(num_stages):
-        body = residual_unit_nnvm(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
-                                  name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck,
-                                  memonger=memonger)
-        for j in range(units[i]-1):
-            body = residual_unit_nnvm(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
-                                      bottle_neck=bottle_neck, memonger=memonger)
-    bn1 = nnvm.sym.batch_norm(data=body, epsilon=2e-5, name='bn1')
-    relu1 = nnvm.sym.relu(data=bn1, name='relu1')
-    # Although kernel is not used here when global_pool=True, we should put one
-    pool1 = nnvm.sym.global_avg_pool2d(data=relu1, name='pool1')
-    flat = nnvm.sym.flatten(data=pool1)
-    fc1 = nnvm.sym.dense(data=flat, units=num_classes, name='fc1')
-    if dtype == 'float16':
-        fc1 = nnvm.sym.cast(data=fc1, dtype=np.float32)
-    return nnvm.sym.softmax(data=fc1, name='softmax')

-def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', lib='mxnet', **kwargs):
+def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', **kwargs):
     """
     Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
     Original author Wei Wu
@@ -311,12 +194,4 @@ def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='
                   image_shape = image_shape,
                   bottle_neck = bottle_neck,
                   workspace = conv_workspace,
-                  dtype = dtype) if lib == 'mxnet' else \
-           resnet_nnvm(units = units,
-                       num_stages = num_stages,
-                       filter_list = filter_list,
-                       num_classes = num_classes,
-                       image_shape = image_shape,
-                       bottle_neck = bottle_neck,
-                       workspace = conv_workspace,
-                       dtype = dtype)
+                  dtype = dtype)
@@ -22,7 +22,6 @@ large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
 """
 import mxnet as mx
-import nnvm
 import numpy as np

 def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
@@ -35,16 +34,6 @@ def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
         internel_layer = mx.sym.Pooling(data=internel_layer, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool%s" %(i + 1))
     return internel_layer

-def get_feature_nnvm(internel_layer, layers, filters, batch_norm = False, **kwargs):
-    for i, num in enumerate(layers):
-        for j in range(num):
-            internel_layer = nnvm.sym.conv2d(data = internel_layer, kernel_size=(3, 3), padding=(1, 1), channels=filters[i], name="conv%s_%s" %(i + 1, j + 1))
-            if batch_norm:
-                internel_layer = nnvm.symbol.batch_norm(data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
-            internel_layer = nnvm.sym.relu(data=internel_layer, name="relu%s_%s" %(i + 1, j + 1))
-        internel_layer = nnvm.sym.max_pool2d(data=internel_layer, pool_size=(2, 2), strides=(2,2), name="pool%s" %(i + 1))
-    return internel_layer

 def get_classifier(input_data, num_classes, **kwargs):
     flatten = mx.sym.Flatten(data=input_data, name="flatten")
     fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
@@ -56,17 +45,6 @@ def get_classifier(input_data, num_classes, **kwargs):
     fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
     return fc8

-def get_classifier_nnvm(input_data, num_classes, **kwargs):
-    flatten = nnvm.sym.flatten(data=input_data, name="flatten")
-    fc6 = nnvm.sym.dense(data=flatten, units=4096, name="fc6")
-    relu6 = nnvm.sym.relu(data=fc6, name="relu6")
-    drop6 = nnvm.sym.dropout(data=relu6, rate=0.5, name="drop6")
-    fc7 = nnvm.sym.dense(data=drop6, units=4096, name="fc7")
-    relu7 = nnvm.sym.relu(data=fc7, name="relu7")
-    drop7 = nnvm.sym.dropout(data=relu7, rate=0.5, name="drop7")
-    fc8 = nnvm.sym.dense(data=drop7, units=num_classes, name="fc8")
-    return fc8

 def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
     """
     Parameters
@@ -96,33 +74,3 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **
     classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
     symbol = mx.sym.softmax(data=classifier, name='softmax')
     return symbol
-
-def get_symbol_nnvm(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
-    """
-    Parameters
-    ----------
-    num_classes : int, default 1000
-        Number of classification classes.
-    num_layers : int
-        Number of layers for the variant of densenet. Options are 11, 13, 16, 19.
-    batch_norm : bool, default False
-        Use batch normalization.
-    dtype: str, float32 or float16
-        Data precision.
-    """
-    vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
-                13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
-                16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
-                19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
-    if not vgg_spec.has_key(num_layers):
-        raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
-    layers, filters = vgg_spec[num_layers]
-    data = nnvm.sym.Variable(name="data")
-    if dtype == 'float16':
-        data = nnvm.sym.cast(data=data, dtype=np.float16)
-    feature = get_feature_nnvm(data, layers, filters, batch_norm)
-    classifier = get_classifier_nnvm(feature, num_classes)
-    if dtype == 'float16':
-        classifier = nnvm.sym.cast(data=classifier, dtype=np.float32)
-    symbol = nnvm.sym.softmax(data=classifier, name='softmax')
-    return symbol
@@ -14,21 +14,21 @@ def compare_graph(sym1, sym2, ishape=(2, 3, 224, 224)):
 def test_mlp():
     mx_sym = model_zoo.mx_mlp
-    from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
+    from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
     nnvm_sym = model_zoo.nnvm_mlp
     compare_graph(from_mx_sym, nnvm_sym)

 def test_vgg():
     for n in [11, 13, 16, 19]:
         mx_sym = model_zoo.mx_vgg[n]
-        from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
+        from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
         nnvm_sym = model_zoo.nnvm_vgg[n]
         compare_graph(from_mx_sym, nnvm_sym)

 def test_resnet():
     for n in [18, 34, 50, 101]:
         mx_sym = model_zoo.mx_resnet[n]
-        from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
+        from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
         nnvm_sym = model_zoo.nnvm_resnet[n]
         compare_graph(from_mx_sym, nnvm_sym)
...
"""Store for onnx examples and common models."""
from __future__ import absolute_import as _abs
import os
from .super_resolution import get_super_resolution
__all__ = ['super_resolution']
def _as_abs_path(fname):
cur_dir = os.path.abspath(os.path.dirname(__file__))
return os.path.join(cur_dir, fname)
# a pair of onnx pb file and corresponding nnvm symbol
super_resolution = (_as_abs_path('super_resolution.onnx'), get_super_resolution())
"""NNVM symbol corresponding to super_resolution.onnx example."""
from nnvm import sym
def get_super_resolution_deprecated():
factor = 3
size = 224
data = sym.Variable(name='9')
conv1 = sym.conv2d(data, channels=64, kernel_size=(5, 5), padding=(2, 2))
relu1 = sym.relu(conv1)
conv2 = sym.conv2d(relu1, channels=64, kernel_size=(3, 3), padding=(1, 1))
relu2 = sym.relu(conv2)
conv3 = sym.conv2d(relu2, channels=32, kernel_size=(3, 3), padding=(1, 1))
relu3 = sym.relu(conv3)
conv4 = sym.conv2d(relu3, channels=factor**2, kernel_size=(3, 3), padding=(1, 1))
r1 = sym.reshape(conv4, shape=(0, 1, factor, factor, size, size))
t1 = sym.transpose(r1, axes=(0, 1, 4, 2, 5, 3))
r2 = sym.reshape(t1, shape=(0, 1, size * factor, size * factor))
return r2
def get_super_resolution():
factor = 3
size = 224
data = sym.Variable(name='9')
conv1 = sym.conv2d(data, channels=64, kernel_size=(5, 5), padding=(2, 2), use_bias=False)
relu1 = sym.relu(conv1 + sym.Variable(name='2'))
conv2 = sym.conv2d(relu1, channels=64, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
relu2 = sym.relu(conv2 + sym.Variable(name='4'))
conv3 = sym.conv2d(relu2, channels=32, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
relu3 = sym.relu(conv3 + sym.Variable(name='6'))
conv4 = sym.conv2d(relu3, channels=factor**2, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
conv4 = conv4 + sym.Variable(name='8')
# TODO(zhreshold): allow shape inference for batch size > 1
r1 = sym.reshape(conv4, shape=(1, 1, factor, factor, size, size))
t1 = sym.transpose(r1, axes=(0, 1, 4, 2, 5, 3))
r2 = sym.reshape(t1, shape=(1, 1, size * factor, size * factor))
return r2
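
The reshape/transpose/reshape tail implements depth-to-space (pixel shuffle): the factor**2 output channels of conv4 are rearranged into a factor x factor spatial block per pixel, upscaling 224x224 to 672x672. A hedged numpy sketch of the same rearrangement, with the spatial size shrunk for illustration:

# Hedged sketch of the depth-to-space rearrangement above (size 4, not 224).
import numpy as np

factor, size = 3, 4
x = np.arange(factor**2 * size * size).reshape(1, factor**2, size, size)
r1 = x.reshape(1, 1, factor, factor, size, size)
t1 = r1.transpose(0, 1, 4, 2, 5, 3)   # -> (N, C, H, factor, W, factor)
out = t1.reshape(1, 1, size * factor, size * factor)
assert out.shape == (1, 1, 12, 12)    # 3x spatial upscaling, single channel
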
import numpy as np
import nnvm
import tvm
from tvm.contrib import graph_runtime
from nnvm.testing.config import ctx_list
import onnx
import onnx_caffe2.backend
from model_zoo import super_resolution


def test_onnx_forward_impl(graph_file, data_shape, out_shape):
    def get_caffe2_output(graph, x, dtype='float32'):
        prepared_backend = onnx_caffe2.backend.prepare(graph)
        W = {graph.input[-1]: x.astype(dtype)}
        c2_out = prepared_backend.run(W)[0]
        return c2_out

    def get_tvm_output(graph, x, target, ctx, dtype='float32'):
        new_sym, params = nnvm.frontend.from_onnx(graph)
        shape_dict = {'input_0': x.shape}
        graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
        m = graph_runtime.create(graph, lib, ctx)
        # set inputs
        m.set_input('input_0', tvm.nd.array(x.astype(dtype)))
        m.set_input(**params)
        m.run()
        # get outputs
        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
        return out.asnumpy()

    dtype = 'float32'
    x = np.random.uniform(size=data_shape)
    graph = onnx.load(graph_file)
    c2_out = get_caffe2_output(graph, x, dtype)
    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(graph, x, target, ctx, dtype)
        np.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)


def test_super_resolution_example():
    test_onnx_forward_impl(super_resolution[0], (1, 1, 224, 224), (1, 1, 672, 672))


if __name__ == '__main__':
    test_super_resolution_example()
"""Test graph equality of onnx models."""
import nnvm
import onnx
from nnvm.compiler import graph_util, graph_attr
from model_zoo import super_resolution
def compare_graph(onnx_file, nnvm_sym, ishape):
onnx_graph = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph)
g1 = nnvm.graph.create(onnx_sym)
g2 = nnvm.graph.create(nnvm_sym)
ishapes = {'input_0': ishape}
graph_attr.set_shape_inputs(g1, ishapes)
graph_attr.set_shape_inputs(g2, ishapes)
g1 = g1.apply("InferShape").apply("SimplifyInference")
g2 = g2.apply("InferShape").apply("SimplifyInference")
graph_util.check_graph_equal(g1, g2)
def test_super_resolution_example():
fname, symbol = super_resolution
compare_graph(fname, symbol, ishape=(1, 1, 224, 224))
if __name__ == '__main__':
test_super_resolution_example()