# Copyright (c) 2022 Pieter Wuille
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.

"""
This module provides the ASNEntry and ASMap classes.
"""

import copy
import ipaddress
import random
import unittest
from collections.abc import Callable, Iterable
from enum import Enum
from functools import total_ordering
from typing import Optional, Union, overload

# Names exported by `from <module> import *`. Test classes and imported
# modules are intentionally not part of the public API.
__all__ = ["ASMap", "ASNDiff", "ASNEntry", "net_to_prefix", "prefix_to_net"]


def net_to_prefix(net: Union[ipaddress.IPv4Network,ipaddress.IPv6Network]) -> list[bool]:
    """
    Convert an IPv4 or IPv6 network to a prefix represented as a list of bits.

    IPv4 ranges are remapped to their IPv4-mapped IPv6 range (::ffff:0:0/96).
    """
    num_bits = net.prefixlen
    netrange = int.from_bytes(net.network_address.packed, 'big')

    # Map an IPv4 prefix into IPv6 space.
    if isinstance(net, ipaddress.IPv4Network):
        num_bits += 96
        netrange += 0xffff00000000

    # All bits below the prefix length must be zero; only the top num_bits are emitted.
    assert (netrange & ((1 << (128 - num_bits)) - 1)) == 0
    return [((netrange >> (127 - i)) & 1) != 0 for i in range(num_bits)]


def prefix_to_net(prefix: list[bool]) -> Union[ipaddress.IPv4Network,ipaddress.IPv6Network]:
    """The reverse operation of net_to_prefix."""
    # Convert to number
    netrange = sum(b << (127 - i) for i, b in enumerate(prefix))
    num_bits = len(prefix)
    assert num_bits <= 128

    # Return IPv4 range if in ::ffff:0:0/96
    if num_bits >= 96 and (netrange >> 32) == 0xffff:
        return ipaddress.IPv4Network((netrange & 0xffffffff, num_bits - 96), True)

    # Return IPv6 range otherwise.
    return ipaddress.IPv6Network((netrange, num_bits), True)


# Shortcut for (prefix, ASN) entries.
ASNEntry = tuple[list[bool], int]

# Shortcut for (prefix, old ASN, new ASN) entries.
ASNDiff = tuple[list[bool], int, int]


class _VarLenCoder:
    """
    A class representing a custom variable-length binary encoder/decoder for
    integers. Each object represents a different coder, with different parameters
    minval and clsbits.

    The encoding is easiest to describe using an example. Let's say minval=100 and
    clsbits=[4,2,2,3]. In that case:
    - x in [100..115]: encoded as [0] + [4-bit BE encoding of (x-100)].
    - x in [116..119]: encoded as [1,0] + [2-bit BE encoding of (x-116)].
    - x in [120..123]: encoded as [1,1,0] + [2-bit BE encoding of (x-120)].
    - x in [124..131]: encoded as [1,1,1] + [3-bit BE encoding of (x-124)].

    In general, every number is encoded as:
    - First, k "1"-bits, where k is the class the number falls in (there is one class
      per element of clsbits).
    - Then, a "0"-bit, unless k is the highest class, in which case there is nothing.
    - Lastly, clsbits[k] bits encoding in big endian the position in its class that
      number falls into.
    - Every class k consists of 2^clsbits[k] consecutive integers. k=0 starts at minval,
      other classes start one past the last element of the class before it.
    """

    def __init__(self, minval: int, clsbits: list[int]):
        """Construct a new _VarLenCoder."""
        self._minval = minval
        self._clsbits = clsbits
        self._maxval = minval + sum(1 << b for b in clsbits) - 1

    def can_encode(self, val: int) -> bool:
        """Check whether value val is in the range this coder supports."""
        return self._minval <= val <= self._maxval

    def encode(self, val: int, ret: list[int]) -> None:
        """Append encoding of val onto integer list ret."""
        assert self._minval <= val <= self._maxval
        val -= self._minval
        bits = 0
        for k, bits in enumerate(self._clsbits):
            if val >> bits:
                # If the value will not fit in class k, subtract its range from val,
                # emit a "1" bit and continue with the next class.
                val -= 1 << bits
                ret.append(1)
            else:
                if k + 1 < len(self._clsbits):
                    # Unless we're in the last class, emit a "0" bit.
                    ret.append(0)
                break
        # And then encode val (now the position within the class) in big endian.
        ret.extend((val >> (bits - 1 - b)) & 1 for b in range(bits))

    def encode_size(self, val: int) -> int:
        """Compute how many bits are needed to encode val."""
        assert self._minval <= val <= self._maxval
        val -= self._minval
        ret = 0
        bits = 0
        for k, bits in enumerate(self._clsbits):
            if val >> bits:
                val -= 1 << bits
                ret += 1
            else:
                # The terminating "0" bit is only present outside the last class.
                ret += k + 1 < len(self._clsbits)
                break
        return ret + bits

    def decode(self, stream, bitpos) -> tuple[int,int]:
        """Decode a number starting at bitpos in stream, returning value and new bitpos."""
        val = self._minval
        bits = 0
        for k, bits in enumerate(self._clsbits):
            bit = 0
            if k + 1 < len(self._clsbits):
                bit = stream[bitpos]
                bitpos += 1
            if not bit:
                break
            # A "1" bit means the value lies beyond class k; skip over its range.
            val += 1 << bits
        for i in range(bits):
            bit = stream[bitpos]
            bitpos += 1
            val += bit << (bits - 1 - i)
        return val, bitpos


# Variable-length encoders used in the binary asmap format.
_CODER_INS = _VarLenCoder(0, [0, 0, 1])
_CODER_ASN = _VarLenCoder(1, list(range(15, 25)))
_CODER_MATCH = _VarLenCoder(2, list(range(1, 9)))
_CODER_JUMP = _VarLenCoder(17, list(range(5, 31)))


class _Instruction(Enum):
    """One instruction in the binary asmap format."""
    # A return instruction, encoded as [0], returns a constant ASN. It is followed by
    # an integer using the ASN encoding.
    RETURN = 0
    # A jump instruction, encoded as [1,0] inspects the next unused bit in the input
    # and either continues execution (if 0), or skips a specified number of bits (if 1).
    # It is followed by an integer, and then two subprograms. The integer uses jump encoding
    # and corresponds to the length of the first subprogram (so it can be skipped).
    JUMP = 1
    # A match instruction, encoded as [1,1,0] inspects 1 or more of the next unused bits
    # in the input with its argument. If they all match, execution continues. If they do
    # not, failure is returned. If a default instruction has been executed before, instead
    # of failure the default instruction's argument is returned. It is followed by an
    # integer in match encoding, and a subprogram. That value is at least 2 bits and at
    # most 9 bits. An n-bit value signifies matching (n-1) bits in the input with the lower
    # (n-1) bits in the match value.
    MATCH = 2
    # A default instruction, encoded as [1,1,1] sets the default variable to its argument,
    # and continues execution. It is followed by an integer in ASN encoding, and a subprogram.
    DEFAULT = 3
    # Not an actual instruction, but a way to encode the empty program that fails. In the
    # encoder, it is used more generally to represent the failure case inside MATCH instructions,
    # which may (if used inside the context of a DEFAULT instruction) actually correspond to
    # a successful return. In this usage, they're always converted to an actual MATCH or RETURN
    # before the top level is reached (see make_default below).
    END = 4


class _BinNode:
    """A class representing a (node of) the parsed binary asmap format."""

    @overload
    def __init__(self, ins: _Instruction): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: int): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: "_BinNode", arg2: "_BinNode"): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: int, arg2: "_BinNode"): ...

    def __init__(self, ins: _Instruction, arg1=None, arg2=None):
        """
        Construct a new asmap node. Possibilities are:
        - _BinNode(_Instruction.RETURN, asn)
        - _BinNode(_Instruction.JUMP, node_0, node_1)
        - _BinNode(_Instruction.MATCH, val, node)
        - _BinNode(_Instruction.DEFAULT, asn, node)
        - _BinNode(_Instruction.END)

        The size attribute is set to the number of bits the encoding of this node
        (including its subprograms) takes.
        """
        self.ins = ins
        self.arg1 = arg1
        self.arg2 = arg2
        if ins == _Instruction.RETURN:
            assert isinstance(arg1, int)
            assert arg2 is None
            self.size = _CODER_INS.encode_size(ins.value) + _CODER_ASN.encode_size(arg1)
        elif ins == _Instruction.JUMP:
            assert isinstance(arg1, _BinNode)
            assert isinstance(arg2, _BinNode)
            self.size = (_CODER_INS.encode_size(ins.value) + _CODER_JUMP.encode_size(arg1.size) +
                         arg1.size + arg2.size)
        elif ins == _Instruction.DEFAULT:
            assert isinstance(arg1, int)
            assert isinstance(arg2, _BinNode)
            self.size = (_CODER_INS.encode_size(ins.value) + _CODER_ASN.encode_size(arg1) +
                         arg2.size)
        elif ins == _Instruction.MATCH:
            assert isinstance(arg1, int)
            assert isinstance(arg2, _BinNode)
            self.size = (_CODER_INS.encode_size(ins.value) + _CODER_MATCH.encode_size(arg1) +
                         arg2.size)
        elif ins == _Instruction.END:
            assert arg1 is None
            assert arg2 is None
            self.size = 0
        else:
            assert False

    @staticmethod
    def make_end() -> "_BinNode":
        """Constructor for a _BinNode with just an END instruction."""
        return _BinNode(_Instruction.END)

    @staticmethod
    def make_leaf(val: int) -> "_BinNode":
        """Constructor for a _BinNode of just a RETURN instruction."""
        assert val is not None and val > 0
        return _BinNode(_Instruction.RETURN, val)

    @staticmethod
    def make_branch(node0: "_BinNode", node1: "_BinNode") -> "_BinNode":
        """
        Construct a _BinNode corresponding to running either the node0 or node1 subprogram,
        based on the next input bit. It exploits shortcuts that are possible in the encoding,
        and uses either a JUMP, MATCH, or END instruction.
        """
        if node0.ins == _Instruction.END and node1.ins == _Instruction.END:
            return node0
        if node0.ins == _Instruction.END:
            # Only the 1-branch can succeed: prepend a "1" bit to an existing MATCH (if it
            # still fits in the match encoding), or create a new single-bit MATCH.
            if node1.ins == _Instruction.MATCH and node1.arg1 <= 0xFF:
                return _BinNode(node1.ins, node1.arg1 + (1 << node1.arg1.bit_length()),
                                node1.arg2)
            return _BinNode(_Instruction.MATCH, 3, node1)
        if node1.ins == _Instruction.END:
            # Only the 0-branch can succeed: prepend a "0" bit to an existing MATCH (if it
            # still fits in the match encoding), or create a new single-bit MATCH.
            if node0.ins == _Instruction.MATCH and node0.arg1 <= 0xFF:
                return _BinNode(node0.ins, node0.arg1 + (1 << (node0.arg1.bit_length() - 1)),
                                node0.arg2)
            return _BinNode(_Instruction.MATCH, 2, node0)
        return _BinNode(_Instruction.JUMP, node0, node1)

    @staticmethod
    def make_default(val: int, sub: "_BinNode") -> "_BinNode":
        """
        Construct a _BinNode that corresponds to the specified subprogram, with the specified
        default value. It exploits shortcuts that are possible in the encoding, and will use
        either a DEFAULT or a RETURN instruction."""
        assert val is not None and val > 0
        if sub.ins == _Instruction.END:
            return _BinNode(_Instruction.RETURN, val)
        if sub.ins in (_Instruction.RETURN, _Instruction.DEFAULT):
            return sub
        return _BinNode(_Instruction.DEFAULT, val, sub)


@total_ordering
class ASMap:
    """
    A class whose objects represent a mapping from subnets to ASNs.

    Internally the mapping is stored as a binary trie, but can be converted
    from/to a list of ASNEntry objects, and from/to the binary asmap file format.

    In the trie representation, nodes are represented as bare lists for efficiency
    and ease of manipulation:
    - [0] means an unassigned subnet (no ASN mapping for it is present)
    - [int] means a subnet mapped entirely to the specified ASN.
    - [node,node] means a subnet whose lower half and upper half have different
      mappings, represented by new trie nodes.
    """

    def update(self, prefix: list[bool], asn: int) -> None:
        """Update this ASMap object to map prefix to the specified asn."""
        assert asn == 0 or _CODER_ASN.can_encode(asn)

        def recurse(node: list, offset: int) -> None:
            if offset == len(prefix):
                # Reached the end of prefix; overwrite this node.
                node.clear()
                node.append(asn)
                return
            if len(node) == 1:
                # Need to descend into a leaf node; split it up.
                oldasn = node[0]
                node.clear()
                node.append([oldasn])
                node.append([oldasn])
            # Descend into the node.
            recurse(node[prefix[offset]], offset + 1)
            # If the result is two identical leaf children, merge them.
            if len(node[0]) == 1 and len(node[1]) == 1 and node[0] == node[1]:
                oldasn = node[0][0]
                node.clear()
                node.append(oldasn)
        recurse(self._trie, 0)

    def update_multi(self, entries: Iterable[ASNEntry]) -> None:
        """
        Apply multiple update operations, where longer prefixes take precedence.

        The entries are applied in order of increasing prefix length (ties keep their
        input order). The passed-in iterable is not modified.
        """
        for prefix, asn in sorted(entries, key=lambda entry: len(entry[0])):
            self.update(prefix, asn)

    def _set_trie(self, trie) -> None:
        """Set trie directly (after merging identical sibling leaves). Internal use only."""
        def recurse(node: list) -> None:
            if len(node) < 2:
                return
            recurse(node[0])
            recurse(node[1])
            if len(node[0]) == 2:
                return
            if node[0] == node[1]:
                if len(node[0]) == 0:
                    node.clear()
                else:
                    asn = node[0][0]
                    node.clear()
                    node.append(asn)
        recurse(trie)
        self._trie = trie

    def __init__(self, entries: Optional[Iterable[ASNEntry]] = None) -> None:
        """Construct an ASMap object from an optional list of entries."""
        self._trie = [0]
        if entries is not None:
            def entry_key(entry):
                """Sort function that places shorter prefixes first."""
                prefix, asn = entry
                return len(prefix), prefix, asn
            for prefix, asn in sorted(entries, key=entry_key):
                self.update(prefix, asn)

    def lookup(self, prefix: list[bool]) -> Optional[int]:
        """Look up a prefix. Returns ASN, or 0 if unassigned, or None if indeterminate."""
        node = self._trie
        for bit in prefix:
            if len(node) == 1:
                break
            node = node[bit]
        if len(node) == 1:
            return node[0]
        return None

    def _to_entries_flat(self, fill: bool = False) -> list[ASNEntry]:
        """Convert an ASMap object to a list of non-overlapping (prefix, asn) objects."""
        prefix : list[bool] = []

        def recurse(node: list) -> list[ASNEntry]:
            ret = []
            if len(node) == 1:
                if node[0] > 0:
                    ret = [(list(prefix), node[0])]
            elif len(node) == 2:
                prefix.append(False)
                ret = recurse(node[0])
                prefix[-1] = True
                ret += recurse(node[1])
                prefix.pop()
                if fill and len(ret) > 1:
                    # When filling is permitted and everything below maps to a single
                    # ASN, one entry for the whole subtree suffices.
                    asns = set(x[1] for x in ret)
                    if len(asns) == 1:
                        ret = [(list(prefix), list(asns)[0])]
            return ret
        return recurse(self._trie)

    def _to_entries_minimal(self, fill: bool = False) -> list[ASNEntry]:
        """Convert a trie to a minimal list of ASNEntry objects, exploiting overlap."""
        prefix : list[bool] = []

        def recurse(node: list) -> (tuple[dict[Optional[int], list[ASNEntry]], bool]):
            # Returns a dict mapping a context (the ASN assumed to be inherited from a
            # shorter enclosing entry; None when nothing is assumed) to the best entry
            # list found for it, plus whether the subtree contains an unassigned hole.
            if len(node) == 1 and node[0] == 0:
                return {None if fill else 0: []}, True
            if len(node) == 1:
                return {node[0]: [], None: [(list(prefix), node[0])]}, False
            ret: dict[Optional[int], list[ASNEntry]] = {}
            prefix.append(False)
            left, lhole = recurse(node[0])
            prefix[-1] = True
            right, rhole = recurse(node[1])
            prefix.pop()
            hole = not fill and (lhole or rhole)

            def candidate(ctx: Optional[int], res0: Optional[list[ASNEntry]],
                          res1: Optional[list[ASNEntry]]):
                # Keep res0 + res1 for ctx if both exist and it is shorter than the best so far.
                if res0 is not None and res1 is not None:
                    if ctx not in ret or len(res0) + len(res1) < len(ret[ctx]):
                        ret[ctx] = res0 + res1

            for ctx in set(left) | set(right):
                candidate(ctx, left.get(ctx), right.get(ctx))
                candidate(ctx, left.get(None), right.get(ctx))
                candidate(ctx, left.get(ctx), right.get(None))
            if not hole:
                for ctx in list(ret):
                    if ctx is not None:
                        candidate(None, [(list(prefix), ctx)], ret[ctx])
            if None in ret:
                ret = {ctx:entries for ctx, entries in ret.items()
                       if ctx is None or len(entries) < len(ret[None])}
            if hole:
                ret = {ctx:entries for ctx, entries in ret.items() if ctx is None or ctx == 0}
            return ret, hole

        res, _ = recurse(self._trie)
        return res[0] if 0 in res else res[None]

    def __str__(self) -> str:
        """Convert this ASMap object to a string containing Python code constructing it."""
        return f"ASMap({self._trie})"

    def to_entries(self, overlapping: bool = True, fill: bool = False) -> list[ASNEntry]:
        """
        Convert the mappings in this ASMap object to a list of ASNEntry objects.

        Arguments:
            overlapping: Permit the subnets in the resulting ASNEntry to overlap.
                         Setting this can result in a shorter list.
            fill:        Permit the resulting ASNEntry objects to cover subnets that
                         are unassigned in this ASMap object. Setting this can
                         result in a shorter list.
        """
        if overlapping:
            return self._to_entries_minimal(fill)
        return self._to_entries_flat(fill)

    @staticmethod
    def from_random(num_leaves: int = 10, max_asn: int = 6,
                    unassigned_prob: float = 0.5) -> "ASMap":
        """
        Construct a random ASMap object, with specified:
         - Number of leaves in its trie (at least 1)
         - Maximum ASN value (at least 1)
         - Probability for leaf nodes to be unassigned

        The number of leaves in the resulting object may be less than what is
        requested. This method is mostly intended for testing.
        """
        assert num_leaves >= 1
        assert max_asn >= 1 or unassigned_prob == 1
        assert _CODER_ASN.can_encode(max_asn)
        assert 0.0 <= unassigned_prob <= 1.0
        trie: list = []
        leaves = [trie]
        ret = ASMap()
        for i in range(1, num_leaves):
            # Pick a random leaf, remove it from the list of leaves (by swapping in the
            # last one), and split it into two new leaves.
            idx = random.randrange(i)
            leaf = leaves[idx]
            lastleaf = leaves.pop()
            if idx + 1 < i:
                leaves[idx] = lastleaf
            leaf.append([])
            leaf.append([])
            leaves.append(leaf[0])
            leaves.append(leaf[1])
        for leaf in leaves:
            if random.random() >= unassigned_prob:
                leaf.append(random.randrange(1, max_asn + 1))
            else:
                leaf.append(0)
        #pylint: disable=protected-access
        ret._set_trie(trie)
        return ret

    def _to_binnode(self, fill: bool = False) -> _BinNode:
        """Convert a trie to a _BinNode object."""
        def recurse(node: list) -> tuple[dict[Optional[int], _BinNode], bool]:
            # Returns a dict mapping a context (the default ASN assumed to be in effect;
            # None when nothing is assumed) to the smallest program found for it, plus
            # whether the subtree contains an unassigned hole.
            if len(node) == 1 and node[0] == 0:
                return {(None if fill else 0): _BinNode.make_end()}, True
            if len(node) == 1:
                return {None: _BinNode.make_leaf(node[0]), node[0]: _BinNode.make_end()}, False
            ret: dict[Optional[int], _BinNode] = {}
            left, lhole = recurse(node[0])
            right, rhole = recurse(node[1])
            hole = (lhole or rhole) and not fill

            def candidate(ctx: Optional[int], arg1, arg2, func: Callable):
                # Keep func(arg1, arg2) for ctx if both args exist and it beats the best so far.
                if arg1 is not None and arg2 is not None:
                    cand = func(arg1, arg2)
                    if ctx not in ret or cand.size < ret[ctx].size:
                        ret[ctx] = cand

            for ctx in set(left) | set(right):
                candidate(ctx, left.get(ctx), right.get(ctx), _BinNode.make_branch)
                candidate(ctx, left.get(None), right.get(ctx), _BinNode.make_branch)
                candidate(ctx, left.get(ctx), right.get(None), _BinNode.make_branch)
            if not hole:
                for ctx in set(ret) - {None}:
                    candidate(None, ctx, ret[ctx], _BinNode.make_default)
            if None in ret:
                ret = {ctx:enc for ctx, enc in ret.items()
                       if ctx is None or enc.size < ret[None].size}
            if hole:
                ret = {ctx:enc for ctx, enc in ret.items() if ctx is None or ctx == 0}
            return ret, hole

        res, _ = recurse(self._trie)
        return res[0] if 0 in res else res[None]

    @staticmethod
    def _from_binnode(binnode: _BinNode) -> "ASMap":
        """Construct an ASMap object from a _BinNode. Internal use only."""
        def recurse(node: _BinNode, default: int) -> list:
            if node.ins == _Instruction.RETURN:
                return [node.arg1]
            if node.ins == _Instruction.JUMP:
                return [recurse(node.arg1, default), recurse(node.arg2, default)]
            if node.ins == _Instruction.MATCH:
                val = node.arg1
                sub = recurse(node.arg2, default)
                # Peel off the match bits from the lowest one up; each one adds a trie
                # level whose non-matching side maps to the default.
                while val >= 2:
                    bit = val & 1
                    val >>= 1
                    if bit:
                        sub = [[default], sub]
                    else:
                        sub = [sub, [default]]
                return sub
            assert node.ins == _Instruction.DEFAULT
            return recurse(node.arg2, node.arg1)
        ret = ASMap()
        if binnode.ins != _Instruction.END:
            #pylint: disable=protected-access
            ret._set_trie(recurse(binnode, 0))
        return ret

    def to_binary(self, fill: bool = False) -> bytes:
        """
        Convert this ASMap object to binary.

        Argument:
            fill: permit the resulting binary encoder to contain mappers for
                  unassigned subnets in this ASMap object. Doing so may
                  reduce the size of the encoding.
        Returns:
            A bytes object with the encoding of this ASMap object.
        """
        bits: list[int] = []

        def recurse(node: _BinNode) -> None:
            _CODER_INS.encode(node.ins.value, bits)
            if node.ins == _Instruction.RETURN:
                _CODER_ASN.encode(node.arg1, bits)
            elif node.ins == _Instruction.JUMP:
                _CODER_JUMP.encode(node.arg1.size, bits)
                recurse(node.arg1)
                recurse(node.arg2)
            elif node.ins == _Instruction.DEFAULT:
                _CODER_ASN.encode(node.arg1, bits)
                recurse(node.arg2)
            else:
                assert node.ins == _Instruction.MATCH
                _CODER_MATCH.encode(node.arg1, bits)
                recurse(node.arg2)

        binnode = self._to_binnode(fill)
        if binnode.ins != _Instruction.END:
            recurse(binnode)

        # Pack the bits into bytes, least significant bit first; the final byte is
        # padded with zero bits.
        val = 0
        nbits = 0
        ret = []
        for bit in bits:
            val += (bit << nbits)
            nbits += 1
            if nbits == 8:
                ret.append(val)
                val = 0
                nbits = 0
        if nbits:
            ret.append(val)
        return bytes(ret)

    @staticmethod
    def from_binary(bindata: bytes) -> Optional["ASMap"]:
        """
        Decode an ASMap object from the provided binary encoding.

        Returns None if the data is truncated, contains an inconsistent jump, or has
        trailing data other than fewer than 8 zero padding bits.
        """
        bits: list[int] = []
        for byte in bindata:
            bits.extend((byte >> i) & 1 for i in range(8))

        def recurse(bitpos: int) -> tuple[_BinNode, int]:
            insval, bitpos = _CODER_INS.decode(bits, bitpos)
            ins = _Instruction(insval)
            if ins == _Instruction.RETURN:
                asn, bitpos = _CODER_ASN.decode(bits, bitpos)
                return _BinNode(ins, asn), bitpos
            if ins == _Instruction.JUMP:
                jump, bitpos = _CODER_JUMP.decode(bits, bitpos)
                left, bitpos1 = recurse(bitpos)
                if bitpos1 != bitpos + jump:
                    raise ValueError("Inconsistent jump")
                right, bitpos = recurse(bitpos1)
                return _BinNode(ins, left, right), bitpos
            if ins == _Instruction.MATCH:
                match, bitpos = _CODER_MATCH.decode(bits, bitpos)
                sub, bitpos = recurse(bitpos)
                return _BinNode(ins, match, sub), bitpos
            assert ins == _Instruction.DEFAULT
            asn, bitpos = _CODER_ASN.decode(bits, bitpos)
            sub, bitpos = recurse(bitpos)
            return _BinNode(ins, asn, sub), bitpos

        if len(bits) == 0:
            binnode = _BinNode(_Instruction.END)
        else:
            try:
                binnode, bitpos = recurse(0)
            except (ValueError, IndexError):
                # IndexError signals reading past the end of the data.
                return None
            # Only padding (fewer than 8 bits, all zero) may follow the program.
            if bitpos < len(bits) - 7:
                return None
            if not all(bit == 0 for bit in bits[bitpos:]):
                return None

        return ASMap._from_binnode(binnode)

    def __lt__(self, other: object) -> bool:
        """
        Order ASMap objects by their trie.

        Tries are compared as nested lists, except that a leaf node sorts before a
        branch node (a plain list comparison would raise TypeError when an ASN and a
        subtree end up being compared). Returns NotImplemented for non-ASMap operands.
        """
        if not isinstance(other, ASMap):
            return NotImplemented

        def key(node: list) -> tuple:
            # Tag leaves with 0 and branches with 1 so differing shapes are comparable.
            if len(node) < 2:
                return (0, *node)
            return (1, key(node[0]), key(node[1]))

        return key(self._trie) < key(other._trie)

    def __eq__(self, other: object) -> bool:
        """Two ASMap objects are equal when their tries are equal."""
        if isinstance(other, ASMap):
            return self._trie == other._trie
        return False

    def extends(self, req: "ASMap") -> bool:
        """Determine whether this matches req for all subranges where req is assigned."""
        def recurse(actual: list, require: list) -> bool:
            if len(require) == 1 and require[0] == 0:
                return True
            if len(require) == 1:
                if len(actual) == 1:
                    return bool(require[0] == actual[0])
                return recurse(actual[0], require) and recurse(actual[1], require)
            if len(actual) == 2:
                return recurse(actual[0], require[0]) and recurse(actual[1], require[1])
            return recurse(actual, require[0]) and recurse(actual, require[1])
        assert isinstance(req, ASMap)
        #pylint: disable=protected-access
        return recurse(self._trie, req._trie)

    def diff(self, other: "ASMap") -> list[ASNDiff]:
        """Compute the diff from self to other."""
        prefix: list[bool] = []
        ret: list[ASNDiff] = []

        def recurse(old_node: list, new_node: list):
            if len(old_node) == 1 and len(new_node) == 1:
                if old_node[0] != new_node[0]:
                    ret.append((list(prefix), old_node[0], new_node[0]))
            else:
                # A leaf on one side is compared against both halves of the other side.
                old_left: list = old_node if len(old_node) == 1 else old_node[0]
                old_right: list = old_node if len(old_node) == 1 else old_node[1]
                new_left: list = new_node if len(new_node) == 1 else new_node[0]
                new_right: list = new_node if len(new_node) == 1 else new_node[1]
                prefix.append(False)
                recurse(old_left, new_left)
                prefix[-1] = True
                recurse(old_right, new_right)
                prefix.pop()
        assert isinstance(other, ASMap)
        #pylint: disable=protected-access
        recurse(self._trie, other._trie)
        return ret

    def __copy__(self) -> "ASMap":
        """Construct a copy of this ASMap object. Its state will not be shared."""
        ret = ASMap()
        #pylint: disable=protected-access
        ret._set_trie(copy.deepcopy(self._trie))
        return ret

    def __deepcopy__(self, _) -> "ASMap":
        """Construct a deep copy of this ASMap object (identical to __copy__)."""
        # ASMap objects do not allow sharing of the _trie member, so we don't need the memoization.
        return self.__copy__()


class TestASMap(unittest.TestCase):
    """Unit tests for this module."""

    def test_ipv6_prefix_roundtrips(self) -> None:
        """Test that random IPv6 network ranges roundtrip through prefix encoding."""
        for _ in range(20):
            net_bits = random.getrandbits(128)
            for prefix_len in range(0, 129):
                masked_bits = (net_bits >> (128 - prefix_len)) << (128 - prefix_len)
                net = ipaddress.IPv6Network((masked_bits.to_bytes(16, 'big'), prefix_len))
                prefix = net_to_prefix(net)
                self.assertTrue(len(prefix) <= 128)
                net2 = prefix_to_net(prefix)
                self.assertEqual(net, net2)

    def test_ipv4_prefix_roundtrips(self) -> None:
        """Test that random IPv4 network ranges roundtrip through prefix encoding."""
        for _ in range(100):
            net_bits = random.getrandbits(32)
            for prefix_len in range(0, 33):
                masked_bits = (net_bits >> (32 - prefix_len)) << (32 - prefix_len)
                net = ipaddress.IPv4Network((masked_bits.to_bytes(4, 'big'), prefix_len))
                prefix = net_to_prefix(net)
                # IPv4 networks are mapped into ::ffff:0:0/96, so at least 96 bits.
                self.assertTrue(96 <= len(prefix) <= 128)
                net2 = prefix_to_net(prefix)
                self.assertEqual(net, net2)

    def test_asmap_roundtrips(self) -> None:
        """Test case that verifies random ASMap objects roundtrip to/from entries/binary."""
        # Iterate over the number of leaves the random test ASMap objects have.
        for leaves in range(1, 20):
            # Iterate over the number of bits in the AS numbers used.
            for asnbits in range(0, 24):
                # Iterate over the probability that leaves are unassigned.
                for pct in range(101):
                    # Construct a random ASMap object according to the above parameters.
                    asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
                                              unassigned_prob=0.01 * pct)
                    # Run tests for to_entries and construction from those entries, both
                    # for overlapping and non-overlapping ones.
                    for overlapping in [False, True]:
                        entries = asmap.to_entries(overlapping=overlapping, fill=False)
                        random.shuffle(entries)
                        asmap2 = ASMap(entries)
                        assert asmap2 is not None
                        self.assertEqual(asmap2, asmap)
                        entries = asmap.to_entries(overlapping=overlapping, fill=True)
                        random.shuffle(entries)
                        asmap2 = ASMap(entries)
                        assert asmap2 is not None
                        self.assertTrue(asmap2.extends(asmap))

                    # Run tests for to_binary and construction from binary.
                    enc = asmap.to_binary(fill=False)
                    asmap3 = ASMap.from_binary(enc)
                    assert asmap3 is not None
                    self.assertEqual(asmap3, asmap)
                    enc = asmap.to_binary(fill=True)
                    asmap3 = ASMap.from_binary(enc)
                    assert asmap3 is not None
                    self.assertTrue(asmap3.extends(asmap))

    def test_patching(self) -> None:
        """Test behavior of update, lookup, extends, and diff."""
        #pylint: disable=too-many-locals,too-many-nested-blocks
        # Iterate over the number of leaves the random test ASMap objects have.
        for leaves in range(1, 20):
            # Iterate over the number of bits in the AS numbers used.
            for asnbits in range(0, 10):
                # Iterate over the probability that leaves are unassigned.
                for pct in range(0, 101):
                    # Construct a random ASMap object according to the above parameters.
                    asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
                                              unassigned_prob=0.01 * pct)
                    # Make a copy of that asmap object to which patches will be applied.
                    # It starts off being equal to asmap.
                    patched = copy.copy(asmap)
                    # Keep a list of patches performed.
                    patches: list[ASNEntry] = []
                    # Initially there cannot be any difference.
                    self.assertEqual(asmap.diff(patched), [])
                    # Make 5 patches, each building on top of the previous ones.
                    for _ in range(0, 5):
                        # Construct a random path and new ASN to assign it to, apply it to
                        # patched, and remember it in patches.
                        pathlen = random.randrange(5)
                        path = [random.getrandbits(1) != 0 for _ in range(pathlen)]
                        newasn = random.randrange(1 + (1 << asnbits))
                        patched.update(path, newasn)
                        patches = [(path, newasn)] + patches

                        # Compute the diff, and whether asmap extends patched, and the other
                        # way around.
                        diff = asmap.diff(patched)
                        self.assertEqual(asmap == patched, len(diff) == 0)
                        extends = asmap.extends(patched)
                        back_extends = patched.extends(asmap)
                        # Determine whether those extends results are consistent with the
                        # diff result.
                        self.assertEqual(extends, all(d[2] == 0 for d in diff))
                        self.assertEqual(back_extends, all(d[1] == 0 for d in diff))
                        # For every diff found:
                        for path, old_asn, new_asn in diff:
                            # Verify asmap and patched actually differ there.
                            self.assertTrue(old_asn != new_asn)
                            self.assertEqual(asmap.lookup(path), old_asn)
                            self.assertEqual(patched.lookup(path), new_asn)
                            for _ in range(2):
                                # Extend the path far enough that it's smaller than any
                                # mapped range, and check the lookup holds there too.
                                spec_path = list(path)
                                while len(spec_path) < 32:
                                    spec_path.append(random.getrandbits(1) != 0)
                                self.assertEqual(asmap.lookup(spec_path), old_asn)
                                self.assertEqual(patched.lookup(spec_path), new_asn)
                                # Search through the list of performed patches to find the
                                # last one applying to the extended path (note that patches
                                # is in reverse order, so the first match should work).
                                found = False
                                for patch_path, patch_asn in patches:
                                    if spec_path[:len(patch_path)] == patch_path:
                                        # When found, it must match whatever the result was
                                        # patched to.
                                        self.assertEqual(new_asn, patch_asn)
                                        found = True
                                        break
                                # And such a patch must exist.
                                self.assertTrue(found)


if __name__ == '__main__':
    unittest.main()