/*!
 *  Copyright (c) 2017 by Contributors
 * \file rpc_socket_impl.cc
 * \brief Socket based RPC implementation.
 */
#include <tvm/runtime/registry.h>
#include <memory>
#include <string>
#include "rpc_session.h"
#include "../../common/socket.h"

namespace tvm {
namespace runtime {

/*!
 * \brief RPCChannel that transfers bytes over a TCP socket.
 *
 * The channel owns the socket it is constructed with and closes it in its
 * destructor (skipped when BadSocket() reports the socket unusable).
 * Copying is disabled so that two channels can never close the same
 * descriptor.
 */
class SockChannel final : public RPCChannel {
 public:
  /*!
   * \brief Wrap a connected socket.
   * \param sock The socket; this channel takes over closing it.
   */
  explicit SockChannel(common::TCPSocket sock) : sock_(sock) {}
  SockChannel(const SockChannel&) = delete;
  SockChannel& operator=(const SockChannel&) = delete;
  ~SockChannel() {
    if (!sock_.BadSocket()) {
      sock_.Close();
    }
  }
  /*!
   * \brief Send bytes with a single Send call (not SendAll), so presumably
   *        fewer than size bytes may be written.
   * \param data Buffer to send.
   * \param size Number of bytes available in data.
   * \return Number of bytes reported sent by the socket.
   * \note A -1 result is reported through common::Socket::Error.
   *       NOTE(review): assumed not to return; confirm in common/socket.h.
   */
  size_t Send(const void* data, size_t size) final {
    ssize_t n = sock_.Send(data, size);
    if (n == -1) {
      common::Socket::Error("SockChannel::Send");
    }
    return static_cast<size_t>(n);
  }
  /*!
   * \brief Receive bytes with a single Recv call (not RecvAll), so presumably
   *        fewer than size bytes may be read.
   * \param data Destination buffer.
   * \param size Capacity of data in bytes.
   * \return Number of bytes reported received by the socket.
   * \note A -1 result is reported through common::Socket::Error.
   *       NOTE(review): assumed not to return; confirm in common/socket.h.
   */
  size_t Recv(void* data, size_t size) final {
    ssize_t n = sock_.Recv(data, size);
    if (n == -1) {
      common::Socket::Error("SockChannel::Recv");
    }
    return static_cast<size_t>(n);
  }

 private:
  /*! \brief The owned socket. */
  common::TCPSocket sock_;
};

/*!
 * \brief Run the client side of the RPC handshake on a connected socket.
 *
 * Sends kRPCMagic, the key length and the key bytes (ints are written as raw
 * native-order bytes), then reads back a status code, the remote key length
 * and the remote key bytes.
 * Status kRPCMagic + 2 means no server matched the key, kRPCMagic + 1 means
 * the key is already taken, and anything other than kRPCMagic means the peer
 * is not a TVM RPC server; each is reported with LOG(FATAL).
 *
 * This helper never closes the socket; the caller is responsible for that.
 *
 * \param sock Connected socket.
 * \param url Server address, used only in error messages.
 * \param port Server port, used only in error messages.
 * \param key Key sent to the server.
 * \return The key reported by the remote side (may be empty).
 */
static std::string RPCClientHandshake(common::TCPSocket* sock,
                                      const std::string& url,
                                      int port,
                                      const std::string& key) {
  int code = kRPCMagic;
  int keylen = static_cast<int>(key.length());
  CHECK_EQ(sock->SendAll(&code, sizeof(code)), sizeof(code));
  CHECK_EQ(sock->SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
  if (keylen != 0) {
    CHECK_EQ(sock->SendAll(key.c_str(), keylen),
             static_cast<size_t>(keylen));
  }
  CHECK_EQ(sock->RecvAll(&code, sizeof(code)), sizeof(code));
  if (code == kRPCMagic + 2) {
    LOG(FATAL) << "URL " << url << ":" << port
               << " cannot find server that matches key=" << key;
  } else if (code == kRPCMagic + 1) {
    LOG(FATAL) << "URL " << url << ":" << port
               << " server already have key=" << key;
  } else if (code != kRPCMagic) {
    LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
  }
  CHECK_EQ(sock->RecvAll(&keylen, sizeof(keylen)), sizeof(keylen));
  // The length comes from the network: a negative value would otherwise be
  // converted to an enormous size_t in resize()/RecvAll().
  CHECK_GE(keylen, 0) << "URL " << url << ":" << port
                      << " sent invalid key length " << keylen;
  std::string remote_key;
  if (keylen != 0) {
    remote_key.resize(keylen);
    CHECK_EQ(sock->RecvAll(&remote_key[0], keylen),
             static_cast<size_t>(keylen));
  }
  return remote_key;
}

/*!
 * \brief Connect to an RPC server and perform the handshake.
 * \param url Server host.
 * \param port Server port.
 * \param key Key sent during the handshake.
 * \return A session over the connected socket; the session's channel owns
 *         the socket.
 */
std::shared_ptr<RPCSession>
RPCConnect(std::string url, int port, std::string key) {
  common::TCPSocket sock;
  common::SockAddr addr(url.c_str(), port);
  sock.Create();
  std::string remote_key;
  try {
    CHECK(sock.Connect(addr))
        << "Connect to " << addr.AsString() << " failed";
    remote_key = RPCClientHandshake(&sock, url, port, key);
  } catch (...) {
    // Until the SockChannel below exists nothing else owns the descriptor,
    // so release it before the failure propagates to avoid leaking it.
    if (!sock.BadSocket()) {
      sock.Close();
    }
    throw;
  }
  return RPCSession::Create(
      std::unique_ptr<SockChannel>(new SockChannel(sock)), key, remote_key);
}

/*!
 * \brief Connect as a client and wrap the session in an RPC module.
 * \param url Server host.
 * \param port Server port.
 * \param key Client key; sent to the server prefixed with "client:".
 * \return Module backed by the remote session.
 */
Module RPCClientConnect(std::string url, int port, std::string key) {
  return CreateRPCModule(RPCConnect(url, port, "client:" + key));
}

/*!
 * \brief Serve RPC requests on an already-open socket descriptor.
 * \param sockfd Socket descriptor. It is handed to a SockChannel, which
 *        closes it when destroyed (presumably when the session is released
 *        after ServerLoop() returns).
 */
void RPCServerLoop(int sockfd) {
  common::TCPSocket sock(
      static_cast<common::TCPSocket::SockType>(sockfd));
  RPCSession::Create(
      std::unique_ptr<SockChannel>(new SockChannel(sock)),
      "SockServerLoop", "")->ServerLoop();
}

// rpc._Connect(url, port, key) -> Module
TVM_REGISTER_GLOBAL("rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = RPCClientConnect(args[0], args[1], args[2]);
  });

// rpc._ServerLoop(sockfd): runs the server loop on the given descriptor.
TVM_REGISTER_GLOBAL("rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    RPCServerLoop(args[0]);
  });
}  // namespace runtime
}  // namespace tvm