module.py 8.53 KB
Newer Older
1
"""Container of compiled functions of TVM."""
2
from __future__ import absolute_import as _abs
3

4
import struct
5
from collections import namedtuple
Hu Shiwen committed
6

7
from ._ffi.function import ModuleBase, _set_class_module
8
from ._ffi.function import _init_api
9
from ._ffi.libinfo import find_include_path
10
from .contrib import cc as _cc, tar as _tar, util as _util
11

12
# Result record built by Module.time_evaluator: `results` is the tuple of
# per-repeat time costs (seconds, per that method's docstring) and `mean`
# is their arithmetic average.
ProfileResult = namedtuple("ProfileResult", "mean results")


class Module(ModuleBase):
    """Module container of all TVM generated functions."""

    def __repr__(self):
        # Shows the module type key and the raw handle address in hex.
        return "Module(%s, %x)" % (self.type_key, self.handle.value)

    @property
    def type_key(self):
        """Get type key of the module."""
        return _GetTypeKey(self)

    def get_source(self, fmt=""):
        """Get source code from module, if available.

        Parameters
        ----------
        fmt : str, optional
            The specified format.

        Returns
        -------
        source : str
            The result source code.
        """
        return _GetSource(self, fmt)

    @property
    def imported_modules(self):
        """Get imported modules.

        Returns
        -------
        modules : list of Module
            The imported modules.
        """
        nmod = _ImportsSize(self)
        return [_GetImport(self, i) for i in range(nmod)]

    def save(self, file_name, fmt=""):
        """Save the module to file.

        This does not save the dependent device modules.

        Parameters
        ----------
        file_name : str
            The name of the file.
        fmt : str, optional
            The format of the file.

        See Also
        --------
        Module.export_library : export the module to shared library.
        """
        _SaveToFile(self, file_name, fmt)

    def export_library(self,
                       file_name,
                       fcompile=None,
                       **kwargs):
        """Export the module and its imported device code into one library.

        This function only works on host llvm and c modules. A stackvm
        module is instead saved directly to a ".stackvm" file. All imported
        modules are packed into the produced library.

        Parameters
        ----------
        file_name : str
            The name of the shared library.

        fcompile : function(target, file_list, kwargs), optional
            Compilation function used to create the dynamic library.
            If fcompile has attribute object_format, the host library is
            compiled to that format. Otherwise the default format is used:
            "o" for llvm modules and "cc" for c modules.
            If fcompile is not given, a file_name ending in ".tar" selects
            tar packing and any other name selects a shared-library build.

        kwargs : dict, optional
            Additional arguments passed to fcompile. For c modules, any
            ``options`` given here are kept and the TVM include paths are
            appended to them.

        Raises
        ------
        ValueError
            If a stackvm module is not exported to a ".stackvm" file, or if
            the module type is neither llvm nor c.
        """
        # Look the type key up once; each property access is an FFI call.
        type_key = self.type_key
        if type_key == "stackvm":
            if not file_name.endswith(".stackvm"):
                raise ValueError("Module[%s]: can only be saved as stackvm format. "
                                 "did you build with LLVM enabled?" % type_key)
            self.save(file_name)
            return

        if type_key not in ("llvm", "c"):
            raise ValueError("Module[%s]: Only llvm and c support export shared" % type_key)
        temp = _util.tempdir()
        if fcompile is not None and hasattr(fcompile, "object_format"):
            object_format = fcompile.object_format
        else:
            # type_key is known to be "llvm" or "c" here.
            object_format = "o" if type_key == "llvm" else "cc"
        path_obj = temp.relpath("lib." + object_format)
        self.save(path_obj)
        files = [path_obj]
        # The system-lib flag is only queried for llvm modules.
        is_system_lib = type_key == "llvm" and self.get_function("__tvm_is_system_module")()
        if self.imported_modules:
            # Imported modules are packed into a generated C++ source file
            # that is handed to fcompile together with the host object.
            path_cc = temp.relpath("devc.cc")
            with open(path_cc, "w") as f:
                f.write(_PackImportsToC(self, is_system_lib))
            files.append(path_cc)
        if not fcompile:
            if file_name.endswith(".tar"):
                fcompile = _tar.tar
            else:
                fcompile = _cc.create_shared
        if type_key == "c":
            # C modules are compiled with the TVM include paths. Append them
            # to any caller-supplied options rather than overwriting them.
            user_options = kwargs.get("options") or []
            if isinstance(user_options, (list, tuple)):
                user_options = list(user_options)
            else:
                user_options = [user_options]
            kwargs["options"] = user_options + ["-I" + path for path in find_include_path()]
        fcompile(file_name, files, **kwargs)

    def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
        """Get an evaluator that measures time cost of running function.

        Parameters
        ----------
        func_name: str
            The name of the function in the module.

        ctx: TVMContext
            The context we should run this function on.

        number: int
            The number of times to run this function for taking average.
            We call these runs one `repeat` of measurement.

        repeat: int, optional
            The number of times to repeat the measurement.
            The returned result contains `repeat` costs,
            each of which is an average of `number` costs.

        min_repeat_ms: int, optional
            The minimum duration of one `repeat` in milliseconds.
            By default, one `repeat` contains `number` runs. If this parameter is set,
            the parameter `number` will be dynamically adjusted to meet the
            minimum duration requirement of one `repeat`,
            i.e., when the run time of one `repeat` falls below this time, the `number`
            parameter will be automatically increased.

        Note
        ----
        The function will be invoked (1 + number x repeat) times,
        with the first call discarded in case there is lazy initialization.

        Returns
        -------
        ftimer : Function
            The function that takes the same arguments as func and returns a ProfileResult.
            The ProfileResult reports `repeat` time costs in seconds.

        Raises
        ------
        NameError
            If the RPC time evaluator is unavailable (RPC not enabled).
        """
        try:
            feval = _RPCTimeEvaluator(
                self, func_name, ctx.device_type, ctx.device_id, number, repeat, min_repeat_ms)
        except NameError:
            # A NameError means _RPCTimeEvaluator was never bound; the message
            # attributes that to RPC being disabled.
            raise NameError("time_evaluator is only supported when RPC is enabled")

        # The returned blob is unpacked as `repeat` doubles in native layout.
        fmt = "@" + ("d" * repeat)

        def evaluator(*args):
            """Internal wrapped evaluator."""
            # Wrap feval so we can add more stats in future.
            blob = feval(*args)
            results = struct.unpack(fmt, blob)
            mean = sum(results) / float(repeat)
            return ProfileResult(mean=mean, results=results)

        return evaluator


def system_lib():
    """Get system-wide library module singleton.

    System lib is a global module that contains self-registered functions at
    startup. Unlike normal dso modules, it does not need to be loaded
    explicitly, which makes it useful in environments where dynamic loading
    APIs like dlopen are banned.

    To build a system lib function, simply specify the target option
    ``llvm --system-lib``. The system lib will be available as long as the
    result code is linked by the program.

    The system lib is intended to be linked and loaded during the entire
    life-cycle of the program. If you want dynamic loading features, use dso
    modules instead.

    Returns
    -------
    module : Module
        The system-wide library module.
    """
    return _GetSystemLib()


def load(path, fmt=""):
    """Load a module from a file.

    Parameters
    ----------
    path : str
        The path to the module file.

    fmt : str, optional
        The format of the file. If not specified, it is inferred from the
        suffix of the file.

    Returns
    -------
    module : Module
        The loaded module.

    Note
    ----
    A path ending in .o or .tar is first linked with cc.create_shared into
    ``path + ".so"``, and that shared library is what gets loaded.
    """
    # Object files and tarballs are linked into a shared library first, to
    # stay consistent with how the RPC module loader treats them.
    objects = None
    if path.endswith(".o"):
        objects = path
    elif path.endswith(".tar"):
        tar_temp = _util.tempdir()
        _tar.untar(path, tar_temp.temp_dir)
        objects = [tar_temp.relpath(name) for name in tar_temp.listdir()]

    if objects is not None:
        shared_path = path + ".so"
        _cc.create_shared(shared_path, objects)
        path = shared_path

    # Hand the (possibly rewritten) path to the runtime loader.
    return _LoadFromFile(path, fmt)


def enabled(target):
    """Check whether the module runtime is enabled for a target.

    Parameters
    ----------
    target : str
        The target device type.

    Returns
    -------
    enabled : bool
        Whether the runtime for ``target`` is enabled.

    Examples
    --------
    The following code checks if gpu is enabled.

    >>> tvm.module.enabled("gpu")
    """
    is_enabled = _Enabled(target)
    return is_enabled


# NOTE(review): `_init_api` presumably binds the `_`-prefixed FFI helpers used
# above (`_GetTypeKey`, `_LoadFromFile`, `_Enabled`, ...), none of which are
# defined in this file -- confirm against tvm._ffi.function.
_init_api("tvm.module")
_set_class_module(Module)