/*!
 *  Copyright (c) 2017 by Contributors
 * \file rpc_socket_impl.cc
 * \brief Socket based RPC implementation.
 */
#include <tvm/runtime/registry.h>
#include <memory>
#include "rpc_session.h"
#include "../../common/socket.h"

namespace tvm {
namespace runtime {

/*!
 * \brief RPCChannel implementation that transfers bytes over a TCP socket.
 *
 *  The channel owns the socket handle it is given: the destructor closes it
 *  unless the socket reports BadSocket(). Because of that ownership, copying
 *  is disabled so two channel objects can never close the same handle.
 */
class SockChannel final : public RPCChannel {
 public:
  /*!
   * \brief Wrap a socket; the channel closes it on destruction.
   * \param sock The socket used for all Send/Recv calls.
   */
  explicit SockChannel(common::TCPSocket sock)
      : sock_(sock) {}
  // ~SockChannel closes sock_, so a copy would double-close the same handle.
  SockChannel(const SockChannel&) = delete;
  SockChannel& operator=(const SockChannel&) = delete;
  ~SockChannel() {
    if (!sock_.BadSocket()) {
      sock_.Close();
    }
  }
  /*!
   * \brief Send bytes through the socket with a single TCPSocket::Send call.
   * \param data Buffer to send.
   * \param size Number of bytes available in data.
   * \return The byte count reported by TCPSocket::Send (presumably may be a
   *         partial write, unlike SendAll — TODO confirm).
   */
  size_t Send(const void* data, size_t size) final {
    ssize_t n = sock_.Send(data, size);
    if (n == -1) {
      // NOTE(review): assumes Socket::Error reports and does not return;
      // otherwise -1 would be cast to a huge size_t below.
      common::Socket::Error("SockChannel::Send");
    }
    return static_cast<size_t>(n);
  }
  /*!
   * \brief Receive bytes from the socket with a single TCPSocket::Recv call.
   * \param data Destination buffer.
   * \param size Capacity of data in bytes.
   * \return The byte count reported by TCPSocket::Recv.
   */
  size_t Recv(void* data, size_t size) final {
    ssize_t n = sock_.Recv(data, size);
    if (n == -1) {
      // NOTE(review): same assumption as in Send about Socket::Error.
      common::Socket::Error("SockChannel::Recv");
    }
    return static_cast<size_t>(n);
  }

 private:
  /*! \brief The underlying socket, closed in the destructor. */
  common::TCPSocket sock_;
};

Tianqi Chen committed
42 43
std::shared_ptr<RPCSession>
RPCConnect(std::string url, int port, std::string key) {
44 45
  common::TCPSocket sock;
  common::SockAddr addr(url.c_str(), port);
46
  sock.Create(addr.ss_family());
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
  CHECK(sock.Connect(addr))
      << "Connect to " << addr.AsString() << " failed";
  // hand shake
  std::ostringstream os;
  int code = kRPCMagic;
  int keylen = static_cast<int>(key.length());
  CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
  CHECK_EQ(sock.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
  if (keylen != 0) {
    CHECK_EQ(sock.SendAll(key.c_str(), keylen), keylen);
  }
  CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
  if (code == kRPCMagic + 2) {
    sock.Close();
    LOG(FATAL) << "URL " << url << ":" << port
               << " cannot find server that matches key=" << key;
  } else if (code == kRPCMagic + 1) {
    sock.Close();
    LOG(FATAL) << "URL " << url << ":" << port
Tianqi Chen committed
66
               << " server already have key=" << key;
67 68 69 70
  } else if (code != kRPCMagic) {
    sock.Close();
    LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
  }
71 72 73 74 75 76 77 78
  CHECK_EQ(sock.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen));
  std::string remote_key;
  if (keylen != 0) {
    remote_key.resize(keylen);
    CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen);
  }
  return RPCSession::Create(
      std::unique_ptr<SockChannel>(new SockChannel(sock)), key, remote_key);
Tianqi Chen committed
79 80 81 82
}

/*!
 * \brief Open a client RPC session and expose it as a Module.
 * \param url Host address of the server.
 * \param port Port of the server.
 * \param key Client key; it is sent with a "client:" prefix.
 * \return The Module created by CreateRPCModule for the new session.
 */
Module RPCClientConnect(std::string url, int port, std::string key) {
  const std::string client_key = "client:" + key;
  auto session = RPCConnect(url, port, client_key);
  return CreateRPCModule(session);
}

void RPCServerLoop(int sockfd) {
  common::TCPSocket sock(
      static_cast<common::TCPSocket::SockType>(sockfd));
  RPCSession::Create(
      std::unique_ptr<SockChannel>(new SockChannel(sock)),
90
      "SockServerLoop", "")->ServerLoop();
91 92
}

93
TVM_REGISTER_GLOBAL("rpc._Connect")
94
.set_body([](TVMArgs args, TVMRetValue* rv) {
Tianqi Chen committed
95
    *rv = RPCClientConnect(args[0], args[1], args[2]);
96 97
  });

98
TVM_REGISTER_GLOBAL("rpc._ServerLoop")
99 100 101 102 103
.set_body([](TVMArgs args, TVMRetValue* rv) {
    RPCServerLoop(args[0]);
  });
}  // namespace runtime
}  // namespace tvm