diff --git a/src/psbt.cpp b/src/psbt.cpp index 94e80cd075e..6cf4fa014d7 100644 --- a/src/psbt.cpp +++ b/src/psbt.cpp @@ -424,6 +424,13 @@ PSBTError SignPSBTInput(const SigningProvider& provider, PartiallySignedTransact if (input.sighash_type && input.sighash_type != sighash) { return PSBTError::SIGHASH_MISMATCH; } + // Set the PSBT sighash field when sighash is not DEFAULT or ALL + // DEFAULT is allowed for non-taproot inputs since DEFAULT may be passed for them (e.g. the psbt being signed also has taproot inputs) + // Note that signing already aliases DEFAULT to ALL for non-taproot inputs. + if (utxo.scriptPubKey.IsPayToTaproot() ? sighash != SIGHASH_DEFAULT : + (sighash != SIGHASH_DEFAULT && sighash != SIGHASH_ALL)) { + input.sighash_type = sighash; + } // Check all existing signatures use the sighash type if (sighash == SIGHASH_DEFAULT) { @@ -522,7 +529,8 @@ bool FinalizePSBT(PartiallySignedTransaction& psbtx) bool complete = true; const PrecomputedTransactionData txdata = PrecomputePSBTData(psbtx); for (unsigned int i = 0; i < psbtx.tx->vin.size(); ++i) { - complete &= (SignPSBTInput(DUMMY_SIGNING_PROVIDER, psbtx, i, &txdata, std::nullopt, nullptr, true) == PSBTError::OK); + PSBTInput& input = psbtx.inputs.at(i); + complete &= (SignPSBTInput(DUMMY_SIGNING_PROVIDER, psbtx, i, &txdata, input.sighash_type, nullptr, true) == PSBTError::OK); } return complete; diff --git a/test/functional/rpc_psbt.py b/test/functional/rpc_psbt.py index 3e257287415..151ed4cc6b7 100755 --- a/test/functional/rpc_psbt.py +++ b/test/functional/rpc_psbt.py @@ -228,6 +228,53 @@ class PSBTTest(BitcoinTestFramework): wallet.unloadwallet() + def test_sighash_adding(self): + self.log.info("Test adding of sighash type field") + self.nodes[0].createwallet("sighash_adding") + wallet = self.nodes[0].get_wallet_rpc("sighash_adding") + def_wallet = self.nodes[0].get_wallet_rpc(self.default_wallet_name) + + addr = wallet.getnewaddress(address_type="bech32") + outputs = [{addr: 1}] + outputs.append({wallet.getnewaddress(address_type="bech32m"): 1}) + descs = wallet.listdescriptors(True)["descriptors"] + def_wallet.send(outputs) + self.generate(self.nodes[0], 6) + utxos = wallet.listunspent() + + # Make a PSBT + psbt = wallet.walletcreatefundedpsbt(utxos, [{def_wallet.getnewaddress(): 0.5}])["psbt"] + + # Process the PSBT with the wallet + wallet_psbt = wallet.walletprocesspsbt(psbt=psbt, sighashtype="ALL|ANYONECANPAY", finalize=False)["psbt"] + + # Separately process the PSBT with descriptors + desc_psbt = self.nodes[0].descriptorprocesspsbt(psbt=psbt, descriptors=descs, sighashtype="ALL|ANYONECANPAY", finalize=False)["psbt"] + + for psbt in [wallet_psbt, desc_psbt]: + # Check that the PSBT has a sighash field on all inputs + dec_psbt = self.nodes[0].decodepsbt(psbt) + for input in dec_psbt["inputs"]: + assert_equal(input["sighash"], "ALL|ANYONECANPAY") + + # Make sure we can still finalize the transaction + fin_res = self.nodes[0].finalizepsbt(psbt) + assert_equal(fin_res["complete"], True) + fin_hex = fin_res["hex"] + + # Change the sighash field to a different value and make sure we can no longer finalize + mod_psbt = PSBT.from_base64(psbt) + mod_psbt.i[0].map[PSBT_IN_SIGHASH_TYPE] = (SIGHASH_ALL).to_bytes(4, byteorder="little") + mod_psbt.i[1].map[PSBT_IN_SIGHASH_TYPE] = (SIGHASH_ALL).to_bytes(4, byteorder="little") + psbt = mod_psbt.to_base64() + fin_res = self.nodes[0].finalizepsbt(psbt) + assert_equal(fin_res["complete"], False) + + self.nodes[0].sendrawtransaction(fin_hex) + self.generate(self.nodes[0], 1) + + wallet.unloadwallet() + def assert_change_type(self, psbtx, expected_type): """Assert that the given PSBT has a change output with the given type.""" @@ -1064,6 +1111,7 @@ class PSBTTest(BitcoinTestFramework): assert_raises_rpc_error(-8, "'all' is not a valid sighash parameter.", self.nodes[2].descriptorprocesspsbt, psbt, [descriptor], sighashtype="all") self.test_sighash_mismatch() + self.test_sighash_adding() if __name__ == '__main__': PSBTTest(__file__).main()