Commit b1cf70a8 by Siju, committed by Tianqi Chen

[RELAY]Testing Inception, Squeezenet, VGG port (#2013)

parent 53ac89ed
...@@ -7,4 +7,7 @@ from . import dqn
from . import dcgan
from . import mobilenet
from . import lstm
from . import inception_v3
from . import squeezenet
from . import vgg
from .config import ctx_list
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=unused-argument
"""
Symbol of SqueezeNet
Reference:
Iandola, Forrest N., et al.
"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
"""
from tvm import relay
from .init import create_workload
from . import layers
# Helpers
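# A "fire" module (below) first squeezes the input with a 1x1 convolution and then
# expands it with parallel 1x1 and 3x3 convolutions whose outputs are concatenated
# along the channel axis (NCHW layout), following the SqueezeNet paper.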
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
net = _make_fire_conv(net, squeeze_channels, 1, 0)
left = _make_fire_conv(net, expand1x1_channels, 1, 0)
right = _make_fire_conv(net, expand3x3_channels, 3, 1)
# NOTE : Assume NCHW layout here
net = relay.concatenate((left, right), axis=1)
return net
def _make_fire_conv(net, channels, kernel_size, padding=0):
net = layers.conv2d(net, channels=channels, kernel_size=(kernel_size, kernel_size),
padding=(padding, padding), name="conv2d")
net = relay.nn.relu(net)
return net
# Net
def get_net(batch_size, image_shape, num_classes, version, dtype):
"""Get symbol of SqueezeNet
Parameters
----------
batch_size : int
The batch size used in the model
image_shape : tuple, optional
The input image shape
num_classes: int
The number of classification results
version : str, optional
"1.0" or "1.1" of SqueezeNet
"""
assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}: "
"1.0 or 1.1 expected".format(version=version))
data_shape = (batch_size,) + image_shape
net = relay.var("data", shape=data_shape, dtype=dtype)
if version == '1.0':
net = layers.conv2d(net,
channels=96,
kernel_size=(7, 7),
strides=(2, 2),
padding=(3, 3),
name="conv2d")
net = relay.nn.bias_add(net, relay.var("dense1_bias"))
net = relay.nn.relu(net)
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 32, 128, 128)
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 64, 256, 256)
else:
net = layers.conv2d(net,
channels=64,
kernel_size=(3, 3),
strides=(2, 2),
padding=(1, 1),
name="conv2d")
net = relay.nn.relu(net)
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128)
net = relay.nn.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256)
net = relay.nn.dropout(net, rate=0.5)
net = layers.conv2d(net, channels=num_classes, kernel_size=(1, 1), name="conv2d")
net = relay.nn.relu(net)
net = relay.nn.global_avg_pool2d(net)
net = relay.nn.batch_flatten(net)
net = relay.nn.softmax(net)
args = relay.ir_pass.free_vars(net)
return relay.Function(args, net)
def get_workload(batch_size=1, num_classes=1000, version='1.0',
image_shape=(3, 224, 224), dtype="float32"):
"""Get benchmark workload for SqueezeNet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
version : str, optional
"1.0" or "1.1" of SqueezeNet
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : relay.Function
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_net(batch_size, image_shape, num_classes, version, dtype)
return create_workload(net)
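# Usage sketch (illustrative only; it mirrors the unit test added in this commit):
#
#   net, params = get_workload(batch_size=1, version='1.1')
#   print(net.astext())  # text form of the Relay IR; params maps names to NDArrays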
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""References:
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""
from tvm import relay
from .init import create_workload
from . import layers as wrapper
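# The feature extractor below stacks layers[i] 3x3 conv + ReLU blocks with filters[i]
# output channels for each stage i, optionally with batch norm, and closes every
# stage with a 2x2, stride-2 max pooling.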
def get_feature(internel_layer, layers, filters, batch_norm=False):
"""Get VGG feature body as stacks of convoltions."""
for i, num in enumerate(layers):
for j in range(num):
internel_layer = wrapper.conv2d(
data=internel_layer, kernel_size=(3, 3), padding=(1, 1),
channels=filters[i], name="conv%s_%s"%(i + 1, j + 1))
if batch_norm:
internel_layer = wrapper.batch_norm_infer(
data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
internel_layer = relay.nn.relu(data=internel_layer)
internel_layer = relay.nn.max_pool2d(
data=internel_layer, pool_size=(2, 2), strides=(2, 2))
return internel_layer
def get_classifier(input_data, num_classes):
"""Get VGG classifier layers as fc layers."""
flatten = relay.nn.batch_flatten(data=input_data)
fc6 = wrapper.dense_add_bias(data=flatten, units=4096, name="fc6")
relu6 = relay.nn.relu(data=fc6)
drop6 = relay.nn.dropout(data=relu6, rate=0.5)
fc7 = wrapper.dense_add_bias(data=drop6, units=4096, name="fc7")
relu7 = relay.nn.relu(data=fc7)
drop7 = relay.nn.dropout(data=relu7, rate=0.5)
fc8 = wrapper.dense_add_bias(data=drop7, units=num_classes, name="fc8")
return fc8
def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_norm=False):
"""
Parameters
----------
batch_size : int
The batch size used in the model
image_shape : tuple
The input image shape
num_classes : int
Number of classes
dtype : str
The data type
num_layers : int, optional
Number of layers for the VGG variant. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
"""
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
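# Each entry maps a named depth to (conv layers per stage, channels per stage);
# e.g. VGG-16 uses 2+2+3+3+3 = 13 conv layers plus the 3 dense layers from
# get_classifier, for 16 weight layers in total.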
if num_layers not in vgg_spec:
raise ValueError("Invalide num_layers {}. Choices are 11,13,16,19.".format(num_layers))
layers, filters = vgg_spec[num_layers]
data_shape = (batch_size,) + image_shape
data = relay.var("data", shape=data_shape, dtype=dtype)
feature = get_feature(data, layers, filters, batch_norm)
classifier = get_classifier(feature, num_classes)
symbol = relay.nn.softmax(data=classifier)
args = relay.ir_pass.free_vars(symbol)
return relay.Function(args, symbol)
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"):
"""Get benchmark workload for VGG nets.
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : relay.Function
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_net(batch_size, image_shape, num_classes, dtype)
return create_workload(net)
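# Note: get_workload builds the default 11-layer variant; deeper configurations
# (13/16/19) can be obtained by calling get_net directly with the desired num_layers.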
...@@ -124,6 +124,18 @@ def test_lstm():
net, params = tvm.relay.testing.lstm.get_workload(4, 4)
net.astext()
def test_inception_v3():
net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
net.astext()
def test_squeezenet():
for version in ['1.0', '1.1']:
net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
net.astext()
def test_vgg():
net, params = tvm.relay.testing.vgg.get_workload(batch_size=1)
net.astext()
if __name__ == "__main__":
do_print[0] = True
...@@ -132,6 +144,9 @@ if __name__ == "__main__":
test_mlp()
test_dqn()
test_dcgan()
test_squeezenet()
test_inception_v3()
test_vgg()
test_func()
test_env()
test_meta_data()
...