# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utility for benchmark"""

import sys

# `tvm` itself must be bound: the 'mxnet' branch of get_network calls
# tvm.IRModule.from_expr, and `from tvm import relay` does not bind `tvm`.
import tvm
from tvm import relay
from tvm.relay import testing


def get_network(name, batch_size, dtype='float32'):
    """Get the symbol definition and random weight of a network

    Parameters
    ----------
    name: str
        The name of the network, can be 'resnet-18', 'resnet-50', 'vgg-16',
        'inception_v3', 'mobilenet', ...
        Names containing 'resnet', 'vgg' or 'densenet' must have the form
        '<family>-<number>'; the number after the first '-' is passed to the
        workload builder. Names containing 'squeezenet' must have the form
        'squeezenet_v<version>'. The special name 'mxnet' converts a
        pretrained MXNet Gluon 'resnet18_v1' model (requires mxnet, which is
        imported only when this name is requested).
    batch_size: int
        batch size
    dtype: str
        Data type

    Returns
    -------
    net: tvm.IRModule
        The relay function of network definition
    params: dict
        The random parameters for benchmark
    input_shape: tuple
        The shape of input tensor; (batch_size, 3, 224, 224) except for
        'inception_v3', which uses (batch_size, 3, 299, 299)
    output_shape: tuple
        The shape of output tensor, always (batch_size, 1000)

    Raises
    ------
    ValueError
        If `name` matches none of the supported networks.
    """
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)

    if name == 'mobilenet':
        net, params = testing.mobilenet.get_workload(
            batch_size=batch_size, dtype=dtype)
    elif name == 'inception_v3':
        # This branch is the only one that overrides the default input size.
        input_shape = (batch_size, 3, 299, 299)
        net, params = testing.inception_v3.get_workload(
            batch_size=batch_size, dtype=dtype)
    elif "resnet" in name:
        n_layer = int(name.split('-')[1])
        net, params = testing.resnet.get_workload(
            num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif "vgg" in name:
        n_layer = int(name.split('-')[1])
        net, params = testing.vgg.get_workload(
            num_layers=n_layer, batch_size=batch_size, dtype=dtype)
    elif "densenet" in name:
        n_layer = int(name.split('-')[1])
        net, params = testing.densenet.get_workload(
            densenet_size=n_layer, batch_size=batch_size, dtype=dtype)
    elif "squeezenet" in name:
        # The version is forwarded as the string that follows "_v".
        version = name.split("_v")[1]
        net, params = testing.squeezenet.get_workload(
            batch_size=batch_size, version=version, dtype=dtype)
    elif name == 'mxnet':
        # an example for mxnet model
        # Imported lazily so mxnet is only required for this branch.
        from mxnet.gluon.model_zoo.vision import get_model

        block = get_model('resnet18_v1', pretrained=True)
        net, params = relay.frontend.from_mxnet(
            block, shape={'data': input_shape}, dtype=dtype)
        # Rebuild the main function with a softmax applied to its body, then
        # wrap the resulting function back into a module.
        net = net["main"]
        net = relay.Function(
            net.params, relay.nn.softmax(net.body), None,
            net.type_params, net.attrs)
        net = tvm.IRModule.from_expr(net)
    else:
        raise ValueError("Unsupported network: " + name)

    return net, params, input_shape, output_shape


def print_progress(msg):
    """print progress message

    Writes `msg` followed by a carriage return (no newline) to stdout and
    flushes immediately. On a typical terminal the carriage return moves the
    cursor back to the start of the line, so the next message overwrites
    this one.

    Parameters
    ----------
    msg: str
        The message to print
    """
    sys.stdout.write(msg + "\r")
    sys.stdout.flush()