Commit 389a00f6 by Joshua Z. Zhang, committed by Tianqi Chen

init mxnet converter (#27)

graph

backup

update

finish mxnet converter

fix

fix various

add tests

fix

add multi networks

uses model_zoo

fix tests

minor fix

fix graph

fix
parent 2b3d2e21
@@ -7,5 +7,6 @@ from . import _base
 from . import symbol as sym
 from . import symbol
 from ._base import NNVMError
+from . import frontend
 __version__ = _base.__version__
"""Frontend package."""
from __future__ import absolute_import
from .mxnet import from_mxnet
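# Usage sketch: the converter entry point takes an MXNet symbol and returns
# the equivalent NNVM symbol, as exercised by the tests in this commit, e.g.
#     nnvm_sym = from_mxnet(mx_symbol)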
from __future__ import absolute_import
from . import mlp, resnet, vgg
_num_class = 1000
# mlp fc
mx_mlp = mlp.get_symbol(_num_class)
nnvm_mlp = mlp.get_symbol_nnvm(_num_class)
# resnet fc
mx_resnet = {}
nnvm_resnet = {}
for num_layer in [18, 34, 50, 101, 152, 200, 269]:
    mx_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3,224,224')
    nnvm_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3,224,224', lib='nnvm')
# vgg fc
mx_vgg = {}
nnvm_vgg = {}
for num_layer in [11, 13, 16, 19]:
    mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
    nnvm_vgg[num_layer] = vgg.get_symbol_nnvm(_num_class, num_layer)
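# Example lookup: model_zoo.mx_resnet[50] is the MXNet ResNet-50 symbol and
# model_zoo.nnvm_resnet[50] is its directly-constructed NNVM counterpart, so
# the tests can compare the two frontends pairwise.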
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
a simple multilayer perceptron
"""
import mxnet as mx
import nnvm
def get_symbol(num_classes=10, **kwargs):
    data = mx.symbol.Variable('data')
    data = mx.sym.Flatten(data=data)
    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=num_classes)
    mlp = mx.symbol.softmax(data=fc3, name='softmax')
    return mlp
def get_symbol_nnvm(num_classes=10, **kwargs):
    data = nnvm.symbol.Variable('data')
    data = nnvm.sym.flatten(data=data)
    fc1 = nnvm.symbol.dense(data=data, name='fc1', units=128)
    act1 = nnvm.symbol.relu(data=fc1, name='relu1')
    fc2 = nnvm.symbol.dense(data=act1, name='fc2', units=64)
    act2 = nnvm.symbol.relu(data=fc2, name='relu2')
    fc3 = nnvm.symbol.dense(data=act2, name='fc3', units=num_classes)
    mlp = nnvm.symbol.softmax(data=fc3, name='softmax')
    return mlp
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""References:
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""
import mxnet as mx
import nnvm
import numpy as np
def get_feature(internal_layer, layers, filters, batch_norm=False, **kwargs):
    for i, num in enumerate(layers):
        for j in range(num):
            internal_layer = mx.sym.Convolution(data=internal_layer, kernel=(3, 3), pad=(1, 1), num_filter=filters[i], name="conv%s_%s" % (i + 1, j + 1))
            if batch_norm:
                internal_layer = mx.symbol.BatchNorm(data=internal_layer, name="bn%s_%s" % (i + 1, j + 1))
            internal_layer = mx.sym.Activation(data=internal_layer, act_type="relu", name="relu%s_%s" % (i + 1, j + 1))
        internal_layer = mx.sym.Pooling(data=internal_layer, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool%s" % (i + 1))
    return internal_layer
def get_feature_nnvm(internal_layer, layers, filters, batch_norm=False, **kwargs):
    for i, num in enumerate(layers):
        for j in range(num):
            internal_layer = nnvm.sym.conv2d(data=internal_layer, kernel_size=(3, 3), padding=(1, 1), channels=filters[i], name="conv%s_%s" % (i + 1, j + 1))
            if batch_norm:
                internal_layer = nnvm.symbol.batch_norm(data=internal_layer, name="bn%s_%s" % (i + 1, j + 1))
            internal_layer = nnvm.sym.relu(data=internal_layer, name="relu%s_%s" % (i + 1, j + 1))
        internal_layer = nnvm.sym.max_pool2d(data=internal_layer, pool_size=(2, 2), strides=(2, 2), name="pool%s" % (i + 1))
    return internal_layer
def get_classifier(input_data, num_classes, **kwargs):
    flatten = mx.sym.Flatten(data=input_data, name="flatten")
    fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
    relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
    drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
    fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
    relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
    drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
    fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
    return fc8
def get_classifier_nnvm(input_data, num_classes, **kwargs):
    flatten = nnvm.sym.flatten(data=input_data, name="flatten")
    fc6 = nnvm.sym.dense(data=flatten, units=4096, name="fc6")
    relu6 = nnvm.sym.relu(data=fc6, name="relu6")
    drop6 = nnvm.sym.dropout(data=relu6, rate=0.5, name="drop6")
    fc7 = nnvm.sym.dense(data=drop6, units=4096, name="fc7")
    relu7 = nnvm.sym.relu(data=fc7, name="relu7")
    drop7 = nnvm.sym.dropout(data=relu7, rate=0.5, name="drop7")
    fc8 = nnvm.sym.dense(data=drop7, units=num_classes, name="fc8")
    return fc8
def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
    """
    Parameters
    ----------
    num_classes : int
        Number of classification classes.
    num_layers : int
        Number of layers for the VGG variant. Options are 11, 13, 16, 19.
    batch_norm : bool, default False
        Use batch normalization.
    dtype : str, float32 or float16
        Data precision.
    """
    vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
                13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
                16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
                19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
    if num_layers not in vgg_spec:
        raise ValueError("Invalid num_layers {}. Possible choices are 11, 13, 16, 19.".format(num_layers))
    layers, filters = vgg_spec[num_layers]
    data = mx.sym.Variable(name="data")
    if dtype == 'float16':
        data = mx.sym.Cast(data=data, dtype=np.float16)
    feature = get_feature(data, layers, filters, batch_norm)
    classifier = get_classifier(feature, num_classes)
    if dtype == 'float16':
        classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
    symbol = mx.sym.softmax(data=classifier, name='softmax')
    return symbol
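# Example: get_symbol(1000, num_layers=16, batch_norm=True) builds VGG-16
# with batch normalization.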
def get_symbol_nnvm(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
    """
    Parameters
    ----------
    num_classes : int
        Number of classification classes.
    num_layers : int
        Number of layers for the VGG variant. Options are 11, 13, 16, 19.
    batch_norm : bool, default False
        Use batch normalization.
    dtype : str, float32 or float16
        Data precision.
    """
    vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
                13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
                16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
                19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
    if num_layers not in vgg_spec:
        raise ValueError("Invalid num_layers {}. Possible choices are 11, 13, 16, 19.".format(num_layers))
    layers, filters = vgg_spec[num_layers]
    data = nnvm.sym.Variable(name="data")
    if dtype == 'float16':
        data = nnvm.sym.cast(data=data, dtype=np.float16)
    feature = get_feature_nnvm(data, layers, filters, batch_norm)
    classifier = get_classifier_nnvm(feature, num_classes)
    if dtype == 'float16':
        classifier = nnvm.sym.cast(data=classifier, dtype=np.float32)
    symbol = nnvm.sym.softmax(data=classifier, name='softmax')
    return symbol
import numpy as np
import topi
import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from nnvm import frontend
import mxnet as mx
import model_zoo
USE_GPU = True
def default_target():
    if USE_GPU:
        return 'cuda'
    else:
        return 'llvm'
def default_ctx():
    if USE_GPU:
        return tvm.gpu(0)
    else:
        return tvm.cpu(0)
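# End-to-end check: run the MXNet module and the TVM-compiled NNVM graph on
# the same random input with shared parameters, then compare their outputs.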
def test_mxnet_frontend_impl(mx_symbol, data_shape=(2, 3, 224, 224), out_shape=(2, 1000)):
    def get_mxnet_output(symbol, x, dtype='float32'):
        from collections import namedtuple
        Batch = namedtuple('Batch', ['data'])
        mod = mx.mod.Module(symbol, label_names=None)
        mod.bind(data_shapes=[('data', x.shape)], for_training=False)
        mod.init_params()
        mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
        out = mod.get_outputs()[0].asnumpy()
        args, auxs = mod.get_params()
        return out, args, auxs
    def get_tvm_output(symbol, x, args, auxs, dtype='float32'):
        dshape = x.shape
        shape_dict = {'data': dshape}
        for k, v in args.items():
            shape_dict[k] = v.shape
        for k, v in auxs.items():
            shape_dict[k] = v.shape
        graph, lib, _ = nnvm.compiler.build(symbol, default_target(), shape_dict)
        m = nnvm.runtime.create(graph, lib, default_ctx())
        # get member functions
        set_input, run, get_output = m['set_input'], m['run'], m['get_output']
        # set inputs
        set_input('data', tvm.nd.array(x.astype(dtype)))
        for k, v in args.items():
            set_input(k, tvm.nd.array(v.asnumpy().astype(dtype)))
        for k, v in auxs.items():
            set_input(k, tvm.nd.array(v.asnumpy().astype(dtype)))
        # execute
        run()
        # get outputs
        out = tvm.nd.empty(out_shape, dtype)
        get_output(0, out)
        return out.asnumpy()
    # random input
    dtype = 'float32'
    x = np.random.uniform(size=data_shape)
    mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
    new_sym = frontend.from_mxnet(mx_symbol)
    tvm_out = get_tvm_output(new_sym, x, args, auxs, dtype)
    np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5)
def test_forward_mlp():
    mlp = model_zoo.mx_mlp
    test_mxnet_frontend_impl(mlp)
def test_forward_vgg():
    for n in [11]:
        mx_sym = model_zoo.mx_vgg[n]
        test_mxnet_frontend_impl(mx_sym)
def test_forward_resnet():
    for n in [18]:
        mx_sym = model_zoo.mx_resnet[n]
        test_mxnet_frontend_impl(mx_sym)
if __name__ == '__main__':
    test_forward_mlp()
    # waiting for max_pool2d
    # test_forward_vgg()
    # test_forward_resnet()
import mxnet as mx
import nnvm
from nnvm.compiler import graph_util, graph_attr
import model_zoo
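# Structural check: build both symbols into graphs, infer shapes, apply
# SimplifyInference (which rewrites inference-only ops such as batch_norm),
# then assert the two graphs are equal.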
def compare_graph(sym1, sym2, ishape=(2, 3, 224, 224)):
    g1 = nnvm.graph.create(sym1)
    g2 = nnvm.graph.create(sym2)
    graph_attr.set_shape_inputs(g1, {'data': ishape})
    graph_attr.set_shape_inputs(g2, {'data': ishape})
    g1 = g1.apply("InferShape").apply("SimplifyInference")
    g2 = g2.apply("InferShape").apply("SimplifyInference")
    graph_util.check_graph_equal(g1, g2)
def test_mlp():
    mx_sym = model_zoo.mx_mlp
    from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
    nnvm_sym = model_zoo.nnvm_mlp
    compare_graph(from_mx_sym, nnvm_sym)
def test_vgg():
    for n in [11, 13, 16, 19]:
        mx_sym = model_zoo.mx_vgg[n]
        from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
        nnvm_sym = model_zoo.nnvm_vgg[n]
        compare_graph(from_mx_sym, nnvm_sym)
def test_resnet():
    for n in [18, 34, 50, 101]:
        mx_sym = model_zoo.mx_resnet[n]
        from_mx_sym = nnvm.frontend.from_mxnet(mx_sym)
        nnvm_sym = model_zoo.nnvm_resnet[n]
        compare_graph(from_mx_sym, nnvm_sym)
if __name__ == '__main__':
    test_mlp()
    test_vgg()
    test_resnet()