#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import ast
import csv
import logging
import os
from os import path as osp
import sys

import numpy as np

import tvm
from tvm import te
from tvm import relay
from tvm.relay import testing
from tvm.contrib import graph_runtime, cc

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Command-line interface. Arguments are parsed at import time and the
# functions below read their configuration from these module-level globals.
parser = argparse.ArgumentParser(description='Resnet build example')
aa = parser.add_argument
aa('--build-dir', type=str, required=True, help='directory to put the build artifacts')
aa('--pretrained', action='store_true', help='use a pretrained resnet')
aa('--batch-size', type=int, default=1, help='input image batch size')
aa('--opt-level', type=int, default=3,
   help='level of optimization. 0 is unoptimized and 3 is the highest level')
aa('--target', type=str, default='llvm', help='target context for compilation')
aa('--image-shape', type=str, default='3,224,224', help='input image dimensions')
aa('--image-name', type=str, default='cat.png', help='name of input image to download')
args = parser.parse_args()

build_dir = args.build_dir
batch_size = args.batch_size
opt_level = args.opt_level
target = tvm.target.create(args.target)
try:
    # "3,224,224" -> (3, 224, 224); int() tolerates surrounding spaces.
    image_shape = tuple(int(dim) for dim in args.image_shape.split(","))
except ValueError:
    # Exit with argparse's usage message instead of a raw traceback.
    parser.error("--image-shape must be comma-separated integers, e.g. 3,224,224")
# Full input shape: the batch dimension prepended to the image shape.
data_shape = (batch_size,) + image_shape


def build(target_dir):
    """Compile ResNet-18 with TVM and save the deployable artifacts.

    Writes ``deploy_lib.o``, ``deploy_lib.so``, ``deploy_graph.json`` and
    ``deploy_param.params`` into *target_dir*, creating the directory if it
    is missing. Model choice and compilation settings come from the
    module-level CLI globals (``args.pretrained``, ``data_shape``,
    ``batch_size``, ``image_shape``, ``opt_level``, ``target``).

    Compilation is skipped when all four artifacts already exist.

    Parameters
    ----------
    target_dir : str
        Directory that receives the build artifacts.
    """
    deploy_lib = osp.join(target_dir, 'deploy_lib.o')
    deploy_so = osp.join(target_dir, 'deploy_lib.so')
    deploy_graph = osp.join(target_dir, 'deploy_graph.json')
    deploy_params = osp.join(target_dir, 'deploy_param.params')
    # Check every artifact, not just the object file. Otherwise a run that
    # died after lib.save() would be skipped forever, and test_build would
    # lack the .so / graph / params files it loads.
    if all(osp.exists(p) for p in (deploy_lib, deploy_so, deploy_graph, deploy_params)):
        return
    os.makedirs(target_dir, exist_ok=True)

    if args.pretrained:
        # needs mxnet installed
        from mxnet.gluon.model_zoo.vision import get_model

        # if `--pretrained` is enabled, it downloads a pretrained
        # resnet18 trained on imagenet1k dataset for image classification task
        block = get_model('resnet18_v1', pretrained=True)
        mod, params = relay.frontend.from_mxnet(block, {"data": data_shape})
        # from_mxnet returns an IRModule, which has no .params/.body; the
        # graph lives in its "main" function. Append a softmax so the model
        # outputs probabilities, then wrap the new function back into a module.
        func = mod["main"]
        func = relay.Function(func.params, relay.nn.softmax(func.body),
                              None, func.type_params, func.attrs)
        net = tvm.IRModule.from_expr(func)
    else:
        # use random weights from relay.testing
        net, params = relay.testing.resnet.get_workload(
            num_layers=18, batch_size=batch_size, image_shape=image_shape)

    # compile the model (PassContext replaces the deprecated relay.build_config)
    with tvm.transform.PassContext(opt_level=opt_level):
        graph, lib, params = relay.build_module.build(net, target, params=params)

    # save the model artifacts
    lib.save(deploy_lib)
    cc.create_shared(deploy_so, [deploy_lib])

    with open(deploy_graph, "w") as fo:
        fo.write(graph)

    with open(deploy_params, "wb") as fo:
        fo.write(relay.save_param_dict(params))


def download_img_labels():
    """Download a test image and the ImageNet-1k class labels.

    The image is saved under ``args.image_name`` (default ``cat.png``). The
    raw label mapping is saved as ``synset.txt``, and its ``(class id, label)``
    pairs are written to ``synset.csv``. All paths are relative to the
    current working directory.
    """
    from mxnet.gluon.utils import download

    # Honour --image-name, which was previously parsed but ignored. The
    # default matches the old hard-coded 'cat.png'; the source URL is fixed.
    img_name = args.image_name
    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
                          '4d0b62f3d01426887599d4f7ede23ee5/raw/',
                          '596b27d23537e5a1b5751d2b0481ef172f58b539/',
                          'imagenet1000_clsid_to_human.txt'])
    synset_name = 'synset.txt'
    download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
    download(synset_url, synset_name)

    # The synset file is a Python dict literal fetched over the network.
    # Parse it with literal_eval instead of eval so it cannot execute code.
    with open(synset_name, encoding='utf-8') as fin:
        synset = ast.literal_eval(fin.read())

    # The csv module requires newline='' on the file it writes to.
    with open("synset.csv", "w", newline='', encoding='utf-8') as fout:
        w = csv.writer(fout)
        w.writerows(synset.items())


def test_build(build_dir):
    """Sanity-check the build artifacts by running them once on random input.

    Loads ``deploy_graph.json``, ``deploy_lib.so`` and ``deploy_param.params``
    from *build_dir*, runs one inference on uniform random float32 data of
    shape ``data_shape``, and logs the shape of output 0.

    Parameters
    ----------
    build_dir : str
        Directory that ``build`` wrote its artifacts to.
    """
    # Use context managers so the file handles are closed deterministically.
    with open(osp.join(build_dir, "deploy_graph.json")) as fin:
        graph = fin.read()
    lib = tvm.runtime.load(osp.join(build_dir, "deploy_lib.so"))
    with open(osp.join(build_dir, "deploy_param.params"), "rb") as fin:
        params = bytearray(fin.read())

    input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
    # NOTE(review): always executes on the CPU context, so this presumably
    # only works for CPU targets such as the default 'llvm' -- confirm.
    ctx = tvm.cpu()
    module = graph_runtime.create(graph, lib, ctx)
    module.load_params(params)
    module.run(data=input_data)
    out = module.get_output(0).asnumpy()
    logger.info("output shape: %s", out.shape)


if __name__ == '__main__':
    # Compile (or reuse) the artifacts, then smoke-test them.
    logger.info("building the model")
    build(build_dir)
    logger.info("build was successful")
    logger.info("test the build artifacts")
    test_build(build_dir)
    logger.info("test was successful")
    # Test image and labels are only fetched for the pretrained model.
    if args.pretrained:
        download_img_labels()
        logger.info("image and synset downloads are successful")