Merge branch 'ci/publish_connect_refactor' into 'master'

Publish connect test refactor

See merge request espressif/esp-idf!25311
This commit is contained in:
Rocha Euripedes 2023-10-12 20:03:28 +08:00
commit ade6384954
6 changed files with 834 additions and 553 deletions

View File

@ -8,11 +8,11 @@
*/
#include <stdint.h>
#include "esp_netif.h"
#include "esp_console.h"
#include "esp_log.h"
#include "mqtt_client.h"
#include "esp_tls.h"
#include "publish_connect_test.h"
#if (!defined(CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT)) || \
(!defined(CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT)) || \
@ -34,17 +34,23 @@ extern const uint8_t client_inv_crt[] asm("_binary_client_inv_crt_start");
extern const uint8_t client_no_pwd_key[] asm("_binary_client_no_pwd_key_start");
static const char *TAG = "connect_test";
static esp_mqtt_client_handle_t mqtt_client = NULL;
static int running_test_case = 0;
static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_t event_id, void *event_data)
{
(void)handler_args;
(void)base;
(void)event_id;
esp_mqtt_event_handle_t event = event_data;
ESP_LOGD(TAG, "Event: %d, Test case: %d", event->event_id, running_test_case);
switch (event->event_id) {
case MQTT_EVENT_BEFORE_CONNECT:
break;
case MQTT_EVENT_CONNECTED:
ESP_LOGI(TAG, "MQTT_EVENT_CONNECTED: Test=%d", running_test_case);
break;
case MQTT_EVENT_DISCONNECTED:
break;
case MQTT_EVENT_ERROR:
ESP_LOGI(TAG, "MQTT_EVENT_ERROR: Test=%d", running_test_case);
if (event->error_handle->error_type == MQTT_ERROR_TYPE_ESP_TLS) {
@ -61,44 +67,17 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_
}
}
static void create_client(void)
static void connect_no_certs(esp_mqtt_client_handle_t client, const char *uri)
{
const esp_mqtt_client_config_t mqtt_cfg = {
.broker.address.uri = "mqtts://127.0.0.1:1234"
};
esp_mqtt_client_handle_t client = esp_mqtt_client_init(&mqtt_cfg);
esp_mqtt_client_register_event(client, ESP_EVENT_ANY_ID, mqtt_event_handler, client);
mqtt_client = client;
esp_mqtt_client_start(client);
ESP_LOGI(TAG, "mqtt client created for connection tests");
}
static void destroy_client(void)
{
if (mqtt_client) {
esp_mqtt_client_stop(mqtt_client);
esp_mqtt_client_destroy(mqtt_client);
mqtt_client = NULL;
ESP_LOGI(TAG, "mqtt client for connection tests destroyed");
}
}
static void connect_no_certs(const char *host, const int port)
{
char uri[64];
sprintf(uri, "mqtts://%s:%d", host, port);
ESP_LOGI(TAG, "Runnning :CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT");
const esp_mqtt_client_config_t mqtt_cfg = {
.broker.address.uri = uri
};
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
esp_mqtt_client_disconnect(mqtt_client);
esp_mqtt_client_reconnect(mqtt_client);
esp_mqtt_set_config(client, &mqtt_cfg);
}
static void connect_with_client_key_password(const char *host, const int port)
static void connect_with_client_key_password(esp_mqtt_client_handle_t client, const char *uri)
{
char uri[64];
sprintf(uri, "mqtts://%s:%d", host, port);
const esp_mqtt_client_config_t mqtt_cfg = {
.broker.address.uri = uri,
.broker.verification.certificate = (const char *)ca_local_crt,
@ -107,15 +86,11 @@ static void connect_with_client_key_password(const char *host, const int port)
.credentials.authentication.key_password = "esp32",
.credentials.authentication.key_password_len = 5
};
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
esp_mqtt_client_disconnect(mqtt_client);
esp_mqtt_client_reconnect(mqtt_client);
esp_mqtt_set_config(client, &mqtt_cfg);
}
static void connect_with_server_der_cert(const char *host, const int port)
static void connect_with_server_der_cert(esp_mqtt_client_handle_t client, const char *uri)
{
char uri[64];
sprintf(uri, "mqtts://%s:%d", host, port);
const esp_mqtt_client_config_t mqtt_cfg = {
.broker.address.uri = uri,
.broker.verification.certificate = (const char *)ca_der_start,
@ -123,123 +98,96 @@ static void connect_with_server_der_cert(const char *host, const int port)
.credentials.authentication.certificate = "NULL",
.credentials.authentication.key = "NULL"
};
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
esp_mqtt_client_disconnect(mqtt_client);
esp_mqtt_client_reconnect(mqtt_client);
esp_mqtt_set_config(client, &mqtt_cfg);
}
static void connect_with_wrong_server_cert(const char *host, const int port)
static void connect_with_wrong_server_cert(esp_mqtt_client_handle_t client, const char *uri)
{
char uri[64];
sprintf(uri, "mqtts://%s:%d", host, port);
const esp_mqtt_client_config_t mqtt_cfg = {
.broker.address.uri = uri,
.broker.verification.certificate = (const char *)client_pwd_crt,
.credentials.authentication.certificate = "NULL",
.credentials.authentication.key = "NULL"
};
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
esp_mqtt_client_disconnect(mqtt_client);
esp_mqtt_client_reconnect(mqtt_client);
esp_mqtt_set_config(client, &mqtt_cfg);
}
static void connect_with_server_cert(const char *host, const int port)
static void connect_with_server_cert(esp_mqtt_client_handle_t client, const char *uri)
{
char uri[64];
sprintf(uri, "mqtts://%s:%d", host, port);
const esp_mqtt_client_config_t mqtt_cfg = {
.broker.address.uri = uri,
.broker.verification.certificate = (const char *)ca_local_crt,
};
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
esp_mqtt_client_disconnect(mqtt_client);
esp_mqtt_client_reconnect(mqtt_client);
esp_mqtt_set_config(client, &mqtt_cfg);
}
static void connect_with_server_client_certs(const char *host, const int port)
static void connect_with_server_client_certs(esp_mqtt_client_handle_t client, const char *uri)
{
char uri[64];
sprintf(uri, "mqtts://%s:%d", host, port);
const esp_mqtt_client_config_t mqtt_cfg = {
.broker.address.uri = uri,
.broker.verification.certificate = (const char *)ca_local_crt,
.credentials.authentication.certificate = (const char *)client_pwd_crt,
.credentials.authentication.key = (const char *)client_no_pwd_key
};
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
esp_mqtt_client_disconnect(mqtt_client);
esp_mqtt_client_reconnect(mqtt_client);
esp_mqtt_set_config(client, &mqtt_cfg);
}
static void connect_with_invalid_client_certs(const char *host, const int port)
static void connect_with_invalid_client_certs(esp_mqtt_client_handle_t client, const char *uri)
{
char uri[64];
sprintf(uri, "mqtts://%s:%d", host, port);
const esp_mqtt_client_config_t mqtt_cfg = {
.broker.address.uri = uri,
.broker.verification.certificate = (const char *)ca_local_crt,
.credentials.authentication.certificate = (const char *)client_inv_crt,
.credentials.authentication.key = (const char *)client_no_pwd_key
};
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
esp_mqtt_client_disconnect(mqtt_client);
esp_mqtt_client_reconnect(mqtt_client);
esp_mqtt_set_config(client, &mqtt_cfg);
}
static void connect_with_alpn(const char *host, const int port)
static void connect_with_alpn(esp_mqtt_client_handle_t client, const char *uri)
{
char uri[64];
const char *alpn_protos[] = { "mymqtt", NULL };
sprintf(uri, "mqtts://%s:%d", host, port);
const esp_mqtt_client_config_t mqtt_cfg = {
.broker.address.uri = uri,
.broker.verification.alpn_protos = alpn_protos
};
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
esp_mqtt_client_disconnect(mqtt_client);
esp_mqtt_client_reconnect(mqtt_client);
esp_mqtt_set_config(client, &mqtt_cfg);
}
void connection_test(const char *line)
{
char test_type[32];
char host[32];
int port;
int test_case;
void connect_setup(command_context_t * ctx) {
esp_mqtt_client_register_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler, ctx->data);
}
sscanf(line, "%s %s %d %d", test_type, host, &port, &test_case);
if (mqtt_client == NULL) {
create_client();
}
if (strcmp(host, "teardown") == 0) {
destroy_client();;
}
ESP_LOGI(TAG, "CASE:%d, connecting to mqtts://%s:%d ", test_case, host, port);
void connect_teardown(command_context_t * ctx) {
esp_mqtt_client_unregister_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler);
}
void connection_test(command_context_t * ctx, const char *uri, int test_case)
{
ESP_LOGI(TAG, "CASE:%d, connecting to %s", test_case, uri);
running_test_case = test_case;
switch (test_case) {
case CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT:
connect_no_certs(host, port);
connect_no_certs(ctx->mqtt_client, uri);
break;
case CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT:
connect_with_server_cert(host, port);
connect_with_server_cert(ctx->mqtt_client, uri);
break;
case CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH:
connect_with_server_client_certs(host, port);
connect_with_server_client_certs(ctx->mqtt_client, uri);
break;
case CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT:
connect_with_wrong_server_cert(host, port);
connect_with_wrong_server_cert(ctx->mqtt_client, uri);
break;
case CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT:
connect_with_server_der_cert(host, port);
connect_with_server_der_cert(ctx->mqtt_client, uri);
break;
case CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD:
connect_with_client_key_password(host, port);
connect_with_client_key_password(ctx->mqtt_client, uri);
break;
case CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT:
connect_with_invalid_client_certs(host, port);
connect_with_invalid_client_certs(ctx->mqtt_client, uri);
break;
case CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN:
connect_with_alpn(host, port);
connect_with_alpn(ctx->mqtt_client, uri);
break;
default:
ESP_LOGE(TAG, "Unknown test case %d ", test_case);

View File

@ -8,68 +8,310 @@
*/
#include <stdio.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include "esp_system.h"
#include "mqtt_client.h"
#include "nvs_flash.h"
#include "esp_event.h"
#include "esp_netif.h"
#include "protocol_examples_common.h"
#include "esp_console.h"
#include "argtable3/argtable3.h"
#include "esp_log.h"
#include "publish_connect_test.h"
static const char *TAG = "publish_connect_test";
void connection_test(const char *line);
void publish_test(const char *line);
command_context_t command_context;
connection_args_t connection_args;
publish_setup_args_t publish_setup_args;
publish_args_t publish_args;
static void get_string(char *line, size_t size)
{
int count = 0;
while (count < size) {
int c = fgetc(stdin);
if (c == '\n') {
line[count] = '\0';
break;
} else if (c > 0 && c < 127) {
line[count] = c;
++count;
}
vTaskDelay(10 / portTICK_PERIOD_MS);
#define RETURN_ON_PARSE_ERROR(args) do { \
int nerrors = arg_parse(argc, argv, (void **) &(args)); \
if (nerrors != 0) { \
arg_print_errors(stderr, (args).end, argv[0]); \
return 1; \
}} while(0)
static int do_free_heap(int argc, char **argv) {
(void)argc;
(void)argv;
ESP_LOGI(TAG, "Note free memory: %d bytes", esp_get_free_heap_size());
return 0;
}
static int do_init(int argc, char **argv) {
(void)argc;
(void)argv;
const esp_mqtt_client_config_t mqtt_cfg = {
.broker.address.uri = "mqtts://127.0.0.1:1234",
.network.disable_auto_reconnect = true
};
command_context.mqtt_client = esp_mqtt_client_init(&mqtt_cfg);
if(!command_context.mqtt_client) {
ESP_LOGE(TAG, "Failed to initialize client");
return 1;
}
publish_init_flags();
ESP_LOGI(TAG, "Mqtt client initialized");
return 0;
}
static int do_start(int argc, char **argv) {
(void)argc;
(void)argv;
if(esp_mqtt_client_start(command_context.mqtt_client) != ESP_OK) {
ESP_LOGE(TAG, "Failed to start mqtt client task");
return 1;
}
ESP_LOGI(TAG, "Mqtt client started");
return 0;
}
static int do_stop(int argc, char **argv) {
(void)argc;
(void)argv;
if(esp_mqtt_client_stop(command_context.mqtt_client) != ESP_OK) {
ESP_LOGE(TAG, "Failed to stop mqtt client task");
return 1;
}
ESP_LOGI(TAG, "Mqtt client stoped");
return 0;
}
static int do_disconnect(int argc, char **argv) {
(void)argc;
(void)argv;
if(esp_mqtt_client_disconnect(command_context.mqtt_client) != ESP_OK) {
ESP_LOGE(TAG, "Failed to request disconnection");
return 1;
}
ESP_LOGI(TAG, "Mqtt client disconnected");
return 0;
}
static int do_connect_setup(int argc, char **argv) {
(void)argc;
(void)argv;
connect_setup(&command_context);
return 0;
}
static int do_connect_teardown(int argc, char **argv) {
(void)argc;
(void)argv;
connect_teardown(&command_context);
return 0;
}
static int do_reconnect(int argc, char **argv) {
(void)argc;
(void)argv;
if(esp_mqtt_client_reconnect(command_context.mqtt_client) != ESP_OK) {
ESP_LOGE(TAG, "Failed to request reconnection");
return 1;
}
ESP_LOGI(TAG, "Mqtt client will reconnect");
return 0;
;
}
static int do_destroy(int argc, char **argv) {
(void)argc;
(void)argv;
esp_mqtt_client_destroy(command_context.mqtt_client);
command_context.mqtt_client = NULL;
ESP_LOGI(TAG, "mqtt client for tests destroyed");
return 0;
}
static int do_connect(int argc, char **argv)
{
int nerrors = arg_parse(argc, argv, (void **) &connection_args);
if (nerrors != 0) {
arg_print_errors(stderr, connection_args.end, argv[0]);
return 1;
}
if(!command_context.mqtt_client) {
ESP_LOGE(TAG, "MQTT client not initialized, call init first");
return 1;
}
connection_test(&command_context, *connection_args.uri->sval, *connection_args.test_case->ival);
return 0;
}
static int do_publish_setup(int argc, char **argv) {
RETURN_ON_PARSE_ERROR(publish_setup_args);
if(command_context.data) {
free(command_context.data);
}
command_context.data = calloc(1, sizeof(publish_context_t));
((publish_context_t*)command_context.data)->pattern = strdup(*publish_setup_args.pattern->sval);
((publish_context_t*)command_context.data)->pattern_repetitions = *publish_setup_args.pattern_repetitions->ival;
publish_setup(&command_context, *publish_setup_args.transport->sval);
return 0;
}
static int do_publish(int argc, char **argv) {
RETURN_ON_PARSE_ERROR(publish_args);
publish_test(&command_context, publish_args.expected_to_publish->ival[0], publish_args.qos->ival[0], publish_args.enqueue->ival[0]);
return 0;
}
static int do_publish_report(int argc, char **argv) {
(void)argc;
(void)argv;
publish_context_t * ctx = command_context.data;
ESP_LOGI(TAG,"Test Report : Messages received %d, %d expected", ctx->nr_of_msg_received, ctx->nr_of_msg_expected);
return 0;
}
void register_common_commands(void) {
const esp_console_cmd_t init = {
.command = "init",
.help = "Run inition test\n",
.hint = NULL,
.func = &do_init,
};
const esp_console_cmd_t start = {
.command = "start",
.help = "Run startion test\n",
.hint = NULL,
.func = &do_start,
};
const esp_console_cmd_t stop = {
.command = "stop",
.help = "Run stopion test\n",
.hint = NULL,
.func = &do_stop,
};
const esp_console_cmd_t destroy = {
.command = "destroy",
.help = "Run destroyion test\n",
.hint = NULL,
.func = &do_destroy,
};
const esp_console_cmd_t free_heap = {
.command = "free_heap",
.help = "Run destroyion test\n",
.hint = NULL,
.func = &do_free_heap,
};
ESP_ERROR_CHECK(esp_console_cmd_register(&init));
ESP_ERROR_CHECK(esp_console_cmd_register(&start));
ESP_ERROR_CHECK(esp_console_cmd_register(&stop));
ESP_ERROR_CHECK(esp_console_cmd_register(&destroy));
ESP_ERROR_CHECK(esp_console_cmd_register(&free_heap));
}
void register_publish_commands(void) {
publish_setup_args.transport = arg_str1(NULL,NULL,"<transport>", "Selected transport to test");
publish_setup_args.pattern = arg_str1(NULL,NULL,"<pattern>", "Message pattern repeated to build big messages");
publish_setup_args.pattern_repetitions = arg_int1(NULL,NULL,"<pattern repetitions>", "How many times the pattern is repeated");
publish_setup_args.end = arg_end(1);
publish_args.expected_to_publish = arg_int1(NULL,NULL,"<number of messages>", "How many times the pattern is repeated");
publish_args.qos = arg_int1(NULL,NULL,"<qos>", "How many times the pattern is repeated");
publish_args.enqueue = arg_int1(NULL,NULL,"<enqueue>", "How many times the pattern is repeated");
publish_args.end = arg_end(1);
const esp_console_cmd_t publish_setup = {
.command = "publish_setup",
.help = "Run publish test\n",
.hint = NULL,
.func = &do_publish_setup,
.argtable = &publish_setup_args
};
const esp_console_cmd_t publish = {
.command = "publish",
.help = "Run publish test\n",
.hint = NULL,
.func = &do_publish,
.argtable = &publish_args
};
const esp_console_cmd_t publish_report = {
.command = "publish_report",
.help = "Run destroyion test\n",
.hint = NULL,
.func = &do_publish_report,
};
ESP_ERROR_CHECK(esp_console_cmd_register(&publish_setup));
ESP_ERROR_CHECK(esp_console_cmd_register(&publish));
ESP_ERROR_CHECK(esp_console_cmd_register(&publish_report));
}
void register_connect_commands(void){
connection_args.uri = arg_str1(NULL,NULL,"<broker uri>", "Broker address");
connection_args.test_case = arg_int1(NULL, NULL, "<test case>","Selected test case");
connection_args.end = arg_end(1);
const esp_console_cmd_t connect = {
.command = "connect",
.help = "Run connection test\n",
.hint = NULL,
.func = &do_connect,
.argtable = &connection_args
};
const esp_console_cmd_t reconnect = {
.command = "reconnect",
.help = "Run reconnection test\n",
.hint = NULL,
.func = &do_reconnect,
};
const esp_console_cmd_t connection_setup = {
.command = "connection_setup",
.help = "Run reconnection test\n",
.hint = NULL,
.func = &do_connect_setup,
};
const esp_console_cmd_t connection_teardown = {
.command = "connection_teardown",
.help = "Run reconnection test\n",
.hint = NULL,
.func = &do_connect_teardown,
};
const esp_console_cmd_t disconnect = {
.command = "disconnect",
.help = "Run disconnection test\n",
.hint = NULL,
.func = &do_disconnect,
};
ESP_ERROR_CHECK(esp_console_cmd_register(&connect));
ESP_ERROR_CHECK(esp_console_cmd_register(&disconnect));
ESP_ERROR_CHECK(esp_console_cmd_register(&reconnect));
ESP_ERROR_CHECK(esp_console_cmd_register(&connection_setup));
ESP_ERROR_CHECK(esp_console_cmd_register(&connection_teardown));
}
void app_main(void)
{
char line[256];
static const size_t max_line = 256;
ESP_LOGI(TAG, "[APP] Free memory: %d bytes", esp_get_free_heap_size());
ESP_LOGI(TAG, "[APP] IDF version: %s", esp_get_idf_version());
esp_log_level_set("*", ESP_LOG_INFO);
esp_log_level_set("wifi", ESP_LOG_ERROR);
esp_log_level_set("mqtt_client", ESP_LOG_VERBOSE);
esp_log_level_set("transport_base", ESP_LOG_VERBOSE);
esp_log_level_set("transport", ESP_LOG_VERBOSE);
esp_log_level_set("outbox", ESP_LOG_VERBOSE);
ESP_ERROR_CHECK(nvs_flash_init());
ESP_ERROR_CHECK(esp_netif_init());
ESP_ERROR_CHECK(esp_event_loop_create_default());
/* This helper function configures Wi-Fi or Ethernet, as selected in menuconfig.
* Read "Establishing Wi-Fi or Ethernet Connection" section in
* examples/protocols/README.md for more information about this function.
*/
ESP_ERROR_CHECK(example_connect());
esp_console_repl_t *repl = NULL;
esp_console_repl_config_t repl_config = ESP_CONSOLE_REPL_CONFIG_DEFAULT();
repl_config.prompt = "mqtt>";
repl_config.max_cmdline_length = max_line;
esp_console_register_help_command();
register_common_commands();
register_connect_commands();
register_publish_commands();
while (1) {
get_string(line, sizeof(line));
if (memcmp(line, "conn", 4) == 0) {
// line starting with "conn" indicate connection tests
connection_test(line);
get_string(line, sizeof(line));
continue;
} else {
publish_test(line);
}
}
esp_console_dev_uart_config_t hw_config = ESP_CONSOLE_DEV_UART_CONFIG_DEFAULT();
ESP_ERROR_CHECK(esp_console_new_repl_uart(&hw_config, &repl_config, &repl));
ESP_ERROR_CHECK(esp_console_start_repl(repl));
}

View File

@ -0,0 +1,55 @@
/*
* SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD
*
* SPDX-License-Identifier: Unlicense OR CC0-1.0
*/
#pragma once
#include "mqtt_client.h"
typedef enum {NONE, TCP, SSL, WS, WSS} transport_t;
typedef struct {
esp_mqtt_client_handle_t mqtt_client;
void * data;
} command_context_t;
typedef struct {
transport_t selected_transport;
char *pattern;
int pattern_repetitions;
int qos;
char *expected;
size_t expected_size;
size_t nr_of_msg_received;
size_t nr_of_msg_expected;
char * received_data;
} publish_context_t ;
typedef struct {
struct arg_str *uri;
struct arg_int *test_case;
struct arg_end *end;
} connection_args_t;
typedef struct {
struct arg_int *expected_to_publish;
struct arg_int *qos;
struct arg_int *enqueue;
struct arg_end *end;
} publish_args_t;
typedef struct {
struct arg_str *transport;
struct arg_str *pattern;
struct arg_int *pattern_repetitions;
struct arg_end *end;
} publish_setup_args_t;
void publish_init_flags(void);
void publish_setup(command_context_t * ctx, char const * transport);
void publish_teardown(command_context_t * ctx);
void publish_test(command_context_t * ctx, int expect_to_publish, int qos, bool enqueue);
void connection_test(command_context_t * ctx, const char *uri, int test_case);
void connect_setup(command_context_t * ctx);
void connect_teardown(command_context_t * ctx);

View File

@ -6,33 +6,25 @@
software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied.
*/
#include <stddef.h>
#include <stdio.h>
#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include "esp_system.h"
#include "freertos/FreeRTOS.h"
#include "freertos/task.h"
#include "freertos/event_groups.h"
#include <freertos/event_groups.h>
#include "esp_system.h"
#include "esp_log.h"
#include "mqtt_client.h"
#include "sdkconfig.h"
#include "publish_connect_test.h"
static const char *TAG = "publish_test";
static EventGroupHandle_t mqtt_event_group;
const static int CONNECTED_BIT = BIT0;
static esp_mqtt_client_handle_t mqtt_client = NULL;
static char *expected_data = NULL;
static char *actual_data = NULL;
static size_t expected_size = 0;
static size_t expected_published = 0;
static size_t actual_published = 0;
static int qos_test = 0;
#if CONFIG_EXAMPLE_BROKER_CERTIFICATE_OVERRIDDEN == 1
static const uint8_t mqtt_eclipseprojects_io_pem_start[] = "-----BEGIN CERTIFICATE-----\n" CONFIG_EXAMPLE_BROKER_CERTIFICATE_OVERRIDE "\n-----END CERTIFICATE-----";
#else
@ -42,6 +34,7 @@ extern const uint8_t mqtt_eclipseprojects_io_pem_end[] asm("_binary_mqtt_eclip
static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_t event_id, void *event_data)
{
publish_context_t * test_data = handler_args;
esp_mqtt_event_handle_t event = event_data;
esp_mqtt_client_handle_t client = event->client;
static int msg_id = 0;
@ -52,7 +45,7 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_
case MQTT_EVENT_CONNECTED:
ESP_LOGI(TAG, "MQTT_EVENT_CONNECTED");
xEventGroupSetBits(mqtt_event_group, CONNECTED_BIT);
msg_id = esp_mqtt_client_subscribe(client, CONFIG_EXAMPLE_SUBSCRIBE_TOPIC, qos_test);
msg_id = esp_mqtt_client_subscribe(client, CONFIG_EXAMPLE_SUBSCRIBE_TOPIC, test_data->qos);
ESP_LOGI(TAG, "sent subscribe successful %s , msg_id=%d", CONFIG_EXAMPLE_SUBSCRIBE_TOPIC, msg_id);
break;
@ -67,13 +60,12 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_
ESP_LOGI(TAG, "MQTT_EVENT_UNSUBSCRIBED, msg_id=%d", event->msg_id);
break;
case MQTT_EVENT_PUBLISHED:
ESP_LOGI(TAG, "MQTT_EVENT_PUBLISHED, msg_id=%d", event->msg_id);
ESP_LOGD(TAG, "MQTT_EVENT_PUBLISHED, msg_id=%d", event->msg_id);
break;
case MQTT_EVENT_DATA:
ESP_LOGI(TAG, "MQTT_EVENT_DATA");
printf("TOPIC=%.*s\r\n", event->topic_len, event->topic);
printf("DATA=%.*s\r\n", event->data_len, event->data);
printf("ID=%d, total_len=%d, data_len=%d, current_data_offset=%d\n", event->msg_id, event->total_data_len, event->data_len, event->current_data_offset);
ESP_LOGI(TAG, "TOPIC=%.*s", event->topic_len, event->topic);
ESP_LOGI(TAG, "ID=%d, total_len=%d, data_len=%d, current_data_offset=%d", event->msg_id, event->total_data_len, event->data_len, event->current_data_offset);
if (event->topic) {
actual_len = event->data_len;
msg_id = event->msg_id;
@ -85,24 +77,23 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_
abort();
}
}
memcpy(actual_data + event->current_data_offset, event->data, event->data_len);
memcpy(test_data->received_data + event->current_data_offset, event->data, event->data_len);
if (actual_len == event->total_data_len) {
if (0 == memcmp(actual_data, expected_data, expected_size)) {
printf("OK!");
memset(actual_data, 0, expected_size);
actual_published ++;
if (actual_published == expected_published) {
printf("Correct pattern received exactly x times\n");
if (0 == memcmp(test_data->received_data, test_data->expected, test_data->expected_size)) {
memset(test_data->received_data, 0, test_data->expected_size);
test_data->nr_of_msg_received ++;
if (test_data->nr_of_msg_received == test_data->nr_of_msg_expected) {
ESP_LOGI(TAG, "Correct pattern received exactly x times");
ESP_LOGI(TAG, "Test finished correctly!");
}
} else {
printf("FAILED!");
ESP_LOGE(TAG, "FAILED!");
abort();
}
}
break;
case MQTT_EVENT_ERROR:
ESP_LOGI(TAG, "MQTT_EVENT_ERROR");
ESP_LOGE(TAG, "MQTT_EVENT_ERROR");
break;
default:
ESP_LOGI(TAG, "Other event id:%d", event->event_id);
@ -110,37 +101,31 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_
}
}
typedef enum {NONE, TCP, SSL, WS, WSS} transport_t;
static transport_t current_transport;
void test_init(void)
{
mqtt_event_group = xEventGroupCreate();
esp_mqtt_client_config_t config = {0};
mqtt_client = esp_mqtt_client_init(&config);
current_transport = NONE;
esp_mqtt_client_register_event(mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler, NULL);
ESP_LOGI(TAG, "[APP] Free memory: %d bytes", esp_get_free_heap_size());
}
void pattern_setup(char *pattern, int repeat)
void pattern_setup(publish_context_t * test_data)
{
int pattern_size = strlen(pattern);
free(expected_data);
free(actual_data);
actual_published = 0;
expected_size = pattern_size * repeat;
expected_data = malloc(expected_size);
actual_data = malloc(expected_size);
for (int i = 0; i < repeat; i++) {
memcpy(expected_data + i * pattern_size, pattern, pattern_size);
int pattern_size = strlen(test_data->pattern);
free(test_data->expected);
free(test_data->received_data);
test_data->nr_of_msg_received = 0;
test_data->expected_size = (size_t)(pattern_size) * test_data->pattern_repetitions;
test_data->expected = malloc(test_data->expected_size);
test_data->received_data = malloc(test_data->expected_size);
for (int i = 0; i < test_data->pattern_repetitions; i++) {
memcpy(test_data->expected + (ptrdiff_t)(i * pattern_size), test_data->pattern, pattern_size);
}
printf("EXPECTED STRING %.*s, SIZE:%d\n", expected_size, expected_data, expected_size);
ESP_LOGI(TAG, "EXPECTED STRING %.*s, SIZE:%d", test_data->expected_size, test_data->expected, test_data->expected_size);
}
static void configure_client(char *transport)
static void configure_client(command_context_t * ctx, const char *transport)
{
publish_context_t * test_data = ctx->data;
ESP_LOGI(TAG, "Configuration");
transport_t selected_transport;
if (0 == strcmp(transport, "tcp")) {
@ -157,7 +142,8 @@ static void configure_client(char *transport)
}
if (selected_transport != current_transport) {
if (selected_transport != test_data->selected_transport) {
test_data->selected_transport = selected_transport;
esp_mqtt_client_config_t config = {0};
switch (selected_transport) {
case NONE:
@ -183,43 +169,45 @@ static void configure_client(char *transport)
ESP_LOGI(TAG, "Set certificate");
config.broker.verification.certificate = (const char *)mqtt_eclipseprojects_io_pem_start;
}
esp_mqtt_set_config(mqtt_client, &config);
esp_mqtt_set_config(ctx->mqtt_client, &config);
}
}
void publish_test(const char *line)
{
char pattern[32];
char transport[32];
int repeat = 0;
int enqueue = 0;
static bool is_test_init = false;
if (!is_test_init) {
test_init();
is_test_init = true;
} else {
esp_mqtt_client_stop(mqtt_client);
}
void publish_init_flags(void) {
mqtt_event_group = xEventGroupCreate();
}
sscanf(line, "%s %s %d %d %d %d", transport, pattern, &repeat, &expected_published, &qos_test, &enqueue);
ESP_LOGI(TAG, "PATTERN:%s REPEATED:%d PUBLISHED:%d", pattern, repeat, expected_published);
pattern_setup(pattern, repeat);
void publish_setup(command_context_t * ctx, char const * const transport) {
xEventGroupClearBits(mqtt_event_group, CONNECTED_BIT);
configure_client(transport);
esp_mqtt_client_start(mqtt_client);
publish_context_t * data = (publish_context_t*)ctx->data;
pattern_setup(data);
configure_client(ctx, transport);
esp_mqtt_client_register_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler, data);
}
ESP_LOGI(TAG, "Note free memory: %d bytes", esp_get_free_heap_size());
void publish_teardown(command_context_t * ctx)
{
esp_mqtt_client_unregister_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler);
}
void publish_test(command_context_t * ctx, int expect_to_publish, int qos, bool enqueue)
{
publish_context_t * data = (publish_context_t*)ctx->data;
data->nr_of_msg_expected = expect_to_publish;
ESP_LOGI(TAG, "PATTERN:%s REPEATED:%d PUBLISHED:%d", data->pattern, data->pattern_repetitions, data->nr_of_msg_expected);
xEventGroupWaitBits(mqtt_event_group, CONNECTED_BIT, false, true, portMAX_DELAY);
for (int i = 0; i < expected_published; i++) {
for (int i = 0; i < data->nr_of_msg_expected; i++) {
int msg_id;
if (enqueue) {
msg_id = esp_mqtt_client_enqueue(mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, expected_data, expected_size, qos_test, 0, true);
msg_id = esp_mqtt_client_enqueue(ctx->mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, data->expected, data->expected_size, qos, 0, true);
} else {
msg_id = esp_mqtt_client_publish(mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, expected_data, expected_size, qos_test, 0);
msg_id = esp_mqtt_client_publish(ctx->mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, data->expected, data->expected_size, qos, 0);
if(msg_id < 0) {
ESP_LOGE(TAG, "Failed to publish");
break;
}
}
ESP_LOGI(TAG, "[%d] Publishing...", msg_id);
ESP_LOGD(TAG, "Publishing msg_id=%d", msg_id);
}
}

View File

@ -1,31 +1,21 @@
# SPDX-FileCopyrightText: 2022-2023 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
from __future__ import print_function, unicode_literals
import difflib
import contextlib
import logging
import os
import random
import re
import select
import socket
import socketserver
import ssl
import string
import subprocess
import sys
import time
import typing
from itertools import count
from threading import Event, Lock, Thread
from typing import Any
from threading import Thread
from typing import Any, Callable, Dict, Optional
import paho.mqtt.client as mqtt
import pytest
from common_test_methods import get_host_ip4_by_dest_ip
from pytest_embedded import Dut
from pytest_embedded_qemu.dut import QemuDut
DEFAULT_MSG_SIZE = 16
SERVER_PORT = 2222
def _path(f): # type: (str) -> str
@ -42,313 +32,98 @@ def set_server_cert_cn(ip): # type: (str) -> None
raise RuntimeError('openssl command {} failed'.format(args))
# Publisher class creating a python client to send/receive published data from esp-mqtt client
class MqttPublisher:
event_client_connected = Event()
event_client_got_all = Event()
expected_data = ''
published = 0
sample = ''
class MQTTHandler(socketserver.StreamRequestHandler):
def __init__(self, dut, transport,
qos, repeat, published, queue, publish_cfg, log_details=False): # type: (MqttPublisher, Dut, str, int, int, int, int, dict, bool) -> None
# instance variables used as parameters of the publish test
self.event_stop_client = Event()
self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
self.client = None
self.dut = dut
self.log_details = log_details
self.repeat = repeat
self.publish_cfg = publish_cfg
self.publish_cfg['qos'] = qos
self.publish_cfg['queue'] = queue
self.publish_cfg['transport'] = transport
self.lock = Lock()
# static variables used to pass options to and from static callbacks of paho-mqtt client
MqttPublisher.event_client_connected = Event()
MqttPublisher.event_client_got_all = Event()
MqttPublisher.published = published
MqttPublisher.event_client_connected.clear()
MqttPublisher.event_client_got_all.clear()
MqttPublisher.expected_data = f'{self.sample_string * self.repeat}'
MqttPublisher.sample = self.sample_string
def print_details(self, text): # type: (str) -> None
if self.log_details:
logging.info(text)
def mqtt_client_task(self, client, lock): # type: (MqttPublisher, mqtt.Client, Lock) -> None
while not self.event_stop_client.is_set():
with lock:
client.loop()
time.sleep(0.001) # yield to other threads
# The callback for when the client receives a CONNACK response from the server (needs to be static)
@staticmethod
def on_connect(_client, _userdata, _flags, _rc): # type: (mqtt.Client, tuple, bool, str) -> None
MqttPublisher.event_client_connected.set()
# The callback for when a PUBLISH message is received from the server (needs to be static)
@staticmethod
def on_message(client, userdata, msg): # type: (mqtt.Client, int, mqtt.client.MQTTMessage) -> None
payload = msg.payload.decode('utf-8')
if payload == MqttPublisher.expected_data:
userdata += 1
client.user_data_set(userdata)
if userdata == MqttPublisher.published:
MqttPublisher.event_client_got_all.set()
else:
differences = len(list(filter(lambda data: data[0] != data[1], zip(payload, MqttPublisher.expected_data))))
logging.error(f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:'
f'{len(MqttPublisher.expected_data)}')
logging.info(f'Repetitions: {payload.count(MqttPublisher.sample)}')
logging.info(f'Pattern: {MqttPublisher.sample}')
logging.info(f'First : {payload[:DEFAULT_MSG_SIZE]}')
logging.info(f'Last : {payload[-DEFAULT_MSG_SIZE:]}')
matcher = difflib.SequenceMatcher(a=payload, b=MqttPublisher.expected_data)
for match in matcher.get_matching_blocks():
logging.info(f'Match: {match}')
def __enter__(self): # type: (MqttPublisher) -> None
qos = self.publish_cfg['qos']
queue = self.publish_cfg['queue']
transport = self.publish_cfg['transport']
broker_host = self.publish_cfg['broker_host_' + transport]
broker_port = self.publish_cfg['broker_port_' + transport]
# Start the test
self.print_details(f'PUBLISH TEST: transport:{transport}, qos:{qos}, sequence:{MqttPublisher.published},'
f"enqueue:{queue}, sample msg:'{MqttPublisher.expected_data}'")
try:
if transport in ['ws', 'wss']:
self.client = mqtt.Client(transport='websockets')
def handle(self) -> None:
logging.info(' - connection from: {}'.format(self.client_address))
data = bytearray(self.request.recv(1024))
message = ''.join(format(x, '02x') for x in data)
if message[0:16] == '101800044d515454':
if self.server.refuse_connection is False: # type: ignore
logging.info(' - received mqtt connect, sending ACK')
self.request.send(bytearray.fromhex('20020000'))
else:
self.client = mqtt.Client()
assert self.client is not None
self.client.on_connect = MqttPublisher.on_connect
self.client.on_message = MqttPublisher.on_message
self.client.user_data_set(0)
if transport in ['ssl', 'wss']:
self.client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
self.client.tls_insecure_set(True)
self.print_details('Connecting...')
self.client.connect(broker_host, broker_port, 60)
except Exception:
self.print_details(f'ENV_TEST_FAILURE: Unexpected error while connecting to broker {broker_host}')
raise
# Starting a py-client in a separate thread
thread1 = Thread(target=self.mqtt_client_task, args=(self.client, self.lock))
thread1.start()
self.print_details('Connecting py-client to broker {}:{}...'.format(broker_host, broker_port))
if not MqttPublisher.event_client_connected.wait(timeout=30):
raise ValueError(f'ENV_TEST_FAILURE: Test script cannot connect to broker: {broker_host}')
with self.lock:
self.client.subscribe(self.publish_cfg['subscribe_topic'], qos)
self.dut.write(f'{transport} {self.sample_string} {self.repeat} {MqttPublisher.published} {qos} {queue}')
try:
# waiting till subscribed to defined topic
self.dut.expect(re.compile(rb'MQTT_EVENT_SUBSCRIBED'), timeout=60)
for _ in range(MqttPublisher.published):
with self.lock:
self.client.publish(self.publish_cfg['publish_topic'], self.sample_string * self.repeat, qos)
self.print_details('Publishing...')
self.print_details('Checking esp-client received msg published from py-client...')
self.dut.expect(re.compile(rb'Correct pattern received exactly x times'), timeout=60)
if not MqttPublisher.event_client_got_all.wait(timeout=60):
raise ValueError('Not all data received from ESP32: {}'.format(transport))
logging.info(' - all data received from ESP32')
finally:
self.event_stop_client.set()
thread1.join()
def __exit__(self, exc_type, exc_value, traceback): # type: (MqttPublisher, str, str, dict) -> None
assert self.client is not None
self.client.disconnect()
self.event_stop_client.clear()
# injecting connection not authorized error
logging.info(' - received mqtt connect, sending NAK')
self.request.send(bytearray.fromhex('20020005'))
else:
raise Exception(' - error process_mqtt_connect unexpected connect received: {}'.format(message))
# Simple server for mqtt over TLS connection
class TlsServer:
class TlsServer(socketserver.TCPServer):
timeout = 30.0
allow_reuse_address = True
allow_reuse_port = True
def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False): # type: (TlsServer, int, bool, bool, bool) -> None
self.port = port
self.socket = socket.socket()
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.settimeout(10.0)
self.shutdown = Event()
self.client_cert = client_cert
def __init__(self,
port:int = SERVER_PORT,
ServerHandler: Callable[[Any, Any, Any], socketserver.BaseRequestHandler] = MQTTHandler,
client_cert:bool=False,
refuse_connection:bool=False,
use_alpn:bool=False):
self.refuse_connection = refuse_connection
self.use_alpn = use_alpn
self.conn = socket.socket()
self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
self.ssl_error = ''
self.alpn_protocol: Optional[str] = None
if client_cert:
self.context.verify_mode = ssl.CERT_REQUIRED
self.context.load_verify_locations(cafile=_path('ca.crt'))
self.context.load_cert_chain(certfile=_path('srv.crt'), keyfile=_path('server.key'))
if use_alpn:
self.context.set_alpn_protocols(['mymqtt', 'http/1.1'])
self.server_thread = Thread(target=self.serve_forever)
super().__init__(('',port), ServerHandler)
def __enter__(self): # type: (TlsServer) -> TlsServer
try:
self.socket.bind(('', self.port))
except socket.error as e:
print('Bind failed:{}'.format(e))
raise
def server_activate(self) -> None:
self.socket = self.context.wrap_socket(self.socket, server_side=True)
super().server_activate()
self.socket.listen(1)
self.server_thread = Thread(target=self.run_server)
def __enter__(self): # type: ignore
self.server_thread.start()
return self
def __exit__(self, exc_type, exc_value, traceback): # type: (TlsServer, str, str, str) -> None
self.shutdown.set()
self.server_thread.join()
self.socket.close()
if (self.conn is not None):
self.conn.close()
def server_close(self) -> None:
try:
self.shutdown()
self.server_thread.join()
super().server_close()
except RuntimeError as e:
logging.exception(e)
def get_last_ssl_error(self): # type: (TlsServer) -> str
# We need to override it here to capture ssl.SSLError
# The implementation is a slightly modified version from cpython original code.
def _handle_request_noblock(self) -> None:
try:
request, client_address = self.get_request()
self.alpn_protocol = request.selected_alpn_protocol() # type: ignore
except ssl.SSLError as e:
self.ssl_error = e.reason
return
except OSError:
return
if self.verify_request(request, client_address):
try:
self.process_request(request, client_address)
except Exception:
self.handle_error(request, client_address)
self.shutdown_request(request)
except: # noqa: E722
self.shutdown_request(request)
raise
else:
self.shutdown_request(request)
def last_ssl_error(self): # type: (TlsServer) -> str
return self.ssl_error
@typing.no_type_check
def get_negotiated_protocol(self):
return self.negotiated_protocol
def run_server(self) -> None:
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
if self.client_cert:
context.verify_mode = ssl.CERT_REQUIRED
context.load_verify_locations(cafile=_path('ca.crt'))
context.load_cert_chain(certfile=_path('srv.crt'), keyfile=_path('server.key'))
if self.use_alpn:
context.set_alpn_protocols(['mymqtt', 'http/1.1'])
self.socket = context.wrap_socket(self.socket, server_side=True)
try:
self.conn, address = self.socket.accept() # accept new connection
self.socket.settimeout(10.0)
print(' - connection from: {}'.format(address))
if self.use_alpn:
self.negotiated_protocol = self.conn.selected_alpn_protocol()
print(' - negotiated_protocol: {}'.format(self.negotiated_protocol))
self.handle_conn()
except ssl.SSLError as e:
self.ssl_error = str(e)
print(' - SSLError: {}'.format(str(e)))
def handle_conn(self) -> None:
while not self.shutdown.is_set():
r,w,e = select.select([self.conn], [], [], 1)
try:
if self.conn in r:
self.process_mqtt_connect()
except socket.error as err:
print(' - error: {}'.format(err))
raise
def process_mqtt_connect(self) -> None:
try:
data = bytearray(self.conn.recv(1024))
message = ''.join(format(x, '02x') for x in data)
if message[0:16] == '101800044d515454':
if self.refuse_connection is False:
print(' - received mqtt connect, sending ACK')
self.conn.send(bytearray.fromhex('20020000'))
else:
# injecting connection not authorized error
print(' - received mqtt connect, sending NAK')
self.conn.send(bytearray.fromhex('20020005'))
else:
raise Exception(' - error process_mqtt_connect unexpected connect received: {}'.format(message))
finally:
# stop the server after the connect message in happy flow, or if any exception occur
self.shutdown.set()
def get_negotiated_protocol(self) -> Optional[str]:
return self.alpn_protocol
def connection_tests(dut, cases, dut_ip): # type: (Dut, dict, str) -> None
ip = get_host_ip4_by_dest_ip(dut_ip)
set_server_cert_cn(ip)
server_port = 2222
def teardown_connection_suite() -> None:
dut.write('conn teardown 0 0\n')
def start_connection_case(case, desc): # type: (str, str) -> Any
print('Starting {}: {}'.format(case, desc))
case_id = cases[case]
dut.write('conn {} {} {}\n'.format(ip, server_port, case_id))
dut.expect('Test case:{} started'.format(case_id))
return case_id
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
# All these cases connect to the server with no server verification or with server only verification
with TlsServer(server_port):
test_nr = start_connection_case(case, 'default server - expect to connect normally')
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
with TlsServer(server_port, refuse_connection=True):
test_nr = start_connection_case(case, 'ssl shall connect, but mqtt sends connect refusal')
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error
with TlsServer(server_port, client_cert=True) as s:
test_nr = start_connection_case(case, 'server with client verification - handshake error since client presents no client certificate')
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' not in s.get_last_ssl_error():
raise RuntimeError('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
# These cases connect to server with both server and client verification (client key might be password protected)
with TlsServer(server_port, client_cert=True):
test_nr = start_connection_case(case, 'server with client verification - expect to connect normally')
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
case = 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
with TlsServer(server_port) as s:
test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error')
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
if 'alert unknown ca' not in s.get_last_ssl_error():
raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
with TlsServer(server_port, client_cert=True) as s:
test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error')
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
if 'CERTIFICATE_VERIFY_FAILED' not in s.get_last_ssl_error():
raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
with TlsServer(server_port, use_alpn=True) as s:
test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol')
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
if case == 'EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None:
print(' - client with alpn off, no negotiated protocol: OK')
elif case == 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt':
print(' - client with alpn on, negotiated protocol resolved: OK')
else:
raise Exception('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol()))
teardown_connection_suite()
@pytest.mark.esp32
@pytest.mark.ethernet
def test_app_protocol_mqtt_publish_connect(dut: Dut) -> None:
"""
steps:
1. join AP
2. connect to uri specified in the config
3. send and receive data
"""
# check and log bin size
binary_file = os.path.join(dut.app.binary_path, 'mqtt_publish_connect_test.bin')
bin_size = os.path.getsize(binary_file)
logging.info('[Performance][mqtt_publish_connect_test_bin_size]: %s KB', bin_size // 1024)
# Look for test case symbolic names and publish configs
def get_test_cases(dut: Dut) -> Any:
cases = {}
publish_cfg = {}
try:
# Get connection test cases configuration: symbolic names for test cases
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_CERT',
@ -360,63 +135,107 @@ def test_app_protocol_mqtt_publish_connect(dut: Dut) -> None:
'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
cases[case] = dut.app.sdkconfig.get(case)
except Exception:
print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
logging.error('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
raise
return cases
esp_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30).group(1).decode()
print('Got IP={}'.format(esp_ip))
connection_tests(dut,cases,esp_ip)
def get_dut_ip(dut: Dut) -> Any:
dut_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30).group(1).decode()
logging.info('Got IP={}'.format(dut_ip))
return get_host_ip4_by_dest_ip(dut_ip)
# Get publish test configuration
@contextlib.contextmanager
def connect_dut(dut: Dut, uri:str, case_id:int) -> Any:
dut.write('connection_setup')
dut.write(f'connect {uri} {case_id}')
dut.expect(f'Test case:{case_id} started')
dut.write('reconnect')
yield
dut.write('connection_teardown')
dut.write('disconnect')
def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None:
try:
@typing.no_type_check
def get_host_port_from_dut(dut, config_option):
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get(config_option))
if value is None:
return None, None
return value.group(1), int(value.group(2))
dut.write('init')
dut.write(f'start')
dut.write(f'disconnect')
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
# All these cases connect to the server with no server verification or with server only verification
with TlsServer(), connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: default server - expect to connect normally')
dut.expect(f'MQTT_EVENT_CONNECTED: Test={cases[case]}', timeout=30)
with TlsServer(refuse_connection=True), connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: ssl shall connect, but mqtt sends connect refusal')
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error
with TlsServer(client_cert=True) as server, connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: server with client verification - handshake error since client presents no client certificate')
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
assert 'PEER_DID_NOT_RETURN_A_CERTIFICATE' in server.last_ssl_error()
publish_cfg['publish_topic'] = dut.app.sdkconfig.get('EXAMPLE_SUBSCRIBE_TOPIC').replace('"','')
publish_cfg['subscribe_topic'] = dut.app.sdkconfig.get('EXAMPLE_PUBLISH_TOPIC').replace('"','')
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI')
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI')
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WS_URI')
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI')
for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
# These cases connect to server with both server and client verification (client key might be password protected)
with TlsServer(client_cert=True), connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: server with client verification - expect to connect normally')
dut.expect(f'MQTT_EVENT_CONNECTED: Test={cases[case]}', timeout=30)
except Exception:
logging.error('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
raise
case = 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
with TlsServer() as s, connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: invalid server certificate on default server - expect ssl handshake error')
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
if re.match('.*alert.*unknown.*ca',s.last_ssl_error(), flags=re.I) is None:
raise Exception(f'Unexpected ssl error from the server: {s.last_ssl_error()}')
# Initialize message sizes and repeat counts (if defined in the environment)
messages = []
for i in count(0):
# Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x}
env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']):
messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore
continue
break
if not messages: # No message sizes present in the env - set defaults
messages = [{'len':0, 'repeat':5}, # zero-sized messages
{'len':2, 'repeat':10}, # short messages
{'len':200, 'repeat':3}, # long messages
{'len':20, 'repeat':50} # many medium sized
]
case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
with TlsServer(client_cert=True) as s, connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: Invalid client certificate on server with client verification - expect ssl handshake error')
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
if 'CERTIFICATE_VERIFY_FAILED' not in s.last_ssl_error():
raise Exception('Unexpected ssl error from the server {}'.format(s.last_ssl_error()))
# Iterate over all publish message properties
for transport in ['tcp', 'ssl', 'ws', 'wss']:
if publish_cfg['broker_host_' + transport] is None:
print('Skipping transport: {}...'.format(transport))
continue
for enqueue in [0, 1]:
for qos in [0, 1, 2]:
for msg in messages:
logging.info(f'Starting Publish test: transport:{transport}, qos:{qos}, nr_of_msgs:{msg["repeat"]},'
f'msg_size:{msg["len"] * DEFAULT_MSG_SIZE}, enqueue:{enqueue}')
with MqttPublisher(dut, transport, qos, msg['len'], msg['repeat'], enqueue, publish_cfg):
pass
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
with TlsServer(use_alpn=True) as s, connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: server with alpn - expect connect, check resolved protocol')
dut.expect(f'MQTT_EVENT_CONNECTED: Test={cases[case]}', timeout=30)
if case == 'EXAMPLE_CONNECT_CASE_NO_CERT':
assert s.get_negotiated_protocol() is None
elif case == 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN':
assert s.get_negotiated_protocol() == 'mymqtt'
else:
assert False, f'Unexpected negotiated protocol {s.get_negotiated_protocol()}'
finally:
dut.write('stop')
dut.write('destroy')
if __name__ == '__main__':
test_app_protocol_mqtt_publish_connect(dut=QemuDut if sys.argv[1:] == ['qemu'] else Dut)
@pytest.mark.esp32
@pytest.mark.ethernet
def test_mqtt_connect(
dut: Dut,
log_performance: Callable[[str, object], None],
) -> None:
"""
steps:
1. join AP
2. connect to uri specified in the config
3. send and receive data
"""
# check and log bin size
binary_file = os.path.join(dut.app.binary_path, 'mqtt_publish_connect_test.bin')
bin_size = os.path.getsize(binary_file)
log_performance('mqtt_publish_connect_test_bin_size', f'{bin_size // 1024} KB')
ip = get_dut_ip(dut)
set_server_cert_cn(ip)
uri = f'mqtts://{ip}:{SERVER_PORT}'
# Look for test case symbolic names and publish configs
cases = get_test_cases(dut)
dut.expect_exact('mqtt>', timeout=30)
run_cases(dut, uri, cases)

View File

@ -0,0 +1,229 @@
# SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import contextlib
import difflib
import logging
import os
import random
import re
import ssl
import string
from itertools import count, product
from threading import Event, Lock
from typing import Any, Dict, List, Tuple, no_type_check
import paho.mqtt.client as mqtt
import pexpect
import pytest
from pytest_embedded import Dut
DEFAULT_MSG_SIZE = 16
# Publisher class creating a python client to send/receive published data from esp-mqtt client
class MqttPublisher(mqtt.Client):
def __init__(self, repeat, published, publish_cfg, log_details=False): # type: (MqttPublisher, int, int, dict, bool) -> None
self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
self.log_details = log_details
self.repeat = repeat
self.publish_cfg = publish_cfg
self.expected_data = f'{self.sample_string * self.repeat}'
self.published = published
self.received = 0
self.lock = Lock()
self.event_client_connected = Event()
self.event_client_got_all = Event()
transport = 'websockets' if self.publish_cfg['transport'] in ['ws', 'wss'] else 'tcp'
super().__init__('MqttTestRunner', userdata=0, transport=transport)
def print_details(self, text): # type: (str) -> None
if self.log_details:
logging.info(text)
def on_connect(self, mqttc: Any, obj: Any, flags: Any, rc:int) -> None:
self.event_client_connected.set()
def on_connect_fail(self, mqttc: Any, obj: Any) -> None:
logging.error('Connect failed')
def on_message(self, mqttc: Any, userdata: Any, msg: mqtt.MQTTMessage) -> None:
payload = msg.payload.decode('utf-8')
if payload == self.expected_data:
userdata += 1
self.user_data_set(userdata)
self.received = userdata
if userdata == self.published:
self.event_client_got_all.set()
else:
differences = len(list(filter(lambda data: data[0] != data[1], zip(payload, self.expected_data))))
logging.error(f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:'
f'{len(self.expected_data)}')
logging.info(f'Repetitions: {payload.count(self.sample_string)}')
logging.info(f'Pattern: {self.sample_string}')
logging.info(f'First : {payload[:DEFAULT_MSG_SIZE]}')
logging.info(f'Last : {payload[-DEFAULT_MSG_SIZE:]}')
matcher = difflib.SequenceMatcher(a=payload, b=self.expected_data)
for match in matcher.get_matching_blocks():
logging.info(f'Match: {match}')
def __enter__(self) -> Any:
qos = self.publish_cfg['qos']
broker_host = self.publish_cfg['broker_host_' + self.publish_cfg['transport']]
broker_port = self.publish_cfg['broker_port_' + self.publish_cfg['transport']]
try:
self.print_details('Connecting...')
if self.publish_cfg['transport'] in ['ssl', 'wss']:
self.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
self.tls_insecure_set(True)
self.event_client_connected.clear()
self.loop_start()
self.connect(broker_host, broker_port, 60)
except Exception:
self.print_details(f'ENV_TEST_FAILURE: Unexpected error while connecting to broker {broker_host}')
raise
self.print_details(f'Connecting py-client to broker {broker_host}:{broker_port}...')
if not self.event_client_connected.wait(timeout=30):
raise ValueError(f'ENV_TEST_FAILURE: Test script cannot connect to broker: {broker_host}')
self.event_client_got_all.clear()
self.subscribe(self.publish_cfg['subscribe_topic'], qos)
return self
def __exit__(self, exc_type, exc_value, traceback): # type: (MqttPublisher, str, str, dict) -> None
self.disconnect()
self.loop_stop()
def get_configurations(dut: Dut) -> Dict[str,Any]:
publish_cfg = {}
try:
@no_type_check
def get_broker_from_dut(dut, config_option):
# logging.info('Option:', config_option, dut.app.sdkconfig.get(config_option))
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get(config_option))
if value is None:
return None, None
return value.group(1), int(value.group(2))
# Get publish test configuration
publish_cfg['publish_topic'] = dut.app.sdkconfig.get('EXAMPLE_SUBSCRIBE_TOPIC').replace('"','')
publish_cfg['subscribe_topic'] = dut.app.sdkconfig.get('EXAMPLE_PUBLISH_TOPIC').replace('"','')
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI')
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI')
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_WS_URI')
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI')
except Exception:
logging.info('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
raise
logging.info(f'configuration: {publish_cfg}')
return publish_cfg
@contextlib.contextmanager
def connected_and_subscribed(dut:Dut, transport:str, pattern:str, pattern_repetitions:int) -> Any:
dut.write(f'publish_setup {transport} {pattern} {pattern_repetitions}')
dut.write(f'start')
dut.expect(re.compile(rb'MQTT_EVENT_SUBSCRIBED'), timeout=60)
yield
dut.write(f'stop')
def get_scenarios() -> List[Dict[str, int]]:
scenarios = []
# Initialize message sizes and repeat counts (if defined in the environment)
for i in count(0):
# Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x}
env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']):
scenarios.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore
continue
break
if not scenarios: # No message sizes present in the env - set defaults
scenarios = [{'len':0, 'repeat':5}, # zero-sized messages
{'len':2, 'repeat':5}, # short messages
{'len':200, 'repeat':3}, # long messages
]
return scenarios
def get_timeout(test_case: Any) -> int:
transport, qos, enqueue, scenario = test_case
if transport in ['ws', 'wss'] or qos == 2:
return 90
return 60
def run_publish_test_case(dut: Dut, test_case: Any, publish_cfg: Any) -> None:
transport, qos, enqueue, scenario = test_case
if publish_cfg['broker_host_' + transport] is None:
pytest.skip(f'Skipping transport: {transport}...')
repeat = scenario['len']
published = scenario['repeat']
publish_cfg['qos'] = qos
publish_cfg['queue'] = enqueue
publish_cfg['transport'] = transport
test_timeout = get_timeout(test_case)
logging.info(f'Starting Publish test: transport:{transport}, qos:{qos}, nr_of_msgs:{published},'
f' msg_size:{repeat*DEFAULT_MSG_SIZE}, enqueue:{enqueue}')
with MqttPublisher(repeat, published, publish_cfg) as publisher, connected_and_subscribed(dut, transport, publisher.sample_string, scenario['len']):
msgs_published: List[mqtt.MQTTMessageInfo] = []
dut.write(f'publish {publisher.published} {qos} {enqueue}')
assert publisher.event_client_got_all.wait(timeout=test_timeout), (f'Not all data received from ESP32: {transport} '
f'qos={qos} received: {publisher.received} '
f'expected: {publisher.published}')
logging.info(' - all data received from ESP32')
payload = publisher.sample_string * publisher.repeat
for _ in range(publisher.published):
with publisher.lock:
msg = publisher.publish(topic=publisher.publish_cfg['publish_topic'], payload=payload, qos=qos)
if qos > 0:
msgs_published.append(msg)
logging.info(f'Published: {len(msgs_published)}')
while msgs_published:
msgs_published = [msg for msg in msgs_published if msg.is_published()]
try:
dut.expect(re.compile(rb'Correct pattern received exactly x times'), timeout=test_timeout)
except pexpect.exceptions.ExceptionPexpect:
dut.write(f'publish_report')
dut.expect(re.compile(rb'Test Report'), timeout=30)
raise
logging.info('ESP32 received all data from runner')
stress_scenarios = [{'len':20, 'repeat':50}] # many medium sized
transport_cases = ['tcp', 'ws', 'wss', 'ssl']
qos_cases = [0, 1, 2]
enqueue_cases = [0, 1]
def make_cases(scenarios: List[Dict[str, int]]) -> List[Tuple[str, int, int, Dict[str, int]]]:
return [test_case for test_case in product(transport_cases, qos_cases, enqueue_cases, scenarios)]
test_cases = make_cases(get_scenarios())
stress_test_cases = make_cases(stress_scenarios)
@pytest.mark.esp32
@pytest.mark.ethernet
@pytest.mark.parametrize('test_case', test_cases)
def test_mqtt_publish(dut: Dut, test_case: Any) -> None:
publish_cfg = get_configurations(dut)
dut.expect(re.compile(rb'mqtt>'), timeout=30)
dut.confirm_write('init', expect_pattern='init', timeout=30)
run_publish_test_case(dut, test_case, publish_cfg)
@pytest.mark.esp32
@pytest.mark.ethernet
@pytest.mark.nightly_run
@pytest.mark.parametrize('test_case', stress_test_cases)
def test_mqtt_publish_stress(dut: Dut, test_case: Any) -> None:
publish_cfg = get_configurations(dut)
dut.expect(re.compile(rb'mqtt>'), timeout=30)
dut.write('init')
run_publish_test_case(dut, test_case, publish_cfg)