"""Utilities for the benchmark scripts."""

import sys

import nnvm
import nnvm.testing  # get_network uses nnvm.testing.*; importing the package alone does not load it

def _name_field(name, sep):
    """Return the component of *name* that follows the first *sep*.

    For example ``_name_field('resnet-18', '-')`` returns ``'18'`` and
    ``_name_field('squeezenet_v1.1', '_v')`` returns ``'1.1'``.

    Raises
    ------
    ValueError
        If *name* does not contain *sep*.
    """
    parts = name.split(sep)
    if len(parts) < 2:
        raise ValueError("Cannot parse network name %r: expected '%s' followed by a "
                         "variant, e.g. 'resnet-18' or 'squeezenet_v1.1'" % (name, sep))
    return parts[1]


def _num_layers(name):
    """Return the integer layer count encoded in a name such as 'resnet-18'.

    Raises
    ------
    ValueError
        If the part after the first '-' is missing or is not an integer.
    """
    field = _name_field(name, '-')
    try:
        return int(field)
    except ValueError:
        raise ValueError("Cannot parse number of layers from network name %r "
                         "(expected e.g. 'resnet-18')" % name)


def get_network(name, batch_size, dtype='float32'):
    """Get the symbol definition and weights of a network for benchmarking.

    Parameters
    ----------
    name: str
        The name of the network. Supported values are 'mobilenet',
        'mobilenet_v2', 'inception_v3', 'resnet-<N>', 'vgg-<N>',
        'densenet-<N>' (where <N> is the number of layers),
        'squeezenet_v<version>', 'custom' (a tiny conv + dense example) and
        'mxnet' (a pretrained Gluon resnet18_v1; requires mxnet).
    batch_size: int
        Batch size.
    dtype: str
        Data type of the generated workload. It is not applied to the
        'mxnet' network, which is converted as-is.

    Returns
    -------
    net: nnvm.symbol
        The NNVM symbol of network definition.
    params: dict
        The parameters of the network (random for the generated workloads,
        pretrained for 'mxnet').
    input_shape: tuple
        The shape of input tensor: (batch_size, 3, 224, 224), or
        (batch_size, 3, 299, 299) for 'inception_v3'.
    output_shape: tuple
        The shape of output tensor: (batch_size, 1000).

    Raises
    ------
    ValueError
        If *name* is not a supported network, or its layer-count / version
        suffix is missing or malformed.
    """
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)

    if name == 'mobilenet':
        net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == 'mobilenet_v2':
        net, params = nnvm.testing.mobilenet_v2.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == 'inception_v3':
        # This workload is benchmarked with 299x299 inputs instead of 224x224.
        input_shape = (batch_size, 3, 299, 299)
        net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
    elif "resnet" in name:
        n_layer = _num_layers(name)
        net, params = nnvm.testing.resnet.get_workload(
            num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif "vgg" in name:
        n_layer = _num_layers(name)
        net, params = nnvm.testing.vgg.get_workload(
            num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif "densenet" in name:
        n_layer = _num_layers(name)
        net, params = nnvm.testing.densenet.get_workload(
            num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif "squeezenet" in name:
        # The version is passed through as a string, e.g. '1.0' or '1.1'.
        version = _name_field(name, "_v")
        net, params = nnvm.testing.squeezenet.get_workload(
            batch_size=batch_size, version=version, dtype=dtype)
    elif name == 'custom':
        # an example for custom network
        from nnvm.testing import utils
        net = nnvm.sym.Variable('data')
        net = nnvm.sym.conv2d(net, channels=4, kernel_size=(3, 3), padding=(1, 1))
        net = nnvm.sym.flatten(net)
        net = nnvm.sym.dense(net, units=1000)
        # Per-sample shape is taken from input_shape so the two cannot drift apart.
        net, params = utils.create_workload(net, batch_size, input_shape[1:], dtype=dtype)
    elif name == 'mxnet':
        # an example for mxnet model (batch_size and dtype are not forwarded here)
        from mxnet.gluon.model_zoo.vision import get_model
        block = get_model('resnet18_v1', pretrained=True)
        net, params = nnvm.frontend.from_mxnet(block)
        net = nnvm.sym.softmax(net)
    else:
        raise ValueError("Unsupported network: " + name)

    return net, params, input_shape, output_shape

def print_progress(msg):
    """Show a progress message on the current terminal line.

    The message is terminated by a carriage return rather than a newline,
    so a subsequent call overwrites it in place. Standard output is flushed
    so the text is visible immediately.

    Parameters
    ----------
    msg: str
        The message to print
    """
    out = sys.stdout
    line = msg + "\r"
    out.write(line)
    out.flush()