Unverified Commit afcf9397 by jmorrill Committed by GitHub

Windows Support for cpp_rpc (#4857)

* Windows Support for cpp_rpc

* Add missing patches that fix crashes under Windows

* On Windows, use python to untar vs wsl

* remove some CMakeLists.txt stuff

* more minor CMakeLists.txt changes

* Remove items from CMakeLists.txt

* Minor CMakeLists.txt changes

* More minor CMakeLists.txt changes

* Even more minor CMakeLists.txt changes

* Modify readme
parent 9a8ed5b7
...@@ -66,9 +66,14 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF) ...@@ -66,9 +66,14 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
tvm_option(USE_CPP_RPC "Build CPP RPC" OFF)
tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF)
tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none)
if(USE_CPP_RPC AND UNIX)
message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. Use the Makefile for non-Windows.")
endif()
# include directories # include directories
include_directories(${CMAKE_INCLUDE_PATH}) include_directories(${CMAKE_INCLUDE_PATH})
include_directories("include") include_directories("include")
...@@ -309,6 +314,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) ...@@ -309,6 +314,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS})
add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
if(USE_CPP_RPC)
add_subdirectory("apps/cpp_rpc")
endif()
if(USE_RELAY_DEBUG) if(USE_RELAY_DEBUG)
message(STATUS "Building Relay in debug mode...") message(STATUS "Building Relay in debug mode...")
......
set(TVM_RPC_SOURCES
main.cc
rpc_env.cc
rpc_server.cc
)
if(WIN32)
list(APPEND TVM_RPC_SOURCES win32_process.cc)
endif()
# Set output to same directory as the other TVM libs
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
add_executable(tvm_rpc ${TVM_RPC_SOURCES})
set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE)
if(WIN32)
target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX)
endif()
target_include_directories(
tvm_rpc
PUBLIC "../../include"
PUBLIC DLPACK_PATH
PUBLIC DMLC_PATH
)
target_link_libraries(tvm_rpc tvm_runtime)
\ No newline at end of file
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
# TVM RPC Server # TVM RPC Server
This folder contains a simple recipe to make RPC server in c++. This folder contains a simple recipe to make RPC server in c++.
## Usage ## Usage (Non-Windows)
- Build tvm runtime - Build tvm runtime
- Make the rpc executable [Makefile](Makefile). - Make the rpc executable [Makefile](Makefile).
`make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux` `make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux`
...@@ -35,6 +35,12 @@ This folder contains a simple recipe to make RPC server in c++. ...@@ -35,6 +35,12 @@ This folder contains a simple recipe to make RPC server in c++.
``` ```
- Use `./tvm_rpc server` to start the RPC server - Use `./tvm_rpc server` to start the RPC server
## Usage (Windows)
- Build tvm with the argument -DUSE_CPP_RPC
- Install [LLVM pre-build binaries](https://releases.llvm.org/download.html), making sure to select the option to add it to the PATH.
- Verify Python 3.6 or newer is installed and in the PATH.
- Use `<tmv_output_dir>\tvm_rpc.exe` to start the RPC server
## How it works ## How it works
- The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library. - The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library.
...@@ -53,4 +59,4 @@ Command line usage ...@@ -53,4 +59,4 @@ Command line usage
``` ```
## Note ## Note
Currently support is only there for Linux / Android environment and proxy mode doesn't be supported currently. Currently support is only there for Linux / Android / Windows environment and proxy mode doesn't be supported currently.
\ No newline at end of file \ No newline at end of file
...@@ -21,10 +21,12 @@ ...@@ -21,10 +21,12 @@
* \file rpc_server.cc * \file rpc_server.cc
* \brief RPC Server for TVM. * \brief RPC Server for TVM.
*/ */
#include <stdlib.h> #include <cstdlib>
#include <signal.h> #include <csignal>
#include <stdio.h> #include <cstdio>
#if defined(__linux__) || defined(__ANDROID__)
#include <unistd.h> #include <unistd.h>
#endif
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
...@@ -35,11 +37,15 @@ ...@@ -35,11 +37,15 @@
#include "../../src/support/socket.h" #include "../../src/support/socket.h"
#include "rpc_server.h" #include "rpc_server.h"
#if defined(_WIN32)
#include "win32_process.h"
#endif
using namespace std; using namespace std;
using namespace tvm::runtime; using namespace tvm::runtime;
using namespace tvm::support; using namespace tvm::support;
static const string kUSAGE = \ static const string kUsage = \
"Command line usage\n" \ "Command line usage\n" \
" server - Start the server\n" \ " server - Start the server\n" \
"--host - The hostname of the server, Default=0.0.0.0\n" \ "--host - The hostname of the server, Default=0.0.0.0\n" \
...@@ -73,13 +79,16 @@ struct RpcServerArgs { ...@@ -73,13 +79,16 @@ struct RpcServerArgs {
string key; string key;
string custom_addr; string custom_addr;
bool silent = false; bool silent = false;
#if defined(WIN32)
std::string mmap_path;
#endif
}; };
/*! /*!
* \brief PrintArgs print the contents of RpcServerArgs * \brief PrintArgs print the contents of RpcServerArgs
* \param args RpcServerArgs structure * \param args RpcServerArgs structure
*/ */
void PrintArgs(struct RpcServerArgs args) { void PrintArgs(const RpcServerArgs& args) {
LOG(INFO) << "host = " << args.host; LOG(INFO) << "host = " << args.host;
LOG(INFO) << "port = " << args.port; LOG(INFO) << "port = " << args.port;
LOG(INFO) << "port_end = " << args.port_end; LOG(INFO) << "port_end = " << args.port_end;
...@@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) { ...@@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) {
LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False"));
} }
#if defined(__linux__) || defined(__ANDROID__)
/*! /*!
* \brief CtrlCHandler, exits if Ctrl+C is pressed * \brief CtrlCHandler, exits if Ctrl+C is pressed
* \param s signal * \param s signal
...@@ -109,7 +119,7 @@ void HandleCtrlC() { ...@@ -109,7 +119,7 @@ void HandleCtrlC() {
sigIntHandler.sa_flags = 0; sigIntHandler.sa_flags = 0;
sigaction(SIGINT, &sigIntHandler, nullptr); sigaction(SIGINT, &sigIntHandler, nullptr);
} }
#endif
/*! /*!
* \brief GetCmdOption Parse and find the command option. * \brief GetCmdOption Parse and find the command option.
* \param argc arg counter * \param argc arg counter
...@@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) { ...@@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) {
} }
// We assume "=" is the end of option. // We assume "=" is the end of option.
CHECK_EQ(*option.rbegin(), '='); CHECK_EQ(*option.rbegin(), '=');
cmd = arg.substr(arg.find("=") + 1); cmd = arg.substr(arg.find('=') + 1);
return cmd; return cmd;
} }
} }
...@@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) { ...@@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) {
* \brief ParseCmdArgs parses the command line arguments. * \brief ParseCmdArgs parses the command line arguments.
* \param argc arg counter * \param argc arg counter
* \param argv arg values * \param argv arg values
* \param args, the output structure which holds the parsed values * \param args the output structure which holds the parsed values
*/ */
void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
string silent = GetCmdOption(argc, argv, "--silent", true); const string silent = GetCmdOption(argc, argv, "--silent", true);
if (!silent.empty()) { if (!silent.empty()) {
args.silent = true; args.silent = true;
// Only errors and fatal is logged // Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2"); dmlc::InitLogging("--minloglevel=2");
} }
string host = GetCmdOption(argc, argv, "--host="); const string host = GetCmdOption(argc, argv, "--host=");
if (!host.empty()) { if (!host.empty()) {
if (!ValidateIP(host)) { if (!ValidateIP(host)) {
LOG(WARNING) << "Wrong host address format."; LOG(WARNING) << "Wrong host address format.";
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
exit(1); exit(1);
} }
args.host = host; args.host = host;
} }
string port = GetCmdOption(argc, argv, "--port="); const string port = GetCmdOption(argc, argv, "--port=");
if (!port.empty()) { if (!port.empty()) {
if (!IsNumber(port) || stoi(port) > 65535) { if (!IsNumber(port) || stoi(port) > 65535) {
LOG(WARNING) << "Wrong port number."; LOG(WARNING) << "Wrong port number.";
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
exit(1); exit(1);
} }
args.port = stoi(port); args.port = stoi(port);
} }
string port_end = GetCmdOption(argc, argv, "--port_end="); const string port_end = GetCmdOption(argc, argv, "--port_end=");
if (!port_end.empty()) { if (!port_end.empty()) {
if (!IsNumber(port_end) || stoi(port_end) > 65535) { if (!IsNumber(port_end) || stoi(port_end) > 65535) {
LOG(WARNING) << "Wrong port_end number."; LOG(WARNING) << "Wrong port_end number.";
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
exit(1); exit(1);
} }
args.port_end = stoi(port_end); args.port_end = stoi(port_end);
...@@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { ...@@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
if (!tracker.empty()) { if (!tracker.empty()) {
if (!ValidateTracker(tracker)) { if (!ValidateTracker(tracker)) {
LOG(WARNING) << "Wrong tracker address format."; LOG(WARNING) << "Wrong tracker address format.";
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
exit(1); exit(1);
} }
args.tracker = tracker; args.tracker = tracker;
} }
string key = GetCmdOption(argc, argv, "--key="); const string key = GetCmdOption(argc, argv, "--key=");
if (!key.empty()) { if (!key.empty()) {
args.key = key; args.key = key;
} }
string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); const string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
if (!custom_addr.empty()) { if (!custom_addr.empty()) {
if (!ValidateIP(custom_addr)) { if (!ValidateIP(custom_addr)) {
LOG(WARNING) << "Wrong custom address format."; LOG(WARNING) << "Wrong custom address format.";
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
exit(1); exit(1);
} }
args.custom_addr = custom_addr; args.custom_addr = custom_addr;
} }
#if defined(WIN32)
const string mmap_path = GetCmdOption(argc, argv, "--child_proc=");
if(!mmap_path.empty()) {
args.mmap_path = mmap_path;
dmlc::InitLogging("--minloglevel=0");
}
#endif
} }
/*! /*!
...@@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { ...@@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
* \return result of operation. * \return result of operation.
*/ */
int RpcServer(int argc, char * argv[]) { int RpcServer(int argc, char * argv[]) {
struct RpcServerArgs args; RpcServerArgs args;
/* parse the command line args */ /* parse the command line args */
ParseCmdArgs(argc, argv, args); ParseCmdArgs(argc, argv, args);
PrintArgs(args); PrintArgs(args);
// Ctrl+C handler
LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop."; LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop.";
#if defined(__linux__) || defined(__ANDROID__)
// Ctrl+C handler
HandleCtrlC(); HandleCtrlC();
tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker, #endif
args.key, args.custom_addr, args.silent);
#if defined(WIN32)
if(!args.mmap_path.empty()) {
int ret = 0;
try {
ChildProcSocketHandler(args.mmap_path);
} catch (const std::exception&) {
ret = -1;
}
return ret;
}
#endif
RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
args.key, args.custom_addr, args.silent);
return 0; return 0;
} }
...@@ -251,15 +286,21 @@ int RpcServer(int argc, char * argv[]) { ...@@ -251,15 +286,21 @@ int RpcServer(int argc, char * argv[]) {
*/ */
int main(int argc, char * argv[]) { int main(int argc, char * argv[]) {
if (argc <= 1) { if (argc <= 1) {
LOG(INFO) << kUSAGE; LOG(INFO) << kUsage;
return 0; return 0;
} }
// Runs WSAStartup on Win32, no-op on POSIX
Socket::Startup();
#if defined(_WIN32)
SetEnvironmentVariableA("CUDA_CACHE_DISABLE", "1");
#endif
if (0 == strcmp(argv[1], "server")) { if (0 == strcmp(argv[1], "server")) {
RpcServer(argc, argv); return RpcServer(argc, argv);
} else {
LOG(INFO) << kUSAGE;
} }
LOG(INFO) << kUsage;
return 0; return 0;
} }
...@@ -40,7 +40,7 @@ namespace runtime { ...@@ -40,7 +40,7 @@ namespace runtime {
* \param file The format of file * \param file The format of file
* \return Module The loaded module * \return Module The loaded module
*/ */
Module Load(std::string *path, const std::string fmt = ""); Module Load(std::string *path, const std::string& fmt = "");
/*! /*!
* \brief CleanDir Removes the files from the directory * \brief CleanDir Removes the files from the directory
...@@ -62,11 +62,11 @@ struct RPCEnv { ...@@ -62,11 +62,11 @@ struct RPCEnv {
* \param name The file name * \param name The file name
* \return The full path of file. * \return The full path of file.
*/ */
std::string GetPath(std::string file_name); std::string GetPath(const std::string& file_name) const;
/*! /*!
* \brief The RPC Environment cleanup function * \brief The RPC Environment cleanup function
*/ */
void CleanUp(); void CleanUp() const;
private: private:
/*! /*!
......
...@@ -30,6 +30,15 @@ ...@@ -30,6 +30,15 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
#if defined(WIN32)
/*!
* \brief ServerLoopFromChild The Server loop process.
* \param sock The socket information
* \param addr The socket address information
*/
void ServerLoopFromChild(SOCKET socket);
#endif
/*! /*!
* \brief RPCServerCreate Creates the RPC Server. * \brief RPCServerCreate Creates the RPC Server.
* \param host The hostname of the server, Default=0.0.0.0 * \param host The hostname of the server, Default=0.0.0.0
...@@ -40,13 +49,13 @@ namespace runtime { ...@@ -40,13 +49,13 @@ namespace runtime {
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param silent Whether run in silent mode. Default=True * \param silent Whether run in silent mode. Default=True
*/ */
TVM_DLL void RPCServerCreate(std::string host = "", void RPCServerCreate(std::string host = "",
int port = 9090, int port = 9090,
int port_end = 9099, int port_end = 9099,
std::string tracker_addr = "", std::string tracker_addr = "",
std::string key = "", std::string key = "",
std::string custom_addr = "", std::string custom_addr = "",
bool silent = true); bool silent = true);
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_APPS_CPP_RPC_SERVER_H_ #endif // TVM_APPS_CPP_RPC_SERVER_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file win32_process.h
* \brief Win32 process code to mimic a POSIX fork()
*/
#ifndef TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
#define TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
#include <chrono>
#include <string>
namespace tvm {
namespace runtime {
/*!
* \brief SpawnRPCChild Spawns a child process with a given timeout to run
* \param fd The client socket to duplicate in the child
* \param timeout The time in seconds to wait for the child to complete before termination
*/
void SpawnRPCChild(SOCKET fd, std::chrono::seconds timeout);
/*!
* \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket
* \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent
*/
void ChildProcSocketHandler(const std::string& mmap_path);
} // namespace runtime
} // namespace tvm
#endif // TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
\ No newline at end of file
...@@ -34,8 +34,12 @@ class SockChannel final : public RPCChannel { ...@@ -34,8 +34,12 @@ class SockChannel final : public RPCChannel {
explicit SockChannel(support::TCPSocket sock) explicit SockChannel(support::TCPSocket sock)
: sock_(sock) {} : sock_(sock) {}
~SockChannel() { ~SockChannel() {
if (!sock_.BadSocket()) { try {
sock_.Close(); // BadSocket can throw
if (!sock_.BadSocket()) {
sock_.Close();
}
} catch (...) {
} }
} }
size_t Send(const void* data, size_t size) final { size_t Send(const void* data, size_t size) final {
...@@ -100,7 +104,8 @@ Module RPCClientConnect(std::string url, int port, std::string key) { ...@@ -100,7 +104,8 @@ Module RPCClientConnect(std::string url, int port, std::string key) {
return CreateRPCModule(RPCConnect(url, port, "client:" + key)); return CreateRPCModule(RPCConnect(url, port, "client:" + key));
} }
void RPCServerLoop(int sockfd) { // TVM_DLL needed for MSVC
TVM_DLL void RPCServerLoop(int sockfd) {
support::TCPSocket sock( support::TCPSocket sock(
static_cast<support::TCPSocket::SockType>(sockfd)); static_cast<support::TCPSocket::SockType>(sockfd));
RPCSession::Create( RPCSession::Create(
......
...@@ -63,7 +63,7 @@ class RingBuffer { ...@@ -63,7 +63,7 @@ class RingBuffer {
size_t ncopy = head_ptr_ + bytes_available_ - old_size; size_t ncopy = head_ptr_ + bytes_available_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy); memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
} }
} else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) { } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) {
// shrink too large temporary buffer to avoid out of memory on some embedded devices // shrink too large temporary buffer to avoid out of memory on some embedded devices
size_t old_bytes = bytes_available_; size_t old_bytes = bytes_available_;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment