Commit 1077f8e8 by Tianqi Chen Committed by GitHub

[RUNTIME][RPC] Enable remote linking of device code. (#444)

* [RUNTIME][RPC] Enable remote linking of device code.

* fix build
parent 64402c14
......@@ -15,7 +15,7 @@ import socket
import struct
import logging
import multiprocessing
from . import util, cc
from . import util, cc, tar
from ..module import load as _load_module
from .._ffi.function import _init_api, register_func
from .._ffi.ndarray import context as _context
......@@ -37,10 +37,16 @@ def _server_env():
"""Load module from remote side."""
path = temp.relpath(file_name)
# Try create a shared library in remote
if path.endswith('.o'):
logging.info('Create shared library based on %s', path)
cc.create_shared(path + '.so', path)
path += '.so'
if path.endswith(".o"):
logging.info("Create shared library based on %s", path)
cc.create_shared(path + ".so", path)
path += ".so"
elif path.endswith(".tar"):
tar_temp = util.tempdir()
tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
cc.create_shared(path + ".so", files)
path += ".so"
m = _load_module(path)
logging.info("load_module %s", path)
return m
......@@ -63,7 +69,7 @@ def _recvall(sock, nbytes):
chunk = sock.recv(min(nbytes - nread, 1024))
nread += len(chunk)
res.append(chunk)
return b''.join(res)
return b"".join(res)
def _listen_loop(sock):
......@@ -71,16 +77,16 @@ def _listen_loop(sock):
while True:
conn, addr = sock.accept()
logging.info("RPCServer: connection from %s", addr)
magic = struct.unpack('@i', _recvall(conn, 4))[0]
magic = struct.unpack("@i", _recvall(conn, 4))[0]
if magic != RPC_MAGIC:
conn.close()
continue
keylen = struct.unpack('@i', _recvall(conn, 4))[0]
keylen = struct.unpack("@i", _recvall(conn, 4))[0]
key = py_str(_recvall(conn, keylen))
if not key.startswith("client:"):
conn.sendall(struct.pack('@i', RPC_MAGIC + 2))
conn.sendall(struct.pack("@i", RPC_MAGIC + 2))
else:
conn.sendall(struct.pack('@i', RPC_MAGIC))
conn.sendall(struct.pack("@i", RPC_MAGIC))
logging.info("Connection from %s", addr)
process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
process.deamon = True
......@@ -94,10 +100,10 @@ def _connect_proxy_loop(addr, key):
while True:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(addr)
sock.sendall(struct.pack('@i', RPC_MAGIC))
sock.sendall(struct.pack('@i', len(key)))
sock.sendall(struct.pack("@i", RPC_MAGIC))
sock.sendall(struct.pack("@i", len(key)))
sock.sendall(key.encode("utf-8"))
magic = struct.unpack('@i', _recvall(sock, 4))[0]
magic = struct.unpack("@i", _recvall(sock, 4))[0]
if magic == RPC_MAGIC + 1:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == RPC_MAGIC + 2:
......@@ -321,7 +327,7 @@ def connect(url, port, key=""):
try:
sess = _Connect(url, port, key)
except NameError:
raise RuntimeError('Please compile with USE_RPC=1')
raise RuntimeError("Please compile with USE_RPC=1")
return RPCSession(sess)
_init_api("tvm.contrib.rpc")
"""Util to invoke tarball in the system."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import os
import shutil
import subprocess
from . import util
def tar(output, files):
    """Create a gzip-compressed tarball with all files placed at its root.

    Every input file is first copied into a temporary directory under its
    base name, and the archive is then built from that directory, so the
    original directory layout of the inputs is not preserved.

    Parameters
    ----------
    output : str
        Path of the tarball to create.

    files : list
        List of files to be bundled.

    Raises
    ------
    ValueError
        If two of the input files share the same base name.

    RuntimeError
        If the ``tar`` command exits with a non-zero status.
    """
    temp = util.tempdir()
    fset = set()
    for fname in files:
        base = os.path.basename(fname)
        # The archive is flat, so two inputs with the same base name
        # would silently overwrite each other.
        if base in fset:
            raise ValueError("duplicate file name %s" % base)
        fset.add(base)
        shutil.copy(fname, temp.relpath(base))

    cmd = ["tar", "-czf", output, "-C", temp.temp_dir]
    cmd += temp.listdir()
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        # communicate() yields bytes on Python 3; decode before building the
        # message so the real tar output is reported instead of a TypeError.
        if not isinstance(out, str):
            out = out.decode("utf-8", "replace")
        msg = "Tar error:\n"
        msg += out
        raise RuntimeError(msg)
def untar(tar_file, directory):
    """Unpack the contents of a tar file into a directory.

    Parameters
    ----------
    tar_file : str
        The source tar file.

    directory : str
        The target directory; passed to ``tar -C``.

    Raises
    ------
    RuntimeError
        If the ``tar`` command exits with a non-zero status.
    """
    cmd = ["tar", "-xf", tar_file, "-C", directory]
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        # communicate() yields bytes on Python 3; decode before building the
        # message so the real tar output is reported instead of a TypeError.
        if not isinstance(out, str):
            out = out.decode("utf-8", "replace")
        msg = "Tar error:\n"
        msg += out
        raise RuntimeError(msg)
......@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
from collections import namedtuple
from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api
from .contrib import cc as _cc, util as _util
from .contrib import cc as _cc, tar as _tar, util as _util
ProfileResult = namedtuple("ProfileResult", ["mean"])
......@@ -100,7 +100,11 @@ class Module(ModuleBase):
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc)
fcompile = fcompile if fcompile else _cc.create_shared
if not fcompile:
if file_name.endswith(".tar"):
fcompile = _tar.tar
else:
fcompile = _cc.create_shared
fcompile(file_name, files, **kwargs)
def time_evaluator(self, func_name, ctx, number):
......
......@@ -7,6 +7,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_set>
#include <cstring>
#include "./file_util.h"
namespace tvm {
......@@ -26,6 +27,16 @@ PackedFunc Module::GetFunction(
}
void Module::Import(Module other) {
// specially handle rpc
if (!std::strcmp((*this)->type_key(), "rpc")) {
static const PackedFunc* fimport_ = nullptr;
if (fimport_ == nullptr) {
fimport_ = runtime::Registry::Get("contrib.rpc._ImportRemoteModule");
CHECK(fimport_ != nullptr);
}
(*fimport_)(*this, other);
return;
}
// cyclic detection.
std::unordered_set<const ModuleNode*> visited{other.node_.get()};
std::vector<const ModuleNode*> stack{other.node_.get()};
......
......@@ -5,6 +5,7 @@
*/
#include <tvm/runtime/registry.h>
#include <memory>
#include <cstring>
#include "./rpc_session.h"
namespace tvm {
......@@ -83,6 +84,10 @@ class RPCModuleNode final : public ModuleNode {
return WrapRemote(handle);
}
// Accessor for the raw handle held in module_handle_; callers pass it as an
// argument to remote calls (e.g. RPCCode::kModuleImport).
void* module_handle() const { return module_handle_; }
private:
PackedFunc WrapRemote(RPCFuncHandle handle) {
if (handle == nullptr) return PackedFunc();
......@@ -162,6 +167,21 @@ TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
*rv = Module(n);
});
// Global function "contrib.rpc._ImportRemoteModule": import one remote module
// into another. args[0] is the parent module, args[1] is the child module.
// Nothing is written to rv.
TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module parent = args[0];
Module child = args[1];
// The static_casts below are only valid for RPC modules, so check both type keys first.
CHECK(!std::strcmp(parent->type_key(), "rpc") &&
!std::strcmp(child->type_key(), "rpc"));
auto* pmod = static_cast<RPCModuleNode*>(parent.operator->());
auto* cmod = static_cast<RPCModuleNode*>(child.operator->());
// Both module handles are sent through the parent's session, so the child
// must belong to that same session.
CHECK(pmod->sess().get() == cmod->sess().get())
<< "Import of remote module need to belong to same session.";
// Ask the remote side to perform the import using the two module handles.
pmod->sess()->CallRemote(RPCCode::kModuleImport,
pmod->module_handle(),
cmod->module_handle());
});
TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
......
......@@ -940,6 +940,13 @@ void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
*rv = static_cast<void*>(new Module(m));
}
// Handler dispatched for RPCCode::kModuleImport.
// args[0] and args[1] are opaque handles that are reinterpreted as Module*
// (parent and child respectively); the child is imported into the parent.
// Nothing is written to rv.
void RPCModuleImport(TVMArgs args, TVMRetValue *rv) {
  void* parent_handle = args[0];
  void* child_handle = args[1];
  Module* parent = static_cast<Module*>(parent_handle);
  Module* child = static_cast<Module*>(child_handle);
  parent->Import(*child);
}
void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0];
delete static_cast<Module*>(mhandle);
......@@ -1006,6 +1013,7 @@ void RPCSession::EventHandler::HandlePackedCall() {
case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break;
case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break;
case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break;
case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break;
case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
......
......@@ -44,6 +44,7 @@ enum class RPCCode : int {
kDevStreamSync,
kCopyAmongRemote,
kModuleLoad,
kModuleImport,
kModuleFree,
kModuleGetFunc,
kModuleGetSource,
......
......@@ -86,6 +86,55 @@ def test_rpc_remote_module():
cost = time_f(a, b).mean
print('%g secs/op' % cost)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
def check_remote_link_cl():
    """Example of running OpenCL code whose linking happens on the remote.

    This check is intentionally not invoked: there is a forking issue in the
    TVM runtime when the server launches after the OpenCL runtime has
    initialized. It is kept as a reference for how to use rpc when linking
    has to be done on the remote side.
    """
    requirements = (
        ("llvm", "Skip because llvm is not enabled"),
        ("opencl", "Skip because opencl is not enabled"),
    )
    for backend, skip_msg in requirements:
        if not tvm.module.enabled(backend):
            print(skip_msg)
            return

    def run_and_verify(host_func):
        # Push random input through the remote function and check b == a + 1.
        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
        host_func(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    workdir = util.tempdir()
    ctx = remote.cl(0)
    sched = tvm.create_schedule(B.op)
    outer, inner = sched[B].split(B.op.axis[0], factor=32)
    sched[B].bind(outer, tvm.thread_axis("blockIdx.x"))
    sched[B].bind(inner, tvm.thread_axis("threadIdx.x"))
    built = tvm.build(sched, [A, B], "opencl", target_host="llvm", name="myadd")

    # Option 1: save the host and device modules separately and let the
    # remote compiler turn the object file into a loadable library.
    obj_path = workdir.relpath("myadd.o")
    cl_path = workdir.relpath("myadd.cl")
    meta_path = workdir.relpath("myadd.tvm_meta.json")
    built.save(obj_path)
    built.imported_modules[0].save(cl_path)
    # The meta data json is uploaded last, after the object and cl files.
    for local_path in (obj_path, cl_path, meta_path):
        remote.upload(local_path)
    host_mod = remote.load_module("myadd.o")
    dev_mod = remote.load_module("myadd.cl")
    host_mod.import_module(dev_mod)
    run_and_verify(host_mod)

    # Option 2: export everything as one tarball that the remote compiler
    # handles when the module is loaded.
    tar_path = workdir.relpath("myadd.tar")
    built.export_library(tar_path)
    remote.upload(tar_path)
    run_and_verify(remote.load_module("myadd.tar"))
check_remote()
def test_rpc_return_func():
......@@ -101,8 +150,8 @@ def test_rpc_return_func():
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
test_rpc_remote_module()
test_rpc_return_func()
test_rpc_file_exchange()
test_rpc_array()
test_rpc_remote_module()
test_rpc_simple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment