diff --git a/src/interfaces/node.cpp b/src/interfaces/node.cpp index 33f0dac2632..969767b90fc 100644 --- a/src/interfaces/node.cpp +++ b/src/interfaces/node.cpp @@ -56,6 +56,7 @@ namespace { class NodeImpl : public Node { public: + NodeImpl(NodeContext* context) { setContext(context); } void initError(const bilingual_str& message) override { InitError(message); } bool parseParameters(int argc, const char* const argv[], std::string& error) override { @@ -81,13 +82,13 @@ public: } bool appInitMain() override { - m_context.chain = MakeChain(m_context); - return AppInitMain(m_context_ref, m_context); + m_context->chain = MakeChain(*m_context); + return AppInitMain(m_context_ref, *m_context); } void appShutdown() override { - Interrupt(m_context); - Shutdown(m_context); + Interrupt(*m_context); + Shutdown(*m_context); } void startShutdown() override { @@ -108,19 +109,19 @@ public: StopMapPort(); } } - void setupServerArgs() override { return SetupServerArgs(m_context); } + void setupServerArgs() override { return SetupServerArgs(*m_context); } bool getProxy(Network net, proxyType& proxy_info) override { return GetProxy(net, proxy_info); } size_t getNodeCount(CConnman::NumConnections flags) override { - return m_context.connman ? m_context.connman->GetNodeCount(flags) : 0; + return m_context->connman ? m_context->connman->GetNodeCount(flags) : 0; } bool getNodesStats(NodesStats& stats) override { stats.clear(); - if (m_context.connman) { + if (m_context->connman) { std::vector stats_temp; - m_context.connman->GetNodeStats(stats_temp); + m_context->connman->GetNodeStats(stats_temp); stats.reserve(stats_temp.size()); for (auto& node_stats_temp : stats_temp) { @@ -141,46 +142,46 @@ public: } bool getBanned(banmap_t& banmap) override { - if (m_context.banman) { - m_context.banman->GetBanned(banmap); + if (m_context->banman) { + m_context->banman->GetBanned(banmap); return true; } return false; } bool ban(const CNetAddr& net_addr, int64_t ban_time_offset) override { - if (m_context.banman) { - m_context.banman->Ban(net_addr, ban_time_offset); + if (m_context->banman) { + m_context->banman->Ban(net_addr, ban_time_offset); return true; } return false; } bool unban(const CSubNet& ip) override { - if (m_context.banman) { - m_context.banman->Unban(ip); + if (m_context->banman) { + m_context->banman->Unban(ip); return true; } return false; } bool disconnectByAddress(const CNetAddr& net_addr) override { - if (m_context.connman) { - return m_context.connman->DisconnectNode(net_addr); + if (m_context->connman) { + return m_context->connman->DisconnectNode(net_addr); } return false; } bool disconnectById(NodeId id) override { - if (m_context.connman) { - return m_context.connman->DisconnectNode(id); + if (m_context->connman) { + return m_context->connman->DisconnectNode(id); } return false; } - int64_t getTotalBytesRecv() override { return m_context.connman ? m_context.connman->GetTotalBytesRecv() : 0; } - int64_t getTotalBytesSent() override { return m_context.connman ? m_context.connman->GetTotalBytesSent() : 0; } - size_t getMempoolSize() override { return m_context.mempool ? m_context.mempool->size() : 0; } - size_t getMempoolDynamicUsage() override { return m_context.mempool ? m_context.mempool->DynamicMemoryUsage() : 0; } + int64_t getTotalBytesRecv() override { return m_context->connman ? m_context->connman->GetTotalBytesRecv() : 0; } + int64_t getTotalBytesSent() override { return m_context->connman ? m_context->connman->GetTotalBytesSent() : 0; } + size_t getMempoolSize() override { return m_context->mempool ? m_context->mempool->size() : 0; } + size_t getMempoolDynamicUsage() override { return m_context->mempool ? m_context->mempool->DynamicMemoryUsage() : 0; } bool getHeaderTip(int& height, int64_t& block_time) override { LOCK(::cs_main); @@ -223,11 +224,11 @@ public: bool getImporting() override { return ::fImporting; } void setNetworkActive(bool active) override { - if (m_context.connman) { - m_context.connman->SetNetworkActive(active); + if (m_context->connman) { + m_context->connman->SetNetworkActive(active); } } - bool getNetworkActive() override { return m_context.connman && m_context.connman->GetNetworkActive(); } + bool getNetworkActive() override { return m_context->connman && m_context->connman->GetNetworkActive(); } CFeeRate estimateSmartFee(int num_blocks, bool conservative, int* returned_target = nullptr) override { FeeCalculation fee_calc; @@ -269,7 +270,7 @@ public: std::vector> getWallets() override { std::vector> wallets; - for (auto& client : m_context.chain_clients) { + for (auto& client : m_context->chain_clients) { auto client_wallets = client->getWallets(); std::move(client_wallets.begin(), client_wallets.end(), std::back_inserter(wallets)); } @@ -277,12 +278,12 @@ public: } std::unique_ptr loadWallet(const std::string& name, bilingual_str& error, std::vector& warnings) override { - return MakeWallet(LoadWallet(*m_context.chain, name, error, warnings)); + return MakeWallet(LoadWallet(*m_context->chain, name, error, warnings)); } std::unique_ptr createWallet(const SecureString& passphrase, uint64_t wallet_creation_flags, const std::string& name, bilingual_str& error, std::vector& warnings, WalletCreationStatus& status) override { std::shared_ptr wallet; - status = CreateWallet(*m_context.chain, passphrase, wallet_creation_flags, name, error, warnings, wallet); + status = CreateWallet(*m_context->chain, passphrase, wallet_creation_flags, name, error, warnings, wallet); return MakeWallet(wallet); } std::unique_ptr handleInitMessage(InitMessageFn fn) override @@ -336,13 +337,22 @@ public: /* verification progress is unused when a header was received */ 0); })); } - NodeContext* context() override { return &m_context; } - NodeContext m_context; - util::Ref m_context_ref{m_context}; + NodeContext* context() override { return m_context; } + void setContext(NodeContext* context) override + { + m_context = context; + if (context) { + m_context_ref.Set(*context); + } else { + m_context_ref.Clear(); + } + } + NodeContext* m_context{nullptr}; + util::Ref m_context_ref; }; } // namespace -std::unique_ptr MakeNode() { return MakeUnique(); } +std::unique_ptr MakeNode(NodeContext* context) { return MakeUnique(context); } } // namespace interfaces diff --git a/src/interfaces/node.h b/src/interfaces/node.h index a9680c42b51..cd3cfe487d5 100644 --- a/src/interfaces/node.h +++ b/src/interfaces/node.h @@ -268,12 +268,14 @@ public: std::function; virtual std::unique_ptr handleNotifyHeaderTip(NotifyHeaderTipFn fn) = 0; - //! Return pointer to internal chain interface, useful for testing. + //! Get and set internal node context. Useful for testing, but not + //! accessible across processes. virtual NodeContext* context() { return nullptr; } + virtual void setContext(NodeContext* context) { } }; //! Return implementation of Node interface. -std::unique_ptr MakeNode(); +std::unique_ptr MakeNode(NodeContext* context = nullptr); //! Block tip (could be a header or not, depends on the subscribed signal). struct BlockTip { diff --git a/src/qt/bitcoin.cpp b/src/qt/bitcoin.cpp index 523f5c429b9..4f1e0056be9 100644 --- a/src/qt/bitcoin.cpp +++ b/src/qt/bitcoin.cpp @@ -29,6 +29,7 @@ #include #include +#include #include #include #include @@ -430,7 +431,8 @@ int GuiMain(int argc, char* argv[]) SetupEnvironment(); util::ThreadSetInternalName("main"); - std::unique_ptr node = interfaces::MakeNode(); + NodeContext node_context; + std::unique_ptr node = interfaces::MakeNode(&node_context); // Subscribe to global signals from core std::unique_ptr handler_message_box = node->handleMessageBox(noui_ThreadSafeMessageBox); diff --git a/src/qt/test/addressbooktests.cpp b/src/qt/test/addressbooktests.cpp index 9347ff9e427..035c8196bc3 100644 --- a/src/qt/test/addressbooktests.cpp +++ b/src/qt/test/addressbooktests.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -59,6 +60,7 @@ void EditAddressAndSubmit( void TestAddAddressesToSendBook(interfaces::Node& node) { TestChain100Setup test; + node.setContext(&test.m_node); std::shared_ptr wallet = std::make_shared(node.context()->chain.get(), WalletLocation(), CreateMockWalletDatabase()); wallet->SetupLegacyScriptPubKeyMan(); bool firstRun; diff --git a/src/qt/test/test_main.cpp b/src/qt/test/test_main.cpp index 12efca25036..031913bd026 100644 --- a/src/qt/test/test_main.cpp +++ b/src/qt/test/test_main.cpp @@ -52,7 +52,8 @@ int main(int argc, char* argv[]) BasicTestingSetup dummy{CBaseChainParams::REGTEST}; } - std::unique_ptr node = interfaces::MakeNode(); + NodeContext node_context; + std::unique_ptr node = interfaces::MakeNode(&node_context); bool fInvalid = false; diff --git a/src/qt/test/wallettests.cpp b/src/qt/test/wallettests.cpp index 6648029bae8..475fd589af1 100644 --- a/src/qt/test/wallettests.cpp +++ b/src/qt/test/wallettests.cpp @@ -138,8 +138,7 @@ void TestGUI(interfaces::Node& node) for (int i = 0; i < 5; ++i) { test.CreateAndProcessBlock({}, GetScriptForRawPubKey(test.coinbaseKey.GetPubKey())); } - node.context()->connman = std::move(test.m_node.connman); - node.context()->mempool = std::move(test.m_node.mempool); + node.setContext(&test.m_node); std::shared_ptr wallet = std::make_shared(node.context()->chain.get(), WalletLocation(), CreateMockWalletDatabase()); bool firstRun; wallet->LoadWallet(firstRun); diff --git a/src/test/util/setup_common.cpp b/src/test/util/setup_common.cpp index 14f65dcb7c9..b2ae1cb845d 100644 --- a/src/test/util/setup_common.cpp +++ b/src/test/util/setup_common.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -32,6 +33,7 @@ #include #include #include +#include #include @@ -104,6 +106,8 @@ BasicTestingSetup::BasicTestingSetup(const std::string& chainName, const std::ve SetupNetworking(); InitSignatureCache(); InitScriptExecutionCache(); + m_node.chain = interfaces::MakeChain(m_node); + g_wallet_init_interface.Construct(m_node); fCheckBlockIndex = true; static bool noui_connected = false; if (!noui_connected) {