Commit 1077f8e8 by Tianqi Chen Committed by GitHub

[RUNTIME][RPC] Enable remote linking of device code. (#444)

* [RUNTIME][RPC] Enable remote linking of device code.

* fix build
parent 64402c14
...@@ -15,7 +15,7 @@ import socket ...@@ -15,7 +15,7 @@ import socket
import struct import struct
import logging import logging
import multiprocessing import multiprocessing
from . import util, cc from . import util, cc, tar
from ..module import load as _load_module from ..module import load as _load_module
from .._ffi.function import _init_api, register_func from .._ffi.function import _init_api, register_func
from .._ffi.ndarray import context as _context from .._ffi.ndarray import context as _context
...@@ -37,10 +37,16 @@ def _server_env(): ...@@ -37,10 +37,16 @@ def _server_env():
"""Load module from remote side.""" """Load module from remote side."""
path = temp.relpath(file_name) path = temp.relpath(file_name)
# Try create a shared library in remote # Try create a shared library in remote
if path.endswith('.o'): if path.endswith(".o"):
logging.info('Create shared library based on %s', path) logging.info("Create shared library based on %s", path)
cc.create_shared(path + '.so', path) cc.create_shared(path + ".so", path)
path += '.so' path += ".so"
elif path.endswith(".tar"):
tar_temp = util.tempdir()
tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
cc.create_shared(path + ".so", files)
path += ".so"
m = _load_module(path) m = _load_module(path)
logging.info("load_module %s", path) logging.info("load_module %s", path)
return m return m
...@@ -63,7 +69,7 @@ def _recvall(sock, nbytes): ...@@ -63,7 +69,7 @@ def _recvall(sock, nbytes):
chunk = sock.recv(min(nbytes - nread, 1024)) chunk = sock.recv(min(nbytes - nread, 1024))
nread += len(chunk) nread += len(chunk)
res.append(chunk) res.append(chunk)
return b''.join(res) return b"".join(res)
def _listen_loop(sock): def _listen_loop(sock):
...@@ -71,16 +77,16 @@ def _listen_loop(sock): ...@@ -71,16 +77,16 @@ def _listen_loop(sock):
while True: while True:
conn, addr = sock.accept() conn, addr = sock.accept()
logging.info("RPCServer: connection from %s", addr) logging.info("RPCServer: connection from %s", addr)
magic = struct.unpack('@i', _recvall(conn, 4))[0] magic = struct.unpack("@i", _recvall(conn, 4))[0]
if magic != RPC_MAGIC: if magic != RPC_MAGIC:
conn.close() conn.close()
continue continue
keylen = struct.unpack('@i', _recvall(conn, 4))[0] keylen = struct.unpack("@i", _recvall(conn, 4))[0]
key = py_str(_recvall(conn, keylen)) key = py_str(_recvall(conn, keylen))
if not key.startswith("client:"): if not key.startswith("client:"):
conn.sendall(struct.pack('@i', RPC_MAGIC + 2)) conn.sendall(struct.pack("@i", RPC_MAGIC + 2))
else: else:
conn.sendall(struct.pack('@i', RPC_MAGIC)) conn.sendall(struct.pack("@i", RPC_MAGIC))
logging.info("Connection from %s", addr) logging.info("Connection from %s", addr)
process = multiprocessing.Process(target=_serve_loop, args=(conn, addr)) process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
process.deamon = True process.deamon = True
...@@ -94,10 +100,10 @@ def _connect_proxy_loop(addr, key): ...@@ -94,10 +100,10 @@ def _connect_proxy_loop(addr, key):
while True: while True:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(addr) sock.connect(addr)
sock.sendall(struct.pack('@i', RPC_MAGIC)) sock.sendall(struct.pack("@i", RPC_MAGIC))
sock.sendall(struct.pack('@i', len(key))) sock.sendall(struct.pack("@i", len(key)))
sock.sendall(key.encode("utf-8")) sock.sendall(key.encode("utf-8"))
magic = struct.unpack('@i', _recvall(sock, 4))[0] magic = struct.unpack("@i", _recvall(sock, 4))[0]
if magic == RPC_MAGIC + 1: if magic == RPC_MAGIC + 1:
raise RuntimeError("key: %s has already been used in proxy" % key) raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == RPC_MAGIC + 2: elif magic == RPC_MAGIC + 2:
...@@ -321,7 +327,7 @@ def connect(url, port, key=""): ...@@ -321,7 +327,7 @@ def connect(url, port, key=""):
try: try:
sess = _Connect(url, port, key) sess = _Connect(url, port, key)
except NameError: except NameError:
raise RuntimeError('Please compile with USE_RPC=1') raise RuntimeError("Please compile with USE_RPC=1")
return RPCSession(sess) return RPCSession(sess)
_init_api("tvm.contrib.rpc") _init_api("tvm.contrib.rpc")
"""Util to invoke tarball in the system."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import os
import shutil
import subprocess
from . import util
def tar(output, files):
    """Create a gzip-compressed tarball containing all the given files.

    Each file is first copied into a temporary directory under its base
    name, so the archive has a flat layout without directory components.

    Parameters
    ----------
    output : str
        The path of the target tarball.

    files : list
        List of files to be bundled.

    Raises
    ------
    ValueError
        If two entries of ``files`` share the same base name.
    RuntimeError
        If the ``tar`` command exits with a non-zero status.
    """
    temp = util.tempdir()
    fset = set()
    for fname in files:
        base = os.path.basename(fname)
        # The archive is flat, so two inputs with the same base name
        # would silently overwrite each other.
        if base in fset:
            raise ValueError("duplicate file name %s" % base)
        fset.add(base)
        shutil.copy(fname, temp.relpath(base))
    cmd = ["tar", "-czf", output, "-C", temp.temp_dir]
    cmd += temp.listdir()
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        # communicate() returns bytes; decode before appending so that the
        # real tar message is reported instead of a str/bytes TypeError.
        msg = "Tar error:\n"
        msg += out.decode("utf-8", "replace")
        raise RuntimeError(msg)
def untar(tar_file, directory):
    """Unpack all files of a tarball into the directory.

    Parameters
    ----------
    tar_file : str
        The source tar file.

    directory : str
        The target directory.

    Raises
    ------
    RuntimeError
        If the ``tar`` command exits with a non-zero status.
    """
    cmd = ["tar", "-xf", tar_file, "-C", directory]
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        # communicate() returns bytes; decode before appending so that the
        # real tar message is reported instead of a str/bytes TypeError.
        msg = "Tar error:\n"
        msg += out.decode("utf-8", "replace")
        raise RuntimeError(msg)
...@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs ...@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
from collections import namedtuple from collections import namedtuple
from ._ffi.function import ModuleBase, _set_class_module from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .contrib import cc as _cc, util as _util from .contrib import cc as _cc, tar as _tar, util as _util
ProfileResult = namedtuple("ProfileResult", ["mean"]) ProfileResult = namedtuple("ProfileResult", ["mean"])
...@@ -100,7 +100,11 @@ class Module(ModuleBase): ...@@ -100,7 +100,11 @@ class Module(ModuleBase):
with open(path_cc, "w") as f: with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib)) f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc) files.append(path_cc)
fcompile = fcompile if fcompile else _cc.create_shared if not fcompile:
if file_name.endswith(".tar"):
fcompile = _tar.tar
else:
fcompile = _cc.create_shared
fcompile(file_name, files, **kwargs) fcompile(file_name, files, **kwargs)
def time_evaluator(self, func_name, ctx, number): def time_evaluator(self, func_name, ctx, number):
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <unordered_set> #include <unordered_set>
#include <cstring>
#include "./file_util.h" #include "./file_util.h"
namespace tvm { namespace tvm {
...@@ -26,6 +27,16 @@ PackedFunc Module::GetFunction( ...@@ -26,6 +27,16 @@ PackedFunc Module::GetFunction(
} }
void Module::Import(Module other) { void Module::Import(Module other) {
// specially handle rpc
if (!std::strcmp((*this)->type_key(), "rpc")) {
static const PackedFunc* fimport_ = nullptr;
if (fimport_ == nullptr) {
fimport_ = runtime::Registry::Get("contrib.rpc._ImportRemoteModule");
CHECK(fimport_ != nullptr);
}
(*fimport_)(*this, other);
return;
}
// cyclic detection. // cyclic detection.
std::unordered_set<const ModuleNode*> visited{other.node_.get()}; std::unordered_set<const ModuleNode*> visited{other.node_.get()};
std::vector<const ModuleNode*> stack{other.node_.get()}; std::vector<const ModuleNode*> stack{other.node_.get()};
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <memory> #include <memory>
#include <cstring>
#include "./rpc_session.h" #include "./rpc_session.h"
namespace tvm { namespace tvm {
...@@ -83,6 +84,10 @@ class RPCModuleNode final : public ModuleNode { ...@@ -83,6 +84,10 @@ class RPCModuleNode final : public ModuleNode {
return WrapRemote(handle); return WrapRemote(handle);
} }
// Return the stored handle (module_handle_) of this RPC module.
// The _ImportRemoteModule global passes this value to
// CallRemote(RPCCode::kModuleImport, ...) to identify the module remotely.
void* module_handle() const {
return module_handle_;
}
private: private:
PackedFunc WrapRemote(RPCFuncHandle handle) { PackedFunc WrapRemote(RPCFuncHandle handle) {
if (handle == nullptr) return PackedFunc(); if (handle == nullptr) return PackedFunc();
...@@ -162,6 +167,21 @@ TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule") ...@@ -162,6 +167,21 @@ TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
*rv = Module(n); *rv = Module(n);
}); });
// Global function "contrib.rpc._ImportRemoteModule".
// args[0]: parent module, args[1]: child module to be imported into it.
// Both must be RPC modules that belong to the same session; the import
// itself is performed on the remote side via RPCCode::kModuleImport.
// Nothing is written to rv.
TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module parent = args[0];
Module child = args[1];
// The downcasts below rely on this type_key check.
CHECK(!std::strcmp(parent->type_key(), "rpc") &&
!std::strcmp(child->type_key(), "rpc"));
auto* pmod = static_cast<RPCModuleNode*>(parent.operator->());
auto* cmod = static_cast<RPCModuleNode*>(child.operator->());
// Both handles are sent through the parent's session, so the child
// is required to come from that same session.
CHECK(pmod->sess().get() == cmod->sess().get())
<< "Import of remote module need to belong to same session.";
pmod->sess()->CallRemote(RPCCode::kModuleImport,
pmod->module_handle(),
cmod->module_handle());
});
TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex") TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0]; Module m = args[0];
......
...@@ -940,6 +940,13 @@ void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { ...@@ -940,6 +940,13 @@ void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
*rv = static_cast<void*>(new Module(m)); *rv = static_cast<void*>(new Module(m));
} }
// Handler dispatched for RPCCode::kModuleImport.
// args[0] and args[1] are opaque handles that are reinterpreted as
// Module pointers; the second module is imported into the first.
// Nothing is written to rv.
void RPCModuleImport(TVMArgs args, TVMRetValue *rv) {
  void* parent_handle = args[0];
  void* child_handle = args[1];
  Module* parent = static_cast<Module*>(parent_handle);
  Module* child = static_cast<Module*>(child_handle);
  parent->Import(*child);
}
void RPCModuleFree(TVMArgs args, TVMRetValue *rv) { void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0]; void* mhandle = args[0];
delete static_cast<Module*>(mhandle); delete static_cast<Module*>(mhandle);
...@@ -1006,6 +1013,7 @@ void RPCSession::EventHandler::HandlePackedCall() { ...@@ -1006,6 +1013,7 @@ void RPCSession::EventHandler::HandlePackedCall() {
case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break; case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break;
case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break; case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break;
case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break; case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break;
case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break;
case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break; case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break; case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break; case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
......
...@@ -44,6 +44,7 @@ enum class RPCCode : int { ...@@ -44,6 +44,7 @@ enum class RPCCode : int {
kDevStreamSync, kDevStreamSync,
kCopyAmongRemote, kCopyAmongRemote,
kModuleLoad, kModuleLoad,
kModuleImport,
kModuleFree, kModuleFree,
kModuleGetFunc, kModuleGetFunc,
kModuleGetSource, kModuleGetSource,
......
...@@ -86,6 +86,55 @@ def test_rpc_remote_module(): ...@@ -86,6 +86,55 @@ def test_rpc_remote_module():
cost = time_f(a, b).mean cost = time_f(a, b).mean
print('%g secs/op' % cost) print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
def check_remote_link_cl():
    """Example of running remote device code (such as OpenCL) over RPC.

    This check is not enabled because of a forking issue in the TVM
    runtime when the server launches after the OpenCL runtime has been
    initialized. It is kept as an example of how to use RPC when the
    linking has to happen on the remote side.
    """
    requirements = (
        ("llvm", "Skip because llvm is not enabled"),
        ("opencl", "Skip because opencl is not enabled"),
    )
    for backend, skip_msg in requirements:
        if not tvm.module.enabled(backend):
            print(skip_msg)
            return

    temp = util.tempdir()
    ctx = remote.cl(0)

    def run_and_verify(func):
        # Run the remote function on fresh random input and check B = A + 1.
        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
        func(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=32)
    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
    f = tvm.build(s, [A, B], "opencl", target_host="llvm", name="myadd")

    # Option 1: save modules separately and rely on remote compiler
    path_o = temp.relpath("myadd.o")
    path_cl = temp.relpath("myadd.cl")
    path_json = temp.relpath("myadd.tvm_meta.json")
    f.save(path_o)
    f.imported_modules[0].save(path_cl)
    # host object, device source, then the meta data
    for local_path in (path_o, path_cl, path_json):
        remote.upload(local_path)
    fhost = remote.load_module("myadd.o")
    fdev = remote.load_module("myadd.cl")
    fhost.import_module(fdev)
    run_and_verify(fhost)

    # Option 2: export library as a tar ball then handled by remote compiler
    path_tar = temp.relpath("myadd.tar")
    f.export_library(path_tar)
    remote.upload(path_tar)
    fhost = remote.load_module("myadd.tar")
    run_and_verify(fhost)
check_remote() check_remote()
def test_rpc_return_func(): def test_rpc_return_func():
...@@ -101,8 +150,8 @@ def test_rpc_return_func(): ...@@ -101,8 +150,8 @@ def test_rpc_return_func():
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
test_rpc_remote_module()
test_rpc_return_func() test_rpc_return_func()
test_rpc_file_exchange() test_rpc_file_exchange()
test_rpc_array() test_rpc_array()
test_rpc_remote_module()
test_rpc_simple() test_rpc_simple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment