/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file rpc_socket_impl.cc
 * \brief Socket-based RPC implementation: the TCP channel, the client
 *        handshake, and the server loop entry points.
 */
#include <tvm/runtime/registry.h>
#include <memory>
#include <string>
#include <utility>
#include "rpc_session.h"
#include "../../support/socket.h"

namespace tvm {
namespace runtime {

class SockChannel final : public RPCChannel {
 public:
34
  explicit SockChannel(support::TCPSocket sock)
35 36 37
      : sock_(sock) {}
  ~SockChannel() {
    if (!sock_.BadSocket()) {
38
      sock_.Close();
39 40 41 42 43
    }
  }
  size_t Send(const void* data, size_t size) final {
    ssize_t n = sock_.Send(data, size);
    if (n == -1) {
44
      support::Socket::Error("SockChannel::Send");
45 46 47 48 49 50
    }
    return static_cast<size_t>(n);
  }
  size_t Recv(void* data, size_t size) final {
    ssize_t n = sock_.Recv(data, size);
    if (n == -1) {
51
      support::Socket::Error("SockChannel::Recv");
52 53 54 55 56
    }
    return static_cast<size_t>(n);
  }

 private:
57
  support::TCPSocket sock_;
58 59
};

std::shared_ptr<RPCSession>
RPCConnect(std::string url, int port, std::string key) {
62 63
  support::TCPSocket sock;
  support::SockAddr addr(url.c_str(), port);
64
  sock.Create(addr.ss_family());
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
  CHECK(sock.Connect(addr))
      << "Connect to " << addr.AsString() << " failed";
  // hand shake
  std::ostringstream os;
  int code = kRPCMagic;
  int keylen = static_cast<int>(key.length());
  CHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code));
  CHECK_EQ(sock.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
  if (keylen != 0) {
    CHECK_EQ(sock.SendAll(key.c_str(), keylen), keylen);
  }
  CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code));
  if (code == kRPCMagic + 2) {
    sock.Close();
    LOG(FATAL) << "URL " << url << ":" << port
               << " cannot find server that matches key=" << key;
  } else if (code == kRPCMagic + 1) {
    sock.Close();
    LOG(FATAL) << "URL " << url << ":" << port
Tianqi Chen committed
84
               << " server already have key=" << key;
85 86 87 88
  } else if (code != kRPCMagic) {
    sock.Close();
    LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server";
  }
89 90 91 92 93 94 95 96
  CHECK_EQ(sock.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen));
  std::string remote_key;
  if (keylen != 0) {
    remote_key.resize(keylen);
    CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen);
  }
  return RPCSession::Create(
      std::unique_ptr<SockChannel>(new SockChannel(sock)), key, remote_key);
Tianqi Chen committed
97 98 99 100
}

/*!
 * \brief Open a client RPC session and wrap it as a Module.
 *
 * The key is announced to the server with a "client:" prefix.
 * \param url Host of the RPC server.
 * \param port Port of the RPC server.
 * \param key Client key, without the "client:" prefix.
 */
Module RPCClientConnect(std::string url, int port, std::string key) {
  const std::string prefixed_key = "client:" + key;
  auto session = RPCConnect(url, port, prefixed_key);
  return CreateRPCModule(session);
}

void RPCServerLoop(int sockfd) {
104 105
  support::TCPSocket sock(
      static_cast<support::TCPSocket::SockType>(sockfd));
106 107
  RPCSession::Create(
      std::unique_ptr<SockChannel>(new SockChannel(sock)),
108
      "SockServerLoop", "")->ServerLoop();
109 110
}

/*!
 * \brief Run the RPC server loop over a callback-based channel.
 * \param fsend Callback the channel presumably uses to send bytes.
 * \param frecv Callback the channel presumably uses to receive bytes.
 */
void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) {
  auto session = RPCSession::Create(
      std::unique_ptr<CallbackChannel>(new CallbackChannel(fsend, frecv)),
      "SockServerLoop", "");
  session->ServerLoop();
}

// rpc._Connect(url, port, key) -> Module backed by a client RPC session.
TVM_REGISTER_GLOBAL("rpc._Connect").set_body_typed(RPCClientConnect);

// rpc._ServerLoop(sockfd) serves on a socket descriptor;
// rpc._ServerLoop(fsend, frecv) serves over a pair of callbacks.
TVM_REGISTER_GLOBAL("rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    if (args.size() != 1) {
      CHECK_EQ(args.size(), 2);
      // Explicit conversion operator calls, as in the original, to select
      // the PackedFunc conversion unambiguously.
      tvm::runtime::PackedFunc fsend =
          args[0].operator tvm::runtime::PackedFunc();
      tvm::runtime::PackedFunc frecv =
          args[1].operator tvm::runtime::PackedFunc();
      RPCServerLoop(fsend, frecv);
      return;
    }
    const int sockfd = args[0];
    RPCServerLoop(sockfd);
  });
}  // namespace runtime
}  // namespace tvm