# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=no-else-return
"""The Python interface to the Relay reference interpreter."""
from __future__ import absolute_import

import numpy as np

from tvm.runtime import container
from tvm.ir import IRModule

from . import _backend
from .. import _make, analysis, transform
from ... import nd
from ..base import Object, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder


@register_relay_node
class ConstructorValue(Object):
    """An interpreter value produced by applying an ADT constructor.

    The underlying handle is created through ``_make.ConstructorValue``.

    Parameters
    ----------
    tag : int
        The constructor tag. ``_arg_to_ast`` maps it back to a constructor
        with ``mod.get_constructor(tag)``.

    fields : list
        The values stored in the constructor's fields.

    constructor : object
        The constructor that built this value (presumably a
        ``relay.Constructor`` -- confirm against callers).
    """
    def __init__(self, tag, fields, constructor):
        self.__init_handle_by_constructor__(
            _make.ConstructorValue, tag, fields, constructor)


@register_relay_node
class RefValue(Object):
    """An interpreter value that holds a reference to another value.

    The underlying handle is created through ``_make.RefValue``;
    ``_arg_to_ast`` turns it back into ``RefCreate(value)``.

    Parameters
    ----------
    value : object
        The value the reference initially points to.
    """
    def __init__(self, value):
        self.__init_handle_by_constructor__(
            _make.RefValue, value)


def _arg_to_ast(mod, arg):
    """Convert a runtime/Python value into an equivalent Relay expression.

    Parameters
    ----------
    mod : tvm.IRModule
        Module used to resolve ADT constructors for ``ConstructorValue``
        arguments (it is not consulted for any other kind of value).

    arg : object
        The value to convert. Handled cases, checked in order:
        ``tvm.nd.NDArray``, ``container.ADT`` or ``tuple``, ``RefValue``,
        ``ConstructorValue``, ``numpy.ndarray``, an existing relay
        ``Constant``; anything else is passed to ``relay.const``.

    Returns
    -------
    expr : relay.Expr
        The Relay expression representing ``arg``.
    """
    if isinstance(arg, nd.NDArray):
        # Copy to CPU before wrapping (presumably Constant data must live on
        # the host -- the original code always does this copy).
        return Constant(arg.copyto(nd.cpu(0)))
    if isinstance(arg, (container.ADT, tuple)):
        # Runtime ADT objects and Python tuples both become Relay tuples,
        # converting each element recursively.
        return Tuple([_arg_to_ast(mod, field) for field in arg])
    if isinstance(arg, RefValue):
        return RefCreate(_arg_to_ast(mod, arg.value))
    if isinstance(arg, ConstructorValue):
        constructor = mod.get_constructor(arg.tag)
        return Call(constructor,
                    [_arg_to_ast(mod, field) for field in arg.fields])
    if isinstance(arg, np.ndarray):
        return Constant(nd.array(arg))
    if isinstance(arg, Constant):
        return arg
    return const(arg)


class Executor(object):
    """An abstract interface for executing Relay programs."""

    def _convert_args(self, expr, args, kwargs):
        """
        Convert the combination of arguments and keyword arguments
        into a sequence of arguments that may be passed to
        a Relay evaluator.

        We first provide all positional arguments, and then attempt
        to fill in the remaining arguments using the keyword arguments. We
        map the keyword arguments to the corresponding parameters; if there
        is an ambiguity between positional and keyword arguments this
        procedure will raise an error.

        Parameters
        ----------
        expr: relay.Expr
            The expression to evaluate.

        args: List[tvm.nd.NDArray]
            The arguments to pass to the evaluator.

        kwargs: Dict[str, tvm.nd.NDArray]
            The keyword arguments to pass to the evaluator.

        Returns
        -------
        args: Union[List[tvm.nd.NDArray], Tuple[tvm.nd.NDArray]]
            The new arguments with all keyword arguments placed in the
            correct slot. When ``kwargs`` is empty, ``args`` is returned
            unchanged; otherwise a tuple is returned.

        Raises
        ------
        Exception
            If keyword arguments are given for a non-Function expression,
            if a parameter is supplied both positionally and by keyword,
            or if the resulting argument count does not match the
            function's parameter count.
        """
        assert expr is not None

        if not kwargs:
            return args

        if not isinstance(expr, Function):
            raise Exception("can only supply keyword parameters for a "
                            "relay.Function, found {0}".format(expr))

        params = expr.params
        param_names = [p.name_hint for p in params]
        num_of_args = len(args)

        cargs = list(args)
        for i, name in enumerate(param_names):
            if i < num_of_args:
                # Membership, not truthiness: a falsy keyword value that
                # duplicates a positional argument is still a duplicate.
                if name in kwargs:
                    raise Exception(
                        "duplicate argument supplied in "
                        "both positional args (at position: {0}), "
                        "and keyword argument (with name: {1})".format(i, name))
            elif name in kwargs:
                cargs.append(kwargs[name])
            else:
                # A missing parameter leaves a gap later keyword arguments
                # cannot fill positionally; the count check below reports it.
                break

        if len(cargs) < len(params):
            raise Exception(
                "insufficient arguments, expected "
                "{0}, provided {1}".format(len(params), len(cargs)))
        if len(cargs) > len(params):
            raise Exception(
                "too many arguments, expected "
                "{0}, provided {1}".format(len(params), len(cargs)))

        return tuple(cargs)

    def _make_executor(self, expr=None):
        """
        Construct a Python function that implements the evaluation
        of expression.

        Parameters
        ----------
        expr: Optional[relay.Expr]
            The Relay expression to execute.

        Returns
        -------
        executor: function,
            A Python function which implements the behavior of `expr`.
        """
        raise NotImplementedError()

    def evaluate(self, expr=None, binds=None):
        """
        Evaluate a Relay expression on the executor.

        Parameters
        ----------
        expr: Optional[tvm.relay.Expr]
            The expression to evaluate.

        binds: Optional[Map[tvm.relay.Var, tvm.relay.Expr]]
            Additional binding of free variable.

        Returns
        -------
        val : Union[function, Object]
            An executor function when ``expr`` is None, a Function or a
            GlobalVar; otherwise the result of evaluating ``expr``.
        """
        if binds:
            # Wrap expr in a chain of let-bindings, one per bound variable.
            scope_builder = ScopeBuilder()
            for key, value in binds.items():
                scope_builder.let(key, _arg_to_ast(self.mod, value))
            scope_builder.ret(expr)
            expr = scope_builder.get()

        # Compare against None explicitly: an Expr that defines __len__
        # (relay.Tuple does, to our knowledge) is falsy when empty and would
        # otherwise be mistaken for "no expression".
        if expr is None:
            return self._make_executor()

        if isinstance(expr, Function):
            assert not analysis.free_vars(expr), \
                "a Function passed to evaluate must not have free variables"

        if isinstance(expr, (Function, GlobalVar)):
            return self._make_executor(expr)

        # normal expression evaluated by running a function.
        func = Function([], expr)
        return self._make_executor(func)()


class Interpreter(Executor):
    """
    Simple interpreter interface.

    Parameters
    ----------
    mod : tvm.IRModule
        The module to support the execution.

    ctx : tvm.runtime.TVMContext
        The runtime context to run the code on.

    target : tvm.Target
        The target option to build the function.
    """
    def __init__(self, mod, ctx, target):
        self.mod = mod
        self.ctx = ctx
        self.target = target

    def optimize(self):
        """Optimize functions in a module.

        Runs SimplifyInference, FuseOps(0), ToANormalForm and InferType,
        in that order, over ``self.mod``.

        Returns
        -------
        opt_mod : tvm.IRModule
            The optimized module.
        """
        seq = transform.Sequential([transform.SimplifyInference(),
                                    transform.FuseOps(0),
                                    transform.ToANormalForm(),
                                    transform.InferType()])
        return seq(self.mod)

    def _make_executor(self, expr=None):
        """Build a Python callable that runs ``expr`` on the interpreter.

        Parameters
        ----------
        expr : Optional[Union[relay.Function, relay.GlobalVar]]
            What to run. ``None`` runs ``self.mod["main"]``; a GlobalVar runs
            the module function it names; a Function is called directly.

        Returns
        -------
        executor : function
            Accepts positional and keyword arguments, converts them with
            ``_arg_to_ast``, optimizes the module and evaluates the entry
            call with ``_backend.CreateInterpreter``.

        Notes
        -----
        Calling the returned function for a GlobalVar or Function overwrites
        ``self.mod["main"]`` (or creates ``self.mod`` when it is None).
        """
        if expr is None or isinstance(expr, GlobalVar):
            assert self.mod is not None

        def _interp_wrapper(*args, **kwargs):
            if expr is None:
                args = self._convert_args(self.mod["main"], args, kwargs)
            elif isinstance(expr, GlobalVar):
                # Match keyword arguments against the function the global
                # variable names; the GlobalVar itself has no parameters.
                args = self._convert_args(self.mod[expr], args, kwargs)
            else:
                args = self._convert_args(expr, args, kwargs)

            relay_args = [_arg_to_ast(self.mod, arg) for arg in args]

            # Set the entry function for the module.
            if isinstance(expr, GlobalVar):
                self.mod["main"] = self.mod[expr]
            elif expr is not None:
                assert isinstance(expr, Function)
                # Bake the arguments into a nullary "main" that calls expr,
                # so the final call below passes no arguments.
                func = Function([], Call(expr, relay_args))
                relay_args = []
                if self.mod is not None:
                    self.mod["main"] = func
                else:
                    self.mod = IRModule.from_expr(func)

            mod = self.optimize()
            opt_expr = Call(mod["main"], relay_args)
            _intrp = _backend.CreateInterpreter(mod, self.ctx, self.target)
            return _intrp(opt_expr)
        return _interp_wrapper