// // SPDX-FileCopyrightText: 2021-2022 Espressif Systems (Shanghai) CO LTD // // SPDX-License-Identifier: BSL-1.0 // #pragma once #include "mbedtls/ssl.h" #include "mbedtls/entropy.h" #include "mbedtls/ctr_drbg.h" #include "mbedtls/error.h" #include "mbedtls/certs.h" #include "mbedtls/esp_debug.h" #include "esp_log.h" namespace asio { namespace ssl { namespace mbedtls { const char *error_message(int error_code) { static char error_buf[100]; mbedtls_strerror(error_code, error_buf, sizeof(error_buf)); return error_buf; } void throw_alloc_failure(const char* location) { asio::error_code ec( MBEDTLS_ERR_SSL_ALLOC_FAILED, asio::error::get_mbedtls_category()); asio::detail::throw_error(ec, location); } namespace error_codes { bool is_error(int ret) { return ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE; } static bool want_write(int ret) { return ret == MBEDTLS_ERR_SSL_WANT_WRITE; } static bool want_read(int ret) { return ret == MBEDTLS_ERR_SSL_WANT_READ; } } // namespace error_codes enum rw_state { IDLE, READING, WRITING, CLOSED }; class engine { public: explicit engine(std::shared_ptr ctx): ctx_(std::move(ctx)), bio_(bio::new_pair("mbedtls-engine")), state_(IDLE), verify_mode_(0) {} void set_verify_mode(asio::ssl::verify_mode mode) { verify_mode_ = mode; } bio* ext_bio() const { return bio_.second.get(); } rw_state get_state() const { return state_; } int shutdown() { int ret = mbedtls_ssl_close_notify(&impl_.ssl_); if (ret) { impl::print_error("mbedtls_ssl_close_notify", ret); } state_ = CLOSED; return ret; } int connect() { return handshake(true); } int accept() { return handshake(false); } int write(const void *buffer, int len) { int ret = impl_.write(buffer, len); state_ = ret == len ? IDLE: WRITING; return ret; } int read(void *buffer, int len) { int ret = impl_.read(buffer, len); state_ = ret == len ? IDLE: READING; return ret; } private: int handshake(bool is_client_not_server) { if (impl_.before_handshake()) { impl_.configure(ctx_.get(), is_client_not_server, impl_verify_mode(is_client_not_server)); } return do_handshake(); } static int bio_read(void *ctx, unsigned char *buf, size_t len) { auto bio = static_cast(ctx); int read = bio->read(buf, len); if (read <= 0 && bio->should_read()) { return MBEDTLS_ERR_SSL_WANT_READ; } return read; } static int bio_write(void *ctx, const unsigned char *buf, size_t len) { auto bio = static_cast(ctx); int written = bio->write(buf, len); if (written <= 0 && bio->should_write()) { return MBEDTLS_ERR_SSL_WANT_WRITE; } return written; } int do_handshake() { int ret = 0; mbedtls_ssl_set_bio(&impl_.ssl_, bio_.first.get(), bio_write, bio_read, nullptr); while (impl_.ssl_.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER) { ret = mbedtls_ssl_handshake_step(&impl_.ssl_); if (ret != 0) { if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { impl::print_error("mbedtls_ssl_handshake_step", ret); } if (ret == MBEDTLS_ERR_SSL_WANT_READ) { state_ = READING; } else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { state_ = WRITING; } break; } } return ret; } // Converts OpenSSL verification mode to mbedtls enum int impl_verify_mode(bool is_client_not_server) const { int mode = MBEDTLS_SSL_VERIFY_UNSET; if (is_client_not_server) { if (verify_mode_ & SSL_VERIFY_PEER) mode = MBEDTLS_SSL_VERIFY_REQUIRED; else if (verify_mode_ == SSL_VERIFY_NONE) mode = MBEDTLS_SSL_VERIFY_NONE; } else { if (verify_mode_ & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) mode = MBEDTLS_SSL_VERIFY_REQUIRED; else if (verify_mode_ & SSL_VERIFY_PEER) mode = MBEDTLS_SSL_VERIFY_OPTIONAL; else if (verify_mode_ == SSL_VERIFY_NONE) mode = MBEDTLS_SSL_VERIFY_NONE; } return mode; } struct impl { static void print_error(const char* function, int error_code) { constexpr const char *TAG="mbedtls-engine-impl"; ESP_LOGE(TAG, "%s() returned -0x%04X", function, -error_code); ESP_LOGI(TAG, "-0x%04X: %s", -error_code, error_message(error_code)); } bool before_handshake() const { return ssl_.MBEDTLS_PRIVATE(state) == 0; } int write(const void *buffer, int len) { int ret = mbedtls_ssl_write(&ssl_, static_cast(buffer), len); if (ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { print_error("mbedtls_ssl_write", ret); } return ret; } int read(void *buffer, int len) { int ret = mbedtls_ssl_read(&ssl_, static_cast(buffer), len); if (ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_READ) { print_error("mbedtls_ssl_read", ret); } return ret; } impl() { const unsigned char pers[] = "asio ssl"; mbedtls_ssl_init(&ssl_); mbedtls_ssl_config_init(&conf_); mbedtls_ctr_drbg_init(&ctr_drbg_); #ifdef CONFIG_MBEDTLS_DEBUG mbedtls_esp_enable_debug_log(&conf_, CONFIG_MBEDTLS_DEBUG_LEVEL); #endif mbedtls_entropy_init(&entropy_); mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, pers, sizeof(pers)); mbedtls_x509_crt_init(&public_cert_); mbedtls_pk_init(&pk_key_); mbedtls_x509_crt_init(&ca_cert_); } bool configure(context *ctx, bool is_client_not_server, int mbedtls_verify_mode) { mbedtls_x509_crt_init(&public_cert_); mbedtls_pk_init(&pk_key_); mbedtls_x509_crt_init(&ca_cert_); int ret = mbedtls_ssl_config_defaults(&conf_, is_client_not_server ? MBEDTLS_SSL_IS_CLIENT: MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); if (ret) { print_error("mbedtls_ssl_config_defaults", ret); return false; } mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_); mbedtls_ssl_conf_authmode(&conf_, mbedtls_verify_mode); if (ctx->cert_chain_.size() > 0 && ctx->private_key_.size() > 0) { ret = mbedtls_x509_crt_parse(&public_cert_, ctx->data(container::CERT), ctx->size(container::CERT)); if (ret < 0) { print_error("mbedtls_x509_crt_parse", ret); return false; } ret = mbedtls_pk_parse_key(&pk_key_, ctx->data(container::PRIVKEY), ctx->size(container::PRIVKEY), nullptr, 0, mbedtls_ctr_drbg_random, &ctr_drbg_); if (ret < 0) { print_error("mbedtls_pk_parse_keyfile", ret); return false; } ret = mbedtls_ssl_conf_own_cert(&conf_, &public_cert_, &pk_key_); if (ret) { print_error("mbedtls_ssl_conf_own_cert", ret); return false; } } if (ctx->ca_cert_.size() > 0) { ret = mbedtls_x509_crt_parse(&ca_cert_, ctx->data(container::CA_CERT), ctx->size(container::CA_CERT)); if (ret < 0) { print_error("mbedtls_x509_crt_parse", ret); return false; } mbedtls_ssl_conf_ca_chain(&conf_, &ca_cert_, nullptr); } else { mbedtls_ssl_conf_ca_chain(&conf_, nullptr, nullptr); } ret = mbedtls_ssl_setup(&ssl_, &conf_); if (ret) { print_error("mbedtls_ssl_setup", ret); return false; } return true; } mbedtls_ssl_context ssl_{}; mbedtls_entropy_context entropy_{}; mbedtls_ctr_drbg_context ctr_drbg_{}; mbedtls_ssl_config conf_{}; mbedtls_x509_crt public_cert_{}; mbedtls_pk_context pk_key_{}; mbedtls_x509_crt ca_cert_{}; }; impl impl_{}; std::shared_ptr ctx_; std::pair, std::shared_ptr> bio_; enum rw_state state_; asio::ssl::verify_mode verify_mode_; }; } } } // namespace asio::ssl::mbedtls