/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file socket.h
 * \brief Cross-platform (POSIX / Winsock) wrappers around TCP sockets,
 *        socket addresses and poll().
 * \author Tianqi Chen
 */
#ifndef TVM_COMMON_SOCKET_H_
#define TVM_COMMON_SOCKET_H_

#if defined(_WIN32)
#define NOMINMAX
#include <winsock2.h>
#include <ws2tcpip.h>
#undef NOMINMAX
using ssize_t = int;
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
#endif
#else
#include <fcntl.h>
#include <netdb.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/ioctl.h>
#endif
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <dmlc/logging.h>

#include "../common/util.h"

#if defined(_WIN32)
/*!
 * \brief POSIX-style poll() for Windows, forwarding to WSAPoll.
 * \param pfd Array of descriptors to watch.
 * \param nfds Number of entries in pfd.
 * \param timeout Timeout in milliseconds; negative waits indefinitely.
 * \return Whatever WSAPoll returns for the same arguments.
 */
static inline int poll(struct pollfd *pfd, int nfds,
                       int timeout) {
  return WSAPoll(pfd, nfds, timeout);
}
#else
// POSIX specifies <poll.h>; <sys/poll.h> is a non-standard alias that some
// libcs (e.g. musl) only provide with a deprecation warning.
#include <poll.h>
#endif  // defined(_WIN32)

namespace tvm {
namespace common {
/*!
 * \brief Look up the name of the machine this process runs on.
 * \return The host name as written by gethostname() into a 256-byte buffer.
 */
inline std::string GetHostName() {
  std::string buf(256, '\0');
  CHECK_NE(gethostname(&buf[0], 256), -1);
  // The buffer starts out all-NUL, so the name ends at the first NUL byte
  // (or spans the whole buffer if none is left).
  return buf.substr(0, buf.find('\0'));
}

/*!
 * \brief Check whether a string names an acceptable IP address.
 * \param ip Either the literal "localhost" or a textual IPv4 (x.x.x.x) or
 *        IPv6 address.
 * \return true if ip is "localhost" or parses as IPv4 or IPv6, false otherwise.
 */
inline bool ValidateIP(std::string ip) {
  if (ip == "localhost") return true;
  const char *text = ip.c_str();
  // inet_pton only writes into these scratch buffers; the parsed value is unused.
  in_addr v4;
  if (inet_pton(AF_INET, text, &v4) != 0) return true;
  in6_addr v6;
  return inet_pton(AF_INET6, text, &v6) != 0;
}

/*!
93
 * \brief Common data structure for network address.
94 95
 */
struct SockAddr {
96
  sockaddr_storage addr;
97 98
  SockAddr() {}
  /*!
99
   * \brief construct address by url and port
100 101 102 103 104 105
   * \param url The url of the address
   * \param port The port of the address.
   */
  SockAddr(const char *url, int port) {
    this->Set(url, port);
  }
106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122

  /*!
  * \brief SockAddr Get the socket address from tracker.
  * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090)
  * \return SockAddr parsed from url.
  */
  explicit SockAddr(const std::string &url) {
    size_t sep = url.find(",");
    std::string host = url.substr(2, sep - 3);
    std::string port = url.substr(sep + 1, url.length() - 1);
    CHECK(ValidateIP(host)) << "Url address is not valid " << url;
    if (host == "localhost") {
      host = "127.0.0.1";
    }
    this->Set(host.c_str(), std::stoi(port));
  }

123 124
  /*!
   * \brief set the address
125
   * \param host the url of the address
126 127 128 129 130
   * \param port the port of address
   */
  void Set(const char *host, int port) {
    addrinfo hints;
    memset(&hints, 0, sizeof(hints));
131 132
    hints.ai_family = PF_UNSPEC;
    hints.ai_flags = AI_PASSIVE;
133
    hints.ai_socktype = SOCK_STREAM;
134 135 136 137
    addrinfo *res = NULL;
    int sig = getaddrinfo(host, NULL, &hints, &res);
    CHECK(sig == 0 && res != NULL)
        << "cannot obtain address of " <<  host;
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
    switch (res->ai_family) {
      case AF_INET: {
          sockaddr_in *addr4 = reinterpret_cast<sockaddr_in *>(&addr);
          memcpy(addr4, res->ai_addr, res->ai_addrlen);
          addr4->sin_port = htons(port);
          addr4->sin_family = AF_INET;
        }
        break;
      case AF_INET6: {
          sockaddr_in6 *addr6 = reinterpret_cast<sockaddr_in6 *>(&addr);
          memcpy(addr6, res->ai_addr, res->ai_addrlen);
          addr6->sin6_port = htons(port);
          addr6->sin6_family = AF_INET6;
        }
        break;
      default:
        CHECK(false) << "cannot decode address";
    }
156 157 158 159
    freeaddrinfo(res);
  }
  /*! \brief return port of the address */
  int port() const {
160 161 162 163 164 165 166
    return ntohs((addr.ss_family == AF_INET6)? \
                    reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_port : \
                    reinterpret_cast<const sockaddr_in *>(&addr)->sin_port);
  }
  /*! \brief return the ip address family */
  int ss_family() const {
    return addr.ss_family;
167 168 169 170
  }
  /*! \return a string representation of the address */
  std::string AsString() const {
    std::string buf; buf.resize(256);
171 172 173 174 175 176 177 178 179 180 181 182

  const void *sinx_addr = nullptr;
  if (addr.ss_family == AF_INET6) {
    const in6_addr& addr6 = reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_addr;
    sinx_addr = reinterpret_cast<const void *>(&addr6);
  } else if (addr.ss_family == AF_INET) {
    const in_addr& addr4 = reinterpret_cast<const sockaddr_in *>(&addr)->sin_addr;
    sinx_addr = reinterpret_cast<const void *>(&addr4);
  } else {
    CHECK(false) << "illegal address";
  }

183
#ifdef _WIN32
184
    const char *s = inet_ntop(addr.ss_family, (PVOID)sinx_addr,  // NOLINT(*)
185 186
                              &buf[0], buf.length());
#else
187
    const char *s = inet_ntop(addr.ss_family, sinx_addr,
188
                              &buf[0], static_cast<socklen_t>(buf.length()));
189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238
#endif
    CHECK(s != nullptr) << "cannot decode address";
    std::ostringstream os;
    os << s << ":" << port();
    return os.str();
  }
};
/*!
 * \brief base class containing common operations of TCP and UDP sockets
 */
class Socket {
 public:
#if defined(_WIN32)
  using sock_size_t = int;
  using SockType = SOCKET;
#else
  using SockType = int;
  using sock_size_t = size_t;
  static constexpr int INVALID_SOCKET = -1;
#endif
  /*! \brief the file descriptor of socket */
  SockType sockfd;
  /*!
   * \brief set this socket to use non-blocking mode
   * \param non_block whether set it to be non-block, if it is false
   *        it will set it back to block mode
   */
  void SetNonBlock(bool non_block) {
#ifdef _WIN32
    u_long mode = non_block ? 1 : 0;
    if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
      Socket::Error("SetNonBlock");
    }
#else
    int flag = fcntl(sockfd, F_GETFL, 0);
    if (flag == -1) {
      Socket::Error("SetNonBlock-1");
    }
    if (non_block) {
      flag |= O_NONBLOCK;
    } else {
      flag &= ~O_NONBLOCK;
    }
    if (fcntl(sockfd, F_SETFL, flag) == -1) {
      Socket::Error("SetNonBlock-2");
    }
#endif
  }
  /*!
   * \brief bind the socket to an address
239
   * \param addr The address to be binded
240 241 242
   */
  void Bind(const SockAddr &addr) {
    if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
eqy committed
243 244
             (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
                                                sizeof(sockaddr_in))) == -1) {
245 246 247 248 249
      Socket::Error("Bind");
    }
  }
  /*!
   * \brief try bind the socket to host, from start_port to end_port
250
   * \param host host address to bind the socket
251 252 253 254
   * \param start_port starting port number to try
   * \param end_port ending port number to try
   * \return the port successfully bind to, return -1 if failed to bind any port
   */
255
  inline int TryBindHost(std::string host, int start_port, int end_port) {
256
    for (int port = start_port; port < end_port; ++port) {
257
      SockAddr addr(host.c_str(), port);
258
      if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
eqy committed
259 260
               (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
                                                  sizeof(sockaddr_in))) == 0) {
261
        return port;
262 263
      } else {
        LOG(WARNING) << "Bind failed to " << host << ":" << port;
264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399
      }
#if defined(_WIN32)
      if (WSAGetLastError() != WSAEADDRINUSE) {
        Socket::Error("TryBindHost");
      }
#else
      if (errno != EADDRINUSE) {
        Socket::Error("TryBindHost");
      }
#endif
    }
    return -1;
  }
  /*! \brief get last error code if any */
  int GetSockError() const {
    int error = 0;
    socklen_t len = sizeof(error);
    if (getsockopt(sockfd,  SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) {
      Error("GetSockError");
    }
    return error;
  }
  /*! \brief check if anything bad happens */
  bool BadSocket() const {
    if (IsClosed()) return true;
    int err = GetSockError();
    if (err == EBADF || err == EINTR) return true;
    return false;
  }
  /*! \brief check if socket is already closed */
  bool IsClosed() const {
    return sockfd == INVALID_SOCKET;
  }
  /*! \brief close the socket */
  void Close() {
    if (sockfd != INVALID_SOCKET) {
#ifdef _WIN32
      closesocket(sockfd);
#else
      close(sockfd);
#endif
      sockfd = INVALID_SOCKET;
    } else {
      Error("Socket::Close double close the socket or close without create");
    }
  }
  /*!
   * \return last error of socket 2operation
   */
  static int GetLastError() {
#ifdef _WIN32
    return WSAGetLastError();
#else
    return errno;
#endif
  }
  /*! \return whether last error was would block */
  static bool LastErrorWouldBlock() {
    int errsv = GetLastError();
#ifdef _WIN32
    return errsv == WSAEWOULDBLOCK;
#else
    return errsv == EAGAIN || errsv == EWOULDBLOCK;
#endif
  }
  /*!
   * \brief start up the socket module
   *   call this before using the sockets
   */
  static void Startup() {
#ifdef _WIN32
    WSADATA wsa_data;
    if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
      Socket::Error("Startup");
    }
    if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
      WSACleanup();
      LOG(FATAL) << "Could not find a usable version of Winsock.dll";
    }
#endif
  }
  /*!
   * \brief shutdown the socket module after use, all sockets need to be closed
   */
  static void Finalize() {
#ifdef _WIN32
    WSACleanup();
#endif
  }
  /*!
   * \brief Report an socket error.
   * \param msg The error message.
   */
  static void Error(const char *msg) {
    int errsv = GetLastError();
#ifdef _WIN32
    LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
#else
    LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv);
#endif
  }

 protected:
  explicit Socket(SockType sockfd) : sockfd(sockfd) {
  }
};

/*!
 * \brief a wrapper of TCP socket that hopefully be cross platform
 */
class TCPSocket : public Socket {
 public:
  TCPSocket() : Socket(INVALID_SOCKET) {
  }
  /*!
   * \brief construct a TCP socket from existing descriptor
   * \param sockfd The descriptor
   */
  explicit TCPSocket(SockType sockfd) : Socket(sockfd) {
  }
  /*!
   * \brief enable/disable TCP keepalive
   * \param keepalive whether to set the keep alive option on
   */
  void SetKeepAlive(bool keepalive) {
    int opt = static_cast<int>(keepalive);
    if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
                   reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
      Socket::Error("SetKeepAlive");
    }
  }
  /*!
   * \brief create the socket, call this before using socket
   * \param af domain
   */
  void Create(int af = PF_INET) {
400
    sockfd = socket(af, SOCK_STREAM, 0);
401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423
    if (sockfd == INVALID_SOCKET) {
      Socket::Error("Create");
    }
  }
  /*!
   * \brief perform listen of the socket
   * \param backlog backlog parameter
   */
  void Listen(int backlog = 16) {
    listen(sockfd, backlog);
  }
  /*!
   * \brief get a new connection
   * \return The accepted socket connection.
   */
  TCPSocket Accept() {
    SockType newfd = accept(sockfd, NULL, NULL);
    if (newfd == INVALID_SOCKET) {
      Socket::Error("Accept");
    }
    return TCPSocket(newfd);
  }
  /*!
424 425 426 427 428 429 430 431 432 433 434 435 436 437
  * \brief get a new connection
  * \param addr client address from which connection accepted
  * \return The accepted socket connection.
  */
  TCPSocket Accept(SockAddr *addr) {
    socklen_t addrlen = sizeof(addr->addr);
    SockType newfd = accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr),
                            &addrlen);
    if (newfd == INVALID_SOCKET) {
      Socket::Error("Accept");
    }
    return TCPSocket(newfd);
  }
  /*!
438
   * \brief decide whether the socket is at OOB mark
439
   * \return 1 if at mark, 0 if not, -1 if an error occurred
440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457
   */
  int AtMark() const {
#ifdef _WIN32
    unsigned long atmark;  // NOLINT(*)
    if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
#else
    int atmark;
    if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
#endif
    return static_cast<int>(atmark);
  }
  /*!
   * \brief connect to an address
   * \param addr the address to connect to
   * \return whether connect is successful
   */
  bool Connect(const SockAddr &addr) {
    return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
eqy committed
458 459
                   (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
                                                      sizeof(sockaddr_in))) == 0;
460 461 462
  }
  /*!
   * \brief send data using the socket
463
   * \param buf_ the pointer to the buffer
464
   * \param len the size of the buffer
465
   * \param flag extra flags
466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487
   * \return size of data actually sent
   *         return -1 if error occurs
   */
  ssize_t Send(const void *buf_, size_t len, int flag = 0) {
    const char *buf = reinterpret_cast<const char*>(buf_);
    return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
  }
  /*!
   * \brief receive data using the socket
   * \param buf_ the pointer to the buffer
   * \param len the size of the buffer
   * \param flags extra flags
   * \return size of data actually received
   *         return -1 if error occurs
   */
  ssize_t Recv(void *buf_, size_t len, int flags = 0) {
    char *buf = reinterpret_cast<char*>(buf_);
    return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
  }
  /*!
   * \brief peform block write that will attempt to send all data out
   *    can still return smaller than request when error occurs
488
   * \param buf_ the pointer to the buffer
489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531
   * \param len the size of the buffer
   * \return size of data actually sent
   */
  size_t SendAll(const void *buf_, size_t len) {
    const char *buf = reinterpret_cast<const char*>(buf_);
    size_t ndone = 0;
    while (ndone <  len) {
      ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
      if (ret == -1) {
        if (LastErrorWouldBlock()) return ndone;
        Socket::Error("SendAll");
      }
      buf += ret;
      ndone += ret;
    }
    return ndone;
  }
  /*!
   * \brief peform block read that will attempt to read all data
   *    can still return smaller than request when error occurs
   * \param buf_ the buffer pointer
   * \param len length of data to recv
   * \return size of data actually sent
   */
  size_t RecvAll(void *buf_, size_t len) {
    char *buf = reinterpret_cast<char*>(buf_);
    size_t ndone = 0;
    while (ndone <  len) {
      ssize_t ret = recv(sockfd, buf,
                         static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
      if (ret == -1) {
        if (LastErrorWouldBlock())  {
          LOG(FATAL) << "would block";
          return ndone;
        }
        Socket::Error("RecvAll");
      }
      if (ret == 0) return ndone;
      buf += ret;
      ndone += ret;
    }
    return ndone;
  }
532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552
  /*!
   * \brief Send the data to remote.
   * \param data The data to be sent.
   */
  void SendBytes(std::string data) {
    int datalen = data.length();
    CHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen));
    CHECK_EQ(SendAll(data.c_str(), datalen), datalen);
  }
  /*!
   * \brief Receive the data to remote.
   * \return The data received.
   */
  std::string RecvBytes() {
    int datalen = 0;
    CHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen));
    std::string data;
    data.resize(datalen);
    CHECK_EQ(RecvAll(&data[0], datalen), datalen);
    return data;
  }
553
};
554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650

/*! \brief helper data structure to perform poll */
struct PollHelper {
 public:
  /*!
   * \brief add file descriptor to watch for read
   * \param fd file descriptor to be watched
   */
  inline void WatchRead(TCPSocket::SockType fd) {
    auto& pfd = fds[fd];
    pfd.fd = fd;
    pfd.events |= POLLIN;
  }
  /*!
   * \brief add file descriptor to watch for write
   * \param fd file descriptor to be watched
   */
  inline void WatchWrite(TCPSocket::SockType fd) {
    auto& pfd = fds[fd];
    pfd.fd = fd;
    pfd.events |= POLLOUT;
  }
  /*!
   * \brief add file descriptor to watch for exception
   * \param fd file descriptor to be watched
   */
  inline void WatchException(TCPSocket::SockType fd) {
    auto& pfd = fds[fd];
    pfd.fd = fd;
    pfd.events |= POLLPRI;
  }
  /*!
   * \brief Check if the descriptor is ready for read
   * \param fd file descriptor to check status
   * \return true if the last Poll() reported fd as readable
   */
  inline bool CheckRead(TCPSocket::SockType fd) const {
    const auto& pfd = fds.find(fd);
    return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
  }
  /*!
   * \brief Check if the descriptor is ready for write
   * \param fd file descriptor to check status
   * \return true if the last Poll() reported fd as writable
   */
  inline bool CheckWrite(TCPSocket::SockType fd) const {
    const auto& pfd = fds.find(fd);
    return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
  }
  /*!
   * \brief Check if the descriptor has any exception
   * \param fd file descriptor to check status
   * \return true if the last Poll() reported priority data on fd
   */
  inline bool CheckExcept(TCPSocket::SockType fd) const {
    const auto& pfd = fds.find(fd);
    return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
  }
  /*!
   * \brief wait for exception event on a single descriptor
   * \param fd the file descriptor to wait the event for
   * \param timeout timeout in milliseconds; negative means wait until the event happens
   * \return 1 if success, 0 if timeout, and -1 if error occurs
   */
  inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) {  // NOLINT(*)
    pollfd pfd{};
    pfd.fd = fd;
    pfd.events = POLLPRI;
    return poll(&pfd, 1, static_cast<int>(timeout));
  }

  /*!
   * \brief perform poll on the watched set (read, write, exception).
   *   Afterwards each entry's events holds only the requested events that
   *   fired; descriptors with none are removed from the set.
   * \param timeout timeout in milliseconds; if negative, poll blocks
   */
  inline void Poll(long timeout = -1) {  // NOLINT(*)
    std::vector<pollfd> fdset;
    fdset.reserve(fds.size());
    for (const auto& kv : fds) {
      fdset.push_back(kv.second);
    }
    int ret = poll(fdset.data(), fdset.size(), static_cast<int>(timeout));
    if (ret == -1) {
      Socket::Error("Poll");
    } else {
      for (const auto& pfd : fdset) {
        auto revents = pfd.revents & pfd.events;
        if (!revents) {
          fds.erase(pfd.fd);
        } else {
          fds[pfd.fd].events = revents;
        }
      }
    }
  }

  /*! \brief watched descriptors, keyed by descriptor */
  std::unordered_map<TCPSocket::SockType, pollfd> fds;
};

}  // namespace common
}  // namespace tvm
#endif  // TVM_COMMON_SOCKET_H_