Commit 389a00f6 by Joshua Z. Zhang Committed by Tianqi Chen

init mxnet converter (#27)

graph

backup

update

finish mxnet converter

fix

fix various

add tests

fix

add multi networks

uses model_zoo

fix tests

minor fix

fix graph

fix
parent 2b3d2e21
...@@ -7,5 +7,6 @@ from . import _base ...@@ -7,5 +7,6 @@ from . import _base
from . import symbol as sym from . import symbol as sym
from . import symbol from . import symbol
from ._base import NNVMError from ._base import NNVMError
from . import frontend
__version__ = _base.__version__ __version__ = _base.__version__
"""Frontend package."""
from __future__ import absolute_import
from .mxnet import from_mxnet
"""MXNet symbol frontend."""
from __future__ import absolute_import as _abs
import json
from .. import symbol as _sym
__all__ = ['from_mxnet']
def _required_attr(attr, key):
assert isinstance(attr, dict)
if key not in attr:
raise AttributeError("Required attribute {} not found.".format(key))
return attr[key]
def _raise_not_supported(attr, op='nnvm'):
err = "{} is not supported in {}.".format(attr, op)
raise NotImplementedError(err)
def _warn_not_used(attr, op='nnvm'):
import warnings
err = "{} is ignored in {}.".format(attr, op)
warnings.warn(err)
def _parse_tshape(tshape):
"""Parse tshape in string."""
return [int(x.strip()) for x in tshape.strip('()').split(',')]
def _parse_bool_str(attr, key, default='False'):
"""Parse bool string to boolean."""
return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']
def _rename(new_name):
def impl(attr):
return new_name, attr
return impl
def _variable(attrs):
return "Variable", attrs
def _pooling(attrs):
kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
if len(kernel) != 2:
_raise_not_supported('non-2d kernel', 'pool_2d')
global_pool = 'global' if _parse_bool_str(attrs, 'global_pool') else ''
pool_type = _required_attr(attrs, 'pool_type')
if pool_type not in ['avg', 'max']:
_raise_not_supported('non-avg/max', 'pool2d')
op_name, new_attrs = '_'.join([global_pool, pool_type, 'pool2d']).strip('_'), {}
# new_attrs['layout'] = 'NCHW'
if not global_pool:
new_attrs['pool_size'] = kernel
new_attrs['strides'] = attrs.get('stride', (1, 1))
new_attrs['padding'] = attrs.get('pad', (0, 0))
new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full')
return op_name, new_attrs
def _batch_norm(attrs):
if _parse_bool_str(attrs, 'output_mean_var'):
_raise_not_supported('output_mean_var', 'batch_norm')
if _parse_bool_str(attrs, 'fix_gamma'):
_warn_not_used('fix_gamma', 'batch_norm')
if _parse_bool_str(attrs, 'use_global_stats'):
_warn_not_used('use_global_stats', 'batch_norm')
if _parse_bool_str(attrs, 'momentum'):
_warn_not_used('momentum', 'batch_norm')
op_name, new_attrs = 'batch_norm', {}
new_attrs['axis'] = attrs.get('axis', 1)
new_attrs['epsilon'] = attrs.get('eps', 0.001)
new_attrs['center'] = True
new_attrs['scale'] = True
return op_name, new_attrs
def _concat(attrs):
op_name = 'concatenate'
new_attrs = {'axis': attrs.get('dim', 1)}
return op_name, new_attrs
def _conv2d(attrs):
kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
if len(kernel) != 2:
_raise_not_supported('non 2d kernel', 'conv2d')
layout = attrs.get('layout', 'NCHW')
if layout not in ['NCHW', 'NHWC']:
_raise_not_supported('layout: ' + layout, 'conv2d')
op_name, new_attrs = 'conv2d', {}
new_attrs['channels'] = _required_attr(attrs, 'num_filter')
new_attrs['kernel_size'] = kernel
new_attrs['strides'] = attrs.get('stride', (1, 1))
new_attrs['padding'] = attrs.get('pad', (0, 0))
new_attrs['dilation'] = attrs.get('dilate', (1, 1))
new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout
new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False'
return op_name, new_attrs
def _conv2d_transpose(attrs):
if 'target_shape' in attrs:
_raise_not_supported('target_shape', 'conv2d_transpose')
kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
if len(kernel) != 2:
_raise_not_supported('non-2d kernel', 'conv2d_transpose')
layout = attrs.get('layout', 'NCHW')
if layout not in ['NCHW', 'NHWC']:
_raise_not_supported('layout: ' + layout, 'conv2d_transpose')
op_name, new_attrs = 'conv2d_transpose', {}
new_attrs['channels'] = _required_attr(attrs, 'num_filter')
new_attrs['kernel_size'] = kernel
new_attrs['strides'] = attrs.get('stride', (1, 1))
new_attrs['output_padding'] = attrs.get('adj', (0, 0))
new_attrs['padding'] = attrs.get('pad', (0, 0))
new_attrs['dilation'] = attrs.get('dilate', (1, 1))
new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
return op_name, new_attrs
def _dense(attrs):
op_name, new_attrs = 'dense', {}
new_attrs['units'] = _required_attr(attrs, 'num_hidden')
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
return op_name, new_attrs
def _dropout(attrs):
op_name, new_attrs = 'dropout', {}
new_attrs['rate'] = attrs.get('p', 0.5)
return op_name, new_attrs
def _leaky_relu(attrs):
act_type = _required_attr(attrs, 'act_type')
if act_type not in ['leaky']:
_raise_not_supported('act_type: ' + act_type)
op_name, new_attrs = 'leaky_relu', {}
new_attrs['alpha'] = attrs.get('slope', 0.25)
return op_name, new_attrs
def _activations(attrs):
act_type = _required_attr(attrs, 'act_type')
if act_type not in ['relu', 'sigmoid', 'tanh']:
_raise_not_supported('act_type: ' + act_type)
op_name, new_attrs = act_type, {}
return op_name, new_attrs
def _reshape(attrs):
if _parse_bool_str(attrs, 'reverse'):
_raise_not_supported('reverse', 'reshape')
op_name, new_attrs = 'reshape', {}
new_attrs['shape'] = _required_attr(attrs, 'shape')
return op_name, new_attrs
def _split(attrs):
if _parse_bool_str(attrs, 'squeeze_axis'):
_raise_not_supported('squeeze_axis', 'split')
op_name, new_attrs = 'split', {}
new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs')
new_attrs['axis'] = attrs.get('axis', 1)
return op_name, new_attrs
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
'__rsub_scalar__', '__sub_scalar__', '__sub_symbol__',
'broadcast_add', 'broadcast_div', 'broadcast_mul',
'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
_convert_map = {
'null' : _variable,
'Activation' : _activations,
'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm,
'Cast' : _rename('cast'),
'Concat' : _concat,
'Convolution' : _conv2d,
'Convolution_v1': _conv2d,
'Deconvolution' : _conv2d_transpose,
'Dropout' : _dropout,
'Flatten' : _rename('flatten'),
'FullyConnected': _dense,
'LeakyReLU' : _leaky_relu,
'Pooling' : _pooling,
'Pooling_v1' : _pooling,
'Reshape' : _reshape,
'Softmax' : _rename('softmax'),
'concat' : _concat,
'max_axis' : _rename('max'),
'min_axis' : _rename('min'),
'reshape' : _reshape,
'sum_axis' : _rename('sum'),
}
def _convert_symbol(op_name, attrs,
identity_list=_identity_list,
convert_map=_convert_map):
"""Convert from mxnet op to nnvm op.
The converter must specify some conversions explicitly to
support gluon format ops such as conv2d...
Parameters
----------
op_name : str
Operator name, such as Convolution, FullyConnected
attrs : dict
Dict of operator attributes
identity_list : list
List of operators that don't require conversion
convert_map : dict
Dict of name : callable, where name is the op's name that
require conversion to nnvm, callable are functions which
take attrs and return (new_op_name, new_attrs)
Returns
-------
(op_name, attrs)
Converted (op_name, attrs) for nnvm.
"""
if op_name in identity_list:
pass
elif op_name in convert_map:
op_name, attrs = convert_map[op_name](attrs)
else:
_raise_not_supported('Operator: ' + op_name)
op = getattr(_sym, op_name, None)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op, attrs
def _is_mxnet_group_symbol(symbol):
"""Internal check for mxnet group symbol."""
return len(symbol.list_outputs()) > 1
def _as_list(arr):
"""Force being a list, ignore if already is."""
if isinstance(arr, list):
return arr
return [arr]
def _from_mxnet_impl(symbol, graph):
"""Convert mxnet symbol to nnvm implementation.
Reconstruct a nnvm symbol by traversing the mxnet symbol.
Parameters
----------
symbol : mxnet.sym.Symbol
Incompatible symbol from mxnet, sharing similar graph structure.
The op_name and attrs inside are not always compatible.
graph : dict
Reusable nodes are stored in graph.
Returns:
-------
nnvm.sym.Symbol
Converted symbol
"""
try:
from mxnet import sym as mx_sym
except ImportError as e:
raise ImportError('{}. MXNet is required to parse symbols.'.format(e))
if not isinstance(symbol, mx_sym.Symbol):
raise ValueError("Provided {}, while MXNet symbol is expected", type(symbol))
if _is_mxnet_group_symbol(symbol):
return [_from_mxnet_impl(s, graph) for s in symbol]
name = symbol.attr('name')
node = graph.get(name, None)
if node:
return node
# op_name = symbol.attr('op_name')
if symbol.get_children():
op_name = symbol.attr('op_name')
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
attr = symbol.list_attr()
new_op, new_attr = _convert_symbol(op_name, attr)
if new_op == _sym.Variable:
node = new_op(name=name, **new_attr)
else:
childs = symbol.get_children()
childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
childs = [x for y in childs for x in _as_list(y)] # expand group symbol
node = new_op(name=name, *childs, **new_attr)
graph[name] = node
return node
def from_mxnet(symbol):
"""Convert from mxnet.Symbol to compatible nnvm.Symbol
Parameters
----------
symbol : mxnet.Symbol
MXNet symbol
Returns
-------
nnvm.Symbol
Compatible nnvm symbol
"""
return _from_mxnet_impl(symbol, {})
from __future__ import absolute_import
from . import mlp, resnet, vgg
_num_class = 1000
# mlp fc
mx_mlp = mlp.get_symbol(_num_class)
nnvm_mlp = mlp.get_symbol_nnvm(_num_class)
# resnet fc
mx_resnet = {}
nnvm_resnet = {}
for num_layer in [18, 34, 50, 101, 152, 200, 269]:
mx_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3,224,224')
nnvm_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3, 224, 224', lib='nnvm')
# vgg fc
mx_vgg = {}
nnvm_vgg = {}
for num_layer in [11, 13, 16, 19]:
mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
nnvm_vgg[num_layer] = vgg.get_symbol_nnvm(_num_class, num_layer)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
a simple multilayer perceptron
"""
import mxnet as mx
import nnvm
def get_symbol(num_classes=10, **kwargs):
data = mx.symbol.Variable('data')
data = mx.sym.Flatten(data=data)
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
mlp = mx.symbol.softmax(data = fc3, name = 'softmax')
return mlp
def get_symbol_nnvm(num_classes=10, **kwargs):
data = nnvm.symbol.Variable('data')
data = nnvm.sym.flatten(data=data)
fc1 = nnvm.symbol.dense(data = data, name='fc1', units=128)
act1 = nnvm.symbol.relu(data = fc1, name='relu1')
fc2 = nnvm.symbol.dense(data = act1, name = 'fc2', units = 64)
act2 = nnvm.symbol.relu(data = fc2, name='relu2')
fc3 = nnvm.symbol.dense(data = act2, name='fc3', units=num_classes)
mlp = nnvm.symbol.softmax(data = fc3, name = 'softmax')
return mlp
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
'''
Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
Original author Wei Wu
Implemented the following paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
import mxnet as mx
import numpy as np
import nnvm
def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
"""Return ResNet Unit symbol for building ResNet
Parameters
----------
data : str
Input data
num_filter : int
Number of output channels
bnf : int
Bottle neck channels factor with regard to num_filter
stride : tuple
Stride used in convolution
dim_match : Boolean
True means channel number between input and output is the same, otherwise means differ
name : str
Base name of the operators
workspace : int
Workspace used in convolution operator
"""
if bottle_neck:
# the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0),
no_bias=True, workspace=workspace, name=name + '_conv1')
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=stride, pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv2')
bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,
workspace=workspace, name=name + '_conv3')
if dim_match:
shortcut = data
else:
shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
workspace=workspace, name=name+'_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return conv3 + shortcut
else:
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn1')
act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
conv1 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv1')
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
conv2 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv2')
if dim_match:
shortcut = data
else:
shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
workspace=workspace, name=name+'_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return conv2 + shortcut
def residual_unit_nnvm(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
"""Return ResNet Unit symbol for building ResNet
Parameters
----------
data : str
Input data
num_filter : int
Number of output channels
bnf : int
Bottle neck channels factor with regard to num_filter
stride : tuple
Stride used in convolution
dim_match : Boolean
True means channel number between input and output is the same, otherwise means differ
name : str
Base name of the operators
workspace : int
Workspace used in convolution operator
"""
if bottle_neck:
# the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
bn1 = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = nnvm.sym.relu(data=bn1, name=name + '_relu1')
conv1 = nnvm.sym.conv2d(data=act1, channels=int(num_filter*0.25), kernel_size=(1,1), strides=(1,1), padding=(0,0),
use_bias=False, name=name + '_conv1')
bn2 = nnvm.sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = nnvm.sym.relu(data=bn2, name=name + '_relu2')
conv2 = nnvm.sym.conv2d(data=act2, channels=int(num_filter*0.25), kernel_size=(3,3), strides=stride, padding=(1,1),
use_bias=False, name=name + '_conv2')
bn3 = nnvm.sym.batch_norm(data=conv2, epsilon=2e-5, name=name + '_bn3')
act3 = nnvm.sym.relu(data=bn3, name=name + '_relu3')
conv3 = nnvm.sym.conv2d(data=act3, channels=num_filter, kernel_size=(1,1), strides=(1,1), padding=(0,0), use_bias=False,
name=name + '_conv3')
if dim_match:
shortcut = data
else:
shortcut = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(1,1), strides=stride, use_bias=False,
name=name+'_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return nnvm.sym.elemwise_add(conv3, shortcut)
else:
bn1 = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = nnvm.sym.relu(data=bn1, name=name + '_relu1')
conv1 = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(3,3), strides=stride, padding=(1,1),
use_bias=False, name=name + '_conv1')
bn2 = nnvm.sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = nnvm.sym.relu(data=bn2, name=name + '_relu2')
conv2 = nnvm.sym.conv2d(data=act2, channels=num_filter, kernel_size=(3,3), strides=(1,1), padding=(1,1),
use_bias=False, name=name + '_conv2')
if dim_match:
shortcut = data
else:
shortcut = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(1,1), strides=stride, use_bias=False,
name=name+'_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return nnvm.sym.elemwise_add(conv2, shortcut)
def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
"""Return ResNet symbol of
Parameters
----------
units : list
Number of units in each stage
num_stages : int
Number of stage
filter_list : list
Channel size of each stage
num_classes : int
Ouput size of symbol
dataset : str
Dataset type, only cifar10 and imagenet supports
workspace : int
Workspace used in convolution operator
dtype : str
Precision (float32 or float16)
"""
num_unit = len(units)
assert(num_unit == num_stages)
data = mx.sym.Variable(name='data')
if dtype == 'float32':
# data = mx.sym.identity(data=data, name='id')
data = data
else:
if dtype == 'float16':
data = mx.sym.Cast(data=data, dtype=np.float16)
data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
(nchannel, height, width) = image_shape
if height <= 32: # such as cifar10
body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(3, 3), stride=(1,1), pad=(1, 1),
no_bias=True, name="conv0", workspace=workspace)
else: # often expected to be 224 such as imagenet
body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
no_bias=True, name="conv0", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
body = mx.sym.Activation(data=body, act_type='relu', name='relu0')
body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
for i in range(num_stages):
body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, workspace=workspace,
memonger=memonger)
for j in range(units[i]-1):
body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
bottle_neck=bottle_neck, workspace=workspace, memonger=memonger)
bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1')
# Although kernel is not used here when global_pool=True, we should put one
pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
flat = mx.sym.Flatten(data=pool1)
fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
if dtype == 'float16':
fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)
return mx.sym.softmax(data=fc1, name='softmax')
def resnet_nnvm(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
"""Return ResNet symbol of
Parameters
----------
units : list
Number of units in each stage
num_stages : int
Number of stage
filter_list : list
Channel size of each stage
num_classes : int
Ouput size of symbol
dataset : str
Dataset type, only cifar10 and imagenet supports
workspace : int
Workspace used in convolution operator
dtype : str
Precision (float32 or float16)
"""
num_unit = len(units)
assert(num_unit == num_stages)
data = nnvm.sym.Variable(name='data')
if dtype == 'float32':
# data = nnvm.sym.identity(data=data, name='id')
data = data
else:
if dtype == 'float16':
data = nnvm.sym.cast(data=data, dtype=np.float16)
data = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name='bn_data')
(nchannel, height, width) = image_shape
if height <= 32: # such as cifar10
body = nnvm.sym.conv2d(data=data, channels=filter_list[0], kernel_size=(3, 3), strides=(1,1), padding=(1, 1),
use_bias=False, name="conv0")
else: # often expected to be 224 such as imagenet
body = nnvm.sym.conv2d(data=data, channels=filter_list[0], kernel_size=(7, 7), strides=(2,2), padding=(3, 3),
use_bias=False, name="conv0")
body = nnvm.sym.batch_norm(data=body, epsilon=2e-5, name='bn0')
body = nnvm.sym.relu(data=body, name='relu0')
body = nnvm.sym.max_pool2d(data=body, pool_size=(3, 3), strides=(2,2), padding=(1,1))
for i in range(num_stages):
body = residual_unit_nnvm(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck,
memonger=memonger)
for j in range(units[i]-1):
body = residual_unit_nnvm(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
bottle_neck=bottle_neck, memonger=memonger)
bn1 = nnvm.sym.batch_norm(data=body, epsilon=2e-5, name='bn1')
relu1 = nnvm.sym.relu(data=bn1, name='relu1')
# Although kernel is not used here when global_pool=True, we should put one
pool1 = nnvm.sym.global_avg_pool2d(data=relu1, name='pool1')
flat = nnvm.sym.flatten(data=pool1)
fc1 = nnvm.sym.dense(data=flat, units=num_classes, name='fc1')
if dtype == 'float16':
fc1 = nnvm.sym.cast(data=fc1, dtype=np.float32)
return nnvm.sym.softmax(data=fc1, name='softmax')
def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', lib='mxnet', **kwargs):
"""
Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
Original author Wei Wu
"""
image_shape = [int(l) for l in image_shape.split(',')]
(nchannel, height, width) = image_shape
if height <= 28:
num_stages = 3
if (num_layers-2) % 9 == 0 and num_layers >= 164:
per_unit = [(num_layers-2)//9]
filter_list = [16, 64, 128, 256]
bottle_neck = True
elif (num_layers-2) % 6 == 0 and num_layers < 164:
per_unit = [(num_layers-2)//6]
filter_list = [16, 16, 32, 64]
bottle_neck = False
else:
raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers))
units = per_unit * num_stages
else:
if num_layers >= 50:
filter_list = [64, 256, 512, 1024, 2048]
bottle_neck = True
else:
filter_list = [64, 64, 128, 256, 512]
bottle_neck = False
num_stages = 4
if num_layers == 18:
units = [2, 2, 2, 2]
elif num_layers == 34:
units = [3, 4, 6, 3]
elif num_layers == 50:
units = [3, 4, 6, 3]
elif num_layers == 101:
units = [3, 4, 23, 3]
elif num_layers == 152:
units = [3, 8, 36, 3]
elif num_layers == 200:
units = [3, 24, 36, 3]
elif num_layers == 269:
units = [3, 30, 48, 8]
else:
raise ValueError("no experiments done on num_layers {}, you can do it yourself".format(num_layers))
return resnet(units = units,
num_stages = num_stages,
filter_list = filter_list,
num_classes = num_classes,
image_shape = image_shape,
bottle_neck = bottle_neck,
workspace = conv_workspace,
dtype = dtype) if lib == 'mxnet' else \
resnet_nnvm(units = units,
num_stages = num_stages,
filter_list = filter_list,
num_classes = num_classes,
image_shape = image_shape,
bottle_neck = bottle_neck,
workspace = conv_workspace,
dtype = dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""References:
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""
import mxnet as mx
import nnvm
import numpy as np
def get_feature(internel_layer, layers, filters, batch_norm = False, **kwargs):
for i, num in enumerate(layers):
for j in range(num):
internel_layer = mx.sym.Convolution(data = internel_layer, kernel=(3, 3), pad=(1, 1), num_filter=filters[i], name="conv%s_%s" %(i + 1, j + 1))
if batch_norm:
internel_layer = mx.symbol.BatchNorm(data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
internel_layer = mx.sym.Activation(data=internel_layer, act_type="relu", name="relu%s_%s" %(i + 1, j + 1))
internel_layer = mx.sym.Pooling(data=internel_layer, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool%s" %(i + 1))
return internel_layer
def get_feature_nnvm(internel_layer, layers, filters, batch_norm = False, **kwargs):
for i, num in enumerate(layers):
for j in range(num):
internel_layer = nnvm.sym.conv2d(data = internel_layer, kernel_size=(3, 3), padding=(1, 1), channels=filters[i], name="conv%s_%s" %(i + 1, j + 1))
if batch_norm:
internel_layer = nnvm.symbol.batch_norm(data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
internel_layer = nnvm.sym.relu(data=internel_layer, name="relu%s_%s" %(i + 1, j + 1))
internel_layer = nnvm.sym.max_pool2d(data=internel_layer, pool_size=(2, 2), strides=(2,2), name="pool%s" %(i + 1))
return internel_layer
def get_classifier(input_data, num_classes, **kwargs):
flatten = mx.sym.Flatten(data=input_data, name="flatten")
fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
return fc8
def get_classifier_nnvm(input_data, num_classes, **kwargs):
flatten = nnvm.sym.flatten(data=input_data, name="flatten")
fc6 = nnvm.sym.dense(data=flatten, units=4096, name="fc6")
relu6 = nnvm.sym.relu(data=fc6, name="relu6")
drop6 = nnvm.sym.dropout(data=relu6, rate=0.5, name="drop6")
fc7 = nnvm.sym.dense(data=drop6, units=4096, name="fc7")
relu7 = nnvm.sym.relu(data=fc7, name="relu7")
drop7 = nnvm.sym.dropout(data=relu7, rate=0.5, name="drop7")
fc8 = nnvm.sym.dense(data=drop7, units=num_classes, name="fc8")
return fc8
def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
"""
Parameters
----------
num_classes : int, default 1000
Number of classification classes.
num_layers : int
Number of layers for the variant of densenet. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
dtype: str, float32 or float16
Data precision.
"""
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
if not vgg_spec.has_key(num_layers):
raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
layers, filters = vgg_spec[num_layers]
data = mx.sym.Variable(name="data")
if dtype == 'float16':
data = mx.sym.Cast(data=data, dtype=np.float16)
feature = get_feature(data, layers, filters, batch_norm)
classifier = get_classifier(feature, num_classes)
if dtype == 'float16':
classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
symbol = mx.sym.softmax(data=classifier, name='softmax')
return symbol
def get_symbol_nnvm(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
"""
Parameters
----------
num_classes : int, default 1000
Number of classification classes.
num_layers : int
Number of layers for the variant of densenet. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
dtype: str, float32 or float16
Data precision.
"""
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
if not vgg_spec.has_key(num_layers):
raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
layers, filters = vgg_spec[num_layers]
data = nnvm.sym.Variable(name="data")
if dtype == 'float16':
data = nnvm.sym.cast(data=data, dtype=np.float16)
feature = get_feature_nnvm(data, layers, filters, batch_norm)
classifier = get_classifier_nnvm(feature, num_classes)
if dtype == 'float16':
classifier = nnvm.sym.cast(data=classifier, dtype=np.float32)
symbol = nnvm.sym.softmax(data=classifier, name='softmax')
return symbol
import numpy as np
import topi
import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from nnvm import frontend
import mxnet as mx
import model_zoo
USE_GPU=True
def default_target():
if USE_GPU:
return 'cuda'
else:
return 'llvm'
def default_ctx():
if USE_GPU:
return tvm.gpu(0)
else:
return tvm.cpu(0)
def test_mxnet_frontend_impl(mx_symbol, data_shape=(2, 3, 224, 224), out_shape=(2, 1000)):
def get_mxnet_output(symbol, x, dtype='float32'):
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])
mod = mx.mod.Module(symbol, label_names=None)
mod.bind(data_shapes=[('data', x.shape)], for_training=False)
mod.init_params()
mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
out = mod.get_outputs()[0].asnumpy()
args, auxs = mod.get_params()
return out, args, auxs
def get_tvm_output(symbol, x, args, auxs, dtype='float32'):
dshape = x.shape
shape_dict = {'data': dshape}
for k, v in args.items():
shape_dict[k] = v.shape
for k, v in auxs.items():
shape_dict[k] = v.shape
graph, lib, _ = nnvm.compiler.build(symbol, default_target(), shape_dict)
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m['set_input'], m['run'], m['get_output']
# set inputs
set_input('data', tvm.nd.array(x.astype(dtype)))
for k, v in args.items():
set_input(k, tvm.nd.array(v.asnumpy().astype(dtype)))
for k, v in auxs.items():
set_input(k, tvm.nd.array(v.asnumpy().astype(dtype)))
# execute
run()
# get outputs
out = tvm.nd.empty(out_shape, dtype)
get_output(0, out)
return out.asnumpy()
# random input
dtype = 'float32'
x = np.random.uniform(size=data_shape)
mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
new_sym = frontend.from_mxnet(mx_symbol)
tvm_out = get_tvm_output(new_sym, x, args, auxs, dtype)
np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5)
def test_forward_mlp():
mlp = model_zoo.mx_mlp
test_mxnet_frontend_impl(mlp)
def test_forward_vgg():
for n in [11]:
mx_sym = model_zoo.mx_vgg[n]
test_mxnet_frontend_impl(mx_sym)
def test_forward_resnet():
for n in [18]:
mx_sym = model_zoo.mx_resnet[n]
test_mxnet_frontend_impl(mx_sym)
if __name__ == '__main__':
test_forward_mlp()
# waiting for max_pool2d
# test_forward_vgg()
# test_forward_resnet()
import mxnet as mx
import nnvm
from nnvm.compiler import graph_util, graph_attr
import model_zoo
def compare_graph(sym1, sym2, ishape=(2, 3, 224, 224)):
g1 = nnvm.graph.create(sym1)
g2 = nnvm.graph.create(sym2)
graph_attr.set_shape_inputs(g1, {'data':ishape})
graph_attr.set_shape_inputs(g2, {'data':ishape})
g1 = g1.apply("InferShape").apply("SimplifyInference")
g2 = g2.apply("InferShape").apply("SimplifyInference")
graph_util.check_graph_equal(g1, g2)
def test_mlp():
mx_sym = model_zoo.mx_mlp
from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_mlp
compare_graph(from_mx_sym, nnvm_sym)
def test_vgg():
for n in [11, 13, 16, 19]:
mx_sym = model_zoo.mx_vgg[n]
from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_vgg[n]
compare_graph(from_mx_sym, nnvm_sym)
def test_resnet():
for n in [18, 34, 50, 101]:
mx_sym = model_zoo.mx_resnet[n]
from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_resnet[n]
compare_graph(from_mx_sym, nnvm_sym)
if __name__ == '__main__':
test_mlp()
test_vgg()
test_resnet()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment