Commit 6bd8dbc7 by Tianqi Chen Committed by GitHub

[RPC] Refactor, introduce tracker (#1080)

* [RPC] Refactor, introduce tracker

* [RPC] Change RPC hand shake convention, always get remote key.

* fix lint
parent c24e82e3
...@@ -280,7 +280,7 @@ lib/libtvm_runtime.${SHARED_LIBRARY_SUFFIX}: $(RUNTIME_DEP) ...@@ -280,7 +280,7 @@ lib/libtvm_runtime.${SHARED_LIBRARY_SUFFIX}: $(RUNTIME_DEP)
lib/libtvm_web_runtime.bc: web/web_runtime.cc lib/libtvm_web_runtime.bc: web/web_runtime.cc
@mkdir -p build/web @mkdir -p build/web
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT lib/libtvm_web_runtime.bc $< >build/web/web_runtime.d emcc $(EMCC_FLAGS) -MM -MT lib/libtvm_web_runtime.bc $< >build/web/web_runtime.d
emcc $(EMCC_FLAGS) -o $@ web/web_runtime.cc emcc $(EMCC_FLAGS) -o $@ web/web_runtime.cc
lib/libtvm_web_runtime.js: lib/libtvm_web_runtime.bc lib/libtvm_web_runtime.js: lib/libtvm_web_runtime.bc
......
...@@ -31,9 +31,12 @@ using FEventHandler = std::function<int(const std::string& in_bytes, int event_f ...@@ -31,9 +31,12 @@ using FEventHandler = std::function<int(const std::string& in_bytes, int event_f
* *
* \param outputStream The output stream used to send outputs. * \param outputStream The output stream used to send outputs.
* \param name The name of the server. * \param name The name of the server.
* \param remote_key The remote key
* \return The event handler. * \return The event handler.
*/ */
FEventHandler CreateServerEventHandler(NSOutputStream *outputStream, std::string name); FEventHandler CreateServerEventHandler(NSOutputStream *outputStream,
std::string name,
std::string remote_key);
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "../../src/runtime/cpu_device_api.cc" #include "../../src/runtime/cpu_device_api.cc"
#include "../../src/runtime/workspace_pool.cc" #include "../../src/runtime/workspace_pool.cc"
#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/module_util.cc" #include "../../src/runtime/module_util.cc"
#include "../../src/runtime/system_lib_module.cc" #include "../../src/runtime/system_lib_module.cc"
#include "../../src/runtime/module.cc" #include "../../src/runtime/module.cc"
...@@ -59,9 +60,10 @@ class NSStreamChannel final : public RPCChannel { ...@@ -59,9 +60,10 @@ class NSStreamChannel final : public RPCChannel {
NSOutputStream* stream_; NSOutputStream* stream_;
}; };
FEventHandler CreateServerEventHandler(NSOutputStream *outputStream, std::string name) { FEventHandler CreateServerEventHandler(
NSOutputStream *outputStream, std::string name, std::string remote_key) {
std::unique_ptr<NSStreamChannel> ch(new NSStreamChannel(outputStream)); std::unique_ptr<NSStreamChannel> ch(new NSStreamChannel(outputStream));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name); std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name, remote_key);
return [sess](const std::string& in_bytes, int flag) { return [sess](const std::string& in_bytes, int flag) {
return sess->ServerEventHandler(in_bytes, flag); return sess->ServerEventHandler(in_bytes, flag);
}; };
......
...@@ -143,7 +143,7 @@ ...@@ -143,7 +143,7 @@
[outputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; [outputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode];
[outputStream_ open]; [outputStream_ open];
[inputStream_ open]; [inputStream_ open];
handler_ = tvm::runtime::CreateServerEventHandler(outputStream_, key_); handler_ = tvm::runtime::CreateServerEventHandler(outputStream_, key_, "%toinit");
CHECK(handler_ != nullptr); CHECK(handler_ != nullptr);
self.infoText.text = @""; self.infoText.text = @"";
self.statusLabel.text = @"Connecting..."; self.statusLabel.text = @"Connecting...";
...@@ -169,7 +169,6 @@ ...@@ -169,7 +169,6 @@
} }
- (IBAction)disconnect:(id)sender { - (IBAction)disconnect:(id)sender {
[self close]; [self close];
} }
......
...@@ -74,6 +74,9 @@ public class ConnectProxyServerProcessor implements ServerProcessor { ...@@ -74,6 +74,9 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
} else if (magic != RPC.RPC_MAGIC) { } else if (magic != RPC.RPC_MAGIC) {
throw new RuntimeException(address + " is not RPC Proxy"); throw new RuntimeException(address + " is not RPC Proxy");
} }
// Get key from remote
int keylen = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
String remoteKey = Utils.decodeToStr(Utils.recvAll(in, keylen));
System.err.println("RPCProxy connected to " + address); System.err.println("RPCProxy connected to " + address);
final int sockFd = socketFileDescriptorGetter.get(currSocket); final int sockFd = socketFileDescriptorGetter.get(currSocket);
......
...@@ -60,7 +60,12 @@ public class StandaloneServerProcessor implements ServerProcessor { ...@@ -60,7 +60,12 @@ public class StandaloneServerProcessor implements ServerProcessor {
out.write(Utils.toBytes(RPC.RPC_MAGIC + 2)); out.write(Utils.toBytes(RPC.RPC_MAGIC + 2));
} else { } else {
out.write(Utils.toBytes(RPC.RPC_MAGIC)); out.write(Utils.toBytes(RPC.RPC_MAGIC));
// send server key to the client
String serverKey = "server:java";
out.write(Utils.toBytes(serverKey.length()));
out.write(Utils.toBytes(serverKey));
} }
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
final int sockFd = socketFileDescriptorGetter.get(socket); final int sockFd = socketFileDescriptorGetter.get(socket);
if (sockFd != -1) { if (sockFd != -1) {
......
import time import time
from tvm.contrib import rpc_proxy from tvm.contrib.rpc import proxy
def start_proxy_server(port, timeout): def start_proxy_server(port, timeout):
prox = rpc_proxy.Proxy("localhost", port=port, port_end=port+1) prox = proxy.Proxy("localhost", port=port, port_end=port+1)
if timeout > 0: if timeout > 0:
import time import time
time.sleep(timeout) time.sleep(timeout)
...@@ -17,4 +17,3 @@ if __name__ == "__main__": ...@@ -17,4 +17,3 @@ if __name__ == "__main__":
port = int(sys.argv[1]) port = int(sys.argv[1])
timeout = 0 if len(sys.argv) == 2 else float(sys.argv[2]) timeout = 0 if len(sys.argv) == 2 else float(sys.argv[2])
start_proxy_server(port, timeout) start_proxy_server(port, timeout)
"""Minimum graph runtime that executes graph containing TVM PackedFunc.""" """Minimum graph runtime that executes graph containing TVM PackedFunc."""
from . import rpc
from .._ffi.base import string_types from .._ffi.base import string_types
from .._ffi.function import get_global_func from .._ffi.function import get_global_func
from .rpc import base as rpc_base
from .. import ndarray as nd from .. import ndarray as nd
...@@ -33,12 +33,12 @@ def create(graph_json_str, libmod, ctx): ...@@ -33,12 +33,12 @@ def create(graph_json_str, libmod, ctx):
raise ValueError("Type %s is not supported" % type(graph_json_str)) raise ValueError("Type %s is not supported" % type(graph_json_str))
device_type = ctx.device_type device_type = ctx.device_type
device_id = ctx.device_id device_id = ctx.device_id
if device_type >= rpc.RPC_SESS_MASK: if device_type >= rpc_base.RPC_SESS_MASK:
assert libmod.type_key == "rpc" assert libmod.type_key == "rpc"
assert rpc._SessTableIndex(libmod) == ctx._rpc_sess._tbl_index assert rpc_base._SessTableIndex(libmod) == ctx._rpc_sess._tbl_index
hmod = rpc._ModuleHandle(libmod) hmod = rpc_base._ModuleHandle(libmod)
fcreate = ctx._rpc_sess.get_function("tvm.graph_runtime.remote_create") fcreate = ctx._rpc_sess.get_function("tvm.graph_runtime.remote_create")
device_type = device_type % rpc.RPC_SESS_MASK device_type = device_type % rpc_base.RPC_SESS_MASK
return GraphModule(fcreate(graph_json_str, hmod, device_type, device_id), ctx) return GraphModule(fcreate(graph_json_str, hmod, device_type, device_id), ctx)
fcreate = get_global_func("tvm.graph_runtime.create") fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, device_type, device_id), ctx) return GraphModule(fcreate(graph_json_str, libmod, device_type, device_id), ctx)
......
"""RPC interface for easy testing.
RPC enables connect to a remote server, upload and launch functions.
This is useful to for cross-compile and remote testing,
The compiler stack runs on local server, while we use RPC server
to run on remote runtime which don't have a compiler available.
The test program compiles the program on local server,
upload and run remote RPC server, get the result back to verify correctness.
"""
from __future__ import absolute_import
import os
import socket
import struct
import logging
import multiprocessing
import subprocess
import time
from . import util, cc, tar
from ..module import load as _load_module
from .._ffi.function import _init_api, register_func
from .._ffi.ndarray import context as _context
from .._ffi.base import py_str
RPC_MAGIC = 0xff271
RPC_SESS_MASK = 128
def _server_env():
    """Set up the server environment; return the session's temp dir.

    Registers the ``workpath`` and ``load_module`` global functions the RPC
    server invokes, both rooted at a fresh temporary directory. The caller
    must call ``.remove()`` on the returned tempdir when serving is done.
    """
    temp = util.tempdir()

    # pylint: disable=unused-variable
    @register_func("tvm.contrib.rpc.server.workpath")
    def get_workpath(path):
        # Resolve a server-side relative path inside the session temp dir.
        return temp.relpath(path)

    @register_func("tvm.contrib.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load module from remote side."""
        path = temp.relpath(file_name)
        # Try create a shared library in remote
        if path.endswith(".o"):
            logging.info("Create shared library based on %s", path)
            cc.create_shared(path + ".so", path)
            path += ".so"
        elif path.endswith(".tar"):
            # Untar first, then link all contained files into one .so
            tar_temp = util.tempdir()
            tar.untar(path, tar_temp.temp_dir)
            files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
            cc.create_shared(path + ".so", files)
            path += ".so"
        m = _load_module(path)
        logging.info("load_module %s", path)
        return m
    return temp
def _serve_loop(sock, addr):
    """Serve a single RPC connection until it finishes, then clean up."""
    fd = sock.fileno()
    env = _server_env()
    _ServerLoop(fd)
    env.remove()
    logging.info("Finish serving %s", addr)
def _recvall(sock, nbytes):
res = []
nread = 0
while nread < nbytes:
chunk = sock.recv(min(nbytes - nread, 1024))
nread += len(chunk)
res.append(chunk)
return b"".join(res)
def _listen_loop(sock, exclusive):
    """Listening loop: accept connections and serve each in a subprocess.

    Parameters
    ----------
    sock : Socket
        The listening server socket.

    exclusive : bool
        If True, terminate the previous serving process when a new
        connection arrives so only one call monopolizes the hardware.
    """
    last_proc = None
    while True:
        conn, addr = sock.accept()
        if last_proc and last_proc.is_alive() and exclusive:
            logging.info("Kill last call")
            last_proc.terminate()
        logging.info("RPCServer: connection from %s", addr)
        magic = struct.unpack("@i", _recvall(conn, 4))[0]
        if magic != RPC_MAGIC:
            conn.close()
            continue
        keylen = struct.unpack("@i", _recvall(conn, 4))[0]
        key = py_str(_recvall(conn, keylen))
        if not key.startswith("client:"):
            conn.sendall(struct.pack("@i", RPC_MAGIC + 2))
        else:
            conn.sendall(struct.pack("@i", RPC_MAGIC))
        logging.info("Connection from %s", addr)
        process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
        # was `process.deamon` (typo): the daemon flag was never actually set
        process.daemon = True
        process.start()
        last_proc = process
        # close from our side.
        conn.close()
def _connect_proxy_loop(addr, key):
    """Keep connecting to the proxy at ``addr`` and serving matched clients.

    Parameters
    ----------
    addr : tuple
        The (host, port) address of the proxy.

    key : str
        The key used to identify this server to the proxy.
    """
    key = "server:" + key
    while True:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(addr)
        sock.sendall(struct.pack("@i", RPC_MAGIC))
        sock.sendall(struct.pack("@i", len(key)))
        sock.sendall(key.encode("utf-8"))
        magic = struct.unpack("@i", _recvall(sock, 4))[0]
        if magic == RPC_MAGIC + 1:
            raise RuntimeError("key: %s has already been used in proxy" % key)
        elif magic == RPC_MAGIC + 2:
            logging.info("RPCProxy do not have matching client key %s", key)
        elif magic != RPC_MAGIC:
            raise RuntimeError("%s is not RPC Proxy" % str(addr))
        logging.info("RPCProxy connected to %s", str(addr))
        process = multiprocessing.Process(target=_serve_loop, args=(sock, addr))
        # was `process.deamon` (typo): the daemon flag was never actually set
        process.daemon = True
        process.start()
        process.join()
def _popen(cmd):
proc = subprocess.Popen(cmd,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
env=os.environ)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Server invoke error:\n"
msg += out
raise RuntimeError(msg)
class Server(object):
    """Start RPC server on a separate process.

    This is a simple python implementation based on multi-processing.
    It is also possible to implement a similar C based server with
    TVM runtime which does not depend on the python.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The port to be bind to

    port_end : int, optional
        The end port to search

    is_proxy : bool, optional
        Whether the address specified is a proxy.
        If this is true, the host and port actually corresponds to the
        address of the proxy server.

    use_popen : bool, optional
        Whether to use Popen to start a fresh new process instead of fork.
        This is recommended to switch on if we want to do local RPC demonstration
        for GPU devices to avoid fork safety issues.

    exclusive : bool, optional
        If this is enabled, the server will kill old connection
        when new connection comes. This can make sure the current call
        monopolize the hardware resource.

    key : str, optional
        The key used to identify the server in Proxy connection.
    """
    def __init__(self,
                 host,
                 port=9091,
                 port_end=9199,
                 is_proxy=False,
                 use_popen=False,
                 exclusive=False,
                 key=""):
        try:
            if _ServerLoop is None:
                raise RuntimeError("Please compile with USE_RPC=1")
        except NameError:
            raise RuntimeError("Please compile with USE_RPC=1")
        self.host = host
        self.port = port
        self.libs = []
        if use_popen:
            cmd = ["python",
                   "-m", "tvm.exec.rpc_server",
                   "--host=%s" % host,
                   "--port=%s" % port]
            self.proc = multiprocessing.Process(
                target=subprocess.check_call, args=(cmd,))
            # fixed typo `deamon` -> `daemon`: the child now dies with the parent
            self.proc.daemon = True
            self.proc.start()
            time.sleep(1)
        elif not is_proxy:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.port = None
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                    self.port = my_port
                    break
                except socket.error as sock_err:
                    # 98 / 48: EADDRINUSE on linux / macos; try the next port
                    if sock_err.errno in [98, 48]:
                        continue
                    else:
                        raise sock_err
            if not self.port:
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logging.info("RPCServer: bind to %s:%d", host, self.port)
            sock.listen(1)
            self.sock = sock
            self.proc = multiprocessing.Process(
                target=_listen_loop, args=(self.sock, exclusive))
            self.proc.daemon = True
            self.proc.start()
        else:
            self.proc = multiprocessing.Process(
                target=_connect_proxy_loop, args=((host, port), key))
            self.proc.daemon = True
            self.proc.start()

    def terminate(self):
        """Terminate the server process"""
        if self.proc:
            self.proc.terminate()
            self.proc = None

    def __del__(self):
        self.terminate()
class RPCSession(object):
    """RPC Client session module.

    Do not directly create the object; call ``connect`` instead.
    """
    # pylint: disable=invalid-name
    def __init__(self, sess):
        self._sess = sess
        # index of this session in the local session table
        self._tbl_index = _SessTableIndex(sess)
        # cache of remote functions fetched so far, keyed by name
        self._remote_funcs = {}

    def get_function(self, name):
        """Get function from the session.

        Parameters
        ----------
        name : str
            The name of the function

        Returns
        -------
        f : Function
            The result function.
        """
        return self._sess.get_function(name)

    def context(self, dev_type, dev_id=0):
        """Construct a remote context.

        Parameters
        ----------
        dev_type: int or str
            The device type.

        dev_id: int, optional
            The device id.

        Returns
        -------
        ctx: TVMContext
            The corresponding encoded remote context.
        """
        ctx = _context(dev_type, dev_id)
        # encode the session table index into the device type so the runtime
        # can route calls on this context to the right remote session
        encode = (self._tbl_index + 1) * RPC_SESS_MASK
        ctx.device_type += encode
        ctx._rpc_sess = self
        return ctx

    def cpu(self, dev_id=0):
        """Construct remote CPU device."""
        return self.context(1, dev_id)

    def gpu(self, dev_id=0):
        """Construct remote GPU device."""
        return self.context(2, dev_id)

    def cl(self, dev_id=0):
        """Construct remote OpenCL device."""
        return self.context(4, dev_id)

    def metal(self, dev_id=0):
        """Construct remote Metal device."""
        return self.context(8, dev_id)

    def opengl(self, dev_id=0):
        """Construct remote OpenGL device."""
        return self.context(11, dev_id)

    def ext_dev(self, dev_id=0):
        """Construct remote extension device."""
        return self.context(12, dev_id)

    def upload(self, data, target=None):
        """Upload file to remote runtime temp folder

        Parameters
        ----------
        data : str or bytearray
            The file name or binary in local to upload.

        target : str, optional
            The path in remote
        """
        if isinstance(data, bytearray):
            if not target:
                raise ValueError("target must present when file is a bytearray")
            blob = data
        else:
            blob = bytearray(open(data, "rb").read())
            if not target:
                target = os.path.basename(data)
        # fetch and cache the remote upload function on first use
        if "upload" not in self._remote_funcs:
            self._remote_funcs["upload"] = self.get_function(
                "tvm.contrib.rpc.server.upload")
        self._remote_funcs["upload"](target, blob)

    def download(self, path):
        """Download file from remote temp folder.

        Parameters
        ----------
        path : str
            The relative location to remote temp folder.

        Returns
        -------
        blob : bytearray
            The result blob from the file.
        """
        if "download" not in self._remote_funcs:
            self._remote_funcs["download"] = self.get_function(
                "tvm.contrib.rpc.server.download")
        return self._remote_funcs["download"](path)

    def load_module(self, path):
        """Load a remote module; the file needs to be uploaded first.

        Parameters
        ----------
        path : str
            The relative location to remote temp folder.

        Returns
        -------
        m : Module
            The remote module containing remote function.
        """
        return _LoadRemoteModule(self._sess, path)
def connect(url, port, key=""):
"""Connect to RPC Server
Parameters
----------
url : str
The url of the host
port : int
The port to connect to
key : str, optional
Additional key to match server
Returns
-------
sess : RPCSession
The connected session.
"""
try:
sess = _Connect(url, port, key)
except NameError:
raise RuntimeError("Please compile with USE_RPC=1")
return RPCSession(sess)
_init_api("tvm.contrib.rpc")
"""Lightweight TVM RPC module.
RPC enables connect to a remote server, upload and launch functions.
This is useful to for cross-compile and remote testing,
The compiler stack runs on local server, while we use RPC server
to run on remote runtime which don't have a compiler available.
The test program compiles the program on local server,
upload and run remote RPC server, get the result back to verify correctness.
"""
from .server import Server
from .client import RPCSession, connect, connect_tracker
"""Base definitions for RPC."""
from __future__ import absolute_import
import socket
import time
import json
import errno
import struct
import random
import logging
from ..._ffi.function import _init_api
from ..._ffi.base import py_str
# Magic header for RPC data plane
RPC_MAGIC = 0xff271
# magic header for RPC tracker (control plane)
RPC_TRACKER_MAGIC = 0x2f271
# success response
RPC_CODE_SUCCESS = RPC_MAGIC + 0
# duplicate key in proxy
RPC_CODE_DUPLICATE = RPC_MAGIC + 1
# no matching key found on the server
RPC_CODE_MISMATCH = RPC_MAGIC + 2


class TrackerCode(object):
    """Enumeration code for the RPC tracker control protocol."""
    FAIL = -1      # request failed
    SUCCESS = 0    # request succeeded
    PING = 1       # liveness probe
    STOP = 2       # stop the tracker
    PUT = 3        # register a server resource
    REQUEST = 4    # request a server resource

# Mask used to encode an RPC session table index into a device type
RPC_SESS_MASK = 128
def recvall(sock, nbytes):
    """Read exactly ``nbytes`` bytes from ``sock``.

    Parameters
    ----------
    sock: Socket
        The socket

    nbytes : int
        Number of bytes to be received.

    Raises
    ------
    IOError
        If the connection is reset before all bytes arrive.
    """
    chunks = []
    remaining = nbytes
    while remaining > 0:
        part = sock.recv(min(remaining, 1024))
        if not part:
            raise IOError("connection reset")
        remaining -= len(part)
        chunks.append(part)
    return b"".join(chunks)
def sendjson(sock, data):
    """Serialize ``data`` as JSON and send it over ``sock``.

    The wire format is a 4-byte native-endian length header followed by
    the UTF-8 encoded JSON text.

    Parameters
    ----------
    sock : Socket
        The socket

    data : object
        Python value to be sent.
    """
    payload = json.dumps(data)
    sock.sendall(struct.pack("@i", len(payload)))
    sock.sendall(payload.encode("utf-8"))
def recvjson(sock):
    """Receive a JSON-serialized python value from ``sock``.

    Parameters
    ----------
    sock : Socket
        The socket

    Returns
    -------
    value : object
        The value received.
    """
    nbytes = struct.unpack("@i", recvall(sock, 4))[0]
    raw = recvall(sock, nbytes)
    return json.loads(py_str(raw))
def random_key():
    """Generate a random key string (decimal text of a value in [0, 1))."""
    return "%s" % random.random()
def connect_with_retry(addr, timeout=60, retry_period=5):
    """Connect to a TCP address with retry.

    This function is only reliable for a short period of server restart.

    Parameters
    ----------
    addr : tuple
        The (host, port) address tuple.

    timeout : float
        Total timeout across retries, in seconds.

    retry_period : float
        Number of seconds before we retry again.

    Returns
    -------
    sock : Socket
        The connected socket.
    """
    tstart = time.time()
    while True:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(addr)
            return sock
        except socket.error as sock_err:
            # close the failed socket so repeated retries do not leak fds
            sock.close()
            if sock_err.args[0] not in (errno.ECONNREFUSED,):
                raise sock_err
            period = time.time() - tstart
            if period > timeout:
                raise RuntimeError(
                    "Failed to connect to server %s" % str(addr))
            logging.info("Cannot connect to tracker %s, retry in %g secs...",
                         str(addr), retry_period)
            time.sleep(retry_period)
# Still use tvm.contrib.rpc for the foreign functions
_init_api("tvm.contrib.rpc", "tvm.contrib.rpc.base")
"""RPC client tools"""
from __future__ import absolute_import
import os
import socket
import struct
from . import base
from ..._ffi.base import TVMError
from ..._ffi.ndarray import context as _context
class RPCSession(object):
    """RPC Client session module.

    Do not directly create the object; call ``connect`` instead.
    """
    # pylint: disable=invalid-name
    def __init__(self, sess):
        self._sess = sess
        # index of this session in the local session table
        self._tbl_index = base._SessTableIndex(sess)
        # cache of remote functions fetched so far, keyed by name
        self._remote_funcs = {}

    def get_function(self, name):
        """Get function from the session.

        Parameters
        ----------
        name : str
            The name of the function

        Returns
        -------
        f : Function
            The result function.
        """
        return self._sess.get_function(name)

    def context(self, dev_type, dev_id=0):
        """Construct a remote context.

        Parameters
        ----------
        dev_type: int or str
            The device type.

        dev_id: int, optional
            The device id.

        Returns
        -------
        ctx: TVMContext
            The corresponding encoded remote context.
        """
        ctx = _context(dev_type, dev_id)
        # encode the session table index into the device type so the runtime
        # can route calls on this context to the right remote session
        encode = (self._tbl_index + 1) * base.RPC_SESS_MASK
        ctx.device_type += encode
        ctx._rpc_sess = self
        return ctx

    def cpu(self, dev_id=0):
        """Construct remote CPU device."""
        return self.context(1, dev_id)

    def gpu(self, dev_id=0):
        """Construct remote GPU device."""
        return self.context(2, dev_id)

    def cl(self, dev_id=0):
        """Construct remote OpenCL device."""
        return self.context(4, dev_id)

    def metal(self, dev_id=0):
        """Construct remote Metal device."""
        return self.context(8, dev_id)

    def opengl(self, dev_id=0):
        """Construct remote OpenGL device."""
        return self.context(11, dev_id)

    def ext_dev(self, dev_id=0):
        """Construct remote extension device."""
        return self.context(12, dev_id)

    def upload(self, data, target=None):
        """Upload file to remote runtime temp folder

        Parameters
        ----------
        data : str or bytearray
            The file name or binary in local to upload.

        target : str, optional
            The path in remote
        """
        if isinstance(data, bytearray):
            if not target:
                raise ValueError("target must present when file is a bytearray")
            blob = data
        else:
            blob = bytearray(open(data, "rb").read())
            if not target:
                target = os.path.basename(data)
        # fetch and cache the remote upload function on first use
        if "upload" not in self._remote_funcs:
            self._remote_funcs["upload"] = self.get_function(
                "tvm.contrib.rpc.server.upload")
        self._remote_funcs["upload"](target, blob)

    def download(self, path):
        """Download file from remote temp folder.

        Parameters
        ----------
        path : str
            The relative location to remote temp folder.

        Returns
        -------
        blob : bytearray
            The result blob from the file.
        """
        if "download" not in self._remote_funcs:
            self._remote_funcs["download"] = self.get_function(
                "tvm.contrib.rpc.server.download")
        return self._remote_funcs["download"](path)

    def load_module(self, path):
        """Load a remote module; the file needs to be uploaded first.

        Parameters
        ----------
        path : str
            The relative location to remote temp folder.

        Returns
        -------
        m : Module
            The remote module containing remote function.
        """
        return base._LoadRemoteModule(self._sess, path)
class TrackerSession(object):
    """Tracker client session.

    Parameters
    ----------
    addr : tuple
        The (host, port) address tuple of the tracker.
    """
    def __init__(self, addr):
        self._addr = addr
        self._sock = None
        self._max_request_retry = 5
        self._connect()

    def __del__(self):
        self.close()

    def _connect(self):
        """Handshake with the tracker and keep the control socket."""
        self._sock = base.connect_with_retry(self._addr)
        self._sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
        magic = struct.unpack("@i", base.recvall(self._sock, 4))[0]
        if magic != base.RPC_TRACKER_MAGIC:
            raise RuntimeError("%s is not RPC Tracker" % str(self._addr))

    def close(self):
        """Close the tracker connection."""
        if self._sock:
            self._sock.close()
            self._sock = None

    def request(self, key, priority=1, session_timeout=0):
        """Request a new connection from the tracker.

        Parameters
        ----------
        key : str
            The type key of the device.

        priority : int, optional
            The priority of the request.

        session_timeout : float, optional
            The duration of the session, allows server to kill
            the connection when duration is longer than this value.
            When duration is zero, it means the request must always be kept alive.

        Returns
        -------
        sess : RPCSession
            The connected RPC session.

        Raises
        ------
        RuntimeError
            If no connection could be established after all retries.
        """
        last_err = None
        for _ in range(self._max_request_retry):
            try:
                if self._sock is None:
                    self._connect()
                base.sendjson(self._sock,
                              [base.TrackerCode.REQUEST, key, "", priority])
                value = base.recvjson(self._sock)
                if value[0] != base.TrackerCode.SUCCESS:
                    raise RuntimeError("Invalid return value %s" % str(value))
                url, port, matchkey = value[1]
                return connect(url, port, key + matchkey, session_timeout)
            except socket.error as err:
                # control connection broke; reconnect on the next retry
                self.close()
                last_err = err
            except TVMError as err:
                last_err = err
        # previously this fell through and silently returned None
        raise RuntimeError(
            "Failed to request a connection from tracker %s: %s"
            % (str(self._addr), str(last_err)))
def connect(url, port, key="", session_timeout=0):
"""Connect to RPC Server
Parameters
----------
url : str
The url of the host
port : int
The port to connect to
key : str, optional
Additional key to match server
session_timeout : float, optional
The duration of the session, allows server to kill
the connection when duration is longer than this value.
When duration is zero, it means the request must always be kept alive.
Returns
-------
sess : RPCSession
The connected session.
"""
try:
if session_timeout:
key += " -timeout=%s" % str(session_timeout)
sess = base._Connect(url, port, key)
except NameError:
raise RuntimeError("Please compile with USE_RPC=1")
return RPCSession(sess)
def connect_tracker(url, port):
    """Connect to a RPC tracker

    Parameters
    ----------
    url : str
        The url of the host

    port : int
        The port to connect to

    Returns
    -------
    sess : TrackerSession
        The connected tracker session.
    """
    # TrackerSession performs the handshake in its constructor
    return TrackerSession((url, port))
...@@ -14,17 +14,20 @@ import socket ...@@ -14,17 +14,20 @@ import socket
import multiprocessing import multiprocessing
import errno import errno
import struct import struct
try: try:
import tornado import tornado
from tornado import gen from tornado import gen
from tornado import websocket from tornado import websocket
from tornado import ioloop from tornado import ioloop
from tornado import websocket from . import tornado_util
except ImportError as error_msg: except ImportError as error_msg:
raise ImportError("RPCProxy module requires tornado package %s" % error_msg) raise ImportError("RPCProxy module requires tornado package %s" % error_msg)
from . import rpc
from .rpc import RPC_MAGIC, _server_env from . import base
from .._ffi.base import py_str from .base import RPC_MAGIC, RPC_CODE_DUPLICATE, RPC_CODE_SUCCESS, RPC_CODE_MISMATCH
from .server import _server_env
from ..._ffi.base import py_str
class ForwardHandler(object): class ForwardHandler(object):
"""Forward handler to forward the message.""" """Forward handler to forward the message."""
...@@ -102,102 +105,38 @@ class ForwardHandler(object): ...@@ -102,102 +105,38 @@ class ForwardHandler(object):
"""on close event""" """on close event"""
assert not self._done assert not self._done
logging.info("RPCProxy:on_close %s ...", self.name()) logging.info("RPCProxy:on_close %s ...", self.name())
self._done = True
self.forward_proxy = None
if self.rpc_key: if self.rpc_key:
key = self.rpc_key[6:] key = self.rpc_key[7:]
if ProxyServerHandler.current._client_pool.get(key, None) == self: if ProxyServerHandler.current._client_pool.get(key, None) == self:
ProxyServerHandler.current._client_pool.pop(key) ProxyServerHandler.current._client_pool.pop(key)
if ProxyServerHandler.current._server_pool.get(key, None) == self: if ProxyServerHandler.current._server_pool.get(key, None) == self:
ProxyServerHandler.current._server_pool.pop(key) ProxyServerHandler.current._server_pool.pop(key)
self._done = True
self.forward_proxy = None
class TCPHandler(ForwardHandler): class TCPHandler(tornado_util.TCPHandler, ForwardHandler):
"""Event driven TCP handler.""" """Event driven TCP handler."""
def __init__(self, sock, addr): def __init__(self, sock, addr):
super(TCPHandler, self).__init__(sock)
self._init_handler() self._init_handler()
self.sock = sock
assert self.sock
self.addr = addr self.addr = addr
self.loop = ioloop.IOLoop.current()
self.sock.setblocking(0)
self.pending_write = []
self._signal_close = False
def event_handler(_, events):
self._on_event(events)
ioloop.IOLoop.current().add_handler(
self.sock.fileno(), event_handler, self.loop.READ | self.loop.ERROR)
def name(self): def name(self):
return "TCPSocket: %s:%s" % (str(self.addr), self.rpc_key) return "TCPSocket: %s:%s" % (str(self.addr), self.rpc_key)
def send_data(self, message, binary=True): def send_data(self, message, binary=True):
assert binary self.write_message(message, True)
self.pending_write.append(message)
self._on_write()
def _on_write(self):
while self.pending_write:
try:
msg = self.pending_write[0]
nsend = self.sock.send(msg)
if nsend != len(msg):
self.pending_write[0] = msg[nsend:]
else:
del self.pending_write[0]
except socket.error as err:
if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
break
else:
self.on_error(err)
if self.pending_write:
self.loop.update_handler(
self.sock.fileno(), self.loop.READ | self.loop.ERROR | self.loop.WRITE)
else:
if self._signal_close:
self.close()
else:
self.loop.update_handler(
self.sock.fileno(), self.loop.READ | self.loop.ERROR)
def _on_read(self): def on_message(self, message):
try: self.on_data(message)
msg = bytes(self.sock.recv(4096))
if msg:
self.on_data(msg)
return True
else:
self.close_pair()
except socket.error as err:
if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
pass
else:
self.on_error(e)
return False
def _on_event(self, events):
if (events & self.loop.ERROR) or (events & self.loop.READ):
if self._on_read() and (events & self.loop.WRITE):
self._on_write()
elif events & self.loop.WRITE:
self._on_write()
def signal_close(self):
if not self.pending_write:
self.close()
else:
self._signal_close = True
def close(self): def on_close(self):
if self.sock is not None: if self.forward_proxy:
logging.info("%s Close socket..", self.name()) self.forward_proxy.signal_close()
try: self.forward_proxy = None
ioloop.IOLoop.current().remove_handler(self.sock.fileno()) logging.info("%s Close socket..", self.name())
self.sock.close() self.on_close_event()
except socket.error:
pass
self.sock = None
self.on_close_event()
class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler): class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
...@@ -207,7 +146,7 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler): ...@@ -207,7 +146,7 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
self._init_handler() self._init_handler()
def name(self): def name(self):
return "WebSocketProxy" return "WebSocketProxy: %s" % (self.rpc_key)
def on_message(self, message): def on_message(self, message):
self.on_data(message) self.on_data(message)
...@@ -234,12 +173,16 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler): ...@@ -234,12 +173,16 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
class RequestHandler(tornado.web.RequestHandler): class RequestHandler(tornado.web.RequestHandler):
"""Handles html request.""" """Handles html request."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.page = open(kwargs.pop("file_path")).read() file_path = kwargs.pop("file_path")
web_port = kwargs.pop("rpc_web_port", None) if file_path.endswith("html"):
if web_port: self.page = open(file_path).read()
self.page = self.page.replace( web_port = kwargs.pop("rpc_web_port", None)
"ws://localhost:9190/ws", if web_port:
"ws://localhost:%d/ws" % web_port) self.page = self.page.replace(
"ws://localhost:9190/ws",
"ws://localhost:%d/ws" % web_port)
else:
self.page = open(file_path, "rb").read()
super(RequestHandler, self).__init__(*args, **kwargs) super(RequestHandler, self).__init__(*args, **kwargs)
def data_received(self, _): def data_received(self, _):
...@@ -303,14 +246,22 @@ class ProxyServerHandler(object): ...@@ -303,14 +246,22 @@ class ProxyServerHandler(object):
def _pair_up(self, lhs, rhs): def _pair_up(self, lhs, rhs):
lhs.forward_proxy = rhs lhs.forward_proxy = rhs
rhs.forward_proxy = lhs rhs.forward_proxy = lhs
lhs.send_data(struct.pack('@i', RPC_MAGIC))
rhs.send_data(struct.pack('@i', RPC_MAGIC)) lhs.send_data(struct.pack('@i', RPC_CODE_SUCCESS))
lhs.send_data(struct.pack('@i', len(rhs.rpc_key)))
lhs.send_data(rhs.rpc_key.encode("utf-8"))
rhs.send_data(struct.pack('@i', RPC_CODE_SUCCESS))
rhs.send_data(struct.pack('@i', len(lhs.rpc_key)))
rhs.send_data(lhs.rpc_key.encode("utf-8"))
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name()) logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
def handler_ready(self, handler): def handler_ready(self, handler):
"""Report handler to be ready.""" """Report handler to be ready."""
logging.info("Handler ready %s", handler.name()) logging.info("Handler ready %s", handler.name())
key = handler.rpc_key[6:] key = handler.rpc_key[7:]
if handler.rpc_key.startswith("server:"): if handler.rpc_key.startswith("server:"):
pool_src, pool_dst = self._client_pool, self._server_pool pool_src, pool_dst = self._client_pool, self._server_pool
timeout = self.timeout_server timeout = self.timeout_server
...@@ -329,18 +280,19 @@ class ProxyServerHandler(object): ...@@ -329,18 +280,19 @@ class ProxyServerHandler(object):
logging.info("Timeout client connection %s, cannot find match key=%s", logging.info("Timeout client connection %s, cannot find match key=%s",
handler.name(), key) handler.name(), key)
pool_dst.pop(key) pool_dst.pop(key)
handler.send_data(struct.pack('@i', RPC_MAGIC + 2)) handler.send_data(struct.pack('@i', RPC_CODE_MISMATCH))
handler.signal_close() handler.signal_close()
self.loop.call_later(timeout, cleanup) self.loop.call_later(timeout, cleanup)
else: else:
logging.info("Duplicate connection with same key=%s", key) logging.info("Duplicate connection with same key=%s", key)
handler.send_data(struct.pack('@i', RPC_MAGIC + 1)) handler.send_data(struct.pack('@i', RPC_CODE_DUPLICATE))
handler.signal_close() handler.signal_close()
def run(self): def run(self):
"""Run the proxy server""" """Run the proxy server"""
ioloop.IOLoop.current().start() ioloop.IOLoop.current().start()
def _proxy_server(listen_sock, def _proxy_server(listen_sock,
web_port, web_port,
timeout_client, timeout_client,
...@@ -446,7 +398,8 @@ def websocket_proxy_server(url, key=""): ...@@ -446,7 +398,8 @@ def websocket_proxy_server(url, key=""):
data = bytes(data) data = bytes(data)
conn.write_message(data, binary=True) conn.write_message(data, binary=True)
return len(data) return len(data)
on_message = rpc._CreateEventDrivenServer(_fsend, "WebSocketProxyServer") on_message = base._CreateEventDrivenServer(
_fsend, "WebSocketProxyServer", "%toinit")
return on_message return on_message
@gen.coroutine @gen.coroutine
...@@ -462,14 +415,16 @@ def websocket_proxy_server(url, key=""): ...@@ -462,14 +415,16 @@ def websocket_proxy_server(url, key=""):
msg = yield conn.read_message() msg = yield conn.read_message()
assert len(msg) >= 4 assert len(msg) >= 4
magic = struct.unpack('@i', msg[:4])[0] magic = struct.unpack('@i', msg[:4])[0]
if magic == RPC_MAGIC + 1: if magic == RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key) raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == RPC_MAGIC + 2: elif magic == RPC_CODE_MISMATCH:
logging.info("RPCProxy do not have matching client key %s", key) logging.info("RPCProxy do not have matching client key %s", key)
elif magic != RPC_MAGIC: elif magic != RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % url) raise RuntimeError("%s is not RPC Proxy" % url)
logging.info("Connection established")
msg = msg[4:] msg = msg[4:]
logging.info("Connection established with remote")
if msg: if msg:
on_message(bytearray(msg), 3) on_message(bytearray(msg), 3)
......
"""RPC server implementation.
Note
----
Server is TCP based with the following protocol:
- Initial handshake to the peer
- [RPC_MAGIC, keysize(int32), key-bytes]
 - The key is in the format
- {server|client}:device-type[:matchkey] [-timeout=timeout]
"""
from __future__ import absolute_import
import os
import socket
import select
import struct
import logging
import multiprocessing
import subprocess
import time
from ..._ffi.function import register_func
from ..._ffi.base import py_str
from ...module import load as _load_module
from .. import util, cc, tar
from . import base
from . base import TrackerCode
def _server_env():
    """Set up the server-side environment and return its temp directory.

    Registers two global packed functions used by the RPC runtime:
    ``tvm.contrib.rpc.server.workpath`` resolves paths inside the session
    temp dir, and ``tvm.contrib.rpc.server.load_module`` loads (and, if
    needed, links) a module uploaded by the client.

    Returns
    -------
    temp : util.tempdir
        Temporary directory holding files for this serving session; the
        caller removes it when serving finishes.
    """
    temp = util.tempdir()
    # pylint: disable=unused-variable
    @register_func("tvm.contrib.rpc.server.workpath")
    def get_workpath(path):
        # Resolve a client-supplied relative path inside the session temp dir.
        return temp.relpath(path)

    @register_func("tvm.contrib.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load module from remote side."""
        path = temp.relpath(file_name)
        # Try create a shared library in remote
        if path.endswith(".o"):
            # Single object file: link it into a shared library first.
            logging.info("Create shared library based on %s", path)
            cc.create_shared(path + ".so", path)
            path += ".so"
        elif path.endswith(".tar"):
            # Tarball of object files: untar, then link all of them together.
            tar_temp = util.tempdir()
            tar.untar(path, tar_temp.temp_dir)
            files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
            cc.create_shared(path + ".so", files)
            path += ".so"
        m = _load_module(path)
        logging.info("load_module %s", path)
        return m
    return temp
def _serve_loop(sock, addr):
    """Serve one RPC session on ``sock`` until the peer disconnects.

    Runs in a dedicated child process: sets up the session temp
    environment, hands the raw socket fd to the C++ server loop, then
    cleans up the temp directory once the session ends.

    Parameters
    ----------
    sock : Socket
        Already-connected socket for the session.

    addr : tuple
        Peer address, used only for logging.
    """
    sockfd = sock.fileno()
    temp = _server_env()
    # Blocks until the remote side closes the session.
    base._ServerLoop(sockfd)
    temp.remove()
    logging.info("Finish serving %s", addr)
def _parse_server_opt(opts):
# parse client options
ret = {}
for kv in opts:
if kv.startswith("-timeout="):
ret["timeout"] = float(kv[9:])
return ret
def _listen_loop(sock, port, rpc_key, tracker_addr):
    """Listening loop of the server master.

    Accepts client connections (optionally coordinating through an RPC
    tracker) and forks one child process per accepted session.

    Parameters
    ----------
    sock : Socket
        The listening socket.

    port : int
        Port the socket is bound to; reported to the tracker.

    rpc_key : str
        Key identifying the type of resource this server provides.

    tracker_addr : tuple or None
        (host, port) of the tracker, or None to serve without one.
    """
    def _accept_conn(listen_sock, tracker_conn, ping_period=0.1):
        """Accept connection from the other places.

        Parameters
        ----------
        listen_sock: Socket
            The socket used by listening process.

        tracker_conn : connection to tracker
            Tracker connection

        ping_period : float, optional
            ping tracker every k seconds if no connection is accepted.
        """
        # Report resource to tracker
        if tracker_conn:
            matchkey = ":" + base.random_key()
            base.sendjson(tracker_conn,
                          [TrackerCode.PUT, rpc_key, (port, matchkey)])
            assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
        else:
            matchkey = ""
        # Wait until we get a valid connection
        while True:
            if tracker_conn:
                trigger = select.select([listen_sock], [], [], ping_period)
                if not listen_sock in trigger[0]:
                    # Nothing pending: ping the tracker to keep the link alive.
                    base.sendjson(tracker_conn, [TrackerCode.PING])
                    assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
                    continue
            conn, addr = listen_sock.accept()
            magic = struct.unpack("@i", base.recvall(conn, 4))[0]
            if magic != base.RPC_MAGIC:
                conn.close()
                continue
            keylen = struct.unpack("@i", base.recvall(conn, 4))[0]
            key = py_str(base.recvall(conn, keylen))
            arr = key.split()
            expect_header = "client:" + rpc_key + matchkey
            server_key = "server:" + rpc_key
            if arr[0] != expect_header:
                conn.sendall(struct.pack("@i", base.RPC_CODE_MISMATCH))
                conn.close()
                logging.info("RPCServer: mismatch key from %s", addr)
                continue
            else:
                conn.sendall(struct.pack("@i", base.RPC_CODE_SUCCESS))
                conn.sendall(struct.pack("@i", len(server_key)))
                conn.sendall(server_key.encode("utf-8"))
                return conn, addr, _parse_server_opt(arr[1:])

    # Server logic
    tracker_conn = None
    while True:
        # step 1: setup tracker and report to tracker
        if tracker_addr and tracker_conn is None:
            tracker_conn = base.connect_with_retry(tracker_addr)
            tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
            magic = struct.unpack("@i", base.recvall(tracker_conn, 4))[0]
            if magic != base.RPC_TRACKER_MAGIC:
                raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
        try:
            # step 2: wait for in-coming connections
            conn, addr, opts = _accept_conn(sock, tracker_conn)
        except (socket.error, IOError):
            # retry when tracker is dropped
            # BUGFIX: tracker_conn may be None when serving without a
            # tracker; guard before closing to avoid an AttributeError.
            if tracker_conn:
                tracker_conn.close()
            tracker_conn = None
            continue
        # step 3: serving
        logging.info("RPCServer: connection from %s", addr)
        server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
        # BUGFIX: the attribute is spelled ``daemon``; the original
        # ``deamon`` silently set an unused attribute.
        server_proc.daemon = True
        server_proc.start()
        # close from our side.
        conn.close()
        # wait until server process finish or timeout
        server_proc.join(opts.get("timeout", None))
        if server_proc.is_alive():
            logging.info("Timeout in RPC session, kill..")
            server_proc.terminate()
def _connect_proxy_loop(addr, key):
    """Keep connecting to an RPC proxy at ``addr`` and serve sessions.

    Parameters
    ----------
    addr : tuple
        (host, port) of the proxy server.

    key : str
        Resource key; sent to the proxy prefixed with "server:".
    """
    key = "server:" + key
    while True:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(addr)
        # Handshake: magic, key length, key bytes.
        sock.sendall(struct.pack("@i", base.RPC_MAGIC))
        sock.sendall(struct.pack("@i", len(key)))
        sock.sendall(key.encode("utf-8"))
        magic = struct.unpack("@i", base.recvall(sock, 4))[0]
        if magic == base.RPC_CODE_DUPLICATE:
            raise RuntimeError("key: %s has already been used in proxy" % key)
        elif magic == base.RPC_CODE_MISMATCH:
            logging.info("RPCProxy do not have matching client key %s", key)
        elif magic != base.RPC_CODE_SUCCESS:
            raise RuntimeError("%s is not RPC Proxy" % str(addr))
        # Proxy replies with the remote key; options may follow the key.
        keylen = struct.unpack("@i", base.recvall(sock, 4))[0]
        remote_key = py_str(base.recvall(sock, keylen))
        opts = _parse_server_opt(remote_key.split()[1:])
        logging.info("RPCProxy connected to %s", str(addr))
        process = multiprocessing.Process(
            target=_serve_loop, args=(sock, addr))
        # BUGFIX: the attribute is spelled ``daemon``; the original
        # ``deamon`` silently set an unused attribute.
        process.daemon = True
        process.start()
        sock.close()
        process.join(opts.get("timeout", None))
        if process.is_alive():
            logging.info("Timeout in RPC session, kill..")
            process.terminate()
def _popen(cmd):
proc = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env=os.environ)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Server invoke error:\n"
msg += out
raise RuntimeError(msg)
class Server(object):
    """Start RPC server on a separate process.

    This is a simple python implementation based on multi-processing.
    It is also possible to implement a similar C based server with
    TVM runtime which does not depend on the python.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The port to be bind to

    port_end : int, optional
        The end port to search

    is_proxy : bool, optional
        Whether the address specified is a proxy.
        If this is true, the host and port actually corresponds to the
        address of the proxy server.

    use_popen : bool, optional
        Whether to use Popen to start a fresh new process instead of fork.
        This is recommended to switch on if we want to do local RPC demonstration
        for GPU devices to avoid fork safety issues.

    tracker_addr : tuple (str, int), optional
        The address (host, port) of the RPC tracker to report to.
        When set, ``key`` must also be provided.

    key : str, optional
        The key used to identify the server in Proxy connection.
    """
    def __init__(self,
                 host,
                 port=9091,
                 port_end=9199,
                 is_proxy=False,
                 use_popen=False,
                 tracker_addr=None,
                 key=""):
        try:
            if base._ServerLoop is None:
                raise RuntimeError("Please compile with USE_RPC=1")
        except NameError:
            raise RuntimeError("Please compile with USE_RPC=1")
        self.host = host
        self.port = port
        self.libs = []
        if use_popen:
            # Launch a brand-new interpreter process (avoids fork-safety
            # issues with GPU drivers).
            cmd = ["python",
                   "-m", "tvm.exec.rpc_server",
                   "--host=%s" % host,
                   "--port=%s" % port]
            if tracker_addr:
                assert key
                cmd += ["--tracker=%s:%d" % tracker_addr,
                        "--key=%s" % key]
            self.proc = multiprocessing.Process(
                target=subprocess.check_call, args=(cmd,))
            # BUGFIX: the attribute is spelled ``daemon``; the original
            # ``deamon`` silently set an unused attribute.
            self.proc.daemon = True
            self.proc.start()
            time.sleep(1)
        elif not is_proxy:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.port = None
            # Search for a free port in [port, port_end).
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                    self.port = my_port
                    break
                except socket.error as sock_err:
                    # EADDRINUSE is errno 98 on Linux, 48 on macOS.
                    if sock_err.errno in [98, 48]:
                        continue
                    else:
                        raise sock_err
            if not self.port:
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logging.info("RPCServer: bind to %s:%d", host, self.port)
            sock.listen(1)
            self.sock = sock
            self.proc = multiprocessing.Process(
                target=_listen_loop, args=(
                    self.sock, self.port, key, tracker_addr))
            self.proc.daemon = True
            self.proc.start()
        else:
            # Connect out to a proxy instead of listening ourselves.
            self.proc = multiprocessing.Process(
                target=_connect_proxy_loop, args=((host, port), key))
            self.proc.daemon = True
            self.proc.start()

    def terminate(self):
        """Terminate the server process"""
        if self.proc:
            self.proc.terminate()
            self.proc = None

    def __del__(self):
        self.terminate()
"""Utilities used in tornado."""
import socket
import errno
from tornado import ioloop
class TCPHandler(object):
    """TCP socket handler backed tornado event loop.

    Subclasses implement the ``on_message``, ``on_close`` and
    ``on_error`` callbacks; this base class manages the non-blocking
    socket, read events, and the buffered write queue.

    Parameters
    ----------
    sock : Socket
        The TCP socket, will set it to non-blocking mode.
    """
    def __init__(self, sock):
        self._sock = sock
        self._ioloop = ioloop.IOLoop.current()
        self._sock.setblocking(0)
        # Outgoing messages not yet fully written to the socket.
        self._pending_write = []
        # When True, close as soon as the write queue drains.
        self._signal_close = False
        def _event_handler(_, events):
            self._event_handler(events)
        self._ioloop.add_handler(
            self._sock.fileno(), _event_handler,
            self._ioloop.READ | self._ioloop.ERROR)

    def signal_close(self):
        """Signal the handler to close.

        The handler will be closed after the existing
        pending message are sent to the peer.
        """
        if not self._pending_write:
            self.close()
        else:
            self._signal_close = True

    def close(self):
        """Close the socket"""
        if self._sock is not None:
            try:
                self._ioloop.remove_handler(self._sock.fileno())
                self._sock.close()
            except socket.error:
                pass
            self._sock = None
            self.on_close()

    def write_message(self, message, binary=True):
        # Queue the message; the actual send happens via the event loop.
        assert binary
        self._pending_write.append(message)
        self._update_write()

    def _event_handler(self, events):
        """central event handler"""
        if (events & self._ioloop.ERROR) or (events & self._ioloop.READ):
            # Read first; only attempt the write if the read succeeded.
            if self._update_read() and (events & self._ioloop.WRITE):
                self._update_write()
        elif events & self._ioloop.WRITE:
            self._update_write()

    def _update_write(self):
        """Update the state on write"""
        while self._pending_write:
            try:
                msg = self._pending_write[0]
                nsend = self._sock.send(msg)
                if nsend != len(msg):
                    # Partial send: keep the unsent tail at the queue head.
                    self._pending_write[0] = msg[nsend:]
                else:
                    self._pending_write.pop(0)
            except socket.error as err:
                if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                    # Socket buffer full: wait for the next WRITE event.
                    break
                else:
                    # NOTE(review): no break after on_error -- the loop only
                    # exits via the queue or a later exception; confirm that
                    # on_error implementations close the handler.
                    self.on_error(err)
        if self._pending_write:
            # Still data to flush: also listen for WRITE readiness.
            self._ioloop.update_handler(
                self._sock.fileno(), self._ioloop.READ | self._ioloop.ERROR | self._ioloop.WRITE)
        else:
            if self._signal_close:
                self.close()
            else:
                self._ioloop.update_handler(
                    self._sock.fileno(), self._ioloop.READ | self._ioloop.ERROR)

    def _update_read(self):
        """Update state when there is read event"""
        try:
            msg = bytes(self._sock.recv(4096))
            if msg:
                self.on_message(msg)
                return True
            else:
                # normal close, remote is closed
                self.close()
        except socket.error as err:
            if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                pass
            else:
                self.on_error(err)
        return False
"""RPC Tracker, tracks and distributes the TVM RPC resources.
This folder implements the tracker server logic.
Note
----
Tracker is a TCP based rest api with the following protocol:
- Initial handshake to the peer
- RPC_TRACKER_MAGIC
- Normal message: [size(int32), json-data]
- Each message is initiated by the client, and the tracker replies with a json.
List of available APIs:
- PING: check if tracker is alive
- input: [TrackerCode.PING]
- return: TrackerCode.SUCCESS
- PUT: report resource to tracker
  - input: [TrackerCode.PUT, key, [port, match-key]]
  - return: TrackerCode.SUCCESS
  - note: match-key is a randomly generated key to identify the resource during connection.
- REQUEST: request a new resource from tracker
  - input: [TrackerCode.REQUEST, key, user, priority]
  - return: [TrackerCode.SUCCESS, [url, port, match-key]]
"""
import heapq
import time
import logging
import socket
import multiprocessing
import errno
import struct
import json
try:
from tornado import ioloop
from . import tornado_util
except ImportError as error_msg:
raise ImportError(
"RPCTracker module requires tornado package %s" % error_msg)
from ..._ffi.base import py_str
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode
class Scheduler(object):
    """Abstract interface of a resource scheduler."""
    def put(self, value):
        """Push a resource into the scheduler.

        This function can trigger callbacks in the scheduler.

        Parameters
        ----------
        value : object
            The resource to be put in the scheduler.
        """
        raise NotImplementedError()

    def request(self, user, priority, callback):
        """Request a resource.

        Parameters
        ----------
        user : str
            The user who is requesting the resource.

        priority : int
            The job priority

        callback : function: value->bool
            Callback function to receive a resource when ready;
            returns True if the resource is consumed.
        """
        raise NotImplementedError()
class PriorityScheduler(Scheduler):
    """Scheduler serving higher-priority requests first, breaking
    ties by arrival time (FIFO)."""
    def __init__(self):
        self._values = []
        self._requests = []

    def _schedule(self):
        """Match queued resources with pending requests."""
        while self._values and self._requests:
            resource = self._values.pop(0)
            entry = heapq.heappop(self._requests)
            notify = entry[-1]
            consumed = notify(resource)
            if not consumed:
                # Callback declined the resource; keep it available.
                self._values.append(resource)

    def put(self, value):
        """Make ``value`` available and try to satisfy waiting requests."""
        self._values.append(value)
        self._schedule()

    def request(self, user, priority, callback):
        """Queue ``callback`` to receive a resource, ordered by priority."""
        heapq.heappush(self._requests, (-priority, time.time(), callback))
        self._schedule()
class TCPEventHandler(tornado_util.TCPHandler):
    """Base asynchronize message handler.

    The tracker and client follows a simple message protocol.
    The message is in form [nbytes(int32)] [json-str].
    All the information is packed in json-str
    """
    def __init__(self, tracker, sock, addr):
        super(TCPEventHandler, self).__init__(sock)
        self._data = bytearray()
        self._tracker = tracker
        self._msg_size = 0
        self._addr = addr
        # Bytes still expected for the initial magic handshake.
        self._init_req_nbytes = 4
        self._tracker._connections.add(self)

    def name(self):
        """name of connection"""
        return "TCPSocket: %s" % str(self._addr)

    def _init_conn(self, message):
        """Initialize the connection by validating the magic handshake."""
        if len(message) != 4:
            logging.info("Invalid connection from %s", self.name())
            self.close()
            # BUGFIX: stop here; unpacking a wrong-length buffer below
            # would raise struct.error on an already-closed handler.
            return
        magic = struct.unpack('@i', message)[0]
        if magic != RPC_TRACKER_MAGIC:
            logging.info("Invalid magic from %s", self.name())
            self.close()
            # BUGFIX: stop here; do not echo magic on a closed handler.
            return
        self.write_message(struct.pack('@i', RPC_TRACKER_MAGIC), binary=True)
        self._init_req_nbytes = 0

    def on_message(self, message):
        """Callback when a message is received.

        Parameters
        ----------
        message : bytearray
            The bytes received
        """
        assert isinstance(message, bytes)
        if self._init_req_nbytes:
            self._init_conn(message)
            return
        self._data += message
        # Drain as many complete [size][json] frames as are buffered.
        while True:
            if self._msg_size == 0:
                if len(self._data) >= 4:
                    self._msg_size = struct.unpack('@i', self._data[:4])[0]
                else:
                    return
            if self._msg_size != 0 and len(self._data) >= self._msg_size + 4:
                msg = py_str(bytes(self._data[4:4 + self._msg_size]))
                del self._data[:4 + self._msg_size]
                self._msg_size = 0
                # pylint: disable=broad-except
                self.call_handler(json.loads(msg))
            else:
                return

    def ret_value(self, data):
        """return value to the output"""
        data = json.dumps(data)
        self.write_message(
            struct.pack('@i', len(data)), binary=True)
        self.write_message(data.encode("utf-8"), binary=True)

    def call_handler(self, args):
        """Event handler when json request arrives."""
        code = args[0]
        if code == TrackerCode.PUT:
            key = args[1]
            port, matchkey = args[2]
            # Record (address, port, matchkey) so clients can be routed here.
            self._tracker.put(key, (self._addr[0], port, matchkey))
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.REQUEST:
            key = args[1]
            user = args[2]
            priority = args[3]
            def _cb(value):
                # Reply once the scheduler hands us a resource; returning
                # True tells the scheduler the value was consumed.
                self.ret_value([TrackerCode.SUCCESS, value])
                return True
            self._tracker.request(key, user, priority, _cb)
        elif code == TrackerCode.PING:
            self.ret_value(TrackerCode.SUCCESS)
        elif code == TrackerCode.STOP:
            # safe stop tracker: only honored with the secret stop key.
            if self._tracker._stop_key == args[1]:
                self.ret_value(TrackerCode.SUCCESS)
                self._tracker.stop()
            else:
                self.ret_value(TrackerCode.FAIL)
        else:
            logging.info("Unknown tracker code %d", code)
            self.close()

    def on_close(self):
        self._tracker._connections.remove(self)

    def on_error(self, err):
        logging.info("%s: Error in RPC Tracker: %s", self.name(), err)
        self.close()
class TrackerServerHandler(object):
    """Tracker that tracks the resources.

    Owns the listening socket and the tornado io loop; accepts incoming
    connections and routes PUT/REQUEST traffic to per-key schedulers.
    """
    def __init__(self, sock, stop_key):
        # Map resource key -> Scheduler instance.
        self._scheduler_map = {}
        self._sock = sock
        self._sock.setblocking(0)
        self._ioloop = ioloop.IOLoop.current()
        # Secret key required by the remote STOP command.
        self._stop_key = stop_key
        # Live TCPEventHandler connections, closed on stop().
        self._connections = set()
        def _event_handler(_, events):
            self._on_event(events)
        # Watch the listening socket for incoming connections.
        self._ioloop.add_handler(
            self._sock.fileno(), _event_handler, self._ioloop.READ)

    def _on_event(self, _):
        # Accept every pending connection; the socket is non-blocking, so
        # stop when accept() reports EAGAIN/EWOULDBLOCK.
        # NOTE(review): non-EAGAIN socket errors are swallowed and the
        # loop retries -- confirm this is intended.
        while True:
            try:
                conn, addr = self._sock.accept()
                TCPEventHandler(self, conn, addr)
            except socket.error as err:
                if err.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                    break

    def create_scheduler(self, key):
        """Create a new scheduler."""
        _ = key
        return PriorityScheduler()

    def put(self, key, value):
        """Report a new resource to the tracker."""
        if key not in self._scheduler_map:
            self._scheduler_map[key] = self.create_scheduler(key)
        self._scheduler_map[key].put(value)

    def request(self, key, user, priority, callback):
        """Request a new resource."""
        if key not in self._scheduler_map:
            self._scheduler_map[key] = self.create_scheduler(key)
        self._scheduler_map[key].request(user, priority, callback)

    def stop(self):
        """Safely stop tracker."""
        for conn in list(self._connections):
            conn.close()
        self._sock.close()
        self._ioloop.stop()

    def run(self):
        """Run the tracker server"""
        self._ioloop.start()
def _tracker_server(listen_sock, stop_key):
    """Process entry point: run a tracker on ``listen_sock`` until stopped.

    Parameters
    ----------
    listen_sock : Socket
        Bound and listening TCP socket.

    stop_key : str
        Secret key that authorizes a remote STOP request.
    """
    handler = TrackerServerHandler(listen_sock, stop_key)
    # Blocks until a STOP request calls handler.stop().
    handler.run()
    logging.info("Tracker Stop signal received, terminating...")
class Tracker(object):
    """Start RPC tracker on a separate process.

    Python implementation based on multi-processing.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The TCP port to bind to

    port_end : int, optional
        The end TCP port to search
    """
    def __init__(self,
                 host,
                 port=9190,
                 port_end=9199):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.port = None
        # Random key that authorizes a remote STOP of this tracker.
        self.stop_key = base.random_key()
        # Search for a free port in [port, port_end).
        for my_port in range(port, port_end):
            try:
                sock.bind((host, my_port))
                self.port = my_port
                break
            except socket.error as sock_err:
                # EADDRINUSE is errno 98 on Linux, 48 on macOS.
                if sock_err.errno in [98, 48]:
                    continue
                else:
                    raise sock_err
        if not self.port:
            raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
        logging.info("RPCTracker: bind to %s:%d", host, self.port)
        sock.listen(1)
        self.proc = multiprocessing.Process(
            target=_tracker_server, args=(sock, self.stop_key))
        self.proc.start()
        self.host = host
        # close the socket on this process
        sock.close()

    def _stop_tracker(self):
        # Connect as a regular client and issue the authorized STOP command.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect((self.host, self.port))
        sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC))
        magic = struct.unpack("@i", base.recvall(sock, 4))[0]
        assert magic == base.RPC_TRACKER_MAGIC
        base.sendjson(sock, [TrackerCode.STOP, self.stop_key])
        assert base.recvjson(sock) == TrackerCode.SUCCESS
        sock.close()

    def terminate(self):
        """Terminate the server process"""
        if self.proc:
            if self.proc.is_alive():
                # Ask the tracker to shut down cleanly before forcing it.
                self._stop_tracker()
                self.proc.join(1)
            if self.proc.is_alive():
                logging.info("Terminating Tracker Server...")
                self.proc.terminate()
            self.proc = None

    def __del__(self):
        self.terminate()
...@@ -4,7 +4,7 @@ from __future__ import absolute_import ...@@ -4,7 +4,7 @@ from __future__ import absolute_import
import logging import logging
import argparse import argparse
import os import os
from ..contrib.rpc_proxy import Proxy from ..contrib.rpc.proxy import Proxy
def find_example_resource(): def find_example_resource():
"""Find resource examples.""" """Find resource examples."""
......
...@@ -17,13 +17,14 @@ def main(): ...@@ -17,13 +17,14 @@ def main():
help='The port of the PRC') help='The port of the PRC')
parser.add_argument('--port-end', type=int, default=9199, parser.add_argument('--port-end', type=int, default=9199,
help='The end search port of the PRC') help='The end search port of the PRC')
parser.add_argument('--key', type=str, default="",
help="RPC key used to identify the connection type.")
parser.add_argument('--with-executor', type=bool, default=False, parser.add_argument('--with-executor', type=bool, default=False,
help="Whether to load executor runtime") help="Whether to load executor runtime")
parser.add_argument('--load-library', type=str, default="", parser.add_argument('--load-library', type=str, default="",
help="Additional library to load") help="Additional library to load")
parser.add_argument('--exclusive', action='store_true', parser.add_argument('--tracker', type=str, default="",
help="If this is enabled, the server will kill old connection" help="Report to RPC tracker")
"when new connection comes")
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -38,7 +39,21 @@ def main(): ...@@ -38,7 +39,21 @@ def main():
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name) logging.info("Load additional library %s", file_name)
server = rpc.Server(args.host, args.port, args.port_end, exclusive=args.exclusive) if args.tracker:
url, port = args.tracker.split(":")
port = int(port)
tracker_addr = (url, port)
if not args.key:
raise RuntimeError(
"Need key to present type of resource when tracker is available")
else:
tracker_addr = None
server = rpc.Server(args.host,
args.port,
args.port_end,
key=args.key,
tracker_addr=tracker_addr)
server.libs += libs server.libs += libs
server.proc.join() server.proc.join()
......
"""RPC web proxy, allows redirect to websocket based RPC servers(browsers)"""
from __future__ import absolute_import
import logging
import argparse
from ..contrib.rpc.tracker import Tracker
def main():
    """Entry point: parse command line arguments and start the tracker."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default="0.0.0.0",
                        help='the hostname of the server')
    parser.add_argument('--port', type=int, default=9190,
                        # BUGFIX: help text said "port of the PRC" (typo).
                        help='The port of the RPC tracker')
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    tracker = Tracker(args.host, port=args.port)
    # Block until the tracker process exits.
    tracker.proc.join()

if __name__ == "__main__":
    main()
...@@ -32,9 +32,12 @@ class CallbackChannel final : public RPCChannel { ...@@ -32,9 +32,12 @@ class CallbackChannel final : public RPCChannel {
PackedFunc fsend_; PackedFunc fsend_;
}; };
PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name) { PackedFunc CreateEventDrivenServer(PackedFunc fsend,
std::string name,
std::string remote_key) {
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend)); std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend));
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name); std::shared_ptr<RPCSession> sess =
RPCSession::Create(std::move(ch), name, remote_key);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
int ret = sess->ServerEventHandler(args[0], args[1]); int ret = sess->ServerEventHandler(args[0], args[1]);
*rv = ret; *rv = ret;
...@@ -43,7 +46,7 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name) { ...@@ -43,7 +46,7 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name) {
TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer") TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEventDrivenServer(args[0], args[1]); *rv = CreateEventDrivenServer(args[0], args[1], args[2]);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -49,6 +49,7 @@ class RPCModuleNode final : public ModuleNode { ...@@ -49,6 +49,7 @@ class RPCModuleNode final : public ModuleNode {
~RPCModuleNode() { ~RPCModuleNode() {
if (module_handle_ != nullptr) { if (module_handle_ != nullptr) {
sess_->CallRemote(RPCCode::kModuleFree, module_handle_); sess_->CallRemote(RPCCode::kModuleFree, module_handle_);
module_handle_ = nullptr;
} }
} }
...@@ -198,5 +199,6 @@ TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex") ...@@ -198,5 +199,6 @@ TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
CHECK_EQ(tkey, "rpc"); CHECK_EQ(tkey, "rpc");
*rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index(); *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -49,11 +49,19 @@ class RPCSession::EventHandler { ...@@ -49,11 +49,19 @@ class RPCSession::EventHandler {
EventHandler(common::RingBuffer* reader, EventHandler(common::RingBuffer* reader,
common::RingBuffer* writer, common::RingBuffer* writer,
int rpc_sess_table_index, int rpc_sess_table_index,
std::string name) std::string name,
: reader_(reader), writer_(writer), std::string* remote_key)
: reader_(reader),
writer_(writer),
rpc_sess_table_index_(rpc_sess_table_index), rpc_sess_table_index_(rpc_sess_table_index),
name_(name) { name_(name),
remote_key_(remote_key) {
this->Clear(); this->Clear();
if (*remote_key == "%toinit") {
state_ = kInitHeader;
remote_key_->resize(0);
pending_request_bytes_ = sizeof(int32_t);
}
} }
// Bytes needed to fulfill current request // Bytes needed to fulfill current request
size_t BytesNeeded() { size_t BytesNeeded() {
...@@ -75,6 +83,7 @@ class RPCSession::EventHandler { ...@@ -75,6 +83,7 @@ class RPCSession::EventHandler {
std::swap(client_mode_, client_mode); std::swap(client_mode_, client_mode);
while (this->Ready()) { while (this->Ready()) {
switch (state_) { switch (state_) {
case kInitHeader: HandleInitHeader(); break;
case kRecvCode: HandleRecvCode(); break; case kRecvCode: HandleRecvCode(); break;
case kRecvCallHandle: { case kRecvCallHandle: {
this->Read(&call_handle_, sizeof(call_handle_)); this->Read(&call_handle_, sizeof(call_handle_));
...@@ -223,6 +232,7 @@ class RPCSession::EventHandler { ...@@ -223,6 +232,7 @@ class RPCSession::EventHandler {
protected: protected:
enum State { enum State {
kInitHeader,
kRecvCode, kRecvCode,
kRecvCallHandle, kRecvCallHandle,
kRecvPackedSeqNumArgs, kRecvPackedSeqNumArgs,
...@@ -240,6 +250,8 @@ class RPCSession::EventHandler { ...@@ -240,6 +250,8 @@ class RPCSession::EventHandler {
RPCCode code_; RPCCode code_;
// Handle for the remote function call. // Handle for the remote function call.
uint64_t call_handle_; uint64_t call_handle_;
// Initialize remote header
bool init_header_step_{0};
// Number of packed arguments. // Number of packed arguments.
int num_packed_args_; int num_packed_args_;
// Current argument index. // Current argument index.
...@@ -266,6 +278,10 @@ class RPCSession::EventHandler { ...@@ -266,6 +278,10 @@ class RPCSession::EventHandler {
<< "state=" << state; << "state=" << state;
state_ = state; state_ = state;
switch (state) { switch (state) {
case kInitHeader: {
LOG(FATAL) << "cannot switch to init header";
break;
}
case kRecvCode: { case kRecvCode: {
this->RequestBytes(sizeof(RPCCode)); this->RequestBytes(sizeof(RPCCode));
break; break;
...@@ -438,6 +454,21 @@ class RPCSession::EventHandler { ...@@ -438,6 +454,21 @@ class RPCSession::EventHandler {
this->SwitchToState(kRecvPackedSeqArg); this->SwitchToState(kRecvPackedSeqArg);
} }
} }
// handler for initial header read
void HandleInitHeader() {
if (init_header_step_ == 0) {
int32_t len;
this->Read(&len, sizeof(len));
remote_key_->resize(len);
init_header_step_ = 1;
this->RequestBytes(len);
return;
} else {
CHECK_EQ(init_header_step_, 1);
this->Read(dmlc::BeginPtr(*remote_key_), remote_key_->length());
this->SwitchToState(kRecvCode);
}
}
// Handler for read code. // Handler for read code.
void HandleRecvCode() { void HandleRecvCode() {
this->Read(&code_, sizeof(code_)); this->Read(&code_, sizeof(code_));
...@@ -633,6 +664,8 @@ class RPCSession::EventHandler { ...@@ -633,6 +664,8 @@ class RPCSession::EventHandler {
int rpc_sess_table_index_; int rpc_sess_table_index_;
// Name of session. // Name of session.
std::string name_; std::string name_;
// remote key
std::string* remote_key_;
}; };
struct RPCSessTable { struct RPCSessTable {
...@@ -699,7 +732,8 @@ RPCCode RPCSession::HandleUntilReturnEvent( ...@@ -699,7 +732,8 @@ RPCCode RPCSession::HandleUntilReturnEvent(
void RPCSession::Init() { void RPCSession::Init() {
// Event handler // Event handler
handler_ = std::make_shared<EventHandler>(&reader_, &writer_, table_index_, name_); handler_ = std::make_shared<EventHandler>(
&reader_, &writer_, table_index_, name_, &remote_key_);
// Quick function to call remote. // Quick function to call remote.
call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) { call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args); handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
...@@ -709,10 +743,13 @@ void RPCSession::Init() { ...@@ -709,10 +743,13 @@ void RPCSession::Init() {
} }
std::shared_ptr<RPCSession> RPCSession::Create( std::shared_ptr<RPCSession> RPCSession::Create(
std::unique_ptr<RPCChannel> channel, std::string name) { std::unique_ptr<RPCChannel> channel,
std::string name,
std::string remote_key) {
std::shared_ptr<RPCSession> sess = std::make_shared<RPCSession>(); std::shared_ptr<RPCSession> sess = std::make_shared<RPCSession>();
sess->channel_ = std::move(channel); sess->channel_ = std::move(channel);
sess->name_ = std::move(name); sess->name_ = std::move(name);
sess->remote_key_ = std::move(remote_key);
sess->table_index_ = RPCSessTable::Global()->Insert(sess); sess->table_index_ = RPCSessTable::Global()->Insert(sess);
sess->Init(); sess->Init();
return sess; return sess;
......
...@@ -171,12 +171,15 @@ class RPCSession { ...@@ -171,12 +171,15 @@ class RPCSession {
/*! /*!
* \brief Create a RPC session with given channel. * \brief Create a RPC session with given channel.
* \param channel The communication channel. * \param channel The communication channel.
* \param name The name of the session, used for debug * \param name The local name of the session, used for debug
* \return The session. * \param remote_key The remote key of the session
 * if remote_key equals "%toinit", we need to re-initialize
* it by event handler.
*/ */
static std::shared_ptr<RPCSession> Create( static std::shared_ptr<RPCSession> Create(
std::unique_ptr<RPCChannel> channel, std::unique_ptr<RPCChannel> channel,
std::string name); std::string name,
std::string remote_key);
/*! /*!
* \brief Try get session from the global session table by table index. * \brief Try get session from the global session table by table index.
* \param table_index The table index of the session. * \param table_index The table index of the session.
...@@ -208,6 +211,8 @@ class RPCSession { ...@@ -208,6 +211,8 @@ class RPCSession {
int table_index_{0}; int table_index_{0};
// The name of the session. // The name of the session.
std::string name_; std::string name_;
// The remote key
std::string remote_key_;
}; };
/*! /*!
......
...@@ -68,7 +68,14 @@ RPCConnect(std::string url, int port, std::string key) { ...@@ -68,7 +68,14 @@ RPCConnect(std::string url, int port, std::string key) {
sock.Close(); sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server"; LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
} }
return RPCSession::Create(std::unique_ptr<SockChannel>(new SockChannel(sock)), key); CHECK_EQ(sock.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen));
std::string remote_key;
if (keylen != 0) {
remote_key.resize(keylen);
CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen);
}
return RPCSession::Create(
std::unique_ptr<SockChannel>(new SockChannel(sock)), key, remote_key);
} }
Module RPCClientConnect(std::string url, int port, std::string key) { Module RPCClientConnect(std::string url, int port, std::string key) {
...@@ -80,7 +87,7 @@ void RPCServerLoop(int sockfd) { ...@@ -80,7 +87,7 @@ void RPCServerLoop(int sockfd) {
static_cast<common::TCPSocket::SockType>(sockfd)); static_cast<common::TCPSocket::SockType>(sockfd));
RPCSession::Create( RPCSession::Create(
std::unique_ptr<SockChannel>(new SockChannel(sock)), std::unique_ptr<SockChannel>(new SockChannel(sock)),
"SockServerLoop")->ServerLoop(); "SockServerLoop", "")->ServerLoop();
} }
TVM_REGISTER_GLOBAL("contrib.rpc._Connect") TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
......
...@@ -17,9 +17,9 @@ def rpc_proxy_check(): ...@@ -17,9 +17,9 @@ def rpc_proxy_check():
""" """
try: try:
from tvm.contrib import rpc_proxy from tvm.contrib.rpc import proxy
web_port = 8888 web_port = 8888
prox = rpc_proxy.Proxy("localhost", web_port=web_port) prox = proxy.Proxy("localhost", web_port=web_port)
def check(): def check():
if not tvm.module.enabled("rpc"): if not tvm.module.enabled("rpc"):
return return
...@@ -30,7 +30,7 @@ def rpc_proxy_check(): ...@@ -30,7 +30,7 @@ def rpc_proxy_check():
def addone(name, x): def addone(name, x):
return "%s:%d" % (name, x) return "%s:%d" % (name, x)
server = multiprocessing.Process( server = multiprocessing.Process(
target=rpc_proxy.websocket_proxy_server, target=proxy.websocket_proxy_server,
args=("ws://localhost:%d/ws" % web_port,"x1")) args=("ws://localhost:%d/ws" % web_port,"x1"))
# Need to make sure that the connection start after proxy comes up # Need to make sure that the connection start after proxy comes up
time.sleep(0.1) time.sleep(0.1)
......
import tvm
import logging
import numpy as np
import time
import multiprocessing
from tvm.contrib import rpc
def check_server_drop():
    """Exercise tracker fault tolerance when RPC servers drop.

    Starts a tracker plus two servers registered under the same key,
    injects stale worker entries directly into the tracker, then requests
    sessions with a short ``session_timeout`` to verify that the tracker
    recovers from both stale values and timed-out servers.

    Skipped (with a message) when tornado is unavailable, since the
    tracker depends on it.
    """
    try:
        from tvm.contrib.rpc import tracker, base
        from tvm.contrib.rpc.base import TrackerCode

        @tvm.register_func("rpc.test2.addone")
        def addone(x):
            return x + 1

        def _put(tclient, value):
            # Push a raw JSON message through the tracker client's socket
            # and consume the reply, bypassing the high-level API.
            base.sendjson(tclient._sock, value)
            base.recvjson(tclient._sock)

        tserver = tracker.Tracker("localhost", 8888)
        tclient = rpc.connect_tracker("localhost", tserver.port)
        server1 = rpc.Server(
            "localhost", port=9099,
            tracker_addr=("localhost", tserver.port),
            key="xyz")
        server2 = rpc.Server(
            "localhost", port=9099,
            tracker_addr=("localhost", tserver.port),
            key="xyz")
        # Fault tolerance to stale worker values: inject bogus entries the
        # tracker must skip over when matching requests.
        _put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abc")])
        _put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abcxxx")])
        _put(tclient, [TrackerCode.PUT, "xyz", (server2.port, "abcxxx11")])

        # Fault tolerance to server timeout.
        def check_timeout(timeout, sleeptime):
            # Timed-out sessions are expected to raise TVMError; that is a
            # valid outcome of this check, so swallow it deliberately.
            try:
                remote = tclient.request("xyz", priority=0,
                                         session_timeout=timeout)
                remote2 = tclient.request("xyz", session_timeout=timeout)
                time.sleep(sleeptime)
                f1 = remote.get_function("rpc.test2.addone")
                assert f1(10) == 11
                f1 = remote2.get_function("rpc.test2.addone")
                assert f1(10) == 11
            except tvm.TVMError:
                pass

        check_timeout(0.01, 0.1)
        check_timeout(2, 0)
    except ImportError:
        print("Skip because tornado is not available")
if __name__ == "__main__":
    # Enable INFO logs so tracker/server connection events are visible
    # when the check runs as a standalone script.
    logging.basicConfig(level=logging.INFO)
    check_server_drop()
...@@ -18,7 +18,7 @@ def test_rpc_simple(): ...@@ -18,7 +18,7 @@ def test_rpc_simple():
def remotethrow(name): def remotethrow(name):
raise ValueError("%s" % name) raise ValueError("%s" % name)
server = rpc.Server("localhost") server = rpc.Server("localhost", key="x1")
client = rpc.connect(server.host, server.port, key="x1") client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.addone") f1 = client.get_function("rpc.test.addone")
assert f1(10) == 11 assert f1(10) == 11
...@@ -41,7 +41,6 @@ def test_rpc_array(): ...@@ -41,7 +41,6 @@ def test_rpc_array():
np.testing.assert_equal(y.asnumpy(), x) np.testing.assert_equal(y.asnumpy(), x)
server = rpc.Server("localhost") server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port) remote = rpc.connect(server.host, server.port)
print("second connect")
r_cpu = tvm.nd.array(x, remote.cpu(0)) r_cpu = tvm.nd.array(x, remote.cpu(0))
assert str(r_cpu.context).startswith("remote") assert str(r_cpu.context).startswith("remote")
np.testing.assert_equal(r_cpu.asnumpy(), x) np.testing.assert_equal(r_cpu.asnumpy(), x)
...@@ -141,7 +140,7 @@ def test_rpc_return_func(): ...@@ -141,7 +140,7 @@ def test_rpc_return_func():
@tvm.register_func("rpc.test.remote_func") @tvm.register_func("rpc.test.remote_func")
def addone(x): def addone(x):
return lambda y: x+y return lambda y: x+y
server = rpc.Server("localhost") server = rpc.Server("localhost", key="x1")
client = rpc.connect(server.host, server.port, key="x1") client = rpc.connect(server.host, server.port, key="x1")
f1 = client.get_function("rpc.test.remote_func") f1 = client.get_function("rpc.test.remote_func")
fadd = f1(10) fadd = f1(10)
......
...@@ -851,6 +851,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -851,6 +851,7 @@ var tvm_runtime = tvm_runtime || {};
} }
// Node js, import websocket // Node js, import websocket
var bkey = StringToUint8Array("server:" + key); var bkey = StringToUint8Array("server:" + key);
bkey = bkey.slice(0, bkey.length - 1);
var server_name = "WebSocketRPCServer[" + key + "]"; var server_name = "WebSocketRPCServer[" + key + "]";
var RPC_MAGIC = 0xff271; var RPC_MAGIC = 0xff271;
function checkEndian() { function checkEndian() {
...@@ -895,7 +896,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -895,7 +896,7 @@ var tvm_runtime = tvm_runtime || {};
} else { } else {
return new TVMConstant(0, "int32"); return new TVMConstant(0, "int32");
} }
} , server_name); } , server_name, "%toinit");
function on_open(event) { function on_open(event) {
var intbuf = new Int32Array(1); var intbuf = new Int32Array(1);
...@@ -912,6 +913,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -912,6 +913,7 @@ var tvm_runtime = tvm_runtime || {};
var msg = new Uint8Array(event.data); var msg = new Uint8Array(event.data);
CHECK(msg.length >= 4, "Need message header to be bigger than 4"); CHECK(msg.length >= 4, "Need message header to be bigger than 4");
var magic = new Int32Array(event.data)[0]; var magic = new Int32Array(event.data)[0];
if (magic == RPC_MAGIC + 1) { if (magic == RPC_MAGIC + 1) {
throwError("key: " + key + " has already been used in proxy"); throwError("key: " + key + " has already been used in proxy");
} else if (magic == RPC_MAGIC + 2) { } else if (magic == RPC_MAGIC + 2) {
...@@ -1014,7 +1016,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -1014,7 +1016,7 @@ var tvm_runtime = tvm_runtime || {};
/** /**
* Load parameters from serialized byte array of parameter dict. * Load parameters from serialized byte array of parameter dict.
* *
* @param {Uint8Array} params The serialized parameter dict. * @param {Uint8Array} params The serialized parameter dict.
*/ */
"load_params" : function(params) { "load_params" : function(params) {
...@@ -1024,7 +1026,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -1024,7 +1026,7 @@ var tvm_runtime = tvm_runtime || {};
/** /**
* Load parameters from serialized base64 string of parameter dict. * Load parameters from serialized base64 string of parameter dict.
* *
* @param {string} base64_params The serialized parameter dict. * @param {string} base64_params The serialized parameter dict.
*/ */
"load_base64_params" : function(base64_params) { "load_base64_params" : function(base64_params) {
...@@ -1046,7 +1048,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -1046,7 +1048,7 @@ var tvm_runtime = tvm_runtime || {};
/** /**
* Get index-th output to out. * Get index-th output to out.
* *
* @param {NDArray} out The output array container. * @param {NDArray} out The output array container.
* @return {NDArray} The output array container. * @return {NDArray} The output array container.
*/ */
...@@ -1076,7 +1078,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -1076,7 +1078,7 @@ var tvm_runtime = tvm_runtime || {};
var tvm_graph_module = fcreate(graph_json_str, libmod, var tvm_graph_module = fcreate(graph_json_str, libmod,
new TVMConstant(ctx.device_type, "int32"), new TVMConstant(ctx.device_type, "int32"),
new TVMConstant(ctx.device_id, "int32")); new TVMConstant(ctx.device_id, "int32"));
return new GraphModule(tvm_graph_module, ctx); return new GraphModule(tvm_graph_module, ctx);
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment