Commit 4f664f5b by Joshua Z. Zhang Committed by Tianqi Chen

[Frontend] Onnx (#40)

* init onnx

finish onnx frontend

add onnx tests

fix various

backup

use transformer

[Frontend] graph passed

add test forward

test forward

fix doc and lint

fix test graph tuple

from_onnx now takes 2 args, outputs (sym, params)

fix rename

fix input names

fix multiple

fix lint

fix lint check

* better doc
parent dddd8d1a
"""NNVM frontends."""
from __future__ import absolute_import
from .mxnet import from_mxnet
from .onnx import from_onnx
"""Shared functions and classes for frontends."""
from __future__ import absolute_import as _abs
import warnings
from .._base import string_types
class Renamer(object):
"""A simply renamer for operators.
Parameters
----------
new_name : str
The new name for the operator
"""
def __init__(self, new_name):
self._new_name = new_name
def __call__(self, attrs):
return self._new_name, attrs
class AttrConverter(object):
"""Common attribute conveter. An AttrConverter instance is a callable:
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
```
Parameters
----------
op_name : str or callable
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name` or `(new_name, default_value, transform_function)`
    If only a new_name is provided, the attribute is simply renamed.
    If a default_value is provided, the attribute is treated as optional.
    If a transform function is provided, the original attribute value is
    passed through that function.
excludes : list
    A list of attributes that must `NOT` appear. A NotImplementedError is
    raised if one occurs.
disables : list
    A list of attributes that are disabled in nnvm. A warning is raised.
ignores : list
    A list of attributes that are silently ignored in nnvm.
extras : dict
    Additional attributes that are always added to the returned
    attribute dict.
custom_check : (callable, str)
    A custom check function together with an error message. The function
    takes the attribute dict and returns True/False; a RuntimeError with
    the message is raised if the check fails.
"""
def __init__(self, op_name, transforms=None,
excludes=None, disables=None, ignores=None,
extras=None, custom_check=None):
self._op_name = op_name
self._transforms = transforms if transforms else {}
self._excludes = excludes if excludes else []
self._disables = disables if disables else []
self._ignores = ignores if ignores else []
self._extras = extras if extras else {}
self._custom_check = custom_check
def __call__(self, attrs):
# apply custom check
if self._custom_check:
func, msg = self._custom_check
if not func(attrs):
raise RuntimeError("Check failed: {}".format(msg))
# get new op_name
if isinstance(self._op_name, string_types):
op_name = self._op_name
else:
assert callable(self._op_name), "op_name can either be string or callable"
op_name = self._op_name(attrs)
# convert attributes
new_attrs = {}
for k in attrs.keys():
if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k))
elif k in self._disables:
warnings.warn("Attribute {} is disabled in nnvm.sym.{}".format(k, op_name))
elif k in self._ignores:
pass
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
new_attr = self._required_attr(attrs, k)
else:
new_attr = attrs.get(k, None)
if new_attr is None:
new_attrs[new_name] = defaults
else:
new_attrs[new_name] = transform(new_attr)
else:
# copy
new_attrs[k] = attrs[k]
# add extras
new_attrs.update(self._extras)
return op_name, new_attrs
def _parse_default(self, target):
"""Helper function to parse default values."""
if not isinstance(target, (list, tuple)):
k, v, t = target, None, lambda x: x
elif len(target) == 1:
k, v, t = target[0], None, lambda x: x
elif len(target) == 2:
k, v, t = target[0], target[1], lambda x: x
elif len(target) > 2:
k, v, t = target[0], target[1], target[2]
else:
k = None # should raise
if not isinstance(k, string_types):
msg = "{} is not a valid target, (name, default) expected.".format(target)
raise ValueError(msg)
return k, v, t
def _parse_bool(self, value):
"""Helper function to parse default boolean values."""
if isinstance(value, string_types):
return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
return bool(value)
def _required_attr(self, attr, key):
"""Wrapper for getting required attributes."""
assert isinstance(attr, dict)
if key not in attr:
raise AttributeError("Required attribute {} not found.".format(key))
return attr[key]
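# A minimal usage sketch for the two helpers above. The operator and
# attribute names are illustrative only, not taken from an actual
# frontend table:
#
#   rename = Renamer('broadcast_add')
#   rename({'axis': 1})          # -> ('broadcast_add', {'axis': 1})
#
#   convert = AttrConverter(
#       'max_pool2d',
#       transforms={
#           'kernel_shape': 'pool_size',            # plain rename
#           'pads': ('padding', (0, 0)),            # default used if value is None
#           'strides': ('strides', (1, 1), tuple),  # rename + default + transform
#       },
#       ignores=['order'])
#   convert({'kernel_shape': (2, 2), 'strides': [2, 2], 'pads': None,
#            'order': 'NCHW'})
#   # -> ('max_pool2d', {'pool_size': (2, 2), 'strides': (2, 2),
#   #                    'padding': (0, 0)})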
@@ -58,12 +58,12 @@ def _pooling(attrs):
def _batch_norm(attrs):
if _parse_bool_str(attrs, 'output_mean_var'):
_raise_not_supported('output_mean_var', 'batch_norm')
-    if _parse_bool_str(attrs, 'fix_gamma'):
-        _warn_not_used('fix_gamma', 'batch_norm')
+    # if _parse_bool_str(attrs, 'fix_gamma'):
+    #     _warn_not_used('fix_gamma', 'batch_norm')
     if _parse_bool_str(attrs, 'use_global_stats'):
         _warn_not_used('use_global_stats', 'batch_norm')
-    if _parse_bool_str(attrs, 'momentum'):
-        _warn_not_used('momentum', 'batch_norm')
+    # if _parse_bool_str(attrs, 'momentum'):
+    #     _warn_not_used('momentum', 'batch_norm')
op_name, new_attrs = 'batch_norm', {}
new_attrs['axis'] = attrs.get('axis', 1)
new_attrs['epsilon'] = attrs.get('eps', 0.001)
......
@@ -2,4 +2,8 @@
from __future__ import absolute_import as _abs
from .config import ctx_list
from .utils import create_workload
from . import mobilenet
from . import mlp
from . import resnet
from . import vgg
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
a simple multilayer perceptron
"""
from .. import symbol as sym
from .utils import create_workload
def get_symbol(num_classes=1000):
data = sym.Variable('data')
data = sym.flatten(data=data)
fc1 = sym.dense(data=data, name='fc1', units=128)
act1 = sym.relu(data=fc1, name='relu1')
fc2 = sym.dense(data=act1, name='fc2', units=64)
act2 = sym.relu(data=fc2, name='relu2')
fc3 = sym.dense(data=act2, name='fc3', units=num_classes)
mlp = sym.softmax(data=fc3, name='softmax')
return mlp
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"):
"""Get benchmark workload for a simple multilayer perceptron
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes)
return create_workload(net, batch_size, image_shape, dtype)
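# Hypothetical usage (the paths assume this file lands as nnvm/testing/mlp.py):
#   net, params = nnvm.testing.mlp.get_workload(batch_size=1, num_classes=10)
# `net` is the softmax-terminated symbol above; `params` maps names such as
# 'fc1_weight' to randomly initialized tvm.nd.array values from create_workload.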
@@ -2,11 +2,8 @@
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
-import numpy as np
-import tvm
-from .. compiler import graph_util
-from .. import graph
 from .. import symbol as sym
+from .utils import create_workload
def conv_block(data, name, channels,
kernel_size=(3, 3), strides=(1, 1), padding=(1, 1),
@@ -104,22 +101,5 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype=
params : dict of str to NDArray
The parameters.
"""
-    image_shape = (3, 224, 224)
-    data_shape = (batch_size,) + image_shape
     net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False)
-    params = {}
-    g = graph.create(net)
-    input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
-    shape_dict = dict(zip(g.index.input_names, input_shapes))
-    for k, v in shape_dict.items():
-        if k == "data":
-            continue
-        # Specially generate non-negative parameters.
-        if k.endswith("gamma"):
-            init = np.random.uniform(0.9, 1, size=v)
-        elif k.endswith("var"):
-            init = np.random.uniform(0.9, 1, size=v)
-        else:
-            init = np.random.uniform(-0.1, 0.1, size=v)
-        params[k] = tvm.nd.array(init.astype(dtype), ctx=tvm.cpu(0))
-    return net, params
+    return create_workload(net, batch_size, image_shape, dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
'''
Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
Original author Wei Wu
Implemented the following paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
import numpy as np
from .. import symbol as sym
from .utils import create_workload
def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True):
"""Return ResNet Unit symbol for building ResNet
Parameters
----------
data : str
Input data
num_filter : int
Number of output channels
bnf : int
Bottle neck channels factor with regard to num_filter
stride : tuple
Stride used in convolution
dim_match : Boolean
True if the numbers of input and output channels match; otherwise they differ
name : str
Base name of the operators
"""
if bottle_neck:
# the same as https://github.com/facebook/fb.resnet.torch#notes,
# a bit difference with origin paper
bn1 = sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = sym.relu(data=bn1, name=name + '_relu1')
conv1 = sym.conv2d(
data=act1, channels=int(num_filter*0.25), kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), use_bias=False, name=name + '_conv1')
bn2 = sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = sym.relu(data=bn2, name=name + '_relu2')
conv2 = sym.conv2d(
data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3),
strides=stride, padding=(1, 1), use_bias=False, name=name + '_conv2')
bn3 = sym.batch_norm(data=conv2, epsilon=2e-5, name=name + '_bn3')
act3 = sym.relu(data=bn3, name=name + '_relu3')
conv3 = sym.conv2d(
data=act3, channels=num_filter, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), use_bias=False, name=name + '_conv3')
if dim_match:
shortcut = data
else:
shortcut = sym.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, use_bias=False, name=name+'_sc')
return sym.elemwise_add(conv3, shortcut)
else:
bn1 = sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = sym.relu(data=bn1, name=name + '_relu1')
conv1 = sym.conv2d(
data=act1, channels=num_filter, kernel_size=(3, 3),
strides=stride, padding=(1, 1), use_bias=False, name=name + '_conv1')
bn2 = sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = sym.relu(data=bn2, name=name + '_relu2')
conv2 = sym.conv2d(
data=act2, channels=num_filter, kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), use_bias=False, name=name + '_conv2')
if dim_match:
shortcut = data
else:
shortcut = sym.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, use_bias=False, name=name+'_sc')
return sym.elemwise_add(conv2, shortcut)
def resnet(units, num_stages, filter_list, num_classes, image_shape,
bottle_neck=True, dtype='float32'):
"""Return ResNet symbol of
Parameters
----------
units : list
Number of units in each stage
num_stages : int
Number of stage
filter_list : list
Channel size of each stage
num_classes : int
Output size of symbol
dataset : str
Dataset type; only cifar10 and imagenet are supported
dtype : str
Precision (float32 or float16)
"""
num_unit = len(units)
assert num_unit == num_stages
data = sym.Variable(name='data')
    if dtype == 'float16':
        data = sym.cast(data=data, dtype=np.float16)
data = sym.batch_norm(data=data, epsilon=2e-5, name='bn_data')
(_, height, _) = image_shape
if height <= 32: # such as cifar10
body = sym.conv2d(
data=data, channels=filter_list[0], kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), use_bias=False, name="conv0")
else: # often expected to be 224 such as imagenet
body = sym.conv2d(
data=data, channels=filter_list[0], kernel_size=(7, 7),
strides=(2, 2), padding=(3, 3), use_bias=False, name="conv0")
body = sym.batch_norm(data=body, epsilon=2e-5, name='bn0')
body = sym.relu(data=body, name='relu0')
body = sym.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
for i in range(num_stages):
body = residual_unit(
body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2),
False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck)
for j in range(units[i]-1):
body = residual_unit(
body, filter_list[i+1], (1, 1), True,
name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck)
bn1 = sym.batch_norm(data=body, epsilon=2e-5, name='bn1')
relu1 = sym.relu(data=bn1, name='relu1')
    # No pooling kernel size is needed here: global_avg_pool2d pools over the entire spatial extent.
pool1 = sym.global_avg_pool2d(data=relu1, name='pool1')
flat = sym.flatten(data=pool1)
fc1 = sym.dense(data=flat, units=num_classes, name='fc1')
if dtype == 'float16':
fc1 = sym.cast(data=fc1, dtype=np.float32)
return sym.softmax(data=fc1, name='softmax')
def get_symbol(num_classes, num_layers=50, image_shape=(3, 224, 224), dtype='float32'):
"""
Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
Original author Wei Wu
"""
(_, height, _) = image_shape
if height <= 28:
num_stages = 3
if (num_layers-2) % 9 == 0 and num_layers >= 164:
per_unit = [(num_layers-2)//9]
filter_list = [16, 64, 128, 256]
bottle_neck = True
elif (num_layers-2) % 6 == 0 and num_layers < 164:
per_unit = [(num_layers-2)//6]
filter_list = [16, 16, 32, 64]
bottle_neck = False
else:
raise ValueError("no experiments done on num_layers {}".format(num_layers))
units = per_unit * num_stages
else:
if num_layers >= 50:
filter_list = [64, 256, 512, 1024, 2048]
bottle_neck = True
else:
filter_list = [64, 64, 128, 256, 512]
bottle_neck = False
num_stages = 4
if num_layers == 18:
units = [2, 2, 2, 2]
elif num_layers == 34:
units = [3, 4, 6, 3]
elif num_layers == 50:
units = [3, 4, 6, 3]
elif num_layers == 101:
units = [3, 4, 23, 3]
elif num_layers == 152:
units = [3, 8, 36, 3]
elif num_layers == 200:
units = [3, 24, 36, 3]
elif num_layers == 269:
units = [3, 30, 48, 8]
else:
raise ValueError("no experiments done on num_layers {}".format(num_layers))
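    # Depth check (our annotation): e.g. num_layers=50 selects
    # units=[3, 4, 6, 3]; with 3 convs per bottleneck unit that is
    # 3 * (3+4+6+3) = 48 conv layers, plus conv0 and fc1 = 50.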
return resnet(units=units,
num_stages=num_stages,
filter_list=filter_list,
num_classes=num_classes,
image_shape=image_shape,
bottle_neck=bottle_neck,
dtype=dtype)
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224),
dtype="float32", **kwargs):
"""Get benchmark workload for resnet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
kwargs : dict
Extra arguments
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, image_shape=image_shape,
dtype=dtype, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
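# Hypothetical usage: an 18-layer ImageNet ResNet workload for one image
# (num_layers is forwarded to get_symbol via **kwargs):
#   net, params = get_workload(batch_size=1, num_layers=18)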
"""Helper utility to create common workload for testing."""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from ..compiler import graph_util
from .. import graph
def create_workload(net, batch_size, image_shape=(3, 224, 224), dtype="float32"):
"""Helper function to create benchmark workload for input network
Parameters
----------
net : nnvm.Symbol
The selected network symbol to use
batch_size : int
The batch size used in the model
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
data_shape = (batch_size,) + image_shape
params = {}
g = graph.create(net)
input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
shape_dict = dict(zip(g.index.input_names, input_shapes))
for k, v in shape_dict.items():
if k == "data":
continue
# Specially generate non-negative parameters.
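        # (Our annotation, an assumption about intent: SimplifyInference
        # folds batch_norm into gamma / sqrt(var + eps), so gamma and var
        # are drawn near 1 and strictly positive to keep the folded scale
        # well-behaved.)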
if k.endswith("gamma"):
init = np.random.uniform(0.9, 1, size=v)
elif k.endswith("var"):
init = np.random.uniform(0.9, 1, size=v)
else:
init = np.random.uniform(-0.1, 0.1, size=v)
params[k] = tvm.nd.array(init.astype(dtype), ctx=tvm.cpu(0))
return net, params
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""References:
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""
import numpy as np
from .. import symbol as sym
from .utils import create_workload
def get_feature(internel_layer, layers, filters, batch_norm=False):
"""Get VGG feature body as stacks of convoltions."""
for i, num in enumerate(layers):
for j in range(num):
internel_layer = sym.conv2d(
data=internel_layer, kernel_size=(3, 3), padding=(1, 1),
channels=filters[i], name="conv%s_%s"%(i + 1, j + 1))
if batch_norm:
internel_layer = sym.batch_norm(
data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
internel_layer = sym.relu(data=internel_layer, name="relu%s_%s" %(i + 1, j + 1))
internel_layer = sym.max_pool2d(
data=internel_layer, pool_size=(2, 2), strides=(2, 2), name="pool%s"%(i + 1))
return internel_layer
def get_classifier(input_data, num_classes):
"""Get VGG classifier layers as fc layers."""
flatten = sym.flatten(data=input_data, name="flatten")
fc6 = sym.dense(data=flatten, units=4096, name="fc6")
relu6 = sym.relu(data=fc6, name="relu6")
drop6 = sym.dropout(data=relu6, rate=0.5, name="drop6")
fc7 = sym.dense(data=drop6, units=4096, name="fc7")
relu7 = sym.relu(data=fc7, name="relu7")
drop7 = sym.dropout(data=relu7, rate=0.5, name="drop7")
fc8 = sym.dense(data=drop7, units=num_classes, name="fc8")
return fc8
def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32'):
"""
Parameters
----------
num_classes : int, default 1000
Number of classification classes.
num_layers : int
Number of layers for the variant of VGG. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
dtype: str, float32 or float16
Data precision.
"""
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
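    # e.g. num_layers=16: 2+2+3+3+3 = 13 conv layers plus fc6/fc7/fc8 = 16.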
    if num_layers not in vgg_spec:
        raise ValueError("Invalid num_layers {}. Choices are 11, 13, 16, 19.".format(num_layers))
layers, filters = vgg_spec[num_layers]
data = sym.Variable(name="data")
if dtype == 'float16':
data = sym.cast(data=data, dtype=np.float16)
feature = get_feature(data, layers, filters, batch_norm)
classifier = get_classifier(feature, num_classes)
if dtype == 'float16':
classifier = sym.cast(data=classifier, dtype=np.float32)
symbol = sym.softmax(data=classifier, name='softmax')
return symbol
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224),
dtype="float32", **kwargs):
"""Get benchmark workload for VGG nets.
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
kwargs : dict
Extra arguments
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, dtype=dtype, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
"""MXNet and NNVM model zoo."""
from __future__ import absolute_import
from . import mlp, resnet, vgg
import nnvm.testing
__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg']
_num_class = 1000
# mlp fc
mx_mlp = mlp.get_symbol(_num_class)
-nnvm_mlp = mlp.get_symbol_nnvm(_num_class)
+nnvm_mlp = nnvm.testing.mlp.get_workload(1, _num_class)[0]
# resnet fc
mx_resnet = {}
nnvm_resnet = {}
for num_layer in [18, 34, 50, 101, 152, 200, 269]:
mx_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3,224,224')
-    nnvm_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3, 224, 224', lib='nnvm')
+    nnvm_resnet[num_layer] = nnvm.testing.resnet.get_workload(
+        1, _num_class, num_layers=num_layer)[0]
# vgg fc
mx_vgg = {}
nnvm_vgg = {}
for num_layer in [11, 13, 16, 19]:
mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
-    nnvm_vgg[num_layer] = vgg.get_symbol_nnvm(_num_class, num_layer)
+    nnvm_vgg[num_layer] = nnvm.testing.vgg.get_workload(
+        1, _num_class, num_layers=num_layer)[0]
@@ -19,7 +19,6 @@
a simple multilayer perceptron
"""
import mxnet as mx
import nnvm
def get_symbol(num_classes=10, **kwargs):
data = mx.symbol.Variable('data')
@@ -31,14 +30,3 @@ def get_symbol(num_classes=10, **kwargs):
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
return mlp
def get_symbol_nnvm(num_classes=10, **kwargs):
data = nnvm.symbol.Variable('data')
data = nnvm.sym.flatten(data=data)
fc1 = nnvm.symbol.dense(data = data, name='fc1', units=128)
act1 = nnvm.symbol.relu(data = fc1, name='relu1')
fc2 = nnvm.symbol.dense(data = act1, name = 'fc2', units = 64)
act2 = nnvm.symbol.relu(data = fc2, name='relu2')
fc3 = nnvm.symbol.dense(data = act2, name='fc3', units=num_classes)
mlp = nnvm.symbol.softmax(data = fc3, name = 'softmax')
return mlp
@@ -25,7 +25,6 @@ Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Re
'''
import mxnet as mx
import numpy as np
import nnvm
def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
"""Return ResNet Unit symbol for building ResNet
@@ -86,65 +85,6 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, b
shortcut._set_attr(mirror_stage='True')
return conv2 + shortcut
def residual_unit_nnvm(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
"""Return ResNet Unit symbol for building ResNet
Parameters
----------
data : str
Input data
num_filter : int
Number of output channels
bnf : int
Bottle neck channels factor with regard to num_filter
stride : tuple
Stride used in convolution
dim_match : Boolean
True means channel number between input and output is the same, otherwise means differ
name : str
Base name of the operators
workspace : int
Workspace used in convolution operator
"""
if bottle_neck:
# the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
bn1 = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = nnvm.sym.relu(data=bn1, name=name + '_relu1')
conv1 = nnvm.sym.conv2d(data=act1, channels=int(num_filter*0.25), kernel_size=(1,1), strides=(1,1), padding=(0,0),
use_bias=False, name=name + '_conv1')
bn2 = nnvm.sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = nnvm.sym.relu(data=bn2, name=name + '_relu2')
conv2 = nnvm.sym.conv2d(data=act2, channels=int(num_filter*0.25), kernel_size=(3,3), strides=stride, padding=(1,1),
use_bias=False, name=name + '_conv2')
bn3 = nnvm.sym.batch_norm(data=conv2, epsilon=2e-5, name=name + '_bn3')
act3 = nnvm.sym.relu(data=bn3, name=name + '_relu3')
conv3 = nnvm.sym.conv2d(data=act3, channels=num_filter, kernel_size=(1,1), strides=(1,1), padding=(0,0), use_bias=False,
name=name + '_conv3')
if dim_match:
shortcut = data
else:
shortcut = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(1,1), strides=stride, use_bias=False,
name=name+'_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return nnvm.sym.elemwise_add(conv3, shortcut)
else:
bn1 = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = nnvm.sym.relu(data=bn1, name=name + '_relu1')
conv1 = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(3,3), strides=stride, padding=(1,1),
use_bias=False, name=name + '_conv1')
bn2 = nnvm.sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = nnvm.sym.relu(data=bn2, name=name + '_relu2')
conv2 = nnvm.sym.conv2d(data=act2, channels=num_filter, kernel_size=(3,3), strides=(1,1), padding=(1,1),
use_bias=False, name=name + '_conv2')
if dim_match:
shortcut = data
else:
shortcut = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(1,1), strides=stride, use_bias=False,
name=name+'_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return nnvm.sym.elemwise_add(conv2, shortcut)
def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
"""Return ResNet symbol of
Parameters
@@ -202,64 +142,7 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck
fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)
return mx.sym.softmax(data=fc1, name='softmax')
def resnet_nnvm(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
"""Return ResNet symbol of
Parameters
----------
units : list
Number of units in each stage
num_stages : int
Number of stage
filter_list : list
Channel size of each stage
num_classes : int
Ouput size of symbol
dataset : str
Dataset type, only cifar10 and imagenet supports
workspace : int
Workspace used in convolution operator
dtype : str
Precision (float32 or float16)
"""
num_unit = len(units)
assert(num_unit == num_stages)
data = nnvm.sym.Variable(name='data')
if dtype == 'float32':
# data = nnvm.sym.identity(data=data, name='id')
data = data
else:
if dtype == 'float16':
data = nnvm.sym.cast(data=data, dtype=np.float16)
data = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name='bn_data')
(nchannel, height, width) = image_shape
if height <= 32: # such as cifar10
body = nnvm.sym.conv2d(data=data, channels=filter_list[0], kernel_size=(3, 3), strides=(1,1), padding=(1, 1),
use_bias=False, name="conv0")
else: # often expected to be 224 such as imagenet
body = nnvm.sym.conv2d(data=data, channels=filter_list[0], kernel_size=(7, 7), strides=(2,2), padding=(3, 3),
use_bias=False, name="conv0")
body = nnvm.sym.batch_norm(data=body, epsilon=2e-5, name='bn0')
body = nnvm.sym.relu(data=body, name='relu0')
body = nnvm.sym.max_pool2d(data=body, pool_size=(3, 3), strides=(2,2), padding=(1,1))
for i in range(num_stages):
body = residual_unit_nnvm(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck,
memonger=memonger)
for j in range(units[i]-1):
body = residual_unit_nnvm(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
bottle_neck=bottle_neck, memonger=memonger)
bn1 = nnvm.sym.batch_norm(data=body, epsilon=2e-5, name='bn1')
relu1 = nnvm.sym.relu(data=bn1, name='relu1')
# Although kernel is not used here when global_pool=True, we should put one
pool1 = nnvm.sym.global_avg_pool2d(data=relu1, name='pool1')
flat = nnvm.sym.flatten(data=pool1)
fc1 = nnvm.sym.dense(data=flat, units=num_classes, name='fc1')
if dtype == 'float16':
fc1 = nnvm.sym.cast(data=fc1, dtype=np.float32)
return nnvm.sym.softmax(data=fc1, name='softmax')
-def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', lib='mxnet', **kwargs):
+def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', **kwargs):
"""
Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
Original author Wei Wu
@@ -311,12 +194,4 @@ def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='
image_shape = image_shape,
bottle_neck = bottle_neck,
workspace = conv_workspace,
-                  dtype = dtype) if lib == 'mxnet' else \
-           resnet_nnvm(units = units,
-                  num_stages = num_stages,
-                  filter_list = filter_list,
-                  num_classes = num_classes,
-                  image_shape = image_shape,
-                  bottle_neck = bottle_neck,
-                  workspace = conv_workspace,
-                  dtype = dtype)
+                  dtype = dtype)
@@ -22,7 +22,6 @@ large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""
import mxnet as mx
import nnvm
import numpy as np
def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
@@ -35,16 +34,6 @@ def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
internel_layer = mx.sym.Pooling(data=internel_layer, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool%s" %(i + 1))
return internel_layer
def get_feature_nnvm(internel_layer, layers, filters, batch_norm = False, **kwargs):
for i, num in enumerate(layers):
for j in range(num):
internel_layer = nnvm.sym.conv2d(data = internel_layer, kernel_size=(3, 3), padding=(1, 1), channels=filters[i], name="conv%s_%s" %(i + 1, j + 1))
if batch_norm:
internel_layer = nnvm.symbol.batch_norm(data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
internel_layer = nnvm.sym.relu(data=internel_layer, name="relu%s_%s" %(i + 1, j + 1))
internel_layer = nnvm.sym.max_pool2d(data=internel_layer, pool_size=(2, 2), strides=(2,2), name="pool%s" %(i + 1))
return internel_layer
def get_classifier(input_data, num_classes, **kwargs):
flatten = mx.sym.Flatten(data=input_data, name="flatten")
fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
@@ -56,17 +45,6 @@ def get_classifier(input_data, num_classes, **kwargs):
fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
return fc8
def get_classifier_nnvm(input_data, num_classes, **kwargs):
flatten = nnvm.sym.flatten(data=input_data, name="flatten")
fc6 = nnvm.sym.dense(data=flatten, units=4096, name="fc6")
relu6 = nnvm.sym.relu(data=fc6, name="relu6")
drop6 = nnvm.sym.dropout(data=relu6, rate=0.5, name="drop6")
fc7 = nnvm.sym.dense(data=drop6, units=4096, name="fc7")
relu7 = nnvm.sym.relu(data=fc7, name="relu7")
drop7 = nnvm.sym.dropout(data=relu7, rate=0.5, name="drop7")
fc8 = nnvm.sym.dense(data=drop7, units=num_classes, name="fc8")
return fc8
def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
"""
Parameters
@@ -96,33 +74,3 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **
classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
symbol = mx.sym.softmax(data=classifier, name='softmax')
return symbol
def get_symbol_nnvm(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
"""
Parameters
----------
num_classes : int, default 1000
Number of classification classes.
num_layers : int
Number of layers for the variant of densenet. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
dtype: str, float32 or float16
Data precision.
"""
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
if not vgg_spec.has_key(num_layers):
raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
layers, filters = vgg_spec[num_layers]
data = nnvm.sym.Variable(name="data")
if dtype == 'float16':
data = nnvm.sym.cast(data=data, dtype=np.float16)
feature = get_feature_nnvm(data, layers, filters, batch_norm)
classifier = get_classifier_nnvm(feature, num_classes)
if dtype == 'float16':
classifier = nnvm.sym.cast(data=classifier, dtype=np.float32)
symbol = nnvm.sym.softmax(data=classifier, name='softmax')
return symbol
@@ -14,21 +14,21 @@ def compare_graph(sym1, sym2, ishape=(2, 3, 224, 224)):
def test_mlp():
mx_sym = model_zoo.mx_mlp
-    from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
+    from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_mlp
compare_graph(from_mx_sym, nnvm_sym)
def test_vgg():
for n in [11, 13, 16, 19]:
mx_sym = model_zoo.mx_vgg[n]
-        from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
+        from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_vgg[n]
compare_graph(from_mx_sym, nnvm_sym)
def test_resnet():
for n in [18, 34, 50, 101]:
mx_sym = model_zoo.mx_resnet[n]
-        from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
+        from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_resnet[n]
compare_graph(from_mx_sym, nnvm_sym)
......
"""Store for onnx examples and common models."""
from __future__ import absolute_import as _abs
import os
from .super_resolution import get_super_resolution
__all__ = ['super_resolution']
def _as_abs_path(fname):
cur_dir = os.path.abspath(os.path.dirname(__file__))
return os.path.join(cur_dir, fname)
# A pair of an ONNX pb file and the corresponding NNVM symbol.
super_resolution = (_as_abs_path('super_resolution.onnx'), get_super_resolution())
"""NNVM symbol corresponding to super_resolution.onnx example."""
from nnvm import sym
def get_super_resolution_deprecated():
factor = 3
size = 224
data = sym.Variable(name='9')
conv1 = sym.conv2d(data, channels=64, kernel_size=(5, 5), padding=(2, 2))
relu1 = sym.relu(conv1)
conv2 = sym.conv2d(relu1, channels=64, kernel_size=(3, 3), padding=(1, 1))
relu2 = sym.relu(conv2)
conv3 = sym.conv2d(relu2, channels=32, kernel_size=(3, 3), padding=(1, 1))
relu3 = sym.relu(conv3)
conv4 = sym.conv2d(relu3, channels=factor**2, kernel_size=(3, 3), padding=(1, 1))
r1 = sym.reshape(conv4, shape=(0, 1, factor, factor, size, size))
t1 = sym.transpose(r1, axes=(0, 1, 4, 2, 5, 3))
r2 = sym.reshape(t1, shape=(0, 1, size * factor, size * factor))
return r2
def get_super_resolution():
factor = 3
size = 224
data = sym.Variable(name='9')
conv1 = sym.conv2d(data, channels=64, kernel_size=(5, 5), padding=(2, 2), use_bias=False)
relu1 = sym.relu(conv1 + sym.Variable(name='2'))
conv2 = sym.conv2d(relu1, channels=64, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
relu2 = sym.relu(conv2 + sym.Variable(name='4'))
conv3 = sym.conv2d(relu2, channels=32, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
relu3 = sym.relu(conv3 + sym.Variable(name='6'))
conv4 = sym.conv2d(relu3, channels=factor**2, kernel_size=(3, 3), padding=(1, 1), use_bias=False)
conv4 = conv4 + sym.Variable(name='8')
# TODO(zhreshold): allow shape inference for batch size > 1
r1 = sym.reshape(conv4, shape=(1, 1, factor, factor, size, size))
t1 = sym.transpose(r1, axes=(0, 1, 4, 2, 5, 3))
r2 = sym.reshape(t1, shape=(1, 1, size * factor, size * factor))
return r2
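# The reshape/transpose tail above implements depth-to-space (PyTorch's
# PixelShuffle): the factor**2 channels are spread over a factor-times
# larger spatial grid. A NumPy sketch of the same index shuffle, with
# illustrative sizes (our example, not part of the model):
#
#   import numpy as np
#   x = np.arange(9 * 4 * 4, dtype='float32').reshape(1, 9, 4, 4)  # (N, r*r, H, W)
#   r1 = x.reshape(1, 1, 3, 3, 4, 4)
#   t1 = r1.transpose(0, 1, 4, 2, 5, 3)   # interleave the two factor axes
#   r2 = t1.reshape(1, 1, 12, 12)         # (N, 1, H*r, W*r)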
import numpy as np
import nnvm
import tvm
from tvm.contrib import graph_runtime
from nnvm.testing.config import ctx_list
import onnx
import onnx_caffe2.backend
from model_zoo import super_resolution
def test_onnx_forward_impl(graph_file, data_shape, out_shape):
def get_caffe2_output(graph, x, dtype='float32'):
prepared_backend = onnx_caffe2.backend.prepare(graph)
W = {graph.input[-1]: x.astype(dtype)}
c2_out = prepared_backend.run(W)[0]
return c2_out
def get_tvm_output(graph, x, target, ctx, dtype='float32'):
new_sym, params = nnvm.frontend.from_onnx(graph)
shape_dict = {'input_0': x.shape}
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('input_0', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
m.run()
# get outputs
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
return out.asnumpy()
dtype = 'float32'
x = np.random.uniform(size=data_shape)
graph = onnx.load(graph_file)
c2_out = get_caffe2_output(graph, x, dtype)
for target, ctx in ctx_list():
tvm_out = get_tvm_output(graph, x, target, ctx, dtype)
np.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_super_resolution_example():
test_onnx_forward_impl(super_resolution[0], (1, 1, 224, 224), (1, 1, 672, 672))
if __name__ == '__main__':
test_super_resolution_example()
"""Test graph equality of onnx models."""
import nnvm
import onnx
from nnvm.compiler import graph_util, graph_attr
from model_zoo import super_resolution
def compare_graph(onnx_file, nnvm_sym, ishape):
onnx_graph = onnx.load(onnx_file)
onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph)
g1 = nnvm.graph.create(onnx_sym)
g2 = nnvm.graph.create(nnvm_sym)
ishapes = {'input_0': ishape}
graph_attr.set_shape_inputs(g1, ishapes)
graph_attr.set_shape_inputs(g2, ishapes)
g1 = g1.apply("InferShape").apply("SimplifyInference")
g2 = g2.apply("InferShape").apply("SimplifyInference")
graph_util.check_graph_equal(g1, g2)
def test_super_resolution_example():
fname, symbol = super_resolution
compare_graph(fname, symbol, ishape=(1, 1, 224, 224))
if __name__ == '__main__':
test_super_resolution_example()