Commit 134c6ba3 by Tianqi Chen Committed by GitHub

[RUNTIME] RPC runtime that support run testing on remote device. (#147)

* [RUNTIME] RPC runtime that support run testing on remote device.

* Fix ctypes in OSX.

* fix lint
parent b7fe6119
......@@ -8,6 +8,7 @@ endif()
include(cmake/Util.cmake)
tvm_option(USE_CUDA "Build with CUDA" ON)
tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_RPC "Build with RPC" OFF)
tvm_option(USE_LLVM "Build with LLVM" OFF)
tvm_option(USE_RTTI "Build with RTTI" OFF)
tvm_option(USE_MSVC_MT "Build with MT" OFF)
......@@ -67,6 +68,7 @@ file(GLOB RUNTIME_SRCS src/runtime/*.cc)
file(GLOB COMPILER_LLVM_SRCS src/codegen/llvm/*.cc)
file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc)
file(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc)
file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
if(USE_CUDA)
find_package(CUDA)
......@@ -98,6 +100,11 @@ else(USE_OPENCL)
add_definitions(-DTVM_OPENCL_RUNTIME=0)
endif(USE_OPENCL)
if(USE_RPC)
message(STATUS "Build with RPC support...")
list(APPEND RUNTIME_SRCS ${RUNTIME_RPC_SRCS})
endif(USE_RPC)
if(USE_LLVM)
find_package(LLVM REQUIRED CONFIG)
message(STATUS "Build with LLVM support...")
......
......@@ -20,10 +20,16 @@ LIB_HALIDEIR = HalideIR/lib/libHalideIR.a
CC_SRC = $(filter-out src/contrib/%.cc src/runtime/%.cc,\
$(wildcard src/*/*.cc src/*/*/*.cc))
METAL_SRC = $(wildcard src/runtime/metal/*.mm)
RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc)
CUDA_SRC = $(wildcard src/runtime/cuda/*.cc)
OPENCL_SRC = $(wildcard src/runtime/opencl/*.cc)
RPC_SRC = $(wildcard src/runtime/rpc/*.cc)
RUNTIME_SRC = $(wildcard src/runtime/*.cc)
# Objectives
METAL_OBJ = $(patsubst src/%.mm, build/%.o, $(METAL_SRC))
CUDA_OBJ = $(patsubst src/%.cc, build/%.o, $(CUDA_SRC))
OPENCL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENCL_SRC))
RPC_OBJ = $(patsubst src/%.cc, build/%.o, $(RPC_SRC))
CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC))
RUNTIME_OBJ = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))
CONTRIB_OBJ =
......@@ -51,6 +57,7 @@ endif
ifeq ($(USE_CUDA), 1)
CFLAGS += -DTVM_CUDA_RUNTIME=1
LDFLAGS += -lcuda -lcudart -lnvrtc
RUNTIME_DEP += $(CUDA_OBJ)
else
CFLAGS += -DTVM_CUDA_RUNTIME=0
endif
......@@ -62,6 +69,7 @@ ifeq ($(USE_OPENCL), 1)
else
LDFLAGS += -lOpenCL
endif
RUNTIME_DEP += $(OPENCL_OBJ)
else
CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif
......@@ -75,6 +83,10 @@ else
CFLAGS += -DTVM_METAL_RUNTIME=0
endif
ifeq ($(USE_RPC), 1)
RUNTIME_DEP += $(RPC_OBJ)
endif
# llvm configuration
ifdef LLVM_CONFIG
LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3)
......
......@@ -57,13 +57,14 @@ typedef enum {
// that is used by TVM API calls.
kHandle = 3U,
kNull = 4U,
kArrayHandle = 5U,
kTVMType = 6U,
kNodeHandle = 7U,
kModuleHandle = 8U,
kFuncHandle = 9U,
kStr = 10U,
kBytes = 11U
kTVMType = 5U,
kTVMContext = 6U,
kArrayHandle = 7U,
kNodeHandle = 8U,
kModuleHandle = 9U,
kFuncHandle = 10U,
kStr = 11U,
kBytes = 12U
} TVMTypeCode;
/*!
......@@ -98,6 +99,7 @@ typedef union {
void* v_handle;
const char* v_str;
TVMType v_type;
TVMContext v_ctx;
} TVMValue;
/*!
......
......@@ -205,6 +205,10 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kArrayHandle);
return static_cast<TVMArray*>(value_.v_handle);
}
operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
int type_code() const {
return type_code_;
}
......@@ -254,6 +258,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*;
using TVMPODValue_::operator TVMArray*;
using TVMPODValue_::operator TVMContext;
// conversion operator.
operator std::string() const {
if (type_code_ == kTVMType) {
......@@ -333,6 +338,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*;
using TVMPODValue_::operator TVMArray*;
using TVMPODValue_::operator TVMContext;
// Disable copy and assign from another value, but allow move.
TVMRetValue(const TVMRetValue& other) {
this->Assign(other);
......@@ -474,7 +480,7 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
case kModuleHandle: {
SwitchToClass<PackedFunc>(kModuleHandle, other);
SwitchToClass<Module>(kModuleHandle, other);
break;
}
case kNodeHandle: {
......@@ -532,6 +538,7 @@ inline const char* TypeCode2Str(int type_code) {
case kNodeHandle: return "NodeHandle";
case kArrayHandle: return "ArrayHandle";
case kTVMType: return "TVMType";
case kTVMContext: return "TVMContext";
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
default: LOG(FATAL) << "unknown type_code="
......@@ -659,6 +666,10 @@ class TVMArgsSetter {
values_[i].v_handle = value;
type_codes_[i] = kArrayHandle;
}
void operator()(size_t i, TVMContext value) const {
values_[i].v_ctx = value;
type_codes_[i] = kTVMContext;
}
void operator()(size_t i, TVMType value) const {
values_[i].v_type = value;
type_codes_[i] = kTVMType;
......@@ -674,6 +685,10 @@ class TVMArgsSetter {
values_[i].v_str = value.c_str();
type_codes_[i] = kStr;
}
void operator()(size_t i, TVMByteArray& value) const { // NOLINT(*)
values_[i].v_handle = &value;
type_codes_[i] = kBytes;
}
void operator()(size_t i, PackedFunc& value) const { // NOLINT(*)
values_[i].v_handle = &value;
type_codes_[i] = kFuncHandle;
......
......@@ -37,6 +37,9 @@ USE_OPENCL = 0
# whether enable Metal during compile
USE_METAL = 0
# Whether enable RPC during compile
USE_RPC = 0
# whether build with LLVM support
# Requires LLVM version >= 4.0
# Set LLVM_CONFIG to your version, uncomment to build with llvm support
......
......@@ -10,7 +10,7 @@ from numbers import Number, Integral
from ..base import _LIB, check_call
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMByteArray
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
......@@ -107,6 +107,9 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, TVMType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, TVMContext):
values[i].v_ctx = arg
type_codes[i] = TypeCode.TVM_CONTEXT
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
......
......@@ -13,13 +13,15 @@ class TypeCode(object):
FLOAT = 2
HANDLE = 3
NULL = 4
ARRAY_HANDLE = 5
TVM_TYPE = 6
NODE_HANDLE = 7
MODULE_HANDLE = 8
FUNC_HANDLE = 9
STR = 10
BYTES = 11
TVM_TYPE = 5
TVM_CONTEXT = 6
ARRAY_HANDLE = 7
NODE_HANDLE = 8
MODULE_HANDLE = 9
FUNC_HANDLE = 10
STR = 11
BYTES = 12
class TVMValue(ctypes.Union):
"""TVMValue in C API"""
......
......@@ -10,13 +10,14 @@ cdef enum TVMTypeCode:
kFloat = 2
kHandle = 3
kNull = 4
kArrayHandle = 5
kTVMType = 6
kNodeHandle = 7
kModuleHandle = 8
kFuncHandle = 9
kStr = 10
kBytes = 11
kTVMType = 5
kTVMContext = 6
kArrayHandle = 7
kNodeHandle = 8
kModuleHandle = 9
kFuncHandle = 10
kStr = 11
kBytes = 12
cdef extern from "tvm/runtime/c_runtime_api.h":
ctypedef struct DLDataType:
......@@ -43,6 +44,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
void* v_handle
const char* v_str
DLDataType v_type
DLContext v_ctx
ctypedef int64_t tvm_index_t
ctypedef void* DLTensorHandle
......
......@@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMByteArray
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
print("TVM: Initializing cython mode...")
......@@ -110,6 +110,10 @@ cdef inline void make_arg(object arg,
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, TVMContext):
value[0].v_ctx = (<DLContext*>(
<unsigned long long>ctypes.addressof(arg)))[0]
tcode[0] = kTVMContext
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
......@@ -170,6 +174,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
return make_ret_bytes(value.v_handle)
elif tcode == kHandle:
return ctypes_handle(value.v_handle)
elif tcode == kTVMContext:
return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
elif tcode == kModuleHandle:
return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kFuncHandle:
......
......@@ -61,6 +61,9 @@ class ModuleBase(object):
self.handle = handle
self._entry = None
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
@property
def entry_func(self):
"""Get the entry function
......
......@@ -64,6 +64,7 @@ class TVMType(ctypes.Structure):
def __ne__(self, other):
return not self.__eq__(other)
RPC_SESS_MASK = 128
class TVMContext(ctypes.Structure):
"""TVM context strucure."""
......@@ -121,6 +122,11 @@ class TVMContext(ctypes.Structure):
return not self.__eq__(other)
def __repr__(self):
if self.device_type >= RPC_SESS_MASK:
tbl_id = self.device_type / RPC_SESS_MASK - 1
dev_type = self.device_type % RPC_SESS_MASK
return "remote[%d]:%s(%d)" % (
tbl_id, TVMContext.MASK2STR[dev_type], self.device_id)
return "%s(%d)" % (
TVMContext.MASK2STR[self.device_type], self.device_id)
......
"""RPC interface for easy testing.
RPC enables connect to a remote server, upload and launch functions.
This is useful to for cross-compile and remote testing,
The compiler stack runs on local server, while we use RPC server
to run on remote runtime which don't have a compiler available.
The test program compiles the program on local server,
upload and run remote RPC server, get the result back to verify correctness.
"""
from __future__ import absolute_import
import os
import socket
import struct
import logging
import multiprocessing
from . import util
from ..module import load as _load_module
from .._ffi.function import _init_api, register_func
from .._ffi.ndarray import context as _context
RPC_MAGIC = 0xff271
RPC_SESS_MASK = 128
def _serve_loop(sock, addr):
"""Server loop"""
sockfd = sock.fileno()
temp = util.tempdir()
# pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.upload")
def upload(file_name, blob):
"""Upload the blob to remote temp file"""
path = temp.relpath(file_name)
with open(path, "wb") as out_file:
out_file.write(blob)
@register_func("tvm.contrib.rpc.server.download")
def download(file_name):
"""Download file from remote"""
path = temp.relpath(file_name)
dat = bytearray(open(path, "rb").read())
return dat
@register_func("tvm.contrib.rpc.server.load_module")
def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
m = _load_module(path)
return m
_ServerLoop(sockfd)
temp.remove()
logging.info("Finish serving %s", addr)
def _recvall(sock, nbytes):
res = []
nread = 0
while nread < nbytes:
chunk = sock.recv(min(nbytes - nread, 1024))
nread += len(chunk)
res.append(chunk)
return b''.join(res)
def _listen_loop(sock):
"""Lisenting loop"""
while True:
conn, addr = sock.accept()
logging.info("RPCServer: connection from %s", addr)
conn.sendall(struct.pack('@i', RPC_MAGIC))
magic = struct.unpack('@i', _recvall(conn, 4))[0]
if magic != RPC_MAGIC:
conn.close()
continue
logging.info("Connection from %s", addr)
process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
process.deamon = True
process.start()
# close from our side.
conn.close()
class Server(object):
"""Start RPC server on a seperate process.
This is a simple python implementation based on multi-processing.
It is also possible to implement a similar C based sever with
TVM runtime which does not depend on the python.
Parameter
---------
host : str
The host url of the server.
port : int
The port to be bind to
port_end : int, optional
The end port to search
"""
def __init__(self, host, port=9091, port_end=9199):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
for port in range(port, port_end):
try:
sock.bind((host, port))
self.port = port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
else:
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCServer: bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.host = host
self.proc = multiprocessing.Process(target=_listen_loop, args=(self.sock,))
self.proc.start()
def terminate(self):
"""Terminate the server process"""
if self.proc:
self.proc.terminate()
self.proc = None
def __del__(self):
self.terminate()
class RPCSession(object):
"""RPC Client session module
Do not directly create the obhect, call connect
"""
# pylint: disable=invalid-name
def __init__(self, sess):
self._sess = sess
self._tbl_index = _SessTableIndex(sess)
self._upload_func = None
self._download_func = None
def get_function(self, name):
"""Get function from the session.
Parameters
----------
name : str
The name of the function
Returns
-------
f : Function
The result function.
"""
return self._sess.get_function(name)
def context(self, dev_type, dev_id=0):
"""Construct a remote context.
Parameters
----------
dev_type: int or str
dev_id: int, optional
Returns
-------
ctx: TVMContext
The corresponding encoded remote context.
"""
ctx = _context(dev_type, dev_id)
encode = (self._tbl_index + 1) * RPC_SESS_MASK
ctx.device_type += encode
return ctx
def cpu(self, dev_id=0):
"""Construct remote CPU device."""
return self.context(1, dev_id)
def gpu(self, dev_id=0):
"""Construct remote GPU device."""
return self.context(2, dev_id)
def cl(self, dev_id=0):
"""Construct remote OpenCL device."""
return self.context(4, dev_id)
def metal(self, dev_id=0):
"""Construct remote Metal device."""
return self.context(8, dev_id)
def upload(self, data, target=None):
"""Upload file to remote runtime temp folder
Parameters
----------
data : str or bytearray
The file name or binary in local to upload.
target : str, optional
The path in remote
"""
if isinstance(data, bytearray):
if not target:
raise ValueError("target must present when file is a bytearray")
blob = data
else:
blob = bytearray(open(data, "rb").read())
if not target:
target = os.path.basename(data)
if not self._upload_func:
self._upload_func = self.get_function(
"tvm.contrib.rpc.server.upload")
self._upload_func(target, blob)
def download(self, path):
"""Download file from remote temp folder.
Parameters
----------
path : str
The relative location to remote temp folder.
Returns
-------
blob : bytearray
The result blob from the file.
"""
if not self._download_func:
self._download_func = self.get_function(
"tvm.contrib.rpc.server.download")
return self._download_func(path)
def load_module(self, path):
"""Load a remote module, the file need to be uploaded first.
Parameters
----------
path : str
The relative location to remote temp folder.
Returns
-------
m : Module
The remote module containing remote function.
"""
return _LoadRemoteModule(self._sess, path)
def connect(url, port):
"""Connect to RPC Server
Parameters
----------
url : str
The url of the host
port : int
The port to connect to
Returns
-------
sess : RPCSession
The connected session.
"""
sess = _Connect(url, port)
return RPCSession(sess)
_init_api("tvm.contrib.rpc")
......@@ -12,8 +12,14 @@ class TempDirectory(object):
def __init__(self):
self.temp_dir = tempfile.mkdtemp()
def __del__(self):
def remove(self):
"""Remote the tmp dir"""
if self.temp_dir:
shutil.rmtree(self.temp_dir)
self.temp_dir = None
def __del__(self):
self.remove()
def relpath(self, name):
"""Relative path in temp dir
......
......@@ -32,6 +32,11 @@ def cpu(dev_id=0):
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(1, dev_id)
......@@ -43,6 +48,11 @@ def gpu(dev_id=0):
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(2, dev_id)
......@@ -54,6 +64,11 @@ def opencl(dev_id=0):
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(4, dev_id)
......@@ -65,6 +80,11 @@ def metal(dev_id=0):
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(8, dev_id)
......@@ -76,6 +96,11 @@ def vpi(dev_id=0):
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(9, dev_id)
......
......@@ -44,8 +44,8 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
if (ptr + size >= ram_max_) return nullptr;
return (char*)(&ram_[0]) + ptr; // NOLINT(*)
}
void SetDevice(int dev_id) final {}
void GetAttr(int dev_id, runtime::DeviceAttrKind kind, TVMRetValue* rv) final {
void SetDevice(TVMContext ctx) final {}
void GetAttr(TVMContext ctx, runtime::DeviceAttrKind kind, TVMRetValue* rv) final {
if (kind == runtime::kExist) {
*rv = 1;
}
......
/*!
* Copyright (c) 2017 by Contributors
* \file socket.h
* \brief this file aims to provide a wrapper of sockets
* \author Tianqi Chen
*/
#ifndef TVM_COMMON_SOCKET_H_
#define TVM_COMMON_SOCKET_H_
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
#endif
#else
#include <fcntl.h>
#include <netdb.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#endif
#include <dmlc/logging.h>
#include <string>
#include <cstring>
namespace tvm {
namespace common {
/*!
* \brief Get current host name.
* \return The hostname.
*/
inline std::string GetHostName() {
std::string buf; buf.resize(256);
CHECK_NE(gethostname(&buf[0], 256), -1);
return std::string(buf.c_str());
}
/*!
* \brief Common data structure fornetwork address.
*/
struct SockAddr {
sockaddr_in addr;
SockAddr() {}
/*!
* \brief construc address by url and port
* \param url The url of the address
* \param port The port of the address.
*/
SockAddr(const char *url, int port) {
this->Set(url, port);
}
/*!
* \brief set the address
* \param url the url of the address
* \param port the port of address
*/
void Set(const char *host, int port) {
addrinfo hints;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_protocol = SOCK_STREAM;
addrinfo *res = NULL;
int sig = getaddrinfo(host, NULL, &hints, &res);
CHECK(sig == 0 && res != NULL)
<< "cannot obtain address of " << host;
CHECK(res->ai_family == AF_INET)
<< "Does not support IPv6";
memcpy(&addr, res->ai_addr, res->ai_addrlen);
addr.sin_port = htons(port);
freeaddrinfo(res);
}
/*! \brief return port of the address */
int port() const {
return ntohs(addr.sin_port);
}
/*! \return a string representation of the address */
std::string AsString() const {
std::string buf; buf.resize(256);
#ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
&buf[0], buf.length());
#else
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
&buf[0], buf.length());
#endif
CHECK(s != nullptr) << "cannot decode address";
std::ostringstream os;
os << s << ":" << port();
return os.str();
}
};
/*!
* \brief base class containing common operations of TCP and UDP sockets
*/
class Socket {
public:
#if defined(_WIN32)
using ssize_t = int;
using sock_size_t = int;
using SockType = SOCKET;
#else
using SockType = int;
using sock_size_t = size_t;
static constexpr int INVALID_SOCKET = -1;
#endif
/*! \brief the file descriptor of socket */
SockType sockfd;
/*!
* \brief set this socket to use non-blocking mode
* \param non_block whether set it to be non-block, if it is false
* it will set it back to block mode
*/
void SetNonBlock(bool non_block) {
#ifdef _WIN32
u_long mode = non_block ? 1 : 0;
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
Socket::Error("SetNonBlock");
}
#else
int flag = fcntl(sockfd, F_GETFL, 0);
if (flag == -1) {
Socket::Error("SetNonBlock-1");
}
if (non_block) {
flag |= O_NONBLOCK;
} else {
flag &= ~O_NONBLOCK;
}
if (fcntl(sockfd, F_SETFL, flag) == -1) {
Socket::Error("SetNonBlock-2");
}
#endif
}
/*!
* \brief bind the socket to an address
* \param addr
*/
void Bind(const SockAddr &addr) {
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == -1) {
Socket::Error("Bind");
}
}
/*!
* \brief try bind the socket to host, from start_port to end_port
* \param start_port starting port number to try
* \param end_port ending port number to try
* \return the port successfully bind to, return -1 if failed to bind any port
*/
inline int TryBindHost(int start_port, int end_port) {
for (int port = start_port; port < end_port; ++port) {
SockAddr addr("0.0.0.0", port);
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0) {
return port;
}
#if defined(_WIN32)
if (WSAGetLastError() != WSAEADDRINUSE) {
Socket::Error("TryBindHost");
}
#else
if (errno != EADDRINUSE) {
Socket::Error("TryBindHost");
}
#endif
}
return -1;
}
/*! \brief get last error code if any */
int GetSockError() const {
int error = 0;
socklen_t len = sizeof(error);
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) {
Error("GetSockError");
}
return error;
}
/*! \brief check if anything bad happens */
bool BadSocket() const {
if (IsClosed()) return true;
int err = GetSockError();
if (err == EBADF || err == EINTR) return true;
return false;
}
/*! \brief check if socket is already closed */
bool IsClosed() const {
return sockfd == INVALID_SOCKET;
}
/*! \brief close the socket */
void Close() {
if (sockfd != INVALID_SOCKET) {
#ifdef _WIN32
closesocket(sockfd);
#else
close(sockfd);
#endif
sockfd = INVALID_SOCKET;
} else {
Error("Socket::Close double close the socket or close without create");
}
}
/*!
* \return last error of socket 2operation
*/
static int GetLastError() {
#ifdef _WIN32
return WSAGetLastError();
#else
return errno;
#endif
}
/*! \return whether last error was would block */
static bool LastErrorWouldBlock() {
int errsv = GetLastError();
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
#else
return errsv == EAGAIN || errsv == EWOULDBLOCK;
#endif
}
/*!
* \brief start up the socket module
* call this before using the sockets
*/
static void Startup() {
#ifdef _WIN32
WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
Socket::Error("Startup");
}
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
WSACleanup();
LOG(FATAL) << "Could not find a usable version of Winsock.dll";
}
#endif
}
/*!
* \brief shutdown the socket module after use, all sockets need to be closed
*/
static void Finalize() {
#ifdef _WIN32
WSACleanup();
#endif
}
/*!
* \brief Report an socket error.
* \param msg The error message.
*/
static void Error(const char *msg) {
int errsv = GetLastError();
#ifdef _WIN32
LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
#else
LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv);
#endif
}
protected:
explicit Socket(SockType sockfd) : sockfd(sockfd) {
}
};
/*!
* \brief a wrapper of TCP socket that hopefully be cross platform
*/
class TCPSocket : public Socket {
public:
TCPSocket() : Socket(INVALID_SOCKET) {
}
/*!
* \brief construct a TCP socket from existing descriptor
* \param sockfd The descriptor
*/
explicit TCPSocket(SockType sockfd) : Socket(sockfd) {
}
/*!
* \brief enable/disable TCP keepalive
* \param keepalive whether to set the keep alive option on
*/
void SetKeepAlive(bool keepalive) {
int opt = static_cast<int>(keepalive);
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
Socket::Error("SetKeepAlive");
}
}
/*!
* \brief create the socket, call this before using socket
* \param af domain
*/
void Create(int af = PF_INET) {
sockfd = socket(PF_INET, SOCK_STREAM, 0);
if (sockfd == INVALID_SOCKET) {
Socket::Error("Create");
}
}
/*!
* \brief perform listen of the socket
* \param backlog backlog parameter
*/
void Listen(int backlog = 16) {
listen(sockfd, backlog);
}
/*!
* \brief get a new connection
* \return The accepted socket connection.
*/
TCPSocket Accept() {
SockType newfd = accept(sockfd, NULL, NULL);
if (newfd == INVALID_SOCKET) {
Socket::Error("Accept");
}
return TCPSocket(newfd);
}
/*!
* \brief decide whether the socket is at OOB mark
* \return 1 if at mark, 0 if not, -1 if an error occured
*/
int AtMark() const {
#ifdef _WIN32
unsigned long atmark; // NOLINT(*)
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
#else
int atmark;
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
#endif
return static_cast<int>(atmark);
}
/*!
* \brief connect to an address
* \param addr the address to connect to
* \return whether connect is successful
*/
bool Connect(const SockAddr &addr) {
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0;
}
/*!
* \brief send data using the socket
* \param buf the pointer to the buffer
* \param len the size of the buffer
* \param flags extra flags
* \return size of data actually sent
* return -1 if error occurs
*/
ssize_t Send(const void *buf_, size_t len, int flag = 0) {
const char *buf = reinterpret_cast<const char*>(buf_);
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
}
/*!
* \brief receive data using the socket
* \param buf_ the pointer to the buffer
* \param len the size of the buffer
* \param flags extra flags
* \return size of data actually received
* return -1 if error occurs
*/
ssize_t Recv(void *buf_, size_t len, int flags = 0) {
char *buf = reinterpret_cast<char*>(buf_);
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
}
/*!
* \brief peform block write that will attempt to send all data out
* can still return smaller than request when error occurs
* \param buf the pointer to the buffer
* \param len the size of the buffer
* \return size of data actually sent
*/
size_t SendAll(const void *buf_, size_t len) {
const char *buf = reinterpret_cast<const char*>(buf_);
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
if (ret == -1) {
if (LastErrorWouldBlock()) return ndone;
Socket::Error("SendAll");
}
buf += ret;
ndone += ret;
}
return ndone;
}
/*!
* \brief peform block read that will attempt to read all data
* can still return smaller than request when error occurs
* \param buf_ the buffer pointer
* \param len length of data to recv
* \return size of data actually sent
*/
size_t RecvAll(void *buf_, size_t len) {
char *buf = reinterpret_cast<char*>(buf_);
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = recv(sockfd, buf,
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
if (ret == -1) {
if (LastErrorWouldBlock()) {
LOG(FATAL) << "would block";
return ndone;
}
Socket::Error("RecvAll");
}
if (ret == 0) return ndone;
buf += ret;
ndone += ret;
}
return ndone;
}
};
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_SOCKET_H_
......@@ -34,6 +34,7 @@ class DeviceAPIManager {
private:
std::array<DeviceAPI*, kMaxDeviceAPI> api_;
DeviceAPI* rpc_api_{nullptr};
std::mutex mutex_;
// constructor
DeviceAPIManager() {
......@@ -45,25 +46,38 @@ class DeviceAPIManager {
return &inst;
}
// Get or initialize API.
DeviceAPI* GetAPI(int type, bool allow_missing);
};
DeviceAPI* DeviceAPIManager::GetAPI(int type, bool allow_missing) {
DeviceAPI* GetAPI(int type, bool allow_missing) {
if (type < kRPCSessMask) {
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
std::string factory = "device_api." + DeviceName(type);
api_[type] = GetAPI(DeviceName(type), allow_missing);
return api_[type];
} else {
if (rpc_api_ != nullptr) return rpc_api_;
std::lock_guard<std::mutex> lock(mutex_);
if (rpc_api_ != nullptr) return rpc_api_;
rpc_api_ = GetAPI("rpc", allow_missing);
return rpc_api_;
}
}
DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
std::string factory = "device_api." + name;
auto* f = Registry::Get(factory);
if (f == nullptr) {
CHECK(allow_missing)
<< "Device API " << DeviceName(type) << " is not enabled.";
<< "Device API " << name << " is not enabled.";
return nullptr;
}
void* ptr = (*f)();
api_[type] = static_cast<DeviceAPI*>(ptr);
return api_[type];
}
return static_cast<DeviceAPI*>(ptr);
}
};
DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
return DeviceAPIManager::Get(
static_cast<int>(ctx.device_type), allow_missing);
}
inline TVMArray* TVMArrayCreate_() {
TVMArray* arr = new TVMArray();
......@@ -293,7 +307,11 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
[func, resource_handle](TVMArgs args, TVMRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
CHECK_EQ(ret, 0) << "TVMCall CFunc Error:\n" << TVMGetLastError();
if (ret != 0) {
std::ostringstream os;
os << "TVMCall CFunc Error:\n" << TVMGetLastError();
throw dmlc::Error(os.str());
}
});
} else {
// wrap it in a shared_ptr, with fin as deleter.
......@@ -303,7 +321,11 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
[func, rpack](TVMArgs args, TVMRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get());
CHECK_EQ(ret, 0) << "TVMCall CFunc Error:\n" << TVMGetLastError();
if (ret != 0) {
std::ostringstream os;
os << "TVMCall CFunc Error:\n" << TVMGetLastError();
throw dmlc::Error(os.str());
}
});
}
API_END();
......@@ -375,25 +397,28 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
// set device api
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
.set_body([](TVMArgs args, TVMRetValue *ret) {
int dev_type = args[0];
int dev_id = args[1];
DeviceAPIManager::Get(dev_type)->SetDevice(dev_id);
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAPIManager::Get(ctx)->SetDevice(ctx);
});
// set device api
TVM_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
int dev_type = args[0];
int dev_id = args[1];
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
if (kind == kExist) {
DeviceAPI* api = DeviceAPIManager::Get(dev_type, true);
DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
if (api != nullptr) {
api->GetAttr(dev_id, kind, ret);
api->GetAttr(ctx, kind, ret);
} else {
*ret = 0;
}
} else {
DeviceAPIManager::Get(dev_type)->GetAttr(dev_id, kind, ret);
DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
}
});
......@@ -13,8 +13,8 @@ namespace runtime {
class CPUDeviceAPI final : public DeviceAPI {
public:
void SetDevice(int dev_id) final {}
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final {
void SetDevice(TVMContext ctx) final {}
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
if (kind == kExist) {
*rv = 1;
}
......
......@@ -17,26 +17,26 @@ namespace runtime {
class CUDADeviceAPI final : public DeviceAPI {
public:
void SetDevice(int dev_id) final {
CUDA_CALL(cudaSetDevice(dev_id));
void SetDevice(TVMContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
}
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final {
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
int value;
switch (kind) {
case kExist:
value = (
cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, dev_id)
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
== cudaSuccess);
break;
case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, dev_id));
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
break;
}
case kWarpSize: {
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrWarpSize, dev_id));
&value, cudaDevAttrWarpSize, ctx.device_id));
break;
}
}
......
......@@ -24,18 +24,18 @@ class DeviceAPI {
/*! \brief virtual destructor */
virtual ~DeviceAPI() {}
/*!
* \brief Set the environment device id to dev_id
* \param dev_id The device id.
* \brief Set the environment device id to ctx
* \param ctx The context to be set.
* \return The allocated device pointer
*/
virtual void SetDevice(int dev_id) = 0;
virtual void SetDevice(TVMContext ctx) = 0;
/*!
* \brief Get attribute of specified device.
* \param dev_id The device id
* \param ctx The device context
* \param kind The result kind
* \param rv The return value.
*/
virtual void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) = 0;
virtual void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) = 0;
/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
......@@ -77,8 +77,18 @@ class DeviceAPI {
* \param stream The stream to be sync.
*/
virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0;
/*!
* \brief Get device API base don context.
* \param ctx The context
* \param allow_missing Whether allow missing
* \return The corresponding device API.
*/
static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false);
};
/*! \brief The device type bigger than this is RPC device */
constexpr int kRPCSessMask = 128;
/*!
* \brief The name of Device API factory.
* \param type The device type.
......
......@@ -26,7 +26,7 @@ class DSOModuleNode final : public ModuleNode {
if (lib_handle_) Unload();
}
const char* type_key() const {
const char* type_key() const final {
return "dso";
}
......
......@@ -60,8 +60,8 @@ class MetalWorkspace final : public DeviceAPI {
// Return false if already initialized, otherwise return true.
void Init();
// override device API
void SetDevice(int dev_id) final;
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final;
void SetDevice(TVMContext ctx) final;
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from,
......
......@@ -18,9 +18,9 @@ MetalWorkspace* MetalWorkspace::Global() {
}
void MetalWorkspace::GetAttr(
int dev_id, DeviceAttrKind kind, TVMRetValue* rv) {
TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init();
size_t index = static_cast<size_t>(dev_id);
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
*rv = int(index< devices.size());
return;
......@@ -30,7 +30,7 @@ void MetalWorkspace::GetAttr(
switch (kind) {
case kMaxThreadsPerBlock: {
*rv = static_cast<int>(
[devices[dev_id] maxThreadsPerThreadgroup].width);
[devices[ctx.device_id] maxThreadsPerThreadgroup].width);
break;
}
case kWarpSize: {
......@@ -69,7 +69,7 @@ int GetWarpSize(id<MTLDevice> dev) {
[NSString stringWithUTF8String:kDummyKernel]
options:nil
error:&error_msg];
CHECK(lib != nil) << error_msg;
CHECK(lib != nil) << [[error_msg localizedDescription] UTF8String];
id<MTLFunction> f =
[lib
newFunctionWithName:
......@@ -79,7 +79,7 @@ int GetWarpSize(id<MTLDevice> dev) {
[dev
newComputePipelineStateWithFunction:f
error:&error_msg];
CHECK(state != nil) << error_msg;
CHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
return state.threadExecutionWidth;
}
......@@ -109,8 +109,8 @@ void MetalWorkspace::Init() {
}
}
void MetalWorkspace::SetDevice(int dev_id) {
MetalThreadEntry::ThreadLocal()->context.device_id = dev_id;
void MetalWorkspace::SetDevice(TVMContext ctx) {
MetalThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
}
void* MetalWorkspace::AllocDataSpace(
......
......@@ -97,6 +97,8 @@ bool RuntimeEnabled(const std::string& target) {
f_name = "codegen.build_stackvm";
} else if (target == "llvm") {
f_name = "codegen.build_llvm";
} else if (target == "rpc") {
f_name = "device_api.rpc";
} else if (target == "vpi" || target == "verilog") {
f_name = "device_api.vpi";
} else {
......
......@@ -139,8 +139,8 @@ class OpenCLWorkspace final : public DeviceAPI {
return queues[ctx.device_id];
}
// override device API
void SetDevice(int dev_id) final;
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final;
void SetDevice(TVMContext ctx) final;
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from,
......
......@@ -18,14 +18,14 @@ OpenCLWorkspace* OpenCLWorkspace::Global() {
return &inst;
}
void OpenCLWorkspace::SetDevice(int dev_id) {
OpenCLThreadEntry::ThreadLocal()->context.device_id = dev_id;
void OpenCLWorkspace::SetDevice(TVMContext ctx) {
OpenCLThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
}
void OpenCLWorkspace::GetAttr(
int dev_id, DeviceAttrKind kind, TVMRetValue* rv) {
TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init();
size_t index = static_cast<size_t>(dev_id);
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
*rv = static_cast<int>(index< devices.size());
return;
......
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_device_api.cc
*/
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include "./rpc_session.h"
#include "../device_api.h"
namespace tvm {
namespace runtime {
class RPCDeviceAPI final : public DeviceAPI {
public:
void SetDevice(TVMContext ctx) final {
GetSess(ctx)->CallRemote(
RPCCode::kDevSetDevice, ctx);
}
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
*rv = GetSess(ctx)->CallRemote(
RPCCode::kDevGetAttr, ctx, static_cast<int>(kind));
}
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
auto sess = GetSess(ctx);
void *data = sess->CallRemote(
RPCCode::kDevAllocData, ctx, size, alignment);
RemoteSpace* space = new RemoteSpace();
space->data = data;
space->sess = std::move(sess);
return space;
}
void FreeDataSpace(TVMContext ctx, void* ptr) final {
RemoteSpace* space = static_cast<RemoteSpace*>(ptr);
GetSess(ctx)->CallRemote(
RPCCode::kDevFreeData, ctx, space->data);
delete space;
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) final {
int from_dev_type = ctx_from.device_type;
int to_dev_type = ctx_to.device_type;
if (from_dev_type > kRPCSessMask &&
to_dev_type > kRPCSessMask) {
CHECK(ctx_from.device_type == ctx_to.device_type)
<< "Cannot copy across two different remote session";
GetSess(ctx_from)->CallRemote(
RPCCode::kCopyAmongRemote,
static_cast<const RemoteSpace*>(from)->data, from_offset,
static_cast<const RemoteSpace*>(to)->data, to_offset,
size, ctx_from, ctx_to, stream);
} else if (from_dev_type > kRPCSessMask &&
to_dev_type == kCPU) {
GetSess(ctx_from)->CopyFromRemote(
static_cast<const RemoteSpace*>(from)->data, from_offset,
to, to_offset, size,
ctx_from);
} else if (from_dev_type == kCPU &&
to_dev_type > kRPCSessMask) {
GetSess(ctx_to)->CopyToRemote(
(void*)from, from_offset, // NOLINT(*)
static_cast<const RemoteSpace*>(to)->data, to_offset,
size, ctx_to);
} else {
LOG(FATAL) << "expect copy from/to remote or between remote";
}
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
GetSess(ctx)->CallRemote(
RPCCode::kDevStreamSync, ctx, stream);
}
private:
std::shared_ptr<RPCSession> GetSess(TVMContext ctx) {
int dev_type = ctx.device_type;
CHECK_GE(dev_type, kRPCSessMask);
int tbl_index = dev_type / kRPCSessMask - 1;
return RPCSession::Get(tbl_index);
}
};
TVM_REGISTER_GLOBAL("device_api.rpc")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCDeviceAPI inst;
DeviceAPI* ptr = &inst;
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_device_api.cc
* \brief RPC module.
*/
#include <tvm/runtime/registry.h>
#include <memory>
#include "./rpc_session.h"
namespace tvm {
namespace runtime {
const int kRPCMagic = 0xff271;
// Wrapped remote function to packed func.
struct RPCWrappedFunc {
public:
RPCWrappedFunc(void* handle, std::shared_ptr<RPCSession> sess)
: handle_(handle), sess_(sess) {}
void operator()(TVMArgs args, TVMRetValue *rv) const {
sess_->CallFunc(handle_, args, rv);
}
~RPCWrappedFunc() {
sess_->CallRemote(RPCCode::kFreeFunc, handle_);
}
private:
void* handle_{nullptr};
std::shared_ptr<RPCSession> sess_;
};
// RPC that represents a remote module session.
class RPCModuleNode final : public ModuleNode {
public:
RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess)
: module_handle_(module_handle), sess_(sess) {
}
~RPCModuleNode() {
if (module_handle_ != nullptr) {
sess_->CallRemote(RPCCode::kModuleFree, module_handle_);
}
}
const char* type_key() const final {
return "rpc";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
RPCFuncHandle handle = nullptr;
if (module_handle_ == nullptr) {
handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
} else {
handle = sess_->CallRemote(
RPCCode::kModuleGetFunc, module_handle_, name);
}
if (handle == nullptr) return PackedFunc();
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
return wf->operator()(args, rv);
});
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "RPCModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "RPCModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final {
if (module_handle_ != nullptr) {
std::string ret = sess_->CallRemote(
RPCCode::kModuleGetSource, module_handle_, format);
}
return "";
}
std::shared_ptr<RPCSession>& sess() {
return sess_;
}
private:
// The module handle
void* module_handle_{nullptr};
// The local channel
std::shared_ptr<RPCSession> sess_;
};
Module RPCConnect(std::string url, int port) {
common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port);
sock.Create();
CHECK(sock.Connect(addr))
<< "Connect to " << addr.AsString() << " failed";
// hand shake
int code = kRPCMagic;
CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
if (code != kRPCMagic) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
}
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(nullptr, RPCSession::Create(sock));
return Module(n);
}
void RPCServerLoop(int sockfd) {
common::TCPSocket sock(
static_cast<common::TCPSocket::SockType>(sockfd));
RPCSession::Create(sock)->ServerLoop();
}
TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCConnect(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
std::string tkey = m->type_key();
CHECK_EQ(tkey, "rpc");
auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess();
void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]);
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(mhandle, sess);
*rv = Module(n);
});
TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
std::string tkey = m->type_key();
CHECK_EQ(tkey, "rpc");
*rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
});
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]);
});
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_session.cc
* \brief RPC session for remote function call.
*/
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <array>
#include "./rpc_session.h"
#include "../device_api.h"
namespace tvm {
namespace runtime {
// Temp buffer for data array
struct RPCByteArrayBuffer {
TVMByteArray arr;
std::string data;
};
// Temp buffer for data array
struct RPCDataArrayBuffer {
DLTensor tensor;
std::vector<int64_t> shape;
};
/*!
* \brief Temporal argument buffer.
*/
struct RPCArgBuffer {
// The argument values
std::vector<TVMValue> value;
// The type codes.
std::vector<int> tcode;
// Temporal resources.
std::vector<std::unique_ptr<RPCByteArrayBuffer> > temp_bytes;
// Temporal array
std::vector<std::unique_ptr<RPCDataArrayBuffer> > temp_array;
// convert buffer as TVMArgs
TVMArgs AsTVMArgs() const {
return TVMArgs(value.data(), tcode.data(), value.size());
}
};
struct RPCSessTable {
public:
static constexpr int kMaxRPCSession = 32;
// Get global singleton
static RPCSessTable* Global() {
static RPCSessTable inst;
return &inst;
}
// Get session from table
std::shared_ptr<RPCSession> Get(int index) {
CHECK(index >= 0 && index < kMaxRPCSession);
return tbl_[index].lock();
}
// Insert session into table.
int Insert(std::shared_ptr<RPCSession> ptr) {
std::lock_guard<std::mutex> lock(mutex_);
for (int i = 0; i < kMaxRPCSession; ++i) {
if (tbl_[i].lock() == nullptr) {
tbl_[i] = ptr; return i;
}
}
LOG(FATAL) << "maximum number of RPC session reached";
return 0;
}
private:
// The mutex
std::mutex mutex_;
// Use weak_ptr intentionally
// If the RPCSession get released, the pointer session will be released
std::array<std::weak_ptr<RPCSession>, kMaxRPCSession> tbl_;
};
void RPCSession::Init() {
// Quick function to call remote.
call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
this->SendPackedSeq(args.values, args.type_codes, args.num_args);
RPCCode code = RPCCode::kCallFunc;
while (code != RPCCode::kReturn) {
code = HandleNextEvent(rv);
}
});
}
std::shared_ptr<RPCSession> RPCSession::Create(common::TCPSocket sock) {
std::shared_ptr<RPCSession> sess = std::make_shared<RPCSession>();
sess->sock_ = sock;
sess->Init();
sess->table_index_ = RPCSessTable::Global()->Insert(sess);
return sess;
}
std::shared_ptr<RPCSession> RPCSession::Get(int table_index) {
return RPCSessTable::Global()->Get(table_index);
}
RPCSession::~RPCSession() {
this->Shutdown();
}
void RPCSession::Shutdown() {
if (!sock_.BadSocket()) {
RPCCode code = RPCCode::kShutdown;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
sock_.Close();
}
}
void RPCSession::ServerLoop() {
std::lock_guard<std::recursive_mutex> lock(mutex_);
RPCCode code = RPCCode::kCallFunc;
TVMRetValue rv;
while (code != RPCCode::kShutdown) {
code = HandleNextEvent(&rv);
CHECK(code != RPCCode::kReturn);
}
if (!sock_.BadSocket()) {
sock_.Close();
}
}
// Get remote function with name
void RPCSession::CallFunc(void* h, TVMArgs args, TVMRetValue* rv) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
RPCCode code = RPCCode::kCallFunc;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
uint64_t handle = reinterpret_cast<uint64_t>(h);
CHECK_EQ(sock_.SendAll(&handle, sizeof(handle)), sizeof(handle));
call_remote_.CallPacked(args, rv);
}
void RPCSession::CopyToRemote(void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t data_size,
TVMContext ctx_to) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
ctx_to = StripSessMask(ctx_to);
RPCCode code = RPCCode::kCopyToRemote;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
uint64_t handle = reinterpret_cast<uint64_t>(to);
CHECK_EQ(sock_.SendAll(&handle, sizeof(handle)), sizeof(handle));
uint64_t offset = static_cast<uint64_t>(to_offset);
CHECK_EQ(sock_.SendAll(&offset, sizeof(offset)), sizeof(offset));
uint64_t size = static_cast<uint64_t>(data_size);
CHECK_EQ(sock_.SendAll(&size, sizeof(size)), sizeof(size));
CHECK_EQ(sock_.SendAll(&ctx_to, sizeof(ctx_to)), sizeof(ctx_to));
CHECK_EQ(sock_.SendAll(reinterpret_cast<char*>(from) + from_offset, data_size),
data_size);
TVMRetValue rv;
while (code != RPCCode::kReturn) {
code = HandleNextEvent(&rv);
}
}
void RPCSession::CopyFromRemote(void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t data_size,
TVMContext ctx_from) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
ctx_from = StripSessMask(ctx_from);
RPCCode code = RPCCode::kCopyFromRemote;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
uint64_t handle = reinterpret_cast<uint64_t>(from);
CHECK_EQ(sock_.SendAll(&handle, sizeof(handle)), sizeof(handle));
uint64_t offset = static_cast<uint64_t>(from_offset);
CHECK_EQ(sock_.SendAll(&offset, sizeof(offset)), sizeof(offset));
uint64_t size = static_cast<uint64_t>(data_size);
CHECK_EQ(sock_.SendAll(&size, sizeof(size)), sizeof(size));
CHECK_EQ(sock_.SendAll(&ctx_from, sizeof(ctx_from)), sizeof(ctx_from));
CHECK_EQ(sock_.RecvAll(&code, sizeof(code)), sizeof(code));
if (code == RPCCode::kCopyAck) {
CHECK_EQ(sock_.RecvAll(reinterpret_cast<char*>(to) + to_offset, data_size),
data_size);
} else {
HandleException();
}
}
void RPCSession::SendReturnValue(
int succ, TVMValue ret_value, int ret_tcode) {
if (succ == 0) {
RPCCode code = RPCCode::kReturn;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
} else {
RPCCode code = RPCCode::kException;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
ret_value.v_str = TVMGetLastError();
ret_tcode = kStr;
}
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
template<typename F>
void RPCSession::CallHandler(F f) {
RPCArgBuffer args;
this->RecvPackedSeq(&args);
TVMRetValue rv;
TVMValue ret_value;
int ret_tcode;
try {
f(TVMArgs(args.value.data(), args.tcode.data(),
static_cast<int>(args.value.size())), &rv);
RPCCode code = RPCCode::kReturn;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
if (rv.type_code() == kStr) {
std::string str = rv;
ret_value.v_str = str.c_str();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
} else {
ret_value = rv.value();
ret_tcode = rv.type_code();
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
} catch (const std::runtime_error& e) {
RPCCode code = RPCCode::kException;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
ret_value.v_str = e.what();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
}
void RPCSession::HandleCallFunc() {
uint64_t handle;
CHECK_EQ(sock_.RecvAll(&handle, sizeof(handle)), sizeof(handle));
PackedFunc* pf = reinterpret_cast<PackedFunc*>(handle);
CallHandler([pf](TVMArgs args, TVMRetValue *rv) {
pf->CallPacked(args, rv);
});
}
void RPCSession::HandleException() {
RPCArgBuffer ret;
this->RecvPackedSeq(&ret);
CHECK_EQ(ret.value.size(), 1U);
CHECK_EQ(ret.tcode[0], kStr);
std::ostringstream os;
os << "Except caught from RPC call: " << ret.value[0].v_str;
throw dmlc::Error(os.str());
}
void RPCSession::HandleCopyToRemote() {
uint64_t handle, offset, size;
TVMContext ctx;
CHECK_EQ(sock_.RecvAll(&handle, sizeof(handle)), sizeof(handle));
CHECK_EQ(sock_.RecvAll(&offset, sizeof(offset)), sizeof(offset));
CHECK_EQ(sock_.RecvAll(&size, sizeof(size)), sizeof(size));
CHECK_EQ(sock_.RecvAll(&ctx, sizeof(ctx)), sizeof(ctx));
int succ = 0;
if (ctx.device_type == kCPU) {
CHECK_EQ(sock_.RecvAll(reinterpret_cast<char*>(handle) + offset, size),
static_cast<size_t>(size));
} else {
temp_data_.resize(size+1);
CHECK_EQ(sock_.RecvAll(&temp_data_[0], size),
static_cast<size_t>(size));
try {
TVMContext cpu_ctx;
cpu_ctx.device_type = kCPU;
cpu_ctx.device_id = 0;
DeviceAPI::Get(ctx)->CopyDataFromTo(
temp_data_.data(), 0,
reinterpret_cast<void*>(handle), offset,
size, cpu_ctx, ctx, nullptr);
} catch (const std::runtime_error &e) {
TVMAPISetLastError(e.what());
succ = -1;
}
}
TVMValue ret_value;
ret_value.v_handle = nullptr;
int ret_tcode = kNull;
SendReturnValue(succ, ret_value, ret_tcode);
}
void RPCSession::HandleCopyFromRemote() {
uint64_t handle, offset, size;
TVMContext ctx;
CHECK_EQ(sock_.RecvAll(&handle, sizeof(handle)), sizeof(handle));
CHECK_EQ(sock_.RecvAll(&offset, sizeof(offset)), sizeof(offset));
CHECK_EQ(sock_.RecvAll(&size, sizeof(size)), sizeof(size));
CHECK_EQ(sock_.RecvAll(&ctx, sizeof(ctx)), sizeof(ctx));
if (ctx.device_type == kCPU) {
RPCCode code = RPCCode::kCopyAck;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock_.SendAll(reinterpret_cast<char*>(handle) + offset, size),
static_cast<size_t>(size));
} else {
temp_data_.resize(size + 1);
try {
TVMContext cpu_ctx;
cpu_ctx.device_type = kCPU;
cpu_ctx.device_id = 0;
DeviceAPI::Get(ctx)->CopyDataFromTo(
reinterpret_cast<void*>(handle), offset,
dmlc::BeginPtr(temp_data_), 0,
size, ctx, cpu_ctx, nullptr);
RPCCode code = RPCCode::kCopyAck;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock_.SendAll(&temp_data_[0], size),
static_cast<size_t>(size));
} catch (const std::runtime_error &e) {
RPCCode code = RPCCode::kException;
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
TVMValue ret_value;
ret_value.v_str = e.what();
int ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
}
}
}
void RPCSession::HandleReturn(TVMRetValue* rv) {
RPCArgBuffer ret;
this->RecvPackedSeq(&ret);
CHECK_EQ(ret.value.size(), 1U);
TVMArgValue argv = ret.AsTVMArgs()[0];
*rv = argv;
}
TVMContext RPCSession::StripSessMask(TVMContext ctx) {
int dev_type = ctx.device_type;
CHECK_EQ(dev_type / kRPCSessMask, table_index_ + 1)
<< "Can only TVMContext related to the same remote sesstion";
ctx.device_type = static_cast<DLDeviceType>(dev_type % kRPCSessMask);
return ctx;
}
// packed Send sequence to the channel
void RPCSession::SendPackedSeq(
const TVMValue* arg_values, const int* type_codes, int n) {
CHECK_EQ(sock_.SendAll(&n, sizeof(n)), sizeof(n));
CHECK_EQ(sock_.SendAll(type_codes, sizeof(int) * n), sizeof(int) * n);
// Argument packing.
for (int i = 0; i < n; ++i) {
int tcode = type_codes[i];
TVMValue value = arg_values[i];
switch (tcode) {
case kInt:
case kUInt:
case kFloat:
case kTVMType: {
CHECK_EQ(sock_.SendAll(&value, sizeof(TVMValue)), sizeof(TVMValue));
break;
}
case kTVMContext: {
value.v_ctx = StripSessMask(value.v_ctx);
CHECK_EQ(sock_.SendAll(&value, sizeof(TVMValue)), sizeof(TVMValue));
break;
}
case kHandle: {
// always send handle in 64 bit.
uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
CHECK_EQ(sock_.SendAll(&handle, sizeof(uint64_t)), sizeof(uint64_t));
break;
}
case kArrayHandle: {
DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
TVMContext ctx = StripSessMask(arr->ctx);
uint64_t data = reinterpret_cast<uint64_t>(
static_cast<RemoteSpace*>(arr->data)->data);
CHECK_EQ(sock_.SendAll(&data, sizeof(uint64_t)), sizeof(uint64_t));
CHECK_EQ(sock_.SendAll(&ctx, sizeof(ctx)), sizeof(ctx));
CHECK_EQ(sock_.SendAll(&(arr->ndim), sizeof(int)), sizeof(int));
CHECK_EQ(sock_.SendAll(&(arr->dtype), sizeof(DLDataType)), sizeof(DLDataType));
CHECK_EQ(sock_.SendAll(arr->shape, sizeof(int64_t) * arr->ndim),
sizeof(int64_t) * arr->ndim);
CHECK(arr->strides == nullptr)
<< "Donot support strided remote array";
CHECK_EQ(arr->byte_offset, 0)
<< "Donot support send byte offset";
break;
}
case kNull: break;
case kStr: {
const char* s = value.v_str;
uint64_t len = strlen(s);
CHECK_EQ(sock_.SendAll(&len, sizeof(len)), sizeof(len));
CHECK_EQ(sock_.SendAll(s, sizeof(char) * len), sizeof(char) * len);
break;
}
case kBytes: {
TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
uint64_t len = bytes->size;
CHECK_EQ(sock_.SendAll(&len, sizeof(len)), sizeof(len));
CHECK_EQ(sock_.SendAll(bytes->data, sizeof(char) * len), sizeof(char) * len);
break;
}
default: {
LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
break;
}
}
}
}
// Receive packed sequence from the channel
void RPCSession::RecvPackedSeq(RPCArgBuffer *buf) {
int n;
CHECK_EQ(sock_.RecvAll(&n, sizeof(n)), sizeof(n));
buf->value.resize(n);
buf->tcode.resize(n);
buf->temp_bytes.clear();
if (n != 0) {
buf->tcode.resize(n);
CHECK_EQ(sock_.RecvAll(buf->tcode.data(), sizeof(int) * n),
sizeof(int) * n);
}
buf->value.resize(n);
for (int i = 0; i < n; ++i) {
int tcode = buf->tcode[i];
TVMValue& value = buf->value[i];
switch (tcode) {
case kInt:
case kUInt:
case kFloat:
case kTVMType:
case kTVMContext: {
CHECK_EQ(sock_.RecvAll(&value, sizeof(TVMValue)), sizeof(TVMValue));
break;
}
case kHandle: {
// always send handle in 64 bit.
uint64_t handle;
CHECK_EQ(sock_.RecvAll(&handle, sizeof(uint64_t)), sizeof(uint64_t));
value.v_handle = reinterpret_cast<void*>(handle);
break;
}
case kNull: {
value.v_handle = nullptr;
break;
}
case kStr:
case kBytes: {
uint64_t len;
CHECK_EQ(sock_.RecvAll(&len, sizeof(len)), sizeof(len));
std::unique_ptr<RPCByteArrayBuffer> temp(new RPCByteArrayBuffer());
temp->data.resize(len);
if (len != 0) {
CHECK_EQ(sock_.RecvAll(&(temp->data[0]), sizeof(char) * len),
sizeof(char) * len);
}
if (tcode == kStr) {
value.v_str = temp->data.c_str();
} else {
temp->arr.size = static_cast<size_t>(len);
temp->arr.data = dmlc::BeginPtr(temp->data);
value.v_handle = &(temp->arr);
}
buf->temp_bytes.emplace_back(std::move(temp));
break;
}
case kArrayHandle: {
std::unique_ptr<RPCDataArrayBuffer> temp(new RPCDataArrayBuffer());
uint64_t handle;
CHECK_EQ(sock_.RecvAll(&handle, sizeof(handle)), sizeof(handle));
DLTensor& tensor = temp->tensor;
tensor.data = reinterpret_cast<void*>(handle);
CHECK_EQ(sock_.RecvAll(&(tensor.ctx), sizeof(TVMContext)), sizeof(TVMContext));
CHECK_EQ(sock_.RecvAll(&(tensor.ndim), sizeof(int)), sizeof(int));
CHECK_EQ(sock_.RecvAll(&(tensor.dtype), sizeof(DLDataType)), sizeof(DLDataType));
temp->shape.resize(tensor.ndim);
tensor.shape = temp->shape.data();
CHECK_EQ(sock_.RecvAll(tensor.shape, tensor.ndim * sizeof(int64_t)),
tensor.ndim * sizeof(int64_t));
tensor.strides = nullptr;
tensor.byte_offset = 0;
value.v_handle = &tensor;
buf->temp_array.emplace_back(std::move(temp));
break;
}
default: {
LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
break;
}
}
}
}
// Event handler functions
void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) {
std::string name = args[0];
auto *fp = tvm::runtime::Registry::Get(name);
if (fp != nullptr) {
*rv = static_cast<void*>(new tvm::runtime::PackedFunc(*fp));
} else {
*rv = nullptr;
}
}
void RPCFreeFunc(TVMArgs args, TVMRetValue *rv) {
void* handle = args[0];
delete static_cast<PackedFunc*>(handle);
}
void RPCDevSetDevice(TVMArgs args, TVMRetValue *rv) {
TVMContext ctx = args[0];
DeviceAPI::Get(ctx)->SetDevice(ctx);
}
void RPCDevGetAttr(TVMArgs args, TVMRetValue *rv) {
TVMContext ctx = args[0];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[1].operator int());
if (kind == kExist) {
DeviceAPI* api = DeviceAPI::Get(ctx, true);
if (api != nullptr) {
api->GetAttr(ctx, kind, rv);
} else {
*rv = 0;
}
} else {
DeviceAPI::Get(ctx)->GetAttr(
ctx, static_cast<DeviceAttrKind>(kind), rv);
}
}
void RPCDevAllocData(TVMArgs args, TVMRetValue *rv) {
TVMContext ctx = args[0];
uint64_t size = args[1];
uint64_t alignment = args[2];
void* data = DeviceAPI::Get(ctx)->AllocDataSpace(ctx, size, alignment);
*rv = data;
}
void RPCDevFreeData(TVMArgs args, TVMRetValue *rv) {
TVMContext ctx = args[0];
void* ptr = args[1];
DeviceAPI::Get(ctx)->FreeDataSpace(ctx, ptr);
}
void RPCDevStreamSync(TVMArgs args, TVMRetValue *rv) {
TVMContext ctx = args[0];
TVMStreamHandle handle = args[1];
DeviceAPI::Get(ctx)->StreamSync(ctx, handle);
}
void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
void* from = args[0];
uint64_t from_offset = args[1];
void* to = args[2];
uint64_t to_offset = args[3];
uint64_t size = args[4];
TVMContext ctx_from = args[5];
TVMContext ctx_to = args[6];
TVMStreamHandle stream = args[7];
TVMContext ctx = ctx_from;
if (ctx.device_type == kCPU) {
ctx = ctx_to;
} else {
CHECK(ctx_to.device_type == kCPU ||
ctx_to.device_type == ctx_from.device_type)
<< "Can not copy across different ctx types directly";
}
DeviceAPI::Get(ctx)->CopyDataFromTo(
from, from_offset,
to, to_offset,
size, ctx_from, ctx_to, stream);
}
void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
static const PackedFunc* fsys_load_ = nullptr;
if (fsys_load_ == nullptr) {
fsys_load_ = runtime::Registry::Get("tvm.contrib.rpc.server.load_module");
CHECK(fsys_load_ != nullptr);
}
std::string file_name = args[0];
TVMRetValue ret = (*fsys_load_)(file_name);
Module m = ret;
*rv = static_cast<void*>(new Module(m));
}
void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0];
delete static_cast<Module*>(mhandle);
}
void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0];
PackedFunc pf = static_cast<Module*>(mhandle)->GetFunction(
args[1], false);
*rv = static_cast<void*>(new PackedFunc(pf));
}
void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0];
std::string fmt = args[1];
*rv = (*static_cast<Module*>(mhandle))->GetSource(fmt);
}
RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
RPCCode code;
CHECK_EQ(sock_.RecvAll(&code, sizeof(int)), sizeof(int));
switch (code) {
case RPCCode::kCallFunc: HandleCallFunc(); break;
case RPCCode::kReturn: HandleReturn(rv); break;
case RPCCode::kException: HandleException(); break;
case RPCCode::kCopyFromRemote: HandleCopyFromRemote(); break;
case RPCCode::kCopyToRemote: HandleCopyToRemote(); break;
case RPCCode::kShutdown: break;
// system functions
case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break;
case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break;
case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break;
case RPCCode::kDevGetAttr: CallHandler(RPCDevGetAttr); break;
case RPCCode::kDevAllocData: CallHandler(RPCDevAllocData); break;
case RPCCode::kDevFreeData: CallHandler(RPCDevFreeData); break;
case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break;
case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break;
case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break;
case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
default: LOG(FATAL) << "Unknown event " << static_cast<int>(code);
}
return code;
}
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_session.h
* \brief Base RPC session interface.
*/
#ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_
#define TVM_RUNTIME_RPC_RPC_SESSION_H_
#include <tvm/runtime/packed_func.h>
#include <mutex>
#include <string>
#include "../device_api.h"
#include "../../common/socket.h"
namespace tvm {
namespace runtime {
/*! \brief The remote functio handle */
using RPCFuncHandle = void*;
struct RPCArgBuffer;
/*! \brief The RPC code */
enum class RPCCode : int {
kCallFunc,
kReturn,
kException,
kShutdown,
kCopyFromRemote,
kCopyToRemote,
kCopyAck,
// The following are code that can send over CallRemote
kGetGlobalFunc,
kFreeFunc,
kDevSetDevice,
kDevGetAttr,
kDevAllocData,
kDevFreeData,
kDevStreamSync,
kCopyAmongRemote,
kModuleLoad,
kModuleFree,
kModuleGetFunc,
kModuleGetSource
};
// Bidirectional Communication Session of PackedRPC
class RPCSession {
public:
/*! \brief virtual destructor */
~RPCSession();
/*!
* \brief The server loop that server runs to handle RPC calls.
*/
void ServerLoop();
/*!
* \brief Call into remote function
* \param handle The function handle
* \param args The arguments
* \param rv The return value.
*/
void CallFunc(RPCFuncHandle handle,
TVMArgs args,
TVMRetValue* rv);
/*!
* \brief Copy bytes into remote array content.
* \param from The source host data.
* \param from_offset The byte offeset in the from.
* \param to The target array.
* \param to_offset The byte offset in the to.
* \param size The size of the memory.
* \param ctx_to The target context.
*/
void CopyToRemote(void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_to);
/*!
* \brief Copy bytes from remote array content.
* \param from The source host data.
* \param from_offset The byte offeset in the from.
* \param to The target array.
* \param to_offset The byte offset in the to.
* \param size The size of the memory.
* \param ctx_from The source context.
*/
void CopyFromRemote(void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from);
/*!
* \brief Call a remote defined system function with arguments.
* \param fcode The function code.
* \param args The arguments
* \return The returned remote value.
*/
template<typename... Args>
inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args);
/*!
* \return The session table index of the session.
*/
int table_index() const {
return table_index_;
}
/*!
* \brief Create a RPC session with given socket
* \param sock The socket.
* \return The session.
*/
static std::shared_ptr<RPCSession> Create(common::TCPSocket sock);
/*!
* \brief Try get session from the global session table by table index.
* \param table_index The table index of the session.
* \return The shared_ptr to the session, can be nullptr.
*/
static std::shared_ptr<RPCSession> Get(int table_index);
private:
/*!
* \brief Handle the remote call with f
* \param f The handle function
* \tparam F the handler function.
*/
template<typename F>
void CallHandler(F f);
void Init();
void Shutdown();
void SendReturnValue(int succ, TVMValue value, int tcode);
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n);
void RecvPackedSeq(RPCArgBuffer *buf);
RPCCode HandleNextEvent(TVMRetValue *rv);
// special handler.
void HandleCallFunc();
void HandleException();
void HandleCopyFromRemote();
void HandleCopyToRemote();
void HandleReturn(TVMRetValue* rv);
TVMContext StripSessMask(TVMContext ctx);
// Internal mutex
std::recursive_mutex mutex_;
// Internal socket
common::TCPSocket sock_;
// Internal temporal data space.
std::string temp_data_;
// call remote with the specified function coede.
PackedFunc call_remote_;
// The index of this session in RPC session table.
int table_index_{0};
};
// Remote space pointer.
struct RemoteSpace {
void* data;
std::shared_ptr<RPCSession> sess;
};
// implementation of inline functions
template<typename... Args>
inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
return call_remote_(std::forward<Args>(args)...);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RPC_RPC_SESSION_H_
......@@ -52,7 +52,6 @@ def test_convert():
f = tvm.convert(myfunc)
assert isinstance(f, tvm.Function)
f(*targs)
def test_byte_array():
s = "hello"
......@@ -63,9 +62,10 @@ def test_byte_array():
f = tvm.convert(myfunc)
f(a)
if __name__ == "__main__":
test_get_global()
test_get_callback_with_node()
test_convert()
test_get_global()
test_return_func()
test_byte_array()
import tvm
import logging
import numpy as np
import time
from tvm.contrib import rpc, util
def test_rpc_simple():
if not tvm.module.enabled("rpc"):
return
@tvm.register_func("rpc.test.addone")
def addone(x):
return x + 1
@tvm.register_func("rpc.test.strcat")
def addone(name, x):
return "%s:%d" % (name, x)
server = rpc.Server("localhost")
client = rpc.connect(server.host, server.port)
f1 = client.get_function("rpc.test.addone")
assert f1(10) == 11
f2 = client.get_function("rpc.test.strcat")
assert f2("abc", 11) == "abc:11"
def test_rpc_array():
if not tvm.module.enabled("rpc"):
return
x = np.random.randint(0, 10, size=(3, 4))
@tvm.register_func("rpc.test.remote_array_func")
def remote_array_func(y):
np.testing.assert_equal(y.asnumpy(), x)
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
print("second connect")
r_cpu = tvm.nd.array(x, remote.cpu(0))
assert str(r_cpu.context).startswith("remote")
np.testing.assert_equal(r_cpu.asnumpy(), x)
fremote = remote.get_function("rpc.test.remote_array_func")
fremote(r_cpu)
def test_rpc_file_exchange():
if not tvm.module.enabled("rpc"):
return
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
blob = bytearray(np.random.randint(0, 10, size=(127)))
remote.upload(blob, "dat.bin")
rev = remote.download("dat.bin")
def test_rpc_remote_module():
if not tvm.module.enabled("rpc"):
return
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
def check_remote():
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled")
return
temp = util.tempdir()
ctx = remote.cpu(0)
f = tvm.build(s, [A, B], "llvm", name="myadd")
path_dso = temp.relpath("dev_lib.so")
f.export_library(path_dso)
remote.upload(path_dso)
f1 = remote.load_module("dev_lib.so")
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f1(a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
test_rpc_array()
test_rpc_remote_module()
test_rpc_file_exchange()
test_rpc_simple()
......@@ -19,6 +19,7 @@ fi
cp make/config.mk config.mk
echo "USE_CUDA=0" >> config.mk
echo "USE_RPC=1" >> config.mk
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
echo "USE_OPENCL=1" >> config.mk
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment