"""APIs for lowering a Python subset to HalideIR"""
from __future__ import absolute_import as _abs

from .._ffi.base import decorate
from .. import _api_internal as _tvm_internal
from ..tensor import Tensor

from .parser import parse_python
from .util import _pruned_source


def script(pyfunc):
    """Decorate a python function as a hybrid script.

    The hybrid function supports two modes: emulation (the decorated
    function body is executed directly as Python) and parsing to the
    internal language IR.

    Parameters
    ----------
    pyfunc : function
        The python function to decorate.

    Returns
    -------
    hybrid_func : function
        A decorated hybrid script function.
    """
    # NOTE(review): `decorate` presumably invokes wrapped_func(pyfunc, *args, **kwargs)
    # on each call of the decorated function — confirm against _ffi.base.decorate.
    def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
        from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
        if _is_tvm_arg_types(args):
            # Parsing mode: lower the function source into a HybridOp. Only the
            # Tensor-typed call arguments become the op's input tensors.
            src = _pruned_source(func)
            parser = parse_python(src, func.__globals__, args)

            input_tensors = [arg for arg in args if isinstance(arg, Tensor)]
            op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
                                         parser.outputs, parser.parsed_body)
            res = [op.output(i) for i in range(len(parser.outputs))]
            # A single output is returned bare rather than as a one-element list.
            return res[0] if len(res) == 1 else res

        # Emulation mode: call the python function directly inside the hybrid
        # runtime. The restore runs in `finally` so that an exception raised by
        # `func` cannot leave the state set up by _enter_hybrid_runtime in place.
        intersect = _enter_hybrid_runtime(func)
        try:
            return func(*args, **kwargs)
        finally:
            _restore_runtime(func, intersect)

    return decorate(pyfunc, wrapped_func)