diff --git a/src/compat.h b/src/compat.h index 049579c365..7b164d5630 100644 --- a/src/compat.h +++ b/src/compat.h @@ -92,8 +92,15 @@ typedef void* sockopt_arg_type; typedef char* sockopt_arg_type; #endif +// Note these both should work with the current usage of poll, but best to be safe +// WIN32 poll is broken https://daniel.haxx.se/blog/2012/10/10/wsapoll-is-broken/ +// __APPLE__ poll is broke https://github.com/bitcoin/bitcoin/pull/14336#issuecomment-437384408 +#if defined(__linux__) +#define USE_POLL +#endif + bool static inline IsSelectableSocket(const SOCKET& s) { -#ifdef WIN32 +#if defined(USE_POLL) || defined(WIN32) return true; #else return (s < FD_SETSIZE); diff --git a/src/init.cpp b/src/init.cpp index a3a7c5a3bb..18c145a023 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -953,8 +953,13 @@ bool AppInitParameterInteraction() // Trim requested connection counts, to fit into system limitations // in std::min(...) to work around FreeBSD compilation issue described in #2695 - nMaxConnections = std::max(std::min(nMaxConnections, FD_SETSIZE - nBind - MIN_CORE_FILEDESCRIPTORS - MAX_ADDNODE_CONNECTIONS), 0); nFD = RaiseFileDescriptorLimit(nMaxConnections + MIN_CORE_FILEDESCRIPTORS + MAX_ADDNODE_CONNECTIONS); +#ifdef USE_POLL + int fd_max = nFD; +#else + int fd_max = FD_SETSIZE; +#endif + nMaxConnections = std::max(std::min(nMaxConnections, fd_max - nBind - MIN_CORE_FILEDESCRIPTORS - MAX_ADDNODE_CONNECTIONS), 0); if (nFD < MIN_CORE_FILEDESCRIPTORS) return InitError(_("Not enough file descriptors available.")); nMaxConnections = std::min(nFD - MIN_CORE_FILEDESCRIPTORS - MAX_ADDNODE_CONNECTIONS, nMaxConnections); diff --git a/src/net.cpp b/src/net.cpp index e595fb0b0b..a288c6e81e 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -26,6 +26,10 @@ #include #endif +#ifdef USE_POLL +#include +#endif + #ifdef USE_UPNP #include #include @@ -33,6 +37,7 @@ #include #endif +#include #include @@ -71,6 +76,10 @@ enum BindFlags { BF_WHITELIST = (1U << 2), }; +// The set of sockets cannot be modified while waiting +// The sleep time needs to be small to avoid new sockets stalling +static const uint64_t SELECT_TIMEOUT_MILLISECONDS = 50; + const static std::string NET_MESSAGE_COMMAND_OTHER = "*other*"; static const uint64_t RANDOMIZER_ID_NETGROUP = 0x6c0edd8036ef4036ULL; // SHA256("netgroup")[0:8] @@ -1258,28 +1267,10 @@ void CConnman::InactivityCheck(CNode *pnode) } } -void CConnman::SocketHandler() +bool CConnman::GenerateSelectSet(std::set &recv_set, std::set &send_set, std::set &error_set) { - // - // Find which sockets have data to receive - // - struct timeval timeout; - timeout.tv_sec = 0; - timeout.tv_usec = 50000; // frequency to poll pnode->vSend - - fd_set fdsetRecv; - fd_set fdsetSend; - fd_set fdsetError; - FD_ZERO(&fdsetRecv); - FD_ZERO(&fdsetSend); - FD_ZERO(&fdsetError); - SOCKET hSocketMax = 0; - bool have_fds = false; - for (const ListenSocket& hListenSocket : vhListenSocket) { - FD_SET(hListenSocket.socket, &fdsetRecv); - hSocketMax = std::max(hSocketMax, hListenSocket.socket); - have_fds = true; + recv_set.insert(hListenSocket.socket); } { @@ -1308,46 +1299,151 @@ void CConnman::SocketHandler() if (pnode->hSocket == INVALID_SOCKET) continue; - FD_SET(pnode->hSocket, &fdsetError); - hSocketMax = std::max(hSocketMax, pnode->hSocket); - have_fds = true; - + error_set.insert(pnode->hSocket); if (select_send) { - FD_SET(pnode->hSocket, &fdsetSend); + send_set.insert(pnode->hSocket); continue; } if (select_recv) { - FD_SET(pnode->hSocket, &fdsetRecv); + recv_set.insert(pnode->hSocket); } } } - int nSelect = select(have_fds ? hSocketMax + 1 : 0, - &fdsetRecv, &fdsetSend, &fdsetError, &timeout); + return !recv_set.empty() || !send_set.empty() || !error_set.empty(); +} + +#ifdef USE_POLL +void CConnman::SocketEvents(std::set &recv_set, std::set &send_set, std::set &error_set) +{ + std::set recv_select_set, send_select_set, error_select_set; + if (!GenerateSelectSet(recv_select_set, send_select_set, error_select_set)) { + interruptNet.sleep_for(std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS)); + return; + } + + std::unordered_map pollfds; + for (SOCKET socket_id : recv_select_set) { + pollfds[socket_id].fd = socket_id; + pollfds[socket_id].events |= POLLIN; + } + + for (SOCKET socket_id : send_select_set) { + pollfds[socket_id].fd = socket_id; + pollfds[socket_id].events |= POLLOUT; + } + + for (SOCKET socket_id : error_select_set) { + pollfds[socket_id].fd = socket_id; + // These flags are ignored, but we set them for clarity + pollfds[socket_id].events |= POLLERR|POLLHUP; + } + + std::vector vpollfds; + vpollfds.reserve(pollfds.size()); + for (auto it : pollfds) { + vpollfds.push_back(std::move(it.second)); + } + + if (poll(vpollfds.data(), vpollfds.size(), SELECT_TIMEOUT_MILLISECONDS) < 0) return; + + if (interruptNet) return; + + for (struct pollfd pollfd_entry : vpollfds) { + if (pollfd_entry.revents & POLLIN) recv_set.insert(pollfd_entry.fd); + if (pollfd_entry.revents & POLLOUT) send_set.insert(pollfd_entry.fd); + if (pollfd_entry.revents & (POLLERR|POLLHUP)) error_set.insert(pollfd_entry.fd); + } +} +#else +void CConnman::SocketEvents(std::set &recv_set, std::set &send_set, std::set &error_set) +{ + std::set recv_select_set, send_select_set, error_select_set; + if (!GenerateSelectSet(recv_select_set, send_select_set, error_select_set)) { + interruptNet.sleep_for(std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS)); + return; + } + + // + // Find which sockets have data to receive + // + struct timeval timeout; + timeout.tv_sec = 0; + timeout.tv_usec = SELECT_TIMEOUT_MILLISECONDS * 1000; // frequency to poll pnode->vSend + + fd_set fdsetRecv; + fd_set fdsetSend; + fd_set fdsetError; + FD_ZERO(&fdsetRecv); + FD_ZERO(&fdsetSend); + FD_ZERO(&fdsetError); + SOCKET hSocketMax = 0; + + for (SOCKET hSocket : recv_select_set) { + FD_SET(hSocket, &fdsetRecv); + hSocketMax = std::max(hSocketMax, hSocket); + } + + for (SOCKET hSocket : send_select_set) { + FD_SET(hSocket, &fdsetSend); + hSocketMax = std::max(hSocketMax, hSocket); + } + + for (SOCKET hSocket : error_select_set) { + FD_SET(hSocket, &fdsetError); + hSocketMax = std::max(hSocketMax, hSocket); + } + + int nSelect = select(hSocketMax + 1, &fdsetRecv, &fdsetSend, &fdsetError, &timeout); + if (interruptNet) return; if (nSelect == SOCKET_ERROR) { - if (have_fds) - { - int nErr = WSAGetLastError(); - LogPrintf("socket select error %s\n", NetworkErrorString(nErr)); - for (unsigned int i = 0; i <= hSocketMax; i++) - FD_SET(i, &fdsetRecv); - } + int nErr = WSAGetLastError(); + LogPrintf("socket select error %s\n", NetworkErrorString(nErr)); + for (unsigned int i = 0; i <= hSocketMax; i++) + FD_SET(i, &fdsetRecv); FD_ZERO(&fdsetSend); FD_ZERO(&fdsetError); - if (!interruptNet.sleep_for(std::chrono::milliseconds(timeout.tv_usec/1000))) + if (!interruptNet.sleep_for(std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS))) return; } + for (SOCKET hSocket : recv_select_set) { + if (FD_ISSET(hSocket, &fdsetRecv)) { + recv_set.insert(hSocket); + } + } + + for (SOCKET hSocket : send_select_set) { + if (FD_ISSET(hSocket, &fdsetSend)) { + send_set.insert(hSocket); + } + } + + for (SOCKET hSocket : error_select_set) { + if (FD_ISSET(hSocket, &fdsetError)) { + error_set.insert(hSocket); + } + } +} +#endif + +void CConnman::SocketHandler() +{ + std::set recv_set, send_set, error_set; + SocketEvents(recv_set, send_set, error_set); + + if (interruptNet) return; + // // Accept new connections // for (const ListenSocket& hListenSocket : vhListenSocket) { - if (hListenSocket.socket != INVALID_SOCKET && FD_ISSET(hListenSocket.socket, &fdsetRecv)) + if (hListenSocket.socket != INVALID_SOCKET && recv_set.count(hListenSocket.socket) > 0) { AcceptConnection(hListenSocket); } @@ -1378,9 +1474,9 @@ void CConnman::SocketHandler() LOCK(pnode->cs_hSocket); if (pnode->hSocket == INVALID_SOCKET) continue; - recvSet = FD_ISSET(pnode->hSocket, &fdsetRecv); - sendSet = FD_ISSET(pnode->hSocket, &fdsetSend); - errorSet = FD_ISSET(pnode->hSocket, &fdsetError); + recvSet = recv_set.count(pnode->hSocket) > 0; + sendSet = send_set.count(pnode->hSocket) > 0; + errorSet = error_set.count(pnode->hSocket) > 0; } if (recvSet || errorSet) { diff --git a/src/net.h b/src/net.h index 775d0c8099..915a8e5b35 100644 --- a/src/net.h +++ b/src/net.h @@ -346,6 +346,8 @@ private: void DisconnectNodes(); void NotifyNumConnectionsChanged(); void InactivityCheck(CNode *pnode); + bool GenerateSelectSet(std::set &recv_set, std::set &send_set, std::set &error_set); + void SocketEvents(std::set &recv_set, std::set &send_set, std::set &error_set); void SocketHandler(); void ThreadSocketHandler(); void ThreadDNSAddressSeed(); diff --git a/src/netbase.cpp b/src/netbase.cpp index 1c043fc981..355e21d4e6 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -21,6 +21,10 @@ #include #endif +#ifdef USE_POLL +#include +#endif + #if !defined(MSG_NOSIGNAL) #define MSG_NOSIGNAL 0 #endif @@ -264,11 +268,19 @@ static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, c if (!IsSelectableSocket(hSocket)) { return IntrRecvError::NetworkError; } - struct timeval tval = MillisToTimeval(std::min(endTime - curTime, maxWait)); + int timeout_ms = std::min(endTime - curTime, maxWait); +#ifdef USE_POLL + struct pollfd pollfd = {}; + pollfd.fd = hSocket; + pollfd.events = POLLIN | POLLOUT; + int nRet = poll(&pollfd, 1, timeout_ms); +#else + struct timeval tval = MillisToTimeval(timeout_ms); fd_set fdset; FD_ZERO(&fdset); FD_SET(hSocket, &fdset); int nRet = select(hSocket + 1, &fdset, nullptr, nullptr, &tval); +#endif if (nRet == SOCKET_ERROR) { return IntrRecvError::NetworkError; } @@ -499,11 +511,18 @@ bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocket, i // WSAEINVAL is here because some legacy version of winsock uses it if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL) { +#ifdef USE_POLL + struct pollfd pollfd = {}; + pollfd.fd = hSocket; + pollfd.events = POLLIN | POLLOUT; + int nRet = poll(&pollfd, 1, nTimeout); +#else struct timeval timeout = MillisToTimeval(nTimeout); fd_set fdset; FD_ZERO(&fdset); FD_SET(hSocket, &fdset); int nRet = select(hSocket + 1, nullptr, &fdset, nullptr, &timeout); +#endif if (nRet == 0) { LogPrint(BCLog::NET, "connection to %s timeout\n", addrConnect.ToString());