test: introduce get_weight() helper for CTransaction

This commit is contained in:
Sebastian Falbesoner 2021-06-30 23:40:39 +02:00
parent 3fc20abab0
commit a084ebe133
3 changed files with 14 additions and 13 deletions

View file

@ -260,8 +260,8 @@ class SegWitTest(BitcoinTestFramework):
assert_equal(int(self.nodes[0].getmempoolentry(txid1)["wtxid"], 16), tx1.calc_sha256(True)) assert_equal(int(self.nodes[0].getmempoolentry(txid1)["wtxid"], 16), tx1.calc_sha256(True))
# Check that weight and vsize are properly reported in mempool entry (txid1) # Check that weight and vsize are properly reported in mempool entry (txid1)
assert_equal(self.nodes[0].getmempoolentry(txid1)["vsize"], (self.nodes[0].getmempoolentry(txid1)["weight"] + 3) // 4) assert_equal(self.nodes[0].getmempoolentry(txid1)["vsize"], tx1.get_vsize())
assert_equal(self.nodes[0].getmempoolentry(txid1)["weight"], len(tx1.serialize_without_witness())*3 + len(tx1.serialize_with_witness())) assert_equal(self.nodes[0].getmempoolentry(txid1)["weight"], tx1.get_weight())
# Now create tx2, which will spend from txid1. # Now create tx2, which will spend from txid1.
tx = CTransaction() tx = CTransaction()
@ -276,8 +276,8 @@ class SegWitTest(BitcoinTestFramework):
assert_equal(int(self.nodes[0].getmempoolentry(txid2)["wtxid"], 16), tx.calc_sha256(True)) assert_equal(int(self.nodes[0].getmempoolentry(txid2)["wtxid"], 16), tx.calc_sha256(True))
# Check that weight and vsize are properly reported in mempool entry (txid2) # Check that weight and vsize are properly reported in mempool entry (txid2)
assert_equal(self.nodes[0].getmempoolentry(txid2)["vsize"], (self.nodes[0].getmempoolentry(txid2)["weight"] + 3) // 4) assert_equal(self.nodes[0].getmempoolentry(txid2)["vsize"], tx.get_vsize())
assert_equal(self.nodes[0].getmempoolentry(txid2)["weight"], len(tx.serialize_without_witness())*3 + len(tx.serialize_with_witness())) assert_equal(self.nodes[0].getmempoolentry(txid2)["weight"], tx.get_weight())
# Now create tx3, which will spend from txid2 # Now create tx3, which will spend from txid2
tx = CTransaction() tx = CTransaction()
@ -299,8 +299,8 @@ class SegWitTest(BitcoinTestFramework):
assert_equal(int(self.nodes[0].getmempoolentry(txid3)["wtxid"], 16), tx.calc_sha256(True)) assert_equal(int(self.nodes[0].getmempoolentry(txid3)["wtxid"], 16), tx.calc_sha256(True))
# Check that weight and vsize are properly reported in mempool entry (txid3) # Check that weight and vsize are properly reported in mempool entry (txid3)
assert_equal(self.nodes[0].getmempoolentry(txid3)["vsize"], (self.nodes[0].getmempoolentry(txid3)["weight"] + 3) // 4) assert_equal(self.nodes[0].getmempoolentry(txid3)["vsize"], tx.get_vsize())
assert_equal(self.nodes[0].getmempoolentry(txid3)["weight"], len(tx.serialize_without_witness())*3 + len(tx.serialize_with_witness())) assert_equal(self.nodes[0].getmempoolentry(txid3)["weight"], tx.get_weight())
# Mine a block to clear the gbt cache again. # Mine a block to clear the gbt cache again.
self.nodes[0].generate(1) self.nodes[0].generate(1)

View file

@ -4,7 +4,6 @@
# file COPYING or http://www.opensource.org/licenses/mit-license.php. # file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Test segwit transactions and blocks on P2P network.""" """Test segwit transactions and blocks on P2P network."""
from decimal import Decimal from decimal import Decimal
import math
import random import random
import struct import struct
import time import time
@ -1367,10 +1366,9 @@ class SegWitTest(BitcoinTestFramework):
raw_tx = self.nodes[0].getrawtransaction(tx3.hash, 1) raw_tx = self.nodes[0].getrawtransaction(tx3.hash, 1)
assert_equal(int(raw_tx["hash"], 16), tx3.calc_sha256(True)) assert_equal(int(raw_tx["hash"], 16), tx3.calc_sha256(True))
assert_equal(raw_tx["size"], len(tx3.serialize_with_witness())) assert_equal(raw_tx["size"], len(tx3.serialize_with_witness()))
weight = len(tx3.serialize_with_witness()) + 3 * len(tx3.serialize_without_witness()) vsize = tx3.get_vsize()
vsize = math.ceil(weight / 4)
assert_equal(raw_tx["vsize"], vsize) assert_equal(raw_tx["vsize"], vsize)
assert_equal(raw_tx["weight"], weight) assert_equal(raw_tx["weight"], tx3.get_weight())
assert_equal(len(raw_tx["vin"][0]["txinwitness"]), 1) assert_equal(len(raw_tx["vin"][0]["txinwitness"]), 1)
assert_equal(raw_tx["vin"][0]["txinwitness"][0], witness_program.hex()) assert_equal(raw_tx["vin"][0]["txinwitness"][0], witness_program.hex())
assert vsize != raw_tx["size"] assert vsize != raw_tx["size"]

View file

@ -590,12 +590,15 @@ class CTransaction:
return False return False
return True return True
# Calculate the virtual transaction size using witness and non-witness # Calculate the transaction weight using witness and non-witness
# serialization size (does NOT use sigops). # serialization size (does NOT use sigops).
def get_vsize(self): def get_weight(self):
with_witness_size = len(self.serialize_with_witness()) with_witness_size = len(self.serialize_with_witness())
without_witness_size = len(self.serialize_without_witness()) without_witness_size = len(self.serialize_without_witness())
return math.ceil(((WITNESS_SCALE_FACTOR - 1) * without_witness_size + with_witness_size) / WITNESS_SCALE_FACTOR) return (WITNESS_SCALE_FACTOR - 1) * without_witness_size + with_witness_size
def get_vsize(self):
return math.ceil(self.get_weight() / WITNESS_SCALE_FACTOR)
def __repr__(self): def __repr__(self):
return "CTransaction(nVersion=%i vin=%s vout=%s wit=%s nLockTime=%i)" \ return "CTransaction(nVersion=%i vin=%s vout=%s wit=%s nLockTime=%i)" \