Update test_framework/psbt.py for PSBTv2

This commit is contained in:
Ava Chow 2024-07-22 17:14:42 -04:00
parent 1ad0eecdaa
commit 985fa0846f
2 changed files with 60 additions and 15 deletions

View file

@ -69,8 +69,8 @@ def signet_txs(block, challenge):
def decode_psbt(b64psbt):
psbt = PSBT.from_base64(b64psbt)
assert len(psbt.tx.vin) == 1
assert len(psbt.tx.vout) == 1
assert len(psbt.i) == 1
assert len(psbt.i) == 1
assert PSBT_SIGNET_BLOCK in psbt.g.map
scriptSig = psbt.i[0].map.get(PSBT_IN_FINAL_SCRIPTSIG, b"")

View file

@ -4,10 +4,14 @@
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
import base64
import struct
from io import BytesIO
from .messages import (
CTransaction,
deser_string,
deser_compact_size,
from_binary,
ser_compact_size,
)
@ -100,37 +104,78 @@ class PSBT:
self.g = g if g is not None else PSBTMap()
self.i = i if i is not None else []
self.o = o if o is not None else []
self.tx = None
self.in_count = len(i) if i is not None else None
self.out_count = len(o) if o is not None else None
self.version = None
def deserialize(self, f):
assert f.read(5) == b"psbt\xff"
self.g = from_binary(PSBTMap, f)
assert PSBT_GLOBAL_UNSIGNED_TX in self.g.map
self.tx = from_binary(CTransaction, self.g.map[PSBT_GLOBAL_UNSIGNED_TX])
self.i = [from_binary(PSBTMap, f) for _ in self.tx.vin]
self.o = [from_binary(PSBTMap, f) for _ in self.tx.vout]
self.version = 0
if PSBT_GLOBAL_VERSION in self.g.map:
assert PSBT_GLOBAL_INPUT_COUNT in self.g.map
assert PSBT_GLOBAL_OUTPUT_COUNT in self.g.map
self.version = struct.unpack("<I", self.g.map[PSBT_GLOBAL_VERSION])[0]
assert self.version in [0, 2]
if self.version == 2:
self.in_count = deser_compact_size(BytesIO(self.g.map[PSBT_GLOBAL_INPUT_COUNT]))
self.out_count = deser_compact_size(BytesIO(self.g.map[PSBT_GLOBAL_OUTPUT_COUNT]))
else:
assert PSBT_GLOBAL_UNSIGNED_TX in self.g.map
tx = from_binary(CTransaction, self.g.map[PSBT_GLOBAL_UNSIGNED_TX])
self.in_count = len(tx.vin)
self.out_count = len(tx.vout)
self.i = [from_binary(PSBTMap, f) for _ in range(self.in_count)]
self.o = [from_binary(PSBTMap, f) for _ in range(self.out_count)]
return self
def serialize(self):
assert isinstance(self.g, PSBTMap)
assert isinstance(self.i, list) and all(isinstance(x, PSBTMap) for x in self.i)
assert isinstance(self.o, list) and all(isinstance(x, PSBTMap) for x in self.o)
assert PSBT_GLOBAL_UNSIGNED_TX in self.g.map
tx = from_binary(CTransaction, self.g.map[PSBT_GLOBAL_UNSIGNED_TX])
assert len(tx.vin) == len(self.i)
assert len(tx.vout) == len(self.o)
if self.version is not None and self.version == 2:
self.g.map[PSBT_GLOBAL_INPUT_COUNT] = ser_compact_size(len(self.i))
self.g.map[PSBT_GLOBAL_OUTPUT_COUNT] = ser_compact_size(len(self.o))
psbt = [x.serialize() for x in [self.g] + self.i + self.o]
return b"psbt\xff" + b"".join(psbt)
def make_blank(self):
"""
Remove all fields except for PSBT_GLOBAL_UNSIGNED_TX
Remove all fields except for required fields depending on version
"""
for m in self.i + self.o:
m.map.clear()
if self.version == 0:
for m in self.i + self.o:
m.map.clear()
self.g = PSBTMap(map={PSBT_GLOBAL_UNSIGNED_TX: self.g.map[PSBT_GLOBAL_UNSIGNED_TX]})
self.g = PSBTMap(map={PSBT_GLOBAL_UNSIGNED_TX: self.g.map[PSBT_GLOBAL_UNSIGNED_TX]})
elif self.version == 2:
self.g = PSBTMap(map={
PSBT_GLOBAL_TX_VERSION: self.g.map[PSBT_GLOBAL_TX_VERSION],
PSBT_GLOBAL_INPUT_COUNT: self.g.map[PSBT_GLOBAL_INPUT_COUNT],
PSBT_GLOBAL_OUTPUT_COUNT: self.g.map[PSBT_GLOBAL_OUTPUT_COUNT],
PSBT_GLOBAL_VERSION: self.g.map[PSBT_GLOBAL_VERSION],
})
new_i = []
for m in self.i:
new_i.append(PSBTMap(map={
PSBT_IN_PREVIOUS_TXID: m.map[PSBT_IN_PREVIOUS_TXID],
PSBT_IN_OUTPUT_INDEX: m.map[PSBT_IN_OUTPUT_INDEX],
}))
self.i = new_i
new_o = []
for m in self.o:
new_o.append(PSBTMap(map={
PSBT_OUT_SCRIPT: m.map[PSBT_OUT_SCRIPT],
PSBT_OUT_AMOUNT: m.map[PSBT_OUT_AMOUNT],
}))
self.o = new_o
else:
assert False
def to_base64(self):
return base64.b64encode(self.serialize()).decode("utf8")