Commit 292609d8 by Lianmin Zheng Committed by Tianqi Chen

remove dtype in model symbol (#310)

parent acb9fd62
......@@ -24,7 +24,6 @@ Implemented the following paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
# pylint: disable=unused-argument
import numpy as np
from .. import symbol as sym
from . utils import create_workload
......@@ -91,7 +90,7 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True):
return sym.elemwise_add(conv2, shortcut)
def resnet(units, num_stages, filter_list, num_classes, image_shape,
bottle_neck=True, dtype='float32'):
bottle_neck=True):
"""Return ResNet symbol of
Parameters
----------
......@@ -105,17 +104,10 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape,
Ouput size of symbol
dataset : str
Dataset type, only cifar10 and imagenet supports
dtype : str
Precision (float32 or float16)
"""
num_unit = len(units)
assert num_unit == num_stages
data = sym.Variable(name='data')
if dtype == 'float32':
data = data
else:
if dtype == 'float16':
data = sym.cast(data=data, dtype=np.float16)
data = sym.batch_norm(data=data, epsilon=2e-5, name='bn_data')
(_, height, _) = image_shape
if height <= 32: # such as cifar10
......@@ -144,11 +136,9 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape,
pool1 = sym.global_avg_pool2d(data=relu1, name='pool1')
flat = sym.flatten(data=pool1)
fc1 = sym.dense(data=flat, units=num_classes, name='fc1')
if dtype == 'float16':
fc1 = sym.cast(data=fc1, dtype=np.float32)
return sym.softmax(data=fc1, name='softmax')
def get_symbol(num_classes, num_layers=50, image_shape=(3, 224, 224), dtype='float32', **kwargs):
def get_symbol(num_classes, num_layers=50, image_shape=(3, 224, 224), **kwargs):
"""
Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
Original author Wei Wu
......@@ -197,8 +187,7 @@ def get_symbol(num_classes, num_layers=50, image_shape=(3, 224, 224), dtype='flo
filter_list=filter_list,
num_classes=num_classes,
image_shape=image_shape,
bottle_neck=bottle_neck,
dtype=dtype)
bottle_neck=bottle_neck)
def get_workload(batch_size=1, num_classes=1000, num_layers=18,
image_shape=(3, 224, 224), dtype="float32", **kwargs):
......@@ -233,5 +222,5 @@ def get_workload(batch_size=1, num_classes=1000, num_layers=18,
The parameters.
"""
net = get_symbol(num_classes=num_classes, num_layers=num_layers,
image_shape=image_shape, dtype=dtype, **kwargs)
image_shape=image_shape, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
......@@ -20,7 +20,6 @@
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""
import numpy as np
from .. import symbol as sym
from . utils import create_workload
......@@ -51,7 +50,7 @@ def get_classifier(input_data, num_classes):
fc8 = sym.dense(data=drop7, units=num_classes, name="fc8")
return fc8
def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32'):
def get_symbol(num_classes, num_layers=11, batch_norm=False):
"""
Parameters
----------
......@@ -61,8 +60,6 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32'):
Number of layers for the variant of densenet. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
dtype: str, float32 or float16
Data precision.
"""
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
......@@ -72,12 +69,8 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32'):
raise ValueError("Invalide num_layers {}. Choices are 11,13,16,19.".format(num_layers))
layers, filters = vgg_spec[num_layers]
data = sym.Variable(name="data")
if dtype == 'float16':
data = sym.cast(data=data, dtype=np.float16)
feature = get_feature(data, layers, filters, batch_norm)
classifier = get_classifier(feature, num_classes)
if dtype == 'float16':
classifier = sym.cast(data=classifier, dtype=np.float32)
symbol = sym.softmax(data=classifier, name='softmax')
return symbol
......@@ -110,5 +103,5 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224),
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, dtype=dtype, **kwargs)
net = get_symbol(num_classes=num_classes, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment