diff --git a/src/scalar.h b/src/scalar.h index 3baacb3721..2469302b80 100644 --- a/src/scalar.h +++ b/src/scalar.h @@ -36,6 +36,9 @@ static void secp256k1_scalar_get_b32(unsigned char *bin, const secp256k1_scalar_ /** Add two scalars together (modulo the group order). */ static void secp256k1_scalar_add(secp256k1_scalar_t *r, const secp256k1_scalar_t *a, const secp256k1_scalar_t *b); +/** Add a power of two to a scalar. The result is not allowed to overflow. */ +static void secp256k1_scalar_add_bit(secp256k1_scalar_t *r, unsigned int bit); + /** Multiply two scalars (modulo the group order). */ static void secp256k1_scalar_mul(secp256k1_scalar_t *r, const secp256k1_scalar_t *a, const secp256k1_scalar_t *b); diff --git a/src/scalar_4x64_impl.h b/src/scalar_4x64_impl.h index f78718234f..02ae318040 100644 --- a/src/scalar_4x64_impl.h +++ b/src/scalar_4x64_impl.h @@ -75,6 +75,22 @@ static void secp256k1_scalar_add(secp256k1_scalar_t *r, const secp256k1_scalar_t secp256k1_scalar_reduce(r, t + secp256k1_scalar_check_overflow(r)); } +static void secp256k1_scalar_add_bit(secp256k1_scalar_t *r, unsigned int bit) { + VERIFY_CHECK(bit < 256); + uint128_t t = (uint128_t)r->d[0] + (((uint64_t)((bit >> 6) == 0)) << bit); + r->d[0] = t & 0xFFFFFFFFFFFFFFFFULL; t >>= 64; + t += (uint128_t)r->d[1] + (((uint64_t)((bit >> 6) == 1)) << (bit & 0x3F)); + r->d[1] = t & 0xFFFFFFFFFFFFFFFFULL; t >>= 64; + t += (uint128_t)r->d[2] + (((uint64_t)((bit >> 6) == 2)) << (bit & 0x3F)); + r->d[2] = t & 0xFFFFFFFFFFFFFFFFULL; t >>= 64; + t += (uint128_t)r->d[3] + (((uint64_t)((bit >> 6) == 3)) << (bit & 0x3F)); + r->d[3] = t & 0xFFFFFFFFFFFFFFFFULL; +#ifdef VERIFY + VERIFY_CHECK((t >> 64) == 0); + VERIFY_CHECK(secp256k1_scalar_check_overflow(r) == 0); +#endif +} + static void secp256k1_scalar_set_b32(secp256k1_scalar_t *r, const unsigned char *b32, int *overflow) { r->d[0] = (uint64_t)b32[31] | (uint64_t)b32[30] << 8 | (uint64_t)b32[29] << 16 | (uint64_t)b32[28] << 24 | (uint64_t)b32[27] << 32 | (uint64_t)b32[26] << 40 | (uint64_t)b32[25] << 48 | (uint64_t)b32[24] << 56; r->d[1] = (uint64_t)b32[23] | (uint64_t)b32[22] << 8 | (uint64_t)b32[21] << 16 | (uint64_t)b32[20] << 24 | (uint64_t)b32[19] << 32 | (uint64_t)b32[18] << 40 | (uint64_t)b32[17] << 48 | (uint64_t)b32[16] << 56; diff --git a/src/scalar_8x32_impl.h b/src/scalar_8x32_impl.h index e58be1365f..cad1065922 100644 --- a/src/scalar_8x32_impl.h +++ b/src/scalar_8x32_impl.h @@ -109,6 +109,30 @@ static void secp256k1_scalar_add(secp256k1_scalar_t *r, const secp256k1_scalar_t secp256k1_scalar_reduce(r, t + secp256k1_scalar_check_overflow(r)); } +static void secp256k1_scalar_add_bit(secp256k1_scalar_t *r, unsigned int bit) { + VERIFY_CHECK(bit < 256); + uint64_t t = (uint64_t)r->d[0] + (((uint32_t)((bit >> 5) == 0)) << bit); + r->d[0] = t & 0xFFFFFFFFULL; t >>= 32; + t += (uint64_t)r->d[1] + (((uint32_t)((bit >> 5) == 1)) << (bit & 0x1F)); + r->d[1] = t & 0xFFFFFFFFULL; t >>= 32; + t += (uint64_t)r->d[2] + (((uint32_t)((bit >> 5) == 2)) << (bit & 0x1F)); + r->d[2] = t & 0xFFFFFFFFULL; t >>= 32; + t += (uint64_t)r->d[3] + (((uint32_t)((bit >> 5) == 3)) << (bit & 0x1F)); + r->d[3] = t & 0xFFFFFFFFULL; t >>= 32; + t += (uint64_t)r->d[4] + (((uint32_t)((bit >> 5) == 4)) << (bit & 0x1F)); + r->d[4] = t & 0xFFFFFFFFULL; t >>= 32; + t += (uint64_t)r->d[5] + (((uint32_t)((bit >> 5) == 5)) << (bit & 0x1F)); + r->d[5] = t & 0xFFFFFFFFULL; t >>= 32; + t += (uint64_t)r->d[6] + (((uint32_t)((bit >> 5) == 6)) << (bit & 0x1F)); + r->d[6] = t & 0xFFFFFFFFULL; t >>= 32; + t += (uint64_t)r->d[7] + (((uint32_t)((bit >> 5) == 7)) << (bit & 0x1F)); + r->d[7] = t & 0xFFFFFFFFULL; +#ifdef VERIFY + VERIFY_CHECK((t >> 32) == 0); + VERIFY_CHECK(secp256k1_scalar_check_overflow(r) == 0); +#endif +} + static void secp256k1_scalar_set_b32(secp256k1_scalar_t *r, const unsigned char *b32, int *overflow) { r->d[0] = (uint32_t)b32[31] | (uint32_t)b32[30] << 8 | (uint32_t)b32[29] << 16 | (uint32_t)b32[28] << 24; r->d[1] = (uint32_t)b32[27] | (uint32_t)b32[26] << 8 | (uint32_t)b32[25] << 16 | (uint32_t)b32[24] << 24; diff --git a/src/tests.c b/src/tests.c index 143f91327c..e96444969c 100644 --- a/src/tests.c +++ b/src/tests.c @@ -382,6 +382,25 @@ void scalar_test(void) { CHECK(secp256k1_scalar_eq(&r1, &r2)); } + { + /* Test add_bit. */ + int bit = secp256k1_rand32() % 256; + secp256k1_scalar_t b; + secp256k1_scalar_clear(&b); + secp256k1_scalar_add_bit(&b, 0); + CHECK(secp256k1_scalar_is_one(&b)); + for (int i = 0; i < bit; i++) { + secp256k1_scalar_add(&b, &b, &b); + } + secp256k1_scalar_t r1 = s1, r2 = s1; + secp256k1_scalar_add(&r1, &r1, &b); + if (!(secp256k1_scalar_get_bits(&s1, 255, 1) == 1 && secp256k1_scalar_get_bits(&r1, 255, 1) == 0)) { + /* No overflow happened. */ + secp256k1_scalar_add_bit(&r2, bit); + CHECK(secp256k1_scalar_eq(&r1, &r2)); + } + } + { /* Test commutativity of mul. */ secp256k1_scalar_t r1, r2;