7ada996f71680cb0f7b1555a01a6e16814a57a7f
galt
  Sun Aug 4 04:30:55 2019 -0700
Making TCP client connect more robust by trying each IP address returned by getaddrinfo() until one connects, or all have failed.

diff --git src/lib/net.c src/lib/net.c
index 8471eca..db0b56a 100644
--- src/lib/net.c
+++ src/lib/net.c
@@ -10,31 +10,31 @@
 #include <sys/time.h>
 #include <pthread.h>
 #include "errAbort.h"
 #include "hash.h"
 #include "net.h"
 #include "linefile.h"
 #include "base64.h"
 #include "cheapcgi.h"
 #include "https.h"
 #include "sqlNum.h"
 #include "obscure.h"
 
 /* Brought errno in to get more useful error messages */
 extern int errno;
 
-static int netStreamSocket6(struct addrinfo *address)
+static int netStreamSocketFromAddrInfo(struct addrinfo *address)
 /* Create a socket from addrinfo structure.  
  * Complain and return something negative if can't. */
 {
 int sd = socket(address->ai_family, address->ai_socktype, address->ai_protocol);
 if (sd < 0)
     warn("Couldn't make %s socket.", familyToString(address->ai_family));
 
 return sd;
 }
 
 static int setSocketNonBlocking(int sd, boolean set)
 /* Use socket control flags to set O_NONBLOCK if set==TRUE,
  * or clear it if set==FALSE.
  * Return -1 if there are any errors, 0 if successful. */
 {
@@ -84,153 +84,176 @@
 /* Return the result of a - b; this handles wrapping of milliseconds. 
  * result.tv_usec will always be positive. 
  * result.tv_sec will be negative if b > a. */
 {
 // subtract b from a.
 if (a.tv_usec < b.tv_usec)
     {
     a.tv_usec += 1000000;
     a.tv_sec--;
     }
 a.tv_usec -= b.tv_usec;
 a.tv_sec -= b.tv_sec;
 return a;
 }
 
-static int netConnectWithTimeout(char *hostName, int port, long msTimeout)
-/* In order to avoid a very long default timeout (several minutes) for hosts that will
- * not answer the port, we are forced to connect non-blocking.
- * After the connection has been established, we return to blocking mode.
- * Also closes sd if error. */
-{
-int sd;
-struct addrinfo *address=NULL;
-int res;
-fd_set mySet;
-char portStr[8];
-safef(portStr, sizeof portStr, "%d", port);
-
-if (hostName == NULL)
+int netConnectWithTimeoutOneAddr(int sd, struct addrinfo *address, long msTimeout, char *hostName, int port, struct dyString *dy)
+/* Try to connect sd to one address within msTimeout milliseconds. Return 0 on success, -1 on failure; sd is left open for the caller to close. */
 {
-    warn("NULL hostName in netConnect");
-    return -1;
-    }
-if (!internetGetAddrInfo6n4(hostName, portStr, &address))
-    return -1;
-if ((sd = netStreamSocket6(address)) < 0)
-    return sd;
-
 // Set socket to nonblocking so we can manage our own timeout time.
 if (setSocketNonBlocking(sd, TRUE) < 0)
     {
-    close(sd);
     return -1;
     }
 
 // Trying to connect with timeout
+int res;
 res = connect(sd, address->ai_addr, address->ai_addrlen);
+char ipStr[NI_MAXHOST];
+getAddrAsString6n4((struct sockaddr_storage *)address->ai_addr, ipStr, sizeof ipStr);
 if (res < 0)
     {
     if (errno == EINPROGRESS)
 	{
 	struct timeval startTime;
 	gettimeofday(&startTime, NULL);
 	struct timeval remainingTime;
 	remainingTime.tv_sec = (long) (msTimeout/1000);
 	remainingTime.tv_usec = (long) (((msTimeout/1000)-remainingTime.tv_sec)*1000000);
 	while (1) 
 	    {
+	    fd_set mySet;
 	    FD_ZERO(&mySet);
 	    FD_SET(sd, &mySet);
 	    // use tempTime (instead of using remainingTime directly) because on some platforms select() may modify the time val.
 	    struct timeval tempTime = remainingTime;
 	    res = select(sd+1, NULL, &mySet, &mySet, &tempTime);  
 	    if (res < 0) 
 		{
 		if (errno == EINTR)  // Ignore the interrupt 
 		    {
 		    // Subtract the elapsed time from remaining time since some platforms need this.
 		    struct timeval newTime;
 		    gettimeofday(&newTime, NULL);
 		    struct timeval elapsedTime = tvMinus(newTime, startTime);
 		    remainingTime = tvMinus(remainingTime, elapsedTime);
 		    if (remainingTime.tv_sec < 0)  // means our timeout has more than expired
 			{
 			remainingTime.tv_sec = 0;
 			remainingTime.tv_usec = 0;
 			}
 		    startTime = newTime;
 		    }
 		else
 		    {
-		    warn("Error in select() during TCP non-blocking connect %d - %s", errno, strerror(errno));
-		    close(sd);
+		    dyStringPrintf(dy, "Error in select() during TCP non-blocking connect %d - %s\n", errno, strerror(errno));
 		    return -1;
 		    }
 		}
 	    else if (res > 0)
 		{
 		// Socket selected for write when it is ready
 		int valOpt;
 		socklen_t lon;
 		// But check the socket for any errors
 		lon = sizeof(valOpt);
 		if (getsockopt(sd, SOL_SOCKET, SO_ERROR, (void*) (&valOpt), &lon) < 0)
 		    {
 		    warn("Error in getsockopt() %d - %s", errno, strerror(errno));
-                    close(sd);
 		    return -1;
 		    }
 		// Check the value returned...
 		if (valOpt)
 		    {
-                    warn("Error in TCP non-blocking connect() %d - %s. Host %s port %d.", valOpt, strerror(valOpt), hostName, port);
-                    close(sd);
+		    dyStringPrintf(dy, "Error in TCP non-blocking connect() %d - %s. Host %s IP %s port %d.\n", valOpt, strerror(valOpt), hostName, ipStr, port);
 		    return -1;
 		    }
-		break;
+		break;  // OK
 		}
 	    else
 		{
-		warn("TCP non-blocking connect() to %s timed-out in select() after %ld milliseconds - Cancelling!", hostName, msTimeout);
-		close(sd);
+		dyStringPrintf(dy, "TCP non-blocking connect() to %s IP %s timed-out in select() after %ld milliseconds - Cancelling!", hostName, ipStr, msTimeout);
 		return -1;
 		}
 	    }
 	}
     else
 	{
-	warn("TCP non-blocking connect() error %d - %s", errno, strerror(errno));
-	close(sd);
+	dyStringPrintf(dy, "TCP non-blocking connect() error %d - %s", errno, strerror(errno));
 	return -1;
 	}
     }
+return 0; // OK
+}
+
+
+static int netConnectWithTimeout(char *hostName, int port, long msTimeout)
+/* In order to avoid a very long default timeout (several minutes) for hosts that will
+ * not answer the port, we are forced to connect non-blocking. Each address returned for
+ * hostName is tried in turn until one connects. After the connection has been established,
+ * we return to blocking mode. Also closes sd if error. */
+{
+int sd;
+struct addrinfo *addressList=NULL, *address;
+char portStr[8];
+safef(portStr, sizeof portStr, "%d", port);
+
+if (hostName == NULL)
+    {
+    warn("NULL hostName in netConnect");
+    return -1;
+    }
+if (!internetGetAddrInfo6n4(hostName, portStr, &addressList))
+    return -1;
+
+struct dyString *errMsg = newDyString(256);
+for (address = addressList; address; address = address->ai_next)
+    {
+    if ((sd = netStreamSocketFromAddrInfo(address)) < 0)
+	continue;
+
+    if (netConnectWithTimeoutOneAddr(sd, address, msTimeout, hostName, port, errMsg) == 0)
+	break;
+
+    close(sd);
+    }
+freeaddrinfo(addressList);
+
+if (!address) // none of the addresses connected successfully
+    {
+    if (!sameString(errMsg->string, ""))
+	{
+	warn("%s", errMsg->string);
+	}
+    }
+dyStringFree(&errMsg);
+if (!address)
+    return -1;
 
 // Set to blocking mode again
 if (setSocketNonBlocking(sd, FALSE) < 0)
     {
     close(sd);
     return -1;
     }
 
 if (setReadWriteTimeouts(sd, DEFAULTREADWRITETTIMEOUTSEC) < 0)
     {
     close(sd);
     return -1;
     }
 
-freeaddrinfo(address);
 return sd;
 
 }
 
 
 int netConnect(char *hostName, int port)
 /* Start connection with a server. */
 {
 return netConnectWithTimeout(hostName, port, DEFAULTCONNECTTIMEOUTMSEC); // 10 seconds connect timeout
 }
 
 int netMustConnect(char *hostName, int port)
 /* Start connection with server or die. */
 {
 int sd = netConnect(hostName, port);