diff --git a/src/ecdsa_impl.h b/src/ecdsa_impl.h index 9dd0548493d..2792e3388ad 100644 --- a/src/ecdsa_impl.h +++ b/src/ecdsa_impl.h @@ -25,7 +25,7 @@ int static secp256k1_ecdsa_pubkey_parse(secp256k1_ge_t *elem, const unsigned cha if (size == 33 && (pub[0] == 0x02 || pub[0] == 0x03)) { secp256k1_fe_t x; secp256k1_fe_set_b32(&x, pub+1); - secp256k1_ge_set_xo(elem, &x, pub[0] == 0x03); + return secp256k1_ge_set_xo(elem, &x, pub[0] == 0x03); } else if (size == 65 && (pub[0] == 0x04 || pub[0] == 0x06 || pub[0] == 0x07)) { secp256k1_fe_t x, y; secp256k1_fe_set_b32(&x, pub+1); @@ -33,10 +33,10 @@ int static secp256k1_ecdsa_pubkey_parse(secp256k1_ge_t *elem, const unsigned cha secp256k1_ge_set_xy(elem, &x, &y); if ((pub[0] == 0x06 || pub[0] == 0x07) && secp256k1_fe_is_odd(&y) != (pub[0] == 0x07)) return 0; + return secp256k1_ge_is_valid(elem); } else { return 0; } - return secp256k1_ge_is_valid(elem); } int static secp256k1_ecdsa_sig_parse(secp256k1_ecdsa_sig_t *r, const unsigned char *sig, int size) { @@ -134,8 +134,7 @@ int static secp256k1_ecdsa_sig_recover(const secp256k1_ecdsa_sig_t *sig, secp256 secp256k1_fe_t fx; secp256k1_fe_set_b32(&fx, brx); secp256k1_ge_t x; - secp256k1_ge_set_xo(&x, &fx, recid & 1); - if (!secp256k1_ge_is_valid(&x)) + if (!secp256k1_ge_set_xo(&x, &fx, recid & 1)) return 0; secp256k1_gej_t xj; secp256k1_gej_set_ge(&xj, &x); diff --git a/src/field.h b/src/field.h index f31bbe14bab..f5fa68aa189 100644 --- a/src/field.h +++ b/src/field.h @@ -82,9 +82,10 @@ void static secp256k1_fe_mul(secp256k1_fe_t *r, const secp256k1_fe_t *a, const s * The output magnitude is 1 (but not guaranteed to be normalized). */ void static secp256k1_fe_sqr(secp256k1_fe_t *r, const secp256k1_fe_t *a); -/** Sets a field element to be the (modular) square root of another. Requires the inputs' magnitude to - * be at most 8. The output magnitude is 1 (but not guaranteed to be normalized). */ -void static secp256k1_fe_sqrt(secp256k1_fe_t *r, const secp256k1_fe_t *a); +/** Sets a field element to be the (modular) square root (if any exist) of another. Requires the + * input's magnitude to be at most 8. The output magnitude is 1 (but not guaranteed to be + * normalized). Return value indicates whether a square root was found. */ +int static secp256k1_fe_sqrt(secp256k1_fe_t *r, const secp256k1_fe_t *a); /** Sets a field element to be the (modular) inverse of another. Requires the input's magnitude to be * at most 8. The output magnitude is 1 (but not guaranteed to be normalized). */ diff --git a/src/field_impl.h b/src/field_impl.h index 11fdfd52a82..748c5ab9bd3 100644 --- a/src/field_impl.h +++ b/src/field_impl.h @@ -62,7 +62,7 @@ void static secp256k1_fe_set_hex(secp256k1_fe_t *r, const char *a, int alen) { secp256k1_fe_set_b32(r, tmp); } -void static secp256k1_fe_sqrt(secp256k1_fe_t *r, const secp256k1_fe_t *a) { +int static secp256k1_fe_sqrt(secp256k1_fe_t *r, const secp256k1_fe_t *a) { // The binary representation of (p + 1)/4 has 3 blocks of 1s, with lengths in // { 2, 22, 223 }. Use an addition chain to calculate 2^n - 1 for each block: @@ -121,6 +121,14 @@ void static secp256k1_fe_sqrt(secp256k1_fe_t *r, const secp256k1_fe_t *a) { secp256k1_fe_mul(&t1, &t1, &x2); secp256k1_fe_sqr(&t1, &t1); secp256k1_fe_sqr(r, &t1); + + // Check that a square root was actually calculated + + secp256k1_fe_sqr(&t1, r); + secp256k1_fe_negate(&t1, &t1, 1); + secp256k1_fe_add(&t1, a); + secp256k1_fe_normalize(&t1); + return secp256k1_fe_is_zero(&t1); } void static secp256k1_fe_inv(secp256k1_fe_t *r, const secp256k1_fe_t *a) { diff --git a/src/group.h b/src/group.h index fc02a424995..738daff2925 100644 --- a/src/group.h +++ b/src/group.h @@ -48,9 +48,9 @@ void static secp256k1_ge_set_infinity(secp256k1_ge_t *r); /** Set a group element equal to the point with given X and Y coordinates */ void static secp256k1_ge_set_xy(secp256k1_ge_t *r, const secp256k1_fe_t *x, const secp256k1_fe_t *y); -/** Set a group element (jacobian) equal to the point with given X coordinate, and given oddness for Y. - The result is not guaranteed to be valid. */ -void static secp256k1_ge_set_xo(secp256k1_ge_t *r, const secp256k1_fe_t *x, int odd); +/** Set a group element (affine) equal to the point with the given X coordinate, and given oddness + * for Y. Return value indicates whether the result is valid. */ +int static secp256k1_ge_set_xo(secp256k1_ge_t *r, const secp256k1_fe_t *x, int odd); /** Check whether a group element is the point at infinity. */ int static secp256k1_ge_is_infinity(const secp256k1_ge_t *a); @@ -91,7 +91,7 @@ void static secp256k1_gej_double(secp256k1_gej_t *r, const secp256k1_gej_t *a); /** Set r equal to the sum of a and b. */ void static secp256k1_gej_add(secp256k1_gej_t *r, const secp256k1_gej_t *a, const secp256k1_gej_t *b); -/** Set r equal to the sum of a and b (with b given in jacobian coordinates). This is more efficient +/** Set r equal to the sum of a and b (with b given in affine coordinates). This is more efficient than secp256k1_gej_add. */ void static secp256k1_gej_add_ge(secp256k1_gej_t *r, const secp256k1_gej_t *a, const secp256k1_ge_t *b); diff --git a/src/group_impl.h b/src/group_impl.h index 927d25a1dcd..d9dace7386f 100644 --- a/src/group_impl.h +++ b/src/group_impl.h @@ -77,17 +77,19 @@ void static secp256k1_gej_set_xy(secp256k1_gej_t *r, const secp256k1_fe_t *x, co secp256k1_fe_set_int(&r->z, 1); } -void static secp256k1_ge_set_xo(secp256k1_ge_t *r, const secp256k1_fe_t *x, int odd) { +int static secp256k1_ge_set_xo(secp256k1_ge_t *r, const secp256k1_fe_t *x, int odd) { r->x = *x; secp256k1_fe_t x2; secp256k1_fe_sqr(&x2, x); secp256k1_fe_t x3; secp256k1_fe_mul(&x3, x, &x2); r->infinity = 0; secp256k1_fe_t c; secp256k1_fe_set_int(&c, 7); secp256k1_fe_add(&c, &x3); - secp256k1_fe_sqrt(&r->y, &c); + if (!secp256k1_fe_sqrt(&r->y, &c)) + return 0; secp256k1_fe_normalize(&r->y); if (secp256k1_fe_is_odd(&r->y) != odd) secp256k1_fe_negate(&r->y, &r->y, 1); + return 1; } void static secp256k1_gej_set_ge(secp256k1_gej_t *r, const secp256k1_ge_t *a) { diff --git a/src/tests.c b/src/tests.c index 30baa5ea1de..819aceb4649 100644 --- a/src/tests.c +++ b/src/tests.c @@ -209,6 +209,54 @@ void run_num_smalltests() { run_num_int(); } +/***** FIELD TESTS *****/ + +void random_fe(secp256k1_fe_t *x) { + unsigned char bin[32]; + secp256k1_rand256(bin); + secp256k1_fe_set_b32(x, bin); +} + +void random_fe_non_square(secp256k1_fe_t *ns) { + secp256k1_fe_t r; + int tries = 100; + while (--tries >= 0) { + random_fe(ns); + if (!secp256k1_fe_sqrt(&r, ns)) + break; + } + // 2^-100 probability of spurious failure here + assert(tries >= 0); +} + +void test_sqrt(const secp256k1_fe_t *a, const secp256k1_fe_t *k) { + secp256k1_fe_t r1, r2; + int v = secp256k1_fe_sqrt(&r1, a); + assert((v == 0) == (k == NULL)); + + if (k != NULL) { + // Check that the returned root is +/- the given known answer + secp256k1_fe_negate(&r2, &r1, 1); + secp256k1_fe_add(&r1, k); secp256k1_fe_add(&r2, k); + secp256k1_fe_normalize(&r1); secp256k1_fe_normalize(&r2); + assert(secp256k1_fe_is_zero(&r1) || secp256k1_fe_is_zero(&r2)); + } +} + +void run_sqrt() { + secp256k1_fe_t ns, x, s, t; + random_fe_non_square(&ns); + for (int i=0; i<10*count; i++) { + random_fe(&x); + secp256k1_fe_sqr(&s, &x); + test_sqrt(&s, &x); + secp256k1_fe_mul(&t, &s, &ns); + test_sqrt(&t, NULL); + } +} + +/***** ECMULT TESTS *****/ + void run_ecmult_chain() { // random starting point A (on the curve) secp256k1_fe_t ax; secp256k1_fe_set_hex(&ax, "8b30bbe9ae2a990696b22f670709dff3727fd8bc04d3362c6c7bf458e2846004", 64); @@ -275,10 +323,7 @@ void run_ecmult_chain() { } void test_point_times_order(const secp256k1_gej_t *point) { - // either the point is not on the curve, or multiplying it by the order results in O - if (!secp256k1_gej_is_valid(point)) - return; - + // multiplying a point by the order results in O const secp256k1_num_t *order = &secp256k1_ge_consts->order; secp256k1_num_t zero; secp256k1_num_init(&zero); @@ -292,9 +337,14 @@ void test_point_times_order(const secp256k1_gej_t *point) { void run_point_times_order() { secp256k1_fe_t x; secp256k1_fe_set_hex(&x, "02", 2); for (int i=0; i<500; i++) { - secp256k1_ge_t p; secp256k1_ge_set_xo(&p, &x, 1); - secp256k1_gej_t j; secp256k1_gej_set_ge(&j, &p); - test_point_times_order(&j); + secp256k1_ge_t p; + if (secp256k1_ge_set_xo(&p, &x, 1)) { + assert(secp256k1_ge_is_valid(&p)); + secp256k1_gej_t j; + secp256k1_gej_set_ge(&j, &p); + assert(secp256k1_gej_is_valid(&j)); + test_point_times_order(&j); + } secp256k1_fe_sqr(&x, &x); } char c[65]; int cl=65; @@ -451,6 +501,9 @@ int main(int argc, char **argv) { // num tests run_num_smalltests(); + // field tests + run_sqrt(); + // ecmult tests run_wnaf(); run_point_times_order();