test: add secp256k1 module with FE (field element) and GE (group element) classes

These are primarily designed for ease of understanding, not performance.
This commit is contained in:
Pieter Wuille 2022-10-01 11:35:28 -04:00
parent b741a62a2f
commit 1830dd8820
3 changed files with 376 additions and 285 deletions

View file

@ -104,8 +104,8 @@ from test_framework.key import (
sign_schnorr, sign_schnorr,
tweak_add_privkey, tweak_add_privkey,
ECKey, ECKey,
SECP256K1
) )
from test_framework import secp256k1
from test_framework.address import ( from test_framework.address import (
hash160, hash160,
program_to_witness, program_to_witness,
@ -695,7 +695,7 @@ def spenders_taproot_active():
# Generate an invalid public key # Generate an invalid public key
while True: while True:
invalid_pub = random_bytes(32) invalid_pub = random_bytes(32)
if not SECP256K1.is_x_coord(int.from_bytes(invalid_pub, 'big')): if not secp256k1.GE.is_valid_x(int.from_bytes(invalid_pub, 'big')):
break break
# Implement a test case that detects validation logic which maps invalid public keys to the # Implement a test case that detects validation logic which maps invalid public keys to the

View file

@ -1,7 +1,7 @@
# Copyright (c) 2019-2020 Pieter Wuille # Copyright (c) 2019-2020 Pieter Wuille
# Distributed under the MIT software license, see the accompanying # Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php. # file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Test-only secp256k1 elliptic curve implementation """Test-only secp256k1 elliptic curve protocols implementation
WARNING: This code is slow, uses bad randomness, does not properly protect WARNING: This code is slow, uses bad randomness, does not properly protect
keys, and is trivially vulnerable to side channel attacks. Do not use for keys, and is trivially vulnerable to side channel attacks. Do not use for
@ -13,9 +13,13 @@ import os
import random import random
import unittest import unittest
from test_framework import secp256k1
# Point with no known discrete log. # Point with no known discrete log.
H_POINT = "50929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0" H_POINT = "50929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0"
# Order of the secp256k1 curve
ORDER = secp256k1.GE.ORDER
def TaggedHash(tag, data): def TaggedHash(tag, data):
ss = hashlib.sha256(tag.encode('utf-8')).digest() ss = hashlib.sha256(tag.encode('utf-8')).digest()
@ -23,233 +27,18 @@ def TaggedHash(tag, data):
ss += data ss += data
return hashlib.sha256(ss).digest() return hashlib.sha256(ss).digest()
def jacobi_symbol(n, k):
"""Compute the Jacobi symbol of n modulo k
See https://en.wikipedia.org/wiki/Jacobi_symbol class ECPubKey:
For our application k is always prime, so this is the same as the Legendre symbol."""
assert k > 0 and k & 1, "jacobi symbol is only defined for positive odd k"
n %= k
t = 0
while n != 0:
while n & 1 == 0:
n >>= 1
r = k & 7
t ^= (r == 3 or r == 5)
n, k = k, n
t ^= (n & k & 3 == 3)
n = n % k
if k == 1:
return -1 if t else 1
return 0
def modsqrt(a, p):
"""Compute the square root of a modulo p when p % 4 = 3.
The Tonelli-Shanks algorithm can be used. See https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm
Limiting this function to only work for p % 4 = 3 means we don't need to
iterate through the loop. The highest n such that p - 1 = 2^n Q with Q odd
is n = 1. Therefore Q = (p-1)/2 and sqrt = a^((Q+1)/2) = a^((p+1)/4)
secp256k1's is defined over field of size 2**256 - 2**32 - 977, which is 3 mod 4.
"""
if p % 4 != 3:
raise NotImplementedError("modsqrt only implemented for p % 4 = 3")
sqrt = pow(a, (p + 1)//4, p)
if pow(sqrt, 2, p) == a % p:
return sqrt
return None
class EllipticCurve:
def __init__(self, p, a, b):
"""Initialize elliptic curve y^2 = x^3 + a*x + b over GF(p)."""
self.p = p
self.a = a % p
self.b = b % p
def affine(self, p1):
"""Convert a Jacobian point tuple p1 to affine form, or None if at infinity.
An affine point is represented as the Jacobian (x, y, 1)"""
x1, y1, z1 = p1
if z1 == 0:
return None
inv = pow(z1, -1, self.p)
inv_2 = (inv**2) % self.p
inv_3 = (inv_2 * inv) % self.p
return ((inv_2 * x1) % self.p, (inv_3 * y1) % self.p, 1)
def has_even_y(self, p1):
"""Whether the point p1 has an even Y coordinate when expressed in affine coordinates."""
return not (p1[2] == 0 or self.affine(p1)[1] & 1)
def negate(self, p1):
"""Negate a Jacobian point tuple p1."""
x1, y1, z1 = p1
return (x1, (self.p - y1) % self.p, z1)
def on_curve(self, p1):
"""Determine whether a Jacobian tuple p is on the curve (and not infinity)"""
x1, y1, z1 = p1
z2 = pow(z1, 2, self.p)
z4 = pow(z2, 2, self.p)
return z1 != 0 and (pow(x1, 3, self.p) + self.a * x1 * z4 + self.b * z2 * z4 - pow(y1, 2, self.p)) % self.p == 0
def is_x_coord(self, x):
"""Test whether x is a valid X coordinate on the curve."""
x_3 = pow(x, 3, self.p)
return jacobi_symbol(x_3 + self.a * x + self.b, self.p) != -1
def lift_x(self, x):
"""Given an X coordinate on the curve, return a corresponding affine point for which the Y coordinate is even."""
x_3 = pow(x, 3, self.p)
v = x_3 + self.a * x + self.b
y = modsqrt(v, self.p)
if y is None:
return None
return (x, self.p - y if y & 1 else y, 1)
def double(self, p1):
"""Double a Jacobian tuple p1
See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Doubling"""
x1, y1, z1 = p1
if z1 == 0:
return (0, 1, 0)
y1_2 = (y1**2) % self.p
y1_4 = (y1_2**2) % self.p
x1_2 = (x1**2) % self.p
s = (4*x1*y1_2) % self.p
m = 3*x1_2
if self.a:
m += self.a * pow(z1, 4, self.p)
m = m % self.p
x2 = (m**2 - 2*s) % self.p
y2 = (m*(s - x2) - 8*y1_4) % self.p
z2 = (2*y1*z1) % self.p
return (x2, y2, z2)
def add_mixed(self, p1, p2):
"""Add a Jacobian tuple p1 and an affine tuple p2
See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Addition (with affine point)"""
x1, y1, z1 = p1
x2, y2, z2 = p2
assert z2 == 1
# Adding to the point at infinity is a no-op
if z1 == 0:
return p2
z1_2 = (z1**2) % self.p
z1_3 = (z1_2 * z1) % self.p
u2 = (x2 * z1_2) % self.p
s2 = (y2 * z1_3) % self.p
if x1 == u2:
if (y1 != s2):
# p1 and p2 are inverses. Return the point at infinity.
return (0, 1, 0)
# p1 == p2. The formulas below fail when the two points are equal.
return self.double(p1)
h = u2 - x1
r = s2 - y1
h_2 = (h**2) % self.p
h_3 = (h_2 * h) % self.p
u1_h_2 = (x1 * h_2) % self.p
x3 = (r**2 - h_3 - 2*u1_h_2) % self.p
y3 = (r*(u1_h_2 - x3) - y1*h_3) % self.p
z3 = (h*z1) % self.p
return (x3, y3, z3)
def add(self, p1, p2):
"""Add two Jacobian tuples p1 and p2
See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Addition"""
x1, y1, z1 = p1
x2, y2, z2 = p2
# Adding the point at infinity is a no-op
if z1 == 0:
return p2
if z2 == 0:
return p1
# Adding an Affine to a Jacobian is more efficient since we save field multiplications and squarings when z = 1
if z1 == 1:
return self.add_mixed(p2, p1)
if z2 == 1:
return self.add_mixed(p1, p2)
z1_2 = (z1**2) % self.p
z1_3 = (z1_2 * z1) % self.p
z2_2 = (z2**2) % self.p
z2_3 = (z2_2 * z2) % self.p
u1 = (x1 * z2_2) % self.p
u2 = (x2 * z1_2) % self.p
s1 = (y1 * z2_3) % self.p
s2 = (y2 * z1_3) % self.p
if u1 == u2:
if (s1 != s2):
# p1 and p2 are inverses. Return the point at infinity.
return (0, 1, 0)
# p1 == p2. The formulas below fail when the two points are equal.
return self.double(p1)
h = u2 - u1
r = s2 - s1
h_2 = (h**2) % self.p
h_3 = (h_2 * h) % self.p
u1_h_2 = (u1 * h_2) % self.p
x3 = (r**2 - h_3 - 2*u1_h_2) % self.p
y3 = (r*(u1_h_2 - x3) - s1*h_3) % self.p
z3 = (h*z1*z2) % self.p
return (x3, y3, z3)
def mul(self, ps):
"""Compute a (multi) point multiplication
ps is a list of (Jacobian tuple, scalar) pairs.
"""
r = (0, 1, 0)
for i in range(255, -1, -1):
r = self.double(r)
for (p, n) in ps:
if ((n >> i) & 1):
r = self.add(r, p)
return r
SECP256K1_FIELD_SIZE = 2**256 - 2**32 - 977
SECP256K1 = EllipticCurve(SECP256K1_FIELD_SIZE, 0, 7)
SECP256K1_G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8, 1)
SECP256K1_ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
SECP256K1_ORDER_HALF = SECP256K1_ORDER // 2
class ECPubKey():
"""A secp256k1 public key""" """A secp256k1 public key"""
def __init__(self): def __init__(self):
"""Construct an uninitialized public key""" """Construct an uninitialized public key"""
self.valid = False self.p = None
def set(self, data): def set(self, data):
"""Construct a public key from a serialization in compressed or uncompressed format""" """Construct a public key from a serialization in compressed or uncompressed format"""
if (len(data) == 65 and data[0] == 0x04): self.p = secp256k1.GE.from_bytes(data)
p = (int.from_bytes(data[1:33], 'big'), int.from_bytes(data[33:65], 'big'), 1) self.compressed = len(data) == 33
self.valid = SECP256K1.on_curve(p)
if self.valid:
self.p = p
self.compressed = False
elif (len(data) == 33 and (data[0] == 0x02 or data[0] == 0x03)):
x = int.from_bytes(data[1:33], 'big')
if SECP256K1.is_x_coord(x):
p = SECP256K1.lift_x(x)
# Make the Y coordinate odd if required (lift_x always produces
# a point with an even Y coordinate).
if data[0] & 1:
p = SECP256K1.negate(p)
self.p = p
self.valid = True
self.compressed = True
else:
self.valid = False
else:
self.valid = False
@property @property
def is_compressed(self): def is_compressed(self):
@ -257,24 +46,21 @@ class ECPubKey():
@property @property
def is_valid(self): def is_valid(self):
return self.valid return self.p is not None
def get_bytes(self): def get_bytes(self):
assert self.valid assert self.is_valid
p = SECP256K1.affine(self.p)
if p is None:
return None
if self.compressed: if self.compressed:
return bytes([0x02 + (p[1] & 1)]) + p[0].to_bytes(32, 'big') return self.p.to_bytes_compressed()
else: else:
return bytes([0x04]) + p[0].to_bytes(32, 'big') + p[1].to_bytes(32, 'big') return self.p.to_bytes_uncompressed()
def verify_ecdsa(self, sig, msg, low_s=True): def verify_ecdsa(self, sig, msg, low_s=True):
"""Verify a strictly DER-encoded ECDSA signature against this pubkey. """Verify a strictly DER-encoded ECDSA signature against this pubkey.
See https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm for the See https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm for the
ECDSA verifier algorithm""" ECDSA verifier algorithm"""
assert self.valid assert self.is_valid
# Extract r and s from the DER formatted signature. Return false for # Extract r and s from the DER formatted signature. Return false for
# any DER encoding errors. # any DER encoding errors.
@ -310,24 +96,22 @@ class ECPubKey():
s = int.from_bytes(sig[6+rlen:6+rlen+slen], 'big') s = int.from_bytes(sig[6+rlen:6+rlen+slen], 'big')
# Verify that r and s are within the group order # Verify that r and s are within the group order
if r < 1 or s < 1 or r >= SECP256K1_ORDER or s >= SECP256K1_ORDER: if r < 1 or s < 1 or r >= ORDER or s >= ORDER:
return False return False
if low_s and s >= SECP256K1_ORDER_HALF: if low_s and s >= secp256k1.GE.ORDER_HALF:
return False return False
z = int.from_bytes(msg, 'big') z = int.from_bytes(msg, 'big')
# Run verifier algorithm on r, s # Run verifier algorithm on r, s
w = pow(s, -1, SECP256K1_ORDER) w = pow(s, -1, ORDER)
u1 = z*w % SECP256K1_ORDER R = secp256k1.GE.mul((z * w, secp256k1.G), (r * w, self.p))
u2 = r*w % SECP256K1_ORDER if R.infinity or (int(R.x) % ORDER) != r:
R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, u1), (self.p, u2)]))
if R is None or (R[0] % SECP256K1_ORDER) != r:
return False return False
return True return True
def generate_privkey(): def generate_privkey():
"""Generate a valid random 32-byte private key.""" """Generate a valid random 32-byte private key."""
return random.randrange(1, SECP256K1_ORDER).to_bytes(32, 'big') return random.randrange(1, ORDER).to_bytes(32, 'big')
def rfc6979_nonce(key): def rfc6979_nonce(key):
"""Compute signing nonce using RFC6979.""" """Compute signing nonce using RFC6979."""
@ -339,7 +123,7 @@ def rfc6979_nonce(key):
v = hmac.new(k, v, 'sha256').digest() v = hmac.new(k, v, 'sha256').digest()
return hmac.new(k, v, 'sha256').digest() return hmac.new(k, v, 'sha256').digest()
class ECKey(): class ECKey:
"""A secp256k1 private key""" """A secp256k1 private key"""
def __init__(self): def __init__(self):
@ -349,7 +133,7 @@ class ECKey():
"""Construct a private key object with given 32-byte secret and compressed flag.""" """Construct a private key object with given 32-byte secret and compressed flag."""
assert len(secret) == 32 assert len(secret) == 32
secret = int.from_bytes(secret, 'big') secret = int.from_bytes(secret, 'big')
self.valid = (secret > 0 and secret < SECP256K1_ORDER) self.valid = (secret > 0 and secret < ORDER)
if self.valid: if self.valid:
self.secret = secret self.secret = secret
self.compressed = compressed self.compressed = compressed
@ -375,9 +159,7 @@ class ECKey():
"""Compute an ECPubKey object for this secret key.""" """Compute an ECPubKey object for this secret key."""
assert self.valid assert self.valid
ret = ECPubKey() ret = ECPubKey()
p = SECP256K1.mul([(SECP256K1_G, self.secret)]) ret.p = self.secret * secp256k1.G
ret.p = p
ret.valid = True
ret.compressed = self.compressed ret.compressed = self.compressed
return ret return ret
@ -392,12 +174,12 @@ class ECKey():
if rfc6979: if rfc6979:
k = int.from_bytes(rfc6979_nonce(self.secret.to_bytes(32, 'big') + msg), 'big') k = int.from_bytes(rfc6979_nonce(self.secret.to_bytes(32, 'big') + msg), 'big')
else: else:
k = random.randrange(1, SECP256K1_ORDER) k = random.randrange(1, ORDER)
R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, k)])) R = k * secp256k1.G
r = R[0] % SECP256K1_ORDER r = int(R.x) % ORDER
s = (pow(k, -1, SECP256K1_ORDER) * (z + self.secret * r)) % SECP256K1_ORDER s = (pow(k, -1, ORDER) * (z + self.secret * r)) % ORDER
if low_s and s > SECP256K1_ORDER_HALF: if low_s and s > secp256k1.GE.ORDER_HALF:
s = SECP256K1_ORDER - s s = ORDER - s
# Represent in DER format. The byte representations of r and s have # Represent in DER format. The byte representations of r and s have
# length rounded up (255 bits becomes 32 bytes and 256 bits becomes 33 # length rounded up (255 bits becomes 32 bytes and 256 bits becomes 33
# bytes). # bytes).
@ -413,10 +195,10 @@ def compute_xonly_pubkey(key):
assert len(key) == 32 assert len(key) == 32
x = int.from_bytes(key, 'big') x = int.from_bytes(key, 'big')
if x == 0 or x >= SECP256K1_ORDER: if x == 0 or x >= ORDER:
return (None, None) return (None, None)
P = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, x)])) P = x * secp256k1.G
return (P[0].to_bytes(32, 'big'), not SECP256K1.has_even_y(P)) return (P.to_bytes_xonly(), not P.y.is_even())
def tweak_add_privkey(key, tweak): def tweak_add_privkey(key, tweak):
"""Tweak a private key (after negating it if needed).""" """Tweak a private key (after negating it if needed)."""
@ -425,14 +207,14 @@ def tweak_add_privkey(key, tweak):
assert len(tweak) == 32 assert len(tweak) == 32
x = int.from_bytes(key, 'big') x = int.from_bytes(key, 'big')
if x == 0 or x >= SECP256K1_ORDER: if x == 0 or x >= ORDER:
return None return None
if not SECP256K1.has_even_y(SECP256K1.mul([(SECP256K1_G, x)])): if not (x * secp256k1.G).y.is_even():
x = SECP256K1_ORDER - x x = ORDER - x
t = int.from_bytes(tweak, 'big') t = int.from_bytes(tweak, 'big')
if t >= SECP256K1_ORDER: if t >= ORDER:
return None return None
x = (x + t) % SECP256K1_ORDER x = (x + t) % ORDER
if x == 0: if x == 0:
return None return None
return x.to_bytes(32, 'big') return x.to_bytes(32, 'big')
@ -443,19 +225,16 @@ def tweak_add_pubkey(key, tweak):
assert len(key) == 32 assert len(key) == 32
assert len(tweak) == 32 assert len(tweak) == 32
x_coord = int.from_bytes(key, 'big') P = secp256k1.GE.from_bytes_xonly(key)
if x_coord >= SECP256K1_FIELD_SIZE:
return None
P = SECP256K1.lift_x(x_coord)
if P is None: if P is None:
return None return None
t = int.from_bytes(tweak, 'big') t = int.from_bytes(tweak, 'big')
if t >= SECP256K1_ORDER: if t >= ORDER:
return None return None
Q = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, t), (P, 1)])) Q = t * secp256k1.G + P
if Q is None: if Q.infinity:
return None return None
return (Q[0].to_bytes(32, 'big'), not SECP256K1.has_even_y(Q)) return (Q.to_bytes_xonly(), not Q.y.is_even())
def verify_schnorr(key, sig, msg): def verify_schnorr(key, sig, msg):
"""Verify a Schnorr signature (see BIP 340). """Verify a Schnorr signature (see BIP 340).
@ -468,23 +247,20 @@ def verify_schnorr(key, sig, msg):
assert len(msg) == 32 assert len(msg) == 32
assert len(sig) == 64 assert len(sig) == 64
x_coord = int.from_bytes(key, 'big') P = secp256k1.GE.from_bytes_xonly(key)
if x_coord == 0 or x_coord >= SECP256K1_FIELD_SIZE:
return False
P = SECP256K1.lift_x(x_coord)
if P is None: if P is None:
return False return False
r = int.from_bytes(sig[0:32], 'big') r = int.from_bytes(sig[0:32], 'big')
if r >= SECP256K1_FIELD_SIZE: if r >= secp256k1.FE.SIZE:
return False return False
s = int.from_bytes(sig[32:64], 'big') s = int.from_bytes(sig[32:64], 'big')
if s >= SECP256K1_ORDER: if s >= ORDER:
return False return False
e = int.from_bytes(TaggedHash("BIP0340/challenge", sig[0:32] + key + msg), 'big') % SECP256K1_ORDER e = int.from_bytes(TaggedHash("BIP0340/challenge", sig[0:32] + key + msg), 'big') % ORDER
R = SECP256K1.mul([(SECP256K1_G, s), (P, SECP256K1_ORDER - e)]) R = secp256k1.GE.mul((s, secp256k1.G), (-e, P))
if not SECP256K1.has_even_y(R): if R.infinity or not R.y.is_even():
return False return False
if ((r * R[2] * R[2]) % SECP256K1_FIELD_SIZE) != R[0]: if r != R.x:
return False return False
return True return True
@ -499,23 +275,24 @@ def sign_schnorr(key, msg, aux=None, flip_p=False, flip_r=False):
assert len(aux) == 32 assert len(aux) == 32
sec = int.from_bytes(key, 'big') sec = int.from_bytes(key, 'big')
if sec == 0 or sec >= SECP256K1_ORDER: if sec == 0 or sec >= ORDER:
return None return None
P = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, sec)])) P = sec * secp256k1.G
if SECP256K1.has_even_y(P) == flip_p: if P.y.is_even() == flip_p:
sec = SECP256K1_ORDER - sec sec = ORDER - sec
t = (sec ^ int.from_bytes(TaggedHash("BIP0340/aux", aux), 'big')).to_bytes(32, 'big') t = (sec ^ int.from_bytes(TaggedHash("BIP0340/aux", aux), 'big')).to_bytes(32, 'big')
kp = int.from_bytes(TaggedHash("BIP0340/nonce", t + P[0].to_bytes(32, 'big') + msg), 'big') % SECP256K1_ORDER kp = int.from_bytes(TaggedHash("BIP0340/nonce", t + P.to_bytes_xonly() + msg), 'big') % ORDER
assert kp != 0 assert kp != 0
R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, kp)])) R = kp * secp256k1.G
k = kp if SECP256K1.has_even_y(R) != flip_r else SECP256K1_ORDER - kp k = kp if R.y.is_even() != flip_r else ORDER - kp
e = int.from_bytes(TaggedHash("BIP0340/challenge", R[0].to_bytes(32, 'big') + P[0].to_bytes(32, 'big') + msg), 'big') % SECP256K1_ORDER e = int.from_bytes(TaggedHash("BIP0340/challenge", R.to_bytes_xonly() + P.to_bytes_xonly() + msg), 'big') % ORDER
return R[0].to_bytes(32, 'big') + ((k + e * sec) % SECP256K1_ORDER).to_bytes(32, 'big') return R.to_bytes_xonly() + ((k + e * sec) % ORDER).to_bytes(32, 'big')
class TestFrameworkKey(unittest.TestCase): class TestFrameworkKey(unittest.TestCase):
def test_schnorr(self): def test_schnorr(self):
"""Test the Python Schnorr implementation.""" """Test the Python Schnorr implementation."""
byte_arrays = [generate_privkey() for _ in range(3)] + [v.to_bytes(32, 'big') for v in [0, SECP256K1_ORDER - 1, SECP256K1_ORDER, 2**256 - 1]] byte_arrays = [generate_privkey() for _ in range(3)] + [v.to_bytes(32, 'big') for v in [0, ORDER - 1, ORDER, 2**256 - 1]]
keys = {} keys = {}
for privkey in byte_arrays: # build array of key/pubkey pairs for privkey in byte_arrays: # build array of key/pubkey pairs
pubkey, _ = compute_xonly_pubkey(privkey) pubkey, _ = compute_xonly_pubkey(privkey)

View file

@ -0,0 +1,314 @@
# Copyright (c) 2022-2023 The Bitcoin Core developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Test-only implementation of low-level secp256k1 field and group arithmetic
It is designed for ease of understanding, not performance.
WARNING: This code is slow and trivially vulnerable to side channel attacks. Do not use for
anything but tests.
Exports:
* FE: class for secp256k1 field elements
* GE: class for secp256k1 group elements
* G: the secp256k1 generator point
"""
class FE:
"""Objects of this class represent elements of the field GF(2**256 - 2**32 - 977).
They are represented internally in numerator / denominator form, in order to delay inversions.
"""
# The size of the field (also its modulus and characteristic).
SIZE = 2**256 - 2**32 - 977
def __init__(self, a=0, b=1):
"""Initialize a field element a/b; both a and b can be ints or field elements."""
if isinstance(a, FE):
num = a._num
den = a._den
else:
num = a % FE.SIZE
den = 1
if isinstance(b, FE):
den = (den * b._num) % FE.SIZE
num = (num * b._den) % FE.SIZE
else:
den = (den * b) % FE.SIZE
assert den != 0
if num == 0:
den = 1
self._num = num
self._den = den
def __add__(self, a):
"""Compute the sum of two field elements (second may be int)."""
if isinstance(a, FE):
return FE(self._num * a._den + self._den * a._num, self._den * a._den)
return FE(self._num + self._den * a, self._den)
def __radd__(self, a):
"""Compute the sum of an integer and a field element."""
return FE(a) + self
def __sub__(self, a):
"""Compute the difference of two field elements (second may be int)."""
if isinstance(a, FE):
return FE(self._num * a._den - self._den * a._num, self._den * a._den)
return FE(self._num - self._den * a, self._den)
def __rsub__(self, a):
"""Compute the difference of an integer and a field element."""
return FE(a) - self
def __mul__(self, a):
"""Compute the product of two field elements (second may be int)."""
if isinstance(a, FE):
return FE(self._num * a._num, self._den * a._den)
return FE(self._num * a, self._den)
def __rmul__(self, a):
"""Compute the product of an integer with a field element."""
return FE(a) * self
def __truediv__(self, a):
"""Compute the ratio of two field elements (second may be int)."""
return FE(self, a)
def __pow__(self, a):
"""Raise a field element to an integer power."""
return FE(pow(self._num, a, FE.SIZE), pow(self._den, a, FE.SIZE))
def __neg__(self):
"""Negate a field element."""
return FE(-self._num, self._den)
def __int__(self):
"""Convert a field element to an integer in range 0..p-1. The result is cached."""
if self._den != 1:
self._num = (self._num * pow(self._den, -1, FE.SIZE)) % FE.SIZE
self._den = 1
return self._num
def sqrt(self):
"""Compute the square root of a field element if it exists (None otherwise).
Due to the fact that our modulus is of the form (p % 4) == 3, the Tonelli-Shanks
algorithm (https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm) is simply
raising the argument to the power (p + 1) / 4.
To see why: (p-1) % 2 = 0, so 2 divides the order of the multiplicative group,
and thus only half of the non-zero field elements are squares. An element a is
a (nonzero) square when Euler's criterion, a^((p-1)/2) = 1 (mod p), holds. We're
looking for x such that x^2 = a (mod p). Given a^((p-1)/2) = 1, that is equivalent
to x^2 = a^(1 + (p-1)/2) mod p. As (1 + (p-1)/2) is even, this is equivalent to
x = a^((1 + (p-1)/2)/2) mod p, or x = a^((p+1)/4) mod p."""
v = int(self)
s = pow(v, (FE.SIZE + 1) // 4, FE.SIZE)
if s**2 % FE.SIZE == v:
return FE(s)
return None
def is_square(self):
"""Determine if this field element has a square root."""
# A more efficient algorithm is possible here (Jacobi symbol).
return self.sqrt() is not None
def is_even(self):
"""Determine whether this field element, represented as integer in 0..p-1, is even."""
return int(self) & 1 == 0
def __eq__(self, a):
"""Check whether two field elements are equal (second may be an int)."""
if isinstance(a, FE):
return (self._num * a._den - self._den * a._num) % FE.SIZE == 0
return (self._num - self._den * a) % FE.SIZE == 0
def to_bytes(self):
"""Convert a field element to a 32-byte array (BE byte order)."""
return int(self).to_bytes(32, 'big')
@staticmethod
def from_bytes(b):
"""Convert a 32-byte array to a field element (BE byte order, no overflow allowed)."""
v = int.from_bytes(b, 'big')
if v >= FE.SIZE:
return None
return FE(v)
def __str__(self):
"""Convert this field element to a 64 character hex string."""
return f"{int(self):064x}"
def __repr__(self):
"""Get a string representation of this field element."""
return f"FE(0x{int(self):x})"
class GE:
"""Objects of this class represent secp256k1 group elements (curve points or infinity)
Normal points on the curve have fields:
* x: the x coordinate (a field element)
* y: the y coordinate (a field element, satisfying y^2 = x^3 + 7)
* infinity: False
The point at infinity has field:
* infinity: True
"""
# Order of the group (number of points on the curve, plus 1 for infinity)
ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
# Number of valid distinct x coordinates on the curve.
ORDER_HALF = ORDER // 2
def __init__(self, x=None, y=None):
"""Initialize a group element with specified x and y coordinates, or infinity."""
if x is None:
# Initialize as infinity.
assert y is None
self.infinity = True
else:
# Initialize as point on the curve (and check that it is).
fx = FE(x)
fy = FE(y)
assert fy**2 == fx**3 + 7
self.infinity = False
self.x = fx
self.y = fy
def __add__(self, a):
"""Add two group elements together."""
# Deal with infinity: a + infinity == infinity + a == a.
if self.infinity:
return a
if a.infinity:
return self
if self.x == a.x:
if self.y != a.y:
# A point added to its own negation is infinity.
assert self.y + a.y == 0
return GE()
else:
# For identical inputs, use the tangent (doubling formula).
lam = (3 * self.x**2) / (2 * self.y)
else:
# For distinct inputs, use the line through both points (adding formula).
lam = (self.y - a.y) / (self.x - a.x)
# Determine point opposite to the intersection of that line with the curve.
x = lam**2 - (self.x + a.x)
y = lam * (self.x - x) - self.y
return GE(x, y)
@staticmethod
def mul(*aps):
"""Compute a (batch) scalar group element multiplication.
GE.mul((a1, p1), (a2, p2), (a3, p3)) is identical to a1*p1 + a2*p2 + a3*p3,
but more efficient."""
# Reduce all the scalars modulo order first (so we can deal with negatives etc).
naps = [(a % GE.ORDER, p) for a, p in aps]
# Start with point at infinity.
r = GE()
# Iterate over all bit positions, from high to low.
for i in range(255, -1, -1):
# Double what we have so far.
r = r + r
# Add then add the points for which the corresponding scalar bit is set.
for (a, p) in naps:
if (a >> i) & 1:
r += p
return r
def __rmul__(self, a):
"""Multiply an integer with a group element."""
return GE.mul((a, self))
def __neg__(self):
"""Compute the negation of a group element."""
if self.infinity:
return self
return GE(self.x, -self.y)
def to_bytes_compressed(self):
"""Convert a non-infinite group element to 33-byte compressed encoding."""
assert not self.infinity
return bytes([3 - self.y.is_even()]) + self.x.to_bytes()
def to_bytes_uncompressed(self):
"""Convert a non-infinite group element to 65-byte uncompressed encoding."""
assert not self.infinity
return b'\x04' + self.x.to_bytes() + self.y.to_bytes()
def to_bytes_xonly(self):
"""Convert (the x coordinate of) a non-infinite group element to 32-byte xonly encoding."""
assert not self.infinity
return self.x.to_bytes()
@staticmethod
def lift_x(x):
"""Return group element with specified field element as x coordinate (and even y)."""
y = (FE(x)**3 + 7).sqrt()
if y is None:
return None
if not y.is_even():
y = -y
return GE(x, y)
@staticmethod
def from_bytes(b):
"""Convert a compressed or uncompressed encoding to a group element."""
assert len(b) in (33, 65)
if len(b) == 33:
if b[0] != 2 and b[0] != 3:
return None
x = FE.from_bytes(b[1:])
if x is None:
return None
r = GE.lift_x(x)
if r is None:
return None
if b[0] == 3:
r = -r
return r
else:
if b[0] != 4:
return None
x = FE.from_bytes(b[1:33])
y = FE.from_bytes(b[33:])
if y**2 != x**3 + 7:
return None
return GE(x, y)
@staticmethod
def from_bytes_xonly(b):
"""Convert a point given in xonly encoding to a group element."""
assert len(b) == 32
x = FE.from_bytes(b)
if x is None:
return None
return GE.lift_x(x)
@staticmethod
def is_valid_x(x):
"""Determine whether the provided field element is a valid X coordinate."""
return (FE(x)**3 + 7).is_square()
def __str__(self):
"""Convert this group element to a string."""
if self.infinity:
return "(inf)"
return f"({self.x},{self.y})"
def __repr__(self):
"""Get a string representation for this group element."""
if self.infinity:
return "GE()"
return f"GE(0x{int(self.x):x},0x{int(self.y):x})"
# The secp256k1 generator point
G = GE.lift_x(0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798)