"""MXNet and NNVM model zoo."""
from __future__ import absolute_import
from . import mlp, resnet, vgg, dqn, dcgan, squeezenet, inception_v3
import nnvm.testing

# Class count passed to the mlp, resnet and vgg builders in this module.
_num_class = 1000

# mlp fc
# MXNet symbol for the MLP and the matching NNVM graph.
mx_mlp = mlp.get_symbol(_num_class)
# NOTE(review): the leading 1 is presumably the batch size, and element 0 of
# the returned workload presumably the network symbol -- confirm against
# nnvm.testing.
nnvm_mlp = nnvm.testing.mlp.get_workload(1, _num_class)[0]

# resnet fc
# MXNet and NNVM ResNet graphs, each stored under its num_layer value.
mx_resnet = {}
nnvm_resnet = {}
for num_layer in [18, 34, 50, 101, 152, 200, 269]:
    # '3,224,224' is handed to the MXNet builder as a string; presumably the
    # input image shape (channels, height, width) -- confirm in resnet.py.
    mx_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3,224,224')
    # Only element 0 of the NNVM workload is kept (presumably the network
    # symbol -- confirm against nnvm.testing.resnet.get_workload).
    nnvm_resnet[num_layer] = nnvm.testing.resnet.get_workload(
        1, _num_class, num_layers=num_layer)[0]

# vgg fc
# MXNet and NNVM VGG graphs, each stored under its num_layer value.
mx_vgg = {}
nnvm_vgg = {}
for num_layer in [11, 13, 16, 19]:
    mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
    # Only element 0 of the NNVM workload is kept (presumably the network
    # symbol -- confirm against nnvm.testing.vgg.get_workload).
    nnvm_vgg[num_layer] = nnvm.testing.vgg.get_workload(
        1, _num_class, num_layers=num_layer)[0]

# squeezenet
# MXNet and NNVM SqueezeNet graphs, each stored under its version string.
mx_squeezenet = {}
nnvm_squeezenet = {}
for version in ['1.0', '1.1']:
    mx_squeezenet[version] = squeezenet.get_symbol(version=version)
    # Only element 0 of the NNVM workload is kept (presumably the network
    # symbol -- confirm against nnvm.testing.squeezenet.get_workload).
    nnvm_squeezenet[version] = nnvm.testing.squeezenet.get_workload(
        1, version=version)[0]

# inception
# MXNet symbol for Inception v3, built with the builder's default arguments.
mx_inception_v3 = inception_v3.get_symbol()
# NOTE(review): the 1 is presumably the batch size, and element 0 of the
# returned workload presumably the network symbol -- confirm against
# nnvm.testing.
nnvm_inception_v3 = nnvm.testing.inception_v3.get_workload(1)[0]

# dqn
# MXNet symbol for the DQN model, built with the builder's default arguments.
mx_dqn = dqn.get_symbol()
# NOTE(review): the 1 is presumably the batch size, and element 0 of the
# returned workload presumably the network symbol -- confirm against
# nnvm.testing.
nnvm_dqn = nnvm.testing.dqn.get_workload(1)[0]

# dcgan generator
# MXNet symbol for the DCGAN model, built with the builder's default arguments.
mx_dcgan = dcgan.get_symbol()
# NOTE(review): the 1 is presumably the batch size, and element 0 of the
# returned workload presumably the network symbol -- confirm against
# nnvm.testing.
nnvm_dcgan = nnvm.testing.dcgan.get_workload(1)[0]