diff --git a/src/net.cpp b/src/net.cpp index c09e3aedb6..eb62ee8a06 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -14,6 +14,7 @@ #include "clientversion.h" #include "consensus/consensus.h" #include "crypto/common.h" +#include "crypto/sha256.h" #include "hash.h" #include "primitives/transaction.h" #include "scheduler.h" @@ -838,6 +839,7 @@ struct NodeEvictionCandidate int64_t nTimeConnected; int64_t nMinPingUsecTime; CAddress addr; + std::vector vchKeyedNetGroup; }; static bool ReverseCompareNodeMinPingTime(const NodeEvictionCandidate &a, const NodeEvictionCandidate &b) @@ -850,36 +852,8 @@ static bool ReverseCompareNodeTimeConnected(const NodeEvictionCandidate &a, cons return a.nTimeConnected > b.nTimeConnected; } -class CompareNetGroupKeyed -{ - std::vector vchSecretKey; -public: - CompareNetGroupKeyed() - { - vchSecretKey.resize(32, 0); - GetRandBytes(vchSecretKey.data(), vchSecretKey.size()); - } - - bool operator()(const NodeEvictionCandidate &a, const NodeEvictionCandidate &b) - { - std::vector vchGroupA, vchGroupB; - CSHA256 hashA, hashB; - std::vector vchA(32), vchB(32); - - vchGroupA = a.addr.GetGroup(); - vchGroupB = b.addr.GetGroup(); - - hashA.Write(begin_ptr(vchGroupA), vchGroupA.size()); - hashB.Write(begin_ptr(vchGroupB), vchGroupB.size()); - - hashA.Write(begin_ptr(vchSecretKey), vchSecretKey.size()); - hashB.Write(begin_ptr(vchSecretKey), vchSecretKey.size()); - - hashA.Finalize(begin_ptr(vchA)); - hashB.Finalize(begin_ptr(vchB)); - - return vchA < vchB; - } +static bool CompareNetGroupKeyed(const NodeEvictionCandidate &a, const NodeEvictionCandidate &b) { + return a.vchKeyedNetGroup < b.vchKeyedNetGroup; }; /** Try to find a connection to evict when the node is full. @@ -902,7 +876,7 @@ static bool AttemptToEvictConnection(bool fPreferNewConnection) { continue; if (node->fDisconnect) continue; - NodeEvictionCandidate candidate = {node->id, node->nTimeConnected, node->nMinPingUsecTime, node->addr}; + NodeEvictionCandidate candidate = {node->id, node->nTimeConnected, node->nMinPingUsecTime, node->addr, node->vchKeyedNetGroup}; vEvictionCandidates.push_back(candidate); } } @@ -912,9 +886,8 @@ static bool AttemptToEvictConnection(bool fPreferNewConnection) { // Protect connections with certain characteristics // Deterministically select 4 peers to protect by netgroup. - // An attacker cannot predict which netgroups will be protected. - static CompareNetGroupKeyed comparerNetGroupKeyed; - std::sort(vEvictionCandidates.begin(), vEvictionCandidates.end(), comparerNetGroupKeyed); + // An attacker cannot predict which netgroups will be protected + std::sort(vEvictionCandidates.begin(), vEvictionCandidates.end(), CompareNetGroupKeyed); vEvictionCandidates.erase(vEvictionCandidates.end() - std::min(4, static_cast(vEvictionCandidates.size())), vEvictionCandidates.end()); if (vEvictionCandidates.empty()) return false; @@ -2392,6 +2365,8 @@ CNode::CNode(SOCKET hSocketIn, const CAddress& addrIn, const std::string& addrNa lastSentFeeFilter = 0; nextSendTimeFeeFilter = 0; + CalculateKeyedNetGroup(); + BOOST_FOREACH(const std::string &msg, getAllNetMessageTypes()) mapRecvBytesPerMsgCmd[msg] = 0; mapRecvBytesPerMsgCmd[NET_MESSAGE_COMMAND_OTHER] = 0; diff --git a/src/net.h b/src/net.h index 403653e8c8..019a3f7ee3 100644 --- a/src/net.h +++ b/src/net.h @@ -9,6 +9,8 @@ #include "amount.h" #include "bloom.h" #include "compat.h" +#include "crypto/common.h" +#include "crypto/sha256.h" #include "limitedmap.h" #include "netbase.h" #include "protocol.h" @@ -362,6 +364,8 @@ public: CBloomFilter* pfilter; int nRefCount; NodeId id; + + std::vector vchKeyedNetGroup; protected: // Denial-of-service detection/prevention @@ -450,6 +454,22 @@ private: CNode(const CNode&); void operator=(const CNode&); + void CalculateKeyedNetGroup() { + static std::vector vchSecretKey; + if (vchSecretKey.empty()) { + vchSecretKey.resize(32, 0); + GetRandBytes(vchSecretKey.data(), vchSecretKey.size()); + } + + std::vector vchNetGroup(this->addr.GetGroup()); + + CSHA256 hash; + hash.Write(begin_ptr(vchNetGroup), vchNetGroup.size()); + hash.Write(begin_ptr(vchSecretKey), vchSecretKey.size()); + + vchKeyedNetGroup.resize(32, 0); + hash.Finalize(begin_ptr(vchKeyedNetGroup)); + } public: NodeId GetId() const {