Commit 1c97eaf6 by Tianqi Chen Committed by GitHub

MXNet NDArray bridge. (#930)

* MXNet NDArray bridge.
Support convert a tvm Function as MXNet's async NDArray function.

* fix lint

* update comment
parent 60d42a97
"""MXNet bridge wrap Function MXNet's async function."""
from __future__ import absolute_import as _abs
from .. import api, _api_internal, ndarray
from ..module import Module
# pylint: disable=invalid-name
_wrap_async = None
def to_mxnet_func(func, const_loc=None):
"""Wrap a TVM function as MXNet function
MXNet function runs asynchrously via its engine.
Parameters
----------
func : Function
A TVM function that can take positional arguments
const_loc : list of int
List of integers indicating the argument position
of read only NDArray argument.
The NDArray argument location that are not annotated
will be viewed as mutable arrays in MXNet's engine.
Returns
-------
async_func : Function
A function that can take MXNet NDArray as argument
in places that used to expect TVM NDArray.
Run asynchrously in MXNet's async engine.
"""
# only import mxnet when wrap get called.
# pylint: disable=import-self
import mxnet
if isinstance(func, Module):
func = func.entry_func
def _get_bridge_func():
"""Get MXNet bridge function"""
if not mxnet.base._LIB.MXTVMBridge:
raise RuntimeError(
"MXTVMBridge not exist in mxnet package,"
" please update to latest version")
fdict = api.extract_ext_funcs(mxnet.base._LIB.MXTVMBridge)
ret = fdict["WrapAsyncCall"]
ret.is_global = True
return ret
global _wrap_async
if _wrap_async is None:
# Register extension type in first time
_wrap_async = _get_bridge_func()
ndarray.register_extension(mxnet.nd.NDArray)
const_loc = const_loc if const_loc else []
return _wrap_async(func, _api_internal._TVMSetStream, len(const_loc), *const_loc)
......@@ -36,4 +36,9 @@ TVM_REGISTER_API("_load_json")
TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMSetStream(args[0], args[1], args[2]);
});
} // namespace tvm
def mxnet_check():
"""This is a simple test function for MXNet bridge
It is not included as nosetests, because of its dependency on mxnet
User can directly run this script to verify correctness.
"""
import mxnet as mx
import topi
import tvm
import numpy as np
from tvm.contrib.mxnet import to_mxnet_func
# build a TVM function through topi
n = 20
shape = (20,)
scale = tvm.var("scale", dtype="float32")
x = tvm.placeholder(shape)
y = tvm.placeholder(shape)
z = topi.broadcast_add(x, y)
zz = tvm.compute(shape, lambda *i: z(*i) * scale)
target = tvm.target.cuda()
# build the function
with target:
s = topi.generic.schedule_injective(zz)
f = tvm.build(s, [x, y, zz, scale])
# get a mxnet version
mxf = to_mxnet_func(f, const_loc=[0, 1])
ctx = mx.gpu(0)
xx = mx.nd.uniform(shape=shape, ctx=ctx)
yy = mx.nd.uniform(shape=shape, ctx=ctx)
zz = mx.nd.empty(shape=shape, ctx=ctx)
# invoke myf: this runs in mxnet engine
mxf(xx, yy, zz, 10.0)
mxf(xx, yy, zz, 10.0)
np.testing.assert_allclose(
zz.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * 10)
if __name__ == "__main__":
mxnet_check()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment