mbedtls hardware RSA: Fix "mbedtls_mpi_exp_mod" hardware calculations

This commit is contained in:
Dong Heng 2016-11-16 20:37:51 +08:00 committed by Angus Gratton
parent cf8c9770a0
commit 6b687b43f4

View File

@ -53,12 +53,11 @@ void mbedtls_mpi_printf(const char *name, const mbedtls_mpi *X)
static char buf[1024];
size_t n;
memset(buf, 0, sizeof(buf));
printf("%s = 0x", name);
mbedtls_mpi_write_string(X, 16, buf, sizeof(buf)-1, &n);
if(n) {
puts(buf);
ESP_LOGI(TAG, "%s = 0x%s", name, buf);
} else {
puts("TOOLONG");
ESP_LOGI(TAG, "TOOLONG");
}
}
@ -278,6 +277,7 @@ int esp_mpi_mul_mpi_mod(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
/*
* Sliding-window exponentiation: Z = X^Y mod M (HAC 14.85)
*/
#if 0
int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M, mbedtls_mpi* _RR )
{
int ret;
@ -336,6 +336,155 @@ int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi
return ret;
}
#else
/**
* There is a need for the value of integer N' such that B^-1(B-1)-N^-1N'=1,
* where B^-1(B-1) mod N=1. Actually, only the least significant part of
* N' is needed, hence the definition N0'=N' mod b. We reproduce below the
* simple algorithm from an article by Dusse and Kaliski to efficiently
* find N0' from N0 and b
*/
static mbedtls_mpi_uint modular_inverse(const mbedtls_mpi *M)
{
int i;
uint64_t t = 1;
uint64_t two_2_i_minus_1 = 2; /* 2^(i-1) */
uint64_t two_2_i = 4; /* 2^i */
uint64_t N = M->p[0];
for (i = 2; i <= 32; i++) {
if ((mbedtls_mpi_uint) N * t % two_2_i >= two_2_i_minus_1) {
t += two_2_i_minus_1;
}
two_2_i_minus_1 <<= 1;
two_2_i <<= 1;
}
return (mbedtls_mpi_uint)(UINT32_MAX - t + 1);
}
/*
 * Prepare the Montgomery parameters needed by the hardware modular
 * exponentiation.
 *
 *   R       = 2^num_bits, where num_bits = num_words * 32
 *   r       = R^2 mod M
 *   *Mi     = modular_inverse(M)
 *
 * If the caller supplies a cached value (_RR != NULL and _RR->p != NULL)
 * it is reused; otherwise r is computed and, when _RR is non-NULL, stored
 * back into it.
 *
 * Note: both directions use a shallow struct copy (memcpy), so after this
 * call r and _RR may share the same limb buffer. bignum_param_deinit()
 * must be used to release r so that the shared buffer is not freed twice.
 *
 * @param M          Modulus.
 * @param _RR        Optional cache for R^2 mod M (may be NULL).
 * @param r          Output: R^2 mod M. Must already be initialised.
 * @param Mi         Output: value returned by modular_inverse(M).
 * @param num_words  Operand length in 32-bit words.
 *
 * @return 0 on success, otherwise the error code of the failing
 *         mbedtls_mpi_* call.
 */
static int bignum_param_init(const mbedtls_mpi *M, mbedtls_mpi *_RR, mbedtls_mpi *r, mbedtls_mpi_uint *Mi, size_t num_words)
{
    int ret = 0;
    size_t num_bits;
    mbedtls_mpi RR;

    /* Initialise unconditionally: the cleanup path frees RR on every exit,
       including the branch that reuses a cached _RR and never touches RR. */
    mbedtls_mpi_init(&RR);

    /* Calculate number of bits */
    num_bits = num_words * 32;
    ESP_LOGI(TAG, "num_bits = %d\n", (int)num_bits);

    /*
     * R = b^n where b = 2^32, n=num_words,
     * R = 2^N (where N=num_bits)
     * RR(R^2) = 2^(2*N) (where N=num_bits)
     *
     * r = RR(R^2) mod M
     *
     * Get the RR(RR == r) value from up level if RR and RR->p is not NULL
     */
    ESP_LOGI(TAG, "r = RR(R^2) mod M\n");
    if (_RR == NULL || _RR->p == NULL) {
        ESP_LOGI(TAG, "RR(R^2) = 2^(2*N) (where N=num_bits)\n");
        MBEDTLS_MPI_CHK(mbedtls_mpi_set_bit(&RR, num_bits * 2, 1));
        mbedtls_mpi_printf("RR", &RR);
        MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(r, &RR, M));

        if (_RR != NULL) {
            /* Shallow copy: _RR now refers to the same limb buffer as r */
            memcpy(_RR, r, sizeof( mbedtls_mpi ) );
        }
    } else {
        /* Shallow copy: r borrows the limb buffer held by _RR */
        memcpy(r, _RR, sizeof( mbedtls_mpi ) );
    }
    mbedtls_mpi_printf("r", r);

    *Mi = modular_inverse(M);

cleanup:
    mbedtls_mpi_free(&RR);
    return ret;
}
/*
 * Release the working value r created by bignum_param_init().
 *
 * r is left untouched when the caller holds a cache (_RR) with an
 * allocated limb array; in every other case r is freed here.
 * NOTE(review): this relies on bignum_param_init() having shallow-copied
 * between r and _RR, so that a populated _RR shares r's buffer - confirm
 * if either function changes.
 */
static void bignum_param_deinit(mbedtls_mpi *_RR, mbedtls_mpi *r)
{
    if (_RR != NULL && _RR->p != NULL) {
        return;
    }

    mbedtls_mpi_free(r);
}
/*
 * Modular exponentiation on the RSA hardware block: Z = X^Y mod M
 *
 * All operands are loaded into the accelerator at a common length
 * (the largest hardware word count among Z, X, Y and M), together with
 * r = R^2 mod M and the value returned by modular_inverse(M).
 *
 * @param Z    Result.
 * @param X    Base.
 * @param Y    Exponent; must not be negative.
 * @param M    Modulus; must be positive and odd.
 * @param _RR  Optional cache for R^2 mod M, reused across calls with the
 *             same modulus (may be NULL).
 *
 * @return 0 on success,
 *         MBEDTLS_ERR_MPI_BAD_INPUT_DATA if M is not positive and odd or
 *         Y is negative,
 *         MBEDTLS_ERR_MPI_NOT_ACCEPTABLE if the operation would exceed
 *         4096 bits,
 *         or the error code from parameter preparation / result read-back.
 */
int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M, mbedtls_mpi* _RR )
{
    int ret = 0;
    size_t z_words = hardware_words_needed(Z);
    size_t x_words = hardware_words_needed(X);
    size_t y_words = hardware_words_needed(Y);
    size_t m_words = hardware_words_needed(M);
    size_t num_words;
    mbedtls_mpi r;
    mbedtls_mpi_uint Mi = 0;

    /* The modulus must be positive and odd: modular_inverse() reads
       M->p[0] and only produces a valid inverse for an odd low limb. */
    if (mbedtls_mpi_cmp_int(M, 0) <= 0 || (M->p[0] & 1) == 0) {
        return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
    }

    /* Negative exponents are not supported */
    if (mbedtls_mpi_cmp_int(Y, 0) < 0) {
        return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
    }

    /* "all numbers must be the same length", so choose longest number
       as cardinal length of operation...
    */
    num_words = z_words;
    if (x_words > num_words) {
        num_words = x_words;
    }
    if (y_words > num_words) {
        num_words = y_words;
    }
    if (m_words > num_words) {
        num_words = m_words;
    }
    ESP_LOGI(TAG, "num_words = %d # %d, %d, %d\n", (int)num_words, (int)x_words, (int)y_words, (int)m_words);

    if (num_words * 32 > 4096) {
        return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
    }

    mbedtls_mpi_init(&r);
    ret = bignum_param_init(M, _RR, &r, &Mi, num_words);
    if (ret != 0) {
        /* Failure can only happen while r was being computed locally,
           before it was shared with _RR, so it is safe to free here. */
        mbedtls_mpi_free(&r);
        return ret;
    }

    mbedtls_mpi_printf("X",X);
    mbedtls_mpi_printf("Y",Y);

    esp_mpi_acquire_hardware();

    /* "mode" register loaded with number of 512-bit blocks, minus 1 */
    REG_WRITE(RSA_MODEXP_MODE_REG, (num_words / 16) - 1);

    /* Load X, Y, M and r (R^2 mod M) into their memory blocks, then the
       32-bit M-prime value into its register */
    mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
    mpi_to_mem_block(RSA_MEM_Y_BLOCK_BASE, Y, num_words);
    mpi_to_mem_block(RSA_MEM_M_BLOCK_BASE, M, num_words);
    mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, &r, num_words);
    REG_WRITE(RSA_M_DASH_REG, Mi);

    execute_op(RSA_START_MODEXP_REG);

    ret = mem_block_to_mpi(Z, RSA_MEM_Z_BLOCK_BASE, num_words);

    esp_mpi_release_hardware();

    mbedtls_mpi_printf("Z",Z);
    ESP_LOGI(TAG, "print (Z == (X ** Y) %% M)\n");

    bignum_param_deinit(_RR, &r);

    return ret;
}
#endif
#endif /* MBEDTLS_MPI_EXP_MOD_ALT */
@ -385,7 +534,7 @@ static int modular_op_prepare(const mbedtls_mpi *X, const mbedtls_mpi *M, size_t
/* Block of debugging data, output suitable to paste into Python
TODO remove
*/
mbedtls_mpi_printf("R", &RR);
mbedtls_mpi_printf("RR", &RR);
mbedtls_mpi_printf("M", M);
mbedtls_mpi_printf("Rinv", &Rinv);
mbedtls_mpi_printf("Mprime", &Mprime);
@ -463,6 +612,7 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
multiplication doesn't have the same restriction, so result is simply the
number of bits in X plus number of bits in Y.)
*/
//ESP_LOGE(TAG, "INFO: %d bit result (%d bits * %d bits)\n", words_z * 32, mbedtls_mpi_bitlen(X), mbedtls_mpi_bitlen(Y));
if (words_mult * 32 > 2048) {
/* Calculate new length of Z */
words_z = words_x + words_y;