# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""The interface of expr function exposed from C++."""
from __future__ import absolute_import

from ... import build_module as _build
from ... import container as _container
from ..._ffi.function import _init_api, register_func


@register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
    """Backend function for lowering.

    Parameters
    ----------
    sch : tvm.Schedule
        The schedule.

    inputs : List[tvm.Tensor]
        The inputs to the function.

    func_name : str
        The name of the function.

    source_func : tvm.relay.Function
        The source function to be lowered.  Its text form
        (``source_func.astext()``) is appended to the error message when
        lowering fails.

    Returns
    -------
    lowered_funcs : List[tvm.LoweredFunc]
        The result of lowering.  A result that is already an Array, tuple
        or list is returned unchanged; anything else is wrapped in a
        single-element list.

    Raises
    ------
    RuntimeError
        If lowering raises any exception.  The message holds the formatted
        traceback of the original error followed by the text of
        ``source_func``.
    """
    import traceback
    # pylint: disable=broad-except
    try:
        f = _build.lower(sch, inputs, name=func_name)
    except Exception:
        # Re-raise with the source function attached so the failing
        # function can be identified from the message alone.
        msg = traceback.format_exc()
        msg += "Error during compile function\n"
        msg += "-----------------------------\n"
        msg += source_func.astext()
        raise RuntimeError(msg)
    return f if isinstance(
        f, (_container.Array, tuple, list)) else [f]


@register_func("relay.backend.build")
def build(funcs, target, target_host=None):
    """Backend build function.

    Parameters
    ----------
    funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
        A list of lowered functions or dictionary mapping from targets to
        lowered functions.

    target : tvm.Target
        The target to run the code on.

    target_host : tvm.Target
        The host target.  An empty string is replaced by None before the
        call is forwarded.

    Returns
    -------
    module : tvm.Module
        The runtime module.
    """
    # Normalize the empty-string host to None before forwarding.
    host = None if target_host == "" else target_host
    return _build.build(funcs, target=target, target_host=host)


@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tvalue):
    """Return ``str()`` of ``tvalue.data`` after its ``asnumpy()`` conversion."""
    converted = tvalue.data.asnumpy()
    return str(converted)


@register_func("relay._constant_repr")
def _tensor_constant_repr(tvalue):
    """Return ``str()`` of ``tvalue.data`` after its ``asnumpy()`` conversion."""
    converted = tvalue.data.asnumpy()
    return str(converted)


# Called with the "relay.backend" prefix and this module's name; presumably
# this attaches the C++-side functions registered under that prefix as
# attributes of this module -- TODO confirm against _ffi.function._init_api.
_init_api("relay.backend", __name__)