a0e46a238441cd617f2afe6089aaed3693a115e7 galt Wed Sep 19 14:54:50 2012 -0700 code cleanup suggested by Angie
diff --git src/lib/net.c src/lib/net.c
index e61fc90..ffa7e89 100644
--- src/lib/net.c
+++ src/lib/net.c
@@ -73,218 +73,221 @@
         }
     ++cf;
     }
 return result;
 }
 
 void addConnFailure(char *hostName, int port, char *format, ...)
 /* add a failure to connFailures[]
  * which can save time and avoid more timeouts */
 {
 if (!connFailuresEnabled)
     return;
 char errorString[1024];
 va_list args;
 va_start(args, format);
-vsprintf(errorString, format, args);
+vasafef(errorString, sizeof errorString, format, args);
 va_end(args);
 if (!checkConnFailure(hostName,port,NULL))
     {
     pthread_mutex_lock( &cfMutex );
     if (numConnFailures < MAXCONNFAILURES)
         {
         struct connFailure *cf = connFailures + numConnFailures;
         cf->hostName = cloneString(hostName);
         cf->port = port;
         cf->errorString = cloneString(errorString);
         numConnFailures++;
         }
     pthread_mutex_unlock( &cfMutex );
     }
 }
 
 static int netStreamSocket()
 /* Create a TCP/IP streaming socket. Complain and return something
  * negative if can't */
 {
 int sd = socket(AF_INET, SOCK_STREAM, 0);
 if (sd < 0)
     warn("Couldn't make AF_INET socket.");
 return sd;
 }
 
+static int setSocketNonBlocking(int sd, boolean set)
+/* Use socket control flags to set O_NONBLOCK if set==TRUE, or clear it if set==FALSE.
+ * Return -1 if there are any errors, 0 if successful.
+ * Does not close sd on error; the caller is responsible for that. */
+{
+long fcntlFlags;
+// Set or clear non-blocking
+if ((fcntlFlags = fcntl(sd, F_GETFL, NULL)) < 0)
+    {
+    warn("Error fcntl(..., F_GETFL) (%s)", strerror(errno));
+    return -1;
+    }
+if (set)
+    fcntlFlags |= O_NONBLOCK;
+else
+    fcntlFlags &= (~O_NONBLOCK);
+if (fcntl(sd, F_SETFL, fcntlFlags) < 0)
+    {
+    warn("Error fcntl(..., F_SETFL) (%s)", strerror(errno));
+    return -1;
+    }
+return 0;
+}
+
+static struct timeval tvMinus(struct timeval a, struct timeval b)
+/* Return the result of a - b; this handles wrapping of microseconds.
+ * result.tv_usec will always be non-negative (assuming both inputs have 0 <= tv_usec < 1000000).
+ * result.tv_sec will be negative if b > a. */
+{
+// subtract b from a.
+if (a.tv_usec < b.tv_usec)
+    {
+    a.tv_usec += 1000000;
+    a.tv_sec--;
+    }
+a.tv_usec -= b.tv_usec;
+a.tv_sec -= b.tv_sec;
+return a;
+}
+
 static int netConnectWithTimeout(char *hostName, int port, long msTimeout)
 /* In order to avoid a very long default timeout (several minutes) for hosts that will
  * not answer the port, we are forced to connect non-blocking.
  * After the connection has been established, we return to blocking mode. */
 {
 int sd;
 struct sockaddr_in sai; /* Some system socket info. */
 int res;
 fd_set mySet;
-struct timeval lTime;
-long fcntlFlags;
-struct timeval startTime;
-gettimeofday(&startTime, NULL);
-struct timeval remainingTime;
-remainingTime.tv_sec = (long) (msTimeout/1000);
-remainingTime.tv_usec = (long) (((msTimeout/1000)-remainingTime.tv_sec)*1000000);
 
-char *errorString = NULL;
-if (checkConnFailure(hostName, port, &errorString))
-    {
-    warn(errorString);
-    return -1;
-    }
-
 if (hostName == NULL)
     {
     warn("NULL hostName in netConnect");
     return -1;
     }
+
+char *errorString = NULL;
+if (checkConnFailure(hostName, port, &errorString))
+    {
+    warn("%s", errorString);
+    return -1;
+    }
 if (!internetFillInAddress(hostName, port, &sai))
     return -1;
 if ((sd = netStreamSocket()) < 0)
     return sd;
 
-// Set non-blocking
-if ((fcntlFlags = fcntl(sd, F_GETFL, NULL)) < 0)
+// Set socket to nonblocking so we can manage our own timeout time.
+if (setSocketNonBlocking(sd, TRUE) < 0)
     {
-    warn("Error fcntl(..., F_GETFL) (%s)", strerror(errno));
-    close(sd);
-    return -1;
-    }
-fcntlFlags |= O_NONBLOCK;
-if (fcntl(sd, F_SETFL, fcntlFlags) < 0)
-    {
-    warn("Error fcntl(..., F_SETFL) (%s)", strerror(errno));
     close(sd);
     return -1;
     }
 
 // Trying to connect with timeout
 res = connect(sd, (struct sockaddr*) &sai, sizeof(sai));
 if (res < 0)
     {
     if (errno == EINPROGRESS)
         {
+        struct timeval startTime;
+        gettimeofday(&startTime, NULL);
+        struct timeval remainingTime;
+        remainingTime.tv_sec = (long) (msTimeout/1000);
+        remainingTime.tv_usec = (long) ((msTimeout % 1000) * 1000);
         while (1)
             {
-            lTime.tv_sec = remainingTime.tv_sec;
-            lTime.tv_usec = remainingTime.tv_usec;
             FD_ZERO(&mySet);
             FD_SET(sd, &mySet);
-            res = select(sd+1, NULL, &mySet, &mySet, &lTime); // some platforms may modify lTime.
+            // use tempTime (instead of using remainingTime directly) because on some platforms select() may modify the time val.
+            struct timeval tempTime = remainingTime;
+            res = select(sd+1, NULL, &mySet, &mySet, &tempTime);
             if (res < 0)
                 {
-                if (errno == EINTR) // ignore the interrupt but subtract the elapsed time from remainingTime since some platforms need this.
+                if (errno == EINTR) // Ignore the interrupt
                     {
+                    // Subtract the elapsed time from remaining time since some platforms need this.
                     struct timeval newTime;
                     gettimeofday(&newTime, NULL);
-                    struct timeval elapsedTime;
-                    // subtract startTime from newTime.
-                    if (newTime.tv_usec < startTime.tv_usec)
-                        {
-                        newTime.tv_usec += 1000000;
-                        newTime.tv_sec--;
-                        }
-                    elapsedTime.tv_usec = newTime.tv_usec - startTime.tv_usec;
-                    elapsedTime.tv_sec = newTime.tv_sec - startTime.tv_sec;
-                    // the elapsedTime should never be negative
-                    // subtract elapsedTime from remainingTime
-                    if (remainingTime.tv_usec < elapsedTime.tv_usec)
-                        {
-                        remainingTime.tv_usec += 1000000;
-                        remainingTime.tv_sec--;
-                        }
-                    remainingTime.tv_usec = remainingTime.tv_usec - elapsedTime.tv_usec;
-                    remainingTime.tv_sec = remainingTime.tv_sec - elapsedTime.tv_sec;
-                    // the remainingTime.tv_usec should never be negative
-                    // the remainingTime.tv_sec may be negative
+                    struct timeval elapsedTime = tvMinus(newTime, startTime);
+                    remainingTime = tvMinus(remainingTime, elapsedTime);
                     if (remainingTime.tv_sec < 0) // means our timeout has more than expired
                         {
                         remainingTime.tv_sec = 0;
                         remainingTime.tv_usec = 0;
                         }
-                    // for the next cycle set start = new
-                    startTime.tv_sec = newTime.tv_sec;
-                    startTime.tv_usec = newTime.tv_usec;
+                    startTime = newTime;
                     }
                 else
                     {
                     warn("Error in select() during TCP non-blocking connect %d - %s", errno, strerror(errno));
                     close(sd);
                     return -1;
                     }
                 }
             else if (res > 0)
                 {
                 // Socket selected for write when it is ready
                 int valOpt;
                 socklen_t lon;
                 // But check the socket for any errors
                 lon = sizeof(valOpt);
                 if (getsockopt(sd, SOL_SOCKET, SO_ERROR, (void*) (&valOpt), &lon) < 0)
                     {
                     warn("Error in getsockopt() %d - %s", errno, strerror(errno));
                     close(sd);
                     return -1;
                     }
                 // Check the value returned...
                 if (valOpt)
                     {
                     warn("Error in TCP non-blocking connect() %d - %s", valOpt, strerror(valOpt));
-                    if (valOpt == 110)
+                    if (valOpt == ETIMEDOUT)
                         addConnFailure(hostName, port, "Error in TCP non-blocking connect() %d - %s", valOpt, strerror(valOpt));
                     close(sd);
                     return -1;
                     }
                 break;
                 }
             else
                 {
                 addConnFailure(hostName, port, "TCP non-blocking connect() to %s timed-out in select() after %ld milliseconds - Cancelling!", hostName, msTimeout);
                 warn("TCP non-blocking connect() to %s timed-out in select() after %ld milliseconds - Cancelling!", hostName, msTimeout);
                 close(sd);
                 return -1;
                 }
             }
         }
     else
         {
         warn("TCP non-blocking connect() error %d - %s", errno, strerror(errno));
         close(sd);
         return -1;
         }
     }
 
 // Set to blocking mode again
-if ((fcntlFlags = fcntl(sd, F_GETFL, NULL)) < 0)
-    {
-    warn("Error fcntl(..., F_GETFL) (%s)", strerror(errno));
-    close(sd);
-    return -1;
-    }
-fcntlFlags &= (~O_NONBLOCK);
-if (fcntl(sd, F_SETFL, fcntlFlags) < 0)
+if (setSocketNonBlocking(sd, FALSE) < 0)
     {
-    warn("Error fcntl(..., F_SETFL) (%s)", strerror(errno));
     close(sd);
     return -1;
     }
 
 return sd;
 
 }
 
 int netConnect(char *hostName, int port)
 /* Start connection with a server. */
 {
 return netConnectWithTimeout(hostName, port, DEFAULTCONNECTTIMEOUTMSEC); // 10 seconds connect timeout
 }