Commit 292609d8 by Lianmin Zheng Committed by Tianqi Chen

remove dtype in model symbol (#310)

parent acb9fd62
...@@ -24,7 +24,6 @@ Implemented the following paper: ...@@ -24,7 +24,6 @@ Implemented the following paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks" Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
''' '''
# pylint: disable=unused-argument # pylint: disable=unused-argument
import numpy as np
from .. import symbol as sym from .. import symbol as sym
from . utils import create_workload from . utils import create_workload
...@@ -91,7 +90,7 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True): ...@@ -91,7 +90,7 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True):
return sym.elemwise_add(conv2, shortcut) return sym.elemwise_add(conv2, shortcut)
def resnet(units, num_stages, filter_list, num_classes, image_shape, def resnet(units, num_stages, filter_list, num_classes, image_shape,
bottle_neck=True, dtype='float32'): bottle_neck=True):
"""Return ResNet symbol of """Return ResNet symbol of
Parameters Parameters
---------- ----------
...@@ -105,17 +104,10 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape, ...@@ -105,17 +104,10 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape,
Output size of symbol Output size of symbol
dataset : str dataset : str
Dataset type, only cifar10 and imagenet are supported Dataset type, only cifar10 and imagenet are supported
dtype : str
Precision (float32 or float16)
""" """
num_unit = len(units) num_unit = len(units)
assert num_unit == num_stages assert num_unit == num_stages
data = sym.Variable(name='data') data = sym.Variable(name='data')
if dtype == 'float32':
data = data
else:
if dtype == 'float16':
data = sym.cast(data=data, dtype=np.float16)
data = sym.batch_norm(data=data, epsilon=2e-5, name='bn_data') data = sym.batch_norm(data=data, epsilon=2e-5, name='bn_data')
(_, height, _) = image_shape (_, height, _) = image_shape
if height <= 32: # such as cifar10 if height <= 32: # such as cifar10
...@@ -144,11 +136,9 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape, ...@@ -144,11 +136,9 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape,
pool1 = sym.global_avg_pool2d(data=relu1, name='pool1') pool1 = sym.global_avg_pool2d(data=relu1, name='pool1')
flat = sym.flatten(data=pool1) flat = sym.flatten(data=pool1)
fc1 = sym.dense(data=flat, units=num_classes, name='fc1') fc1 = sym.dense(data=flat, units=num_classes, name='fc1')
if dtype == 'float16':
fc1 = sym.cast(data=fc1, dtype=np.float32)
return sym.softmax(data=fc1, name='softmax') return sym.softmax(data=fc1, name='softmax')
def get_symbol(num_classes, num_layers=50, image_shape=(3, 224, 224), dtype='float32', **kwargs): def get_symbol(num_classes, num_layers=50, image_shape=(3, 224, 224), **kwargs):
""" """
Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
Original author Wei Wu Original author Wei Wu
...@@ -197,8 +187,7 @@ def get_symbol(num_classes, num_layers=50, image_shape=(3, 224, 224), dtype='flo ...@@ -197,8 +187,7 @@ def get_symbol(num_classes, num_layers=50, image_shape=(3, 224, 224), dtype='flo
filter_list=filter_list, filter_list=filter_list,
num_classes=num_classes, num_classes=num_classes,
image_shape=image_shape, image_shape=image_shape,
bottle_neck=bottle_neck, bottle_neck=bottle_neck)
dtype=dtype)
def get_workload(batch_size=1, num_classes=1000, num_layers=18, def get_workload(batch_size=1, num_classes=1000, num_layers=18,
image_shape=(3, 224, 224), dtype="float32", **kwargs): image_shape=(3, 224, 224), dtype="float32", **kwargs):
...@@ -233,5 +222,5 @@ def get_workload(batch_size=1, num_classes=1000, num_layers=18, ...@@ -233,5 +222,5 @@ def get_workload(batch_size=1, num_classes=1000, num_layers=18,
The parameters. The parameters.
""" """
net = get_symbol(num_classes=num_classes, num_layers=num_layers, net = get_symbol(num_classes=num_classes, num_layers=num_layers,
image_shape=image_shape, dtype=dtype, **kwargs) image_shape=image_shape, **kwargs)
return create_workload(net, batch_size, image_shape, dtype) return create_workload(net, batch_size, image_shape, dtype)
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014). large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
""" """
import numpy as np
from .. import symbol as sym from .. import symbol as sym
from . utils import create_workload from . utils import create_workload
...@@ -51,7 +50,7 @@ def get_classifier(input_data, num_classes): ...@@ -51,7 +50,7 @@ def get_classifier(input_data, num_classes):
fc8 = sym.dense(data=drop7, units=num_classes, name="fc8") fc8 = sym.dense(data=drop7, units=num_classes, name="fc8")
return fc8 return fc8
def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32'): def get_symbol(num_classes, num_layers=11, batch_norm=False):
""" """
Parameters Parameters
---------- ----------
...@@ -61,8 +60,6 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32'): ...@@ -61,8 +60,6 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32'):
Number of layers for the variant of VGG. Options are 11, 13, 16, 19. Number of layers for the variant of VGG. Options are 11, 13, 16, 19.
batch_norm : bool, default False batch_norm : bool, default False
Use batch normalization. Use batch normalization.
dtype: str, float32 or float16
Data precision.
""" """
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]), vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]), 13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
...@@ -72,12 +69,8 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32'): ...@@ -72,12 +69,8 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32'):
raise ValueError("Invalide num_layers {}. Choices are 11,13,16,19.".format(num_layers)) raise ValueError("Invalide num_layers {}. Choices are 11,13,16,19.".format(num_layers))
layers, filters = vgg_spec[num_layers] layers, filters = vgg_spec[num_layers]
data = sym.Variable(name="data") data = sym.Variable(name="data")
if dtype == 'float16':
data = sym.cast(data=data, dtype=np.float16)
feature = get_feature(data, layers, filters, batch_norm) feature = get_feature(data, layers, filters, batch_norm)
classifier = get_classifier(feature, num_classes) classifier = get_classifier(feature, num_classes)
if dtype == 'float16':
classifier = sym.cast(data=classifier, dtype=np.float32)
symbol = sym.softmax(data=classifier, name='softmax') symbol = sym.softmax(data=classifier, name='softmax')
return symbol return symbol
...@@ -110,5 +103,5 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), ...@@ -110,5 +103,5 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224),
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
""" """
net = get_symbol(num_classes=num_classes, dtype=dtype, **kwargs) net = get_symbol(num_classes=num_classes, **kwargs)
return create_workload(net, batch_size, image_shape, dtype) return create_workload(net, batch_size, image_shape, dtype)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment