/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file socket.h
 * \brief this file aims to provide a wrapper of sockets
 * \author Tianqi Chen
 */
#ifndef TVM_COMMON_SOCKET_H_
#define TVM_COMMON_SOCKET_H_

#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
using ssize_t = int;
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
#endif
#else
#include <fcntl.h>
#include <netdb.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#endif
#include <dmlc/logging.h>
#include <string>
#include <cstring>


namespace tvm {
namespace common {
/*!
 * \brief Query the name of the machine this process runs on.
 * \return The host name, cut at the first NUL byte of the buffer
 *         filled in by gethostname.
 */
inline std::string GetHostName() {
  // Fixed capacity handed to gethostname; failure of the call is fatal.
  constexpr int kNameCapacity = 256;
  std::string name(kNameCapacity, '\0');
  CHECK_NE(gethostname(&name[0], kNameCapacity), -1);
  // Drop the unused tail so only the NUL-terminated prefix is returned.
  name.resize(std::strlen(name.c_str()));
  return name;
}

/*!
 * \brief Common data structure for network address.
 *
 *  Holds one IPv4 or IPv6 socket address (including its port) inside a
 *  sockaddr_storage; addr.ss_family tells which of the two layouts is used.
 */
struct SockAddr {
  /*! \brief raw address storage, interpreted according to addr.ss_family */
  sockaddr_storage addr;
  /*! \brief construct an empty address with every byte set to zero */
  SockAddr() {
    // Zero the storage so port()/ss_family()/AsString() never read
    // indeterminate bytes on an address that was not Set() yet.
    std::memset(&addr, 0, sizeof(addr));
  }
  /*!
   * \brief construct address by host and port
   * \param url The host name or numeric address string to resolve
   * \param port The port of the address.
   */
  SockAddr(const char *url, int port) {
    std::memset(&addr, 0, sizeof(addr));
    this->Set(url, port);
  }
  /*!
   * \brief set the address
   *
   *  Resolves host with getaddrinfo and stores the first returned entry.
   *  Only AF_INET and AF_INET6 results are accepted; anything else, or a
   *  failed lookup, is reported through CHECK. The stored address is left
   *  untouched when the result cannot be decoded.
   *
   * \param host the host name or numeric address string to resolve
   * \param port the port of address, in host byte order
   */
  void Set(const char *host, int port) {
    addrinfo hints;
    std::memset(&hints, 0, sizeof(hints));
    hints.ai_family = PF_UNSPEC;
    hints.ai_flags = AI_PASSIVE;
    hints.ai_socktype = SOCK_STREAM;
    addrinfo *res = nullptr;
    int sig = getaddrinfo(host, nullptr, &hints, &res);
    CHECK(sig == 0 && res != nullptr)
        << "cannot obtain address of " <<  host;
    const int family = res->ai_family;
    const bool supported = (family == AF_INET || family == AF_INET6);
    if (supported) {
      // Clear first: only ai_addrlen bytes are copied, the rest of the
      // storage must not keep stale data from a previous address.
      std::memset(&addr, 0, sizeof(addr));
      std::memcpy(&addr, res->ai_addr, res->ai_addrlen);
      const unsigned short net_port = htons(static_cast<unsigned short>(port));  // NOLINT(*)
      if (family == AF_INET) {
        sockaddr_in *addr4 = reinterpret_cast<sockaddr_in *>(&addr);
        addr4->sin_port = net_port;
        addr4->sin_family = AF_INET;
      } else {
        sockaddr_in6 *addr6 = reinterpret_cast<sockaddr_in6 *>(&addr);
        addr6->sin6_port = net_port;
        addr6->sin6_family = AF_INET6;
      }
    }
    // Release the result list before the check below, so it is not leaked
    // if the failing CHECK leaves this function (e.g. by throwing,
    // depending on how the logging library is configured).
    freeaddrinfo(res);
    CHECK(supported) << "cannot decode address";
  }
  /*! \brief return port of the address, in host byte order */
  int port() const {
    if (addr.ss_family == AF_INET6) {
      return ntohs(reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_port);
    }
    return ntohs(reinterpret_cast<const sockaddr_in *>(&addr)->sin_port);
  }
  /*! \brief return the ip address family */
  int ss_family() const {
    return addr.ss_family;
  }
  /*!
   * \return a string representation of the address in the form "ip:port";
   *         an address that is neither AF_INET nor AF_INET6 is reported
   *         through CHECK as "illegal address"
   */
  std::string AsString() const {
    // Pointer to the in_addr / in6_addr part expected by inet_ntop.
    const void *sinx_addr = nullptr;
    if (addr.ss_family == AF_INET6) {
      sinx_addr = &reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_addr;
    } else if (addr.ss_family == AF_INET) {
      sinx_addr = &reinterpret_cast<const sockaddr_in *>(&addr)->sin_addr;
    } else {
      CHECK(false) << "illegal address";
    }

    std::string buf(256, '\0');
#ifdef _WIN32
    // Some Windows SDK versions declare the source pointer as non-const.
    const char *s = inet_ntop(addr.ss_family, const_cast<void *>(sinx_addr),
                              &buf[0], buf.length());
#else
    const char *s = inet_ntop(addr.ss_family, sinx_addr,
                              &buf[0], static_cast<socklen_t>(buf.length()));
#endif
    CHECK(s != nullptr) << "cannot decode address";
    return std::string(s) + ":" + std::to_string(port());
  }
};
/*!
 * \brief base class containing common operations of TCP and UDP sockets
 */
class Socket {
 public:
#if defined(_WIN32)
  using sock_size_t = int;
  using SockType = SOCKET;
#else
  using SockType = int;
  using sock_size_t = size_t;
  static constexpr int INVALID_SOCKET = -1;
#endif
  /*! \brief the file descriptor of socket */
  SockType sockfd;
  /*!
   * \brief set this socket to use non-blocking mode
   * \param non_block whether set it to be non-block, if it is false
   *        it will set it back to block mode
   */
  void SetNonBlock(bool non_block) {
#ifdef _WIN32
    u_long mode = non_block ? 1 : 0;
    if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
      Socket::Error("SetNonBlock");
    }
#else
    int flag = fcntl(sockfd, F_GETFL, 0);
    if (flag == -1) {
      Socket::Error("SetNonBlock-1");
    }
    if (non_block) {
      flag |= O_NONBLOCK;
    } else {
      flag &= ~O_NONBLOCK;
    }
    if (fcntl(sockfd, F_SETFL, flag) == -1) {
      Socket::Error("SetNonBlock-2");
    }
#endif
  }
  /*!
   * \brief bind the socket to an address
193
   * \param addr The address to be binded
194 195 196
   */
  void Bind(const SockAddr &addr) {
    if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
eqy committed
197 198
             (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
                                                sizeof(sockaddr_in))) == -1) {
199 200 201 202 203 204 205 206 207 208 209 210 211
      Socket::Error("Bind");
    }
  }
  /*!
   * \brief try bind the socket to host, from start_port to end_port
   * \param start_port starting port number to try
   * \param end_port ending port number to try
   * \return the port successfully bind to, return -1 if failed to bind any port
   */
  inline int TryBindHost(int start_port, int end_port) {
    for (int port = start_port; port < end_port; ++port) {
      SockAddr addr("0.0.0.0", port);
      if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
eqy committed
212 213
               (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
                                                  sizeof(sockaddr_in))) == 0) {
214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350
        return port;
      }
#if defined(_WIN32)
      if (WSAGetLastError() != WSAEADDRINUSE) {
        Socket::Error("TryBindHost");
      }
#else
      if (errno != EADDRINUSE) {
        Socket::Error("TryBindHost");
      }
#endif
    }
    return -1;
  }
  /*! \brief get last error code if any */
  int GetSockError() const {
    int error = 0;
    socklen_t len = sizeof(error);
    if (getsockopt(sockfd,  SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) {
      Error("GetSockError");
    }
    return error;
  }
  /*! \brief check if anything bad happens */
  bool BadSocket() const {
    if (IsClosed()) return true;
    int err = GetSockError();
    if (err == EBADF || err == EINTR) return true;
    return false;
  }
  /*! \brief check if socket is already closed */
  bool IsClosed() const {
    return sockfd == INVALID_SOCKET;
  }
  /*! \brief close the socket */
  void Close() {
    if (sockfd != INVALID_SOCKET) {
#ifdef _WIN32
      closesocket(sockfd);
#else
      close(sockfd);
#endif
      sockfd = INVALID_SOCKET;
    } else {
      Error("Socket::Close double close the socket or close without create");
    }
  }
  /*!
   * \return last error of socket 2operation
   */
  static int GetLastError() {
#ifdef _WIN32
    return WSAGetLastError();
#else
    return errno;
#endif
  }
  /*! \return whether last error was would block */
  static bool LastErrorWouldBlock() {
    int errsv = GetLastError();
#ifdef _WIN32
    return errsv == WSAEWOULDBLOCK;
#else
    return errsv == EAGAIN || errsv == EWOULDBLOCK;
#endif
  }
  /*!
   * \brief start up the socket module
   *   call this before using the sockets
   */
  static void Startup() {
#ifdef _WIN32
    WSADATA wsa_data;
    if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
      Socket::Error("Startup");
    }
    if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
      WSACleanup();
      LOG(FATAL) << "Could not find a usable version of Winsock.dll";
    }
#endif
  }
  /*!
   * \brief shutdown the socket module after use, all sockets need to be closed
   */
  static void Finalize() {
#ifdef _WIN32
    WSACleanup();
#endif
  }
  /*!
   * \brief Report an socket error.
   * \param msg The error message.
   */
  static void Error(const char *msg) {
    int errsv = GetLastError();
#ifdef _WIN32
    LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
#else
    LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv);
#endif
  }

 protected:
  explicit Socket(SockType sockfd) : sockfd(sockfd) {
  }
};

/*!
 * \brief a wrapper of TCP socket that hopefully be cross platform
 */
class TCPSocket : public Socket {
 public:
  TCPSocket() : Socket(INVALID_SOCKET) {
  }
  /*!
   * \brief construct a TCP socket from existing descriptor
   * \param sockfd The descriptor
   */
  explicit TCPSocket(SockType sockfd) : Socket(sockfd) {
  }
  /*!
   * \brief enable/disable TCP keepalive
   * \param keepalive whether to set the keep alive option on
   */
  void SetKeepAlive(bool keepalive) {
    int opt = static_cast<int>(keepalive);
    if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
                   reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
      Socket::Error("SetKeepAlive");
    }
  }
  /*!
   * \brief create the socket, call this before using socket
   * \param af domain
   */
  void Create(int af = PF_INET) {
351
    sockfd = socket(af, SOCK_STREAM, 0);
352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375
    if (sockfd == INVALID_SOCKET) {
      Socket::Error("Create");
    }
  }
  /*!
   * \brief perform listen of the socket
   * \param backlog backlog parameter
   */
  void Listen(int backlog = 16) {
    listen(sockfd, backlog);
  }
  /*!
   * \brief get a new connection
   * \return The accepted socket connection.
   */
  TCPSocket Accept() {
    SockType newfd = accept(sockfd, NULL, NULL);
    if (newfd == INVALID_SOCKET) {
      Socket::Error("Accept");
    }
    return TCPSocket(newfd);
  }
  /*!
   * \brief decide whether the socket is at OOB mark
376
   * \return 1 if at mark, 0 if not, -1 if an error occurred
377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394
   */
  int AtMark() const {
#ifdef _WIN32
    unsigned long atmark;  // NOLINT(*)
    if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
#else
    int atmark;
    if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
#endif
    return static_cast<int>(atmark);
  }
  /*!
   * \brief connect to an address
   * \param addr the address to connect to
   * \return whether connect is successful
   */
  bool Connect(const SockAddr &addr) {
    return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
eqy committed
395 396
                   (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) :
                                                      sizeof(sockaddr_in))) == 0;
397 398 399
  }
  /*!
   * \brief send data using the socket
400
   * \param buf_ the pointer to the buffer
401
   * \param len the size of the buffer
402
   * \param flag extra flags
403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424
   * \return size of data actually sent
   *         return -1 if error occurs
   */
  ssize_t Send(const void *buf_, size_t len, int flag = 0) {
    const char *buf = reinterpret_cast<const char*>(buf_);
    return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
  }
  /*!
   * \brief receive data using the socket
   * \param buf_ the pointer to the buffer
   * \param len the size of the buffer
   * \param flags extra flags
   * \return size of data actually received
   *         return -1 if error occurs
   */
  ssize_t Recv(void *buf_, size_t len, int flags = 0) {
    char *buf = reinterpret_cast<char*>(buf_);
    return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
  }
  /*!
   * \brief peform block write that will attempt to send all data out
   *    can still return smaller than request when error occurs
425
   * \param buf_ the pointer to the buffer
426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472
   * \param len the size of the buffer
   * \return size of data actually sent
   */
  size_t SendAll(const void *buf_, size_t len) {
    const char *buf = reinterpret_cast<const char*>(buf_);
    size_t ndone = 0;
    while (ndone <  len) {
      ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
      if (ret == -1) {
        if (LastErrorWouldBlock()) return ndone;
        Socket::Error("SendAll");
      }
      buf += ret;
      ndone += ret;
    }
    return ndone;
  }
  /*!
   * \brief peform block read that will attempt to read all data
   *    can still return smaller than request when error occurs
   * \param buf_ the buffer pointer
   * \param len length of data to recv
   * \return size of data actually sent
   */
  size_t RecvAll(void *buf_, size_t len) {
    char *buf = reinterpret_cast<char*>(buf_);
    size_t ndone = 0;
    while (ndone <  len) {
      ssize_t ret = recv(sockfd, buf,
                         static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
      if (ret == -1) {
        if (LastErrorWouldBlock())  {
          LOG(FATAL) << "would block";
          return ndone;
        }
        Socket::Error("RecvAll");
      }
      if (ret == 0) return ndone;
      buf += ret;
      ndone += ret;
    }
    return ndone;
  }
};
}  // namespace common
}  // namespace tvm
#endif  // TVM_COMMON_SOCKET_H_