From 748dbcd9f29dbe4110da8a06f08e3eefa95f5321 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 13 Apr 2021 14:37:16 +0200 Subject: [PATCH 1/2] net: add new method Sock::GetSockName() that wraps getsockname() This will help to increase `Sock` usage and make more code mockable. --- src/test/fuzz/util.cpp | 14 ++++++++++++++ src/test/fuzz/util.h | 2 ++ src/test/util/net.h | 6 ++++++ src/util/sock.cpp | 5 +++++ src/util/sock.h | 7 +++++++ 5 files changed, 34 insertions(+) diff --git a/src/test/fuzz/util.cpp b/src/test/fuzz/util.cpp index 883698aff1a..76a85b7a139 100644 --- a/src/test/fuzz/util.cpp +++ b/src/test/fuzz/util.cpp @@ -206,6 +206,20 @@ int FuzzedSock::SetSockOpt(int, int, const void*, socklen_t) const return 0; } +int FuzzedSock::GetSockName(sockaddr* name, socklen_t* name_len) const +{ + constexpr std::array getsockname_errnos{ + ECONNRESET, + ENOBUFS, + }; + if (m_fuzzed_data_provider.ConsumeBool()) { + SetFuzzedErrNo(m_fuzzed_data_provider, getsockname_errnos); + return -1; + } + *name_len = m_fuzzed_data_provider.ConsumeData(name, *name_len); + return 0; +} + bool FuzzedSock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const { constexpr std::array wait_errnos{ diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h index 66d00b17676..9ac7347d8b1 100644 --- a/src/test/fuzz/util.h +++ b/src/test/fuzz/util.h @@ -69,6 +69,8 @@ public: int SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const override; + int GetSockName(sockaddr* name, socklen_t* name_len) const override; + bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override; bool WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const override; diff --git a/src/test/util/net.h b/src/test/util/net.h index 37d278645ab..ec3c2894ce7 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -152,6 +152,12 @@ public: int SetSockOpt(int, int, const void*, socklen_t) const override { return 0; } + int GetSockName(sockaddr* name, socklen_t* name_len) const override + { + std::memset(name, 0x0, *name_len); + return 0; + } + bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override diff --git a/src/util/sock.cpp b/src/util/sock.cpp index 7d5069423af..b4c0aa42052 100644 --- a/src/util/sock.cpp +++ b/src/util/sock.cpp @@ -111,6 +111,11 @@ int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt return setsockopt(m_socket, level, opt_name, static_cast(opt_val), opt_len); } +int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const +{ + return getsockname(m_socket, name, name_len); +} + bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const { // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want diff --git a/src/util/sock.h b/src/util/sock.h index 3245820995f..96d0b3b56b2 100644 --- a/src/util/sock.h +++ b/src/util/sock.h @@ -126,6 +126,13 @@ public: const void* opt_val, socklen_t opt_len) const; + /** + * getsockname(2) wrapper. Equivalent to + * `getsockname(this->Get(), name, name_len)`. Code that uses this + * wrapper can be unit tested if this method is overridden by a mock Sock implementation. + */ + [[nodiscard]] virtual int GetSockName(sockaddr* name, socklen_t* name_len) const; + using Event = uint8_t; /** From a8d6abba5ec4ae2a3375e9be0b739f298899eca2 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 13 Apr 2021 15:11:20 +0200 Subject: [PATCH 2/2] net: change GetBindAddress() to take Sock argument This avoids the direct call to `getsockname()` and allows mocking. --- src/net.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index d42f130af71..7499d4c72fd 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -422,13 +422,13 @@ bool CConnman::CheckIncomingNonce(uint64_t nonce) } /** Get the bind address for a socket as CAddress */ -static CAddress GetBindAddress(SOCKET sock) +static CAddress GetBindAddress(const Sock& sock) { CAddress addr_bind; struct sockaddr_storage sockaddr_bind; socklen_t sockaddr_bind_len = sizeof(sockaddr_bind); - if (sock != INVALID_SOCKET) { - if (!getsockname(sock, (struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) { + if (sock.Get() != INVALID_SOCKET) { + if (!sock.GetSockName((struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) { addr_bind.SetSockAddr((const struct sockaddr*)&sockaddr_bind); } else { LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "getsockname failed\n"); @@ -540,7 +540,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo NodeId id = GetNewNodeId(); uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); if (!addr_bind.IsValid()) { - addr_bind = GetBindAddress(sock->Get()); + addr_bind = GetBindAddress(*sock); } CNode* pnode = new CNode(id, nLocalServices, @@ -1154,7 +1154,7 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { addr = CAddress{MaybeFlipIPv6toCJDNS(addr), NODE_NONE}; } - const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(sock->Get())), NODE_NONE}; + const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(*sock)), NODE_NONE}; NetPermissionFlags permissionFlags = NetPermissionFlags::None; hListenSocket.AddSocketPermissionFlags(permissionFlags);