From acc7bd2ca45c21033cbd02220a27c3c1ecdd5ad0 Mon Sep 17 00:00:00 2001 From: David Cermak Date: Mon, 11 Jan 2021 18:15:13 +0100 Subject: [PATCH] ws_transport: Add option to propagate control packets to the app Client could choose if they want to receive control packets and handle them. * If disabled (default) the transport itself tries to handle PING and CLOSE frames automatically during read operation. If handled correctly, read outputs 0 indicating no (actual app) data received. * if enabled, all control frames are passed to the application to be processed there. Closes https://github.com/espressif/esp-idf/issues/6307 --- .../esp_websocket_client.c | 44 ++-- .../tcp_transport/include/esp_transport_ws.h | 29 +++ .../private_include/esp_transport_utils.h | 12 + components/tcp_transport/transport_ws.c | 230 +++++++++++++++++- .../websocket/main/websocket_example.c | 1 + 5 files changed, 296 insertions(+), 20 deletions(-) diff --git a/components/esp_websocket_client/esp_websocket_client.c b/components/esp_websocket_client/esp_websocket_client.c index e396c3fb0f..bd3fa0e42d 100644 --- a/components/esp_websocket_client/esp_websocket_client.c +++ b/components/esp_websocket_client/esp_websocket_client.c @@ -46,6 +46,14 @@ static const char *TAG = "WEBSOCKET_CLIENT"; action; \ } +#define ESP_WS_CLIENT_ERR_OK_CHECK(TAG, err, action) { \ + esp_err_t _esp_ws_err_to_check = err; \ + if (_esp_ws_err_to_check != ESP_OK) { \ + ESP_LOGE(TAG,"%s(%d): Expected ESP_OK; reported: %d", __FUNCTION__, __LINE__, _esp_ws_err_to_check); \ + action; \ + } \ + } + #define ESP_WS_CLIENT_STATE_CHECK(TAG, a, action) if ((a->state) < WEBSOCKET_STATE_INIT) { \ ESP_LOGE(TAG,"%s:%d (%s): %s", __FILE__, __LINE__, __FUNCTION__, "Websocket already stop"); \ action; \ @@ -258,20 +266,20 @@ static esp_err_t esp_websocket_client_destroy_config(esp_websocket_client_handle return ESP_OK; } -static void set_websocket_transport_optional_settings(esp_websocket_client_handle_t client, esp_transport_handle_t trans) +static esp_err_t set_websocket_transport_optional_settings(esp_websocket_client_handle_t client, const char *scheme) { - if (trans && client->config->path) { - esp_transport_ws_set_path(trans, client->config->path); - } - if (trans && client->config->subprotocol) { - esp_transport_ws_set_subprotocol(trans, client->config->subprotocol); - } - if (trans && client->config->user_agent) { - esp_transport_ws_set_user_agent(trans, client->config->user_agent); - } - if (trans && client->config->headers) { - esp_transport_ws_set_headers(trans, client->config->headers); + esp_transport_handle_t trans = esp_transport_list_get_transport(client->transport_list, scheme); + if (trans) { + const esp_transport_ws_config_t config = { + .ws_path = client->config->path, + .sub_protocol = client->config->subprotocol, + .user_agent = client->config->user_agent, + .headers = client->config->headers, + .propagate_control_frames = true + }; + return esp_transport_ws_set_config(trans, &config); } + return ESP_ERR_INVALID_ARG; } esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_client_config_t *config) @@ -376,8 +384,8 @@ esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_clie ESP_WS_CLIENT_MEM_CHECK(TAG, client->config->scheme, goto _websocket_init_fail); } - set_websocket_transport_optional_settings(client, esp_transport_list_get_transport(client->transport_list, "ws")); - set_websocket_transport_optional_settings(client, esp_transport_list_get_transport(client->transport_list, "wss")); + ESP_WS_CLIENT_ERR_OK_CHECK(TAG, set_websocket_transport_optional_settings(client, "ws"), goto _websocket_init_fail;) + ESP_WS_CLIENT_ERR_OK_CHECK(TAG, set_websocket_transport_optional_settings(client, "wss"), goto _websocket_init_fail;) client->keepalive_tick_ms = _tick_get_ms(); client->reconnect_tick_ms = _tick_get_ms(); @@ -510,6 +518,11 @@ static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client) client->payload_len = esp_transport_ws_get_read_payload_len(client->transport); client->last_opcode = esp_transport_ws_get_read_opcode(client->transport); + if (rlen == 0 && client->last_opcode == WS_TRANSPORT_OPCODES_NONE ) { + ESP_LOGV(TAG, "esp_transport_read timeouts"); + return ESP_OK; + } + esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen); client->payload_offset += rlen; @@ -518,6 +531,7 @@ static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client) // if a PING message received -> send out the PONG, this will not work for PING messages with payload longer than buffer len if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) { const char *data = (client->payload_len == 0) ? NULL : client->rx_buffer; + ESP_LOGD(TAG, "Sending PONG with payload len=%d", client->payload_len); esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG | WS_TRANSPORT_OPCODES_FIN, data, client->payload_len, client->config->network_timeout_ms); } else if (client->last_opcode == WS_TRANSPORT_OPCODES_PONG) { @@ -557,7 +571,7 @@ static void esp_websocket_client_task(void *pv) int read_select = 0; while (client->run) { if (xSemaphoreTakeRecursive(client->lock, lock_timeout) != pdPASS) { - ESP_LOGE(TAG, "Failed to lock ws-client tasks, exitting the task..."); + ESP_LOGE(TAG, "Failed to lock ws-client tasks, exiting the task..."); break; } switch ((int)client->state) { diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index febe1d0bb9..338b3fcd01 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -8,6 +8,7 @@ #define _ESP_TRANSPORT_WS_H_ #include "esp_transport.h" +#include #ifdef __cplusplus extern "C" { @@ -21,8 +22,24 @@ typedef enum ws_transport_opcodes { WS_TRANSPORT_OPCODES_PING = 0x09, WS_TRANSPORT_OPCODES_PONG = 0x0a, WS_TRANSPORT_OPCODES_FIN = 0x80, + WS_TRANSPORT_OPCODES_NONE = 0x100, /*!< not a valid opcode to indicate no message previously received + * from the API esp_transport_ws_get_read_opcode() */ } ws_transport_opcodes_t; +/** + * WS transport configuration structure + */ +typedef struct { + const char *ws_path; /*!< HTTP path to update protocol to websocket */ + const char *sub_protocol; /*!< WS subprotocol */ + const char *user_agent; /*!< WS user agent */ + const char *headers; /*!< WS additional headers */ + bool propagate_control_frames; /*!< If true, control frames are passed to the reader + * If false, only user frames are propagated, control frames are handled + * automatically during read operations + */ +} esp_transport_ws_config_t; + /** * @brief Create web socket transport * @@ -76,6 +93,18 @@ esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char * */ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers); +/** + * @brief Set websocket transport parameters + * + * @param t websocket transport handle + * @param config pointer to websocket config structure + * + * @return + * - ESP_OK on success + * - One of the error codes + */ +esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transport_ws_config_t *config); + /** * @brief Sends websocket raw message with custom opcode and payload * diff --git a/components/tcp_transport/private_include/esp_transport_utils.h b/components/tcp_transport/private_include/esp_transport_utils.h index 2aa95119d8..2c86dbe454 100644 --- a/components/tcp_transport/private_include/esp_transport_utils.h +++ b/components/tcp_transport/private_include/esp_transport_utils.h @@ -29,6 +29,18 @@ extern "C" { action; \ } +/** + * @brief Utility macro for checking the error code of esp_err_t + */ +#define ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, action) \ + { \ + esp_err_t _esp_transport_err_to_check = err; \ + if (_esp_transport_err_to_check != ESP_OK) { \ + ESP_LOGE(TAG,"%s(%d): Expected ESP_OK; reported: %d", __FUNCTION__, __LINE__, _esp_transport_err_to_check); \ + action; \ + } \ + } + /** * @brief Convert milliseconds to timeval struct for valid timeouts, otherwise * (if "wait forever" requested by timeout_ms=-1) timeval structure is not updated and NULL returned diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 344491d6b8..b5043392cf 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -22,6 +22,7 @@ static const char *TAG = "TRANSPORT_WS"; #define WS_OPCODE_CLOSE 0x08 #define WS_OPCODE_PING 0x09 #define WS_OPCODE_PONG 0x0a +#define WS_OPCODE_CONTROL_FRAME 0x08 // Second byte #define WS_MASK 0x80 @@ -29,6 +30,7 @@ static const char *TAG = "TRANSPORT_WS"; #define WS_SIZE64 127 #define MAX_WEBSOCKET_HEADER_SIZE 16 #define WS_RESPONSE_OK 101 +#define WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN 125 typedef struct { @@ -36,6 +38,7 @@ typedef struct { char mask_key[4]; /*!< Mask key for this payload */ int payload_len; /*!< Total length of the payload */ int bytes_remaining; /*!< Bytes left to read of the payload */ + bool header_received; /*!< Flag to indicate that a new message header was received */ } ws_transport_frame_state_t; typedef struct { @@ -44,10 +47,33 @@ typedef struct { char *sub_protocol; char *user_agent; char *headers; + bool propagate_control_frames; ws_transport_frame_state_t frame_state; esp_transport_handle_t parent; } transport_ws_t; +/** + * @brief Handles control frames + * + * This API is used internally to handle control frames at the transport layer. + * The API could be possibly promoted to a public API if needed by some clients + * + * @param t Websocket transport handle + * @param buffer Buffer with the actual payload of the control packet to be processed + * @param len Length of the buffer (typically the same as the payload buffer) + * @param timeout_ms The timeout milliseconds + * @param client_closed To indicate that the connection has been closed by the client +* (to prevent echoing the CLOSE packet if true, as this is the actual echo from the server) + * + * @return + * 0 - no activity, or successfully responded to PING + * -1 - Failure: Error on read or the actual payload longer then buffer + * 1 - Close handshake success + * 2 - Got PONG message + */ + +static int esp_transport_ws_handle_control_frames(esp_transport_handle_t t, char *buffer, int len, int timeout_ms, bool client_closed); + static inline uint8_t ws_get_bin_opcode(ws_transport_opcodes_t opcode) { return (uint8_t)opcode; @@ -333,6 +359,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t char *data_ptr = ws_header, mask; int rlen; int poll_read; + ws->frame_state.header_received = false; if ((poll_read = esp_transport_poll_read(ws->parent, timeout_ms)) <= 0) { return poll_read; } @@ -344,6 +371,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t ESP_LOGE(TAG, "Error read data"); return rlen; } + ws->frame_state.header_received = true; ws->frame_state.opcode = (*data_ptr & 0x0F); data_ptr ++; mask = ((*data_ptr >> 7) & 0x01); @@ -390,6 +418,56 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t return payload_len; } +static int ws_handle_control_frame_internal(esp_transport_handle_t t, int timeout_ms) +{ + transport_ws_t *ws = esp_transport_get_context_data(t); + char *control_frame_buffer = NULL; + int control_frame_buffer_len = 0; + int payload_len = ws->frame_state.payload_len; + int ret = 0; + + // If no new header reception in progress, or not a control frame + // just pass 0 -> no need to handle control frames + if (ws->frame_state.header_received == false || + !(ws->frame_state.opcode & WS_OPCODE_CONTROL_FRAME)) { + return 0; + } + + if (payload_len > WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN) { + ESP_LOGE(TAG, "Not enough room for reading control frames (need=%d, max_allowed=%d)", + ws->frame_state.payload_len, WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN); + return -1; + } + + // Now we can handle the control frame correctly (either zero payload, or a short one for which we allocate mem) + control_frame_buffer_len = payload_len; + if (control_frame_buffer_len > 0) { + control_frame_buffer = malloc(control_frame_buffer_len); + if (control_frame_buffer == NULL) { + ESP_LOGE(TAG, "Cannot allocate buffer for control frames, need-%d", control_frame_buffer_len); + return -1; + } + } else { + control_frame_buffer_len = 0; + } + + // read the payload of the control frame + int actual_len = ws_read_payload(t, control_frame_buffer, control_frame_buffer_len, timeout_ms); + if (actual_len != payload_len) { + ESP_LOGE(TAG, "Control frame (opcode=%d) payload read failed (payload_len=%d, read_len=%d)", + ws->frame_state.opcode, payload_len, actual_len); + ret = -1; + goto free_payload_buffer; + } + + ret = esp_transport_ws_handle_control_frames(t, control_frame_buffer, control_frame_buffer_len, timeout_ms, false); + +free_payload_buffer: + free(control_frame_buffer); + return ret > 0 ? 0 : ret; // We don't propagate control frames, pass 0 to upper layers + +} + static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) { int rlen = 0; @@ -397,12 +475,28 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ // If message exceeds buffer len then subsequent reads will skip reading header and read whatever is left of the payload if (ws->frame_state.bytes_remaining <= 0) { - if ( (rlen = ws_read_header(t, buffer, len, timeout_ms)) <= 0) { + + if ( (rlen = ws_read_header(t, buffer, len, timeout_ms)) < 0) { // If something when wrong then we prepare for reading a new header ws->frame_state.bytes_remaining = 0; return rlen; } + + // If the new opcode is a control frame and we don't pass it to the app + // - try to handle it internally using the application buffer + if (ws->frame_state.header_received && (ws->frame_state.opcode & WS_OPCODE_CONTROL_FRAME) && + ws->propagate_control_frames == false) { + // automatically handle only 0 payload frames and make the transport read to return 0 on success + // which might be interpreted as timeouts + return ws_handle_control_frame_internal(t, timeout_ms); + } + + if (rlen == 0) { + ws->frame_state.bytes_remaining = 0; + return 0; // timeout + } } + if (ws->frame_state.payload_len) { if ( (rlen = ws_read_payload(t, buffer, len, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error reading payload data"); @@ -444,11 +538,32 @@ static esp_err_t ws_destroy(esp_transport_handle_t t) free(ws); return 0; } +static esp_err_t internal_esp_transport_ws_set_path(esp_transport_handle_t t, const char *path) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + transport_ws_t *ws = esp_transport_get_context_data(t); + if (ws->path) { + free(ws->path); + } + if (path == NULL) { + ws->path = NULL; + return ESP_OK; + } + ws->path = strdup(path); + if (ws->path == NULL) { + return ESP_ERR_NO_MEM; + } + return ESP_OK; +} + void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path) { - transport_ws_t *ws = esp_transport_get_context_data(t); - ws->path = realloc(ws->path, strlen(path) + 1); - strcpy(ws->path, path); + esp_err_t err = internal_esp_transport_ws_set_path(t, path); + if (err != ESP_OK) { + ESP_LOGE(TAG, "esp_transport_ws_set_path has internally failed with err=%d", err); + } } static int ws_get_socket(esp_transport_handle_t t) @@ -550,10 +665,42 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea return ESP_OK; } +esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transport_ws_config_t *config) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + esp_err_t err = ESP_OK; + transport_ws_t *ws = esp_transport_get_context_data(t); + if (config->ws_path) { + err = internal_esp_transport_ws_set_path(t, config->ws_path); + ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) + } + if (config->sub_protocol) { + err = esp_transport_ws_set_subprotocol(t, config->sub_protocol); + ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) + } + if (config->user_agent) { + err = esp_transport_ws_set_user_agent(t, config->user_agent); + ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) + } + if (config->headers) { + err = esp_transport_ws_set_headers(t, config->headers); + ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) + } + ws->propagate_control_frames = config->propagate_control_frames; + + return err; +} + ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t) { transport_ws_t *ws = esp_transport_get_context_data(t); - return ws->frame_state.opcode; + if (ws->frame_state.header_received) { + // convert the header byte to enum if correctly received + return (ws_transport_opcodes_t)ws->frame_state.opcode; + } + return WS_TRANSPORT_OPCODES_NONE; } int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t) @@ -562,6 +709,79 @@ int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t) return ws->frame_state.payload_len; } +static int esp_transport_ws_handle_control_frames(esp_transport_handle_t t, char *buffer, int len, int timeout_ms, bool client_closed) +{ + transport_ws_t *ws = esp_transport_get_context_data(t); + + // If no new header reception in progress, or not a control frame + // just pass 0 -> no need to handle control frames + if (ws->frame_state.header_received == false || + !(ws->frame_state.opcode & WS_OPCODE_CONTROL_FRAME)) { + return 0; + } + int actual_len; + int payload_len = ws->frame_state.payload_len; + + ESP_LOGD(TAG, "Handling control frame with %d bytes payload", payload_len); + if (payload_len > len) { + ESP_LOGE(TAG, "Not enough room for processing the payload (need=%d, available=%d)", payload_len, len); + ws->frame_state.bytes_remaining = payload_len - len; + return -1; + } + + if (ws->frame_state.opcode == WS_OPCODE_PING) { + // handle PING frames internally: just send a PONG with the same payload + actual_len = _ws_write(t, WS_OPCODE_PONG | WS_FIN, WS_MASK, buffer, + payload_len, timeout_ms); + if (actual_len != payload_len) { + ESP_LOGE(TAG, "PONG send failed (payload_len=%d, written_len=%d)", payload_len, actual_len); + return -1; + } + ESP_LOGD(TAG, "PONG sent correctly (payload_len=%d)", payload_len); + + // control frame handled correctly, reset the flag indicating new header received + ws->frame_state.header_received = false; + return 0; + + } else if (ws->frame_state.opcode == WS_OPCODE_CLOSE) { + // handle CLOSE by the server: send a zero payload frame + if (buffer && payload_len > 0) { // if some payload, print out the status code + uint16_t *code_network_order = (uint16_t *) buffer; + ESP_LOGI(TAG, "Got CLOSE frame with status code=%u", ntohs(*code_network_order)); + } + + if (client_closed == false) { + // Only echo the closing frame if not initiated by the client + if (_ws_write(t, WS_OPCODE_CLOSE | WS_FIN, WS_MASK, NULL,0, timeout_ms) < 0) { + ESP_LOGE(TAG, "Sending CLOSE frame with 0 payload failed"); + return -1; + } + ESP_LOGD(TAG, "CLOSE frame with no payload sent correctly"); + } + + // control frame handled correctly, reset the flag indicating new header received + ws->frame_state.header_received = false; + int ret = esp_transport_ws_poll_connection_closed(t, timeout_ms); + if (ret == 0) { + ESP_LOGW(TAG, "Connection cannot be terminated gracefully within timeout=%d", timeout_ms); + return -1; + } + if (ret < 0) { + ESP_LOGW(TAG, "Connection terminated while waiting for clean TCP close"); + return -1; + } + ESP_LOGI(TAG, "Connection terminated gracefully"); + return 1; + } else if (ws->frame_state.opcode == WS_OPCODE_PONG) { + // handle PONG: just indicate return code + ESP_LOGD(TAG, "Received PONG frame with payload=%d", payload_len); + // control frame handled correctly, reset the flag indicating new header received + ws->frame_state.header_received = false; + return 2; + } + return 0; +} + int esp_transport_ws_poll_connection_closed(esp_transport_handle_t t, int timeout_ms) { struct timeval timeout; diff --git a/examples/protocols/websocket/main/websocket_example.c b/examples/protocols/websocket/main/websocket_example.c index eb3db80ce7..a1cb08c566 100644 --- a/examples/protocols/websocket/main/websocket_example.c +++ b/examples/protocols/websocket/main/websocket_example.c @@ -137,6 +137,7 @@ void app_main(void) ESP_LOGI(TAG, "[APP] IDF version: %s", esp_get_idf_version()); esp_log_level_set("*", ESP_LOG_INFO); esp_log_level_set("WEBSOCKET_CLIENT", ESP_LOG_DEBUG); + esp_log_level_set("TRANSPORT_WS", ESP_LOG_DEBUG); esp_log_level_set("TRANS_TCP", ESP_LOG_DEBUG); ESP_ERROR_CHECK(nvs_flash_init());