Merge branch 'examples/mqtt_tests_migration_pytest' into 'master'

Examples: migration mqtt examples to pytest

See merge request espressif/esp-idf!20957
This commit is contained in:
Suren Gabrielyan 2022-12-21 21:23:14 +08:00
commit e5926d1b1b
7 changed files with 161 additions and 147 deletions

View File

@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging
import os
import re
import ssl
@ -5,8 +8,9 @@ import sys
from threading import Event, Thread
import paho.mqtt.client as mqtt
import ttfw_idf
from tiny_test_fw import DUT
import pexpect
import pytest
from pytest_embedded import Dut
event_client_connected = Event()
event_stop_client = Event()
@ -16,19 +20,20 @@ message_log = ''
# The callback for when the client receives a CONNACK response from the server.
def on_connect(client, userdata, flags, rc):
def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, str, bool, str) -> None
_ = (userdata, flags)
print('Connected with result code ' + str(rc))
event_client_connected.set()
client.subscribe('/topic/qos0')
def mqtt_client_task(client):
def mqtt_client_task(client): # type: (mqtt.Client) -> None
while not event_stop_client.is_set():
client.loop()
# The callback for when a PUBLISH message is received from the server.
def on_message(client, userdata, msg):
def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None
global message_log
global event_client_received_correct
global event_client_received_binary
@ -55,8 +60,9 @@ def on_message(client, userdata, msg):
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@ttfw_idf.idf_example_test(env_tag='ethernet_router')
def test_examples_protocol_mqtt_ssl(env, extra_data):
@pytest.mark.esp32
@pytest.mark.ethernet
def test_examples_protocol_mqtt_ssl(dut): # type: (Dut) -> None
broker_url = ''
broker_port = 0
"""
@ -67,18 +73,17 @@ def test_examples_protocol_mqtt_ssl(env, extra_data):
4. Test ESP32 client received correct qos0 message
5. Test python client receives binary data from running partition and compares it with the binary
"""
dut1 = env.get_dut('mqtt_ssl', 'examples/protocols/mqtt/ssl', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_ssl.bin')
binary_file = os.path.join(dut.app.binary_path, 'mqtt_ssl.bin')
bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance('mqtt_ssl_bin_size', '{}KB'
.format(bin_size // 1024))
logging.info('[Performance][mqtt_ssl_bin_size]: %s KB', bin_size // 1024)
# Look for host:port in sdkconfig
try:
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI'])
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI'))
assert value is not None
broker_url = value.group(1)
broker_port = int(value.group(2))
bin_size = min(int(dut1.app.get_sdkconfig()['CONFIG_BROKER_BIN_SIZE_TO_SEND']), bin_size)
bin_size = min(int(dut.app.sdkconfig.get('BROKER_BIN_SIZE_TO_SEND')), bin_size)
except Exception:
print('ENV_TEST_FAILURE: Cannot find broker url in sdkconfig')
raise
@ -105,25 +110,20 @@ def test_examples_protocol_mqtt_ssl(env, extra_data):
print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
if not event_client_connected.wait(timeout=30):
raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
dut1.start_app()
try:
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0]
print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout:
except pexpect.TIMEOUT:
print('ENV_TEST_FAILURE: Cannot connect to AP')
raise
print('Checking py-client received msg published from esp...')
if not event_client_received_correct.wait(timeout=30):
raise ValueError('Wrong data received, msg log: {}'.format(message_log))
print('Checking esp-client received msg published from py-client...')
dut1.expect(re.compile(r'DATA=send binary please'), timeout=30)
dut.expect(r'DATA=send binary please', timeout=30)
print('Receiving binary data from running partition...')
if not event_client_received_binary.wait(timeout=30):
raise ValueError('Binary not received within timeout')
finally:
event_stop_client.set()
thread1.join()
if __name__ == '__main__':
test_examples_protocol_mqtt_ssl()

View File

@ -1,19 +1,22 @@
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging
import os
import re
import socket
import struct
import sys
import time
from threading import Thread
import ttfw_idf
import pexpect
import pytest
from common_test_methods import get_host_ip4_by_dest_ip
from tiny_test_fw import DUT
from pytest_embedded import Dut
msgid = -1
def mqqt_server_sketch(my_ip, port):
def mqqt_server_sketch(my_ip, port): # type: (str, str) -> None
global msgid
print('Starting the server on {}'.format(my_ip))
s = None
@ -32,13 +35,13 @@ def mqqt_server_sketch(my_ip, port):
raise
data = q.recv(1024)
# check if received initial empty message
print('received from client {}'.format(data))
print('received from client {!r}'.format(data))
data = bytearray([0x20, 0x02, 0x00, 0x00])
q.send(data)
# try to receive qos1
data = q.recv(1024)
msgid = struct.unpack('>H', data[15:17])[0]
print('received from client {}, msgid: {}'.format(data, msgid))
print('received from client {!r}, msgid: {}'.format(data, msgid))
data = bytearray([0x40, 0x02, data[15], data[16]])
q.send(data)
time.sleep(5)
@ -46,8 +49,9 @@ def mqqt_server_sketch(my_ip, port):
print('server closed')
@ttfw_idf.idf_example_test(env_tag='ethernet_router')
def test_examples_protocol_mqtt_qos1(env, extra_data):
@pytest.mark.esp32
@pytest.mark.ethernet
def test_examples_protocol_mqtt_qos1(dut: Dut) -> None:
global msgid
"""
steps: (QoS1: Happy flow)
@ -56,18 +60,15 @@ def test_examples_protocol_mqtt_qos1(env, extra_data):
3. Test evaluates that qos1 message is queued and removed from queued after ACK received
4. Test the broker received the same message id evaluated in step 3
"""
dut1 = env.get_dut('mqtt_tcp', 'examples/protocols/mqtt/tcp', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_tcp.bin')
binary_file = os.path.join(dut.app.binary_path, 'mqtt_tcp.bin')
bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance('mqtt_tcp_bin_size', '{}KB'.format(bin_size // 1024))
# 1. start the dut test and wait till client gets IP address
dut1.start_app()
logging.info('[Performance][mqtt_tcp_bin_size]: %s KB', bin_size // 1024)
# waiting for getting the IP address
try:
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)', timeout=30).group(1).decode()
print('Connected to AP/Ethernet with IP: {}'.format(ip_address))
except DUT.ExpectTimeout:
except pexpect.TIMEOUT:
raise ValueError('ENV_TEST_FAILURE: Cannot connect to AP/Ethernet')
# 2. start mqtt broker sketch
@ -75,20 +76,17 @@ def test_examples_protocol_mqtt_qos1(env, extra_data):
thread1 = Thread(target=mqqt_server_sketch, args=(host_ip,1883))
thread1.start()
print('writing to device: {}'.format('mqtt://' + host_ip + '\n'))
dut1.write('mqtt://' + host_ip + '\n')
data_write = 'mqtt://' + host_ip
print('writing to device: {}'.format(data_write))
dut.write(data_write)
thread1.join()
print('Message id received from server: {}'.format(msgid))
# 3. check the message id was enqueued and then deleted
msgid_enqueued = dut1.expect(re.compile(r'outbox: ENQUEUE msgid=([0-9]+)'), timeout=30)
msgid_deleted = dut1.expect(re.compile(r'outbox: DELETED msgid=([0-9]+)'), timeout=30)
msgid_enqueued = dut.expect(b'outbox: ENQUEUE msgid=([0-9]+)', timeout=30).group(1).decode()
msgid_deleted = dut.expect(b'outbox: DELETED msgid=([0-9]+)', timeout=30).group(1).decode()
# 4. check the msgid of received data are the same as that of enqueued and deleted from outbox
if (msgid_enqueued[0] == str(msgid) and msgid_deleted[0] == str(msgid)):
if (msgid_enqueued == str(msgid) and msgid_deleted == str(msgid)):
print('PASS: Received correct msg id')
else:
print('Failure!')
raise ValueError('Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted))
if __name__ == '__main__':
test_examples_protocol_mqtt_qos1()

View File

@ -1,11 +1,16 @@
#!/usr/bin/env python
#
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging
import os
import re
import sys
from threading import Event, Thread
import paho.mqtt.client as mqtt
import ttfw_idf
from tiny_test_fw import DUT
import pytest
from pytest_embedded import Dut
event_client_connected = Event()
event_stop_client = Event()
@ -14,19 +19,21 @@ message_log = ''
# The callback for when the client receives a CONNACK response from the server.
def on_connect(client, userdata, flags, rc):
def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, tuple, bool, str) -> None
_ = (userdata, flags)
print('Connected with result code ' + str(rc))
event_client_connected.set()
client.subscribe('/topic/qos0')
def mqtt_client_task(client):
def mqtt_client_task(client): # type: (mqtt.Client) -> None
while not event_stop_client.is_set():
client.loop()
# The callback for when a PUBLISH message is received from the server.
def on_message(client, userdata, msg):
def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None
_ = userdata
global message_log
payload = msg.payload.decode()
if not event_client_received_correct.is_set() and payload == 'data':
@ -36,8 +43,9 @@ def on_message(client, userdata, msg):
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@ttfw_idf.idf_example_test(env_tag='ethernet_router')
def test_examples_protocol_mqtt_ws(env, extra_data):
@pytest.mark.esp32
@pytest.mark.ethernet
def test_examples_protocol_mqtt_ws(dut): # type: (Dut) -> None
broker_url = ''
broker_port = 0
"""
@ -47,14 +55,14 @@ def test_examples_protocol_mqtt_ws(env, extra_data):
3. Test evaluates it received correct qos0 message
4. Test ESP32 client received correct qos0 message
"""
dut1 = env.get_dut('mqtt_websocket', 'examples/protocols/mqtt/ws', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_websocket.bin')
binary_file = os.path.join(dut.app.binary_path, 'mqtt_websocket.bin')
bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance('mqtt_websocket_bin_size', '{}KB'.format(bin_size // 1024))
logging.info('[Performance][mqtt_websocket_bin_size]: %s KB', bin_size // 1024)
# Look for host:port in sdkconfig
try:
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI'])
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI'))
assert value is not None
broker_url = value.group(1)
broker_port = int(value.group(2))
except Exception:
@ -78,22 +86,17 @@ def test_examples_protocol_mqtt_ws(env, extra_data):
print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
if not event_client_connected.wait(timeout=30):
raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
dut1.start_app()
try:
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0]
print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout:
except Dut.ExpectTimeout:
print('ENV_TEST_FAILURE: Cannot connect to AP')
raise
print('Checking py-client received msg published from esp...')
if not event_client_received_correct.wait(timeout=30):
raise ValueError('Wrong data received, msg log: {}'.format(message_log))
print('Checking esp-client received msg published from py-client...')
dut1.expect(re.compile(r'DATA=data_to_esp32'), timeout=30)
dut.expect(r'DATA=data_to_esp32', timeout=30)
finally:
event_stop_client.set()
thread1.join()
if __name__ == '__main__':
test_examples_protocol_mqtt_ws()

View File

@ -1,3 +1,8 @@
#!/usr/bin/env python
#
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging
import os
import re
import ssl
@ -5,8 +10,9 @@ import sys
from threading import Event, Thread
import paho.mqtt.client as mqtt
import ttfw_idf
from tiny_test_fw import DUT
import pexpect
import pytest
from pytest_embedded import Dut
event_client_connected = Event()
event_stop_client = Event()
@ -15,19 +21,21 @@ message_log = ''
# The callback for when the client receives a CONNACK response from the server.
def on_connect(client, userdata, flags, rc):
def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, tuple, bool, str) -> None
_ = (userdata, flags)
print('Connected with result code ' + str(rc))
event_client_connected.set()
client.subscribe('/topic/qos0')
def mqtt_client_task(client):
def mqtt_client_task(client): # type: (mqtt.Client) -> None
while not event_stop_client.is_set():
client.loop()
# The callback for when a PUBLISH message is received from the server.
def on_message(client, userdata, msg):
def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None
_ = userdata
global message_log
payload = msg.payload.decode()
if not event_client_received_correct.is_set() and payload == 'data':
@ -37,8 +45,9 @@ def on_message(client, userdata, msg):
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@ttfw_idf.idf_example_test(env_tag='ethernet_router')
def test_examples_protocol_mqtt_wss(env, extra_data):
@pytest.mark.esp32
@pytest.mark.ethernet
def test_examples_protocol_mqtt_wss(dut): # type: (Dut) -> None
broker_url = ''
broker_port = 0
"""
@ -48,14 +57,14 @@ def test_examples_protocol_mqtt_wss(env, extra_data):
3. Test evaluates it received correct qos0 message
4. Test ESP32 client received correct qos0 message
"""
dut1 = env.get_dut('mqtt_websocket_secure', 'examples/protocols/mqtt/wss', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_websocket_secure.bin')
binary_file = os.path.join(dut.app.binary_path, 'mqtt_websocket_secure.bin')
bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance('mqtt_websocket_secure_bin_size', '{}KB'.format(bin_size // 1024))
logging.info('[Performance][mqtt_websocket_secure_bin_size]: %s KB', bin_size // 1024)
# Look for host:port in sdkconfig
try:
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI'])
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI'))
assert value is not None
broker_url = value.group(1)
broker_port = int(value.group(2))
except Exception:
@ -82,22 +91,17 @@ def test_examples_protocol_mqtt_wss(env, extra_data):
print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
if not event_client_connected.wait(timeout=30):
raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
dut1.start_app()
try:
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0]
print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout:
except pexpect.TIMEOUT:
print('ENV_TEST_FAILURE: Cannot connect to AP')
raise
print('Checking py-client received msg published from esp...')
if not event_client_received_correct.wait(timeout=30):
raise ValueError('Wrong data received, msg log: {}'.format(message_log))
print('Checking esp-client received msg published from py-client...')
dut1.expect(re.compile(r'DATA=data_to_esp32'), timeout=30)
dut.expect(r'DATA=data_to_esp32', timeout=30)
finally:
event_stop_client.set()
thread1.join()
if __name__ == '__main__':
test_examples_protocol_mqtt_wss()

View File

@ -18,6 +18,7 @@ netifaces
rangehttpserver
dbus-python; sys_platform == 'linux'
protobuf
paho-mqtt
# for twai tests, communicate with socket can device (e.g. Canable)
python-can

View File

@ -33,7 +33,3 @@ SimpleWebSocketServer
# py_debug_backend
debug_backend
# examples/protocols/mqtt
# tools/test_apps/protocols/mqtt
paho-mqtt

View File

@ -1,5 +1,8 @@
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
from __future__ import print_function, unicode_literals
import logging
import os
import random
import re
@ -10,21 +13,25 @@ import string
import subprocess
import sys
import time
import typing
from itertools import count
from threading import Event, Lock, Thread
from typing import Any
import paho.mqtt.client as mqtt
import ttfw_idf
import pytest
from common_test_methods import get_host_ip4_by_dest_ip
from pytest_embedded import Dut
from pytest_embedded_qemu.dut import QemuDut
DEFAULT_MSG_SIZE = 16
def _path(f):
def _path(f): # type: (str) -> str
return os.path.join(os.path.dirname(os.path.realpath(__file__)),f)
def set_server_cert_cn(ip):
def set_server_cert_cn(ip): # type: (str) -> None
arg_list = [
['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),'-subj', '/CN={}'.format(ip), '-new'],
['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'),
@ -36,8 +43,13 @@ def set_server_cert_cn(ip):
# Publisher class creating a python client to send/receive published data from esp-mqtt client
class MqttPublisher:
event_client_connected = Event()
event_client_got_all = Event()
expected_data = ''
published = 0
def __init__(self, dut, transport, qos, repeat, published, queue, publish_cfg, log_details=False):
def __init__(self, dut, transport,
qos, repeat, published, queue, publish_cfg, log_details=False): # type: (MqttPublisher, Dut, str, int, int, int, int, dict, bool) -> None
# instance variables used as parameters of the publish test
self.event_stop_client = Event()
self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
@ -58,11 +70,11 @@ class MqttPublisher:
MqttPublisher.event_client_got_all.clear()
MqttPublisher.expected_data = self.sample_string * self.repeat
def print_details(self, text):
def print_details(self, text): # type: (str) -> None
if self.log_details:
print(text)
def mqtt_client_task(self, client, lock):
def mqtt_client_task(self, client, lock): # type: (MqttPublisher, mqtt.Client, Lock) -> None
while not self.event_stop_client.is_set():
with lock:
client.loop()
@ -70,12 +82,12 @@ class MqttPublisher:
# The callback for when the client receives a CONNACK response from the server (needs to be static)
@staticmethod
def on_connect(_client, _userdata, _flags, _rc):
def on_connect(_client, _userdata, _flags, _rc): # type: (mqtt.Client, tuple, bool, str) -> None
MqttPublisher.event_client_connected.set()
# The callback for when a PUBLISH message is received from the server (needs to be static)
@staticmethod
def on_message(client, userdata, msg):
def on_message(client, userdata, msg): # type: (mqtt.Client, int, mqtt.client.MQTTMessage) -> None
payload = msg.payload.decode()
if payload == MqttPublisher.expected_data:
userdata += 1
@ -83,7 +95,7 @@ class MqttPublisher:
if userdata == MqttPublisher.published:
MqttPublisher.event_client_got_all.set()
def __enter__(self):
def __enter__(self): # type: (MqttPublisher) -> None
qos = self.publish_cfg['qos']
queue = self.publish_cfg['queue']
@ -100,6 +112,7 @@ class MqttPublisher:
self.client = mqtt.Client(transport='websockets')
else:
self.client = mqtt.Client()
assert self.client is not None
self.client.on_connect = MqttPublisher.on_connect
self.client.on_message = MqttPublisher.on_message
self.client.user_data_set(0)
@ -137,7 +150,8 @@ class MqttPublisher:
self.event_stop_client.set()
thread1.join()
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback): # type: (MqttPublisher, str, str, dict) -> None
assert self.client is not None
self.client.disconnect()
self.event_stop_client.clear()
@ -145,7 +159,7 @@ class MqttPublisher:
# Simple server for mqtt over TLS connection
class TlsServer:
def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False):
def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False): # type: (TlsServer, int, bool, bool, bool) -> None
self.port = port
self.socket = socket.socket()
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -153,11 +167,9 @@ class TlsServer:
self.shutdown = Event()
self.client_cert = client_cert
self.refuse_connection = refuse_connection
self.ssl_error = None
self.use_alpn = use_alpn
self.negotiated_protocol = None
def __enter__(self):
def __enter__(self): # type: (TlsServer) -> TlsServer
try:
self.socket.bind(('', self.port))
except socket.error as e:
@ -170,20 +182,21 @@ class TlsServer:
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback): # type: (TlsServer, str, str, str) -> None
self.shutdown.set()
self.server_thread.join()
self.socket.close()
if (self.conn is not None):
self.conn.close()
def get_last_ssl_error(self):
def get_last_ssl_error(self): # type: (TlsServer) -> str
return self.ssl_error
@typing.no_type_check
def get_negotiated_protocol(self):
return self.negotiated_protocol
def run_server(self):
def run_server(self) -> None:
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
if self.client_cert:
context.verify_mode = ssl.CERT_REQUIRED
@ -201,11 +214,10 @@ class TlsServer:
print(' - negotiated_protocol: {}'.format(self.negotiated_protocol))
self.handle_conn()
except ssl.SSLError as e:
self.conn = None
self.ssl_error = str(e)
print(' - SSLError: {}'.format(str(e)))
def handle_conn(self):
def handle_conn(self) -> None:
while not self.shutdown.is_set():
r,w,e = select.select([self.conn], [], [], 1)
try:
@ -216,7 +228,7 @@ class TlsServer:
print(' - error: {}'.format(err))
raise
def process_mqtt_connect(self):
def process_mqtt_connect(self) -> None:
try:
data = bytearray(self.conn.recv(1024))
message = ''.join(format(x, '02x') for x in data)
@ -235,22 +247,22 @@ class TlsServer:
self.shutdown.set()
def connection_tests(dut, cases, dut_ip):
def connection_tests(dut, cases, dut_ip): # type: (Dut, dict, str) -> None
ip = get_host_ip4_by_dest_ip(dut_ip)
set_server_cert_cn(ip)
server_port = 2222
def teardown_connection_suite():
def teardown_connection_suite() -> None:
dut.write('conn teardown 0 0')
def start_connection_case(case, desc):
def start_connection_case(case, desc): # type: (str, str) -> Any
print('Starting {}: {}'.format(case, desc))
case_id = cases[case]
dut.write('conn {} {} {}'.format(ip, server_port, case_id))
dut.expect('Test case:{} started'.format(case_id))
return case_id
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
# All these cases connect to the server with no server verification or with server only verification
with TlsServer(server_port):
test_nr = start_connection_case(case, 'default server - expect to connect normally')
@ -266,13 +278,13 @@ def connection_tests(dut, cases, dut_ip):
if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' not in s.get_last_ssl_error():
raise RuntimeError('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
# These cases connect to server with both server and client verification (client key might be password protected)
with TlsServer(server_port, client_cert=True):
test_nr = start_connection_case(case, 'server with client verification - expect to connect normally')
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
case = 'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
case = 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
with TlsServer(server_port) as s:
test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error')
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
@ -280,7 +292,7 @@ def connection_tests(dut, cases, dut_ip):
if 'alert unknown ca' not in s.get_last_ssl_error():
raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
case = 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
with TlsServer(server_port, client_cert=True) as s:
test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error')
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
@ -288,13 +300,13 @@ def connection_tests(dut, cases, dut_ip):
if 'CERTIFICATE_VERIFY_FAILED' not in s.get_last_ssl_error():
raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
with TlsServer(server_port, use_alpn=True) as s:
test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol')
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
if case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None:
if case == 'EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None:
print(' - client with alpn off, no negotiated protocol: OK')
elif case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt':
elif case == 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt':
print(' - client with alpn on, negotiated protocol resolved: OK')
else:
raise Exception('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol()))
@ -302,19 +314,19 @@ def connection_tests(dut, cases, dut_ip):
teardown_connection_suite()
@ttfw_idf.idf_custom_test(env_tag='ethernet_router', group='test-apps')
def test_app_protocol_mqtt_publish_connect(env, extra_data):
@pytest.mark.esp32
@pytest.mark.ethernet
def test_app_protocol_mqtt_publish_connect(dut: Dut) -> None:
"""
steps:
1. join AP
2. connect to uri specified in the config
3. send and receive data
"""
dut1 = env.get_dut('mqtt_publish_connect_test', 'tools/test_apps/protocols/mqtt/publish_connect_test')
# check and log bin size
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_publish_connect_test.bin')
binary_file = os.path.join(dut.app.binary_path, 'mqtt_publish_connect_test.bin')
bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance('mqtt_publish_connect_test_bin_size', '{}KB'.format(bin_size // 1024))
logging.info('[Performance][mqtt_publish_connect_test_bin_size]: %s KB', bin_size // 1024)
# Look for test case symbolic names and publish configs
cases = {}
@ -322,25 +334,24 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
try:
# Get connection test cases configuration: symbolic names for test cases
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT',
'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT',
'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
cases[case] = dut1.app.get_sdkconfig()[case]
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_CERT',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
cases[case] = dut.app.sdkconfig.get(case)
except Exception:
print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
raise
dut1.start_app()
esp_ip = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
esp_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30).group(1).decode()
print('Got IP={}'.format(esp_ip))
if not os.getenv('MQTT_SKIP_CONNECT_TEST'):
connection_tests(dut1,cases,esp_ip)
connection_tests(dut,cases,esp_ip)
#
# start publish tests only if enabled in the environment (for weekend tests only)
@ -349,27 +360,28 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
# Get publish test configuration
try:
def get_host_port_from_dut(dut1, config_option):
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()[config_option])
@typing.no_type_check
def get_host_port_from_dut(dut, config_option):
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get(config_option))
if value is None:
return None, None
return value.group(1), int(value.group(2))
publish_cfg['publish_topic'] = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_SUBSCRIBE_TOPIC'].replace('"','')
publish_cfg['subscribe_topic'] = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_PUBLISH_TOPIC'].replace('"','')
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_SSL_URI')
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_TCP_URI')
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WS_URI')
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WSS_URI')
publish_cfg['publish_topic'] = dut.app.sdkconfig.get('EXAMPLE_SUBSCRIBE_TOPIC').replace('"','')
publish_cfg['subscribe_topic'] = dut.app.sdkconfig.get('EXAMPLE_PUBLISH_TOPIC').replace('"','')
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI')
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI')
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WS_URI')
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI')
except Exception:
print('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
raise
def start_publish_case(transport, qos, repeat, published, queue):
def start_publish_case(transport, qos, repeat, published, queue): # type: (str, int, int, int, int) -> None
print('Starting Publish test: transport:{}, qos:{}, nr_of_msgs:{}, msg_size:{}, enqueue:{}'
.format(transport, qos, published, repeat * DEFAULT_MSG_SIZE, queue))
with MqttPublisher(dut1, transport, qos, repeat, published, queue, publish_cfg):
with MqttPublisher(dut, transport, qos, repeat, published, queue, publish_cfg):
pass
# Initialize message sizes and repeat counts (if defined in the environment)
@ -378,7 +390,7 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
# Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x}
env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']):
messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']})
messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore
continue
break
if not messages: # No message sizes present in the env - set defaults
@ -400,4 +412,4 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
if __name__ == '__main__':
test_app_protocol_mqtt_publish_connect(dut=ttfw_idf.ESP32QEMUDUT if sys.argv[1:] == ['qemu'] else ttfw_idf.ESP32DUT)
test_app_protocol_mqtt_publish_connect(dut=QemuDut if sys.argv[1:] == ['qemu'] else Dut)