Commit dc8fd79c by Wuwei Lin Committed by Tianqi Chen

[RELAY] Add missing arg in vgg (#2329)

parent 14acb80a
...@@ -98,7 +98,8 @@ def get_workload(batch_size, ...@@ -98,7 +98,8 @@ def get_workload(batch_size,
num_classes=1000, num_classes=1000,
image_shape=(3, 224, 224), image_shape=(3, 224, 224),
dtype="float32", dtype="float32",
num_layers=11): num_layers=11,
batch_norm=False):
"""Get benchmark workload for VGG nets. """Get benchmark workload for VGG nets.
Parameters Parameters
...@@ -118,6 +119,9 @@ def get_workload(batch_size, ...@@ -118,6 +119,9 @@ def get_workload(batch_size,
num_layers : int num_layers : int
Number of layers for the variant of vgg. Options are 11, 13, 16, 19. Number of layers for the variant of vgg. Options are 11, 13, 16, 19.
batch_norm : bool
Use batch normalization.
Returns Returns
------- -------
net : nnvm.Symbol net : nnvm.Symbol
...@@ -126,5 +130,5 @@ def get_workload(batch_size, ...@@ -126,5 +130,5 @@ def get_workload(batch_size,
params : dict of str to NDArray params : dict of str to NDArray
The parameters. The parameters.
""" """
net = get_net(batch_size, image_shape, num_classes, dtype, num_layers) net = get_net(batch_size, image_shape, num_classes, dtype, num_layers, batch_norm)
return create_workload(net) return create_workload(net)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment