hwcrypto bignum: Implement multiplication modulo

Fixes case where hardware bignum multiplication fails due to either
operand >2048 bits.
This commit is contained in:
Angus Gratton 2016-09-20 21:02:07 +10:00
parent 1a6dd44d03
commit 6b3bc4d8c5
2 changed files with 223 additions and 26 deletions

View File

@ -52,6 +52,140 @@ static void esp_mpi_release_hardware( void )
_lock_release(&mpi_lock); _lock_release(&mpi_lock);
} }
/* Given a & b, determine u & v such that
gcd(a,b) = d = au + bv
Underlying algorithm comes from:
http://www.ucl.ac.uk/~ucahcjm/combopt/ext_gcd_python_programs.pdf
http://www.hackersdelight.org/hdcodetxt/mont64.c.txt
*/
static void extended_binary_gcd(const mbedtls_mpi *a, const mbedtls_mpi *b,
mbedtls_mpi *u, mbedtls_mpi *v)
{
mbedtls_mpi ta, tb;
mbedtls_mpi_init(&ta);
mbedtls_mpi_copy(&ta, a);
mbedtls_mpi_init(&tb);
mbedtls_mpi_copy(&tb, b);
mbedtls_mpi_lset(u, 1);
mbedtls_mpi_lset(v, 0);
/* Loop invariant:
ta = u*2*a - v*b. */
while (mbedtls_mpi_cmp_int(&ta, 0) != 0) {
mbedtls_mpi_shift_r(&ta, 1);
if (mbedtls_mpi_get_bit(u, 0) == 0) {
// Remove common factor of 2 in u & v
mbedtls_mpi_shift_r(u, 1);
mbedtls_mpi_shift_r(v, 1);
}
else {
/* u = (u + b) >> 1 */
mbedtls_mpi_add_mpi(u, u, b);
mbedtls_mpi_shift_r(u, 1);
/* v = (v >> 1) + a */
mbedtls_mpi_shift_r(v, 1);
mbedtls_mpi_add_mpi(v, v, a);
}
}
mbedtls_mpi_free(&ta);
mbedtls_mpi_free(&tb);
/* u = u * 2, so 1 = u*a - v*b */
mbedtls_mpi_shift_l(u, 1);
}
/* inner part of MPI modular multiply, after Rinv & Mprime are calculated */
static int mpi_mul_mpi_mod_inner(mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi *M, mbedtls_mpi *Rinv, uint32_t Mprime, size_t num_words)
{
int ret;
mbedtls_mpi TA, TB;
size_t num_bits = num_words * 32;
mbedtls_mpi_grow(Rinv, num_words);
/* TODO: fill memory blocks directly so this isn't needed */
mbedtls_mpi_init(&TA);
mbedtls_mpi_copy(&TA, A);
mbedtls_mpi_grow(&TA, num_words);
A = &TA;
mbedtls_mpi_init(&TB);
mbedtls_mpi_copy(&TB, B);
mbedtls_mpi_grow(&TB, num_words);
B = &TB;
esp_mpi_acquire_hardware();
if(ets_bigint_mod_mult_prepare(A->p, B->p, M->p, Mprime,
Rinv->p, num_bits, false)) {
mbedtls_mpi_grow(X, num_words);
ets_bigint_wait_finish();
if(ets_bigint_mod_mult_getz(M->p, X->p, num_bits)) {
X->s = A->s * B->s;
ret = 0;
} else {
printf("ets_bigint_mod_mult_getz failed\n");
ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
}
} else {
printf("ets_bigint_mod_mult_prepare failed\n");
ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
}
esp_mpi_release_hardware();
/* unclear why this is necessary, but the result seems
to come back rotated 32 bits to the right... */
uint32_t last_word = X->p[num_words-1];
X->p[num_words-1] = 0;
mbedtls_mpi_shift_l(X, 32);
X->p[0] = last_word;
mbedtls_mpi_free(&TA);
mbedtls_mpi_free(&TB);
return ret;
}
/* X = (A * B) mod M
Not an mbedTLS function
num_bits guaranteed to be a multiple of 512 already.
TODO: ensure M is odd
*/
int esp_mpi_mul_mpi_mod(mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi *M, size_t num_bits)
{
int ret = 0;
mbedtls_mpi RR, Rinv, Mprime;
uint32_t Mprime_int;
size_t num_words = num_bits / 32;
/* Rinv & Mprime are calculated via extended binary gcd
algorithm, see references on extended_binary_gcd above.
*/
mbedtls_mpi_init(&Rinv);
mbedtls_mpi_init(&RR);
mbedtls_mpi_set_bit(&RR, num_bits+32, 1);
mbedtls_mpi_init(&Mprime);
extended_binary_gcd(&RR, M, &Rinv, &Mprime);
/* M' is mod 2^32 */
Mprime_int = Mprime.p[0];
ret = mpi_mul_mpi_mod_inner(X, A, B, M, &Rinv, Mprime_int, num_words);
mbedtls_mpi_free(&RR);
mbedtls_mpi_free(&Mprime);
mbedtls_mpi_free(&Rinv);
return ret;
}
/* /*
* Helper for mbedtls_mpi multiplication * Helper for mbedtls_mpi multiplication
* copied/trimmed from mbedtls bignum.c * copied/trimmed from mbedtls bignum.c
@ -223,6 +357,53 @@ static inline size_t hardware_words_needed(const mbedtls_mpi *mpi)
return res; return res;
} }
/* Special-case multiply, where we use hardware montgomery mod
multiplication to solve the case where A or B are >2048 bits so
can't do standard multiplication.
the modulus here is chosen with M=(2^num_bits-1)
to guarantee the output isn't actually modulo anything. This means
we don't need to calculate M' and Rinv, they are predictable
as follows:
M' = 1
Rinv = (1 << (num_bits - 32)
(See RSA Accelerator section in Technical Reference for derivation
of M', Rinv)
*/
static int esp_mpi_mult_mpi_failover_mod_mult(mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B, size_t num_words)
{
mbedtls_mpi M, Rinv;
int ret;
size_t mprime;
size_t num_bits = num_words * 32;
mbedtls_mpi_init(&M);
mbedtls_mpi_init(&Rinv);
/* TODO: it may be faster to just use 4096-bit arithmetic every time,
and make these constants rather than runtime derived
derived. */
/* M = (2^num_words)-1 */
mbedtls_mpi_grow(&M, num_words);
for(int i = 0; i < num_words*32; i++) {
mbedtls_mpi_set_bit(&M, i, 1);
}
/* Rinv = (2^num_words-32) */
mbedtls_mpi_grow(&Rinv, num_words);
mbedtls_mpi_set_bit(&Rinv, num_bits - 32, 1);
mprime = 1;
ret = mpi_mul_mpi_mod_inner(X, A, B, &M, &Rinv, mprime, num_words);
mbedtls_mpi_free(&M);
mbedtls_mpi_free(&Rinv);
return ret;
}
int mbedtls_mpi_mul_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B ) int mbedtls_mpi_mul_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
{ {
int ret = -1; int ret = -1;
@ -236,6 +417,8 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi
words_a = hardware_words_needed(A); words_a = hardware_words_needed(A);
words_b = hardware_words_needed(B); words_b = hardware_words_needed(B);
words_mult = (words_a > words_b ? words_a : words_b);
/* Take a copy of A if either X == A OR if A isn't long enough /* Take a copy of A if either X == A OR if A isn't long enough
to hold the number of words needed for hardware. to hold the number of words needed for hardware.
@ -248,47 +431,63 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi
RAM. But we need to reimplement ets_bigint_mult_prepare() in RAM. But we need to reimplement ets_bigint_mult_prepare() in
software for this. software for this.
*/ */
if( X == A || A->n < words_a) { if( X == A || A->n < words_mult) {
MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TA, A ) ); MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TA, A ) );
MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TA, words_a) ); MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TA, words_mult) );
A = &TA; A = &TA;
} }
/* Same for B */ /* Same for B */
if( X == B || B->n < words_b ) { if( X == B || B->n < words_mult ) {
MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TB, B ) ); MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TB, B ) );
MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TB, words_b) ); MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TB, words_mult) );
B = &TB; B = &TB;
} }
/* Result X has to have room for double the larger operand */ /* Result X has to have room for double the larger operand */
words_mult = (words_a > words_b ? words_a : words_b);
words_x = words_mult * 2; words_x = words_mult * 2;
MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, words_x ) ); MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, words_x ) );
/* TODO: check if lset here is necessary, hardware should zero */ /* TODO: check if lset here is necessary, hardware should zero */
MBEDTLS_MPI_CHK( mbedtls_mpi_lset( X, 0 ) ); MBEDTLS_MPI_CHK( mbedtls_mpi_lset( X, 0 ) );
esp_mpi_acquire_hardware(); /* If either operand is over 2048 bits, we can't use the standard hardware multiplier
(it assumes result is double longest operand, and result is max 4096 bits.)
However, we can fail over to mod_mult for up to 4096 bits.
*/
if(words_mult * 32 > 2048) { if(words_mult * 32 > 2048) {
printf("WARNING: %d bit operands (%d bits * %d bits) too large for hardware unit\n", words_mult * 32, mbedtls_mpi_bitlen(A), mbedtls_mpi_bitlen(B)); /* TODO: check if there's an overflow condition if words_a & words_b are both
} the bit lengths of the operands, result could be 1 bit longer
if (ets_bigint_mult_prepare(A->p, B->p, words_mult * 32)) {
ets_bigint_wait_finish();
/* NB: argument to bigint_mult_getz is length of inputs, double this number (words_x) is
copied to output X->p.
*/ */
if (ets_bigint_mult_getz(X->p, words_mult * 32) == true) { if((words_a + words_b) * 32 > 4096) {
ret = 0; printf("ERROR: %d bit operands (%d bits * %d bits) too large for hardware unit\n", words_mult * 32, mbedtls_mpi_bitlen(A), mbedtls_mpi_bitlen(B));
} else { ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
printf("ets_bigint_mult_getz failed\n"); }
} else {
} else{ ret = esp_mpi_mult_mpi_failover_mod_mult(X, A, B, words_a + words_b);
printf("Baseline multiplication failed\n"); }
} }
esp_mpi_release_hardware(); else {
X->s = A->s * B->s; /* normal mpi multiplication */
esp_mpi_acquire_hardware();
if (ets_bigint_mult_prepare(A->p, B->p, words_mult * 32)) {
ets_bigint_wait_finish();
/* NB: argument to bigint_mult_getz is length of inputs, double this number (words_x) is
copied to output X->p.
*/
if (ets_bigint_mult_getz(X->p, words_mult * 32) == true) {
X->s = A->s * B->s;
ret = 0;
} else {
printf("ets_bigint_mult_getz failed\n");
ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
}
} else{
printf("Baseline multiplication failed\n");
ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
}
esp_mpi_release_hardware();
}
cleanup: cleanup:
mbedtls_mpi_free( &TB ); mbedtls_mpi_free( &TA ); mbedtls_mpi_free( &TB ); mbedtls_mpi_free( &TA );

View File

@ -250,10 +250,8 @@
/* The following MPI (bignum) functions have ESP32 hardware support, /* The following MPI (bignum) functions have ESP32 hardware support,
Uncommenting these macros will use the hardware-accelerated Uncommenting these macros will use the hardware-accelerated
implementations. implementations.
Disabled as number of limbs limited by bug. Internal TW#7112.
*/ */
#define MBEDTLS_MPI_EXP_MOD_ALT //#define MBEDTLS_MPI_EXP_MOD_ALT
#define MBEDTLS_MPI_MUL_MPI_ALT #define MBEDTLS_MPI_MUL_MPI_ALT
/** /**