fix(ws_transport): fix first fragment losting during websocket connection

This commit is contained in:
Suren Gabrielyan 2024-01-23 15:54:40 +04:00
parent dc4bf7d3e3
commit 2267d4b6b5

View File

@ -52,11 +52,12 @@ typedef struct {
typedef struct { typedef struct {
char *path; char *path;
char *buffer;
char *sub_protocol; char *sub_protocol;
char *user_agent; char *user_agent;
char *headers; char *headers;
char *auth; char *auth;
char *buffer; /*!< Initial HTTP connection buffer, which may include data beyond the handshake headers, such as the next WebSocket packet*/
size_t buffer_len; /*!< The buffer length */
int http_status_code; int http_status_code;
bool propagate_control_frames; bool propagate_control_frames;
ws_transport_frame_state_t frame_state; ws_transport_frame_state_t frame_state;
@ -101,6 +102,35 @@ static esp_transport_handle_t ws_get_payload_transport_handle(esp_transport_hand
return ws->parent; return ws->parent;
} }
static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len, int timeout_ms)
{
// No buffered data to read from, directly attempt to read from the transport.
if (ws->buffer_len == 0) {
return esp_transport_read(ws->parent, buffer, len, timeout_ms);
}
// At this point, buffer_len is guaranteed to be > 0.
int to_read = (ws->buffer_len >= len) ? len : ws->buffer_len;
// Copy the available or requested data to the buffer.
memcpy(buffer, ws->buffer, to_read);
if (to_read < ws->buffer_len) {
// Shift remaining data if not all was read.
memmove(ws->buffer, ws->buffer + to_read, ws->buffer_len - to_read);
ws->buffer_len -= to_read;
} else {
// All buffer data was consumed.
#ifdef CONFIG_WS_DYNAMIC_BUFFER
free(ws->buffer);
ws->buffer = NULL;
#endif
ws->buffer_len = 0;
}
return to_read;
}
static char *trimwhitespace(const char *str) static char *trimwhitespace(const char *str)
{ {
char *end; char *end;
@ -164,6 +194,8 @@ static char *get_http_header(const char *buffer, const char *key)
static int ws_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) static int ws_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{ {
transport_ws_t *ws = esp_transport_get_context_data(t); transport_ws_t *ws = esp_transport_get_context_data(t);
const char delimiter[] = "\r\n\r\n";
if (esp_transport_connect(ws->parent, host, port, timeout_ms) < 0) { if (esp_transport_connect(ws->parent, host, port, timeout_ms) < 0) {
ESP_LOGE(TAG, "Error connecting to host %s:%d", host, port); ESP_LOGE(TAG, "Error connecting to host %s:%d", host, port);
return -1; return -1;
@ -256,9 +288,12 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
return -1; return -1;
} }
header_len += len; header_len += len;
ws->buffer[header_len] = '\0'; ws->buffer_len = header_len;
ws->buffer[header_len] = '\0'; // We will mark the end of the header to ensure that strstr operations for parsing the headers don't fail.
ESP_LOGD(TAG, "Read header chunk %d, current header size: %d", len, header_len); ESP_LOGD(TAG, "Read header chunk %d, current header size: %d", len, header_len);
} while (NULL == strstr(ws->buffer, "\r\n\r\n") && header_len < WS_BUFFER_SIZE); } while (NULL == strstr(ws->buffer, delimiter) && header_len < WS_BUFFER_SIZE);
char* delim_ptr = strstr(ws->buffer, delimiter);
ws->http_status_code = get_http_status_code(ws->buffer); ws->http_status_code = get_http_status_code(ws->buffer);
if (ws->http_status_code == -1) { if (ws->http_status_code == -1) {
@ -272,6 +307,20 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
return -1; return -1;
} }
if (delim_ptr != NULL) {
size_t delim_pos = delim_ptr - ws->buffer + sizeof(delimiter) - 1;
size_t remaining_len = ws->buffer_len - delim_pos;
if (remaining_len > 0) {
memmove(ws->buffer, ws->buffer + delim_pos, remaining_len);
ws->buffer_len = remaining_len;
} else {
#ifdef CONFIG_WS_DYNAMIC_BUFFER
free(ws->buffer);
ws->buffer = NULL;
#endif
ws->buffer_len = 0;
}
}
// See esp_crypto_sha1() arg size // See esp_crypto_sha1() arg size
unsigned char expected_server_sha1[20]; unsigned char expected_server_sha1[20];
// Size of base64 coded string see above // Size of base64 coded string see above
@ -291,10 +340,6 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
ESP_LOGE(TAG, "Invalid websocket key"); ESP_LOGE(TAG, "Invalid websocket key");
return -1; return -1;
} }
#ifdef CONFIG_WS_DYNAMIC_BUFFER
free(ws->buffer);
ws->buffer = NULL;
#endif
return 0; return 0;
} }
@ -406,7 +451,7 @@ static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int
} }
// Receive and process payload // Receive and process payload
if (bytes_to_read != 0 && (rlen = esp_transport_read(ws->parent, buffer, bytes_to_read, timeout_ms)) <= 0) { if (bytes_to_read != 0 && (rlen = esp_transport_read_internal(ws, buffer, bytes_to_read, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read data"); ESP_LOGE(TAG, "Error read data");
return rlen; return rlen;
} }
@ -437,7 +482,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
// Receive and process header first (based on header size) // Receive and process header first (based on header size)
int header = 2; int header = 2;
int mask_len = 4; int mask_len = 4;
if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read data"); ESP_LOGE(TAG, "Error read data");
return rlen; return rlen;
} }
@ -451,7 +496,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d", ws->frame_state.opcode, mask, payload_len); ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d", ws->frame_state.opcode, mask, payload_len);
if (payload_len == 126) { if (payload_len == 126) {
// headerLen += 2; // headerLen += 2;
if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read data"); ESP_LOGE(TAG, "Error read data");
return rlen; return rlen;
} }
@ -459,7 +504,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
} else if (payload_len == 127) { } else if (payload_len == 127) {
// headerLen += 8; // headerLen += 8;
header = 8; header = 8;
if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) { if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read data"); ESP_LOGE(TAG, "Error read data");
return rlen; return rlen;
} }
@ -474,7 +519,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
if (mask) { if (mask) {
// Read and store mask // Read and store mask
if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, mask_len, timeout_ms)) <= 0) { if (payload_len != 0 && (rlen = esp_transport_read_internal(ws, buffer, mask_len, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read data"); ESP_LOGE(TAG, "Error read data");
return rlen; return rlen;
} }