7ada996f71680cb0f7b1555a01a6e16814a57a7f galt Sun Aug 4 04:30:55 2019 -0700 Making TCP client connect more robust by trying each IP address returned by getaddrinfo() until one connects, or all have failed. diff --git src/lib/net.c src/lib/net.c index 8471eca..db0b56a 100644 --- src/lib/net.c +++ src/lib/net.c @@ -10,31 +10,31 @@ #include <sys/time.h> #include <pthread.h> #include "errAbort.h" #include "hash.h" #include "net.h" #include "linefile.h" #include "base64.h" #include "cheapcgi.h" #include "https.h" #include "sqlNum.h" #include "obscure.h" /* Brought errno in to get more useful error messages */ extern int errno; -static int netStreamSocket6(struct addrinfo *address) +static int netStreamSocketFromAddrInfo(struct addrinfo *address) /* Create a socket from addrinfo structure. * Complain and return something negative if can't. */ { int sd = socket(address->ai_family, address->ai_socktype, address->ai_protocol); if (sd < 0) warn("Couldn't make %s socket.", familyToString(address->ai_family)); return sd; } static int setSocketNonBlocking(int sd, boolean set) /* Use socket control flags to set O_NONBLOCK if set==TRUE, * or clear it if set==FALSE. * Return -1 if there are any errors, 0 if successful. */ { @@ -84,153 +84,176 @@ /* Return the result of a - b; this handles wrapping of milliseconds. * result.tv_usec will always be positive. * result.tv_sec will be negative if b > a. */ { // subtract b from a. if (a.tv_usec < b.tv_usec) { a.tv_usec += 1000000; a.tv_sec--; } a.tv_usec -= b.tv_usec; a.tv_sec -= b.tv_sec; return a; } -static int netConnectWithTimeout(char *hostName, int port, long msTimeout) -/* In order to avoid a very long default timeout (several minutes) for hosts that will - * not answer the port, we are forced to connect non-blocking. - * After the connection has been established, we return to blocking mode. - * Also closes sd if error. */ -{ -int sd; -struct addrinfo *address=NULL; -int res; -fd_set mySet; -char portStr[8]; -safef(portStr, sizeof portStr, "%d", port); - -if (hostName == NULL) +int netConnectWithTimeoutOneAddr(int sd, struct addrinfo *address, long msTimeout, char *hostName, int port, struct dyString *dy) +/* Try to connect one address with timeout or return success == 0, failure == -1 */ { - warn("NULL hostName in netConnect"); - return -1; - } -if (!internetGetAddrInfo6n4(hostName, portStr, &address)) - return -1; -if ((sd = netStreamSocket6(address)) < 0) - return sd; - // Set socket to nonblocking so we can manage our own timeout time. if (setSocketNonBlocking(sd, TRUE) < 0) { - close(sd); return -1; } // Trying to connect with timeout +int res; res = connect(sd, address->ai_addr, address->ai_addrlen); +char ipStr[NI_MAXHOST]; +getAddrAsString6n4((struct sockaddr_storage *)address->ai_addr, ipStr, sizeof ipStr); if (res < 0) { if (errno == EINPROGRESS) { struct timeval startTime; gettimeofday(&startTime, NULL); struct timeval remainingTime; remainingTime.tv_sec = (long) (msTimeout/1000); remainingTime.tv_usec = (long) (((msTimeout/1000)-remainingTime.tv_sec)*1000000); while (1) { + fd_set mySet; FD_ZERO(&mySet); FD_SET(sd, &mySet); // use tempTime (instead of using remainingTime directly) because on some platforms select() may modify the time val. struct timeval tempTime = remainingTime; res = select(sd+1, NULL, &mySet, &mySet, &tempTime); if (res < 0) { if (errno == EINTR) // Ignore the interrupt { // Subtract the elapsed time from remaining time since some platforms need this. struct timeval newTime; gettimeofday(&newTime, NULL); struct timeval elapsedTime = tvMinus(newTime, startTime); remainingTime = tvMinus(remainingTime, elapsedTime); if (remainingTime.tv_sec < 0) // means our timeout has more than expired { remainingTime.tv_sec = 0; remainingTime.tv_usec = 0; } startTime = newTime; } else { - warn("Error in select() during TCP non-blocking connect %d - %s", errno, strerror(errno)); - close(sd); + dyStringPrintf(dy, "Error in select() during TCP non-blocking connect %d - %s\n", errno, strerror(errno)); return -1; } } else if (res > 0) { // Socket selected for write when it is ready int valOpt; socklen_t lon; // But check the socket for any errors lon = sizeof(valOpt); if (getsockopt(sd, SOL_SOCKET, SO_ERROR, (void*) (&valOpt), &lon) < 0) { warn("Error in getsockopt() %d - %s", errno, strerror(errno)); - close(sd); return -1; } // Check the value returned... if (valOpt) { - warn("Error in TCP non-blocking connect() %d - %s. Host %s port %d.", valOpt, strerror(valOpt), hostName, port); - close(sd); + dyStringPrintf(dy, "Error in TCP non-blocking connect() %d - %s. Host %s IP %s port %d.\n", valOpt, strerror(valOpt), hostName, ipStr, port); return -1; } - break; + break; // OK } else { - warn("TCP non-blocking connect() to %s timed-out in select() after %ld milliseconds - Cancelling!", hostName, msTimeout); - close(sd); + dyStringPrintf(dy, "TCP non-blocking connect() to %s IP %s timed-out in select() after %ld milliseconds - Cancelling!", hostName, ipStr, msTimeout); return -1; } } } else { - warn("TCP non-blocking connect() error %d - %s", errno, strerror(errno)); - close(sd); + dyStringPrintf(dy, "TCP non-blocking connect() error %d - %s", errno, strerror(errno)); return -1; } } +return 0; // OK +} + + +static int netConnectWithTimeout(char *hostName, int port, long msTimeout) +/* In order to avoid a very long default timeout (several minutes) for hosts that will +* not answer the port, we are forced to connect non-blocking. +* After the connection has been established, we return to blocking mode. +* Also closes sd if error. */ +{ +int sd; +struct addrinfo *addressList=NULL, *address; +char portStr[8]; +safef(portStr, sizeof portStr, "%d", port); + +if (hostName == NULL) + { + warn("NULL hostName in netConnect"); + return -1; + } +if (!internetGetAddrInfo6n4(hostName, portStr, &addressList)) + return -1; + +struct dyString *errMsg = newDyString(256); +for (address = addressList; address; address = address->ai_next) + { + if ((sd = netStreamSocketFromAddrInfo(address)) < 0) + continue; + + if (netConnectWithTimeoutOneAddr(sd, address, msTimeout, hostName, port, errMsg) == 0) + break; + + close(sd); + } +freeaddrinfo(addressList); + +if (!address) // none of the addresses connected successfully + { + if (!sameString(errMsg->string, "")) + { + warn("%s", errMsg->string); + } + } +dyStringFree(&errMsg); +if (!address) + return -1; // Set to blocking mode again if (setSocketNonBlocking(sd, FALSE) < 0) { close(sd); return -1; } if (setReadWriteTimeouts(sd, DEFAULTREADWRITETTIMEOUTSEC) < 0) { close(sd); return -1; } -freeaddrinfo(address); return sd; } int netConnect(char *hostName, int port) /* Start connection with a server. */ { return netConnectWithTimeout(hostName, port, DEFAULTCONNECTTIMEOUTMSEC); // 10 seconds connect timeout } int netMustConnect(char *hostName, int port) /* Start connection with server or die. */ { int sd = netConnect(hostName, port);