Commit d2fc0252 by Zhao Wu Committed by Tianqi Chen

[RUTNIME] Support C++ RPC (#4281)

parent 2f65a87f
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Makefile to compile RPC Server.
TVM_ROOT=$(shell cd ../..; pwd)
DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core
TVM_RUNTIME_DIR?=
OS?=
# Android can not link pthrad, but Linux need.
ifeq ($(OS), Linux)
LINK_PTHREAD=-lpthread
else
LINK_PTHREAD=
endif
PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\
-I${TVM_ROOT}/include\
-I${DMLC_CORE}/include\
-I${TVM_ROOT}/3rdparty/dlpack/include
PKG_LDFLAGS = -L$(TVM_RUNTIME_DIR) $(LINK_PTHREAD) -ltvm_runtime -ldl -Wl,-R$(TVM_RUNTIME_DIR)
ifeq ($(USE_GLOG), 1)
PKG_CFLAGS += -DDMLC_USE_GLOG=1
PKG_LDFLAGS += -lglog
endif
.PHONY: clean all
all: tvm_rpc
# Build rule for all in one TVM package library
tvm_rpc: *.cc
@mkdir -p $(@D)
$(CXX) $(PKG_CFLAGS) -o $@ $(filter %.cc %.o %.a, $^) $(PKG_LDFLAGS)
clean:
-rm -f tvm_rpc
\ No newline at end of file
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
# TVM RPC Server
This folder contains a simple recipe to make RPC server in c++.
## Usage
- Build tvm runtime
- Make the rpc executable [Makefile](Makefile).
`make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux`
if you want to compile it for embedded Linux, you should add `OS=Linux`.
if the target os is Android, you doesn't need to pass OS argument.
You could cross compile the TVM runtime like this:
```
cd tvm
mkdir arm_runtime
cp cmake/config.cmake arm_runtime
cd arm_runtime
cmake .. -DCMAKE_CXX_COMPILER="/path/to/cross compiler g++/"
make runtime
```
- Use `./tvm_rpc server` to start the RPC server
## How it works
- The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library.
```
Command line usage
server - Start the server
--host - The hostname of the server, Default=0.0.0.0
--port - The port of the RPC, Default=9090
--port-end - The end search port of the RPC, Default=9199
--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=""
--key - The key used to identify the device type in tracker. Default=""
--custom-addr - Custom IP Address to Report to RPC Tracker. Default=""
--silent - Whether to run in silent mode. Default=False
Example
./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 --tracker=127.0.0.1:9190 --key=rasp
```
## Note
Currently support is only there for Linux / Android environment and proxy mode doesn't be supported currently.
\ No newline at end of file
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rpc_server.cc
* \brief RPC Server for TVM.
*/
#include <stdlib.h>
#include <signal.h>
#include <stdio.h>
#include <unistd.h>
#include <dmlc/logging.h>
#include <iostream>
#include <cstring>
#include <vector>
#include <sstream>
#include "../../src/common/util.h"
#include "../../src/common/socket.h"
#include "rpc_server.h"
using namespace std;
using namespace tvm::runtime;
using namespace tvm::common;
static const string kUSAGE = \
"Command line usage\n" \
" server - Start the server\n" \
"--host - The hostname of the server, Default=0.0.0.0\n" \
"--port - The port of the RPC, Default=9090\n" \
"--port-end - The end search port of the RPC, Default=9199\n" \
"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" \
"--key - The key used to identify the device type in tracker. Default=\"\"\n" \
"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \
"--silent - Whether to run in silent mode. Default=False\n" \
"\n" \
" Example\n" \
" ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 "
" --tracker=127.0.0.1:9190 --key=rasp" \
"\n";
/*!
* \brief RpcServerArgs.
* \arg host The hostname of the server, Default=0.0.0.0
* \arg port The port of the RPC, Default=9090
* \arg port_end The end search port of the RPC, Default=9199
* \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \arg key The key used to identify the device type in tracker. Default=""
* \arg custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \arg silent Whether run in silent mode. Default=False
*/
struct RpcServerArgs {
string host = "0.0.0.0";
int port = 9090;
int port_end = 9099;
string tracker;
string key;
string custom_addr;
bool silent = false;
};
/*!
* \brief PrintArgs print the contents of RpcServerArgs
* \param args RpcServerArgs structure
*/
void PrintArgs(struct RpcServerArgs args) {
LOG(INFO) << "host = " << args.host;
LOG(INFO) << "port = " << args.port;
LOG(INFO) << "port_end = " << args.port_end;
LOG(INFO) << "tracker = " << args.tracker;
LOG(INFO) << "key = " << args.key;
LOG(INFO) << "custom_addr = " << args.custom_addr;
LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False"));
}
/*!
* \brief CtrlCHandler, exits if Ctrl+C is pressed
* \param s signal
*/
void CtrlCHandler(int s) {
LOG(INFO) << "\nUser pressed Ctrl+C, Exiting";
exit(1);
}
/*!
* \brief HandleCtrlC Register for handling Ctrl+C event.
*/
void HandleCtrlC() {
// Ctrl+C handler
struct sigaction sigIntHandler;
sigIntHandler.sa_handler = CtrlCHandler;
sigemptyset(&sigIntHandler.sa_mask);
sigIntHandler.sa_flags = 0;
sigaction(SIGINT, &sigIntHandler, nullptr);
}
/*!
* \brief GetCmdOption Parse and find the command option.
* \param argc arg counter
* \param argv arg values
* \param option command line option to search for.
* \param key whether the option itself is key
* \return value corresponding to option.
*/
string GetCmdOption(int argc, char* argv[], string option, bool key = false) {
string cmd;
for (int i = 1; i < argc; ++i) {
string arg = argv[i];
if (arg.find(option) == 0) {
if (key) {
cmd = argv[i];
return cmd;
}
// We assume "=" is the end of option.
CHECK_EQ(*option.rbegin(), '=');
cmd = arg.substr(arg.find("=") + 1);
return cmd;
}
}
return cmd;
}
/*!
* \brief ValidateTracker Check the tracker address format is correct and changes the format.
* \param tracker The tracker input.
* \return result of operation.
*/
bool ValidateTracker(string &tracker) {
vector<string> list = Split(tracker, ':');
if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) {
return false;
}
ostringstream ss;
ss << "('" << list[0] << "', " << list[1] << ")";
tracker = ss.str();
return true;
}
/*!
* \brief ParseCmdArgs parses the command line arguments.
* \param argc arg counter
* \param argv arg values
* \param args, the output structure which holds the parsed values
*/
void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
string silent = GetCmdOption(argc, argv, "--silent", true);
if (!silent.empty()) {
args.silent = true;
// Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2");
}
string host = GetCmdOption(argc, argv, "--host=");
if (!host.empty()) {
if (!ValidateIP(host)) {
LOG(WARNING) << "Wrong host address format.";
LOG(INFO) << kUSAGE;
exit(1);
}
args.host = host;
}
string port = GetCmdOption(argc, argv, "--port=");
if (!port.empty()) {
if (!IsNumber(port) || stoi(port) > 65535) {
LOG(WARNING) << "Wrong port number.";
LOG(INFO) << kUSAGE;
exit(1);
}
args.port = stoi(port);
}
string port_end = GetCmdOption(argc, argv, "--port_end=");
if (!port_end.empty()) {
if (!IsNumber(port_end) || stoi(port_end) > 65535) {
LOG(WARNING) << "Wrong port_end number.";
LOG(INFO) << kUSAGE;
exit(1);
}
args.port_end = stoi(port_end);
}
string tracker = GetCmdOption(argc, argv, "--tracker=");
if (!tracker.empty()) {
if (!ValidateTracker(tracker)) {
LOG(WARNING) << "Wrong tracker address format.";
LOG(INFO) << kUSAGE;
exit(1);
}
args.tracker = tracker;
}
string key = GetCmdOption(argc, argv, "--key=");
if (!key.empty()) {
args.key = key;
}
string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
if (!custom_addr.empty()) {
if (!ValidateIP(custom_addr)) {
LOG(WARNING) << "Wrong custom address format.";
LOG(INFO) << kUSAGE;
exit(1);
}
args.custom_addr = custom_addr;
}
}
/*!
* \brief RpcServer Starts the RPC server.
* \param argc arg counter
* \param argv arg values
* \return result of operation.
*/
int RpcServer(int argc, char * argv[]) {
struct RpcServerArgs args;
/* parse the command line args */
ParseCmdArgs(argc, argv, args);
PrintArgs(args);
// Ctrl+C handler
LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop.";
HandleCtrlC();
tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
args.key, args.custom_addr, args.silent);
return 0;
}
/*!
* \brief main The main function.
* \param argc arg counter
* \param argv arg values
* \return result of operation.
*/
int main(int argc, char * argv[]) {
if (argc <= 1) {
LOG(INFO) << kUSAGE;
return 0;
}
if (0 == strcmp(argv[1], "server")) {
RpcServer(argc, argv);
} else {
LOG(INFO) << kUSAGE;
}
return 0;
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rpc_env.cc
* \brief Server environment of the RPC.
*/
#include <tvm/runtime/registry.h>
#include <errno.h>
#ifndef _MSC_VER
#include <sys/stat.h>
#include <dirent.h>
#include <unistd.h>
#else
#include <Windows.h>
#endif
#include <fstream>
#include <vector>
#include <iostream>
#include <string>
#include <cstring>
#include "rpc_env.h"
#include "../../src/common/util.h"
#include "../../src/runtime/file_util.h"
namespace tvm {
namespace runtime {
RPCEnv::RPCEnv() {
#if defined(__linux__) || defined(__ANDROID__)
base_ = "./rpc";
mkdir(&base_[0], 0777);
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env;
*rv = env.GetPath(args[0]);
});
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
static RPCEnv env;
std::string file_name = env.GetPath(args[0]);
*rv = Load(&file_name, "");
LOG(INFO) << "Load module from " << file_name << " ...";
});
#else
LOG(FATAL) << "Only support RPC in linux environment";
#endif
}
/*!
* \brief GetPath To get the workpath from packed function
* \param name The file name
* \return The full path of file.
*/
std::string RPCEnv::GetPath(std::string file_name) {
// we assume file_name has "/" means file_name is the exact path
// and does not create /.rpc/
if (file_name.find("/") != std::string::npos) {
return file_name;
} else {
return base_ + "/" + file_name;
}
}
/*!
* \brief Remove The RPC Environment cleanup function
*/
void RPCEnv::CleanUp() {
#if defined(__linux__) || defined(__ANDROID__)
CleanDir(&base_[0]);
int ret = rmdir(&base_[0]);
if (ret != 0) {
LOG(WARNING) << "Remove directory " << base_ << " failed";
}
#else
LOG(FATAL) << "Only support RPC in linux environment";
#endif
}
/*!
* \brief ListDir get the list of files in a directory
* \param dirname The root directory name
* \return vector Files in directory.
*/
std::vector<std::string> ListDir(const std::string &dirname) {
std::vector<std::string> vec;
#ifndef _MSC_VER
DIR *dp = opendir(dirname.c_str());
if (dp == nullptr) {
int errsv = errno;
LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv);
}
dirent *d;
while ((d = readdir(dp)) != nullptr) {
std::string filename = d->d_name;
if (filename != "." && filename != "..") {
std::string f = dirname;
if (f[f.length() - 1] != '/') {
f += '/';
}
f += d->d_name;
vec.push_back(f);
}
}
closedir(dp);
#else
WIN32_FIND_DATA fd;
std::string pattern = dirname + "/*";
HANDLE handle = FindFirstFile(pattern.c_str(), &fd);
if (handle == INVALID_HANDLE_VALUE) {
int errsv = GetLastError();
LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv);
}
do {
if (fd.cFileName != "." && fd.cFileName != "..") {
std::string f = dirname;
char clast = f[f.length() - 1];
if (f == ".") {
f = fd.cFileName;
} else if (clast != '/' && clast != '\\') {
f += '/';
f += fd.cFileName;
}
vec.push_back(f);
}
} while (FindNextFile(handle, &fd));
FindClose(handle);
#endif
return vec;
}
/*!
* \brief LinuxShared Creates a linux shared library
* \param output The output file name
* \param files The files for building
* \param options The compiler options
* \param cc The compiler
*/
void LinuxShared(const std::string output,
const std::vector<std::string> &files,
std::string options = "",
std::string cc = "g++") {
std::string cmd = cc;
cmd += " -shared -fPIC ";
cmd += " -o " + output;
for (auto f = files.begin(); f != files.end(); ++f) {
cmd += " " + *f;
}
cmd += " " + options;
std::string err_msg;
auto executed_status = common::Execute(cmd, &err_msg);
if (executed_status) {
LOG(FATAL) << err_msg;
}
}
/*!
* \brief CreateShared Creates a shared library
* \param output The output file name
* \param files The files for building
*/
void CreateShared(const std::string output, const std::vector<std::string> &files) {
#if defined(__linux__) || defined(__ANDROID__)
LinuxShared(output, files);
#else
LOG(FATAL) << "Do not support creating shared library";
#endif
}
/*!
* \brief Load Load module from file
This function will automatically call
cc.create_shared if the path is in format .o or .tar
High level handling for .o and .tar file.
We support this to be consistent with RPC module load.
* \param fileIn The input file, file name will be updated
* \param fmt The format of file
* \return Module The loaded module
*/
Module Load(std::string *fileIn, const std::string fmt) {
std::string file = *fileIn;
if (common::EndsWith(file, ".so")) {
return Module::LoadFromFile(file, fmt);
}
#if defined(__linux__) || defined(__ANDROID__)
std::string file_name = file + ".so";
if (common::EndsWith(file, ".o")) {
std::vector<std::string> files;
files.push_back(file);
CreateShared(file_name, files);
} else if (common::EndsWith(file, ".tar")) {
std::string tmp_dir = "./rpc/tmp/";
mkdir(&tmp_dir[0], 0777);
std::string cmd = "tar -C " + tmp_dir + " -zxf " + file;
std::string err_msg;
int executed_status = common::Execute(cmd, &err_msg);
if (executed_status) {
LOG(FATAL) << err_msg;
}
CreateShared(file_name, ListDir(tmp_dir));
CleanDir(tmp_dir);
rmdir(&tmp_dir[0]);
} else {
file_name = file;
}
*fileIn = file_name;
return Module::LoadFromFile(file_name, fmt);
#else
LOG(FATAL) << "Do not support creating shared library";
#endif
}
/*!
* \brief CleanDir Removes the files from the directory
* \param dirname The name of the directory
*/
void CleanDir(const std::string &dirname) {
#if defined(__linux__) || defined(__ANDROID__)
DIR *dp = opendir(dirname.c_str());
dirent *d;
while ((d = readdir(dp)) != nullptr) {
std::string filename = d->d_name;
if (filename != "." && filename != "..") {
filename = dirname + "/" + d->d_name;
int ret = std::remove(&filename[0]);
if (ret != 0) {
LOG(WARNING) << "Remove file " << filename << " failed";
}
}
}
#else
LOG(FATAL) << "Only support RPC in linux environment";
#endif
}
} // namespace runtime
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rpc_env.h
* \brief Server environment of the RPC.
*/
#ifndef TVM_APPS_CPP_RPC_ENV_H_
#define TVM_APPS_CPP_RPC_ENV_H_
#include <tvm/runtime/registry.h>
#include <string>
namespace tvm {
namespace runtime {
/*!
* \brief Load Load module from file
This function will automatically call
cc.create_shared if the path is in format .o or .tar
High level handling for .o and .tar file.
We support this to be consistent with RPC module load.
* \param file The input file
* \param file The format of file
* \return Module The loaded module
*/
Module Load(std::string *path, const std::string fmt = "");
/*!
* \brief CleanDir Removes the files from the directory
* \param dirname THe name of the directory
*/
void CleanDir(const std::string &dirname);
/*!
* \brief RPCEnv The RPC Environment parameters for c++ rpc server
*/
struct RPCEnv {
public:
/*!
* \brief Constructor Init The RPC Environment initialize function
*/
RPCEnv();
/*!
* \brief GetPath To get the workpath from packed function
* \param name The file name
* \return The full path of file.
*/
std::string GetPath(std::string file_name);
/*!
* \brief The RPC Environment cleanup function
*/
void CleanUp();
private:
/*!
* \brief Holds the environment path.
*/
std::string base_;
}; // RPCEnv
} // namespace runtime
} // namespace tvm
#endif // TVM_APPS_CPP_RPC_ENV_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rpc_server.h
* \brief RPC Server implementation.
*/
#ifndef TVM_APPS_CPP_RPC_SERVER_H_
#define TVM_APPS_CPP_RPC_SERVER_H_
#include <string>
#include "tvm/runtime/c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief RPCServerCreate Creates the RPC Server.
* \param host The hostname of the server, Default=0.0.0.0
* \param port The port of the RPC, Default=9090
* \param port_end The end search port of the RPC, Default=9199
* \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \param key The key used to identify the device type in tracker. Default=""
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param silent Whether run in silent mode. Default=True
*/
TVM_DLL void RPCServerCreate(std::string host = "",
int port = 9090,
int port_end = 9099,
std::string tracker_addr = "",
std::string key = "",
std::string custom_addr = "",
bool silent = true);
} // namespace runtime
} // namespace tvm
#endif // TVM_APPS_CPP_RPC_SERVER_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rpc_tracker_client.h
* \brief RPC Tracker client to report resources.
*/
#ifndef TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
#define TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
#include <set>
#include <iostream>
#include <chrono>
#include <random>
#include <vector>
#include <string>
#include "../../src/runtime/rpc/rpc_session.h"
#include "../../src/common/socket.h"
namespace tvm {
namespace runtime {
/*!
* \brief TrackerClient Tracker client class.
* \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \param key The key used to identify the device type in tracker. Default=""
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
*/
class TrackerClient {
public:
/*!
* \brief Constructor.
*/
TrackerClient(const std::string& tracker_addr,
const std::string& key,
const std::string& custom_addr)
: tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr),
gen_(std::random_device{}()), dis_(0.0, 1.0) {
}
/*!
* \brief Destructor.
*/
~TrackerClient() {
// Free the resources
Close();
}
/*!
* \brief IsValid Check tracker is valid.
*/
bool IsValid() {
return (!tracker_addr_.empty() && !tracker_sock_.IsClosed());
}
/*!
* \brief TryConnect Connect to tracker if the tracker address is valid.
*/
void TryConnect() {
if (!tracker_addr_.empty() && (tracker_sock_.IsClosed())) {
tracker_sock_ = ConnectWithRetry();
int code = kRPCTrackerMagic;
CHECK_EQ(tracker_sock_.SendAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(tracker_sock_.RecvAll(&code, sizeof(code)), sizeof(code));
CHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker";
std::ostringstream ss;
ss << "[" << static_cast<int>(TrackerCode::kUpdateInfo)
<< ", {\"key\": \"server:"<< key_ << "\"}]";
tracker_sock_.SendBytes(ss.str());
// Receive status and validate
std::string remote_status = tracker_sock_.RecvBytes();
CHECK_EQ(std::stoi(remote_status), static_cast<int>(TrackerCode::kSuccess));
}
}
/*!
* \brief Close Clean up tracker resources.
*/
void Close() {
// close tracker resource
if (!tracker_sock_.IsClosed()) {
tracker_sock_.Close();
}
}
/*!
* \brief ReportResourceAndGetKey Report resource to tracker.
* \param port listening port.
* \param matchkey Random match key output.
*/
void ReportResourceAndGetKey(int port,
std::string *matchkey) {
if (!tracker_sock_.IsClosed()) {
*matchkey = RandomKey(key_ + ":", old_keyset_);
if (custom_addr_.empty()) {
custom_addr_ = "null";
}
std::ostringstream ss;
ss << "[" << static_cast<int>(TrackerCode::kPut) << ", \"" << key_ << "\", ["
<< port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]";
tracker_sock_.SendBytes(ss.str());
// Receive status and validate
std::string remote_status = tracker_sock_.RecvBytes();
CHECK_EQ(std::stoi(remote_status), static_cast<int>(TrackerCode::kSuccess));
} else {
*matchkey = key_;
}
}
/*!
* \brief ReportResourceAndGetKey Report resource to tracker.
* \param listen_sock Listen socket details for select.
* \param port listening port.
* \param ping_period Select wait time.
* \param matchkey Random match key output.
*/
void WaitConnectionAndUpdateKey(common::TCPSocket listen_sock,
int port,
int ping_period,
std::string *matchkey) {
int unmatch_period_count = 0;
int unmatch_timeout = 4;
while (true) {
if (!tracker_sock_.IsClosed()) {
common::PollHelper poller;
poller.WatchRead(listen_sock.sockfd);
poller.Poll(ping_period * 1000);
if (!poller.CheckRead(listen_sock.sockfd)) {
std::ostringstream ss;
ss << "[" << int(TrackerCode::kGetPendingMatchKeys) << "]";
tracker_sock_.SendBytes(ss.str());
// Receive status and validate
std::string pending_keys = tracker_sock_.RecvBytes();
old_keyset_.insert(*matchkey);
// if match key not in pending key set
// it means the key is acquired by a client but not used.
if (pending_keys.find(*matchkey) == std::string::npos) {
unmatch_period_count += 1;
} else {
unmatch_period_count = 0;
}
// regenerate match key if key is acquired but not used for a while
if (unmatch_period_count * ping_period > unmatch_timeout + ping_period) {
LOG(INFO) << "no incoming connections, regenerate key ...";
*matchkey = RandomKey(key_ + ":", old_keyset_);
std::ostringstream ss;
ss << "[" << static_cast<int>(TrackerCode::kPut) << ", \"" << key_ << "\", ["
<< port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]";
tracker_sock_.SendBytes(ss.str());
std::string remote_status = tracker_sock_.RecvBytes();
CHECK_EQ(std::stoi(remote_status), static_cast<int>(TrackerCode::kSuccess));
unmatch_period_count = 0;
}
continue;
}
}
break;
}
}
private:
/*!
* \brief Connect to a RPC address with retry.
This function is only reliable to short period of server restart.
* \param timeout Timeout during retry
* \param retry_period Number of seconds before we retry again.
* \return TCPSocket The socket information if connect is success.
*/
common::TCPSocket ConnectWithRetry(int timeout = 60, int retry_period = 5) {
auto tbegin = std::chrono::system_clock::now();
while (true) {
common::SockAddr addr(tracker_addr_);
common::TCPSocket sock;
sock.Create();
LOG(INFO) << "Tracker connecting to " << addr.AsString();
if (sock.Connect(addr)) {
return sock;
}
auto period = (std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::system_clock::now() - tbegin)).count();
CHECK(period < timeout) << "Failed to connect to server" << addr.AsString();
LOG(WARNING) << "Cannot connect to tracker " << addr.AsString()
<< " retry in " << retry_period << " seconds.";
std::this_thread::sleep_for(std::chrono::seconds(retry_period));
}
}
/*!
* \brief Random Generate a random number between 0 and 1.
* \return random float value.
*/
float Random() {
return dis_(gen_);
}
/*!
* \brief Generate a random key.
* \param prefix The string prefix.
* \return cmap The conflict map set.
*/
std::string RandomKey(const std::string& prefix, const std::set <std::string> &cmap) {
if (!cmap.empty()) {
while (true) {
std::string key = prefix + std::to_string(Random());
if (cmap.find(key) == cmap.end()) {
return key;
}
}
}
return prefix + std::to_string(Random());
}
std::string tracker_addr_;
std::string key_;
std::string custom_addr_;
common::TCPSocket tracker_sock_;
std::set <std::string> old_keyset_;
std::mt19937 gen_;
std::uniform_real_distribution<float> dis_;
};
} // namespace runtime
} // namespace tvm
#endif // TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_
...@@ -43,12 +43,27 @@ using ssize_t = int; ...@@ -43,12 +43,27 @@ using ssize_t = int;
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netinet/in.h> #include <netinet/in.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/select.h>
#include <sys/ioctl.h> #include <sys/ioctl.h>
#endif #endif
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <string> #include <string>
#include <cstring> #include <cstring>
#include <vector>
#include <unordered_map>
#include "../common/util.h"
#if defined(_WIN32)
static inline int poll(struct pollfd *pfd, int nfds,
int timeout) {
return WSAPoll(pfd, nfds, timeout);
}
static inline int inet_pton(int family, const char* addr_str, void* addr_buf) {
return InetPton(family, addr_str, addr_buf);
}
#else
#include <sys/poll.h>
#endif // defined(_WIN32)
namespace tvm { namespace tvm {
namespace common { namespace common {
...@@ -63,6 +78,22 @@ inline std::string GetHostName() { ...@@ -63,6 +78,22 @@ inline std::string GetHostName() {
} }
/*! /*!
* \brief ValidateIP validates an ip address.
* \param ip The ip address in string format localhost or x.x.x.x format
* \return result of operation.
*/
inline bool ValidateIP(std::string ip) {
if (ip == "localhost") {
return true;
}
struct sockaddr_in sa_ipv4;
struct sockaddr_in6 sa_ipv6;
bool is_ipv4 = inet_pton(AF_INET, ip.c_str(), &(sa_ipv4.sin_addr));
bool is_ipv6 = inet_pton(AF_INET6, ip.c_str(), &(sa_ipv6.sin6_addr));
return is_ipv4 || is_ipv6;
}
/*!
* \brief Common data structure for network address. * \brief Common data structure for network address.
*/ */
struct SockAddr { struct SockAddr {
...@@ -76,6 +107,23 @@ struct SockAddr { ...@@ -76,6 +107,23 @@ struct SockAddr {
SockAddr(const char *url, int port) { SockAddr(const char *url, int port) {
this->Set(url, port); this->Set(url, port);
} }
/*!
* \brief SockAddr Get the socket address from tracker.
* \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090)
* \return SockAddr parsed from url.
*/
explicit SockAddr(const std::string &url) {
size_t sep = url.find(",");
std::string host = url.substr(2, sep - 3);
std::string port = url.substr(sep + 1, url.length() - 1);
CHECK(ValidateIP(host)) << "Url address is not valid " << url;
if (host == "localhost") {
host = "127.0.0.1";
}
this->Set(host.c_str(), std::stoi(port));
}
/*! /*!
* \brief set the address * \brief set the address
* \param host the url of the address * \param host the url of the address
...@@ -203,17 +251,20 @@ class Socket { ...@@ -203,17 +251,20 @@ class Socket {
} }
/*! /*!
* \brief try bind the socket to host, from start_port to end_port * \brief try bind the socket to host, from start_port to end_port
* \param host host address to bind the socket
* \param start_port starting port number to try * \param start_port starting port number to try
* \param end_port ending port number to try * \param end_port ending port number to try
* \return the port successfully bind to, return -1 if failed to bind any port * \return the port successfully bind to, return -1 if failed to bind any port
*/ */
inline int TryBindHost(int start_port, int end_port) { inline int TryBindHost(std::string host, int start_port, int end_port) {
for (int port = start_port; port < end_port; ++port) { for (int port = start_port; port < end_port; ++port) {
SockAddr addr("0.0.0.0", port); SockAddr addr(host.c_str(), port);
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr), if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
(addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
sizeof(sockaddr_in))) == 0) { sizeof(sockaddr_in))) == 0) {
return port; return port;
} else {
LOG(WARNING) << "Bind failed to " << host << ":" << port;
} }
#if defined(_WIN32) #if defined(_WIN32)
if (WSAGetLastError() != WSAEADDRINUSE) { if (WSAGetLastError() != WSAEADDRINUSE) {
...@@ -374,6 +425,20 @@ class TCPSocket : public Socket { ...@@ -374,6 +425,20 @@ class TCPSocket : public Socket {
return TCPSocket(newfd); return TCPSocket(newfd);
} }
/*! /*!
* \brief get a new connection
* \param addr client address from which connection accepted
* \return The accepted socket connection.
*/
TCPSocket Accept(SockAddr *addr) {
socklen_t addrlen = sizeof(addr->addr);
SockType newfd = accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr),
&addrlen);
if (newfd == INVALID_SOCKET) {
Socket::Error("Accept");
}
return TCPSocket(newfd);
}
/*!
* \brief decide whether the socket is at OOB mark * \brief decide whether the socket is at OOB mark
* \return 1 if at mark, 0 if not, -1 if an error occurred * \return 1 if at mark, 0 if not, -1 if an error occurred
*/ */
...@@ -468,7 +533,125 @@ class TCPSocket : public Socket { ...@@ -468,7 +533,125 @@ class TCPSocket : public Socket {
} }
return ndone; return ndone;
} }
/*!
* \brief Send the data to remote.
* \param data The data to be sent.
*/
void SendBytes(std::string data) {
int datalen = data.length();
CHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen));
CHECK_EQ(SendAll(data.c_str(), datalen), datalen);
}
/*!
* \brief Receive the data to remote.
* \return The data received.
*/
std::string RecvBytes() {
int datalen = 0;
CHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen));
std::string data;
data.resize(datalen);
CHECK_EQ(RecvAll(&data[0], datalen), datalen);
return data;
}
}; };
/*! \brief helper data structure to perform poll */
struct PollHelper {
public:
/*!
* \brief add file descriptor to watch for read
* \param fd file descriptor to be watched
*/
inline void WatchRead(TCPSocket::SockType fd) {
auto& pfd = fds[fd];
pfd.fd = fd;
pfd.events |= POLLIN;
}
/*!
* \brief add file descriptor to watch for write
* \param fd file descriptor to be watched
*/
inline void WatchWrite(TCPSocket::SockType fd) {
auto& pfd = fds[fd];
pfd.fd = fd;
pfd.events |= POLLOUT;
}
/*!
* \brief add file descriptor to watch for exception
* \param fd file descriptor to be watched
*/
inline void WatchException(TCPSocket::SockType fd) {
auto& pfd = fds[fd];
pfd.fd = fd;
pfd.events |= POLLPRI;
}
/*!
* \brief Check if the descriptor is ready for read
* \param fd file descriptor to check status
*/
inline bool CheckRead(TCPSocket::SockType fd) const {
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
}
/*!
* \brief Check if the descriptor is ready for write
* \param fd file descriptor to check status
*/
inline bool CheckWrite(TCPSocket::SockType fd) const {
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
}
/*!
* \brief Check if the descriptor has any exception
* \param fd file descriptor to check status
*/
inline bool CheckExcept(TCPSocket::SockType fd) const {
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
}
/*!
* \brief wait for exception event on a single descriptor
* \param fd the file descriptor to wait the event for
* \param timeout the timeout counter, can be negative, which means wait until the event happen
* \return 1 if success, 0 if timeout, and -1 if error occurs
*/
inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*)
pollfd pfd;
pfd.fd = fd;
pfd.events = POLLPRI;
return poll(&pfd, 1, timeout);
}
/*!
* \brief peform poll on the set defined, read, write, exception
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
* \return
*/
inline void Poll(long timeout = -1) { // NOLINT(*)
std::vector<pollfd> fdset;
fdset.reserve(fds.size());
for (auto kv : fds) {
fdset.push_back(kv.second);
}
int ret = poll(fdset.data(), fdset.size(), timeout);
if (ret == -1) {
Socket::Error("Poll");
} else {
for (auto& pfd : fdset) {
auto revents = pfd.revents & pfd.events;
if (!revents) {
fds.erase(pfd.fd);
} else {
fds[pfd.fd].events = revents;
}
}
}
}
std::unordered_map<TCPSocket::SockType, pollfd> fds;
};
} // namespace common } // namespace common
} // namespace tvm } // namespace tvm
#endif // TVM_COMMON_SOCKET_H_ #endif // TVM_COMMON_SOCKET_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file util.h
* \brief Defines some common utility function..
*/
#ifndef TVM_COMMON_UTIL_H_
#define TVM_COMMON_UTIL_H_
#include <stdio.h>
#ifndef _WIN32
#include <sys/wait.h>
#include <sys/types.h>
#endif
#include <vector>
#include <string>
#include <sstream>
#include <algorithm>
#include <array>
#include <memory>
namespace tvm {
namespace common {
/*!
* \brief TVMPOpen wrapper of popen between windows / unix.
* \param command executed command
* \param type "r" is for reading or "w" for writing.
* \return normal standard stream
*/
inline FILE* TVMPOpen(const char* command, const char* type) {
#if defined(_WIN32)
return _popen(command, type);
#else
return popen(command, type);
#endif
}
/*!
* \brief TVMPClose wrapper of pclose between windows / linux
* \param stream the stream needed to be close.
* \return exit status
*/
inline int TVMPClose(FILE* stream) {
#if defined(_WIN32)
return _pclose(stream);
#else
return pclose(stream);
#endif
}
/*!
* \brief TVMWifexited wrapper of WIFEXITED between windows / linux
* \param status The status field that was filled in by the wait or waitpid function
* \return the exit code of the child process
*/
inline int TVMWifexited(int status) {
#if defined(_WIN32)
return (status != 3);
#else
return WIFEXITED(status);
#endif
}
/*!
* \brief TVMWexitstatus wrapper of WEXITSTATUS between windows / linux
* \param status The status field that was filled in by the wait or waitpid function.
* \return the child process exited normally or not
*/
inline int TVMWexitstatus(int status) {
#if defined(_WIN32)
return status;
#else
return WEXITSTATUS(status);
#endif
}
/*!
* \brief IsNumber check whether string is a number.
* \param str input string
* \return result of operation.
*/
inline bool IsNumber(const std::string& str) {
return !str.empty() && std::find_if(str.begin(),
str.end(), [](char c) { return !std::isdigit(c); }) == str.end();
}
/*!
* \brief split Split the string based on delimiter
* \param str Input string
* \param delim The delimiter.
* \return vector of strings which are splitted.
*/
inline std::vector<std::string> Split(const std::string& str, char delim) {
std::string item;
std::istringstream is(str);
std::vector<std::string> ret;
while (std::getline(is, item, delim)) {
ret.push_back(item);
}
return ret;
}
/*!
* \brief EndsWith check whether the strings ends with
* \param value The full string
* \param end The end substring
* \return bool The result.
*/
inline bool EndsWith(std::string const& value, std::string const& end) {
if (end.size() <= value.size()) {
return std::equal(end.rbegin(), end.rend(), value.rbegin());
}
return false;
}
/*!
* \brief Execute the command
* \param cmd The command we want to execute
* \param err_msg The error message if we have
* \return executed output status
*/
inline int Execute(std::string cmd, std::string* err_msg) {
std::array<char, 128> buffer;
std::string result;
cmd += " 2>&1";
FILE* fd = TVMPOpen(cmd.c_str(), "r");
while (fgets(buffer.data(), buffer.size(), fd) != nullptr) {
*err_msg += buffer.data();
}
int status = TVMPClose(fd);
if (TVMWifexited(status)) {
return TVMWexitstatus(status);
}
return 255;
}
} // namespace common
} // namespace tvm
#endif // TVM_COMMON_UTIL_H_
...@@ -36,8 +36,27 @@ ...@@ -36,8 +36,27 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
// Magic header for RPC data plane
const int kRPCMagic = 0xff271; const int kRPCMagic = 0xff271;
// magic header for RPC tracker(control plane)
const int kRPCTrackerMagic = 0x2f271;
// sucess response
const int kRPCSuccess = kRPCMagic + 0;
// cannot found matched key in server
const int kRPCMismatch = kRPCMagic + 2;
/*! \brief Enumeration code for the RPC tracker */
enum class TrackerCode : int {
kFail = -1,
kSuccess = 0,
kPing = 1,
kStop = 2,
kPut = 3,
kRequest = 4,
kUpdateInfo = 5,
kSummary = 6,
kGetPendingMatchKeys = 7
};
/*! \brief The remote functio handle */ /*! \brief The remote functio handle */
using RPCFuncHandle = void*; using RPCFuncHandle = void*;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file rpc_socket_impl.h
* \brief Socket based RPC implementation.
*/
#ifndef TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_
#define TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_
namespace tvm {
namespace runtime {
/*!
* \brief RPCServerLoop Start the rpc server loop.
* \param sockfd Socket file descriptor
*/
void RPCServerLoop(int sockfd);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment