diff --git a/components/protocomm/include/security/protocomm_security.h b/components/protocomm/include/security/protocomm_security.h index 28f43a6574..04ca8b3355 100644 --- a/components/protocomm/include/security/protocomm_security.h +++ b/components/protocomm/include/security/protocomm_security.h @@ -35,6 +35,8 @@ typedef struct protocomm_security_pop { uint16_t len; } protocomm_security_pop_t; +typedef void * protocomm_security_handle_t; + /** * @brief Protocomm security object structure. * @@ -54,28 +56,31 @@ typedef struct protocomm_security { * Function for initializing/allocating security * infrastructure */ - esp_err_t (*init)(); + esp_err_t (*init)(protocomm_security_handle_t *handle); /** * Function for deallocating security infrastructure */ - esp_err_t (*cleanup)(); + esp_err_t (*cleanup)(protocomm_security_handle_t handle); /** * Starts new secure transport session with specified ID */ - esp_err_t (*new_transport_session)(uint32_t session_id); + esp_err_t (*new_transport_session)(protocomm_security_handle_t handle, + uint32_t session_id); /** * Closes a secure transport session with specified ID */ - esp_err_t (*close_transport_session)(uint32_t session_id); + esp_err_t (*close_transport_session)(protocomm_security_handle_t handle, + uint32_t session_id); /** * Handler function for authenticating connection * request and establishing secure session */ - esp_err_t (*security_req_handler)(const protocomm_security_pop_t *pop, + esp_err_t (*security_req_handler)(protocomm_security_handle_t handle, + const protocomm_security_pop_t *pop, uint32_t session_id, const uint8_t *inbuf, ssize_t inlen, uint8_t **outbuf, ssize_t *outlen, @@ -84,14 +89,16 @@ typedef struct protocomm_security { /** * Function which implements the encryption algorithm */ - esp_err_t (*encrypt)(uint32_t session_id, + esp_err_t (*encrypt)(protocomm_security_handle_t handle, + uint32_t session_id, const uint8_t *inbuf, ssize_t inlen, uint8_t *outbuf, ssize_t *outlen); /** * Function which implements the decryption algorithm */ - esp_err_t (*decrypt)(uint32_t session_id, + esp_err_t (*decrypt)(protocomm_security_handle_t handle, + uint32_t session_id, const uint8_t *inbuf, ssize_t inlen, uint8_t *outbuf, ssize_t *outlen); } protocomm_security_t; diff --git a/components/protocomm/src/common/protocomm.c b/components/protocomm/src/common/protocomm.c index fe0102856b..12b5b59ee3 100644 --- a/components/protocomm/src/common/protocomm.c +++ b/components/protocomm/src/common/protocomm.c @@ -59,7 +59,7 @@ void protocomm_delete(protocomm_t *pc) /* Free memory allocated to security */ if (pc->sec && pc->sec->cleanup) { - pc->sec->cleanup(); + pc->sec->cleanup(pc->sec_inst); } if (pc->pop) { free(pc->pop); @@ -182,7 +182,7 @@ esp_err_t protocomm_req_handle(protocomm_t *pc, const char *ep_name, uint32_t se } ssize_t dec_inbuf_len = inlen; - ret = pc->sec->decrypt(session_id, inbuf, inlen, dec_inbuf, &dec_inbuf_len); + ret = pc->sec->decrypt(pc->sec_inst, session_id, inbuf, inlen, dec_inbuf, &dec_inbuf_len); if (ret != ESP_OK) { ESP_LOGE(TAG, "Decryption of response failed for endpoint %s", ep_name); free(dec_inbuf); @@ -214,7 +214,7 @@ esp_err_t protocomm_req_handle(protocomm_t *pc, const char *ep_name, uint32_t se } ssize_t enc_resp_len = plaintext_resp_len; - ret = pc->sec->encrypt(session_id, plaintext_resp, plaintext_resp_len, + ret = pc->sec->encrypt(pc->sec_inst, session_id, plaintext_resp, plaintext_resp_len, enc_resp, &enc_resp_len); if (ret != ESP_OK) { @@ -253,7 +253,8 @@ static int protocomm_common_security_handler(uint32_t session_id, protocomm_t *pc = (protocomm_t *) priv_data; if (pc->sec && pc->sec->security_req_handler) { - return pc->sec->security_req_handler(pc->pop, session_id, + return pc->sec->security_req_handler(pc->sec_inst, + pc->pop, session_id, inbuf, inlen, outbuf, outlen, priv_data); @@ -283,7 +284,7 @@ esp_err_t protocomm_set_security(protocomm_t *pc, const char *ep_name, } if (sec->init) { - ret = sec->init(); + ret = sec->init(&pc->sec_inst); if (ret != ESP_OK) { ESP_LOGE(TAG, "Error initializing security"); protocomm_remove_endpoint(pc, ep_name); @@ -297,7 +298,8 @@ esp_err_t protocomm_set_security(protocomm_t *pc, const char *ep_name, if (pc->pop == NULL) { ESP_LOGE(TAG, "Error allocating Proof of Possession"); if (pc->sec && pc->sec->cleanup) { - pc->sec->cleanup(); + pc->sec->cleanup(pc->sec_inst); + pc->sec_inst = NULL; pc->sec = NULL; } @@ -316,7 +318,8 @@ esp_err_t protocomm_unset_security(protocomm_t *pc, const char *ep_name) } if (pc->sec && pc->sec->cleanup) { - pc->sec->cleanup(); + pc->sec->cleanup(pc->sec_inst); + pc->sec_inst = NULL; pc->sec = NULL; } diff --git a/components/protocomm/src/common/protocomm_priv.h b/components/protocomm/src/common/protocomm_priv.h index 16e478337d..0757562049 100644 --- a/components/protocomm/src/common/protocomm_priv.h +++ b/components/protocomm/src/common/protocomm_priv.h @@ -61,10 +61,13 @@ struct protocomm { * internally when protocomm_remove_endpoint() is invoked. */ int (*remove_endpoint)(const char *ep_name); - /* Pointer to security layer instance to be used internally for + /* Pointer to security layer to be used internally for * establishing secure sessions */ const protocomm_security_t *sec; + /* Handle to the security layer instance */ + protocomm_security_handle_t sec_inst; + /* Pointer to proof of possession object */ protocomm_security_pop_t *pop; diff --git a/components/protocomm/src/security/security0.c b/components/protocomm/src/security/security0.c index a127136a3c..11de0e824a 100644 --- a/components/protocomm/src/security/security0.c +++ b/components/protocomm/src/security/security0.c @@ -65,7 +65,9 @@ static void sec0_session_setup_cleanup(uint32_t session_id, SessionData *resp) return; } -static esp_err_t sec0_req_handler(const protocomm_security_pop_t *pop, uint32_t session_id, +static esp_err_t sec0_req_handler(protocomm_security_handle_t handle, + const protocomm_security_pop_t *pop, + uint32_t session_id, const uint8_t *inbuf, ssize_t inlen, uint8_t **outbuf, ssize_t *outlen, void *priv_data) diff --git a/components/protocomm/src/security/security1.c b/components/protocomm/src/security/security1.c index 36d99f0a29..20eb4ca191 100644 --- a/components/protocomm/src/security/security1.c +++ b/components/protocomm/src/security/security1.c @@ -57,8 +57,6 @@ typedef struct session { size_t nc_off; } session_t; -static session_t *cur_session; - static void flip_endian(uint8_t *data, size_t len) { uint8_t swp_buf; @@ -75,7 +73,8 @@ static void hexdump(const char *msg, uint8_t *buf, int len) ESP_LOG_BUFFER_HEX_LEVEL(TAG, buf, len, ESP_LOG_DEBUG); } -static esp_err_t handle_session_command1(uint32_t session_id, +static esp_err_t handle_session_command1(session_t *cur_session, + uint32_t session_id, SessionData *req, SessionData *resp) { ESP_LOGD(TAG, "Request to handle setup1_command"); @@ -176,7 +175,8 @@ static esp_err_t handle_session_command1(uint32_t session_id, return ESP_OK; } -static esp_err_t handle_session_command0(uint32_t session_id, +static esp_err_t handle_session_command0(session_t *cur_session, + uint32_t session_id, SessionData *req, SessionData *resp, const protocomm_security_pop_t *pop) { @@ -355,23 +355,14 @@ exit_cmd0: return ret; } -static esp_err_t sec1_session_setup(uint32_t session_id, +static esp_err_t sec1_session_setup(session_t *cur_session, + uint32_t session_id, SessionData *req, SessionData *resp, const protocomm_security_pop_t *pop) { Sec1Payload *in = (Sec1Payload *) req->sec1; esp_err_t ret; - if (!cur_session) { - ESP_LOGE(TAG, "Invalid session context data"); - return ESP_ERR_INVALID_ARG; - } - - if (session_id != cur_session->id) { - ESP_LOGE(TAG, "Invalid session ID : %d (expected %d)", session_id, cur_session->id); - return ESP_ERR_INVALID_STATE; - } - if (!in) { ESP_LOGE(TAG, "Empty session data"); return ESP_ERR_INVALID_ARG; @@ -379,10 +370,10 @@ static esp_err_t sec1_session_setup(uint32_t session_id, switch (in->msg) { case SEC1_MSG_TYPE__Session_Command0: - ret = handle_session_command0(session_id, req, resp, pop); + ret = handle_session_command0(cur_session, session_id, req, resp, pop); break; case SEC1_MSG_TYPE__Session_Command1: - ret = handle_session_command1(session_id, req, resp); + ret = handle_session_command1(cur_session, session_id, req, resp); break; default: ESP_LOGE(TAG, "Invalid security message type"); @@ -393,7 +384,7 @@ static esp_err_t sec1_session_setup(uint32_t session_id, } -static void sec1_session_setup_cleanup(uint32_t session_id, SessionData *resp) +static void sec1_session_setup_cleanup(session_t *cur_session, uint32_t session_id, SessionData *resp) { Sec1Payload *out = resp->sec1; @@ -427,11 +418,16 @@ static void sec1_session_setup_cleanup(uint32_t session_id, SessionData *resp) return; } -static esp_err_t sec1_close_session(uint32_t session_id) +static esp_err_t sec1_close_session(protocomm_security_handle_t handle, uint32_t session_id) { + session_t *cur_session = (session_t *) handle; + if (!cur_session) { + return ESP_ERR_INVALID_ARG; + } + if (!cur_session || cur_session->id != session_id) { ESP_LOGE(TAG, "Attempt to close invalid session"); - return ESP_ERR_INVALID_ARG; + return ESP_ERR_INVALID_STATE; } if (cur_session->state == SESSION_STATE_DONE) { @@ -439,48 +435,63 @@ static esp_err_t sec1_close_session(uint32_t session_id) mbedtls_aes_free(&cur_session->ctx_aes); } - bzero(cur_session, sizeof(session_t)); - free(cur_session); - cur_session = NULL; + memset(cur_session, 0, sizeof(session_t)); + cur_session->id = -1; return ESP_OK; } -static esp_err_t sec1_new_session(uint32_t session_id) +static esp_err_t sec1_new_session(protocomm_security_handle_t handle, uint32_t session_id) { - if (cur_session) { - /* Only one session is allowed at a time */ - ESP_LOGE(TAG, "Closing old session with id %u", cur_session->id); - sec1_close_session(cur_session->id); + session_t *cur_session = (session_t *) handle; + if (!cur_session) { + return ESP_ERR_INVALID_ARG; } - cur_session = (session_t *) calloc(1, sizeof(session_t)); - if (!cur_session) { - ESP_LOGE(TAG, "Error allocating session structure"); - return ESP_ERR_NO_MEM; + if (cur_session->id != -1) { + /* Only one session is allowed at a time */ + ESP_LOGE(TAG, "Closing old session with id %u", cur_session->id); + sec1_close_session(cur_session, session_id); } cur_session->id = session_id; return ESP_OK; } -static esp_err_t sec1_init() +static esp_err_t sec1_init(protocomm_security_handle_t *handle) { - return ESP_OK; -} - -static esp_err_t sec1_cleanup() -{ - if (cur_session) { - ESP_LOGD(TAG, "Closing current session with id %u", cur_session->id); - sec1_close_session(cur_session->id); + if (!handle) { + return ESP_ERR_INVALID_ARG; } + session_t *cur_session = (session_t *) calloc(1, sizeof(session_t)); + if (!cur_session) { + ESP_LOGE(TAG, "Error allocating new session"); + return ESP_ERR_NO_MEM; + } + cur_session->id = -1; + *handle = (protocomm_security_handle_t) cur_session; return ESP_OK; } -static esp_err_t sec1_decrypt(uint32_t session_id, +static esp_err_t sec1_cleanup(protocomm_security_handle_t handle) +{ + session_t *cur_session = (session_t *) handle; + if (cur_session) { + sec1_close_session(handle, cur_session->id); + } + free(handle); + return ESP_OK; +} + +static esp_err_t sec1_decrypt(protocomm_security_handle_t handle, + uint32_t session_id, const uint8_t *inbuf, ssize_t inlen, uint8_t *outbuf, ssize_t *outlen) { + session_t *cur_session = (session_t *) handle; + if (!cur_session) { + return ESP_ERR_INVALID_ARG; + } + if (*outlen < inlen) { return ESP_ERR_INVALID_ARG; } @@ -505,11 +516,24 @@ static esp_err_t sec1_decrypt(uint32_t session_id, return ESP_OK; } -static esp_err_t sec1_req_handler(const protocomm_security_pop_t *pop, uint32_t session_id, +static esp_err_t sec1_req_handler(protocomm_security_handle_t handle, + const protocomm_security_pop_t *pop, + uint32_t session_id, const uint8_t *inbuf, ssize_t inlen, uint8_t **outbuf, ssize_t *outlen, void *priv_data) { + session_t *cur_session = (session_t *) handle; + if (!cur_session) { + ESP_LOGE(TAG, "Invalid session context data"); + return ESP_ERR_INVALID_ARG; + } + + if (session_id != cur_session->id) { + ESP_LOGE(TAG, "Invalid session ID : %d (expected %d)", session_id, cur_session->id); + return ESP_ERR_INVALID_STATE; + } + SessionData *req; SessionData resp; esp_err_t ret; @@ -526,7 +550,7 @@ static esp_err_t sec1_req_handler(const protocomm_security_pop_t *pop, uint32_t } session_data__init(&resp); - ret = sec1_session_setup(session_id, req, &resp, pop); + ret = sec1_session_setup(cur_session, session_id, req, &resp, pop); if (ret != ESP_OK) { ESP_LOGE(TAG, "Session setup error %d", ret); session_data__free_unpacked(req, NULL); @@ -544,7 +568,7 @@ static esp_err_t sec1_req_handler(const protocomm_security_pop_t *pop, uint32_t } session_data__pack(&resp, *outbuf); - sec1_session_setup_cleanup(session_id, &resp); + sec1_session_setup_cleanup(cur_session, session_id, &resp); return ESP_OK; } diff --git a/components/protocomm/src/transports/protocomm_ble.c b/components/protocomm/src/transports/protocomm_ble.c index 5e65aa95d4..5feb83de0d 100644 --- a/components/protocomm/src/transports/protocomm_ble.c +++ b/components/protocomm/src/transports/protocomm_ble.c @@ -273,7 +273,8 @@ static void transport_simple_ble_disconnect(esp_gatts_cb_event_t event, esp_gatt ESP_LOGD(TAG, "Inside disconnect w/ session - %d", param->disconnect.conn_id); if (protoble_internal->pc_ble->sec && protoble_internal->pc_ble->sec->close_transport_session) { - ret = protoble_internal->pc_ble->sec->close_transport_session(param->disconnect.conn_id); + ret = protoble_internal->pc_ble->sec->close_transport_session(protoble_internal->pc_ble->sec_inst, + param->disconnect.conn_id); if (ret != ESP_OK) { ESP_LOGE(TAG, "error closing the session after disconnect"); } @@ -287,7 +288,8 @@ static void transport_simple_ble_connect(esp_gatts_cb_event_t event, esp_gatt_if ESP_LOGD(TAG, "Inside BLE connect w/ conn_id - %d", param->connect.conn_id); if (protoble_internal->pc_ble->sec && protoble_internal->pc_ble->sec->new_transport_session) { - ret = protoble_internal->pc_ble->sec->new_transport_session(param->connect.conn_id); + ret = protoble_internal->pc_ble->sec->new_transport_session(protoble_internal->pc_ble->sec_inst, + param->connect.conn_id); if (ret != ESP_OK) { ESP_LOGE(TAG, "error creating the session"); } diff --git a/components/protocomm/src/transports/protocomm_console.c b/components/protocomm/src/transports/protocomm_console.c index e1e10a7720..9e0243be44 100644 --- a/components/protocomm/src/transports/protocomm_console.c +++ b/components/protocomm/src/transports/protocomm_console.c @@ -145,7 +145,7 @@ static int common_cmd_handler(int argc, char** argv) if (cur_session_id != session_id) { if (pc_console->sec && pc_console->sec->new_transport_session) { - ret = pc_console->sec->new_transport_session(cur_session_id); + ret = pc_console->sec->new_transport_session(pc_console->sec_inst, cur_session_id); if (ret == ESP_OK) { session_id = cur_session_id; } diff --git a/components/protocomm/src/transports/protocomm_httpd.c b/components/protocomm/src/transports/protocomm_httpd.c index b4653b9a58..b59ffb1a70 100644 --- a/components/protocomm/src/transports/protocomm_httpd.c +++ b/components/protocomm/src/transports/protocomm_httpd.c @@ -48,7 +48,7 @@ static esp_err_t common_post_handler(httpd_req_t *req) /* Presently HTTP server doesn't support callback on socket closure so * previous session can only be closed when new session is requested */ if (pc_httpd->sec && pc_httpd->sec->close_transport_session) { - ret = pc_httpd->sec->close_transport_session(session_id); + ret = pc_httpd->sec->close_transport_session(pc_httpd->sec_inst, session_id); if (ret != ESP_OK) { ESP_LOGE(TAG, "Failed to close session with ID: %d", session_id); ret = ESP_FAIL; @@ -58,7 +58,7 @@ static esp_err_t common_post_handler(httpd_req_t *req) session_id = PROTOCOMM_NO_SESSION_ID; } if (pc_httpd->sec && pc_httpd->sec->new_transport_session) { - ret = pc_httpd->sec->new_transport_session(cur_session_id); + ret = pc_httpd->sec->new_transport_session(pc_httpd->sec_inst, cur_session_id); if (ret != ESP_OK) { ESP_LOGE(TAG, "Failed to launch new session with ID: %d", cur_session_id); ret = ESP_FAIL;