Unverified Commit afcf9397 by jmorrill Committed by GitHub

Windows Support for cpp_rpc (#4857)

* Windows Support for cpp_rpc

* Add missing patches that fix crashes under Windows

* On Windows, use python to untar vs wsl

* remove some CMakeLists.txt stuff

* more minor CMakeLists.txt changes

* Remove items from CMakeLists.txt

* Minor CMakeLists.txt changes

* More minor CMakeLists.txt changes

* Even more minor CMakeLists.txt changes

* Modify readme
parent 9a8ed5b7
...@@ -66,9 +66,14 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF) ...@@ -66,9 +66,14 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
tvm_option(USE_CPP_RPC "Build CPP RPC" OFF)
tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF)
tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none)
if(USE_CPP_RPC AND UNIX)
message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. Use the Makefile for non-Windows.")
endif()
# include directories # include directories
include_directories(${CMAKE_INCLUDE_PATH}) include_directories(${CMAKE_INCLUDE_PATH})
include_directories("include") include_directories("include")
...@@ -309,6 +314,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) ...@@ -309,6 +314,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS})
add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
if(USE_CPP_RPC)
add_subdirectory("apps/cpp_rpc")
endif()
if(USE_RELAY_DEBUG) if(USE_RELAY_DEBUG)
message(STATUS "Building Relay in debug mode...") message(STATUS "Building Relay in debug mode...")
......
set(TVM_RPC_SOURCES
main.cc
rpc_env.cc
rpc_server.cc
)
if(WIN32)
list(APPEND TVM_RPC_SOURCES win32_process.cc)
endif()
# Set output to same directory as the other TVM libs
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
add_executable(tvm_rpc ${TVM_RPC_SOURCES})
set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE)
if(WIN32)
target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX)
endif()
target_include_directories(
tvm_rpc
PUBLIC "../../include"
PUBLIC DLPACK_PATH
PUBLIC DMLC_PATH
)
target_link_libraries(tvm_rpc tvm_runtime)
\ No newline at end of file
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
# TVM RPC Server # TVM RPC Server
This folder contains a simple recipe to make RPC server in c++. This folder contains a simple recipe to make RPC server in c++.
## Usage ## Usage (Non-Windows)
- Build tvm runtime - Build tvm runtime
- Make the rpc executable [Makefile](Makefile). - Make the rpc executable [Makefile](Makefile).
`make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux` `make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux`
...@@ -35,6 +35,12 @@ This folder contains a simple recipe to make RPC server in c++. ...@@ -35,6 +35,12 @@ This folder contains a simple recipe to make RPC server in c++.
``` ```
- Use `./tvm_rpc server` to start the RPC server - Use `./tvm_rpc server` to start the RPC server
## Usage (Windows)
- Build tvm with the argument -DUSE_CPP_RPC
- Install [LLVM pre-build binaries](https://releases.llvm.org/download.html), making sure to select the option to add it to the PATH.
- Verify Python 3.6 or newer is installed and in the PATH.
- Use `<tmv_output_dir>\tvm_rpc.exe` to start the RPC server
## How it works ## How it works
- The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library. - The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library.
...@@ -53,4 +59,4 @@ Command line usage ...@@ -53,4 +59,4 @@ Command line usage
``` ```
## Note ## Note
Currently support is only there for Linux / Android environment and proxy mode doesn't be supported currently. Currently support is only there for Linux / Android / Windows environment and proxy mode doesn't be supported currently.
\ No newline at end of file \ No newline at end of file
...@@ -21,10 +21,12 @@ ...@@ -21,10 +21,12 @@
* \file rpc_server.cc * \file rpc_server.cc
* \brief RPC Server for TVM. * \brief RPC Server for TVM.
*/ */
#include <stdlib.h> #include <cstdlib>
#include <signal.h> #include <csignal>
#include <stdio.h> #include <cstdio>
#if defined(__linux__) || defined(__ANDROID__)
#include <unistd.h> #include <unistd.h>
#endif
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
...@@ -35,11 +37,15 @@ ...@@ -35,11 +37,15 @@
#include "../../src/support/socket.h" #include "../../src/support/socket.h"
#include "rpc_server.h" #include "rpc_server.h"
#if defined(_WIN32)
#include "win32_process.h"
#endif
using namespace std; using namespace std;
using namespace tvm::runtime; using namespace tvm::runtime;
using namespace tvm::support; using namespace tvm::support;
static const string kUSAGE = \ static const string kUsage = \
"Command line usage\n" \ "Command line usage\n" \
" server - Start the server\n" \ " server - Start the server\n" \
"--host - The hostname of the server, Default=0.0.0.0\n" \ "--host - The hostname of the server, Default=0.0.0.0\n" \
...@@ -73,13 +79,16 @@ struct RpcServerArgs { ...@@ -73,13 +79,16 @@ struct RpcServerArgs {
string key; string key;
string custom_addr; string custom_addr;
bool silent = false; bool silent = false;
#if defined(WIN32)
std::string mmap_path;
#endif
}; };
/*! /*!
* \brief PrintArgs print the contents of RpcServerArgs * \brief PrintArgs print the contents of RpcServerArgs
* \param args RpcServerArgs structure * \param args RpcServerArgs structure
*/ */
void PrintArgs(struct RpcServerArgs args) { void PrintArgs(const RpcServerArgs& args) {
LOG(INFO) << "host = " << args.host; LOG(INFO) << "host = " << args.host;
LOG(INFO) << "port = " << args.port; LOG(INFO) << "port = " << args.port;
LOG(INFO) << "port_end = " << args.port_end; LOG(INFO) << "port_end = " << args.port_end;
...@@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) { ...@@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) {
LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False"));
} }
#if defined(__linux__) || defined(__ANDROID__)
/*! /*!
* \brief CtrlCHandler, exits if Ctrl+C is pressed * \brief CtrlCHandler, exits if Ctrl+C is pressed
* \param s signal * \param s signal
...@@ -109,7 +119,7 @@ void HandleCtrlC() { ...@@ -109,7 +119,7 @@ void HandleCtrlC() {
sigIntHandler.sa_flags = 0; sigIntHandler.sa_flags = 0;
sigaction(SIGINT, &sigIntHandler, nullptr); sigaction(SIGINT, &sigIntHandler, nullptr);
} }
#endif
/*! /*!
* \brief GetCmdOption Parse and find the command option. * \brief GetCmdOption Parse and find the command option.
* \param argc arg counter * \param argc arg counter
...@@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) { ...@@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) {
} }
// We assume "=" is the end of option. // We assume "=" is the end of option.
CHECK_EQ(*option.rbegin(), '='); CHECK_EQ(*option.rbegin(), '=');
cmd = arg.substr(arg.find("=") + 1); cmd = arg.substr(arg.find('=') + 1);
return cmd; return cmd;
} }
} }
...@@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) { ...@@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) {
* \brief ParseCmdArgs parses the command line arguments. * \brief ParseCmdArgs parses the command line arguments.
* \param argc arg counter * \param argc arg counter
* \param argv arg values * \param argv arg values
* \param args, the output structure which holds the parsed values * \param args the output structure which holds the parsed values
*/ */
void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
string silent = GetCmdOption(argc, argv, "--silent", true); const string silent = GetCmdOption(argc, argv, "--silent", true);
if (!silent.empty()) { if (!silent.empty()) {
args.silent = true; args.silent = true;
// Only errors and fatal is logged // Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2"); dmlc::InitLogging("--minloglevel=2");
} }
string host = GetCmdOption(argc, argv, "--host="); const string host = GetCmdOption(argc, argv, "--host=");
if (!host.empty()) { if (!host.empty()) {
if (!ValidateIP(host)) { if (!ValidateIP(host)) {
LOG(WARNING) << "Wrong host address format."; LOG(WARNING) << "Wrong host address format.";
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
exit(1); exit(1);
} }
args.host = host; args.host = host;
} }
string port = GetCmdOption(argc, argv, "--port="); const string port = GetCmdOption(argc, argv, "--port=");
if (!port.empty()) { if (!port.empty()) {
if (!IsNumber(port) || stoi(port) > 65535) { if (!IsNumber(port) || stoi(port) > 65535) {
LOG(WARNING) << "Wrong port number."; LOG(WARNING) << "Wrong port number.";
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
exit(1); exit(1);
} }
args.port = stoi(port); args.port = stoi(port);
} }
string port_end = GetCmdOption(argc, argv, "--port_end="); const string port_end = GetCmdOption(argc, argv, "--port_end=");
if (!port_end.empty()) { if (!port_end.empty()) {
if (!IsNumber(port_end) || stoi(port_end) > 65535) { if (!IsNumber(port_end) || stoi(port_end) > 65535) {
LOG(WARNING) << "Wrong port_end number."; LOG(WARNING) << "Wrong port_end number.";
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
exit(1); exit(1);
} }
args.port_end = stoi(port_end); args.port_end = stoi(port_end);
...@@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { ...@@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
if (!tracker.empty()) { if (!tracker.empty()) {
if (!ValidateTracker(tracker)) { if (!ValidateTracker(tracker)) {
LOG(WARNING) << "Wrong tracker address format."; LOG(WARNING) << "Wrong tracker address format.";
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
exit(1); exit(1);
} }
args.tracker = tracker; args.tracker = tracker;
} }
string key = GetCmdOption(argc, argv, "--key="); const string key = GetCmdOption(argc, argv, "--key=");
if (!key.empty()) { if (!key.empty()) {
args.key = key; args.key = key;
} }
string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); const string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
if (!custom_addr.empty()) { if (!custom_addr.empty()) {
if (!ValidateIP(custom_addr)) { if (!ValidateIP(custom_addr)) {
LOG(WARNING) << "Wrong custom address format."; LOG(WARNING) << "Wrong custom address format.";
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
exit(1); exit(1);
} }
args.custom_addr = custom_addr; args.custom_addr = custom_addr;
} }
#if defined(WIN32)
const string mmap_path = GetCmdOption(argc, argv, "--child_proc=");
if(!mmap_path.empty()) {
args.mmap_path = mmap_path;
dmlc::InitLogging("--minloglevel=0");
}
#endif
} }
/*! /*!
...@@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { ...@@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
* \return result of operation. * \return result of operation.
*/ */
int RpcServer(int argc, char * argv[]) { int RpcServer(int argc, char * argv[]) {
struct RpcServerArgs args; RpcServerArgs args;
/* parse the command line args */ /* parse the command line args */
ParseCmdArgs(argc, argv, args); ParseCmdArgs(argc, argv, args);
PrintArgs(args); PrintArgs(args);
// Ctrl+C handler
LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop."; LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop.";
#if defined(__linux__) || defined(__ANDROID__)
// Ctrl+C handler
HandleCtrlC(); HandleCtrlC();
tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker, #endif
args.key, args.custom_addr, args.silent);
#if defined(WIN32)
if(!args.mmap_path.empty()) {
int ret = 0;
try {
ChildProcSocketHandler(args.mmap_path);
} catch (const std::exception&) {
ret = -1;
}
return ret;
}
#endif
RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
args.key, args.custom_addr, args.silent);
return 0; return 0;
} }
...@@ -251,15 +286,21 @@ int RpcServer(int argc, char * argv[]) { ...@@ -251,15 +286,21 @@ int RpcServer(int argc, char * argv[]) {
*/ */
int main(int argc, char * argv[]) { int main(int argc, char * argv[]) {
if (argc <= 1) { if (argc <= 1) {
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
return 0; return 0;
} }
// Runs WSAStartup on Win32, no-op on POSIX
Socket::Startup();
#if defined(_WIN32)
SetEnvironmentVariableA("CUDA_CACHE_DISABLE", "1");
#endif
if (0 == strcmp(argv[1], "server")) { if (0 == strcmp(argv[1], "server")) {
RpcServer(argc, argv); return RpcServer(argc, argv);
} else {
LOG(INFO) << kUSAGE;
} }
LOG(INFO) << kUsage;
return 0; return 0;
} }
...@@ -20,77 +20,86 @@ ...@@ -20,77 +20,86 @@
* \file rpc_env.cc * \file rpc_env.cc
* \brief Server environment of the RPC. * \brief Server environment of the RPC.
*/ */
#include <cerrno>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <errno.h> #ifndef _WIN32
#ifndef _MSC_VER
#include <sys/stat.h>
#include <dirent.h> #include <dirent.h>
#include <sys/stat.h>
#include <unistd.h> #include <unistd.h>
#else #else
#include <Windows.h> #include <Windows.h>
#include <direct.h>
namespace {
int mkdir(const char* path, int /* ignored */) { return _mkdir(path); }
}
#endif #endif
#include <cstring>
#include <fstream> #include <fstream>
#include <vector>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <cstring> #include <vector>
#include <string>
#include "rpc_env.h"
#include "../../src/support/util.h" #include "../../src/support/util.h"
#include "../../src/runtime/file_util.h" #include "../../src/runtime/file_util.h"
#include "rpc_env.h"
namespace {
std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) {
std::string untar_cmd;
untar_cmd.reserve(512);
#if defined(__linux__) || defined(__ANDROID__)
untar_cmd += "tar -C ";
untar_cmd += output_dir;
untar_cmd += " -zxf ";
untar_cmd += tar_file;
#elif defined(_WIN32)
untar_cmd += "python -m tarfile -e ";
untar_cmd += tar_file;
untar_cmd += " ";
untar_cmd += output_dir;
#endif
return untar_cmd;
}
}// Anonymous namespace
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
RPCEnv::RPCEnv() { RPCEnv::RPCEnv() {
#if defined(__linux__) || defined(__ANDROID__) base_ = "./rpc";
base_ = "./rpc"; mkdir(base_.c_str(), 0777);
mkdir(&base_[0], 0777); TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env;
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") *rv = env.GetPath(args[0]);
.set_body([](TVMArgs args, TVMRetValue* rv) { });
static RPCEnv env;
*rv = env.GetPath(args[0]);
});
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) {
.set_body([](TVMArgs args, TVMRetValue *rv) { static RPCEnv env;
static RPCEnv env; std::string file_name = env.GetPath(args[0]);
std::string file_name = env.GetPath(args[0]); *rv = Load(&file_name, "");
*rv = Load(&file_name, ""); LOG(INFO) << "Load module from " << file_name << " ...";
LOG(INFO) << "Load module from " << file_name << " ..."; });
});
#else
LOG(FATAL) << "Only support RPC in linux environment";
#endif
} }
/*! /*!
* \brief GetPath To get the workpath from packed function * \brief GetPath To get the work path from packed function
* \param name The file name * \param file_name The file name
* \return The full path of file. * \return The full path of file.
*/ */
std::string RPCEnv::GetPath(std::string file_name) { std::string RPCEnv::GetPath(const std::string& file_name) const {
// we assume file_name has "/" means file_name is the exact path // we assume file_name has "/" means file_name is the exact path
// and does not create /.rpc/ // and does not create /.rpc/
if (file_name.find("/") != std::string::npos) { return file_name.find('/') != std::string::npos ? file_name : base_ + "/" + file_name;
return file_name;
} else {
return base_ + "/" + file_name;
}
} }
/*! /*!
* \brief Remove The RPC Environment cleanup function * \brief Remove The RPC Environment cleanup function
*/ */
void RPCEnv::CleanUp() { void RPCEnv::CleanUp() const {
#if defined(__linux__) || defined(__ANDROID__) CleanDir(base_);
CleanDir(&base_[0]); const int ret = rmdir(base_.c_str());
int ret = rmdir(&base_[0]); if (ret != 0) {
if (ret != 0) { LOG(WARNING) << "Remove directory " << base_ << " failed";
LOG(WARNING) << "Remove directory " << base_ << " failed"; }
}
#else
LOG(FATAL) << "Only support RPC in linux environment";
#endif
} }
/*! /*!
...@@ -98,53 +107,54 @@ void RPCEnv::CleanUp() { ...@@ -98,53 +107,54 @@ void RPCEnv::CleanUp() {
* \param dirname The root directory name * \param dirname The root directory name
* \return vector Files in directory. * \return vector Files in directory.
*/ */
std::vector<std::string> ListDir(const std::string &dirname) { std::vector<std::string> ListDir(const std::string& dirname) {
std::vector<std::string> vec; std::vector<std::string> vec;
#ifndef _MSC_VER #ifndef _WIN32
DIR *dp = opendir(dirname.c_str()); DIR* dp = opendir(dirname.c_str());
if (dp == nullptr) { if (dp == nullptr) {
int errsv = errno; int errsv = errno;
LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv); LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv);
} }
dirent *d; dirent* d;
while ((d = readdir(dp)) != nullptr) { while ((d = readdir(dp)) != nullptr) {
std::string filename = d->d_name; std::string filename = d->d_name;
if (filename != "." && filename != "..") { if (filename != "." && filename != "..") {
std::string f = dirname; std::string f = dirname;
if (f[f.length() - 1] != '/') { if (f[f.length() - 1] != '/') {
f += '/'; f += '/';
}
f += d->d_name;
vec.push_back(f);
} }
f += d->d_name;
vec.push_back(f);
} }
closedir(dp); }
#else closedir(dp);
WIN32_FIND_DATA fd; #elif defined(_WIN32)
std::string pattern = dirname + "/*"; WIN32_FIND_DATAA fd;
HANDLE handle = FindFirstFile(pattern.c_str(), &fd); const std::string pattern = dirname + "/*";
if (handle == INVALID_HANDLE_VALUE) { HANDLE handle = FindFirstFileA(pattern.c_str(), &fd);
int errsv = GetLastError(); if (handle == INVALID_HANDLE_VALUE) {
LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); const int errsv = GetLastError();
} LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv);
do { }
if (fd.cFileName != "." && fd.cFileName != "..") { do {
std::string f = dirname; std::string filename = fd.cFileName;
char clast = f[f.length() - 1]; if (filename != "." && filename != "..") {
if (f == ".") { std::string f = dirname;
f = fd.cFileName; if (f[f.length() - 1] != '/') {
} else if (clast != '/' && clast != '\\') { f += '/';
f += '/';
f += fd.cFileName;
}
vec.push_back(f);
} }
} while (FindNextFile(handle, &fd)); f += filename;
FindClose(handle); vec.push_back(f);
#endif }
} while (FindNextFileA(handle, &fd));
FindClose(handle);
#else
LOG(FATAL) << "Operating system not supported";
#endif
return vec; return vec;
} }
#if defined(__linux__) || defined(__ANDROID__)
/*! /*!
* \brief LinuxShared Creates a linux shared library * \brief LinuxShared Creates a linux shared library
* \param output The output file name * \param output The output file name
...@@ -152,9 +162,9 @@ std::vector<std::string> ListDir(const std::string &dirname) { ...@@ -152,9 +162,9 @@ std::vector<std::string> ListDir(const std::string &dirname) {
* \param options The compiler options * \param options The compiler options
* \param cc The compiler * \param cc The compiler
*/ */
void LinuxShared(const std::string output, void LinuxShared(const std::string output,
const std::vector<std::string> &files, const std::vector<std::string> &files,
std::string options = "", std::string options = "",
std::string cc = "g++") { std::string cc = "g++") {
std::string cmd = cc; std::string cmd = cc;
cmd += " -shared -fPIC "; cmd += " -shared -fPIC ";
...@@ -169,18 +179,48 @@ void LinuxShared(const std::string output, ...@@ -169,18 +179,48 @@ void LinuxShared(const std::string output,
LOG(FATAL) << err_msg; LOG(FATAL) << err_msg;
} }
} }
#endif
#ifdef _WIN32
/*!
* \brief WindowsShared Creates a Windows shared library
* \param output The output file name
* \param files The files for building
* \param options The compiler options
* \param cc The compiler
*/
void WindowsShared(const std::string& output,
const std::vector<std::string>& files,
const std::string& options = "",
const std::string& cc = "clang") {
std::string cmd = cc;
cmd += " -O2 -flto=full -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared ";
cmd += " -o " + output;
for (const auto& file : files) {
cmd += " " + file;
}
cmd += " " + options;
std::string err_msg;
const auto executed_status = support::Execute(cmd, &err_msg);
if (executed_status) {
LOG(FATAL) << err_msg;
}
}
#endif
/*! /*!
* \brief CreateShared Creates a shared library * \brief CreateShared Creates a shared library
* \param output The output file name * \param output The output file name
* \param files The files for building * \param files The files for building
*/ */
void CreateShared(const std::string output, const std::vector<std::string> &files) { void CreateShared(const std::string& output, const std::vector<std::string>& files) {
#if defined(__linux__) || defined(__ANDROID__) #if defined(__linux__) || defined(__ANDROID__)
LinuxShared(output, files); LinuxShared(output, files);
#else #elif defined(_WIN32)
LOG(FATAL) << "Do not support creating shared library"; WindowsShared(output, files);
#endif #else
LOG(FATAL) << "Operating system not supported";
#endif
} }
/*! /*!
...@@ -193,61 +233,52 @@ void CreateShared(const std::string output, const std::vector<std::string> &file ...@@ -193,61 +233,52 @@ void CreateShared(const std::string output, const std::vector<std::string> &file
* \param fmt The format of file * \param fmt The format of file
* \return Module The loaded module * \return Module The loaded module
*/ */
Module Load(std::string *fileIn, const std::string fmt) { Module Load(std::string *fileIn, const std::string& fmt) {
std::string file = *fileIn; const std::string& file = *fileIn;
if (support::EndsWith(file, ".so")) { if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll")) {
return Module::LoadFromFile(file, fmt); return Module::LoadFromFile(file, fmt);
} }
#if defined(__linux__) || defined(__ANDROID__) std::string file_name = file + ".so";
std::string file_name = file + ".so"; if (support::EndsWith(file, ".o")) {
if (support::EndsWith(file, ".o")) { std::vector<std::string> files;
std::vector<std::string> files; files.push_back(file);
files.push_back(file); CreateShared(file_name, files);
CreateShared(file_name, files); } else if (support::EndsWith(file, ".tar")) {
} else if (support::EndsWith(file, ".tar")) { const std::string tmp_dir = "./rpc/tmp/";
std::string tmp_dir = "./rpc/tmp/"; mkdir(tmp_dir.c_str(), 0777);
mkdir(&tmp_dir[0], 0777);
std::string cmd = "tar -C " + tmp_dir + " -zxf " + file; const std::string cmd = GenerateUntarCommand(file, tmp_dir);
std::string err_msg;
int executed_status = support::Execute(cmd, &err_msg); std::string err_msg;
if (executed_status) { const int executed_status = support::Execute(cmd, &err_msg);
LOG(FATAL) << err_msg; if (executed_status) {
} LOG(FATAL) << err_msg;
CreateShared(file_name, ListDir(tmp_dir));
CleanDir(tmp_dir);
rmdir(&tmp_dir[0]);
} else {
file_name = file;
} }
*fileIn = file_name; CreateShared(file_name, ListDir(tmp_dir));
return Module::LoadFromFile(file_name, fmt); CleanDir(tmp_dir);
#else (void)rmdir(tmp_dir.c_str());
LOG(FATAL) << "Do not support creating shared library"; } else {
#endif file_name = file;
}
*fileIn = file_name;
return Module::LoadFromFile(file_name, fmt);
} }
/*! /*!
* \brief CleanDir Removes the files from the directory * \brief CleanDir Removes the files from the directory
* \param dirname The name of the directory * \param dirname The name of the directory
*/ */
void CleanDir(const std::string &dirname) { void CleanDir(const std::string& dirname) {
#if defined(__linux__) || defined(__ANDROID__) auto files = ListDir(dirname);
DIR *dp = opendir(dirname.c_str()); for (const auto& filename : files) {
dirent *d; std::string file_path = dirname + "/";
while ((d = readdir(dp)) != nullptr) { file_path += filename;
std::string filename = d->d_name; const int ret = std::remove(filename.c_str());
if (filename != "." && filename != "..") { if (ret != 0) {
filename = dirname + "/" + d->d_name; LOG(WARNING) << "Remove file " << filename << " failed";
int ret = std::remove(&filename[0]);
if (ret != 0) {
LOG(WARNING) << "Remove file " << filename << " failed";
}
}
} }
#else }
LOG(FATAL) << "Only support RPC in linux environment";
#endif
} }
} // namespace runtime } // namespace runtime
......
...@@ -40,7 +40,7 @@ namespace runtime { ...@@ -40,7 +40,7 @@ namespace runtime {
* \param file The format of file * \param file The format of file
* \return Module The loaded module * \return Module The loaded module
*/ */
Module Load(std::string *path, const std::string fmt = ""); Module Load(std::string *path, const std::string& fmt = "");
/*! /*!
* \brief CleanDir Removes the files from the directory * \brief CleanDir Removes the files from the directory
...@@ -62,11 +62,11 @@ struct RPCEnv { ...@@ -62,11 +62,11 @@ struct RPCEnv {
* \param name The file name * \param name The file name
* \return The full path of file. * \return The full path of file.
*/ */
std::string GetPath(std::string file_name); std::string GetPath(const std::string& file_name) const;
/*! /*!
* \brief The RPC Environment cleanup function * \brief The RPC Environment cleanup function
*/ */
void CleanUp(); void CleanUp() const;
private: private:
/*! /*!
......
...@@ -22,24 +22,27 @@ ...@@ -22,24 +22,27 @@
* \brief RPC Server implementation. * \brief RPC Server implementation.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#if defined(__linux__) || defined(__ANDROID__) #if defined(__linux__) || defined(__ANDROID__)
#include <sys/select.h> #include <sys/select.h>
#include <sys/wait.h> #include <sys/wait.h>
#endif #endif
#include <set>
#include <iostream>
#include <future>
#include <thread>
#include <chrono> #include <chrono>
#include <future>
#include <iostream>
#include <set>
#include <string> #include <string>
#include "rpc_server.h" #include "../../src/support/socket.h"
#include "rpc_env.h"
#include "rpc_tracker_client.h"
#include "../../src/runtime/rpc/rpc_session.h" #include "../../src/runtime/rpc/rpc_session.h"
#include "../../src/runtime/rpc/rpc_socket_impl.h" #include "../../src/runtime/rpc/rpc_socket_impl.h"
#include "../../src/support/socket.h" #include "rpc_env.h"
#include "rpc_server.h"
#include "rpc_tracker_client.h"
#if defined(_WIN32)
#include "win32_process.h"
#endif
using namespace std::chrono;
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -49,7 +52,7 @@ namespace runtime { ...@@ -49,7 +52,7 @@ namespace runtime {
* \param status status value * \param status status value
*/ */
#if defined(__linux__) || defined(__ANDROID__) #if defined(__linux__) || defined(__ANDROID__)
static pid_t waitPidEintr(int *status) { static pid_t waitPidEintr(int* status) {
pid_t pid = 0; pid_t pid = 0;
while ((pid = waitpid(-1, status, 0)) == -1) { while ((pid = waitpid(-1, status, 0)) == -1) {
if (errno == EINTR) { if (errno == EINTR) {
...@@ -76,34 +79,32 @@ class RPCServer { ...@@ -76,34 +79,32 @@ class RPCServer {
public: public:
/*! /*!
* \brief Constructor. * \brief Constructor.
*/ */
RPCServer(const std::string &host, RPCServer(std::string host, int port, int port_end, std::string tracker_addr,
int port, std::string key, std::string custom_addr) :
int port_end, host_(std::move(host)), port_(port), my_port_(0), port_end_(port_end),
const std::string &tracker_addr, tracker_addr_(std::move(tracker_addr)), key_(std::move(key)),
const std::string &key, custom_addr_(std::move(custom_addr))
const std::string &custom_addr) { {
// Init the values
host_ = host;
port_ = port;
port_end_ = port_end;
tracker_addr_ = tracker_addr;
key_ = key;
custom_addr_ = custom_addr;
} }
/*! /*!
* \brief Destructor. * \brief Destructor.
*/ */
~RPCServer() { ~RPCServer() {
// Free the resources try {
tracker_sock_.Close(); // Free the resources
listen_sock_.Close(); tracker_sock_.Close();
listen_sock_.Close();
} catch(...) {
}
} }
/*! /*!
* \brief Start Creates the RPC listen process and execution. * \brief Start Creates the RPC listen process and execution.
*/ */
void Start() { void Start() {
listen_sock_.Create(); listen_sock_.Create();
my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_); my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_);
...@@ -130,102 +131,98 @@ class RPCServer { ...@@ -130,102 +131,98 @@ class RPCServer {
tracker.TryConnect(); tracker.TryConnect();
// step 2: wait for in-coming connections // step 2: wait for in-coming connections
AcceptConnection(&tracker, &conn, &addr, &opts); AcceptConnection(&tracker, &conn, &addr, &opts);
} } catch (const char* msg) {
catch (const char* msg) {
LOG(WARNING) << "Socket exception: " << msg; LOG(WARNING) << "Socket exception: " << msg;
// close tracker resource // close tracker resource
tracker.Close(); tracker.Close();
continue; continue;
} } catch (const std::exception& e) {
catch (std::exception& e) { // close tracker resource
// Other errors tracker.Close();
LOG(WARNING) << "Exception standard: " << e.what(); LOG(WARNING) << "Exception standard: " << e.what();
continue; continue;
} }
int timeout = GetTimeOutFromOpts(opts); int timeout = GetTimeOutFromOpts(opts);
#if defined(__linux__) || defined(__ANDROID__) #if defined(__linux__) || defined(__ANDROID__)
// step 3: serving // step 3: serving
if (timeout != 0) { if (timeout != 0) {
const pid_t timer_pid = fork(); const pid_t timer_pid = fork();
if (timer_pid == 0) { if (timer_pid == 0) {
// Timer process // Timer process
sleep(timeout); sleep(timeout);
exit(0); exit(0);
} }
const pid_t worker_pid = fork(); const pid_t worker_pid = fork();
if (worker_pid == 0) { if (worker_pid == 0) {
// Worker process // Worker process
ServerLoopProc(conn, addr); ServerLoopProc(conn, addr);
exit(0); exit(0);
} }
int status = 0; int status = 0;
const pid_t finished_first = waitPidEintr(&status); const pid_t finished_first = waitPidEintr(&status);
if (finished_first == timer_pid) { if (finished_first == timer_pid) {
kill(worker_pid, SIGKILL); kill(worker_pid, SIGKILL);
} else if (finished_first == worker_pid) { } else if (finished_first == worker_pid) {
kill(timer_pid, SIGKILL); kill(timer_pid, SIGKILL);
} else { } else {
LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue.";
} }
int status_second = 0; int status_second = 0;
waitPidEintr(&status_second); waitPidEintr(&status_second);
// Logging. // Logging.
if (finished_first == timer_pid) { if (finished_first == timer_pid) {
LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout
<< "), Process status = " << status_second; << "), Process status = " << status_second;
} else if (finished_first == worker_pid) { } else if (finished_first == worker_pid) {
LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second;
}
} else {
auto pid = fork();
if (pid == 0) {
ServerLoopProc(conn, addr);
exit(0);
}
// Wait for the result
int status = 0;
wait(&status);
LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status;
} }
#else } else {
// step 3: serving auto pid = fork();
std::future<void> proc(std::async(std::launch::async, if (pid == 0) {
&RPCServer::ServerLoopProc, this, conn, addr)); ServerLoopProc(conn, addr);
// wait until server process finish or timeout exit(0);
if (timeout != 0) {
// Autoterminate after timeout
proc.wait_for(std::chrono::seconds(timeout));
} else {
// Wait for the result
proc.get();
} }
#endif // Wait for the result
int status = 0;
wait(&status);
LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status;
}
#elif defined(WIN32)
auto start_time = high_resolution_clock::now();
try {
SpawnRPCChild(conn.sockfd, seconds(timeout));
} catch (const std::exception&) {
}
auto dur = high_resolution_clock::now() - start_time;
LOG(INFO) << "Serve Time " << duration_cast<milliseconds>(dur).count() << "ms";
#endif
// close from our side. // close from our side.
LOG(INFO) << "Socket Connection Closed"; LOG(INFO) << "Socket Connection Closed";
conn.Close(); conn.Close();
} }
} }
/*! /*!
* \brief AcceptConnection Accepts the RPC Server connection. * \brief AcceptConnection Accepts the RPC Server connection.
* \param tracker Tracker details. * \param tracker Tracker details.
* \param conn New connection information. * \param conn_sock New connection information.
* \param addr New connection address information. * \param addr New connection address information.
* \param opts Parsed options for socket * \param opts Parsed options for socket
* \param ping_period Timeout for select call waiting * \param ping_period Timeout for select call waiting
*/ */
void AcceptConnection(TrackerClient* tracker, void AcceptConnection(TrackerClient* tracker,
support::TCPSocket* conn_sock, support::TCPSocket* conn_sock,
support::SockAddr* addr, support::SockAddr* addr,
std::string* opts, std::string* opts,
int ping_period = 2) { int ping_period = 2) {
std::set <std::string> old_keyset; std::set<std::string> old_keyset;
std::string matchkey; std::string matchkey;
// Report resource to tracker and get key // Report resource to tracker and get key
...@@ -236,7 +233,7 @@ class RPCServer { ...@@ -236,7 +233,7 @@ class RPCServer {
support::TCPSocket conn = listen_sock_.Accept(addr); support::TCPSocket conn = listen_sock_.Accept(addr);
int code = kRPCMagic; int code = kRPCMagic;
CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code));
if (code != kRPCMagic) { if (code != kRPCMagic) {
conn.Close(); conn.Close();
LOG(FATAL) << "Client connected is not TVM RPC server"; LOG(FATAL) << "Client connected is not TVM RPC server";
...@@ -265,15 +262,15 @@ class RPCServer { ...@@ -265,15 +262,15 @@ class RPCServer {
std::string arg0; std::string arg0;
ssin >> arg0; ssin >> arg0;
if (arg0 != expect_header) { if (arg0 != expect_header) {
code = kRPCMismatch; code = kRPCMismatch;
CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
conn.Close(); conn.Close();
LOG(WARNING) << "Mismatch key from" << addr->AsString(); LOG(WARNING) << "Mismatch key from" << addr->AsString();
continue; continue;
} else { } else {
code = kRPCSuccess; code = kRPCSuccess;
CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
keylen = server_key.length(); keylen = int(server_key.length());
CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen);
LOG(INFO) << "Connection success " << addr->AsString(); LOG(INFO) << "Connection success " << addr->AsString();
...@@ -289,25 +286,23 @@ class RPCServer { ...@@ -289,25 +286,23 @@ class RPCServer {
* \param sock The socket information * \param sock The socket information
* \param addr The socket address information * \param addr The socket address information
*/ */
void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) { static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) {
// Server loop // Server loop
auto env = RPCEnv(); const auto env = RPCEnv();
RPCServerLoop(sock.sockfd); RPCServerLoop(int(sock.sockfd));
LOG(INFO) << "Finish serving " << addr.AsString(); LOG(INFO) << "Finish serving " << addr.AsString();
env.CleanUp(); env.CleanUp();
} }
/*! /*!
* \brief GetTimeOutFromOpts Parse and get the timeout option. * \brief GetTimeOutFromOpts Parse and get the timeout option.
* \param opts The option string * \param opts The option string
* \param timeout value after parsing.
*/ */
int GetTimeOutFromOpts(std::string opts) { int GetTimeOutFromOpts(const std::string& opts) const {
std::string cmd; const std::string option = "-timeout=";
std::string option = "-timeout=";
if (opts.find(option) == 0) { if (opts.find(option) == 0) {
cmd = opts.substr(opts.find_last_of(option) + 1); const std::string cmd = opts.substr(opts.find_last_of(option) + 1);
CHECK(support::IsNumber(cmd)) << "Timeout is not valid"; CHECK(support::IsNumber(cmd)) << "Timeout is not valid";
return std::stoi(cmd); return std::stoi(cmd);
} }
...@@ -325,29 +320,40 @@ class RPCServer { ...@@ -325,29 +320,40 @@ class RPCServer {
support::TCPSocket tracker_sock_; support::TCPSocket tracker_sock_;
}; };
#if defined(WIN32)
/*!
* \brief ServerLoopFromChild The Server loop process.
* \param socket The socket information
*/
void ServerLoopFromChild(SOCKET socket) {
// Server loop
tvm::support::TCPSocket sock(socket);
const auto env = RPCEnv();
RPCServerLoop(int(sock.sockfd));
sock.Close();
env.CleanUp();
}
#endif
/*! /*!
* \brief RPCServerCreate Creates the RPC Server. * \brief RPCServerCreate Creates the RPC Server.
* \param host The hostname of the server, Default=0.0.0.0 * \param host The hostname of the server, Default=0.0.0.0
* \param port The port of the RPC, Default=9090 * \param port The port of the RPC, Default=9090
* \param port_end The end search port of the RPC, Default=9199 * \param port_end The end search port of the RPC, Default=9199
* \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \param key The key used to identify the device type in tracker. Default="" * \param key The key used to identify the device type in tracker. Default=""
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param silent Whether run in silent mode. Default=True * \param silent Whether run in silent mode. Default=True
*/ */
void RPCServerCreate(std::string host, void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr,
int port, std::string key, std::string custom_addr, bool silent) {
int port_end,
std::string tracker_addr,
std::string key,
std::string custom_addr,
bool silent) {
if (silent) { if (silent) {
// Only errors and fatal is logged // Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2"); dmlc::InitLogging("--minloglevel=2");
} }
// Start the rpc server // Start the rpc server
RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr); RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), std::move(custom_addr));
rpc.Start(); rpc.Start();
} }
......
...@@ -30,6 +30,15 @@ ...@@ -30,6 +30,15 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
#if defined(WIN32)
/*!
* \brief ServerLoopFromChild The Server loop process.
* \param sock The socket information
* \param addr The socket address information
*/
void ServerLoopFromChild(SOCKET socket);
#endif
/*! /*!
* \brief RPCServerCreate Creates the RPC Server. * \brief RPCServerCreate Creates the RPC Server.
* \param host The hostname of the server, Default=0.0.0.0 * \param host The hostname of the server, Default=0.0.0.0
...@@ -40,13 +49,13 @@ namespace runtime { ...@@ -40,13 +49,13 @@ namespace runtime {
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param silent Whether run in silent mode. Default=True * \param silent Whether run in silent mode. Default=True
*/ */
TVM_DLL void RPCServerCreate(std::string host = "", void RPCServerCreate(std::string host = "",
int port = 9090, int port = 9090,
int port_end = 9099, int port_end = 9099,
std::string tracker_addr = "", std::string tracker_addr = "",
std::string key = "", std::string key = "",
std::string custom_addr = "", std::string custom_addr = "",
bool silent = true); bool silent = true);
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_APPS_CPP_RPC_SERVER_H_ #endif // TVM_APPS_CPP_RPC_SERVER_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#include <winsock2.h>
#include <ws2tcpip.h>
#include <cstdio>
#include <memory>
#include <conio.h>
#include <string>
#include <stdexcept>
#include <dmlc/logging.h>
#include "win32_process.h"
#include "rpc_server.h"
using namespace std::chrono;
using namespace tvm::runtime;
namespace {
// The prefix path for the memory mapped file used to store IPC information
const std::string kMemoryMapPrefix = "/MAPPED_FILE/TVM_RPC";
// Used to construct unique names for named resources in the parent process
const std::string kParent = "parent";
// Used to construct unique names for named resources in the child process
const std::string kChild = "child";
// The timeout of the WIN32 events, in the parent and the child
const milliseconds kEventTimeout(2000);
// Used to create unique WIN32 mmap paths and event names
int child_counter_ = 0;
/*!
* \brief HandleDeleter Deleter for UniqueHandle smart pointer
* \param handle The WIN32 HANDLE to manage
*/
struct HandleDeleter {
void operator()(HANDLE handle) const {
if (handle != INVALID_HANDLE_VALUE && handle != nullptr) {
CloseHandle(handle);
}
}
};
/*!
* \brief UniqueHandle Smart pointer to manage a WIN32 HANDLE
*/
using UniqueHandle = std::unique_ptr<void, HandleDeleter>;
/*!
* \brief MakeUniqueHandle Helper method to construct a UniqueHandle
* \param handle The WIN32 HANDLE to manage
*/
UniqueHandle MakeUniqueHandle(HANDLE handle) {
if (handle == INVALID_HANDLE_VALUE || handle == nullptr) {
return nullptr;
}
return UniqueHandle(handle);
}
/*!
* \brief GetSocket Gets the socket info from the parent process and duplicates the socket
* \param mmap_path The path to the memory mapped info set by the parent
*/
SOCKET GetSocket(const std::string& mmap_path) {
WSAPROTOCOL_INFO protocol_info;
const std::string parent_event_name = mmap_path + kParent;
const std::string child_event_name = mmap_path + kChild;
// Open the events
UniqueHandle parent_file_mapping_event;
if ((parent_file_mapping_event = MakeUniqueHandle(OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) {
LOG(FATAL) << "OpenEvent() failed: " << GetLastError();
}
UniqueHandle child_file_mapping_event;
if ((child_file_mapping_event = MakeUniqueHandle(OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) {
LOG(FATAL) << "OpenEvent() failed: " << GetLastError();
}
// Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read
if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) {
LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError();
}
const UniqueHandle file_map = MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE,
false,
mmap_path.c_str()));
if (!file_map) {
LOG(INFO) << "CreateFileMapping() failed: " << GetLastError();
}
void* map_view = MapViewOfFile(file_map.get(),
FILE_MAP_READ | FILE_MAP_WRITE,
0, 0, 0);
SOCKET sock_duplicated = INVALID_SOCKET;
if (map_view != nullptr) {
memcpy(&protocol_info, map_view, sizeof(WSAPROTOCOL_INFO));
UnmapViewOfFile(map_view);
// Creates the duplicate socket, that was created in the parent
sock_duplicated = WSASocket(FROM_PROTOCOL_INFO,
FROM_PROTOCOL_INFO,
FROM_PROTOCOL_INFO,
&protocol_info,
0,
0);
// Let the parent know we are finished dupicating the socket
SetEvent(child_file_mapping_event.get());
} else {
LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError();
}
return sock_duplicated;
}
}// Anonymous namespace
namespace tvm {
namespace runtime {
/*!
* \brief SpawnRPCChild Spawns a child process with a given timeout to run
* \param fd The client socket to duplicate in the child
* \param timeout The time in seconds to wait for the child to complete before termination
*/
void SpawnRPCChild(SOCKET fd, seconds timeout) {
STARTUPINFOA startup_info;
memset(&startup_info, 0, sizeof(startup_info));
startup_info.cb = sizeof(startup_info);
std::string file_map_path = kMemoryMapPrefix + std::to_string(child_counter_++);
const std::string parent_event_name = file_map_path + kParent;
const std::string child_event_name = file_map_path + kChild;
// Create an event to let the child know the socket info was set to the mmap file
UniqueHandle parent_file_mapping_event;
if ((parent_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) {
LOG(FATAL) << "CreateEvent for parent file mapping failed";
}
UniqueHandle child_file_mapping_event;
// An event to let the parent know the socket info was read from the mmap file
if ((child_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) {
LOG(FATAL) << "CreateEvent for child file mapping failed";
}
char current_executable[MAX_PATH];
// Get the full path of the current executable
GetModuleFileNameA(nullptr, current_executable, MAX_PATH);
std::string child_command_line = current_executable;
child_command_line += " server --child_proc=";
child_command_line += file_map_path;
// CreateProcessA requires a non const char*, so we copy our std::string
std::unique_ptr<char[]> command_line_ptr(new char[child_command_line.size() + 1]);
strcpy(command_line_ptr.get(), child_command_line.c_str());
PROCESS_INFORMATION child_process_info;
if (CreateProcessA(nullptr,
command_line_ptr.get(),
nullptr,
nullptr,
false,
CREATE_NO_WINDOW,
nullptr,
nullptr,
&startup_info,
&child_process_info)) {
// Child process and thread handles must be closed, so wrapped in RAII
auto child_process_handle = MakeUniqueHandle(child_process_info.hProcess);
auto child_process_thread_handle = MakeUniqueHandle(child_process_info.hThread);
WSAPROTOCOL_INFO protocol_info;
// Get info needed to duplicate the socket
if (WSADuplicateSocket(fd,
child_process_info.dwProcessId,
&protocol_info) == SOCKET_ERROR) {
LOG(FATAL) << "WSADuplicateSocket(): failed. Error =" << WSAGetLastError();
}
// Create a mmap file to store the info needed for duplicating the SOCKET in the child proc
UniqueHandle file_map = MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE,
nullptr,
PAGE_READWRITE,
0,
sizeof(WSAPROTOCOL_INFO),
file_map_path.c_str()));
if (!file_map) {
LOG(INFO) << "CreateFileMapping() failed: " << GetLastError();
}
if (GetLastError() == ERROR_ALREADY_EXISTS) {
LOG(FATAL) << "CreateFileMapping(): mapping file already exists";
} else {
void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0);
if (map_view != nullptr) {
memcpy(map_view, &protocol_info, sizeof(WSAPROTOCOL_INFO));
UnmapViewOfFile(map_view);
// Let child proc know the mmap file is ready to be read
SetEvent(parent_file_mapping_event.get());
// Wait for the child to finish reading mmap file
if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) {
TerminateProcess(child_process_handle.get(), 0);
LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child process.";
}
} else {
TerminateProcess(child_process_handle.get(), 0);
LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError();
}
}
const DWORD process_timeout = timeout.count()
? uint32_t(duration_cast<milliseconds>(timeout).count())
: INFINITE;
// Wait for child process to exit, or hit configured timeout
if (WaitForSingleObject(child_process_handle.get(), process_timeout) != WAIT_OBJECT_0) {
LOG(INFO) << "Child process timeout. Terminating.";
TerminateProcess(child_process_handle.get(), 0);
}
} else {
LOG(INFO) << "Create child process failed: " << GetLastError();
}
}
/*!
* \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket
* \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent
*/
void ChildProcSocketHandler(const std::string& mmap_path) {
SOCKET socket;
// Set high thread priority to avoid the thread scheduler from
// interfering with any measurements in the RPC server.
SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL);
if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) {
tvm::runtime::ServerLoopFromChild(socket);
}
else {
LOG(FATAL) << "GetSocket() failed";
}
}
} // namespace runtime
} // namespace tvm
\ No newline at end of file
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file win32_process.h
* \brief Win32 process code to mimic a POSIX fork()
*/
#ifndef TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
#define TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
#include <chrono>
#include <string>
namespace tvm {
namespace runtime {
/*!
* \brief SpawnRPCChild Spawns a child process with a given timeout to run
* \param fd The client socket to duplicate in the child
* \param timeout The time in seconds to wait for the child to complete before termination
*/
void SpawnRPCChild(SOCKET fd, std::chrono::seconds timeout);
/*!
* \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket
* \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent
*/
void ChildProcSocketHandler(const std::string& mmap_path);
} // namespace runtime
} // namespace tvm
#endif // TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
\ No newline at end of file
...@@ -34,8 +34,12 @@ class SockChannel final : public RPCChannel { ...@@ -34,8 +34,12 @@ class SockChannel final : public RPCChannel {
explicit SockChannel(support::TCPSocket sock) explicit SockChannel(support::TCPSocket sock)
: sock_(sock) {} : sock_(sock) {}
~SockChannel() { ~SockChannel() {
if (!sock_.BadSocket()) { try {
sock_.Close(); // BadSocket can throw
if (!sock_.BadSocket()) {
sock_.Close();
}
} catch (...) {
} }
} }
size_t Send(const void* data, size_t size) final { size_t Send(const void* data, size_t size) final {
...@@ -100,7 +104,8 @@ Module RPCClientConnect(std::string url, int port, std::string key) { ...@@ -100,7 +104,8 @@ Module RPCClientConnect(std::string url, int port, std::string key) {
return CreateRPCModule(RPCConnect(url, port, "client:" + key)); return CreateRPCModule(RPCConnect(url, port, "client:" + key));
} }
void RPCServerLoop(int sockfd) { // TVM_DLL needed for MSVC
TVM_DLL void RPCServerLoop(int sockfd) {
support::TCPSocket sock( support::TCPSocket sock(
static_cast<support::TCPSocket::SockType>(sockfd)); static_cast<support::TCPSocket::SockType>(sockfd));
RPCSession::Create( RPCSession::Create(
......
...@@ -63,7 +63,7 @@ class RingBuffer { ...@@ -63,7 +63,7 @@ class RingBuffer {
size_t ncopy = head_ptr_ + bytes_available_ - old_size; size_t ncopy = head_ptr_ + bytes_available_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy); memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
} }
} else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) { } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) {
// shrink too large temporary buffer to avoid out of memory on some embedded devices // shrink too large temporary buffer to avoid out of memory on some embedded devices
size_t old_bytes = bytes_available_; size_t old_bytes = bytes_available_;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment