esp_prov: Compatibility changes and refactoring

- Removed python 2 compatibility
- Removed dependencies on redundant external modules
- Interactive provisioning input for security scheme 2
- Style changes:
  Updated print statements to format strings
  Colored verbose logging
  Raised exceptions on errors instead of clean exits
This commit is contained in:
Laukik Hase 2022-06-22 15:14:19 +05:30
parent 2c4e5c2963
commit 9aefcb12f5
No known key found for this signature in database
GPG Key ID: 11C571361F51A199
20 changed files with 212 additions and 341 deletions

View File

@ -2192,14 +2192,6 @@ tools/ci/python_packages/ttfw_idf/unity_test_parser.py
tools/ci/python_packages/wifi_tools.py
tools/ci/test_autocomplete.py
tools/esp_app_trace/test/sysview/blink.c
tools/esp_prov/__init__.py
tools/esp_prov/prov/__init__.py
tools/esp_prov/prov/wifi_prov.py
tools/esp_prov/security/security.py
tools/esp_prov/security/security0.py
tools/esp_prov/security/security1.py
tools/esp_prov/transport/__init__.py
tools/esp_prov/utils/__init__.py
tools/find_apps.py
tools/find_build_apps/__init__.py
tools/find_build_apps/cmake.py

View File

@ -206,7 +206,6 @@ tools/esp_prov/transport/transport.py
tools/esp_prov/transport/transport_ble.py
tools/esp_prov/transport/transport_console.py
tools/esp_prov/transport/transport_http.py
tools/esp_prov/utils/convenience.py
tools/find_apps.py
tools/find_build_apps/common.py
tools/gen_esp_err_to_name.py

View File

@ -1 +1,5 @@
# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
from .esp_prov import * # noqa: export esp_prov module to users

View File

@ -4,8 +4,6 @@
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import print_function
import argparse
import asyncio
import json
@ -13,7 +11,6 @@ import os
import sys
import textwrap
import time
from builtins import input as binput
from getpass import getpass
try:
@ -289,7 +286,7 @@ async def wait_wifi_connected(tp, sec):
retry -= 1
print('Waiting to poll status again (status %s, %d tries left)...' % (ret, retry))
else:
print('---- Provisioning failed ----')
print('---- Provisioning failed! ----')
return False
@ -381,68 +378,69 @@ async def main():
if args.secver == 2 and args.sec2_gen_cred:
if not args.sec2_usr or not args.sec2_pwd:
print('---- Username/password cannot be empty for security scheme 2 (SRP6a) ----')
exit(1)
raise ValueError('Username/password cannot be empty for security scheme 2 (SRP6a)')
print('==== Salt-verifier for security scheme 2 (SRP6a) ====')
security.sec2_gen_salt_verifier(args.sec2_usr, args.sec2_pwd, args.sec2_salt_len)
exit(0)
sys.exit()
obj_transport = await get_transport(args.mode.lower(), args.name)
if obj_transport is None:
print('---- Failed to establish connection ----')
exit(1)
raise RuntimeError('Failed to establish connection')
try:
# If security version not specified check in capabilities
if args.secver is None:
# First check if capabilities are supported or not
if not await has_capability(obj_transport):
print('Security capabilities could not be determined. Please specify "--sec_ver" explicitly')
print('---- Invalid Security Version ----')
exit(2)
print('Security capabilities could not be determined, please specify "--sec_ver" explicitly')
raise ValueError('Invalid Security Version')
# When no_sec is present, use security 0, else security 1
args.secver = int(not await has_capability(obj_transport, 'no_sec'))
print('Security scheme determined to be :', args.secver)
print(f'==== Security Scheme: {args.secver} ====')
if (args.secver != 0) and not await has_capability(obj_transport, 'no_pop'):
if (args.secver == 1):
if not await has_capability(obj_transport, 'no_pop'):
if len(args.sec1_pop) == 0:
args.sec1_pop = binput('Proof of Possession required : ')
prompt_str = 'Proof of Possession required: '
args.sec1_pop = getpass(prompt_str)
elif len(args.sec1_pop) != 0:
print('---- Proof of Possession will be ignored ----')
print('Proof of Possession will be ignored')
args.sec1_pop = ''
if (args.secver == 2):
if len(args.sec2_usr) == 0:
args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ')
if len(args.sec2_pwd) == 0:
prompt_str = 'Security Scheme 2 - SRP6a Password required: '
args.sec2_pwd = getpass(prompt_str)
obj_security = get_security(args.secver, args.sec2_usr, args.sec2_pwd, args.sec1_pop, args.verbose)
if obj_security is None:
print('---- Invalid Security Version ----')
exit(2)
raise ValueError('Invalid Security Version')
if args.version != '':
print('\n==== Verifying protocol version ====')
if not await version_match(obj_transport, args.version, args.verbose):
print('---- Error in protocol version matching ----')
exit(3)
raise RuntimeError('Error in protocol version matching')
print('==== Verified protocol version successfully ====')
print('\n==== Starting Session ====')
if not await establish_session(obj_transport, obj_security):
print('Failed to establish session. Ensure that security scheme and proof of possession are correct')
print('---- Error in establishing session ----')
exit(4)
raise RuntimeError('Error in establishing session')
print('==== Session Established ====')
if args.custom_data != '':
print('\n==== Sending Custom data to esp32 ====')
print('\n==== Sending Custom data to Target ====')
if not await custom_data(obj_transport, obj_security, args.custom_data):
print('---- Error in custom data ----')
exit(5)
raise RuntimeError('Error in custom data')
print('==== Custom data sent successfully ====')
if args.ssid == '':
if not await has_capability(obj_transport, 'wifi_scan'):
print('---- Wi-Fi Scan List is not supported by provisioning service ----')
print('---- Rerun esp_prov with SSID and Passphrase as argument ----')
exit(3)
raise RuntimeError('Wi-Fi Scan List is not supported by provisioning service')
while True:
print('\n==== Scanning Wi-Fi APs ====')
@ -451,12 +449,11 @@ async def main():
end_time = time.time()
print('\n++++ Scan finished in ' + str(end_time - start_time) + ' sec')
if APs is None:
print('---- Error in scanning Wi-Fi APs ----')
exit(8)
raise RuntimeError('Error in scanning Wi-Fi APs')
if len(APs) == 0:
print('No APs found!')
exit(9)
sys.exit()
print('==== Wi-Fi Scan results ====')
print('{0: >4} {1: <33} {2: <12} {3: >4} {4: <4} {5: <16}'.format(
@ -467,7 +464,7 @@ async def main():
while True:
try:
select = int(binput('Select AP by number (0 to rescan) : '))
select = int(input('Select AP by number (0 to rescan) : '))
if select < 0 or select > len(APs):
raise ValueError
break
@ -483,16 +480,14 @@ async def main():
prompt_str = 'Enter passphrase for {0} : '.format(args.ssid)
args.passphrase = getpass(prompt_str)
print('\n==== Sending Wi-Fi credential to esp32 ====')
print('\n==== Sending Wi-Fi Credentials to Target ====')
if not await send_wifi_config(obj_transport, obj_security, args.ssid, args.passphrase):
print('---- Error in send Wi-Fi config ----')
exit(6)
raise RuntimeError('Error in send Wi-Fi config')
print('==== Wi-Fi Credentials sent successfully ====')
print('\n==== Applying config to esp32 ====')
print('\n==== Applying Wi-Fi Config to Target ====')
if not await apply_wifi_config(obj_transport, obj_security):
print('---- Error in apply Wi-Fi config ----')
exit(7)
raise RuntimeError('Error in apply Wi-Fi config')
print('==== Apply config sent successfully ====')
await wait_wifi_connected(obj_transport, obj_security)

View File

@ -1,16 +1,5 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
from .custom_prov import * # noqa F403

View File

@ -4,26 +4,23 @@
# APIs for interpreting and creating protobuf packets for `custom-config` protocomm endpoint
from __future__ import print_function
import utils
from future.utils import tobytes
from utils import str_to_bytes
def print_verbose(security_ctx, data):
if (security_ctx.verbose):
print('++++ ' + data + ' ++++')
print(f'\x1b[32;20m++++ {data} ++++\x1b[0m')
def custom_data_request(security_ctx, data):
# Encrypt the custom data
enc_cmd = security_ctx.encrypt_data(tobytes(data))
print_verbose(security_ctx, 'Client -> Device (CustomData cmd) ' + utils.str_to_hexstr(enc_cmd))
enc_cmd = security_ctx.encrypt_data(str_to_bytes(data))
print_verbose(security_ctx, f'Client -> Device (CustomData cmd): 0x{enc_cmd.hex()}')
return enc_cmd.decode('latin-1')
def custom_data_response(security_ctx, response_data):
# Decrypt response packet
decrypt = security_ctx.decrypt_data(tobytes(response_data))
print('CustomData response: ' + str(decrypt))
decrypt = security_ctx.decrypt_data(str_to_bytes(response_data))
print(f'++++ CustomData response: {str(decrypt)}++++')
return 0

View File

@ -1,30 +1,16 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
# APIs for interpreting and creating protobuf packets for Wi-Fi provisioning
from __future__ import print_function
import proto
import utils
from future.utils import tobytes
from utils import str_to_bytes
def print_verbose(security_ctx, data):
if (security_ctx.verbose):
print('++++ ' + data + ' ++++')
print(f'\x1b[32;20m++++ {data} ++++\x1b[0m')
def config_get_status_request(security_ctx):
@ -33,34 +19,34 @@ def config_get_status_request(security_ctx):
cfg1.msg = proto.wifi_config_pb2.TypeCmdGetStatus
cmd_get_status = proto.wifi_config_pb2.CmdGetStatus()
cfg1.cmd_get_status.MergeFrom(cmd_get_status)
encrypted_cfg = security_ctx.encrypt_data(cfg1.SerializeToString()).decode('latin-1')
print_verbose(security_ctx, 'Client -> Device (Encrypted CmdGetStatus) ' + utils.str_to_hexstr(encrypted_cfg))
return encrypted_cfg
encrypted_cfg = security_ctx.encrypt_data(cfg1.SerializeToString())
print_verbose(security_ctx, f'Client -> Device (Encrypted CmdGetStatus): 0x{encrypted_cfg.hex()}')
return encrypted_cfg.decode('latin-1')
def config_get_status_response(security_ctx, response_data):
# Interpret protobuf response packet from GetStatus command
decrypted_message = security_ctx.decrypt_data(tobytes(response_data))
decrypted_message = security_ctx.decrypt_data(str_to_bytes(response_data))
cmd_resp1 = proto.wifi_config_pb2.WiFiConfigPayload()
cmd_resp1.ParseFromString(decrypted_message)
print_verbose(security_ctx, 'Response type ' + str(cmd_resp1.msg))
print_verbose(security_ctx, 'Response status ' + str(cmd_resp1.resp_get_status.status))
print_verbose(security_ctx, f'CmdGetStatus type: {str(cmd_resp1.msg)}')
print_verbose(security_ctx, f'CmdGetStatus status: {str(cmd_resp1.resp_get_status.status)}')
if cmd_resp1.resp_get_status.sta_state == 0:
print('++++ WiFi state: ' + 'connected ++++')
print('==== WiFi state: Connected ====')
return 'connected'
elif cmd_resp1.resp_get_status.sta_state == 1:
print('++++ WiFi state: ' + 'connecting... ++++')
print('++++ WiFi state: Connecting... ++++')
return 'connecting'
elif cmd_resp1.resp_get_status.sta_state == 2:
print('++++ WiFi state: ' + 'disconnected ++++')
print('---- WiFi state: Disconnected ----')
return 'disconnected'
elif cmd_resp1.resp_get_status.sta_state == 3:
print('++++ WiFi state: ' + 'connection failed ++++')
print('---- WiFi state: Connection Failed ----')
if cmd_resp1.resp_get_status.fail_reason == 0:
print('++++ Failure reason: ' + 'Incorrect Password ++++')
print('---- Failure reason: Incorrect Password ----')
elif cmd_resp1.resp_get_status.fail_reason == 1:
print('++++ Failure reason: ' + 'Incorrect SSID ++++')
print('---- Failure reason: Incorrect SSID ----')
return 'failed'
return 'unknown'
@ -69,19 +55,19 @@ def config_set_config_request(security_ctx, ssid, passphrase):
# Form protobuf request packet for SetConfig command
cmd = proto.wifi_config_pb2.WiFiConfigPayload()
cmd.msg = proto.wifi_config_pb2.TypeCmdSetConfig
cmd.cmd_set_config.ssid = tobytes(ssid)
cmd.cmd_set_config.passphrase = tobytes(passphrase)
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString()).decode('latin-1')
print_verbose(security_ctx, 'Client -> Device (SetConfig cmd) ' + utils.str_to_hexstr(enc_cmd))
return enc_cmd
cmd.cmd_set_config.ssid = str_to_bytes(ssid)
cmd.cmd_set_config.passphrase = str_to_bytes(passphrase)
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString())
print_verbose(security_ctx, f'Client -> Device (SetConfig cmd): 0x{enc_cmd.hex()}')
return enc_cmd.decode('latin-1')
def config_set_config_response(security_ctx, response_data):
# Interpret protobuf response packet from SetConfig command
decrypt = security_ctx.decrypt_data(tobytes(response_data))
decrypt = security_ctx.decrypt_data(str_to_bytes(response_data))
cmd_resp4 = proto.wifi_config_pb2.WiFiConfigPayload()
cmd_resp4.ParseFromString(decrypt)
print_verbose(security_ctx, 'SetConfig status ' + str(cmd_resp4.resp_set_config.status))
print_verbose(security_ctx, f'SetConfig status: 0x{str(cmd_resp4.resp_set_config.status)}')
return cmd_resp4.resp_set_config.status
@ -89,15 +75,15 @@ def config_apply_config_request(security_ctx):
# Form protobuf request packet for ApplyConfig command
cmd = proto.wifi_config_pb2.WiFiConfigPayload()
cmd.msg = proto.wifi_config_pb2.TypeCmdApplyConfig
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString()).decode('latin-1')
print_verbose(security_ctx, 'Client -> Device (ApplyConfig cmd) ' + utils.str_to_hexstr(enc_cmd))
return enc_cmd
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString())
print_verbose(security_ctx, f'Client -> Device (ApplyConfig cmd): 0x{enc_cmd.hex()}')
return enc_cmd.decode('latin-1')
def config_apply_config_response(security_ctx, response_data):
# Interpret protobuf response packet from ApplyConfig command
decrypt = security_ctx.decrypt_data(tobytes(response_data))
decrypt = security_ctx.decrypt_data(str_to_bytes(response_data))
cmd_resp5 = proto.wifi_config_pb2.WiFiConfigPayload()
cmd_resp5.ParseFromString(decrypt)
print_verbose(security_ctx, 'ApplyConfig status ' + str(cmd_resp5.resp_apply_config.status))
print_verbose(security_ctx, f'ApplyConfig status: 0x{str(cmd_resp5.resp_apply_config.status)}')
return cmd_resp5.resp_apply_config.status

View File

@ -3,17 +3,13 @@
#
# APIs for interpreting and creating protobuf packets for Wi-Fi Scanning
from __future__ import print_function
import proto
import utils
from future.utils import tobytes
from utils import str_to_bytes
def print_verbose(security_ctx, data):
if (security_ctx.verbose):
print('++++ ' + data + ' ++++')
print(f'\x1b[32;20m++++ {data} ++++\x1b[0m')
def scan_start_request(security_ctx, blocking=True, passive=False, group_channels=5, period_ms=120):
@ -24,17 +20,17 @@ def scan_start_request(security_ctx, blocking=True, passive=False, group_channel
cmd.cmd_scan_start.passive = passive
cmd.cmd_scan_start.group_channels = group_channels
cmd.cmd_scan_start.period_ms = period_ms
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString()).decode('latin-1')
print_verbose(security_ctx, 'Client -> Device (Encrypted CmdScanStart) ' + utils.str_to_hexstr(enc_cmd))
return enc_cmd
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString())
print_verbose(security_ctx, f'Client -> Device (Encrypted CmdScanStart): 0x{enc_cmd.hex()}')
return enc_cmd.decode('latin-1')
def scan_start_response(security_ctx, response_data):
# Interpret protobuf response packet from ScanStart command
dec_resp = security_ctx.decrypt_data(tobytes(response_data))
dec_resp = security_ctx.decrypt_data(str_to_bytes(response_data))
resp = proto.wifi_scan_pb2.WiFiScanPayload()
resp.ParseFromString(dec_resp)
print_verbose(security_ctx, 'ScanStart status ' + str(resp.status))
print_verbose(security_ctx, f'ScanStart status: 0x{str(resp.status)}')
if resp.status != 0:
raise RuntimeError
@ -43,17 +39,17 @@ def scan_status_request(security_ctx):
# Form protobuf request packet for ScanStatus command
cmd = proto.wifi_scan_pb2.WiFiScanPayload()
cmd.msg = proto.wifi_scan_pb2.TypeCmdScanStatus
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString()).decode('latin-1')
print_verbose(security_ctx, 'Client -> Device (Encrypted CmdScanStatus) ' + utils.str_to_hexstr(enc_cmd))
return enc_cmd
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString())
print_verbose(security_ctx, f'Client -> Device (Encrypted CmdScanStatus): 0x{enc_cmd.hex()}')
return enc_cmd.decode('latin-1')
def scan_status_response(security_ctx, response_data):
# Interpret protobuf response packet from ScanStatus command
dec_resp = security_ctx.decrypt_data(tobytes(response_data))
dec_resp = security_ctx.decrypt_data(str_to_bytes(response_data))
resp = proto.wifi_scan_pb2.WiFiScanPayload()
resp.ParseFromString(dec_resp)
print_verbose(security_ctx, 'ScanStatus status ' + str(resp.status))
print_verbose(security_ctx, f'ScanStatus status: 0x{str(resp.status)}')
if resp.status != 0:
raise RuntimeError
return {'finished': resp.resp_scan_status.scan_finished, 'count': resp.resp_scan_status.result_count}
@ -65,17 +61,17 @@ def scan_result_request(security_ctx, index, count):
cmd.msg = proto.wifi_scan_pb2.TypeCmdScanResult
cmd.cmd_scan_result.start_index = index
cmd.cmd_scan_result.count = count
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString()).decode('latin-1')
print_verbose(security_ctx, 'Client -> Device (Encrypted CmdScanResult) ' + utils.str_to_hexstr(enc_cmd))
return enc_cmd
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString())
print_verbose(security_ctx, f'Client -> Device (Encrypted CmdScanResult): 0x{enc_cmd.hex()}')
return enc_cmd.decode('latin-1')
def scan_result_response(security_ctx, response_data):
# Interpret protobuf response packet from ScanResult command
dec_resp = security_ctx.decrypt_data(tobytes(response_data))
dec_resp = security_ctx.decrypt_data(str_to_bytes(response_data))
resp = proto.wifi_scan_pb2.WiFiScanPayload()
resp.ParseFromString(dec_resp)
print_verbose(security_ctx, 'ScanResult status ' + str(resp.status))
print_verbose(security_ctx, f'ScanResult status: 0x{str(resp.status)}')
if resp.status != 0:
raise RuntimeError
authmode_str = ['Open', 'WEP', 'WPA_PSK', 'WPA2_PSK', 'WPA_WPA2_PSK',
@ -83,13 +79,13 @@ def scan_result_response(security_ctx, response_data):
results = []
for entry in resp.resp_scan_result.entries:
results += [{'ssid': entry.ssid.decode('latin-1').rstrip('\x00'),
'bssid': utils.str_to_hexstr(entry.bssid.decode('latin-1')),
'bssid': entry.bssid.hex(),
'channel': entry.channel,
'rssi': entry.rssi,
'auth': authmode_str[entry.auth]}]
print_verbose(security_ctx, 'ScanResult SSID : ' + str(results[-1]['ssid']))
print_verbose(security_ctx, 'ScanResult BSSID : ' + str(results[-1]['bssid']))
print_verbose(security_ctx, 'ScanResult Channel : ' + str(results[-1]['channel']))
print_verbose(security_ctx, 'ScanResult RSSI : ' + str(results[-1]['rssi']))
print_verbose(security_ctx, 'ScanResult AUTH : ' + str(results[-1]['auth']))
print_verbose(security_ctx, f"ScanResult SSID : {str(results[-1]['ssid'])}")
print_verbose(security_ctx, f"ScanResult BSSID : {str(results[-1]['bssid'])}")
print_verbose(security_ctx, f"ScanResult Channel : {str(results[-1]['channel'])}")
print_verbose(security_ctx, f"ScanResult RSSI : {str(results[-1]['rssi'])}")
print_verbose(security_ctx, f"ScanResult AUTH : {str(results[-1]['auth'])}")
return results

View File

@ -1,16 +1,5 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
# Base class for protocomm security

View File

@ -1,25 +1,12 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
# APIs for interpreting and creating protobuf packets for
# protocomm endpoint with security type protocomm_security0
from __future__ import print_function
import proto
from future.utils import tobytes
from utils import str_to_bytes
from .security import Security
@ -52,10 +39,10 @@ class Security0(Security):
def setup0_response(self, response_data):
# Interpret protocomm security0 response packet
setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data))
setup_resp.ParseFromString(str_to_bytes(response_data))
# Check if security scheme matches
if setup_resp.sec_ver != proto.session_pb2.SecScheme0:
print('Incorrect sec scheme')
raise RuntimeError('Incorrect security scheme')
def encrypt_data(self, data):
# Passive. No encryption when security0 used

View File

@ -1,35 +1,24 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
# APIs for interpreting and creating protobuf packets for
# protocomm endpoint with security type protocomm_security1
from __future__ import print_function
import proto
import session_pb2
import utils
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from future.utils import tobytes
from utils import long_to_bytes, str_to_bytes
from .security import Security
def a_xor_b(a: bytes, b: bytes) -> bytes:
return b''.join(long_to_bytes(a[i] ^ b[i]) for i in range(0, len(b)))
# Enum for state of protocomm_security1 FSM
class security_state:
REQUEST1 = 0
@ -38,25 +27,11 @@ class security_state:
FINISHED = 3
def xor(a, b):
# XOR two inputs of type `bytes`
ret = bytearray()
# Decode the input bytes to strings
a = a.decode('latin-1')
b = b.decode('latin-1')
for i in range(max(len(a), len(b))):
# Convert the characters to corresponding 8-bit ASCII codes
# then XOR them and store in bytearray
ret.append(([0, ord(a[i])][i < len(a)]) ^ ([0, ord(b[i])][i < len(b)]))
# Convert bytearray to bytes
return bytes(ret)
class Security1(Security):
def __init__(self, pop, verbose):
# Initialize state of the security1 FSM
self.session_state = security_state.REQUEST1
self.pop = tobytes(pop)
self.pop = str_to_bytes(pop)
self.verbose = verbose
Security.__init__(self, self.security1_session)
@ -66,59 +41,55 @@ class Security1(Security):
if (self.session_state == security_state.REQUEST1):
self.session_state = security_state.RESPONSE1_REQUEST2
return self.setup0_request()
if (self.session_state == security_state.RESPONSE1_REQUEST2):
elif (self.session_state == security_state.RESPONSE1_REQUEST2):
self.session_state = security_state.RESPONSE2
self.setup0_response(response_data)
return self.setup1_request()
if (self.session_state == security_state.RESPONSE2):
elif (self.session_state == security_state.RESPONSE2):
self.session_state = security_state.FINISHED
self.setup1_response(response_data)
return None
else:
print('Unexpected state')
return None
print('Unexpected state')
return None
def __generate_key(self):
# Generate private and public key pair for client
self.client_private_key = X25519PrivateKey.generate()
try:
self.client_public_key = self.client_private_key.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw)
except TypeError:
# backward compatible call for older cryptography library
self.client_public_key = self.client_private_key.public_key().public_bytes()
self.client_public_key = self.client_private_key.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw)
def _print_verbose(self, data):
if (self.verbose):
print('++++ ' + data + ' ++++')
print(f'\x1b[32;20m++++ {data} ++++\x1b[0m')
def setup0_request(self):
# Form SessionCmd0 request packet using client public key
setup_req = session_pb2.SessionData()
setup_req.sec_ver = session_pb2.SecScheme1
setup_req = proto.session_pb2.SessionData()
setup_req.sec_ver = proto.session_pb2.SecScheme1
self.__generate_key()
setup_req.sec1.sc0.client_pubkey = self.client_public_key
self._print_verbose('Client Public Key:\t' + utils.str_to_hexstr(self.client_public_key.decode('latin-1')))
self._print_verbose(f'Client Public Key:\t0x{self.client_public_key.hex()}')
return setup_req.SerializeToString().decode('latin-1')
def setup0_response(self, response_data):
# Interpret SessionResp0 response packet
setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data))
setup_resp.ParseFromString(str_to_bytes(response_data))
self._print_verbose('Security version:\t' + str(setup_resp.sec_ver))
if setup_resp.sec_ver != session_pb2.SecScheme1:
print('Incorrect sec scheme')
exit(1)
if setup_resp.sec_ver != proto.session_pb2.SecScheme1:
raise RuntimeError('Incorrect security scheme')
self.device_public_key = setup_resp.sec1.sr0.device_pubkey
# Device random is the initialization vector
device_random = setup_resp.sec1.sr0.device_random
self._print_verbose('Device Public Key:\t' + utils.str_to_hexstr(self.device_public_key.decode('latin-1')))
self._print_verbose('Device Random:\t' + utils.str_to_hexstr(device_random.decode('latin-1')))
self._print_verbose(f'Device Public Key:\t0x{self.device_public_key.hex()}')
self._print_verbose(f'Device Random:\t0x{device_random.hex()}')
# Calculate Curve25519 shared key using Client private key and Device public key
sharedK = self.client_private_key.exchange(X25519PublicKey.from_public_bytes(self.device_public_key))
self._print_verbose('Shared Key:\t' + utils.str_to_hexstr(sharedK.decode('latin-1')))
self._print_verbose(f'Shared Key:\t0x{sharedK.hex()}')
# If PoP is provided, XOR SHA256 of PoP with the previously
# calculated Shared Key to form the actual Shared Key
@ -128,8 +99,8 @@ class Security1(Security):
h.update(self.pop)
digest = h.finalize()
# XOR with and update Shared Key
sharedK = xor(sharedK, digest)
self._print_verbose('New Shared Key XORed with PoP:\t' + utils.str_to_hexstr(sharedK.decode('latin-1')))
sharedK = a_xor_b(sharedK, digest)
self._print_verbose(f'Updated Shared Key (Shared key XORed with PoP):\t0x{sharedK.hex()}')
# Initialize the encryption engine with Shared Key and initialization vector
cipher = Cipher(algorithms.AES(sharedK), modes.CTR(device_random), backend=default_backend())
self.cipher = cipher.encryptor()
@ -137,36 +108,33 @@ class Security1(Security):
def setup1_request(self):
# Form SessionCmd1 request packet using encrypted device public key
setup_req = proto.session_pb2.SessionData()
setup_req.sec_ver = session_pb2.SecScheme1
setup_req.sec_ver = proto.session_pb2.SecScheme1
setup_req.sec1.msg = proto.sec1_pb2.Session_Command1
# Encrypt device public key and attach to the request packet
client_verify = self.cipher.update(self.device_public_key)
self._print_verbose('Client Verify:\t' + utils.str_to_hexstr(client_verify.decode('latin-1')))
self._print_verbose(f'Client Proof:\t0x{client_verify.hex()}')
setup_req.sec1.sc1.client_verify_data = client_verify
return setup_req.SerializeToString().decode('latin-1')
def setup1_response(self, response_data):
# Interpret SessionResp1 response packet
setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data))
setup_resp.ParseFromString(str_to_bytes(response_data))
# Ensure security scheme matches
if setup_resp.sec_ver == session_pb2.SecScheme1:
if setup_resp.sec_ver == proto.session_pb2.SecScheme1:
# Read encrypyed device verify string
device_verify = setup_resp.sec1.sr1.device_verify_data
self._print_verbose('Device verify:\t' + utils.str_to_hexstr(device_verify.decode('latin-1')))
self._print_verbose(f'Device Proof:\t0x{device_verify.hex()}')
# Decrypt the device verify string
enc_client_pubkey = self.cipher.update(setup_resp.sec1.sr1.device_verify_data)
self._print_verbose('Enc client pubkey:\t ' + utils.str_to_hexstr(enc_client_pubkey.decode('latin-1')))
# Match decryped string with client public key
if enc_client_pubkey != self.client_public_key:
print('Mismatch in device verify')
return -2
raise RuntimeError('Failed to verify device!')
else:
print('Unsupported security protocol')
return -1
raise RuntimeError('Unsupported security protocol')
def encrypt_data(self, data):
return self.cipher.update(tobytes(data))
return self.cipher.update(data)
def decrypt_data(self, data):
return self.cipher.update(tobytes(data))
return self.cipher.update(data)

View File

@ -9,10 +9,10 @@ from typing import Any, Type
import proto
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from future.utils import tobytes
from utils import long_to_bytes, str_to_bytes
from .security import Security
from .srp6a import Srp6a, bytes_to_long, generate_salt_and_verifier, long_to_bytes
from .srp6a import Srp6a, generate_salt_and_verifier
AES_KEY_LEN = 256 // 8
@ -70,7 +70,7 @@ class Security2(Security):
self.setup1_response(response_data)
return None
print('Unexpected state')
print('---- Unexpected state! ----')
return None
def _print_verbose(self, data: str) -> None:
@ -83,34 +83,30 @@ class Security2(Security):
setup_req.sec_ver = proto.session_pb2.SecScheme2
setup_req.sec2.msg = proto.sec2_pb2.S2Session_Command0
setup_req.sec2.sc0.client_username = tobytes(self.username)
setup_req.sec2.sc0.client_username = str_to_bytes(self.username)
self.srp6a_ctx = Srp6a(self.username, self.password)
if self.srp6a_ctx is None:
print('Failed to initialize SRP6a instance!')
exit(1)
raise RuntimeError('Failed to initialize SRP6a instance!')
client_pubkey = long_to_bytes(self.srp6a_ctx.A)
setup_req.sec2.sc0.client_pubkey = client_pubkey
self._print_verbose('Client Public Key:\t' + hex(bytes_to_long(client_pubkey)))
self._print_verbose(f'Client Public Key:\t0x{client_pubkey.hex()}')
return setup_req.SerializeToString().decode('latin-1')
def setup0_response(self, response_data: bytes) -> None:
# Interpret SessionResp0 response packet
setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data))
self._print_verbose('Security version:\t' + str(setup_resp.sec_ver))
setup_resp.ParseFromString(str_to_bytes(response_data))
self._print_verbose(f'Security version:\t{str(setup_resp.sec_ver)}')
if setup_resp.sec_ver != proto.session_pb2.SecScheme2:
print('Incorrect sec scheme')
exit(1)
raise RuntimeError('Incorrect security scheme')
# Device public key, random salt and password verifier
device_pubkey = setup_resp.sec2.sr0.device_pubkey
device_salt = setup_resp.sec2.sr0.device_salt
self._print_verbose('Device Public Key:\t' + hex(bytes_to_long(device_pubkey)))
self._print_verbose('Device Salt:\t' + hex(bytes_to_long(device_salt)))
self._print_verbose(f'Device Public Key:\t0x{device_pubkey.hex()}')
self.client_pop_key = self.srp6a_ctx.process_challenge(device_salt, device_pubkey)
def setup1_request(self) -> Any:
@ -120,7 +116,10 @@ class Security2(Security):
setup_req.sec2.msg = proto.sec2_pb2.S2Session_Command1
# Encrypt device public key and attach to the request packet
self._print_verbose('Client Proof:\t' + hex(bytes_to_long(self.client_pop_key)))
if self.client_pop_key is None:
raise RuntimeError('Failed to generate client proof!')
self._print_verbose(f'Client Proof:\t0x{self.client_pop_key.hex()}')
setup_req.sec2.sc1.client_proof = self.client_pop_key
return setup_req.SerializeToString().decode('latin-1')
@ -128,37 +127,36 @@ class Security2(Security):
def setup1_response(self, response_data: bytes) -> Any:
# Interpret SessionResp1 response packet
setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data))
setup_resp.ParseFromString(str_to_bytes(response_data))
# Ensure security scheme matches
if setup_resp.sec_ver == proto.session_pb2.SecScheme2:
# Read encrypyed device proof string
device_proof = setup_resp.sec2.sr1.device_proof
self._print_verbose('Device Proof:\t' + hex(bytes_to_long(device_proof)))
self._print_verbose(f'Device Proof:\t0x{device_proof.hex()}')
self.srp6a_ctx.verify_session(device_proof)
if not self.srp6a_ctx.authenticated():
print('Failed to verify device proof')
exit(1)
raise RuntimeError('Failed to verify device proof')
else:
print('Unsupported security protocol')
exit(1)
raise RuntimeError('Unsupported security protocol')
# Getting the shared secret
shared_secret = self.srp6a_ctx.get_session_key()
self._print_verbose('Shared Secret:\t' + hex(bytes_to_long(shared_secret)))
self._print_verbose(f'Shared Secret:\t0x{shared_secret.hex()}')
# Using the first 256 bits of a 512 bit key
session_key = shared_secret[:AES_KEY_LEN]
self._print_verbose('Session Key:\t' + hex(bytes_to_long(session_key)))
self._print_verbose(f'Session Key:\t0x{session_key.hex()}')
# 96-bit nonce
self.nonce = setup_resp.sec2.sr1.device_nonce
self._print_verbose('Nonce:\t' + hex(bytes_to_long(self.nonce)))
if self.nonce is None:
raise RuntimeError('Received invalid nonce from device!')
self._print_verbose(f'Nonce:\t0x{self.nonce.hex()}')
# Initialize the encryption engine with Shared Key and initialization vector
self.cipher = AESGCM(session_key)
if self.cipher is None:
print('Failed to initialize AES-GCM cryptographic engine!')
exit(1)
raise RuntimeError('Failed to initialize AES-GCM cryptographic engine!')
def encrypt_data(self, data: bytes) -> Any:
return self.cipher.encrypt(self.nonce, data, None)

View File

@ -1,5 +1,6 @@
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
# N A large safe prime (N = 2q+1, where q is prime) [All arithmetic is done modulo N]
# g A generator modulo N
@ -19,6 +20,8 @@ import hashlib
import os
from typing import Any, Callable, Optional, Tuple
from utils import bytes_to_long, long_to_bytes
SHA1 = 0
SHA224 = 1
SHA256 = 2
@ -143,21 +146,11 @@ def get_ng(ng_type: int) -> Tuple[int, int]:
return int(n_hex, 16), int(g_hex, 16)
def bytes_to_long(s: bytes) -> int:
return int.from_bytes(s, 'big')
def long_to_bytes(n: int) -> bytes:
if n == 0:
return b'\x00'
return n.to_bytes((n.bit_length() + 7) // 8, 'big')
def get_random(nbytes: int) -> int:
def get_random(nbytes: int) -> Any:
return bytes_to_long(os.urandom(nbytes))
def get_random_of_length(nbytes: int) -> int:
def get_random_of_length(nbytes: int) -> Any:
offset = (nbytes * 8) - 1
return get_random(nbytes) | (1 << offset)
@ -255,7 +248,7 @@ class Srp6a (object):
def get_username(self) -> str:
return self.Iu
def get_ephemeral_secret(self) -> bytes:
def get_ephemeral_secret(self) -> Any:
return long_to_bytes(self.a)
def get_session_key(self) -> Any:

View File

@ -1,16 +1,5 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
from .transport_ble import * # noqa: F403, F401

View File

@ -2,12 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import print_function
import platform
from builtins import input
import utils
from utils import hex_str_to_bytes, str_to_bytes
fallback = True
@ -29,18 +26,23 @@ def device_sort(device):
class BLE_Bleak_Client:
def __init__(self):
self.adapter = None
self.adapter_props = None
self.characteristics = dict()
self.chrc_names = None
self.device = None
self.devname = None
self.iface = None
self.nu_lookup = None
self.services = None
self.srv_uuid_adv = None
self.srv_uuid_fallback = None
async def connect(self, devname, iface, chrc_names, fallback_srv_uuid):
self.devname = devname
self.srv_uuid_fallback = fallback_srv_uuid
self.chrc_names = [name.lower() for name in chrc_names]
self.device = None
self.adapter = None
self.services = None
self.nu_lookup = None
self.characteristics = dict()
self.srv_uuid_adv = None
self.iface = iface
print('Discovering...')
try:
@ -62,7 +64,7 @@ class BLE_Bleak_Client:
print('==== BLE Discovery results ====')
print('{0: >4} {1: <33} {2: <12}'.format(
'S.N.', 'Name', 'Address'))
for i in range(len(devices)):
for i, _ in enumerate(devices):
print('[{0: >2}] {1: <33} {2: <12}'.format(i + 1, devices[i].name or 'Unknown', devices[i].address))
while True:
@ -193,10 +195,10 @@ class BLE_Console_Client:
async def send_data(self, characteristic_uuid, data):
print("BLECLI >> Write following data to characteristic with UUID '" + characteristic_uuid + "' :")
print('\t>> ' + utils.str_to_hexstr(data))
print('\t>> ' + str_to_bytes(data).hex())
print('BLECLI >> Enter data read from characteristic (in hex) :')
resp = input('\t<< ')
return utils.hexstr_to_str(resp)
return hex_str_to_bytes(resp)
# --------------------------------------------------------------------

View File

@ -12,6 +12,7 @@ class Transport_BLE(Transport):
def __init__(self, service_uuid, nu_lookup):
self.nu_lookup = nu_lookup
self.service_uuid = service_uuid
self.name_uuid_lookup = None
# Expect service UUID like '0000ffff-0000-1000-8000-00805f9b34fb'
for name in nu_lookup.keys():
# Calculate characteristic UUID for each endpoint
@ -39,7 +40,7 @@ class Transport_BLE(Transport):
# Check if expected characteristics are provided by the service
for name in self.name_uuid_lookup.keys():
if not self.cli.has_characteristic(self.name_uuid_lookup[name]):
raise RuntimeError("'" + name + "' endpoint not found")
raise RuntimeError(f"'{name}' endpoint not found")
async def disconnect(self):
await self.cli.disconnect()
@ -47,5 +48,5 @@ class Transport_BLE(Transport):
async def send_data(self, ep_name, data):
# Write (and read) data to characteristic corresponding to the endpoint
if ep_name not in self.name_uuid_lookup.keys():
raise RuntimeError('Invalid endpoint : ' + ep_name)
raise RuntimeError(f'Invalid endpoint: {ep_name}')
return await self.cli.send_data(self.name_uuid_lookup[ep_name], data)

View File

@ -2,11 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import print_function
from builtins import input
import utils
from utils import hex_str_to_bytes, str_to_bytes
from .transport import Transport
@ -14,10 +10,10 @@ from .transport import Transport
class Transport_Console(Transport):
async def send_data(self, path, data, session_id=0):
print('Client->Device msg :', path, session_id, utils.str_to_hexstr(data))
print('Client->Device msg :', path, session_id, str_to_bytes(data).hex())
try:
resp = input('Enter device->client msg : ')
except Exception as err:
print('error:', err)
return None
return utils.hexstr_to_str(resp)
return hex_str_to_bytes(resp)

View File

@ -22,14 +22,14 @@ class Transport_HTTP(Transport):
try:
socket.gethostbyname(hostname.split(':')[0])
except socket.gaierror:
raise RuntimeError('Unable to resolve hostname :' + hostname)
raise RuntimeError(f'Unable to resolve hostname: {hostname}')
if ssl_context is None:
self.conn = HTTPConnection(hostname, timeout=60)
else:
self.conn = HTTPSConnection(hostname, context=ssl_context, timeout=60)
try:
print('Connecting to ' + hostname)
print(f'++++ Connecting to {hostname}++++')
self.conn.connect()
except Exception as err:
raise RuntimeError('Connection Failure : ' + str(err))

View File

@ -1,16 +1,5 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
from .convenience import * # noqa: F403, F401

View File

@ -3,21 +3,22 @@
#
# Convenience functions for commonly used data type conversions
import binascii
from future.utils import tobytes
def bytes_to_long(s: bytes) -> int:
return int.from_bytes(s, 'big')
def str_to_hexstr(string):
# Form hexstr by appending ASCII codes (in hex) corresponding to
# each character in the input string
return binascii.hexlify(tobytes(string)).decode('latin-1')
def long_to_bytes(n: int) -> bytes:
if n == 0:
return b'\x00'
return n.to_bytes((n.bit_length() + 7) // 8, 'big')
def hexstr_to_str(hexstr):
# Prepend 0 (if needed) to make the hexstr length an even number
if len(hexstr) % 2 == 1:
hexstr = '0' + hexstr
# Interpret consecutive pairs of hex characters as 8 bit ASCII codes
# and append characters corresponding to each code to form the string
return binascii.unhexlify(tobytes(hexstr)).decode('latin-1')
# 'deadbeef' -> b'deadbeef'
def str_to_bytes(s: str) -> bytes:
return bytes(s, encoding='latin-1')
# 'deadbeef' -> b'\xde\xad\xbe\xef'
def hex_str_to_bytes(s: str) -> bytes:
return bytes.fromhex(s)