descriptors: Have GetPubKey fill origins directly

Instead of having ExpandHelper fill in the origins in the
FlatSigningProvider output, have GetPubKey do it by itself. This reduces
the extra variables needed in order to track and set origins in
ExpandHelper.

Also changes GetPubKey to return a std::optional<CPubKey> rather than
using a bool and output parameters.
This commit is contained in:
Ava Chow 2024-01-22 17:07:50 -05:00
parent 6268bde0af
commit 25a3b9b0f5

View file

@ -174,22 +174,20 @@ public:
* Used by the Miniscript descriptors to check for duplicate keys in the script. * Used by the Miniscript descriptors to check for duplicate keys in the script.
*/ */
bool operator<(PubkeyProvider& other) const { bool operator<(PubkeyProvider& other) const {
CPubKey a, b; FlatSigningProvider dummy;
SigningProvider dummy;
KeyOriginInfo dummy_info;
GetPubKey(0, dummy, a, dummy_info); std::optional<CPubKey> a = GetPubKey(0, dummy, dummy);
other.GetPubKey(0, dummy, b, dummy_info); std::optional<CPubKey> b = other.GetPubKey(0, dummy, dummy);
return a < b; return a < b;
} }
/** Derive a public key. /** Derive a public key and put it into out.
* read_cache is the cache to read keys from (if not nullptr) * read_cache is the cache to read keys from (if not nullptr)
* write_cache is the cache to write keys to (if not nullptr) * write_cache is the cache to write keys to (if not nullptr)
* Caches are not exclusive but this is not tested. Currently we use them exclusively * Caches are not exclusive but this is not tested. Currently we use them exclusively
*/ */
virtual bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const = 0; virtual std::optional<CPubKey> GetPubKey(int pos, const SigningProvider& arg, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const = 0;
/** Whether this represent multiple public keys at different positions. */ /** Whether this represent multiple public keys at different positions. */
virtual bool IsRange() const = 0; virtual bool IsRange() const = 0;
@ -240,12 +238,15 @@ class OriginPubkeyProvider final : public PubkeyProvider
public: public:
OriginPubkeyProvider(uint32_t exp_index, KeyOriginInfo info, std::unique_ptr<PubkeyProvider> provider, bool apostrophe) : PubkeyProvider(exp_index), m_origin(std::move(info)), m_provider(std::move(provider)), m_apostrophe(apostrophe) {} OriginPubkeyProvider(uint32_t exp_index, KeyOriginInfo info, std::unique_ptr<PubkeyProvider> provider, bool apostrophe) : PubkeyProvider(exp_index), m_origin(std::move(info)), m_provider(std::move(provider)), m_apostrophe(apostrophe) {}
bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override std::optional<CPubKey> GetPubKey(int pos, const SigningProvider& arg, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override
{ {
if (!m_provider->GetPubKey(pos, arg, key, info, read_cache, write_cache)) return false; std::optional<CPubKey> pub = m_provider->GetPubKey(pos, arg, out, read_cache, write_cache);
std::copy(std::begin(m_origin.fingerprint), std::end(m_origin.fingerprint), info.fingerprint); if (!pub) return std::nullopt;
info.path.insert(info.path.begin(), m_origin.path.begin(), m_origin.path.end()); auto& [pubkey, suborigin] = out.origins[pub->GetID()];
return true; Assert(pubkey == *pub); // m_provider must have a valid origin by this point.
std::copy(std::begin(m_origin.fingerprint), std::end(m_origin.fingerprint), suborigin.fingerprint);
suborigin.path.insert(suborigin.path.begin(), m_origin.path.begin(), m_origin.path.end());
return pub;
} }
bool IsRange() const override { return m_provider->IsRange(); } bool IsRange() const override { return m_provider->IsRange(); }
size_t GetSize() const override { return m_provider->GetSize(); } size_t GetSize() const override { return m_provider->GetSize(); }
@ -298,13 +299,13 @@ class ConstPubkeyProvider final : public PubkeyProvider
public: public:
ConstPubkeyProvider(uint32_t exp_index, const CPubKey& pubkey, bool xonly) : PubkeyProvider(exp_index), m_pubkey(pubkey), m_xonly(xonly) {} ConstPubkeyProvider(uint32_t exp_index, const CPubKey& pubkey, bool xonly) : PubkeyProvider(exp_index), m_pubkey(pubkey), m_xonly(xonly) {}
bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override std::optional<CPubKey> GetPubKey(int pos, const SigningProvider&, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override
{ {
key = m_pubkey; KeyOriginInfo info;
info.path.clear();
CKeyID keyid = m_pubkey.GetID(); CKeyID keyid = m_pubkey.GetID();
std::copy(keyid.begin(), keyid.begin() + sizeof(info.fingerprint), info.fingerprint); std::copy(keyid.begin(), keyid.begin() + sizeof(info.fingerprint), info.fingerprint);
return true; out.origins.emplace(keyid, std::make_pair(m_pubkey, info));
return m_pubkey;
} }
bool IsRange() const override { return false; } bool IsRange() const override { return false; }
size_t GetSize() const override { return m_pubkey.size(); } size_t GetSize() const override { return m_pubkey.size(); }
@ -394,7 +395,7 @@ public:
BIP32PubkeyProvider(uint32_t exp_index, const CExtPubKey& extkey, KeyPath path, DeriveType derive, bool apostrophe) : PubkeyProvider(exp_index), m_root_extkey(extkey), m_path(std::move(path)), m_derive(derive), m_apostrophe(apostrophe) {} BIP32PubkeyProvider(uint32_t exp_index, const CExtPubKey& extkey, KeyPath path, DeriveType derive, bool apostrophe) : PubkeyProvider(exp_index), m_root_extkey(extkey), m_path(std::move(path)), m_derive(derive), m_apostrophe(apostrophe) {}
bool IsRange() const override { return m_derive != DeriveType::NO; } bool IsRange() const override { return m_derive != DeriveType::NO; }
size_t GetSize() const override { return 33; } size_t GetSize() const override { return 33; }
bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key_out, KeyOriginInfo& final_info_out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override std::optional<CPubKey> GetPubKey(int pos, const SigningProvider& arg, FlatSigningProvider& out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override
{ {
KeyOriginInfo info; KeyOriginInfo info;
CKeyID keyid = m_root_extkey.pubkey.GetID(); CKeyID keyid = m_root_extkey.pubkey.GetID();
@ -410,16 +411,16 @@ public:
bool der = true; bool der = true;
if (read_cache) { if (read_cache) {
if (!read_cache->GetCachedDerivedExtPubKey(m_expr_index, pos, final_extkey)) { if (!read_cache->GetCachedDerivedExtPubKey(m_expr_index, pos, final_extkey)) {
if (m_derive == DeriveType::HARDENED) return false; if (m_derive == DeriveType::HARDENED) return std::nullopt;
// Try to get the derivation parent // Try to get the derivation parent
if (!read_cache->GetCachedParentExtPubKey(m_expr_index, parent_extkey)) return false; if (!read_cache->GetCachedParentExtPubKey(m_expr_index, parent_extkey)) return std::nullopt;
final_extkey = parent_extkey; final_extkey = parent_extkey;
if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos); if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos);
} }
} else if (IsHardened()) { } else if (IsHardened()) {
CExtKey xprv; CExtKey xprv;
CExtKey lh_xprv; CExtKey lh_xprv;
if (!GetDerivedExtKey(arg, xprv, lh_xprv)) return false; if (!GetDerivedExtKey(arg, xprv, lh_xprv)) return std::nullopt;
parent_extkey = xprv.Neuter(); parent_extkey = xprv.Neuter();
if (m_derive == DeriveType::UNHARDENED) der = xprv.Derive(xprv, pos); if (m_derive == DeriveType::UNHARDENED) der = xprv.Derive(xprv, pos);
if (m_derive == DeriveType::HARDENED) der = xprv.Derive(xprv, pos | 0x80000000UL); if (m_derive == DeriveType::HARDENED) der = xprv.Derive(xprv, pos | 0x80000000UL);
@ -429,16 +430,15 @@ public:
} }
} else { } else {
for (auto entry : m_path) { for (auto entry : m_path) {
if (!parent_extkey.Derive(parent_extkey, entry)) return false; if (!parent_extkey.Derive(parent_extkey, entry)) return std::nullopt;
} }
final_extkey = parent_extkey; final_extkey = parent_extkey;
if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos); if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos);
assert(m_derive != DeriveType::HARDENED); assert(m_derive != DeriveType::HARDENED);
} }
if (!der) return false; if (!der) return std::nullopt;
final_info_out = info; out.origins.emplace(final_extkey.pubkey.GetID(), std::make_pair(final_extkey.pubkey, info));
key_out = final_extkey.pubkey;
if (write_cache) { if (write_cache) {
// Only cache parent if there is any unhardened derivation // Only cache parent if there is any unhardened derivation
@ -448,12 +448,12 @@ public:
if (last_hardened_extkey.pubkey.IsValid()) { if (last_hardened_extkey.pubkey.IsValid()) {
write_cache->CacheLastHardenedExtPubKey(m_expr_index, last_hardened_extkey); write_cache->CacheLastHardenedExtPubKey(m_expr_index, last_hardened_extkey);
} }
} else if (final_info_out.path.size() > 0) { } else if (info.path.size() > 0) {
write_cache->CacheDerivedExtPubKey(m_expr_index, pos, final_extkey); write_cache->CacheDerivedExtPubKey(m_expr_index, pos, final_extkey);
} }
} }
return true; return final_extkey.pubkey;
} }
std::string ToString(StringType type, bool normalized) const std::string ToString(StringType type, bool normalized) const
{ {
@ -696,16 +696,17 @@ public:
// NOLINTNEXTLINE(misc-no-recursion) // NOLINTNEXTLINE(misc-no-recursion)
bool ExpandHelper(int pos, const SigningProvider& arg, const DescriptorCache* read_cache, std::vector<CScript>& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache) const bool ExpandHelper(int pos, const SigningProvider& arg, const DescriptorCache* read_cache, std::vector<CScript>& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache) const
{ {
std::vector<std::pair<CPubKey, KeyOriginInfo>> entries; FlatSigningProvider subprovider;
entries.reserve(m_pubkey_args.size()); std::vector<CPubKey> pubkeys;
pubkeys.reserve(m_pubkey_args.size());
// Construct temporary data in `entries`, `subscripts`, and `subprovider` to avoid producing output in case of failure. // Construct temporary data in `pubkeys`, `subscripts`, and `subprovider` to avoid producing output in case of failure.
for (const auto& p : m_pubkey_args) { for (const auto& p : m_pubkey_args) {
entries.emplace_back(); std::optional<CPubKey> pubkey = p->GetPubKey(pos, arg, subprovider, read_cache, write_cache);
if (!p->GetPubKey(pos, arg, entries.back().first, entries.back().second, read_cache, write_cache)) return false; if (!pubkey) return false;
pubkeys.push_back(pubkey.value());
} }
std::vector<CScript> subscripts; std::vector<CScript> subscripts;
FlatSigningProvider subprovider;
for (const auto& subarg : m_subdescriptor_args) { for (const auto& subarg : m_subdescriptor_args) {
std::vector<CScript> outscripts; std::vector<CScript> outscripts;
if (!subarg->ExpandHelper(pos, arg, read_cache, outscripts, subprovider, write_cache)) return false; if (!subarg->ExpandHelper(pos, arg, read_cache, outscripts, subprovider, write_cache)) return false;
@ -714,13 +715,6 @@ public:
} }
out.Merge(std::move(subprovider)); out.Merge(std::move(subprovider));
std::vector<CPubKey> pubkeys;
pubkeys.reserve(entries.size());
for (auto& entry : entries) {
pubkeys.push_back(entry.first);
out.origins.emplace(entry.first.GetID(), std::make_pair<CPubKey, KeyOriginInfo>(CPubKey(entry.first), std::move(entry.second)));
}
output_scripts = MakeScripts(pubkeys, std::span{subscripts}, out); output_scripts = MakeScripts(pubkeys, std::span{subscripts}, out);
return true; return true;
} }