mxnet.py 1.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
"""MXNet bridge wrap Function MXNet's async function."""
from __future__ import absolute_import as _abs

from .. import api, _api_internal, ndarray
from ..module import Module

# pylint: disable=invalid-name
# Cache for MXNet's "WrapAsyncCall" bridge function. It stays None until
# to_mxnet_func() runs for the first time; that call fetches the bridge,
# stores it here, and later calls reuse it.
_wrap_async = None


def to_mxnet_func(func, const_loc=None):
    """Wrap a TVM function as an MXNet function.

    The returned function runs asynchronously via MXNet's engine.

    Parameters
    ----------
    func : Function or Module
        A TVM function that can take positional arguments.
        If a ``Module`` is given, its ``entry_func`` is wrapped.

    const_loc : list of int, optional
        List of integers indicating the argument positions
        of read-only NDArray arguments.
        NDArray arguments whose positions are not listed
        are treated as mutable arrays by MXNet's engine.

    Returns
    -------
    async_func : Function
        A function that can take MXNet NDArray as argument
        in places that used to expect TVM NDArray.
        Runs asynchronously in MXNet's async engine.

    Raises
    ------
    RuntimeError
        If the installed mxnet package does not export ``MXTVMBridge``.
    """
    # Only import mxnet when a wrapper is actually requested.
    # pylint: disable=import-self
    import mxnet
    global _wrap_async

    if isinstance(func, Module):
        func = func.entry_func

    def _get_bridge_func():
        """Look up MXNet's ``WrapAsyncCall`` bridge function."""
        # A ctypes library handle raises AttributeError (it does not return
        # a falsy value) when a symbol is missing. Probe with getattr so an
        # mxnet build without the bridge yields the intended RuntimeError.
        bridge = getattr(mxnet.base._LIB, "MXTVMBridge", None)
        if not bridge:
            raise RuntimeError(
                "MXTVMBridge not exist in mxnet package,"
                " please update to latest version")

        fdict = api.extract_ext_funcs(bridge)
        ret = fdict["WrapAsyncCall"]
        ret.is_global = True
        return ret

    if _wrap_async is None:
        # First call: fetch the bridge and register MXNet's NDArray as an
        # extension type. Publish the cached bridge only after both steps
        # succeed. Otherwise a failed registration would be skipped on every
        # later call, because the cache would already be non-None.
        wrap_async = _get_bridge_func()
        ndarray.register_extension(mxnet.nd.NDArray)
        _wrap_async = wrap_async

    const_loc = const_loc if const_loc else []
    return _wrap_async(func, _api_internal._TVMSetStream, len(const_loc), *const_loc)