Commit 0806b69e by llyfacebook Committed by eqy

[RPC] Add the IPV6 support for server side auto tuning (#2462)

* use IPV6 instead of IPV4

* backward compatible

* add error report

* fix linter

* more linter

* fix the python2 api
parent e4b9f986
...@@ -12,7 +12,7 @@ def main(args): ...@@ -12,7 +12,7 @@ def main(args):
"""Main function""" """Main function"""
if args.tracker: if args.tracker:
url, port = args.tracker.split(":") url, port = args.tracker.rsplit(":", 1)
port = int(port) port = int(port)
tracker_addr = (url, port) tracker_addr = (url, port)
if not args.key: if not args.key:
......
...@@ -42,6 +42,11 @@ class TrackerCode(object): ...@@ -42,6 +42,11 @@ class TrackerCode(object):
RPC_SESS_MASK = 128 RPC_SESS_MASK = 128
def get_addr_family(addr):
res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP)
return res[0][0]
def recvall(sock, nbytes): def recvall(sock, nbytes):
"""Receive all nbytes from socket. """Receive all nbytes from socket.
...@@ -142,7 +147,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5): ...@@ -142,7 +147,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
tstart = time.time() tstart = time.time()
while True: while True:
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(get_addr_family(addr), socket.SOCK_STREAM)
sock.connect(addr) sock.connect(addr)
return sock return sock
except socket.error as sock_err: except socket.error as sock_err:
......
...@@ -298,7 +298,8 @@ class ProxyServerHandler(object): ...@@ -298,7 +298,8 @@ class ProxyServerHandler(object):
"""Update information on tracker.""" """Update information on tracker."""
try: try:
if self._tracker_conn is None: if self._tracker_conn is None:
self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._tracker_conn = socket.socket(base.get_addr_family(self._tracker_addr),
socket.SOCK_STREAM)
self._tracker_conn.connect(self._tracker_addr) self._tracker_conn.connect(self._tracker_addr)
self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC)) self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0] magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0]
...@@ -481,7 +482,7 @@ class Proxy(object): ...@@ -481,7 +482,7 @@ class Proxy(object):
tracker_addr=None, tracker_addr=None,
index_page=None, index_page=None,
resource_files=None): resource_files=None):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None self.port = None
for my_port in range(port, port_end): for my_port in range(port, port_end):
try: try:
......
...@@ -201,7 +201,7 @@ def _connect_proxy_loop(addr, key, load_library): ...@@ -201,7 +201,7 @@ def _connect_proxy_loop(addr, key, load_library):
retry_period = 5 retry_period = 5
while True: while True:
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(base.get_addr_family(addr), socket.SOCK_STREAM)
sock.connect(addr) sock.connect(addr)
sock.sendall(struct.pack("<i", base.RPC_MAGIC)) sock.sendall(struct.pack("<i", base.RPC_MAGIC))
sock.sendall(struct.pack("<i", len(key))) sock.sendall(struct.pack("<i", len(key)))
...@@ -334,7 +334,7 @@ class Server(object): ...@@ -334,7 +334,7 @@ class Server(object):
self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid) self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid)
time.sleep(0.5) time.sleep(0.5)
elif not is_proxy: elif not is_proxy:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None self.port = None
for my_port in range(port, port_end): for my_port in range(port, port_end):
try: try:
......
...@@ -366,7 +366,7 @@ class Tracker(object): ...@@ -366,7 +366,7 @@ class Tracker(object):
if silent: if silent:
logger.setLevel(logging.WARN) logger.setLevel(logging.WARN)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None self.port = None
self.stop_key = base.random_key("tracker") self.stop_key = base.random_key("tracker")
for my_port in range(port, port_end): for my_port in range(port, port_end):
...@@ -391,7 +391,7 @@ class Tracker(object): ...@@ -391,7 +391,7 @@ class Tracker(object):
sock.close() sock.close()
def _stop_tracker(self): def _stop_tracker(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
sock.connect((self.host, self.port)) sock.connect((self.host, self.port))
sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC)) sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(sock, 4))[0] magic = struct.unpack("<i", base.recvall(sock, 4))[0]
......
...@@ -45,7 +45,7 @@ inline std::string GetHostName() { ...@@ -45,7 +45,7 @@ inline std::string GetHostName() {
* \brief Common data structure fornetwork address. * \brief Common data structure fornetwork address.
*/ */
struct SockAddr { struct SockAddr {
sockaddr_in addr; sockaddr_storage addr;
SockAddr() {} SockAddr() {}
/*! /*!
* \brief construc address by url and port * \brief construc address by url and port
...@@ -63,30 +63,63 @@ struct SockAddr { ...@@ -63,30 +63,63 @@ struct SockAddr {
void Set(const char *host, int port) { void Set(const char *host, int port) {
addrinfo hints; addrinfo hints;
memset(&hints, 0, sizeof(hints)); memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET; hints.ai_family = PF_UNSPEC;
hints.ai_flags = AI_PASSIVE;
hints.ai_protocol = SOCK_STREAM; hints.ai_protocol = SOCK_STREAM;
addrinfo *res = NULL; addrinfo *res = NULL;
int sig = getaddrinfo(host, NULL, &hints, &res); int sig = getaddrinfo(host, NULL, &hints, &res);
CHECK(sig == 0 && res != NULL) CHECK(sig == 0 && res != NULL)
<< "cannot obtain address of " << host; << "cannot obtain address of " << host;
CHECK(res->ai_family == AF_INET) switch (res->ai_family) {
<< "Does not support IPv6"; case AF_INET: {
memcpy(&addr, res->ai_addr, res->ai_addrlen); sockaddr_in *addr4 = reinterpret_cast<sockaddr_in *>(&addr);
addr.sin_port = htons(port); memcpy(addr4, res->ai_addr, res->ai_addrlen);
addr4->sin_port = htons(port);
addr4->sin_family = AF_INET;
}
break;
case AF_INET6: {
sockaddr_in6 *addr6 = reinterpret_cast<sockaddr_in6 *>(&addr);
memcpy(addr6, res->ai_addr, res->ai_addrlen);
addr6->sin6_port = htons(port);
addr6->sin6_family = AF_INET6;
}
break;
default:
CHECK(false) << "cannot decode address";
}
freeaddrinfo(res); freeaddrinfo(res);
} }
/*! \brief return port of the address */ /*! \brief return port of the address */
int port() const { int port() const {
return ntohs(addr.sin_port); return ntohs((addr.ss_family == AF_INET6)? \
reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_port : \
reinterpret_cast<const sockaddr_in *>(&addr)->sin_port);
}
/*! \brief return the ip address family */
int ss_family() const {
return addr.ss_family;
} }
/*! \return a string representation of the address */ /*! \return a string representation of the address */
std::string AsString() const { std::string AsString() const {
std::string buf; buf.resize(256); std::string buf; buf.resize(256);
const void *sinx_addr = nullptr;
if (addr.ss_family == AF_INET6) {
const in6_addr& addr6 = reinterpret_cast<const sockaddr_in6 *>(&addr)->sin6_addr;
sinx_addr = reinterpret_cast<const void *>(&addr6);
} else if (addr.ss_family == AF_INET) {
const in_addr& addr4 = reinterpret_cast<const sockaddr_in *>(&addr)->sin_addr;
sinx_addr = reinterpret_cast<const void *>(&addr4);
} else {
CHECK(false) << "illegal address";
}
#ifdef _WIN32 #ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, const char *s = inet_ntop(addr.ss_family, sinx_addr,
&buf[0], buf.length()); &buf[0], buf.length());
#else #else
const char *s = inet_ntop(AF_INET, &addr.sin_addr, const char *s = inet_ntop(addr.ss_family, sinx_addr,
&buf[0], static_cast<socklen_t>(buf.length())); &buf[0], static_cast<socklen_t>(buf.length()));
#endif #endif
CHECK(s != nullptr) << "cannot decode address"; CHECK(s != nullptr) << "cannot decode address";
...@@ -294,7 +327,7 @@ class TCPSocket : public Socket { ...@@ -294,7 +327,7 @@ class TCPSocket : public Socket {
* \param af domain * \param af domain
*/ */
void Create(int af = PF_INET) { void Create(int af = PF_INET) {
sockfd = socket(PF_INET, SOCK_STREAM, 0); sockfd = socket(af, SOCK_STREAM, 0);
if (sockfd == INVALID_SOCKET) { if (sockfd == INVALID_SOCKET) {
Socket::Error("Create"); Socket::Error("Create");
} }
......
...@@ -43,7 +43,7 @@ std::shared_ptr<RPCSession> ...@@ -43,7 +43,7 @@ std::shared_ptr<RPCSession>
RPCConnect(std::string url, int port, std::string key) { RPCConnect(std::string url, int port, std::string key) {
common::TCPSocket sock; common::TCPSocket sock;
common::SockAddr addr(url.c_str(), port); common::SockAddr addr(url.c_str(), port);
sock.Create(); sock.Create(addr.ss_family());
CHECK(sock.Connect(addr)) CHECK(sock.Connect(addr))
<< "Connect to " << addr.AsString() << " failed"; << "Connect to " << addr.AsString() << " failed";
// hand shake // hand shake
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment