Enforce PSBT version constraints

With PSBTv2, some fields are not allowed in PSBTv2, and some are
required. Enforce those.
This commit is contained in:
Ava Chow 2024-07-22 17:14:08 -04:00
parent 964e0a7bdc
commit 057249398d

View file

@ -463,26 +463,29 @@ struct PSBTInput
SerializeToVector(s, final_script_witness.stack); SerializeToVector(s, final_script_witness.stack);
} }
// Write prev txid, vout, sequence, and lock times // Write PSBTv2 fields
if (!prev_txid.IsNull()) { if (m_psbt_version >= 2) {
SerializeToVector(s, CompactSizeWriter(PSBT_IN_PREVIOUS_TXID)); // Write prev txid, vout, sequence, and lock times
SerializeToVector(s, prev_txid); if (!prev_txid.IsNull()) {
} SerializeToVector(s, CompactSizeWriter(PSBT_IN_PREVIOUS_TXID));
if (prev_out != std::nullopt) { SerializeToVector(s, prev_txid);
SerializeToVector(s, CompactSizeWriter(PSBT_IN_OUTPUT_INDEX)); }
SerializeToVector(s, *prev_out); if (prev_out != std::nullopt) {
} SerializeToVector(s, CompactSizeWriter(PSBT_IN_OUTPUT_INDEX));
if (sequence != std::nullopt) { SerializeToVector(s, *prev_out);
SerializeToVector(s, CompactSizeWriter(PSBT_IN_SEQUENCE)); }
SerializeToVector(s, *sequence); if (sequence != std::nullopt) {
} SerializeToVector(s, CompactSizeWriter(PSBT_IN_SEQUENCE));
if (time_locktime != std::nullopt) { SerializeToVector(s, *sequence);
SerializeToVector(s, CompactSizeWriter(PSBT_IN_REQUIRED_TIME_LOCKTIME)); }
SerializeToVector(s, *time_locktime); if (time_locktime != std::nullopt) {
} SerializeToVector(s, CompactSizeWriter(PSBT_IN_REQUIRED_TIME_LOCKTIME));
if (height_locktime != std::nullopt) { SerializeToVector(s, *time_locktime);
SerializeToVector(s, CompactSizeWriter(PSBT_IN_REQUIRED_HEIGHT_LOCKTIME)); }
SerializeToVector(s, *height_locktime); if (height_locktime != std::nullopt) {
SerializeToVector(s, CompactSizeWriter(PSBT_IN_REQUIRED_HEIGHT_LOCKTIME));
SerializeToVector(s, *height_locktime);
}
} }
// Write proprietary things // Write proprietary things
@ -713,6 +716,8 @@ struct PSBTInput
throw std::ios_base::failure("Duplicate Key, previous txid is already provided"); throw std::ios_base::failure("Duplicate Key, previous txid is already provided");
} else if (key.size() != 1) { } else if (key.size() != 1) {
throw std::ios_base::failure("Previous txid key is more than one byte type"); throw std::ios_base::failure("Previous txid key is more than one byte type");
} else if (m_psbt_version == 0) {
throw std::ios_base::failure("Previous txid is not allowed in PSBTv0");
} }
UnserializeFromVector(s, prev_txid); UnserializeFromVector(s, prev_txid);
break; break;
@ -723,6 +728,8 @@ struct PSBTInput
throw std::ios_base::failure("Duplicate Key, previous output's index is already provided"); throw std::ios_base::failure("Duplicate Key, previous output's index is already provided");
} else if (key.size() != 1) { } else if (key.size() != 1) {
throw std::ios_base::failure("Previous output's index is more than one byte type"); throw std::ios_base::failure("Previous output's index is more than one byte type");
} else if (m_psbt_version == 0) {
throw std::ios_base::failure("Previous output's index is not allowed in PSBTv0");
} }
uint32_t v; uint32_t v;
UnserializeFromVector(s, v); UnserializeFromVector(s, v);
@ -735,6 +742,8 @@ struct PSBTInput
throw std::ios_base::failure("Duplicate Key, sequence is already provided"); throw std::ios_base::failure("Duplicate Key, sequence is already provided");
} else if (key.size() != 1) { } else if (key.size() != 1) {
throw std::ios_base::failure("Sequence key is more than one byte type"); throw std::ios_base::failure("Sequence key is more than one byte type");
} else if (m_psbt_version == 0) {
throw std::ios_base::failure("Sequence is not allowed in PSBTv0");
} }
uint32_t v; uint32_t v;
UnserializeFromVector(s, v); UnserializeFromVector(s, v);
@ -747,6 +756,8 @@ struct PSBTInput
throw std::ios_base::failure("Duplicate Key, required time based locktime is already provided"); throw std::ios_base::failure("Duplicate Key, required time based locktime is already provided");
} else if (key.size() != 1) { } else if (key.size() != 1) {
throw std::ios_base::failure("Required time based locktime is more than one byte type"); throw std::ios_base::failure("Required time based locktime is more than one byte type");
} else if (m_psbt_version == 0) {
throw std::ios_base::failure("Required time based locktime is not allowed in PSBTv0");
} }
uint32_t v; uint32_t v;
UnserializeFromVector(s, v); UnserializeFromVector(s, v);
@ -762,6 +773,8 @@ struct PSBTInput
throw std::ios_base::failure("Duplicate Key, required height based locktime is already provided"); throw std::ios_base::failure("Duplicate Key, required height based locktime is already provided");
} else if (key.size() != 1) { } else if (key.size() != 1) {
throw std::ios_base::failure("Required height based locktime is more than one byte type"); throw std::ios_base::failure("Required height based locktime is more than one byte type");
} else if (m_psbt_version == 0) {
throw std::ios_base::failure("Required height based locktime is not allowed in PSBTv0");
} }
uint32_t v; uint32_t v;
UnserializeFromVector(s, v); UnserializeFromVector(s, v);
@ -948,6 +961,16 @@ struct PSBTInput
if (!found_sep) { if (!found_sep) {
throw std::ios_base::failure("Separator is missing at the end of an input map"); throw std::ios_base::failure("Separator is missing at the end of an input map");
} }
// Make sure required PSBTv2 fields are present
if (m_psbt_version >= 2) {
if (prev_txid.IsNull()) {
throw std::ios_base::failure("Previous TXID is required in PSBTv2");
}
if (prev_out == std::nullopt) {
throw std::ios_base::failure("Previous output's index is required in PSBTv2");
}
}
} }
template <typename Stream> template <typename Stream>
@ -999,14 +1022,16 @@ struct PSBTOutput
// Write any hd keypaths // Write any hd keypaths
SerializeHDKeypaths(s, hd_keypaths, CompactSizeWriter(PSBT_OUT_BIP32_DERIVATION)); SerializeHDKeypaths(s, hd_keypaths, CompactSizeWriter(PSBT_OUT_BIP32_DERIVATION));
// Write amount and spk if (m_psbt_version >= 2) {
if (amount != std::nullopt) { // Write amount and spk
SerializeToVector(s, CompactSizeWriter(PSBT_OUT_AMOUNT)); if (amount != std::nullopt) {
SerializeToVector(s, *amount); SerializeToVector(s, CompactSizeWriter(PSBT_OUT_AMOUNT));
} SerializeToVector(s, *amount);
if (script.has_value()) { }
SerializeToVector(s, CompactSizeWriter(PSBT_OUT_SCRIPT)); if (script.has_value()) {
s << *script; SerializeToVector(s, CompactSizeWriter(PSBT_OUT_SCRIPT));
s << *script;
}
} }
// Write proprietary things // Write proprietary things
@ -1122,6 +1147,8 @@ struct PSBTOutput
throw std::ios_base::failure("Duplicate Key, output amount is already provided"); throw std::ios_base::failure("Duplicate Key, output amount is already provided");
} else if (key.size() != 1) { } else if (key.size() != 1) {
throw std::ios_base::failure("Output amount key is more than one byte type"); throw std::ios_base::failure("Output amount key is more than one byte type");
} else if (m_psbt_version == 0) {
throw std::ios_base::failure("Output amount is not allowed in PSBTv0");
} }
CAmount v; CAmount v;
UnserializeFromVector(s, v); UnserializeFromVector(s, v);
@ -1134,6 +1161,8 @@ struct PSBTOutput
throw std::ios_base::failure("Duplicate Key, output script is already provided"); throw std::ios_base::failure("Duplicate Key, output script is already provided");
} else if (key.size() != 1) { } else if (key.size() != 1) {
throw std::ios_base::failure("Output script key is more than one byte type"); throw std::ios_base::failure("Output script key is more than one byte type");
} else if (m_psbt_version == 0) {
throw std::ios_base::failure("Output script is not allowed in PSBTv0");
} }
CScript v; CScript v;
s >> v; s >> v;
@ -1247,6 +1276,16 @@ struct PSBTOutput
if (!found_sep) { if (!found_sep) {
throw std::ios_base::failure("Separator is missing at the end of an output map"); throw std::ios_base::failure("Separator is missing at the end of an output map");
} }
// Make sure required PSBTv2 fields are present
if (m_psbt_version >= 2) {
if (amount == std::nullopt) {
throw std::ios_base::failure("Output amount is required in PSBTv2");
}
if (!script.has_value()) {
throw std::ios_base::failure("Output script is required in PSBTv2");
}
}
} }
template <typename Stream> template <typename Stream>
@ -1296,11 +1335,13 @@ struct PartiallySignedTransaction
// magic bytes // magic bytes
s << PSBT_MAGIC_BYTES; s << PSBT_MAGIC_BYTES;
// unsigned tx flag if (GetVersion() == 0) {
SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_UNSIGNED_TX)); // unsigned tx flag
SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_UNSIGNED_TX));
// Write serialized tx to a stream // Write serialized tx to a stream
SerializeToVector(s, TX_NO_WITNESS(*tx)); SerializeToVector(s, TX_NO_WITNESS(*tx));
}
// Write xpubs // Write xpubs
for (const auto& xpub_pair : m_xpubs) { for (const auto& xpub_pair : m_xpubs) {
@ -1314,24 +1355,24 @@ struct PartiallySignedTransaction
} }
} }
// Write PSBTv2 tx version, locktime, counts, etc. if (GetVersion() >= 2) {
if (tx_version != std::nullopt) { // Write PSBTv2 tx version, locktime, counts, etc.
SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_TX_VERSION)); SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_TX_VERSION));
SerializeToVector(s, *tx_version); SerializeToVector(s, *tx_version);
} if (fallback_locktime != std::nullopt) {
if (fallback_locktime != std::nullopt) { SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_FALLBACK_LOCKTIME));
SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_FALLBACK_LOCKTIME)); SerializeToVector(s, *fallback_locktime);
SerializeToVector(s, *fallback_locktime); }
}
if (m_version != std::nullopt && *m_version >= 2) {
SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_INPUT_COUNT)); SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_INPUT_COUNT));
SerializeToVector(s, CompactSizeWriter(inputs.size())); SerializeToVector(s, CompactSizeWriter(inputs.size()));
SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_OUTPUT_COUNT)); SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_OUTPUT_COUNT));
SerializeToVector(s, CompactSizeWriter(outputs.size())); SerializeToVector(s, CompactSizeWriter(outputs.size()));
}
if (m_tx_modifiable != std::nullopt) { if (m_tx_modifiable != std::nullopt) {
SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_TX_MODIFIABLE)); SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_TX_MODIFIABLE));
SerializeToVector(s, static_cast<uint8_t>(m_tx_modifiable->to_ulong())); SerializeToVector(s, static_cast<uint8_t>(m_tx_modifiable->to_ulong()));
}
} }
// PSBT version // PSBT version
@ -1385,6 +1426,8 @@ struct PartiallySignedTransaction
bool found_sep = false; bool found_sep = false;
uint64_t input_count = 0; uint64_t input_count = 0;
uint64_t output_count = 0; uint64_t output_count = 0;
bool found_input_count = false;
bool found_output_count = false;
while(!s.empty()) { while(!s.empty()) {
// Read // Read
std::vector<unsigned char> key; std::vector<unsigned char> key;
@ -1458,6 +1501,7 @@ struct PartiallySignedTransaction
} }
CompactSizeReader reader(input_count); CompactSizeReader reader(input_count);
UnserializeFromVector(s, reader); UnserializeFromVector(s, reader);
found_input_count = true;
break; break;
} }
case PSBT_GLOBAL_OUTPUT_COUNT: case PSBT_GLOBAL_OUTPUT_COUNT:
@ -1469,6 +1513,7 @@ struct PartiallySignedTransaction
} }
CompactSizeReader reader(output_count); CompactSizeReader reader(output_count);
UnserializeFromVector(s, reader); UnserializeFromVector(s, reader);
found_output_count = true;
break; break;
} }
case PSBT_GLOBAL_TX_MODIFIABLE: case PSBT_GLOBAL_TX_MODIFIABLE:
@ -1559,13 +1604,52 @@ struct PartiallySignedTransaction
throw std::ios_base::failure("Separator is missing at the end of the global map"); throw std::ios_base::failure("Separator is missing at the end of the global map");
} }
// Make sure that we got an unsigned tx
if (!tx) {
throw std::ios_base::failure("No unsigned transaction was provided");
}
const uint32_t psbt_ver = GetVersion(); const uint32_t psbt_ver = GetVersion();
// Check PSBT version constraints
if (psbt_ver == 0) {
// Make sure that we got an unsigned tx for PSBTv0
if (!tx) {
throw std::ios_base::failure("No unsigned transaction was provided");
}
// Make sure no PSBTv2 fields are present
if (tx_version != std::nullopt) {
throw std::ios_base::failure("PSBT_GLOBAL_TX_VERSION is not allowed in PSBTv0");
}
if (fallback_locktime != std::nullopt) {
throw std::ios_base::failure("PSBT_GLOBAL_FALLBACK_LOCKTIME is not allowed in PSBTv0");
}
if (found_input_count) {
throw std::ios_base::failure("PSBT_GLOBAL_INPUT_COUNT is not allowed in PSBTv0");
}
if (found_output_count) {
throw std::ios_base::failure("PSBT_GLOBAL_OUTPUT_COUNT is not allowed in PSBTv0");
}
if (m_tx_modifiable != std::nullopt) {
throw std::ios_base::failure("PSBT_GLOBAL_TX_MODIFIABLE is not allowed in PSBTv0");
}
}
// Disallow v1
if (psbt_ver == 1) {
throw std::ios_base::failure("There is no PSBT version 1");
}
if (psbt_ver >= 2) {
// Tx version, input, and output counts are required
if (tx_version == std::nullopt) {
throw std::ios_base::failure("PSBT_GLOBAL_TX_VERSION is required in PSBTv2");
}
if (!found_input_count) {
throw std::ios_base::failure("PSBT_GLOBAL_INPUT_COUNT is required in PSBTv2");
}
if (!found_output_count) {
throw std::ios_base::failure("PSBT_GLOBAL_OUTPUT_COUNT is required in PSBTv2");
}
// Unsigned tx is disallowed
if (tx) {
throw std::ios_base::failure("PSBT_GLOBAL_UNSIGNED_TX is not allowed in PSBTv2");
}
}
// Read input data // Read input data
unsigned int i = 0; unsigned int i = 0;
while (!s.empty() && i < input_count) { while (!s.empty() && i < input_count) {
@ -1575,11 +1659,21 @@ struct PartiallySignedTransaction
// Make sure the non-witness utxo matches the outpoint // Make sure the non-witness utxo matches the outpoint
if (input.non_witness_utxo) { if (input.non_witness_utxo) {
if (input.non_witness_utxo->GetHash() != tx->vin[i].prevout.hash) { if (psbt_ver == 0) {
throw std::ios_base::failure("Non-witness UTXO does not match outpoint hash"); if (input.non_witness_utxo->GetHash() != tx->vin[i].prevout.hash) {
throw std::ios_base::failure("Non-witness UTXO does not match outpoint hash");
}
if (tx->vin[i].prevout.n >= input.non_witness_utxo->vout.size()) {
throw std::ios_base::failure("Input specifies output index that does not exist");
}
} }
if (tx->vin[i].prevout.n >= input.non_witness_utxo->vout.size()) { if (psbt_ver >= 2) {
throw std::ios_base::failure("Input specifies output index that does not exist"); if (input.non_witness_utxo->GetHash() != input.prev_txid) {
throw std::ios_base::failure("Non-witness UTXO does not match outpoint hash");
}
if (input.prev_out.value() >= input.non_witness_utxo->vout.size()) {
throw std::ios_base::failure("Input specifies output index that does not exist");
}
} }
} }
++i; ++i;