From aa17a44551c03b00a47854438afe9f2f89b6ea74 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Mon, 4 Jan 2021 13:14:32 +0100 Subject: [PATCH 1/6] net: move MillisToTimeval() from netbase to util/time Move `MillisToTimeval()` from `netbase.{h,cpp}` to `src/util/system.{h,cpp}`. This is necessary in order to use `MillisToTimeval()` from a newly introduced `src/util/sock.{h,cpp}` which cannot depend on netbase because netbase will depend on it. --- src/netbase.cpp | 9 +-------- src/netbase.h | 4 ---- src/torcontrol.cpp | 1 + src/util/time.cpp | 9 +++++++++ src/util/time.h | 7 +++++++ 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/netbase.cpp b/src/netbase.cpp index 264029d8a2..93c395b9ec 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -271,14 +272,6 @@ CService LookupNumeric(const std::string& name, int portDefault) return addr; } -struct timeval MillisToTimeval(int64_t nTimeout) -{ - struct timeval timeout; - timeout.tv_sec = nTimeout / 1000; - timeout.tv_usec = (nTimeout % 1000) * 1000; - return timeout; -} - /** SOCKS version */ enum SOCKSVersion: uint8_t { SOCKS4 = 0x04, diff --git a/src/netbase.h b/src/netbase.h index ac4cd97673..3dc656d0db 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -62,10 +62,6 @@ bool CloseSocket(SOCKET& hSocket); bool SetSocketNonBlocking(const SOCKET& hSocket, bool fNonBlocking); /** Set the TCP_NODELAY flag on a socket */ bool SetSocketNoDelay(const SOCKET& hSocket); -/** - * Convert milliseconds to a struct timeval for e.g. select. - */ -struct timeval MillisToTimeval(int64_t nTimeout); void InterruptSocks5(bool interrupt); #endif // BITCOIN_NETBASE_H diff --git a/src/torcontrol.cpp b/src/torcontrol.cpp index 90ee9422ba..208794a4e5 100644 --- a/src/torcontrol.cpp +++ b/src/torcontrol.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include diff --git a/src/util/time.cpp b/src/util/time.cpp index e96972fe12..4da041e5a5 100644 --- a/src/util/time.cpp +++ b/src/util/time.cpp @@ -7,6 +7,7 @@ #include #endif +#include #include #include @@ -114,3 +115,11 @@ int64_t ParseISO8601DateTime(const std::string& str) return 0; return (ptime - epoch).total_seconds(); } + +struct timeval MillisToTimeval(int64_t nTimeout) +{ + struct timeval timeout; + timeout.tv_sec = nTimeout / 1000; + timeout.tv_usec = (nTimeout % 1000) * 1000; + return timeout; +} diff --git a/src/util/time.h b/src/util/time.h index c69f604dc6..2c0e3d83f6 100644 --- a/src/util/time.h +++ b/src/util/time.h @@ -6,6 +6,8 @@ #ifndef BITCOIN_UTIL_TIME_H #define BITCOIN_UTIL_TIME_H +#include + #include #include #include @@ -57,4 +59,9 @@ std::string FormatISO8601DateTime(int64_t nTime); std::string FormatISO8601Date(int64_t nTime); int64_t ParseISO8601DateTime(const std::string& str); +/** + * Convert milliseconds to a struct timeval for e.g. select. + */ +struct timeval MillisToTimeval(int64_t nTimeout); + #endif // BITCOIN_UTIL_TIME_H From dec9b5e850c6aad989e814aea5b630b36f55d580 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Mon, 4 Jan 2021 13:02:43 +0100 Subject: [PATCH 2/6] net: move CloseSocket() from netbase to util/sock Move `CloseSocket()` (and `NetworkErrorString()` which it uses) from `netbase.{h,cpp}` to newly added `src/util/sock.{h,cpp}`. This is necessary in order to use `CloseSocket()` from a newly introduced Sock class (which will live in `src/util/sock.{h,cpp}`). `sock.{h,cpp}` cannot depend on netbase because netbase will depend on it. --- src/Makefile.am | 2 ++ src/net.cpp | 1 + src/netbase.cpp | 52 +------------------------------------- src/netbase.h | 4 --- src/util/sock.cpp | 64 +++++++++++++++++++++++++++++++++++++++++++++++ src/util/sock.h | 18 +++++++++++++ 6 files changed, 86 insertions(+), 55 deletions(-) create mode 100644 src/util/sock.cpp create mode 100644 src/util/sock.h diff --git a/src/Makefile.am b/src/Makefile.am index 2871df124c..2e35ecdfbd 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -238,6 +238,7 @@ BITCOIN_CORE_H = \ util/rbf.h \ util/ref.h \ util/settings.h \ + util/sock.h \ util/spanparsing.h \ util/string.h \ util/system.h \ @@ -552,6 +553,7 @@ libbitcoin_util_a_SOURCES = \ util/error.cpp \ util/fees.cpp \ util/hasher.cpp \ + util/sock.cpp \ util/system.cpp \ util/message.cpp \ util/moneystr.cpp \ diff --git a/src/net.cpp b/src/net.cpp index 4f74bbede4..38aaeff121 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include diff --git a/src/netbase.cpp b/src/netbase.cpp index 93c395b9ec..3a3407f901 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -862,57 +863,6 @@ bool LookupSubNet(const std::string& strSubnet, CSubNet& ret) return false; } -#ifdef WIN32 -std::string NetworkErrorString(int err) -{ - wchar_t buf[256]; - buf[0] = 0; - if(FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS | FORMAT_MESSAGE_MAX_WIDTH_MASK, - nullptr, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - buf, ARRAYSIZE(buf), nullptr)) - { - return strprintf("%s (%d)", std::wstring_convert,wchar_t>().to_bytes(buf), err); - } - else - { - return strprintf("Unknown error (%d)", err); - } -} -#else -std::string NetworkErrorString(int err) -{ - char buf[256]; - buf[0] = 0; - /* Too bad there are two incompatible implementations of the - * thread-safe strerror. */ - const char *s; -#ifdef STRERROR_R_CHAR_P /* GNU variant can return a pointer outside the passed buffer */ - s = strerror_r(err, buf, sizeof(buf)); -#else /* POSIX variant always returns message in buffer */ - s = buf; - if (strerror_r(err, buf, sizeof(buf))) - buf[0] = 0; -#endif - return strprintf("%s (%d)", s, err); -} -#endif - -bool CloseSocket(SOCKET& hSocket) -{ - if (hSocket == INVALID_SOCKET) - return false; -#ifdef WIN32 - int ret = closesocket(hSocket); -#else - int ret = close(hSocket); -#endif - if (ret) { - LogPrintf("Socket close failed: %d. Error: %s\n", hSocket, NetworkErrorString(WSAGetLastError())); - } - hSocket = INVALID_SOCKET; - return ret != SOCKET_ERROR; -} - bool SetSocketNonBlocking(const SOCKET& hSocket, bool fNonBlocking) { if (fNonBlocking) { diff --git a/src/netbase.h b/src/netbase.h index 3dc656d0db..38d33e475b 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -54,10 +54,6 @@ bool LookupSubNet(const std::string& strSubnet, CSubNet& subnet); SOCKET CreateSocket(const CService &addrConnect); bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout, bool manual_connection); bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocketRet, int nTimeout, bool& outProxyConnectionFailed); -/** Return readable error string for a network error code */ -std::string NetworkErrorString(int err); -/** Close socket and set hSocket to INVALID_SOCKET */ -bool CloseSocket(SOCKET& hSocket); /** Disable or enable blocking-mode for a socket */ bool SetSocketNonBlocking(const SOCKET& hSocket, bool fNonBlocking); /** Set the TCP_NODELAY flag on a socket */ diff --git a/src/util/sock.cpp b/src/util/sock.cpp new file mode 100644 index 0000000000..35eca4afb1 --- /dev/null +++ b/src/util/sock.cpp @@ -0,0 +1,64 @@ +// Copyright (c) 2020-2021 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include +#include +#include +#include + +#include +#include +#include +#include + +#ifdef WIN32 +std::string NetworkErrorString(int err) +{ + wchar_t buf[256]; + buf[0] = 0; + if(FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS | FORMAT_MESSAGE_MAX_WIDTH_MASK, + nullptr, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + buf, ARRAYSIZE(buf), nullptr)) + { + return strprintf("%s (%d)", std::wstring_convert,wchar_t>().to_bytes(buf), err); + } + else + { + return strprintf("Unknown error (%d)", err); + } +} +#else +std::string NetworkErrorString(int err) +{ + char buf[256]; + buf[0] = 0; + /* Too bad there are two incompatible implementations of the + * thread-safe strerror. */ + const char *s; +#ifdef STRERROR_R_CHAR_P /* GNU variant can return a pointer outside the passed buffer */ + s = strerror_r(err, buf, sizeof(buf)); +#else /* POSIX variant always returns message in buffer */ + s = buf; + if (strerror_r(err, buf, sizeof(buf))) + buf[0] = 0; +#endif + return strprintf("%s (%d)", s, err); +} +#endif + +bool CloseSocket(SOCKET& hSocket) +{ + if (hSocket == INVALID_SOCKET) + return false; +#ifdef WIN32 + int ret = closesocket(hSocket); +#else + int ret = close(hSocket); +#endif + if (ret) { + LogPrintf("Socket close failed: %d. Error: %s\n", hSocket, NetworkErrorString(WSAGetLastError())); + } + hSocket = INVALID_SOCKET; + return ret != SOCKET_ERROR; +} diff --git a/src/util/sock.h b/src/util/sock.h new file mode 100644 index 0000000000..0d48235043 --- /dev/null +++ b/src/util/sock.h @@ -0,0 +1,18 @@ +// Copyright (c) 2020-2021 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_UTIL_SOCK_H +#define BITCOIN_UTIL_SOCK_H + +#include + +#include + +/** Return readable error string for a network error code */ +std::string NetworkErrorString(int err); + +/** Close socket and set hSocket to INVALID_SOCKET */ +bool CloseSocket(SOCKET& hSocket); + +#endif // BITCOIN_UTIL_SOCK_H From ba9d73268f9585d4b9254adcf54708f88222798b Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Wed, 23 Dec 2020 16:40:11 +0100 Subject: [PATCH 3/6] net: add RAII socket and use it instead of bare SOCKET Introduce a class to manage the lifetime of a socket - when the object that contains the socket goes out of scope, the underlying socket will be closed. In addition, the new `Sock` class has a `Send()`, `Recv()` and `Wait()` methods that can be overridden by unit tests to mock the socket operations. The `Wait()` method also hides the `#ifdef USE_POLL poll() #else select() #endif` technique from higher level code. --- src/net.cpp | 47 +++++++++++----------- src/netbase.cpp | 31 +++++++------- src/netbase.h | 17 +++++++- src/util/sock.cpp | 85 +++++++++++++++++++++++++++++++++++++++ src/util/sock.h | 100 ++++++++++++++++++++++++++++++++++++++++++++++ src/util/time.cpp | 5 +++ src/util/time.h | 6 +++ 7 files changed, 250 insertions(+), 41 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 38aaeff121..2a3669b90e 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -429,24 +429,26 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo // Connect bool connected = false; - SOCKET hSocket = INVALID_SOCKET; + std::unique_ptr sock; proxyType proxy; if (addrConnect.IsValid()) { bool proxyConnectionFailed = false; if (GetProxy(addrConnect.GetNetwork(), proxy)) { - hSocket = CreateSocket(proxy.proxy); - if (hSocket == INVALID_SOCKET) { + sock = CreateSock(proxy.proxy); + if (!sock) { return nullptr; } - connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), hSocket, nConnectTimeout, proxyConnectionFailed); + connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), + sock->Get(), nConnectTimeout, proxyConnectionFailed); } else { // no proxy needed (none set for target network) - hSocket = CreateSocket(addrConnect); - if (hSocket == INVALID_SOCKET) { + sock = CreateSock(addrConnect); + if (!sock) { return nullptr; } - connected = ConnectSocketDirectly(addrConnect, hSocket, nConnectTimeout, conn_type == ConnectionType::MANUAL); + connected = ConnectSocketDirectly(addrConnect, sock->Get(), nConnectTimeout, + conn_type == ConnectionType::MANUAL); } if (!proxyConnectionFailed) { // If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to @@ -454,26 +456,26 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo addrman.Attempt(addrConnect, fCountFailure); } } else if (pszDest && GetNameProxy(proxy)) { - hSocket = CreateSocket(proxy.proxy); - if (hSocket == INVALID_SOCKET) { + sock = CreateSock(proxy.proxy); + if (!sock) { return nullptr; } std::string host; int port = default_port; SplitHostPort(std::string(pszDest), port, host); bool proxyConnectionFailed; - connected = ConnectThroughProxy(proxy, host, port, hSocket, nConnectTimeout, proxyConnectionFailed); + connected = ConnectThroughProxy(proxy, host, port, sock->Get(), nConnectTimeout, + proxyConnectionFailed); } if (!connected) { - CloseSocket(hSocket); return nullptr; } // Add node NodeId id = GetNewNodeId(); uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); - CAddress addr_bind = GetBindAddress(hSocket); - CNode* pnode = new CNode(id, nLocalServices, hSocket, addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type); + CAddress addr_bind = GetBindAddress(sock->Get()); + CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type); pnode->AddRef(); // We're making a new connection, harvest entropy from the time (and our peer count) @@ -2177,9 +2179,8 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, return false; } - SOCKET hListenSocket = CreateSocket(addrBind); - if (hListenSocket == INVALID_SOCKET) - { + std::unique_ptr sock = CreateSock(addrBind); + if (!sock) { strError = strprintf(Untranslated("Error: Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError())); LogPrintf("%s\n", strError.original); return false; @@ -2187,21 +2188,21 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, // Allow binding if the port is still in TIME_WAIT state after // the program was closed and restarted. - setsockopt(hListenSocket, SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)); + setsockopt(sock->Get(), SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)); // some systems don't have IPV6_V6ONLY but are always v6only; others do have the option // and enable it by default or not. Try to enable it, if possible. if (addrBind.IsIPv6()) { #ifdef IPV6_V6ONLY - setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)); + setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)); #endif #ifdef WIN32 int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED; - setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); + setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); #endif } - if (::bind(hListenSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) + if (::bind(sock->Get(), (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) { int nErr = WSAGetLastError(); if (nErr == WSAEADDRINUSE) @@ -2209,21 +2210,19 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, else strError = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), addrBind.ToString(), NetworkErrorString(nErr)); LogPrintf("%s\n", strError.original); - CloseSocket(hListenSocket); return false; } LogPrintf("Bound to %s\n", addrBind.ToString()); // Listen for incoming connections - if (listen(hListenSocket, SOMAXCONN) == SOCKET_ERROR) + if (listen(sock->Get(), SOMAXCONN) == SOCKET_ERROR) { strError = strprintf(_("Error: Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError())); LogPrintf("%s\n", strError.original); - CloseSocket(hListenSocket); return false; } - vhListenSocket.push_back(ListenSocket(hListenSocket, permissions)); + vhListenSocket.push_back(ListenSocket(sock->Release(), permissions)); return true; } diff --git a/src/netbase.cpp b/src/netbase.cpp index 3a3407f901..93a04ab5b4 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -15,7 +15,9 @@ #include #include +#include #include +#include #ifndef WIN32 #include @@ -559,34 +561,28 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials return true; } -/** - * Try to create a socket file descriptor with specific properties in the - * communications domain (address family) of the specified service. - * - * For details on the desired properties, see the inline comments in the source - * code. - */ -SOCKET CreateSocket(const CService &addrConnect) +std::unique_ptr CreateSockTCP(const CService& address_family) { // Create a sockaddr from the specified service. struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); - if (!addrConnect.GetSockAddr((struct sockaddr*)&sockaddr, &len)) { - LogPrintf("Cannot create socket for %s: unsupported network\n", addrConnect.ToString()); - return INVALID_SOCKET; + if (!address_family.GetSockAddr((struct sockaddr*)&sockaddr, &len)) { + LogPrintf("Cannot create socket for %s: unsupported network\n", address_family.ToString()); + return nullptr; } // Create a TCP socket in the address family of the specified service. SOCKET hSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP); - if (hSocket == INVALID_SOCKET) - return INVALID_SOCKET; + if (hSocket == INVALID_SOCKET) { + return nullptr; + } // Ensure that waiting for I/O on this socket won't result in undefined // behavior. if (!IsSelectableSocket(hSocket)) { CloseSocket(hSocket); LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n"); - return INVALID_SOCKET; + return nullptr; } #ifdef SO_NOSIGPIPE @@ -602,11 +598,14 @@ SOCKET CreateSocket(const CService &addrConnect) // Set the non-blocking option on the socket. if (!SetSocketNonBlocking(hSocket, true)) { CloseSocket(hSocket); - LogPrintf("CreateSocket: Setting socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError())); + LogPrintf("Error setting socket to non-blocking: %s\n", NetworkErrorString(WSAGetLastError())); + return nullptr; } - return hSocket; + return std::make_unique(hSocket); } +std::function(const CService&)> CreateSock = CreateSockTCP; + template static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) { std::string error_message = tfm::format(fmt, args...); diff --git a/src/netbase.h b/src/netbase.h index 38d33e475b..d906888235 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -12,7 +12,10 @@ #include #include #include +#include +#include +#include #include #include #include @@ -51,7 +54,19 @@ bool Lookup(const std::string& name, CService& addr, int portDefault, bool fAllo bool Lookup(const std::string& name, std::vector& vAddr, int portDefault, bool fAllowLookup, unsigned int nMaxSolutions); CService LookupNumeric(const std::string& name, int portDefault = 0); bool LookupSubNet(const std::string& strSubnet, CSubNet& subnet); -SOCKET CreateSocket(const CService &addrConnect); + +/** + * Create a TCP socket in the given address family. + * @param[in] address_family The socket is created in the same address family as this address. + * @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure + */ +std::unique_ptr CreateSockTCP(const CService& address_family); + +/** + * Socket factory. Defaults to `CreateSockTCP()`, but can be overridden by unit tests. + */ +extern std::function(const CService&)> CreateSock; + bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout, bool manual_connection); bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocketRet, int nTimeout, bool& outProxyConnectionFailed); /** Disable or enable blocking-mode for a socket */ diff --git a/src/util/sock.cpp b/src/util/sock.cpp index 35eca4afb1..4c65b5b680 100644 --- a/src/util/sock.cpp +++ b/src/util/sock.cpp @@ -6,12 +6,97 @@ #include #include #include +#include +#include #include #include #include #include +#ifdef USE_POLL +#include +#endif + +Sock::Sock() : m_socket(INVALID_SOCKET) {} + +Sock::Sock(SOCKET s) : m_socket(s) {} + +Sock::Sock(Sock&& other) +{ + m_socket = other.m_socket; + other.m_socket = INVALID_SOCKET; +} + +Sock::~Sock() { Reset(); } + +Sock& Sock::operator=(Sock&& other) +{ + Reset(); + m_socket = other.m_socket; + other.m_socket = INVALID_SOCKET; + return *this; +} + +SOCKET Sock::Get() const { return m_socket; } + +SOCKET Sock::Release() +{ + const SOCKET s = m_socket; + m_socket = INVALID_SOCKET; + return s; +} + +void Sock::Reset() { CloseSocket(m_socket); } + +ssize_t Sock::Send(const void* data, size_t len, int flags) const +{ + return send(m_socket, static_cast(data), len, flags); +} + +ssize_t Sock::Recv(void* buf, size_t len, int flags) const +{ + return recv(m_socket, static_cast(buf), len, flags); +} + +bool Sock::Wait(std::chrono::milliseconds timeout, Event requested) const +{ +#ifdef USE_POLL + pollfd fd; + fd.fd = m_socket; + fd.events = 0; + if (requested & RECV) { + fd.events |= POLLIN; + } + if (requested & SEND) { + fd.events |= POLLOUT; + } + + return poll(&fd, 1, count_milliseconds(timeout)) != SOCKET_ERROR; +#else + if (!IsSelectableSocket(m_socket)) { + return false; + } + + fd_set fdset_recv; + fd_set fdset_send; + FD_ZERO(&fdset_recv); + FD_ZERO(&fdset_send); + + if (requested & RECV) { + FD_SET(m_socket, &fdset_recv); + } + + if (requested & SEND) { + FD_SET(m_socket, &fdset_send); + } + + timeval timeout_struct = MillisToTimeval(timeout); + + return select(m_socket + 1, &fdset_recv, &fdset_send, nullptr, &timeout_struct) != SOCKET_ERROR; +#endif /* USE_POLL */ +} + #ifdef WIN32 std::string NetworkErrorString(int err) { diff --git a/src/util/sock.h b/src/util/sock.h index 0d48235043..26fe60f18f 100644 --- a/src/util/sock.h +++ b/src/util/sock.h @@ -7,8 +7,108 @@ #include +#include #include +/** + * RAII helper class that manages a socket. Mimics `std::unique_ptr`, but instead of a pointer it + * contains a socket and closes it automatically when it goes out of scope. + */ +class Sock +{ +public: + /** + * Default constructor, creates an empty object that does nothing when destroyed. + */ + Sock(); + + /** + * Take ownership of an existent socket. + */ + explicit Sock(SOCKET s); + + /** + * Copy constructor, disabled because closing the same socket twice is undesirable. + */ + Sock(const Sock&) = delete; + + /** + * Move constructor, grab the socket from another object and close ours (if set). + */ + Sock(Sock&& other); + + /** + * Destructor, close the socket or do nothing if empty. + */ + virtual ~Sock(); + + /** + * Copy assignment operator, disabled because closing the same socket twice is undesirable. + */ + Sock& operator=(const Sock&) = delete; + + /** + * Move assignment operator, grab the socket from another object and close ours (if set). + */ + virtual Sock& operator=(Sock&& other); + + /** + * Get the value of the contained socket. + * @return socket or INVALID_SOCKET if empty + */ + virtual SOCKET Get() const; + + /** + * Get the value of the contained socket and drop ownership. It will not be closed by the + * destructor after this call. + * @return socket or INVALID_SOCKET if empty + */ + virtual SOCKET Release(); + + /** + * Close if non-empty. + */ + virtual void Reset(); + + /** + * send(2) wrapper. Equivalent to `send(this->Get(), data, len, flags);`. Code that uses this + * wrapper can be unit-tested if this method is overridden by a mock Sock implementation. + */ + virtual ssize_t Send(const void* data, size_t len, int flags) const; + + /** + * recv(2) wrapper. Equivalent to `recv(this->Get(), buf, len, flags);`. Code that uses this + * wrapper can be unit-tested if this method is overridden by a mock Sock implementation. + */ + virtual ssize_t Recv(void* buf, size_t len, int flags) const; + + using Event = uint8_t; + + /** + * If passed to `Wait()`, then it will wait for readiness to read from the socket. + */ + static constexpr Event RECV = 0b01; + + /** + * If passed to `Wait()`, then it will wait for readiness to send to the socket. + */ + static constexpr Event SEND = 0b10; + + /** + * Wait for readiness for input (recv) or output (send). + * @param[in] timeout Wait this much for at least one of the requested events to occur. + * @param[in] requested Wait for those events, bitwise-or of `RECV` and `SEND`. + * @return true on success and false otherwise + */ + virtual bool Wait(std::chrono::milliseconds timeout, Event requested) const; + +private: + /** + * Contained socket. `INVALID_SOCKET` designates the object is empty. + */ + SOCKET m_socket; +}; + /** Return readable error string for a network error code */ std::string NetworkErrorString(int err); diff --git a/src/util/time.cpp b/src/util/time.cpp index 4da041e5a5..4aed9f60b0 100644 --- a/src/util/time.cpp +++ b/src/util/time.cpp @@ -123,3 +123,8 @@ struct timeval MillisToTimeval(int64_t nTimeout) timeout.tv_usec = (nTimeout % 1000) * 1000; return timeout; } + +struct timeval MillisToTimeval(std::chrono::milliseconds ms) +{ + return MillisToTimeval(count_milliseconds(ms)); +} diff --git a/src/util/time.h b/src/util/time.h index 2c0e3d83f6..03b75b5be5 100644 --- a/src/util/time.h +++ b/src/util/time.h @@ -27,6 +27,7 @@ void UninterruptibleSleep(const std::chrono::microseconds& n); * interface that doesn't support std::chrono (e.g. RPC, debug log, or the GUI) */ inline int64_t count_seconds(std::chrono::seconds t) { return t.count(); } +inline int64_t count_milliseconds(std::chrono::milliseconds t) { return t.count(); } inline int64_t count_microseconds(std::chrono::microseconds t) { return t.count(); } /** @@ -64,4 +65,9 @@ int64_t ParseISO8601DateTime(const std::string& str); */ struct timeval MillisToTimeval(int64_t nTimeout); +/** + * Convert milliseconds to a struct timeval for e.g. select. + */ +struct timeval MillisToTimeval(std::chrono::milliseconds ms); + #endif // BITCOIN_UTIL_TIME_H From 04ae8469049e1f14585aabfb618ae522150240a7 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Mon, 28 Dec 2020 16:57:10 +0100 Subject: [PATCH 4/6] net: use Sock in InterruptibleRecv() and Socks5() Use the `Sock` class instead of `SOCKET` for `InterruptibleRecv()` and `Socks5()`. This way the `Socks5()` function can be tested by giving it a mocked instance of a socket. Co-authored-by: practicalswift --- src/net.cpp | 4 ++-- src/netbase.cpp | 33 +++++++++------------------------ src/netbase.h | 2 +- 3 files changed, 12 insertions(+), 27 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 2a3669b90e..16aa489873 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -440,7 +440,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo return nullptr; } connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), - sock->Get(), nConnectTimeout, proxyConnectionFailed); + *sock, nConnectTimeout, proxyConnectionFailed); } else { // no proxy needed (none set for target network) sock = CreateSock(addrConnect); @@ -464,7 +464,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo int port = default_port; SplitHostPort(std::string(pszDest), port, host); bool proxyConnectionFailed; - connected = ConnectThroughProxy(proxy, host, port, sock->Get(), nConnectTimeout, + connected = ConnectThroughProxy(proxy, host, port, *sock, nConnectTimeout, proxyConnectionFailed); } if (!connected) { diff --git a/src/netbase.cpp b/src/netbase.cpp index 93a04ab5b4..59a082befa 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -343,7 +343,7 @@ enum class IntrRecvError { * Sockets can be made non-blocking with SetSocketNonBlocking(const * SOCKET&, bool). */ -static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, const SOCKET& hSocket) +static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, const Sock& hSocket) { int64_t curTime = GetTimeMillis(); int64_t endTime = curTime + timeout; @@ -351,7 +351,7 @@ static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, c // (in millis) to break off in case of an interruption. const int64_t maxWait = 1000; while (len > 0 && curTime < endTime) { - ssize_t ret = recv(hSocket, (char*)data, len, 0); // Optimistically try the recv first + ssize_t ret = hSocket.Recv(data, len, 0); // Optimistically try the recv first if (ret > 0) { len -= ret; data += ret; @@ -360,25 +360,10 @@ static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, c } else { // Other error or blocking int nErr = WSAGetLastError(); if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL) { - if (!IsSelectableSocket(hSocket)) { - return IntrRecvError::NetworkError; - } // Only wait at most maxWait milliseconds at a time, unless // we're approaching the end of the specified total timeout int timeout_ms = std::min(endTime - curTime, maxWait); -#ifdef USE_POLL - struct pollfd pollfd = {}; - pollfd.fd = hSocket; - pollfd.events = POLLIN; - int nRet = poll(&pollfd, 1, timeout_ms); -#else - struct timeval tval = MillisToTimeval(timeout_ms); - fd_set fdset; - FD_ZERO(&fdset); - FD_SET(hSocket, &fdset); - int nRet = select(hSocket + 1, &fdset, nullptr, nullptr, &tval); -#endif - if (nRet == SOCKET_ERROR) { + if (!hSocket.Wait(std::chrono::milliseconds{timeout_ms}, Sock::RECV)) { return IntrRecvError::NetworkError; } } else { @@ -442,7 +427,7 @@ static std::string Socks5ErrorString(uint8_t err) * @see RFC1928: SOCKS Protocol * Version 5 */ -static bool Socks5(const std::string& strDest, int port, const ProxyCredentials *auth, const SOCKET& hSocket) +static bool Socks5(const std::string& strDest, int port, const ProxyCredentials* auth, const Sock& hSocket) { IntrRecvError recvr; LogPrint(BCLog::NET, "SOCKS5 connecting %s\n", strDest); @@ -460,7 +445,7 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials vSocks5Init.push_back(0x01); // 1 method identifier follows... vSocks5Init.push_back(SOCKS5Method::NOAUTH); } - ssize_t ret = send(hSocket, (const char*)vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL); + ssize_t ret = hSocket.Send(vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vSocks5Init.size()) { return error("Error sending to proxy"); } @@ -482,7 +467,7 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials vAuth.insert(vAuth.end(), auth->username.begin(), auth->username.end()); vAuth.push_back(auth->password.size()); vAuth.insert(vAuth.end(), auth->password.begin(), auth->password.end()); - ret = send(hSocket, (const char*)vAuth.data(), vAuth.size(), MSG_NOSIGNAL); + ret = hSocket.Send(vAuth.data(), vAuth.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vAuth.size()) { return error("Error sending authentication to proxy"); } @@ -508,7 +493,7 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials vSocks5.insert(vSocks5.end(), strDest.begin(), strDest.end()); vSocks5.push_back((port >> 8) & 0xFF); vSocks5.push_back((port >> 0) & 0xFF); - ret = send(hSocket, (const char*)vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL); + ret = hSocket.Send(vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vSocks5.size()) { return error("Error sending to proxy"); } @@ -787,10 +772,10 @@ bool IsProxy(const CNetAddr &addr) { * * @returns Whether or not the operation succeeded. */ -bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocket, int nTimeout, bool& outProxyConnectionFailed) +bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& hSocket, int nTimeout, bool& outProxyConnectionFailed) { // first connect to proxy server - if (!ConnectSocketDirectly(proxy.proxy, hSocket, nTimeout, true)) { + if (!ConnectSocketDirectly(proxy.proxy, hSocket.Get(), nTimeout, true)) { outProxyConnectionFailed = true; return false; } diff --git a/src/netbase.h b/src/netbase.h index d906888235..9ad9a864d8 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -68,7 +68,7 @@ std::unique_ptr CreateSockTCP(const CService& address_family); extern std::function(const CService&)> CreateSock; bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout, bool manual_connection); -bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocketRet, int nTimeout, bool& outProxyConnectionFailed); +bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& hSocketRet, int nTimeout, bool& outProxyConnectionFailed); /** Disable or enable blocking-mode for a socket */ bool SetSocketNonBlocking(const SOCKET& hSocket, bool fNonBlocking); /** Set the TCP_NODELAY flag on a socket */ From 7bd21ce1efc363b3e8ea1d51dd1410ccd66820cb Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Thu, 4 Feb 2021 18:07:24 +0100 Subject: [PATCH 5/6] style: rename hSocket to sock In the arguments of `InterruptibleRecv()`, `Socks5()` and `ConnectThroughProxy()` the variable `hSocket` was previously of type `SOCKET`, but has been changed to `Sock`. Thus rename it to `sock` to imply its type, to distinguish from other `SOCKET` variables and to abide to the coding style wrt variables' names. --- src/netbase.cpp | 45 ++++++++++++++++++++++----------------------- src/netbase.h | 2 +- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/src/netbase.cpp b/src/netbase.cpp index 59a082befa..24188f83c6 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -332,8 +332,7 @@ enum class IntrRecvError { * @param data The buffer where the read bytes should be stored. * @param len The number of bytes to read into the specified buffer. * @param timeout The total timeout in milliseconds for this read. - * @param hSocket The socket (has to be in non-blocking mode) from which to read - * bytes. + * @param sock The socket (has to be in non-blocking mode) from which to read bytes. * * @returns An IntrRecvError indicating the resulting status of this read. * IntrRecvError::OK only if all of the specified number of bytes were @@ -343,7 +342,7 @@ enum class IntrRecvError { * Sockets can be made non-blocking with SetSocketNonBlocking(const * SOCKET&, bool). */ -static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, const Sock& hSocket) +static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, const Sock& sock) { int64_t curTime = GetTimeMillis(); int64_t endTime = curTime + timeout; @@ -351,7 +350,7 @@ static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, c // (in millis) to break off in case of an interruption. const int64_t maxWait = 1000; while (len > 0 && curTime < endTime) { - ssize_t ret = hSocket.Recv(data, len, 0); // Optimistically try the recv first + ssize_t ret = sock.Recv(data, len, 0); // Optimistically try the recv first if (ret > 0) { len -= ret; data += ret; @@ -363,7 +362,7 @@ static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, c // Only wait at most maxWait milliseconds at a time, unless // we're approaching the end of the specified total timeout int timeout_ms = std::min(endTime - curTime, maxWait); - if (!hSocket.Wait(std::chrono::milliseconds{timeout_ms}, Sock::RECV)) { + if (!sock.Wait(std::chrono::milliseconds{timeout_ms}, Sock::RECV)) { return IntrRecvError::NetworkError; } } else { @@ -417,7 +416,7 @@ static std::string Socks5ErrorString(uint8_t err) * @param port The destination port. * @param auth The credentials with which to authenticate with the specified * SOCKS5 proxy. - * @param hSocket The SOCKS5 proxy socket. + * @param sock The SOCKS5 proxy socket. * * @returns Whether or not the operation succeeded. * @@ -427,7 +426,7 @@ static std::string Socks5ErrorString(uint8_t err) * @see RFC1928: SOCKS Protocol * Version 5 */ -static bool Socks5(const std::string& strDest, int port, const ProxyCredentials* auth, const Sock& hSocket) +static bool Socks5(const std::string& strDest, int port, const ProxyCredentials* auth, const Sock& sock) { IntrRecvError recvr; LogPrint(BCLog::NET, "SOCKS5 connecting %s\n", strDest); @@ -445,12 +444,12 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials* vSocks5Init.push_back(0x01); // 1 method identifier follows... vSocks5Init.push_back(SOCKS5Method::NOAUTH); } - ssize_t ret = hSocket.Send(vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL); + ssize_t ret = sock.Send(vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vSocks5Init.size()) { return error("Error sending to proxy"); } uint8_t pchRet1[2]; - if ((recvr = InterruptibleRecv(pchRet1, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { + if ((recvr = InterruptibleRecv(pchRet1, 2, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) { LogPrintf("Socks5() connect to %s:%d failed: InterruptibleRecv() timeout or other failure\n", strDest, port); return false; } @@ -467,13 +466,13 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials* vAuth.insert(vAuth.end(), auth->username.begin(), auth->username.end()); vAuth.push_back(auth->password.size()); vAuth.insert(vAuth.end(), auth->password.begin(), auth->password.end()); - ret = hSocket.Send(vAuth.data(), vAuth.size(), MSG_NOSIGNAL); + ret = sock.Send(vAuth.data(), vAuth.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vAuth.size()) { return error("Error sending authentication to proxy"); } LogPrint(BCLog::PROXY, "SOCKS5 sending proxy authentication %s:%s\n", auth->username, auth->password); uint8_t pchRetA[2]; - if ((recvr = InterruptibleRecv(pchRetA, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { + if ((recvr = InterruptibleRecv(pchRetA, 2, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) { return error("Error reading proxy authentication response"); } if (pchRetA[0] != 0x01 || pchRetA[1] != 0x00) { @@ -493,12 +492,12 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials* vSocks5.insert(vSocks5.end(), strDest.begin(), strDest.end()); vSocks5.push_back((port >> 8) & 0xFF); vSocks5.push_back((port >> 0) & 0xFF); - ret = hSocket.Send(vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL); + ret = sock.Send(vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vSocks5.size()) { return error("Error sending to proxy"); } uint8_t pchRet2[4]; - if ((recvr = InterruptibleRecv(pchRet2, 4, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { + if ((recvr = InterruptibleRecv(pchRet2, 4, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) { if (recvr == IntrRecvError::Timeout) { /* If a timeout happens here, this effectively means we timed out while connecting * to the remote node. This is very common for Tor, so do not print an @@ -522,16 +521,16 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials* uint8_t pchRet3[256]; switch (pchRet2[3]) { - case SOCKS5Atyp::IPV4: recvr = InterruptibleRecv(pchRet3, 4, SOCKS5_RECV_TIMEOUT, hSocket); break; - case SOCKS5Atyp::IPV6: recvr = InterruptibleRecv(pchRet3, 16, SOCKS5_RECV_TIMEOUT, hSocket); break; + case SOCKS5Atyp::IPV4: recvr = InterruptibleRecv(pchRet3, 4, SOCKS5_RECV_TIMEOUT, sock); break; + case SOCKS5Atyp::IPV6: recvr = InterruptibleRecv(pchRet3, 16, SOCKS5_RECV_TIMEOUT, sock); break; case SOCKS5Atyp::DOMAINNAME: { - recvr = InterruptibleRecv(pchRet3, 1, SOCKS5_RECV_TIMEOUT, hSocket); + recvr = InterruptibleRecv(pchRet3, 1, SOCKS5_RECV_TIMEOUT, sock); if (recvr != IntrRecvError::OK) { return error("Error reading from proxy"); } int nRecv = pchRet3[0]; - recvr = InterruptibleRecv(pchRet3, nRecv, SOCKS5_RECV_TIMEOUT, hSocket); + recvr = InterruptibleRecv(pchRet3, nRecv, SOCKS5_RECV_TIMEOUT, sock); break; } default: return error("Error: malformed proxy response"); @@ -539,7 +538,7 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials* if (recvr != IntrRecvError::OK) { return error("Error reading from proxy"); } - if ((recvr = InterruptibleRecv(pchRet3, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) { + if ((recvr = InterruptibleRecv(pchRet3, 2, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) { return error("Error reading from proxy"); } LogPrint(BCLog::NET, "SOCKS5 connected %s\n", strDest); @@ -764,7 +763,7 @@ bool IsProxy(const CNetAddr &addr) { * @param proxy The SOCKS5 proxy. * @param strDest The destination service to which to connect. * @param port The destination port. - * @param hSocket The socket on which to connect to the SOCKS5 proxy. + * @param sock The socket on which to connect to the SOCKS5 proxy. * @param nTimeout Wait this many milliseconds for the connection to the SOCKS5 * proxy to be established. * @param[out] outProxyConnectionFailed Whether or not the connection to the @@ -772,10 +771,10 @@ bool IsProxy(const CNetAddr &addr) { * * @returns Whether or not the operation succeeded. */ -bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& hSocket, int nTimeout, bool& outProxyConnectionFailed) +bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& sock, int nTimeout, bool& outProxyConnectionFailed) { // first connect to proxy server - if (!ConnectSocketDirectly(proxy.proxy, hSocket.Get(), nTimeout, true)) { + if (!ConnectSocketDirectly(proxy.proxy, sock.Get(), nTimeout, true)) { outProxyConnectionFailed = true; return false; } @@ -784,11 +783,11 @@ bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int ProxyCredentials random_auth; static std::atomic_int counter(0); random_auth.username = random_auth.password = strprintf("%i", counter++); - if (!Socks5(strDest, (uint16_t)port, &random_auth, hSocket)) { + if (!Socks5(strDest, (uint16_t)port, &random_auth, sock)) { return false; } } else { - if (!Socks5(strDest, (uint16_t)port, 0, hSocket)) { + if (!Socks5(strDest, (uint16_t)port, 0, sock)) { return false; } } diff --git a/src/netbase.h b/src/netbase.h index 9ad9a864d8..afc373ef49 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -68,7 +68,7 @@ std::unique_ptr CreateSockTCP(const CService& address_family); extern std::function(const CService&)> CreateSock; bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout, bool manual_connection); -bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& hSocketRet, int nTimeout, bool& outProxyConnectionFailed); +bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& sock, int nTimeout, bool& outProxyConnectionFailed); /** Disable or enable blocking-mode for a socket */ bool SetSocketNonBlocking(const SOCKET& hSocket, bool fNonBlocking); /** Set the TCP_NODELAY flag on a socket */ From 615ba0eb96cf131364c1ceca9d3dedf006fa1e1c Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Mon, 4 Jan 2021 11:44:21 +0100 Subject: [PATCH 6/6] test: add Sock unit tests --- src/Makefile.test.include | 1 + src/test/sock_tests.cpp | 149 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+) create mode 100644 src/test/sock_tests.cpp diff --git a/src/Makefile.test.include b/src/Makefile.test.include index e9f9b73abe..770aba467d 100644 --- a/src/Makefile.test.include +++ b/src/Makefile.test.include @@ -114,6 +114,7 @@ BITCOIN_TESTS =\ test/sighash_tests.cpp \ test/sigopcount_tests.cpp \ test/skiplist_tests.cpp \ + test/sock_tests.cpp \ test/streams_tests.cpp \ test/sync_tests.cpp \ test/system_tests.cpp \ diff --git a/src/test/sock_tests.cpp b/src/test/sock_tests.cpp new file mode 100644 index 0000000000..cc0e6e7057 --- /dev/null +++ b/src/test/sock_tests.cpp @@ -0,0 +1,149 @@ +// Copyright (c) 2021-2021 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include +#include +#include +#include + +#include + +#include + +using namespace std::chrono_literals; + +BOOST_FIXTURE_TEST_SUITE(sock_tests, BasicTestingSetup) + +static bool SocketIsClosed(const SOCKET& s) +{ + // Notice that if another thread is running and creates its own socket after `s` has been + // closed, it may be assigned the same file descriptor number. In this case, our test will + // wrongly pretend that the socket is not closed. + int type; + socklen_t len = sizeof(type); + return getsockopt(s, SOL_SOCKET, SO_TYPE, (sockopt_arg_type)&type, &len) == SOCKET_ERROR; +} + +static SOCKET CreateSocket() +{ + const SOCKET s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + BOOST_REQUIRE(s != static_cast(SOCKET_ERROR)); + return s; +} + +BOOST_AUTO_TEST_CASE(constructor_and_destructor) +{ + const SOCKET s = CreateSocket(); + Sock* sock = new Sock(s); + BOOST_CHECK_EQUAL(sock->Get(), s); + BOOST_CHECK(!SocketIsClosed(s)); + delete sock; + BOOST_CHECK(SocketIsClosed(s)); +} + +BOOST_AUTO_TEST_CASE(move_constructor) +{ + const SOCKET s = CreateSocket(); + Sock* sock1 = new Sock(s); + Sock* sock2 = new Sock(std::move(*sock1)); + delete sock1; + BOOST_CHECK(!SocketIsClosed(s)); + BOOST_CHECK_EQUAL(sock2->Get(), s); + delete sock2; + BOOST_CHECK(SocketIsClosed(s)); +} + +BOOST_AUTO_TEST_CASE(move_assignment) +{ + const SOCKET s = CreateSocket(); + Sock* sock1 = new Sock(s); + Sock* sock2 = new Sock(); + *sock2 = std::move(*sock1); + delete sock1; + BOOST_CHECK(!SocketIsClosed(s)); + BOOST_CHECK_EQUAL(sock2->Get(), s); + delete sock2; + BOOST_CHECK(SocketIsClosed(s)); +} + +BOOST_AUTO_TEST_CASE(release) +{ + SOCKET s = CreateSocket(); + Sock* sock = new Sock(s); + BOOST_CHECK_EQUAL(sock->Release(), s); + delete sock; + BOOST_CHECK(!SocketIsClosed(s)); + BOOST_REQUIRE(CloseSocket(s)); +} + +BOOST_AUTO_TEST_CASE(reset) +{ + const SOCKET s = CreateSocket(); + Sock sock(s); + sock.Reset(); + BOOST_CHECK(SocketIsClosed(s)); +} + +#ifndef WIN32 // Windows does not have socketpair(2). + +static void CreateSocketPair(int s[2]) +{ + BOOST_REQUIRE_EQUAL(socketpair(AF_UNIX, SOCK_STREAM, 0, s), 0); +} + +static void SendAndRecvMessage(const Sock& sender, const Sock& receiver) +{ + const char* msg = "abcd"; + constexpr size_t msg_len = 4; + char recv_buf[10]; + + BOOST_CHECK_EQUAL(sender.Send(msg, msg_len, 0), msg_len); + BOOST_CHECK_EQUAL(receiver.Recv(recv_buf, sizeof(recv_buf), 0), msg_len); + BOOST_CHECK_EQUAL(strncmp(msg, recv_buf, msg_len), 0); +} + +BOOST_AUTO_TEST_CASE(send_and_receive) +{ + int s[2]; + CreateSocketPair(s); + + Sock* sock0 = new Sock(s[0]); + Sock* sock1 = new Sock(s[1]); + + SendAndRecvMessage(*sock0, *sock1); + + Sock* sock0moved = new Sock(std::move(*sock0)); + Sock* sock1moved = new Sock(); + *sock1moved = std::move(*sock1); + + delete sock0; + delete sock1; + + SendAndRecvMessage(*sock1moved, *sock0moved); + + delete sock0moved; + delete sock1moved; + + BOOST_CHECK(SocketIsClosed(s[0])); + BOOST_CHECK(SocketIsClosed(s[1])); +} + +BOOST_AUTO_TEST_CASE(wait) +{ + int s[2]; + CreateSocketPair(s); + + Sock sock0(s[0]); + Sock sock1(s[1]); + + std::thread waiter([&sock0]() { sock0.Wait(24h, Sock::RECV); }); + + BOOST_REQUIRE_EQUAL(sock1.Send("a", 1, 0), 1); + + waiter.join(); +} + +#endif /* WIN32 */ + +BOOST_AUTO_TEST_SUITE_END()