Commit a80356bb by Lianmin Zheng Committed by Tianqi Chen

[NNVM] Add symbol for inception v3 (#1604)

parent 7751a6ba
......@@ -8,6 +8,7 @@ from . import mlp
from . import resnet
from . import vgg
from . import squeezenet
from . import inception_v3
from . import dcgan
from . import dqn
from . import yolo2_detection
......@@ -98,7 +98,7 @@ def get_symbol(num_classes, version, **kwargs):
def get_workload(batch_size=1, num_classes=1000, version='1.0',
image_shape=(3, 224, 224), dtype="float32", **kwargs):
"""Get benchmark workload for resnet
"""Get benchmark workload for SqueezeNet
Parameters
----------
......
......@@ -125,7 +125,7 @@ std::string GraphDeepCompare(const Graph& a,
const IndexedGraph& idxb = b.indexed_graph();
std::ostringstream err;
if (idxa.num_nodes() != idxb.num_nodes()) {
err << "Number of nodes mismatch";
err << "Number of nodes mismatch (" << idxa.num_nodes() << " v.s " << idxb.num_nodes() << ")";
return err.str();
}
if (idxa.num_node_entries() != idxb.num_node_entries()) {
......
"""MXNet and NNVM model zoo."""
from __future__ import absolute_import
from . import mlp, resnet, vgg, dqn, dcgan, squeezenet
from . import mlp, resnet, vgg, dqn, dcgan, squeezenet, inception_v3
import nnvm.testing
__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg',
'mx_squeezenet', 'nnvm_squeezenet']
_num_class = 1000
# mlp fc
......@@ -35,6 +32,10 @@ for version in ['1.0', '1.1']:
mx_squeezenet[version] = squeezenet.get_symbol(version=version)
nnvm_squeezenet[version] = nnvm.testing.squeezenet.get_workload(1, version=version)[0]
# inception
mx_inception_v3 = inception_v3.get_symbol()
nnvm_inception_v3 = nnvm.testing.inception_v3.get_workload(1)[0]
# dqn
mx_dqn = dqn.get_symbol()
nnvm_dqn = nnvm.testing.dqn.get_workload(1)[0]
......
"""
Inception V3, suitable for images with around 299 x 299
Reference:
Szegedy, Christian, et al. "Rethinking the Inception Architecture for Computer Vision." arXiv preprint arXiv:1512.00567 (2015).
Adopted from https://github.com/apache/incubator-mxnet/blob/
master/example/image-classification/symbols/inception-v3.py
"""
import mxnet as mx
import numpy as np
def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix))
bn = mx.sym.BatchNorm(data=conv, eps=2e-5, name='%s%s_batchnorm' % (name, suffix))
act = mx.sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix))
return act
def Inception7A(data,
num_1x1,
num_3x3_red, num_3x3_1, num_3x3_2,
num_5x5_red, num_5x5,
pool, proj,
name):
tower_1x1 = Conv(data, num_1x1, name=('%s_conv' % name))
tower_5x5 = Conv(data, num_5x5_red, name=('%s_tower' % name), suffix='_conv')
tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_2')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv')
concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name)
return concat
# First Downsample
def Inception7B(data,
num_3x3,
num_d3x3_red, num_d3x3_1, num_d3x3_2,
pool,
name):
tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_conv' % name))
tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv')
tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_tower' % name), suffix='_conv_1')
tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name))
concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name)
return concat
def Inception7C(data,
num_1x1,
num_d7_red, num_d7_1, num_d7_2,
num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4,
pool, proj,
name):
tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
tower_d7 = Conv(data=data, num_filter=num_d7_red, name=('%s_tower' % name), suffix='_conv')
tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower' % name), suffix='_conv_1')
tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower' % name), suffix='_conv_2')
tower_q7 = Conv(data=data, num_filter=num_q7_red, name=('%s_tower_1' % name), suffix='_conv')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_1')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_2')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_3')
tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_4')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv')
# concat
concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name)
return concat
def Inception7D(data,
num_3x3_red, num_3x3,
num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3,
pool,
name):
tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=('%s_tower' % name), suffix='_conv')
tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0,0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_1')
tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_2')
tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2), name=('%s_tower_1' % name), suffix='_conv_3')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
# concat
concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name)
return concat
def Inception7E(data,
num_1x1,
num_d3_red, num_d3_1, num_d3_2,
num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2,
pool, proj,
name):
tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
tower_d3 = Conv(data=data, num_filter=num_d3_red, name=('%s_tower' % name), suffix='_conv')
tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower' % name), suffix='_mixed_conv')
tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower' % name), suffix='_mixed_conv_1')
tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red, name=('%s_tower_1' % name), suffix='_conv')
tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower_1' % name), suffix='_mixed_conv')
tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower_1' % name), suffix='_mixed_conv_1')
pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv')
# concat
concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name)
return concat
def get_symbol(num_classes=1000, **kwargs):
data = mx.sym.Variable(name="data")
# stage 1
conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv")
conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1")
conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2")
pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool")
# stage 2
conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3")
conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4")
pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1")
# # stage 3
in3a = Inception7A(pool1, 64,
64, 96, 96,
48, 64,
"avg", 32, "mixed")
in3b = Inception7A(in3a, 64,
64, 96, 96,
48, 64,
"avg", 64, "mixed_1")
in3c = Inception7A(in3b, 64,
64, 96, 96,
48, 64,
"avg", 64, "mixed_2")
in3d = Inception7B(in3c, 384,
64, 96, 96,
"max", "mixed_3")
# stage 4
in4a = Inception7C(in3d, 192,
128, 128, 192,
128, 128, 128, 128, 192,
"avg", 192, "mixed_4")
in4b = Inception7C(in4a, 192,
160, 160, 192,
160, 160, 160, 160, 192,
"avg", 192, "mixed_5")
in4c = Inception7C(in4b, 192,
160, 160, 192,
160, 160, 160, 160, 192,
"avg", 192, "mixed_6")
in4d = Inception7C(in4c, 192,
192, 192, 192,
192, 192, 192, 192, 192,
"avg", 192, "mixed_7")
in4e = Inception7D(in4d, 192, 320,
192, 192, 192, 192,
"max", "mixed_8")
# stage 5
in5a = Inception7E(in4e, 320,
384, 384, 384,
448, 384, 384, 384,
"avg", 192, "mixed_9")
in5b = Inception7E(in5a, 320,
384, 384, 384,
448, 384, 384, 384,
"max", 192, "mixed_10")
# pool
pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool")
flatten = mx.sym.Flatten(data=pool, name="flatten")
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1', flatten=False)
softmax = mx.sym.SoftmaxOutput(data=fc1, name='softmax')
return softmax
......@@ -39,17 +39,23 @@ def test_squeezenet():
nnvm_sym = model_zoo.nnvm_squeezenet[version]
compare_graph(from_mx_sym, nnvm_sym)
def test_inception_v3():
mx_sym = model_zoo.mx_inception_v3
from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_inception_v3
compare_graph(from_mx_sym, nnvm_sym, ishape=(2, 3, 299, 299))
def test_dqn():
mx_sym = model_zoo.mx_dqn
from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_dqn
compare_graph(from_mx_sym, nnvm_sym)
compare_graph(from_mx_sym, nnvm_sym, ishape=(2, 4, 84, 84))
def test_dcgan():
mx_sym = model_zoo.mx_dcgan
from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_dcgan
compare_graph(from_mx_sym, nnvm_sym)
compare_graph(from_mx_sym, nnvm_sym, ishape=(2, 100))
def test_multi_outputs():
def compose(F, **kwargs):
......@@ -70,3 +76,4 @@ if __name__ == '__main__':
test_dqn()
test_dcgan()
test_squeezenet()
test_inception_v3()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment