# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-outer-name, invalid-name
"""Start an RPC server"""
from __future__ import absolute_import

import argparse
import ast
import json
import logging
import multiprocessing
import sys

import tvm
from tvm import micro

from .. import rpc

def main(args):
    """Main function

    Resolve the optional tracker address, install the MicroTVM overrides when
    a MicroTVM device option is given, then start an ``rpc.Server`` and wait
    on its ``proc``.

    Parameters
    ----------
    args : argparse.Namespace
        parsed args from command-line invocation

    Raises
    ------
    ValueError
        If ``args.tracker`` is not of the form ``host:port`` with an integer port.
    RuntimeError
        If a tracker is given but ``args.key`` is empty.
    """
    tracker_addr = None
    if args.tracker:
        # Split on the last ':' so only the trailing component is the port.
        url, sep, port_str = args.tracker.rpartition(":")
        try:
            port = int(port_str)
        except ValueError:
            port = None
        if not sep or port is None:
            raise ValueError(
                "Invalid tracker address {!r}, expected host:port "
                "(e.g. 10.77.1.234:9190)".format(args.tracker))
        tracker_addr = (url, port)
        if not args.key:
            raise RuntimeError(
                "Need key to present type of resource when tracker is available")

    if args.utvm_dev_config or args.utvm_dev_id:
        init_utvm(args)

    server = rpc.Server(args.host,
                        args.port,
                        args.port_end,
                        key=args.key,
                        tracker_addr=tracker_addr,
                        load_library=args.load_library,
                        custom_addr=args.custom_addr,
                        silent=args.silent)
    # Keep this launcher alive until the server's process finishes.
    server.proc.join()


def init_utvm(args):
    """MicroTVM-specific RPC initialization

    Build a MicroTVM device config and override the
    ``tvm.rpc.server.start`` hook so that starting the server opens a
    ``micro.Session`` on that config. The start hook in turn registers a
    ``tvm.rpc.server.shutdown`` hook that exits the same session.

    The device config comes either from a JSON file (``args.utvm_dev_config``)
    or from the ``default_config`` function of the device named by
    ``args.utvm_dev_id``. That function is called with the literals parsed
    from ``args.utvm_dev_config_args``. If neither option is set this function
    does nothing.

    Parameters
    ----------
    args : argparse.Namespace
        parsed args from command-line invocation

    Raises
    ------
    RuntimeError
        If both ``--utvm-dev-config`` and ``--utvm-dev-id`` are given, or if
        ``--utvm-dev-config-args`` does not evaluate to a list/tuple.
    """
    if args.utvm_dev_config and args.utvm_dev_id:
        raise RuntimeError('only one of --utvm-dev-config and --utvm-dev-id allowed')
    if not (args.utvm_dev_config or args.utvm_dev_id):
        # Nothing to configure; leave the default RPC server hooks in place.
        return

    if args.utvm_dev_config:
        with open(args.utvm_dev_config, 'r') as dev_conf_file:
            dev_config = json.load(dev_conf_file)
    else:
        raw_config_args = args.utvm_dev_config_args
        # The flag is optional: a missing/empty value means "no arguments".
        # literal_eval (not eval) so only Python literals are accepted.
        dev_config_args = ast.literal_eval(raw_config_args) if raw_config_args else []
        if not isinstance(dev_config_args, (list, tuple)):
            raise RuntimeError(
                '--utvm-dev-config-args must be a Python list of literals, got {!r}'
                .format(raw_config_args))
        default_config_func = micro.device.get_device_funcs(args.utvm_dev_id)['default_config']
        dev_config = default_config_func(*dev_config_args)

    # add MicroTVM overrides
    @tvm.register_func('tvm.rpc.server.start', override=True)
    def server_start():
        # pylint: disable=unused-variable
        session = micro.Session(dev_config)
        session._enter()

        # Registered per start so the shutdown hook closes this very session.
        @tvm.register_func('tvm.rpc.server.shutdown', override=True)
        def server_shutdown():
            session._exit()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default="0.0.0.0",
                        help='the hostname of the server')
    parser.add_argument('--port', type=int, default=9090,
                        help='The port of the RPC')
    parser.add_argument('--port-end', type=int, default=9199,
                        help='The end search port of the RPC')
    parser.add_argument('--tracker', type=str,
                        help="The address of RPC tracker in host:port format. "
                             "e.g. (10.77.1.234:9190)")
    parser.add_argument('--key', type=str, default="",
                        help="The key used to identify the device type in tracker.")
    parser.add_argument('--silent', action='store_true',
                        help="Whether run in silent mode.")
    parser.add_argument('--load-library', type=str,
                        help="Additional library to load")
    # Implicit string concatenation (instead of a backslash continuation inside
    # the literal) keeps source indentation out of the help text.
    parser.add_argument('--no-fork', dest='fork', action='store_false',
                        help="Use spawn mode to avoid fork. This option "
                             "is able to avoid potential fork problems with Metal, OpenCL "
                             "and ROCM compilers.")
    parser.add_argument('--custom-addr', type=str,
                        help="Custom IP Address to Report to RPC Tracker")
    # A MicroTVM device is given either by config file or by ID, never both;
    # let argparse reject the combination at parse time with a usage message.
    utvm_dev_group = parser.add_mutually_exclusive_group()
    utvm_dev_group.add_argument('--utvm-dev-config', type=str,
                                help='JSON config file for the target device '
                                     '(if using MicroTVM)')
    utvm_dev_group.add_argument('--utvm-dev-id', type=str,
                                help='Unique ID for the target device '
                                     '(if using MicroTVM)')
    parser.add_argument('--utvm-dev-config-args', type=str,
                        help=('Python list of literals required to generate a default'
                              ' MicroTVM config (if --utvm-dev-id is specified)'))

    parser.set_defaults(fork=True)
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    if not args.fork:
        if sys.version_info[0] < 3:
            raise RuntimeError(
                "Python3 is required for spawn mode."
            )
        multiprocessing.set_start_method('spawn')
    elif not args.silent:
        logging.info("If you are running ROCM/Metal, fork will cause "
                     "compiler internal error. Try to launch with arg ```--no-fork```")
    main(args)