Examples: migration mqtt examples to pytest

This commit is contained in:
Suren Gabrielyan 2022-11-07 10:46:46 +04:00
parent 7f4179744b
commit b68203bfb5
7 changed files with 161 additions and 147 deletions

View File

@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging
import os import os
import re import re
import ssl import ssl
@ -5,8 +8,9 @@ import sys
from threading import Event, Thread from threading import Event, Thread
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
import ttfw_idf import pexpect
from tiny_test_fw import DUT import pytest
from pytest_embedded import Dut
event_client_connected = Event() event_client_connected = Event()
event_stop_client = Event() event_stop_client = Event()
@ -16,19 +20,20 @@ message_log = ''
# The callback for when the client receives a CONNACK response from the server. # The callback for when the client receives a CONNACK response from the server.
def on_connect(client, userdata, flags, rc): def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, str, bool, str) -> None
_ = (userdata, flags)
print('Connected with result code ' + str(rc)) print('Connected with result code ' + str(rc))
event_client_connected.set() event_client_connected.set()
client.subscribe('/topic/qos0') client.subscribe('/topic/qos0')
def mqtt_client_task(client): def mqtt_client_task(client): # type: (mqtt.Client) -> None
while not event_stop_client.is_set(): while not event_stop_client.is_set():
client.loop() client.loop()
# The callback for when a PUBLISH message is received from the server. # The callback for when a PUBLISH message is received from the server.
def on_message(client, userdata, msg): def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None
global message_log global message_log
global event_client_received_correct global event_client_received_correct
global event_client_received_binary global event_client_received_binary
@ -55,8 +60,9 @@ def on_message(client, userdata, msg):
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@ttfw_idf.idf_example_test(env_tag='ethernet_router') @pytest.mark.esp32
def test_examples_protocol_mqtt_ssl(env, extra_data): @pytest.mark.ethernet
def test_examples_protocol_mqtt_ssl(dut): # type: (Dut) -> None
broker_url = '' broker_url = ''
broker_port = 0 broker_port = 0
""" """
@ -67,18 +73,17 @@ def test_examples_protocol_mqtt_ssl(env, extra_data):
4. Test ESP32 client received correct qos0 message 4. Test ESP32 client received correct qos0 message
5. Test python client receives binary data from running partition and compares it with the binary 5. Test python client receives binary data from running partition and compares it with the binary
""" """
dut1 = env.get_dut('mqtt_ssl', 'examples/protocols/mqtt/ssl', dut_class=ttfw_idf.ESP32DUT) binary_file = os.path.join(dut.app.binary_path, 'mqtt_ssl.bin')
# check and log bin size
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_ssl.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance('mqtt_ssl_bin_size', '{}KB' logging.info('[Performance][mqtt_ssl_bin_size]: %s KB', bin_size // 1024)
.format(bin_size // 1024))
# Look for host:port in sdkconfig # Look for host:port in sdkconfig
try: try:
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI']) value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI'))
assert value is not None
broker_url = value.group(1) broker_url = value.group(1)
broker_port = int(value.group(2)) broker_port = int(value.group(2))
bin_size = min(int(dut1.app.get_sdkconfig()['CONFIG_BROKER_BIN_SIZE_TO_SEND']), bin_size) bin_size = min(int(dut.app.sdkconfig.get('BROKER_BIN_SIZE_TO_SEND')), bin_size)
except Exception: except Exception:
print('ENV_TEST_FAILURE: Cannot find broker url in sdkconfig') print('ENV_TEST_FAILURE: Cannot find broker url in sdkconfig')
raise raise
@ -105,25 +110,20 @@ def test_examples_protocol_mqtt_ssl(env, extra_data):
print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port)) print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
if not event_client_connected.wait(timeout=30): if not event_client_connected.wait(timeout=30):
raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url)) raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
dut1.start_app()
try: try:
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0] ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0]
print('Connected to AP with IP: {}'.format(ip_address)) print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout: except pexpect.TIMEOUT:
print('ENV_TEST_FAILURE: Cannot connect to AP') print('ENV_TEST_FAILURE: Cannot connect to AP')
raise raise
print('Checking py-client received msg published from esp...') print('Checking py-client received msg published from esp...')
if not event_client_received_correct.wait(timeout=30): if not event_client_received_correct.wait(timeout=30):
raise ValueError('Wrong data received, msg log: {}'.format(message_log)) raise ValueError('Wrong data received, msg log: {}'.format(message_log))
print('Checking esp-client received msg published from py-client...') print('Checking esp-client received msg published from py-client...')
dut1.expect(re.compile(r'DATA=send binary please'), timeout=30) dut.expect(r'DATA=send binary please', timeout=30)
print('Receiving binary data from running partition...') print('Receiving binary data from running partition...')
if not event_client_received_binary.wait(timeout=30): if not event_client_received_binary.wait(timeout=30):
raise ValueError('Binary not received within timeout') raise ValueError('Binary not received within timeout')
finally: finally:
event_stop_client.set() event_stop_client.set()
thread1.join() thread1.join()
if __name__ == '__main__':
test_examples_protocol_mqtt_ssl()

View File

@ -1,19 +1,22 @@
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging
import os import os
import re
import socket import socket
import struct import struct
import sys import sys
import time import time
from threading import Thread from threading import Thread
import ttfw_idf import pexpect
import pytest
from common_test_methods import get_host_ip4_by_dest_ip from common_test_methods import get_host_ip4_by_dest_ip
from tiny_test_fw import DUT from pytest_embedded import Dut
msgid = -1 msgid = -1
def mqqt_server_sketch(my_ip, port): def mqqt_server_sketch(my_ip, port): # type: (str, str) -> None
global msgid global msgid
print('Starting the server on {}'.format(my_ip)) print('Starting the server on {}'.format(my_ip))
s = None s = None
@ -32,13 +35,13 @@ def mqqt_server_sketch(my_ip, port):
raise raise
data = q.recv(1024) data = q.recv(1024)
# check if received initial empty message # check if received initial empty message
print('received from client {}'.format(data)) print('received from client {!r}'.format(data))
data = bytearray([0x20, 0x02, 0x00, 0x00]) data = bytearray([0x20, 0x02, 0x00, 0x00])
q.send(data) q.send(data)
# try to receive qos1 # try to receive qos1
data = q.recv(1024) data = q.recv(1024)
msgid = struct.unpack('>H', data[15:17])[0] msgid = struct.unpack('>H', data[15:17])[0]
print('received from client {}, msgid: {}'.format(data, msgid)) print('received from client {!r}, msgid: {}'.format(data, msgid))
data = bytearray([0x40, 0x02, data[15], data[16]]) data = bytearray([0x40, 0x02, data[15], data[16]])
q.send(data) q.send(data)
time.sleep(5) time.sleep(5)
@ -46,8 +49,9 @@ def mqqt_server_sketch(my_ip, port):
print('server closed') print('server closed')
@ttfw_idf.idf_example_test(env_tag='ethernet_router') @pytest.mark.esp32
def test_examples_protocol_mqtt_qos1(env, extra_data): @pytest.mark.ethernet
def test_examples_protocol_mqtt_qos1(dut: Dut) -> None:
global msgid global msgid
""" """
steps: (QoS1: Happy flow) steps: (QoS1: Happy flow)
@ -56,18 +60,15 @@ def test_examples_protocol_mqtt_qos1(env, extra_data):
3. Test evaluates that qos1 message is queued and removed from queued after ACK received 3. Test evaluates that qos1 message is queued and removed from queued after ACK received
4. Test the broker received the same message id evaluated in step 3 4. Test the broker received the same message id evaluated in step 3
""" """
dut1 = env.get_dut('mqtt_tcp', 'examples/protocols/mqtt/tcp', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_tcp.bin') binary_file = os.path.join(dut.app.binary_path, 'mqtt_tcp.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance('mqtt_tcp_bin_size', '{}KB'.format(bin_size // 1024)) logging.info('[Performance][mqtt_tcp_bin_size]: %s KB', bin_size // 1024)
# 1. start the dut test and wait till client gets IP address
dut1.start_app()
# waiting for getting the IP address # waiting for getting the IP address
try: try:
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0] ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)', timeout=30).group(1).decode()
print('Connected to AP/Ethernet with IP: {}'.format(ip_address)) print('Connected to AP/Ethernet with IP: {}'.format(ip_address))
except DUT.ExpectTimeout: except pexpect.TIMEOUT:
raise ValueError('ENV_TEST_FAILURE: Cannot connect to AP/Ethernet') raise ValueError('ENV_TEST_FAILURE: Cannot connect to AP/Ethernet')
# 2. start mqtt broker sketch # 2. start mqtt broker sketch
@ -75,20 +76,17 @@ def test_examples_protocol_mqtt_qos1(env, extra_data):
thread1 = Thread(target=mqqt_server_sketch, args=(host_ip,1883)) thread1 = Thread(target=mqqt_server_sketch, args=(host_ip,1883))
thread1.start() thread1.start()
print('writing to device: {}'.format('mqtt://' + host_ip + '\n')) data_write = 'mqtt://' + host_ip
dut1.write('mqtt://' + host_ip + '\n') print('writing to device: {}'.format(data_write))
dut.write(data_write)
thread1.join() thread1.join()
print('Message id received from server: {}'.format(msgid)) print('Message id received from server: {}'.format(msgid))
# 3. check the message id was enqueued and then deleted # 3. check the message id was enqueued and then deleted
msgid_enqueued = dut1.expect(re.compile(r'outbox: ENQUEUE msgid=([0-9]+)'), timeout=30) msgid_enqueued = dut.expect(b'outbox: ENQUEUE msgid=([0-9]+)', timeout=30).group(1).decode()
msgid_deleted = dut1.expect(re.compile(r'outbox: DELETED msgid=([0-9]+)'), timeout=30) msgid_deleted = dut.expect(b'outbox: DELETED msgid=([0-9]+)', timeout=30).group(1).decode()
# 4. check the msgid of received data are the same as that of enqueued and deleted from outbox # 4. check the msgid of received data are the same as that of enqueued and deleted from outbox
if (msgid_enqueued[0] == str(msgid) and msgid_deleted[0] == str(msgid)): if (msgid_enqueued == str(msgid) and msgid_deleted == str(msgid)):
print('PASS: Received correct msg id') print('PASS: Received correct msg id')
else: else:
print('Failure!') print('Failure!')
raise ValueError('Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted)) raise ValueError('Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted))
if __name__ == '__main__':
test_examples_protocol_mqtt_qos1()

View File

@ -1,11 +1,16 @@
#!/usr/bin/env python
#
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging
import os import os
import re import re
import sys import sys
from threading import Event, Thread from threading import Event, Thread
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
import ttfw_idf import pytest
from tiny_test_fw import DUT from pytest_embedded import Dut
event_client_connected = Event() event_client_connected = Event()
event_stop_client = Event() event_stop_client = Event()
@ -14,19 +19,21 @@ message_log = ''
# The callback for when the client receives a CONNACK response from the server. # The callback for when the client receives a CONNACK response from the server.
def on_connect(client, userdata, flags, rc): def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, tuple, bool, str) -> None
_ = (userdata, flags)
print('Connected with result code ' + str(rc)) print('Connected with result code ' + str(rc))
event_client_connected.set() event_client_connected.set()
client.subscribe('/topic/qos0') client.subscribe('/topic/qos0')
def mqtt_client_task(client): def mqtt_client_task(client): # type: (mqtt.Client) -> None
while not event_stop_client.is_set(): while not event_stop_client.is_set():
client.loop() client.loop()
# The callback for when a PUBLISH message is received from the server. # The callback for when a PUBLISH message is received from the server.
def on_message(client, userdata, msg): def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None
_ = userdata
global message_log global message_log
payload = msg.payload.decode() payload = msg.payload.decode()
if not event_client_received_correct.is_set() and payload == 'data': if not event_client_received_correct.is_set() and payload == 'data':
@ -36,8 +43,9 @@ def on_message(client, userdata, msg):
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@ttfw_idf.idf_example_test(env_tag='ethernet_router') @pytest.mark.esp32
def test_examples_protocol_mqtt_ws(env, extra_data): @pytest.mark.ethernet
def test_examples_protocol_mqtt_ws(dut): # type: (Dut) -> None
broker_url = '' broker_url = ''
broker_port = 0 broker_port = 0
""" """
@ -47,14 +55,14 @@ def test_examples_protocol_mqtt_ws(env, extra_data):
3. Test evaluates it received correct qos0 message 3. Test evaluates it received correct qos0 message
4. Test ESP32 client received correct qos0 message 4. Test ESP32 client received correct qos0 message
""" """
dut1 = env.get_dut('mqtt_websocket', 'examples/protocols/mqtt/ws', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_websocket.bin') binary_file = os.path.join(dut.app.binary_path, 'mqtt_websocket.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance('mqtt_websocket_bin_size', '{}KB'.format(bin_size // 1024)) logging.info('[Performance][mqtt_websocket_bin_size]: %s KB', bin_size // 1024)
# Look for host:port in sdkconfig # Look for host:port in sdkconfig
try: try:
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI']) value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI'))
assert value is not None
broker_url = value.group(1) broker_url = value.group(1)
broker_port = int(value.group(2)) broker_port = int(value.group(2))
except Exception: except Exception:
@ -78,22 +86,17 @@ def test_examples_protocol_mqtt_ws(env, extra_data):
print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port)) print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
if not event_client_connected.wait(timeout=30): if not event_client_connected.wait(timeout=30):
raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url)) raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
dut1.start_app()
try: try:
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0] ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0]
print('Connected to AP with IP: {}'.format(ip_address)) print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout: except Dut.ExpectTimeout:
print('ENV_TEST_FAILURE: Cannot connect to AP') print('ENV_TEST_FAILURE: Cannot connect to AP')
raise raise
print('Checking py-client received msg published from esp...') print('Checking py-client received msg published from esp...')
if not event_client_received_correct.wait(timeout=30): if not event_client_received_correct.wait(timeout=30):
raise ValueError('Wrong data received, msg log: {}'.format(message_log)) raise ValueError('Wrong data received, msg log: {}'.format(message_log))
print('Checking esp-client received msg published from py-client...') print('Checking esp-client received msg published from py-client...')
dut1.expect(re.compile(r'DATA=data_to_esp32'), timeout=30) dut.expect(r'DATA=data_to_esp32', timeout=30)
finally: finally:
event_stop_client.set() event_stop_client.set()
thread1.join() thread1.join()
if __name__ == '__main__':
test_examples_protocol_mqtt_ws()

View File

@ -1,3 +1,8 @@
#!/usr/bin/env python
#
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging
import os import os
import re import re
import ssl import ssl
@ -5,8 +10,9 @@ import sys
from threading import Event, Thread from threading import Event, Thread
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
import ttfw_idf import pexpect
from tiny_test_fw import DUT import pytest
from pytest_embedded import Dut
event_client_connected = Event() event_client_connected = Event()
event_stop_client = Event() event_stop_client = Event()
@ -15,19 +21,21 @@ message_log = ''
# The callback for when the client receives a CONNACK response from the server. # The callback for when the client receives a CONNACK response from the server.
def on_connect(client, userdata, flags, rc): def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, tuple, bool, str) -> None
_ = (userdata, flags)
print('Connected with result code ' + str(rc)) print('Connected with result code ' + str(rc))
event_client_connected.set() event_client_connected.set()
client.subscribe('/topic/qos0') client.subscribe('/topic/qos0')
def mqtt_client_task(client): def mqtt_client_task(client): # type: (mqtt.Client) -> None
while not event_stop_client.is_set(): while not event_stop_client.is_set():
client.loop() client.loop()
# The callback for when a PUBLISH message is received from the server. # The callback for when a PUBLISH message is received from the server.
def on_message(client, userdata, msg): def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None
_ = userdata
global message_log global message_log
payload = msg.payload.decode() payload = msg.payload.decode()
if not event_client_received_correct.is_set() and payload == 'data': if not event_client_received_correct.is_set() and payload == 'data':
@ -37,8 +45,9 @@ def on_message(client, userdata, msg):
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@ttfw_idf.idf_example_test(env_tag='ethernet_router') @pytest.mark.esp32
def test_examples_protocol_mqtt_wss(env, extra_data): @pytest.mark.ethernet
def test_examples_protocol_mqtt_wss(dut): # type: (Dut) -> None
broker_url = '' broker_url = ''
broker_port = 0 broker_port = 0
""" """
@ -48,14 +57,14 @@ def test_examples_protocol_mqtt_wss(env, extra_data):
3. Test evaluates it received correct qos0 message 3. Test evaluates it received correct qos0 message
4. Test ESP32 client received correct qos0 message 4. Test ESP32 client received correct qos0 message
""" """
dut1 = env.get_dut('mqtt_websocket_secure', 'examples/protocols/mqtt/wss', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_websocket_secure.bin') binary_file = os.path.join(dut.app.binary_path, 'mqtt_websocket_secure.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance('mqtt_websocket_secure_bin_size', '{}KB'.format(bin_size // 1024)) logging.info('[Performance][mqtt_websocket_secure_bin_size]: %s KB', bin_size // 1024)
# Look for host:port in sdkconfig # Look for host:port in sdkconfig
try: try:
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI']) value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI'))
assert value is not None
broker_url = value.group(1) broker_url = value.group(1)
broker_port = int(value.group(2)) broker_port = int(value.group(2))
except Exception: except Exception:
@ -82,22 +91,17 @@ def test_examples_protocol_mqtt_wss(env, extra_data):
print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port)) print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
if not event_client_connected.wait(timeout=30): if not event_client_connected.wait(timeout=30):
raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url)) raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
dut1.start_app()
try: try:
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0] ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0]
print('Connected to AP with IP: {}'.format(ip_address)) print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout: except pexpect.TIMEOUT:
print('ENV_TEST_FAILURE: Cannot connect to AP') print('ENV_TEST_FAILURE: Cannot connect to AP')
raise raise
print('Checking py-client received msg published from esp...') print('Checking py-client received msg published from esp...')
if not event_client_received_correct.wait(timeout=30): if not event_client_received_correct.wait(timeout=30):
raise ValueError('Wrong data received, msg log: {}'.format(message_log)) raise ValueError('Wrong data received, msg log: {}'.format(message_log))
print('Checking esp-client received msg published from py-client...') print('Checking esp-client received msg published from py-client...')
dut1.expect(re.compile(r'DATA=data_to_esp32'), timeout=30) dut.expect(r'DATA=data_to_esp32', timeout=30)
finally: finally:
event_stop_client.set() event_stop_client.set()
thread1.join() thread1.join()
if __name__ == '__main__':
test_examples_protocol_mqtt_wss()

View File

@ -18,6 +18,7 @@ netifaces
rangehttpserver rangehttpserver
dbus-python; sys_platform == 'linux' dbus-python; sys_platform == 'linux'
protobuf protobuf
paho-mqtt
# for twai tests, communicate with socket can device (e.g. Canable) # for twai tests, communicate with socket can device (e.g. Canable)
python-can python-can

View File

@ -33,7 +33,3 @@ SimpleWebSocketServer
# py_debug_backend # py_debug_backend
debug_backend debug_backend
# examples/protocols/mqtt
# tools/test_apps/protocols/mqtt
paho-mqtt

View File

@ -1,5 +1,8 @@
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import logging
import os import os
import random import random
import re import re
@ -10,21 +13,25 @@ import string
import subprocess import subprocess
import sys import sys
import time import time
import typing
from itertools import count from itertools import count
from threading import Event, Lock, Thread from threading import Event, Lock, Thread
from typing import Any
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
import ttfw_idf import pytest
from common_test_methods import get_host_ip4_by_dest_ip from common_test_methods import get_host_ip4_by_dest_ip
from pytest_embedded import Dut
from pytest_embedded_qemu.dut import QemuDut
DEFAULT_MSG_SIZE = 16 DEFAULT_MSG_SIZE = 16
def _path(f): def _path(f): # type: (str) -> str
return os.path.join(os.path.dirname(os.path.realpath(__file__)),f) return os.path.join(os.path.dirname(os.path.realpath(__file__)),f)
def set_server_cert_cn(ip): def set_server_cert_cn(ip): # type: (str) -> None
arg_list = [ arg_list = [
['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),'-subj', '/CN={}'.format(ip), '-new'], ['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),'-subj', '/CN={}'.format(ip), '-new'],
['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'), ['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'),
@ -36,8 +43,13 @@ def set_server_cert_cn(ip):
# Publisher class creating a python client to send/receive published data from esp-mqtt client # Publisher class creating a python client to send/receive published data from esp-mqtt client
class MqttPublisher: class MqttPublisher:
event_client_connected = Event()
event_client_got_all = Event()
expected_data = ''
published = 0
def __init__(self, dut, transport, qos, repeat, published, queue, publish_cfg, log_details=False): def __init__(self, dut, transport,
qos, repeat, published, queue, publish_cfg, log_details=False): # type: (MqttPublisher, Dut, str, int, int, int, int, dict, bool) -> None
# instance variables used as parameters of the publish test # instance variables used as parameters of the publish test
self.event_stop_client = Event() self.event_stop_client = Event()
self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE)) self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
@ -58,11 +70,11 @@ class MqttPublisher:
MqttPublisher.event_client_got_all.clear() MqttPublisher.event_client_got_all.clear()
MqttPublisher.expected_data = self.sample_string * self.repeat MqttPublisher.expected_data = self.sample_string * self.repeat
def print_details(self, text): def print_details(self, text): # type: (str) -> None
if self.log_details: if self.log_details:
print(text) print(text)
def mqtt_client_task(self, client, lock): def mqtt_client_task(self, client, lock): # type: (MqttPublisher, mqtt.Client, Lock) -> None
while not self.event_stop_client.is_set(): while not self.event_stop_client.is_set():
with lock: with lock:
client.loop() client.loop()
@ -70,12 +82,12 @@ class MqttPublisher:
# The callback for when the client receives a CONNACK response from the server (needs to be static) # The callback for when the client receives a CONNACK response from the server (needs to be static)
@staticmethod @staticmethod
def on_connect(_client, _userdata, _flags, _rc): def on_connect(_client, _userdata, _flags, _rc): # type: (mqtt.Client, tuple, bool, str) -> None
MqttPublisher.event_client_connected.set() MqttPublisher.event_client_connected.set()
# The callback for when a PUBLISH message is received from the server (needs to be static) # The callback for when a PUBLISH message is received from the server (needs to be static)
@staticmethod @staticmethod
def on_message(client, userdata, msg): def on_message(client, userdata, msg): # type: (mqtt.Client, int, mqtt.client.MQTTMessage) -> None
payload = msg.payload.decode() payload = msg.payload.decode()
if payload == MqttPublisher.expected_data: if payload == MqttPublisher.expected_data:
userdata += 1 userdata += 1
@ -83,7 +95,7 @@ class MqttPublisher:
if userdata == MqttPublisher.published: if userdata == MqttPublisher.published:
MqttPublisher.event_client_got_all.set() MqttPublisher.event_client_got_all.set()
def __enter__(self): def __enter__(self): # type: (MqttPublisher) -> None
qos = self.publish_cfg['qos'] qos = self.publish_cfg['qos']
queue = self.publish_cfg['queue'] queue = self.publish_cfg['queue']
@ -100,6 +112,7 @@ class MqttPublisher:
self.client = mqtt.Client(transport='websockets') self.client = mqtt.Client(transport='websockets')
else: else:
self.client = mqtt.Client() self.client = mqtt.Client()
assert self.client is not None
self.client.on_connect = MqttPublisher.on_connect self.client.on_connect = MqttPublisher.on_connect
self.client.on_message = MqttPublisher.on_message self.client.on_message = MqttPublisher.on_message
self.client.user_data_set(0) self.client.user_data_set(0)
@ -137,7 +150,8 @@ class MqttPublisher:
self.event_stop_client.set() self.event_stop_client.set()
thread1.join() thread1.join()
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback): # type: (MqttPublisher, str, str, dict) -> None
assert self.client is not None
self.client.disconnect() self.client.disconnect()
self.event_stop_client.clear() self.event_stop_client.clear()
@ -145,7 +159,7 @@ class MqttPublisher:
# Simple server for mqtt over TLS connection # Simple server for mqtt over TLS connection
class TlsServer: class TlsServer:
def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False): def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False): # type: (TlsServer, int, bool, bool, bool) -> None
self.port = port self.port = port
self.socket = socket.socket() self.socket = socket.socket()
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -153,11 +167,9 @@ class TlsServer:
self.shutdown = Event() self.shutdown = Event()
self.client_cert = client_cert self.client_cert = client_cert
self.refuse_connection = refuse_connection self.refuse_connection = refuse_connection
self.ssl_error = None
self.use_alpn = use_alpn self.use_alpn = use_alpn
self.negotiated_protocol = None
def __enter__(self): def __enter__(self): # type: (TlsServer) -> TlsServer
try: try:
self.socket.bind(('', self.port)) self.socket.bind(('', self.port))
except socket.error as e: except socket.error as e:
@ -170,20 +182,21 @@ class TlsServer:
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback): # type: (TlsServer, str, str, str) -> None
self.shutdown.set() self.shutdown.set()
self.server_thread.join() self.server_thread.join()
self.socket.close() self.socket.close()
if (self.conn is not None): if (self.conn is not None):
self.conn.close() self.conn.close()
def get_last_ssl_error(self): def get_last_ssl_error(self): # type: (TlsServer) -> str
return self.ssl_error return self.ssl_error
@typing.no_type_check
def get_negotiated_protocol(self): def get_negotiated_protocol(self):
return self.negotiated_protocol return self.negotiated_protocol
def run_server(self): def run_server(self) -> None:
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
if self.client_cert: if self.client_cert:
context.verify_mode = ssl.CERT_REQUIRED context.verify_mode = ssl.CERT_REQUIRED
@ -201,11 +214,10 @@ class TlsServer:
print(' - negotiated_protocol: {}'.format(self.negotiated_protocol)) print(' - negotiated_protocol: {}'.format(self.negotiated_protocol))
self.handle_conn() self.handle_conn()
except ssl.SSLError as e: except ssl.SSLError as e:
self.conn = None
self.ssl_error = str(e) self.ssl_error = str(e)
print(' - SSLError: {}'.format(str(e))) print(' - SSLError: {}'.format(str(e)))
def handle_conn(self): def handle_conn(self) -> None:
while not self.shutdown.is_set(): while not self.shutdown.is_set():
r,w,e = select.select([self.conn], [], [], 1) r,w,e = select.select([self.conn], [], [], 1)
try: try:
@ -216,7 +228,7 @@ class TlsServer:
print(' - error: {}'.format(err)) print(' - error: {}'.format(err))
raise raise
def process_mqtt_connect(self): def process_mqtt_connect(self) -> None:
try: try:
data = bytearray(self.conn.recv(1024)) data = bytearray(self.conn.recv(1024))
message = ''.join(format(x, '02x') for x in data) message = ''.join(format(x, '02x') for x in data)
@ -235,22 +247,22 @@ class TlsServer:
self.shutdown.set() self.shutdown.set()
def connection_tests(dut, cases, dut_ip): def connection_tests(dut, cases, dut_ip): # type: (Dut, dict, str) -> None
ip = get_host_ip4_by_dest_ip(dut_ip) ip = get_host_ip4_by_dest_ip(dut_ip)
set_server_cert_cn(ip) set_server_cert_cn(ip)
server_port = 2222 server_port = 2222
def teardown_connection_suite(): def teardown_connection_suite() -> None:
dut.write('conn teardown 0 0') dut.write('conn teardown 0 0')
def start_connection_case(case, desc): def start_connection_case(case, desc): # type: (str, str) -> Any
print('Starting {}: {}'.format(case, desc)) print('Starting {}: {}'.format(case, desc))
case_id = cases[case] case_id = cases[case]
dut.write('conn {} {} {}'.format(ip, server_port, case_id)) dut.write('conn {} {} {}'.format(ip, server_port, case_id))
dut.expect('Test case:{} started'.format(case_id)) dut.expect('Test case:{} started'.format(case_id))
return case_id return case_id
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']: for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
# All these cases connect to the server with no server verification or with server only verification # All these cases connect to the server with no server verification or with server only verification
with TlsServer(server_port): with TlsServer(server_port):
test_nr = start_connection_case(case, 'default server - expect to connect normally') test_nr = start_connection_case(case, 'default server - expect to connect normally')
@ -266,13 +278,13 @@ def connection_tests(dut, cases, dut_ip):
if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' not in s.get_last_ssl_error(): if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' not in s.get_last_ssl_error():
raise RuntimeError('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error())) raise RuntimeError('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']: for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
# These cases connect to server with both server and client verification (client key might be password protected) # These cases connect to server with both server and client verification (client key might be password protected)
with TlsServer(server_port, client_cert=True): with TlsServer(server_port, client_cert=True):
test_nr = start_connection_case(case, 'server with client verification - expect to connect normally') test_nr = start_connection_case(case, 'server with client verification - expect to connect normally')
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30) dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
case = 'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT' case = 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
with TlsServer(server_port) as s: with TlsServer(server_port) as s:
test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error') test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error')
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30) dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
@ -280,7 +292,7 @@ def connection_tests(dut, cases, dut_ip):
if 'alert unknown ca' not in s.get_last_ssl_error(): if 'alert unknown ca' not in s.get_last_ssl_error():
raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error())) raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
case = 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT' case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
with TlsServer(server_port, client_cert=True) as s: with TlsServer(server_port, client_cert=True) as s:
test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error') test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error')
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30) dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
@ -288,13 +300,13 @@ def connection_tests(dut, cases, dut_ip):
if 'CERTIFICATE_VERIFY_FAILED' not in s.get_last_ssl_error(): if 'CERTIFICATE_VERIFY_FAILED' not in s.get_last_ssl_error():
raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error())) raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
with TlsServer(server_port, use_alpn=True) as s: with TlsServer(server_port, use_alpn=True) as s:
test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol') test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol')
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30) dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
if case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None: if case == 'EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None:
print(' - client with alpn off, no negotiated protocol: OK') print(' - client with alpn off, no negotiated protocol: OK')
elif case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt': elif case == 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt':
print(' - client with alpn on, negotiated protocol resolved: OK') print(' - client with alpn on, negotiated protocol resolved: OK')
else: else:
raise Exception('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol())) raise Exception('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol()))
@ -302,19 +314,19 @@ def connection_tests(dut, cases, dut_ip):
teardown_connection_suite() teardown_connection_suite()
@ttfw_idf.idf_custom_test(env_tag='ethernet_router', group='test-apps') @pytest.mark.esp32
def test_app_protocol_mqtt_publish_connect(env, extra_data): @pytest.mark.ethernet
def test_app_protocol_mqtt_publish_connect(dut: Dut) -> None:
""" """
steps: steps:
1. join AP 1. join AP
2. connect to uri specified in the config 2. connect to uri specified in the config
3. send and receive data 3. send and receive data
""" """
dut1 = env.get_dut('mqtt_publish_connect_test', 'tools/test_apps/protocols/mqtt/publish_connect_test')
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_publish_connect_test.bin') binary_file = os.path.join(dut.app.binary_path, 'mqtt_publish_connect_test.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance('mqtt_publish_connect_test_bin_size', '{}KB'.format(bin_size // 1024)) logging.info('[Performance][mqtt_publish_connect_test_bin_size]: %s KB', bin_size // 1024)
# Look for test case symbolic names and publish configs # Look for test case symbolic names and publish configs
cases = {} cases = {}
@ -322,25 +334,24 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
try: try:
# Get connection test cases configuration: symbolic names for test cases # Get connection test cases configuration: symbolic names for test cases
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', for case in ['EXAMPLE_CONNECT_CASE_NO_CERT',
'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT',
'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
cases[case] = dut1.app.get_sdkconfig()[case] cases[case] = dut.app.sdkconfig.get(case)
except Exception: except Exception:
print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig') print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
raise raise
dut1.start_app() esp_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30).group(1).decode()
esp_ip = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
print('Got IP={}'.format(esp_ip)) print('Got IP={}'.format(esp_ip))
if not os.getenv('MQTT_SKIP_CONNECT_TEST'): if not os.getenv('MQTT_SKIP_CONNECT_TEST'):
connection_tests(dut1,cases,esp_ip) connection_tests(dut,cases,esp_ip)
# #
# start publish tests only if enabled in the environment (for weekend tests only) # start publish tests only if enabled in the environment (for weekend tests only)
@ -349,27 +360,28 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
# Get publish test configuration # Get publish test configuration
try: try:
def get_host_port_from_dut(dut1, config_option): @typing.no_type_check
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()[config_option]) def get_host_port_from_dut(dut, config_option):
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get(config_option))
if value is None: if value is None:
return None, None return None, None
return value.group(1), int(value.group(2)) return value.group(1), int(value.group(2))
publish_cfg['publish_topic'] = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_SUBSCRIBE_TOPIC'].replace('"','') publish_cfg['publish_topic'] = dut.app.sdkconfig.get('EXAMPLE_SUBSCRIBE_TOPIC').replace('"','')
publish_cfg['subscribe_topic'] = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_PUBLISH_TOPIC'].replace('"','') publish_cfg['subscribe_topic'] = dut.app.sdkconfig.get('EXAMPLE_PUBLISH_TOPIC').replace('"','')
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_SSL_URI') publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI')
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_TCP_URI') publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI')
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WS_URI') publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WS_URI')
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WSS_URI') publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI')
except Exception: except Exception:
print('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig') print('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
raise raise
def start_publish_case(transport, qos, repeat, published, queue): def start_publish_case(transport, qos, repeat, published, queue): # type: (str, int, int, int, int) -> None
print('Starting Publish test: transport:{}, qos:{}, nr_of_msgs:{}, msg_size:{}, enqueue:{}' print('Starting Publish test: transport:{}, qos:{}, nr_of_msgs:{}, msg_size:{}, enqueue:{}'
.format(transport, qos, published, repeat * DEFAULT_MSG_SIZE, queue)) .format(transport, qos, published, repeat * DEFAULT_MSG_SIZE, queue))
with MqttPublisher(dut1, transport, qos, repeat, published, queue, publish_cfg): with MqttPublisher(dut, transport, qos, repeat, published, queue, publish_cfg):
pass pass
# Initialize message sizes and repeat counts (if defined in the environment) # Initialize message sizes and repeat counts (if defined in the environment)
@ -378,7 +390,7 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
# Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x} # Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x}
env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']} env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']): if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']):
messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore
continue continue
break break
if not messages: # No message sizes present in the env - set defaults if not messages: # No message sizes present in the env - set defaults
@ -400,4 +412,4 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
if __name__ == '__main__': if __name__ == '__main__':
test_app_protocol_mqtt_publish_connect(dut=ttfw_idf.ESP32QEMUDUT if sys.argv[1:] == ['qemu'] else ttfw_idf.ESP32DUT) test_app_protocol_mqtt_publish_connect(dut=QemuDut if sys.argv[1:] == ['qemu'] else Dut)