diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index b971be5ddd..bc22dcd704 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -490,7 +490,9 @@ std::shared_ptr CreateWallet(WalletContext& context, const std::string& return wallet; } -std::shared_ptr RestoreWallet(WalletContext& context, const fs::path& backup_file, const std::string& wallet_name, std::optional load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector& warnings) +// Re-creates wallet from the backup file by renaming and moving it into the wallet's directory. +// If 'load_after_restore=true', the wallet object will be fully initialized and appended to the context. +std::shared_ptr RestoreWallet(WalletContext& context, const fs::path& backup_file, const std::string& wallet_name, std::optional load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector& warnings, bool load_after_restore) { DatabaseOptions options; ReadDatabaseArgs(*context.args, options); @@ -515,13 +517,17 @@ std::shared_ptr RestoreWallet(WalletContext& context, const fs::path& b fs::copy_file(backup_file, wallet_file, fs::copy_options::none); - wallet = LoadWallet(context, wallet_name, load_on_start, options, status, error, warnings); + if (load_after_restore) { + wallet = LoadWallet(context, wallet_name, load_on_start, options, status, error, warnings); + } } catch (const std::exception& e) { assert(!wallet); if (!error.empty()) error += Untranslated("\n"); error += Untranslated(strprintf("Unexpected exception: %s", e.what())); } - if (!wallet) { + + // Remove created wallet path only when loading fails + if (load_after_restore && !wallet) { fs::remove_all(wallet_path); } @@ -4527,7 +4533,7 @@ util::Result MigrateLegacyToDescriptor(const std::string& walle } if (!success) { // Migration failed, cleanup - // Copy the backup to the actual wallet dir + // Before deleting the wallet's directory, copy the backup file to the top-level wallets dir fs::path temp_backup_location = fsbridge::AbsPathJoin(GetWalletDir(), backup_filename); fs::copy_file(backup_path, temp_backup_location, fs::copy_options::none); @@ -4564,17 +4570,24 @@ util::Result MigrateLegacyToDescriptor(const std::string& walle } // Restore the backup - DatabaseStatus status; - std::vector warnings; - if (!RestoreWallet(context, temp_backup_location, wallet_name, /*load_on_start=*/std::nullopt, status, error, warnings)) { - error += _("\nUnable to restore backup of wallet."); + // Convert the backup file to the wallet db file by renaming it and moving it into the wallet's directory. + // Reload it into memory if the wallet was previously loaded. + bilingual_str restore_error; + const auto& ptr_wallet = RestoreWallet(context, temp_backup_location, wallet_name, /*load_on_start=*/std::nullopt, status, restore_error, warnings, /*load_after_restore=*/was_loaded); + if (!restore_error.empty()) { + error += restore_error + _("\nUnable to restore backup of wallet."); return util::Error{error}; } - // Move the backup to the wallet dir + // The wallet directory has been restored, but just in case, copy the previously created backup to the wallet dir fs::copy_file(temp_backup_location, backup_path, fs::copy_options::none); fs::remove(temp_backup_location); + // Verify that there is no dangling wallet: when the wallet wasn't loaded before, expect null. + // This check is performed after restoration to avoid an early error before saving the backup. + bool wallet_reloaded = ptr_wallet != nullptr; + assert(was_loaded == wallet_reloaded); + return util::Error{error}; } return res; diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index d869f031bb..c6a45e9a14 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -95,7 +95,7 @@ std::shared_ptr GetDefaultWallet(WalletContext& context, size_t& count) std::shared_ptr GetWallet(WalletContext& context, const std::string& name); std::shared_ptr LoadWallet(WalletContext& context, const std::string& name, std::optional load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); std::shared_ptr CreateWallet(WalletContext& context, const std::string& name, std::optional load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); -std::shared_ptr RestoreWallet(WalletContext& context, const fs::path& backup_file, const std::string& wallet_name, std::optional load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector& warnings); +std::shared_ptr RestoreWallet(WalletContext& context, const fs::path& backup_file, const std::string& wallet_name, std::optional load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector& warnings, bool load_after_restore = true); std::unique_ptr HandleLoadWallet(WalletContext& context, LoadWalletFn load_wallet); void NotifyWalletLoaded(WalletContext& context, const std::shared_ptr& wallet); std::unique_ptr MakeWalletDatabase(const std::string& name, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error); diff --git a/test/functional/wallet_migration.py b/test/functional/wallet_migration.py index 3a56050731..5be56cec29 100755 --- a/test/functional/wallet_migration.py +++ b/test/functional/wallet_migration.py @@ -896,9 +896,7 @@ class WalletMigrationTest(BitcoinTestFramework): shutil.copytree(self.old_node.wallets_path / "failed", self.master_node.wallets_path / "failed") assert_raises_rpc_error(-4, "Failed to create database", self.master_node.migratewallet, "failed") - assert "failed" in self.master_node.listwallets() - assert "failed_watchonly" not in self.master_node.listwallets() - assert "failed_solvables" not in self.master_node.listwallets() + assert all(wallet not in self.master_node.listwallets() for wallet in ["failed", "failed_watchonly", "failed_solvables"]) assert not (self.master_node.wallets_path / "failed_watchonly").exists() # Since the file in failed_solvables is one that we put there, migration shouldn't touch it @@ -912,6 +910,22 @@ class WalletMigrationTest(BitcoinTestFramework): _, _, magic = struct.unpack("QII", data) assert_equal(magic, BTREE_MAGIC) + #################################################### + # Perform the same test with a loaded legacy wallet. + # The wallet should remain loaded after the failure. + # + # This applies only when BDB is enabled, as the user + # cannot interact with the legacy wallet database + # without BDB support. + if self.is_bdb_compiled() is not None: + # Advance time to generate a different backup name + self.master_node.setmocktime(self.master_node.getblockheader(self.master_node.getbestblockhash())['time'] + 100) + assert "failed" not in self.master_node.listwallets() + self.master_node.loadwallet("failed") + assert_raises_rpc_error(-4, "Failed to create database", self.master_node.migratewallet, "failed") + wallets = self.master_node.listwallets() + assert "failed" in wallets and all(wallet not in wallets for wallet in ["failed_watchonly", "failed_solvables"]) + def test_blank(self): self.log.info("Test that a blank wallet is migrated") wallet = self.create_legacy_wallet("blank", blank=True)