/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file rpc_server.cc
 * \brief RPC Server implementation.
 */
#include <tvm/runtime/registry.h>

#if defined(__linux__) || defined(__ANDROID__)
// POSIX process control used by the forked serving path: fork/sleep (unistd),
// kill/SIGKILL (signal), waitpid/wait (wait), select (select).
#include <signal.h>
#include <sys/select.h>
#include <sys/wait.h>
#include <unistd.h>
#endif
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <set>
#include <iostream>
#include <future>
#include <thread>
#include <chrono>
#include <sstream>
#include <string>

#include "rpc_server.h"
#include "rpc_env.h"
#include "rpc_tracker_client.h"
#include "../../src/runtime/rpc/rpc_session.h"
#include "../../src/runtime/rpc/rpc_socket_impl.h"
#include "../../src/common/socket.h"

namespace tvm {
namespace runtime {

#if defined(__linux__) || defined(__ANDROID__)
/*!
 * \brief Block until any child process of this process terminates.
 *
 * waitpid() is retried for as long as it fails with EINTR (interrupted by a
 * signal), so callers never observe a spurious failure.
 *
 * \param status Receives the wait status of the reaped child, as filled in by waitpid().
 * \return The pid of the child that was reaped.
 * \note Any waitpid() failure other than EINTR is reported via perror("waitpid")
 *       and the process is aborted; this function never returns -1.
 */
static pid_t waitPidEintr(int* status) {
  pid_t pid = 0;
  while ((pid = waitpid(-1, status, 0)) == -1) {
    if (errno != EINTR) {
      perror("waitpid");
      abort();
    }
    // EINTR: a signal interrupted the wait; simply wait again.
  }
  return pid;
}
#endif

/*!
 * \brief RPCServer RPC Server class.
 * \param host The hostname of the server, Default=0.0.0.0
 * \param port The port of the RPC, Default=9090
 * \param port_end The end search port of the RPC, Default=9199
 * \param tracker The address of RPC tracker in host:port format e.g.
10.77.1.234:9190 Default="" * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" */ class RPCServer { public: /*! * \brief Constructor. */ RPCServer(const std::string &host, int port, int port_end, const std::string &tracker_addr, const std::string &key, const std::string &custom_addr) { // Init the values host_ = host; port_ = port; port_end_ = port_end; tracker_addr_ = tracker_addr; key_ = key; custom_addr_ = custom_addr; } /*! * \brief Destructor. */ ~RPCServer() { // Free the resources tracker_sock_.Close(); listen_sock_.Close(); } /*! * \brief Start Creates the RPC listen process and execution. */ void Start() { listen_sock_.Create(); my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_); LOG(INFO) << "bind to " << host_ << ":" << my_port_; listen_sock_.Listen(1); std::future<void> proc(std::async(std::launch::async, &RPCServer::ListenLoopProc, this)); proc.get(); // Close the listen socket listen_sock_.Close(); } private: /*! * \brief ListenLoopProc The listen process. 
*/ void ListenLoopProc() { TrackerClient tracker(tracker_addr_, key_, custom_addr_); while (true) { common::TCPSocket conn; common::SockAddr addr("0.0.0.0", 0); std::string opts; try { // step 1: setup tracker and report to tracker tracker.TryConnect(); // step 2: wait for in-coming connections AcceptConnection(&tracker, &conn, &addr, &opts); } catch (const char* msg) { LOG(WARNING) << "Socket exception: " << msg; // close tracker resource tracker.Close(); continue; } catch (std::exception& e) { // Other errors LOG(WARNING) << "Exception standard: " << e.what(); continue; } int timeout = GetTimeOutFromOpts(opts); #if defined(__linux__) || defined(__ANDROID__) // step 3: serving if (timeout != 0) { const pid_t timer_pid = fork(); if (timer_pid == 0) { // Timer process sleep(timeout); exit(0); } const pid_t worker_pid = fork(); if (worker_pid == 0) { // Worker process ServerLoopProc(conn, addr); exit(0); } int status = 0; const pid_t finished_first = waitPidEintr(&status); if (finished_first == timer_pid) { kill(worker_pid, SIGKILL); } else if (finished_first == worker_pid) { kill(timer_pid, SIGKILL); } else { LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; } int status_second = 0; waitPidEintr(&status_second); // Logging. 
if (finished_first == timer_pid) { LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout << "), Process status = " << status_second; } else if (finished_first == worker_pid) { LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; } } else { auto pid = fork(); if (pid == 0) { ServerLoopProc(conn, addr); exit(0); } // Wait for the result int status = 0; wait(&status); LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status; } #else // step 3: serving std::future<void> proc(std::async(std::launch::async, &RPCServer::ServerLoopProc, this, conn, addr)); // wait until server process finish or timeout if (timeout != 0) { // Autoterminate after timeout proc.wait_for(std::chrono::seconds(timeout)); } else { // Wait for the result proc.get(); } #endif // close from our side. LOG(INFO) << "Socket Connection Closed"; conn.Close(); } } /*! * \brief AcceptConnection Accepts the RPC Server connection. * \param tracker Tracker details. * \param conn New connection information. * \param addr New connection address information. 
* \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ void AcceptConnection(TrackerClient* tracker, common::TCPSocket* conn_sock, common::SockAddr* addr, std::string* opts, int ping_period = 2) { std::set <std::string> old_keyset; std::string matchkey; // Report resource to tracker and get key tracker->ReportResourceAndGetKey(my_port_, &matchkey); while (true) { tracker->WaitConnectionAndUpdateKey(listen_sock_, my_port_, ping_period, &matchkey); common::TCPSocket conn = listen_sock_.Accept(addr); int code = kRPCMagic; CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); if (code != kRPCMagic) { conn.Close(); LOG(FATAL) << "Client connected is not TVM RPC server"; continue; } int keylen = 0; CHECK_EQ(conn.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen)); const char* CLIENT_HEADER = "client:"; const char* SERVER_HEADER = "server:"; std::string expect_header = CLIENT_HEADER + matchkey; std::string server_key = SERVER_HEADER + key_; if (size_t(keylen) < expect_header.length()) { conn.Close(); LOG(INFO) << "Wrong client header length"; continue; } CHECK_NE(keylen, 0); std::string remote_key; remote_key.resize(keylen); CHECK_EQ(conn.RecvAll(&remote_key[0], keylen), keylen); std::stringstream ssin(remote_key); std::string arg0; ssin >> arg0; if (arg0 != expect_header) { code = kRPCMismatch; CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); conn.Close(); LOG(WARNING) << "Mismatch key from" << addr->AsString(); continue; } else { code = kRPCSuccess; CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); keylen = server_key.length(); CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); LOG(INFO) << "Connection success " << addr->AsString(); ssin >> *opts; *conn_sock = conn; return; } } } /*! * \brief ServerLoopProc The Server loop process. 
* \param sock The socket information * \param addr The socket address information */ void ServerLoopProc(common::TCPSocket sock, common::SockAddr addr) { // Server loop auto env = RPCEnv(); RPCServerLoop(sock.sockfd); LOG(INFO) << "Finish serving " << addr.AsString(); env.CleanUp(); } /*! * \brief GetTimeOutFromOpts Parse and get the timeout option. * \param opts The option string * \param timeout value after parsing. */ int GetTimeOutFromOpts(std::string opts) { std::string cmd; std::string option = "-timeout="; if (opts.find(option) == 0) { cmd = opts.substr(opts.find_last_of(option) + 1); CHECK(common::IsNumber(cmd)) << "Timeout is not valid"; return std::stoi(cmd); } return 0; } std::string host_; int port_; int my_port_; int port_end_; std::string tracker_addr_; std::string key_; std::string custom_addr_; common::TCPSocket listen_sock_; common::TCPSocket tracker_sock_; }; /*! * \brief RPCServerCreate Creates the RPC Server. * \param host The hostname of the server, Default=0.0.0.0 * \param port The port of the RPC, Default=9090 * \param port_end The end search port of the RPC, Default=9199 * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. Default=True */ void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr, std::string key, std::string custom_addr, bool silent) { if (silent) { // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } // Start the rpc server RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr); rpc.Start(); } TVM_REGISTER_GLOBAL("rpc._ServerCreate") .set_body([](TVMArgs args, TVMRetValue* rv) { RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); }); } // namespace runtime } // namespace tvm