mirror of
https://github.com/espressif/esp-idf.git
synced 2024-10-05 20:47:46 -04:00
295 lines
8.9 KiB
C++
295 lines
8.9 KiB
C++
|
//
|
||
|
// SPDX-FileCopyrightText: 2021 Espressif Systems (Shanghai) CO LTD
|
||
|
//
|
||
|
// SPDX-License-Identifier: BSL-1.0
|
||
|
//
|
||
|
#pragma once
|
||
|
|
||
|
#include "mbedtls/ssl.h"
|
||
|
#include "mbedtls/entropy.h"
|
||
|
#include "mbedtls/ctr_drbg.h"
|
||
|
#include "mbedtls/error.h"
|
||
|
#include "mbedtls/certs.h"
|
||
|
#include "mbedtls/esp_debug.h"
|
||
|
#include "esp_log.h"
|
||
|
|
||
|
namespace asio {
|
||
|
namespace ssl {
|
||
|
namespace mbedtls {
|
||
|
|
||
|
const char *error_message(int error_code)
|
||
|
{
|
||
|
static char error_buf[100];
|
||
|
mbedtls_strerror(error_code, error_buf, sizeof(error_buf));
|
||
|
return error_buf;
|
||
|
}
|
||
|
|
||
|
void throw_alloc_failure(const char* location)
|
||
|
{
|
||
|
asio::error_code ec( MBEDTLS_ERR_SSL_ALLOC_FAILED, asio::error::get_mbedtls_category());
|
||
|
asio::detail::throw_error(ec, location);
|
||
|
}
|
||
|
|
||
|
namespace error_codes {
|
||
|
|
||
|
bool is_error(int ret)
|
||
|
{
|
||
|
return ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE;
|
||
|
}
|
||
|
|
||
|
static bool want_write(int ret)
|
||
|
{
|
||
|
return ret == MBEDTLS_ERR_SSL_WANT_WRITE;
|
||
|
}
|
||
|
|
||
|
static bool want_read(int ret)
|
||
|
{
|
||
|
return ret == MBEDTLS_ERR_SSL_WANT_READ;
|
||
|
}
|
||
|
|
||
|
} // namespace error_codes
|
||
|
|
||
|
/**
 * @brief I/O state recorded by the engine after its most recent operation
 *        (exposed through engine::get_state()).
 */
enum rw_state {
    IDLE = 0,     ///< Initial state; also set when read()/write() transferred exactly the requested length.
    READING = 1,  ///< Set when a handshake step returned WANT_READ, or read() returned anything else.
    WRITING = 2,  ///< Set when a handshake step returned WANT_WRITE, or write() returned anything else.
    CLOSED = 3    ///< Set by engine::shutdown() after issuing close_notify.
};
|
||
|
|
||
|
class engine {
|
||
|
public:
|
||
|
explicit engine(std::shared_ptr<context> ctx): ctx_(std::move(ctx)),
|
||
|
bio_(bio::new_pair("mbedtls-engine")), state_(IDLE), verify_mode_(0) {}
|
||
|
|
||
|
void set_verify_mode(asio::ssl::verify_mode mode)
|
||
|
{
|
||
|
verify_mode_ = mode;
|
||
|
}
|
||
|
|
||
|
bio* ext_bio() const
|
||
|
{
|
||
|
return bio_.second.get();
|
||
|
}
|
||
|
|
||
|
rw_state get_state() const
|
||
|
{
|
||
|
return state_;
|
||
|
}
|
||
|
|
||
|
int shutdown()
|
||
|
{
|
||
|
int ret = mbedtls_ssl_close_notify(&impl_.ssl_);
|
||
|
if (ret) {
|
||
|
impl::print_error("mbedtls_ssl_close_notify", ret);
|
||
|
}
|
||
|
state_ = CLOSED;
|
||
|
return ret;
|
||
|
}
|
||
|
|
||
|
int connect()
|
||
|
{
|
||
|
return handshake(true);
|
||
|
}
|
||
|
|
||
|
int accept()
|
||
|
{
|
||
|
return handshake(false);
|
||
|
}
|
||
|
|
||
|
int write(const void *buffer, int len)
|
||
|
{
|
||
|
int ret = impl_.write(buffer, len);
|
||
|
state_ = ret == len ? IDLE: WRITING;
|
||
|
return ret;
|
||
|
}
|
||
|
|
||
|
int read(void *buffer, int len)
|
||
|
{
|
||
|
int ret = impl_.read(buffer, len);
|
||
|
state_ = ret == len ? IDLE: READING;
|
||
|
return ret;
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
int handshake(bool is_client_not_server)
|
||
|
{
|
||
|
if (impl_.before_handshake()) {
|
||
|
impl_.configure(ctx_.get(), is_client_not_server, impl_verify_mode(is_client_not_server));
|
||
|
}
|
||
|
return do_handshake();
|
||
|
}
|
||
|
|
||
|
static int bio_read(void *ctx, unsigned char *buf, size_t len)
|
||
|
{
|
||
|
auto bio = static_cast<BIO*>(ctx);
|
||
|
int read = bio->read(buf, len);
|
||
|
if (read <= 0 && bio->should_read()) {
|
||
|
return MBEDTLS_ERR_SSL_WANT_READ;
|
||
|
}
|
||
|
return read;
|
||
|
}
|
||
|
|
||
|
static int bio_write(void *ctx, const unsigned char *buf, size_t len)
|
||
|
{
|
||
|
auto bio = static_cast<BIO*>(ctx);
|
||
|
int written = bio->write(buf, len);
|
||
|
if (written <= 0 && bio->should_write()) {
|
||
|
return MBEDTLS_ERR_SSL_WANT_WRITE;
|
||
|
}
|
||
|
return written;
|
||
|
}
|
||
|
|
||
|
int do_handshake()
|
||
|
{
|
||
|
int ret = 0;
|
||
|
mbedtls_ssl_set_bio(&impl_.ssl_, bio_.first.get(), bio_write, bio_read, nullptr);
|
||
|
|
||
|
while (impl_.ssl_.state != MBEDTLS_SSL_HANDSHAKE_OVER) {
|
||
|
ret = mbedtls_ssl_handshake_step(&impl_.ssl_);
|
||
|
|
||
|
if (ret != 0) {
|
||
|
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
|
||
|
impl::print_error("mbedtls_ssl_handshake_step", ret);
|
||
|
}
|
||
|
if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
|
||
|
state_ = READING;
|
||
|
} else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
|
||
|
state_ = WRITING;
|
||
|
}
|
||
|
break;
|
||
|
}
|
||
|
}
|
||
|
return ret;
|
||
|
}
|
||
|
|
||
|
// Converts OpenSSL verification mode to mbedtls enum
|
||
|
int impl_verify_mode(bool is_client_not_server) const
|
||
|
{
|
||
|
int mode = MBEDTLS_SSL_VERIFY_UNSET;
|
||
|
if (is_client_not_server) {
|
||
|
if (verify_mode_ & SSL_VERIFY_PEER)
|
||
|
mode = MBEDTLS_SSL_VERIFY_REQUIRED;
|
||
|
else if (verify_mode_ == SSL_VERIFY_NONE)
|
||
|
mode = MBEDTLS_SSL_VERIFY_NONE;
|
||
|
} else {
|
||
|
if (verify_mode_ & SSL_VERIFY_FAIL_IF_NO_PEER_CERT)
|
||
|
mode = MBEDTLS_SSL_VERIFY_REQUIRED;
|
||
|
else if (verify_mode_ & SSL_VERIFY_PEER)
|
||
|
mode = MBEDTLS_SSL_VERIFY_OPTIONAL;
|
||
|
else if (verify_mode_ == SSL_VERIFY_NONE)
|
||
|
mode = MBEDTLS_SSL_VERIFY_NONE;
|
||
|
}
|
||
|
return mode;
|
||
|
}
|
||
|
|
||
|
struct impl {
|
||
|
static void print_error(const char* function, int error_code)
|
||
|
{
|
||
|
constexpr const char *TAG="mbedtls-engine-impl";
|
||
|
ESP_LOGE(TAG, "%s() returned -0x%04X", function, -error_code);
|
||
|
ESP_LOGI(TAG, "-0x%04X: %s", -error_code, error_message(error_code));
|
||
|
}
|
||
|
|
||
|
bool before_handshake() const
|
||
|
{
|
||
|
return ssl_.state == 0;
|
||
|
}
|
||
|
|
||
|
int write(const void *buffer, int len)
|
||
|
{
|
||
|
int ret = mbedtls_ssl_write(&ssl_, static_cast<const unsigned char *>(buffer), len);
|
||
|
if (ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
|
||
|
print_error("mbedtls_ssl_write", ret);
|
||
|
}
|
||
|
return ret;
|
||
|
}
|
||
|
|
||
|
int read(void *buffer, int len)
|
||
|
{
|
||
|
int ret = mbedtls_ssl_read(&ssl_, static_cast<unsigned char *>(buffer), len);
|
||
|
if (ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_READ) {
|
||
|
print_error("mbedtls_ssl_read", ret);
|
||
|
}
|
||
|
return ret;
|
||
|
}
|
||
|
|
||
|
impl()
|
||
|
{
|
||
|
const unsigned char pers[] = "asio ssl";
|
||
|
mbedtls_ssl_init(&ssl_);
|
||
|
mbedtls_ssl_config_init(&conf_);
|
||
|
mbedtls_ctr_drbg_init(&ctr_drbg_);
|
||
|
#ifdef CONFIG_MBEDTLS_DEBUG
|
||
|
mbedtls_esp_enable_debug_log(&conf_, CONFIG_MBEDTLS_DEBUG_LEVEL);
|
||
|
#endif
|
||
|
mbedtls_entropy_init(&entropy_);
|
||
|
mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, pers, sizeof(pers));
|
||
|
mbedtls_x509_crt_init(&public_cert_);
|
||
|
mbedtls_pk_init(&pk_key_);
|
||
|
mbedtls_x509_crt_init(&ca_cert_);
|
||
|
}
|
||
|
|
||
|
bool configure(context *ctx, bool is_client_not_server, int mbedtls_verify_mode)
|
||
|
{
|
||
|
mbedtls_x509_crt_init(&public_cert_);
|
||
|
mbedtls_pk_init(&pk_key_);
|
||
|
mbedtls_x509_crt_init(&ca_cert_);
|
||
|
int ret = mbedtls_ssl_config_defaults(&conf_, is_client_not_server ? MBEDTLS_SSL_IS_CLIENT: MBEDTLS_SSL_IS_SERVER,
|
||
|
MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
|
||
|
if (ret) {
|
||
|
print_error("mbedtls_ssl_config_defaults", ret);
|
||
|
return false;
|
||
|
}
|
||
|
mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_);
|
||
|
mbedtls_ssl_conf_authmode(&conf_, mbedtls_verify_mode);
|
||
|
if (ctx->cert_chain_.size() > 0 && ctx->private_key_.size() > 0) {
|
||
|
ret = mbedtls_x509_crt_parse(&public_cert_, ctx->data(container::CERT), ctx->size(container::CERT));
|
||
|
if (ret < 0) {
|
||
|
print_error("mbedtls_x509_crt_parse", ret);
|
||
|
return false;
|
||
|
}
|
||
|
ret = mbedtls_pk_parse_key(&pk_key_, ctx->data(container::PRIVKEY), ctx->size(container::PRIVKEY),
|
||
|
nullptr, 0);
|
||
|
if (ret < 0) {
|
||
|
print_error("mbedtls_pk_parse_keyfile", ret);
|
||
|
return false;
|
||
|
}
|
||
|
ret = mbedtls_ssl_conf_own_cert(&conf_, &public_cert_, &pk_key_);
|
||
|
if (ret) {
|
||
|
print_error("mbedtls_ssl_conf_own_cert", ret);
|
||
|
return false;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if (ctx->ca_cert_.size() > 0) {
|
||
|
ret = mbedtls_x509_crt_parse(&ca_cert_, ctx->data(container::CA_CERT), ctx->size(container::CA_CERT));
|
||
|
if (ret < 0) {
|
||
|
print_error("mbedtls_x509_crt_parse", ret);
|
||
|
return false;
|
||
|
}
|
||
|
mbedtls_ssl_conf_ca_chain(&conf_, &ca_cert_, nullptr);
|
||
|
} else {
|
||
|
mbedtls_ssl_conf_ca_chain(&conf_, nullptr, nullptr);
|
||
|
}
|
||
|
ret = mbedtls_ssl_setup(&ssl_, &conf_);
|
||
|
if (ret) {
|
||
|
print_error("mbedtls_ssl_setup", ret);
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
mbedtls_ssl_context ssl_{};
|
||
|
mbedtls_entropy_context entropy_{};
|
||
|
mbedtls_ctr_drbg_context ctr_drbg_{};
|
||
|
mbedtls_ssl_config conf_{};
|
||
|
mbedtls_x509_crt public_cert_{};
|
||
|
mbedtls_pk_context pk_key_{};
|
||
|
mbedtls_x509_crt ca_cert_{};
|
||
|
};
|
||
|
|
||
|
impl impl_{};
|
||
|
std::shared_ptr<context> ctx_;
|
||
|
std::pair<std::shared_ptr<bio>, std::shared_ptr<bio>> bio_;
|
||
|
enum rw_state state_;
|
||
|
asio::ssl::verify_mode verify_mode_;
|
||
|
};
|
||
|
|
||
|
} } } // namespace asio::ssl::mbedtls
|