Commit b3f09b01 by Tianqi Chen Committed by GitHub

[RPC] LocalSession to provide RPCSession back by local env (#1102)

parent 56a6ef31
...@@ -10,4 +10,4 @@ upload and run remote RPC server, get the result back to verify correctness. ...@@ -10,4 +10,4 @@ upload and run remote RPC server, get the result back to verify correctness.
""" """
from .server import Server from .server import Server
from .client import RPCSession, connect, connect_tracker from .client import RPCSession, LocalSession, connect, connect_tracker
...@@ -7,8 +7,11 @@ import struct ...@@ -7,8 +7,11 @@ import struct
import time import time
from . import base from . import base
from .. import util
from ..._ffi.base import TVMError from ..._ffi.base import TVMError
from ..._ffi.ndarray import context as _context from ..._ffi import function as function
from ..._ffi import ndarray as nd
from ...module import load as _load_module
class RPCSession(object): class RPCSession(object):
...@@ -51,36 +54,12 @@ class RPCSession(object): ...@@ -51,36 +54,12 @@ class RPCSession(object):
ctx: TVMContext ctx: TVMContext
The corresponding encoded remote context. The corresponding encoded remote context.
""" """
ctx = _context(dev_type, dev_id) ctx = nd.context(dev_type, dev_id)
encode = (self._tbl_index + 1) * base.RPC_SESS_MASK encode = (self._tbl_index + 1) * base.RPC_SESS_MASK
ctx.device_type += encode ctx.device_type += encode
ctx._rpc_sess = self ctx._rpc_sess = self
return ctx return ctx
def cpu(self, dev_id=0):
"""Construct remote CPU device."""
return self.context(1, dev_id)
def gpu(self, dev_id=0):
"""Construct remote GPU device."""
return self.context(2, dev_id)
def cl(self, dev_id=0):
"""Construct remote OpenCL device."""
return self.context(4, dev_id)
def metal(self, dev_id=0):
"""Construct remote Metal device."""
return self.context(8, dev_id)
def opengl(self, dev_id=0):
"""Construct remote OpenGL device."""
return self.context(11, dev_id)
def ext_dev(self, dev_id=0):
"""Construct remote extension device."""
return self.context(12, dev_id)
def upload(self, data, target=None): def upload(self, data, target=None):
"""Upload file to remote runtime temp folder """Upload file to remote runtime temp folder
...@@ -139,6 +118,61 @@ class RPCSession(object): ...@@ -139,6 +118,61 @@ class RPCSession(object):
""" """
return base._LoadRemoteModule(self._sess, path) return base._LoadRemoteModule(self._sess, path)
def cpu(self, dev_id=0):
"""Construct CPU device."""
return self.context(1, dev_id)
def gpu(self, dev_id=0):
"""Construct GPU device."""
return self.context(2, dev_id)
def cl(self, dev_id=0):
"""Construct OpenCL device."""
return self.context(4, dev_id)
def metal(self, dev_id=0):
"""Construct Metal device."""
return self.context(8, dev_id)
def opengl(self, dev_id=0):
"""Construct OpenGL device."""
return self.context(11, dev_id)
def ext_dev(self, dev_id=0):
"""Construct extension device."""
return self.context(12, dev_id)
class LocalSession(RPCSession):
"""RPCSession interface backed by local environment.
This class can be used to implement functions that
need to be ran both locally and remotely.
"""
def __init__(self):
# pylint: disable=super-init-not-called
self.context = nd.context
self.get_function = function.get_global_func
self._temp = util.tempdir()
def upload(self, data, target=None):
if isinstance(data, bytearray):
if not target:
raise ValueError("target must present when file is a bytearray")
blob = data
else:
blob = bytearray(open(data, "rb").read())
if not target:
target = os.path.basename(data)
with open(self._temp.relpath(target), "wb") as f:
f.write(blob)
def download(self, path):
return bytearray(open(self._temp.relpath(path), "rb").read())
def load_module(self, path):
return _load_module(self._temp.relpath(path))
class TrackerSession(object): class TrackerSession(object):
"""Tracker client session. """Tracker client session.
......
...@@ -22,7 +22,7 @@ import time ...@@ -22,7 +22,7 @@ import time
from ..._ffi.function import register_func from ..._ffi.function import register_func
from ..._ffi.base import py_str from ..._ffi.base import py_str
from ...module import load as _load_module from ...module import load as _load_module
from .. import util, cc, tar from .. import util
from . import base from . import base
from . base import TrackerCode from . base import TrackerCode
...@@ -38,17 +38,6 @@ def _server_env(): ...@@ -38,17 +38,6 @@ def _server_env():
def load_module(file_name): def load_module(file_name):
"""Load module from remote side.""" """Load module from remote side."""
path = temp.relpath(file_name) path = temp.relpath(file_name)
# Try create a shared library in remote
if path.endswith(".o"):
logging.info("Create shared library based on %s", path)
cc.create_shared(path + ".so", path)
path += ".so"
elif path.endswith(".tar"):
tar_temp = util.tempdir()
tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
cc.create_shared(path + ".so", files)
path += ".so"
m = _load_module(path) m = _load_module(path)
logging.info("load_module %s", path) logging.info("load_module %s", path)
return m return m
......
...@@ -186,7 +186,7 @@ def system_lib(): ...@@ -186,7 +186,7 @@ def system_lib():
def load(path, fmt=""): def load(path, fmt=""):
"""Load module from file """Load module from file.
Parameters Parameters
---------- ----------
...@@ -201,7 +201,24 @@ def load(path, fmt=""): ...@@ -201,7 +201,24 @@ def load(path, fmt=""):
------- -------
module : Module module : Module
The loaded module The loaded module
Note
----
This function will automatically call
cc.create_shared if the path is in format .o or .tar
""" """
# High level handling for .o and .tar file.
# We support this to be consistent with RPC module load.
if path.endswith(".o"):
_cc.create_shared(path + ".so", path)
path += ".so"
elif path.endswith(".tar"):
tar_temp = _util.tempdir()
_tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
_cc.create_shared(path + ".so", files)
path += ".so"
# Redirect to the load API
return _LoadFromFile(path, fmt) return _LoadFromFile(path, fmt)
......
...@@ -61,14 +61,14 @@ def test_rpc_remote_module(): ...@@ -61,14 +61,14 @@ def test_rpc_remote_module():
if not tvm.module.enabled("rpc"): if not tvm.module.enabled("rpc"):
return return
server = rpc.Server("localhost") server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port) client = rpc.connect(server.host, server.port)
# graph # graph
n = tvm.convert(1024) n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
def check_remote(): def check_remote(remote):
if not tvm.module.enabled("llvm"): if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled") print("Skip because llvm is not enabled")
return return
...@@ -86,7 +86,7 @@ def test_rpc_remote_module(): ...@@ -86,7 +86,7 @@ def test_rpc_remote_module():
print('%g secs/op' % cost) print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
def check_remote_link_cl(): def check_remote_link_cl(remote):
"""Test function to run remote code such as cl """Test function to run remote code such as cl
This is not enabled because there is forking issue This is not enabled because there is forking issue
...@@ -134,7 +134,9 @@ def test_rpc_remote_module(): ...@@ -134,7 +134,9 @@ def test_rpc_remote_module():
fhost(a, b) fhost(a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote() check_remote(client)
check_remote(rpc.LocalSession())
def test_rpc_return_func(): def test_rpc_return_func():
@tvm.register_func("rpc.test.remote_func") @tvm.register_func("rpc.test.remote_func")
...@@ -147,6 +149,21 @@ def test_rpc_return_func(): ...@@ -147,6 +149,21 @@ def test_rpc_return_func():
assert fadd(12) == 22 assert fadd(12) == 22
def test_local_func():
@tvm.register_func("rpc.test.remote_func2")
def addone(x):
return lambda y: x+y
client = rpc.LocalSession()
f1 = client.get_function("rpc.test.remote_func2")
fadd = f1(10)
assert fadd(12) == 22
blob = bytearray(np.random.randint(0, 10, size=(10)))
client.upload(blob, "dat.bin")
rev = client.download("dat.bin")
assert rev == blob
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
test_rpc_remote_module() test_rpc_remote_module()
...@@ -154,3 +171,4 @@ if __name__ == "__main__": ...@@ -154,3 +171,4 @@ if __name__ == "__main__":
test_rpc_file_exchange() test_rpc_file_exchange()
test_rpc_array() test_rpc_array()
test_rpc_simple() test_rpc_simple()
test_local_func()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment