Implement joinpsbts RPC and tests

Adds a joinpsbts RPC which combines multiple distinct PSBTs into
one PSBT.
This commit is contained in:
Andrew Chow 2018-07-20 18:24:16 -07:00
parent 7344a7b998
commit 08f749c914
5 changed files with 115 additions and 0 deletions

View file

@ -42,6 +42,26 @@ bool PartiallySignedTransaction::IsSane() const
return true;
}
bool PartiallySignedTransaction::AddInput(const CTxIn& txin, PSBTInput& psbtin)
{
if (std::find(tx->vin.begin(), tx->vin.end(), txin) != tx->vin.end()) {
return false;
}
tx->vin.push_back(txin);
psbtin.partial_sigs.clear();
psbtin.final_script_sig.clear();
psbtin.final_script_witness.SetNull();
inputs.push_back(psbtin);
return true;
}
bool PartiallySignedTransaction::AddOutput(const CTxOut& txout, const PSBTOutput& psbtout)
{
tx->vout.push_back(txout);
outputs.push_back(psbtout);
return true;
}
bool PSBTInput::IsNull() const
{
return !non_witness_utxo && witness_utxo.IsNull() && partial_sigs.empty() && unknown.empty() && hd_keypaths.empty() && redeem_script.empty() && witness_script.empty();

View file

@ -389,6 +389,8 @@ struct PartiallySignedTransaction
* same actual Bitcoin transaction.) Returns true if the merge succeeded, false otherwise. */
NODISCARD bool Merge(const PartiallySignedTransaction& psbt);
bool IsSane() const;
bool AddInput(const CTxIn& txin, PSBTInput& psbtin);
bool AddOutput(const CTxOut& txout, const PSBTOutput& psbtout);
PartiallySignedTransaction() {}
PartiallySignedTransaction(const PartiallySignedTransaction& psbt_in) : tx(psbt_in.tx), inputs(psbt_in.inputs), outputs(psbt_in.outputs), unknown(psbt_in.unknown) {}
explicit PartiallySignedTransaction(const CMutableTransaction& tx);

View file

@ -112,6 +112,7 @@ static const CRPCConvertParam vRPCConvertParams[] =
{ "createpsbt", 2, "locktime" },
{ "createpsbt", 3, "replaceable" },
{ "combinepsbt", 0, "txs"},
{ "joinpsbts", 0, "txs"},
{ "finalizepsbt", 1, "extract"},
{ "converttopsbt", 1, "permitsigdata"},
{ "converttopsbt", 2, "iswitness"},

View file

@ -1755,6 +1755,80 @@ UniValue utxoupdatepsbt(const JSONRPCRequest& request)
return EncodeBase64((unsigned char*)ssTx.data(), ssTx.size());
}
UniValue joinpsbts(const JSONRPCRequest& request)
{
if (request.fHelp || request.params.size() != 1) {
throw std::runtime_error(
RPCHelpMan{"joinpsbts",
"\nJoins multiple distinct PSBTs with different inputs and outputs into one PSBT with inputs and outputs from all of the PSBTs\n"
"No input in any of the PSBTs can be in more than one of the PSBTs.\n",
{
{"txs", RPCArg::Type::ARR, RPCArg::Optional::NO, "A json array of base64 strings of partially signed transactions",
{
{"psbt", RPCArg::Type::STR, RPCArg::Optional::NO, "A base64 string of a PSBT"}
}}
},
RPCResult {
" \"psbt\" (string) The base64-encoded partially signed transaction\n"
},
RPCExamples {
HelpExampleCli("joinpsbts", "\"psbt\"")
}}.ToString());
}
RPCTypeCheck(request.params, {UniValue::VARR}, true);
// Unserialize the transactions
std::vector<PartiallySignedTransaction> psbtxs;
UniValue txs = request.params[0].get_array();
if (txs.size() <= 1) {
throw JSONRPCError(RPC_INVALID_PARAMETER, "At least two PSBTs are required to join PSBTs.");
}
int32_t best_version = 1;
uint32_t best_locktime = 0xffffffff;
for (unsigned int i = 0; i < txs.size(); ++i) {
PartiallySignedTransaction psbtx;
std::string error;
if (!DecodeBase64PSBT(psbtx, txs[i].get_str(), error)) {
throw JSONRPCError(RPC_DESERIALIZATION_ERROR, strprintf("TX decode failed %s", error));
}
psbtxs.push_back(psbtx);
// Choose the highest version number
if (psbtx.tx->nVersion > best_version) {
best_version = psbtx.tx->nVersion;
}
// Choose the lowest lock time
if (psbtx.tx->nLockTime < best_locktime) {
best_locktime = psbtx.tx->nLockTime;
}
}
// Create a blank psbt where everything will be added
PartiallySignedTransaction merged_psbt;
merged_psbt.tx = CMutableTransaction();
merged_psbt.tx->nVersion = best_version;
merged_psbt.tx->nLockTime = best_locktime;
// Merge
for (auto& psbt : psbtxs) {
for (unsigned int i = 0; i < psbt.tx->vin.size(); ++i) {
if (!merged_psbt.AddInput(psbt.tx->vin[i], psbt.inputs[i])) {
throw JSONRPCError(RPC_INVALID_PARAMETER, strprintf("Input %s:%d exists in multiple PSBTs", psbt.tx->vin[i].prevout.hash.ToString().c_str(), psbt.tx->vin[i].prevout.n));
}
}
for (unsigned int i = 0; i < psbt.tx->vout.size(); ++i) {
merged_psbt.AddOutput(psbt.tx->vout[i], psbt.outputs[i]);
}
merged_psbt.unknown.insert(psbt.unknown.begin(), psbt.unknown.end());
}
CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION);
ssTx << merged_psbt;
return EncodeBase64((unsigned char*)ssTx.data(), ssTx.size());
}
// clang-format off
static const CRPCCommand commands[] =
{ // category name actor (function) argNames
@ -1774,6 +1848,7 @@ static const CRPCCommand commands[] =
{ "rawtransactions", "createpsbt", &createpsbt, {"inputs","outputs","locktime","replaceable"} },
{ "rawtransactions", "converttopsbt", &converttopsbt, {"hexstring","permitsigdata","iswitness"} },
{ "rawtransactions", "utxoupdatepsbt", &utxoupdatepsbt, {"psbt"} },
{ "rawtransactions", "joinpsbts", &joinpsbts, {"txs"} },
{ "blockchain", "gettxoutproof", &gettxoutproof, {"txids", "blockhash"} },
{ "blockchain", "verifytxoutproof", &verifytxoutproof, {"proof"} },

View file

@ -321,6 +321,23 @@ class PSBTTest(BitcoinTestFramework):
assert "witness_utxo" not in decoded['inputs'][1] and "non_witness_utxo" not in decoded['inputs'][1]
assert "witness_utxo" not in decoded['inputs'][2] and "non_witness_utxo" not in decoded['inputs'][2]
# Two PSBTs with a common input should not be joinable
psbt1 = self.nodes[1].createpsbt([{"txid":txid1, "vout":vout1}], {self.nodes[0].getnewaddress():Decimal('10.999')})
assert_raises_rpc_error(-8, "exists in multiple PSBTs", self.nodes[1].joinpsbts, [psbt1, updated])
# Join two distinct PSBTs
addr4 = self.nodes[1].getnewaddress("", "p2sh-segwit")
txid4 = self.nodes[0].sendtoaddress(addr4, 5)
vout4 = find_output(self.nodes[0], txid4, 5)
self.nodes[0].generate(6)
self.sync_all()
psbt2 = self.nodes[1].createpsbt([{"txid":txid4, "vout":vout4}], {self.nodes[0].getnewaddress():Decimal('4.999')})
psbt2 = self.nodes[1].walletprocesspsbt(psbt2)['psbt']
psbt2_decoded = self.nodes[0].decodepsbt(psbt2)
assert "final_scriptwitness" in psbt2_decoded['inputs'][0] and "final_scriptSig" in psbt2_decoded['inputs'][0]
joined = self.nodes[0].joinpsbts([psbt, psbt2])
joined_decoded = self.nodes[0].decodepsbt(joined)
assert len(joined_decoded['inputs']) == 4 and len(joined_decoded['outputs']) == 2 and "final_scriptwitness" not in joined_decoded['inputs'][3] and "final_scriptSig" not in joined_decoded['inputs'][3]
if __name__ == '__main__':