/* * SPDX-FileCopyrightText: 2020-2021 Espressif Systems (Shanghai) CO LTD * * SPDX-License-Identifier: Apache-2.0 */ #include #include "esp_mbedtls_dynamic_impl.h" int __real_mbedtls_ssl_write(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len); int __real_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len); void __real_mbedtls_ssl_free(mbedtls_ssl_context *ssl); int __real_mbedtls_ssl_session_reset(mbedtls_ssl_context *ssl); int __real_mbedtls_ssl_setup(mbedtls_ssl_context *ssl, const mbedtls_ssl_config *conf); int __real_mbedtls_ssl_send_alert_message(mbedtls_ssl_context *ssl, unsigned char level, unsigned char message); int __real_mbedtls_ssl_close_notify(mbedtls_ssl_context *ssl); int __wrap_mbedtls_ssl_write(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len); int __wrap_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len); void __wrap_mbedtls_ssl_free(mbedtls_ssl_context *ssl); int __wrap_mbedtls_ssl_session_reset(mbedtls_ssl_context *ssl); int __wrap_mbedtls_ssl_setup(mbedtls_ssl_context *ssl, const mbedtls_ssl_config *conf); int __wrap_mbedtls_ssl_send_alert_message(mbedtls_ssl_context *ssl, unsigned char level, unsigned char message); int __wrap_mbedtls_ssl_close_notify(mbedtls_ssl_context *ssl); static const char *TAG = "SSL TLS"; static int tx_done(mbedtls_ssl_context *ssl) { if (!ssl->out_left) return 1; return 0; } static int rx_done(mbedtls_ssl_context *ssl) { if (!ssl->in_msglen) { return 1; } ESP_LOGD(TAG, "RX left %d bytes", ssl->in_msglen); return 0; } static void ssl_transform_init( mbedtls_ssl_transform *transform ) { memset( transform, 0, sizeof(mbedtls_ssl_transform) ); mbedtls_cipher_init( &transform->cipher_ctx_enc ); mbedtls_cipher_init( &transform->cipher_ctx_dec ); mbedtls_md_init( &transform->md_ctx_enc ); mbedtls_md_init( &transform->md_ctx_dec ); } static void ssl_update_checksum_start( mbedtls_ssl_context *ssl, const unsigned char *buf, size_t len ) { #if defined(MBEDTLS_SSL_PROTO_SSL3) || defined(MBEDTLS_SSL_PROTO_TLS1) || \ defined(MBEDTLS_SSL_PROTO_TLS1_1) mbedtls_md5_update_ret( &ssl->handshake->fin_md5 , buf, len ); mbedtls_sha1_update_ret( &ssl->handshake->fin_sha1, buf, len ); #endif #if defined(MBEDTLS_SSL_PROTO_TLS1_2) #if defined(MBEDTLS_SHA256_C) mbedtls_sha256_update_ret( &ssl->handshake->fin_sha256, buf, len ); #endif #if defined(MBEDTLS_SHA512_C) mbedtls_sha512_update_ret( &ssl->handshake->fin_sha512, buf, len ); #endif #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ } static void ssl_handshake_params_init( mbedtls_ssl_handshake_params *handshake ) { memset( handshake, 0, sizeof( mbedtls_ssl_handshake_params ) ); #if defined(MBEDTLS_SSL_PROTO_SSL3) || defined(MBEDTLS_SSL_PROTO_TLS1) || \ defined(MBEDTLS_SSL_PROTO_TLS1_1) mbedtls_md5_init( &handshake->fin_md5 ); mbedtls_sha1_init( &handshake->fin_sha1 ); mbedtls_md5_starts_ret( &handshake->fin_md5 ); mbedtls_sha1_starts_ret( &handshake->fin_sha1 ); #endif #if defined(MBEDTLS_SSL_PROTO_TLS1_2) #if defined(MBEDTLS_SHA256_C) mbedtls_sha256_init( &handshake->fin_sha256 ); mbedtls_sha256_starts_ret( &handshake->fin_sha256, 0 ); #endif #if defined(MBEDTLS_SHA512_C) mbedtls_sha512_init( &handshake->fin_sha512 ); mbedtls_sha512_starts_ret( &handshake->fin_sha512, 1 ); #endif #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */ handshake->update_checksum = ssl_update_checksum_start; #if defined(MBEDTLS_SSL_PROTO_TLS1_2) && \ defined(MBEDTLS_KEY_EXCHANGE__WITH_CERT__ENABLED) mbedtls_ssl_sig_hash_set_init( &handshake->hash_algs ); #endif #if defined(MBEDTLS_DHM_C) mbedtls_dhm_init( &handshake->dhm_ctx ); #endif #if defined(MBEDTLS_ECDH_C) mbedtls_ecdh_init( &handshake->ecdh_ctx ); #endif #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) mbedtls_ecjpake_init( &handshake->ecjpake_ctx ); #if defined(MBEDTLS_SSL_CLI_C) handshake->ecjpake_cache = NULL; handshake->ecjpake_cache_len = 0; #endif #endif #if defined(MBEDTLS_SSL__ECP_RESTARTABLE) mbedtls_x509_crt_restart_init( &handshake->ecrs_ctx ); #endif #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) handshake->sni_authmode = MBEDTLS_SSL_VERIFY_UNSET; #endif } static int ssl_handshake_init( mbedtls_ssl_context *ssl ) { /* Clear old handshake information if present */ if( ssl->transform_negotiate ) mbedtls_ssl_transform_free( ssl->transform_negotiate ); if( ssl->session_negotiate ) mbedtls_ssl_session_free( ssl->session_negotiate ); if( ssl->handshake ) mbedtls_ssl_handshake_free( ssl ); /* * Either the pointers are now NULL or cleared properly and can be freed. * Now allocate missing structures. */ if( ssl->transform_negotiate == NULL ) { ssl->transform_negotiate = mbedtls_calloc( 1, sizeof(mbedtls_ssl_transform) ); } if( ssl->session_negotiate == NULL ) { ssl->session_negotiate = mbedtls_calloc( 1, sizeof(mbedtls_ssl_session) ); } if( ssl->handshake == NULL ) { ssl->handshake = mbedtls_calloc( 1, sizeof(mbedtls_ssl_handshake_params) ); } /* All pointers should exist and can be directly freed without issue */ if( ssl->handshake == NULL || ssl->transform_negotiate == NULL || ssl->session_negotiate == NULL ) { ESP_LOGD(TAG, "alloc() of ssl sub-contexts failed"); mbedtls_free( ssl->handshake ); mbedtls_free( ssl->transform_negotiate ); mbedtls_free( ssl->session_negotiate ); ssl->handshake = NULL; ssl->transform_negotiate = NULL; ssl->session_negotiate = NULL; return( MBEDTLS_ERR_SSL_ALLOC_FAILED ); } /* Initialize structures */ mbedtls_ssl_session_init( ssl->session_negotiate ); ssl_transform_init( ssl->transform_negotiate ); ssl_handshake_params_init( ssl->handshake ); return( 0 ); } int __wrap_mbedtls_ssl_setup(mbedtls_ssl_context *ssl, const mbedtls_ssl_config *conf) { ssl->conf = conf; CHECK_OK(ssl_handshake_init(ssl)); ssl->out_buf = NULL; CHECK_OK(esp_mbedtls_setup_tx_buffer(ssl)); ssl->in_buf = NULL; esp_mbedtls_setup_rx_buffer(ssl); return 0; } int __wrap_mbedtls_ssl_write(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len) { int ret; CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, 0)); ret = __real_mbedtls_ssl_write(ssl, buf, len); if (tx_done(ssl)) { CHECK_OK(esp_mbedtls_free_tx_buffer(ssl)); } return ret; } int __wrap_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t len) { int ret; ESP_LOGD(TAG, "add mbedtls RX buffer"); ret = esp_mbedtls_add_rx_buffer(ssl); if (ret == MBEDTLS_ERR_SSL_CONN_EOF) { ESP_LOGD(TAG, "fail, the connection indicated an EOF"); return 0; } else if (ret < 0) { ESP_LOGD(TAG, "fail, error=-0x%x", -ret); return ret; } ESP_LOGD(TAG, "end"); ret = __real_mbedtls_ssl_read(ssl, buf, len); if (rx_done(ssl)) { CHECK_OK(esp_mbedtls_free_rx_buffer(ssl)); } return ret; } void __wrap_mbedtls_ssl_free(mbedtls_ssl_context *ssl) { if (ssl->out_buf) { esp_mbedtls_free_buf(ssl->out_buf); ssl->out_buf = NULL; } if (ssl->in_buf) { esp_mbedtls_free_buf(ssl->in_buf); ssl->in_buf = NULL; } __real_mbedtls_ssl_free(ssl); } int __wrap_mbedtls_ssl_session_reset(mbedtls_ssl_context *ssl) { CHECK_OK(esp_mbedtls_reset_add_tx_buffer(ssl)); CHECK_OK(esp_mbedtls_reset_add_rx_buffer(ssl)); CHECK_OK(__real_mbedtls_ssl_session_reset(ssl)); CHECK_OK(esp_mbedtls_reset_free_tx_buffer(ssl)); esp_mbedtls_reset_free_rx_buffer(ssl); return 0; } int __wrap_mbedtls_ssl_send_alert_message(mbedtls_ssl_context *ssl, unsigned char level, unsigned char message) { int ret; CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, 0)); ret = __real_mbedtls_ssl_send_alert_message(ssl, level, message); if (tx_done(ssl)) { CHECK_OK(esp_mbedtls_free_tx_buffer(ssl)); } return ret; } int __wrap_mbedtls_ssl_close_notify(mbedtls_ssl_context *ssl) { int ret; CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, 0)); ret = __real_mbedtls_ssl_close_notify(ssl); if (tx_done(ssl)) { CHECK_OK(esp_mbedtls_free_tx_buffer(ssl)); } return ret; }