diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index 87e89f07cb..fcf2230732 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -34,6 +34,7 @@ typedef struct { const char *sub_protocol; /*!< WS subprotocol */ const char *user_agent; /*!< WS user agent */ const char *headers; /*!< WS additional headers */ + const char *auth; /*!< HTTP authorization header */ bool propagate_control_frames; /*!< If true, control frames are passed to the reader * If false, only user frames are propagated, control frames are handled * automatically during read operations @@ -93,6 +94,18 @@ esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char * */ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers); +/** + * @brief Set websocket authorization headers + * + * @param t websocket transport handle + * @param sub_protocol The HTTP authorization header string, set NULL to clear the old value + * + * @return + * - ESP_OK on success + * - One of the error codes + */ +esp_err_t esp_transport_ws_set_auth(esp_transport_handle_t t, const char *auth); + /** * @brief Set websocket transport parameters * diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index ddeaf8e112..07fe4932e1 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -54,6 +54,7 @@ typedef struct { char *user_agent; char *headers; int http_status_code; + char *auth; bool propagate_control_frames; ws_transport_frame_state_t frame_state; esp_transport_handle_t parent; @@ -209,6 +210,16 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int return -1; } } + if (ws->auth) { + ESP_LOGD(TAG, "Authorization: %s", ws->auth); + int r = snprintf(ws->buffer + len, WS_BUFFER_SIZE - len, "Authorization: %s\r\n", ws->auth); + len += r; + if (r <= 0 || len >= WS_BUFFER_SIZE) { + ESP_LOGE(TAG, "Error in request generation" + "(snprintf of authorization returned %d, desired request len: %d, buffer size: %d", r, len, WS_BUFFER_SIZE); + return -1; + } + } if (ws->headers) { ESP_LOGD(TAG, "headers: %s", ws->headers); int r = snprintf(ws->buffer + len, WS_BUFFER_SIZE - len, "%s", ws->headers); @@ -586,6 +597,7 @@ static esp_err_t ws_destroy(esp_transport_handle_t t) free(ws->sub_protocol); free(ws->user_agent); free(ws->headers); + free(ws->auth); free(ws); return 0; } @@ -730,6 +742,26 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea return ESP_OK; } +esp_err_t esp_transport_ws_set_auth(esp_transport_handle_t t, const char *auth) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + transport_ws_t *ws = esp_transport_get_context_data(t); + if (ws->auth) { + free(ws->auth); + } + if (auth == NULL) { + ws->auth = NULL; + return ESP_OK; + } + ws->auth = strdup(auth); + if (ws->auth == NULL) { + return ESP_ERR_NO_MEM; + } + return ESP_OK; +} + esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transport_ws_config_t *config) { if (t == NULL) { @@ -753,6 +785,10 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp err = esp_transport_ws_set_headers(t, config->headers); ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) } + if (config->auth) { + err = esp_transport_ws_set_auth(t, config->auth); + ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) + } ws->propagate_control_frames = config->propagate_control_frames; return err;