mirror of
https://github.com/espressif/esp-idf.git
synced 2024-10-05 20:47:46 -04:00
Merge branch 'ci/publish_connect_refactor' into 'master'
Publish connect test refactor See merge request espressif/esp-idf!25311
This commit is contained in:
commit
ade6384954
@ -8,11 +8,11 @@
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include "esp_netif.h"
|
||||
#include "esp_console.h"
|
||||
|
||||
#include "esp_log.h"
|
||||
#include "mqtt_client.h"
|
||||
#include "esp_tls.h"
|
||||
#include "publish_connect_test.h"
|
||||
|
||||
#if (!defined(CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT)) || \
|
||||
(!defined(CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT)) || \
|
||||
@ -34,17 +34,23 @@ extern const uint8_t client_inv_crt[] asm("_binary_client_inv_crt_start");
|
||||
extern const uint8_t client_no_pwd_key[] asm("_binary_client_no_pwd_key_start");
|
||||
|
||||
static const char *TAG = "connect_test";
|
||||
static esp_mqtt_client_handle_t mqtt_client = NULL;
|
||||
static int running_test_case = 0;
|
||||
|
||||
static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_t event_id, void *event_data)
|
||||
{
|
||||
(void)handler_args;
|
||||
(void)base;
|
||||
(void)event_id;
|
||||
esp_mqtt_event_handle_t event = event_data;
|
||||
ESP_LOGD(TAG, "Event: %d, Test case: %d", event->event_id, running_test_case);
|
||||
switch (event->event_id) {
|
||||
case MQTT_EVENT_BEFORE_CONNECT:
|
||||
break;
|
||||
case MQTT_EVENT_CONNECTED:
|
||||
ESP_LOGI(TAG, "MQTT_EVENT_CONNECTED: Test=%d", running_test_case);
|
||||
break;
|
||||
case MQTT_EVENT_DISCONNECTED:
|
||||
break;
|
||||
case MQTT_EVENT_ERROR:
|
||||
ESP_LOGI(TAG, "MQTT_EVENT_ERROR: Test=%d", running_test_case);
|
||||
if (event->error_handle->error_type == MQTT_ERROR_TYPE_ESP_TLS) {
|
||||
@ -61,44 +67,17 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_
|
||||
}
|
||||
}
|
||||
|
||||
static void create_client(void)
|
||||
static void connect_no_certs(esp_mqtt_client_handle_t client, const char *uri)
|
||||
{
|
||||
const esp_mqtt_client_config_t mqtt_cfg = {
|
||||
.broker.address.uri = "mqtts://127.0.0.1:1234"
|
||||
};
|
||||
esp_mqtt_client_handle_t client = esp_mqtt_client_init(&mqtt_cfg);
|
||||
esp_mqtt_client_register_event(client, ESP_EVENT_ANY_ID, mqtt_event_handler, client);
|
||||
mqtt_client = client;
|
||||
esp_mqtt_client_start(client);
|
||||
ESP_LOGI(TAG, "mqtt client created for connection tests");
|
||||
}
|
||||
|
||||
static void destroy_client(void)
|
||||
{
|
||||
if (mqtt_client) {
|
||||
esp_mqtt_client_stop(mqtt_client);
|
||||
esp_mqtt_client_destroy(mqtt_client);
|
||||
mqtt_client = NULL;
|
||||
ESP_LOGI(TAG, "mqtt client for connection tests destroyed");
|
||||
}
|
||||
}
|
||||
|
||||
static void connect_no_certs(const char *host, const int port)
|
||||
{
|
||||
char uri[64];
|
||||
sprintf(uri, "mqtts://%s:%d", host, port);
|
||||
ESP_LOGI(TAG, "Runnning :CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT");
|
||||
const esp_mqtt_client_config_t mqtt_cfg = {
|
||||
.broker.address.uri = uri
|
||||
};
|
||||
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
|
||||
esp_mqtt_client_disconnect(mqtt_client);
|
||||
esp_mqtt_client_reconnect(mqtt_client);
|
||||
esp_mqtt_set_config(client, &mqtt_cfg);
|
||||
}
|
||||
|
||||
static void connect_with_client_key_password(const char *host, const int port)
|
||||
static void connect_with_client_key_password(esp_mqtt_client_handle_t client, const char *uri)
|
||||
{
|
||||
char uri[64];
|
||||
sprintf(uri, "mqtts://%s:%d", host, port);
|
||||
const esp_mqtt_client_config_t mqtt_cfg = {
|
||||
.broker.address.uri = uri,
|
||||
.broker.verification.certificate = (const char *)ca_local_crt,
|
||||
@ -107,15 +86,11 @@ static void connect_with_client_key_password(const char *host, const int port)
|
||||
.credentials.authentication.key_password = "esp32",
|
||||
.credentials.authentication.key_password_len = 5
|
||||
};
|
||||
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
|
||||
esp_mqtt_client_disconnect(mqtt_client);
|
||||
esp_mqtt_client_reconnect(mqtt_client);
|
||||
esp_mqtt_set_config(client, &mqtt_cfg);
|
||||
}
|
||||
|
||||
static void connect_with_server_der_cert(const char *host, const int port)
|
||||
static void connect_with_server_der_cert(esp_mqtt_client_handle_t client, const char *uri)
|
||||
{
|
||||
char uri[64];
|
||||
sprintf(uri, "mqtts://%s:%d", host, port);
|
||||
const esp_mqtt_client_config_t mqtt_cfg = {
|
||||
.broker.address.uri = uri,
|
||||
.broker.verification.certificate = (const char *)ca_der_start,
|
||||
@ -123,123 +98,96 @@ static void connect_with_server_der_cert(const char *host, const int port)
|
||||
.credentials.authentication.certificate = "NULL",
|
||||
.credentials.authentication.key = "NULL"
|
||||
};
|
||||
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
|
||||
esp_mqtt_client_disconnect(mqtt_client);
|
||||
esp_mqtt_client_reconnect(mqtt_client);
|
||||
esp_mqtt_set_config(client, &mqtt_cfg);
|
||||
}
|
||||
|
||||
static void connect_with_wrong_server_cert(const char *host, const int port)
|
||||
static void connect_with_wrong_server_cert(esp_mqtt_client_handle_t client, const char *uri)
|
||||
{
|
||||
char uri[64];
|
||||
sprintf(uri, "mqtts://%s:%d", host, port);
|
||||
const esp_mqtt_client_config_t mqtt_cfg = {
|
||||
.broker.address.uri = uri,
|
||||
.broker.verification.certificate = (const char *)client_pwd_crt,
|
||||
.credentials.authentication.certificate = "NULL",
|
||||
.credentials.authentication.key = "NULL"
|
||||
};
|
||||
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
|
||||
esp_mqtt_client_disconnect(mqtt_client);
|
||||
esp_mqtt_client_reconnect(mqtt_client);
|
||||
esp_mqtt_set_config(client, &mqtt_cfg);
|
||||
}
|
||||
|
||||
static void connect_with_server_cert(const char *host, const int port)
|
||||
static void connect_with_server_cert(esp_mqtt_client_handle_t client, const char *uri)
|
||||
{
|
||||
char uri[64];
|
||||
sprintf(uri, "mqtts://%s:%d", host, port);
|
||||
const esp_mqtt_client_config_t mqtt_cfg = {
|
||||
.broker.address.uri = uri,
|
||||
.broker.verification.certificate = (const char *)ca_local_crt,
|
||||
};
|
||||
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
|
||||
esp_mqtt_client_disconnect(mqtt_client);
|
||||
esp_mqtt_client_reconnect(mqtt_client);
|
||||
esp_mqtt_set_config(client, &mqtt_cfg);
|
||||
}
|
||||
|
||||
static void connect_with_server_client_certs(const char *host, const int port)
|
||||
static void connect_with_server_client_certs(esp_mqtt_client_handle_t client, const char *uri)
|
||||
{
|
||||
char uri[64];
|
||||
sprintf(uri, "mqtts://%s:%d", host, port);
|
||||
const esp_mqtt_client_config_t mqtt_cfg = {
|
||||
.broker.address.uri = uri,
|
||||
.broker.verification.certificate = (const char *)ca_local_crt,
|
||||
.credentials.authentication.certificate = (const char *)client_pwd_crt,
|
||||
.credentials.authentication.key = (const char *)client_no_pwd_key
|
||||
};
|
||||
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
|
||||
esp_mqtt_client_disconnect(mqtt_client);
|
||||
esp_mqtt_client_reconnect(mqtt_client);
|
||||
esp_mqtt_set_config(client, &mqtt_cfg);
|
||||
}
|
||||
|
||||
static void connect_with_invalid_client_certs(const char *host, const int port)
|
||||
static void connect_with_invalid_client_certs(esp_mqtt_client_handle_t client, const char *uri)
|
||||
{
|
||||
char uri[64];
|
||||
sprintf(uri, "mqtts://%s:%d", host, port);
|
||||
const esp_mqtt_client_config_t mqtt_cfg = {
|
||||
.broker.address.uri = uri,
|
||||
.broker.verification.certificate = (const char *)ca_local_crt,
|
||||
.credentials.authentication.certificate = (const char *)client_inv_crt,
|
||||
.credentials.authentication.key = (const char *)client_no_pwd_key
|
||||
};
|
||||
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
|
||||
esp_mqtt_client_disconnect(mqtt_client);
|
||||
esp_mqtt_client_reconnect(mqtt_client);
|
||||
esp_mqtt_set_config(client, &mqtt_cfg);
|
||||
}
|
||||
|
||||
static void connect_with_alpn(const char *host, const int port)
|
||||
static void connect_with_alpn(esp_mqtt_client_handle_t client, const char *uri)
|
||||
{
|
||||
char uri[64];
|
||||
const char *alpn_protos[] = { "mymqtt", NULL };
|
||||
sprintf(uri, "mqtts://%s:%d", host, port);
|
||||
const esp_mqtt_client_config_t mqtt_cfg = {
|
||||
.broker.address.uri = uri,
|
||||
.broker.verification.alpn_protos = alpn_protos
|
||||
};
|
||||
esp_mqtt_set_config(mqtt_client, &mqtt_cfg);
|
||||
esp_mqtt_client_disconnect(mqtt_client);
|
||||
esp_mqtt_client_reconnect(mqtt_client);
|
||||
esp_mqtt_set_config(client, &mqtt_cfg);
|
||||
}
|
||||
|
||||
void connection_test(const char *line)
|
||||
{
|
||||
char test_type[32];
|
||||
char host[32];
|
||||
int port;
|
||||
int test_case;
|
||||
void connect_setup(command_context_t * ctx) {
|
||||
esp_mqtt_client_register_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler, ctx->data);
|
||||
}
|
||||
|
||||
sscanf(line, "%s %s %d %d", test_type, host, &port, &test_case);
|
||||
if (mqtt_client == NULL) {
|
||||
create_client();
|
||||
}
|
||||
if (strcmp(host, "teardown") == 0) {
|
||||
destroy_client();;
|
||||
}
|
||||
ESP_LOGI(TAG, "CASE:%d, connecting to mqtts://%s:%d ", test_case, host, port);
|
||||
void connect_teardown(command_context_t * ctx) {
|
||||
esp_mqtt_client_unregister_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler);
|
||||
}
|
||||
void connection_test(command_context_t * ctx, const char *uri, int test_case)
|
||||
{
|
||||
ESP_LOGI(TAG, "CASE:%d, connecting to %s", test_case, uri);
|
||||
running_test_case = test_case;
|
||||
switch (test_case) {
|
||||
case CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT:
|
||||
connect_no_certs(host, port);
|
||||
connect_no_certs(ctx->mqtt_client, uri);
|
||||
break;
|
||||
case CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT:
|
||||
connect_with_server_cert(host, port);
|
||||
connect_with_server_cert(ctx->mqtt_client, uri);
|
||||
break;
|
||||
case CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH:
|
||||
connect_with_server_client_certs(host, port);
|
||||
connect_with_server_client_certs(ctx->mqtt_client, uri);
|
||||
break;
|
||||
case CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT:
|
||||
connect_with_wrong_server_cert(host, port);
|
||||
connect_with_wrong_server_cert(ctx->mqtt_client, uri);
|
||||
break;
|
||||
case CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT:
|
||||
connect_with_server_der_cert(host, port);
|
||||
connect_with_server_der_cert(ctx->mqtt_client, uri);
|
||||
break;
|
||||
case CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD:
|
||||
connect_with_client_key_password(host, port);
|
||||
connect_with_client_key_password(ctx->mqtt_client, uri);
|
||||
break;
|
||||
case CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT:
|
||||
connect_with_invalid_client_certs(host, port);
|
||||
connect_with_invalid_client_certs(ctx->mqtt_client, uri);
|
||||
break;
|
||||
case CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN:
|
||||
connect_with_alpn(host, port);
|
||||
connect_with_alpn(ctx->mqtt_client, uri);
|
||||
break;
|
||||
default:
|
||||
ESP_LOGE(TAG, "Unknown test case %d ", test_case);
|
||||
|
@ -8,68 +8,310 @@
|
||||
*/
|
||||
#include <stdio.h>
|
||||
#include <stddef.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "esp_system.h"
|
||||
#include "mqtt_client.h"
|
||||
#include "nvs_flash.h"
|
||||
#include "esp_event.h"
|
||||
#include "esp_netif.h"
|
||||
#include "protocol_examples_common.h"
|
||||
#include "esp_console.h"
|
||||
#include "argtable3/argtable3.h"
|
||||
#include "esp_log.h"
|
||||
#include "publish_connect_test.h"
|
||||
|
||||
static const char *TAG = "publish_connect_test";
|
||||
|
||||
void connection_test(const char *line);
|
||||
void publish_test(const char *line);
|
||||
command_context_t command_context;
|
||||
connection_args_t connection_args;
|
||||
publish_setup_args_t publish_setup_args;
|
||||
publish_args_t publish_args;
|
||||
|
||||
static void get_string(char *line, size_t size)
|
||||
{
|
||||
int count = 0;
|
||||
while (count < size) {
|
||||
int c = fgetc(stdin);
|
||||
if (c == '\n') {
|
||||
line[count] = '\0';
|
||||
break;
|
||||
} else if (c > 0 && c < 127) {
|
||||
line[count] = c;
|
||||
++count;
|
||||
}
|
||||
vTaskDelay(10 / portTICK_PERIOD_MS);
|
||||
#define RETURN_ON_PARSE_ERROR(args) do { \
|
||||
int nerrors = arg_parse(argc, argv, (void **) &(args)); \
|
||||
if (nerrors != 0) { \
|
||||
arg_print_errors(stderr, (args).end, argv[0]); \
|
||||
return 1; \
|
||||
}} while(0)
|
||||
|
||||
|
||||
static int do_free_heap(int argc, char **argv) {
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
ESP_LOGI(TAG, "Note free memory: %d bytes", esp_get_free_heap_size());
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_init(int argc, char **argv) {
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
const esp_mqtt_client_config_t mqtt_cfg = {
|
||||
.broker.address.uri = "mqtts://127.0.0.1:1234",
|
||||
.network.disable_auto_reconnect = true
|
||||
};
|
||||
command_context.mqtt_client = esp_mqtt_client_init(&mqtt_cfg);
|
||||
if(!command_context.mqtt_client) {
|
||||
ESP_LOGE(TAG, "Failed to initialize client");
|
||||
return 1;
|
||||
}
|
||||
publish_init_flags();
|
||||
ESP_LOGI(TAG, "Mqtt client initialized");
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_start(int argc, char **argv) {
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
if(esp_mqtt_client_start(command_context.mqtt_client) != ESP_OK) {
|
||||
ESP_LOGE(TAG, "Failed to start mqtt client task");
|
||||
return 1;
|
||||
}
|
||||
ESP_LOGI(TAG, "Mqtt client started");
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_stop(int argc, char **argv) {
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
if(esp_mqtt_client_stop(command_context.mqtt_client) != ESP_OK) {
|
||||
ESP_LOGE(TAG, "Failed to stop mqtt client task");
|
||||
return 1;
|
||||
}
|
||||
ESP_LOGI(TAG, "Mqtt client stoped");
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_disconnect(int argc, char **argv) {
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
if(esp_mqtt_client_disconnect(command_context.mqtt_client) != ESP_OK) {
|
||||
ESP_LOGE(TAG, "Failed to request disconnection");
|
||||
return 1;
|
||||
}
|
||||
ESP_LOGI(TAG, "Mqtt client disconnected");
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_connect_setup(int argc, char **argv) {
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
connect_setup(&command_context);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_connect_teardown(int argc, char **argv) {
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
connect_teardown(&command_context);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_reconnect(int argc, char **argv) {
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
if(esp_mqtt_client_reconnect(command_context.mqtt_client) != ESP_OK) {
|
||||
ESP_LOGE(TAG, "Failed to request reconnection");
|
||||
return 1;
|
||||
}
|
||||
ESP_LOGI(TAG, "Mqtt client will reconnect");
|
||||
return 0;
|
||||
;
|
||||
}
|
||||
|
||||
static int do_destroy(int argc, char **argv) {
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
esp_mqtt_client_destroy(command_context.mqtt_client);
|
||||
command_context.mqtt_client = NULL;
|
||||
ESP_LOGI(TAG, "mqtt client for tests destroyed");
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_connect(int argc, char **argv)
|
||||
{
|
||||
int nerrors = arg_parse(argc, argv, (void **) &connection_args);
|
||||
if (nerrors != 0) {
|
||||
arg_print_errors(stderr, connection_args.end, argv[0]);
|
||||
return 1;
|
||||
}
|
||||
if(!command_context.mqtt_client) {
|
||||
ESP_LOGE(TAG, "MQTT client not initialized, call init first");
|
||||
return 1;
|
||||
}
|
||||
connection_test(&command_context, *connection_args.uri->sval, *connection_args.test_case->ival);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_publish_setup(int argc, char **argv) {
|
||||
RETURN_ON_PARSE_ERROR(publish_setup_args);
|
||||
if(command_context.data) {
|
||||
free(command_context.data);
|
||||
}
|
||||
command_context.data = calloc(1, sizeof(publish_context_t));
|
||||
((publish_context_t*)command_context.data)->pattern = strdup(*publish_setup_args.pattern->sval);
|
||||
((publish_context_t*)command_context.data)->pattern_repetitions = *publish_setup_args.pattern_repetitions->ival;
|
||||
publish_setup(&command_context, *publish_setup_args.transport->sval);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_publish(int argc, char **argv) {
|
||||
RETURN_ON_PARSE_ERROR(publish_args);
|
||||
publish_test(&command_context, publish_args.expected_to_publish->ival[0], publish_args.qos->ival[0], publish_args.enqueue->ival[0]);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int do_publish_report(int argc, char **argv) {
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
publish_context_t * ctx = command_context.data;
|
||||
ESP_LOGI(TAG,"Test Report : Messages received %d, %d expected", ctx->nr_of_msg_received, ctx->nr_of_msg_expected);
|
||||
|
||||
return 0;
|
||||
}
|
||||
void register_common_commands(void) {
|
||||
const esp_console_cmd_t init = {
|
||||
.command = "init",
|
||||
.help = "Run inition test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_init,
|
||||
};
|
||||
|
||||
const esp_console_cmd_t start = {
|
||||
.command = "start",
|
||||
.help = "Run startion test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_start,
|
||||
};
|
||||
const esp_console_cmd_t stop = {
|
||||
.command = "stop",
|
||||
.help = "Run stopion test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_stop,
|
||||
};
|
||||
const esp_console_cmd_t destroy = {
|
||||
.command = "destroy",
|
||||
.help = "Run destroyion test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_destroy,
|
||||
};
|
||||
const esp_console_cmd_t free_heap = {
|
||||
.command = "free_heap",
|
||||
.help = "Run destroyion test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_free_heap,
|
||||
};
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&init));
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&start));
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&stop));
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&destroy));
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&free_heap));
|
||||
}
|
||||
void register_publish_commands(void) {
|
||||
publish_setup_args.transport = arg_str1(NULL,NULL,"<transport>", "Selected transport to test");
|
||||
publish_setup_args.pattern = arg_str1(NULL,NULL,"<pattern>", "Message pattern repeated to build big messages");
|
||||
publish_setup_args.pattern_repetitions = arg_int1(NULL,NULL,"<pattern repetitions>", "How many times the pattern is repeated");
|
||||
publish_setup_args.end = arg_end(1);
|
||||
|
||||
publish_args.expected_to_publish = arg_int1(NULL,NULL,"<number of messages>", "How many times the pattern is repeated");
|
||||
publish_args.qos = arg_int1(NULL,NULL,"<qos>", "How many times the pattern is repeated");
|
||||
publish_args.enqueue = arg_int1(NULL,NULL,"<enqueue>", "How many times the pattern is repeated");
|
||||
publish_args.end = arg_end(1);
|
||||
const esp_console_cmd_t publish_setup = {
|
||||
.command = "publish_setup",
|
||||
.help = "Run publish test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_publish_setup,
|
||||
.argtable = &publish_setup_args
|
||||
};
|
||||
|
||||
const esp_console_cmd_t publish = {
|
||||
.command = "publish",
|
||||
.help = "Run publish test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_publish,
|
||||
.argtable = &publish_args
|
||||
};
|
||||
const esp_console_cmd_t publish_report = {
|
||||
.command = "publish_report",
|
||||
.help = "Run destroyion test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_publish_report,
|
||||
};
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&publish_setup));
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&publish));
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&publish_report));
|
||||
}
|
||||
void register_connect_commands(void){
|
||||
connection_args.uri = arg_str1(NULL,NULL,"<broker uri>", "Broker address");
|
||||
connection_args.test_case = arg_int1(NULL, NULL, "<test case>","Selected test case");
|
||||
connection_args.end = arg_end(1);
|
||||
|
||||
const esp_console_cmd_t connect = {
|
||||
.command = "connect",
|
||||
.help = "Run connection test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_connect,
|
||||
.argtable = &connection_args
|
||||
};
|
||||
|
||||
const esp_console_cmd_t reconnect = {
|
||||
.command = "reconnect",
|
||||
.help = "Run reconnection test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_reconnect,
|
||||
};
|
||||
const esp_console_cmd_t connection_setup = {
|
||||
.command = "connection_setup",
|
||||
.help = "Run reconnection test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_connect_setup,
|
||||
};
|
||||
const esp_console_cmd_t connection_teardown = {
|
||||
.command = "connection_teardown",
|
||||
.help = "Run reconnection test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_connect_teardown,
|
||||
};
|
||||
const esp_console_cmd_t disconnect = {
|
||||
.command = "disconnect",
|
||||
.help = "Run disconnection test\n",
|
||||
.hint = NULL,
|
||||
.func = &do_disconnect,
|
||||
};
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&connect));
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&disconnect));
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&reconnect));
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&connection_setup));
|
||||
ESP_ERROR_CHECK(esp_console_cmd_register(&connection_teardown));
|
||||
}
|
||||
|
||||
void app_main(void)
|
||||
{
|
||||
char line[256];
|
||||
static const size_t max_line = 256;
|
||||
|
||||
ESP_LOGI(TAG, "[APP] Free memory: %d bytes", esp_get_free_heap_size());
|
||||
ESP_LOGI(TAG, "[APP] IDF version: %s", esp_get_idf_version());
|
||||
|
||||
esp_log_level_set("*", ESP_LOG_INFO);
|
||||
esp_log_level_set("wifi", ESP_LOG_ERROR);
|
||||
esp_log_level_set("mqtt_client", ESP_LOG_VERBOSE);
|
||||
esp_log_level_set("transport_base", ESP_LOG_VERBOSE);
|
||||
esp_log_level_set("transport", ESP_LOG_VERBOSE);
|
||||
esp_log_level_set("outbox", ESP_LOG_VERBOSE);
|
||||
|
||||
ESP_ERROR_CHECK(nvs_flash_init());
|
||||
ESP_ERROR_CHECK(esp_netif_init());
|
||||
ESP_ERROR_CHECK(esp_event_loop_create_default());
|
||||
|
||||
/* This helper function configures Wi-Fi or Ethernet, as selected in menuconfig.
|
||||
* Read "Establishing Wi-Fi or Ethernet Connection" section in
|
||||
* examples/protocols/README.md for more information about this function.
|
||||
*/
|
||||
ESP_ERROR_CHECK(example_connect());
|
||||
esp_console_repl_t *repl = NULL;
|
||||
esp_console_repl_config_t repl_config = ESP_CONSOLE_REPL_CONFIG_DEFAULT();
|
||||
repl_config.prompt = "mqtt>";
|
||||
repl_config.max_cmdline_length = max_line;
|
||||
esp_console_register_help_command();
|
||||
register_common_commands();
|
||||
register_connect_commands();
|
||||
register_publish_commands();
|
||||
|
||||
while (1) {
|
||||
get_string(line, sizeof(line));
|
||||
if (memcmp(line, "conn", 4) == 0) {
|
||||
// line starting with "conn" indicate connection tests
|
||||
connection_test(line);
|
||||
get_string(line, sizeof(line));
|
||||
continue;
|
||||
} else {
|
||||
publish_test(line);
|
||||
}
|
||||
}
|
||||
|
||||
esp_console_dev_uart_config_t hw_config = ESP_CONSOLE_DEV_UART_CONFIG_DEFAULT();
|
||||
ESP_ERROR_CHECK(esp_console_new_repl_uart(&hw_config, &repl_config, &repl));
|
||||
ESP_ERROR_CHECK(esp_console_start_repl(repl));
|
||||
}
|
||||
|
@ -0,0 +1,55 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD
|
||||
*
|
||||
* SPDX-License-Identifier: Unlicense OR CC0-1.0
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "mqtt_client.h"
|
||||
|
||||
typedef enum {NONE, TCP, SSL, WS, WSS} transport_t;
|
||||
|
||||
typedef struct {
|
||||
esp_mqtt_client_handle_t mqtt_client;
|
||||
void * data;
|
||||
} command_context_t;
|
||||
|
||||
typedef struct {
|
||||
transport_t selected_transport;
|
||||
char *pattern;
|
||||
int pattern_repetitions;
|
||||
int qos;
|
||||
char *expected;
|
||||
size_t expected_size;
|
||||
size_t nr_of_msg_received;
|
||||
size_t nr_of_msg_expected;
|
||||
char * received_data;
|
||||
} publish_context_t ;
|
||||
|
||||
typedef struct {
|
||||
struct arg_str *uri;
|
||||
struct arg_int *test_case;
|
||||
struct arg_end *end;
|
||||
} connection_args_t;
|
||||
|
||||
typedef struct {
|
||||
struct arg_int *expected_to_publish;
|
||||
struct arg_int *qos;
|
||||
struct arg_int *enqueue;
|
||||
struct arg_end *end;
|
||||
} publish_args_t;
|
||||
|
||||
typedef struct {
|
||||
struct arg_str *transport;
|
||||
struct arg_str *pattern;
|
||||
struct arg_int *pattern_repetitions;
|
||||
struct arg_end *end;
|
||||
} publish_setup_args_t;
|
||||
|
||||
void publish_init_flags(void);
|
||||
void publish_setup(command_context_t * ctx, char const * transport);
|
||||
void publish_teardown(command_context_t * ctx);
|
||||
void publish_test(command_context_t * ctx, int expect_to_publish, int qos, bool enqueue);
|
||||
void connection_test(command_context_t * ctx, const char *uri, int test_case);
|
||||
void connect_setup(command_context_t * ctx);
|
||||
void connect_teardown(command_context_t * ctx);
|
@ -6,33 +6,25 @@
|
||||
software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
CONDITIONS OF ANY KIND, either express or implied.
|
||||
*/
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <string.h>
|
||||
#include "esp_system.h"
|
||||
#include "freertos/FreeRTOS.h"
|
||||
#include "freertos/task.h"
|
||||
#include "freertos/event_groups.h"
|
||||
#include <freertos/event_groups.h>
|
||||
#include "esp_system.h"
|
||||
|
||||
#include "esp_log.h"
|
||||
#include "mqtt_client.h"
|
||||
#include "sdkconfig.h"
|
||||
#include "publish_connect_test.h"
|
||||
|
||||
static const char *TAG = "publish_test";
|
||||
|
||||
static EventGroupHandle_t mqtt_event_group;
|
||||
const static int CONNECTED_BIT = BIT0;
|
||||
|
||||
static esp_mqtt_client_handle_t mqtt_client = NULL;
|
||||
|
||||
static char *expected_data = NULL;
|
||||
static char *actual_data = NULL;
|
||||
static size_t expected_size = 0;
|
||||
static size_t expected_published = 0;
|
||||
static size_t actual_published = 0;
|
||||
static int qos_test = 0;
|
||||
|
||||
#if CONFIG_EXAMPLE_BROKER_CERTIFICATE_OVERRIDDEN == 1
|
||||
static const uint8_t mqtt_eclipseprojects_io_pem_start[] = "-----BEGIN CERTIFICATE-----\n" CONFIG_EXAMPLE_BROKER_CERTIFICATE_OVERRIDE "\n-----END CERTIFICATE-----";
|
||||
#else
|
||||
@ -42,6 +34,7 @@ extern const uint8_t mqtt_eclipseprojects_io_pem_end[] asm("_binary_mqtt_eclip
|
||||
|
||||
static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_t event_id, void *event_data)
|
||||
{
|
||||
publish_context_t * test_data = handler_args;
|
||||
esp_mqtt_event_handle_t event = event_data;
|
||||
esp_mqtt_client_handle_t client = event->client;
|
||||
static int msg_id = 0;
|
||||
@ -52,7 +45,7 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_
|
||||
case MQTT_EVENT_CONNECTED:
|
||||
ESP_LOGI(TAG, "MQTT_EVENT_CONNECTED");
|
||||
xEventGroupSetBits(mqtt_event_group, CONNECTED_BIT);
|
||||
msg_id = esp_mqtt_client_subscribe(client, CONFIG_EXAMPLE_SUBSCRIBE_TOPIC, qos_test);
|
||||
msg_id = esp_mqtt_client_subscribe(client, CONFIG_EXAMPLE_SUBSCRIBE_TOPIC, test_data->qos);
|
||||
ESP_LOGI(TAG, "sent subscribe successful %s , msg_id=%d", CONFIG_EXAMPLE_SUBSCRIBE_TOPIC, msg_id);
|
||||
|
||||
break;
|
||||
@ -67,13 +60,12 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_
|
||||
ESP_LOGI(TAG, "MQTT_EVENT_UNSUBSCRIBED, msg_id=%d", event->msg_id);
|
||||
break;
|
||||
case MQTT_EVENT_PUBLISHED:
|
||||
ESP_LOGI(TAG, "MQTT_EVENT_PUBLISHED, msg_id=%d", event->msg_id);
|
||||
ESP_LOGD(TAG, "MQTT_EVENT_PUBLISHED, msg_id=%d", event->msg_id);
|
||||
break;
|
||||
case MQTT_EVENT_DATA:
|
||||
ESP_LOGI(TAG, "MQTT_EVENT_DATA");
|
||||
printf("TOPIC=%.*s\r\n", event->topic_len, event->topic);
|
||||
printf("DATA=%.*s\r\n", event->data_len, event->data);
|
||||
printf("ID=%d, total_len=%d, data_len=%d, current_data_offset=%d\n", event->msg_id, event->total_data_len, event->data_len, event->current_data_offset);
|
||||
ESP_LOGI(TAG, "TOPIC=%.*s", event->topic_len, event->topic);
|
||||
ESP_LOGI(TAG, "ID=%d, total_len=%d, data_len=%d, current_data_offset=%d", event->msg_id, event->total_data_len, event->data_len, event->current_data_offset);
|
||||
if (event->topic) {
|
||||
actual_len = event->data_len;
|
||||
msg_id = event->msg_id;
|
||||
@ -85,24 +77,23 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_
|
||||
abort();
|
||||
}
|
||||
}
|
||||
memcpy(actual_data + event->current_data_offset, event->data, event->data_len);
|
||||
memcpy(test_data->received_data + event->current_data_offset, event->data, event->data_len);
|
||||
if (actual_len == event->total_data_len) {
|
||||
if (0 == memcmp(actual_data, expected_data, expected_size)) {
|
||||
printf("OK!");
|
||||
memset(actual_data, 0, expected_size);
|
||||
actual_published ++;
|
||||
if (actual_published == expected_published) {
|
||||
printf("Correct pattern received exactly x times\n");
|
||||
if (0 == memcmp(test_data->received_data, test_data->expected, test_data->expected_size)) {
|
||||
memset(test_data->received_data, 0, test_data->expected_size);
|
||||
test_data->nr_of_msg_received ++;
|
||||
if (test_data->nr_of_msg_received == test_data->nr_of_msg_expected) {
|
||||
ESP_LOGI(TAG, "Correct pattern received exactly x times");
|
||||
ESP_LOGI(TAG, "Test finished correctly!");
|
||||
}
|
||||
} else {
|
||||
printf("FAILED!");
|
||||
ESP_LOGE(TAG, "FAILED!");
|
||||
abort();
|
||||
}
|
||||
}
|
||||
break;
|
||||
case MQTT_EVENT_ERROR:
|
||||
ESP_LOGI(TAG, "MQTT_EVENT_ERROR");
|
||||
ESP_LOGE(TAG, "MQTT_EVENT_ERROR");
|
||||
break;
|
||||
default:
|
||||
ESP_LOGI(TAG, "Other event id:%d", event->event_id);
|
||||
@ -110,37 +101,31 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_
|
||||
}
|
||||
}
|
||||
|
||||
typedef enum {NONE, TCP, SSL, WS, WSS} transport_t;
|
||||
static transport_t current_transport;
|
||||
|
||||
|
||||
void test_init(void)
|
||||
{
|
||||
mqtt_event_group = xEventGroupCreate();
|
||||
esp_mqtt_client_config_t config = {0};
|
||||
mqtt_client = esp_mqtt_client_init(&config);
|
||||
current_transport = NONE;
|
||||
esp_mqtt_client_register_event(mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler, NULL);
|
||||
ESP_LOGI(TAG, "[APP] Free memory: %d bytes", esp_get_free_heap_size());
|
||||
}
|
||||
|
||||
void pattern_setup(char *pattern, int repeat)
|
||||
void pattern_setup(publish_context_t * test_data)
|
||||
{
|
||||
int pattern_size = strlen(pattern);
|
||||
free(expected_data);
|
||||
free(actual_data);
|
||||
actual_published = 0;
|
||||
expected_size = pattern_size * repeat;
|
||||
expected_data = malloc(expected_size);
|
||||
actual_data = malloc(expected_size);
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
memcpy(expected_data + i * pattern_size, pattern, pattern_size);
|
||||
int pattern_size = strlen(test_data->pattern);
|
||||
free(test_data->expected);
|
||||
free(test_data->received_data);
|
||||
test_data->nr_of_msg_received = 0;
|
||||
test_data->expected_size = (size_t)(pattern_size) * test_data->pattern_repetitions;
|
||||
test_data->expected = malloc(test_data->expected_size);
|
||||
test_data->received_data = malloc(test_data->expected_size);
|
||||
for (int i = 0; i < test_data->pattern_repetitions; i++) {
|
||||
memcpy(test_data->expected + (ptrdiff_t)(i * pattern_size), test_data->pattern, pattern_size);
|
||||
}
|
||||
printf("EXPECTED STRING %.*s, SIZE:%d\n", expected_size, expected_data, expected_size);
|
||||
ESP_LOGI(TAG, "EXPECTED STRING %.*s, SIZE:%d", test_data->expected_size, test_data->expected, test_data->expected_size);
|
||||
}
|
||||
|
||||
static void configure_client(char *transport)
|
||||
static void configure_client(command_context_t * ctx, const char *transport)
|
||||
{
|
||||
publish_context_t * test_data = ctx->data;
|
||||
ESP_LOGI(TAG, "Configuration");
|
||||
transport_t selected_transport;
|
||||
if (0 == strcmp(transport, "tcp")) {
|
||||
@ -157,7 +142,8 @@ static void configure_client(char *transport)
|
||||
}
|
||||
|
||||
|
||||
if (selected_transport != current_transport) {
|
||||
if (selected_transport != test_data->selected_transport) {
|
||||
test_data->selected_transport = selected_transport;
|
||||
esp_mqtt_client_config_t config = {0};
|
||||
switch (selected_transport) {
|
||||
case NONE:
|
||||
@ -183,43 +169,45 @@ static void configure_client(char *transport)
|
||||
ESP_LOGI(TAG, "Set certificate");
|
||||
config.broker.verification.certificate = (const char *)mqtt_eclipseprojects_io_pem_start;
|
||||
}
|
||||
esp_mqtt_set_config(mqtt_client, &config);
|
||||
|
||||
esp_mqtt_set_config(ctx->mqtt_client, &config);
|
||||
}
|
||||
|
||||
}
|
||||
void publish_test(const char *line)
|
||||
{
|
||||
char pattern[32];
|
||||
char transport[32];
|
||||
int repeat = 0;
|
||||
int enqueue = 0;
|
||||
|
||||
static bool is_test_init = false;
|
||||
if (!is_test_init) {
|
||||
test_init();
|
||||
is_test_init = true;
|
||||
} else {
|
||||
esp_mqtt_client_stop(mqtt_client);
|
||||
}
|
||||
void publish_init_flags(void) {
|
||||
mqtt_event_group = xEventGroupCreate();
|
||||
}
|
||||
|
||||
sscanf(line, "%s %s %d %d %d %d", transport, pattern, &repeat, &expected_published, &qos_test, &enqueue);
|
||||
ESP_LOGI(TAG, "PATTERN:%s REPEATED:%d PUBLISHED:%d", pattern, repeat, expected_published);
|
||||
pattern_setup(pattern, repeat);
|
||||
void publish_setup(command_context_t * ctx, char const * const transport) {
|
||||
xEventGroupClearBits(mqtt_event_group, CONNECTED_BIT);
|
||||
configure_client(transport);
|
||||
esp_mqtt_client_start(mqtt_client);
|
||||
publish_context_t * data = (publish_context_t*)ctx->data;
|
||||
pattern_setup(data);
|
||||
configure_client(ctx, transport);
|
||||
esp_mqtt_client_register_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler, data);
|
||||
}
|
||||
|
||||
ESP_LOGI(TAG, "Note free memory: %d bytes", esp_get_free_heap_size());
|
||||
void publish_teardown(command_context_t * ctx)
|
||||
{
|
||||
esp_mqtt_client_unregister_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler);
|
||||
}
|
||||
|
||||
void publish_test(command_context_t * ctx, int expect_to_publish, int qos, bool enqueue)
|
||||
{
|
||||
publish_context_t * data = (publish_context_t*)ctx->data;
|
||||
data->nr_of_msg_expected = expect_to_publish;
|
||||
ESP_LOGI(TAG, "PATTERN:%s REPEATED:%d PUBLISHED:%d", data->pattern, data->pattern_repetitions, data->nr_of_msg_expected);
|
||||
|
||||
xEventGroupWaitBits(mqtt_event_group, CONNECTED_BIT, false, true, portMAX_DELAY);
|
||||
for (int i = 0; i < expected_published; i++) {
|
||||
for (int i = 0; i < data->nr_of_msg_expected; i++) {
|
||||
int msg_id;
|
||||
if (enqueue) {
|
||||
msg_id = esp_mqtt_client_enqueue(mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, expected_data, expected_size, qos_test, 0, true);
|
||||
msg_id = esp_mqtt_client_enqueue(ctx->mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, data->expected, data->expected_size, qos, 0, true);
|
||||
} else {
|
||||
msg_id = esp_mqtt_client_publish(mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, expected_data, expected_size, qos_test, 0);
|
||||
msg_id = esp_mqtt_client_publish(ctx->mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, data->expected, data->expected_size, qos, 0);
|
||||
if(msg_id < 0) {
|
||||
ESP_LOGE(TAG, "Failed to publish");
|
||||
break;
|
||||
}
|
||||
}
|
||||
ESP_LOGI(TAG, "[%d] Publishing...", msg_id);
|
||||
ESP_LOGD(TAG, "Publishing msg_id=%d", msg_id);
|
||||
}
|
||||
}
|
||||
|
@ -1,31 +1,21 @@
|
||||
# SPDX-FileCopyrightText: 2022-2023 Espressif Systems (Shanghai) CO LTD
|
||||
# SPDX-License-Identifier: Unlicense OR CC0-1.0
|
||||
from __future__ import print_function, unicode_literals
|
||||
|
||||
import difflib
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import select
|
||||
import socket
|
||||
import socketserver
|
||||
import ssl
|
||||
import string
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
from itertools import count
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
import pytest
|
||||
from common_test_methods import get_host_ip4_by_dest_ip
|
||||
from pytest_embedded import Dut
|
||||
from pytest_embedded_qemu.dut import QemuDut
|
||||
|
||||
DEFAULT_MSG_SIZE = 16
|
||||
SERVER_PORT = 2222
|
||||
|
||||
|
||||
def _path(f): # type: (str) -> str
|
||||
@ -42,313 +32,98 @@ def set_server_cert_cn(ip): # type: (str) -> None
|
||||
raise RuntimeError('openssl command {} failed'.format(args))
|
||||
|
||||
|
||||
# Publisher class creating a python client to send/receive published data from esp-mqtt client
|
||||
class MqttPublisher:
|
||||
event_client_connected = Event()
|
||||
event_client_got_all = Event()
|
||||
expected_data = ''
|
||||
published = 0
|
||||
sample = ''
|
||||
class MQTTHandler(socketserver.StreamRequestHandler):
|
||||
|
||||
def __init__(self, dut, transport,
|
||||
qos, repeat, published, queue, publish_cfg, log_details=False): # type: (MqttPublisher, Dut, str, int, int, int, int, dict, bool) -> None
|
||||
# instance variables used as parameters of the publish test
|
||||
self.event_stop_client = Event()
|
||||
self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
|
||||
self.client = None
|
||||
self.dut = dut
|
||||
self.log_details = log_details
|
||||
self.repeat = repeat
|
||||
self.publish_cfg = publish_cfg
|
||||
self.publish_cfg['qos'] = qos
|
||||
self.publish_cfg['queue'] = queue
|
||||
self.publish_cfg['transport'] = transport
|
||||
self.lock = Lock()
|
||||
# static variables used to pass options to and from static callbacks of paho-mqtt client
|
||||
MqttPublisher.event_client_connected = Event()
|
||||
MqttPublisher.event_client_got_all = Event()
|
||||
MqttPublisher.published = published
|
||||
MqttPublisher.event_client_connected.clear()
|
||||
MqttPublisher.event_client_got_all.clear()
|
||||
MqttPublisher.expected_data = f'{self.sample_string * self.repeat}'
|
||||
MqttPublisher.sample = self.sample_string
|
||||
|
||||
def print_details(self, text): # type: (str) -> None
|
||||
if self.log_details:
|
||||
logging.info(text)
|
||||
|
||||
def mqtt_client_task(self, client, lock): # type: (MqttPublisher, mqtt.Client, Lock) -> None
|
||||
while not self.event_stop_client.is_set():
|
||||
with lock:
|
||||
client.loop()
|
||||
time.sleep(0.001) # yield to other threads
|
||||
|
||||
# The callback for when the client receives a CONNACK response from the server (needs to be static)
|
||||
@staticmethod
|
||||
def on_connect(_client, _userdata, _flags, _rc): # type: (mqtt.Client, tuple, bool, str) -> None
|
||||
MqttPublisher.event_client_connected.set()
|
||||
|
||||
# The callback for when a PUBLISH message is received from the server (needs to be static)
|
||||
@staticmethod
|
||||
def on_message(client, userdata, msg): # type: (mqtt.Client, int, mqtt.client.MQTTMessage) -> None
|
||||
payload = msg.payload.decode('utf-8')
|
||||
if payload == MqttPublisher.expected_data:
|
||||
userdata += 1
|
||||
client.user_data_set(userdata)
|
||||
if userdata == MqttPublisher.published:
|
||||
MqttPublisher.event_client_got_all.set()
|
||||
else:
|
||||
differences = len(list(filter(lambda data: data[0] != data[1], zip(payload, MqttPublisher.expected_data))))
|
||||
logging.error(f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:'
|
||||
f'{len(MqttPublisher.expected_data)}')
|
||||
logging.info(f'Repetitions: {payload.count(MqttPublisher.sample)}')
|
||||
logging.info(f'Pattern: {MqttPublisher.sample}')
|
||||
logging.info(f'First : {payload[:DEFAULT_MSG_SIZE]}')
|
||||
logging.info(f'Last : {payload[-DEFAULT_MSG_SIZE:]}')
|
||||
matcher = difflib.SequenceMatcher(a=payload, b=MqttPublisher.expected_data)
|
||||
for match in matcher.get_matching_blocks():
|
||||
logging.info(f'Match: {match}')
|
||||
|
||||
def __enter__(self): # type: (MqttPublisher) -> None
|
||||
|
||||
qos = self.publish_cfg['qos']
|
||||
queue = self.publish_cfg['queue']
|
||||
transport = self.publish_cfg['transport']
|
||||
broker_host = self.publish_cfg['broker_host_' + transport]
|
||||
broker_port = self.publish_cfg['broker_port_' + transport]
|
||||
|
||||
# Start the test
|
||||
self.print_details(f'PUBLISH TEST: transport:{transport}, qos:{qos}, sequence:{MqttPublisher.published},'
|
||||
f"enqueue:{queue}, sample msg:'{MqttPublisher.expected_data}'")
|
||||
|
||||
try:
|
||||
if transport in ['ws', 'wss']:
|
||||
self.client = mqtt.Client(transport='websockets')
|
||||
def handle(self) -> None:
|
||||
logging.info(' - connection from: {}'.format(self.client_address))
|
||||
data = bytearray(self.request.recv(1024))
|
||||
message = ''.join(format(x, '02x') for x in data)
|
||||
if message[0:16] == '101800044d515454':
|
||||
if self.server.refuse_connection is False: # type: ignore
|
||||
logging.info(' - received mqtt connect, sending ACK')
|
||||
self.request.send(bytearray.fromhex('20020000'))
|
||||
else:
|
||||
self.client = mqtt.Client()
|
||||
assert self.client is not None
|
||||
self.client.on_connect = MqttPublisher.on_connect
|
||||
self.client.on_message = MqttPublisher.on_message
|
||||
self.client.user_data_set(0)
|
||||
|
||||
if transport in ['ssl', 'wss']:
|
||||
self.client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
|
||||
self.client.tls_insecure_set(True)
|
||||
self.print_details('Connecting...')
|
||||
self.client.connect(broker_host, broker_port, 60)
|
||||
except Exception:
|
||||
self.print_details(f'ENV_TEST_FAILURE: Unexpected error while connecting to broker {broker_host}')
|
||||
raise
|
||||
# Starting a py-client in a separate thread
|
||||
thread1 = Thread(target=self.mqtt_client_task, args=(self.client, self.lock))
|
||||
thread1.start()
|
||||
self.print_details('Connecting py-client to broker {}:{}...'.format(broker_host, broker_port))
|
||||
if not MqttPublisher.event_client_connected.wait(timeout=30):
|
||||
raise ValueError(f'ENV_TEST_FAILURE: Test script cannot connect to broker: {broker_host}')
|
||||
with self.lock:
|
||||
self.client.subscribe(self.publish_cfg['subscribe_topic'], qos)
|
||||
self.dut.write(f'{transport} {self.sample_string} {self.repeat} {MqttPublisher.published} {qos} {queue}')
|
||||
try:
|
||||
# waiting till subscribed to defined topic
|
||||
self.dut.expect(re.compile(rb'MQTT_EVENT_SUBSCRIBED'), timeout=60)
|
||||
for _ in range(MqttPublisher.published):
|
||||
with self.lock:
|
||||
self.client.publish(self.publish_cfg['publish_topic'], self.sample_string * self.repeat, qos)
|
||||
self.print_details('Publishing...')
|
||||
self.print_details('Checking esp-client received msg published from py-client...')
|
||||
self.dut.expect(re.compile(rb'Correct pattern received exactly x times'), timeout=60)
|
||||
if not MqttPublisher.event_client_got_all.wait(timeout=60):
|
||||
raise ValueError('Not all data received from ESP32: {}'.format(transport))
|
||||
logging.info(' - all data received from ESP32')
|
||||
finally:
|
||||
self.event_stop_client.set()
|
||||
thread1.join()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: (MqttPublisher, str, str, dict) -> None
|
||||
assert self.client is not None
|
||||
self.client.disconnect()
|
||||
self.event_stop_client.clear()
|
||||
# injecting connection not authorized error
|
||||
logging.info(' - received mqtt connect, sending NAK')
|
||||
self.request.send(bytearray.fromhex('20020005'))
|
||||
else:
|
||||
raise Exception(' - error process_mqtt_connect unexpected connect received: {}'.format(message))
|
||||
|
||||
|
||||
# Simple server for mqtt over TLS connection
|
||||
class TlsServer:
|
||||
class TlsServer(socketserver.TCPServer):
|
||||
timeout = 30.0
|
||||
allow_reuse_address = True
|
||||
allow_reuse_port = True
|
||||
|
||||
def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False): # type: (TlsServer, int, bool, bool, bool) -> None
|
||||
self.port = port
|
||||
self.socket = socket.socket()
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.settimeout(10.0)
|
||||
self.shutdown = Event()
|
||||
self.client_cert = client_cert
|
||||
def __init__(self,
|
||||
port:int = SERVER_PORT,
|
||||
ServerHandler: Callable[[Any, Any, Any], socketserver.BaseRequestHandler] = MQTTHandler,
|
||||
client_cert:bool=False,
|
||||
refuse_connection:bool=False,
|
||||
use_alpn:bool=False):
|
||||
self.refuse_connection = refuse_connection
|
||||
self.use_alpn = use_alpn
|
||||
self.conn = socket.socket()
|
||||
self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
self.ssl_error = ''
|
||||
self.alpn_protocol: Optional[str] = None
|
||||
if client_cert:
|
||||
self.context.verify_mode = ssl.CERT_REQUIRED
|
||||
self.context.load_verify_locations(cafile=_path('ca.crt'))
|
||||
self.context.load_cert_chain(certfile=_path('srv.crt'), keyfile=_path('server.key'))
|
||||
if use_alpn:
|
||||
self.context.set_alpn_protocols(['mymqtt', 'http/1.1'])
|
||||
self.server_thread = Thread(target=self.serve_forever)
|
||||
super().__init__(('',port), ServerHandler)
|
||||
|
||||
def __enter__(self): # type: (TlsServer) -> TlsServer
|
||||
try:
|
||||
self.socket.bind(('', self.port))
|
||||
except socket.error as e:
|
||||
print('Bind failed:{}'.format(e))
|
||||
raise
|
||||
def server_activate(self) -> None:
|
||||
self.socket = self.context.wrap_socket(self.socket, server_side=True)
|
||||
super().server_activate()
|
||||
|
||||
self.socket.listen(1)
|
||||
self.server_thread = Thread(target=self.run_server)
|
||||
def __enter__(self): # type: ignore
|
||||
self.server_thread.start()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: (TlsServer, str, str, str) -> None
|
||||
self.shutdown.set()
|
||||
self.server_thread.join()
|
||||
self.socket.close()
|
||||
if (self.conn is not None):
|
||||
self.conn.close()
|
||||
def server_close(self) -> None:
|
||||
try:
|
||||
self.shutdown()
|
||||
self.server_thread.join()
|
||||
super().server_close()
|
||||
except RuntimeError as e:
|
||||
logging.exception(e)
|
||||
|
||||
def get_last_ssl_error(self): # type: (TlsServer) -> str
|
||||
# We need to override it here to capture ssl.SSLError
|
||||
# The implementation is a slightly modified version from cpython original code.
|
||||
def _handle_request_noblock(self) -> None:
|
||||
try:
|
||||
request, client_address = self.get_request()
|
||||
self.alpn_protocol = request.selected_alpn_protocol() # type: ignore
|
||||
except ssl.SSLError as e:
|
||||
self.ssl_error = e.reason
|
||||
return
|
||||
except OSError:
|
||||
return
|
||||
if self.verify_request(request, client_address):
|
||||
try:
|
||||
self.process_request(request, client_address)
|
||||
except Exception:
|
||||
self.handle_error(request, client_address)
|
||||
self.shutdown_request(request)
|
||||
except: # noqa: E722
|
||||
self.shutdown_request(request)
|
||||
raise
|
||||
else:
|
||||
self.shutdown_request(request)
|
||||
|
||||
def last_ssl_error(self): # type: (TlsServer) -> str
|
||||
return self.ssl_error
|
||||
|
||||
@typing.no_type_check
|
||||
def get_negotiated_protocol(self):
|
||||
return self.negotiated_protocol
|
||||
|
||||
def run_server(self) -> None:
|
||||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
if self.client_cert:
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
context.load_verify_locations(cafile=_path('ca.crt'))
|
||||
context.load_cert_chain(certfile=_path('srv.crt'), keyfile=_path('server.key'))
|
||||
if self.use_alpn:
|
||||
context.set_alpn_protocols(['mymqtt', 'http/1.1'])
|
||||
self.socket = context.wrap_socket(self.socket, server_side=True)
|
||||
try:
|
||||
self.conn, address = self.socket.accept() # accept new connection
|
||||
self.socket.settimeout(10.0)
|
||||
print(' - connection from: {}'.format(address))
|
||||
if self.use_alpn:
|
||||
self.negotiated_protocol = self.conn.selected_alpn_protocol()
|
||||
print(' - negotiated_protocol: {}'.format(self.negotiated_protocol))
|
||||
self.handle_conn()
|
||||
except ssl.SSLError as e:
|
||||
self.ssl_error = str(e)
|
||||
print(' - SSLError: {}'.format(str(e)))
|
||||
|
||||
def handle_conn(self) -> None:
|
||||
while not self.shutdown.is_set():
|
||||
r,w,e = select.select([self.conn], [], [], 1)
|
||||
try:
|
||||
if self.conn in r:
|
||||
self.process_mqtt_connect()
|
||||
|
||||
except socket.error as err:
|
||||
print(' - error: {}'.format(err))
|
||||
raise
|
||||
|
||||
def process_mqtt_connect(self) -> None:
|
||||
try:
|
||||
data = bytearray(self.conn.recv(1024))
|
||||
message = ''.join(format(x, '02x') for x in data)
|
||||
if message[0:16] == '101800044d515454':
|
||||
if self.refuse_connection is False:
|
||||
print(' - received mqtt connect, sending ACK')
|
||||
self.conn.send(bytearray.fromhex('20020000'))
|
||||
else:
|
||||
# injecting connection not authorized error
|
||||
print(' - received mqtt connect, sending NAK')
|
||||
self.conn.send(bytearray.fromhex('20020005'))
|
||||
else:
|
||||
raise Exception(' - error process_mqtt_connect unexpected connect received: {}'.format(message))
|
||||
finally:
|
||||
# stop the server after the connect message in happy flow, or if any exception occur
|
||||
self.shutdown.set()
|
||||
def get_negotiated_protocol(self) -> Optional[str]:
|
||||
return self.alpn_protocol
|
||||
|
||||
|
||||
def connection_tests(dut, cases, dut_ip): # type: (Dut, dict, str) -> None
|
||||
ip = get_host_ip4_by_dest_ip(dut_ip)
|
||||
set_server_cert_cn(ip)
|
||||
server_port = 2222
|
||||
|
||||
def teardown_connection_suite() -> None:
|
||||
dut.write('conn teardown 0 0\n')
|
||||
|
||||
def start_connection_case(case, desc): # type: (str, str) -> Any
|
||||
print('Starting {}: {}'.format(case, desc))
|
||||
case_id = cases[case]
|
||||
dut.write('conn {} {} {}\n'.format(ip, server_port, case_id))
|
||||
dut.expect('Test case:{} started'.format(case_id))
|
||||
return case_id
|
||||
|
||||
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
|
||||
# All these cases connect to the server with no server verification or with server only verification
|
||||
with TlsServer(server_port):
|
||||
test_nr = start_connection_case(case, 'default server - expect to connect normally')
|
||||
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
|
||||
with TlsServer(server_port, refuse_connection=True):
|
||||
test_nr = start_connection_case(case, 'ssl shall connect, but mqtt sends connect refusal')
|
||||
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
|
||||
dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error
|
||||
with TlsServer(server_port, client_cert=True) as s:
|
||||
test_nr = start_connection_case(case, 'server with client verification - handshake error since client presents no client certificate')
|
||||
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
|
||||
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
|
||||
if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' not in s.get_last_ssl_error():
|
||||
raise RuntimeError('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
|
||||
|
||||
for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
|
||||
# These cases connect to server with both server and client verification (client key might be password protected)
|
||||
with TlsServer(server_port, client_cert=True):
|
||||
test_nr = start_connection_case(case, 'server with client verification - expect to connect normally')
|
||||
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
|
||||
|
||||
case = 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
|
||||
with TlsServer(server_port) as s:
|
||||
test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error')
|
||||
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
|
||||
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
|
||||
if 'alert unknown ca' not in s.get_last_ssl_error():
|
||||
raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
|
||||
|
||||
case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
|
||||
with TlsServer(server_port, client_cert=True) as s:
|
||||
test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error')
|
||||
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
|
||||
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
|
||||
if 'CERTIFICATE_VERIFY_FAILED' not in s.get_last_ssl_error():
|
||||
raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
|
||||
|
||||
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
|
||||
with TlsServer(server_port, use_alpn=True) as s:
|
||||
test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol')
|
||||
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
|
||||
if case == 'EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None:
|
||||
print(' - client with alpn off, no negotiated protocol: OK')
|
||||
elif case == 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt':
|
||||
print(' - client with alpn on, negotiated protocol resolved: OK')
|
||||
else:
|
||||
raise Exception('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol()))
|
||||
|
||||
teardown_connection_suite()
|
||||
|
||||
|
||||
@pytest.mark.esp32
|
||||
@pytest.mark.ethernet
|
||||
def test_app_protocol_mqtt_publish_connect(dut: Dut) -> None:
|
||||
"""
|
||||
steps:
|
||||
1. join AP
|
||||
2. connect to uri specified in the config
|
||||
3. send and receive data
|
||||
"""
|
||||
# check and log bin size
|
||||
binary_file = os.path.join(dut.app.binary_path, 'mqtt_publish_connect_test.bin')
|
||||
bin_size = os.path.getsize(binary_file)
|
||||
logging.info('[Performance][mqtt_publish_connect_test_bin_size]: %s KB', bin_size // 1024)
|
||||
|
||||
# Look for test case symbolic names and publish configs
|
||||
def get_test_cases(dut: Dut) -> Any:
|
||||
cases = {}
|
||||
publish_cfg = {}
|
||||
try:
|
||||
|
||||
# Get connection test cases configuration: symbolic names for test cases
|
||||
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT',
|
||||
'EXAMPLE_CONNECT_CASE_SERVER_CERT',
|
||||
@ -360,63 +135,107 @@ def test_app_protocol_mqtt_publish_connect(dut: Dut) -> None:
|
||||
'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
|
||||
cases[case] = dut.app.sdkconfig.get(case)
|
||||
except Exception:
|
||||
print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
|
||||
logging.error('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
|
||||
raise
|
||||
return cases
|
||||
|
||||
esp_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30).group(1).decode()
|
||||
print('Got IP={}'.format(esp_ip))
|
||||
|
||||
connection_tests(dut,cases,esp_ip)
|
||||
def get_dut_ip(dut: Dut) -> Any:
|
||||
dut_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30).group(1).decode()
|
||||
logging.info('Got IP={}'.format(dut_ip))
|
||||
return get_host_ip4_by_dest_ip(dut_ip)
|
||||
|
||||
# Get publish test configuration
|
||||
|
||||
@contextlib.contextmanager
|
||||
def connect_dut(dut: Dut, uri:str, case_id:int) -> Any:
|
||||
dut.write('connection_setup')
|
||||
dut.write(f'connect {uri} {case_id}')
|
||||
dut.expect(f'Test case:{case_id} started')
|
||||
dut.write('reconnect')
|
||||
yield
|
||||
dut.write('connection_teardown')
|
||||
dut.write('disconnect')
|
||||
|
||||
|
||||
def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None:
|
||||
try:
|
||||
@typing.no_type_check
|
||||
def get_host_port_from_dut(dut, config_option):
|
||||
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get(config_option))
|
||||
if value is None:
|
||||
return None, None
|
||||
return value.group(1), int(value.group(2))
|
||||
dut.write('init')
|
||||
dut.write(f'start')
|
||||
dut.write(f'disconnect')
|
||||
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
|
||||
# All these cases connect to the server with no server verification or with server only verification
|
||||
with TlsServer(), connect_dut(dut, uri, cases[case]):
|
||||
logging.info(f'Running {case}: default server - expect to connect normally')
|
||||
dut.expect(f'MQTT_EVENT_CONNECTED: Test={cases[case]}', timeout=30)
|
||||
with TlsServer(refuse_connection=True), connect_dut(dut, uri, cases[case]):
|
||||
logging.info(f'Running {case}: ssl shall connect, but mqtt sends connect refusal')
|
||||
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
|
||||
dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error
|
||||
with TlsServer(client_cert=True) as server, connect_dut(dut, uri, cases[case]):
|
||||
logging.info(f'Running {case}: server with client verification - handshake error since client presents no client certificate')
|
||||
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
|
||||
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
|
||||
assert 'PEER_DID_NOT_RETURN_A_CERTIFICATE' in server.last_ssl_error()
|
||||
|
||||
publish_cfg['publish_topic'] = dut.app.sdkconfig.get('EXAMPLE_SUBSCRIBE_TOPIC').replace('"','')
|
||||
publish_cfg['subscribe_topic'] = dut.app.sdkconfig.get('EXAMPLE_PUBLISH_TOPIC').replace('"','')
|
||||
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI')
|
||||
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI')
|
||||
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WS_URI')
|
||||
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI')
|
||||
for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
|
||||
# These cases connect to server with both server and client verification (client key might be password protected)
|
||||
with TlsServer(client_cert=True), connect_dut(dut, uri, cases[case]):
|
||||
logging.info(f'Running {case}: server with client verification - expect to connect normally')
|
||||
dut.expect(f'MQTT_EVENT_CONNECTED: Test={cases[case]}', timeout=30)
|
||||
|
||||
except Exception:
|
||||
logging.error('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
|
||||
raise
|
||||
case = 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
|
||||
with TlsServer() as s, connect_dut(dut, uri, cases[case]):
|
||||
logging.info(f'Running {case}: invalid server certificate on default server - expect ssl handshake error')
|
||||
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
|
||||
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
|
||||
if re.match('.*alert.*unknown.*ca',s.last_ssl_error(), flags=re.I) is None:
|
||||
raise Exception(f'Unexpected ssl error from the server: {s.last_ssl_error()}')
|
||||
|
||||
# Initialize message sizes and repeat counts (if defined in the environment)
|
||||
messages = []
|
||||
for i in count(0):
|
||||
# Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x}
|
||||
env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
|
||||
if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']):
|
||||
messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore
|
||||
continue
|
||||
break
|
||||
if not messages: # No message sizes present in the env - set defaults
|
||||
messages = [{'len':0, 'repeat':5}, # zero-sized messages
|
||||
{'len':2, 'repeat':10}, # short messages
|
||||
{'len':200, 'repeat':3}, # long messages
|
||||
{'len':20, 'repeat':50} # many medium sized
|
||||
]
|
||||
case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
|
||||
with TlsServer(client_cert=True) as s, connect_dut(dut, uri, cases[case]):
|
||||
logging.info(f'Running {case}: Invalid client certificate on server with client verification - expect ssl handshake error')
|
||||
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
|
||||
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
|
||||
if 'CERTIFICATE_VERIFY_FAILED' not in s.last_ssl_error():
|
||||
raise Exception('Unexpected ssl error from the server {}'.format(s.last_ssl_error()))
|
||||
|
||||
# Iterate over all publish message properties
|
||||
for transport in ['tcp', 'ssl', 'ws', 'wss']:
|
||||
if publish_cfg['broker_host_' + transport] is None:
|
||||
print('Skipping transport: {}...'.format(transport))
|
||||
continue
|
||||
for enqueue in [0, 1]:
|
||||
for qos in [0, 1, 2]:
|
||||
for msg in messages:
|
||||
logging.info(f'Starting Publish test: transport:{transport}, qos:{qos}, nr_of_msgs:{msg["repeat"]},'
|
||||
f'msg_size:{msg["len"] * DEFAULT_MSG_SIZE}, enqueue:{enqueue}')
|
||||
with MqttPublisher(dut, transport, qos, msg['len'], msg['repeat'], enqueue, publish_cfg):
|
||||
pass
|
||||
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
|
||||
with TlsServer(use_alpn=True) as s, connect_dut(dut, uri, cases[case]):
|
||||
logging.info(f'Running {case}: server with alpn - expect connect, check resolved protocol')
|
||||
dut.expect(f'MQTT_EVENT_CONNECTED: Test={cases[case]}', timeout=30)
|
||||
if case == 'EXAMPLE_CONNECT_CASE_NO_CERT':
|
||||
assert s.get_negotiated_protocol() is None
|
||||
elif case == 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN':
|
||||
assert s.get_negotiated_protocol() == 'mymqtt'
|
||||
else:
|
||||
assert False, f'Unexpected negotiated protocol {s.get_negotiated_protocol()}'
|
||||
finally:
|
||||
dut.write('stop')
|
||||
dut.write('destroy')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_app_protocol_mqtt_publish_connect(dut=QemuDut if sys.argv[1:] == ['qemu'] else Dut)
|
||||
@pytest.mark.esp32
|
||||
@pytest.mark.ethernet
|
||||
def test_mqtt_connect(
|
||||
dut: Dut,
|
||||
log_performance: Callable[[str, object], None],
|
||||
) -> None:
|
||||
"""
|
||||
steps:
|
||||
1. join AP
|
||||
2. connect to uri specified in the config
|
||||
3. send and receive data
|
||||
"""
|
||||
# check and log bin size
|
||||
binary_file = os.path.join(dut.app.binary_path, 'mqtt_publish_connect_test.bin')
|
||||
bin_size = os.path.getsize(binary_file)
|
||||
log_performance('mqtt_publish_connect_test_bin_size', f'{bin_size // 1024} KB')
|
||||
|
||||
ip = get_dut_ip(dut)
|
||||
set_server_cert_cn(ip)
|
||||
uri = f'mqtts://{ip}:{SERVER_PORT}'
|
||||
|
||||
# Look for test case symbolic names and publish configs
|
||||
cases = get_test_cases(dut)
|
||||
dut.expect_exact('mqtt>', timeout=30)
|
||||
run_cases(dut, uri, cases)
|
||||
|
@ -0,0 +1,229 @@
|
||||
# SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD
|
||||
# SPDX-License-Identifier: Unlicense OR CC0-1.0
|
||||
|
||||
import contextlib
|
||||
import difflib
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import ssl
|
||||
import string
|
||||
from itertools import count, product
|
||||
from threading import Event, Lock
|
||||
from typing import Any, Dict, List, Tuple, no_type_check
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
import pexpect
|
||||
import pytest
|
||||
from pytest_embedded import Dut
|
||||
|
||||
DEFAULT_MSG_SIZE = 16
|
||||
|
||||
|
||||
# Publisher class creating a python client to send/receive published data from esp-mqtt client
|
||||
class MqttPublisher(mqtt.Client):
|
||||
|
||||
def __init__(self, repeat, published, publish_cfg, log_details=False): # type: (MqttPublisher, int, int, dict, bool) -> None
|
||||
self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
|
||||
self.log_details = log_details
|
||||
self.repeat = repeat
|
||||
self.publish_cfg = publish_cfg
|
||||
self.expected_data = f'{self.sample_string * self.repeat}'
|
||||
self.published = published
|
||||
self.received = 0
|
||||
self.lock = Lock()
|
||||
self.event_client_connected = Event()
|
||||
self.event_client_got_all = Event()
|
||||
transport = 'websockets' if self.publish_cfg['transport'] in ['ws', 'wss'] else 'tcp'
|
||||
super().__init__('MqttTestRunner', userdata=0, transport=transport)
|
||||
|
||||
def print_details(self, text): # type: (str) -> None
|
||||
if self.log_details:
|
||||
logging.info(text)
|
||||
|
||||
def on_connect(self, mqttc: Any, obj: Any, flags: Any, rc:int) -> None:
|
||||
self.event_client_connected.set()
|
||||
|
||||
def on_connect_fail(self, mqttc: Any, obj: Any) -> None:
|
||||
logging.error('Connect failed')
|
||||
|
||||
def on_message(self, mqttc: Any, userdata: Any, msg: mqtt.MQTTMessage) -> None:
|
||||
payload = msg.payload.decode('utf-8')
|
||||
if payload == self.expected_data:
|
||||
userdata += 1
|
||||
self.user_data_set(userdata)
|
||||
self.received = userdata
|
||||
if userdata == self.published:
|
||||
self.event_client_got_all.set()
|
||||
else:
|
||||
differences = len(list(filter(lambda data: data[0] != data[1], zip(payload, self.expected_data))))
|
||||
logging.error(f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:'
|
||||
f'{len(self.expected_data)}')
|
||||
logging.info(f'Repetitions: {payload.count(self.sample_string)}')
|
||||
logging.info(f'Pattern: {self.sample_string}')
|
||||
logging.info(f'First : {payload[:DEFAULT_MSG_SIZE]}')
|
||||
logging.info(f'Last : {payload[-DEFAULT_MSG_SIZE:]}')
|
||||
matcher = difflib.SequenceMatcher(a=payload, b=self.expected_data)
|
||||
for match in matcher.get_matching_blocks():
|
||||
logging.info(f'Match: {match}')
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
qos = self.publish_cfg['qos']
|
||||
broker_host = self.publish_cfg['broker_host_' + self.publish_cfg['transport']]
|
||||
broker_port = self.publish_cfg['broker_port_' + self.publish_cfg['transport']]
|
||||
|
||||
try:
|
||||
self.print_details('Connecting...')
|
||||
if self.publish_cfg['transport'] in ['ssl', 'wss']:
|
||||
self.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
|
||||
self.tls_insecure_set(True)
|
||||
self.event_client_connected.clear()
|
||||
self.loop_start()
|
||||
self.connect(broker_host, broker_port, 60)
|
||||
except Exception:
|
||||
self.print_details(f'ENV_TEST_FAILURE: Unexpected error while connecting to broker {broker_host}')
|
||||
raise
|
||||
self.print_details(f'Connecting py-client to broker {broker_host}:{broker_port}...')
|
||||
|
||||
if not self.event_client_connected.wait(timeout=30):
|
||||
raise ValueError(f'ENV_TEST_FAILURE: Test script cannot connect to broker: {broker_host}')
|
||||
self.event_client_got_all.clear()
|
||||
self.subscribe(self.publish_cfg['subscribe_topic'], qos)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: (MqttPublisher, str, str, dict) -> None
|
||||
self.disconnect()
|
||||
self.loop_stop()
|
||||
|
||||
|
||||
def get_configurations(dut: Dut) -> Dict[str,Any]:
|
||||
publish_cfg = {}
|
||||
try:
|
||||
@no_type_check
|
||||
def get_broker_from_dut(dut, config_option):
|
||||
# logging.info('Option:', config_option, dut.app.sdkconfig.get(config_option))
|
||||
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get(config_option))
|
||||
if value is None:
|
||||
return None, None
|
||||
return value.group(1), int(value.group(2))
|
||||
# Get publish test configuration
|
||||
publish_cfg['publish_topic'] = dut.app.sdkconfig.get('EXAMPLE_SUBSCRIBE_TOPIC').replace('"','')
|
||||
publish_cfg['subscribe_topic'] = dut.app.sdkconfig.get('EXAMPLE_PUBLISH_TOPIC').replace('"','')
|
||||
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI')
|
||||
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI')
|
||||
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_WS_URI')
|
||||
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI')
|
||||
|
||||
except Exception:
|
||||
logging.info('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
|
||||
raise
|
||||
logging.info(f'configuration: {publish_cfg}')
|
||||
return publish_cfg
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def connected_and_subscribed(dut:Dut, transport:str, pattern:str, pattern_repetitions:int) -> Any:
|
||||
dut.write(f'publish_setup {transport} {pattern} {pattern_repetitions}')
|
||||
dut.write(f'start')
|
||||
dut.expect(re.compile(rb'MQTT_EVENT_SUBSCRIBED'), timeout=60)
|
||||
yield
|
||||
dut.write(f'stop')
|
||||
|
||||
|
||||
def get_scenarios() -> List[Dict[str, int]]:
|
||||
scenarios = []
|
||||
# Initialize message sizes and repeat counts (if defined in the environment)
|
||||
for i in count(0):
|
||||
# Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x}
|
||||
env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
|
||||
if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']):
|
||||
scenarios.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore
|
||||
continue
|
||||
break
|
||||
if not scenarios: # No message sizes present in the env - set defaults
|
||||
scenarios = [{'len':0, 'repeat':5}, # zero-sized messages
|
||||
{'len':2, 'repeat':5}, # short messages
|
||||
{'len':200, 'repeat':3}, # long messages
|
||||
]
|
||||
return scenarios
|
||||
|
||||
|
||||
def get_timeout(test_case: Any) -> int:
|
||||
transport, qos, enqueue, scenario = test_case
|
||||
if transport in ['ws', 'wss'] or qos == 2:
|
||||
return 90
|
||||
return 60
|
||||
|
||||
|
||||
def run_publish_test_case(dut: Dut, test_case: Any, publish_cfg: Any) -> None:
|
||||
transport, qos, enqueue, scenario = test_case
|
||||
if publish_cfg['broker_host_' + transport] is None:
|
||||
pytest.skip(f'Skipping transport: {transport}...')
|
||||
repeat = scenario['len']
|
||||
published = scenario['repeat']
|
||||
publish_cfg['qos'] = qos
|
||||
publish_cfg['queue'] = enqueue
|
||||
publish_cfg['transport'] = transport
|
||||
test_timeout = get_timeout(test_case)
|
||||
logging.info(f'Starting Publish test: transport:{transport}, qos:{qos}, nr_of_msgs:{published},'
|
||||
f' msg_size:{repeat*DEFAULT_MSG_SIZE}, enqueue:{enqueue}')
|
||||
with MqttPublisher(repeat, published, publish_cfg) as publisher, connected_and_subscribed(dut, transport, publisher.sample_string, scenario['len']):
|
||||
msgs_published: List[mqtt.MQTTMessageInfo] = []
|
||||
dut.write(f'publish {publisher.published} {qos} {enqueue}')
|
||||
assert publisher.event_client_got_all.wait(timeout=test_timeout), (f'Not all data received from ESP32: {transport} '
|
||||
f'qos={qos} received: {publisher.received} '
|
||||
f'expected: {publisher.published}')
|
||||
logging.info(' - all data received from ESP32')
|
||||
payload = publisher.sample_string * publisher.repeat
|
||||
for _ in range(publisher.published):
|
||||
with publisher.lock:
|
||||
msg = publisher.publish(topic=publisher.publish_cfg['publish_topic'], payload=payload, qos=qos)
|
||||
if qos > 0:
|
||||
msgs_published.append(msg)
|
||||
logging.info(f'Published: {len(msgs_published)}')
|
||||
while msgs_published:
|
||||
msgs_published = [msg for msg in msgs_published if msg.is_published()]
|
||||
|
||||
try:
|
||||
dut.expect(re.compile(rb'Correct pattern received exactly x times'), timeout=test_timeout)
|
||||
except pexpect.exceptions.ExceptionPexpect:
|
||||
dut.write(f'publish_report')
|
||||
dut.expect(re.compile(rb'Test Report'), timeout=30)
|
||||
raise
|
||||
logging.info('ESP32 received all data from runner')
|
||||
|
||||
|
||||
stress_scenarios = [{'len':20, 'repeat':50}] # many medium sized
|
||||
transport_cases = ['tcp', 'ws', 'wss', 'ssl']
|
||||
qos_cases = [0, 1, 2]
|
||||
enqueue_cases = [0, 1]
|
||||
|
||||
|
||||
def make_cases(scenarios: List[Dict[str, int]]) -> List[Tuple[str, int, int, Dict[str, int]]]:
|
||||
return [test_case for test_case in product(transport_cases, qos_cases, enqueue_cases, scenarios)]
|
||||
|
||||
|
||||
test_cases = make_cases(get_scenarios())
|
||||
stress_test_cases = make_cases(stress_scenarios)
|
||||
|
||||
|
||||
@pytest.mark.esp32
|
||||
@pytest.mark.ethernet
|
||||
@pytest.mark.parametrize('test_case', test_cases)
|
||||
def test_mqtt_publish(dut: Dut, test_case: Any) -> None:
|
||||
publish_cfg = get_configurations(dut)
|
||||
dut.expect(re.compile(rb'mqtt>'), timeout=30)
|
||||
dut.confirm_write('init', expect_pattern='init', timeout=30)
|
||||
run_publish_test_case(dut, test_case, publish_cfg)
|
||||
|
||||
|
||||
@pytest.mark.esp32
|
||||
@pytest.mark.ethernet
|
||||
@pytest.mark.nightly_run
|
||||
@pytest.mark.parametrize('test_case', stress_test_cases)
|
||||
def test_mqtt_publish_stress(dut: Dut, test_case: Any) -> None:
|
||||
publish_cfg = get_configurations(dut)
|
||||
dut.expect(re.compile(rb'mqtt>'), timeout=30)
|
||||
dut.write('init')
|
||||
run_publish_test_case(dut, test_case, publish_cfg)
|
Loading…
x
Reference in New Issue
Block a user