base.py 7.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
"""Base definitions for MicroTVM"""
18 19 20 21 22

from __future__ import absolute_import

import os
import sys
23
from enum import Enum
24

25
import tvm
26 27
import tvm._ffi

28 29
from tvm.contrib import util as _util
from tvm.contrib import cc as _cc
30

31

32 33 34 35 36 37 38
class LibType(Enum):
    """Kinds of library that can be compiled for, and loaded onto, a device.

    Members
    -------
    RUNTIME
        the library serves as the MicroTVM runtime
    OPERATOR
        the library serves as an operator
    """
    RUNTIME = 0   # MicroTVM runtime library
    OPERATOR = 1  # operator library

39 40 41 42 43 44

class Session:
    """MicroTVM Device Session

    Compiles the host-driven runtime (`utvm_runtime.c`) with the device's
    `create_micro_lib` function and creates a session module through
    `_CreateSession`.  The object is a context manager: entering and leaving
    the `with` block call the session module's "enter" and "exit" functions.

    Parameters
    ----------
    config : dict
        configuration for this session (as generated by
        `tvm.micro.device.host.default_config()`, for example)

    Raises
    ------
    RuntimeError
        if the host is not a 64-bit Linux system, or if
        `config["comms_method"]` is neither "openocd" nor "host"

    Example
    -------
    .. code-block:: python

      c_mod = ...  # some module generated with "c" as the target
      dev_config = micro.device.arm.stm32f746xx.default_config("127.0.0.1", 6666)
      with tvm.micro.Session(dev_config) as sess:
          micro_mod = create_micro_mod(c_mod, dev_config)
    """

    # Memory sections, in the order in which their (start, size) pairs are
    # passed positionally to `_CreateSession`.
    _MEM_SECTIONS = ("text", "rodata", "data", "bss",
                     "args", "heap", "workspace", "stack")

    def __init__(self, config):
        self._check_system()
        # TODO(weberlo): add config validation

        # grab a binutil instance from the ID in the config
        dev_funcs = tvm.micro.device.get_device_funcs(config["device_id"])
        self.create_micro_lib = dev_funcs["create_micro_lib"]
        self.toolchain_prefix = config["toolchain_prefix"]
        self.mem_layout = config["mem_layout"]
        self.word_size = config["word_size"]
        self.thumb_mode = config["thumb_mode"]
        self.comms_method = config["comms_method"]

        # Resolve the server settings first so that an unknown communication
        # method is reported before any time is spent compiling the runtime.
        server_addr, server_port = self._server_endpoint(config)

        # Find and compile the runtime library.
        runtime_src_path = os.path.join(get_micro_host_driven_dir(), "utvm_runtime.c")
        tmp_dir = _util.tempdir()
        runtime_obj_path = tmp_dir.relpath("utvm_runtime.obj")
        self.create_micro_lib(runtime_obj_path, runtime_src_path, LibType.RUNTIME)

        self.module = _CreateSession(
            self.comms_method,
            runtime_obj_path,
            self.toolchain_prefix,
            *self._mem_layout_args(self.mem_layout),
            self.word_size,
            self.thumb_mode,
            server_addr,
            server_port)
        self._enter = self.module["enter"]
        self._exit = self.module["exit"]

    @staticmethod
    def _server_endpoint(config):
        """Return the `(server_addr, server_port)` pair for the session.

        Parameters
        ----------
        config : dict
            session configuration; `config["comms_method"]` selects where the
            pair comes from

        Return
        ------
        endpoint : Tuple[str, int]
            `config["server_addr"]` and `config["server_port"]` for "openocd";
            `("", 0)` for "host"

        Raises
        ------
        RuntimeError
            if the communication method is not one of the above
        """
        comms_method = config["comms_method"]
        if comms_method == "openocd":
            return config["server_addr"], config["server_port"]
        if comms_method == "host":
            # no server settings are read from the config in this mode
            return "", 0
        raise RuntimeError(f"unknown communication method: {comms_method}")

    @classmethod
    def _mem_layout_args(cls, mem_layout):
        """Flatten `mem_layout` into `[start, size, start, size, ...]`.

        Sections are visited in `_MEM_SECTIONS` order.  A section without a
        "start" entry contributes 0 as its start; "size" is required.
        """
        flat = []
        for name in cls._MEM_SECTIONS:
            section = mem_layout[name]
            flat.append(section.get("start", 0))
            flat.append(section["size"])
        return flat

    def _check_system(self):
        """Check if the user's system is supported by MicroTVM.

        Raises error if not supported.
        """
        if not sys.platform.startswith("linux"):
            raise RuntimeError("MicroTVM is currently only supported on Linux hosts")
        # TODO(weberlo): Add 32-bit support.
        # It's primarily the compilation pipeline that isn't compatible.
        if sys.maxsize <= 2**32:
            raise RuntimeError("MicroTVM is currently only supported on 64-bit host platforms")

    def __enter__(self):
        self._enter()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self._exit()


136 137 138 139 140 141 142 143 144 145
def create_micro_mod(c_mod, dev_config):
    """Build a micro module out of a "c"-target module.

    The module is exported as an object file into a fresh temporary
    directory, using a cross compiler created for operator libraries, and
    the resulting file is loaded back as a module.

    Parameters
    ----------
    c_mod : tvm.module.Module
        module with "c" as its target backend

    dev_config : Dict[str, Any]
        MicroTVM config dict for the target device

    Return
    ------
    micro_mod : tvm.module.Module
        micro module for the target device
    """
    workdir = _util.tempdir()
    obj_file = workdir.relpath("dev_lib.obj")
    c_mod.export_library(obj_file, fcompile=cross_compiler(dev_config, LibType.OPERATOR))
    return tvm.module.load(obj_file)
159 160


161 162
def cross_compiler(dev_config, lib_type):
    """Build a cross-compile function around the device's `create_micro_lib`.

    The result is meant to be passed as `fcompile` to
    `tvm.module.Module.export_library`.

    Parameters
    ----------
    dev_config : Dict[str, Any]
        MicroTVM config dict for the target device; its "device_id" entry
        selects the device functions

    lib_type : micro.LibType
        whether to compile a MicroTVM runtime or operator library

    Return
    ------
    func : Callable[[str, str, Optional[str]], None]
        cross compile function taking a destination path for the object file
        and a path for the input source file.

    Example
    --------
    .. code-block:: python

      c_mod = ...  # some module generated with "c" as the target
      fcompile = tvm.micro.cross_compiler(dev_config, LibType.OPERATOR)
      c_mod.export_library("dev_lib.obj", fcompile=fcompile)
    """
    create_micro_lib = tvm.micro.device.get_device_funcs(
        dev_config["device_id"])["create_micro_lib"]

    def _first_if_list(path):
        # when a list is given, only its first entry is used
        return path[0] if isinstance(path, list) else path

    def compile_func(obj_path, src_path, **kwargs):
        create_micro_lib(
            _first_if_list(obj_path),
            _first_if_list(src_path),
            lib_type,
            kwargs.get("options"))

    return _cc.cross_compiler(compile_func, output_format="obj")
197 198


199 200
def get_micro_host_driven_dir():
    """Locate the directory holding the uTVM host-driven runtime sources.

    The path is built relative to the directory of this file (after user
    expansion and symlink resolution of the file path) and is returned
    without normalization, so it still contains its ".." components.

    Return
    ------
    micro_host_driven_dir : str
        directory path
    """
    this_file = os.path.realpath(os.path.expanduser(__file__))
    relative_parts = ("..", "..", "..", "src", "runtime", "micro", "host_driven")
    return os.path.join(os.path.dirname(this_file), *relative_parts)
211 212


213 214
def get_micro_device_dir():
    """Locate the parent directory of the device-specific source files.

    The path is built relative to the directory of this file (after user
    expansion and symlink resolution of the file path) and is returned
    without normalization, so it still contains its ".." components.

    Return
    ------
    micro_device_dir : str
        directory path
    """
    this_file = os.path.realpath(os.path.expanduser(__file__))
    relative_parts = ("..", "..", "..", "src", "runtime", "micro", "device")
    return os.path.join(os.path.dirname(this_file), *relative_parts)
225 226


227
# NOTE(review): presumably this populates the module namespace with the FFI
# functions registered under the "tvm.micro" prefix (e.g. `_CreateSession`,
# which `Session.__init__` calls but this file never defines) -- confirm
# against `tvm._ffi._init_api`.
tvm._ffi._init_api("tvm.micro", "tvm.micro.base")