# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Construct the necessary state for the TVM graph runtime
from a Relay expression.
"""
import warnings
import numpy as np

from tvm import expr as tvm_expr
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import _build_module
from . import ty as _ty
from . import expr as _expr
from .module import Module as _Module
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor


def _update_target(target):
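    """Normalize ``target`` into a dict mapping device type to Target.

    A bare str or Target is keyed by its own device type; a dict of
    device/context name to target is converted entry by entry.
    """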
    target = target if target else _target.current_target()
    if target is None:
        raise ValueError("Target is not set in env or passed as argument.")

    tgts = {}
    if isinstance(target, (str, _target.Target)):
        dev_type = tvm_expr.IntImm("int32", _nd.context(str(target)).device_type)
        tgts[dev_type] = _target.create(target)
    elif isinstance(target, dict):
        for dev, tgt in target.items():
            dev_type = tvm_expr.IntImm("int32", _nd.context(dev).device_type)
            tgts[dev_type] = _target.create(tgt)
    else:
        raise TypeError("target is expected to be str, tvm.target.Target, " +
                        "or dict of device name to target, but received " +
                        "{}".format(type(target)))
    return tgts


class BuildModule(object):
    """Build a Relay function to run on TVM graph runtime. This class is used
    to expose the `RelayBuildModule` APIs implemented in C++.
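
    A minimal usage sketch (normally reached indirectly through
    :py:func:`relay.build`; ``func`` here is assumed to be a Relay
    function defined elsewhere):

    .. code-block:: python

        bld_mod = BuildModule()
        graph_json, lib, params = bld_mod.build(func, target="llvm")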
    """
    def __init__(self):
        self.mod = _build_module._BuildModule()
        self._get_graph_json = self.mod["get_graph_json"]
        self._get_module = self.mod["get_module"]
        self._build = self.mod["build"]
        self._set_params_func = self.mod["set_params"]
        self._get_params_func = self.mod["get_params"]

    def build(self, func, target=None, target_host=None, params=None):
        """
        Parameters
        ----------
        func : relay.Function
            The function to build.

        target : str, :any:`tvm.target.Target`, or dict of str (i.e.
        device/context name) to str/tvm.target.Target, optional
            For heterogeneous compilation, it is a dictionary indicating context
            to target mapping. For homogeneous compilation, it is a build target.

        target_host : str or :any:`tvm.target.Target`, optional
            Host compilation target, if the target is a device.
            When TVM compiles device-specific programs such as CUDA,
            we also need host (CPU) side code to interact with the driver
            to set up the dimensions and parameters correctly.
            target_host is used to specify the host-side codegen target.
            By default, llvm is used if it is enabled;
            otherwise a stackvm interpreter is used.

        params : dict of str to NDArray
            Input parameters to the graph that do not change
            during inference time. Used for constant folding.

        Returns
        -------
        graph_json : str
            The JSON string that can be accepted by the graph runtime.

        mod : tvm.Module
            The module containing necessary libraries.

        params : dict
            The parameters of the final graph.
        """
        target = _update_target(target)

        # Setup the params.
        if params:
            self._set_params(params)
        # Build the function
        self._build(func, target, target_host)
        # Get artifacts
        graph_json = self.get_json()
        mod = self.get_module()
        params = self.get_params()

        return graph_json, mod, params

    def _set_params(self, params):
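        """Pass constant parameters to the underlying C++ build module,
        converting numpy arrays into TVM NDArray constants as needed."""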
        inputs = {}
        for name, param in params.items():
            if isinstance(param, np.ndarray):
                param = _nd.array(param)
            inputs[name] = _expr.const(param)
        self._set_params_func(inputs)

    def get_json(self):
        """Return the json file of the built program."""
        return self._get_graph_json()

    def get_module(self):
        """Return the built module."""
        return self._get_module()

    def get_params(self):
        """Return the updated weights."""
        params = self._get_params_func()
        ret = {}
        for key, value in params.items():
            ret[key] = value.data
        return ret


def build(mod, target=None, target_host=None, params=None):
    """Helper function that builds a Relay function to run on TVM graph
    runtime.

    Parameters
    ----------
    mod : relay.Module
        The module to build. Using relay.Function is deprecated.

    target : str, :any:`tvm.target.Target`, or dict of str (i.e. device/context
    name) to str/tvm.target.Target, optional
        For heterogeneous compilation, it is a dictionary indicating context to
        target mapping. For homogeneous compilation, it is a build target.

    target_host : str or :any:`tvm.target.Target`, optional
        Host compilation target, if the target is a device.
        When TVM compiles device-specific programs such as CUDA,
        we also need host (CPU) side code to interact with the driver
        to set up the dimensions and parameters correctly.
        target_host is used to specify the host-side codegen target.
        By default, llvm is used if it is enabled;
        otherwise a stackvm interpreter is used.

    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    Returns
    -------
    graph_json : str
        The JSON string that can be accepted by the graph runtime.

    mod : tvm.Module
        The module containing necessary libraries.

    params : dict
        The parameters of the final graph.
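
    Examples
    --------
    A minimal sketch (the toy network below is only an illustration; any
    Relay module can be built the same way):

    .. code-block:: python

        import numpy as np
        import tvm
        from tvm import relay
        from tvm.contrib import graph_runtime

        # Toy network: y = x + x.
        x = relay.var("x", shape=(1, 10), dtype="float32")
        mod = relay.Module.from_expr(relay.Function([x], relay.add(x, x)))

        # Homogeneous compilation on CPU. For heterogeneous compilation
        # pass a dict instead, e.g. target={"cpu": "llvm", "gpu": "cuda"},
        # and for device targets such as "cuda" a host target can be
        # given via target_host="llvm".
        graph_json, lib, params = relay.build(mod, target="llvm")

        # Run the result with the graph runtime.
        rt = graph_runtime.create(graph_json, lib, tvm.cpu(0))
        rt.set_input("x", np.ones((1, 10), dtype="float32"))
        rt.set_input(**params)
        rt.run()
        out = rt.get_output(0)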
    """
    if isinstance(mod, _Module):
        func = mod["main"]
    elif isinstance(mod, _expr.Function):
        func = mod
        warnings.warn(
            "Please use input parameter mod (tvm.relay.module.Module) "
            "instead of deprecated parameter func (tvm.relay.expr.Function)",
            DeprecationWarning)
    else:
        raise ValueError("Type of input parameter mod must be tvm.relay.module.Module")

    target = _update_target(target)

    if isinstance(target_host, (str, _target.Target)):
        target_host = _target.create(target_host)
    elif target_host:
        raise ValueError("target host must be the type of str, " +
                         "tvm.target.Target, or None")

    # If current dispatch context is fallback context (the default root context),
    # then load pre-tuned parameters from TopHub
    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
        tophub_context = autotvm.tophub.context(list(target.values()))
    else:
        tophub_context = autotvm.util.EmptyContext()

    with tophub_context:
        bld_mod = BuildModule()
        graph_json, mod, params = bld_mod.build(func, target, target_host, params)
    return graph_json, mod, params


class GraphExecutor(_interpreter.Executor):
    """Wrapper around Executor interface.

    This executor is used for debugging and testing purposes.

    Parameters
    ----------
    mod : :py:class:`~tvm.relay.module.Module`
        The module to support the execution.

    ctx : :py:class:`TVMContext`
        The runtime context to run the code on.

    target : :py:class:`Target`
        The target option to build the function.
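
    Examples
    --------
    Normally constructed through :py:func:`relay.create_executor`; a
    direct-construction sketch (``mod`` is assumed to be a Relay module
    defined elsewhere):

    .. code-block:: python

        import tvm

        ex = GraphExecutor(mod, tvm.cpu(0), tvm.target.create("llvm"))
        run_fn = ex.evaluate()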
    """

    def __init__(self, mod, ctx, target):
        assert mod is not None
        self.mod = mod
        self.ctx = ctx
        self.target = target

    def _make_executor(self, expr=None):
        if expr:
            self.mod["main"] = expr
        ret_type = self.mod["main"].checked_type.ret_type
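        # A tuple return type corresponds to multiple graph runtime outputs.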
        num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
        graph_json, mod, params = build(self.mod, target=self.target)
        gmodule = _graph_rt.create(graph_json, mod, self.ctx)
        if params:
            gmodule.set_input(**params)

        def _graph_wrapper(*args, **kwargs):
            args = self._convert_args(self.mod["main"], args, kwargs)
            # Create map of inputs.
            for i, arg in enumerate(args):
                gmodule.set_input(i, arg)
            # Run the module, and fetch the output.
            gmodule.run()
            # Make a copy so multiple invocations won't hurt perf.
            if num_outputs == 1:
                return gmodule.get_output(0).copyto(_nd.cpu(0))
            outputs = []
            for i in range(num_outputs):
                outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0)))
            return outputs

        return _graph_wrapper


def create_executor(kind="debug",
                    mod=None,
                    ctx=None,
                    target="llvm"):
    """Factory function to create an executor.

    Parameters
    ----------
    kind : str
        The type of executor. Possible values are "debug" (the
        interpreter), "graph" (the graph runtime), and "vm" (the
        virtual machine).

    mod : :py:class:`~tvm.relay.module.Module`
        The Relay module containing a collection of functions.

    ctx : :py:class:`tvm.TVMContext`
        The context to execute the code on.

    target : :py:class:`tvm.Target`
        The compilation target.
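
    Examples
    --------
    A minimal sketch (the toy function below is only an illustration):

    .. code-block:: python

        import numpy as np
        import tvm
        from tvm import relay

        x = relay.var("x", shape=(2, 2), dtype="float32")
        mod = relay.Module.from_expr(relay.Function([x], relay.add(x, x)))

        ex = relay.create_executor(kind="graph", mod=mod, target="llvm")
        out = ex.evaluate()(np.ones((2, 2), dtype="float32"))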
    """
    if mod is None:
        mod = _Module()
    if ctx is not None:
        assert ctx.device_type == _nd.context(str(target), 0).device_type
    else:
        ctx = _nd.context(str(target), 0)

    if isinstance(target, str):
        target = _target.create(target)
    if kind == "debug":
        return _interpreter.Interpreter(mod, ctx, target)
293
    if kind == "graph":
294
        return GraphExecutor(mod, ctx, target)
295 296 297 298
    elif kind == "vm":
        return VMExecutor(mod, ctx, target)
    else:
        raise RuntimeError("unknown execution strategy: {0}".format(kind))