From 62a09a30772141ef4add2f10d29927211abf57eb Mon Sep 17 00:00:00 2001 From: Russell Yanofsky Date: Thu, 28 May 2020 13:06:43 -0400 Subject: [PATCH] refactor: remove ::vpwallets and related global variables Move global wallet variables to WalletContext struct --- src/interfaces/wallet.h | 5 +- src/qt/test/addressbooktests.cpp | 7 +-- src/qt/test/wallettests.cpp | 7 +-- src/wallet/context.h | 16 +++++- src/wallet/interfaces.cpp | 30 +++++------ src/wallet/load.cpp | 35 +++++++------ src/wallet/load.h | 13 ++--- src/wallet/rpcwallet.cpp | 18 ++++--- src/wallet/test/wallet_tests.cpp | 48 +++++++++++------- src/wallet/wallet.cpp | 86 ++++++++++++++++---------------- src/wallet/wallet.h | 22 ++++---- src/wallet/walletdb.cpp | 4 +- src/wallet/walletdb.h | 3 +- 13 files changed, 167 insertions(+), 127 deletions(-) diff --git a/src/interfaces/wallet.h b/src/interfaces/wallet.h index fb1febc11b..a85db04b8b 100644 --- a/src/interfaces/wallet.h +++ b/src/interfaces/wallet.h @@ -332,6 +332,9 @@ public: //! loaded at startup or by RPC. using LoadWalletFn = std::function wallet)>; virtual std::unique_ptr handleLoadWallet(LoadWalletFn fn) = 0; + + //! Return pointer to internal context, useful for testing. + virtual WalletContext* context() { return nullptr; } }; //! Information about one wallet address. @@ -410,7 +413,7 @@ struct WalletTxOut //! Return implementation of Wallet interface. This function is defined in //! dummywallet.cpp and throws if the wallet component is not compiled. -std::unique_ptr MakeWallet(const std::shared_ptr& wallet); +std::unique_ptr MakeWallet(WalletContext& context, const std::shared_ptr& wallet); //! Return implementation of ChainClient interface for a wallet client. This //! function will be undefined in builds where ENABLE_WALLET is false. diff --git a/src/qt/test/addressbooktests.cpp b/src/qt/test/addressbooktests.cpp index 39c69fe184..022f367422 100644 --- a/src/qt/test/addressbooktests.cpp +++ b/src/qt/test/addressbooktests.cpp @@ -109,9 +109,10 @@ void TestAddAddressesToSendBook(interfaces::Node& node) std::unique_ptr platformStyle(PlatformStyle::instantiate("other")); OptionsModel optionsModel; ClientModel clientModel(node, &optionsModel); - AddWallet(wallet); - WalletModel walletModel(interfaces::MakeWallet(wallet), clientModel, platformStyle.get()); - RemoveWallet(wallet, std::nullopt); + WalletContext& context = *node.walletClient().context(); + AddWallet(context, wallet); + WalletModel walletModel(interfaces::MakeWallet(context, wallet), clientModel, platformStyle.get()); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt); EditAddressDialog editAddressDialog(EditAddressDialog::NewSendingAddress); editAddressDialog.setModel(walletModel.getAddressTableModel()); diff --git a/src/qt/test/wallettests.cpp b/src/qt/test/wallettests.cpp index e883337fb5..1976bee74b 100644 --- a/src/qt/test/wallettests.cpp +++ b/src/qt/test/wallettests.cpp @@ -164,9 +164,10 @@ void TestGUI(interfaces::Node& node) TransactionView transactionView(platformStyle.get()); OptionsModel optionsModel; ClientModel clientModel(node, &optionsModel); - AddWallet(wallet); - WalletModel walletModel(interfaces::MakeWallet(wallet), clientModel, platformStyle.get()); - RemoveWallet(wallet, std::nullopt); + WalletContext& context = *node.walletClient().context(); + AddWallet(context, wallet); + WalletModel walletModel(interfaces::MakeWallet(context, wallet), clientModel, platformStyle.get()); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt); sendCoinsDialog.setModel(&walletModel); transactionView.setModel(&walletModel); diff --git a/src/wallet/context.h b/src/wallet/context.h index a83591154f..a382fb9021 100644 --- a/src/wallet/context.h +++ b/src/wallet/context.h @@ -5,11 +5,22 @@ #ifndef BITCOIN_WALLET_CONTEXT_H #define BITCOIN_WALLET_CONTEXT_H +#include + +#include +#include +#include +#include + class ArgsManager; +class CWallet; namespace interfaces { class Chain; +class Wallet; } // namespace interfaces +using LoadWalletFn = std::function wallet)>; + //! WalletContext struct containing references to state shared between CWallet //! instances, like the reference to the chain interface, and the list of opened //! wallets. @@ -22,7 +33,10 @@ class Chain; //! behavior. struct WalletContext { interfaces::Chain* chain{nullptr}; - ArgsManager* args{nullptr}; + ArgsManager* args{nullptr}; // Currently a raw pointer because the memory is not managed by this struct + Mutex wallets_mutex; + std::vector> wallets GUARDED_BY(wallets_mutex); + std::list wallet_load_fns GUARDED_BY(wallets_mutex); //! Declare default constructor and destructor that are not inline, so code //! instantiating the WalletContext struct doesn't need to #include class diff --git a/src/wallet/interfaces.cpp b/src/wallet/interfaces.cpp index 2c891c3c1e..0d4b98ecaf 100644 --- a/src/wallet/interfaces.cpp +++ b/src/wallet/interfaces.cpp @@ -110,7 +110,7 @@ WalletTxOut MakeWalletTxOut(const CWallet& wallet, class WalletImpl : public Wallet { public: - explicit WalletImpl(const std::shared_ptr& wallet) : m_wallet(wallet) {} + explicit WalletImpl(WalletContext& context, const std::shared_ptr& wallet) : m_context(context), m_wallet(wallet) {} bool encryptWallet(const SecureString& wallet_passphrase) override { @@ -458,7 +458,7 @@ public: CAmount getDefaultMaxTxFee() override { return m_wallet->m_default_max_tx_fee; } void remove() override { - RemoveWallet(m_wallet, false /* load_on_start */); + RemoveWallet(m_context, m_wallet, false /* load_on_start */); } bool isLegacy() override { return m_wallet->IsLegacy(); } std::unique_ptr handleUnload(UnloadFn fn) override @@ -494,6 +494,7 @@ public: } CWallet* wallet() override { return m_wallet.get(); } + WalletContext& m_context; std::shared_ptr m_wallet; }; @@ -505,7 +506,7 @@ public: m_context.chain = &chain; m_context.args = &args; } - ~WalletClientImpl() override { UnloadWallets(); } + ~WalletClientImpl() override { UnloadWallets(m_context); } //! ChainClient methods void registerRpcs() override @@ -519,11 +520,11 @@ public: m_rpc_handlers.emplace_back(m_context.chain->handleRpc(m_rpc_commands.back())); } } - bool verify() override { return VerifyWallets(*m_context.chain); } - bool load() override { return LoadWallets(*m_context.chain); } - void start(CScheduler& scheduler) override { return StartWallets(scheduler, *Assert(m_context.args)); } - void flush() override { return FlushWallets(); } - void stop() override { return StopWallets(); } + bool verify() override { return VerifyWallets(m_context); } + bool load() override { return LoadWallets(m_context); } + void start(CScheduler& scheduler) override { return StartWallets(m_context, scheduler); } + void flush() override { return FlushWallets(m_context); } + void stop() override { return StopWallets(m_context); } void setMockTime(int64_t time) override { return SetMockTime(time); } //! WalletClient methods @@ -535,14 +536,14 @@ public: options.require_create = true; options.create_flags = wallet_creation_flags; options.create_passphrase = passphrase; - return MakeWallet(CreateWallet(*m_context.chain, name, true /* load_on_start */, options, status, error, warnings)); + return MakeWallet(m_context, CreateWallet(m_context, name, true /* load_on_start */, options, status, error, warnings)); } std::unique_ptr loadWallet(const std::string& name, bilingual_str& error, std::vector& warnings) override { DatabaseOptions options; DatabaseStatus status; options.require_existing = true; - return MakeWallet(LoadWallet(*m_context.chain, name, true /* load_on_start */, options, status, error, warnings)); + return MakeWallet(m_context, LoadWallet(m_context, name, true /* load_on_start */, options, status, error, warnings)); } std::string getWalletDir() override { @@ -559,15 +560,16 @@ public: std::vector> getWallets() override { std::vector> wallets; - for (const auto& wallet : GetWallets()) { - wallets.emplace_back(MakeWallet(wallet)); + for (const auto& wallet : GetWallets(m_context)) { + wallets.emplace_back(MakeWallet(m_context, wallet)); } return wallets; } std::unique_ptr handleLoadWallet(LoadWalletFn fn) override { - return HandleLoadWallet(std::move(fn)); + return HandleLoadWallet(m_context, std::move(fn)); } + WalletContext* context() override { return &m_context; } WalletContext m_context; const std::vector m_wallet_filenames; @@ -578,7 +580,7 @@ public: } // namespace wallet namespace interfaces { -std::unique_ptr MakeWallet(const std::shared_ptr& wallet) { return wallet ? std::make_unique(wallet) : nullptr; } +std::unique_ptr MakeWallet(WalletContext& context, const std::shared_ptr& wallet) { return wallet ? std::make_unique(context, wallet) : nullptr; } std::unique_ptr MakeWalletClient(Chain& chain, ArgsManager& args) { diff --git a/src/wallet/load.cpp b/src/wallet/load.cpp index dbf9fd46b6..9009f93dbc 100644 --- a/src/wallet/load.cpp +++ b/src/wallet/load.cpp @@ -11,13 +11,15 @@ #include #include #include +#include #include #include #include -bool VerifyWallets(interfaces::Chain& chain) +bool VerifyWallets(WalletContext& context) { + interfaces::Chain& chain = *context.chain; if (gArgs.IsArgSet("-walletdir")) { fs::path wallet_dir = gArgs.GetArg("-walletdir", ""); boost::system::error_code error; @@ -87,8 +89,9 @@ bool VerifyWallets(interfaces::Chain& chain) return true; } -bool LoadWallets(interfaces::Chain& chain) +bool LoadWallets(WalletContext& context) { + interfaces::Chain& chain = *context.chain; try { std::set wallet_paths; for (const std::string& name : gArgs.GetArgs("-wallet")) { @@ -106,13 +109,13 @@ bool LoadWallets(interfaces::Chain& chain) continue; } chain.initMessage(_("Loading wallet…").translated); - std::shared_ptr pwallet = database ? CWallet::Create(&chain, name, std::move(database), options.create_flags, error, warnings) : nullptr; + std::shared_ptr pwallet = database ? CWallet::Create(context, name, std::move(database), options.create_flags, error, warnings) : nullptr; if (!warnings.empty()) chain.initWarning(Join(warnings, Untranslated("\n"))); if (!pwallet) { chain.initError(error); return false; } - AddWallet(pwallet); + AddWallet(context, pwallet); } return true; } catch (const std::runtime_error& e) { @@ -121,41 +124,41 @@ bool LoadWallets(interfaces::Chain& chain) } } -void StartWallets(CScheduler& scheduler, const ArgsManager& args) +void StartWallets(WalletContext& context, CScheduler& scheduler) { - for (const std::shared_ptr& pwallet : GetWallets()) { + for (const std::shared_ptr& pwallet : GetWallets(context)) { pwallet->postInitProcess(); } // Schedule periodic wallet flushes and tx rebroadcasts - if (args.GetBoolArg("-flushwallet", DEFAULT_FLUSHWALLET)) { - scheduler.scheduleEvery(MaybeCompactWalletDB, std::chrono::milliseconds{500}); + if (context.args->GetBoolArg("-flushwallet", DEFAULT_FLUSHWALLET)) { + scheduler.scheduleEvery([&context] { MaybeCompactWalletDB(context); }, std::chrono::milliseconds{500}); } - scheduler.scheduleEvery(MaybeResendWalletTxs, std::chrono::milliseconds{1000}); + scheduler.scheduleEvery([&context] { MaybeResendWalletTxs(context); }, std::chrono::milliseconds{1000}); } -void FlushWallets() +void FlushWallets(WalletContext& context) { - for (const std::shared_ptr& pwallet : GetWallets()) { + for (const std::shared_ptr& pwallet : GetWallets(context)) { pwallet->Flush(); } } -void StopWallets() +void StopWallets(WalletContext& context) { - for (const std::shared_ptr& pwallet : GetWallets()) { + for (const std::shared_ptr& pwallet : GetWallets(context)) { pwallet->Close(); } } -void UnloadWallets() +void UnloadWallets(WalletContext& context) { - auto wallets = GetWallets(); + auto wallets = GetWallets(context); while (!wallets.empty()) { auto wallet = wallets.back(); wallets.pop_back(); std::vector warnings; - RemoveWallet(wallet, std::nullopt, warnings); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt, warnings); UnloadWallet(std::move(wallet)); } } diff --git a/src/wallet/load.h b/src/wallet/load.h index 7910f0d6e1..e207bc2e09 100644 --- a/src/wallet/load.h +++ b/src/wallet/load.h @@ -11,27 +11,28 @@ class ArgsManager; class CScheduler; +struct WalletContext; namespace interfaces { class Chain; } // namespace interfaces //! Responsible for reading and validating the -wallet arguments and verifying the wallet database. -bool VerifyWallets(interfaces::Chain& chain); +bool VerifyWallets(WalletContext& context); //! Load wallet databases. -bool LoadWallets(interfaces::Chain& chain); +bool LoadWallets(WalletContext& context); //! Complete startup of wallets. -void StartWallets(CScheduler& scheduler, const ArgsManager& args); +void StartWallets(WalletContext& context, CScheduler& scheduler); //! Flush all wallets in preparation for shutdown. -void FlushWallets(); +void FlushWallets(WalletContext& context); //! Stop all wallets. Wallets will be flushed first. -void StopWallets(); +void StopWallets(WalletContext& context); //! Close all wallets. -void UnloadWallets(); +void UnloadWallets(WalletContext& context); #endif // BITCOIN_WALLET_LOAD_H diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index 2a5b547858..916f811f9b 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -96,14 +96,16 @@ bool GetWalletNameFromJSONRPCRequest(const JSONRPCRequest& request, std::string& std::shared_ptr GetWalletForJSONRPCRequest(const JSONRPCRequest& request) { CHECK_NONFATAL(request.mode == JSONRPCRequest::EXECUTE); + WalletContext& context = EnsureWalletContext(request.context); + std::string wallet_name; if (GetWalletNameFromJSONRPCRequest(request, wallet_name)) { - std::shared_ptr pwallet = GetWallet(wallet_name); + std::shared_ptr pwallet = GetWallet(context, wallet_name); if (!pwallet) throw JSONRPCError(RPC_WALLET_NOT_FOUND, "Requested wallet does not exist or is not loaded"); return pwallet; } - std::vector> wallets = GetWallets(); + std::vector> wallets = GetWallets(context); if (wallets.size() == 1) { return wallets[0]; } @@ -2562,7 +2564,8 @@ static RPCHelpMan listwallets() { UniValue obj(UniValue::VARR); - for (const std::shared_ptr& wallet : GetWallets()) { + WalletContext& context = EnsureWalletContext(request.context); + for (const std::shared_ptr& wallet : GetWallets(context)) { LOCK(wallet->cs_wallet); obj.push_back(wallet->GetName()); } @@ -2580,7 +2583,7 @@ static std::tuple, std::vector> LoadWall bilingual_str error; std::vector warnings; std::optional load_on_start = load_on_start_param.isNull() ? std::nullopt : std::optional(load_on_start_param.get_bool()); - std::shared_ptr const wallet = LoadWallet(*context.chain, wallet_name, load_on_start, options, status, error, warnings); + std::shared_ptr const wallet = LoadWallet(context, wallet_name, load_on_start, options, status, error, warnings); if (!wallet) { // Map bad format to not found, since bad format is returned when the @@ -2788,7 +2791,7 @@ static RPCHelpMan createwallet() options.create_passphrase = passphrase; bilingual_str error; std::optional load_on_start = request.params[6].isNull() ? std::nullopt : std::optional(request.params[6].get_bool()); - std::shared_ptr wallet = CreateWallet(*context.chain, request.params[0].get_str(), load_on_start, options, status, error, warnings); + std::shared_ptr wallet = CreateWallet(context, request.params[0].get_str(), load_on_start, options, status, error, warnings); if (!wallet) { RPCErrorCode code = status == DatabaseStatus::FAILED_ENCRYPT ? RPC_WALLET_ENCRYPTION_FAILED : RPC_WALLET_ERROR; throw JSONRPCError(code, error.original); @@ -2892,7 +2895,8 @@ static RPCHelpMan unloadwallet() wallet_name = request.params[0].get_str(); } - std::shared_ptr wallet = GetWallet(wallet_name); + WalletContext& context = EnsureWalletContext(request.context); + std::shared_ptr wallet = GetWallet(context, wallet_name); if (!wallet) { throw JSONRPCError(RPC_WALLET_NOT_FOUND, "Requested wallet does not exist or is not loaded"); } @@ -2902,7 +2906,7 @@ static RPCHelpMan unloadwallet() // is destroyed (see CheckUniqueFileid). std::vector warnings; std::optional load_on_start = request.params[1].isNull() ? std::nullopt : std::optional(request.params[1].get_bool()); - if (!RemoveWallet(wallet, load_on_start, warnings)) { + if (!RemoveWallet(context, wallet, load_on_start, warnings)) { throw JSONRPCError(RPC_MISC_ERROR, "Requested wallet already unloaded"); } diff --git a/src/wallet/test/wallet_tests.cpp b/src/wallet/test/wallet_tests.cpp index c8c5215e1b..9fcea5826b 100644 --- a/src/wallet/test/wallet_tests.cpp +++ b/src/wallet/test/wallet_tests.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -30,8 +31,6 @@ RPCHelpMan importmulti(); RPCHelpMan dumpwallet(); RPCHelpMan importwallet(); -extern RecursiveMutex cs_wallets; - // Ensure that fee levels defined in the wallet are at least as high // as the default levels for node policy. static_assert(DEFAULT_TRANSACTION_MINFEE >= DEFAULT_MIN_RELAY_TX_FEE, "wallet minimum fee is smaller than default relay fee"); @@ -39,15 +38,15 @@ static_assert(WALLET_INCREMENTAL_RELAY_FEE >= DEFAULT_INCREMENTAL_RELAY_FEE, "wa BOOST_FIXTURE_TEST_SUITE(wallet_tests, WalletTestingSetup) -static std::shared_ptr TestLoadWallet(interfaces::Chain* chain) +static std::shared_ptr TestLoadWallet(WalletContext& context) { DatabaseOptions options; DatabaseStatus status; bilingual_str error; std::vector warnings; auto database = MakeWalletDatabase("", options, status, error); - auto wallet = CWallet::Create(chain, "", std::move(database), options.create_flags, error, warnings); - if (chain) { + auto wallet = CWallet::Create(context, "", std::move(database), options.create_flags, error, warnings); + if (context.chain) { wallet->postInitProcess(); } return wallet; @@ -200,7 +199,8 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup) std::shared_ptr wallet = std::make_shared(m_node.chain.get(), "", CreateDummyWalletDatabase()); wallet->SetupLegacyScriptPubKeyMan(); WITH_LOCK(wallet->cs_wallet, wallet->SetLastBlockProcessed(newTip->nHeight, newTip->GetBlockHash())); - AddWallet(wallet); + WalletContext context; + AddWallet(context, wallet); UniValue keys; keys.setArray(); UniValue key; @@ -218,6 +218,7 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup) key.pushKV("internal", UniValue(true)); keys.push_back(key); JSONRPCRequest request; + request.context = &context; request.params.setArray(); request.params.push_back(keys); @@ -231,7 +232,7 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup) "downloading and rescanning the relevant blocks (see -reindex and -rescan " "options).\"}},{\"success\":true}]", 0, oldTip->GetBlockTimeMax(), TIMESTAMP_WINDOW)); - RemoveWallet(wallet, std::nullopt); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt); } } @@ -258,6 +259,7 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) // Import key into wallet and call dumpwallet to create backup file. { + WalletContext context; std::shared_ptr wallet = std::make_shared(m_node.chain.get(), "", CreateDummyWalletDatabase()); { auto spk_man = wallet->GetOrCreateLegacyScriptPubKeyMan(); @@ -265,15 +267,16 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) spk_man->mapKeyMetadata[coinbaseKey.GetPubKey().GetID()].nCreateTime = KEY_TIME; spk_man->AddKeyPubKey(coinbaseKey, coinbaseKey.GetPubKey()); - AddWallet(wallet); + AddWallet(context, wallet); wallet->SetLastBlockProcessed(m_node.chainman->ActiveChain().Height(), m_node.chainman->ActiveChain().Tip()->GetBlockHash()); } JSONRPCRequest request; + request.context = &context; request.params.setArray(); request.params.push_back(backup_file); ::dumpwallet().HandleRequest(request); - RemoveWallet(wallet, std::nullopt); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt); } // Call importwallet RPC and verify all blocks with timestamps >= BLOCK_TIME @@ -283,13 +286,15 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) LOCK(wallet->cs_wallet); wallet->SetupLegacyScriptPubKeyMan(); + WalletContext context; JSONRPCRequest request; + request.context = &context; request.params.setArray(); request.params.push_back(backup_file); - AddWallet(wallet); + AddWallet(context, wallet); wallet->SetLastBlockProcessed(m_node.chainman->ActiveChain().Height(), m_node.chainman->ActiveChain().Tip()->GetBlockHash()); ::importwallet().HandleRequest(request); - RemoveWallet(wallet, std::nullopt); + RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt); BOOST_CHECK_EQUAL(wallet->mapWallet.size(), 3U); BOOST_CHECK_EQUAL(m_coinbase_txns.size(), 103U); @@ -679,7 +684,9 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup) { gArgs.ForceSetArg("-unsafesqlitesync", "1"); // Create new wallet with known key and unload it. - auto wallet = TestLoadWallet(m_node.chain.get()); + WalletContext context; + context.chain = m_node.chain.get(); + auto wallet = TestLoadWallet(context); CKey key; key.MakeNewKey(true); AddKey(*wallet, key); @@ -719,7 +726,7 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup) // Reload wallet and make sure new transactions are detected despite events // being blocked - wallet = TestLoadWallet(m_node.chain.get()); + wallet = TestLoadWallet(context); BOOST_CHECK(rescan_completed); BOOST_CHECK_EQUAL(addtx_count, 2); { @@ -746,20 +753,20 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup) // deadlock during the sync and simulates a new block notification happening // as soon as possible. addtx_count = 0; - auto handler = HandleLoadWallet([&](std::unique_ptr wallet) EXCLUSIVE_LOCKS_REQUIRED(wallet->wallet()->cs_wallet, cs_wallets) { + auto handler = HandleLoadWallet(context, [&](std::unique_ptr wallet) EXCLUSIVE_LOCKS_REQUIRED(wallet->wallet()->cs_wallet, context.wallets_mutex) { BOOST_CHECK(rescan_completed); m_coinbase_txns.push_back(CreateAndProcessBlock({}, GetScriptForRawPubKey(coinbaseKey.GetPubKey())).vtx[0]); block_tx = TestSimpleSpend(*m_coinbase_txns[2], 0, coinbaseKey, GetScriptForRawPubKey(key.GetPubKey())); m_coinbase_txns.push_back(CreateAndProcessBlock({block_tx}, GetScriptForRawPubKey(coinbaseKey.GetPubKey())).vtx[0]); mempool_tx = TestSimpleSpend(*m_coinbase_txns[3], 0, coinbaseKey, GetScriptForRawPubKey(key.GetPubKey())); BOOST_CHECK(m_node.chain->broadcastTransaction(MakeTransactionRef(mempool_tx), DEFAULT_TRANSACTION_MAXFEE, false, error)); - LEAVE_CRITICAL_SECTION(cs_wallets); + LEAVE_CRITICAL_SECTION(context.wallets_mutex); LEAVE_CRITICAL_SECTION(wallet->wallet()->cs_wallet); SyncWithValidationInterfaceQueue(); ENTER_CRITICAL_SECTION(wallet->wallet()->cs_wallet); - ENTER_CRITICAL_SECTION(cs_wallets); + ENTER_CRITICAL_SECTION(context.wallets_mutex); }); - wallet = TestLoadWallet(m_node.chain.get()); + wallet = TestLoadWallet(context); BOOST_CHECK_EQUAL(addtx_count, 4); { LOCK(wallet->cs_wallet); @@ -773,7 +780,8 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup) BOOST_FIXTURE_TEST_CASE(CreateWalletWithoutChain, BasicTestingSetup) { - auto wallet = TestLoadWallet(nullptr); + WalletContext context; + auto wallet = TestLoadWallet(context); BOOST_CHECK(wallet); UnloadWallet(std::move(wallet)); } @@ -781,7 +789,9 @@ BOOST_FIXTURE_TEST_CASE(CreateWalletWithoutChain, BasicTestingSetup) BOOST_FIXTURE_TEST_CASE(ZapSelectTx, TestChain100Setup) { gArgs.ForceSetArg("-unsafesqlitesync", "1"); - auto wallet = TestLoadWallet(m_node.chain.get()); + WalletContext context; + context.chain = m_node.chain.get(); + auto wallet = TestLoadWallet(context); CKey key; key.MakeNewKey(true); AddKey(*wallet, key); diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index e6227048d2..cf869fac0c 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -54,10 +55,6 @@ const std::map WALLET_FLAG_CAVEATS{ }, }; -RecursiveMutex cs_wallets; -static std::vector> vpwallets GUARDED_BY(cs_wallets); -static std::list g_load_wallet_fns GUARDED_BY(cs_wallets); - bool AddWalletSetting(interfaces::Chain& chain, const std::string& wallet_name) { util::SettingsValue setting_value = chain.getRwSetting("wallet"); @@ -104,19 +101,19 @@ static void RefreshMempoolStatus(CWalletTx& tx, interfaces::Chain& chain) tx.fInMempool = chain.isInMempool(tx.GetHash()); } -bool AddWallet(const std::shared_ptr& wallet) +bool AddWallet(WalletContext& context, const std::shared_ptr& wallet) { - LOCK(cs_wallets); + LOCK(context.wallets_mutex); assert(wallet); - std::vector>::const_iterator i = std::find(vpwallets.begin(), vpwallets.end(), wallet); - if (i != vpwallets.end()) return false; - vpwallets.push_back(wallet); + std::vector>::const_iterator i = std::find(context.wallets.begin(), context.wallets.end(), wallet); + if (i != context.wallets.end()) return false; + context.wallets.push_back(wallet); wallet->ConnectScriptPubKeyManNotifiers(); wallet->NotifyCanGetAddressesChanged(); return true; } -bool RemoveWallet(const std::shared_ptr& wallet, std::optional load_on_start, std::vector& warnings) +bool RemoveWallet(WalletContext& context, const std::shared_ptr& wallet, std::optional load_on_start, std::vector& warnings) { assert(wallet); @@ -125,10 +122,10 @@ bool RemoveWallet(const std::shared_ptr& wallet, std::optional lo // Unregister with the validation interface which also drops shared ponters. wallet->m_chain_notifications_handler.reset(); - LOCK(cs_wallets); - std::vector>::iterator i = std::find(vpwallets.begin(), vpwallets.end(), wallet); - if (i == vpwallets.end()) return false; - vpwallets.erase(i); + LOCK(context.wallets_mutex); + std::vector>::iterator i = std::find(context.wallets.begin(), context.wallets.end(), wallet); + if (i == context.wallets.end()) return false; + context.wallets.erase(i); // Write the wallet setting UpdateWalletSetting(chain, name, load_on_start, warnings); @@ -136,32 +133,32 @@ bool RemoveWallet(const std::shared_ptr& wallet, std::optional lo return true; } -bool RemoveWallet(const std::shared_ptr& wallet, std::optional load_on_start) +bool RemoveWallet(WalletContext& context, const std::shared_ptr& wallet, std::optional load_on_start) { std::vector warnings; - return RemoveWallet(wallet, load_on_start, warnings); + return RemoveWallet(context, wallet, load_on_start, warnings); } -std::vector> GetWallets() +std::vector> GetWallets(WalletContext& context) { - LOCK(cs_wallets); - return vpwallets; + LOCK(context.wallets_mutex); + return context.wallets; } -std::shared_ptr GetWallet(const std::string& name) +std::shared_ptr GetWallet(WalletContext& context, const std::string& name) { - LOCK(cs_wallets); - for (const std::shared_ptr& wallet : vpwallets) { + LOCK(context.wallets_mutex); + for (const std::shared_ptr& wallet : context.wallets) { if (wallet->GetName() == name) return wallet; } return nullptr; } -std::unique_ptr HandleLoadWallet(LoadWalletFn load_wallet) +std::unique_ptr HandleLoadWallet(WalletContext& context, LoadWalletFn load_wallet) { - LOCK(cs_wallets); - auto it = g_load_wallet_fns.emplace(g_load_wallet_fns.end(), std::move(load_wallet)); - return interfaces::MakeHandler([it] { LOCK(cs_wallets); g_load_wallet_fns.erase(it); }); + LOCK(context.wallets_mutex); + auto it = context.wallet_load_fns.emplace(context.wallet_load_fns.end(), std::move(load_wallet)); + return interfaces::MakeHandler([&context, it] { LOCK(context.wallets_mutex); context.wallet_load_fns.erase(it); }); } static Mutex g_loading_wallet_mutex; @@ -213,7 +210,7 @@ void UnloadWallet(std::shared_ptr&& wallet) } namespace { -std::shared_ptr LoadWalletInternal(interfaces::Chain& chain, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) +std::shared_ptr LoadWalletInternal(WalletContext& context, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) { try { std::unique_ptr database = MakeWalletDatabase(name, options, status, error); @@ -222,18 +219,18 @@ std::shared_ptr LoadWalletInternal(interfaces::Chain& chain, const std: return nullptr; } - chain.initMessage(_("Loading wallet…").translated); - std::shared_ptr wallet = CWallet::Create(&chain, name, std::move(database), options.create_flags, error, warnings); + context.chain->initMessage(_("Loading wallet…").translated); + std::shared_ptr wallet = CWallet::Create(context, name, std::move(database), options.create_flags, error, warnings); if (!wallet) { error = Untranslated("Wallet loading failed.") + Untranslated(" ") + error; status = DatabaseStatus::FAILED_LOAD; return nullptr; } - AddWallet(wallet); + AddWallet(context, wallet); wallet->postInitProcess(); // Write the wallet setting - UpdateWalletSetting(chain, name, load_on_start, warnings); + UpdateWalletSetting(*context.chain, name, load_on_start, warnings); return wallet; } catch (const std::runtime_error& e) { @@ -244,7 +241,7 @@ std::shared_ptr LoadWalletInternal(interfaces::Chain& chain, const std: } } // namespace -std::shared_ptr LoadWallet(interfaces::Chain& chain, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) +std::shared_ptr LoadWallet(WalletContext& context, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) { auto result = WITH_LOCK(g_loading_wallet_mutex, return g_loading_wallet_set.insert(name)); if (!result.second) { @@ -252,12 +249,12 @@ std::shared_ptr LoadWallet(interfaces::Chain& chain, const std::string& status = DatabaseStatus::FAILED_LOAD; return nullptr; } - auto wallet = LoadWalletInternal(chain, name, load_on_start, options, status, error, warnings); + auto wallet = LoadWalletInternal(context, name, load_on_start, options, status, error, warnings); WITH_LOCK(g_loading_wallet_mutex, g_loading_wallet_set.erase(result.first)); return wallet; } -std::shared_ptr CreateWallet(interfaces::Chain& chain, const std::string& name, std::optional load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) +std::shared_ptr CreateWallet(WalletContext& context, const std::string& name, std::optional load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) { uint64_t wallet_creation_flags = options.create_flags; const SecureString& passphrase = options.create_passphrase; @@ -302,8 +299,8 @@ std::shared_ptr CreateWallet(interfaces::Chain& chain, const std::strin } // Make the wallet - chain.initMessage(_("Loading wallet…").translated); - std::shared_ptr wallet = CWallet::Create(&chain, name, std::move(database), wallet_creation_flags, error, warnings); + context.chain->initMessage(_("Loading wallet…").translated); + std::shared_ptr wallet = CWallet::Create(context, name, std::move(database), wallet_creation_flags, error, warnings); if (!wallet) { error = Untranslated("Wallet creation failed.") + Untranslated(" ") + error; status = DatabaseStatus::FAILED_CREATE; @@ -345,11 +342,11 @@ std::shared_ptr CreateWallet(interfaces::Chain& chain, const std::strin wallet->Lock(); } } - AddWallet(wallet); + AddWallet(context, wallet); wallet->postInitProcess(); // Write the wallet settings - UpdateWalletSetting(chain, name, load_on_start, warnings); + UpdateWalletSetting(*context.chain, name, load_on_start, warnings); status = DatabaseStatus::SUCCESS; return wallet; @@ -1802,9 +1799,9 @@ void CWallet::ResendWalletTransactions() /** @} */ // end of mapWallet -void MaybeResendWalletTxs() +void MaybeResendWalletTxs(WalletContext& context) { - for (const std::shared_ptr& pwallet : GetWallets()) { + for (const std::shared_ptr& pwallet : GetWallets(context)) { pwallet->ResendWalletTransactions(); } } @@ -2509,8 +2506,9 @@ std::unique_ptr MakeWalletDatabase(const std::string& name, cons return MakeDatabase(wallet_path, options, status, error_string); } -std::shared_ptr CWallet::Create(interfaces::Chain* chain, const std::string& name, std::unique_ptr database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector& warnings) +std::shared_ptr CWallet::Create(WalletContext& context, const std::string& name, std::unique_ptr database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector& warnings) { + interfaces::Chain* chain = context.chain; const std::string& walletFile = database->Filename(); int64_t nStart = GetTimeMillis(); @@ -2722,9 +2720,9 @@ std::shared_ptr CWallet::Create(interfaces::Chain* chain, const std::st } { - LOCK(cs_wallets); - for (auto& load_wallet : g_load_wallet_fns) { - load_wallet(interfaces::MakeWallet(walletInstance)); + LOCK(context.wallets_mutex); + for (auto& load_wallet : context.wallet_load_fns) { + load_wallet(interfaces::MakeWallet(context, walletInstance)); } } diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 25f89e8ea4..a1bbf66136 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -42,6 +42,8 @@ #include +struct WalletContext; + using LoadWalletFn = std::function wallet)>; struct bilingual_str; @@ -53,14 +55,14 @@ struct bilingual_str; //! by the shared pointer deleter. void UnloadWallet(std::shared_ptr&& wallet); -bool AddWallet(const std::shared_ptr& wallet); -bool RemoveWallet(const std::shared_ptr& wallet, std::optional load_on_start, std::vector& warnings); -bool RemoveWallet(const std::shared_ptr& wallet, std::optional load_on_start); -std::vector> GetWallets(); -std::shared_ptr GetWallet(const std::string& name); -std::shared_ptr LoadWallet(interfaces::Chain& chain, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); -std::shared_ptr CreateWallet(interfaces::Chain& chain, const std::string& name, std::optional load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); -std::unique_ptr HandleLoadWallet(LoadWalletFn load_wallet); +bool AddWallet(WalletContext& context, const std::shared_ptr& wallet); +bool RemoveWallet(WalletContext& context, const std::shared_ptr& wallet, std::optional load_on_start, std::vector& warnings); +bool RemoveWallet(WalletContext& context, const std::shared_ptr& wallet, std::optional load_on_start); +std::vector> GetWallets(WalletContext& context); +std::shared_ptr GetWallet(WalletContext& context, const std::string& name); +std::shared_ptr LoadWallet(WalletContext& context, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); +std::shared_ptr CreateWallet(WalletContext& context, const std::string& name, std::optional load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); +std::unique_ptr HandleLoadWallet(WalletContext& context, LoadWalletFn load_wallet); std::unique_ptr MakeWalletDatabase(const std::string& name, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error); //! -paytxfee default @@ -772,7 +774,7 @@ public: bool MarkReplaced(const uint256& originalHash, const uint256& newHash); /* Initializes the wallet, returns a new CWallet instance or a null pointer in case of an error */ - static std::shared_ptr Create(interfaces::Chain* chain, const std::string& name, std::unique_ptr database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector& warnings); + static std::shared_ptr Create(WalletContext& context, const std::string& name, std::unique_ptr database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector& warnings); /** * Wallet post-init setup @@ -919,7 +921,7 @@ public: * Called periodically by the schedule thread. Prompts individual wallets to resend * their transactions. Actual rebroadcast schedule is managed by the wallets themselves. */ -void MaybeResendWalletTxs(); +void MaybeResendWalletTxs(WalletContext& context); /** RAII object to check and reserve a wallet rescan */ class WalletRescanReserver diff --git a/src/wallet/walletdb.cpp b/src/wallet/walletdb.cpp index 1e5d8dfa3a..2fabe65a93 100644 --- a/src/wallet/walletdb.cpp +++ b/src/wallet/walletdb.cpp @@ -1004,14 +1004,14 @@ DBErrors WalletBatch::ZapSelectTx(std::vector& vTxHashIn, std::vector fOneThread(false); if (fOneThread.exchange(true)) { return; } - for (const std::shared_ptr& pwallet : GetWallets()) { + for (const std::shared_ptr& pwallet : GetWallets(context)) { WalletDatabase& dbh = pwallet->GetDatabase(); unsigned int nUpdateCounter = dbh.nUpdateCounter; diff --git a/src/wallet/walletdb.h b/src/wallet/walletdb.h index 9b775eb481..25c2ec5909 100644 --- a/src/wallet/walletdb.h +++ b/src/wallet/walletdb.h @@ -31,6 +31,7 @@ static const bool DEFAULT_FLUSHWALLET = true; struct CBlockLocator; +struct WalletContext; class CKeyPool; class CMasterKey; class CScript; @@ -279,7 +280,7 @@ private: }; //! Compacts BDB state so that wallet.dat is self-contained (if there are changes) -void MaybeCompactWalletDB(); +void MaybeCompactWalletDB(WalletContext& context); //! Callback for filtering key types to deserialize in ReadKeyValue using KeyFilterFn = std::function;