Commit b3f09b01 by Tianqi Chen Committed by GitHub

[RPC] LocalSession to provide RPCSession back by local env (#1102)

parent 56a6ef31
......@@ -10,4 +10,4 @@ upload and run remote RPC server, get the result back to verify correctness.
"""
from .server import Server
from .client import RPCSession, connect, connect_tracker
from .client import RPCSession, LocalSession, connect, connect_tracker
......@@ -7,8 +7,11 @@ import struct
import time
from . import base
from .. import util
from ..._ffi.base import TVMError
from ..._ffi.ndarray import context as _context
from ..._ffi import function as function
from ..._ffi import ndarray as nd
from ...module import load as _load_module
class RPCSession(object):
......@@ -51,36 +54,12 @@ class RPCSession(object):
ctx: TVMContext
The corresponding encoded remote context.
"""
ctx = _context(dev_type, dev_id)
ctx = nd.context(dev_type, dev_id)
encode = (self._tbl_index + 1) * base.RPC_SESS_MASK
ctx.device_type += encode
ctx._rpc_sess = self
return ctx
def cpu(self, dev_id=0):
"""Construct remote CPU device."""
return self.context(1, dev_id)
def gpu(self, dev_id=0):
"""Construct remote GPU device."""
return self.context(2, dev_id)
def cl(self, dev_id=0):
"""Construct remote OpenCL device."""
return self.context(4, dev_id)
def metal(self, dev_id=0):
"""Construct remote Metal device."""
return self.context(8, dev_id)
def opengl(self, dev_id=0):
"""Construct remote OpenGL device."""
return self.context(11, dev_id)
def ext_dev(self, dev_id=0):
"""Construct remote extension device."""
return self.context(12, dev_id)
def upload(self, data, target=None):
"""Upload file to remote runtime temp folder
......@@ -139,6 +118,61 @@ class RPCSession(object):
"""
return base._LoadRemoteModule(self._sess, path)
def cpu(self, dev_id=0):
"""Construct CPU device."""
return self.context(1, dev_id)
def gpu(self, dev_id=0):
"""Construct GPU device."""
return self.context(2, dev_id)
def cl(self, dev_id=0):
"""Construct OpenCL device."""
return self.context(4, dev_id)
def metal(self, dev_id=0):
"""Construct Metal device."""
return self.context(8, dev_id)
def opengl(self, dev_id=0):
"""Construct OpenGL device."""
return self.context(11, dev_id)
def ext_dev(self, dev_id=0):
"""Construct extension device."""
return self.context(12, dev_id)
class LocalSession(RPCSession):
"""RPCSession interface backed by local environment.
This class can be used to implement functions that
need to be ran both locally and remotely.
"""
def __init__(self):
# pylint: disable=super-init-not-called
self.context = nd.context
self.get_function = function.get_global_func
self._temp = util.tempdir()
def upload(self, data, target=None):
if isinstance(data, bytearray):
if not target:
raise ValueError("target must present when file is a bytearray")
blob = data
else:
blob = bytearray(open(data, "rb").read())
if not target:
target = os.path.basename(data)
with open(self._temp.relpath(target), "wb") as f:
f.write(blob)
def download(self, path):
return bytearray(open(self._temp.relpath(path), "rb").read())
def load_module(self, path):
return _load_module(self._temp.relpath(path))
class TrackerSession(object):
"""Tracker client session.
......
......@@ -22,7 +22,7 @@ import time
from ..._ffi.function import register_func
from ..._ffi.base import py_str
from ...module import load as _load_module
from .. import util, cc, tar
from .. import util
from . import base
from . base import TrackerCode
......@@ -38,17 +38,6 @@ def _server_env():
def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
# Try create a shared library in remote
if path.endswith(".o"):
logging.info("Create shared library based on %s", path)
cc.create_shared(path + ".so", path)
path += ".so"
elif path.endswith(".tar"):
tar_temp = util.tempdir()
tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
cc.create_shared(path + ".so", files)
path += ".so"
m = _load_module(path)
logging.info("load_module %s", path)
return m
......
......@@ -186,7 +186,7 @@ def system_lib():
def load(path, fmt=""):
"""Load module from file
"""Load module from file.
Parameters
----------
......@@ -201,7 +201,24 @@ def load(path, fmt=""):
-------
module : Module
The loaded module
Note
----
This function will automatically call
cc.create_shared if the path is in format .o or .tar
"""
# High level handling for .o and .tar file.
# We support this to be consistent with RPC module load.
if path.endswith(".o"):
_cc.create_shared(path + ".so", path)
path += ".so"
elif path.endswith(".tar"):
tar_temp = _util.tempdir()
_tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
_cc.create_shared(path + ".so", files)
path += ".so"
# Redirect to the load API
return _LoadFromFile(path, fmt)
......
......@@ -61,14 +61,14 @@ def test_rpc_remote_module():
if not tvm.module.enabled("rpc"):
return
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
client = rpc.connect(server.host, server.port)
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
def check_remote():
def check_remote(remote):
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled")
return
......@@ -86,7 +86,7 @@ def test_rpc_remote_module():
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
def check_remote_link_cl():
def check_remote_link_cl(remote):
"""Test function to run remote code such as cl
This is not enabled because there is forking issue
......@@ -134,7 +134,9 @@ def test_rpc_remote_module():
fhost(a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote()
check_remote(client)
check_remote(rpc.LocalSession())
def test_rpc_return_func():
@tvm.register_func("rpc.test.remote_func")
......@@ -147,6 +149,21 @@ def test_rpc_return_func():
assert fadd(12) == 22
def test_local_func():
@tvm.register_func("rpc.test.remote_func2")
def addone(x):
return lambda y: x+y
client = rpc.LocalSession()
f1 = client.get_function("rpc.test.remote_func2")
fadd = f1(10)
assert fadd(12) == 22
blob = bytearray(np.random.randint(0, 10, size=(10)))
client.upload(blob, "dat.bin")
rev = client.download("dat.bin")
assert rev == blob
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
test_rpc_remote_module()
......@@ -154,3 +171,4 @@ if __name__ == "__main__":
test_rpc_file_exchange()
test_rpc_array()
test_rpc_simple()
test_local_func()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment