rpc_server.py 6.02 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
# pylint: disable=redefined-outer-name, invalid-name
18 19 20 21
"""Start an RPC server"""
from __future__ import absolute_import

import argparse
22
import ast
23
import json
24 25
import multiprocessing
import sys
Tianqi Chen committed
26
import logging
27 28
import tvm
from tvm import micro
29
from .. import rpc
30

31
def main(args):
    """Start the RPC server described by ``args`` and block until it exits.

    Parameters
    ----------
    args : argparse.Namespace
        parsed args from command-line invocation

    Raises
    ------
    ValueError
        If ``args.tracker`` is not of the form ``host:port`` with an integer port.
    RuntimeError
        If a tracker is given without a key, or if ``--utvm-dev-config-args``
        is given without ``--utvm-dev-id``.
    """
    tracker_addr = None
    if args.tracker:
        # Split on the last colon only (same split as rsplit(":", 1)), but
        # detect a missing separator / bad port instead of failing on unpacking.
        url, sep, port_str = args.tracker.rpartition(":")
        try:
            port = int(port_str)
        except ValueError:
            port = None
        if not sep or port is None:
            raise ValueError(
                "Invalid tracker address %r: expected host:port, e.g. 10.77.1.234:9190"
                % args.tracker)
        tracker_addr = (url, port)
        if not args.key:
            raise RuntimeError(
                'Need key to present type of resource when tracker is available')

    if args.utvm_dev_config_args and not args.utvm_dev_id:
        # Without a device id there is no generate_config to pass these
        # arguments to; reject them rather than silently ignoring them.
        raise RuntimeError('--utvm-dev-config-args requires --utvm-dev-id')

    if args.utvm_dev_config or args.utvm_dev_id:
        init_utvm(args)

    server = rpc.Server(args.host,
                        args.port,
                        args.port_end,
                        key=args.key,
                        tracker_addr=tracker_addr,
                        load_library=args.load_library,
                        custom_addr=args.custom_addr,
                        silent=args.silent)
    # Block the caller until the server process terminates.
    server.proc.join()

62

63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
def init_utvm(args):
    """MicroTVM-specific RPC initialization.

    Resolves a device configuration, then overrides the
    ``tvm.rpc.server.start`` function so that it opens a ``micro.Session``
    with that configuration. The start function also registers
    ``tvm.rpc.server.shutdown`` to close the session.

    The configuration comes from exactly one of:

    * ``--utvm-dev-config``: a JSON file read with ``json.load``; or
    * ``--utvm-dev-id``: the device module name passed to
      ``micro.device.get_device_funcs``. Its ``generate_config`` is called with
      the arguments parsed from ``--utvm-dev-config-args``. That value is a
      Python literal: a tuple or list is unpacked into positional arguments,
      any other value is passed as a single argument, and an omitted value
      means no arguments.

    Parameters
    ----------
    args : argparse.Namespace
        parsed args from command-line invocation

    Raises
    ------
    RuntimeError
        If both or neither of ``--utvm-dev-config`` and ``--utvm-dev-id`` are
        given, or if ``--utvm-dev-config`` is combined with
        ``--utvm-dev-config-args``.
    """
    if args.utvm_dev_config and args.utvm_dev_id:
        raise RuntimeError('only one of --utvm-dev-config and --utvm-dev-id allowed')
    if args.utvm_dev_config and args.utvm_dev_config_args:
        raise RuntimeError('only one of --utvm-dev-config and --utvm-dev-config-args allowed')
    if not (args.utvm_dev_config or args.utvm_dev_id):
        raise RuntimeError('one of --utvm-dev-config and --utvm-dev-id is required')

    if args.utvm_dev_config:
        with open(args.utvm_dev_config, 'r') as dev_conf_file:
            dev_config = json.load(dev_conf_file)
    else:
        dev_config_args = ()
        if args.utvm_dev_config_args:
            # literal_eval only evaluates Python literals, so the argument
            # string cannot execute arbitrary code.
            dev_config_args = ast.literal_eval(args.utvm_dev_config_args)
            if not isinstance(dev_config_args, (tuple, list)):
                # A single literal (e.g. a string) must be passed as one
                # argument, not unpacked element by element.
                dev_config_args = (dev_config_args,)
        generate_config_func = micro.device.get_device_funcs(args.utvm_dev_id)['generate_config']
        dev_config = generate_config_func(*dev_config_args)

    # add MicroTVM overrides
    @tvm.register_func('tvm.rpc.server.start', override=True)
    def server_start():
        # pylint: disable=unused-variable
        session = micro.Session(dev_config)
        session._enter()

        @tvm.register_func('tvm.rpc.server.shutdown', override=True)
        def server_shutdown():
            session._exit()


95
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default="0.0.0.0",
                        help='the hostname of the server')
    parser.add_argument('--port', type=int, default=9090,
                        help='The port of the RPC')
    parser.add_argument('--port-end', type=int, default=9199,
                        help='The end search port of the RPC')
    parser.add_argument('--tracker', type=str,
                        help=("The address of RPC tracker in host:port format. "
                              "e.g. (10.77.1.234:9190)"))
    parser.add_argument('--key', type=str, default="",
                        help="The key used to identify the device type in tracker.")
    parser.add_argument('--silent', action='store_true',
                        help="Whether run in silent mode.")
    parser.add_argument('--load-library', type=str,
                        help="Additional library to load")
    # store_false: args.fork is True unless --no-fork is passed.
    parser.add_argument('--no-fork', dest='fork', action='store_false',
                        help=("Use spawn mode to avoid fork. This option "
                              "is able to avoid potential fork problems with Metal, OpenCL "
                              "and ROCM compilers."))
    parser.add_argument('--custom-addr', type=str,
                        help="Custom IP Address to Report to RPC Tracker")
    parser.add_argument('--utvm-dev-config', type=str,
                        help=('JSON config file for the target device (if using MicroTVM). '
                              'This file should contain serialized output similar to that returned '
                              "from the device module's generate_config. Can't be specified when "
                              '--utvm-dev-config-args is specified.'))
    parser.add_argument('--utvm-dev-config-args', type=str,
                        help=("Arguments to the device module's generate_config function. "
                              'Must be a python literal parseable by literal_eval. If specified, '
                              "the device configuration is generated using the device module's "
                              "generate_config. Can't be specified when --utvm-dev-config is "
                              "specified."))
    parser.add_argument('--utvm-dev-id', type=str,
                        help=('Unique ID for the target device (if using MicroTVM). Should '
                              'match the name of a module underneath tvm.micro.device.'))

    parser.set_defaults(fork=True)
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    if not args.fork:
        # multiprocessing.set_start_method only exists on Python 3.
        if sys.version_info[0] < 3:
            raise RuntimeError(
                "Python3 is required for spawn mode."
            )
        multiprocessing.set_start_method('spawn')
    elif not args.silent:
        logging.info("If you are running ROCM/Metal, fork will cause "
                     "compiler internal error. Try to launch with arg ```--no-fork```")
    main(args)