Merge bitcoin/bitcoin#19101: refactor: remove ::vpwallets and related global variables

62a09a3077 refactor: remove ::vpwallets and related global variables (Russell Yanofsky)

Pull request description:

  Get rid of global wallet list variables by moving them to WalletContext struct

  - [`cs_wallets`](e638acf697/src/wallet/wallet.cpp (L56)) is now [`WalletContext::wallet_mutex`](4be544c7ec/src/wallet/context.h (L37))
  - [`vpwallets`](e638acf697/src/wallet/wallet.cpp (L57)) is now [`WalletContext::wallets`](4be544c7ec/src/wallet/context.h (L38))
  - [`g_load_wallet_fns`](e638acf697/src/wallet/wallet.cpp (L58)) is now [`WalletContext::wallet_load_fns`](4be544c7ec/src/wallet/context.h (L39))

ACKs for top commit:
  achow101:
    ACK 62a09a3077
  meshcollider:
    re-utACK 62a09a3077

Tree-SHA512: 74428180d57b4214c3d96963e6ff43e8778f6f23b6880262d1272f2de67d02714fdc3ebb558f62e48655b221a642c36f80ef37c8f89d362e2d66fd93cbf03b8f
This commit is contained in:
fanquake 2021-08-19 09:18:28 +08:00
commit 638855af63
No known key found for this signature in database
GPG key ID: 2EEB9F5CC09526C1
13 changed files with 167 additions and 127 deletions

View file

@ -332,6 +332,9 @@ public:
//! loaded at startup or by RPC. //! loaded at startup or by RPC.
using LoadWalletFn = std::function<void(std::unique_ptr<Wallet> wallet)>; using LoadWalletFn = std::function<void(std::unique_ptr<Wallet> wallet)>;
virtual std::unique_ptr<Handler> handleLoadWallet(LoadWalletFn fn) = 0; virtual std::unique_ptr<Handler> handleLoadWallet(LoadWalletFn fn) = 0;
//! Return pointer to internal context, useful for testing.
virtual WalletContext* context() { return nullptr; }
}; };
//! Information about one wallet address. //! Information about one wallet address.
@ -410,7 +413,7 @@ struct WalletTxOut
//! Return implementation of Wallet interface. This function is defined in //! Return implementation of Wallet interface. This function is defined in
//! dummywallet.cpp and throws if the wallet component is not compiled. //! dummywallet.cpp and throws if the wallet component is not compiled.
std::unique_ptr<Wallet> MakeWallet(const std::shared_ptr<CWallet>& wallet); std::unique_ptr<Wallet> MakeWallet(WalletContext& context, const std::shared_ptr<CWallet>& wallet);
//! Return implementation of ChainClient interface for a wallet client. This //! Return implementation of ChainClient interface for a wallet client. This
//! function will be undefined in builds where ENABLE_WALLET is false. //! function will be undefined in builds where ENABLE_WALLET is false.

View file

@ -109,9 +109,10 @@ void TestAddAddressesToSendBook(interfaces::Node& node)
std::unique_ptr<const PlatformStyle> platformStyle(PlatformStyle::instantiate("other")); std::unique_ptr<const PlatformStyle> platformStyle(PlatformStyle::instantiate("other"));
OptionsModel optionsModel; OptionsModel optionsModel;
ClientModel clientModel(node, &optionsModel); ClientModel clientModel(node, &optionsModel);
AddWallet(wallet); WalletContext& context = *node.walletClient().context();
WalletModel walletModel(interfaces::MakeWallet(wallet), clientModel, platformStyle.get()); AddWallet(context, wallet);
RemoveWallet(wallet, std::nullopt); WalletModel walletModel(interfaces::MakeWallet(context, wallet), clientModel, platformStyle.get());
RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt);
EditAddressDialog editAddressDialog(EditAddressDialog::NewSendingAddress); EditAddressDialog editAddressDialog(EditAddressDialog::NewSendingAddress);
editAddressDialog.setModel(walletModel.getAddressTableModel()); editAddressDialog.setModel(walletModel.getAddressTableModel());

View file

@ -164,9 +164,10 @@ void TestGUI(interfaces::Node& node)
TransactionView transactionView(platformStyle.get()); TransactionView transactionView(platformStyle.get());
OptionsModel optionsModel; OptionsModel optionsModel;
ClientModel clientModel(node, &optionsModel); ClientModel clientModel(node, &optionsModel);
AddWallet(wallet); WalletContext& context = *node.walletClient().context();
WalletModel walletModel(interfaces::MakeWallet(wallet), clientModel, platformStyle.get()); AddWallet(context, wallet);
RemoveWallet(wallet, std::nullopt); WalletModel walletModel(interfaces::MakeWallet(context, wallet), clientModel, platformStyle.get());
RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt);
sendCoinsDialog.setModel(&walletModel); sendCoinsDialog.setModel(&walletModel);
transactionView.setModel(&walletModel); transactionView.setModel(&walletModel);

View file

@ -5,11 +5,22 @@
#ifndef BITCOIN_WALLET_CONTEXT_H #ifndef BITCOIN_WALLET_CONTEXT_H
#define BITCOIN_WALLET_CONTEXT_H #define BITCOIN_WALLET_CONTEXT_H
#include <sync.h>
#include <functional>
#include <list>
#include <memory>
#include <vector>
class ArgsManager; class ArgsManager;
class CWallet;
namespace interfaces { namespace interfaces {
class Chain; class Chain;
class Wallet;
} // namespace interfaces } // namespace interfaces
using LoadWalletFn = std::function<void(std::unique_ptr<interfaces::Wallet> wallet)>;
//! WalletContext struct containing references to state shared between CWallet //! WalletContext struct containing references to state shared between CWallet
//! instances, like the reference to the chain interface, and the list of opened //! instances, like the reference to the chain interface, and the list of opened
//! wallets. //! wallets.
@ -22,7 +33,10 @@ class Chain;
//! behavior. //! behavior.
struct WalletContext { struct WalletContext {
interfaces::Chain* chain{nullptr}; interfaces::Chain* chain{nullptr};
ArgsManager* args{nullptr}; ArgsManager* args{nullptr}; // Currently a raw pointer because the memory is not managed by this struct
Mutex wallets_mutex;
std::vector<std::shared_ptr<CWallet>> wallets GUARDED_BY(wallets_mutex);
std::list<LoadWalletFn> wallet_load_fns GUARDED_BY(wallets_mutex);
//! Declare default constructor and destructor that are not inline, so code //! Declare default constructor and destructor that are not inline, so code
//! instantiating the WalletContext struct doesn't need to #include class //! instantiating the WalletContext struct doesn't need to #include class

View file

@ -110,7 +110,7 @@ WalletTxOut MakeWalletTxOut(const CWallet& wallet,
class WalletImpl : public Wallet class WalletImpl : public Wallet
{ {
public: public:
explicit WalletImpl(const std::shared_ptr<CWallet>& wallet) : m_wallet(wallet) {} explicit WalletImpl(WalletContext& context, const std::shared_ptr<CWallet>& wallet) : m_context(context), m_wallet(wallet) {}
bool encryptWallet(const SecureString& wallet_passphrase) override bool encryptWallet(const SecureString& wallet_passphrase) override
{ {
@ -458,7 +458,7 @@ public:
CAmount getDefaultMaxTxFee() override { return m_wallet->m_default_max_tx_fee; } CAmount getDefaultMaxTxFee() override { return m_wallet->m_default_max_tx_fee; }
void remove() override void remove() override
{ {
RemoveWallet(m_wallet, false /* load_on_start */); RemoveWallet(m_context, m_wallet, false /* load_on_start */);
} }
bool isLegacy() override { return m_wallet->IsLegacy(); } bool isLegacy() override { return m_wallet->IsLegacy(); }
std::unique_ptr<Handler> handleUnload(UnloadFn fn) override std::unique_ptr<Handler> handleUnload(UnloadFn fn) override
@ -494,6 +494,7 @@ public:
} }
CWallet* wallet() override { return m_wallet.get(); } CWallet* wallet() override { return m_wallet.get(); }
WalletContext& m_context;
std::shared_ptr<CWallet> m_wallet; std::shared_ptr<CWallet> m_wallet;
}; };
@ -505,7 +506,7 @@ public:
m_context.chain = &chain; m_context.chain = &chain;
m_context.args = &args; m_context.args = &args;
} }
~WalletClientImpl() override { UnloadWallets(); } ~WalletClientImpl() override { UnloadWallets(m_context); }
//! ChainClient methods //! ChainClient methods
void registerRpcs() override void registerRpcs() override
@ -519,11 +520,11 @@ public:
m_rpc_handlers.emplace_back(m_context.chain->handleRpc(m_rpc_commands.back())); m_rpc_handlers.emplace_back(m_context.chain->handleRpc(m_rpc_commands.back()));
} }
} }
bool verify() override { return VerifyWallets(*m_context.chain); } bool verify() override { return VerifyWallets(m_context); }
bool load() override { return LoadWallets(*m_context.chain); } bool load() override { return LoadWallets(m_context); }
void start(CScheduler& scheduler) override { return StartWallets(scheduler, *Assert(m_context.args)); } void start(CScheduler& scheduler) override { return StartWallets(m_context, scheduler); }
void flush() override { return FlushWallets(); } void flush() override { return FlushWallets(m_context); }
void stop() override { return StopWallets(); } void stop() override { return StopWallets(m_context); }
void setMockTime(int64_t time) override { return SetMockTime(time); } void setMockTime(int64_t time) override { return SetMockTime(time); }
//! WalletClient methods //! WalletClient methods
@ -535,14 +536,14 @@ public:
options.require_create = true; options.require_create = true;
options.create_flags = wallet_creation_flags; options.create_flags = wallet_creation_flags;
options.create_passphrase = passphrase; options.create_passphrase = passphrase;
return MakeWallet(CreateWallet(*m_context.chain, name, true /* load_on_start */, options, status, error, warnings)); return MakeWallet(m_context, CreateWallet(m_context, name, true /* load_on_start */, options, status, error, warnings));
} }
std::unique_ptr<Wallet> loadWallet(const std::string& name, bilingual_str& error, std::vector<bilingual_str>& warnings) override std::unique_ptr<Wallet> loadWallet(const std::string& name, bilingual_str& error, std::vector<bilingual_str>& warnings) override
{ {
DatabaseOptions options; DatabaseOptions options;
DatabaseStatus status; DatabaseStatus status;
options.require_existing = true; options.require_existing = true;
return MakeWallet(LoadWallet(*m_context.chain, name, true /* load_on_start */, options, status, error, warnings)); return MakeWallet(m_context, LoadWallet(m_context, name, true /* load_on_start */, options, status, error, warnings));
} }
std::string getWalletDir() override std::string getWalletDir() override
{ {
@ -559,15 +560,16 @@ public:
std::vector<std::unique_ptr<Wallet>> getWallets() override std::vector<std::unique_ptr<Wallet>> getWallets() override
{ {
std::vector<std::unique_ptr<Wallet>> wallets; std::vector<std::unique_ptr<Wallet>> wallets;
for (const auto& wallet : GetWallets()) { for (const auto& wallet : GetWallets(m_context)) {
wallets.emplace_back(MakeWallet(wallet)); wallets.emplace_back(MakeWallet(m_context, wallet));
} }
return wallets; return wallets;
} }
std::unique_ptr<Handler> handleLoadWallet(LoadWalletFn fn) override std::unique_ptr<Handler> handleLoadWallet(LoadWalletFn fn) override
{ {
return HandleLoadWallet(std::move(fn)); return HandleLoadWallet(m_context, std::move(fn));
} }
WalletContext* context() override { return &m_context; }
WalletContext m_context; WalletContext m_context;
const std::vector<std::string> m_wallet_filenames; const std::vector<std::string> m_wallet_filenames;
@ -578,7 +580,7 @@ public:
} // namespace wallet } // namespace wallet
namespace interfaces { namespace interfaces {
std::unique_ptr<Wallet> MakeWallet(const std::shared_ptr<CWallet>& wallet) { return wallet ? std::make_unique<wallet::WalletImpl>(wallet) : nullptr; } std::unique_ptr<Wallet> MakeWallet(WalletContext& context, const std::shared_ptr<CWallet>& wallet) { return wallet ? std::make_unique<wallet::WalletImpl>(context, wallet) : nullptr; }
std::unique_ptr<WalletClient> MakeWalletClient(Chain& chain, ArgsManager& args) std::unique_ptr<WalletClient> MakeWalletClient(Chain& chain, ArgsManager& args)
{ {

View file

@ -11,13 +11,15 @@
#include <util/string.h> #include <util/string.h>
#include <util/system.h> #include <util/system.h>
#include <util/translation.h> #include <util/translation.h>
#include <wallet/context.h>
#include <wallet/wallet.h> #include <wallet/wallet.h>
#include <wallet/walletdb.h> #include <wallet/walletdb.h>
#include <univalue.h> #include <univalue.h>
bool VerifyWallets(interfaces::Chain& chain) bool VerifyWallets(WalletContext& context)
{ {
interfaces::Chain& chain = *context.chain;
if (gArgs.IsArgSet("-walletdir")) { if (gArgs.IsArgSet("-walletdir")) {
fs::path wallet_dir = gArgs.GetArg("-walletdir", ""); fs::path wallet_dir = gArgs.GetArg("-walletdir", "");
boost::system::error_code error; boost::system::error_code error;
@ -87,8 +89,9 @@ bool VerifyWallets(interfaces::Chain& chain)
return true; return true;
} }
bool LoadWallets(interfaces::Chain& chain) bool LoadWallets(WalletContext& context)
{ {
interfaces::Chain& chain = *context.chain;
try { try {
std::set<fs::path> wallet_paths; std::set<fs::path> wallet_paths;
for (const std::string& name : gArgs.GetArgs("-wallet")) { for (const std::string& name : gArgs.GetArgs("-wallet")) {
@ -106,13 +109,13 @@ bool LoadWallets(interfaces::Chain& chain)
continue; continue;
} }
chain.initMessage(_("Loading wallet…").translated); chain.initMessage(_("Loading wallet…").translated);
std::shared_ptr<CWallet> pwallet = database ? CWallet::Create(&chain, name, std::move(database), options.create_flags, error, warnings) : nullptr; std::shared_ptr<CWallet> pwallet = database ? CWallet::Create(context, name, std::move(database), options.create_flags, error, warnings) : nullptr;
if (!warnings.empty()) chain.initWarning(Join(warnings, Untranslated("\n"))); if (!warnings.empty()) chain.initWarning(Join(warnings, Untranslated("\n")));
if (!pwallet) { if (!pwallet) {
chain.initError(error); chain.initError(error);
return false; return false;
} }
AddWallet(pwallet); AddWallet(context, pwallet);
} }
return true; return true;
} catch (const std::runtime_error& e) { } catch (const std::runtime_error& e) {
@ -121,41 +124,41 @@ bool LoadWallets(interfaces::Chain& chain)
} }
} }
void StartWallets(CScheduler& scheduler, const ArgsManager& args) void StartWallets(WalletContext& context, CScheduler& scheduler)
{ {
for (const std::shared_ptr<CWallet>& pwallet : GetWallets()) { for (const std::shared_ptr<CWallet>& pwallet : GetWallets(context)) {
pwallet->postInitProcess(); pwallet->postInitProcess();
} }
// Schedule periodic wallet flushes and tx rebroadcasts // Schedule periodic wallet flushes and tx rebroadcasts
if (args.GetBoolArg("-flushwallet", DEFAULT_FLUSHWALLET)) { if (context.args->GetBoolArg("-flushwallet", DEFAULT_FLUSHWALLET)) {
scheduler.scheduleEvery(MaybeCompactWalletDB, std::chrono::milliseconds{500}); scheduler.scheduleEvery([&context] { MaybeCompactWalletDB(context); }, std::chrono::milliseconds{500});
} }
scheduler.scheduleEvery(MaybeResendWalletTxs, std::chrono::milliseconds{1000}); scheduler.scheduleEvery([&context] { MaybeResendWalletTxs(context); }, std::chrono::milliseconds{1000});
} }
void FlushWallets() void FlushWallets(WalletContext& context)
{ {
for (const std::shared_ptr<CWallet>& pwallet : GetWallets()) { for (const std::shared_ptr<CWallet>& pwallet : GetWallets(context)) {
pwallet->Flush(); pwallet->Flush();
} }
} }
void StopWallets() void StopWallets(WalletContext& context)
{ {
for (const std::shared_ptr<CWallet>& pwallet : GetWallets()) { for (const std::shared_ptr<CWallet>& pwallet : GetWallets(context)) {
pwallet->Close(); pwallet->Close();
} }
} }
void UnloadWallets() void UnloadWallets(WalletContext& context)
{ {
auto wallets = GetWallets(); auto wallets = GetWallets(context);
while (!wallets.empty()) { while (!wallets.empty()) {
auto wallet = wallets.back(); auto wallet = wallets.back();
wallets.pop_back(); wallets.pop_back();
std::vector<bilingual_str> warnings; std::vector<bilingual_str> warnings;
RemoveWallet(wallet, std::nullopt, warnings); RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt, warnings);
UnloadWallet(std::move(wallet)); UnloadWallet(std::move(wallet));
} }
} }

View file

@ -11,27 +11,28 @@
class ArgsManager; class ArgsManager;
class CScheduler; class CScheduler;
struct WalletContext;
namespace interfaces { namespace interfaces {
class Chain; class Chain;
} // namespace interfaces } // namespace interfaces
//! Responsible for reading and validating the -wallet arguments and verifying the wallet database. //! Responsible for reading and validating the -wallet arguments and verifying the wallet database.
bool VerifyWallets(interfaces::Chain& chain); bool VerifyWallets(WalletContext& context);
//! Load wallet databases. //! Load wallet databases.
bool LoadWallets(interfaces::Chain& chain); bool LoadWallets(WalletContext& context);
//! Complete startup of wallets. //! Complete startup of wallets.
void StartWallets(CScheduler& scheduler, const ArgsManager& args); void StartWallets(WalletContext& context, CScheduler& scheduler);
//! Flush all wallets in preparation for shutdown. //! Flush all wallets in preparation for shutdown.
void FlushWallets(); void FlushWallets(WalletContext& context);
//! Stop all wallets. Wallets will be flushed first. //! Stop all wallets. Wallets will be flushed first.
void StopWallets(); void StopWallets(WalletContext& context);
//! Close all wallets. //! Close all wallets.
void UnloadWallets(); void UnloadWallets(WalletContext& context);
#endif // BITCOIN_WALLET_LOAD_H #endif // BITCOIN_WALLET_LOAD_H

View file

@ -96,14 +96,16 @@ bool GetWalletNameFromJSONRPCRequest(const JSONRPCRequest& request, std::string&
std::shared_ptr<CWallet> GetWalletForJSONRPCRequest(const JSONRPCRequest& request) std::shared_ptr<CWallet> GetWalletForJSONRPCRequest(const JSONRPCRequest& request)
{ {
CHECK_NONFATAL(request.mode == JSONRPCRequest::EXECUTE); CHECK_NONFATAL(request.mode == JSONRPCRequest::EXECUTE);
WalletContext& context = EnsureWalletContext(request.context);
std::string wallet_name; std::string wallet_name;
if (GetWalletNameFromJSONRPCRequest(request, wallet_name)) { if (GetWalletNameFromJSONRPCRequest(request, wallet_name)) {
std::shared_ptr<CWallet> pwallet = GetWallet(wallet_name); std::shared_ptr<CWallet> pwallet = GetWallet(context, wallet_name);
if (!pwallet) throw JSONRPCError(RPC_WALLET_NOT_FOUND, "Requested wallet does not exist or is not loaded"); if (!pwallet) throw JSONRPCError(RPC_WALLET_NOT_FOUND, "Requested wallet does not exist or is not loaded");
return pwallet; return pwallet;
} }
std::vector<std::shared_ptr<CWallet>> wallets = GetWallets(); std::vector<std::shared_ptr<CWallet>> wallets = GetWallets(context);
if (wallets.size() == 1) { if (wallets.size() == 1) {
return wallets[0]; return wallets[0];
} }
@ -2562,7 +2564,8 @@ static RPCHelpMan listwallets()
{ {
UniValue obj(UniValue::VARR); UniValue obj(UniValue::VARR);
for (const std::shared_ptr<CWallet>& wallet : GetWallets()) { WalletContext& context = EnsureWalletContext(request.context);
for (const std::shared_ptr<CWallet>& wallet : GetWallets(context)) {
LOCK(wallet->cs_wallet); LOCK(wallet->cs_wallet);
obj.push_back(wallet->GetName()); obj.push_back(wallet->GetName());
} }
@ -2580,7 +2583,7 @@ static std::tuple<std::shared_ptr<CWallet>, std::vector<bilingual_str>> LoadWall
bilingual_str error; bilingual_str error;
std::vector<bilingual_str> warnings; std::vector<bilingual_str> warnings;
std::optional<bool> load_on_start = load_on_start_param.isNull() ? std::nullopt : std::optional<bool>(load_on_start_param.get_bool()); std::optional<bool> load_on_start = load_on_start_param.isNull() ? std::nullopt : std::optional<bool>(load_on_start_param.get_bool());
std::shared_ptr<CWallet> const wallet = LoadWallet(*context.chain, wallet_name, load_on_start, options, status, error, warnings); std::shared_ptr<CWallet> const wallet = LoadWallet(context, wallet_name, load_on_start, options, status, error, warnings);
if (!wallet) { if (!wallet) {
// Map bad format to not found, since bad format is returned when the // Map bad format to not found, since bad format is returned when the
@ -2788,7 +2791,7 @@ static RPCHelpMan createwallet()
options.create_passphrase = passphrase; options.create_passphrase = passphrase;
bilingual_str error; bilingual_str error;
std::optional<bool> load_on_start = request.params[6].isNull() ? std::nullopt : std::optional<bool>(request.params[6].get_bool()); std::optional<bool> load_on_start = request.params[6].isNull() ? std::nullopt : std::optional<bool>(request.params[6].get_bool());
std::shared_ptr<CWallet> wallet = CreateWallet(*context.chain, request.params[0].get_str(), load_on_start, options, status, error, warnings); std::shared_ptr<CWallet> wallet = CreateWallet(context, request.params[0].get_str(), load_on_start, options, status, error, warnings);
if (!wallet) { if (!wallet) {
RPCErrorCode code = status == DatabaseStatus::FAILED_ENCRYPT ? RPC_WALLET_ENCRYPTION_FAILED : RPC_WALLET_ERROR; RPCErrorCode code = status == DatabaseStatus::FAILED_ENCRYPT ? RPC_WALLET_ENCRYPTION_FAILED : RPC_WALLET_ERROR;
throw JSONRPCError(code, error.original); throw JSONRPCError(code, error.original);
@ -2892,7 +2895,8 @@ static RPCHelpMan unloadwallet()
wallet_name = request.params[0].get_str(); wallet_name = request.params[0].get_str();
} }
std::shared_ptr<CWallet> wallet = GetWallet(wallet_name); WalletContext& context = EnsureWalletContext(request.context);
std::shared_ptr<CWallet> wallet = GetWallet(context, wallet_name);
if (!wallet) { if (!wallet) {
throw JSONRPCError(RPC_WALLET_NOT_FOUND, "Requested wallet does not exist or is not loaded"); throw JSONRPCError(RPC_WALLET_NOT_FOUND, "Requested wallet does not exist or is not loaded");
} }
@ -2902,7 +2906,7 @@ static RPCHelpMan unloadwallet()
// is destroyed (see CheckUniqueFileid). // is destroyed (see CheckUniqueFileid).
std::vector<bilingual_str> warnings; std::vector<bilingual_str> warnings;
std::optional<bool> load_on_start = request.params[1].isNull() ? std::nullopt : std::optional<bool>(request.params[1].get_bool()); std::optional<bool> load_on_start = request.params[1].isNull() ? std::nullopt : std::optional<bool>(request.params[1].get_bool());
if (!RemoveWallet(wallet, load_on_start, warnings)) { if (!RemoveWallet(context, wallet, load_on_start, warnings)) {
throw JSONRPCError(RPC_MISC_ERROR, "Requested wallet already unloaded"); throw JSONRPCError(RPC_MISC_ERROR, "Requested wallet already unloaded");
} }

View file

@ -20,6 +20,7 @@
#include <util/translation.h> #include <util/translation.h>
#include <validation.h> #include <validation.h>
#include <wallet/coincontrol.h> #include <wallet/coincontrol.h>
#include <wallet/context.h>
#include <wallet/test/util.h> #include <wallet/test/util.h>
#include <wallet/test/wallet_test_fixture.h> #include <wallet/test/wallet_test_fixture.h>
@ -30,8 +31,6 @@ RPCHelpMan importmulti();
RPCHelpMan dumpwallet(); RPCHelpMan dumpwallet();
RPCHelpMan importwallet(); RPCHelpMan importwallet();
extern RecursiveMutex cs_wallets;
// Ensure that fee levels defined in the wallet are at least as high // Ensure that fee levels defined in the wallet are at least as high
// as the default levels for node policy. // as the default levels for node policy.
static_assert(DEFAULT_TRANSACTION_MINFEE >= DEFAULT_MIN_RELAY_TX_FEE, "wallet minimum fee is smaller than default relay fee"); static_assert(DEFAULT_TRANSACTION_MINFEE >= DEFAULT_MIN_RELAY_TX_FEE, "wallet minimum fee is smaller than default relay fee");
@ -39,15 +38,15 @@ static_assert(WALLET_INCREMENTAL_RELAY_FEE >= DEFAULT_INCREMENTAL_RELAY_FEE, "wa
BOOST_FIXTURE_TEST_SUITE(wallet_tests, WalletTestingSetup) BOOST_FIXTURE_TEST_SUITE(wallet_tests, WalletTestingSetup)
static std::shared_ptr<CWallet> TestLoadWallet(interfaces::Chain* chain) static std::shared_ptr<CWallet> TestLoadWallet(WalletContext& context)
{ {
DatabaseOptions options; DatabaseOptions options;
DatabaseStatus status; DatabaseStatus status;
bilingual_str error; bilingual_str error;
std::vector<bilingual_str> warnings; std::vector<bilingual_str> warnings;
auto database = MakeWalletDatabase("", options, status, error); auto database = MakeWalletDatabase("", options, status, error);
auto wallet = CWallet::Create(chain, "", std::move(database), options.create_flags, error, warnings); auto wallet = CWallet::Create(context, "", std::move(database), options.create_flags, error, warnings);
if (chain) { if (context.chain) {
wallet->postInitProcess(); wallet->postInitProcess();
} }
return wallet; return wallet;
@ -200,7 +199,8 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup)
std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(m_node.chain.get(), "", CreateDummyWalletDatabase()); std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(m_node.chain.get(), "", CreateDummyWalletDatabase());
wallet->SetupLegacyScriptPubKeyMan(); wallet->SetupLegacyScriptPubKeyMan();
WITH_LOCK(wallet->cs_wallet, wallet->SetLastBlockProcessed(newTip->nHeight, newTip->GetBlockHash())); WITH_LOCK(wallet->cs_wallet, wallet->SetLastBlockProcessed(newTip->nHeight, newTip->GetBlockHash()));
AddWallet(wallet); WalletContext context;
AddWallet(context, wallet);
UniValue keys; UniValue keys;
keys.setArray(); keys.setArray();
UniValue key; UniValue key;
@ -218,6 +218,7 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup)
key.pushKV("internal", UniValue(true)); key.pushKV("internal", UniValue(true));
keys.push_back(key); keys.push_back(key);
JSONRPCRequest request; JSONRPCRequest request;
request.context = &context;
request.params.setArray(); request.params.setArray();
request.params.push_back(keys); request.params.push_back(keys);
@ -231,7 +232,7 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup)
"downloading and rescanning the relevant blocks (see -reindex and -rescan " "downloading and rescanning the relevant blocks (see -reindex and -rescan "
"options).\"}},{\"success\":true}]", "options).\"}},{\"success\":true}]",
0, oldTip->GetBlockTimeMax(), TIMESTAMP_WINDOW)); 0, oldTip->GetBlockTimeMax(), TIMESTAMP_WINDOW));
RemoveWallet(wallet, std::nullopt); RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt);
} }
} }
@ -258,6 +259,7 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup)
// Import key into wallet and call dumpwallet to create backup file. // Import key into wallet and call dumpwallet to create backup file.
{ {
WalletContext context;
std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(m_node.chain.get(), "", CreateDummyWalletDatabase()); std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(m_node.chain.get(), "", CreateDummyWalletDatabase());
{ {
auto spk_man = wallet->GetOrCreateLegacyScriptPubKeyMan(); auto spk_man = wallet->GetOrCreateLegacyScriptPubKeyMan();
@ -265,15 +267,16 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup)
spk_man->mapKeyMetadata[coinbaseKey.GetPubKey().GetID()].nCreateTime = KEY_TIME; spk_man->mapKeyMetadata[coinbaseKey.GetPubKey().GetID()].nCreateTime = KEY_TIME;
spk_man->AddKeyPubKey(coinbaseKey, coinbaseKey.GetPubKey()); spk_man->AddKeyPubKey(coinbaseKey, coinbaseKey.GetPubKey());
AddWallet(wallet); AddWallet(context, wallet);
wallet->SetLastBlockProcessed(m_node.chainman->ActiveChain().Height(), m_node.chainman->ActiveChain().Tip()->GetBlockHash()); wallet->SetLastBlockProcessed(m_node.chainman->ActiveChain().Height(), m_node.chainman->ActiveChain().Tip()->GetBlockHash());
} }
JSONRPCRequest request; JSONRPCRequest request;
request.context = &context;
request.params.setArray(); request.params.setArray();
request.params.push_back(backup_file); request.params.push_back(backup_file);
::dumpwallet().HandleRequest(request); ::dumpwallet().HandleRequest(request);
RemoveWallet(wallet, std::nullopt); RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt);
} }
// Call importwallet RPC and verify all blocks with timestamps >= BLOCK_TIME // Call importwallet RPC and verify all blocks with timestamps >= BLOCK_TIME
@ -283,13 +286,15 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup)
LOCK(wallet->cs_wallet); LOCK(wallet->cs_wallet);
wallet->SetupLegacyScriptPubKeyMan(); wallet->SetupLegacyScriptPubKeyMan();
WalletContext context;
JSONRPCRequest request; JSONRPCRequest request;
request.context = &context;
request.params.setArray(); request.params.setArray();
request.params.push_back(backup_file); request.params.push_back(backup_file);
AddWallet(wallet); AddWallet(context, wallet);
wallet->SetLastBlockProcessed(m_node.chainman->ActiveChain().Height(), m_node.chainman->ActiveChain().Tip()->GetBlockHash()); wallet->SetLastBlockProcessed(m_node.chainman->ActiveChain().Height(), m_node.chainman->ActiveChain().Tip()->GetBlockHash());
::importwallet().HandleRequest(request); ::importwallet().HandleRequest(request);
RemoveWallet(wallet, std::nullopt); RemoveWallet(context, wallet, /* load_on_startup= */ std::nullopt);
BOOST_CHECK_EQUAL(wallet->mapWallet.size(), 3U); BOOST_CHECK_EQUAL(wallet->mapWallet.size(), 3U);
BOOST_CHECK_EQUAL(m_coinbase_txns.size(), 103U); BOOST_CHECK_EQUAL(m_coinbase_txns.size(), 103U);
@ -679,7 +684,9 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup)
{ {
gArgs.ForceSetArg("-unsafesqlitesync", "1"); gArgs.ForceSetArg("-unsafesqlitesync", "1");
// Create new wallet with known key and unload it. // Create new wallet with known key and unload it.
auto wallet = TestLoadWallet(m_node.chain.get()); WalletContext context;
context.chain = m_node.chain.get();
auto wallet = TestLoadWallet(context);
CKey key; CKey key;
key.MakeNewKey(true); key.MakeNewKey(true);
AddKey(*wallet, key); AddKey(*wallet, key);
@ -719,7 +726,7 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup)
// Reload wallet and make sure new transactions are detected despite events // Reload wallet and make sure new transactions are detected despite events
// being blocked // being blocked
wallet = TestLoadWallet(m_node.chain.get()); wallet = TestLoadWallet(context);
BOOST_CHECK(rescan_completed); BOOST_CHECK(rescan_completed);
BOOST_CHECK_EQUAL(addtx_count, 2); BOOST_CHECK_EQUAL(addtx_count, 2);
{ {
@ -746,20 +753,20 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup)
// deadlock during the sync and simulates a new block notification happening // deadlock during the sync and simulates a new block notification happening
// as soon as possible. // as soon as possible.
addtx_count = 0; addtx_count = 0;
auto handler = HandleLoadWallet([&](std::unique_ptr<interfaces::Wallet> wallet) EXCLUSIVE_LOCKS_REQUIRED(wallet->wallet()->cs_wallet, cs_wallets) { auto handler = HandleLoadWallet(context, [&](std::unique_ptr<interfaces::Wallet> wallet) EXCLUSIVE_LOCKS_REQUIRED(wallet->wallet()->cs_wallet, context.wallets_mutex) {
BOOST_CHECK(rescan_completed); BOOST_CHECK(rescan_completed);
m_coinbase_txns.push_back(CreateAndProcessBlock({}, GetScriptForRawPubKey(coinbaseKey.GetPubKey())).vtx[0]); m_coinbase_txns.push_back(CreateAndProcessBlock({}, GetScriptForRawPubKey(coinbaseKey.GetPubKey())).vtx[0]);
block_tx = TestSimpleSpend(*m_coinbase_txns[2], 0, coinbaseKey, GetScriptForRawPubKey(key.GetPubKey())); block_tx = TestSimpleSpend(*m_coinbase_txns[2], 0, coinbaseKey, GetScriptForRawPubKey(key.GetPubKey()));
m_coinbase_txns.push_back(CreateAndProcessBlock({block_tx}, GetScriptForRawPubKey(coinbaseKey.GetPubKey())).vtx[0]); m_coinbase_txns.push_back(CreateAndProcessBlock({block_tx}, GetScriptForRawPubKey(coinbaseKey.GetPubKey())).vtx[0]);
mempool_tx = TestSimpleSpend(*m_coinbase_txns[3], 0, coinbaseKey, GetScriptForRawPubKey(key.GetPubKey())); mempool_tx = TestSimpleSpend(*m_coinbase_txns[3], 0, coinbaseKey, GetScriptForRawPubKey(key.GetPubKey()));
BOOST_CHECK(m_node.chain->broadcastTransaction(MakeTransactionRef(mempool_tx), DEFAULT_TRANSACTION_MAXFEE, false, error)); BOOST_CHECK(m_node.chain->broadcastTransaction(MakeTransactionRef(mempool_tx), DEFAULT_TRANSACTION_MAXFEE, false, error));
LEAVE_CRITICAL_SECTION(cs_wallets); LEAVE_CRITICAL_SECTION(context.wallets_mutex);
LEAVE_CRITICAL_SECTION(wallet->wallet()->cs_wallet); LEAVE_CRITICAL_SECTION(wallet->wallet()->cs_wallet);
SyncWithValidationInterfaceQueue(); SyncWithValidationInterfaceQueue();
ENTER_CRITICAL_SECTION(wallet->wallet()->cs_wallet); ENTER_CRITICAL_SECTION(wallet->wallet()->cs_wallet);
ENTER_CRITICAL_SECTION(cs_wallets); ENTER_CRITICAL_SECTION(context.wallets_mutex);
}); });
wallet = TestLoadWallet(m_node.chain.get()); wallet = TestLoadWallet(context);
BOOST_CHECK_EQUAL(addtx_count, 4); BOOST_CHECK_EQUAL(addtx_count, 4);
{ {
LOCK(wallet->cs_wallet); LOCK(wallet->cs_wallet);
@ -773,7 +780,8 @@ BOOST_FIXTURE_TEST_CASE(CreateWallet, TestChain100Setup)
BOOST_FIXTURE_TEST_CASE(CreateWalletWithoutChain, BasicTestingSetup) BOOST_FIXTURE_TEST_CASE(CreateWalletWithoutChain, BasicTestingSetup)
{ {
auto wallet = TestLoadWallet(nullptr); WalletContext context;
auto wallet = TestLoadWallet(context);
BOOST_CHECK(wallet); BOOST_CHECK(wallet);
UnloadWallet(std::move(wallet)); UnloadWallet(std::move(wallet));
} }
@ -781,7 +789,9 @@ BOOST_FIXTURE_TEST_CASE(CreateWalletWithoutChain, BasicTestingSetup)
BOOST_FIXTURE_TEST_CASE(ZapSelectTx, TestChain100Setup) BOOST_FIXTURE_TEST_CASE(ZapSelectTx, TestChain100Setup)
{ {
gArgs.ForceSetArg("-unsafesqlitesync", "1"); gArgs.ForceSetArg("-unsafesqlitesync", "1");
auto wallet = TestLoadWallet(m_node.chain.get()); WalletContext context;
context.chain = m_node.chain.get();
auto wallet = TestLoadWallet(context);
CKey key; CKey key;
key.MakeNewKey(true); key.MakeNewKey(true);
AddKey(*wallet, key); AddKey(*wallet, key);

View file

@ -33,6 +33,7 @@
#include <util/string.h> #include <util/string.h>
#include <util/translation.h> #include <util/translation.h>
#include <wallet/coincontrol.h> #include <wallet/coincontrol.h>
#include <wallet/context.h>
#include <wallet/fees.h> #include <wallet/fees.h>
#include <wallet/external_signer_scriptpubkeyman.h> #include <wallet/external_signer_scriptpubkeyman.h>
@ -54,10 +55,6 @@ const std::map<uint64_t,std::string> WALLET_FLAG_CAVEATS{
}, },
}; };
RecursiveMutex cs_wallets;
static std::vector<std::shared_ptr<CWallet>> vpwallets GUARDED_BY(cs_wallets);
static std::list<LoadWalletFn> g_load_wallet_fns GUARDED_BY(cs_wallets);
bool AddWalletSetting(interfaces::Chain& chain, const std::string& wallet_name) bool AddWalletSetting(interfaces::Chain& chain, const std::string& wallet_name)
{ {
util::SettingsValue setting_value = chain.getRwSetting("wallet"); util::SettingsValue setting_value = chain.getRwSetting("wallet");
@ -104,19 +101,19 @@ static void RefreshMempoolStatus(CWalletTx& tx, interfaces::Chain& chain)
tx.fInMempool = chain.isInMempool(tx.GetHash()); tx.fInMempool = chain.isInMempool(tx.GetHash());
} }
bool AddWallet(const std::shared_ptr<CWallet>& wallet) bool AddWallet(WalletContext& context, const std::shared_ptr<CWallet>& wallet)
{ {
LOCK(cs_wallets); LOCK(context.wallets_mutex);
assert(wallet); assert(wallet);
std::vector<std::shared_ptr<CWallet>>::const_iterator i = std::find(vpwallets.begin(), vpwallets.end(), wallet); std::vector<std::shared_ptr<CWallet>>::const_iterator i = std::find(context.wallets.begin(), context.wallets.end(), wallet);
if (i != vpwallets.end()) return false; if (i != context.wallets.end()) return false;
vpwallets.push_back(wallet); context.wallets.push_back(wallet);
wallet->ConnectScriptPubKeyManNotifiers(); wallet->ConnectScriptPubKeyManNotifiers();
wallet->NotifyCanGetAddressesChanged(); wallet->NotifyCanGetAddressesChanged();
return true; return true;
} }
bool RemoveWallet(const std::shared_ptr<CWallet>& wallet, std::optional<bool> load_on_start, std::vector<bilingual_str>& warnings) bool RemoveWallet(WalletContext& context, const std::shared_ptr<CWallet>& wallet, std::optional<bool> load_on_start, std::vector<bilingual_str>& warnings)
{ {
assert(wallet); assert(wallet);
@ -125,10 +122,10 @@ bool RemoveWallet(const std::shared_ptr<CWallet>& wallet, std::optional<bool> lo
// Unregister with the validation interface which also drops shared ponters. // Unregister with the validation interface which also drops shared ponters.
wallet->m_chain_notifications_handler.reset(); wallet->m_chain_notifications_handler.reset();
LOCK(cs_wallets); LOCK(context.wallets_mutex);
std::vector<std::shared_ptr<CWallet>>::iterator i = std::find(vpwallets.begin(), vpwallets.end(), wallet); std::vector<std::shared_ptr<CWallet>>::iterator i = std::find(context.wallets.begin(), context.wallets.end(), wallet);
if (i == vpwallets.end()) return false; if (i == context.wallets.end()) return false;
vpwallets.erase(i); context.wallets.erase(i);
// Write the wallet setting // Write the wallet setting
UpdateWalletSetting(chain, name, load_on_start, warnings); UpdateWalletSetting(chain, name, load_on_start, warnings);
@ -136,32 +133,32 @@ bool RemoveWallet(const std::shared_ptr<CWallet>& wallet, std::optional<bool> lo
return true; return true;
} }
bool RemoveWallet(const std::shared_ptr<CWallet>& wallet, std::optional<bool> load_on_start) bool RemoveWallet(WalletContext& context, const std::shared_ptr<CWallet>& wallet, std::optional<bool> load_on_start)
{ {
std::vector<bilingual_str> warnings; std::vector<bilingual_str> warnings;
return RemoveWallet(wallet, load_on_start, warnings); return RemoveWallet(context, wallet, load_on_start, warnings);
} }
std::vector<std::shared_ptr<CWallet>> GetWallets() std::vector<std::shared_ptr<CWallet>> GetWallets(WalletContext& context)
{ {
LOCK(cs_wallets); LOCK(context.wallets_mutex);
return vpwallets; return context.wallets;
} }
std::shared_ptr<CWallet> GetWallet(const std::string& name) std::shared_ptr<CWallet> GetWallet(WalletContext& context, const std::string& name)
{ {
LOCK(cs_wallets); LOCK(context.wallets_mutex);
for (const std::shared_ptr<CWallet>& wallet : vpwallets) { for (const std::shared_ptr<CWallet>& wallet : context.wallets) {
if (wallet->GetName() == name) return wallet; if (wallet->GetName() == name) return wallet;
} }
return nullptr; return nullptr;
} }
std::unique_ptr<interfaces::Handler> HandleLoadWallet(LoadWalletFn load_wallet) std::unique_ptr<interfaces::Handler> HandleLoadWallet(WalletContext& context, LoadWalletFn load_wallet)
{ {
LOCK(cs_wallets); LOCK(context.wallets_mutex);
auto it = g_load_wallet_fns.emplace(g_load_wallet_fns.end(), std::move(load_wallet)); auto it = context.wallet_load_fns.emplace(context.wallet_load_fns.end(), std::move(load_wallet));
return interfaces::MakeHandler([it] { LOCK(cs_wallets); g_load_wallet_fns.erase(it); }); return interfaces::MakeHandler([&context, it] { LOCK(context.wallets_mutex); context.wallet_load_fns.erase(it); });
} }
static Mutex g_loading_wallet_mutex; static Mutex g_loading_wallet_mutex;
@ -213,7 +210,7 @@ void UnloadWallet(std::shared_ptr<CWallet>&& wallet)
} }
namespace { namespace {
std::shared_ptr<CWallet> LoadWalletInternal(interfaces::Chain& chain, const std::string& name, std::optional<bool> load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings) std::shared_ptr<CWallet> LoadWalletInternal(WalletContext& context, const std::string& name, std::optional<bool> load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings)
{ {
try { try {
std::unique_ptr<WalletDatabase> database = MakeWalletDatabase(name, options, status, error); std::unique_ptr<WalletDatabase> database = MakeWalletDatabase(name, options, status, error);
@ -222,18 +219,18 @@ std::shared_ptr<CWallet> LoadWalletInternal(interfaces::Chain& chain, const std:
return nullptr; return nullptr;
} }
chain.initMessage(_("Loading wallet…").translated); context.chain->initMessage(_("Loading wallet…").translated);
std::shared_ptr<CWallet> wallet = CWallet::Create(&chain, name, std::move(database), options.create_flags, error, warnings); std::shared_ptr<CWallet> wallet = CWallet::Create(context, name, std::move(database), options.create_flags, error, warnings);
if (!wallet) { if (!wallet) {
error = Untranslated("Wallet loading failed.") + Untranslated(" ") + error; error = Untranslated("Wallet loading failed.") + Untranslated(" ") + error;
status = DatabaseStatus::FAILED_LOAD; status = DatabaseStatus::FAILED_LOAD;
return nullptr; return nullptr;
} }
AddWallet(wallet); AddWallet(context, wallet);
wallet->postInitProcess(); wallet->postInitProcess();
// Write the wallet setting // Write the wallet setting
UpdateWalletSetting(chain, name, load_on_start, warnings); UpdateWalletSetting(*context.chain, name, load_on_start, warnings);
return wallet; return wallet;
} catch (const std::runtime_error& e) { } catch (const std::runtime_error& e) {
@ -244,7 +241,7 @@ std::shared_ptr<CWallet> LoadWalletInternal(interfaces::Chain& chain, const std:
} }
} // namespace } // namespace
std::shared_ptr<CWallet> LoadWallet(interfaces::Chain& chain, const std::string& name, std::optional<bool> load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings) std::shared_ptr<CWallet> LoadWallet(WalletContext& context, const std::string& name, std::optional<bool> load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings)
{ {
auto result = WITH_LOCK(g_loading_wallet_mutex, return g_loading_wallet_set.insert(name)); auto result = WITH_LOCK(g_loading_wallet_mutex, return g_loading_wallet_set.insert(name));
if (!result.second) { if (!result.second) {
@ -252,12 +249,12 @@ std::shared_ptr<CWallet> LoadWallet(interfaces::Chain& chain, const std::string&
status = DatabaseStatus::FAILED_LOAD; status = DatabaseStatus::FAILED_LOAD;
return nullptr; return nullptr;
} }
auto wallet = LoadWalletInternal(chain, name, load_on_start, options, status, error, warnings); auto wallet = LoadWalletInternal(context, name, load_on_start, options, status, error, warnings);
WITH_LOCK(g_loading_wallet_mutex, g_loading_wallet_set.erase(result.first)); WITH_LOCK(g_loading_wallet_mutex, g_loading_wallet_set.erase(result.first));
return wallet; return wallet;
} }
std::shared_ptr<CWallet> CreateWallet(interfaces::Chain& chain, const std::string& name, std::optional<bool> load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings) std::shared_ptr<CWallet> CreateWallet(WalletContext& context, const std::string& name, std::optional<bool> load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings)
{ {
uint64_t wallet_creation_flags = options.create_flags; uint64_t wallet_creation_flags = options.create_flags;
const SecureString& passphrase = options.create_passphrase; const SecureString& passphrase = options.create_passphrase;
@ -302,8 +299,8 @@ std::shared_ptr<CWallet> CreateWallet(interfaces::Chain& chain, const std::strin
} }
// Make the wallet // Make the wallet
chain.initMessage(_("Loading wallet…").translated); context.chain->initMessage(_("Loading wallet…").translated);
std::shared_ptr<CWallet> wallet = CWallet::Create(&chain, name, std::move(database), wallet_creation_flags, error, warnings); std::shared_ptr<CWallet> wallet = CWallet::Create(context, name, std::move(database), wallet_creation_flags, error, warnings);
if (!wallet) { if (!wallet) {
error = Untranslated("Wallet creation failed.") + Untranslated(" ") + error; error = Untranslated("Wallet creation failed.") + Untranslated(" ") + error;
status = DatabaseStatus::FAILED_CREATE; status = DatabaseStatus::FAILED_CREATE;
@ -345,11 +342,11 @@ std::shared_ptr<CWallet> CreateWallet(interfaces::Chain& chain, const std::strin
wallet->Lock(); wallet->Lock();
} }
} }
AddWallet(wallet); AddWallet(context, wallet);
wallet->postInitProcess(); wallet->postInitProcess();
// Write the wallet settings // Write the wallet settings
UpdateWalletSetting(chain, name, load_on_start, warnings); UpdateWalletSetting(*context.chain, name, load_on_start, warnings);
status = DatabaseStatus::SUCCESS; status = DatabaseStatus::SUCCESS;
return wallet; return wallet;
@ -1802,9 +1799,9 @@ void CWallet::ResendWalletTransactions()
/** @} */ // end of mapWallet /** @} */ // end of mapWallet
void MaybeResendWalletTxs() void MaybeResendWalletTxs(WalletContext& context)
{ {
for (const std::shared_ptr<CWallet>& pwallet : GetWallets()) { for (const std::shared_ptr<CWallet>& pwallet : GetWallets(context)) {
pwallet->ResendWalletTransactions(); pwallet->ResendWalletTransactions();
} }
} }
@ -2509,8 +2506,9 @@ std::unique_ptr<WalletDatabase> MakeWalletDatabase(const std::string& name, cons
return MakeDatabase(wallet_path, options, status, error_string); return MakeDatabase(wallet_path, options, status, error_string);
} }
std::shared_ptr<CWallet> CWallet::Create(interfaces::Chain* chain, const std::string& name, std::unique_ptr<WalletDatabase> database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector<bilingual_str>& warnings) std::shared_ptr<CWallet> CWallet::Create(WalletContext& context, const std::string& name, std::unique_ptr<WalletDatabase> database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector<bilingual_str>& warnings)
{ {
interfaces::Chain* chain = context.chain;
const std::string& walletFile = database->Filename(); const std::string& walletFile = database->Filename();
int64_t nStart = GetTimeMillis(); int64_t nStart = GetTimeMillis();
@ -2722,9 +2720,9 @@ std::shared_ptr<CWallet> CWallet::Create(interfaces::Chain* chain, const std::st
} }
{ {
LOCK(cs_wallets); LOCK(context.wallets_mutex);
for (auto& load_wallet : g_load_wallet_fns) { for (auto& load_wallet : context.wallet_load_fns) {
load_wallet(interfaces::MakeWallet(walletInstance)); load_wallet(interfaces::MakeWallet(context, walletInstance));
} }
} }

View file

@ -42,6 +42,8 @@
#include <boost/signals2/signal.hpp> #include <boost/signals2/signal.hpp>
struct WalletContext;
using LoadWalletFn = std::function<void(std::unique_ptr<interfaces::Wallet> wallet)>; using LoadWalletFn = std::function<void(std::unique_ptr<interfaces::Wallet> wallet)>;
struct bilingual_str; struct bilingual_str;
@ -53,14 +55,14 @@ struct bilingual_str;
//! by the shared pointer deleter. //! by the shared pointer deleter.
void UnloadWallet(std::shared_ptr<CWallet>&& wallet); void UnloadWallet(std::shared_ptr<CWallet>&& wallet);
bool AddWallet(const std::shared_ptr<CWallet>& wallet); bool AddWallet(WalletContext& context, const std::shared_ptr<CWallet>& wallet);
bool RemoveWallet(const std::shared_ptr<CWallet>& wallet, std::optional<bool> load_on_start, std::vector<bilingual_str>& warnings); bool RemoveWallet(WalletContext& context, const std::shared_ptr<CWallet>& wallet, std::optional<bool> load_on_start, std::vector<bilingual_str>& warnings);
bool RemoveWallet(const std::shared_ptr<CWallet>& wallet, std::optional<bool> load_on_start); bool RemoveWallet(WalletContext& context, const std::shared_ptr<CWallet>& wallet, std::optional<bool> load_on_start);
std::vector<std::shared_ptr<CWallet>> GetWallets(); std::vector<std::shared_ptr<CWallet>> GetWallets(WalletContext& context);
std::shared_ptr<CWallet> GetWallet(const std::string& name); std::shared_ptr<CWallet> GetWallet(WalletContext& context, const std::string& name);
std::shared_ptr<CWallet> LoadWallet(interfaces::Chain& chain, const std::string& name, std::optional<bool> load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings); std::shared_ptr<CWallet> LoadWallet(WalletContext& context, const std::string& name, std::optional<bool> load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings);
std::shared_ptr<CWallet> CreateWallet(interfaces::Chain& chain, const std::string& name, std::optional<bool> load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings); std::shared_ptr<CWallet> CreateWallet(WalletContext& context, const std::string& name, std::optional<bool> load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings);
std::unique_ptr<interfaces::Handler> HandleLoadWallet(LoadWalletFn load_wallet); std::unique_ptr<interfaces::Handler> HandleLoadWallet(WalletContext& context, LoadWalletFn load_wallet);
std::unique_ptr<WalletDatabase> MakeWalletDatabase(const std::string& name, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error); std::unique_ptr<WalletDatabase> MakeWalletDatabase(const std::string& name, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error);
//! -paytxfee default //! -paytxfee default
@ -772,7 +774,7 @@ public:
bool MarkReplaced(const uint256& originalHash, const uint256& newHash); bool MarkReplaced(const uint256& originalHash, const uint256& newHash);
/* Initializes the wallet, returns a new CWallet instance or a null pointer in case of an error */ /* Initializes the wallet, returns a new CWallet instance or a null pointer in case of an error */
static std::shared_ptr<CWallet> Create(interfaces::Chain* chain, const std::string& name, std::unique_ptr<WalletDatabase> database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector<bilingual_str>& warnings); static std::shared_ptr<CWallet> Create(WalletContext& context, const std::string& name, std::unique_ptr<WalletDatabase> database, uint64_t wallet_creation_flags, bilingual_str& error, std::vector<bilingual_str>& warnings);
/** /**
* Wallet post-init setup * Wallet post-init setup
@ -919,7 +921,7 @@ public:
* Called periodically by the schedule thread. Prompts individual wallets to resend * Called periodically by the schedule thread. Prompts individual wallets to resend
* their transactions. Actual rebroadcast schedule is managed by the wallets themselves. * their transactions. Actual rebroadcast schedule is managed by the wallets themselves.
*/ */
void MaybeResendWalletTxs(); void MaybeResendWalletTxs(WalletContext& context);
/** RAII object to check and reserve a wallet rescan */ /** RAII object to check and reserve a wallet rescan */
class WalletRescanReserver class WalletRescanReserver

View file

@ -1004,14 +1004,14 @@ DBErrors WalletBatch::ZapSelectTx(std::vector<uint256>& vTxHashIn, std::vector<u
return DBErrors::LOAD_OK; return DBErrors::LOAD_OK;
} }
void MaybeCompactWalletDB() void MaybeCompactWalletDB(WalletContext& context)
{ {
static std::atomic<bool> fOneThread(false); static std::atomic<bool> fOneThread(false);
if (fOneThread.exchange(true)) { if (fOneThread.exchange(true)) {
return; return;
} }
for (const std::shared_ptr<CWallet>& pwallet : GetWallets()) { for (const std::shared_ptr<CWallet>& pwallet : GetWallets(context)) {
WalletDatabase& dbh = pwallet->GetDatabase(); WalletDatabase& dbh = pwallet->GetDatabase();
unsigned int nUpdateCounter = dbh.nUpdateCounter; unsigned int nUpdateCounter = dbh.nUpdateCounter;

View file

@ -31,6 +31,7 @@
static const bool DEFAULT_FLUSHWALLET = true; static const bool DEFAULT_FLUSHWALLET = true;
struct CBlockLocator; struct CBlockLocator;
struct WalletContext;
class CKeyPool; class CKeyPool;
class CMasterKey; class CMasterKey;
class CScript; class CScript;
@ -279,7 +280,7 @@ private:
}; };
//! Compacts BDB state so that wallet.dat is self-contained (if there are changes) //! Compacts BDB state so that wallet.dat is self-contained (if there are changes)
void MaybeCompactWalletDB(); void MaybeCompactWalletDB(WalletContext& context);
//! Callback for filtering key types to deserialize in ReadKeyValue //! Callback for filtering key types to deserialize in ReadKeyValue
using KeyFilterFn = std::function<bool(const std::string&)>; using KeyFilterFn = std::function<bool(const std::string&)>;