Commit 134c6ba3 by Tianqi Chen Committed by GitHub

[RUNTIME] RPC runtime that support run testing on remote device. (#147)

* [RUNTIME] RPC runtime that support run testing on remote device.

* Fix ctypes in OSX.

* fix lint
parent b7fe6119
...@@ -8,6 +8,7 @@ endif() ...@@ -8,6 +8,7 @@ endif()
include(cmake/Util.cmake) include(cmake/Util.cmake)
tvm_option(USE_CUDA "Build with CUDA" ON) tvm_option(USE_CUDA "Build with CUDA" ON)
tvm_option(USE_OPENCL "Build with OpenCL" OFF) tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_RPC "Build with RPC" OFF)
tvm_option(USE_LLVM "Build with LLVM" OFF) tvm_option(USE_LLVM "Build with LLVM" OFF)
tvm_option(USE_RTTI "Build with RTTI" OFF) tvm_option(USE_RTTI "Build with RTTI" OFF)
tvm_option(USE_MSVC_MT "Build with MT" OFF) tvm_option(USE_MSVC_MT "Build with MT" OFF)
...@@ -67,6 +68,7 @@ file(GLOB RUNTIME_SRCS src/runtime/*.cc) ...@@ -67,6 +68,7 @@ file(GLOB RUNTIME_SRCS src/runtime/*.cc)
file(GLOB COMPILER_LLVM_SRCS src/codegen/llvm/*.cc) file(GLOB COMPILER_LLVM_SRCS src/codegen/llvm/*.cc)
file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc) file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc)
file(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc) file(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc)
file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
if(USE_CUDA) if(USE_CUDA)
find_package(CUDA) find_package(CUDA)
...@@ -98,6 +100,11 @@ else(USE_OPENCL) ...@@ -98,6 +100,11 @@ else(USE_OPENCL)
add_definitions(-DTVM_OPENCL_RUNTIME=0) add_definitions(-DTVM_OPENCL_RUNTIME=0)
endif(USE_OPENCL) endif(USE_OPENCL)
if(USE_RPC)
message(STATUS "Build with RPC support...")
list(APPEND RUNTIME_SRCS ${RUNTIME_RPC_SRCS})
endif(USE_RPC)
if(USE_LLVM) if(USE_LLVM)
find_package(LLVM REQUIRED CONFIG) find_package(LLVM REQUIRED CONFIG)
message(STATUS "Build with LLVM support...") message(STATUS "Build with LLVM support...")
......
...@@ -20,10 +20,16 @@ LIB_HALIDEIR = HalideIR/lib/libHalideIR.a ...@@ -20,10 +20,16 @@ LIB_HALIDEIR = HalideIR/lib/libHalideIR.a
CC_SRC = $(filter-out src/contrib/%.cc src/runtime/%.cc,\ CC_SRC = $(filter-out src/contrib/%.cc src/runtime/%.cc,\
$(wildcard src/*/*.cc src/*/*/*.cc)) $(wildcard src/*/*.cc src/*/*/*.cc))
METAL_SRC = $(wildcard src/runtime/metal/*.mm) METAL_SRC = $(wildcard src/runtime/metal/*.mm)
RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc) CUDA_SRC = $(wildcard src/runtime/cuda/*.cc)
OPENCL_SRC = $(wildcard src/runtime/opencl/*.cc)
RPC_SRC = $(wildcard src/runtime/rpc/*.cc)
RUNTIME_SRC = $(wildcard src/runtime/*.cc)
# Objectives # Objectives
METAL_OBJ = $(patsubst src/%.mm, build/%.o, $(METAL_SRC)) METAL_OBJ = $(patsubst src/%.mm, build/%.o, $(METAL_SRC))
CUDA_OBJ = $(patsubst src/%.cc, build/%.o, $(CUDA_SRC))
OPENCL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENCL_SRC))
RPC_OBJ = $(patsubst src/%.cc, build/%.o, $(RPC_SRC))
CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC)) CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC))
RUNTIME_OBJ = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC)) RUNTIME_OBJ = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))
CONTRIB_OBJ = CONTRIB_OBJ =
...@@ -51,6 +57,7 @@ endif ...@@ -51,6 +57,7 @@ endif
ifeq ($(USE_CUDA), 1) ifeq ($(USE_CUDA), 1)
CFLAGS += -DTVM_CUDA_RUNTIME=1 CFLAGS += -DTVM_CUDA_RUNTIME=1
LDFLAGS += -lcuda -lcudart -lnvrtc LDFLAGS += -lcuda -lcudart -lnvrtc
RUNTIME_DEP += $(CUDA_OBJ)
else else
CFLAGS += -DTVM_CUDA_RUNTIME=0 CFLAGS += -DTVM_CUDA_RUNTIME=0
endif endif
...@@ -62,6 +69,7 @@ ifeq ($(USE_OPENCL), 1) ...@@ -62,6 +69,7 @@ ifeq ($(USE_OPENCL), 1)
else else
LDFLAGS += -lOpenCL LDFLAGS += -lOpenCL
endif endif
RUNTIME_DEP += $(OPENCL_OBJ)
else else
CFLAGS += -DTVM_OPENCL_RUNTIME=0 CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif endif
...@@ -75,6 +83,10 @@ else ...@@ -75,6 +83,10 @@ else
CFLAGS += -DTVM_METAL_RUNTIME=0 CFLAGS += -DTVM_METAL_RUNTIME=0
endif endif
ifeq ($(USE_RPC), 1)
RUNTIME_DEP += $(RPC_OBJ)
endif
# llvm configuration # llvm configuration
ifdef LLVM_CONFIG ifdef LLVM_CONFIG
LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3) LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3)
......
...@@ -57,13 +57,14 @@ typedef enum { ...@@ -57,13 +57,14 @@ typedef enum {
// that is used by TVM API calls. // that is used by TVM API calls.
kHandle = 3U, kHandle = 3U,
kNull = 4U, kNull = 4U,
kArrayHandle = 5U, kTVMType = 5U,
kTVMType = 6U, kTVMContext = 6U,
kNodeHandle = 7U, kArrayHandle = 7U,
kModuleHandle = 8U, kNodeHandle = 8U,
kFuncHandle = 9U, kModuleHandle = 9U,
kStr = 10U, kFuncHandle = 10U,
kBytes = 11U kStr = 11U,
kBytes = 12U
} TVMTypeCode; } TVMTypeCode;
/*! /*!
...@@ -98,6 +99,7 @@ typedef union { ...@@ -98,6 +99,7 @@ typedef union {
void* v_handle; void* v_handle;
const char* v_str; const char* v_str;
TVMType v_type; TVMType v_type;
TVMContext v_ctx;
} TVMValue; } TVMValue;
/*! /*!
......
...@@ -205,6 +205,10 @@ class TVMPODValue_ { ...@@ -205,6 +205,10 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kArrayHandle); TVM_CHECK_TYPE_CODE(type_code_, kArrayHandle);
return static_cast<TVMArray*>(value_.v_handle); return static_cast<TVMArray*>(value_.v_handle);
} }
operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
int type_code() const { int type_code() const {
return type_code_; return type_code_;
} }
...@@ -254,6 +258,7 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -254,6 +258,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator bool; using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*; using TVMPODValue_::operator void*;
using TVMPODValue_::operator TVMArray*; using TVMPODValue_::operator TVMArray*;
using TVMPODValue_::operator TVMContext;
// conversion operator. // conversion operator.
operator std::string() const { operator std::string() const {
if (type_code_ == kTVMType) { if (type_code_ == kTVMType) {
...@@ -333,6 +338,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -333,6 +338,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator bool; using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*; using TVMPODValue_::operator void*;
using TVMPODValue_::operator TVMArray*; using TVMPODValue_::operator TVMArray*;
using TVMPODValue_::operator TVMContext;
// Disable copy and assign from another value, but allow move. // Disable copy and assign from another value, but allow move.
TVMRetValue(const TVMRetValue& other) { TVMRetValue(const TVMRetValue& other) {
this->Assign(other); this->Assign(other);
...@@ -474,7 +480,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -474,7 +480,7 @@ class TVMRetValue : public TVMPODValue_ {
break; break;
} }
case kModuleHandle: { case kModuleHandle: {
SwitchToClass<PackedFunc>(kModuleHandle, other); SwitchToClass<Module>(kModuleHandle, other);
break; break;
} }
case kNodeHandle: { case kNodeHandle: {
...@@ -532,6 +538,7 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -532,6 +538,7 @@ inline const char* TypeCode2Str(int type_code) {
case kNodeHandle: return "NodeHandle"; case kNodeHandle: return "NodeHandle";
case kArrayHandle: return "ArrayHandle"; case kArrayHandle: return "ArrayHandle";
case kTVMType: return "TVMType"; case kTVMType: return "TVMType";
case kTVMContext: return "TVMContext";
case kFuncHandle: return "FunctionHandle"; case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle"; case kModuleHandle: return "ModuleHandle";
default: LOG(FATAL) << "unknown type_code=" default: LOG(FATAL) << "unknown type_code="
...@@ -659,6 +666,10 @@ class TVMArgsSetter { ...@@ -659,6 +666,10 @@ class TVMArgsSetter {
values_[i].v_handle = value; values_[i].v_handle = value;
type_codes_[i] = kArrayHandle; type_codes_[i] = kArrayHandle;
} }
void operator()(size_t i, TVMContext value) const {
values_[i].v_ctx = value;
type_codes_[i] = kTVMContext;
}
void operator()(size_t i, TVMType value) const { void operator()(size_t i, TVMType value) const {
values_[i].v_type = value; values_[i].v_type = value;
type_codes_[i] = kTVMType; type_codes_[i] = kTVMType;
...@@ -674,6 +685,10 @@ class TVMArgsSetter { ...@@ -674,6 +685,10 @@ class TVMArgsSetter {
values_[i].v_str = value.c_str(); values_[i].v_str = value.c_str();
type_codes_[i] = kStr; type_codes_[i] = kStr;
} }
void operator()(size_t i, TVMByteArray& value) const { // NOLINT(*)
values_[i].v_handle = &value;
type_codes_[i] = kBytes;
}
void operator()(size_t i, PackedFunc& value) const { // NOLINT(*) void operator()(size_t i, PackedFunc& value) const { // NOLINT(*)
values_[i].v_handle = &value; values_[i].v_handle = &value;
type_codes_[i] = kFuncHandle; type_codes_[i] = kFuncHandle;
......
...@@ -37,6 +37,9 @@ USE_OPENCL = 0 ...@@ -37,6 +37,9 @@ USE_OPENCL = 0
# whether enable Metal during compile # whether enable Metal during compile
USE_METAL = 0 USE_METAL = 0
# Whether enable RPC during compile
USE_RPC = 0
# whether build with LLVM support # whether build with LLVM support
# Requires LLVM version >= 4.0 # Requires LLVM version >= 4.0
# Set LLVM_CONFIG to your version, uncomment to build with llvm support # Set LLVM_CONFIG to your version, uncomment to build with llvm support
......
...@@ -10,7 +10,7 @@ from numbers import Number, Integral ...@@ -10,7 +10,7 @@ from numbers import Number, Integral
from ..base import _LIB, check_call from ..base import _LIB, check_call
from ..base import c_str, string_types from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMByteArray from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode from .types import TVMValue, TypeCode
...@@ -107,6 +107,9 @@ def _make_tvm_args(args, temp_args): ...@@ -107,6 +107,9 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, TVMType): elif isinstance(arg, TVMType):
values[i].v_str = c_str(str(arg)) values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR type_codes[i] = TypeCode.STR
elif isinstance(arg, TVMContext):
values[i].v_ctx = arg
type_codes[i] = TypeCode.TVM_CONTEXT
elif isinstance(arg, bytearray): elif isinstance(arg, bytearray):
arr = TVMByteArray() arr = TVMByteArray()
arr.data = ctypes.cast( arr.data = ctypes.cast(
......
...@@ -13,13 +13,15 @@ class TypeCode(object): ...@@ -13,13 +13,15 @@ class TypeCode(object):
FLOAT = 2 FLOAT = 2
HANDLE = 3 HANDLE = 3
NULL = 4 NULL = 4
ARRAY_HANDLE = 5 TVM_TYPE = 5
TVM_TYPE = 6 TVM_CONTEXT = 6
NODE_HANDLE = 7 ARRAY_HANDLE = 7
MODULE_HANDLE = 8 NODE_HANDLE = 8
FUNC_HANDLE = 9 MODULE_HANDLE = 9
STR = 10 FUNC_HANDLE = 10
BYTES = 11 STR = 11
BYTES = 12
class TVMValue(ctypes.Union): class TVMValue(ctypes.Union):
"""TVMValue in C API""" """TVMValue in C API"""
......
...@@ -10,13 +10,14 @@ cdef enum TVMTypeCode: ...@@ -10,13 +10,14 @@ cdef enum TVMTypeCode:
kFloat = 2 kFloat = 2
kHandle = 3 kHandle = 3
kNull = 4 kNull = 4
kArrayHandle = 5 kTVMType = 5
kTVMType = 6 kTVMContext = 6
kNodeHandle = 7 kArrayHandle = 7
kModuleHandle = 8 kNodeHandle = 8
kFuncHandle = 9 kModuleHandle = 9
kStr = 10 kFuncHandle = 10
kBytes = 11 kStr = 11
kBytes = 12
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "tvm/runtime/c_runtime_api.h":
ctypedef struct DLDataType: ctypedef struct DLDataType:
...@@ -43,6 +44,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ...@@ -43,6 +44,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
void* v_handle void* v_handle
const char* v_str const char* v_str
DLDataType v_type DLDataType v_type
DLContext v_ctx
ctypedef int64_t tvm_index_t ctypedef int64_t tvm_index_t
ctypedef void* DLTensorHandle ctypedef void* DLTensorHandle
......
...@@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF ...@@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral from numbers import Number, Integral
from ..base import string_types from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMByteArray from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
print("TVM: Initializing cython mode...") print("TVM: Initializing cython mode...")
...@@ -110,6 +110,10 @@ cdef inline void make_arg(object arg, ...@@ -110,6 +110,10 @@ cdef inline void make_arg(object arg,
value[0].v_str = tstr value[0].v_str = tstr
tcode[0] = kStr tcode[0] = kStr
temp_args.append(tstr) temp_args.append(tstr)
elif isinstance(arg, TVMContext):
value[0].v_ctx = (<DLContext*>(
<unsigned long long>ctypes.addressof(arg)))[0]
tcode[0] = kTVMContext
elif isinstance(arg, bytearray): elif isinstance(arg, bytearray):
arr = TVMByteArray() arr = TVMByteArray()
arr.data = ctypes.cast( arr.data = ctypes.cast(
...@@ -170,6 +174,8 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -170,6 +174,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
return make_ret_bytes(value.v_handle) return make_ret_bytes(value.v_handle)
elif tcode == kHandle: elif tcode == kHandle:
return ctypes_handle(value.v_handle) return ctypes_handle(value.v_handle)
elif tcode == kTVMContext:
return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
elif tcode == kModuleHandle: elif tcode == kModuleHandle:
return _CLASS_MODULE(ctypes_handle(value.v_handle)) return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kFuncHandle: elif tcode == kFuncHandle:
......
...@@ -61,6 +61,9 @@ class ModuleBase(object): ...@@ -61,6 +61,9 @@ class ModuleBase(object):
self.handle = handle self.handle = handle
self._entry = None self._entry = None
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
@property @property
def entry_func(self): def entry_func(self):
"""Get the entry function """Get the entry function
......
...@@ -64,6 +64,7 @@ class TVMType(ctypes.Structure): ...@@ -64,6 +64,7 @@ class TVMType(ctypes.Structure):
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
RPC_SESS_MASK = 128
class TVMContext(ctypes.Structure): class TVMContext(ctypes.Structure):
"""TVM context strucure.""" """TVM context strucure."""
...@@ -121,6 +122,11 @@ class TVMContext(ctypes.Structure): ...@@ -121,6 +122,11 @@ class TVMContext(ctypes.Structure):
return not self.__eq__(other) return not self.__eq__(other)
def __repr__(self): def __repr__(self):
if self.device_type >= RPC_SESS_MASK:
tbl_id = self.device_type / RPC_SESS_MASK - 1
dev_type = self.device_type % RPC_SESS_MASK
return "remote[%d]:%s(%d)" % (
tbl_id, TVMContext.MASK2STR[dev_type], self.device_id)
return "%s(%d)" % ( return "%s(%d)" % (
TVMContext.MASK2STR[self.device_type], self.device_id) TVMContext.MASK2STR[self.device_type], self.device_id)
......
"""RPC interface for easy testing.
RPC enables connect to a remote server, upload and launch functions.
This is useful to for cross-compile and remote testing,
The compiler stack runs on local server, while we use RPC server
to run on remote runtime which don't have a compiler available.
The test program compiles the program on local server,
upload and run remote RPC server, get the result back to verify correctness.
"""
from __future__ import absolute_import
import os
import socket
import struct
import logging
import multiprocessing
from . import util
from ..module import load as _load_module
from .._ffi.function import _init_api, register_func
from .._ffi.ndarray import context as _context
RPC_MAGIC = 0xff271
RPC_SESS_MASK = 128
def _serve_loop(sock, addr):
"""Server loop"""
sockfd = sock.fileno()
temp = util.tempdir()
# pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.upload")
def upload(file_name, blob):
"""Upload the blob to remote temp file"""
path = temp.relpath(file_name)
with open(path, "wb") as out_file:
out_file.write(blob)
@register_func("tvm.contrib.rpc.server.download")
def download(file_name):
"""Download file from remote"""
path = temp.relpath(file_name)
dat = bytearray(open(path, "rb").read())
return dat
@register_func("tvm.contrib.rpc.server.load_module")
def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
m = _load_module(path)
return m
_ServerLoop(sockfd)
temp.remove()
logging.info("Finish serving %s", addr)
def _recvall(sock, nbytes):
res = []
nread = 0
while nread < nbytes:
chunk = sock.recv(min(nbytes - nread, 1024))
nread += len(chunk)
res.append(chunk)
return b''.join(res)
def _listen_loop(sock):
"""Lisenting loop"""
while True:
conn, addr = sock.accept()
logging.info("RPCServer: connection from %s", addr)
conn.sendall(struct.pack('@i', RPC_MAGIC))
magic = struct.unpack('@i', _recvall(conn, 4))[0]
if magic != RPC_MAGIC:
conn.close()
continue
logging.info("Connection from %s", addr)
process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
process.deamon = True
process.start()
# close from our side.
conn.close()
class Server(object):
"""Start RPC server on a seperate process.
This is a simple python implementation based on multi-processing.
It is also possible to implement a similar C based sever with
TVM runtime which does not depend on the python.
Parameter
---------
host : str
The host url of the server.
port : int
The port to be bind to
port_end : int, optional
The end port to search
"""
def __init__(self, host, port=9091, port_end=9199):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
for port in range(port, port_end):
try:
sock.bind((host, port))
self.port = port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
else:
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCServer: bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.host = host
self.proc = multiprocessing.Process(target=_listen_loop, args=(self.sock,))
self.proc.start()
def terminate(self):
"""Terminate the server process"""
if self.proc:
self.proc.terminate()
self.proc = None
def __del__(self):
self.terminate()
class RPCSession(object):
"""RPC Client session module
Do not directly create the obhect, call connect
"""
# pylint: disable=invalid-name
def __init__(self, sess):
self._sess = sess
self._tbl_index = _SessTableIndex(sess)
self._upload_func = None
self._download_func = None
def get_function(self, name):
"""Get function from the session.
Parameters
----------
name : str
The name of the function
Returns
-------
f : Function
The result function.
"""
return self._sess.get_function(name)
def context(self, dev_type, dev_id=0):
"""Construct a remote context.
Parameters
----------
dev_type: int or str
dev_id: int, optional
Returns
-------
ctx: TVMContext
The corresponding encoded remote context.
"""
ctx = _context(dev_type, dev_id)
encode = (self._tbl_index + 1) * RPC_SESS_MASK
ctx.device_type += encode
return ctx
def cpu(self, dev_id=0):
"""Construct remote CPU device."""
return self.context(1, dev_id)
def gpu(self, dev_id=0):
"""Construct remote GPU device."""
return self.context(2, dev_id)
def cl(self, dev_id=0):
"""Construct remote OpenCL device."""
return self.context(4, dev_id)
def metal(self, dev_id=0):
"""Construct remote Metal device."""
return self.context(8, dev_id)
def upload(self, data, target=None):
"""Upload file to remote runtime temp folder
Parameters
----------
data : str or bytearray
The file name or binary in local to upload.
target : str, optional
The path in remote
"""
if isinstance(data, bytearray):
if not target:
raise ValueError("target must present when file is a bytearray")
blob = data
else:
blob = bytearray(open(data, "rb").read())
if not target:
target = os.path.basename(data)
if not self._upload_func:
self._upload_func = self.get_function(
"tvm.contrib.rpc.server.upload")
self._upload_func(target, blob)
def download(self, path):
"""Download file from remote temp folder.
Parameters
----------
path : str
The relative location to remote temp folder.
Returns
-------
blob : bytearray
The result blob from the file.
"""
if not self._download_func:
self._download_func = self.get_function(
"tvm.contrib.rpc.server.download")
return self._download_func(path)
def load_module(self, path):
"""Load a remote module, the file need to be uploaded first.
Parameters
----------
path : str
The relative location to remote temp folder.
Returns
-------
m : Module
The remote module containing remote function.
"""
return _LoadRemoteModule(self._sess, path)
def connect(url, port):
"""Connect to RPC Server
Parameters
----------
url : str
The url of the host
port : int
The port to connect to
Returns
-------
sess : RPCSession
The connected session.
"""
sess = _Connect(url, port)
return RPCSession(sess)
_init_api("tvm.contrib.rpc")
...@@ -12,8 +12,14 @@ class TempDirectory(object): ...@@ -12,8 +12,14 @@ class TempDirectory(object):
def __init__(self): def __init__(self):
self.temp_dir = tempfile.mkdtemp() self.temp_dir = tempfile.mkdtemp()
def __del__(self): def remove(self):
"""Remote the tmp dir"""
if self.temp_dir:
shutil.rmtree(self.temp_dir) shutil.rmtree(self.temp_dir)
self.temp_dir = None
def __del__(self):
self.remove()
def relpath(self, name): def relpath(self, name):
"""Relative path in temp dir """Relative path in temp dir
......
...@@ -32,6 +32,11 @@ def cpu(dev_id=0): ...@@ -32,6 +32,11 @@ def cpu(dev_id=0):
---------- ----------
dev_id : int, optional dev_id : int, optional
The integer device id The integer device id
Returns
-------
ctx : TVMContext
The created context
""" """
return TVMContext(1, dev_id) return TVMContext(1, dev_id)
...@@ -43,6 +48,11 @@ def gpu(dev_id=0): ...@@ -43,6 +48,11 @@ def gpu(dev_id=0):
---------- ----------
dev_id : int, optional dev_id : int, optional
The integer device id The integer device id
Returns
-------
ctx : TVMContext
The created context
""" """
return TVMContext(2, dev_id) return TVMContext(2, dev_id)
...@@ -54,6 +64,11 @@ def opencl(dev_id=0): ...@@ -54,6 +64,11 @@ def opencl(dev_id=0):
---------- ----------
dev_id : int, optional dev_id : int, optional
The integer device id The integer device id
Returns
-------
ctx : TVMContext
The created context
""" """
return TVMContext(4, dev_id) return TVMContext(4, dev_id)
...@@ -65,6 +80,11 @@ def metal(dev_id=0): ...@@ -65,6 +80,11 @@ def metal(dev_id=0):
---------- ----------
dev_id : int, optional dev_id : int, optional
The integer device id The integer device id
Returns
-------
ctx : TVMContext
The created context
""" """
return TVMContext(8, dev_id) return TVMContext(8, dev_id)
...@@ -76,6 +96,11 @@ def vpi(dev_id=0): ...@@ -76,6 +96,11 @@ def vpi(dev_id=0):
---------- ----------
dev_id : int, optional dev_id : int, optional
The integer device id The integer device id
Returns
-------
ctx : TVMContext
The created context
""" """
return TVMContext(9, dev_id) return TVMContext(9, dev_id)
......
...@@ -44,8 +44,8 @@ class VPIDeviceAPI final : public runtime::DeviceAPI { ...@@ -44,8 +44,8 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
if (ptr + size >= ram_max_) return nullptr; if (ptr + size >= ram_max_) return nullptr;
return (char*)(&ram_[0]) + ptr; // NOLINT(*) return (char*)(&ram_[0]) + ptr; // NOLINT(*)
} }
void SetDevice(int dev_id) final {} void SetDevice(TVMContext ctx) final {}
void GetAttr(int dev_id, runtime::DeviceAttrKind kind, TVMRetValue* rv) final { void GetAttr(TVMContext ctx, runtime::DeviceAttrKind kind, TVMRetValue* rv) final {
if (kind == runtime::kExist) { if (kind == runtime::kExist) {
*rv = 1; *rv = 1;
} }
......
...@@ -34,6 +34,7 @@ class DeviceAPIManager { ...@@ -34,6 +34,7 @@ class DeviceAPIManager {
private: private:
std::array<DeviceAPI*, kMaxDeviceAPI> api_; std::array<DeviceAPI*, kMaxDeviceAPI> api_;
DeviceAPI* rpc_api_{nullptr};
std::mutex mutex_; std::mutex mutex_;
// constructor // constructor
DeviceAPIManager() { DeviceAPIManager() {
...@@ -45,25 +46,38 @@ class DeviceAPIManager { ...@@ -45,25 +46,38 @@ class DeviceAPIManager {
return &inst; return &inst;
} }
// Get or initialize API. // Get or initialize API.
DeviceAPI* GetAPI(int type, bool allow_missing); DeviceAPI* GetAPI(int type, bool allow_missing) {
}; if (type < kRPCSessMask) {
DeviceAPI* DeviceAPIManager::GetAPI(int type, bool allow_missing) {
if (api_[type] != nullptr) return api_[type]; if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type]; if (api_[type] != nullptr) return api_[type];
std::string factory = "device_api." + DeviceName(type); api_[type] = GetAPI(DeviceName(type), allow_missing);
return api_[type];
} else {
if (rpc_api_ != nullptr) return rpc_api_;
std::lock_guard<std::mutex> lock(mutex_);
if (rpc_api_ != nullptr) return rpc_api_;
rpc_api_ = GetAPI("rpc", allow_missing);
return rpc_api_;
}
}
DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
std::string factory = "device_api." + name;
auto* f = Registry::Get(factory); auto* f = Registry::Get(factory);
if (f == nullptr) { if (f == nullptr) {
CHECK(allow_missing) CHECK(allow_missing)
<< "Device API " << DeviceName(type) << " is not enabled."; << "Device API " << name << " is not enabled.";
return nullptr; return nullptr;
} }
void* ptr = (*f)(); void* ptr = (*f)();
api_[type] = static_cast<DeviceAPI*>(ptr); return static_cast<DeviceAPI*>(ptr);
return api_[type]; }
} };
DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
return DeviceAPIManager::Get(
static_cast<int>(ctx.device_type), allow_missing);
}
inline TVMArray* TVMArrayCreate_() { inline TVMArray* TVMArrayCreate_() {
TVMArray* arr = new TVMArray(); TVMArray* arr = new TVMArray();
...@@ -293,7 +307,11 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -293,7 +307,11 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
[func, resource_handle](TVMArgs args, TVMRetValue* rv) { [func, resource_handle](TVMArgs args, TVMRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle); args.num_args, rv, resource_handle);
CHECK_EQ(ret, 0) << "TVMCall CFunc Error:\n" << TVMGetLastError(); if (ret != 0) {
std::ostringstream os;
os << "TVMCall CFunc Error:\n" << TVMGetLastError();
throw dmlc::Error(os.str());
}
}); });
} else { } else {
// wrap it in a shared_ptr, with fin as deleter. // wrap it in a shared_ptr, with fin as deleter.
...@@ -303,7 +321,11 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -303,7 +321,11 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
[func, rpack](TVMArgs args, TVMRetValue* rv) { [func, rpack](TVMArgs args, TVMRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get()); args.num_args, rv, rpack.get());
CHECK_EQ(ret, 0) << "TVMCall CFunc Error:\n" << TVMGetLastError(); if (ret != 0) {
std::ostringstream os;
os << "TVMCall CFunc Error:\n" << TVMGetLastError();
throw dmlc::Error(os.str());
}
}); });
} }
API_END(); API_END();
...@@ -375,25 +397,28 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) { ...@@ -375,25 +397,28 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
// set device api // set device api
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
int dev_type = args[0]; TVMContext ctx;
int dev_id = args[1]; ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
DeviceAPIManager::Get(dev_type)->SetDevice(dev_id); ctx.device_id = args[1];
DeviceAPIManager::Get(ctx)->SetDevice(ctx);
}); });
// set device api // set device api
TVM_REGISTER_GLOBAL("_GetDeviceAttr") TVM_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
int dev_type = args[0]; TVMContext ctx;
int dev_id = args[1]; ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int()); DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
if (kind == kExist) { if (kind == kExist) {
DeviceAPI* api = DeviceAPIManager::Get(dev_type, true); DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
if (api != nullptr) { if (api != nullptr) {
api->GetAttr(dev_id, kind, ret); api->GetAttr(ctx, kind, ret);
} else { } else {
*ret = 0; *ret = 0;
} }
} else { } else {
DeviceAPIManager::Get(dev_type)->GetAttr(dev_id, kind, ret); DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
} }
}); });
...@@ -13,8 +13,8 @@ namespace runtime { ...@@ -13,8 +13,8 @@ namespace runtime {
class CPUDeviceAPI final : public DeviceAPI { class CPUDeviceAPI final : public DeviceAPI {
public: public:
void SetDevice(int dev_id) final {} void SetDevice(TVMContext ctx) final {}
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final { void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
if (kind == kExist) { if (kind == kExist) {
*rv = 1; *rv = 1;
} }
......
...@@ -17,26 +17,26 @@ namespace runtime { ...@@ -17,26 +17,26 @@ namespace runtime {
class CUDADeviceAPI final : public DeviceAPI { class CUDADeviceAPI final : public DeviceAPI {
public: public:
void SetDevice(int dev_id) final { void SetDevice(TVMContext ctx) final {
CUDA_CALL(cudaSetDevice(dev_id)); CUDA_CALL(cudaSetDevice(ctx.device_id));
} }
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final { void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
int value; int value;
switch (kind) { switch (kind) {
case kExist: case kExist:
value = ( value = (
cudaDeviceGetAttribute( cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, dev_id) &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
== cudaSuccess); == cudaSuccess);
break; break;
case kMaxThreadsPerBlock: { case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute( CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, dev_id)); &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
break; break;
} }
case kWarpSize: { case kWarpSize: {
CUDA_CALL(cudaDeviceGetAttribute( CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrWarpSize, dev_id)); &value, cudaDevAttrWarpSize, ctx.device_id));
break; break;
} }
} }
......
...@@ -24,18 +24,18 @@ class DeviceAPI { ...@@ -24,18 +24,18 @@ class DeviceAPI {
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~DeviceAPI() {} virtual ~DeviceAPI() {}
/*! /*!
* \brief Set the environment device id to dev_id * \brief Set the environment device id to ctx
* \param dev_id The device id. * \param ctx The context to be set.
* \return The allocated device pointer * \return The allocated device pointer
*/ */
virtual void SetDevice(int dev_id) = 0; virtual void SetDevice(TVMContext ctx) = 0;
/*! /*!
* \brief Get attribute of specified device. * \brief Get attribute of specified device.
* \param dev_id The device id * \param ctx The device context
* \param kind The result kind * \param kind The result kind
* \param rv The return value. * \param rv The return value.
*/ */
virtual void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) = 0; virtual void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) = 0;
/*! /*!
* \brief Allocate a data space on device. * \brief Allocate a data space on device.
* \param ctx The device context to perform operation. * \param ctx The device context to perform operation.
...@@ -77,8 +77,18 @@ class DeviceAPI { ...@@ -77,8 +77,18 @@ class DeviceAPI {
* \param stream The stream to be sync. * \param stream The stream to be sync.
*/ */
virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0; virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0;
/*!
* \brief Get device API base don context.
* \param ctx The context
* \param allow_missing Whether allow missing
* \return The corresponding device API.
*/
static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false);
}; };
/*! \brief The device type bigger than this is RPC device */
constexpr int kRPCSessMask = 128;
/*! /*!
* \brief The name of Device API factory. * \brief The name of Device API factory.
* \param type The device type. * \param type The device type.
......
...@@ -26,7 +26,7 @@ class DSOModuleNode final : public ModuleNode { ...@@ -26,7 +26,7 @@ class DSOModuleNode final : public ModuleNode {
if (lib_handle_) Unload(); if (lib_handle_) Unload();
} }
const char* type_key() const { const char* type_key() const final {
return "dso"; return "dso";
} }
......
...@@ -60,8 +60,8 @@ class MetalWorkspace final : public DeviceAPI { ...@@ -60,8 +60,8 @@ class MetalWorkspace final : public DeviceAPI {
// Return false if already initialized, otherwise return true. // Return false if already initialized, otherwise return true.
void Init(); void Init();
// override device API // override device API
void SetDevice(int dev_id) final; void SetDevice(TVMContext ctx) final;
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final; void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final; void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from, void CopyDataFromTo(const void* from,
......
...@@ -18,9 +18,9 @@ MetalWorkspace* MetalWorkspace::Global() { ...@@ -18,9 +18,9 @@ MetalWorkspace* MetalWorkspace::Global() {
} }
void MetalWorkspace::GetAttr( void MetalWorkspace::GetAttr(
int dev_id, DeviceAttrKind kind, TVMRetValue* rv) { TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init(); this->Init();
size_t index = static_cast<size_t>(dev_id); size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) { if (kind == kExist) {
*rv = int(index< devices.size()); *rv = int(index< devices.size());
return; return;
...@@ -30,7 +30,7 @@ void MetalWorkspace::GetAttr( ...@@ -30,7 +30,7 @@ void MetalWorkspace::GetAttr(
switch (kind) { switch (kind) {
case kMaxThreadsPerBlock: { case kMaxThreadsPerBlock: {
*rv = static_cast<int>( *rv = static_cast<int>(
[devices[dev_id] maxThreadsPerThreadgroup].width); [devices[ctx.device_id] maxThreadsPerThreadgroup].width);
break; break;
} }
case kWarpSize: { case kWarpSize: {
...@@ -69,7 +69,7 @@ int GetWarpSize(id<MTLDevice> dev) { ...@@ -69,7 +69,7 @@ int GetWarpSize(id<MTLDevice> dev) {
[NSString stringWithUTF8String:kDummyKernel] [NSString stringWithUTF8String:kDummyKernel]
options:nil options:nil
error:&error_msg]; error:&error_msg];
CHECK(lib != nil) << error_msg; CHECK(lib != nil) << [[error_msg localizedDescription] UTF8String];
id<MTLFunction> f = id<MTLFunction> f =
[lib [lib
newFunctionWithName: newFunctionWithName:
...@@ -79,7 +79,7 @@ int GetWarpSize(id<MTLDevice> dev) { ...@@ -79,7 +79,7 @@ int GetWarpSize(id<MTLDevice> dev) {
[dev [dev
newComputePipelineStateWithFunction:f newComputePipelineStateWithFunction:f
error:&error_msg]; error:&error_msg];
CHECK(state != nil) << error_msg; CHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
return state.threadExecutionWidth; return state.threadExecutionWidth;
} }
...@@ -109,8 +109,8 @@ void MetalWorkspace::Init() { ...@@ -109,8 +109,8 @@ void MetalWorkspace::Init() {
} }
} }
void MetalWorkspace::SetDevice(int dev_id) { void MetalWorkspace::SetDevice(TVMContext ctx) {
MetalThreadEntry::ThreadLocal()->context.device_id = dev_id; MetalThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
} }
void* MetalWorkspace::AllocDataSpace( void* MetalWorkspace::AllocDataSpace(
......
...@@ -97,6 +97,8 @@ bool RuntimeEnabled(const std::string& target) { ...@@ -97,6 +97,8 @@ bool RuntimeEnabled(const std::string& target) {
f_name = "codegen.build_stackvm"; f_name = "codegen.build_stackvm";
} else if (target == "llvm") { } else if (target == "llvm") {
f_name = "codegen.build_llvm"; f_name = "codegen.build_llvm";
} else if (target == "rpc") {
f_name = "device_api.rpc";
} else if (target == "vpi" || target == "verilog") { } else if (target == "vpi" || target == "verilog") {
f_name = "device_api.vpi"; f_name = "device_api.vpi";
} else { } else {
......
...@@ -139,8 +139,8 @@ class OpenCLWorkspace final : public DeviceAPI { ...@@ -139,8 +139,8 @@ class OpenCLWorkspace final : public DeviceAPI {
return queues[ctx.device_id]; return queues[ctx.device_id];
} }
// override device API // override device API
void SetDevice(int dev_id) final; void SetDevice(TVMContext ctx) final;
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final; void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final; void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from, void CopyDataFromTo(const void* from,
......
...@@ -18,14 +18,14 @@ OpenCLWorkspace* OpenCLWorkspace::Global() { ...@@ -18,14 +18,14 @@ OpenCLWorkspace* OpenCLWorkspace::Global() {
return &inst; return &inst;
} }
void OpenCLWorkspace::SetDevice(int dev_id) { void OpenCLWorkspace::SetDevice(TVMContext ctx) {
OpenCLThreadEntry::ThreadLocal()->context.device_id = dev_id; OpenCLThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
} }
void OpenCLWorkspace::GetAttr( void OpenCLWorkspace::GetAttr(
int dev_id, DeviceAttrKind kind, TVMRetValue* rv) { TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init(); this->Init();
size_t index = static_cast<size_t>(dev_id); size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) { if (kind == kExist) {
*rv = static_cast<int>(index< devices.size()); *rv = static_cast<int>(index< devices.size());
return; return;
......
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_device_api.cc
*/
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include "./rpc_session.h"
#include "../device_api.h"
namespace tvm {
namespace runtime {
class RPCDeviceAPI final : public DeviceAPI {
public:
void SetDevice(TVMContext ctx) final {
GetSess(ctx)->CallRemote(
RPCCode::kDevSetDevice, ctx);
}
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
*rv = GetSess(ctx)->CallRemote(
RPCCode::kDevGetAttr, ctx, static_cast<int>(kind));
}
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
auto sess = GetSess(ctx);
void *data = sess->CallRemote(
RPCCode::kDevAllocData, ctx, size, alignment);
RemoteSpace* space = new RemoteSpace();
space->data = data;
space->sess = std::move(sess);
return space;
}
void FreeDataSpace(TVMContext ctx, void* ptr) final {
RemoteSpace* space = static_cast<RemoteSpace*>(ptr);
GetSess(ctx)->CallRemote(
RPCCode::kDevFreeData, ctx, space->data);
delete space;
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) final {
int from_dev_type = ctx_from.device_type;
int to_dev_type = ctx_to.device_type;
if (from_dev_type > kRPCSessMask &&
to_dev_type > kRPCSessMask) {
CHECK(ctx_from.device_type == ctx_to.device_type)
<< "Cannot copy across two different remote session";
GetSess(ctx_from)->CallRemote(
RPCCode::kCopyAmongRemote,
static_cast<const RemoteSpace*>(from)->data, from_offset,
static_cast<const RemoteSpace*>(to)->data, to_offset,
size, ctx_from, ctx_to, stream);
} else if (from_dev_type > kRPCSessMask &&
to_dev_type == kCPU) {
GetSess(ctx_from)->CopyFromRemote(
static_cast<const RemoteSpace*>(from)->data, from_offset,
to, to_offset, size,
ctx_from);
} else if (from_dev_type == kCPU &&
to_dev_type > kRPCSessMask) {
GetSess(ctx_to)->CopyToRemote(
(void*)from, from_offset, // NOLINT(*)
static_cast<const RemoteSpace*>(to)->data, to_offset,
size, ctx_to);
} else {
LOG(FATAL) << "expect copy from/to remote or between remote";
}
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
GetSess(ctx)->CallRemote(
RPCCode::kDevStreamSync, ctx, stream);
}
private:
std::shared_ptr<RPCSession> GetSess(TVMContext ctx) {
int dev_type = ctx.device_type;
CHECK_GE(dev_type, kRPCSessMask);
int tbl_index = dev_type / kRPCSessMask - 1;
return RPCSession::Get(tbl_index);
}
};
TVM_REGISTER_GLOBAL("device_api.rpc")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCDeviceAPI inst;
DeviceAPI* ptr = &inst;
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_device_api.cc
* \brief RPC module.
*/
#include <tvm/runtime/registry.h>
#include <memory>
#include "./rpc_session.h"
namespace tvm {
namespace runtime {
const int kRPCMagic = 0xff271;
// Wrapped remote function to packed func.
struct RPCWrappedFunc {
public:
RPCWrappedFunc(void* handle, std::shared_ptr<RPCSession> sess)
: handle_(handle), sess_(sess) {}
void operator()(TVMArgs args, TVMRetValue *rv) const {
sess_->CallFunc(handle_, args, rv);
}
~RPCWrappedFunc() {
sess_->CallRemote(RPCCode::kFreeFunc, handle_);
}
private:
void* handle_{nullptr};
std::shared_ptr<RPCSession> sess_;
};
// RPC that represents a remote module session.
class RPCModuleNode final : public ModuleNode {
public:
RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess)
: module_handle_(module_handle), sess_(sess) {
}
~RPCModuleNode() {
if (module_handle_ != nullptr) {
sess_->CallRemote(RPCCode::kModuleFree, module_handle_);
}
}
const char* type_key() const final {
return "rpc";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
RPCFuncHandle handle = nullptr;
if (module_handle_ == nullptr) {
handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
} else {
handle = sess_->CallRemote(
RPCCode::kModuleGetFunc, module_handle_, name);
}
if (handle == nullptr) return PackedFunc();
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
return wf->operator()(args, rv);
});
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "RPCModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "RPCModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final {
if (module_handle_ != nullptr) {
std::string ret = sess_->CallRemote(
RPCCode::kModuleGetSource, module_handle_, format);
}
return "";
}
std::shared_ptr<RPCSession>& sess() {
return sess_;
}
private:
// The module handle
void* module_handle_{nullptr};
// The local channel
std::shared_ptr<RPCSession> sess_;
};
Module RPCConnect(std::string url, int port) {
common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port);
sock.Create();
CHECK(sock.Connect(addr))
<< "Connect to " << addr.AsString() << " failed";
// hand shake
int code = kRPCMagic;
CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
if (code != kRPCMagic) {
sock.Close();
LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
}
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(nullptr, RPCSession::Create(sock));
return Module(n);
}
void RPCServerLoop(int sockfd) {
common::TCPSocket sock(
static_cast<common::TCPSocket::SockType>(sockfd));
RPCSession::Create(sock)->ServerLoop();
}
TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCConnect(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
std::string tkey = m->type_key();
CHECK_EQ(tkey, "rpc");
auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess();
void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]);
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(mhandle, sess);
*rv = Module(n);
});
TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
std::string tkey = m->type_key();
CHECK_EQ(tkey, "rpc");
*rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
});
TVM_REGISTER_GLOBAL("contrib.rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerLoop(args[0]);
});
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_session.h
* \brief Base RPC session interface.
*/
#ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_
#define TVM_RUNTIME_RPC_RPC_SESSION_H_
#include <tvm/runtime/packed_func.h>
#include <mutex>
#include <string>
#include "../device_api.h"
#include "../../common/socket.h"
namespace tvm {
namespace runtime {
/*! \brief The remote functio handle */
using RPCFuncHandle = void*;
struct RPCArgBuffer;
/*! \brief The RPC code */
enum class RPCCode : int {
kCallFunc,
kReturn,
kException,
kShutdown,
kCopyFromRemote,
kCopyToRemote,
kCopyAck,
// The following are code that can send over CallRemote
kGetGlobalFunc,
kFreeFunc,
kDevSetDevice,
kDevGetAttr,
kDevAllocData,
kDevFreeData,
kDevStreamSync,
kCopyAmongRemote,
kModuleLoad,
kModuleFree,
kModuleGetFunc,
kModuleGetSource
};
// Bidirectional Communication Session of PackedRPC
class RPCSession {
public:
/*! \brief virtual destructor */
~RPCSession();
/*!
* \brief The server loop that server runs to handle RPC calls.
*/
void ServerLoop();
/*!
* \brief Call into remote function
* \param handle The function handle
* \param args The arguments
* \param rv The return value.
*/
void CallFunc(RPCFuncHandle handle,
TVMArgs args,
TVMRetValue* rv);
/*!
* \brief Copy bytes into remote array content.
* \param from The source host data.
* \param from_offset The byte offeset in the from.
* \param to The target array.
* \param to_offset The byte offset in the to.
* \param size The size of the memory.
* \param ctx_to The target context.
*/
void CopyToRemote(void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_to);
/*!
* \brief Copy bytes from remote array content.
* \param from The source host data.
* \param from_offset The byte offeset in the from.
* \param to The target array.
* \param to_offset The byte offset in the to.
* \param size The size of the memory.
* \param ctx_from The source context.
*/
void CopyFromRemote(void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from);
/*!
* \brief Call a remote defined system function with arguments.
* \param fcode The function code.
* \param args The arguments
* \return The returned remote value.
*/
template<typename... Args>
inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args);
/*!
* \return The session table index of the session.
*/
int table_index() const {
return table_index_;
}
/*!
* \brief Create a RPC session with given socket
* \param sock The socket.
* \return The session.
*/
static std::shared_ptr<RPCSession> Create(common::TCPSocket sock);
/*!
* \brief Try get session from the global session table by table index.
* \param table_index The table index of the session.
* \return The shared_ptr to the session, can be nullptr.
*/
static std::shared_ptr<RPCSession> Get(int table_index);
private:
/*!
* \brief Handle the remote call with f
* \param f The handle function
* \tparam F the handler function.
*/
template<typename F>
void CallHandler(F f);
void Init();
void Shutdown();
void SendReturnValue(int succ, TVMValue value, int tcode);
void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n);
void RecvPackedSeq(RPCArgBuffer *buf);
RPCCode HandleNextEvent(TVMRetValue *rv);
// special handler.
void HandleCallFunc();
void HandleException();
void HandleCopyFromRemote();
void HandleCopyToRemote();
void HandleReturn(TVMRetValue* rv);
TVMContext StripSessMask(TVMContext ctx);
// Internal mutex
std::recursive_mutex mutex_;
// Internal socket
common::TCPSocket sock_;
// Internal temporal data space.
std::string temp_data_;
// call remote with the specified function coede.
PackedFunc call_remote_;
// The index of this session in RPC session table.
int table_index_{0};
};
// Remote space pointer.
struct RemoteSpace {
void* data;
std::shared_ptr<RPCSession> sess;
};
// implementation of inline functions
template<typename... Args>
inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
CHECK_EQ(sock_.SendAll(&code, sizeof(code)), sizeof(code));
return call_remote_(std::forward<Args>(args)...);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RPC_RPC_SESSION_H_
...@@ -52,7 +52,6 @@ def test_convert(): ...@@ -52,7 +52,6 @@ def test_convert():
f = tvm.convert(myfunc) f = tvm.convert(myfunc)
assert isinstance(f, tvm.Function) assert isinstance(f, tvm.Function)
f(*targs)
def test_byte_array(): def test_byte_array():
s = "hello" s = "hello"
...@@ -63,9 +62,10 @@ def test_byte_array(): ...@@ -63,9 +62,10 @@ def test_byte_array():
f = tvm.convert(myfunc) f = tvm.convert(myfunc)
f(a) f(a)
if __name__ == "__main__": if __name__ == "__main__":
test_get_global()
test_get_callback_with_node() test_get_callback_with_node()
test_convert() test_convert()
test_get_global()
test_return_func() test_return_func()
test_byte_array() test_byte_array()
import tvm
import logging
import numpy as np
import time
from tvm.contrib import rpc, util
def test_rpc_simple():
if not tvm.module.enabled("rpc"):
return
@tvm.register_func("rpc.test.addone")
def addone(x):
return x + 1
@tvm.register_func("rpc.test.strcat")
def addone(name, x):
return "%s:%d" % (name, x)
server = rpc.Server("localhost")
client = rpc.connect(server.host, server.port)
f1 = client.get_function("rpc.test.addone")
assert f1(10) == 11
f2 = client.get_function("rpc.test.strcat")
assert f2("abc", 11) == "abc:11"
def test_rpc_array():
if not tvm.module.enabled("rpc"):
return
x = np.random.randint(0, 10, size=(3, 4))
@tvm.register_func("rpc.test.remote_array_func")
def remote_array_func(y):
np.testing.assert_equal(y.asnumpy(), x)
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
print("second connect")
r_cpu = tvm.nd.array(x, remote.cpu(0))
assert str(r_cpu.context).startswith("remote")
np.testing.assert_equal(r_cpu.asnumpy(), x)
fremote = remote.get_function("rpc.test.remote_array_func")
fremote(r_cpu)
def test_rpc_file_exchange():
if not tvm.module.enabled("rpc"):
return
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
blob = bytearray(np.random.randint(0, 10, size=(127)))
remote.upload(blob, "dat.bin")
rev = remote.download("dat.bin")
def test_rpc_remote_module():
if not tvm.module.enabled("rpc"):
return
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
def check_remote():
if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled")
return
temp = util.tempdir()
ctx = remote.cpu(0)
f = tvm.build(s, [A, B], "llvm", name="myadd")
path_dso = temp.relpath("dev_lib.so")
f.export_library(path_dso)
remote.upload(path_dso)
f1 = remote.load_module("dev_lib.so")
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f1(a, b)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
check_remote()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
test_rpc_array()
test_rpc_remote_module()
test_rpc_file_exchange()
test_rpc_simple()
...@@ -19,6 +19,7 @@ fi ...@@ -19,6 +19,7 @@ fi
cp make/config.mk config.mk cp make/config.mk config.mk
echo "USE_CUDA=0" >> config.mk echo "USE_CUDA=0" >> config.mk
echo "USE_RPC=1" >> config.mk
if [ ${TRAVIS_OS_NAME} == "osx" ]; then if [ ${TRAVIS_OS_NAME} == "osx" ]; then
echo "USE_OPENCL=1" >> config.mk echo "USE_OPENCL=1" >> config.mk
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment