/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file socket.h * \brief this file aims to provide a wrapper of sockets * \author Tianqi Chen */ #ifndef TVM_SUPPORT_SOCKET_H_ #define TVM_SUPPORT_SOCKET_H_ #if defined(_WIN32) #define NOMINMAX #include <winsock2.h> #include <ws2tcpip.h> #undef NOMINMAX using ssize_t = int; #ifdef _MSC_VER #pragma comment(lib, "Ws2_32.lib") #endif #else #include <fcntl.h> #include <netdb.h> #include <errno.h> #include <unistd.h> #include <arpa/inet.h> #include <netinet/in.h> #include <sys/socket.h> #include <sys/select.h> #include <sys/ioctl.h> #endif #include <dmlc/logging.h> #include <string> #include <cstring> #include <vector> #include <unordered_map> #include "../support/util.h" #if defined(_WIN32) static inline int poll(struct pollfd *pfd, int nfds, int timeout) { return WSAPoll(pfd, nfds, timeout); } #else #include <sys/poll.h> #endif // defined(_WIN32) namespace tvm { namespace support { /*! * \brief Get current host name. * \return The hostname. */ inline std::string GetHostName() { std::string buf; buf.resize(256); CHECK_NE(gethostname(&buf[0], 256), -1); return std::string(buf.c_str()); } /*! * \brief ValidateIP validates an ip address. * \param ip The ip address in string format localhost or x.x.x.x format * \return result of operation. */ inline bool ValidateIP(std::string ip) { if (ip == "localhost") { return true; } struct sockaddr_in sa_ipv4; struct sockaddr_in6 sa_ipv6; bool is_ipv4 = inet_pton(AF_INET, ip.c_str(), &(sa_ipv4.sin_addr)); bool is_ipv6 = inet_pton(AF_INET6, ip.c_str(), &(sa_ipv6.sin6_addr)); return is_ipv4 || is_ipv6; } /*! * \brief Common data structure for network address. */ struct SockAddr { sockaddr_storage addr; SockAddr() {} /*! * \brief construct address by url and port * \param url The url of the address * \param port The port of the address. */ SockAddr(const char *url, int port) { this->Set(url, port); } /*! * \brief SockAddr Get the socket address from tracker. * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) * \return SockAddr parsed from url. */ explicit SockAddr(const std::string &url) { size_t sep = url.find(","); std::string host = url.substr(2, sep - 3); std::string port = url.substr(sep + 1, url.length() - 1); CHECK(ValidateIP(host)) << "Url address is not valid " << url; if (host == "localhost") { host = "127.0.0.1"; } this->Set(host.c_str(), std::stoi(port)); } /*! * \brief set the address * \param host the url of the address * \param port the port of address */ void Set(const char *host, int port) { addrinfo hints; memset(&hints, 0, sizeof(hints)); hints.ai_family = PF_UNSPEC; hints.ai_flags = AI_PASSIVE; hints.ai_socktype = SOCK_STREAM; addrinfo *res = NULL; int sig = getaddrinfo(host, NULL, &hints, &res); CHECK(sig == 0 && res != NULL) << "cannot obtain address of " << host; switch (res->ai_family) { case AF_INET: { sockaddr_in *addr4 = reinterpret_cast<sockaddr_in *>(&addr); memcpy(addr4, res->ai_addr, res->ai_addrlen); addr4->sin_port = htons(port); addr4->sin_family = AF_INET; } break; case AF_INET6: { sockaddr_in6 *addr6 = reinterpret_cast<sockaddr_in6 *>(&addr); memcpy(addr6, res->ai_addr, res->ai_addrlen); addr6->sin6_port = htons(port); addr6->sin6_family = AF_INET6; } break; default: CHECK(false) << "cannot decode address"; } freeaddrinfo(res); } /*! \brief return port of the address */ int port() const { return ntohs((addr.ss_family == AF_INET6)? \ reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_port : \ reinterpret_cast<const sockaddr_in *>(&addr)->sin_port); } /*! \brief return the ip address family */ int ss_family() const { return addr.ss_family; } /*! \return a string representation of the address */ std::string AsString() const { std::string buf; buf.resize(256); const void *sinx_addr = nullptr; if (addr.ss_family == AF_INET6) { const in6_addr& addr6 = reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_addr; sinx_addr = reinterpret_cast<const void *>(&addr6); } else if (addr.ss_family == AF_INET) { const in_addr& addr4 = reinterpret_cast<const sockaddr_in *>(&addr)->sin_addr; sinx_addr = reinterpret_cast<const void *>(&addr4); } else { CHECK(false) << "illegal address"; } #ifdef _WIN32 const char *s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*) &buf[0], buf.length()); #else const char *s = inet_ntop(addr.ss_family, sinx_addr, &buf[0], static_cast<socklen_t>(buf.length())); #endif CHECK(s != nullptr) << "cannot decode address"; std::ostringstream os; os << s << ":" << port(); return os.str(); } }; /*! * \brief base class containing common operations of TCP and UDP sockets */ class Socket { public: #if defined(_WIN32) using sock_size_t = int; using SockType = SOCKET; #else using SockType = int; using sock_size_t = size_t; static constexpr int INVALID_SOCKET = -1; #endif /*! \brief the file descriptor of socket */ SockType sockfd; /*! * \brief set this socket to use non-blocking mode * \param non_block whether set it to be non-block, if it is false * it will set it back to block mode */ void SetNonBlock(bool non_block) { #ifdef _WIN32 u_long mode = non_block ? 1 : 0; if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { Socket::Error("SetNonBlock"); } #else int flag = fcntl(sockfd, F_GETFL, 0); if (flag == -1) { Socket::Error("SetNonBlock-1"); } if (non_block) { flag |= O_NONBLOCK; } else { flag &= ~O_NONBLOCK; } if (fcntl(sockfd, F_SETFL, flag) == -1) { Socket::Error("SetNonBlock-2"); } #endif } /*! * \brief bind the socket to an address * \param addr The address to be binded */ void Bind(const SockAddr &addr) { if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr), (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == -1) { Socket::Error("Bind"); } } /*! * \brief try bind the socket to host, from start_port to end_port * \param host host address to bind the socket * \param start_port starting port number to try * \param end_port ending port number to try * \return the port successfully bind to, return -1 if failed to bind any port */ inline int TryBindHost(std::string host, int start_port, int end_port) { for (int port = start_port; port < end_port; ++port) { SockAddr addr(host.c_str(), port); if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr), (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0) { return port; } else { LOG(WARNING) << "Bind failed to " << host << ":" << port; } #if defined(_WIN32) if (WSAGetLastError() != WSAEADDRINUSE) { Socket::Error("TryBindHost"); } #else if (errno != EADDRINUSE) { Socket::Error("TryBindHost"); } #endif } return -1; } /*! \brief get last error code if any */ int GetSockError() const { int error = 0; socklen_t len = sizeof(error); if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) { Error("GetSockError"); } return error; } /*! \brief check if anything bad happens */ bool BadSocket() const { if (IsClosed()) return true; int err = GetSockError(); if (err == EBADF || err == EINTR) return true; return false; } /*! \brief check if socket is already closed */ bool IsClosed() const { return sockfd == INVALID_SOCKET; } /*! \brief close the socket */ void Close() { if (sockfd != INVALID_SOCKET) { #ifdef _WIN32 closesocket(sockfd); #else close(sockfd); #endif sockfd = INVALID_SOCKET; } else { Error("Socket::Close double close the socket or close without create"); } } /*! * \return last error of socket 2operation */ static int GetLastError() { #ifdef _WIN32 return WSAGetLastError(); #else return errno; #endif } /*! \return whether last error was would block */ static bool LastErrorWouldBlock() { int errsv = GetLastError(); #ifdef _WIN32 return errsv == WSAEWOULDBLOCK; #else return errsv == EAGAIN || errsv == EWOULDBLOCK; #endif } /*! * \brief start up the socket module * call this before using the sockets */ static void Startup() { #ifdef _WIN32 WSADATA wsa_data; if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { Socket::Error("Startup"); } if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { WSACleanup(); LOG(FATAL) << "Could not find a usable version of Winsock.dll"; } #endif } /*! * \brief shutdown the socket module after use, all sockets need to be closed */ static void Finalize() { #ifdef _WIN32 WSACleanup(); #endif } /*! * \brief Report an socket error. * \param msg The error message. */ static void Error(const char *msg) { int errsv = GetLastError(); #ifdef _WIN32 LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv; #else LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv); #endif } protected: explicit Socket(SockType sockfd) : sockfd(sockfd) { } }; /*! * \brief a wrapper of TCP socket that hopefully be cross platform */ class TCPSocket : public Socket { public: TCPSocket() : Socket(INVALID_SOCKET) { } /*! * \brief construct a TCP socket from existing descriptor * \param sockfd The descriptor */ explicit TCPSocket(SockType sockfd) : Socket(sockfd) { } /*! * \brief enable/disable TCP keepalive * \param keepalive whether to set the keep alive option on */ void SetKeepAlive(bool keepalive) { int opt = static_cast<int>(keepalive); if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) { Socket::Error("SetKeepAlive"); } } /*! * \brief create the socket, call this before using socket * \param af domain */ void Create(int af = PF_INET) { sockfd = socket(af, SOCK_STREAM, 0); if (sockfd == INVALID_SOCKET) { Socket::Error("Create"); } } /*! * \brief perform listen of the socket * \param backlog backlog parameter */ void Listen(int backlog = 16) { listen(sockfd, backlog); } /*! * \brief get a new connection * \return The accepted socket connection. */ TCPSocket Accept() { SockType newfd = accept(sockfd, NULL, NULL); if (newfd == INVALID_SOCKET) { Socket::Error("Accept"); } return TCPSocket(newfd); } /*! * \brief get a new connection * \param addr client address from which connection accepted * \return The accepted socket connection. */ TCPSocket Accept(SockAddr *addr) { socklen_t addrlen = sizeof(addr->addr); SockType newfd = accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); if (newfd == INVALID_SOCKET) { Socket::Error("Accept"); } return TCPSocket(newfd); } /*! * \brief decide whether the socket is at OOB mark * \return 1 if at mark, 0 if not, -1 if an error occurred */ int AtMark() const { #ifdef _WIN32 unsigned long atmark; // NOLINT(*) if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; #else int atmark; if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; #endif return static_cast<int>(atmark); } /*! * \brief connect to an address * \param addr the address to connect to * \return whether connect is successful */ bool Connect(const SockAddr &addr) { return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr), (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0; } /*! * \brief send data using the socket * \param buf_ the pointer to the buffer * \param len the size of the buffer * \param flag extra flags * \return size of data actually sent * return -1 if error occurs */ ssize_t Send(const void *buf_, size_t len, int flag = 0) { const char *buf = reinterpret_cast<const char*>(buf_); return send(sockfd, buf, static_cast<sock_size_t>(len), flag); } /*! * \brief receive data using the socket * \param buf_ the pointer to the buffer * \param len the size of the buffer * \param flags extra flags * \return size of data actually received * return -1 if error occurs */ ssize_t Recv(void *buf_, size_t len, int flags = 0) { char *buf = reinterpret_cast<char*>(buf_); return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); } /*! * \brief peform block write that will attempt to send all data out * can still return smaller than request when error occurs * \param buf_ the pointer to the buffer * \param len the size of the buffer * \return size of data actually sent */ size_t SendAll(const void *buf_, size_t len) { const char *buf = reinterpret_cast<const char*>(buf_); size_t ndone = 0; while (ndone < len) { ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); if (ret == -1) { if (LastErrorWouldBlock()) return ndone; Socket::Error("SendAll"); } buf += ret; ndone += ret; } return ndone; } /*! * \brief peform block read that will attempt to read all data * can still return smaller than request when error occurs * \param buf_ the buffer pointer * \param len length of data to recv * \return size of data actually sent */ size_t RecvAll(void *buf_, size_t len) { char *buf = reinterpret_cast<char*>(buf_); size_t ndone = 0; while (ndone < len) { ssize_t ret = recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); if (ret == -1) { if (LastErrorWouldBlock()) { LOG(FATAL) << "would block"; return ndone; } Socket::Error("RecvAll"); } if (ret == 0) return ndone; buf += ret; ndone += ret; } return ndone; } /*! * \brief Send the data to remote. * \param data The data to be sent. */ void SendBytes(std::string data) { int datalen = data.length(); CHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen)); CHECK_EQ(SendAll(data.c_str(), datalen), datalen); } /*! * \brief Receive the data to remote. * \return The data received. */ std::string RecvBytes() { int datalen = 0; CHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen)); std::string data; data.resize(datalen); CHECK_EQ(RecvAll(&data[0], datalen), datalen); return data; } }; /*! \brief helper data structure to perform poll */ struct PollHelper { public: /*! * \brief add file descriptor to watch for read * \param fd file descriptor to be watched */ inline void WatchRead(TCPSocket::SockType fd) { auto& pfd = fds[fd]; pfd.fd = fd; pfd.events |= POLLIN; } /*! * \brief add file descriptor to watch for write * \param fd file descriptor to be watched */ inline void WatchWrite(TCPSocket::SockType fd) { auto& pfd = fds[fd]; pfd.fd = fd; pfd.events |= POLLOUT; } /*! * \brief add file descriptor to watch for exception * \param fd file descriptor to be watched */ inline void WatchException(TCPSocket::SockType fd) { auto& pfd = fds[fd]; pfd.fd = fd; pfd.events |= POLLPRI; } /*! * \brief Check if the descriptor is ready for read * \param fd file descriptor to check status */ inline bool CheckRead(TCPSocket::SockType fd) const { const auto& pfd = fds.find(fd); return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); } /*! * \brief Check if the descriptor is ready for write * \param fd file descriptor to check status */ inline bool CheckWrite(TCPSocket::SockType fd) const { const auto& pfd = fds.find(fd); return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); } /*! * \brief Check if the descriptor has any exception * \param fd file descriptor to check status */ inline bool CheckExcept(TCPSocket::SockType fd) const { const auto& pfd = fds.find(fd); return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0); } /*! * \brief wait for exception event on a single descriptor * \param fd the file descriptor to wait the event for * \param timeout the timeout counter, can be negative, which means wait until the event happen * \return 1 if success, 0 if timeout, and -1 if error occurs */ inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) pollfd pfd; pfd.fd = fd; pfd.events = POLLPRI; return poll(&pfd, 1, timeout); } /*! * \brief peform poll on the set defined, read, write, exception * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block * \return */ inline void Poll(long timeout = -1) { // NOLINT(*) std::vector<pollfd> fdset; fdset.reserve(fds.size()); for (auto kv : fds) { fdset.push_back(kv.second); } int ret = poll(fdset.data(), fdset.size(), timeout); if (ret == -1) { Socket::Error("Poll"); } else { for (auto& pfd : fdset) { auto revents = pfd.revents & pfd.events; if (!revents) { fds.erase(pfd.fd); } else { fds[pfd.fd].events = revents; } } } } std::unordered_map<TCPSocket::SockType, pollfd> fds; }; } // namespace support } // namespace tvm #endif // TVM_SUPPORT_SOCKET_H_