mirror of
https://github.com/espressif/esp-idf.git
synced 2024-10-05 20:47:46 -04:00
Examples: migration mqtt examples to pytest
This commit is contained in:
parent
7f4179744b
commit
b68203bfb5
@ -1,3 +1,6 @@
|
||||
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
|
||||
# SPDX-License-Identifier: Unlicense OR CC0-1.0
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import ssl
|
||||
@ -5,8 +8,9 @@ import sys
|
||||
from threading import Event, Thread
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
import ttfw_idf
|
||||
from tiny_test_fw import DUT
|
||||
import pexpect
|
||||
import pytest
|
||||
from pytest_embedded import Dut
|
||||
|
||||
event_client_connected = Event()
|
||||
event_stop_client = Event()
|
||||
@ -16,19 +20,20 @@ message_log = ''
|
||||
|
||||
|
||||
# The callback for when the client receives a CONNACK response from the server.
|
||||
def on_connect(client, userdata, flags, rc):
|
||||
def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, str, bool, str) -> None
|
||||
_ = (userdata, flags)
|
||||
print('Connected with result code ' + str(rc))
|
||||
event_client_connected.set()
|
||||
client.subscribe('/topic/qos0')
|
||||
|
||||
|
||||
def mqtt_client_task(client):
|
||||
def mqtt_client_task(client): # type: (mqtt.Client) -> None
|
||||
while not event_stop_client.is_set():
|
||||
client.loop()
|
||||
|
||||
|
||||
# The callback for when a PUBLISH message is received from the server.
|
||||
def on_message(client, userdata, msg):
|
||||
def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None
|
||||
global message_log
|
||||
global event_client_received_correct
|
||||
global event_client_received_binary
|
||||
@ -55,8 +60,9 @@ def on_message(client, userdata, msg):
|
||||
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
|
||||
|
||||
|
||||
@ttfw_idf.idf_example_test(env_tag='ethernet_router')
|
||||
def test_examples_protocol_mqtt_ssl(env, extra_data):
|
||||
@pytest.mark.esp32
|
||||
@pytest.mark.ethernet
|
||||
def test_examples_protocol_mqtt_ssl(dut): # type: (Dut) -> None
|
||||
broker_url = ''
|
||||
broker_port = 0
|
||||
"""
|
||||
@ -67,18 +73,17 @@ def test_examples_protocol_mqtt_ssl(env, extra_data):
|
||||
4. Test ESP32 client received correct qos0 message
|
||||
5. Test python client receives binary data from running partition and compares it with the binary
|
||||
"""
|
||||
dut1 = env.get_dut('mqtt_ssl', 'examples/protocols/mqtt/ssl', dut_class=ttfw_idf.ESP32DUT)
|
||||
# check and log bin size
|
||||
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_ssl.bin')
|
||||
binary_file = os.path.join(dut.app.binary_path, 'mqtt_ssl.bin')
|
||||
bin_size = os.path.getsize(binary_file)
|
||||
ttfw_idf.log_performance('mqtt_ssl_bin_size', '{}KB'
|
||||
.format(bin_size // 1024))
|
||||
logging.info('[Performance][mqtt_ssl_bin_size]: %s KB', bin_size // 1024)
|
||||
|
||||
# Look for host:port in sdkconfig
|
||||
try:
|
||||
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI'])
|
||||
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI'))
|
||||
assert value is not None
|
||||
broker_url = value.group(1)
|
||||
broker_port = int(value.group(2))
|
||||
bin_size = min(int(dut1.app.get_sdkconfig()['CONFIG_BROKER_BIN_SIZE_TO_SEND']), bin_size)
|
||||
bin_size = min(int(dut.app.sdkconfig.get('BROKER_BIN_SIZE_TO_SEND')), bin_size)
|
||||
except Exception:
|
||||
print('ENV_TEST_FAILURE: Cannot find broker url in sdkconfig')
|
||||
raise
|
||||
@ -105,25 +110,20 @@ def test_examples_protocol_mqtt_ssl(env, extra_data):
|
||||
print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
|
||||
if not event_client_connected.wait(timeout=30):
|
||||
raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
|
||||
dut1.start_app()
|
||||
try:
|
||||
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
|
||||
ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0]
|
||||
print('Connected to AP with IP: {}'.format(ip_address))
|
||||
except DUT.ExpectTimeout:
|
||||
except pexpect.TIMEOUT:
|
||||
print('ENV_TEST_FAILURE: Cannot connect to AP')
|
||||
raise
|
||||
print('Checking py-client received msg published from esp...')
|
||||
if not event_client_received_correct.wait(timeout=30):
|
||||
raise ValueError('Wrong data received, msg log: {}'.format(message_log))
|
||||
print('Checking esp-client received msg published from py-client...')
|
||||
dut1.expect(re.compile(r'DATA=send binary please'), timeout=30)
|
||||
dut.expect(r'DATA=send binary please', timeout=30)
|
||||
print('Receiving binary data from running partition...')
|
||||
if not event_client_received_binary.wait(timeout=30):
|
||||
raise ValueError('Binary not received within timeout')
|
||||
finally:
|
||||
event_stop_client.set()
|
||||
thread1.join()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_examples_protocol_mqtt_ssl()
|
@ -1,19 +1,22 @@
|
||||
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
|
||||
# SPDX-License-Identifier: Unlicense OR CC0-1.0
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
import ttfw_idf
|
||||
import pexpect
|
||||
import pytest
|
||||
from common_test_methods import get_host_ip4_by_dest_ip
|
||||
from tiny_test_fw import DUT
|
||||
from pytest_embedded import Dut
|
||||
|
||||
msgid = -1
|
||||
|
||||
|
||||
def mqqt_server_sketch(my_ip, port):
|
||||
def mqqt_server_sketch(my_ip, port): # type: (str, str) -> None
|
||||
global msgid
|
||||
print('Starting the server on {}'.format(my_ip))
|
||||
s = None
|
||||
@ -32,13 +35,13 @@ def mqqt_server_sketch(my_ip, port):
|
||||
raise
|
||||
data = q.recv(1024)
|
||||
# check if received initial empty message
|
||||
print('received from client {}'.format(data))
|
||||
print('received from client {!r}'.format(data))
|
||||
data = bytearray([0x20, 0x02, 0x00, 0x00])
|
||||
q.send(data)
|
||||
# try to receive qos1
|
||||
data = q.recv(1024)
|
||||
msgid = struct.unpack('>H', data[15:17])[0]
|
||||
print('received from client {}, msgid: {}'.format(data, msgid))
|
||||
print('received from client {!r}, msgid: {}'.format(data, msgid))
|
||||
data = bytearray([0x40, 0x02, data[15], data[16]])
|
||||
q.send(data)
|
||||
time.sleep(5)
|
||||
@ -46,8 +49,9 @@ def mqqt_server_sketch(my_ip, port):
|
||||
print('server closed')
|
||||
|
||||
|
||||
@ttfw_idf.idf_example_test(env_tag='ethernet_router')
|
||||
def test_examples_protocol_mqtt_qos1(env, extra_data):
|
||||
@pytest.mark.esp32
|
||||
@pytest.mark.ethernet
|
||||
def test_examples_protocol_mqtt_qos1(dut: Dut) -> None:
|
||||
global msgid
|
||||
"""
|
||||
steps: (QoS1: Happy flow)
|
||||
@ -56,18 +60,15 @@ def test_examples_protocol_mqtt_qos1(env, extra_data):
|
||||
3. Test evaluates that qos1 message is queued and removed from queued after ACK received
|
||||
4. Test the broker received the same message id evaluated in step 3
|
||||
"""
|
||||
dut1 = env.get_dut('mqtt_tcp', 'examples/protocols/mqtt/tcp', dut_class=ttfw_idf.ESP32DUT)
|
||||
# check and log bin size
|
||||
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_tcp.bin')
|
||||
binary_file = os.path.join(dut.app.binary_path, 'mqtt_tcp.bin')
|
||||
bin_size = os.path.getsize(binary_file)
|
||||
ttfw_idf.log_performance('mqtt_tcp_bin_size', '{}KB'.format(bin_size // 1024))
|
||||
# 1. start the dut test and wait till client gets IP address
|
||||
dut1.start_app()
|
||||
logging.info('[Performance][mqtt_tcp_bin_size]: %s KB', bin_size // 1024)
|
||||
# waiting for getting the IP address
|
||||
try:
|
||||
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
|
||||
ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)', timeout=30).group(1).decode()
|
||||
print('Connected to AP/Ethernet with IP: {}'.format(ip_address))
|
||||
except DUT.ExpectTimeout:
|
||||
except pexpect.TIMEOUT:
|
||||
raise ValueError('ENV_TEST_FAILURE: Cannot connect to AP/Ethernet')
|
||||
|
||||
# 2. start mqtt broker sketch
|
||||
@ -75,20 +76,17 @@ def test_examples_protocol_mqtt_qos1(env, extra_data):
|
||||
thread1 = Thread(target=mqqt_server_sketch, args=(host_ip,1883))
|
||||
thread1.start()
|
||||
|
||||
print('writing to device: {}'.format('mqtt://' + host_ip + '\n'))
|
||||
dut1.write('mqtt://' + host_ip + '\n')
|
||||
data_write = 'mqtt://' + host_ip
|
||||
print('writing to device: {}'.format(data_write))
|
||||
dut.write(data_write)
|
||||
thread1.join()
|
||||
print('Message id received from server: {}'.format(msgid))
|
||||
# 3. check the message id was enqueued and then deleted
|
||||
msgid_enqueued = dut1.expect(re.compile(r'outbox: ENQUEUE msgid=([0-9]+)'), timeout=30)
|
||||
msgid_deleted = dut1.expect(re.compile(r'outbox: DELETED msgid=([0-9]+)'), timeout=30)
|
||||
msgid_enqueued = dut.expect(b'outbox: ENQUEUE msgid=([0-9]+)', timeout=30).group(1).decode()
|
||||
msgid_deleted = dut.expect(b'outbox: DELETED msgid=([0-9]+)', timeout=30).group(1).decode()
|
||||
# 4. check the msgid of received data are the same as that of enqueued and deleted from outbox
|
||||
if (msgid_enqueued[0] == str(msgid) and msgid_deleted[0] == str(msgid)):
|
||||
if (msgid_enqueued == str(msgid) and msgid_deleted == str(msgid)):
|
||||
print('PASS: Received correct msg id')
|
||||
else:
|
||||
print('Failure!')
|
||||
raise ValueError('Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_examples_protocol_mqtt_qos1()
|
@ -1,11 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
|
||||
# SPDX-License-Identifier: Unlicense OR CC0-1.0
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from threading import Event, Thread
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
import ttfw_idf
|
||||
from tiny_test_fw import DUT
|
||||
import pytest
|
||||
from pytest_embedded import Dut
|
||||
|
||||
event_client_connected = Event()
|
||||
event_stop_client = Event()
|
||||
@ -14,19 +19,21 @@ message_log = ''
|
||||
|
||||
|
||||
# The callback for when the client receives a CONNACK response from the server.
|
||||
def on_connect(client, userdata, flags, rc):
|
||||
def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, tuple, bool, str) -> None
|
||||
_ = (userdata, flags)
|
||||
print('Connected with result code ' + str(rc))
|
||||
event_client_connected.set()
|
||||
client.subscribe('/topic/qos0')
|
||||
|
||||
|
||||
def mqtt_client_task(client):
|
||||
def mqtt_client_task(client): # type: (mqtt.Client) -> None
|
||||
while not event_stop_client.is_set():
|
||||
client.loop()
|
||||
|
||||
|
||||
# The callback for when a PUBLISH message is received from the server.
|
||||
def on_message(client, userdata, msg):
|
||||
def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None
|
||||
_ = userdata
|
||||
global message_log
|
||||
payload = msg.payload.decode()
|
||||
if not event_client_received_correct.is_set() and payload == 'data':
|
||||
@ -36,8 +43,9 @@ def on_message(client, userdata, msg):
|
||||
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
|
||||
|
||||
|
||||
@ttfw_idf.idf_example_test(env_tag='ethernet_router')
|
||||
def test_examples_protocol_mqtt_ws(env, extra_data):
|
||||
@pytest.mark.esp32
|
||||
@pytest.mark.ethernet
|
||||
def test_examples_protocol_mqtt_ws(dut): # type: (Dut) -> None
|
||||
broker_url = ''
|
||||
broker_port = 0
|
||||
"""
|
||||
@ -47,14 +55,14 @@ def test_examples_protocol_mqtt_ws(env, extra_data):
|
||||
3. Test evaluates it received correct qos0 message
|
||||
4. Test ESP32 client received correct qos0 message
|
||||
"""
|
||||
dut1 = env.get_dut('mqtt_websocket', 'examples/protocols/mqtt/ws', dut_class=ttfw_idf.ESP32DUT)
|
||||
# check and log bin size
|
||||
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_websocket.bin')
|
||||
binary_file = os.path.join(dut.app.binary_path, 'mqtt_websocket.bin')
|
||||
bin_size = os.path.getsize(binary_file)
|
||||
ttfw_idf.log_performance('mqtt_websocket_bin_size', '{}KB'.format(bin_size // 1024))
|
||||
logging.info('[Performance][mqtt_websocket_bin_size]: %s KB', bin_size // 1024)
|
||||
# Look for host:port in sdkconfig
|
||||
try:
|
||||
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI'])
|
||||
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI'))
|
||||
assert value is not None
|
||||
broker_url = value.group(1)
|
||||
broker_port = int(value.group(2))
|
||||
except Exception:
|
||||
@ -78,22 +86,17 @@ def test_examples_protocol_mqtt_ws(env, extra_data):
|
||||
print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
|
||||
if not event_client_connected.wait(timeout=30):
|
||||
raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
|
||||
dut1.start_app()
|
||||
try:
|
||||
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
|
||||
ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0]
|
||||
print('Connected to AP with IP: {}'.format(ip_address))
|
||||
except DUT.ExpectTimeout:
|
||||
except Dut.ExpectTimeout:
|
||||
print('ENV_TEST_FAILURE: Cannot connect to AP')
|
||||
raise
|
||||
print('Checking py-client received msg published from esp...')
|
||||
if not event_client_received_correct.wait(timeout=30):
|
||||
raise ValueError('Wrong data received, msg log: {}'.format(message_log))
|
||||
print('Checking esp-client received msg published from py-client...')
|
||||
dut1.expect(re.compile(r'DATA=data_to_esp32'), timeout=30)
|
||||
dut.expect(r'DATA=data_to_esp32', timeout=30)
|
||||
finally:
|
||||
event_stop_client.set()
|
||||
thread1.join()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_examples_protocol_mqtt_ws()
|
@ -1,3 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
|
||||
# SPDX-License-Identifier: Unlicense OR CC0-1.0
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import ssl
|
||||
@ -5,8 +10,9 @@ import sys
|
||||
from threading import Event, Thread
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
import ttfw_idf
|
||||
from tiny_test_fw import DUT
|
||||
import pexpect
|
||||
import pytest
|
||||
from pytest_embedded import Dut
|
||||
|
||||
event_client_connected = Event()
|
||||
event_stop_client = Event()
|
||||
@ -15,19 +21,21 @@ message_log = ''
|
||||
|
||||
|
||||
# The callback for when the client receives a CONNACK response from the server.
|
||||
def on_connect(client, userdata, flags, rc):
|
||||
def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, tuple, bool, str) -> None
|
||||
_ = (userdata, flags)
|
||||
print('Connected with result code ' + str(rc))
|
||||
event_client_connected.set()
|
||||
client.subscribe('/topic/qos0')
|
||||
|
||||
|
||||
def mqtt_client_task(client):
|
||||
def mqtt_client_task(client): # type: (mqtt.Client) -> None
|
||||
while not event_stop_client.is_set():
|
||||
client.loop()
|
||||
|
||||
|
||||
# The callback for when a PUBLISH message is received from the server.
|
||||
def on_message(client, userdata, msg):
|
||||
def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None
|
||||
_ = userdata
|
||||
global message_log
|
||||
payload = msg.payload.decode()
|
||||
if not event_client_received_correct.is_set() and payload == 'data':
|
||||
@ -37,8 +45,9 @@ def on_message(client, userdata, msg):
|
||||
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
|
||||
|
||||
|
||||
@ttfw_idf.idf_example_test(env_tag='ethernet_router')
|
||||
def test_examples_protocol_mqtt_wss(env, extra_data):
|
||||
@pytest.mark.esp32
|
||||
@pytest.mark.ethernet
|
||||
def test_examples_protocol_mqtt_wss(dut): # type: (Dut) -> None
|
||||
broker_url = ''
|
||||
broker_port = 0
|
||||
"""
|
||||
@ -48,14 +57,14 @@ def test_examples_protocol_mqtt_wss(env, extra_data):
|
||||
3. Test evaluates it received correct qos0 message
|
||||
4. Test ESP32 client received correct qos0 message
|
||||
"""
|
||||
dut1 = env.get_dut('mqtt_websocket_secure', 'examples/protocols/mqtt/wss', dut_class=ttfw_idf.ESP32DUT)
|
||||
# check and log bin size
|
||||
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_websocket_secure.bin')
|
||||
binary_file = os.path.join(dut.app.binary_path, 'mqtt_websocket_secure.bin')
|
||||
bin_size = os.path.getsize(binary_file)
|
||||
ttfw_idf.log_performance('mqtt_websocket_secure_bin_size', '{}KB'.format(bin_size // 1024))
|
||||
logging.info('[Performance][mqtt_websocket_secure_bin_size]: %s KB', bin_size // 1024)
|
||||
# Look for host:port in sdkconfig
|
||||
try:
|
||||
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI'])
|
||||
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI'))
|
||||
assert value is not None
|
||||
broker_url = value.group(1)
|
||||
broker_port = int(value.group(2))
|
||||
except Exception:
|
||||
@ -82,22 +91,17 @@ def test_examples_protocol_mqtt_wss(env, extra_data):
|
||||
print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
|
||||
if not event_client_connected.wait(timeout=30):
|
||||
raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
|
||||
dut1.start_app()
|
||||
try:
|
||||
ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
|
||||
ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0]
|
||||
print('Connected to AP with IP: {}'.format(ip_address))
|
||||
except DUT.ExpectTimeout:
|
||||
except pexpect.TIMEOUT:
|
||||
print('ENV_TEST_FAILURE: Cannot connect to AP')
|
||||
raise
|
||||
print('Checking py-client received msg published from esp...')
|
||||
if not event_client_received_correct.wait(timeout=30):
|
||||
raise ValueError('Wrong data received, msg log: {}'.format(message_log))
|
||||
print('Checking esp-client received msg published from py-client...')
|
||||
dut1.expect(re.compile(r'DATA=data_to_esp32'), timeout=30)
|
||||
dut.expect(r'DATA=data_to_esp32', timeout=30)
|
||||
finally:
|
||||
event_stop_client.set()
|
||||
thread1.join()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_examples_protocol_mqtt_wss()
|
@ -18,6 +18,7 @@ netifaces
|
||||
rangehttpserver
|
||||
dbus-python; sys_platform == 'linux'
|
||||
protobuf
|
||||
paho-mqtt
|
||||
|
||||
# for twai tests, communicate with socket can device (e.g. Canable)
|
||||
python-can
|
||||
|
@ -33,7 +33,3 @@ SimpleWebSocketServer
|
||||
|
||||
# py_debug_backend
|
||||
debug_backend
|
||||
|
||||
# examples/protocols/mqtt
|
||||
# tools/test_apps/protocols/mqtt
|
||||
paho-mqtt
|
||||
|
@ -1,5 +1,8 @@
|
||||
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
|
||||
# SPDX-License-Identifier: Unlicense OR CC0-1.0
|
||||
from __future__ import print_function, unicode_literals
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
@ -10,21 +13,25 @@ import string
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
from itertools import count
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
import ttfw_idf
|
||||
import pytest
|
||||
from common_test_methods import get_host_ip4_by_dest_ip
|
||||
from pytest_embedded import Dut
|
||||
from pytest_embedded_qemu.dut import QemuDut
|
||||
|
||||
DEFAULT_MSG_SIZE = 16
|
||||
|
||||
|
||||
def _path(f):
|
||||
def _path(f): # type: (str) -> str
|
||||
return os.path.join(os.path.dirname(os.path.realpath(__file__)),f)
|
||||
|
||||
|
||||
def set_server_cert_cn(ip):
|
||||
def set_server_cert_cn(ip): # type: (str) -> None
|
||||
arg_list = [
|
||||
['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),'-subj', '/CN={}'.format(ip), '-new'],
|
||||
['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'),
|
||||
@ -36,8 +43,13 @@ def set_server_cert_cn(ip):
|
||||
|
||||
# Publisher class creating a python client to send/receive published data from esp-mqtt client
|
||||
class MqttPublisher:
|
||||
event_client_connected = Event()
|
||||
event_client_got_all = Event()
|
||||
expected_data = ''
|
||||
published = 0
|
||||
|
||||
def __init__(self, dut, transport, qos, repeat, published, queue, publish_cfg, log_details=False):
|
||||
def __init__(self, dut, transport,
|
||||
qos, repeat, published, queue, publish_cfg, log_details=False): # type: (MqttPublisher, Dut, str, int, int, int, int, dict, bool) -> None
|
||||
# instance variables used as parameters of the publish test
|
||||
self.event_stop_client = Event()
|
||||
self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
|
||||
@ -58,11 +70,11 @@ class MqttPublisher:
|
||||
MqttPublisher.event_client_got_all.clear()
|
||||
MqttPublisher.expected_data = self.sample_string * self.repeat
|
||||
|
||||
def print_details(self, text):
|
||||
def print_details(self, text): # type: (str) -> None
|
||||
if self.log_details:
|
||||
print(text)
|
||||
|
||||
def mqtt_client_task(self, client, lock):
|
||||
def mqtt_client_task(self, client, lock): # type: (MqttPublisher, mqtt.Client, Lock) -> None
|
||||
while not self.event_stop_client.is_set():
|
||||
with lock:
|
||||
client.loop()
|
||||
@ -70,12 +82,12 @@ class MqttPublisher:
|
||||
|
||||
# The callback for when the client receives a CONNACK response from the server (needs to be static)
|
||||
@staticmethod
|
||||
def on_connect(_client, _userdata, _flags, _rc):
|
||||
def on_connect(_client, _userdata, _flags, _rc): # type: (mqtt.Client, tuple, bool, str) -> None
|
||||
MqttPublisher.event_client_connected.set()
|
||||
|
||||
# The callback for when a PUBLISH message is received from the server (needs to be static)
|
||||
@staticmethod
|
||||
def on_message(client, userdata, msg):
|
||||
def on_message(client, userdata, msg): # type: (mqtt.Client, int, mqtt.client.MQTTMessage) -> None
|
||||
payload = msg.payload.decode()
|
||||
if payload == MqttPublisher.expected_data:
|
||||
userdata += 1
|
||||
@ -83,7 +95,7 @@ class MqttPublisher:
|
||||
if userdata == MqttPublisher.published:
|
||||
MqttPublisher.event_client_got_all.set()
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self): # type: (MqttPublisher) -> None
|
||||
|
||||
qos = self.publish_cfg['qos']
|
||||
queue = self.publish_cfg['queue']
|
||||
@ -100,6 +112,7 @@ class MqttPublisher:
|
||||
self.client = mqtt.Client(transport='websockets')
|
||||
else:
|
||||
self.client = mqtt.Client()
|
||||
assert self.client is not None
|
||||
self.client.on_connect = MqttPublisher.on_connect
|
||||
self.client.on_message = MqttPublisher.on_message
|
||||
self.client.user_data_set(0)
|
||||
@ -137,7 +150,8 @@ class MqttPublisher:
|
||||
self.event_stop_client.set()
|
||||
thread1.join()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: (MqttPublisher, str, str, dict) -> None
|
||||
assert self.client is not None
|
||||
self.client.disconnect()
|
||||
self.event_stop_client.clear()
|
||||
|
||||
@ -145,7 +159,7 @@ class MqttPublisher:
|
||||
# Simple server for mqtt over TLS connection
|
||||
class TlsServer:
|
||||
|
||||
def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False):
|
||||
def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False): # type: (TlsServer, int, bool, bool, bool) -> None
|
||||
self.port = port
|
||||
self.socket = socket.socket()
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
@ -153,11 +167,9 @@ class TlsServer:
|
||||
self.shutdown = Event()
|
||||
self.client_cert = client_cert
|
||||
self.refuse_connection = refuse_connection
|
||||
self.ssl_error = None
|
||||
self.use_alpn = use_alpn
|
||||
self.negotiated_protocol = None
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self): # type: (TlsServer) -> TlsServer
|
||||
try:
|
||||
self.socket.bind(('', self.port))
|
||||
except socket.error as e:
|
||||
@ -170,20 +182,21 @@ class TlsServer:
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: (TlsServer, str, str, str) -> None
|
||||
self.shutdown.set()
|
||||
self.server_thread.join()
|
||||
self.socket.close()
|
||||
if (self.conn is not None):
|
||||
self.conn.close()
|
||||
|
||||
def get_last_ssl_error(self):
|
||||
def get_last_ssl_error(self): # type: (TlsServer) -> str
|
||||
return self.ssl_error
|
||||
|
||||
@typing.no_type_check
|
||||
def get_negotiated_protocol(self):
|
||||
return self.negotiated_protocol
|
||||
|
||||
def run_server(self):
|
||||
def run_server(self) -> None:
|
||||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
if self.client_cert:
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
@ -201,11 +214,10 @@ class TlsServer:
|
||||
print(' - negotiated_protocol: {}'.format(self.negotiated_protocol))
|
||||
self.handle_conn()
|
||||
except ssl.SSLError as e:
|
||||
self.conn = None
|
||||
self.ssl_error = str(e)
|
||||
print(' - SSLError: {}'.format(str(e)))
|
||||
|
||||
def handle_conn(self):
|
||||
def handle_conn(self) -> None:
|
||||
while not self.shutdown.is_set():
|
||||
r,w,e = select.select([self.conn], [], [], 1)
|
||||
try:
|
||||
@ -216,7 +228,7 @@ class TlsServer:
|
||||
print(' - error: {}'.format(err))
|
||||
raise
|
||||
|
||||
def process_mqtt_connect(self):
|
||||
def process_mqtt_connect(self) -> None:
|
||||
try:
|
||||
data = bytearray(self.conn.recv(1024))
|
||||
message = ''.join(format(x, '02x') for x in data)
|
||||
@ -235,22 +247,22 @@ class TlsServer:
|
||||
self.shutdown.set()
|
||||
|
||||
|
||||
def connection_tests(dut, cases, dut_ip):
|
||||
def connection_tests(dut, cases, dut_ip): # type: (Dut, dict, str) -> None
|
||||
ip = get_host_ip4_by_dest_ip(dut_ip)
|
||||
set_server_cert_cn(ip)
|
||||
server_port = 2222
|
||||
|
||||
def teardown_connection_suite():
|
||||
def teardown_connection_suite() -> None:
|
||||
dut.write('conn teardown 0 0')
|
||||
|
||||
def start_connection_case(case, desc):
|
||||
def start_connection_case(case, desc): # type: (str, str) -> Any
|
||||
print('Starting {}: {}'.format(case, desc))
|
||||
case_id = cases[case]
|
||||
dut.write('conn {} {} {}'.format(ip, server_port, case_id))
|
||||
dut.expect('Test case:{} started'.format(case_id))
|
||||
return case_id
|
||||
|
||||
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
|
||||
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
|
||||
# All these cases connect to the server with no server verification or with server only verification
|
||||
with TlsServer(server_port):
|
||||
test_nr = start_connection_case(case, 'default server - expect to connect normally')
|
||||
@ -266,13 +278,13 @@ def connection_tests(dut, cases, dut_ip):
|
||||
if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' not in s.get_last_ssl_error():
|
||||
raise RuntimeError('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
|
||||
|
||||
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
|
||||
for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
|
||||
# These cases connect to server with both server and client verification (client key might be password protected)
|
||||
with TlsServer(server_port, client_cert=True):
|
||||
test_nr = start_connection_case(case, 'server with client verification - expect to connect normally')
|
||||
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
|
||||
|
||||
case = 'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
|
||||
case = 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
|
||||
with TlsServer(server_port) as s:
|
||||
test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error')
|
||||
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
|
||||
@ -280,7 +292,7 @@ def connection_tests(dut, cases, dut_ip):
|
||||
if 'alert unknown ca' not in s.get_last_ssl_error():
|
||||
raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
|
||||
|
||||
case = 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
|
||||
case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
|
||||
with TlsServer(server_port, client_cert=True) as s:
|
||||
test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error')
|
||||
dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
|
||||
@ -288,13 +300,13 @@ def connection_tests(dut, cases, dut_ip):
|
||||
if 'CERTIFICATE_VERIFY_FAILED' not in s.get_last_ssl_error():
|
||||
raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
|
||||
|
||||
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
|
||||
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
|
||||
with TlsServer(server_port, use_alpn=True) as s:
|
||||
test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol')
|
||||
dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
|
||||
if case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None:
|
||||
if case == 'EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None:
|
||||
print(' - client with alpn off, no negotiated protocol: OK')
|
||||
elif case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt':
|
||||
elif case == 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt':
|
||||
print(' - client with alpn on, negotiated protocol resolved: OK')
|
||||
else:
|
||||
raise Exception('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol()))
|
||||
@ -302,19 +314,19 @@ def connection_tests(dut, cases, dut_ip):
|
||||
teardown_connection_suite()
|
||||
|
||||
|
||||
@ttfw_idf.idf_custom_test(env_tag='ethernet_router', group='test-apps')
|
||||
def test_app_protocol_mqtt_publish_connect(env, extra_data):
|
||||
@pytest.mark.esp32
|
||||
@pytest.mark.ethernet
|
||||
def test_app_protocol_mqtt_publish_connect(dut: Dut) -> None:
|
||||
"""
|
||||
steps:
|
||||
1. join AP
|
||||
2. connect to uri specified in the config
|
||||
3. send and receive data
|
||||
"""
|
||||
dut1 = env.get_dut('mqtt_publish_connect_test', 'tools/test_apps/protocols/mqtt/publish_connect_test')
|
||||
# check and log bin size
|
||||
binary_file = os.path.join(dut1.app.binary_path, 'mqtt_publish_connect_test.bin')
|
||||
binary_file = os.path.join(dut.app.binary_path, 'mqtt_publish_connect_test.bin')
|
||||
bin_size = os.path.getsize(binary_file)
|
||||
ttfw_idf.log_performance('mqtt_publish_connect_test_bin_size', '{}KB'.format(bin_size // 1024))
|
||||
logging.info('[Performance][mqtt_publish_connect_test_bin_size]: %s KB', bin_size // 1024)
|
||||
|
||||
# Look for test case symbolic names and publish configs
|
||||
cases = {}
|
||||
@ -322,25 +334,24 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
|
||||
try:
|
||||
|
||||
# Get connection test cases configuration: symbolic names for test cases
|
||||
for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT',
|
||||
'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT',
|
||||
'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
|
||||
'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
|
||||
'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
|
||||
'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
|
||||
'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
|
||||
'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
|
||||
cases[case] = dut1.app.get_sdkconfig()[case]
|
||||
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT',
|
||||
'EXAMPLE_CONNECT_CASE_SERVER_CERT',
|
||||
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
|
||||
'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
|
||||
'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
|
||||
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
|
||||
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
|
||||
'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
|
||||
cases[case] = dut.app.sdkconfig.get(case)
|
||||
except Exception:
|
||||
print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
|
||||
raise
|
||||
|
||||
dut1.start_app()
|
||||
esp_ip = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
|
||||
esp_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30).group(1).decode()
|
||||
print('Got IP={}'.format(esp_ip))
|
||||
|
||||
if not os.getenv('MQTT_SKIP_CONNECT_TEST'):
|
||||
connection_tests(dut1,cases,esp_ip)
|
||||
connection_tests(dut,cases,esp_ip)
|
||||
|
||||
#
|
||||
# start publish tests only if enabled in the environment (for weekend tests only)
|
||||
@ -349,27 +360,28 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
|
||||
|
||||
# Get publish test configuration
|
||||
try:
|
||||
def get_host_port_from_dut(dut1, config_option):
|
||||
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()[config_option])
|
||||
@typing.no_type_check
|
||||
def get_host_port_from_dut(dut, config_option):
|
||||
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get(config_option))
|
||||
if value is None:
|
||||
return None, None
|
||||
return value.group(1), int(value.group(2))
|
||||
|
||||
publish_cfg['publish_topic'] = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_SUBSCRIBE_TOPIC'].replace('"','')
|
||||
publish_cfg['subscribe_topic'] = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_PUBLISH_TOPIC'].replace('"','')
|
||||
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_SSL_URI')
|
||||
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_TCP_URI')
|
||||
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WS_URI')
|
||||
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WSS_URI')
|
||||
publish_cfg['publish_topic'] = dut.app.sdkconfig.get('EXAMPLE_SUBSCRIBE_TOPIC').replace('"','')
|
||||
publish_cfg['subscribe_topic'] = dut.app.sdkconfig.get('EXAMPLE_PUBLISH_TOPIC').replace('"','')
|
||||
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI')
|
||||
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI')
|
||||
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WS_URI')
|
||||
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI')
|
||||
|
||||
except Exception:
|
||||
print('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
|
||||
raise
|
||||
|
||||
def start_publish_case(transport, qos, repeat, published, queue):
|
||||
def start_publish_case(transport, qos, repeat, published, queue): # type: (str, int, int, int, int) -> None
|
||||
print('Starting Publish test: transport:{}, qos:{}, nr_of_msgs:{}, msg_size:{}, enqueue:{}'
|
||||
.format(transport, qos, published, repeat * DEFAULT_MSG_SIZE, queue))
|
||||
with MqttPublisher(dut1, transport, qos, repeat, published, queue, publish_cfg):
|
||||
with MqttPublisher(dut, transport, qos, repeat, published, queue, publish_cfg):
|
||||
pass
|
||||
|
||||
# Initialize message sizes and repeat counts (if defined in the environment)
|
||||
@ -378,7 +390,7 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
|
||||
# Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x}
|
||||
env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
|
||||
if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']):
|
||||
messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']})
|
||||
messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore
|
||||
continue
|
||||
break
|
||||
if not messages: # No message sizes present in the env - set defaults
|
||||
@ -400,4 +412,4 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_app_protocol_mqtt_publish_connect(dut=ttfw_idf.ESP32QEMUDUT if sys.argv[1:] == ['qemu'] else ttfw_idf.ESP32DUT)
|
||||
test_app_protocol_mqtt_publish_connect(dut=QemuDut if sys.argv[1:] == ['qemu'] else Dut)
|
Loading…
Reference in New Issue
Block a user