style: format python files with isort and double-quote-string-fixer

This commit is contained in:
Fu Hanxi 2021-01-26 10:49:01 +08:00
parent dc8402ea61
commit 0146f258d7
276 changed files with 8241 additions and 8162 deletions

View File

@ -16,23 +16,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function, division from __future__ import division, print_function
import argparse import argparse
import os
import sys
import binascii import binascii
import tempfile
import collections import collections
import os
import struct import struct
import sys
import tempfile
try: try:
from parttool import PartitionName, PartitionType, ParttoolTarget, PARTITION_TABLE_OFFSET from parttool import PARTITION_TABLE_OFFSET, PartitionName, PartitionType, ParttoolTarget
except ImportError: except ImportError:
COMPONENTS_PATH = os.path.expandvars(os.path.join("$IDF_PATH", "components")) COMPONENTS_PATH = os.path.expandvars(os.path.join('$IDF_PATH', 'components'))
PARTTOOL_DIR = os.path.join(COMPONENTS_PATH, "partition_table") PARTTOOL_DIR = os.path.join(COMPONENTS_PATH, 'partition_table')
sys.path.append(PARTTOOL_DIR) sys.path.append(PARTTOOL_DIR)
from parttool import PartitionName, PartitionType, ParttoolTarget, PARTITION_TABLE_OFFSET from parttool import PARTITION_TABLE_OFFSET, PartitionName, PartitionType, ParttoolTarget
__version__ = '2.0' __version__ = '2.0'
@ -48,7 +49,7 @@ def status(msg):
class OtatoolTarget(): class OtatoolTarget():
OTADATA_PARTITION = PartitionType("data", "ota") OTADATA_PARTITION = PartitionType('data', 'ota')
def __init__(self, port=None, baud=None, partition_table_offset=PARTITION_TABLE_OFFSET, partition_table_file=None, def __init__(self, port=None, baud=None, partition_table_offset=PARTITION_TABLE_OFFSET, partition_table_file=None,
spi_flash_sec_size=SPI_FLASH_SEC_SIZE, esptool_args=[], esptool_write_args=[], spi_flash_sec_size=SPI_FLASH_SEC_SIZE, esptool_args=[], esptool_write_args=[],
@ -61,14 +62,14 @@ class OtatoolTarget():
temp_file.close() temp_file.close()
try: try:
self.target.read_partition(OtatoolTarget.OTADATA_PARTITION, temp_file.name) self.target.read_partition(OtatoolTarget.OTADATA_PARTITION, temp_file.name)
with open(temp_file.name, "rb") as f: with open(temp_file.name, 'rb') as f:
self.otadata = f.read() self.otadata = f.read()
finally: finally:
os.unlink(temp_file.name) os.unlink(temp_file.name)
def _check_otadata_partition(self): def _check_otadata_partition(self):
if not self.otadata: if not self.otadata:
raise Exception("No otadata partition found") raise Exception('No otadata partition found')
def erase_otadata(self): def erase_otadata(self):
self._check_otadata_partition() self._check_otadata_partition()
@ -77,7 +78,7 @@ class OtatoolTarget():
def _get_otadata_info(self): def _get_otadata_info(self):
info = [] info = []
otadata_info = collections.namedtuple("otadata_info", "seq crc") otadata_info = collections.namedtuple('otadata_info', 'seq crc')
for i in range(2): for i in range(2):
start = i * (self.spi_flash_sec_size >> 1) start = i * (self.spi_flash_sec_size >> 1)
@ -94,7 +95,7 @@ class OtatoolTarget():
def _get_partition_id_from_ota_id(self, ota_id): def _get_partition_id_from_ota_id(self, ota_id):
if isinstance(ota_id, int): if isinstance(ota_id, int):
return PartitionType("app", "ota_" + str(ota_id)) return PartitionType('app', 'ota_' + str(ota_id))
else: else:
return PartitionName(ota_id) return PartitionName(ota_id)
@ -106,7 +107,7 @@ class OtatoolTarget():
def is_otadata_info_valid(status): def is_otadata_info_valid(status):
seq = status.seq % (1 << 32) seq = status.seq % (1 << 32)
crc = hex(binascii.crc32(struct.pack("I", seq), 0xFFFFFFFF) % (1 << 32)) crc = hex(binascii.crc32(struct.pack('I', seq), 0xFFFFFFFF) % (1 << 32))
return seq < (int('0xFFFFFFFF', 16) % (1 << 32)) and status.crc == crc return seq < (int('0xFFFFFFFF', 16) % (1 << 32)) and status.crc == crc
partition_table = self.target.partition_table partition_table = self.target.partition_table
@ -124,7 +125,7 @@ class OtatoolTarget():
ota_partitions = sorted(ota_partitions, key=lambda p: p.subtype) ota_partitions = sorted(ota_partitions, key=lambda p: p.subtype)
if not ota_partitions: if not ota_partitions:
raise Exception("No ota app partitions found") raise Exception('No ota app partitions found')
# Look for the app partition to switch to # Look for the app partition to switch to
ota_partition_next = None ota_partition_next = None
@ -137,7 +138,7 @@ class OtatoolTarget():
ota_partition_next = list(ota_partition_next)[0] ota_partition_next = list(ota_partition_next)[0]
except IndexError: except IndexError:
raise Exception("Partition to switch to not found") raise Exception('Partition to switch to not found')
otadata_info = self._get_otadata_info() otadata_info = self._get_otadata_info()
@ -177,15 +178,15 @@ class OtatoolTarget():
ota_seq_next = target_seq ota_seq_next = target_seq
# Create binary data from computed values # Create binary data from computed values
ota_seq_next = struct.pack("I", ota_seq_next) ota_seq_next = struct.pack('I', ota_seq_next)
ota_seq_crc_next = binascii.crc32(ota_seq_next, 0xFFFFFFFF) % (1 << 32) ota_seq_crc_next = binascii.crc32(ota_seq_next, 0xFFFFFFFF) % (1 << 32)
ota_seq_crc_next = struct.pack("I", ota_seq_crc_next) ota_seq_crc_next = struct.pack('I', ota_seq_crc_next)
temp_file = tempfile.NamedTemporaryFile(delete=False) temp_file = tempfile.NamedTemporaryFile(delete=False)
temp_file.close() temp_file.close()
try: try:
with open(temp_file.name, "wb") as otadata_next_file: with open(temp_file.name, 'wb') as otadata_next_file:
start = (1 if otadata_compute_base == 0 else 0) * (self.spi_flash_sec_size >> 1) start = (1 if otadata_compute_base == 0 else 0) * (self.spi_flash_sec_size >> 1)
otadata_next_file.write(self.otadata) otadata_next_file.write(self.otadata)
@ -217,14 +218,14 @@ def _read_otadata(target):
otadata_info = target._get_otadata_info() otadata_info = target._get_otadata_info()
print(" {:8s} \t {:8s} | \t {:8s} \t {:8s}".format("OTA_SEQ", "CRC", "OTA_SEQ", "CRC")) print(' {:8s} \t {:8s} | \t {:8s} \t {:8s}'.format('OTA_SEQ', 'CRC', 'OTA_SEQ', 'CRC'))
print("Firmware: 0x{:8x} \t0x{:8x} | \t0x{:8x} \t 0x{:8x}".format(otadata_info[0].seq, otadata_info[0].crc, print('Firmware: 0x{:8x} \t0x{:8x} | \t0x{:8x} \t 0x{:8x}'.format(otadata_info[0].seq, otadata_info[0].crc,
otadata_info[1].seq, otadata_info[1].crc)) otadata_info[1].seq, otadata_info[1].crc))
def _erase_otadata(target): def _erase_otadata(target):
target.erase_otadata() target.erase_otadata()
status("Erased ota_data partition contents") status('Erased ota_data partition contents')
def _switch_ota_partition(target, ota_id): def _switch_ota_partition(target, ota_id):
@ -233,68 +234,68 @@ def _switch_ota_partition(target, ota_id):
def _read_ota_partition(target, ota_id, output): def _read_ota_partition(target, ota_id, output):
target.read_ota_partition(ota_id, output) target.read_ota_partition(ota_id, output)
status("Read ota partition contents to file {}".format(output)) status('Read ota partition contents to file {}'.format(output))
def _write_ota_partition(target, ota_id, input): def _write_ota_partition(target, ota_id, input):
target.write_ota_partition(ota_id, input) target.write_ota_partition(ota_id, input)
status("Written contents of file {} to ota partition".format(input)) status('Written contents of file {} to ota partition'.format(input))
def _erase_ota_partition(target, ota_id): def _erase_ota_partition(target, ota_id):
target.erase_ota_partition(ota_id) target.erase_ota_partition(ota_id)
status("Erased contents of ota partition") status('Erased contents of ota partition')
def main(): def main():
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
print("WARNING: Support for Python 2 is deprecated and will be removed in future versions.", file=sys.stderr) print('WARNING: Support for Python 2 is deprecated and will be removed in future versions.', file=sys.stderr)
elif sys.version_info[0] == 3 and sys.version_info[1] < 6: elif sys.version_info[0] == 3 and sys.version_info[1] < 6:
print("WARNING: Python 3 versions older than 3.6 are not supported.", file=sys.stderr) print('WARNING: Python 3 versions older than 3.6 are not supported.', file=sys.stderr)
global quiet global quiet
parser = argparse.ArgumentParser("ESP-IDF OTA Partitions Tool") parser = argparse.ArgumentParser('ESP-IDF OTA Partitions Tool')
parser.add_argument("--quiet", "-q", help="suppress stderr messages", action="store_true") parser.add_argument('--quiet', '-q', help='suppress stderr messages', action='store_true')
parser.add_argument("--esptool-args", help="additional main arguments for esptool", nargs="+") parser.add_argument('--esptool-args', help='additional main arguments for esptool', nargs='+')
parser.add_argument("--esptool-write-args", help="additional subcommand arguments for esptool write_flash", nargs="+") parser.add_argument('--esptool-write-args', help='additional subcommand arguments for esptool write_flash', nargs='+')
parser.add_argument("--esptool-read-args", help="additional subcommand arguments for esptool read_flash", nargs="+") parser.add_argument('--esptool-read-args', help='additional subcommand arguments for esptool read_flash', nargs='+')
parser.add_argument("--esptool-erase-args", help="additional subcommand arguments for esptool erase_region", nargs="+") parser.add_argument('--esptool-erase-args', help='additional subcommand arguments for esptool erase_region', nargs='+')
# There are two possible sources for the partition table: a device attached to the host # There are two possible sources for the partition table: a device attached to the host
# or a partition table CSV/binary file. These sources are mutually exclusive. # or a partition table CSV/binary file. These sources are mutually exclusive.
parser.add_argument("--port", "-p", help="port where the device to read the partition table from is attached") parser.add_argument('--port', '-p', help='port where the device to read the partition table from is attached')
parser.add_argument("--baud", "-b", help="baudrate to use", type=int) parser.add_argument('--baud', '-b', help='baudrate to use', type=int)
parser.add_argument("--partition-table-offset", "-o", help="offset to read the partition table from", type=str) parser.add_argument('--partition-table-offset', '-o', help='offset to read the partition table from', type=str)
parser.add_argument("--partition-table-file", "-f", help="file (CSV/binary) to read the partition table from; \ parser.add_argument('--partition-table-file', '-f', help='file (CSV/binary) to read the partition table from; \
overrides device attached to specified port as the partition table source when defined") overrides device attached to specified port as the partition table source when defined')
subparsers = parser.add_subparsers(dest="operation", help="run otatool -h for additional help") subparsers = parser.add_subparsers(dest='operation', help='run otatool -h for additional help')
spi_flash_sec_size = argparse.ArgumentParser(add_help=False) spi_flash_sec_size = argparse.ArgumentParser(add_help=False)
spi_flash_sec_size.add_argument("--spi-flash-sec-size", help="value of SPI_FLASH_SEC_SIZE macro", type=str) spi_flash_sec_size.add_argument('--spi-flash-sec-size', help='value of SPI_FLASH_SEC_SIZE macro', type=str)
# Specify the supported operations # Specify the supported operations
subparsers.add_parser("read_otadata", help="read otadata partition", parents=[spi_flash_sec_size]) subparsers.add_parser('read_otadata', help='read otadata partition', parents=[spi_flash_sec_size])
subparsers.add_parser("erase_otadata", help="erase otadata partition") subparsers.add_parser('erase_otadata', help='erase otadata partition')
slot_or_name_parser = argparse.ArgumentParser(add_help=False) slot_or_name_parser = argparse.ArgumentParser(add_help=False)
slot_or_name_parser_args = slot_or_name_parser.add_mutually_exclusive_group() slot_or_name_parser_args = slot_or_name_parser.add_mutually_exclusive_group()
slot_or_name_parser_args.add_argument("--slot", help="slot number of the ota partition", type=int) slot_or_name_parser_args.add_argument('--slot', help='slot number of the ota partition', type=int)
slot_or_name_parser_args.add_argument("--name", help="name of the ota partition") slot_or_name_parser_args.add_argument('--name', help='name of the ota partition')
subparsers.add_parser("switch_ota_partition", help="switch otadata partition", parents=[slot_or_name_parser, spi_flash_sec_size]) subparsers.add_parser('switch_ota_partition', help='switch otadata partition', parents=[slot_or_name_parser, spi_flash_sec_size])
read_ota_partition_subparser = subparsers.add_parser("read_ota_partition", help="read contents of an ota partition", parents=[slot_or_name_parser]) read_ota_partition_subparser = subparsers.add_parser('read_ota_partition', help='read contents of an ota partition', parents=[slot_or_name_parser])
read_ota_partition_subparser.add_argument("--output", help="file to write the contents of the ota partition to") read_ota_partition_subparser.add_argument('--output', help='file to write the contents of the ota partition to')
write_ota_partition_subparser = subparsers.add_parser("write_ota_partition", help="write contents to an ota partition", parents=[slot_or_name_parser]) write_ota_partition_subparser = subparsers.add_parser('write_ota_partition', help='write contents to an ota partition', parents=[slot_or_name_parser])
write_ota_partition_subparser.add_argument("--input", help="file whose contents to write to the ota partition") write_ota_partition_subparser.add_argument('--input', help='file whose contents to write to the ota partition')
subparsers.add_parser("erase_ota_partition", help="erase contents of an ota partition", parents=[slot_or_name_parser]) subparsers.add_parser('erase_ota_partition', help='erase contents of an ota partition', parents=[slot_or_name_parser])
args = parser.parse_args() args = parser.parse_args()
@ -309,34 +310,34 @@ def main():
target_args = {} target_args = {}
if args.port: if args.port:
target_args["port"] = args.port target_args['port'] = args.port
if args.partition_table_file: if args.partition_table_file:
target_args["partition_table_file"] = args.partition_table_file target_args['partition_table_file'] = args.partition_table_file
if args.partition_table_offset: if args.partition_table_offset:
target_args["partition_table_offset"] = int(args.partition_table_offset, 0) target_args['partition_table_offset'] = int(args.partition_table_offset, 0)
try: try:
if args.spi_flash_sec_size: if args.spi_flash_sec_size:
target_args["spi_flash_sec_size"] = int(args.spi_flash_sec_size, 0) target_args['spi_flash_sec_size'] = int(args.spi_flash_sec_size, 0)
except AttributeError: except AttributeError:
pass pass
if args.esptool_args: if args.esptool_args:
target_args["esptool_args"] = args.esptool_args target_args['esptool_args'] = args.esptool_args
if args.esptool_write_args: if args.esptool_write_args:
target_args["esptool_write_args"] = args.esptool_write_args target_args['esptool_write_args'] = args.esptool_write_args
if args.esptool_read_args: if args.esptool_read_args:
target_args["esptool_read_args"] = args.esptool_read_args target_args['esptool_read_args'] = args.esptool_read_args
if args.esptool_erase_args: if args.esptool_erase_args:
target_args["esptool_erase_args"] = args.esptool_erase_args target_args['esptool_erase_args'] = args.esptool_erase_args
if args.baud: if args.baud:
target_args["baud"] = args.baud target_args['baud'] = args.baud
target = OtatoolTarget(**target_args) target = OtatoolTarget(**target_args)
@ -347,10 +348,10 @@ def main():
try: try:
if args.name is not None: if args.name is not None:
ota_id = ["name"] ota_id = ['name']
else: else:
if args.slot is not None: if args.slot is not None:
ota_id = ["slot"] ota_id = ['slot']
except AttributeError: except AttributeError:
pass pass
@ -358,8 +359,8 @@ def main():
'read_otadata':(_read_otadata, []), 'read_otadata':(_read_otadata, []),
'erase_otadata':(_erase_otadata, []), 'erase_otadata':(_erase_otadata, []),
'switch_ota_partition':(_switch_ota_partition, ota_id), 'switch_ota_partition':(_switch_ota_partition, ota_id),
'read_ota_partition':(_read_ota_partition, ["output"] + ota_id), 'read_ota_partition':(_read_ota_partition, ['output'] + ota_id),
'write_ota_partition':(_write_ota_partition, ["input"] + ota_id), 'write_ota_partition':(_write_ota_partition, ['input'] + ota_id),
'erase_ota_partition':(_erase_ota_partition, ota_id) 'erase_ota_partition':(_erase_ota_partition, ota_id)
} }

View File

@ -17,18 +17,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function, division from __future__ import division, print_function
import argparse import argparse
import hashlib
import os import os
import re import re
import sys import sys
import hashlib
__version__ = '1.0' __version__ = '1.0'
quiet = False quiet = False
max_blk_len = 256 max_blk_len = 256
idf_target = "esp32" idf_target = 'esp32'
copyright = '''// Copyright 2017-2020 Espressif Systems (Shanghai) PTE LTD copyright = '''// Copyright 2017-2020 Espressif Systems (Shanghai) PTE LTD
// //
@ -61,7 +62,7 @@ def critical(msg):
class FuseTable(list): class FuseTable(list):
def __init__(self): def __init__(self):
super(FuseTable, self).__init__(self) super(FuseTable, self).__init__(self)
self.md5_digest_table = "" self.md5_digest_table = ''
@classmethod @classmethod
def from_csv(cls, csv_contents): def from_csv(cls, csv_contents):
@ -77,14 +78,14 @@ class FuseTable(list):
for line_no in range(len(lines)): for line_no in range(len(lines)):
line = expand_vars(lines[line_no]).strip() line = expand_vars(lines[line_no]).strip()
if line.startswith("#") or len(line) == 0: if line.startswith('#') or len(line) == 0:
continue continue
try: try:
res.append(FuseDefinition.from_csv(line)) res.append(FuseDefinition.from_csv(line))
except InputError as e: except InputError as e:
raise InputError("Error at line %d: %s" % (line_no + 1, e)) raise InputError('Error at line %d: %s' % (line_no + 1, e))
except Exception: except Exception:
critical("Unexpected error parsing line %d: %s" % (line_no + 1, line)) critical('Unexpected error parsing line %d: %s' % (line_no + 1, line))
raise raise
# fix up missing bit_start # fix up missing bit_start
@ -102,9 +103,9 @@ class FuseTable(list):
# fix up missing field_name # fix up missing field_name
last_field = None last_field = None
for e in res: for e in res:
if e.field_name == "" and last_field is None: if e.field_name == '' and last_field is None:
raise InputError("Error at line %d: %s missing field name" % (line_no + 1, e)) raise InputError('Error at line %d: %s missing field name' % (line_no + 1, e))
elif e.field_name == "" and last_field is not None: elif e.field_name == '' and last_field is not None:
e.field_name = last_field.field_name e.field_name = last_field.field_name
last_field = e last_field = e
@ -136,12 +137,12 @@ class FuseTable(list):
fl_error = False fl_error = False
for p in self: for p in self:
field_name = p.field_name + p.group field_name = p.field_name + p.group
if field_name != "" and len(duplicates.intersection([field_name])) != 0: if field_name != '' and len(duplicates.intersection([field_name])) != 0:
fl_error = True fl_error = True
print("Field at %s, %s, %s, %s have dublicate field_name" % print('Field at %s, %s, %s, %s have dublicate field_name' %
(p.field_name, p.efuse_block, p.bit_start, p.bit_count)) (p.field_name, p.efuse_block, p.bit_start, p.bit_count))
if fl_error is True: if fl_error is True:
raise InputError("Field names must be unique") raise InputError('Field names must be unique')
def verify(self, type_table=None): def verify(self, type_table=None):
for p in self: for p in self:
@ -153,7 +154,7 @@ class FuseTable(list):
last = None last = None
for p in sorted(self, key=lambda x:(x.efuse_block, x.bit_start)): for p in sorted(self, key=lambda x:(x.efuse_block, x.bit_start)):
if last is not None and last.efuse_block == p.efuse_block and p.bit_start < last.bit_start + last.bit_count: if last is not None and last.efuse_block == p.efuse_block and p.bit_start < last.bit_start + last.bit_count:
raise InputError("Field at %s, %s, %s, %s overlaps %s, %s, %s, %s" % raise InputError('Field at %s, %s, %s, %s overlaps %s, %s, %s, %s' %
(p.field_name, p.efuse_block, p.bit_start, p.bit_count, (p.field_name, p.efuse_block, p.bit_start, p.bit_count,
last.field_name, last.efuse_block, last.bit_start, last.bit_count)) last.field_name, last.efuse_block, last.bit_start, last.bit_count))
last = p last = p
@ -161,7 +162,7 @@ class FuseTable(list):
def calc_md5(self): def calc_md5(self):
txt_table = '' txt_table = ''
for p in self: for p in self:
txt_table += "%s %s %d %s %s" % (p.field_name, p.efuse_block, p.bit_start, str(p.get_bit_count()), p.comment) + "\n" txt_table += '%s %s %d %s %s' % (p.field_name, p.efuse_block, p.bit_start, str(p.get_bit_count()), p.comment) + '\n'
self.md5_digest_table = hashlib.md5(txt_table.encode('utf-8')).hexdigest() self.md5_digest_table = hashlib.md5(txt_table.encode('utf-8')).hexdigest()
def show_range_used_bits(self): def show_range_used_bits(self):
@ -169,9 +170,9 @@ class FuseTable(list):
rows = '' rows = ''
rows += 'Sorted efuse table:\n' rows += 'Sorted efuse table:\n'
num = 1 num = 1
rows += "{0} \t{1:<30} \t{2} \t{3} \t{4}".format("#", "field_name", "efuse_block", "bit_start", "bit_count") + "\n" rows += '{0} \t{1:<30} \t{2} \t{3} \t{4}'.format('#', 'field_name', 'efuse_block', 'bit_start', 'bit_count') + '\n'
for p in sorted(self, key=lambda x:(x.efuse_block, x.bit_start)): for p in sorted(self, key=lambda x:(x.efuse_block, x.bit_start)):
rows += "{0} \t{1:<30} \t{2} \t{3:^8} \t{4:^8}".format(num, p.field_name, p.efuse_block, p.bit_start, p.bit_count) + "\n" rows += '{0} \t{1:<30} \t{2} \t{3:^8} \t{4:^8}'.format(num, p.field_name, p.efuse_block, p.bit_start, p.bit_count) + '\n'
num += 1 num += 1
rows += '\nUsed bits in efuse table:\n' rows += '\nUsed bits in efuse table:\n'
@ -204,30 +205,30 @@ class FuseTable(list):
def to_header(self, file_name): def to_header(self, file_name):
rows = [copyright] rows = [copyright]
rows += ["#ifdef __cplusplus", rows += ['#ifdef __cplusplus',
'extern "C" {', 'extern "C" {',
"#endif", '#endif',
"", '',
"", '',
"// md5_digest_table " + self.md5_digest_table, '// md5_digest_table ' + self.md5_digest_table,
"// This file was generated from the file " + file_name + ".csv. DO NOT CHANGE THIS FILE MANUALLY.", '// This file was generated from the file ' + file_name + '.csv. DO NOT CHANGE THIS FILE MANUALLY.',
"// If you want to change some fields, you need to change " + file_name + ".csv file", '// If you want to change some fields, you need to change ' + file_name + '.csv file',
"// then run `efuse_common_table` or `efuse_custom_table` command it will generate this file.", '// then run `efuse_common_table` or `efuse_custom_table` command it will generate this file.',
"// To show efuse_table run the command 'show_efuse_table'.", "// To show efuse_table run the command 'show_efuse_table'.",
"", '',
""] '']
last_field_name = '' last_field_name = ''
for p in self: for p in self:
if (p.field_name != last_field_name): if (p.field_name != last_field_name):
rows += ["extern const esp_efuse_desc_t* " + "ESP_EFUSE_" + p.field_name + "[];"] rows += ['extern const esp_efuse_desc_t* ' + 'ESP_EFUSE_' + p.field_name + '[];']
last_field_name = p.field_name last_field_name = p.field_name
rows += ["", rows += ['',
"#ifdef __cplusplus", '#ifdef __cplusplus',
"}", '}',
"#endif", '#endif',
""] '']
return '\n'.join(rows) return '\n'.join(rows)
def to_c_file(self, file_name, debug): def to_c_file(self, file_name, debug):
@ -236,33 +237,33 @@ class FuseTable(list):
'#include "esp_efuse.h"', '#include "esp_efuse.h"',
'#include <assert.h>', '#include <assert.h>',
'#include "' + file_name + '.h"', '#include "' + file_name + '.h"',
"", '',
"// md5_digest_table " + self.md5_digest_table, '// md5_digest_table ' + self.md5_digest_table,
"// This file was generated from the file " + file_name + ".csv. DO NOT CHANGE THIS FILE MANUALLY.", '// This file was generated from the file ' + file_name + '.csv. DO NOT CHANGE THIS FILE MANUALLY.',
"// If you want to change some fields, you need to change " + file_name + ".csv file", '// If you want to change some fields, you need to change ' + file_name + '.csv file',
"// then run `efuse_common_table` or `efuse_custom_table` command it will generate this file.", '// then run `efuse_common_table` or `efuse_custom_table` command it will generate this file.',
"// To show efuse_table run the command 'show_efuse_table'."] "// To show efuse_table run the command 'show_efuse_table'."]
rows += [""] rows += ['']
if idf_target == "esp32": if idf_target == 'esp32':
rows += ["#define MAX_BLK_LEN CONFIG_EFUSE_MAX_BLK_LEN"] rows += ['#define MAX_BLK_LEN CONFIG_EFUSE_MAX_BLK_LEN']
rows += [""] rows += ['']
last_free_bit_blk1 = self.get_str_position_last_free_bit_in_blk("EFUSE_BLK1") last_free_bit_blk1 = self.get_str_position_last_free_bit_in_blk('EFUSE_BLK1')
last_free_bit_blk2 = self.get_str_position_last_free_bit_in_blk("EFUSE_BLK2") last_free_bit_blk2 = self.get_str_position_last_free_bit_in_blk('EFUSE_BLK2')
last_free_bit_blk3 = self.get_str_position_last_free_bit_in_blk("EFUSE_BLK3") last_free_bit_blk3 = self.get_str_position_last_free_bit_in_blk('EFUSE_BLK3')
rows += ["// The last free bit in the block is counted over the entire file."] rows += ['// The last free bit in the block is counted over the entire file.']
if last_free_bit_blk1 is not None: if last_free_bit_blk1 is not None:
rows += ["#define LAST_FREE_BIT_BLK1 " + last_free_bit_blk1] rows += ['#define LAST_FREE_BIT_BLK1 ' + last_free_bit_blk1]
if last_free_bit_blk2 is not None: if last_free_bit_blk2 is not None:
rows += ["#define LAST_FREE_BIT_BLK2 " + last_free_bit_blk2] rows += ['#define LAST_FREE_BIT_BLK2 ' + last_free_bit_blk2]
if last_free_bit_blk3 is not None: if last_free_bit_blk3 is not None:
rows += ["#define LAST_FREE_BIT_BLK3 " + last_free_bit_blk3] rows += ['#define LAST_FREE_BIT_BLK3 ' + last_free_bit_blk3]
rows += [""] rows += ['']
if last_free_bit_blk1 is not None: if last_free_bit_blk1 is not None:
rows += ['_Static_assert(LAST_FREE_BIT_BLK1 <= MAX_BLK_LEN, "The eFuse table does not match the coding scheme. ' rows += ['_Static_assert(LAST_FREE_BIT_BLK1 <= MAX_BLK_LEN, "The eFuse table does not match the coding scheme. '
@ -274,50 +275,50 @@ class FuseTable(list):
rows += ['_Static_assert(LAST_FREE_BIT_BLK3 <= MAX_BLK_LEN, "The eFuse table does not match the coding scheme. ' rows += ['_Static_assert(LAST_FREE_BIT_BLK3 <= MAX_BLK_LEN, "The eFuse table does not match the coding scheme. '
'Edit the table and restart the efuse_common_table or efuse_custom_table command to regenerate the new files.");'] 'Edit the table and restart the efuse_common_table or efuse_custom_table command to regenerate the new files.");']
rows += [""] rows += ['']
last_name = '' last_name = ''
for p in self: for p in self:
if (p.field_name != last_name): if (p.field_name != last_name):
if last_name != '': if last_name != '':
rows += ["};\n"] rows += ['};\n']
rows += ["static const esp_efuse_desc_t " + p.field_name + "[] = {"] rows += ['static const esp_efuse_desc_t ' + p.field_name + '[] = {']
last_name = p.field_name last_name = p.field_name
rows += [p.to_struct(debug) + ","] rows += [p.to_struct(debug) + ',']
rows += ["};\n"] rows += ['};\n']
rows += ["\n\n\n"] rows += ['\n\n\n']
last_name = '' last_name = ''
for p in self: for p in self:
if (p.field_name != last_name): if (p.field_name != last_name):
if last_name != '': if last_name != '':
rows += [" NULL", rows += [' NULL',
"};\n"] '};\n']
rows += ["const esp_efuse_desc_t* " + "ESP_EFUSE_" + p.field_name + "[] = {"] rows += ['const esp_efuse_desc_t* ' + 'ESP_EFUSE_' + p.field_name + '[] = {']
last_name = p.field_name last_name = p.field_name
index = str(0) if str(p.group) == "" else str(p.group) index = str(0) if str(p.group) == '' else str(p.group)
rows += [" &" + p.field_name + "[" + index + "], \t\t// " + p.comment] rows += [' &' + p.field_name + '[' + index + '], \t\t// ' + p.comment]
rows += [" NULL", rows += [' NULL',
"};\n"] '};\n']
return '\n'.join(rows) return '\n'.join(rows)
class FuseDefinition(object): class FuseDefinition(object):
def __init__(self): def __init__(self):
self.field_name = "" self.field_name = ''
self.group = "" self.group = ''
self.efuse_block = "" self.efuse_block = ''
self.bit_start = None self.bit_start = None
self.bit_count = None self.bit_count = None
self.define = None self.define = None
self.comment = "" self.comment = ''
@classmethod @classmethod
def from_csv(cls, line): def from_csv(cls, line):
""" Parse a line from the CSV """ """ Parse a line from the CSV """
line_w_defaults = line + ",,,," # lazy way to support default fields line_w_defaults = line + ',,,,' # lazy way to support default fields
fields = [f.strip() for f in line_w_defaults.split(",")] fields = [f.strip() for f in line_w_defaults.split(',')]
res = FuseDefinition() res = FuseDefinition()
res.field_name = fields[0] res.field_name = fields[0]
@ -330,12 +331,12 @@ class FuseDefinition(object):
return res return res
def parse_num(self, strval): def parse_num(self, strval):
if strval == "": if strval == '':
return None # Field will fill in default return None # Field will fill in default
return self.parse_int(strval) return self.parse_int(strval)
def parse_bit_count(self, strval): def parse_bit_count(self, strval):
if strval == "MAX_BLK_LEN": if strval == 'MAX_BLK_LEN':
self.define = strval self.define = strval
return self.get_max_bits_of_block() return self.get_max_bits_of_block()
else: else:
@ -345,18 +346,18 @@ class FuseDefinition(object):
try: try:
return int(v, 0) return int(v, 0)
except ValueError: except ValueError:
raise InputError("Invalid field value %s" % v) raise InputError('Invalid field value %s' % v)
def parse_block(self, strval): def parse_block(self, strval):
if strval == "": if strval == '':
raise InputError("Field 'efuse_block' can't be left empty.") raise InputError("Field 'efuse_block' can't be left empty.")
if idf_target == "esp32": if idf_target == 'esp32':
if strval not in ["EFUSE_BLK0", "EFUSE_BLK1", "EFUSE_BLK2", "EFUSE_BLK3"]: if strval not in ['EFUSE_BLK0', 'EFUSE_BLK1', 'EFUSE_BLK2', 'EFUSE_BLK3']:
raise InputError("Field 'efuse_block' should be one of EFUSE_BLK0..EFUSE_BLK3") raise InputError("Field 'efuse_block' should be one of EFUSE_BLK0..EFUSE_BLK3")
else: else:
if strval not in ["EFUSE_BLK0", "EFUSE_BLK1", "EFUSE_BLK2", "EFUSE_BLK3", "EFUSE_BLK4", if strval not in ['EFUSE_BLK0', 'EFUSE_BLK1', 'EFUSE_BLK2', 'EFUSE_BLK3', 'EFUSE_BLK4',
"EFUSE_BLK5", "EFUSE_BLK6", "EFUSE_BLK7", "EFUSE_BLK8", "EFUSE_BLK9", 'EFUSE_BLK5', 'EFUSE_BLK6', 'EFUSE_BLK7', 'EFUSE_BLK8', 'EFUSE_BLK9',
"EFUSE_BLK10"]: 'EFUSE_BLK10']:
raise InputError("Field 'efuse_block' should be one of EFUSE_BLK0..EFUSE_BLK10") raise InputError("Field 'efuse_block' should be one of EFUSE_BLK0..EFUSE_BLK10")
return strval return strval
@ -365,32 +366,32 @@ class FuseDefinition(object):
'''common_table: EFUSE_BLK0, EFUSE_BLK1, EFUSE_BLK2, EFUSE_BLK3 '''common_table: EFUSE_BLK0, EFUSE_BLK1, EFUSE_BLK2, EFUSE_BLK3
custom_table: ----------, ----------, ----------, EFUSE_BLK3(some reserved in common_table) custom_table: ----------, ----------, ----------, EFUSE_BLK3(some reserved in common_table)
''' '''
if self.efuse_block == "EFUSE_BLK0": if self.efuse_block == 'EFUSE_BLK0':
return 256 return 256
else: else:
return max_blk_len return max_blk_len
def verify(self, type_table): def verify(self, type_table):
if self.efuse_block is None: if self.efuse_block is None:
raise ValidationError(self, "efuse_block field is not set") raise ValidationError(self, 'efuse_block field is not set')
if self.bit_count is None: if self.bit_count is None:
raise ValidationError(self, "bit_count field is not set") raise ValidationError(self, 'bit_count field is not set')
if type_table is not None: if type_table is not None:
if type_table == "custom_table": if type_table == 'custom_table':
if self.efuse_block != "EFUSE_BLK3": if self.efuse_block != 'EFUSE_BLK3':
raise ValidationError(self, "custom_table should use only EFUSE_BLK3") raise ValidationError(self, 'custom_table should use only EFUSE_BLK3')
max_bits = self.get_max_bits_of_block() max_bits = self.get_max_bits_of_block()
if self.bit_start + self.bit_count > max_bits: if self.bit_start + self.bit_count > max_bits:
raise ValidationError(self, "The field is outside the boundaries(max_bits = %d) of the %s block" % (max_bits, self.efuse_block)) raise ValidationError(self, 'The field is outside the boundaries(max_bits = %d) of the %s block' % (max_bits, self.efuse_block))
def get_full_name(self): def get_full_name(self):
def get_postfix(group): def get_postfix(group):
postfix = "" postfix = ''
if group != "": if group != '':
postfix = "_PART_" + group postfix = '_PART_' + group
return postfix return postfix
return self.field_name + get_postfix(self.group) return self.field_name + get_postfix(self.group)
@ -402,19 +403,19 @@ class FuseDefinition(object):
return self.bit_count return self.bit_count
def to_struct(self, debug): def to_struct(self, debug):
start = " {" start = ' {'
if debug is True: if debug is True:
start = " {" + '"' + self.field_name + '" ,' start = ' {' + '"' + self.field_name + '" ,'
return ", ".join([start + self.efuse_block, return ', '.join([start + self.efuse_block,
str(self.bit_start), str(self.bit_start),
str(self.get_bit_count()) + "}, \t // " + self.comment]) str(self.get_bit_count()) + '}, \t // ' + self.comment])
def process_input_file(file, type_table): def process_input_file(file, type_table):
status("Parsing efuse CSV input file " + file.name + " ...") status('Parsing efuse CSV input file ' + file.name + ' ...')
input = file.read() input = file.read()
table = FuseTable.from_csv(input) table = FuseTable.from_csv(input)
status("Verifying efuse table...") status('Verifying efuse table...')
table.verify(type_table) table.verify(type_table)
return table return table
@ -432,35 +433,35 @@ def create_output_files(name, output_table, debug):
file_name = os.path.splitext(os.path.basename(name))[0] file_name = os.path.splitext(os.path.basename(name))[0]
gen_dir = os.path.dirname(name) gen_dir = os.path.dirname(name)
dir_for_file_h = gen_dir + "/include" dir_for_file_h = gen_dir + '/include'
try: try:
os.stat(dir_for_file_h) os.stat(dir_for_file_h)
except Exception: except Exception:
os.mkdir(dir_for_file_h) os.mkdir(dir_for_file_h)
file_h_path = os.path.join(dir_for_file_h, file_name + ".h") file_h_path = os.path.join(dir_for_file_h, file_name + '.h')
file_c_path = os.path.join(gen_dir, file_name + ".c") file_c_path = os.path.join(gen_dir, file_name + '.c')
# src files are the same # src files are the same
if ckeck_md5_in_file(output_table.md5_digest_table, file_c_path) is False: if ckeck_md5_in_file(output_table.md5_digest_table, file_c_path) is False:
status("Creating efuse *.h file " + file_h_path + " ...") status('Creating efuse *.h file ' + file_h_path + ' ...')
output = output_table.to_header(file_name) output = output_table.to_header(file_name)
with open(file_h_path, 'w') as f: with open(file_h_path, 'w') as f:
f.write(output) f.write(output)
status("Creating efuse *.c file " + file_c_path + " ...") status('Creating efuse *.c file ' + file_c_path + ' ...')
output = output_table.to_c_file(file_name, debug) output = output_table.to_c_file(file_name, debug)
with open(file_c_path, 'w') as f: with open(file_c_path, 'w') as f:
f.write(output) f.write(output)
else: else:
print("Source files do not require updating correspond to csv file.") print('Source files do not require updating correspond to csv file.')
def main(): def main():
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
print("WARNING: Support for Python 2 is deprecated and will be removed in future versions.", file=sys.stderr) print('WARNING: Support for Python 2 is deprecated and will be removed in future versions.', file=sys.stderr)
elif sys.version_info[0] == 3 and sys.version_info[1] < 6: elif sys.version_info[0] == 3 and sys.version_info[1] < 6:
print("WARNING: Python 3 versions older than 3.6 are not supported.", file=sys.stderr) print('WARNING: Python 3 versions older than 3.6 are not supported.', file=sys.stderr)
global quiet global quiet
global max_blk_len global max_blk_len
global idf_target global idf_target
@ -468,8 +469,8 @@ def main():
parser = argparse.ArgumentParser(description='ESP32 eFuse Manager') parser = argparse.ArgumentParser(description='ESP32 eFuse Manager')
parser.add_argument('--idf_target', '-t', help='Target chip type', choices=['esp32', 'esp32s2', 'esp32s3', 'esp32c3'], default='esp32') parser.add_argument('--idf_target', '-t', help='Target chip type', choices=['esp32', 'esp32s2', 'esp32s3', 'esp32c3'], default='esp32')
parser.add_argument('--quiet', '-q', help="Don't print non-critical status messages to stderr", action='store_true') parser.add_argument('--quiet', '-q', help="Don't print non-critical status messages to stderr", action='store_true')
parser.add_argument('--debug', help='Create header file with debug info', default=False, action="store_false") parser.add_argument('--debug', help='Create header file with debug info', default=False, action='store_false')
parser.add_argument('--info', help='Print info about range of used bits', default=False, action="store_true") parser.add_argument('--info', help='Print info about range of used bits', default=False, action='store_true')
parser.add_argument('--max_blk_len', help='Max number of bits in BLOCKs', type=int, default=256) parser.add_argument('--max_blk_len', help='Max number of bits in BLOCKs', type=int, default=256)
parser.add_argument('common_input', help='Path to common CSV file to parse.', type=argparse.FileType('r')) parser.add_argument('common_input', help='Path to common CSV file to parse.', type=argparse.FileType('r'))
parser.add_argument('custom_input', help='Path to custom CSV file to parse.', type=argparse.FileType('r'), nargs='?', default=None) parser.add_argument('custom_input', help='Path to custom CSV file to parse.', type=argparse.FileType('r'), nargs='?', default=None)
@ -479,18 +480,18 @@ def main():
idf_target = args.idf_target idf_target = args.idf_target
max_blk_len = args.max_blk_len max_blk_len = args.max_blk_len
print("Max number of bits in BLK %d" % (max_blk_len)) print('Max number of bits in BLK %d' % (max_blk_len))
if max_blk_len not in [256, 192, 128]: if max_blk_len not in [256, 192, 128]:
raise InputError("Unsupported block length = %d" % (max_blk_len)) raise InputError('Unsupported block length = %d' % (max_blk_len))
quiet = args.quiet quiet = args.quiet
debug = args.debug debug = args.debug
info = args.info info = args.info
common_table = process_input_file(args.common_input, "common_table") common_table = process_input_file(args.common_input, 'common_table')
two_table = common_table two_table = common_table
if args.custom_input is not None: if args.custom_input is not None:
custom_table = process_input_file(args.custom_input, "custom_table") custom_table = process_input_file(args.custom_input, 'custom_table')
two_table += custom_table two_table += custom_table
two_table.verify() two_table.verify()
@ -512,7 +513,7 @@ class InputError(RuntimeError):
class ValidationError(InputError): class ValidationError(InputError):
def __init__(self, p, message): def __init__(self, p, message):
super(ValidationError, self).__init__("Entry %s invalid: %s" % (p.field_name, message)) super(ValidationError, self).__init__('Entry %s invalid: %s' % (p.field_name, message))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,12 +1,13 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import print_function, division from __future__ import division, print_function
import unittest
import sys import sys
import unittest
try: try:
import efuse_table_gen import efuse_table_gen
except ImportError: except ImportError:
sys.path.append("..") sys.path.append('..')
import efuse_table_gen import efuse_table_gen
@ -117,7 +118,7 @@ name2, EFUSE_BLK2, ,
, EFUSE_BLK2, , 4, , EFUSE_BLK2, , 4,
name1, EFUSE_BLK3, , 5, name1, EFUSE_BLK3, , 5,
""" """
with self.assertRaisesRegex(efuse_table_gen.InputError, "Field names must be unique"): with self.assertRaisesRegex(efuse_table_gen.InputError, 'Field names must be unique'):
efuse_table_gen.FuseTable.from_csv(csv) efuse_table_gen.FuseTable.from_csv(csv)
def test_seq_bit_start5_fill(self): def test_seq_bit_start5_fill(self):
@ -154,7 +155,7 @@ name1, EFUSE_BLK3, 1,
name2, EFUSE_BLK3, 5, 4, Use for test name 2 name2, EFUSE_BLK3, 5, 4, Use for test name 2
""" """
t = efuse_table_gen.FuseTable.from_csv(csv) t = efuse_table_gen.FuseTable.from_csv(csv)
with self.assertRaisesRegex(efuse_table_gen.InputError, "overlap"): with self.assertRaisesRegex(efuse_table_gen.InputError, 'overlap'):
t.verify() t.verify()
def test_empty_field_name_fail(self): def test_empty_field_name_fail(self):
@ -163,7 +164,7 @@ name2, EFUSE_BLK3, 5,
, EFUSE_BLK3, , 5, , EFUSE_BLK3, , 5,
name2, EFUSE_BLK2, , 4, name2, EFUSE_BLK2, , 4,
""" """
with self.assertRaisesRegex(efuse_table_gen.InputError, "missing field name"): with self.assertRaisesRegex(efuse_table_gen.InputError, 'missing field name'):
efuse_table_gen.FuseTable.from_csv(csv) efuse_table_gen.FuseTable.from_csv(csv)
def test_unique_field_name_fail(self): def test_unique_field_name_fail(self):
@ -172,7 +173,7 @@ name2, EFUSE_BLK2, ,
name1, EFUSE_BLK3, 0, 5, Use for test name 1 name1, EFUSE_BLK3, 0, 5, Use for test name 1
name1, EFUSE_BLK3, 5, 4, Use for test name 2 name1, EFUSE_BLK3, 5, 4, Use for test name 2
""" """
with self.assertRaisesRegex(efuse_table_gen.InputError, "Field names must be unique"): with self.assertRaisesRegex(efuse_table_gen.InputError, 'Field names must be unique'):
efuse_table_gen.FuseTable.from_csv(csv) efuse_table_gen.FuseTable.from_csv(csv)
def test_bit_count_empty_fail(self): def test_bit_count_empty_fail(self):
@ -181,7 +182,7 @@ name1, EFUSE_BLK3, 5,
name1, EFUSE_BLK3, 0, , Use for test name 1 name1, EFUSE_BLK3, 0, , Use for test name 1
name2, EFUSE_BLK3, 5, 4, Use for test name 2 name2, EFUSE_BLK3, 5, 4, Use for test name 2
""" """
with self.assertRaisesRegex(efuse_table_gen.InputError, "empty"): with self.assertRaisesRegex(efuse_table_gen.InputError, 'empty'):
efuse_table_gen.FuseTable.from_csv(csv) efuse_table_gen.FuseTable.from_csv(csv)
def test_bit_start_num_fail(self): def test_bit_start_num_fail(self):
@ -190,7 +191,7 @@ name2, EFUSE_BLK3, 5,
name1, EFUSE_BLK3, k, 5, Use for test name 1 name1, EFUSE_BLK3, k, 5, Use for test name 1
name2, EFUSE_BLK3, 5, 4, Use for test name 2 name2, EFUSE_BLK3, 5, 4, Use for test name 2
""" """
with self.assertRaisesRegex(efuse_table_gen.InputError, "Invalid field value"): with self.assertRaisesRegex(efuse_table_gen.InputError, 'Invalid field value'):
efuse_table_gen.FuseTable.from_csv(csv) efuse_table_gen.FuseTable.from_csv(csv)
def test_join_entry(self): def test_join_entry(self):
@ -257,7 +258,7 @@ name2, EFUSE_BLK3, 191,
""" """
efuse_table_gen.max_blk_len = 192 efuse_table_gen.max_blk_len = 192
t = efuse_table_gen.FuseTable.from_csv(csv) t = efuse_table_gen.FuseTable.from_csv(csv)
with self.assertRaisesRegex(efuse_table_gen.InputError, "The field is outside the boundaries"): with self.assertRaisesRegex(efuse_table_gen.InputError, 'The field is outside the boundaries'):
t.verify() t.verify()
def test_field_blk1_size_is_more(self): def test_field_blk1_size_is_more(self):
@ -267,7 +268,7 @@ name1, EFUSE_BLK0, 0,
name2, EFUSE_BLK1, 1, 256, Use for test name 2 name2, EFUSE_BLK1, 1, 256, Use for test name 2
""" """
t = efuse_table_gen.FuseTable.from_csv(csv) t = efuse_table_gen.FuseTable.from_csv(csv)
with self.assertRaisesRegex(efuse_table_gen.InputError, "The field is outside the boundaries"): with self.assertRaisesRegex(efuse_table_gen.InputError, 'The field is outside the boundaries'):
t.verify() t.verify()
@ -311,8 +312,8 @@ name1, EFUSE_BLK3, 0,
name2, EFUSE_BLK2, 5, 4, Use for test name 2 name2, EFUSE_BLK2, 5, 4, Use for test name 2
""" """
t = efuse_table_gen.FuseTable.from_csv(csv) t = efuse_table_gen.FuseTable.from_csv(csv)
with self.assertRaisesRegex(efuse_table_gen.ValidationError, "custom_table should use only EFUSE_BLK3"): with self.assertRaisesRegex(efuse_table_gen.ValidationError, 'custom_table should use only EFUSE_BLK3'):
t.verify("custom_table") t.verify('custom_table')
def test_common_and_custom_table_use_the_same_bits(self): def test_common_and_custom_table_use_the_same_bits(self):
csv_common = """ csv_common = """
@ -321,7 +322,7 @@ name1, EFUSE_BLK3, 0,
name2, EFUSE_BLK2, 5, 4, Use for test name 2 name2, EFUSE_BLK2, 5, 4, Use for test name 2
""" """
common_table = efuse_table_gen.FuseTable.from_csv(csv_common) common_table = efuse_table_gen.FuseTable.from_csv(csv_common)
common_table.verify("common_table") common_table.verify('common_table')
two_tables = common_table two_tables = common_table
csv_custom = """ csv_custom = """
@ -330,12 +331,12 @@ name3, EFUSE_BLK3, 20,
name4, EFUSE_BLK3, 4, 1, Use for test name 2 name4, EFUSE_BLK3, 4, 1, Use for test name 2
""" """
custom_table = efuse_table_gen.FuseTable.from_csv(csv_custom) custom_table = efuse_table_gen.FuseTable.from_csv(csv_custom)
custom_table.verify("custom_table") custom_table.verify('custom_table')
two_tables += custom_table two_tables += custom_table
with self.assertRaisesRegex(efuse_table_gen.InputError, "overlaps"): with self.assertRaisesRegex(efuse_table_gen.InputError, 'overlaps'):
two_tables.verify() two_tables.verify()
if __name__ == "__main__": if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -2,13 +2,13 @@
import hashlib import hashlib
import hmac import hmac
import struct
import os import os
import random import random
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes import struct
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.utils import int_to_bytes from cryptography.utils import int_to_bytes
@ -19,9 +19,9 @@ def number_as_bignum_words(number):
""" """
result = [] result = []
while number != 0: while number != 0:
result.append("0x%08x" % (number & 0xFFFFFFFF)) result.append('0x%08x' % (number & 0xFFFFFFFF))
number >>= 32 number >>= 32
return "{ " + ", ".join(result) + " }" return '{ ' + ', '.join(result) + ' }'
def number_as_bytes(number, pad_bits=None): def number_as_bytes(number, pad_bits=None):
@ -38,7 +38,7 @@ def bytes_as_char_array(b):
""" """
Given a sequence of bytes, format as a char array Given a sequence of bytes, format as a char array
""" """
return "{ " + ", ".join("0x%02x" % x for x in b) + " }" return '{ ' + ', '.join('0x%02x' % x for x in b) + ' }'
NUM_HMAC_KEYS = 3 NUM_HMAC_KEYS = 3
@ -50,36 +50,36 @@ hmac_keys = [os.urandom(32) for x in range(NUM_HMAC_KEYS)]
messages = [random.randrange(0, 1 << 4096) for x in range(NUM_MESSAGES)] messages = [random.randrange(0, 1 << 4096) for x in range(NUM_MESSAGES)]
with open("digital_signature_test_cases.h", "w") as f: with open('digital_signature_test_cases.h', 'w') as f:
f.write("/* File generated by gen_digital_signature_tests.py */\n\n") f.write('/* File generated by gen_digital_signature_tests.py */\n\n')
# Write out HMAC keys # Write out HMAC keys
f.write("#define NUM_HMAC_KEYS %d\n\n" % NUM_HMAC_KEYS) f.write('#define NUM_HMAC_KEYS %d\n\n' % NUM_HMAC_KEYS)
f.write("static const uint8_t test_hmac_keys[NUM_HMAC_KEYS][32] = {\n") f.write('static const uint8_t test_hmac_keys[NUM_HMAC_KEYS][32] = {\n')
for h in hmac_keys: for h in hmac_keys:
f.write(" %s,\n" % bytes_as_char_array(h)) f.write(' %s,\n' % bytes_as_char_array(h))
f.write("};\n\n") f.write('};\n\n')
# Write out messages # Write out messages
f.write("#define NUM_MESSAGES %d\n\n" % NUM_MESSAGES) f.write('#define NUM_MESSAGES %d\n\n' % NUM_MESSAGES)
f.write("static const uint32_t test_messages[NUM_MESSAGES][4096/32] = {\n") f.write('static const uint32_t test_messages[NUM_MESSAGES][4096/32] = {\n')
for m in messages: for m in messages:
f.write(" // Message %d\n" % messages.index(m)) f.write(' // Message %d\n' % messages.index(m))
f.write(" %s,\n" % number_as_bignum_words(m)) f.write(' %s,\n' % number_as_bignum_words(m))
f.write(" };\n") f.write(' };\n')
f.write("\n\n\n") f.write('\n\n\n')
f.write("#define NUM_CASES %d\n\n" % NUM_CASES) f.write('#define NUM_CASES %d\n\n' % NUM_CASES)
f.write("static const encrypt_testcase_t test_cases[NUM_CASES] = {\n") f.write('static const encrypt_testcase_t test_cases[NUM_CASES] = {\n')
for case in range(NUM_CASES): for case in range(NUM_CASES):
f.write(" { /* Case %d */\n" % case) f.write(' { /* Case %d */\n' % case)
iv = os.urandom(16) iv = os.urandom(16)
f.write(" .iv = %s,\n" % (bytes_as_char_array(iv))) f.write(' .iv = %s,\n' % (bytes_as_char_array(iv)))
hmac_key_idx = random.randrange(0, NUM_HMAC_KEYS) hmac_key_idx = random.randrange(0, NUM_HMAC_KEYS)
aes_key = hmac.HMAC(hmac_keys[hmac_key_idx], b"\xFF" * 32, hashlib.sha256).digest() aes_key = hmac.HMAC(hmac_keys[hmac_key_idx], b'\xFF' * 32, hashlib.sha256).digest()
sizes = [4096, 3072, 2048, 1024, 512] sizes = [4096, 3072, 2048, 1024, 512]
key_size = sizes[case % len(sizes)] key_size = sizes[case % len(sizes)]
@ -100,13 +100,13 @@ with open("digital_signature_test_cases.h", "w") as f:
mprime &= 0xFFFFFFFF mprime &= 0xFFFFFFFF
length = key_size // 32 - 1 length = key_size // 32 - 1
f.write(" .p_data = {\n") f.write(' .p_data = {\n')
f.write(" .Y = %s,\n" % number_as_bignum_words(Y)) f.write(' .Y = %s,\n' % number_as_bignum_words(Y))
f.write(" .M = %s,\n" % number_as_bignum_words(M)) f.write(' .M = %s,\n' % number_as_bignum_words(M))
f.write(" .Rb = %s,\n" % number_as_bignum_words(rinv)) f.write(' .Rb = %s,\n' % number_as_bignum_words(rinv))
f.write(" .M_prime = 0x%08x,\n" % mprime) f.write(' .M_prime = 0x%08x,\n' % mprime)
f.write(" .length = %d, // %d bit\n" % (length, key_size)) f.write(' .length = %d, // %d bit\n' % (length, key_size))
f.write(" },\n") f.write(' },\n')
# calculate MD from preceding values and IV # calculate MD from preceding values and IV
@ -114,7 +114,7 @@ with open("digital_signature_test_cases.h", "w") as f:
md_in = number_as_bytes(Y, 4096) + \ md_in = number_as_bytes(Y, 4096) + \
number_as_bytes(M, 4096) + \ number_as_bytes(M, 4096) + \
number_as_bytes(rinv, 4096) + \ number_as_bytes(rinv, 4096) + \
struct.pack("<II", mprime, length) + \ struct.pack('<II', mprime, length) + \
iv iv
assert len(md_in) == 12480 / 8 assert len(md_in) == 12480 / 8
md = hashlib.sha256(md_in).digest() md = hashlib.sha256(md_in).digest()
@ -126,7 +126,7 @@ with open("digital_signature_test_cases.h", "w") as f:
number_as_bytes(M, 4096) + \ number_as_bytes(M, 4096) + \
number_as_bytes(rinv, 4096) + \ number_as_bytes(rinv, 4096) + \
md + \ md + \
struct.pack("<II", mprime, length) + \ struct.pack('<II', mprime, length) + \
b'\x08' * 8 b'\x08' * 8
assert len(p) == 12672 / 8 assert len(p) == 12672 / 8
@ -135,16 +135,16 @@ with open("digital_signature_test_cases.h", "w") as f:
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
c = encryptor.update(p) + encryptor.finalize() c = encryptor.update(p) + encryptor.finalize()
f.write(" .expected_c = %s,\n" % bytes_as_char_array(c)) f.write(' .expected_c = %s,\n' % bytes_as_char_array(c))
f.write(" .hmac_key_idx = %d,\n" % (hmac_key_idx)) f.write(' .hmac_key_idx = %d,\n' % (hmac_key_idx))
f.write(" // results of message array encrypted with these keys\n") f.write(' // results of message array encrypted with these keys\n')
f.write(" .expected_results = {\n") f.write(' .expected_results = {\n')
mask = (1 << key_size) - 1 # truncate messages if needed mask = (1 << key_size) - 1 # truncate messages if needed
for m in messages: for m in messages:
f.write(" // Message %d\n" % messages.index(m)) f.write(' // Message %d\n' % messages.index(m))
f.write(" %s," % (number_as_bignum_words(pow(m & mask, Y, M)))) f.write(' %s,' % (number_as_bignum_words(pow(m & mask, Y, M))))
f.write(" },\n") f.write(' },\n')
f.write(" },\n") f.write(' },\n')
f.write("};\n") f.write('};\n')

View File

@ -3,12 +3,14 @@
# source: esp_local_ctrl.proto # source: esp_local_ctrl.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import enum_type_wrapper
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -16,7 +18,6 @@ _sym_db = _symbol_database.Default()
import constants_pb2 as constants__pb2 import constants_pb2 as constants__pb2
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='esp_local_ctrl.proto', name='esp_local_ctrl.proto',
package='', package='',
@ -153,7 +154,7 @@ _PROPERTYINFO = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='name', full_name='PropertyInfo.name', index=1, name='name', full_name='PropertyInfo.name', index=1,
number=2, type=9, cpp_type=9, label=1, number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b('').decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
@ -174,7 +175,7 @@ _PROPERTYINFO = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='value', full_name='PropertyInfo.value', index=4, name='value', full_name='PropertyInfo.value', index=4,
number=5, type=12, cpp_type=9, label=1, number=5, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
@ -281,7 +282,7 @@ _PROPERTYVALUE = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='value', full_name='PropertyValue.value', index=1, name='value', full_name='PropertyValue.value', index=1,
number=2, type=12, cpp_type=9, label=1, number=2, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),

View File

@ -1,13 +1,13 @@
import re
import os import os
import re
import socket import socket
from threading import Thread, Event
import subprocess import subprocess
import time import time
from shutil import copyfile from shutil import copyfile
from threading import Event, Thread
from tiny_test_fw import Utility, DUT
import ttfw_idf import ttfw_idf
from tiny_test_fw import DUT, Utility
stop_sock_listener = Event() stop_sock_listener = Event()
stop_io_listener = Event() stop_io_listener = Event()
@ -22,16 +22,16 @@ def io_listener(dut1):
data = b'' data = b''
while not stop_io_listener.is_set(): while not stop_io_listener.is_set():
try: try:
data = dut1.expect(re.compile(r"PacketOut:\[([a-fA-F0-9]+)\]"), timeout=5) data = dut1.expect(re.compile(r'PacketOut:\[([a-fA-F0-9]+)\]'), timeout=5)
except DUT.ExpectTimeout: except DUT.ExpectTimeout:
continue continue
if data != () and data[0] != b'': if data != () and data[0] != b'':
packet_data = data[0] packet_data = data[0]
print("Packet_data>{}<".format(packet_data)) print('Packet_data>{}<'.format(packet_data))
response = bytearray.fromhex(packet_data.decode()) response = bytearray.fromhex(packet_data.decode())
print("Sending to socket:") print('Sending to socket:')
packet = ' '.join(format(x, '02x') for x in bytearray(response)) packet = ' '.join(format(x, '02x') for x in bytearray(response))
print("Packet>{}<".format(packet)) print('Packet>{}<'.format(packet))
if client_address is not None: if client_address is not None:
sock.sendto(response, ('127.0.0.1', 7777)) sock.sendto(response, ('127.0.0.1', 7777))
@ -50,7 +50,7 @@ def sock_listener(dut1):
try: try:
payload, client_address = sock.recvfrom(1024) payload, client_address = sock.recvfrom(1024)
packet = ' '.join(format(x, '02x') for x in bytearray(payload)) packet = ' '.join(format(x, '02x') for x in bytearray(payload))
print("Received from address {}, data {}".format(client_address, packet)) print('Received from address {}, data {}'.format(client_address, packet))
dut1.write(str.encode(packet)) dut1.write(str.encode(packet))
except socket.timeout: except socket.timeout:
pass pass
@ -59,7 +59,7 @@ def sock_listener(dut1):
sock = None sock = None
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def lwip_test_suite(env, extra_data): def lwip_test_suite(env, extra_data):
global stop_io_listener global stop_io_listener
global stop_sock_listener global stop_sock_listener
@ -70,12 +70,12 @@ def lwip_test_suite(env, extra_data):
3. Execute ttcn3 test suite 3. Execute ttcn3 test suite
4. Collect result from ttcn3 4. Collect result from ttcn3
""" """
dut1 = env.get_dut("net_suite", "examples/system/network_tests", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('net_suite', 'examples/system/network_tests', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "net_suite.bin") binary_file = os.path.join(dut1.app.binary_path, 'net_suite.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("net_suite", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('net_suite', '{}KB'.format(bin_size // 1024))
ttfw_idf.check_performance("net_suite", bin_size // 1024, dut1.TARGET) ttfw_idf.check_performance('net_suite', bin_size // 1024, dut1.TARGET)
dut1.start_app() dut1.start_app()
thread1 = Thread(target=sock_listener, args=(dut1, )) thread1 = Thread(target=sock_listener, args=(dut1, ))
thread2 = Thread(target=io_listener, args=(dut1, )) thread2 = Thread(target=io_listener, args=(dut1, ))
@ -84,48 +84,48 @@ def lwip_test_suite(env, extra_data):
TTCN_SRC = 'esp32_netsuite.ttcn' TTCN_SRC = 'esp32_netsuite.ttcn'
TTCN_CFG = 'esp32_netsuite.cfg' TTCN_CFG = 'esp32_netsuite.cfg'
# System Paths # System Paths
netsuite_path = os.getenv("NETSUITE_PATH") netsuite_path = os.getenv('NETSUITE_PATH')
netsuite_src_path = os.path.join(netsuite_path, "src") netsuite_src_path = os.path.join(netsuite_path, 'src')
test_dir = os.path.dirname(os.path.realpath(__file__)) test_dir = os.path.dirname(os.path.realpath(__file__))
# Building the suite # Building the suite
print("Rebuilding the test suite") print('Rebuilding the test suite')
print("-------------------------") print('-------------------------')
# copy esp32 specific files to ttcn net-suite dir # copy esp32 specific files to ttcn net-suite dir
copyfile(os.path.join(test_dir, TTCN_SRC), os.path.join(netsuite_src_path, TTCN_SRC)) copyfile(os.path.join(test_dir, TTCN_SRC), os.path.join(netsuite_src_path, TTCN_SRC))
copyfile(os.path.join(test_dir, TTCN_CFG), os.path.join(netsuite_src_path, TTCN_CFG)) copyfile(os.path.join(test_dir, TTCN_CFG), os.path.join(netsuite_src_path, TTCN_CFG))
proc = subprocess.Popen(['bash', '-c', 'cd ' + netsuite_src_path + ' && source make.sh'], proc = subprocess.Popen(['bash', '-c', 'cd ' + netsuite_src_path + ' && source make.sh'],
cwd=netsuite_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) cwd=netsuite_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output = proc.stdout.read() output = proc.stdout.read()
print("Note: First build step we expect failure (titan/net_suite build system not suitable for multijob make)") print('Note: First build step we expect failure (titan/net_suite build system not suitable for multijob make)')
print(output) print(output)
proc = subprocess.Popen(['bash', '-c', 'cd ' + netsuite_src_path + ' && make'], proc = subprocess.Popen(['bash', '-c', 'cd ' + netsuite_src_path + ' && make'],
cwd=netsuite_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE) cwd=netsuite_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
print("Note: This time all dependencies shall be generated -- multijob make shall pass") print('Note: This time all dependencies shall be generated -- multijob make shall pass')
output = proc.stdout.read() output = proc.stdout.read()
print(output) print(output)
# Executing the test suite # Executing the test suite
thread1.start() thread1.start()
thread2.start() thread2.start()
time.sleep(2) time.sleep(2)
print("Executing the test suite") print('Executing the test suite')
print("------------------------") print('------------------------')
proc = subprocess.Popen(['ttcn3_start', os.path.join(netsuite_src_path,'test_suite'), os.path.join(netsuite_src_path, TTCN_CFG)], proc = subprocess.Popen(['ttcn3_start', os.path.join(netsuite_src_path,'test_suite'), os.path.join(netsuite_src_path, TTCN_CFG)],
stdout=subprocess.PIPE) stdout=subprocess.PIPE)
output = proc.stdout.read() output = proc.stdout.read()
print(output) print(output)
print("Collecting results") print('Collecting results')
print("------------------") print('------------------')
verdict_stats = re.search('(Verdict statistics:.*)', output) verdict_stats = re.search('(Verdict statistics:.*)', output)
if verdict_stats: if verdict_stats:
verdict_stats = verdict_stats.group(1) verdict_stats = verdict_stats.group(1)
else: else:
verdict_stats = b"" verdict_stats = b''
verdict = re.search('Overall verdict: pass', output) verdict = re.search('Overall verdict: pass', output)
if verdict: if verdict:
print("Test passed!") print('Test passed!')
Utility.console_log(verdict_stats, "green") Utility.console_log(verdict_stats, 'green')
else: else:
Utility.console_log(verdict_stats, "red") Utility.console_log(verdict_stats, 'red')
raise ValueError('Test failed with: {}'.format(verdict_stats)) raise ValueError('Test failed with: {}'.format(verdict_stats))
else: else:
try: try:
@ -137,8 +137,8 @@ def lwip_test_suite(env, extra_data):
time.sleep(0.5) time.sleep(0.5)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
print("Executing done, waiting for tests to finish") print('Executing done, waiting for tests to finish')
print("-------------------------------------------") print('-------------------------------------------')
stop_io_listener.set() stop_io_listener.set()
stop_sock_listener.set() stop_sock_listener.set()
thread1.join() thread1.join()
@ -146,6 +146,6 @@ def lwip_test_suite(env, extra_data):
if __name__ == '__main__': if __name__ == '__main__':
print("Manual execution, please build and start ttcn in a separate console") print('Manual execution, please build and start ttcn in a separate console')
manual_test = True manual_test = True
lwip_test_suite() lwip_test_suite()

View File

@ -24,12 +24,12 @@
from __future__ import with_statement from __future__ import with_statement
import os
import sys
import struct
import argparse import argparse
import csv import csv
import os
import re import re
import struct
import sys
from io import open from io import open
try: try:
@ -80,22 +80,22 @@ class CertificateBundle:
def add_from_file(self, file_path): def add_from_file(self, file_path):
try: try:
if file_path.endswith('.pem'): if file_path.endswith('.pem'):
status("Parsing certificates from %s" % file_path) status('Parsing certificates from %s' % file_path)
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, 'r', encoding='utf-8') as f:
crt_str = f.read() crt_str = f.read()
self.add_from_pem(crt_str) self.add_from_pem(crt_str)
return True return True
elif file_path.endswith('.der'): elif file_path.endswith('.der'):
status("Parsing certificates from %s" % file_path) status('Parsing certificates from %s' % file_path)
with open(file_path, 'rb') as f: with open(file_path, 'rb') as f:
crt_str = f.read() crt_str = f.read()
self.add_from_der(crt_str) self.add_from_der(crt_str)
return True return True
except ValueError: except ValueError:
critical("Invalid certificate in %s" % file_path) critical('Invalid certificate in %s' % file_path)
raise InputError("Invalid certificate") raise InputError('Invalid certificate')
return False return False
@ -119,13 +119,13 @@ class CertificateBundle:
crt += strg crt += strg
if(count == 0): if(count == 0):
raise InputError("No certificate found") raise InputError('No certificate found')
status("Successfully added %d certificates" % count) status('Successfully added %d certificates' % count)
def add_from_der(self, crt_str): def add_from_der(self, crt_str):
self.certificates.append(x509.load_der_x509_certificate(crt_str, default_backend())) self.certificates.append(x509.load_der_x509_certificate(crt_str, default_backend()))
status("Successfully added 1 certificate") status('Successfully added 1 certificate')
def create_bundle(self): def create_bundle(self):
# Sort certificates in order to do binary search when looking up certificates # Sort certificates in order to do binary search when looking up certificates
@ -162,7 +162,7 @@ class CertificateBundle:
for row in csv_reader: for row in csv_reader:
filter_set.add(row[1]) filter_set.add(row[1])
status("Parsing certificates from %s" % crts_path) status('Parsing certificates from %s' % crts_path)
crt_str = [] crt_str = []
with open(crts_path, 'r', encoding='utf-8') as f: with open(crts_path, 'r', encoding='utf-8') as f:
crt_str = f.read() crt_str = f.read()
@ -202,14 +202,14 @@ def main():
for path in args.input: for path in args.input:
if os.path.isfile(path): if os.path.isfile(path):
if os.path.basename(path) == "cacrt_all.pem" and args.filter: if os.path.basename(path) == 'cacrt_all.pem' and args.filter:
bundle.add_with_filter(path, args.filter) bundle.add_with_filter(path, args.filter)
else: else:
bundle.add_from_file(path) bundle.add_from_file(path)
elif os.path.isdir(path): elif os.path.isdir(path):
bundle.add_from_path(path) bundle.add_from_path(path)
else: else:
raise InputError("Invalid --input=%s, is neither file nor folder" % args.input) raise InputError('Invalid --input=%s, is neither file nor folder' % args.input)
status('Successfully added %d certificates in total' % len(bundle.certificates)) status('Successfully added %d certificates in total' % len(bundle.certificates))

View File

@ -1,13 +1,13 @@
#!/usr/bin/env python #!/usr/bin/env python
import unittest
import sys
import os import os
import sys
import unittest
try: try:
import gen_crt_bundle import gen_crt_bundle
except ImportError: except ImportError:
sys.path.append("..") sys.path.append('..')
import gen_crt_bundle import gen_crt_bundle
@ -67,11 +67,11 @@ class GenCrtBundleTests(Py23TestCase):
def test_invalid_crt_input(self): def test_invalid_crt_input(self):
bundle = gen_crt_bundle.CertificateBundle() bundle = gen_crt_bundle.CertificateBundle()
with self.assertRaisesRegex(gen_crt_bundle.InputError, "Invalid certificate"): with self.assertRaisesRegex(gen_crt_bundle.InputError, 'Invalid certificate'):
bundle.add_from_file(test_crts_path + invalid_test_file) bundle.add_from_file(test_crts_path + invalid_test_file)
with self.assertRaisesRegex(gen_crt_bundle.InputError, "No certificate found"): with self.assertRaisesRegex(gen_crt_bundle.InputError, 'No certificate found'):
bundle.add_from_pem("") bundle.add_from_pem('')
def test_non_ascii_crt_input(self): def test_non_ascii_crt_input(self):
bundle = gen_crt_bundle.CertificateBundle() bundle = gen_crt_bundle.CertificateBundle()
@ -80,5 +80,5 @@ class GenCrtBundleTests(Py23TestCase):
self.assertTrue(len(bundle.certificates)) self.assertTrue(len(bundle.certificates))
if __name__ == "__main__": if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,36 +1,35 @@
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
from builtins import str
import re
import sys
import ssl
import paho.mqtt.client as mqtt
from threading import Thread, Event
import time
import string
import random import random
import re
import ssl
import string
import sys
import time
from builtins import str
from threading import Event, Thread
from tiny_test_fw import DUT import paho.mqtt.client as mqtt
import ttfw_idf import ttfw_idf
from tiny_test_fw import DUT
event_client_connected = Event() event_client_connected = Event()
event_stop_client = Event() event_stop_client = Event()
event_client_received_correct = Event() event_client_received_correct = Event()
message_log = "" message_log = ''
broker_host = {} broker_host = {}
broker_port = {} broker_port = {}
expected_data = "" expected_data = ''
subscribe_topic = "" subscribe_topic = ''
publish_topic = "" publish_topic = ''
expected_count = 0 expected_count = 0
# The callback for when the client receives a CONNACK response from the server. # The callback for when the client receives a CONNACK response from the server.
def on_connect(client, userdata, flags, rc): def on_connect(client, userdata, flags, rc):
print("Connected with result code " + str(rc)) print('Connected with result code ' + str(rc))
event_client_connected.set() event_client_connected.set()
client.subscribe("/topic/qos0") client.subscribe('/topic/qos0')
def mqtt_client_task(client): def mqtt_client_task(client):
@ -52,8 +51,8 @@ def on_message(client, userdata, msg):
payload = msg.payload.decode() payload = msg.payload.decode()
if payload == expected_data: if payload == expected_data:
expected_count += 1 expected_count += 1
print("[{}] Received...".format(msg.mid)) print('[{}] Received...'.format(msg.mid))
message_log += "Received data:" + msg.topic + " " + payload + "\n" message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
def test_single_config(dut, transport, qos, repeat, published, queue=0): def test_single_config(dut, transport, qos, repeat, published, queue=0):
@ -63,49 +62,49 @@ def test_single_config(dut, transport, qos, repeat, published, queue=0):
sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(16)) sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(16))
event_client_connected.clear() event_client_connected.clear()
expected_count = 0 expected_count = 0
message_log = "" message_log = ''
expected_data = sample_string * repeat expected_data = sample_string * repeat
print("PUBLISH TEST: transport:{}, qos:{}, sequence:{}, enqueue:{}, sample msg:'{}'".format(transport, qos, published, queue, expected_data)) print("PUBLISH TEST: transport:{}, qos:{}, sequence:{}, enqueue:{}, sample msg:'{}'".format(transport, qos, published, queue, expected_data))
client = None client = None
try: try:
if transport in ["ws", "wss"]: if transport in ['ws', 'wss']:
client = mqtt.Client(transport="websockets") client = mqtt.Client(transport='websockets')
else: else:
client = mqtt.Client() client = mqtt.Client()
client.on_connect = on_connect client.on_connect = on_connect
client.on_message = on_message client.on_message = on_message
if transport in ["ssl", "wss"]: if transport in ['ssl', 'wss']:
client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None) client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
client.tls_insecure_set(True) client.tls_insecure_set(True)
print("Connecting...") print('Connecting...')
client.connect(broker_host[transport], broker_port[transport], 60) client.connect(broker_host[transport], broker_port[transport], 60)
except Exception: except Exception:
print("ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:".format(broker_host[transport], sys.exc_info()[0])) print('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(broker_host[transport], sys.exc_info()[0]))
raise raise
# Starting a py-client in a separate thread # Starting a py-client in a separate thread
thread1 = Thread(target=mqtt_client_task, args=(client,)) thread1 = Thread(target=mqtt_client_task, args=(client,))
thread1.start() thread1.start()
print("Connecting py-client to broker {}:{}...".format(broker_host[transport], broker_port[transport])) print('Connecting py-client to broker {}:{}...'.format(broker_host[transport], broker_port[transport]))
if not event_client_connected.wait(timeout=30): if not event_client_connected.wait(timeout=30):
raise ValueError("ENV_TEST_FAILURE: Test script cannot connect to broker: {}".format(broker_host[transport])) raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_host[transport]))
client.subscribe(subscribe_topic, qos) client.subscribe(subscribe_topic, qos)
dut.write(' '.join(str(x) for x in (transport, sample_string, repeat, published, qos, queue)), eol="\n") dut.write(' '.join(str(x) for x in (transport, sample_string, repeat, published, qos, queue)), eol='\n')
try: try:
# waiting till subscribed to defined topic # waiting till subscribed to defined topic
dut.expect(re.compile(r"MQTT_EVENT_SUBSCRIBED"), timeout=30) dut.expect(re.compile(r'MQTT_EVENT_SUBSCRIBED'), timeout=30)
for i in range(published): for i in range(published):
client.publish(publish_topic, sample_string * repeat, qos) client.publish(publish_topic, sample_string * repeat, qos)
print("Publishing...") print('Publishing...')
print("Checking esp-client received msg published from py-client...") print('Checking esp-client received msg published from py-client...')
dut.expect(re.compile(r"Correct pattern received exactly x times"), timeout=60) dut.expect(re.compile(r'Correct pattern received exactly x times'), timeout=60)
start = time.time() start = time.time()
while expected_count < published and time.time() - start <= 60: while expected_count < published and time.time() - start <= 60:
time.sleep(1) time.sleep(1)
# Note: tolerate that messages qos=1 to be received more than once # Note: tolerate that messages qos=1 to be received more than once
if expected_count == published or (expected_count > published and qos == 1): if expected_count == published or (expected_count > published and qos == 1):
print("All data received from ESP32...") print('All data received from ESP32...')
else: else:
raise ValueError("Not all data received from ESP32: Expected:{}x{}, Received:{}x{}".format(expected_count, published, expected_data, message_log)) raise ValueError('Not all data received from ESP32: Expected:{}x{}, Received:{}x{}'.format(expected_count, published, expected_data, message_log))
finally: finally:
event_stop_client.set() event_stop_client.set()
thread1.join() thread1.join()
@ -113,7 +112,7 @@ def test_single_config(dut, transport, qos, repeat, published, queue=0):
event_stop_client.clear() event_stop_client.clear()
@ttfw_idf.idf_custom_test(env_tag="Example_WIFI") @ttfw_idf.idf_custom_test(env_tag='Example_WIFI')
def test_weekend_mqtt_publish(env, extra_data): def test_weekend_mqtt_publish(env, extra_data):
# Using broker url dictionary for different transport # Using broker url dictionary for different transport
global broker_host global broker_host
@ -127,28 +126,28 @@ def test_weekend_mqtt_publish(env, extra_data):
3. Test evaluates python client received correct qos0 message 3. Test evaluates python client received correct qos0 message
4. Test ESP32 client received correct qos0 message 4. Test ESP32 client received correct qos0 message
""" """
dut1 = env.get_dut("mqtt_publish_connect_test", "tools/test_apps/protocols/mqtt/publish_connect_test") dut1 = env.get_dut('mqtt_publish_connect_test', 'tools/test_apps/protocols/mqtt/publish_connect_test')
# Look for host:port in sdkconfig # Look for host:port in sdkconfig
try: try:
# python client subscribes to the topic to which esp client publishes and vice versa # python client subscribes to the topic to which esp client publishes and vice versa
publish_topic = dut1.app.get_sdkconfig()["CONFIG_EXAMPLE_SUBSCIBE_TOPIC"].replace('"','') publish_topic = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_SUBSCIBE_TOPIC'].replace('"','')
subscribe_topic = dut1.app.get_sdkconfig()["CONFIG_EXAMPLE_PUBLISH_TOPIC"].replace('"','') subscribe_topic = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_PUBLISH_TOPIC'].replace('"','')
broker_host["ssl"], broker_port["ssl"] = get_host_port_from_dut(dut1, "CONFIG_EXAMPLE_BROKER_SSL_URI") broker_host['ssl'], broker_port['ssl'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_SSL_URI')
broker_host["tcp"], broker_port["tcp"] = get_host_port_from_dut(dut1, "CONFIG_EXAMPLE_BROKER_TCP_URI") broker_host['tcp'], broker_port['tcp'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_TCP_URI')
broker_host["ws"], broker_port["ws"] = get_host_port_from_dut(dut1, "CONFIG_EXAMPLE_BROKER_WS_URI") broker_host['ws'], broker_port['ws'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WS_URI')
broker_host["wss"], broker_port["wss"] = get_host_port_from_dut(dut1, "CONFIG_EXAMPLE_BROKER_WSS_URI") broker_host['wss'], broker_port['wss'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WSS_URI')
except Exception: except Exception:
print('ENV_TEST_FAILURE: Cannot find broker url in sdkconfig') print('ENV_TEST_FAILURE: Cannot find broker url in sdkconfig')
raise raise
dut1.start_app() dut1.start_app()
try: try:
ip_address = dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30) ip_address = dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)
print("Connected to AP with IP: {}".format(ip_address)) print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout: except DUT.ExpectTimeout:
print('ENV_TEST_FAILURE: Cannot connect to AP') print('ENV_TEST_FAILURE: Cannot connect to AP')
raise raise
for qos in [0, 1, 2]: for qos in [0, 1, 2]:
for transport in ["tcp", "ssl", "ws", "wss"]: for transport in ['tcp', 'ssl', 'ws', 'wss']:
for q in [0, 1]: for q in [0, 1]:
if broker_host[transport] is None: if broker_host[transport] is None:
print('Skipping transport: {}...'.format(transport)) print('Skipping transport: {}...'.format(transport))
@ -156,14 +155,14 @@ def test_weekend_mqtt_publish(env, extra_data):
# simple test with empty message # simple test with empty message
test_single_config(dut1, transport, qos, 0, 5, q) test_single_config(dut1, transport, qos, 0, 5, q)
# decide on broker what level of test will pass (local broker works the best) # decide on broker what level of test will pass (local broker works the best)
if broker_host[transport].startswith("192.168") and qos > 0 and q == 0: if broker_host[transport].startswith('192.168') and qos > 0 and q == 0:
# medium size, medium repeated # medium size, medium repeated
test_single_config(dut1, transport, qos, 5, 50, q) test_single_config(dut1, transport, qos, 5, 50, q)
# long data # long data
test_single_config(dut1, transport, qos, 1000, 10, q) test_single_config(dut1, transport, qos, 1000, 10, q)
# short data, many repeats # short data, many repeats
test_single_config(dut1, transport, qos, 2, 200, q) test_single_config(dut1, transport, qos, 2, 200, q)
elif transport in ["ws", "wss"]: elif transport in ['ws', 'wss']:
# more relaxed criteria for websockets! # more relaxed criteria for websockets!
test_single_config(dut1, transport, qos, 2, 5, q) test_single_config(dut1, transport, qos, 2, 5, q)
test_single_config(dut1, transport, qos, 50, 1, q) test_single_config(dut1, transport, qos, 50, 1, q)

View File

@ -19,40 +19,42 @@
# #
from __future__ import division, print_function from __future__ import division, print_function
from future.moves.itertools import zip_longest
from builtins import int, range, bytes
from io import open
import sys
import argparse import argparse
import binascii
import random
import struct
import os
import array import array
import zlib import binascii
import codecs import codecs
import datetime import datetime
import distutils.dir_util import distutils.dir_util
import os
import random
import struct
import sys
import zlib
from builtins import bytes, int, range
from io import open
from future.moves.itertools import zip_longest
try: try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
except ImportError: except ImportError:
print('The cryptography package is not installed.' print('The cryptography package is not installed.'
'Please refer to the Get Started section of the ESP-IDF Programming Guide for ' 'Please refer to the Get Started section of the ESP-IDF Programming Guide for '
'setting up the required packages.') 'setting up the required packages.')
raise raise
VERSION1_PRINT = "V1 - Multipage Blob Support Disabled" VERSION1_PRINT = 'V1 - Multipage Blob Support Disabled'
VERSION2_PRINT = "V2 - Multipage Blob Support Enabled" VERSION2_PRINT = 'V2 - Multipage Blob Support Enabled'
def reverse_hexbytes(addr_tmp): def reverse_hexbytes(addr_tmp):
addr = [] addr = []
reversed_bytes = "" reversed_bytes = ''
for i in range(0, len(addr_tmp), 2): for i in range(0, len(addr_tmp), 2):
addr.append(addr_tmp[i:i + 2]) addr.append(addr_tmp[i:i + 2])
reversed_bytes = "".join(reversed(addr)) reversed_bytes = ''.join(reversed(addr))
return reversed_bytes return reversed_bytes
@ -62,10 +64,10 @@ def reverse_hexbytes(addr_tmp):
class Page(object): class Page(object):
PAGE_PARAMS = { PAGE_PARAMS = {
"max_size": 4096, 'max_size': 4096,
"max_old_blob_size": 1984, 'max_old_blob_size': 1984,
"max_new_blob_size": 4000, 'max_new_blob_size': 4000,
"max_entries": 126 'max_entries': 126
} }
# Item type codes # Item type codes
@ -98,7 +100,7 @@ class Page(object):
self.entry_num = 0 self.entry_num = 0
self.bitmap_array = array.array('B') self.bitmap_array = array.array('B')
self.version = version self.version = version
self.page_buf = bytearray(b'\xff') * Page.PAGE_PARAMS["max_size"] self.page_buf = bytearray(b'\xff') * Page.PAGE_PARAMS['max_size']
if not is_rsrv_page: if not is_rsrv_page:
self.bitmap_array = self.create_bitmap_array() self.bitmap_array = self.create_bitmap_array()
self.set_header(page_num, version) self.set_header(page_num, version)
@ -167,7 +169,7 @@ class Page(object):
else: else:
encr_key_input = codecs.decode(nvs_obj.encr_key, 'hex') encr_key_input = codecs.decode(nvs_obj.encr_key, 'hex')
rel_addr = nvs_obj.page_num * Page.PAGE_PARAMS["max_size"] + Page.FIRST_ENTRY_OFFSET rel_addr = nvs_obj.page_num * Page.PAGE_PARAMS['max_size'] + Page.FIRST_ENTRY_OFFSET
if not isinstance(data_input, bytearray): if not isinstance(data_input, bytearray):
byte_arr = bytearray(b'\xff') * 32 byte_arr = bytearray(b'\xff') * 32
@ -249,8 +251,8 @@ class Page(object):
chunk_size = 0 chunk_size = 0
# Get the size available in current page # Get the size available in current page
tailroom = (Page.PAGE_PARAMS["max_entries"] - self.entry_num - 1) * Page.SINGLE_ENTRY_SIZE tailroom = (Page.PAGE_PARAMS['max_entries'] - self.entry_num - 1) * Page.SINGLE_ENTRY_SIZE
assert tailroom >= 0, "Page overflow!!" assert tailroom >= 0, 'Page overflow!!'
# Split the binary data into two and store a chunk of available size onto curr page # Split the binary data into two and store a chunk of available size onto curr page
if tailroom < remaining_size: if tailroom < remaining_size:
@ -358,14 +360,14 @@ class Page(object):
# Set size of data # Set size of data
datalen = len(data) datalen = len(data)
if datalen > Page.PAGE_PARAMS["max_old_blob_size"]: if datalen > Page.PAGE_PARAMS['max_old_blob_size']:
if self.version == Page.VERSION1: if self.version == Page.VERSION1:
raise InputError(" Input File: Size (%d) exceeds max allowed length `%s` bytes for key `%s`." raise InputError(' Input File: Size (%d) exceeds max allowed length `%s` bytes for key `%s`.'
% (datalen, Page.PAGE_PARAMS["max_old_blob_size"], key)) % (datalen, Page.PAGE_PARAMS['max_old_blob_size'], key))
else: else:
if encoding == "string": if encoding == 'string':
raise InputError(" Input File: Size (%d) exceeds max allowed length `%s` bytes for key `%s`." raise InputError(' Input File: Size (%d) exceeds max allowed length `%s` bytes for key `%s`.'
% (datalen, Page.PAGE_PARAMS["max_old_blob_size"], key)) % (datalen, Page.PAGE_PARAMS['max_old_blob_size'], key))
# Calculate no. of entries data will require # Calculate no. of entries data will require
rounded_size = (datalen + 31) & ~31 rounded_size = (datalen + 31) & ~31
@ -373,10 +375,10 @@ class Page(object):
total_entry_count = data_entry_count + 1 # +1 for the entry header total_entry_count = data_entry_count + 1 # +1 for the entry header
# Check if page is already full and new page is needed to be created right away # Check if page is already full and new page is needed to be created right away
if self.entry_num >= Page.PAGE_PARAMS["max_entries"]: if self.entry_num >= Page.PAGE_PARAMS['max_entries']:
raise PageFullError() raise PageFullError()
elif (self.entry_num + total_entry_count) >= Page.PAGE_PARAMS["max_entries"]: elif (self.entry_num + total_entry_count) >= Page.PAGE_PARAMS['max_entries']:
if not (self.version == Page.VERSION2 and encoding in ["hex2bin", "binary", "base64"]): if not (self.version == Page.VERSION2 and encoding in ['hex2bin', 'binary', 'base64']):
raise PageFullError() raise PageFullError()
# Entry header # Entry header
@ -385,7 +387,7 @@ class Page(object):
entry_struct[0] = ns_index entry_struct[0] = ns_index
# Set Span # Set Span
if self.version == Page.VERSION2: if self.version == Page.VERSION2:
if encoding == "string": if encoding == 'string':
entry_struct[2] = data_entry_count + 1 entry_struct[2] = data_entry_count + 1
# Set Chunk Index # Set Chunk Index
chunk_index = Page.CHUNK_ANY chunk_index = Page.CHUNK_ANY
@ -399,12 +401,12 @@ class Page(object):
entry_struct[8:8 + len(key)] = key.encode() entry_struct[8:8 + len(key)] = key.encode()
# set Type # set Type
if encoding == "string": if encoding == 'string':
entry_struct[1] = Page.SZ entry_struct[1] = Page.SZ
elif encoding in ["hex2bin", "binary", "base64"]: elif encoding in ['hex2bin', 'binary', 'base64']:
entry_struct[1] = Page.BLOB entry_struct[1] = Page.BLOB
if self.version == Page.VERSION2 and (encoding in ["hex2bin", "binary", "base64"]): if self.version == Page.VERSION2 and (encoding in ['hex2bin', 'binary', 'base64']):
entry_struct = self.write_varlen_binary_data(entry_struct,ns_index,key,data, entry_struct = self.write_varlen_binary_data(entry_struct,ns_index,key,data,
datalen,total_entry_count, encoding, nvs_obj) datalen,total_entry_count, encoding, nvs_obj)
else: else:
@ -413,7 +415,7 @@ class Page(object):
""" Low-level function to write data of primitive type into page buffer. """ """ Low-level function to write data of primitive type into page buffer. """
def write_primitive_data(self, key, data, encoding, ns_index,nvs_obj): def write_primitive_data(self, key, data, encoding, ns_index,nvs_obj):
# Check if entry exceeds max number of entries allowed per page # Check if entry exceeds max number of entries allowed per page
if self.entry_num >= Page.PAGE_PARAMS["max_entries"]: if self.entry_num >= Page.PAGE_PARAMS['max_entries']:
raise PageFullError() raise PageFullError()
entry_struct = bytearray(b'\xff') * 32 entry_struct = bytearray(b'\xff') * 32
@ -427,28 +429,28 @@ class Page(object):
entry_struct[8:24] = key_array entry_struct[8:24] = key_array
entry_struct[8:8 + len(key)] = key.encode() entry_struct[8:8 + len(key)] = key.encode()
if encoding == "u8": if encoding == 'u8':
entry_struct[1] = Page.U8 entry_struct[1] = Page.U8
struct.pack_into('<B', entry_struct, 24, data) struct.pack_into('<B', entry_struct, 24, data)
elif encoding == "i8": elif encoding == 'i8':
entry_struct[1] = Page.I8 entry_struct[1] = Page.I8
struct.pack_into('<b', entry_struct, 24, data) struct.pack_into('<b', entry_struct, 24, data)
elif encoding == "u16": elif encoding == 'u16':
entry_struct[1] = Page.U16 entry_struct[1] = Page.U16
struct.pack_into('<H', entry_struct, 24, data) struct.pack_into('<H', entry_struct, 24, data)
elif encoding == "i16": elif encoding == 'i16':
entry_struct[1] = Page.I16 entry_struct[1] = Page.I16
struct.pack_into('<h', entry_struct, 24, data) struct.pack_into('<h', entry_struct, 24, data)
elif encoding == "u32": elif encoding == 'u32':
entry_struct[1] = Page.U32 entry_struct[1] = Page.U32
struct.pack_into('<I', entry_struct, 24, data) struct.pack_into('<I', entry_struct, 24, data)
elif encoding == "i32": elif encoding == 'i32':
entry_struct[1] = Page.I32 entry_struct[1] = Page.I32
struct.pack_into('<i', entry_struct, 24, data) struct.pack_into('<i', entry_struct, 24, data)
elif encoding == "u64": elif encoding == 'u64':
entry_struct[1] = Page.U64 entry_struct[1] = Page.U64
struct.pack_into('<Q', entry_struct, 24, data) struct.pack_into('<Q', entry_struct, 24, data)
elif encoding == "i64": elif encoding == 'i64':
entry_struct[1] = Page.I64 entry_struct[1] = Page.I64
struct.pack_into('<q', entry_struct, 24, data) struct.pack_into('<q', entry_struct, 24, data)
@ -516,9 +518,9 @@ class NVS(object):
version = self.version version = self.version
# Update available size as each page is created # Update available size as each page is created
if self.size == 0: if self.size == 0:
raise InsufficientSizeError("Error: Size parameter is less than the size of data in csv.Please increase size.") raise InsufficientSizeError('Error: Size parameter is less than the size of data in csv.Please increase size.')
if not is_rsrv_page: if not is_rsrv_page:
self.size = self.size - Page.PAGE_PARAMS["max_size"] self.size = self.size - Page.PAGE_PARAMS['max_size']
self.page_num += 1 self.page_num += 1
# Set version for each page and page header # Set version for each page and page header
new_page = Page(self.page_num, version, is_rsrv_page) new_page = Page(self.page_num, version, is_rsrv_page)
@ -533,10 +535,10 @@ class NVS(object):
def write_namespace(self, key): def write_namespace(self, key):
self.namespace_idx += 1 self.namespace_idx += 1
try: try:
self.cur_page.write_primitive_data(key, self.namespace_idx, "u8", 0,self) self.cur_page.write_primitive_data(key, self.namespace_idx, 'u8', 0,self)
except PageFullError: except PageFullError:
new_page = self.create_new_page() new_page = self.create_new_page()
new_page.write_primitive_data(key, self.namespace_idx, "u8", 0,self) new_page.write_primitive_data(key, self.namespace_idx, 'u8', 0,self)
""" """
Write key-value pair. Function accepts value in the form of ascii character and converts Write key-value pair. Function accepts value in the form of ascii character and converts
@ -545,23 +547,23 @@ class NVS(object):
We don't have to guard re-invocation with try-except since no entry can span multiple pages. We don't have to guard re-invocation with try-except since no entry can span multiple pages.
""" """
def write_entry(self, key, value, encoding): def write_entry(self, key, value, encoding):
if encoding == "hex2bin": if encoding == 'hex2bin':
value = value.strip() value = value.strip()
if len(value) % 2 != 0: if len(value) % 2 != 0:
raise InputError("%s: Invalid data length. Should be multiple of 2." % key) raise InputError('%s: Invalid data length. Should be multiple of 2.' % key)
value = binascii.a2b_hex(value) value = binascii.a2b_hex(value)
if encoding == "base64": if encoding == 'base64':
value = binascii.a2b_base64(value) value = binascii.a2b_base64(value)
if encoding == "string": if encoding == 'string':
if type(value) == bytes: if type(value) == bytes:
value = value.decode() value = value.decode()
value += '\0' value += '\0'
encoding = encoding.lower() encoding = encoding.lower()
varlen_encodings = ["string", "binary", "hex2bin", "base64"] varlen_encodings = ['string', 'binary', 'hex2bin', 'base64']
primitive_encodings = ["u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64"] primitive_encodings = ['u8', 'i8', 'u16', 'i16', 'u32', 'i32', 'u64', 'i64']
if encoding in varlen_encodings: if encoding in varlen_encodings:
try: try:
@ -576,7 +578,7 @@ class NVS(object):
new_page = self.create_new_page() new_page = self.create_new_page()
new_page.write_primitive_data(key, int(value), encoding, self.namespace_idx,self) new_page.write_primitive_data(key, int(value), encoding, self.namespace_idx,self)
else: else:
raise InputError("%s: Unsupported encoding" % encoding) raise InputError('%s: Unsupported encoding' % encoding)
""" Return accumulated data of all pages """ """ Return accumulated data of all pages """
def get_binary_data(self): def get_binary_data(self):
@ -600,7 +602,7 @@ class InputError(RuntimeError):
Represents error on the input Represents error on the input
""" """
def __init__(self, e): def __init__(self, e):
print("\nError:") print('\nError:')
super(InputError, self).__init__(e) super(InputError, self).__init__(e)
@ -634,7 +636,7 @@ def write_entry(nvs_instance, key, datatype, encoding, value):
:return: None :return: None
""" """
if datatype == "file": if datatype == 'file':
abs_file_path = value abs_file_path = value
if os.path.isabs(value) is False: if os.path.isabs(value) is False:
script_dir = os.getcwd() script_dir = os.getcwd()
@ -643,7 +645,7 @@ def write_entry(nvs_instance, key, datatype, encoding, value):
with open(abs_file_path, 'rb') as f: with open(abs_file_path, 'rb') as f:
value = f.read() value = f.read()
if datatype == "namespace": if datatype == 'namespace':
nvs_instance.write_namespace(key) nvs_instance.write_namespace(key)
else: else:
nvs_instance.write_entry(key, value, encoding) nvs_instance.write_entry(key, value, encoding)
@ -667,13 +669,13 @@ def check_size(size):
# Set size # Set size
input_size = int(size, 0) input_size = int(size, 0)
if input_size % 4096 != 0: if input_size % 4096 != 0:
sys.exit("Size of partition must be multiple of 4096") sys.exit('Size of partition must be multiple of 4096')
# Update size as a page needs to be reserved of size 4KB # Update size as a page needs to be reserved of size 4KB
input_size = input_size - Page.PAGE_PARAMS["max_size"] input_size = input_size - Page.PAGE_PARAMS['max_size']
if input_size < (2 * Page.PAGE_PARAMS["max_size"]): if input_size < (2 * Page.PAGE_PARAMS['max_size']):
sys.exit("Minimum NVS partition size needed is 0x3000 bytes.") sys.exit('Minimum NVS partition size needed is 0x3000 bytes.')
return input_size return input_size
except Exception as e: except Exception as e:
print(e) print(e)
@ -708,7 +710,7 @@ def set_target_filepath(outdir, filepath):
if os.path.isabs(filepath): if os.path.isabs(filepath):
if not outdir == os.getcwd(): if not outdir == os.getcwd():
print("\nWarning: `%s` \n\t==> absolute path given so outdir is ignored for this file." % filepath) print('\nWarning: `%s` \n\t==> absolute path given so outdir is ignored for this file.' % filepath)
# Set to empty as outdir is ignored here # Set to empty as outdir is ignored here
outdir = '' outdir = ''
@ -728,11 +730,11 @@ def encrypt(args):
check_size(args.size) check_size(args.size)
if (args.keygen is False) and (not args.inputkey): if (args.keygen is False) and (not args.inputkey):
sys.exit("Error. --keygen or --inputkey argument needed.") sys.exit('Error. --keygen or --inputkey argument needed.')
elif args.keygen and args.inputkey: elif args.keygen and args.inputkey:
sys.exit("Error. --keygen and --inputkey both are not allowed.") sys.exit('Error. --keygen and --inputkey both are not allowed.')
elif not args.keygen and args.keyfile: elif not args.keygen and args.keyfile:
print("\nWarning:","--inputkey argument is given. --keyfile argument will be ignored...") print('\nWarning:','--inputkey argument is given. --keyfile argument will be ignored...')
if args.inputkey: if args.inputkey:
# Check if key file has .bin extension # Check if key file has .bin extension
@ -835,7 +837,7 @@ def decrypt(args):
start_entry_offset += nvs_read_bytes start_entry_offset += nvs_read_bytes
output_file.write(output_buf) output_file.write(output_buf)
print("\nCreated NVS decrypted binary: ===>", args.output) print('\nCreated NVS decrypted binary: ===>', args.output)
def generate_key(args): def generate_key(args):
@ -850,7 +852,7 @@ def generate_key(args):
if not args.keyfile: if not args.keyfile:
timestamp = datetime.datetime.now().strftime('%m-%d_%H-%M') timestamp = datetime.datetime.now().strftime('%m-%d_%H-%M')
args.keyfile = "keys-" + timestamp + bin_ext args.keyfile = 'keys-' + timestamp + bin_ext
keys_outdir = os.path.join(args.outdir,keys_dir, '') keys_outdir = os.path.join(args.outdir,keys_dir, '')
# Create keys/ dir in <outdir> if does not exist # Create keys/ dir in <outdir> if does not exist
@ -872,7 +874,7 @@ def generate_key(args):
with open(output_keyfile, 'wb') as output_keys_file: with open(output_keyfile, 'wb') as output_keys_file:
output_keys_file.write(keys_buf) output_keys_file.write(keys_buf)
print("\nCreated encryption keys: ===> ", output_keyfile) print('\nCreated encryption keys: ===> ', output_keyfile)
return key return key
@ -914,7 +916,7 @@ def generate(args, is_encr_enabled=False, encr_key=None):
else: else:
version_set = VERSION2_PRINT version_set = VERSION2_PRINT
print("\nCreating NVS binary with version:", version_set) print('\nCreating NVS binary with version:', version_set)
line = input_file.readline().strip() line = input_file.readline().strip()
@ -939,25 +941,25 @@ def generate(args, is_encr_enabled=False, encr_key=None):
try: try:
# Check key length # Check key length
if len(data["key"]) > 15: if len(data['key']) > 15:
raise InputError("Length of key `{}` should be <= 15 characters.".format(data["key"])) raise InputError('Length of key `{}` should be <= 15 characters.'.format(data['key']))
write_entry(nvs_obj, data["key"], data["type"], data["encoding"], data["value"]) write_entry(nvs_obj, data['key'], data['type'], data['encoding'], data['value'])
except InputError as e: except InputError as e:
print(e) print(e)
filedir, filename = os.path.split(args.output) filedir, filename = os.path.split(args.output)
if filename: if filename:
print("\nWarning: NVS binary not created...") print('\nWarning: NVS binary not created...')
os.remove(args.output) os.remove(args.output)
if is_dir_new and not filedir == os.getcwd(): if is_dir_new and not filedir == os.getcwd():
print("\nWarning: Output dir not created...") print('\nWarning: Output dir not created...')
os.rmdir(filedir) os.rmdir(filedir)
sys.exit(-2) sys.exit(-2)
print("\nCreated NVS binary: ===>", args.output) print('\nCreated NVS binary: ===>', args.output)
def main(): def main():
parser = argparse.ArgumentParser(description="\nESP NVS partition generation utility", formatter_class=argparse.RawTextHelpFormatter) parser = argparse.ArgumentParser(description='\nESP NVS partition generation utility', formatter_class=argparse.RawTextHelpFormatter)
subparser = parser.add_subparsers(title='Commands', subparser = parser.add_subparsers(title='Commands',
dest='command', dest='command',
help='\nRun nvs_partition_gen.py {command} -h for additional help\n\n') help='\nRun nvs_partition_gen.py {command} -h for additional help\n\n')
@ -1022,7 +1024,7 @@ def main():
\nVersion 2 - Multipage blob support enabled.\ \nVersion 2 - Multipage blob support enabled.\
\nDefault: Version 2''') \nDefault: Version 2''')
parser_encr.add_argument('--keygen', parser_encr.add_argument('--keygen',
action="store_true", action='store_true',
default=False, default=False,
help='Generates key for encrypting NVS partition') help='Generates key for encrypting NVS partition')
parser_encr.add_argument('--keyfile', parser_encr.add_argument('--keyfile',
@ -1057,5 +1059,5 @@ def main():
args.func(args) args.func(args)
if __name__ == "__main__": if __name__ == '__main__':
main() main()

View File

@ -17,8 +17,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function, division from __future__ import division, print_function, unicode_literals
from __future__ import unicode_literals
import argparse import argparse
import sys import sys
@ -28,7 +28,7 @@ quiet = False
def generate_blanked_file(size, output_path): def generate_blanked_file(size, output_path):
output = b"\xFF" * size output = b'\xFF' * size
try: try:
stdout_binary = sys.stdout.buffer # Python 3 stdout_binary = sys.stdout.buffer # Python 3
except AttributeError: except AttributeError:

View File

@ -20,19 +20,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function, division from __future__ import division, print_function, unicode_literals
from __future__ import unicode_literals
import argparse import argparse
import binascii
import errno
import hashlib
import os import os
import re import re
import struct import struct
import sys import sys
import hashlib
import binascii
import errno
MAX_PARTITION_LENGTH = 0xC00 # 3K for partition data (96 entries) leaves 1K in a 4K sector for signature MAX_PARTITION_LENGTH = 0xC00 # 3K for partition data (96 entries) leaves 1K in a 4K sector for signature
MD5_PARTITION_BEGIN = b"\xEB\xEB" + b"\xFF" * 14 # The first 2 bytes are like magic numbers for MD5 sum MD5_PARTITION_BEGIN = b'\xEB\xEB' + b'\xFF' * 14 # The first 2 bytes are like magic numbers for MD5 sum
PARTITION_TABLE_SIZE = 0x1000 # Size of partition table PARTITION_TABLE_SIZE = 0x1000 # Size of partition table
MIN_PARTITION_SUBTYPE_APP_OTA = 0x10 MIN_PARTITION_SUBTYPE_APP_OTA = 0x10
@ -44,26 +44,26 @@ APP_TYPE = 0x00
DATA_TYPE = 0x01 DATA_TYPE = 0x01
TYPES = { TYPES = {
"app": APP_TYPE, 'app': APP_TYPE,
"data": DATA_TYPE, 'data': DATA_TYPE,
} }
# Keep this map in sync with esp_partition_subtype_t enum in esp_partition.h # Keep this map in sync with esp_partition_subtype_t enum in esp_partition.h
SUBTYPES = { SUBTYPES = {
APP_TYPE: { APP_TYPE: {
"factory": 0x00, 'factory': 0x00,
"test": 0x20, 'test': 0x20,
}, },
DATA_TYPE: { DATA_TYPE: {
"ota": 0x00, 'ota': 0x00,
"phy": 0x01, 'phy': 0x01,
"nvs": 0x02, 'nvs': 0x02,
"coredump": 0x03, 'coredump': 0x03,
"nvs_keys": 0x04, 'nvs_keys': 0x04,
"efuse": 0x05, 'efuse': 0x05,
"esphttpd": 0x80, 'esphttpd': 0x80,
"fat": 0x81, 'fat': 0x81,
"spiffs": 0x82, 'spiffs': 0x82,
}, },
} }
@ -103,14 +103,14 @@ class PartitionTable(list):
for line_no in range(len(lines)): for line_no in range(len(lines)):
line = expand_vars(lines[line_no]).strip() line = expand_vars(lines[line_no]).strip()
if line.startswith("#") or len(line) == 0: if line.startswith('#') or len(line) == 0:
continue continue
try: try:
res.append(PartitionDefinition.from_csv(line, line_no + 1)) res.append(PartitionDefinition.from_csv(line, line_no + 1))
except InputError as e: except InputError as e:
raise InputError("Error at line %d: %s" % (line_no + 1, e)) raise InputError('Error at line %d: %s' % (line_no + 1, e))
except Exception: except Exception:
critical("Unexpected error parsing CSV line %d: %s" % (line_no + 1, line)) critical('Unexpected error parsing CSV line %d: %s' % (line_no + 1, line))
raise raise
# fix up missing offsets & negative sizes # fix up missing offsets & negative sizes
@ -118,10 +118,10 @@ class PartitionTable(list):
for e in res: for e in res:
if e.offset is not None and e.offset < last_end: if e.offset is not None and e.offset < last_end:
if e == res[0]: if e == res[0]:
raise InputError("CSV Error: First partition offset 0x%x overlaps end of partition table 0x%x" raise InputError('CSV Error: First partition offset 0x%x overlaps end of partition table 0x%x'
% (e.offset, last_end)) % (e.offset, last_end))
else: else:
raise InputError("CSV Error: Partitions overlap. Partition at line %d sets offset 0x%x. Previous partition ends 0x%x" raise InputError('CSV Error: Partitions overlap. Partition at line %d sets offset 0x%x. Previous partition ends 0x%x'
% (e.line_no, e.offset, last_end)) % (e.line_no, e.offset, last_end))
if e.offset is None: if e.offset is None:
pad_to = 0x10000 if e.type == APP_TYPE else 4 pad_to = 0x10000 if e.type == APP_TYPE else 4
@ -186,19 +186,19 @@ class PartitionTable(list):
# print sorted duplicate partitions by name # print sorted duplicate partitions by name
if len(duplicates) != 0: if len(duplicates) != 0:
print("A list of partitions that have the same name:") print('A list of partitions that have the same name:')
for p in sorted(self, key=lambda x:x.name): for p in sorted(self, key=lambda x:x.name):
if len(duplicates.intersection([p.name])) != 0: if len(duplicates.intersection([p.name])) != 0:
print("%s" % (p.to_csv())) print('%s' % (p.to_csv()))
raise InputError("Partition names must be unique") raise InputError('Partition names must be unique')
# check for overlaps # check for overlaps
last = None last = None
for p in sorted(self, key=lambda x:x.offset): for p in sorted(self, key=lambda x:x.offset):
if p.offset < offset_part_table + PARTITION_TABLE_SIZE: if p.offset < offset_part_table + PARTITION_TABLE_SIZE:
raise InputError("Partition offset 0x%x is below 0x%x" % (p.offset, offset_part_table + PARTITION_TABLE_SIZE)) raise InputError('Partition offset 0x%x is below 0x%x' % (p.offset, offset_part_table + PARTITION_TABLE_SIZE))
if last is not None and p.offset < last.offset + last.size: if last is not None and p.offset < last.offset + last.size:
raise InputError("Partition at 0x%x overlaps 0x%x-0x%x" % (p.offset, last.offset, last.offset + last.size - 1)) raise InputError('Partition at 0x%x overlaps 0x%x-0x%x' % (p.offset, last.offset, last.offset + last.size - 1))
last = p last = p
def flash_size(self): def flash_size(self):
@ -218,7 +218,7 @@ class PartitionTable(list):
for o in range(0,len(b),32): for o in range(0,len(b),32):
data = b[o:o + 32] data = b[o:o + 32]
if len(data) != 32: if len(data) != 32:
raise InputError("Partition table length must be a multiple of 32 bytes") raise InputError('Partition table length must be a multiple of 32 bytes')
if data == b'\xFF' * 32: if data == b'\xFF' * 32:
return result # got end marker return result # got end marker
if md5sum and data[:2] == MD5_PARTITION_BEGIN[:2]: # check only the magic number part if md5sum and data[:2] == MD5_PARTITION_BEGIN[:2]: # check only the magic number part
@ -229,26 +229,26 @@ class PartitionTable(list):
else: else:
md5.update(data) md5.update(data)
result.append(PartitionDefinition.from_binary(data)) result.append(PartitionDefinition.from_binary(data))
raise InputError("Partition table is missing an end-of-table marker") raise InputError('Partition table is missing an end-of-table marker')
def to_binary(self): def to_binary(self):
result = b"".join(e.to_binary() for e in self) result = b''.join(e.to_binary() for e in self)
if md5sum: if md5sum:
result += MD5_PARTITION_BEGIN + hashlib.md5(result).digest() result += MD5_PARTITION_BEGIN + hashlib.md5(result).digest()
if len(result) >= MAX_PARTITION_LENGTH: if len(result) >= MAX_PARTITION_LENGTH:
raise InputError("Binary partition table length (%d) longer than max" % len(result)) raise InputError('Binary partition table length (%d) longer than max' % len(result))
result += b"\xFF" * (MAX_PARTITION_LENGTH - len(result)) # pad the sector, for signing result += b'\xFF' * (MAX_PARTITION_LENGTH - len(result)) # pad the sector, for signing
return result return result
def to_csv(self, simple_formatting=False): def to_csv(self, simple_formatting=False):
rows = ["# ESP-IDF Partition Table", rows = ['# ESP-IDF Partition Table',
"# Name, Type, SubType, Offset, Size, Flags"] '# Name, Type, SubType, Offset, Size, Flags']
rows += [x.to_csv(simple_formatting) for x in self] rows += [x.to_csv(simple_formatting) for x in self]
return "\n".join(rows) + "\n" return '\n'.join(rows) + '\n'
class PartitionDefinition(object): class PartitionDefinition(object):
MAGIC_BYTES = b"\xAA\x50" MAGIC_BYTES = b'\xAA\x50'
ALIGNMENT = { ALIGNMENT = {
APP_TYPE: 0x10000, APP_TYPE: 0x10000,
@ -258,15 +258,15 @@ class PartitionDefinition(object):
# dictionary maps flag name (as used in CSV flags list, property name) # dictionary maps flag name (as used in CSV flags list, property name)
# to bit set in flags words in binary format # to bit set in flags words in binary format
FLAGS = { FLAGS = {
"encrypted": 0 'encrypted': 0
} }
# add subtypes for the 16 OTA slot values ("ota_XX, etc.") # add subtypes for the 16 OTA slot values ("ota_XX, etc.")
for ota_slot in range(NUM_PARTITION_SUBTYPE_APP_OTA): for ota_slot in range(NUM_PARTITION_SUBTYPE_APP_OTA):
SUBTYPES[TYPES["app"]]["ota_%d" % ota_slot] = MIN_PARTITION_SUBTYPE_APP_OTA + ota_slot SUBTYPES[TYPES['app']]['ota_%d' % ota_slot] = MIN_PARTITION_SUBTYPE_APP_OTA + ota_slot
def __init__(self): def __init__(self):
self.name = "" self.name = ''
self.type = None self.type = None
self.subtype = None self.subtype = None
self.offset = None self.offset = None
@ -276,8 +276,8 @@ class PartitionDefinition(object):
@classmethod @classmethod
def from_csv(cls, line, line_no): def from_csv(cls, line, line_no):
""" Parse a line from the CSV """ """ Parse a line from the CSV """
line_w_defaults = line + ",,,," # lazy way to support default fields line_w_defaults = line + ',,,,' # lazy way to support default fields
fields = [f.strip() for f in line_w_defaults.split(",")] fields = [f.strip() for f in line_w_defaults.split(',')]
res = PartitionDefinition() res = PartitionDefinition()
res.line_no = line_no res.line_no = line_no
@ -289,7 +289,7 @@ class PartitionDefinition(object):
if res.size is None: if res.size is None:
raise InputError("Size field can't be empty") raise InputError("Size field can't be empty")
flags = fields[5].split(":") flags = fields[5].split(':')
for flag in flags: for flag in flags:
if flag in cls.FLAGS: if flag in cls.FLAGS:
setattr(res, flag, True) setattr(res, flag, True)
@ -305,7 +305,7 @@ class PartitionDefinition(object):
def __repr__(self): def __repr__(self):
def maybe_hex(x): def maybe_hex(x):
return "0x%x" % x if x is not None else "None" return '0x%x' % x if x is not None else 'None'
return "PartitionDefinition('%s', 0x%x, 0x%x, %s, %s)" % (self.name, self.type, self.subtype or 0, return "PartitionDefinition('%s', 0x%x, 0x%x, %s, %s)" % (self.name, self.type, self.subtype or 0,
maybe_hex(self.offset), maybe_hex(self.size)) maybe_hex(self.offset), maybe_hex(self.size))
@ -328,65 +328,65 @@ class PartitionDefinition(object):
return self.offset >= other.offset return self.offset >= other.offset
def parse_type(self, strval): def parse_type(self, strval):
if strval == "": if strval == '':
raise InputError("Field 'type' can't be left empty.") raise InputError("Field 'type' can't be left empty.")
return parse_int(strval, TYPES) return parse_int(strval, TYPES)
def parse_subtype(self, strval): def parse_subtype(self, strval):
if strval == "": if strval == '':
return 0 # default return 0 # default
return parse_int(strval, SUBTYPES.get(self.type, {})) return parse_int(strval, SUBTYPES.get(self.type, {}))
def parse_address(self, strval): def parse_address(self, strval):
if strval == "": if strval == '':
return None # PartitionTable will fill in default return None # PartitionTable will fill in default
return parse_int(strval) return parse_int(strval)
def verify(self): def verify(self):
if self.type is None: if self.type is None:
raise ValidationError(self, "Type field is not set") raise ValidationError(self, 'Type field is not set')
if self.subtype is None: if self.subtype is None:
raise ValidationError(self, "Subtype field is not set") raise ValidationError(self, 'Subtype field is not set')
if self.offset is None: if self.offset is None:
raise ValidationError(self, "Offset field is not set") raise ValidationError(self, 'Offset field is not set')
align = self.ALIGNMENT.get(self.type, 4) align = self.ALIGNMENT.get(self.type, 4)
if self.offset % align: if self.offset % align:
raise ValidationError(self, "Offset 0x%x is not aligned to 0x%x" % (self.offset, align)) raise ValidationError(self, 'Offset 0x%x is not aligned to 0x%x' % (self.offset, align))
if self.size % align and secure: if self.size % align and secure:
raise ValidationError(self, "Size 0x%x is not aligned to 0x%x" % (self.size, align)) raise ValidationError(self, 'Size 0x%x is not aligned to 0x%x' % (self.size, align))
if self.size is None: if self.size is None:
raise ValidationError(self, "Size field is not set") raise ValidationError(self, 'Size field is not set')
if self.name in TYPES and TYPES.get(self.name, "") != self.type: if self.name in TYPES and TYPES.get(self.name, '') != self.type:
critical("WARNING: Partition has name '%s' which is a partition type, but does not match this partition's " critical("WARNING: Partition has name '%s' which is a partition type, but does not match this partition's "
"type (0x%x). Mistake in partition table?" % (self.name, self.type)) 'type (0x%x). Mistake in partition table?' % (self.name, self.type))
all_subtype_names = [] all_subtype_names = []
for names in (t.keys() for t in SUBTYPES.values()): for names in (t.keys() for t in SUBTYPES.values()):
all_subtype_names += names all_subtype_names += names
if self.name in all_subtype_names and SUBTYPES.get(self.type, {}).get(self.name, "") != self.subtype: if self.name in all_subtype_names and SUBTYPES.get(self.type, {}).get(self.name, '') != self.subtype:
critical("WARNING: Partition has name '%s' which is a partition subtype, but this partition has " critical("WARNING: Partition has name '%s' which is a partition subtype, but this partition has "
"non-matching type 0x%x and subtype 0x%x. Mistake in partition table?" % (self.name, self.type, self.subtype)) 'non-matching type 0x%x and subtype 0x%x. Mistake in partition table?' % (self.name, self.type, self.subtype))
STRUCT_FORMAT = b"<2sBBLL16sL" STRUCT_FORMAT = b'<2sBBLL16sL'
@classmethod @classmethod
def from_binary(cls, b): def from_binary(cls, b):
if len(b) != 32: if len(b) != 32:
raise InputError("Partition definition length must be exactly 32 bytes. Got %d bytes." % len(b)) raise InputError('Partition definition length must be exactly 32 bytes. Got %d bytes.' % len(b))
res = cls() res = cls()
(magic, res.type, res.subtype, res.offset, (magic, res.type, res.subtype, res.offset,
res.size, res.name, flags) = struct.unpack(cls.STRUCT_FORMAT, b) res.size, res.name, flags) = struct.unpack(cls.STRUCT_FORMAT, b)
if b"\x00" in res.name: # strip null byte padding from name string if b'\x00' in res.name: # strip null byte padding from name string
res.name = res.name[:res.name.index(b"\x00")] res.name = res.name[:res.name.index(b'\x00')]
res.name = res.name.decode() res.name = res.name.decode()
if magic != cls.MAGIC_BYTES: if magic != cls.MAGIC_BYTES:
raise InputError("Invalid magic bytes (%r) for partition definition" % magic) raise InputError('Invalid magic bytes (%r) for partition definition' % magic)
for flag,bit in cls.FLAGS.items(): for flag,bit in cls.FLAGS.items():
if flags & (1 << bit): if flags & (1 << bit):
setattr(res, flag, True) setattr(res, flag, True)
flags &= ~(1 << bit) flags &= ~(1 << bit)
if flags != 0: if flags != 0:
critical("WARNING: Partition definition had unknown flag(s) 0x%08x. Newer binary format?" % flags) critical('WARNING: Partition definition had unknown flag(s) 0x%08x. Newer binary format?' % flags)
return res return res
def get_flags_list(self): def get_flags_list(self):
@ -404,22 +404,22 @@ class PartitionDefinition(object):
def to_csv(self, simple_formatting=False): def to_csv(self, simple_formatting=False):
def addr_format(a, include_sizes): def addr_format(a, include_sizes):
if not simple_formatting and include_sizes: if not simple_formatting and include_sizes:
for (val, suffix) in [(0x100000, "M"), (0x400, "K")]: for (val, suffix) in [(0x100000, 'M'), (0x400, 'K')]:
if a % val == 0: if a % val == 0:
return "%d%s" % (a // val, suffix) return '%d%s' % (a // val, suffix)
return "0x%x" % a return '0x%x' % a
def lookup_keyword(t, keywords): def lookup_keyword(t, keywords):
for k,v in keywords.items(): for k,v in keywords.items():
if simple_formatting is False and t == v: if simple_formatting is False and t == v:
return k return k
return "%d" % t return '%d' % t
def generate_text_flags(): def generate_text_flags():
""" colon-delimited list of flags """ """ colon-delimited list of flags """
return ":".join(self.get_flags_list()) return ':'.join(self.get_flags_list())
return ",".join([self.name, return ','.join([self.name,
lookup_keyword(self.type, TYPES), lookup_keyword(self.type, TYPES),
lookup_keyword(self.subtype, SUBTYPES.get(self.type, {})), lookup_keyword(self.subtype, SUBTYPES.get(self.type, {})),
addr_format(self.offset, False), addr_format(self.offset, False),
@ -432,17 +432,17 @@ def parse_int(v, keywords={}):
k/m/K/M suffixes and 'keyword' value lookup. k/m/K/M suffixes and 'keyword' value lookup.
""" """
try: try:
for letter, multiplier in [("k", 1024), ("m", 1024 * 1024)]: for letter, multiplier in [('k', 1024), ('m', 1024 * 1024)]:
if v.lower().endswith(letter): if v.lower().endswith(letter):
return parse_int(v[:-1], keywords) * multiplier return parse_int(v[:-1], keywords) * multiplier
return int(v, 0) return int(v, 0)
except ValueError: except ValueError:
if len(keywords) == 0: if len(keywords) == 0:
raise InputError("Invalid field value %s" % v) raise InputError('Invalid field value %s' % v)
try: try:
return keywords[v.lower()] return keywords[v.lower()]
except KeyError: except KeyError:
raise InputError("Value '%s' is not valid. Known keywords: %s" % (v, ", ".join(keywords))) raise InputError("Value '%s' is not valid. Known keywords: %s" % (v, ', '.join(keywords)))
def main(): def main():
@ -456,11 +456,11 @@ def main():
nargs='?', choices=['1MB', '2MB', '4MB', '8MB', '16MB']) nargs='?', choices=['1MB', '2MB', '4MB', '8MB', '16MB'])
parser.add_argument('--disable-md5sum', help='Disable md5 checksum for the partition table', default=False, action='store_true') parser.add_argument('--disable-md5sum', help='Disable md5 checksum for the partition table', default=False, action='store_true')
parser.add_argument('--no-verify', help="Don't verify partition table fields", action='store_true') parser.add_argument('--no-verify', help="Don't verify partition table fields", action='store_true')
parser.add_argument('--verify', '-v', help="Verify partition table fields (deprecated, this behaviour is " parser.add_argument('--verify', '-v', help='Verify partition table fields (deprecated, this behaviour is '
"enabled by default and this flag does nothing.", action='store_true') 'enabled by default and this flag does nothing.', action='store_true')
parser.add_argument('--quiet', '-q', help="Don't print non-critical status messages to stderr", action='store_true') parser.add_argument('--quiet', '-q', help="Don't print non-critical status messages to stderr", action='store_true')
parser.add_argument('--offset', '-o', help='Set offset partition table', default='0x8000') parser.add_argument('--offset', '-o', help='Set offset partition table', default='0x8000')
parser.add_argument('--secure', help="Require app partitions to be suitable for secure boot", action='store_true') parser.add_argument('--secure', help='Require app partitions to be suitable for secure boot', action='store_true')
parser.add_argument('input', help='Path to CSV or binary file to parse.', type=argparse.FileType('rb')) parser.add_argument('input', help='Path to CSV or binary file to parse.', type=argparse.FileType('rb'))
parser.add_argument('output', help='Path to output converted binary or CSV file. Will use stdout if omitted.', parser.add_argument('output', help='Path to output converted binary or CSV file. Will use stdout if omitted.',
nargs='?', default='-') nargs='?', default='-')
@ -474,19 +474,19 @@ def main():
input = args.input.read() input = args.input.read()
input_is_binary = input[0:2] == PartitionDefinition.MAGIC_BYTES input_is_binary = input[0:2] == PartitionDefinition.MAGIC_BYTES
if input_is_binary: if input_is_binary:
status("Parsing binary partition input...") status('Parsing binary partition input...')
table = PartitionTable.from_binary(input) table = PartitionTable.from_binary(input)
else: else:
input = input.decode() input = input.decode()
status("Parsing CSV input...") status('Parsing CSV input...')
table = PartitionTable.from_csv(input) table = PartitionTable.from_csv(input)
if not args.no_verify: if not args.no_verify:
status("Verifying table...") status('Verifying table...')
table.verify() table.verify()
if args.flash_size: if args.flash_size:
size_mb = int(args.flash_size.replace("MB", "")) size_mb = int(args.flash_size.replace('MB', ''))
size = size_mb * 1024 * 1024 # flash memory uses honest megabytes! size = size_mb * 1024 * 1024 # flash memory uses honest megabytes!
table_size = table.flash_size() table_size = table.flash_size()
if size < table_size: if size < table_size:
@ -526,7 +526,7 @@ class InputError(RuntimeError):
class ValidationError(InputError): class ValidationError(InputError):
def __init__(self, partition, message): def __init__(self, partition, message):
super(ValidationError, self).__init__( super(ValidationError, self).__init__(
"Partition %s invalid: %s" % (partition.name, message)) 'Partition %s invalid: %s' % (partition.name, message))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -16,20 +16,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function, division from __future__ import division, print_function
import argparse import argparse
import os import os
import sys
import subprocess
import tempfile
import re import re
import gen_esp32part as gen import subprocess
import sys
import tempfile
import gen_esp32part as gen
__version__ = '2.0' __version__ = '2.0'
COMPONENTS_PATH = os.path.expandvars(os.path.join("$IDF_PATH", "components")) COMPONENTS_PATH = os.path.expandvars(os.path.join('$IDF_PATH', 'components'))
ESPTOOL_PY = os.path.join(COMPONENTS_PATH, "esptool_py", "esptool", "esptool.py") ESPTOOL_PY = os.path.join(COMPONENTS_PATH, 'esptool_py', 'esptool', 'esptool.py')
PARTITION_TABLE_OFFSET = 0x8000 PARTITION_TABLE_OFFSET = 0x8000
@ -78,14 +79,14 @@ class ParttoolTarget():
def parse_esptool_args(esptool_args): def parse_esptool_args(esptool_args):
results = list() results = list()
for arg in esptool_args: for arg in esptool_args:
pattern = re.compile(r"(.+)=(.+)") pattern = re.compile(r'(.+)=(.+)')
result = pattern.match(arg) result = pattern.match(arg)
try: try:
key = result.group(1) key = result.group(1)
value = result.group(2) value = result.group(2)
results.extend(["--" + key, value]) results.extend(['--' + key, value])
except AttributeError: except AttributeError:
results.extend(["--" + arg]) results.extend(['--' + arg])
return results return results
self.esptool_args = parse_esptool_args(esptool_args) self.esptool_args = parse_esptool_args(esptool_args)
@ -95,14 +96,14 @@ class ParttoolTarget():
if partition_table_file: if partition_table_file:
partition_table = None partition_table = None
with open(partition_table_file, "rb") as f: with open(partition_table_file, 'rb') as f:
input_is_binary = (f.read(2) == gen.PartitionDefinition.MAGIC_BYTES) input_is_binary = (f.read(2) == gen.PartitionDefinition.MAGIC_BYTES)
f.seek(0) f.seek(0)
if input_is_binary: if input_is_binary:
partition_table = gen.PartitionTable.from_binary(f.read()) partition_table = gen.PartitionTable.from_binary(f.read())
if partition_table is None: if partition_table is None:
with open(partition_table_file, "r") as f: with open(partition_table_file, 'r') as f:
f.seek(0) f.seek(0)
partition_table = gen.PartitionTable.from_csv(f.read()) partition_table = gen.PartitionTable.from_csv(f.read())
else: else:
@ -110,8 +111,8 @@ class ParttoolTarget():
temp_file.close() temp_file.close()
try: try:
self._call_esptool(["read_flash", str(partition_table_offset), str(gen.MAX_PARTITION_LENGTH), temp_file.name]) self._call_esptool(['read_flash', str(partition_table_offset), str(gen.MAX_PARTITION_LENGTH), temp_file.name])
with open(temp_file.name, "rb") as f: with open(temp_file.name, 'rb') as f:
partition_table = gen.PartitionTable.from_binary(f.read()) partition_table = gen.PartitionTable.from_binary(f.read())
finally: finally:
os.unlink(temp_file.name) os.unlink(temp_file.name)
@ -125,18 +126,18 @@ class ParttoolTarget():
esptool_args = [sys.executable, ESPTOOL_PY] + self.esptool_args esptool_args = [sys.executable, ESPTOOL_PY] + self.esptool_args
if self.port: if self.port:
esptool_args += ["--port", self.port] esptool_args += ['--port', self.port]
if self.baud: if self.baud:
esptool_args += ["--baud", str(self.baud)] esptool_args += ['--baud', str(self.baud)]
esptool_args += args esptool_args += args
print("Running %s..." % (" ".join(esptool_args))) print('Running %s...' % (' '.join(esptool_args)))
try: try:
subprocess.check_call(esptool_args, stdout=out, stderr=subprocess.STDOUT) subprocess.check_call(esptool_args, stdout=out, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print("An exception: **", str(e), "** occurred in _call_esptool.", file=out) print('An exception: **', str(e), '** occurred in _call_esptool.', file=out)
raise e raise e
def get_partition_info(self, partition_id): def get_partition_info(self, partition_id):
@ -149,37 +150,37 @@ class ParttoolTarget():
if not partition_id.part_list: if not partition_id.part_list:
partition = partition[0] partition = partition[0]
else: # default boot partition else: # default boot partition
search = ["factory"] + ["ota_{}".format(d) for d in range(16)] search = ['factory'] + ['ota_{}'.format(d) for d in range(16)]
for subtype in search: for subtype in search:
partition = next(self.partition_table.find_by_type("app", subtype), None) partition = next(self.partition_table.find_by_type('app', subtype), None)
if partition: if partition:
break break
if not partition: if not partition:
raise Exception("Partition does not exist") raise Exception('Partition does not exist')
return partition return partition
def erase_partition(self, partition_id): def erase_partition(self, partition_id):
partition = self.get_partition_info(partition_id) partition = self.get_partition_info(partition_id)
self._call_esptool(["erase_region", str(partition.offset), str(partition.size)] + self.esptool_erase_args) self._call_esptool(['erase_region', str(partition.offset), str(partition.size)] + self.esptool_erase_args)
def read_partition(self, partition_id, output): def read_partition(self, partition_id, output):
partition = self.get_partition_info(partition_id) partition = self.get_partition_info(partition_id)
self._call_esptool(["read_flash", str(partition.offset), str(partition.size), output] + self.esptool_read_args) self._call_esptool(['read_flash', str(partition.offset), str(partition.size), output] + self.esptool_read_args)
def write_partition(self, partition_id, input): def write_partition(self, partition_id, input):
self.erase_partition(partition_id) self.erase_partition(partition_id)
partition = self.get_partition_info(partition_id) partition = self.get_partition_info(partition_id)
with open(input, "rb") as input_file: with open(input, 'rb') as input_file:
content_len = len(input_file.read()) content_len = len(input_file.read())
if content_len > partition.size: if content_len > partition.size:
raise Exception("Input file size exceeds partition size") raise Exception('Input file size exceeds partition size')
self._call_esptool(["write_flash", str(partition.offset), input] + self.esptool_write_args) self._call_esptool(['write_flash', str(partition.offset), input] + self.esptool_write_args)
def _write_partition(target, partition_id, input): def _write_partition(target, partition_id, input):
@ -214,41 +215,41 @@ def _get_partition_info(target, partition_id, info):
try: try:
for p in partitions: for p in partitions:
info_dict = { info_dict = {
"name": '{}'.format(p.name), 'name': '{}'.format(p.name),
"type": '{}'.format(p.type), 'type': '{}'.format(p.type),
"subtype": '{}'.format(p.subtype), 'subtype': '{}'.format(p.subtype),
"offset": '0x{:x}'.format(p.offset), 'offset': '0x{:x}'.format(p.offset),
"size": '0x{:x}'.format(p.size), 'size': '0x{:x}'.format(p.size),
"encrypted": '{}'.format(p.encrypted) 'encrypted': '{}'.format(p.encrypted)
} }
for i in info: for i in info:
infos += [info_dict[i]] infos += [info_dict[i]]
except KeyError: except KeyError:
raise RuntimeError("Request for unknown partition info {}".format(i)) raise RuntimeError('Request for unknown partition info {}'.format(i))
print(" ".join(infos)) print(' '.join(infos))
def main(): def main():
global quiet global quiet
parser = argparse.ArgumentParser("ESP-IDF Partitions Tool") parser = argparse.ArgumentParser('ESP-IDF Partitions Tool')
parser.add_argument("--quiet", "-q", help="suppress stderr messages", action="store_true") parser.add_argument('--quiet', '-q', help='suppress stderr messages', action='store_true')
parser.add_argument("--esptool-args", help="additional main arguments for esptool", nargs="+") parser.add_argument('--esptool-args', help='additional main arguments for esptool', nargs='+')
parser.add_argument("--esptool-write-args", help="additional subcommand arguments when writing to flash", nargs="+") parser.add_argument('--esptool-write-args', help='additional subcommand arguments when writing to flash', nargs='+')
parser.add_argument("--esptool-read-args", help="additional subcommand arguments when reading flash", nargs="+") parser.add_argument('--esptool-read-args', help='additional subcommand arguments when reading flash', nargs='+')
parser.add_argument("--esptool-erase-args", help="additional subcommand arguments when erasing regions of flash", nargs="+") parser.add_argument('--esptool-erase-args', help='additional subcommand arguments when erasing regions of flash', nargs='+')
# By default the device attached to the specified port is queried for the partition table. If a partition table file # By default the device attached to the specified port is queried for the partition table. If a partition table file
# is specified, that is used instead. # is specified, that is used instead.
parser.add_argument("--port", "-p", help="port where the target device of the command is connected to; the partition table is sourced from this device \ parser.add_argument('--port', '-p', help='port where the target device of the command is connected to; the partition table is sourced from this device \
when the partition table file is not defined") when the partition table file is not defined')
parser.add_argument("--baud", "-b", help="baudrate to use", type=int) parser.add_argument('--baud', '-b', help='baudrate to use', type=int)
parser.add_argument("--partition-table-offset", "-o", help="offset to read the partition table from", type=str) parser.add_argument('--partition-table-offset', '-o', help='offset to read the partition table from', type=str)
parser.add_argument("--partition-table-file", "-f", help="file (CSV/binary) to read the partition table from; \ parser.add_argument('--partition-table-file', '-f', help='file (CSV/binary) to read the partition table from; \
overrides device attached to specified port as the partition table source when defined") overrides device attached to specified port as the partition table source when defined')
partition_selection_parser = argparse.ArgumentParser(add_help=False) partition_selection_parser = argparse.ArgumentParser(add_help=False)
@ -256,30 +257,30 @@ def main():
# partition name or the first partition that matches the specified type/subtype # partition name or the first partition that matches the specified type/subtype
partition_selection_args = partition_selection_parser.add_mutually_exclusive_group() partition_selection_args = partition_selection_parser.add_mutually_exclusive_group()
partition_selection_args.add_argument("--partition-name", "-n", help="name of the partition") partition_selection_args.add_argument('--partition-name', '-n', help='name of the partition')
partition_selection_args.add_argument("--partition-type", "-t", help="type of the partition") partition_selection_args.add_argument('--partition-type', '-t', help='type of the partition')
partition_selection_args.add_argument('--partition-boot-default', "-d", help='select the default boot partition \ partition_selection_args.add_argument('--partition-boot-default', '-d', help='select the default boot partition \
using the same fallback logic as the IDF bootloader', action="store_true") using the same fallback logic as the IDF bootloader', action='store_true')
partition_selection_parser.add_argument("--partition-subtype", "-s", help="subtype of the partition") partition_selection_parser.add_argument('--partition-subtype', '-s', help='subtype of the partition')
subparsers = parser.add_subparsers(dest="operation", help="run parttool -h for additional help") subparsers = parser.add_subparsers(dest='operation', help='run parttool -h for additional help')
# Specify the supported operations # Specify the supported operations
read_part_subparser = subparsers.add_parser("read_partition", help="read partition from device and dump contents into a file", read_part_subparser = subparsers.add_parser('read_partition', help='read partition from device and dump contents into a file',
parents=[partition_selection_parser]) parents=[partition_selection_parser])
read_part_subparser.add_argument("--output", help="file to dump the read partition contents to") read_part_subparser.add_argument('--output', help='file to dump the read partition contents to')
write_part_subparser = subparsers.add_parser("write_partition", help="write contents of a binary file to partition on device", write_part_subparser = subparsers.add_parser('write_partition', help='write contents of a binary file to partition on device',
parents=[partition_selection_parser]) parents=[partition_selection_parser])
write_part_subparser.add_argument("--input", help="file whose contents are to be written to the partition offset") write_part_subparser.add_argument('--input', help='file whose contents are to be written to the partition offset')
subparsers.add_parser("erase_partition", help="erase the contents of a partition on the device", parents=[partition_selection_parser]) subparsers.add_parser('erase_partition', help='erase the contents of a partition on the device', parents=[partition_selection_parser])
print_partition_info_subparser = subparsers.add_parser("get_partition_info", help="get partition information", parents=[partition_selection_parser]) print_partition_info_subparser = subparsers.add_parser('get_partition_info', help='get partition information', parents=[partition_selection_parser])
print_partition_info_subparser.add_argument("--info", help="type of partition information to get", print_partition_info_subparser.add_argument('--info', help='type of partition information to get',
choices=["name", "type", "subtype", "offset", "size", "encrypted"], default=["offset", "size"], nargs="+") choices=['name', 'type', 'subtype', 'offset', 'size', 'encrypted'], default=['offset', 'size'], nargs='+')
print_partition_info_subparser.add_argument('--part_list', help="Get a list of partitions suitable for a given type", action='store_true') print_partition_info_subparser.add_argument('--part_list', help='Get a list of partitions suitable for a given type', action='store_true')
args = parser.parse_args() args = parser.parse_args()
quiet = args.quiet quiet = args.quiet
@ -295,40 +296,40 @@ def main():
partition_id = PartitionName(args.partition_name) partition_id = PartitionName(args.partition_name)
elif args.partition_type: elif args.partition_type:
if not args.partition_subtype: if not args.partition_subtype:
raise RuntimeError("--partition-subtype should be defined when --partition-type is defined") raise RuntimeError('--partition-subtype should be defined when --partition-type is defined')
partition_id = PartitionType(args.partition_type, args.partition_subtype, getattr(args, 'part_list', None)) partition_id = PartitionType(args.partition_type, args.partition_subtype, getattr(args, 'part_list', None))
elif args.partition_boot_default: elif args.partition_boot_default:
partition_id = PARTITION_BOOT_DEFAULT partition_id = PARTITION_BOOT_DEFAULT
else: else:
raise RuntimeError("Partition to operate on should be defined using --partition-name OR \ raise RuntimeError('Partition to operate on should be defined using --partition-name OR \
partition-type,--partition-subtype OR partition-boot-default") partition-type,--partition-subtype OR partition-boot-default')
# Prepare the device to perform operation on # Prepare the device to perform operation on
target_args = {} target_args = {}
if args.port: if args.port:
target_args["port"] = args.port target_args['port'] = args.port
if args.baud: if args.baud:
target_args["baud"] = args.baud target_args['baud'] = args.baud
if args.partition_table_file: if args.partition_table_file:
target_args["partition_table_file"] = args.partition_table_file target_args['partition_table_file'] = args.partition_table_file
if args.partition_table_offset: if args.partition_table_offset:
target_args["partition_table_offset"] = int(args.partition_table_offset, 0) target_args['partition_table_offset'] = int(args.partition_table_offset, 0)
if args.esptool_args: if args.esptool_args:
target_args["esptool_args"] = args.esptool_args target_args['esptool_args'] = args.esptool_args
if args.esptool_write_args: if args.esptool_write_args:
target_args["esptool_write_args"] = args.esptool_write_args target_args['esptool_write_args'] = args.esptool_write_args
if args.esptool_read_args: if args.esptool_read_args:
target_args["esptool_read_args"] = args.esptool_read_args target_args['esptool_read_args'] = args.esptool_read_args
if args.esptool_erase_args: if args.esptool_erase_args:
target_args["esptool_erase_args"] = args.esptool_erase_args target_args['esptool_erase_args'] = args.esptool_erase_args
target = ParttoolTarget(**target_args) target = ParttoolTarget(**target_args)
@ -336,9 +337,9 @@ def main():
common_args = {'target':target, 'partition_id':partition_id} common_args = {'target':target, 'partition_id':partition_id}
parttool_ops = { parttool_ops = {
'erase_partition':(_erase_partition, []), 'erase_partition':(_erase_partition, []),
'read_partition':(_read_partition, ["output"]), 'read_partition':(_read_partition, ['output']),
'write_partition':(_write_partition, ["input"]), 'write_partition':(_write_partition, ['input']),
'get_partition_info':(_get_partition_info, ["info"]) 'get_partition_info':(_get_partition_info, ['info'])
} }
(op, op_args) = parttool_ops[args.operation] (op, op_args) = parttool_ops[args.operation]

View File

@ -1,18 +1,19 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import print_function, division from __future__ import division, print_function
import unittest
import struct
import csv import csv
import sys
import subprocess
import tempfile
import os
import io import io
import os
import struct
import subprocess
import sys
import tempfile
import unittest
try: try:
import gen_esp32part import gen_esp32part
except ImportError: except ImportError:
sys.path.append("..") sys.path.append('..')
import gen_esp32part import gen_esp32part
SIMPLE_CSV = """ SIMPLE_CSV = """
@ -20,40 +21,40 @@ SIMPLE_CSV = """
factory,0,2,65536,1048576, factory,0,2,65536,1048576,
""" """
LONGER_BINARY_TABLE = b"" LONGER_BINARY_TABLE = b''
# type 0x00, subtype 0x00, # type 0x00, subtype 0x00,
# offset 64KB, size 1MB # offset 64KB, size 1MB
LONGER_BINARY_TABLE += b"\xAA\x50\x00\x00" + \ LONGER_BINARY_TABLE += b'\xAA\x50\x00\x00' + \
b"\x00\x00\x01\x00" + \ b'\x00\x00\x01\x00' + \
b"\x00\x00\x10\x00" + \ b'\x00\x00\x10\x00' + \
b"factory\0" + (b"\0" * 8) + \ b'factory\0' + (b'\0' * 8) + \
b"\x00\x00\x00\x00" b'\x00\x00\x00\x00'
# type 0x01, subtype 0x20, # type 0x01, subtype 0x20,
# offset 0x110000, size 128KB # offset 0x110000, size 128KB
LONGER_BINARY_TABLE += b"\xAA\x50\x01\x20" + \ LONGER_BINARY_TABLE += b'\xAA\x50\x01\x20' + \
b"\x00\x00\x11\x00" + \ b'\x00\x00\x11\x00' + \
b"\x00\x02\x00\x00" + \ b'\x00\x02\x00\x00' + \
b"data" + (b"\0" * 12) + \ b'data' + (b'\0' * 12) + \
b"\x00\x00\x00\x00" b'\x00\x00\x00\x00'
# type 0x10, subtype 0x00, # type 0x10, subtype 0x00,
# offset 0x150000, size 1MB # offset 0x150000, size 1MB
LONGER_BINARY_TABLE += b"\xAA\x50\x10\x00" + \ LONGER_BINARY_TABLE += b'\xAA\x50\x10\x00' + \
b"\x00\x00\x15\x00" + \ b'\x00\x00\x15\x00' + \
b"\x00\x10\x00\x00" + \ b'\x00\x10\x00\x00' + \
b"second" + (b"\0" * 10) + \ b'second' + (b'\0' * 10) + \
b"\x00\x00\x00\x00" b'\x00\x00\x00\x00'
# MD5 checksum # MD5 checksum
LONGER_BINARY_TABLE += b"\xEB\xEB" + b"\xFF" * 14 LONGER_BINARY_TABLE += b'\xEB\xEB' + b'\xFF' * 14
LONGER_BINARY_TABLE += b'\xf9\xbd\x06\x1b\x45\x68\x6f\x86\x57\x1a\x2c\xd5\x2a\x1d\xa6\x5b' LONGER_BINARY_TABLE += b'\xf9\xbd\x06\x1b\x45\x68\x6f\x86\x57\x1a\x2c\xd5\x2a\x1d\xa6\x5b'
# empty partition # empty partition
LONGER_BINARY_TABLE += b"\xFF" * 32 LONGER_BINARY_TABLE += b'\xFF' * 32
def _strip_trailing_ffs(binary_table): def _strip_trailing_ffs(binary_table):
""" """
Strip all FFs down to the last 32 bytes (terminating entry) Strip all FFs down to the last 32 bytes (terminating entry)
""" """
while binary_table.endswith(b"\xFF" * 64): while binary_table.endswith(b'\xFF' * 64):
binary_table = binary_table[0:len(binary_table) - 32] binary_table = binary_table[0:len(binary_table) - 32]
return binary_table return binary_table
@ -75,7 +76,7 @@ class CSVParserTests(Py23TestCase):
def test_simple_partition(self): def test_simple_partition(self):
table = gen_esp32part.PartitionTable.from_csv(SIMPLE_CSV) table = gen_esp32part.PartitionTable.from_csv(SIMPLE_CSV)
self.assertEqual(len(table), 1) self.assertEqual(len(table), 1)
self.assertEqual(table[0].name, "factory") self.assertEqual(table[0].name, 'factory')
self.assertEqual(table[0].type, 0) self.assertEqual(table[0].type, 0)
self.assertEqual(table[0].subtype, 2) self.assertEqual(table[0].subtype, 2)
self.assertEqual(table[0].offset, 65536) self.assertEqual(table[0].offset, 65536)
@ -86,7 +87,7 @@ class CSVParserTests(Py23TestCase):
# Name,Type, SubType,Offset,Size # Name,Type, SubType,Offset,Size
ihavenotype, ihavenotype,
""" """
with self.assertRaisesRegex(gen_esp32part.InputError, "type"): with self.assertRaisesRegex(gen_esp32part.InputError, 'type'):
gen_esp32part.PartitionTable.from_csv(csv) gen_esp32part.PartitionTable.from_csv(csv)
def test_type_subtype_names(self): def test_type_subtype_names(self):
@ -115,15 +116,15 @@ myota_status, data, ota,, 0x100000
nomagic = gen_esp32part.PartitionTable.from_csv(csv_nomagicnumbers) nomagic = gen_esp32part.PartitionTable.from_csv(csv_nomagicnumbers)
nomagic.verify() nomagic.verify()
self.assertEqual(nomagic["myapp"].type, 0) self.assertEqual(nomagic['myapp'].type, 0)
self.assertEqual(nomagic["myapp"].subtype, 0) self.assertEqual(nomagic['myapp'].subtype, 0)
self.assertEqual(nomagic["myapp"], magic["myapp"]) self.assertEqual(nomagic['myapp'], magic['myapp'])
self.assertEqual(nomagic["myota_0"].type, 0) self.assertEqual(nomagic['myota_0'].type, 0)
self.assertEqual(nomagic["myota_0"].subtype, 0x10) self.assertEqual(nomagic['myota_0'].subtype, 0x10)
self.assertEqual(nomagic["myota_0"], magic["myota_0"]) self.assertEqual(nomagic['myota_0'], magic['myota_0'])
self.assertEqual(nomagic["myota_15"], magic["myota_15"]) self.assertEqual(nomagic['myota_15'], magic['myota_15'])
self.assertEqual(nomagic["mytest"], magic["mytest"]) self.assertEqual(nomagic['mytest'], magic['mytest'])
self.assertEqual(nomagic["myota_status"], magic["myota_status"]) self.assertEqual(nomagic['myota_status'], magic['myota_status'])
# self.assertEqual(nomagic.to_binary(), magic.to_binary()) # self.assertEqual(nomagic.to_binary(), magic.to_binary())
@ -176,7 +177,7 @@ second, data, 0x15, , 1M
first, app, factory, 0x100000, 2M first, app, factory, 0x100000, 2M
second, app, ota_0, 0x200000, 1M second, app, ota_0, 0x200000, 1M
""" """
with self.assertRaisesRegex(gen_esp32part.InputError, "overlap"): with self.assertRaisesRegex(gen_esp32part.InputError, 'overlap'):
t = gen_esp32part.PartitionTable.from_csv(csv) t = gen_esp32part.PartitionTable.from_csv(csv)
t.verify() t.verify()
@ -185,7 +186,7 @@ second, app, ota_0, 0x200000, 1M
first, app, factory, 0x100000, 1M first, app, factory, 0x100000, 1M
first, app, ota_0, 0x200000, 1M first, app, ota_0, 0x200000, 1M
""" """
with self.assertRaisesRegex(gen_esp32part.InputError, "Partition names must be unique"): with self.assertRaisesRegex(gen_esp32part.InputError, 'Partition names must be unique'):
t = gen_esp32part.PartitionTable.from_csv(csv) t = gen_esp32part.PartitionTable.from_csv(csv)
t.verify() t.verify()
@ -200,10 +201,10 @@ first, 0x30, 0xEE, 0x100400, 0x300000
self.assertEqual(len(tb), 64 + 32) self.assertEqual(len(tb), 64 + 32)
self.assertEqual(b'\xAA\x50', tb[0:2]) # magic self.assertEqual(b'\xAA\x50', tb[0:2]) # magic
self.assertEqual(b'\x30\xee', tb[2:4]) # type, subtype self.assertEqual(b'\x30\xee', tb[2:4]) # type, subtype
eo, es = struct.unpack("<LL", tb[4:12]) eo, es = struct.unpack('<LL', tb[4:12])
self.assertEqual(eo, 0x100400) # offset self.assertEqual(eo, 0x100400) # offset
self.assertEqual(es, 0x300000) # size self.assertEqual(es, 0x300000) # size
self.assertEqual(b"\xEB\xEB" + b"\xFF" * 14, tb[32:48]) self.assertEqual(b'\xEB\xEB' + b'\xFF' * 14, tb[32:48])
self.assertEqual(b'\x43\x03\x3f\x33\x40\x87\x57\x51\x69\x83\x9b\x40\x61\xb1\x27\x26', tb[48:64]) self.assertEqual(b'\x43\x03\x3f\x33\x40\x87\x57\x51\x69\x83\x9b\x40\x61\xb1\x27\x26', tb[48:64])
def test_multiple_entries(self): def test_multiple_entries(self):
@ -233,12 +234,12 @@ class BinaryParserTests(Py23TestCase):
def test_parse_one_entry(self): def test_parse_one_entry(self):
# type 0x30, subtype 0xee, # type 0x30, subtype 0xee,
# offset 1MB, size 2MB # offset 1MB, size 2MB
entry = b"\xAA\x50\x30\xee" + \ entry = b'\xAA\x50\x30\xee' + \
b"\x00\x00\x10\x00" + \ b'\x00\x00\x10\x00' + \
b"\x00\x00\x20\x00" + \ b'\x00\x00\x20\x00' + \
b"0123456789abc\0\0\0" + \ b'0123456789abc\0\0\0' + \
b"\x00\x00\x00\x00" + \ b'\x00\x00\x00\x00' + \
b"\xFF" * 32 b'\xFF' * 32
# verify that parsing 32 bytes as a table # verify that parsing 32 bytes as a table
# or as a single Definition are the same thing # or as a single Definition are the same thing
t = gen_esp32part.PartitionTable.from_binary(entry) t = gen_esp32part.PartitionTable.from_binary(entry)
@ -253,7 +254,7 @@ class BinaryParserTests(Py23TestCase):
self.assertEqual(e.subtype, 0xEE) self.assertEqual(e.subtype, 0xEE)
self.assertEqual(e.offset, 0x100000) self.assertEqual(e.offset, 0x100000)
self.assertEqual(e.size, 0x200000) self.assertEqual(e.size, 0x200000)
self.assertEqual(e.name, "0123456789abc") self.assertEqual(e.name, '0123456789abc')
def test_multiple_entries(self): def test_multiple_entries(self):
t = gen_esp32part.PartitionTable.from_binary(LONGER_BINARY_TABLE) t = gen_esp32part.PartitionTable.from_binary(LONGER_BINARY_TABLE)
@ -261,53 +262,53 @@ class BinaryParserTests(Py23TestCase):
self.assertEqual(3, len(t)) self.assertEqual(3, len(t))
self.assertEqual(t[0].type, gen_esp32part.APP_TYPE) self.assertEqual(t[0].type, gen_esp32part.APP_TYPE)
self.assertEqual(t[0].name, "factory") self.assertEqual(t[0].name, 'factory')
self.assertEqual(t[1].type, gen_esp32part.DATA_TYPE) self.assertEqual(t[1].type, gen_esp32part.DATA_TYPE)
self.assertEqual(t[1].name, "data") self.assertEqual(t[1].name, 'data')
self.assertEqual(t[2].type, 0x10) self.assertEqual(t[2].type, 0x10)
self.assertEqual(t[2].name, "second") self.assertEqual(t[2].name, 'second')
round_trip = _strip_trailing_ffs(t.to_binary()) round_trip = _strip_trailing_ffs(t.to_binary())
self.assertEqual(round_trip, LONGER_BINARY_TABLE) self.assertEqual(round_trip, LONGER_BINARY_TABLE)
def test_bad_magic(self): def test_bad_magic(self):
bad_magic = b"OHAI" + \ bad_magic = b'OHAI' + \
b"\x00\x00\x10\x00" + \ b'\x00\x00\x10\x00' + \
b"\x00\x00\x20\x00" + \ b'\x00\x00\x20\x00' + \
b"0123456789abc\0\0\0" + \ b'0123456789abc\0\0\0' + \
b"\x00\x00\x00\x00" b'\x00\x00\x00\x00'
with self.assertRaisesRegex(gen_esp32part.InputError, "Invalid magic bytes"): with self.assertRaisesRegex(gen_esp32part.InputError, 'Invalid magic bytes'):
gen_esp32part.PartitionTable.from_binary(bad_magic) gen_esp32part.PartitionTable.from_binary(bad_magic)
def test_bad_length(self): def test_bad_length(self):
bad_length = b"OHAI" + \ bad_length = b'OHAI' + \
b"\x00\x00\x10\x00" + \ b'\x00\x00\x10\x00' + \
b"\x00\x00\x20\x00" + \ b'\x00\x00\x20\x00' + \
b"0123456789" b'0123456789'
with self.assertRaisesRegex(gen_esp32part.InputError, "32 bytes"): with self.assertRaisesRegex(gen_esp32part.InputError, '32 bytes'):
gen_esp32part.PartitionTable.from_binary(bad_length) gen_esp32part.PartitionTable.from_binary(bad_length)
class CSVOutputTests(Py23TestCase): class CSVOutputTests(Py23TestCase):
def _readcsv(self, source_str): def _readcsv(self, source_str):
return list(csv.reader(source_str.split("\n"))) return list(csv.reader(source_str.split('\n')))
def test_output_simple_formatting(self): def test_output_simple_formatting(self):
table = gen_esp32part.PartitionTable.from_csv(SIMPLE_CSV) table = gen_esp32part.PartitionTable.from_csv(SIMPLE_CSV)
as_csv = table.to_csv(True) as_csv = table.to_csv(True)
c = self._readcsv(as_csv) c = self._readcsv(as_csv)
# first two lines should start with comments # first two lines should start with comments
self.assertEqual(c[0][0][0], "#") self.assertEqual(c[0][0][0], '#')
self.assertEqual(c[1][0][0], "#") self.assertEqual(c[1][0][0], '#')
row = c[2] row = c[2]
self.assertEqual(row[0], "factory") self.assertEqual(row[0], 'factory')
self.assertEqual(row[1], "0") self.assertEqual(row[1], '0')
self.assertEqual(row[2], "2") self.assertEqual(row[2], '2')
self.assertEqual(row[3], "0x10000") # reformatted as hex self.assertEqual(row[3], '0x10000') # reformatted as hex
self.assertEqual(row[4], "0x100000") # also hex self.assertEqual(row[4], '0x100000') # also hex
# round trip back to a PartitionTable and check is identical # round trip back to a PartitionTable and check is identical
roundtrip = gen_esp32part.PartitionTable.from_csv(as_csv) roundtrip = gen_esp32part.PartitionTable.from_csv(as_csv)
@ -318,14 +319,14 @@ class CSVOutputTests(Py23TestCase):
as_csv = table.to_csv(False) as_csv = table.to_csv(False)
c = self._readcsv(as_csv) c = self._readcsv(as_csv)
# first two lines should start with comments # first two lines should start with comments
self.assertEqual(c[0][0][0], "#") self.assertEqual(c[0][0][0], '#')
self.assertEqual(c[1][0][0], "#") self.assertEqual(c[1][0][0], '#')
row = c[2] row = c[2]
self.assertEqual(row[0], "factory") self.assertEqual(row[0], 'factory')
self.assertEqual(row[1], "app") self.assertEqual(row[1], 'app')
self.assertEqual(row[2], "2") self.assertEqual(row[2], '2')
self.assertEqual(row[3], "0x10000") self.assertEqual(row[3], '0x10000')
self.assertEqual(row[4], "1M") self.assertEqual(row[4], '1M')
# round trip back to a PartitionTable and check is identical # round trip back to a PartitionTable and check is identical
roundtrip = gen_esp32part.PartitionTable.from_csv(as_csv) roundtrip = gen_esp32part.PartitionTable.from_csv(as_csv)
@ -344,18 +345,18 @@ class CommandLineTests(Py23TestCase):
f.write(LONGER_BINARY_TABLE) f.write(LONGER_BINARY_TABLE)
# run gen_esp32part.py to convert binary file to CSV # run gen_esp32part.py to convert binary file to CSV
output = subprocess.check_output([sys.executable, "../gen_esp32part.py", output = subprocess.check_output([sys.executable, '../gen_esp32part.py',
binpath, csvpath], stderr=subprocess.STDOUT) binpath, csvpath], stderr=subprocess.STDOUT)
# reopen the CSV and check the generated binary is identical # reopen the CSV and check the generated binary is identical
self.assertNotIn(b"WARNING", output) self.assertNotIn(b'WARNING', output)
with open(csvpath, 'r') as f: with open(csvpath, 'r') as f:
from_csv = gen_esp32part.PartitionTable.from_csv(f.read()) from_csv = gen_esp32part.PartitionTable.from_csv(f.read())
self.assertEqual(_strip_trailing_ffs(from_csv.to_binary()), LONGER_BINARY_TABLE) self.assertEqual(_strip_trailing_ffs(from_csv.to_binary()), LONGER_BINARY_TABLE)
# run gen_esp32part.py to conver the CSV to binary again # run gen_esp32part.py to conver the CSV to binary again
output = subprocess.check_output([sys.executable, "../gen_esp32part.py", output = subprocess.check_output([sys.executable, '../gen_esp32part.py',
csvpath, binpath], stderr=subprocess.STDOUT) csvpath, binpath], stderr=subprocess.STDOUT)
self.assertNotIn(b"WARNING", output) self.assertNotIn(b'WARNING', output)
# assert that file reads back as identical # assert that file reads back as identical
with open(binpath, 'rb') as f: with open(binpath, 'rb') as f:
binary_readback = f.read() binary_readback = f.read()
@ -377,7 +378,7 @@ class VerificationTests(Py23TestCase):
# Name,Type, SubType,Offset,Size # Name,Type, SubType,Offset,Size
app,app, factory, 32K, 1M app,app, factory, 32K, 1M
""" """
with self.assertRaisesRegex(gen_esp32part.ValidationError, r"Offset.+not aligned"): with self.assertRaisesRegex(gen_esp32part.ValidationError, r'Offset.+not aligned'):
t = gen_esp32part.PartitionTable.from_csv(csv) t = gen_esp32part.PartitionTable.from_csv(csv)
t.verify() t.verify()
@ -385,16 +386,16 @@ app,app, factory, 32K, 1M
try: try:
sys.stderr = io.StringIO() # capture stderr sys.stderr = io.StringIO() # capture stderr
csv_1 = "app, 1, 2, 32K, 1M\n" csv_1 = 'app, 1, 2, 32K, 1M\n'
gen_esp32part.PartitionTable.from_csv(csv_1).verify() gen_esp32part.PartitionTable.from_csv(csv_1).verify()
self.assertIn("WARNING", sys.stderr.getvalue()) self.assertIn('WARNING', sys.stderr.getvalue())
self.assertIn("partition type", sys.stderr.getvalue()) self.assertIn('partition type', sys.stderr.getvalue())
sys.stderr = io.StringIO() sys.stderr = io.StringIO()
csv_2 = "ota_0, app, ota_1, , 1M\n" csv_2 = 'ota_0, app, ota_1, , 1M\n'
gen_esp32part.PartitionTable.from_csv(csv_2).verify() gen_esp32part.PartitionTable.from_csv(csv_2).verify()
self.assertIn("WARNING", sys.stderr.getvalue()) self.assertIn('WARNING', sys.stderr.getvalue())
self.assertIn("partition subtype", sys.stderr.getvalue()) self.assertIn('partition subtype', sys.stderr.getvalue())
finally: finally:
sys.stderr = sys.__stderr__ sys.stderr = sys.__stderr__
@ -404,13 +405,13 @@ class PartToolTests(Py23TestCase):
def _run_parttool(self, csvcontents, args): def _run_parttool(self, csvcontents, args):
csvpath = tempfile.mktemp() csvpath = tempfile.mktemp()
with open(csvpath, "w") as f: with open(csvpath, 'w') as f:
f.write(csvcontents) f.write(csvcontents)
try: try:
output = subprocess.check_output([sys.executable, "../parttool.py", "-q", "--partition-table-file", output = subprocess.check_output([sys.executable, '../parttool.py', '-q', '--partition-table-file',
csvpath, "get_partition_info"] + args, csvpath, 'get_partition_info'] + args,
stderr=subprocess.STDOUT) stderr=subprocess.STDOUT)
self.assertNotIn(b"WARNING", output) self.assertNotIn(b'WARNING', output)
return output.strip() return output.strip()
finally: finally:
os.remove(csvpath) os.remove(csvpath)
@ -431,41 +432,41 @@ nvs_key2, data, nvs_keys, 0x119000, 0x1000, encrypted
return self._run_parttool(csv, args) return self._run_parttool(csv, args)
self.assertEqual( self.assertEqual(
rpt(["--partition-type", "data", "--partition-subtype", "nvs", "--info", "offset"]), b"0x9000") rpt(['--partition-type', 'data', '--partition-subtype', 'nvs', '--info', 'offset']), b'0x9000')
self.assertEqual( self.assertEqual(
rpt(["--partition-type", "data", "--partition-subtype", "nvs", "--info", "size"]), b"0x4000") rpt(['--partition-type', 'data', '--partition-subtype', 'nvs', '--info', 'size']), b'0x4000')
self.assertEqual( self.assertEqual(
rpt(["--partition-name", "otadata", "--info", "offset"]), b"0xd000") rpt(['--partition-name', 'otadata', '--info', 'offset']), b'0xd000')
self.assertEqual( self.assertEqual(
rpt(["--partition-boot-default", "--info", "offset"]), b"0x10000") rpt(['--partition-boot-default', '--info', 'offset']), b'0x10000')
self.assertEqual( self.assertEqual(
rpt(["--partition-type", "data", "--partition-subtype", "nvs", "--info", "name", "offset", "size", "encrypted"]), rpt(['--partition-type', 'data', '--partition-subtype', 'nvs', '--info', 'name', 'offset', 'size', 'encrypted']),
b"nvs 0x9000 0x4000 False") b'nvs 0x9000 0x4000 False')
self.assertEqual( self.assertEqual(
rpt(["--partition-type", "data", "--partition-subtype", "nvs", "--info", "name", "offset", "size", "encrypted", "--part_list"]), rpt(['--partition-type', 'data', '--partition-subtype', 'nvs', '--info', 'name', 'offset', 'size', 'encrypted', '--part_list']),
b"nvs 0x9000 0x4000 False nvs1_user 0x110000 0x4000 False nvs2_user 0x114000 0x4000 False") b'nvs 0x9000 0x4000 False nvs1_user 0x110000 0x4000 False nvs2_user 0x114000 0x4000 False')
self.assertEqual( self.assertEqual(
rpt(["--partition-type", "data", "--partition-subtype", "nvs", "--info", "name", "--part_list"]), rpt(['--partition-type', 'data', '--partition-subtype', 'nvs', '--info', 'name', '--part_list']),
b"nvs nvs1_user nvs2_user") b'nvs nvs1_user nvs2_user')
self.assertEqual( self.assertEqual(
rpt(["--partition-type", "data", "--partition-subtype", "nvs_keys", "--info", "name", "--part_list"]), rpt(['--partition-type', 'data', '--partition-subtype', 'nvs_keys', '--info', 'name', '--part_list']),
b"nvs_key1 nvs_key2") b'nvs_key1 nvs_key2')
self.assertEqual( self.assertEqual(
rpt(["--partition-name", "nvs", "--info", "encrypted"]), b"False") rpt(['--partition-name', 'nvs', '--info', 'encrypted']), b'False')
self.assertEqual( self.assertEqual(
rpt(["--partition-name", "nvs1_user", "--info", "encrypted"]), b"False") rpt(['--partition-name', 'nvs1_user', '--info', 'encrypted']), b'False')
self.assertEqual( self.assertEqual(
rpt(["--partition-name", "nvs2_user", "--info", "encrypted"]), b"False") rpt(['--partition-name', 'nvs2_user', '--info', 'encrypted']), b'False')
self.assertEqual( self.assertEqual(
rpt(["--partition-name", "nvs_key1", "--info", "encrypted"]), b"True") rpt(['--partition-name', 'nvs_key1', '--info', 'encrypted']), b'True')
self.assertEqual( self.assertEqual(
rpt(["--partition-name", "nvs_key2", "--info", "encrypted"]), b"True") rpt(['--partition-name', 'nvs_key2', '--info', 'encrypted']), b'True')
self.assertEqual( self.assertEqual(
rpt(["--partition-type", "data", "--partition-subtype", "nvs_keys", "--info", "name", "encrypted", "--part_list"]), rpt(['--partition-type', 'data', '--partition-subtype', 'nvs_keys', '--info', 'name', 'encrypted', '--part_list']),
b"nvs_key1 True nvs_key2 True") b'nvs_key1 True nvs_key2 True')
self.assertEqual( self.assertEqual(
rpt(["--partition-type", "data", "--partition-subtype", "nvs", "--info", "name", "encrypted", "--part_list"]), rpt(['--partition-type', 'data', '--partition-subtype', 'nvs', '--info', 'name', 'encrypted', '--part_list']),
b"nvs False nvs1_user False nvs2_user False") b'nvs False nvs1_user False nvs2_user False')
def test_fallback(self): def test_fallback(self):
csv = """ csv = """
@ -480,14 +481,14 @@ ota_1, app, ota_1, , 1M
return self._run_parttool(csv, args) return self._run_parttool(csv, args)
self.assertEqual( self.assertEqual(
rpt(["--partition-type", "app", "--partition-subtype", "ota_1", "--info", "offset"]), b"0x130000") rpt(['--partition-type', 'app', '--partition-subtype', 'ota_1', '--info', 'offset']), b'0x130000')
self.assertEqual( self.assertEqual(
rpt(["--partition-boot-default", "--info", "offset"]), b"0x30000") # ota_0 rpt(['--partition-boot-default', '--info', 'offset']), b'0x30000') # ota_0
csv_mod = csv.replace("ota_0", "ota_2") csv_mod = csv.replace('ota_0', 'ota_2')
self.assertEqual( self.assertEqual(
self._run_parttool(csv_mod, ["--partition-boot-default", "--info", "offset"]), self._run_parttool(csv_mod, ['--partition-boot-default', '--info', 'offset']),
b"0x130000") # now default is ota_1 b'0x130000') # now default is ota_1
if __name__ == "__main__": if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -2,13 +2,15 @@
# source: constants.proto # source: constants.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pb2
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2 from google.protobuf.internal import enum_type_wrapper
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()

View File

@ -2,13 +2,15 @@
# source: sec0.proto # source: sec0.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pb2
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2 from google.protobuf.internal import enum_type_wrapper
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -16,7 +18,6 @@ _sym_db = _symbol_database.Default()
import constants_pb2 as constants__pb2 import constants_pb2 as constants__pb2
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='sec0.proto', name='sec0.proto',
package='', package='',

View File

@ -2,13 +2,15 @@
# source: sec1.proto # source: sec1.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pb2
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2 from google.protobuf.internal import enum_type_wrapper
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -16,7 +18,6 @@ _sym_db = _symbol_database.Default()
import constants_pb2 as constants__pb2 import constants_pb2 as constants__pb2
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='sec1.proto', name='sec1.proto',
package='', package='',
@ -73,7 +74,7 @@ _SESSIONCMD1 = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='client_verify_data', full_name='SessionCmd1.client_verify_data', index=0, name='client_verify_data', full_name='SessionCmd1.client_verify_data', index=0,
number=2, type=12, cpp_type=9, label=1, number=2, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),
@ -111,7 +112,7 @@ _SESSIONRESP1 = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='device_verify_data', full_name='SessionResp1.device_verify_data', index=1, name='device_verify_data', full_name='SessionResp1.device_verify_data', index=1,
number=3, type=12, cpp_type=9, label=1, number=3, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),
@ -142,7 +143,7 @@ _SESSIONCMD0 = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='client_pubkey', full_name='SessionCmd0.client_pubkey', index=0, name='client_pubkey', full_name='SessionCmd0.client_pubkey', index=0,
number=1, type=12, cpp_type=9, label=1, number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),
@ -180,14 +181,14 @@ _SESSIONRESP0 = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='device_pubkey', full_name='SessionResp0.device_pubkey', index=1, name='device_pubkey', full_name='SessionResp0.device_pubkey', index=1,
number=2, type=12, cpp_type=9, label=1, number=2, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='device_random', full_name='SessionResp0.device_random', index=2, name='device_random', full_name='SessionResp0.device_random', index=2,
number=3, type=12, cpp_type=9, label=1, number=3, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),

View File

@ -2,13 +2,15 @@
# source: session.proto # source: session.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pb2
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2 from google.protobuf.internal import enum_type_wrapper
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -17,7 +19,6 @@ _sym_db = _symbol_database.Default()
import sec0_pb2 as sec0__pb2 import sec0_pb2 as sec0__pb2
import sec1_pb2 as sec1__pb2 import sec1_pb2 as sec1__pb2
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='session.proto', name='session.proto',
package='', package='',

View File

@ -17,13 +17,14 @@
# limitations under the License. # limitations under the License.
from __future__ import division, print_function from __future__ import division, print_function
import os
import sys
import io
import math
import struct
import argparse import argparse
import ctypes import ctypes
import io
import math
import os
import struct
import sys
SPIFFS_PH_FLAG_USED_FINAL_INDEX = 0xF8 SPIFFS_PH_FLAG_USED_FINAL_INDEX = 0xF8
SPIFFS_PH_FLAG_USED_FINAL = 0xFC SPIFFS_PH_FLAG_USED_FINAL = 0xFC
@ -45,7 +46,7 @@ class SpiffsBuildConfig():
block_ix_len, meta_len, obj_name_len, obj_id_len, block_ix_len, meta_len, obj_name_len, obj_id_len,
span_ix_len, packed, aligned, endianness, use_magic, use_magic_len): span_ix_len, packed, aligned, endianness, use_magic, use_magic_len):
if block_size % page_size != 0: if block_size % page_size != 0:
raise RuntimeError("block size should be a multiple of page size") raise RuntimeError('block size should be a multiple of page size')
self.page_size = page_size self.page_size = page_size
self.block_size = block_size self.block_size = block_size
@ -88,15 +89,15 @@ class SpiffsFullError(RuntimeError):
class SpiffsPage(): class SpiffsPage():
_endianness_dict = { _endianness_dict = {
"little": "<", 'little': '<',
"big": ">" 'big': '>'
} }
_len_dict = { _len_dict = {
1: "B", 1: 'B',
2: "H", 2: 'H',
4: "I", 4: 'I',
8: "Q" 8: 'Q'
} }
_type_dict = { _type_dict = {
@ -137,7 +138,7 @@ class SpiffsObjLuPage(SpiffsPage):
def to_binary(self): def to_binary(self):
global test global test
img = b"" img = b''
for (obj_id, page_type) in self.obj_ids: for (obj_id, page_type) in self.obj_ids:
if page_type == SpiffsObjIndexPage: if page_type == SpiffsObjIndexPage:
@ -147,7 +148,7 @@ class SpiffsObjLuPage(SpiffsPage):
assert(len(img) <= self.build_config.page_size) assert(len(img) <= self.build_config.page_size)
img += b"\xFF" * (self.build_config.page_size - len(img)) img += b'\xFF' * (self.build_config.page_size - len(img))
return img return img
@ -205,7 +206,7 @@ class SpiffsObjIndexPage(SpiffsPage):
SPIFFS_PH_FLAG_USED_FINAL_INDEX) SPIFFS_PH_FLAG_USED_FINAL_INDEX)
# Add padding before the object index page specific information # Add padding before the object index page specific information
img += b"\xFF" * self.build_config.OBJ_DATA_PAGE_HEADER_LEN_ALIGNED_PAD img += b'\xFF' * self.build_config.OBJ_DATA_PAGE_HEADER_LEN_ALIGNED_PAD
# If this is the first object index page for the object, add filname, type # If this is the first object index page for the object, add filname, type
# and size information # and size information
@ -216,7 +217,7 @@ class SpiffsObjIndexPage(SpiffsPage):
self.size, self.size,
SPIFFS_TYPE_FILE) SPIFFS_TYPE_FILE)
img += self.name.encode() + (b"\x00" * ((self.build_config.obj_name_len - len(self.name)) + self.build_config.meta_len)) img += self.name.encode() + (b'\x00' * ((self.build_config.obj_name_len - len(self.name)) + self.build_config.meta_len))
# Finally, add the page index of daa pages # Finally, add the page index of daa pages
for page in self.pages: for page in self.pages:
@ -226,7 +227,7 @@ class SpiffsObjIndexPage(SpiffsPage):
assert(len(img) <= self.build_config.page_size) assert(len(img) <= self.build_config.page_size)
img += b"\xFF" * (self.build_config.page_size - len(img)) img += b'\xFF' * (self.build_config.page_size - len(img))
return img return img
@ -252,7 +253,7 @@ class SpiffsObjDataPage(SpiffsPage):
assert(len(img) <= self.build_config.page_size) assert(len(img) <= self.build_config.page_size)
img += b"\xFF" * (self.build_config.page_size - len(img)) img += b'\xFF' * (self.build_config.page_size - len(img))
return img return img
@ -296,7 +297,7 @@ class SpiffsBlock():
except AttributeError: # no next lookup page except AttributeError: # no next lookup page
# Since the amount of lookup pages is pre-computed at every block instance, # Since the amount of lookup pages is pre-computed at every block instance,
# this should never occur # this should never occur
raise RuntimeError("invalid attempt to add page to a block when there is no more space in lookup") raise RuntimeError('invalid attempt to add page to a block when there is no more space in lookup')
self.pages.append(page) self.pages.append(page)
@ -335,7 +336,7 @@ class SpiffsBlock():
return self.remaining_pages <= 0 return self.remaining_pages <= 0
def to_binary(self, blocks_lim): def to_binary(self, blocks_lim):
img = b"" img = b''
if self.build_config.use_magic: if self.build_config.use_magic:
for (idx, page) in enumerate(self.pages): for (idx, page) in enumerate(self.pages):
@ -348,14 +349,14 @@ class SpiffsBlock():
assert(len(img) <= self.build_config.block_size) assert(len(img) <= self.build_config.block_size)
img += b"\xFF" * (self.build_config.block_size - len(img)) img += b'\xFF' * (self.build_config.block_size - len(img))
return img return img
class SpiffsFS(): class SpiffsFS():
def __init__(self, img_size, build_config): def __init__(self, img_size, build_config):
if img_size % build_config.block_size != 0: if img_size % build_config.block_size != 0:
raise RuntimeError("image size should be a multiple of block size") raise RuntimeError('image size should be a multiple of block size')
self.img_size = img_size self.img_size = img_size
self.build_config = build_config self.build_config = build_config
@ -367,7 +368,7 @@ class SpiffsFS():
def _create_block(self): def _create_block(self):
if self.is_full(): if self.is_full():
raise SpiffsFullError("the image size has been exceeded") raise SpiffsFullError('the image size has been exceeded')
block = SpiffsBlock(len(self.blocks), self.blocks_lim, self.build_config) block = SpiffsBlock(len(self.blocks), self.blocks_lim, self.build_config)
self.blocks.append(block) self.blocks.append(block)
@ -385,7 +386,7 @@ class SpiffsFS():
name = img_path name = img_path
with open(file_path, "rb") as obj: with open(file_path, 'rb') as obj:
contents = obj.read() contents = obj.read()
stream = io.BytesIO(contents) stream = io.BytesIO(contents)
@ -434,7 +435,7 @@ class SpiffsFS():
self.cur_obj_id += 1 self.cur_obj_id += 1
def to_binary(self): def to_binary(self):
img = b"" img = b''
for block in self.blocks: for block in self.blocks:
img += block.to_binary(self.blocks_lim) img += block.to_binary(self.blocks_lim)
bix = len(self.blocks) bix = len(self.blocks)
@ -447,78 +448,78 @@ class SpiffsFS():
bix += 1 bix += 1
else: else:
# Just fill remaining spaces FF's # Just fill remaining spaces FF's
img += "\xFF" * (self.img_size - len(img)) img += '\xFF' * (self.img_size - len(img))
return img return img
def main(): def main():
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
print("WARNING: Support for Python 2 is deprecated and will be removed in future versions.", file=sys.stderr) print('WARNING: Support for Python 2 is deprecated and will be removed in future versions.', file=sys.stderr)
elif sys.version_info[0] == 3 and sys.version_info[1] < 6: elif sys.version_info[0] == 3 and sys.version_info[1] < 6:
print("WARNING: Python 3 versions older than 3.6 are not supported.", file=sys.stderr) print('WARNING: Python 3 versions older than 3.6 are not supported.', file=sys.stderr)
parser = argparse.ArgumentParser(description="SPIFFS Image Generator", parser = argparse.ArgumentParser(description='SPIFFS Image Generator',
formatter_class=argparse.ArgumentDefaultsHelpFormatter) formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("image_size", parser.add_argument('image_size',
help="Size of the created image") help='Size of the created image')
parser.add_argument("base_dir", parser.add_argument('base_dir',
help="Path to directory from which the image will be created") help='Path to directory from which the image will be created')
parser.add_argument("output_file", parser.add_argument('output_file',
help="Created image output file path") help='Created image output file path')
parser.add_argument("--page-size", parser.add_argument('--page-size',
help="Logical page size. Set to value same as CONFIG_SPIFFS_PAGE_SIZE.", help='Logical page size. Set to value same as CONFIG_SPIFFS_PAGE_SIZE.',
type=int, type=int,
default=256) default=256)
parser.add_argument("--block-size", parser.add_argument('--block-size',
help="Logical block size. Set to the same value as the flash chip's sector size (g_rom_flashchip.sector_size).", help="Logical block size. Set to the same value as the flash chip's sector size (g_rom_flashchip.sector_size).",
type=int, type=int,
default=4096) default=4096)
parser.add_argument("--obj-name-len", parser.add_argument('--obj-name-len',
help="File full path maximum length. Set to value same as CONFIG_SPIFFS_OBJ_NAME_LEN.", help='File full path maximum length. Set to value same as CONFIG_SPIFFS_OBJ_NAME_LEN.',
type=int, type=int,
default=32) default=32)
parser.add_argument("--meta-len", parser.add_argument('--meta-len',
help="File metadata length. Set to value same as CONFIG_SPIFFS_META_LENGTH.", help='File metadata length. Set to value same as CONFIG_SPIFFS_META_LENGTH.',
type=int, type=int,
default=4) default=4)
parser.add_argument("--use-magic", parser.add_argument('--use-magic',
help="Use magic number to create an identifiable SPIFFS image. Specify if CONFIG_SPIFFS_USE_MAGIC.", help='Use magic number to create an identifiable SPIFFS image. Specify if CONFIG_SPIFFS_USE_MAGIC.',
action="store_true", action='store_true',
default=True) default=True)
parser.add_argument("--follow-symlinks", parser.add_argument('--follow-symlinks',
help="Take into account symbolic links during partition image creation.", help='Take into account symbolic links during partition image creation.',
action="store_true", action='store_true',
default=False) default=False)
parser.add_argument("--use-magic-len", parser.add_argument('--use-magic-len',
help="Use position in memory to create different magic numbers for each block. Specify if CONFIG_SPIFFS_USE_MAGIC_LENGTH.", help='Use position in memory to create different magic numbers for each block. Specify if CONFIG_SPIFFS_USE_MAGIC_LENGTH.',
action="store_true", action='store_true',
default=True) default=True)
parser.add_argument("--big-endian", parser.add_argument('--big-endian',
help="Specify if the target architecture is big-endian. If not specified, little-endian is assumed.", help='Specify if the target architecture is big-endian. If not specified, little-endian is assumed.',
action="store_true", action='store_true',
default=False) default=False)
args = parser.parse_args() args = parser.parse_args()
if not os.path.exists(args.base_dir): if not os.path.exists(args.base_dir):
raise RuntimeError("given base directory %s does not exist" % args.base_dir) raise RuntimeError('given base directory %s does not exist' % args.base_dir)
with open(args.output_file, "wb") as image_file: with open(args.output_file, 'wb') as image_file:
image_size = int(args.image_size, 0) image_size = int(args.image_size, 0)
spiffs_build_default = SpiffsBuildConfig(args.page_size, SPIFFS_PAGE_IX_LEN, spiffs_build_default = SpiffsBuildConfig(args.page_size, SPIFFS_PAGE_IX_LEN,
args.block_size, SPIFFS_BLOCK_IX_LEN, args.meta_len, args.block_size, SPIFFS_BLOCK_IX_LEN, args.meta_len,
args.obj_name_len, SPIFFS_OBJ_ID_LEN, SPIFFS_SPAN_IX_LEN, args.obj_name_len, SPIFFS_OBJ_ID_LEN, SPIFFS_SPAN_IX_LEN,
True, True, "big" if args.big_endian else "little", True, True, 'big' if args.big_endian else 'little',
args.use_magic, args.use_magic_len) args.use_magic, args.use_magic_len)
spiffs = SpiffsFS(image_size, spiffs_build_default) spiffs = SpiffsFS(image_size, spiffs_build_default)
@ -526,12 +527,12 @@ def main():
for root, dirs, files in os.walk(args.base_dir, followlinks=args.follow_symlinks): for root, dirs, files in os.walk(args.base_dir, followlinks=args.follow_symlinks):
for f in files: for f in files:
full_path = os.path.join(root, f) full_path = os.path.join(root, f)
spiffs.create_file("/" + os.path.relpath(full_path, args.base_dir).replace("\\", "/"), full_path) spiffs.create_file('/' + os.path.relpath(full_path, args.base_dir).replace('\\', '/'), full_path)
image = spiffs.to_binary() image = spiffs.to_binary()
image_file.write(image) image_file.write(image)
if __name__ == "__main__": if __name__ == '__main__':
main() main()

View File

@ -6,58 +6,59 @@
# Distributed under the terms of Apache License v2.0 found in the top-level LICENSE file. # Distributed under the terms of Apache License v2.0 found in the top-level LICENSE file.
from __future__ import print_function from __future__ import print_function
from optparse import OptionParser
import sys import sys
from optparse import OptionParser
BASE_ADDR = 0x50000000 BASE_ADDR = 0x50000000
def gen_ld_h_from_sym(f_sym, f_ld, f_h): def gen_ld_h_from_sym(f_sym, f_ld, f_h):
f_ld.write("/* Variable definitions for ESP32ULP linker\n") f_ld.write('/* Variable definitions for ESP32ULP linker\n')
f_ld.write(" * This file is generated automatically by esp32ulp_mapgen.py utility.\n") f_ld.write(' * This file is generated automatically by esp32ulp_mapgen.py utility.\n')
f_ld.write(" */\n\n") f_ld.write(' */\n\n')
f_h.write("// Variable definitions for ESP32ULP\n") f_h.write('// Variable definitions for ESP32ULP\n')
f_h.write("// This file is generated automatically by esp32ulp_mapgen.py utility\n\n") f_h.write('// This file is generated automatically by esp32ulp_mapgen.py utility\n\n')
f_h.write("#pragma once\n\n") f_h.write('#pragma once\n\n')
for line in f_sym: for line in f_sym:
name, _, addr_str = line.split(" ", 2) name, _, addr_str = line.split(' ', 2)
addr = int(addr_str, 16) + BASE_ADDR addr = int(addr_str, 16) + BASE_ADDR
f_h.write("extern uint32_t ulp_{0};\n".format(name)) f_h.write('extern uint32_t ulp_{0};\n'.format(name))
f_ld.write("PROVIDE ( ulp_{0} = 0x{1:08x} );\n".format(name, addr)) f_ld.write('PROVIDE ( ulp_{0} = 0x{1:08x} );\n'.format(name, addr))
def gen_ld_h_from_sym_riscv(f_sym, f_ld, f_h): def gen_ld_h_from_sym_riscv(f_sym, f_ld, f_h):
f_ld.write("/* Variable definitions for ESP32ULP linker\n") f_ld.write('/* Variable definitions for ESP32ULP linker\n')
f_ld.write(" * This file is generated automatically by esp32ulp_mapgen.py utility.\n") f_ld.write(' * This file is generated automatically by esp32ulp_mapgen.py utility.\n')
f_ld.write(" */\n\n") f_ld.write(' */\n\n')
f_h.write("// Variable definitions for ESP32ULP\n") f_h.write('// Variable definitions for ESP32ULP\n')
f_h.write("// This file is generated automatically by esp32ulp_mapgen.py utility\n\n") f_h.write('// This file is generated automatically by esp32ulp_mapgen.py utility\n\n')
f_h.write("#pragma once\n\n") f_h.write('#pragma once\n\n')
for line in f_sym: for line in f_sym:
addr_str, _, name = line.split() addr_str, _, name = line.split()
addr = int(addr_str, 16) + BASE_ADDR addr = int(addr_str, 16) + BASE_ADDR
f_h.write("extern uint32_t ulp_{0};\n".format(name)) f_h.write('extern uint32_t ulp_{0};\n'.format(name))
f_ld.write("PROVIDE ( ulp_{0} = 0x{1:08x} );\n".format(name, addr)) f_ld.write('PROVIDE ( ulp_{0} = 0x{1:08x} );\n'.format(name, addr))
def main(): def main():
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
print("WARNING: Support for Python 2 is deprecated and will be removed in future versions.", file=sys.stderr) print('WARNING: Support for Python 2 is deprecated and will be removed in future versions.', file=sys.stderr)
elif sys.version_info[0] == 3 and sys.version_info[1] < 6: elif sys.version_info[0] == 3 and sys.version_info[1] < 6:
print("WARNING: Python 3 versions older than 3.6 are not supported.", file=sys.stderr) print('WARNING: Python 3 versions older than 3.6 are not supported.', file=sys.stderr)
description = ("This application generates .h and .ld files for symbols defined in input file. " description = ('This application generates .h and .ld files for symbols defined in input file. '
"The input symbols file can be generated using nm utility like this: " 'The input symbols file can be generated using nm utility like this: '
"esp32-ulp-nm -g -f posix <elf_file> > <symbols_file>") 'esp32-ulp-nm -g -f posix <elf_file> > <symbols_file>')
parser = OptionParser(description=description) parser = OptionParser(description=description)
parser.add_option("-s", "--symfile", dest="symfile", parser.add_option('-s', '--symfile', dest='symfile',
help="symbols file name", metavar="SYMFILE") help='symbols file name', metavar='SYMFILE')
parser.add_option("-o", "--outputfile", dest="outputfile", parser.add_option('-o', '--outputfile', dest='outputfile',
help="destination .h and .ld files name prefix", metavar="OUTFILE") help='destination .h and .ld files name prefix', metavar='OUTFILE')
parser.add_option("--riscv", action="store_true", help="use format for ulp riscv .sym file") parser.add_option('--riscv', action='store_true', help='use format for ulp riscv .sym file')
(options, args) = parser.parse_args() (options, args) = parser.parse_args()
if options.symfile is None: if options.symfile is None:
@ -69,14 +70,14 @@ def main():
return 1 return 1
if options.riscv: if options.riscv:
with open(options.outputfile + ".h", 'w') as f_h, open(options.outputfile + ".ld", 'w') as f_ld, open(options.symfile) as f_sym: with open(options.outputfile + '.h', 'w') as f_h, open(options.outputfile + '.ld', 'w') as f_ld, open(options.symfile) as f_sym:
gen_ld_h_from_sym_riscv(f_sym, f_ld, f_h) gen_ld_h_from_sym_riscv(f_sym, f_ld, f_h)
return 0 return 0
with open(options.outputfile + ".h", 'w') as f_h, open(options.outputfile + ".ld", 'w') as f_ld, open(options.symfile) as f_sym: with open(options.outputfile + '.h', 'w') as f_h, open(options.outputfile + '.ld', 'w') as f_ld, open(options.symfile) as f_sym:
gen_ld_h_from_sym(f_sym, f_ld, f_h) gen_ld_h_from_sym(f_sym, f_ld, f_h)
return 0 return 0
if __name__ == "__main__": if __name__ == '__main__':
exit(main()) exit(main())

View File

@ -2,13 +2,15 @@
# source: wifi_config.proto # source: wifi_config.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pb2
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2 from google.protobuf.internal import enum_type_wrapper
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -17,7 +19,6 @@ _sym_db = _symbol_database.Default()
import constants_pb2 as constants__pb2 import constants_pb2 as constants__pb2
import wifi_constants_pb2 as wifi__constants__pb2 import wifi_constants_pb2 as wifi__constants__pb2
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='wifi_config.proto', name='wifi_config.proto',
package='', package='',
@ -163,21 +164,21 @@ _CMDSETCONFIG = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='ssid', full_name='CmdSetConfig.ssid', index=0, name='ssid', full_name='CmdSetConfig.ssid', index=0,
number=1, type=12, cpp_type=9, label=1, number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='passphrase', full_name='CmdSetConfig.passphrase', index=1, name='passphrase', full_name='CmdSetConfig.passphrase', index=1,
number=2, type=12, cpp_type=9, label=1, number=2, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='bssid', full_name='CmdSetConfig.bssid', index=2, name='bssid', full_name='CmdSetConfig.bssid', index=2,
number=3, type=12, cpp_type=9, label=1, number=3, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),

View File

@ -2,13 +2,15 @@
# source: wifi_constants.proto # source: wifi_constants.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pb2
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2 from google.protobuf.internal import enum_type_wrapper
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -141,7 +143,7 @@ _WIFICONNECTEDSTATE = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='ip4_addr', full_name='WifiConnectedState.ip4_addr', index=0, name='ip4_addr', full_name='WifiConnectedState.ip4_addr', index=0,
number=1, type=9, cpp_type=9, label=1, number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b('').decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),
@ -155,14 +157,14 @@ _WIFICONNECTEDSTATE = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='ssid', full_name='WifiConnectedState.ssid', index=2, name='ssid', full_name='WifiConnectedState.ssid', index=2,
number=3, type=12, cpp_type=9, label=1, number=3, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='bssid', full_name='WifiConnectedState.bssid', index=3, name='bssid', full_name='WifiConnectedState.bssid', index=3,
number=4, type=12, cpp_type=9, label=1, number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR), options=None, file=DESCRIPTOR),

View File

@ -3,12 +3,14 @@
# source: wifi_scan.proto # source: wifi_scan.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import enum_type_wrapper
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -17,7 +19,6 @@ _sym_db = _symbol_database.Default()
import constants_pb2 as constants__pb2 import constants_pb2 as constants__pb2
import wifi_constants_pb2 as wifi__constants__pb2 import wifi_constants_pb2 as wifi__constants__pb2
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='wifi_scan.proto', name='wifi_scan.proto',
package='', package='',
@ -261,7 +262,7 @@ _WIFISCANRESULT = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='ssid', full_name='WiFiScanResult.ssid', index=0, name='ssid', full_name='WiFiScanResult.ssid', index=0,
number=1, type=12, cpp_type=9, label=1, number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),
@ -282,7 +283,7 @@ _WIFISCANRESULT = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='bssid', full_name='WiFiScanResult.bssid', index=3, name='bssid', full_name='WiFiScanResult.bssid', index=3,
number=4, type=12, cpp_type=9, label=1, number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""), has_default_value=False, default_value=_b(''),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),

View File

@ -50,11 +50,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import sys import sys
# Check if loaded into GDB # Check if loaded into GDB
try: try:
assert gdb.__name__ == "gdb" assert gdb.__name__ == 'gdb'
WITH_GDB = True WITH_GDB = True
except NameError: except NameError:
WITH_GDB = False WITH_GDB = False
@ -114,7 +115,7 @@ class TraxPacket(object):
return result return result
def __str__(self): def __str__(self):
return "%d byte packet%s" % (self.size_bytes, " (truncated)" if self.truncated else "") return '%d byte packet%s' % (self.size_bytes, ' (truncated)' if self.truncated else '')
class TraxMessage(object): class TraxMessage(object):
@ -175,7 +176,7 @@ class TraxMessage(object):
self.icnt = self.packets[0].get_bits(12, -1) self.icnt = self.packets[0].get_bits(12, -1)
self.is_correlation = True self.is_correlation = True
else: else:
raise NotImplementedError("Unknown message type (%d)" % self.msg_type) raise NotImplementedError('Unknown message type (%d)' % self.msg_type)
def process_forward(self, cur_pc): def process_forward(self, cur_pc):
""" """
@ -229,23 +230,23 @@ class TraxMessage(object):
return prev_pc return prev_pc
def __str__(self): def __str__(self):
desc = "Unknown (%d)" % self.msg_type desc = 'Unknown (%d)' % self.msg_type
extra = "" extra = ''
if self.truncated: if self.truncated:
desc = "Truncated" desc = 'Truncated'
if self.msg_type == TVAL_INDBR: if self.msg_type == TVAL_INDBR:
desc = "Indirect branch" desc = 'Indirect branch'
extra = ", icnt=%d, uaddr=0x%x, exc=%d" % (self.icnt, self.uaddr, self.is_exception) extra = ', icnt=%d, uaddr=0x%x, exc=%d' % (self.icnt, self.uaddr, self.is_exception)
if self.msg_type == TVAL_INDBRSYNC: if self.msg_type == TVAL_INDBRSYNC:
desc = "Indirect branch w/sync" desc = 'Indirect branch w/sync'
extra = ", icnt=%d, dcont=%d, exc=%d" % (self.icnt, self.dcont, self.is_exception) extra = ', icnt=%d, dcont=%d, exc=%d' % (self.icnt, self.dcont, self.is_exception)
if self.msg_type == TVAL_SYNC: if self.msg_type == TVAL_SYNC:
desc = "Synchronization" desc = 'Synchronization'
extra = ", icnt=%d, dcont=%d" % (self.icnt, self.dcont) extra = ', icnt=%d, dcont=%d' % (self.icnt, self.dcont)
if self.msg_type == TVAL_CORR: if self.msg_type == TVAL_CORR:
desc = "Correlation" desc = 'Correlation'
extra = ", icnt=%d" % self.icnt extra = ', icnt=%d' % self.icnt
return "%s message, %d packets, PC range 0x%08x - 0x%08x, target PC 0x%08x" % ( return '%s message, %d packets, PC range 0x%08x - 0x%08x, target PC 0x%08x' % (
desc, len(self.packets), self.pc_start, self.pc_end, self.pc_target) + extra desc, len(self.packets), self.pc_start, self.pc_end, self.pc_target) + extra
@ -264,7 +265,7 @@ def load_messages(data):
# Iterate over the input data, splitting bytes into packets and messages # Iterate over the input data, splitting bytes into packets and messages
for i, b in enumerate(data): for i, b in enumerate(data):
if (b & MSEO_MSGEND) and not (b & MSEO_PKTEND): if (b & MSEO_MSGEND) and not (b & MSEO_PKTEND):
raise AssertionError("Invalid MSEO bits in b=0x%x. Not a TRAX dump?" % b) raise AssertionError('Invalid MSEO bits in b=0x%x. Not a TRAX dump?' % b)
if b & MSEO_PKTEND: if b & MSEO_PKTEND:
pkt_cnt += 1 pkt_cnt += 1
@ -276,7 +277,7 @@ def load_messages(data):
try: try:
messages.append(TraxMessage(packets, len(messages) == 0)) messages.append(TraxMessage(packets, len(messages) == 0))
except NotImplementedError as e: except NotImplementedError as e:
sys.stderr.write("Failed to parse message #%03d (at %d bytes): %s\n" % (msg_cnt, i, str(e))) sys.stderr.write('Failed to parse message #%03d (at %d bytes): %s\n' % (msg_cnt, i, str(e)))
packets = [] packets = []
# Resolve PC ranges of messages. # Resolve PC ranges of messages.
@ -312,32 +313,32 @@ def parse_and_dump(filename, disassemble=WITH_GDB):
data = f.read() data = f.read()
messages = load_messages(data) messages = load_messages(data)
sys.stderr.write("Loaded %d messages in %d bytes\n" % (len(messages), len(data))) sys.stderr.write('Loaded %d messages in %d bytes\n' % (len(messages), len(data)))
for i, m in enumerate(messages): for i, m in enumerate(messages):
if m.truncated: if m.truncated:
continue continue
print("%04d: %s" % (i, str(m))) print('%04d: %s' % (i, str(m)))
if m.is_exception: if m.is_exception:
print("*** Exception occurred ***") print('*** Exception occurred ***')
if disassemble and WITH_GDB: if disassemble and WITH_GDB:
try: try:
gdb.execute("disassemble 0x%08x, 0x%08x" % (m.pc_start, m.pc_end)) # noqa: F821 gdb.execute('disassemble 0x%08x, 0x%08x' % (m.pc_start, m.pc_end)) # noqa: F821
except gdb.MemoryError: # noqa: F821 except gdb.MemoryError: # noqa: F821
print("Failed to disassemble from 0x%08x to 0x%08x" % (m.pc_start, m.pc_end)) print('Failed to disassemble from 0x%08x to 0x%08x' % (m.pc_start, m.pc_end))
def main(): def main():
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
print("WARNING: Support for Python 2 is deprecated and will be removed in future versions.", file=sys.stderr) print('WARNING: Support for Python 2 is deprecated and will be removed in future versions.', file=sys.stderr)
elif sys.version_info[0] == 3 and sys.version_info[1] < 6: elif sys.version_info[0] == 3 and sys.version_info[1] < 6:
print("WARNING: Python 3 versions older than 3.6 are not supported.", file=sys.stderr) print('WARNING: Python 3 versions older than 3.6 are not supported.', file=sys.stderr)
if len(sys.argv) < 2: if len(sys.argv) < 2:
sys.stderr.write("Usage: %s <dump_file>\n") sys.stderr.write('Usage: %s <dump_file>\n')
raise SystemExit(1) raise SystemExit(1)
parse_and_dump(sys.argv[1]) parse_and_dump(sys.argv[1])
if __name__ == "__main__" and not WITH_GDB: if __name__ == '__main__' and not WITH_GDB:
main() main()

View File

@ -24,31 +24,33 @@
# limitations under the License. # limitations under the License.
# #
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import locale import locale
import math import math
import multiprocessing import multiprocessing
import os import os
import os.path import os.path
import re
import subprocess import subprocess
import sys import sys
import re
from packaging import version
from collections import namedtuple from collections import namedtuple
LANGUAGES = ["en", "zh_CN"] from packaging import version
TARGETS = ["esp32", "esp32s2"]
SPHINX_WARN_LOG = "sphinx-warning-log.txt" LANGUAGES = ['en', 'zh_CN']
SPHINX_SANITIZED_LOG = "sphinx-warning-log-sanitized.txt" TARGETS = ['esp32', 'esp32s2']
SPHINX_KNOWN_WARNINGS = os.path.join(os.environ["IDF_PATH"], "docs", "sphinx-known-warnings.txt")
DXG_WARN_LOG = "doxygen-warning-log.txt" SPHINX_WARN_LOG = 'sphinx-warning-log.txt'
DXG_SANITIZED_LOG = "doxygen-warning-log-sanitized.txt" SPHINX_SANITIZED_LOG = 'sphinx-warning-log-sanitized.txt'
DXG_KNOWN_WARNINGS = os.path.join(os.environ["IDF_PATH"], "docs", "doxygen-known-warnings.txt") SPHINX_KNOWN_WARNINGS = os.path.join(os.environ['IDF_PATH'], 'docs', 'sphinx-known-warnings.txt')
DXG_WARN_LOG = 'doxygen-warning-log.txt'
DXG_SANITIZED_LOG = 'doxygen-warning-log-sanitized.txt'
DXG_KNOWN_WARNINGS = os.path.join(os.environ['IDF_PATH'], 'docs', 'doxygen-known-warnings.txt')
DXG_CI_VERSION = version.parse('1.8.11') DXG_CI_VERSION = version.parse('1.8.11')
LogMessage = namedtuple("LogMessage", "original_text sanitized_text") LogMessage = namedtuple('LogMessage', 'original_text sanitized_text')
languages = LANGUAGES languages = LANGUAGES
targets = TARGETS targets = TARGETS
@ -58,11 +60,11 @@ def main():
# check Python dependencies for docs # check Python dependencies for docs
try: try:
subprocess.check_call([sys.executable, subprocess.check_call([sys.executable,
os.path.join(os.environ["IDF_PATH"], os.path.join(os.environ['IDF_PATH'],
"tools", 'tools',
"check_python_dependencies.py"), 'check_python_dependencies.py'),
"-r", '-r',
"{}/docs/requirements.txt".format(os.environ["IDF_PATH"]) '{}/docs/requirements.txt'.format(os.environ['IDF_PATH'])
]) ])
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
raise SystemExit(2) # stdout will already have these errors raise SystemExit(2) # stdout will already have these errors
@ -73,31 +75,31 @@ def main():
# type not the str type. # type not the str type.
if ('UTF-8' not in locale.getlocale()) and ('utf8' not in locale.getlocale()): if ('UTF-8' not in locale.getlocale()) and ('utf8' not in locale.getlocale()):
raise RuntimeError("build_docs.py requires the default locale's encoding to be UTF-8.\n" + raise RuntimeError("build_docs.py requires the default locale's encoding to be UTF-8.\n" +
" - Linux. Setting environment variable LC_ALL=C.UTF-8 when running build_docs.py may be " + ' - Linux. Setting environment variable LC_ALL=C.UTF-8 when running build_docs.py may be ' +
"enough to fix this.\n" 'enough to fix this.\n'
" - Windows. Possible solution for the Windows 10 starting version 1803. Go to " + ' - Windows. Possible solution for the Windows 10 starting version 1803. Go to ' +
"Control Panel->Clock and Region->Region->Administrative->Change system locale...; " + 'Control Panel->Clock and Region->Region->Administrative->Change system locale...; ' +
"Check `Beta: Use Unicode UTF-8 for worldwide language support` and reboot") 'Check `Beta: Use Unicode UTF-8 for worldwide language support` and reboot')
parser = argparse.ArgumentParser(description='build_docs.py: Build IDF docs', prog='build_docs.py') parser = argparse.ArgumentParser(description='build_docs.py: Build IDF docs', prog='build_docs.py')
parser.add_argument("--language", "-l", choices=LANGUAGES, required=False) parser.add_argument('--language', '-l', choices=LANGUAGES, required=False)
parser.add_argument("--target", "-t", choices=TARGETS, required=False) parser.add_argument('--target', '-t', choices=TARGETS, required=False)
parser.add_argument("--build-dir", "-b", type=str, default="_build") parser.add_argument('--build-dir', '-b', type=str, default='_build')
parser.add_argument("--source-dir", "-s", type=str, default="") parser.add_argument('--source-dir', '-s', type=str, default='')
parser.add_argument("--builders", "-bs", nargs='+', type=str, default=["html"], parser.add_argument('--builders', '-bs', nargs='+', type=str, default=['html'],
help="List of builders for Sphinx, e.g. html or latex, for latex a PDF is also generated") help='List of builders for Sphinx, e.g. html or latex, for latex a PDF is also generated')
parser.add_argument("--sphinx-parallel-builds", "-p", choices=["auto"] + [str(x) for x in range(8)], parser.add_argument('--sphinx-parallel-builds', '-p', choices=['auto'] + [str(x) for x in range(8)],
help="Parallel Sphinx builds - number of independent Sphinx builds to run", default="auto") help='Parallel Sphinx builds - number of independent Sphinx builds to run', default='auto')
parser.add_argument("--sphinx-parallel-jobs", "-j", choices=["auto"] + [str(x) for x in range(8)], parser.add_argument('--sphinx-parallel-jobs', '-j', choices=['auto'] + [str(x) for x in range(8)],
help="Sphinx parallel jobs argument - number of threads for each Sphinx build to use", default="1") help='Sphinx parallel jobs argument - number of threads for each Sphinx build to use', default='1')
parser.add_argument("--input-docs", "-i", nargs='+', default=[""], parser.add_argument('--input-docs', '-i', nargs='+', default=[''],
help="List of documents to build relative to the doc base folder, i.e. the language folder. Defaults to all documents") help='List of documents to build relative to the doc base folder, i.e. the language folder. Defaults to all documents')
action_parsers = parser.add_subparsers(dest='action') action_parsers = parser.add_subparsers(dest='action')
build_parser = action_parsers.add_parser('build', help='Build documentation') build_parser = action_parsers.add_parser('build', help='Build documentation')
build_parser.add_argument("--check-warnings-only", "-w", action='store_true') build_parser.add_argument('--check-warnings-only', '-w', action='store_true')
action_parsers.add_parser('linkcheck', help='Check links (a current IDF revision should be uploaded to GitHub)') action_parsers.add_parser('linkcheck', help='Check links (a current IDF revision should be uploaded to GitHub)')
@ -107,27 +109,27 @@ def main():
global languages global languages
if args.language is None: if args.language is None:
print("Building all languages") print('Building all languages')
languages = LANGUAGES languages = LANGUAGES
else: else:
languages = [args.language] languages = [args.language]
global targets global targets
if args.target is None: if args.target is None:
print("Building all targets") print('Building all targets')
targets = TARGETS targets = TARGETS
else: else:
targets = [args.target] targets = [args.target]
if args.action == "build" or args.action is None: if args.action == 'build' or args.action is None:
if args.action is None: if args.action is None:
args.check_warnings_only = False args.check_warnings_only = False
sys.exit(action_build(args)) sys.exit(action_build(args))
if args.action == "linkcheck": if args.action == 'linkcheck':
sys.exit(action_linkcheck(args)) sys.exit(action_linkcheck(args))
if args.action == "gh-linkcheck": if args.action == 'gh-linkcheck':
sys.exit(action_gh_linkcheck(args)) sys.exit(action_gh_linkcheck(args))
@ -135,7 +137,7 @@ def parallel_call(args, callback):
num_sphinx_builds = len(languages) * len(targets) num_sphinx_builds = len(languages) * len(targets)
num_cpus = multiprocessing.cpu_count() num_cpus = multiprocessing.cpu_count()
if args.sphinx_parallel_builds == "auto": if args.sphinx_parallel_builds == 'auto':
# at most one sphinx build per CPU, up to the number of CPUs # at most one sphinx build per CPU, up to the number of CPUs
args.sphinx_parallel_builds = min(num_sphinx_builds, num_cpus) args.sphinx_parallel_builds = min(num_sphinx_builds, num_cpus)
else: else:
@ -143,17 +145,17 @@ def parallel_call(args, callback):
# Force -j1 because sphinx works incorrectly # Force -j1 because sphinx works incorrectly
args.sphinx_parallel_jobs = 1 args.sphinx_parallel_jobs = 1
if args.sphinx_parallel_jobs == "auto": if args.sphinx_parallel_jobs == 'auto':
# N CPUs per build job, rounded up - (maybe smarter to round down to avoid contention, idk) # N CPUs per build job, rounded up - (maybe smarter to round down to avoid contention, idk)
args.sphinx_parallel_jobs = int(math.ceil(num_cpus / args.sphinx_parallel_builds)) args.sphinx_parallel_jobs = int(math.ceil(num_cpus / args.sphinx_parallel_builds))
else: else:
args.sphinx_parallel_jobs = int(args.sphinx_parallel_jobs) args.sphinx_parallel_jobs = int(args.sphinx_parallel_jobs)
print("Will use %d parallel builds and %d jobs per build" % (args.sphinx_parallel_builds, args.sphinx_parallel_jobs)) print('Will use %d parallel builds and %d jobs per build' % (args.sphinx_parallel_builds, args.sphinx_parallel_jobs))
pool = multiprocessing.Pool(args.sphinx_parallel_builds) pool = multiprocessing.Pool(args.sphinx_parallel_builds)
if args.sphinx_parallel_jobs > 1: if args.sphinx_parallel_jobs > 1:
print("WARNING: Sphinx parallel jobs currently produce incorrect docs output with Sphinx 1.8.5") print('WARNING: Sphinx parallel jobs currently produce incorrect docs output with Sphinx 1.8.5')
# make a list of all combinations of build_docs() args as tuples # make a list of all combinations of build_docs() args as tuples
# #
@ -173,13 +175,13 @@ def parallel_call(args, callback):
is_error = False is_error = False
for ret in errcodes: for ret in errcodes:
if ret != 0: if ret != 0:
print("\nThe following language/target combinations failed to build:") print('\nThe following language/target combinations failed to build:')
is_error = True is_error = True
break break
if is_error: if is_error:
for ret, entry in zip(errcodes, entries): for ret, entry in zip(errcodes, entries):
if ret != 0: if ret != 0:
print("language: %s, target: %s, errcode: %d" % (entry[0], entry[1], ret)) print('language: %s, target: %s, errcode: %d' % (entry[0], entry[1], ret))
# Don't re-throw real error code from each parallel process # Don't re-throw real error code from each parallel process
return 1 return 1
else: else:
@ -193,9 +195,9 @@ def sphinx_call(language, target, build_dir, src_dir, sphinx_parallel_jobs, buil
# wrap stdout & stderr in a way that lets us see which build_docs instance they come from # wrap stdout & stderr in a way that lets us see which build_docs instance they come from
# #
# this doesn't apply to subprocesses, they write to OS stdout & stderr so no prefix appears # this doesn't apply to subprocesses, they write to OS stdout & stderr so no prefix appears
prefix = "%s/%s: " % (language, target) prefix = '%s/%s: ' % (language, target)
print("Building in build_dir: %s" % (build_dir)) print('Building in build_dir: %s' % (build_dir))
try: try:
os.makedirs(build_dir) os.makedirs(build_dir)
except OSError: except OSError:
@ -205,21 +207,21 @@ def sphinx_call(language, target, build_dir, src_dir, sphinx_parallel_jobs, buil
environ.update(os.environ) environ.update(os.environ)
environ['BUILDDIR'] = build_dir environ['BUILDDIR'] = build_dir
args = [sys.executable, "-u", "-m", "sphinx.cmd.build", args = [sys.executable, '-u', '-m', 'sphinx.cmd.build',
"-j", str(sphinx_parallel_jobs), '-j', str(sphinx_parallel_jobs),
"-b", buildername, '-b', buildername,
"-d", os.path.join(build_dir, "doctrees"), '-d', os.path.join(build_dir, 'doctrees'),
"-w", SPHINX_WARN_LOG, '-w', SPHINX_WARN_LOG,
"-t", target, '-t', target,
"-D", "idf_target={}".format(target), '-D', 'idf_target={}'.format(target),
"-D", "docs_to_build={}".format(",". join(input_docs)), '-D', 'docs_to_build={}'.format(','. join(input_docs)),
src_dir, src_dir,
os.path.join(build_dir, buildername) # build directory os.path.join(build_dir, buildername) # build directory
] ]
saved_cwd = os.getcwd() saved_cwd = os.getcwd()
os.chdir(build_dir) # also run sphinx in the build directory os.chdir(build_dir) # also run sphinx in the build directory
print("Running '%s'" % (" ".join(args))) print("Running '%s'" % (' '.join(args)))
ret = 1 ret = 1
try: try:
@ -282,7 +284,7 @@ def call_build_docs(entry):
# Build PDF from tex # Build PDF from tex
if 'latex' in builders: if 'latex' in builders:
latex_dir = os.path.join(build_dir, "latex") latex_dir = os.path.join(build_dir, 'latex')
ret = build_pdf(language, target, latex_dir) ret = build_pdf(language, target, latex_dir)
return ret return ret
@ -294,9 +296,9 @@ def build_pdf(language, target, latex_dir):
# wrap stdout & stderr in a way that lets us see which build_docs instance they come from # wrap stdout & stderr in a way that lets us see which build_docs instance they come from
# #
# this doesn't apply to subprocesses, they write to OS stdout & stderr so no prefix appears # this doesn't apply to subprocesses, they write to OS stdout & stderr so no prefix appears
prefix = "%s/%s: " % (language, target) prefix = '%s/%s: ' % (language, target)
print("Building PDF in latex_dir: %s" % (latex_dir)) print('Building PDF in latex_dir: %s' % (latex_dir))
saved_cwd = os.getcwd() saved_cwd = os.getcwd()
os.chdir(latex_dir) os.chdir(latex_dir)
@ -337,8 +339,8 @@ def build_pdf(language, target, latex_dir):
return ret return ret
SANITIZE_FILENAME_REGEX = re.compile("[^:]*/([^/:]*)(:.*)") SANITIZE_FILENAME_REGEX = re.compile('[^:]*/([^/:]*)(:.*)')
SANITIZE_LINENUM_REGEX = re.compile("([^:]*)(:[0-9]+:)(.*)") SANITIZE_LINENUM_REGEX = re.compile('([^:]*)(:[0-9]+:)(.*)')
def sanitize_line(line): def sanitize_line(line):
@ -376,12 +378,12 @@ def check_docs(language, target, log_file, known_warnings_file, out_sanitized_lo
for known_line in k: for known_line in k:
known_messages.append(known_line) known_messages.append(known_line)
if "doxygen" in known_warnings_file: if 'doxygen' in known_warnings_file:
# Clean a known Doxygen limitation: it's expected to always document anonymous # Clean a known Doxygen limitation: it's expected to always document anonymous
# structs/unions but we don't do this in our docs, so filter these all out with a regex # structs/unions but we don't do this in our docs, so filter these all out with a regex
# (this won't match any named field, only anonymous members - # (this won't match any named field, only anonymous members -
# ie the last part of the field is is just <something>::@NUM not <something>::name) # ie the last part of the field is is just <something>::@NUM not <something>::name)
RE_ANONYMOUS_FIELD = re.compile(r".+:line: warning: parameters of member [^:\s]+(::[^:\s]+)*(::@\d+)+ are not \(all\) documented") RE_ANONYMOUS_FIELD = re.compile(r'.+:line: warning: parameters of member [^:\s]+(::[^:\s]+)*(::@\d+)+ are not \(all\) documented')
all_messages = [msg for msg in all_messages if not re.match(RE_ANONYMOUS_FIELD, msg.sanitized_text)] all_messages = [msg for msg in all_messages if not re.match(RE_ANONYMOUS_FIELD, msg.sanitized_text)]
# Collect all new messages that are not match with the known messages. # Collect all new messages that are not match with the known messages.
@ -395,17 +397,17 @@ def check_docs(language, target, log_file, known_warnings_file, out_sanitized_lo
new_messages.append(msg) new_messages.append(msg)
if new_messages: if new_messages:
print("\n%s/%s: Build failed due to new/different warnings (%s):\n" % (language, target, log_file)) print('\n%s/%s: Build failed due to new/different warnings (%s):\n' % (language, target, log_file))
for msg in new_messages: for msg in new_messages:
print("%s/%s: %s" % (language, target, msg.original_text), end='') print('%s/%s: %s' % (language, target, msg.original_text), end='')
print("\n%s/%s: (Check files %s and %s for full details.)" % (language, target, known_warnings_file, log_file)) print('\n%s/%s: (Check files %s and %s for full details.)' % (language, target, known_warnings_file, log_file))
return 1 return 1
return 0 return 0
def action_linkcheck(args): def action_linkcheck(args):
args.builders = "linkcheck" args.builders = 'linkcheck'
return parallel_call(args, call_linkcheck) return parallel_call(args, call_linkcheck)
@ -416,49 +418,49 @@ def call_linkcheck(entry):
# https://github.com/espressif/esp-idf/tree/ # https://github.com/espressif/esp-idf/tree/
# https://github.com/espressif/esp-idf/blob/ # https://github.com/espressif/esp-idf/blob/
# https://github.com/espressif/esp-idf/raw/ # https://github.com/espressif/esp-idf/raw/
GH_LINK_RE = r"https://github.com/espressif/esp-idf/(?:tree|blob|raw)/[^\s]+" GH_LINK_RE = r'https://github.com/espressif/esp-idf/(?:tree|blob|raw)/[^\s]+'
# we allow this one doc, because we always want users to see the latest support policy # we allow this one doc, because we always want users to see the latest support policy
GH_LINK_ALLOWED = ["https://github.com/espressif/esp-idf/blob/master/SUPPORT_POLICY.md", GH_LINK_ALLOWED = ['https://github.com/espressif/esp-idf/blob/master/SUPPORT_POLICY.md',
"https://github.com/espressif/esp-idf/blob/master/SUPPORT_POLICY_CN.md"] 'https://github.com/espressif/esp-idf/blob/master/SUPPORT_POLICY_CN.md']
def action_gh_linkcheck(args): def action_gh_linkcheck(args):
print("Checking for hardcoded GitHub links\n") print('Checking for hardcoded GitHub links\n')
github_links = [] github_links = []
docs_dir = os.path.relpath(os.path.dirname(__file__)) docs_dir = os.path.relpath(os.path.dirname(__file__))
for root, _, files in os.walk(docs_dir): for root, _, files in os.walk(docs_dir):
if "_build" in root: if '_build' in root:
continue continue
files = [os.path.join(root, f) for f in files if f.endswith(".rst")] files = [os.path.join(root, f) for f in files if f.endswith('.rst')]
for path in files: for path in files:
with open(path, "r") as f: with open(path, 'r') as f:
for link in re.findall(GH_LINK_RE, f.read()): for link in re.findall(GH_LINK_RE, f.read()):
if link not in GH_LINK_ALLOWED: if link not in GH_LINK_ALLOWED:
github_links.append((path, link)) github_links.append((path, link))
if github_links: if github_links:
for path, link in github_links: for path, link in github_links:
print("%s: %s" % (path, link)) print('%s: %s' % (path, link))
print("WARNING: Some .rst files contain hardcoded Github links.") print('WARNING: Some .rst files contain hardcoded Github links.')
print("Please check above output and replace links with one of the following:") print('Please check above output and replace links with one of the following:')
print("- :idf:`dir` - points to directory inside ESP-IDF") print('- :idf:`dir` - points to directory inside ESP-IDF')
print("- :idf_file:`file` - points to file inside ESP-IDF") print('- :idf_file:`file` - points to file inside ESP-IDF')
print("- :idf_raw:`file` - points to raw view of the file inside ESP-IDF") print('- :idf_raw:`file` - points to raw view of the file inside ESP-IDF')
print("- :component:`dir` - points to directory inside ESP-IDF components dir") print('- :component:`dir` - points to directory inside ESP-IDF components dir')
print("- :component_file:`file` - points to file inside ESP-IDF components dir") print('- :component_file:`file` - points to file inside ESP-IDF components dir')
print("- :component_raw:`file` - points to raw view of the file inside ESP-IDF components dir") print('- :component_raw:`file` - points to raw view of the file inside ESP-IDF components dir')
print("- :example:`dir` - points to directory inside ESP-IDF examples dir") print('- :example:`dir` - points to directory inside ESP-IDF examples dir')
print("- :example_file:`file` - points to file inside ESP-IDF examples dir") print('- :example_file:`file` - points to file inside ESP-IDF examples dir')
print("- :example_raw:`file` - points to raw view of the file inside ESP-IDF examples dir") print('- :example_raw:`file` - points to raw view of the file inside ESP-IDF examples dir')
print("These link types will point to the correct GitHub version automatically") print('These link types will point to the correct GitHub version automatically')
return 1 return 1
else: else:
print("No hardcoded links found") print('No hardcoded links found')
return 0 return 0
if __name__ == "__main__": if __name__ == '__main__':
main() main()

View File

@ -14,17 +14,17 @@
# All configuration values have a default; values that are commented out # All configuration values have a default; values that are commented out
# serve to show the default. # serve to show the default.
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
import sys
import os import os
import os.path import os.path
import re import re
import subprocess import subprocess
from sanitize_version import sanitize_version import sys
from idf_extensions.util import download_file_if_missing
from get_github_rev import get_github_rev
from get_github_rev import get_github_rev
from idf_extensions.util import download_file_if_missing
from sanitize_version import sanitize_version
# build_docs on the CI server sometimes fails under Python3. This is a workaround: # build_docs on the CI server sometimes fails under Python3. This is a workaround:
sys.setrecursionlimit(3500) sys.setrecursionlimit(3500)
@ -242,7 +242,7 @@ versions_url = 'https://dl.espressif.com/dl/esp-idf/idf_versions.js'
idf_targets = ['esp32', 'esp32s2'] idf_targets = ['esp32', 'esp32s2']
languages = ['en', 'zh_CN'] languages = ['en', 'zh_CN']
project_homepage = "https://github.com/espressif/esp-idf" project_homepage = 'https://github.com/espressif/esp-idf'
# -- Options for HTML output ---------------------------------------------- # -- Options for HTML output ----------------------------------------------
@ -250,11 +250,11 @@ project_homepage = "https://github.com/espressif/esp-idf"
# #
# Redirects should be listed in page_redirects.xt # Redirects should be listed in page_redirects.xt
# #
with open("../page_redirects.txt") as f: with open('../page_redirects.txt') as f:
lines = [re.sub(" +", " ", line.strip()) for line in f.readlines() if line.strip() != "" and not line.startswith("#")] lines = [re.sub(' +', ' ', line.strip()) for line in f.readlines() if line.strip() != '' and not line.startswith('#')]
for line in lines: # check for well-formed entries for line in lines: # check for well-formed entries
if len(line.split(' ')) != 2: if len(line.split(' ')) != 2:
raise RuntimeError("Invalid line in page_redirects.txt: %s" % line) raise RuntimeError('Invalid line in page_redirects.txt: %s' % line)
html_redirect_pages = [tuple(line.split(' ')) for line in lines] html_redirect_pages = [tuple(line.split(' ')) for line in lines]
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
@ -264,10 +264,10 @@ html_theme = 'sphinx_idf_theme'
# context used by sphinx_idf_theme # context used by sphinx_idf_theme
html_context = { html_context = {
"display_github": True, # Add 'Edit on Github' link instead of 'View page source' 'display_github': True, # Add 'Edit on Github' link instead of 'View page source'
"github_user": "espressif", 'github_user': 'espressif',
"github_repo": "esp-idf", 'github_repo': 'esp-idf',
"github_version": get_github_rev(), 'github_version': get_github_rev(),
} }
# Theme options are theme-specific and customize the look and feel of a theme # Theme options are theme-specific and customize the look and feel of a theme
@ -287,7 +287,7 @@ html_context = {
# The name of an image file (relative to this directory) to place at the top # The name of an image file (relative to this directory) to place at the top
# of the sidebar. # of the sidebar.
html_logo = "../_static/espressif-logo.svg" html_logo = '../_static/espressif-logo.svg'
# The name of an image file (within the static path) to use as favicon of the # The name of an image file (within the static path) to use as favicon of the
@ -380,7 +380,7 @@ latex_elements = {
# The name of an image file (relative to this directory) to place at the bottom of # The name of an image file (relative to this directory) to place at the bottom of
# the title page. # the title page.
latex_logo = "../_static/espressif2.pdf" latex_logo = '../_static/espressif2.pdf'
latex_engine = 'xelatex' latex_engine = 'xelatex'
latex_use_xindy = False latex_use_xindy = False
@ -427,7 +427,7 @@ def setup(app):
app.add_stylesheet('theme_overrides.css') app.add_stylesheet('theme_overrides.css')
# these two must be pushed in by build_docs.py # these two must be pushed in by build_docs.py
if "idf_target" not in app.config: if 'idf_target' not in app.config:
app.add_config_value('idf_target', None, 'env') app.add_config_value('idf_target', None, 'env')
app.add_config_value('idf_targets', None, 'env') app.add_config_value('idf_targets', None, 'env')
@ -436,8 +436,8 @@ def setup(app):
# Breathe extension variables (depend on build_dir) # Breathe extension variables (depend on build_dir)
# note: we generate into xml_in and then copy_if_modified to xml dir # note: we generate into xml_in and then copy_if_modified to xml dir
app.config.breathe_projects = {"esp32-idf": os.path.join(app.config.build_dir, "xml_in/")} app.config.breathe_projects = {'esp32-idf': os.path.join(app.config.build_dir, 'xml_in/')}
app.config.breathe_default_project = "esp32-idf" app.config.breathe_default_project = 'esp32-idf'
setup_diag_font(app) setup_diag_font(app)
@ -455,13 +455,13 @@ def setup_config_values(app, config):
app.add_config_value('idf_target_title_dict', idf_target_title_dict, 'env') app.add_config_value('idf_target_title_dict', idf_target_title_dict, 'env')
pdf_name = "esp-idf-{}-{}-{}".format(app.config.language, app.config.version, app.config.idf_target) pdf_name = 'esp-idf-{}-{}-{}'.format(app.config.language, app.config.version, app.config.idf_target)
app.add_config_value('pdf_file', pdf_name, 'env') app.add_config_value('pdf_file', pdf_name, 'env')
def setup_html_context(app, config): def setup_html_context(app, config):
# Setup path for 'edit on github'-link # Setup path for 'edit on github'-link
config.html_context['conf_py_path'] = "/docs/{}/".format(app.config.language) config.html_context['conf_py_path'] = '/docs/{}/'.format(app.config.language)
def setup_diag_font(app): def setup_diag_font(app):
@ -476,7 +476,7 @@ def setup_diag_font(app):
font_dir = os.path.join(config_dir, '_static') font_dir = os.path.join(config_dir, '_static')
assert os.path.exists(font_dir) assert os.path.exists(font_dir)
print("Downloading font file %s for %s" % (font_name, app.config.language)) print('Downloading font file %s for %s' % (font_name, app.config.language))
download_file_if_missing('https://dl.espressif.com/dl/esp-idf/docs/_static/{}'.format(font_name), font_dir) download_file_if_missing('https://dl.espressif.com/dl/esp-idf/docs/_static/{}'.format(font_name), font_dir)
font_path = os.path.abspath(os.path.join(font_dir, font_name)) font_path = os.path.abspath(os.path.join(font_dir, font_name))

View File

@ -9,8 +9,8 @@
try: try:
from conf_common import * # noqa: F403,F401 from conf_common import * # noqa: F403,F401
except ImportError: except ImportError:
import sys
import os import os
import sys
sys.path.insert(0, os.path.abspath('..')) sys.path.insert(0, os.path.abspath('..'))
from conf_common import * # noqa: F403,F401 from conf_common import * # noqa: F403,F401

View File

@ -53,22 +53,22 @@ def create_redirect_pages(app):
return # only relevant for standalone HTML output return # only relevant for standalone HTML output
for (old_url, new_url) in app.config.html_redirect_pages: for (old_url, new_url) in app.config.html_redirect_pages:
print("Creating redirect %s to %s..." % (old_url, new_url)) print('Creating redirect %s to %s...' % (old_url, new_url))
if old_url.startswith('/'): if old_url.startswith('/'):
print("Stripping leading / from URL in config file...") print('Stripping leading / from URL in config file...')
old_url = old_url[1:] old_url = old_url[1:]
new_url = app.builder.get_relative_uri(old_url, new_url) new_url = app.builder.get_relative_uri(old_url, new_url)
out_file = app.builder.get_outfilename(old_url) out_file = app.builder.get_outfilename(old_url)
print("HTML file %s redirects to relative URL %s" % (out_file, new_url)) print('HTML file %s redirects to relative URL %s' % (out_file, new_url))
out_dir = os.path.dirname(out_file) out_dir = os.path.dirname(out_file)
if not os.path.exists(out_dir): if not os.path.exists(out_dir):
os.makedirs(out_dir) os.makedirs(out_dir)
content = REDIRECT_TEMPLATE.replace("$NEWURL", new_url) content = REDIRECT_TEMPLATE.replace('$NEWURL', new_url)
with open(out_file, "w") as rp: with open(out_file, 'w') as rp:
rp.write(content) rp.write(content)
return [] return []

View File

@ -1,4 +1,5 @@
import re import re
from docutils import nodes from docutils import nodes
from docutils.parsers.rst import Directive from docutils.parsers.rst import Directive

View File

@ -1,5 +1,6 @@
# Based on https://stackoverflow.com/a/46600038 with some modifications # Based on https://stackoverflow.com/a/46600038 with some modifications
import re import re
from sphinx.directives.other import TocTree from sphinx.directives.other import TocTree

View File

@ -18,14 +18,14 @@ import argparse
import datetime as dt import datetime as dt
import json import json
import numpy as np
import requests
import matplotlib.dates import matplotlib.dates
import matplotlib.patches as mpatches import matplotlib.patches as mpatches
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.dates import MONTHLY, DateFormatter, RRuleLocator, rrulewrapper import numpy as np
import requests
from dateutil import parser from dateutil import parser
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from matplotlib.dates import MONTHLY, DateFormatter, RRuleLocator, rrulewrapper
class Version(object): class Version(object):
@ -68,18 +68,18 @@ class ChartVersions(object):
def get_releases_as_json(self): def get_releases_as_json(self):
return { return {
x.version_name: { x.version_name: {
"start_date": x.get_start_date().strftime("%Y-%m-%d"), 'start_date': x.get_start_date().strftime('%Y-%m-%d'),
"end_service": x.get_end_service_date().strftime("%Y-%m-%d"), 'end_service': x.get_end_service_date().strftime('%Y-%m-%d'),
"end_date": x.get_end_of_life_date().strftime("%Y-%m-%d") 'end_date': x.get_end_of_life_date().strftime('%Y-%m-%d')
} for x in self.sorted_releases_supported } for x in self.sorted_releases_supported
} }
@staticmethod @staticmethod
def parse_chart_releases_from_js(js_as_string): def parse_chart_releases_from_js(js_as_string):
return json.loads(js_as_string[js_as_string.find("RELEASES: ") + len("RELEASES: "):js_as_string.rfind("};")]) return json.loads(js_as_string[js_as_string.find('RELEASES: ') + len('RELEASES: '):js_as_string.rfind('};')])
def _get_all_version_from_url(self, url=None, filename=None): def _get_all_version_from_url(self, url=None, filename=None):
releases_file = requests.get(url).text if url is not None else "".join(open(filename).readlines()) releases_file = requests.get(url).text if url is not None else ''.join(open(filename).readlines())
return self.parse_chart_releases_from_js(releases_file) return self.parse_chart_releases_from_js(releases_file)
def _get_releases_from_url(self, url=None, filename=None): def _get_releases_from_url(self, url=None, filename=None):
@ -178,7 +178,7 @@ class ChartVersions(object):
rule = rrulewrapper(MONTHLY, interval=x_ax_interval) rule = rrulewrapper(MONTHLY, interval=x_ax_interval)
loc = RRuleLocator(rule) loc = RRuleLocator(rule)
formatter = DateFormatter("%b %Y") formatter = DateFormatter('%b %Y')
ax.xaxis.set_major_locator(loc) ax.xaxis.set_major_locator(loc)
ax.xaxis.set_major_formatter(formatter) ax.xaxis.set_major_formatter(formatter)
@ -198,19 +198,19 @@ class ChartVersions(object):
bbox_to_anchor=(1.01, 1.165), loc='upper right') bbox_to_anchor=(1.01, 1.165), loc='upper right')
fig.set_size_inches(11, 5, forward=True) fig.set_size_inches(11, 5, forward=True)
plt.savefig(output_chart_name + output_chart_extension, bbox_inches='tight') plt.savefig(output_chart_name + output_chart_extension, bbox_inches='tight')
print("Saved into " + output_chart_name + output_chart_extension) print('Saved into ' + output_chart_name + output_chart_extension)
if __name__ == '__main__': if __name__ == '__main__':
arg_parser = argparse.ArgumentParser( arg_parser = argparse.ArgumentParser(
description="Create chart of version support. Set the url or filename with versions." description='Create chart of version support. Set the url or filename with versions.'
"If you set both filename and url the script will prefer filename.") 'If you set both filename and url the script will prefer filename.')
arg_parser.add_argument("--url", metavar="URL", default="https://dl.espressif.com/dl/esp-idf/idf_versions.js") arg_parser.add_argument('--url', metavar='URL', default='https://dl.espressif.com/dl/esp-idf/idf_versions.js')
arg_parser.add_argument("--filename", arg_parser.add_argument('--filename',
help="Set the name of the source file, if is set, the script ignores the url.") help='Set the name of the source file, if is set, the script ignores the url.')
arg_parser.add_argument("--output-format", help="Set the output format of the image.", default="svg") arg_parser.add_argument('--output-format', help='Set the output format of the image.', default='svg')
arg_parser.add_argument("--output-file", help="Set the name of the output file.", default="docs/chart") arg_parser.add_argument('--output-file', help='Set the name of the output file.', default='docs/chart')
args = arg_parser.parse_args() args = arg_parser.parse_args()
ChartVersions(url=args.url if args.filename is None else None, filename=args.filename).create_chart( ChartVersions(url=args.url if args.filename is None else None, filename=args.filename).create_chart(
output_chart_extension="." + args.output_format.lower()[-3:], output_chart_name=args.output_file) output_chart_extension='.' + args.output_format.lower()[-3:], output_chart_name=args.output_file)

View File

@ -6,11 +6,11 @@
# #
# Then emits the new 'idf-info' event which has information read from IDF # Then emits the new 'idf-info' event which has information read from IDF
# build system, that other extensions can use to generate relevant data. # build system, that other extensions can use to generate relevant data.
import json
import os.path import os.path
import shutil import shutil
import sys
import subprocess import subprocess
import json import sys
# this directory also contains the dummy IDF project # this directory also contains the dummy IDF project
project_path = os.path.abspath(os.path.dirname(__file__)) project_path = os.path.abspath(os.path.dirname(__file__))
@ -23,7 +23,7 @@ def setup(app):
# Setup some common paths # Setup some common paths
try: try:
build_dir = os.environ["BUILDDIR"] # TODO see if we can remove this build_dir = os.environ['BUILDDIR'] # TODO see if we can remove this
except KeyError: except KeyError:
build_dir = os.path.dirname(app.doctreedir.rstrip(os.sep)) build_dir = os.path.dirname(app.doctreedir.rstrip(os.sep))
@ -43,7 +43,7 @@ def setup(app):
except KeyError: except KeyError:
idf_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '..')) idf_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '..'))
app.add_config_value('docs_root', os.path.join(idf_path, "docs"), 'env') app.add_config_value('docs_root', os.path.join(idf_path, 'docs'), 'env')
app.add_config_value('idf_path', idf_path, 'env') app.add_config_value('idf_path', idf_path, 'env')
app.add_config_value('build_dir', build_dir, 'env') # not actually an IDF thing app.add_config_value('build_dir', build_dir, 'env') # not actually an IDF thing
app.add_event('idf-info') app.add_event('idf-info')
@ -55,43 +55,43 @@ def setup(app):
def generate_idf_info(app, config): def generate_idf_info(app, config):
print("Running CMake on dummy project to get build info...") print('Running CMake on dummy project to get build info...')
build_dir = os.path.dirname(app.doctreedir.rstrip(os.sep)) build_dir = os.path.dirname(app.doctreedir.rstrip(os.sep))
cmake_build_dir = os.path.join(build_dir, "build_dummy_project") cmake_build_dir = os.path.join(build_dir, 'build_dummy_project')
idf_py_path = os.path.join(app.config.idf_path, "tools", "idf.py") idf_py_path = os.path.join(app.config.idf_path, 'tools', 'idf.py')
print("Running idf.py...") print('Running idf.py...')
idf_py = [sys.executable, idf_py = [sys.executable,
idf_py_path, idf_py_path,
"-B", '-B',
cmake_build_dir, cmake_build_dir,
"-C", '-C',
project_path, project_path,
"-D", '-D',
"SDKCONFIG={}".format(os.path.join(build_dir, "dummy_project_sdkconfig")) 'SDKCONFIG={}'.format(os.path.join(build_dir, 'dummy_project_sdkconfig'))
] ]
# force a clean idf.py build w/ new sdkconfig each time # force a clean idf.py build w/ new sdkconfig each time
# (not much slower than 'reconfigure', avoids any potential config & build versioning problems # (not much slower than 'reconfigure', avoids any potential config & build versioning problems
shutil.rmtree(cmake_build_dir, ignore_errors=True) shutil.rmtree(cmake_build_dir, ignore_errors=True)
print("Starting new dummy IDF project... ") print('Starting new dummy IDF project... ')
if (app.config.idf_target in PREVIEW_TARGETS): if (app.config.idf_target in PREVIEW_TARGETS):
subprocess.check_call(idf_py + ["--preview", "set-target", app.config.idf_target]) subprocess.check_call(idf_py + ['--preview', 'set-target', app.config.idf_target])
else: else:
subprocess.check_call(idf_py + ["set-target", app.config.idf_target]) subprocess.check_call(idf_py + ['set-target', app.config.idf_target])
print("Running CMake on dummy project...") print('Running CMake on dummy project...')
subprocess.check_call(idf_py + ["reconfigure"]) subprocess.check_call(idf_py + ['reconfigure'])
with open(os.path.join(cmake_build_dir, "project_description.json")) as f: with open(os.path.join(cmake_build_dir, 'project_description.json')) as f:
project_description = json.load(f) project_description = json.load(f)
if project_description["target"] != app.config.idf_target: if project_description['target'] != app.config.idf_target:
# this shouldn't really happen unless someone has been moving around directories inside _build, as # this shouldn't really happen unless someone has been moving around directories inside _build, as
# the cmake_build_dir path should be target-specific # the cmake_build_dir path should be target-specific
raise RuntimeError(("Error configuring the dummy IDF project for {}. " + raise RuntimeError(('Error configuring the dummy IDF project for {}. ' +
"Target in project description is {}. " + 'Target in project description is {}. ' +
"Is build directory contents corrupt?") 'Is build directory contents corrupt?')
.format(app.config.idf_target, project_description["target"])) .format(app.config.idf_target, project_description['target']))
app.emit('idf-info', project_description) app.emit('idf-info', project_description)
return [] return []

View File

@ -1,5 +1,5 @@
# Extension to generate esp_err definition as .rst # Extension to generate esp_err definition as .rst
from .util import copy_if_modified, call_with_python from .util import call_with_python, copy_if_modified
def setup(app): def setup(app):

View File

@ -35,8 +35,8 @@ def build_subset(app, config):
# Get all docs that will be built # Get all docs that will be built
docs = [filename for filename in get_matching_files(app.srcdir, compile_matchers(exclude_docs))] docs = [filename for filename in get_matching_files(app.srcdir, compile_matchers(exclude_docs))]
if not docs: if not docs:
raise ValueError("No documents to build") raise ValueError('No documents to build')
print("Building a subset of the documents: {}".format(docs)) print('Building a subset of the documents: {}'.format(docs))
# Sphinx requires a master document, if there is a document name 'index' then we pick that # Sphinx requires a master document, if there is a document name 'index' then we pick that
index_docs = [doc for doc in docs if 'index' in doc] index_docs = [doc for doc in docs if 'index' in doc]

View File

@ -1,9 +1,10 @@
import re
import os import os
import os.path import os.path
import re
from docutils import io, nodes, statemachine, utils from docutils import io, nodes, statemachine, utils
from docutils.utils.error_reporting import SafeString, ErrorString
from docutils.parsers.rst import directives from docutils.parsers.rst import directives
from docutils.utils.error_reporting import ErrorString, SafeString
from sphinx.directives.other import Include as BaseInclude from sphinx.directives.other import Include as BaseInclude
from sphinx.util import logging from sphinx.util import logging
@ -73,26 +74,26 @@ class StringSubstituter:
def init_sub_strings(self, config): def init_sub_strings(self, config):
self.target_name = config.idf_target self.target_name = config.idf_target
self.add_pair("{IDF_TARGET_NAME}", self.TARGET_NAMES[config.idf_target]) self.add_pair('{IDF_TARGET_NAME}', self.TARGET_NAMES[config.idf_target])
self.add_pair("{IDF_TARGET_PATH_NAME}", config.idf_target) self.add_pair('{IDF_TARGET_PATH_NAME}', config.idf_target)
self.add_pair("{IDF_TARGET_TOOLCHAIN_NAME}", self.TOOLCHAIN_NAMES[config.idf_target]) self.add_pair('{IDF_TARGET_TOOLCHAIN_NAME}', self.TOOLCHAIN_NAMES[config.idf_target])
self.add_pair("{IDF_TARGET_CFG_PREFIX}", self.CONFIG_PREFIX[config.idf_target]) self.add_pair('{IDF_TARGET_CFG_PREFIX}', self.CONFIG_PREFIX[config.idf_target])
self.add_pair("{IDF_TARGET_TRM_EN_URL}", self.TRM_EN_URL[config.idf_target]) self.add_pair('{IDF_TARGET_TRM_EN_URL}', self.TRM_EN_URL[config.idf_target])
self.add_pair("{IDF_TARGET_TRM_CN_URL}", self.TRM_CN_URL[config.idf_target]) self.add_pair('{IDF_TARGET_TRM_CN_URL}', self.TRM_CN_URL[config.idf_target])
def add_local_subs(self, matches): def add_local_subs(self, matches):
for sub_def in matches: for sub_def in matches:
if len(sub_def) != 2: if len(sub_def) != 2:
raise ValueError("IDF_TARGET_X substitution define invalid, val={}".format(sub_def)) raise ValueError('IDF_TARGET_X substitution define invalid, val={}'.format(sub_def))
tag = "{" + "IDF_TARGET_{}".format(sub_def[0]) + "}" tag = '{' + 'IDF_TARGET_{}'.format(sub_def[0]) + '}'
match_default = re.match(r'^\s*default(\s*)=(\s*)\"(.*?)\"', sub_def[1]) match_default = re.match(r'^\s*default(\s*)=(\s*)\"(.*?)\"', sub_def[1])
if match_default is None: if match_default is None:
# There should always be a default value # There should always be a default value
raise ValueError("No default value in IDF_TARGET_X substitution define, val={}".format(sub_def)) raise ValueError('No default value in IDF_TARGET_X substitution define, val={}'.format(sub_def))
match_target = re.match(r'^.*{}(\s*)=(\s*)\"(.*?)\"'.format(self.target_name), sub_def[1]) match_target = re.match(r'^.*{}(\s*)=(\s*)\"(.*?)\"'.format(self.target_name), sub_def[1])

View File

@ -8,35 +8,35 @@
import glob import glob
import os import os
import pprint import pprint
import subprocess
import re import re
import subprocess
def generate_defines(app, project_description): def generate_defines(app, project_description):
sdk_config_path = os.path.join(project_description["build_dir"], "config") sdk_config_path = os.path.join(project_description['build_dir'], 'config')
# Parse kconfig macros to pass into doxygen # Parse kconfig macros to pass into doxygen
# #
# TODO: this should use the set of "config which can't be changed" eventually, # TODO: this should use the set of "config which can't be changed" eventually,
# not the header # not the header
defines = get_defines(os.path.join(project_description["build_dir"], defines = get_defines(os.path.join(project_description['build_dir'],
"config", "sdkconfig.h"), sdk_config_path) 'config', 'sdkconfig.h'), sdk_config_path)
# Add all SOC _caps.h headers and kconfig macros to the defines # Add all SOC _caps.h headers and kconfig macros to the defines
# #
# kind of a hack, be nicer to add a component info dict in project_description.json # kind of a hack, be nicer to add a component info dict in project_description.json
soc_path = [p for p in project_description["build_component_paths"] if p.endswith("/soc")][0] soc_path = [p for p in project_description['build_component_paths'] if p.endswith('/soc')][0]
soc_headers = glob.glob(os.path.join(soc_path, project_description["target"], soc_headers = glob.glob(os.path.join(soc_path, project_description['target'],
"include", "soc", "*_caps.h")) 'include', 'soc', '*_caps.h'))
assert len(soc_headers) > 0 assert len(soc_headers) > 0
for soc_header in soc_headers: for soc_header in soc_headers:
defines.update(get_defines(soc_header, sdk_config_path)) defines.update(get_defines(soc_header, sdk_config_path))
# write a list of definitions to make debugging easier # write a list of definitions to make debugging easier
with open(os.path.join(app.config.build_dir, "macro-definitions.txt"), "w") as f: with open(os.path.join(app.config.build_dir, 'macro-definitions.txt'), 'w') as f:
pprint.pprint(defines, f) pprint.pprint(defines, f)
print("Saved macro list to %s" % f.name) print('Saved macro list to %s' % f.name)
add_tags(app, defines) add_tags(app, defines)
@ -48,19 +48,19 @@ def get_defines(header_path, sdk_config_path):
# Note: we run C preprocessor here without any -I arguments (except "sdkconfig.h"), so assumption is # Note: we run C preprocessor here without any -I arguments (except "sdkconfig.h"), so assumption is
# that these headers are all self-contained and don't include any other headers # that these headers are all self-contained and don't include any other headers
# not in the same directory # not in the same directory
print("Reading macros from %s..." % (header_path)) print('Reading macros from %s...' % (header_path))
processed_output = subprocess.check_output(["xtensa-esp32-elf-gcc", "-I", sdk_config_path, processed_output = subprocess.check_output(['xtensa-esp32-elf-gcc', '-I', sdk_config_path,
"-dM", "-E", header_path]).decode() '-dM', '-E', header_path]).decode()
for line in processed_output.split("\n"): for line in processed_output.split('\n'):
line = line.strip() line = line.strip()
m = re.search("#define ([^ ]+) ?(.*)", line) m = re.search('#define ([^ ]+) ?(.*)', line)
if m: if m:
name = m.group(1) name = m.group(1)
value = m.group(2) value = m.group(2)
if name.startswith("_"): if name.startswith('_'):
continue # toolchain macro continue # toolchain macro
if (" " in value) or ("=" in value): if (' ' in value) or ('=' in value):
value = "" # macros that expand to multiple tokens (ie function macros) cause doxygen errors, so just mark as 'defined' value = '' # macros that expand to multiple tokens (ie function macros) cause doxygen errors, so just mark as 'defined'
defines[name] = value defines[name] = value
return defines return defines
@ -70,7 +70,7 @@ def add_tags(app, defines):
# try to parse define values as ints and add to tags # try to parse define values as ints and add to tags
for name, value in defines.items(): for name, value in defines.items():
try: try:
define_value = int(value.strip("()")) define_value = int(value.strip('()'))
if define_value > 0: if define_value > 0:
app.tags.add(name) app.tags.add(name)
except ValueError: except ValueError:

View File

@ -1,7 +1,9 @@
# Generate toolchain download links from toolchain info makefile # Generate toolchain download links from toolchain info makefile
from __future__ import print_function from __future__ import print_function
import os.path import os.path
from .util import copy_if_modified, call_with_python
from .util import call_with_python, copy_if_modified
def setup(app): def setup(app):
@ -12,9 +14,9 @@ def setup(app):
def generate_idf_tools_links(app, project_description): def generate_idf_tools_links(app, project_description):
print("Generating IDF Tools list") print('Generating IDF Tools list')
os.environ["IDF_MAINTAINER"] = "1" os.environ['IDF_MAINTAINER'] = '1'
tools_rst = os.path.join(app.config.build_dir, 'inc', 'idf-tools-inc.rst') tools_rst = os.path.join(app.config.build_dir, 'inc', 'idf-tools-inc.rst')
tools_rst_tmp = os.path.join(app.config.build_dir, 'idf-tools-inc.rst') tools_rst_tmp = os.path.join(app.config.build_dir, 'idf-tools-inc.rst')
call_with_python("{}/tools/idf_tools.py gen-doc --output {}".format(app.config.idf_path, tools_rst_tmp)) call_with_python('{}/tools/idf_tools.py gen-doc --output {}'.format(app.config.idf_path, tools_rst_tmp))
copy_if_modified(tools_rst_tmp, tools_rst) copy_if_modified(tools_rst_tmp, tools_rst)

View File

@ -1,17 +1,19 @@
# Generate toolchain download links from toolchain info makefile # Generate toolchain download links from toolchain info makefile
from __future__ import print_function from __future__ import print_function
import os.path import os.path
from collections import namedtuple from collections import namedtuple
from .util import copy_if_modified from .util import copy_if_modified
BASE_URL = 'https://dl.espressif.com/dl/' BASE_URL = 'https://dl.espressif.com/dl/'
PlatformInfo = namedtuple("PlatformInfo", [ PlatformInfo = namedtuple('PlatformInfo', [
"platform_name", 'platform_name',
"platform_archive_suffix", 'platform_archive_suffix',
"extension", 'extension',
"unpack_cmd", 'unpack_cmd',
"unpack_code" 'unpack_code'
]) ])
@ -23,9 +25,9 @@ def setup(app):
def generate_toolchain_download_links(app, project_description): def generate_toolchain_download_links(app, project_description):
print("Generating toolchain download links") print('Generating toolchain download links')
toolchain_tmpdir = '{}/toolchain_inc'.format(app.config.build_dir) toolchain_tmpdir = '{}/toolchain_inc'.format(app.config.build_dir)
toolchain_versions = os.path.join(app.config.idf_path, "tools/toolchain_versions.mk") toolchain_versions = os.path.join(app.config.idf_path, 'tools/toolchain_versions.mk')
gen_toolchain_links(toolchain_versions, toolchain_tmpdir) gen_toolchain_links(toolchain_versions, toolchain_tmpdir)
copy_if_modified(toolchain_tmpdir, '{}/inc'.format(app.config.build_dir)) copy_if_modified(toolchain_tmpdir, '{}/inc'.format(app.config.build_dir))
@ -34,11 +36,11 @@ def gen_toolchain_links(versions_file, out_dir):
version_vars = {} version_vars = {}
with open(versions_file) as f: with open(versions_file) as f:
for line in f: for line in f:
name, var = line.partition("=")[::2] name, var = line.partition('=')[::2]
version_vars[name.strip()] = var.strip() version_vars[name.strip()] = var.strip()
gcc_version = version_vars["CURRENT_TOOLCHAIN_GCC_VERSION"] gcc_version = version_vars['CURRENT_TOOLCHAIN_GCC_VERSION']
toolchain_desc = version_vars["CURRENT_TOOLCHAIN_COMMIT_DESC_SHORT"] toolchain_desc = version_vars['CURRENT_TOOLCHAIN_COMMIT_DESC_SHORT']
unpack_code_linux_macos = """ unpack_code_linux_macos = """
:: ::
@ -59,10 +61,10 @@ def gen_toolchain_links(versions_file, out_dir):
""" """
platform_info = [ platform_info = [
PlatformInfo("linux64", "linux-amd64", "tar.gz", "z", unpack_code_linux_macos), PlatformInfo('linux64', 'linux-amd64', 'tar.gz', 'z', unpack_code_linux_macos),
PlatformInfo("linux32", "linux-i686","tar.gz", "z", unpack_code_linux_macos), PlatformInfo('linux32', 'linux-i686','tar.gz', 'z', unpack_code_linux_macos),
PlatformInfo("osx", "macos", "tar.gz", "z", unpack_code_linux_macos), PlatformInfo('osx', 'macos', 'tar.gz', 'z', unpack_code_linux_macos),
PlatformInfo("win32", "win32", "zip", None, None) PlatformInfo('win32', 'win32', 'zip', None, None)
] ]
try: try:
@ -70,7 +72,7 @@ def gen_toolchain_links(versions_file, out_dir):
except OSError: except OSError:
pass pass
with open(os.path.join(out_dir, 'download-links.inc'), "w") as links_file: with open(os.path.join(out_dir, 'download-links.inc'), 'w') as links_file:
for p in platform_info: for p in platform_info:
archive_name = 'xtensa-esp32-elf-gcc{}-{}-{}.{}'.format( archive_name = 'xtensa-esp32-elf-gcc{}-{}-{}.{}'.format(
gcc_version.replace('.', '_'), toolchain_desc, p.platform_archive_suffix, p.extension) gcc_version.replace('.', '_'), toolchain_desc, p.platform_archive_suffix, p.extension)
@ -79,8 +81,8 @@ def gen_toolchain_links(versions_file, out_dir):
p.platform_name, BASE_URL, archive_name), file=links_file) p.platform_name, BASE_URL, archive_name), file=links_file)
if p.unpack_code is not None: if p.unpack_code is not None:
with open(os.path.join(out_dir, 'unpack-code-%s.inc' % p.platform_name), "w") as f: with open(os.path.join(out_dir, 'unpack-code-%s.inc' % p.platform_name), 'w') as f:
print(p.unpack_code.format(p.unpack_cmd, archive_name), file=f) print(p.unpack_code.format(p.unpack_cmd, archive_name), file=f)
with open(os.path.join(out_dir, 'scratch-build-code.inc'), "w") as code_file: with open(os.path.join(out_dir, 'scratch-build-code.inc'), 'w') as code_file:
print(scratch_build_code_linux_macos.format(toolchain_desc), file=code_file) print(scratch_build_code_linux_macos.format(toolchain_desc), file=code_file)

View File

@ -4,17 +4,18 @@
# Sphinx extension to generate ReSTructured Text .inc snippets # Sphinx extension to generate ReSTructured Text .inc snippets
# with version-based content for this IDF version # with version-based content for this IDF version
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
from io import open
from .util import copy_if_modified
import subprocess
import os import os
import re import re
import subprocess
from io import open
from .util import copy_if_modified
TEMPLATES = { TEMPLATES = {
"en": { 'en': {
"git-clone-bash": """ 'git-clone-bash': """
.. code-block:: bash .. code-block:: bash
mkdir -p ~/esp mkdir -p ~/esp
@ -22,7 +23,7 @@ TEMPLATES = {
git clone %(clone_args)s--recursive https://github.com/espressif/esp-idf.git git clone %(clone_args)s--recursive https://github.com/espressif/esp-idf.git
""", """,
"git-clone-windows": """ 'git-clone-windows': """
.. code-block:: batch .. code-block:: batch
mkdir %%userprofile%%\\esp mkdir %%userprofile%%\\esp
@ -30,8 +31,8 @@ TEMPLATES = {
git clone %(clone_args)s--recursive https://github.com/espressif/esp-idf.git git clone %(clone_args)s--recursive https://github.com/espressif/esp-idf.git
""", """,
"git-clone-notes": { 'git-clone-notes': {
"template": """ 'template': """
.. note:: .. note::
%(extra_note)s %(extra_note)s
@ -40,35 +41,35 @@ TEMPLATES = {
%(zipfile_note)s %(zipfile_note)s
""", """,
"master": 'This command will clone the master branch, which has the latest development ("bleeding edge") ' 'master': 'This command will clone the master branch, which has the latest development ("bleeding edge") '
'version of ESP-IDF. It is fully functional and updated on weekly basis with the most recent features and bugfixes.', 'version of ESP-IDF. It is fully functional and updated on weekly basis with the most recent features and bugfixes.',
"branch": 'The ``git clone`` option ``-b %(clone_arg)s`` tells git to clone the %(ver_type)s in the ESP-IDF repository ``git clone`` ' 'branch': 'The ``git clone`` option ``-b %(clone_arg)s`` tells git to clone the %(ver_type)s in the ESP-IDF repository ``git clone`` '
'corresponding to this version of the documentation.', 'corresponding to this version of the documentation.',
"zipfile": { 'zipfile': {
"stable": 'As a fallback, it is also possible to download a zip file of this stable release from the `Releases page`_. ' 'stable': 'As a fallback, it is also possible to download a zip file of this stable release from the `Releases page`_. '
'Do not download the "Source code" zip file(s) generated automatically by GitHub, they do not work with ESP-IDF.', 'Do not download the "Source code" zip file(s) generated automatically by GitHub, they do not work with ESP-IDF.',
"unstable": 'GitHub\'s "Download zip file" feature does not work with ESP-IDF, a ``git clone`` is required. As a fallback, ' 'unstable': 'GitHub\'s "Download zip file" feature does not work with ESP-IDF, a ``git clone`` is required. As a fallback, '
'`Stable version`_ can be installed without Git.' '`Stable version`_ can be installed without Git.'
}, # zipfile }, # zipfile
}, # git-clone-notes }, # git-clone-notes
"version-note": { 'version-note': {
"master": """ 'master': """
.. note:: .. note::
This is documentation for the master branch (latest version) of ESP-IDF. This version is under continual development. This is documentation for the master branch (latest version) of ESP-IDF. This version is under continual development.
`Stable version`_ documentation is available, as well as other :doc:`/versions`. `Stable version`_ documentation is available, as well as other :doc:`/versions`.
""", """,
"stable": """ 'stable': """
.. note:: .. note::
This is documentation for stable version %s of ESP-IDF. Other :doc:`/versions` are also available. This is documentation for stable version %s of ESP-IDF. Other :doc:`/versions` are also available.
""", """,
"branch": """ 'branch': """
.. note:: .. note::
This is documentation for %s ``%s`` of ESP-IDF. Other :doc:`/versions` are also available. This is documentation for %s ``%s`` of ESP-IDF. Other :doc:`/versions` are also available.
""" """
}, # version-note }, # version-note
}, # en }, # en
"zh_CN": { 'zh_CN': {
"git-clone-bash": """ 'git-clone-bash': """
.. code-block:: bash .. code-block:: bash
mkdir -p ~/esp mkdir -p ~/esp
@ -76,7 +77,7 @@ TEMPLATES = {
git clone %(clone_args)s--recursive https://github.com/espressif/esp-idf.git git clone %(clone_args)s--recursive https://github.com/espressif/esp-idf.git
""", """,
"git-clone-windows": """ 'git-clone-windows': """
.. code-block:: batch .. code-block:: batch
mkdir %%userprofile%%\\esp mkdir %%userprofile%%\\esp
@ -84,8 +85,8 @@ TEMPLATES = {
git clone %(clone_args)s--recursive https://github.com/espressif/esp-idf.git git clone %(clone_args)s--recursive https://github.com/espressif/esp-idf.git
""", """,
"git-clone-notes": { 'git-clone-notes': {
"template": """ 'template': """
.. note:: .. note::
%(extra_note)s %(extra_note)s
@ -94,24 +95,24 @@ TEMPLATES = {
%(zipfile_note)s %(zipfile_note)s
""", """,
"master": '此命令将克隆 master 分支,该分支保存着 ESP-IDF 的最新版本,它功能齐全,每周都会更新一些新功能并修正一些错误。', 'master': '此命令将克隆 master 分支,该分支保存着 ESP-IDF 的最新版本,它功能齐全,每周都会更新一些新功能并修正一些错误。',
"branch": '``git clone`` 命令的 ``-b %(clone_arg)s`` 选项告诉 git 从 ESP-IDF 仓库中克隆与此版本的文档对应的分支。', 'branch': '``git clone`` 命令的 ``-b %(clone_arg)s`` 选项告诉 git 从 ESP-IDF 仓库中克隆与此版本的文档对应的分支。',
"zipfile": { 'zipfile': {
"stable": '作为备份,还可以从 `Releases page`_ 下载此稳定版本的 zip 文件。不要下载由 GitHub 自动生成的"源代码"的 zip 文件,它们不适用于 ESP-IDF。', 'stable': '作为备份,还可以从 `Releases page`_ 下载此稳定版本的 zip 文件。不要下载由 GitHub 自动生成的"源代码"的 zip 文件,它们不适用于 ESP-IDF。',
"unstable": 'GitHub 中"下载 zip 文档"的功能不适用于 ESP-IDF所以需要使用 ``git clone`` 命令。作为备份,可以在没有安装 Git 的环境中下载 ' 'unstable': 'GitHub 中"下载 zip 文档"的功能不适用于 ESP-IDF所以需要使用 ``git clone`` 命令。作为备份,可以在没有安装 Git 的环境中下载 '
'`Stable version`_ 的 zip 归档文件。' '`Stable version`_ 的 zip 归档文件。'
}, # zipfile }, # zipfile
}, # git-clone }, # git-clone
"version-note": { 'version-note': {
"master": """ 'master': """
.. note:: .. note::
这是ESP-IDF master 分支最新版本的文档该版本在持续开发中还有 `Stable version`_ 的文档以及其他版本的文档 :doc:`/versions` 供参考 这是ESP-IDF master 分支最新版本的文档该版本在持续开发中还有 `Stable version`_ 的文档以及其他版本的文档 :doc:`/versions` 供参考
""", """,
"stable": """ 'stable': """
.. note:: .. note::
这是ESP-IDF 稳定版本 %s 的文档还有其他版本的文档 :doc:`/versions` 供参考 这是ESP-IDF 稳定版本 %s 的文档还有其他版本的文档 :doc:`/versions` 供参考
""", """,
"branch": """ 'branch': """
.. note:: .. note::
这是ESP-IDF %s ``%s`` 版本的文档还有其他版本的文档 :doc:`/versions` 供参考 这是ESP-IDF %s ``%s`` 版本的文档还有其他版本的文档 :doc:`/versions` 供参考
""" """
@ -128,9 +129,9 @@ def setup(app):
def generate_version_specific_includes(app, project_description): def generate_version_specific_includes(app, project_description):
language = app.config.language language = app.config.language
tmp_out_dir = os.path.join(app.config.build_dir, "version_inc") tmp_out_dir = os.path.join(app.config.build_dir, 'version_inc')
if not os.path.exists(tmp_out_dir): if not os.path.exists(tmp_out_dir):
print("Creating directory %s" % tmp_out_dir) print('Creating directory %s' % tmp_out_dir)
os.mkdir(tmp_out_dir) os.mkdir(tmp_out_dir)
template = TEMPLATES[language] template = TEMPLATES[language]
@ -138,56 +139,56 @@ def generate_version_specific_includes(app, project_description):
version, ver_type, is_stable = get_version() version, ver_type, is_stable = get_version()
write_git_clone_inc_files(template, tmp_out_dir, version, ver_type, is_stable) write_git_clone_inc_files(template, tmp_out_dir, version, ver_type, is_stable)
write_version_note(template["version-note"], tmp_out_dir, version, ver_type, is_stable) write_version_note(template['version-note'], tmp_out_dir, version, ver_type, is_stable)
copy_if_modified(tmp_out_dir, os.path.join(app.config.build_dir, "inc")) copy_if_modified(tmp_out_dir, os.path.join(app.config.build_dir, 'inc'))
print("Done") print('Done')
def write_git_clone_inc_files(templates, out_dir, version, ver_type, is_stable): def write_git_clone_inc_files(templates, out_dir, version, ver_type, is_stable):
def out_file(basename): def out_file(basename):
p = os.path.join(out_dir, "%s.inc" % basename) p = os.path.join(out_dir, '%s.inc' % basename)
print("Writing %s..." % p) print('Writing %s...' % p)
return p return p
if version == "master": if version == 'master':
clone_args = "" clone_args = ''
else: else:
clone_args = "-b %s " % version clone_args = '-b %s ' % version
with open(out_file("git-clone-bash"), "w", encoding="utf-8") as f: with open(out_file('git-clone-bash'), 'w', encoding='utf-8') as f:
f.write(templates["git-clone-bash"] % locals()) f.write(templates['git-clone-bash'] % locals())
with open(out_file("git-clone-windows"), "w", encoding="utf-8") as f: with open(out_file('git-clone-windows'), 'w', encoding='utf-8') as f:
f.write(templates["git-clone-windows"] % locals()) f.write(templates['git-clone-windows'] % locals())
with open(out_file("git-clone-notes"), "w", encoding="utf-8") as f: with open(out_file('git-clone-notes'), 'w', encoding='utf-8') as f:
template = templates["git-clone-notes"] template = templates['git-clone-notes']
zipfile = template["zipfile"] zipfile = template['zipfile']
if version == "master": if version == 'master':
extra_note = template["master"] extra_note = template['master']
zipfile_note = zipfile["unstable"] zipfile_note = zipfile['unstable']
else: else:
extra_note = template["branch"] % {"clone_arg": version, "ver_type": ver_type} extra_note = template['branch'] % {'clone_arg': version, 'ver_type': ver_type}
zipfile_note = zipfile["stable"] if is_stable else zipfile["unstable"] zipfile_note = zipfile['stable'] if is_stable else zipfile['unstable']
f.write(template["template"] % locals()) f.write(template['template'] % locals())
print("Wrote git-clone-xxx.inc files") print('Wrote git-clone-xxx.inc files')
def write_version_note(template, out_dir, version, ver_type, is_stable): def write_version_note(template, out_dir, version, ver_type, is_stable):
if version == "master": if version == 'master':
content = template["master"] content = template['master']
elif ver_type == "tag" and is_stable: elif ver_type == 'tag' and is_stable:
content = template["stable"] % version content = template['stable'] % version
else: else:
content = template["branch"] % (ver_type, version) content = template['branch'] % (ver_type, version)
out_file = os.path.join(out_dir, "version-note.inc") out_file = os.path.join(out_dir, 'version-note.inc')
with open(out_file, "w", encoding='utf-8') as f: with open(out_file, 'w', encoding='utf-8') as f:
f.write(content) f.write(content)
print("%s written" % out_file) print('%s written' % out_file)
def get_version(): def get_version():
@ -196,22 +197,22 @@ def get_version():
""" """
# Use git to look for a tag # Use git to look for a tag
try: try:
tag = subprocess.check_output(["git", "describe", "--exact-match"]).strip().decode('utf-8') tag = subprocess.check_output(['git', 'describe', '--exact-match']).strip().decode('utf-8')
is_stable = re.match(r"v[0-9\.]+$", tag) is not None is_stable = re.match(r'v[0-9\.]+$', tag) is not None
return (tag, "tag", is_stable) return (tag, 'tag', is_stable)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
pass pass
# No tag, look at branch name from CI, this will give the correct branch name even if the ref for the branch we # No tag, look at branch name from CI, this will give the correct branch name even if the ref for the branch we
# merge into has moved forward before the pipeline runs # merge into has moved forward before the pipeline runs
branch = os.environ.get("CI_COMMIT_REF_NAME", None) branch = os.environ.get('CI_COMMIT_REF_NAME', None)
if branch is not None: if branch is not None:
return (branch, "branch", False) return (branch, 'branch', False)
# Try to find the branch name even if docs are built locally # Try to find the branch name even if docs are built locally
branch = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).strip().decode('utf-8') branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip().decode('utf-8')
if branch != "HEAD": if branch != 'HEAD':
return (branch, "branch", False) return (branch, 'branch', False)
# As a last resort we return commit SHA-1, should never happen in CI/docs that should be published # As a last resort we return commit SHA-1, should never happen in CI/docs that should be published
return (subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).strip().decode('utf-8'), "commit", False) return (subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).strip().decode('utf-8'), 'commit', False)

View File

@ -1,4 +1,5 @@
import os.path import os.path
from docutils.parsers.rst import directives from docutils.parsers.rst import directives
from docutils.parsers.rst.directives.misc import Include as BaseInclude from docutils.parsers.rst.directives.misc import Include as BaseInclude
from sphinx.util.docutils import SphinxDirective from sphinx.util.docutils import SphinxDirective

View File

@ -1,7 +1,7 @@
# Extension to generate the KConfig reference list # Extension to generate the KConfig reference list
import os.path import os.path
import sys
import subprocess import subprocess
import sys
from .util import copy_if_modified from .util import copy_if_modified
@ -18,18 +18,18 @@ def generate_reference(app, project_description):
build_dir = os.path.dirname(app.doctreedir.rstrip(os.sep)) build_dir = os.path.dirname(app.doctreedir.rstrip(os.sep))
# Generate 'kconfig.inc' file from components' Kconfig files # Generate 'kconfig.inc' file from components' Kconfig files
print("Generating kconfig.inc from kconfig contents") print('Generating kconfig.inc from kconfig contents')
kconfig_inc_path = '{}/inc/kconfig.inc'.format(build_dir) kconfig_inc_path = '{}/inc/kconfig.inc'.format(build_dir)
temp_sdkconfig_path = '{}/sdkconfig.tmp'.format(build_dir) temp_sdkconfig_path = '{}/sdkconfig.tmp'.format(build_dir)
kconfigs = project_description["config_environment"]["COMPONENT_KCONFIGS"].split(";") kconfigs = project_description['config_environment']['COMPONENT_KCONFIGS'].split(';')
kconfig_projbuilds = project_description["config_environment"]["COMPONENT_KCONFIGS_PROJBUILD"].split(";") kconfig_projbuilds = project_description['config_environment']['COMPONENT_KCONFIGS_PROJBUILD'].split(';')
sdkconfig_renames = set() sdkconfig_renames = set()
# TODO: this should be generated in project description as well, if possible # TODO: this should be generated in project description as well, if possible
for k in kconfigs + kconfig_projbuilds: for k in kconfigs + kconfig_projbuilds:
component_dir = os.path.dirname(k) component_dir = os.path.dirname(k)
sdkconfig_rename = os.path.join(component_dir, "sdkconfig.rename") sdkconfig_rename = os.path.join(component_dir, 'sdkconfig.rename')
if os.path.exists(sdkconfig_rename): if os.path.exists(sdkconfig_rename):
sdkconfig_renames.add(sdkconfig_rename) sdkconfig_renames.add(sdkconfig_rename)
@ -37,27 +37,27 @@ def generate_reference(app, project_description):
kconfig_projbuilds_source_path = '{}/inc/kconfig_projbuilds_source.in'.format(build_dir) kconfig_projbuilds_source_path = '{}/inc/kconfig_projbuilds_source.in'.format(build_dir)
prepare_kconfig_files_args = [sys.executable, prepare_kconfig_files_args = [sys.executable,
"{}/tools/kconfig_new/prepare_kconfig_files.py".format(app.config.idf_path), '{}/tools/kconfig_new/prepare_kconfig_files.py'.format(app.config.idf_path),
"--env", "COMPONENT_KCONFIGS={}".format(" ".join(kconfigs)), '--env', 'COMPONENT_KCONFIGS={}'.format(' '.join(kconfigs)),
"--env", "COMPONENT_KCONFIGS_PROJBUILD={}".format(" ".join(kconfig_projbuilds)), '--env', 'COMPONENT_KCONFIGS_PROJBUILD={}'.format(' '.join(kconfig_projbuilds)),
"--env", "COMPONENT_KCONFIGS_SOURCE_FILE={}".format(kconfigs_source_path), '--env', 'COMPONENT_KCONFIGS_SOURCE_FILE={}'.format(kconfigs_source_path),
"--env", "COMPONENT_KCONFIGS_PROJBUILD_SOURCE_FILE={}".format(kconfig_projbuilds_source_path), '--env', 'COMPONENT_KCONFIGS_PROJBUILD_SOURCE_FILE={}'.format(kconfig_projbuilds_source_path),
] ]
subprocess.check_call(prepare_kconfig_files_args) subprocess.check_call(prepare_kconfig_files_args)
confgen_args = [sys.executable, confgen_args = [sys.executable,
"{}/tools/kconfig_new/confgen.py".format(app.config.idf_path), '{}/tools/kconfig_new/confgen.py'.format(app.config.idf_path),
"--kconfig", "./Kconfig", '--kconfig', './Kconfig',
"--sdkconfig-rename", "./sdkconfig.rename", '--sdkconfig-rename', './sdkconfig.rename',
"--config", temp_sdkconfig_path, '--config', temp_sdkconfig_path,
"--env", "COMPONENT_KCONFIGS={}".format(" ".join(kconfigs)), '--env', 'COMPONENT_KCONFIGS={}'.format(' '.join(kconfigs)),
"--env", "COMPONENT_KCONFIGS_PROJBUILD={}".format(" ".join(kconfig_projbuilds)), '--env', 'COMPONENT_KCONFIGS_PROJBUILD={}'.format(' '.join(kconfig_projbuilds)),
"--env", "COMPONENT_SDKCONFIG_RENAMES={}".format(" ".join(sdkconfig_renames)), '--env', 'COMPONENT_SDKCONFIG_RENAMES={}'.format(' '.join(sdkconfig_renames)),
"--env", "COMPONENT_KCONFIGS_SOURCE_FILE={}".format(kconfigs_source_path), '--env', 'COMPONENT_KCONFIGS_SOURCE_FILE={}'.format(kconfigs_source_path),
"--env", "COMPONENT_KCONFIGS_PROJBUILD_SOURCE_FILE={}".format(kconfig_projbuilds_source_path), '--env', 'COMPONENT_KCONFIGS_PROJBUILD_SOURCE_FILE={}'.format(kconfig_projbuilds_source_path),
"--env", "IDF_PATH={}".format(app.config.idf_path), '--env', 'IDF_PATH={}'.format(app.config.idf_path),
"--env", "IDF_TARGET={}".format(app.config.idf_target), '--env', 'IDF_TARGET={}'.format(app.config.idf_target),
"--output", "docs", kconfig_inc_path + '.in' '--output', 'docs', kconfig_inc_path + '.in'
] ]
subprocess.check_call(confgen_args, cwd=app.config.idf_path) subprocess.check_call(confgen_args, cwd=app.config.idf_path)
copy_if_modified(kconfig_inc_path + '.in', kconfig_inc_path) copy_if_modified(kconfig_inc_path + '.in', kconfig_inc_path)

View File

@ -1,6 +1,7 @@
from sphinx.builders.latex import LaTeXBuilder
import os import os
from sphinx.builders.latex import LaTeXBuilder
# Overrides the default Sphinx latex build # Overrides the default Sphinx latex build
class IdfLatexBuilder(LaTeXBuilder): class IdfLatexBuilder(LaTeXBuilder):
@ -26,7 +27,7 @@ class IdfLatexBuilder(LaTeXBuilder):
def prepare_latex_macros(self, package_path, config): def prepare_latex_macros(self, package_path, config):
PACKAGE_NAME = "espidf.sty" PACKAGE_NAME = 'espidf.sty'
latex_package = '' latex_package = ''
with open(package_path, 'r') as template: with open(package_path, 'r') as template:
@ -36,7 +37,7 @@ class IdfLatexBuilder(LaTeXBuilder):
latex_package = latex_package.replace('<idf_target_title>', idf_target_title) latex_package = latex_package.replace('<idf_target_title>', idf_target_title)
# Release name for the PDF front page, remove '_' as this is used for subscript in Latex # Release name for the PDF front page, remove '_' as this is used for subscript in Latex
idf_release_name = "Release {}".format(config.version.replace('_', '-')) idf_release_name = 'Release {}'.format(config.version.replace('_', '-'))
latex_package = latex_package.replace('<idf_release_name>', idf_release_name) latex_package = latex_package.replace('<idf_release_name>', idf_release_name)
with open(os.path.join(self.outdir, PACKAGE_NAME), 'w') as package_file: with open(os.path.join(self.outdir, PACKAGE_NAME), 'w') as package_file:
@ -45,7 +46,7 @@ class IdfLatexBuilder(LaTeXBuilder):
def finish(self): def finish(self):
super().finish() super().finish()
TEMPLATE_PATH = "../latex_templates/espidf.sty" TEMPLATE_PATH = '../latex_templates/espidf.sty'
self.prepare_latex_macros(os.path.join(self.confdir,TEMPLATE_PATH), self.config) self.prepare_latex_macros(os.path.join(self.confdir,TEMPLATE_PATH), self.config)

View File

@ -1,14 +1,15 @@
# based on http://protips.readthedocs.io/link-roles.html # based on http://protips.readthedocs.io/link-roles.html
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
import re
import os import os
import re
import subprocess import subprocess
from docutils import nodes
from collections import namedtuple from collections import namedtuple
from sphinx.transforms.post_transforms import SphinxPostTransform
from docutils import nodes
from get_github_rev import get_github_rev from get_github_rev import get_github_rev
from sphinx.transforms.post_transforms import SphinxPostTransform
# Creates a dict of all submodules with the format {submodule_path : (url relative to git root), commit)} # Creates a dict of all submodules with the format {submodule_path : (url relative to git root), commit)}
@ -27,7 +28,7 @@ def get_submodules():
rev = sub_info[0].lstrip('-')[0:7] rev = sub_info[0].lstrip('-')[0:7]
path = sub_info[1].lstrip('./') path = sub_info[1].lstrip('./')
config_key_arg = "submodule.{}.url".format(path) config_key_arg = 'submodule.{}.url'.format(path)
rel_url = subprocess.check_output(['git', 'config', '--file', gitmodules_file, '--get', config_key_arg]).decode('utf-8').lstrip('./').rstrip('\n') rel_url = subprocess.check_output(['git', 'config', '--file', gitmodules_file, '--get', config_key_arg]).decode('utf-8').lstrip('./').rstrip('\n')
submodule_dict[path] = Submodule(rel_url, rev) submodule_dict[path] = Submodule(rel_url, rev)
@ -38,8 +39,8 @@ def get_submodules():
def url_join(*url_parts): def url_join(*url_parts):
""" Make a URL out of multiple components, assume first part is the https:// part and """ Make a URL out of multiple components, assume first part is the https:// part and
anything else is a path component """ anything else is a path component """
result = "/".join(url_parts) result = '/'.join(url_parts)
result = re.sub(r"([^:])//+", r"\1/", result) # remove any // that isn't in the https:// part result = re.sub(r'([^:])//+', r'\1/', result) # remove any // that isn't in the https:// part
return result return result
@ -47,7 +48,7 @@ def github_link(link_type, idf_rev, submods, root_path, app_config):
def role(name, rawtext, text, lineno, inliner, options={}, content=[]): def role(name, rawtext, text, lineno, inliner, options={}, content=[]):
msgs = [] msgs = []
BASE_URL = 'https://github.com/' BASE_URL = 'https://github.com/'
IDF_REPO = "espressif/esp-idf" IDF_REPO = 'espressif/esp-idf'
def warning(msg): def warning(msg):
system_msg = inliner.reporter.warning(msg) system_msg = inliner.reporter.warning(msg)
@ -90,31 +91,31 @@ def github_link(link_type, idf_rev, submods, root_path, app_config):
line_no = tuple(int(ln_group) for ln_group in line_no.groups() if ln_group) # tuple of (nnn,) or (nnn, NNN) for ranges line_no = tuple(int(ln_group) for ln_group in line_no.groups() if ln_group) # tuple of (nnn,) or (nnn, NNN) for ranges
elif '#' in abs_path: # drop any other anchor from the line elif '#' in abs_path: # drop any other anchor from the line
abs_path = abs_path.split('#')[0] abs_path = abs_path.split('#')[0]
warning("URL %s seems to contain an unusable anchor after the #, only line numbers are supported" % link) warning('URL %s seems to contain an unusable anchor after the #, only line numbers are supported' % link)
is_dir = (link_type == 'tree') is_dir = (link_type == 'tree')
if not os.path.exists(abs_path): if not os.path.exists(abs_path):
warning("IDF path %s does not appear to exist (absolute path %s)" % (rel_path, abs_path)) warning('IDF path %s does not appear to exist (absolute path %s)' % (rel_path, abs_path))
elif is_dir and not os.path.isdir(abs_path): elif is_dir and not os.path.isdir(abs_path):
# note these "wrong type" warnings are not strictly needed as GitHub will apply a redirect, # note these "wrong type" warnings are not strictly needed as GitHub will apply a redirect,
# but the may become important in the future (plus make for cleaner links) # but the may become important in the future (plus make for cleaner links)
warning("IDF path %s is not a directory but role :%s: is for linking to a directory, try :%s_file:" % (rel_path, name, name)) warning('IDF path %s is not a directory but role :%s: is for linking to a directory, try :%s_file:' % (rel_path, name, name))
elif not is_dir and os.path.isdir(abs_path): elif not is_dir and os.path.isdir(abs_path):
warning("IDF path %s is a directory but role :%s: is for linking to a file" % (rel_path, name)) warning('IDF path %s is a directory but role :%s: is for linking to a file' % (rel_path, name))
# check the line number is valid # check the line number is valid
if line_no: if line_no:
if is_dir: if is_dir:
warning("URL %s contains a line number anchor but role :%s: is for linking to a directory" % (rel_path, name, name)) warning('URL %s contains a line number anchor but role :%s: is for linking to a directory' % (rel_path, name, name))
elif os.path.exists(abs_path) and not os.path.isdir(abs_path): elif os.path.exists(abs_path) and not os.path.isdir(abs_path):
with open(abs_path, "r") as f: with open(abs_path, 'r') as f:
lines = len(f.readlines()) lines = len(f.readlines())
if any(True for ln in line_no if ln > lines): if any(True for ln in line_no if ln > lines):
warning("URL %s specifies a range larger than file (file has %d lines)" % (rel_path, lines)) warning('URL %s specifies a range larger than file (file has %d lines)' % (rel_path, lines))
if tuple(sorted(line_no)) != line_no: # second line number comes before first one! if tuple(sorted(line_no)) != line_no: # second line number comes before first one!
warning("URL %s specifies a backwards line number range" % rel_path) warning('URL %s specifies a backwards line number range' % rel_path)
node = nodes.reference(rawtext, link_text, refuri=url, **options) node = nodes.reference(rawtext, link_text, refuri=url, **options)
return [node], msgs return [node], msgs
@ -148,7 +149,7 @@ class TranslationLinkNodeTransform(SphinxPostTransform):
doc_path = env.doc2path(docname, None, None) doc_path = env.doc2path(docname, None, None)
return_path = '../' * doc_path.count('/') # path back to the root from 'docname' return_path = '../' * doc_path.count('/') # path back to the root from 'docname'
# then take off 3 more paths for language/release/targetname and build the new URL # then take off 3 more paths for language/release/targetname and build the new URL
url = "{}.html".format(os.path.join(return_path, '../../..', language, env.config.release, url = '{}.html'.format(os.path.join(return_path, '../../..', language, env.config.release,
env.config.idf_target, docname)) env.config.idf_target, docname))
node.replace_self(nodes.reference(rawtext, link_text, refuri=url, **options)) node.replace_self(nodes.reference(rawtext, link_text, refuri=url, **options))
else: else:

View File

@ -1,20 +1,21 @@
# Extension to generate Doxygen XML include files, with IDF config & soc macros included # Extension to generate Doxygen XML include files, with IDF config & soc macros included
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
from io import open
import os import os
import os.path import os.path
import re import re
import subprocess import subprocess
from io import open
from .util import copy_if_modified from .util import copy_if_modified
ALL_KINDS = [ ALL_KINDS = [
("function", "Functions"), ('function', 'Functions'),
("union", "Unions"), ('union', 'Unions'),
("struct", "Structures"), ('struct', 'Structures'),
("define", "Macros"), ('define', 'Macros'),
("typedef", "Type Definitions"), ('typedef', 'Type Definitions'),
("enum", "Enumerations") ('enum', 'Enumerations')
] ]
"""list of items that will be generated for a single API file """list of items that will be generated for a single API file
""" """
@ -30,27 +31,27 @@ def generate_doxygen(app, defines):
build_dir = os.path.dirname(app.doctreedir.rstrip(os.sep)) build_dir = os.path.dirname(app.doctreedir.rstrip(os.sep))
# Call Doxygen to get XML files from the header files # Call Doxygen to get XML files from the header files
print("Calling Doxygen to generate latest XML files") print('Calling Doxygen to generate latest XML files')
doxy_env = os.environ doxy_env = os.environ
doxy_env.update({ doxy_env.update({
"ENV_DOXYGEN_DEFINES": " ".join('{}={}'.format(key, value) for key, value in defines.items()), 'ENV_DOXYGEN_DEFINES': ' '.join('{}={}'.format(key, value) for key, value in defines.items()),
"IDF_PATH": app.config.idf_path, 'IDF_PATH': app.config.idf_path,
"IDF_TARGET": app.config.idf_target, 'IDF_TARGET': app.config.idf_target,
}) })
doxyfile_dir = os.path.join(app.config.docs_root, "doxygen") doxyfile_dir = os.path.join(app.config.docs_root, 'doxygen')
doxyfile_main = os.path.join(doxyfile_dir, "Doxyfile_common") doxyfile_main = os.path.join(doxyfile_dir, 'Doxyfile_common')
doxyfile_target = os.path.join(doxyfile_dir, "Doxyfile_" + app.config.idf_target) doxyfile_target = os.path.join(doxyfile_dir, 'Doxyfile_' + app.config.idf_target)
print("Running doxygen with doxyfiles {} and {}".format(doxyfile_main, doxyfile_target)) print('Running doxygen with doxyfiles {} and {}'.format(doxyfile_main, doxyfile_target))
# It's possible to have doxygen log warnings to a file using WARN_LOGFILE directive, # It's possible to have doxygen log warnings to a file using WARN_LOGFILE directive,
# but in some cases it will still log an error to stderr and return success! # but in some cases it will still log an error to stderr and return success!
# #
# So take all of stderr and redirect it to a logfile (will contain warnings and errors) # So take all of stderr and redirect it to a logfile (will contain warnings and errors)
logfile = os.path.join(build_dir, "doxygen-warning-log.txt") logfile = os.path.join(build_dir, 'doxygen-warning-log.txt')
with open(logfile, "w") as f: with open(logfile, 'w') as f:
# note: run Doxygen in the build directory, so the xml & xml_in files end up in there # note: run Doxygen in the build directory, so the xml & xml_in files end up in there
subprocess.check_call(["doxygen", doxyfile_main], env=doxy_env, cwd=build_dir, stderr=f) subprocess.check_call(['doxygen', doxyfile_main], env=doxy_env, cwd=build_dir, stderr=f)
# Doxygen has generated XML files in 'xml' directory. # Doxygen has generated XML files in 'xml' directory.
# Copy them to 'xml_in', only touching the files which have changed. # Copy them to 'xml_in', only touching the files which have changed.
@ -69,11 +70,11 @@ def convert_api_xml_to_inc(app, doxyfiles):
""" """
build_dir = app.config.build_dir build_dir = app.config.build_dir
xml_directory_path = "{}/xml".format(build_dir) xml_directory_path = '{}/xml'.format(build_dir)
inc_directory_path = "{}/inc".format(build_dir) inc_directory_path = '{}/inc'.format(build_dir)
if not os.path.isdir(xml_directory_path): if not os.path.isdir(xml_directory_path):
raise RuntimeError("Directory {} does not exist!".format(xml_directory_path)) raise RuntimeError('Directory {} does not exist!'.format(xml_directory_path))
if not os.path.exists(inc_directory_path): if not os.path.exists(inc_directory_path):
os.makedirs(inc_directory_path) os.makedirs(inc_directory_path)
@ -83,16 +84,16 @@ def convert_api_xml_to_inc(app, doxyfiles):
print("Generating 'api_name.inc' files with Doxygen directives") print("Generating 'api_name.inc' files with Doxygen directives")
for header_file_path in header_paths: for header_file_path in header_paths:
api_name = get_api_name(header_file_path) api_name = get_api_name(header_file_path)
inc_file_path = inc_directory_path + "/" + api_name + ".inc" inc_file_path = inc_directory_path + '/' + api_name + '.inc'
rst_output = generate_directives(header_file_path, xml_directory_path) rst_output = generate_directives(header_file_path, xml_directory_path)
previous_rst_output = '' previous_rst_output = ''
if os.path.isfile(inc_file_path): if os.path.isfile(inc_file_path):
with open(inc_file_path, "r", encoding='utf-8') as inc_file_old: with open(inc_file_path, 'r', encoding='utf-8') as inc_file_old:
previous_rst_output = inc_file_old.read() previous_rst_output = inc_file_old.read()
if previous_rst_output != rst_output: if previous_rst_output != rst_output:
with open(inc_file_path, "w", encoding='utf-8') as inc_file: with open(inc_file_path, 'w', encoding='utf-8') as inc_file:
inc_file.write(rst_output) inc_file.write(rst_output)
@ -108,11 +109,11 @@ def get_doxyfile_input_paths(app, doxyfile_path):
print("Getting Doxyfile's INPUT") print("Getting Doxyfile's INPUT")
with open(doxyfile_path, "r", encoding='utf-8') as input_file: with open(doxyfile_path, 'r', encoding='utf-8') as input_file:
line = input_file.readline() line = input_file.readline()
# read contents of Doxyfile until 'INPUT' statement # read contents of Doxyfile until 'INPUT' statement
while line: while line:
if line.find("INPUT") == 0: if line.find('INPUT') == 0:
break break
line = input_file.readline() line = input_file.readline()
@ -124,13 +125,13 @@ def get_doxyfile_input_paths(app, doxyfile_path):
# we have reached the end of 'INPUT' statement # we have reached the end of 'INPUT' statement
break break
# process only lines that are not comments # process only lines that are not comments
if line.find("#") == -1: if line.find('#') == -1:
# extract header file path inside components folder # extract header file path inside components folder
m = re.search("components/(.*\.h)", line) # noqa: W605 - regular expression m = re.search('components/(.*\.h)', line) # noqa: W605 - regular expression
header_file_path = m.group(1) header_file_path = m.group(1)
# Replace env variable used for multi target header # Replace env variable used for multi target header
header_file_path = header_file_path.replace("$(IDF_TARGET)", app.config.idf_target) header_file_path = header_file_path.replace('$(IDF_TARGET)', app.config.idf_target)
doxyfile_INPUT.append(header_file_path) doxyfile_INPUT.append(header_file_path)
@ -150,8 +151,8 @@ def get_api_name(header_file_path):
The name of API. The name of API.
""" """
api_name = "" api_name = ''
regex = r".*/(.*)\.h" regex = r'.*/(.*)\.h'
m = re.search(regex, header_file_path) m = re.search(regex, header_file_path)
if m: if m:
api_name = m.group(1) api_name = m.group(1)
@ -173,15 +174,15 @@ def generate_directives(header_file_path, xml_directory_path):
api_name = get_api_name(header_file_path) api_name = get_api_name(header_file_path)
# in XLT file name each "_" in the api name is expanded by Doxygen to "__" # in XLT file name each "_" in the api name is expanded by Doxygen to "__"
xlt_api_name = api_name.replace("_", "__") xlt_api_name = api_name.replace('_', '__')
xml_file_path = "%s/%s_8h.xml" % (xml_directory_path, xlt_api_name) xml_file_path = '%s/%s_8h.xml' % (xml_directory_path, xlt_api_name)
rst_output = "" rst_output = ''
rst_output = ".. File automatically generated by 'gen-dxd.py'\n" rst_output = ".. File automatically generated by 'gen-dxd.py'\n"
rst_output += "\n" rst_output += '\n'
rst_output += get_rst_header("Header File") rst_output += get_rst_header('Header File')
rst_output += "* :component_file:`" + header_file_path + "`\n" rst_output += '* :component_file:`' + header_file_path + '`\n'
rst_output += "\n" rst_output += '\n'
try: try:
import xml.etree.cElementTree as ET import xml.etree.cElementTree as ET
@ -206,10 +207,10 @@ def get_rst_header(header_name):
""" """
rst_output = "" rst_output = ''
rst_output += header_name + "\n" rst_output += header_name + '\n'
rst_output += "^" * len(header_name) + "\n" rst_output += '^' * len(header_name) + '\n'
rst_output += "\n" rst_output += '\n'
return rst_output return rst_output
@ -226,14 +227,14 @@ def select_unions(innerclass_list):
""" """
rst_output = "" rst_output = ''
for line in innerclass_list.splitlines(): for line in innerclass_list.splitlines():
# union is denoted by "union" at the beginning of line # union is denoted by "union" at the beginning of line
if line.find("union") == 0: if line.find('union') == 0:
union_id, union_name = re.split(r"\t+", line) union_id, union_name = re.split(r'\t+', line)
rst_output += ".. doxygenunion:: " rst_output += '.. doxygenunion:: '
rst_output += union_name rst_output += union_name
rst_output += "\n" rst_output += '\n'
return rst_output return rst_output
@ -251,20 +252,20 @@ def select_structs(innerclass_list):
""" """
rst_output = "" rst_output = ''
for line in innerclass_list.splitlines(): for line in innerclass_list.splitlines():
# structure is denoted by "struct" at the beginning of line # structure is denoted by "struct" at the beginning of line
if line.find("struct") == 0: if line.find('struct') == 0:
# skip structures that are part of union # skip structures that are part of union
# they are documented by 'doxygenunion' directive # they are documented by 'doxygenunion' directive
if line.find("::") > 0: if line.find('::') > 0:
continue continue
struct_id, struct_name = re.split(r"\t+", line) struct_id, struct_name = re.split(r'\t+', line)
rst_output += ".. doxygenstruct:: " rst_output += '.. doxygenstruct:: '
rst_output += struct_name rst_output += struct_name
rst_output += "\n" rst_output += '\n'
rst_output += " :members:\n" rst_output += ' :members:\n'
rst_output += "\n" rst_output += '\n'
return rst_output return rst_output
@ -282,12 +283,12 @@ def get_directives(tree, kind):
""" """
rst_output = "" rst_output = ''
if kind in ["union", "struct"]: if kind in ['union', 'struct']:
innerclass_list = "" innerclass_list = ''
for elem in tree.iterfind('compounddef/innerclass'): for elem in tree.iterfind('compounddef/innerclass'):
innerclass_list += elem.attrib["refid"] + "\t" + elem.text + "\n" innerclass_list += elem.attrib['refid'] + '\t' + elem.text + '\n'
if kind == "union": if kind == 'union':
rst_output += select_unions(innerclass_list) rst_output += select_unions(innerclass_list)
else: else:
rst_output += select_structs(innerclass_list) rst_output += select_structs(innerclass_list)
@ -295,10 +296,10 @@ def get_directives(tree, kind):
for elem in tree.iterfind( for elem in tree.iterfind(
'compounddef/sectiondef/memberdef[@kind="%s"]' % kind): 'compounddef/sectiondef/memberdef[@kind="%s"]' % kind):
name = elem.find('name') name = elem.find('name')
rst_output += ".. doxygen%s:: " % kind rst_output += '.. doxygen%s:: ' % kind
rst_output += name.text + "\n" rst_output += name.text + '\n'
if rst_output: if rst_output:
all_kinds_dict = dict(ALL_KINDS) all_kinds_dict = dict(ALL_KINDS)
rst_output = get_rst_header(all_kinds_dict[kind]) + rst_output + "\n" rst_output = get_rst_header(all_kinds_dict[kind]) + rst_output + '\n'
return rst_output return rst_output

View File

@ -15,10 +15,11 @@
# limitations under the License. # limitations under the License.
from __future__ import unicode_literals from __future__ import unicode_literals
from io import open
import os import os
import shutil import shutil
import sys import sys
from io import open
try: try:
import urllib.request import urllib.request
@ -33,10 +34,10 @@ def files_equal(path_1, path_2):
if not os.path.exists(path_1) or not os.path.exists(path_2): if not os.path.exists(path_1) or not os.path.exists(path_2):
return False return False
file_1_contents = '' file_1_contents = ''
with open(path_1, "r", encoding='utf-8') as f_1: with open(path_1, 'r', encoding='utf-8') as f_1:
file_1_contents = f_1.read() file_1_contents = f_1.read()
file_2_contents = '' file_2_contents = ''
with open(path_2, "r", encoding='utf-8') as f_2: with open(path_2, 'r', encoding='utf-8') as f_2:
file_2_contents = f_2.read() file_2_contents = f_2.read()
return file_1_contents == file_2_contents return file_1_contents == file_2_contents
@ -63,7 +64,7 @@ def copy_if_modified(src_path, dst_path):
def download_file_if_missing(from_url, to_path): def download_file_if_missing(from_url, to_path):
filename_with_path = to_path + "/" + os.path.basename(from_url) filename_with_path = to_path + '/' + os.path.basename(from_url)
exists = os.path.isfile(filename_with_path) exists = os.path.isfile(filename_with_path)
if exists: if exists:
print("The file '%s' already exists" % (filename_with_path)) print("The file '%s' already exists" % (filename_with_path))

View File

@ -35,8 +35,8 @@ def sanitize_version(original_version):
except KeyError: except KeyError:
version = original_version version = original_version
if version == "master": if version == 'master':
return "latest" return 'latest'
version = version.replace('/', '-') version = version.replace('/', '-')

View File

@ -8,8 +8,8 @@
try: try:
from conf_common import * # noqa: F403,F401 from conf_common import * # noqa: F403,F401
except ImportError: except ImportError:
import sys
import os import os
import sys
sys.path.insert(0, os.path.abspath('../..')) sys.path.insert(0, os.path.abspath('../..'))
from conf_common import * # noqa: F403,F401 from conf_common import * # noqa: F403,F401
@ -27,7 +27,7 @@ html_logo = None
latex_logo = None latex_logo = None
html_static_path = [] html_static_path = []
conditional_include_dict = {'esp32':["esp32_page.rst"], conditional_include_dict = {'esp32':['esp32_page.rst'],
'esp32s2':["esp32s2_page.rst"], 'esp32s2':['esp32s2_page.rst'],
'SOC_BT_SUPPORTED':["bt_page.rst"], 'SOC_BT_SUPPORTED':['bt_page.rst'],
} }

View File

@ -1,16 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import unittest import os
import subprocess import subprocess
import sys import sys
import os import unittest
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
ESP32_DOC = "esp32_page" ESP32_DOC = 'esp32_page'
ESP32_S2_DOC = "esp32s2_page" ESP32_S2_DOC = 'esp32s2_page'
BT_DOC = "bt_page" BT_DOC = 'bt_page'
LINK_ROLES_DOC = "link_roles" LINK_ROLES_DOC = 'link_roles'
IDF_FORMAT_DOC = "idf_target_format" IDF_FORMAT_DOC = 'idf_target_format'
class DocBuilder(): class DocBuilder():
@ -24,7 +24,7 @@ class DocBuilder():
self.html_out_dir = os.path.join(CURRENT_DIR, build_dir, language, target, 'html') self.html_out_dir = os.path.join(CURRENT_DIR, build_dir, language, target, 'html')
def build(self, opt_args=[]): def build(self, opt_args=[]):
args = [sys.executable, self.build_docs_py_path, "-b", self.build_dir, "-s", self.src_dir, "-t", self.target, "-l", self.language] args = [sys.executable, self.build_docs_py_path, '-b', self.build_dir, '-s', self.src_dir, '-t', self.target, '-l', self.language]
args.extend(opt_args) args.extend(opt_args)
return subprocess.call(args) return subprocess.call(args)
@ -33,65 +33,65 @@ class TestDocs(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.builder = DocBuilder("test", "_build/test_docs", "esp32s2", "en") cls.builder = DocBuilder('test', '_build/test_docs', 'esp32s2', 'en')
cls.build_ret_flag = cls.builder.build() cls.build_ret_flag = cls.builder.build()
def setUp(self): def setUp(self):
if self.build_ret_flag: if self.build_ret_flag:
self.fail("Build docs failed with return: {}".format(self.build_ret_flag)) self.fail('Build docs failed with return: {}'.format(self.build_ret_flag))
def assert_str_not_in_doc(self, doc_name, str_to_find): def assert_str_not_in_doc(self, doc_name, str_to_find):
with open(os.path.join(self.builder.html_out_dir, doc_name)) as f: with open(os.path.join(self.builder.html_out_dir, doc_name)) as f:
content = f.read() content = f.read()
self.assertFalse(str_to_find in content, "Found {} in {}".format(str_to_find, doc_name)) self.assertFalse(str_to_find in content, 'Found {} in {}'.format(str_to_find, doc_name))
def assert_str_in_doc(self, doc_name, str_to_find): def assert_str_in_doc(self, doc_name, str_to_find):
with open(os.path.join(self.builder.html_out_dir, doc_name)) as f: with open(os.path.join(self.builder.html_out_dir, doc_name)) as f:
content = f.read() content = f.read()
self.assertTrue(str_to_find in content, "Did not find {} in {}".format(str_to_find, doc_name)) self.assertTrue(str_to_find in content, 'Did not find {} in {}'.format(str_to_find, doc_name))
def test_only_dir(self): def test_only_dir(self):
# Test that ESP32 content was excluded # Test that ESP32 content was excluded
self.assert_str_not_in_doc(ESP32_S2_DOC + ".html", "!ESP32_CONTENT!") self.assert_str_not_in_doc(ESP32_S2_DOC + '.html', '!ESP32_CONTENT!')
# Test that ESP32 S2 content was included # Test that ESP32 S2 content was included
self.assert_str_in_doc(ESP32_S2_DOC + ".html", "!ESP32_S2_CONTENT!") self.assert_str_in_doc(ESP32_S2_DOC + '.html', '!ESP32_S2_CONTENT!')
# Test that BT content was excluded # Test that BT content was excluded
self.assert_str_not_in_doc(ESP32_S2_DOC + ".html", "!BT_CONTENT!") self.assert_str_not_in_doc(ESP32_S2_DOC + '.html', '!BT_CONTENT!')
def test_toctree_filter(self): def test_toctree_filter(self):
# ESP32 page should NOT be built # ESP32 page should NOT be built
esp32_doc = os.path.join(self.builder.html_out_dir, ESP32_DOC + ".html") esp32_doc = os.path.join(self.builder.html_out_dir, ESP32_DOC + '.html')
self.assertFalse(os.path.isfile(esp32_doc), "Found {}".format(esp32_doc)) self.assertFalse(os.path.isfile(esp32_doc), 'Found {}'.format(esp32_doc))
self.assert_str_not_in_doc('index.html', "!ESP32_CONTENT!") self.assert_str_not_in_doc('index.html', '!ESP32_CONTENT!')
esp32s2_doc = os.path.join(self.builder.html_out_dir, ESP32_S2_DOC + ".html") esp32s2_doc = os.path.join(self.builder.html_out_dir, ESP32_S2_DOC + '.html')
self.assertTrue(os.path.isfile(esp32s2_doc), "{} not found".format(esp32s2_doc)) self.assertTrue(os.path.isfile(esp32s2_doc), '{} not found'.format(esp32s2_doc))
# Spot check a few other tags # Spot check a few other tags
# No Bluetooth on ESP32 S2 # No Bluetooth on ESP32 S2
bt_doc = os.path.join(self.builder.html_out_dir, BT_DOC + ".html") bt_doc = os.path.join(self.builder.html_out_dir, BT_DOC + '.html')
self.assertFalse(os.path.isfile(bt_doc), "Found {}".format(bt_doc)) self.assertFalse(os.path.isfile(bt_doc), 'Found {}'.format(bt_doc))
self.assert_str_not_in_doc('index.html', "!BT_CONTENT!") self.assert_str_not_in_doc('index.html', '!BT_CONTENT!')
def test_link_roles(self): def test_link_roles(self):
print("test") print('test')
class TestBuildSubset(unittest.TestCase): class TestBuildSubset(unittest.TestCase):
def test_build_subset(self): def test_build_subset(self):
builder = DocBuilder("test", "_build/test_build_subset", "esp32", "en") builder = DocBuilder('test', '_build/test_build_subset', 'esp32', 'en')
docs_to_build = "esp32_page.rst" docs_to_build = 'esp32_page.rst'
self.assertFalse(builder.build(["-i", docs_to_build])) self.assertFalse(builder.build(['-i', docs_to_build]))
# Check that we only built the input docs # Check that we only built the input docs
bt_doc = os.path.join(builder.html_out_dir, BT_DOC + ".html") bt_doc = os.path.join(builder.html_out_dir, BT_DOC + '.html')
esp32_doc = os.path.join(builder.html_out_dir, ESP32_DOC + ".html") esp32_doc = os.path.join(builder.html_out_dir, ESP32_DOC + '.html')
self.assertFalse(os.path.isfile(bt_doc), "Found {}".format(bt_doc)) self.assertFalse(os.path.isfile(bt_doc), 'Found {}'.format(bt_doc))
self.assertTrue(os.path.isfile(esp32_doc), "Found {}".format(esp32_doc)) self.assertTrue(os.path.isfile(esp32_doc), 'Found {}'.format(esp32_doc))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -3,8 +3,8 @@
import os import os
import sys import sys
import unittest import unittest
from unittest.mock import MagicMock
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from unittest.mock import MagicMock
from sphinx.util import tags from sphinx.util import tags
@ -14,9 +14,7 @@ except ImportError:
sys.path.append('..') sys.path.append('..')
from idf_extensions import exclude_docs from idf_extensions import exclude_docs
from idf_extensions import format_idf_target from idf_extensions import format_idf_target, gen_idf_tools_links, link_roles
from idf_extensions import gen_idf_tools_links
from idf_extensions import link_roles
class TestFormatIdfTarget(unittest.TestCase): class TestFormatIdfTarget(unittest.TestCase):
@ -30,14 +28,14 @@ class TestFormatIdfTarget(unittest.TestCase):
def test_add_subs(self): def test_add_subs(self):
self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_NAME}'], "ESP32") self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_NAME}'], 'ESP32')
self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_PATH_NAME}'], "esp32") self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_PATH_NAME}'], 'esp32')
self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_TOOLCHAIN_NAME}'], "esp32") self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_TOOLCHAIN_NAME}'], 'esp32')
self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_CFG_PREFIX}'], "ESP32") self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_CFG_PREFIX}'], 'ESP32')
self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_TRM_EN_URL}'], self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_TRM_EN_URL}'],
"https://www.espressif.com/sites/default/files/documentation/esp32_technical_reference_manual_en.pdf") 'https://www.espressif.com/sites/default/files/documentation/esp32_technical_reference_manual_en.pdf')
self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_TRM_CN_URL}'], self.assertEqual(self.str_sub.substitute_strings['{IDF_TARGET_TRM_CN_URL}'],
"https://www.espressif.com/sites/default/files/documentation/esp32_technical_reference_manual_cn.pdf") 'https://www.espressif.com/sites/default/files/documentation/esp32_technical_reference_manual_cn.pdf')
def test_sub(self): def test_sub(self):
content = ('This is a {IDF_TARGET_NAME}, with {IDF_TARGET_PATH_NAME}/soc.c, compiled with ' content = ('This is a {IDF_TARGET_NAME}, with {IDF_TARGET_PATH_NAME}/soc.c, compiled with '
@ -54,14 +52,14 @@ class TestFormatIdfTarget(unittest.TestCase):
content = ('{IDF_TARGET_TX_PIN:default="IO3", esp32="IO4", esp32s2="IO5"}' content = ('{IDF_TARGET_TX_PIN:default="IO3", esp32="IO4", esp32s2="IO5"}'
'The {IDF_TARGET_NAME} UART {IDF_TARGET_TX_PIN} uses for TX') 'The {IDF_TARGET_NAME} UART {IDF_TARGET_TX_PIN} uses for TX')
expected = "The ESP32 UART IO4 uses for TX" expected = 'The ESP32 UART IO4 uses for TX'
self.assertEqual(self.str_sub.substitute(content), expected) self.assertEqual(self.str_sub.substitute(content), expected)
def test_local_sub_default(self): def test_local_sub_default(self):
content = ('{IDF_TARGET_TX_PIN:default="IO3", esp32s2="IO5"}' content = ('{IDF_TARGET_TX_PIN:default="IO3", esp32s2="IO5"}'
'The {IDF_TARGET_NAME} UART {IDF_TARGET_TX_PIN} uses for TX') 'The {IDF_TARGET_NAME} UART {IDF_TARGET_TX_PIN} uses for TX')
expected = "The ESP32 UART IO3 uses for TX" expected = 'The ESP32 UART IO3 uses for TX'
self.assertEqual(self.str_sub.substitute(content), expected) self.assertEqual(self.str_sub.substitute(content), expected)
def test_local_sub_no_default(self): def test_local_sub_no_default(self):
@ -76,12 +74,12 @@ class TestExclude(unittest.TestCase):
def setUp(self): def setUp(self):
self.app = MagicMock() self.app = MagicMock()
self.app.tags = tags.Tags() self.app.tags = tags.Tags()
self.app.config.conditional_include_dict = {"esp32":["esp32.rst", "bt.rst"], "esp32s2":["esp32s2.rst"]} self.app.config.conditional_include_dict = {'esp32':['esp32.rst', 'bt.rst'], 'esp32s2':['esp32s2.rst']}
self.app.config.docs_to_build = None self.app.config.docs_to_build = None
self.app.config.exclude_patterns = [] self.app.config.exclude_patterns = []
def test_update_exclude_pattern(self): def test_update_exclude_pattern(self):
self.app.tags.add("esp32") self.app.tags.add('esp32')
exclude_docs.update_exclude_patterns(self.app, self.app.config) exclude_docs.update_exclude_patterns(self.app, self.app.config)
docs_to_build = set(self.app.config.conditional_include_dict['esp32']) docs_to_build = set(self.app.config.conditional_include_dict['esp32'])
@ -92,7 +90,7 @@ class TestExclude(unittest.TestCase):
class TestGenIDFToolLinks(unittest.TestCase): class TestGenIDFToolLinks(unittest.TestCase):
def setUp(self): def setUp(self):
self.app = MagicMock() self.app = MagicMock()
self.app.config.build_dir = "_build" self.app.config.build_dir = '_build'
self.app.config.idf_path = os.environ['IDF_PATH'] self.app.config.idf_path = os.environ['IDF_PATH']
def test_gen_idf_tool_links(self): def test_gen_idf_tool_links(self):

View File

@ -9,8 +9,8 @@
try: try:
from conf_common import * # noqa: F403,F401 from conf_common import * # noqa: F403,F401
except ImportError: except ImportError:
import sys
import os import os
import sys
sys.path.insert(0, os.path.abspath('..')) sys.path.insert(0, os.path.abspath('..'))
from conf_common import * # noqa: F403,F401 from conf_common import * # noqa: F403,F401

View File

@ -15,21 +15,22 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import os import os
import re import re
import uuid
import subprocess import subprocess
import uuid
from tiny_test_fw import Utility
import ttfw_idf import ttfw_idf
from ble import lib_ble_client from ble import lib_ble_client
from tiny_test_fw import Utility
# When running on local machine execute the following before running this script # When running on local machine execute the following before running this script
# > make app bootloader # > make app bootloader
# > make print_flash_cmd | tail -n 1 > build/download.config # > make print_flash_cmd | tail -n 1 > build/download.config
@ttfw_idf.idf_example_test(env_tag="Example_WIFI_BT") @ttfw_idf.idf_example_test(env_tag='Example_WIFI_BT')
def test_example_app_ble_central(env, extra_data): def test_example_app_ble_central(env, extra_data):
""" """
Steps: Steps:
@ -37,7 +38,7 @@ def test_example_app_ble_central(env, extra_data):
""" """
interface = 'hci0' interface = 'hci0'
adv_host_name = "BleCentTestApp" adv_host_name = 'BleCentTestApp'
adv_iface_index = 0 adv_iface_index = 0
adv_type = 'peripheral' adv_type = 'peripheral'
adv_uuid = '1811' adv_uuid = '1811'
@ -45,15 +46,15 @@ def test_example_app_ble_central(env, extra_data):
subprocess.check_output(['rm','-rf','/var/lib/bluetooth/*']) subprocess.check_output(['rm','-rf','/var/lib/bluetooth/*'])
subprocess.check_output(['hciconfig','hci0','reset']) subprocess.check_output(['hciconfig','hci0','reset'])
# Acquire DUT # Acquire DUT
dut = env.get_dut("blecent", "examples/bluetooth/nimble/blecent", dut_class=ttfw_idf.ESP32DUT) dut = env.get_dut('blecent', 'examples/bluetooth/nimble/blecent', dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut.app.binary_path, "blecent.bin") binary_file = os.path.join(dut.app.binary_path, 'blecent.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("blecent_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('blecent_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
Utility.console_log("Starting blecent example test app") Utility.console_log('Starting blecent example test app')
dut.start_app() dut.start_app()
dut.reset() dut.reset()
@ -62,16 +63,16 @@ def test_example_app_ble_central(env, extra_data):
# Get BLE client module # Get BLE client module
ble_client_obj = lib_ble_client.BLE_Bluez_Client(interface) ble_client_obj = lib_ble_client.BLE_Bluez_Client(interface)
if not ble_client_obj: if not ble_client_obj:
raise RuntimeError("Get DBus-Bluez object failed !!") raise RuntimeError('Get DBus-Bluez object failed !!')
# Discover Bluetooth Adapter and power on # Discover Bluetooth Adapter and power on
is_adapter_set = ble_client_obj.set_adapter() is_adapter_set = ble_client_obj.set_adapter()
if not is_adapter_set: if not is_adapter_set:
raise RuntimeError("Adapter Power On failed !!") raise RuntimeError('Adapter Power On failed !!')
# Write device address to dut # Write device address to dut
dut.expect("BLE Host Task Started", timeout=60) dut.expect('BLE Host Task Started', timeout=60)
dut.write(device_addr + "\n") dut.write(device_addr + '\n')
''' '''
Blecent application run: Blecent application run:
@ -87,22 +88,22 @@ def test_example_app_ble_central(env, extra_data):
ble_client_obj.disconnect() ble_client_obj.disconnect()
# Check dut responses # Check dut responses
dut.expect("Connection established", timeout=60) dut.expect('Connection established', timeout=60)
dut.expect("Service discovery complete; status=0", timeout=60) dut.expect('Service discovery complete; status=0', timeout=60)
print("Service discovery passed\n\tService Discovery Status: 0") print('Service discovery passed\n\tService Discovery Status: 0')
dut.expect("GATT procedure initiated: read;", timeout=60) dut.expect('GATT procedure initiated: read;', timeout=60)
dut.expect("Read complete; status=0", timeout=60) dut.expect('Read complete; status=0', timeout=60)
print("Read passed\n\tSupportedNewAlertCategoryCharacteristic\n\tRead Status: 0") print('Read passed\n\tSupportedNewAlertCategoryCharacteristic\n\tRead Status: 0')
dut.expect("GATT procedure initiated: write;", timeout=60) dut.expect('GATT procedure initiated: write;', timeout=60)
dut.expect("Write complete; status=0", timeout=60) dut.expect('Write complete; status=0', timeout=60)
print("Write passed\n\tAlertNotificationControlPointCharacteristic\n\tWrite Status: 0") print('Write passed\n\tAlertNotificationControlPointCharacteristic\n\tWrite Status: 0')
dut.expect("GATT procedure initiated: write;", timeout=60) dut.expect('GATT procedure initiated: write;', timeout=60)
dut.expect("Subscribe complete; status=0", timeout=60) dut.expect('Subscribe complete; status=0', timeout=60)
print("Subscribe passed\n\tClientCharacteristicConfigurationDescriptor\n\tSubscribe Status: 0") print('Subscribe passed\n\tClientCharacteristicConfigurationDescriptor\n\tSubscribe Status: 0')
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -15,20 +15,21 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import os import os
import re import re
import subprocess
import threading import threading
import traceback import traceback
import subprocess
try: try:
import Queue import Queue
except ImportError: except ImportError:
import queue as Queue import queue as Queue
from tiny_test_fw import Utility
import ttfw_idf import ttfw_idf
from ble import lib_ble_client from ble import lib_ble_client
from tiny_test_fw import Utility
# When running on local machine execute the following before running this script # When running on local machine execute the following before running this script
# > make app bootloader # > make app bootloader
@ -44,28 +45,28 @@ def blehr_client_task(hr_obj, dut_addr):
# Get BLE client module # Get BLE client module
ble_client_obj = lib_ble_client.BLE_Bluez_Client(interface, devname=ble_devname, devaddr=dut_addr) ble_client_obj = lib_ble_client.BLE_Bluez_Client(interface, devname=ble_devname, devaddr=dut_addr)
if not ble_client_obj: if not ble_client_obj:
raise RuntimeError("Failed to get DBus-Bluez object") raise RuntimeError('Failed to get DBus-Bluez object')
# Discover Bluetooth Adapter and power on # Discover Bluetooth Adapter and power on
is_adapter_set = ble_client_obj.set_adapter() is_adapter_set = ble_client_obj.set_adapter()
if not is_adapter_set: if not is_adapter_set:
raise RuntimeError("Adapter Power On failed !!") raise RuntimeError('Adapter Power On failed !!')
# Connect BLE Device # Connect BLE Device
is_connected = ble_client_obj.connect() is_connected = ble_client_obj.connect()
if not is_connected: if not is_connected:
# Call disconnect to perform cleanup operations before exiting application # Call disconnect to perform cleanup operations before exiting application
ble_client_obj.disconnect() ble_client_obj.disconnect()
raise RuntimeError("Connection to device " + str(ble_devname) + " failed !!") raise RuntimeError('Connection to device ' + str(ble_devname) + ' failed !!')
# Read Services # Read Services
services_ret = ble_client_obj.get_services() services_ret = ble_client_obj.get_services()
if services_ret: if services_ret:
Utility.console_log("\nServices\n") Utility.console_log('\nServices\n')
Utility.console_log(str(services_ret)) Utility.console_log(str(services_ret))
else: else:
ble_client_obj.disconnect() ble_client_obj.disconnect()
raise RuntimeError("Failure: Read Services failed") raise RuntimeError('Failure: Read Services failed')
''' '''
Blehr application run: Blehr application run:
@ -75,9 +76,9 @@ def blehr_client_task(hr_obj, dut_addr):
''' '''
blehr_ret = ble_client_obj.hr_update_simulation(hr_srv_uuid, hr_char_uuid) blehr_ret = ble_client_obj.hr_update_simulation(hr_srv_uuid, hr_char_uuid)
if blehr_ret: if blehr_ret:
Utility.console_log("Success: blehr example test passed") Utility.console_log('Success: blehr example test passed')
else: else:
raise RuntimeError("Failure: blehr example test failed") raise RuntimeError('Failure: blehr example test failed')
# Call disconnect to perform cleanup operations before exiting application # Call disconnect to perform cleanup operations before exiting application
ble_client_obj.disconnect() ble_client_obj.disconnect()
@ -96,7 +97,7 @@ class BleHRThread(threading.Thread):
self.exceptions_queue.put(traceback.format_exc(), block=False) self.exceptions_queue.put(traceback.format_exc(), block=False)
@ttfw_idf.idf_example_test(env_tag="Example_WIFI_BT") @ttfw_idf.idf_example_test(env_tag='Example_WIFI_BT')
def test_example_app_ble_hr(env, extra_data): def test_example_app_ble_hr(env, extra_data):
""" """
Steps: Steps:
@ -110,20 +111,20 @@ def test_example_app_ble_hr(env, extra_data):
subprocess.check_output(['hciconfig','hci0','reset']) subprocess.check_output(['hciconfig','hci0','reset'])
# Acquire DUT # Acquire DUT
dut = env.get_dut("blehr", "examples/bluetooth/nimble/blehr", dut_class=ttfw_idf.ESP32DUT) dut = env.get_dut('blehr', 'examples/bluetooth/nimble/blehr', dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut.app.binary_path, "blehr.bin") binary_file = os.path.join(dut.app.binary_path, 'blehr.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("blehr_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('blehr_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
Utility.console_log("Starting blehr simple example test app") Utility.console_log('Starting blehr simple example test app')
dut.start_app() dut.start_app()
dut.reset() dut.reset()
# Get device address from dut # Get device address from dut
dut_addr = dut.expect(re.compile(r"Device Address: ([a-fA-F0-9:]+)"), timeout=30)[0] dut_addr = dut.expect(re.compile(r'Device Address: ([a-fA-F0-9:]+)'), timeout=30)[0]
exceptions_queue = Queue.Queue() exceptions_queue = Queue.Queue()
# Starting a py-client in a separate thread # Starting a py-client in a separate thread
blehr_thread_obj = BleHRThread(dut_addr, exceptions_queue) blehr_thread_obj = BleHRThread(dut_addr, exceptions_queue)
@ -137,15 +138,15 @@ def test_example_app_ble_hr(env, extra_data):
except Queue.Empty: except Queue.Empty:
break break
else: else:
Utility.console_log("\n" + exception_msg) Utility.console_log('\n' + exception_msg)
if exception_msg: if exception_msg:
raise Exception("Thread did not run successfully") raise Exception('Thread did not run successfully')
# Check dut responses # Check dut responses
dut.expect("subscribe event; cur_notify=1", timeout=30) dut.expect('subscribe event; cur_notify=1', timeout=30)
dut.expect("subscribe event; cur_notify=0", timeout=30) dut.expect('subscribe event; cur_notify=0', timeout=30)
dut.expect("disconnect;", timeout=30) dut.expect('disconnect;', timeout=30)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -15,20 +15,21 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import os import os
import re import re
import traceback
import threading
import subprocess import subprocess
import threading
import traceback
try: try:
import Queue import Queue
except ImportError: except ImportError:
import queue as Queue import queue as Queue
from tiny_test_fw import Utility
import ttfw_idf import ttfw_idf
from ble import lib_ble_client from ble import lib_ble_client
from tiny_test_fw import Utility
# When running on local machine execute the following before running this script # When running on local machine execute the following before running this script
# > make app bootloader # > make app bootloader
@ -44,45 +45,45 @@ def bleprph_client_task(prph_obj, dut, dut_addr):
# Get BLE client module # Get BLE client module
ble_client_obj = lib_ble_client.BLE_Bluez_Client(interface, devname=ble_devname, devaddr=dut_addr) ble_client_obj = lib_ble_client.BLE_Bluez_Client(interface, devname=ble_devname, devaddr=dut_addr)
if not ble_client_obj: if not ble_client_obj:
raise RuntimeError("Failed to get DBus-Bluez object") raise RuntimeError('Failed to get DBus-Bluez object')
# Discover Bluetooth Adapter and power on # Discover Bluetooth Adapter and power on
is_adapter_set = ble_client_obj.set_adapter() is_adapter_set = ble_client_obj.set_adapter()
if not is_adapter_set: if not is_adapter_set:
raise RuntimeError("Adapter Power On failed !!") raise RuntimeError('Adapter Power On failed !!')
# Connect BLE Device # Connect BLE Device
is_connected = ble_client_obj.connect() is_connected = ble_client_obj.connect()
if not is_connected: if not is_connected:
# Call disconnect to perform cleanup operations before exiting application # Call disconnect to perform cleanup operations before exiting application
ble_client_obj.disconnect() ble_client_obj.disconnect()
raise RuntimeError("Connection to device " + ble_devname + " failed !!") raise RuntimeError('Connection to device ' + ble_devname + ' failed !!')
# Check dut responses # Check dut responses
dut.expect("GAP procedure initiated: advertise;", timeout=30) dut.expect('GAP procedure initiated: advertise;', timeout=30)
# Read Services # Read Services
services_ret = ble_client_obj.get_services(srv_uuid) services_ret = ble_client_obj.get_services(srv_uuid)
if services_ret: if services_ret:
Utility.console_log("\nServices\n") Utility.console_log('\nServices\n')
Utility.console_log(str(services_ret)) Utility.console_log(str(services_ret))
else: else:
ble_client_obj.disconnect() ble_client_obj.disconnect()
raise RuntimeError("Failure: Read Services failed") raise RuntimeError('Failure: Read Services failed')
# Read Characteristics # Read Characteristics
chars_ret = {} chars_ret = {}
chars_ret = ble_client_obj.read_chars() chars_ret = ble_client_obj.read_chars()
if chars_ret: if chars_ret:
Utility.console_log("\nCharacteristics retrieved") Utility.console_log('\nCharacteristics retrieved')
for path, props in chars_ret.items(): for path, props in chars_ret.items():
Utility.console_log("\n\tCharacteristic: " + str(path)) Utility.console_log('\n\tCharacteristic: ' + str(path))
Utility.console_log("\tCharacteristic UUID: " + str(props[2])) Utility.console_log('\tCharacteristic UUID: ' + str(props[2]))
Utility.console_log("\tValue: " + str(props[0])) Utility.console_log('\tValue: ' + str(props[0]))
Utility.console_log("\tProperties: : " + str(props[1])) Utility.console_log('\tProperties: : ' + str(props[1]))
else: else:
ble_client_obj.disconnect() ble_client_obj.disconnect()
raise RuntimeError("Failure: Read Characteristics failed") raise RuntimeError('Failure: Read Characteristics failed')
''' '''
Write Characteristics Write Characteristics
@ -91,15 +92,15 @@ def bleprph_client_task(prph_obj, dut, dut_addr):
chars_ret_on_write = {} chars_ret_on_write = {}
chars_ret_on_write = ble_client_obj.write_chars(b'A') chars_ret_on_write = ble_client_obj.write_chars(b'A')
if chars_ret_on_write: if chars_ret_on_write:
Utility.console_log("\nCharacteristics after write operation") Utility.console_log('\nCharacteristics after write operation')
for path, props in chars_ret_on_write.items(): for path, props in chars_ret_on_write.items():
Utility.console_log("\n\tCharacteristic:" + str(path)) Utility.console_log('\n\tCharacteristic:' + str(path))
Utility.console_log("\tCharacteristic UUID: " + str(props[2])) Utility.console_log('\tCharacteristic UUID: ' + str(props[2]))
Utility.console_log("\tValue:" + str(props[0])) Utility.console_log('\tValue:' + str(props[0]))
Utility.console_log("\tProperties: : " + str(props[1])) Utility.console_log('\tProperties: : ' + str(props[1]))
else: else:
ble_client_obj.disconnect() ble_client_obj.disconnect()
raise RuntimeError("Failure: Write Characteristics failed") raise RuntimeError('Failure: Write Characteristics failed')
# Call disconnect to perform cleanup operations before exiting application # Call disconnect to perform cleanup operations before exiting application
ble_client_obj.disconnect() ble_client_obj.disconnect()
@ -119,7 +120,7 @@ class BlePrphThread(threading.Thread):
self.exceptions_queue.put(traceback.format_exc(), block=False) self.exceptions_queue.put(traceback.format_exc(), block=False)
@ttfw_idf.idf_example_test(env_tag="Example_WIFI_BT") @ttfw_idf.idf_example_test(env_tag='Example_WIFI_BT')
def test_example_app_ble_peripheral(env, extra_data): def test_example_app_ble_peripheral(env, extra_data):
""" """
Steps: Steps:
@ -133,20 +134,20 @@ def test_example_app_ble_peripheral(env, extra_data):
subprocess.check_output(['hciconfig','hci0','reset']) subprocess.check_output(['hciconfig','hci0','reset'])
# Acquire DUT # Acquire DUT
dut = env.get_dut("bleprph", "examples/bluetooth/nimble/bleprph", dut_class=ttfw_idf.ESP32DUT) dut = env.get_dut('bleprph', 'examples/bluetooth/nimble/bleprph', dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut.app.binary_path, "bleprph.bin") binary_file = os.path.join(dut.app.binary_path, 'bleprph.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("bleprph_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('bleprph_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
Utility.console_log("Starting bleprph simple example test app") Utility.console_log('Starting bleprph simple example test app')
dut.start_app() dut.start_app()
dut.reset() dut.reset()
# Get device address from dut # Get device address from dut
dut_addr = dut.expect(re.compile(r"Device Address: ([a-fA-F0-9:]+)"), timeout=30)[0] dut_addr = dut.expect(re.compile(r'Device Address: ([a-fA-F0-9:]+)'), timeout=30)[0]
exceptions_queue = Queue.Queue() exceptions_queue = Queue.Queue()
# Starting a py-client in a separate thread # Starting a py-client in a separate thread
@ -161,14 +162,14 @@ def test_example_app_ble_peripheral(env, extra_data):
except Queue.Empty: except Queue.Empty:
break break
else: else:
Utility.console_log("\n" + exception_msg) Utility.console_log('\n' + exception_msg)
if exception_msg: if exception_msg:
raise Exception("Thread did not run successfully") raise Exception('Thread did not run successfully')
# Check dut responses # Check dut responses
dut.expect("connection established; status=0", timeout=30) dut.expect('connection established; status=0', timeout=30)
dut.expect("disconnect;", timeout=30) dut.expect('disconnect;', timeout=30)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,5 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import re import re
import ttfw_idf import ttfw_idf

View File

@ -1,20 +1,19 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import division from __future__ import division, print_function, unicode_literals
from __future__ import print_function
from __future__ import unicode_literals import hashlib
import re import os
import os import re
import hashlib
from tiny_test_fw import Utility
import ttfw_idf import ttfw_idf
from tiny_test_fw import Utility
def verify_elf_sha256_embedding(dut): def verify_elf_sha256_embedding(dut):
elf_file = os.path.join(dut.app.binary_path, "blink.elf") elf_file = os.path.join(dut.app.binary_path, 'blink.elf')
sha256 = hashlib.sha256() sha256 = hashlib.sha256()
with open(elf_file, "rb") as f: with open(elf_file, 'rb') as f:
sha256.update(f.read()) sha256.update(f.read())
sha256_expected = sha256.hexdigest() sha256_expected = sha256.hexdigest()
@ -28,12 +27,12 @@ def verify_elf_sha256_embedding(dut):
raise ValueError('ELF file SHA256 mismatch') raise ValueError('ELF file SHA256 mismatch')
@ttfw_idf.idf_example_test(env_tag="Example_GENERIC") @ttfw_idf.idf_example_test(env_tag='Example_GENERIC')
def test_examples_blink(env, extra_data): def test_examples_blink(env, extra_data):
dut = env.get_dut("blink", "examples/get-started/blink", dut_class=ttfw_idf.ESP32DUT) dut = env.get_dut('blink', 'examples/get-started/blink', dut_class=ttfw_idf.ESP32DUT)
binary_file = os.path.join(dut.app.binary_path, "blink.bin") binary_file = os.path.join(dut.app.binary_path, 'blink.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("blink_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('blink_bin_size', '{}KB'.format(bin_size // 1024))
dut.start_app() dut.start_app()

View File

@ -1,16 +1,14 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import division from __future__ import division, print_function, unicode_literals
from __future__ import print_function
from __future__ import unicode_literals
import ttfw_idf import ttfw_idf
@ttfw_idf.idf_example_test(env_tag="Example_GENERIC", target=['esp32', 'esp32s2'], ci_target=['esp32']) @ttfw_idf.idf_example_test(env_tag='Example_GENERIC', target=['esp32', 'esp32s2'], ci_target=['esp32'])
def test_examples_hello_world(env, extra_data): def test_examples_hello_world(env, extra_data):
app_name = 'hello_world' app_name = 'hello_world'
dut = env.get_dut(app_name, "examples/get-started/hello_world") dut = env.get_dut(app_name, 'examples/get-started/hello_world')
dut.start_app() dut.start_app()
res = dut.expect(ttfw_idf.MINIMUM_FREE_HEAP_SIZE_RE) res = dut.expect(ttfw_idf.MINIMUM_FREE_HEAP_SIZE_RE)
if not res: if not res:

View File

@ -1,16 +1,14 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import division from __future__ import division, print_function, unicode_literals
from __future__ import print_function
from __future__ import unicode_literals
import ttfw_idf import ttfw_idf
@ttfw_idf.idf_example_test(env_tag="Example_TWAI1", target=['esp32', 'esp32s2'], ci_target=['esp32']) @ttfw_idf.idf_example_test(env_tag='Example_TWAI1', target=['esp32', 'esp32s2'], ci_target=['esp32'])
def test_examples_gpio(env, extra_data): def test_examples_gpio(env, extra_data):
app_name = "gpio" app_name = 'gpio'
dut = env.get_dut(app_name, "examples/peripherals/gpio/generic_gpio") dut = env.get_dut(app_name, 'examples/peripherals/gpio/generic_gpio')
dut.start_app() dut.start_app()
res = dut.expect(ttfw_idf.MINIMUM_FREE_HEAP_SIZE_RE) res = dut.expect(ttfw_idf.MINIMUM_FREE_HEAP_SIZE_RE)
if not res: if not res:

View File

@ -10,27 +10,27 @@ def test_i2ctools_example(env, extra_data):
# Get device under test, flash and start example. "i2ctool" must be defined in EnvConfig # Get device under test, flash and start example. "i2ctool" must be defined in EnvConfig
dut = env.get_dut('i2ctools', 'examples/peripherals/i2c/i2c_tools', dut_class=ttfw_idf.ESP32DUT) dut = env.get_dut('i2ctools', 'examples/peripherals/i2c/i2c_tools', dut_class=ttfw_idf.ESP32DUT)
dut.start_app() dut.start_app()
dut.expect("i2c-tools>", timeout=EXPECT_TIMEOUT) dut.expect('i2c-tools>', timeout=EXPECT_TIMEOUT)
# Get i2c address # Get i2c address
dut.write("i2cdetect") dut.write('i2cdetect')
dut.expect("5b", timeout=EXPECT_TIMEOUT) dut.expect('5b', timeout=EXPECT_TIMEOUT)
# Get chip ID # Get chip ID
dut.write("i2cget -c 0x5b -r 0x20 -l 1") dut.write('i2cget -c 0x5b -r 0x20 -l 1')
dut.expect("0x81", timeout=EXPECT_TIMEOUT) dut.expect('0x81', timeout=EXPECT_TIMEOUT)
# Reset sensor # Reset sensor
dut.write("i2cset -c 0x5b -r 0xFF 0x11 0xE5 0x72 0x8A") dut.write('i2cset -c 0x5b -r 0xFF 0x11 0xE5 0x72 0x8A')
dut.expect("OK", timeout=EXPECT_TIMEOUT) dut.expect('OK', timeout=EXPECT_TIMEOUT)
# Get status # Get status
dut.write("i2cget -c 0x5b -r 0x00 -l 1") dut.write('i2cget -c 0x5b -r 0x00 -l 1')
dut.expect_any("0x10", timeout=EXPECT_TIMEOUT) dut.expect_any('0x10', timeout=EXPECT_TIMEOUT)
# Change work mode # Change work mode
dut.write("i2cset -c 0x5b -r 0xF4") dut.write('i2cset -c 0x5b -r 0xF4')
dut.expect("OK", timeout=EXPECT_TIMEOUT) dut.expect('OK', timeout=EXPECT_TIMEOUT)
dut.write("i2cset -c 0x5b -r 0x01 0x10") dut.write('i2cset -c 0x5b -r 0x01 0x10')
dut.expect("OK", timeout=EXPECT_TIMEOUT) dut.expect('OK', timeout=EXPECT_TIMEOUT)
# Get new status # Get new status
dut.write("i2cget -c 0x5b -r 0x00 -l 1") dut.write('i2cget -c 0x5b -r 0x00 -l 1')
dut.expect_any("0x98", "0x90", timeout=EXPECT_TIMEOUT) dut.expect_any('0x98', '0x90', timeout=EXPECT_TIMEOUT)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,43 +1,44 @@
from __future__ import print_function from __future__ import print_function
from builtins import range
import os import os
import wave
import struct import struct
import wave
from builtins import range
def get_wave_array_str(filename, target_bits): def get_wave_array_str(filename, target_bits):
wave_read = wave.open(filename, "r") wave_read = wave.open(filename, 'r')
array_str = "" array_str = ''
nchannels, sampwidth, framerate, nframes, comptype, compname = wave_read.getparams() nchannels, sampwidth, framerate, nframes, comptype, compname = wave_read.getparams()
sampwidth *= 8 sampwidth *= 8
for i in range(wave_read.getnframes()): for i in range(wave_read.getnframes()):
val, = struct.unpack("<H", wave_read.readframes(1)) val, = struct.unpack('<H', wave_read.readframes(1))
scale_val = (1 << target_bits) - 1 scale_val = (1 << target_bits) - 1
cur_lim = (1 << sampwidth) - 1 cur_lim = (1 << sampwidth) - 1
# scale current data to 8-bit data # scale current data to 8-bit data
val = val * scale_val / cur_lim val = val * scale_val / cur_lim
val = int(val + ((scale_val + 1) // 2)) & scale_val val = int(val + ((scale_val + 1) // 2)) & scale_val
array_str += "0x%x, " % (val) array_str += '0x%x, ' % (val)
if (i + 1) % 16 == 0: if (i + 1) % 16 == 0:
array_str += "\n" array_str += '\n'
return array_str return array_str
def gen_wave_table(wav_file_list, target_file_name, scale_bits=8): def gen_wave_table(wav_file_list, target_file_name, scale_bits=8):
with open(target_file_name, "w") as audio_table: with open(target_file_name, 'w') as audio_table:
print('#include <stdio.h>', file=audio_table) print('#include <stdio.h>', file=audio_table)
print('const unsigned char audio_table[] = {', file=audio_table) print('const unsigned char audio_table[] = {', file=audio_table)
for wav in wav_file_list: for wav in wav_file_list:
print("processing: {}".format(wav)) print('processing: {}'.format(wav))
print(get_wave_array_str(filename=wav, target_bits=scale_bits), file=audio_table) print(get_wave_array_str(filename=wav, target_bits=scale_bits), file=audio_table)
print('};\n', file=audio_table) print('};\n', file=audio_table)
print("Done...") print('Done...')
if __name__ == '__main__': if __name__ == '__main__':
print("Generating audio array...") print('Generating audio array...')
wav_list = [] wav_list = []
for filename in os.listdir("./"): for filename in os.listdir('./'):
if filename.endswith(".wav"): if filename.endswith('.wav'):
wav_list.append(filename) wav_list.append(filename)
gen_wave_table(wav_file_list=wav_list, target_file_name="audio_example_file.h") gen_wave_table(wav_file_list=wav_list, target_file_name='audio_example_file.h')

View File

@ -8,19 +8,19 @@ EXPECT_TIMEOUT = 20
@ttfw_idf.idf_example_test(env_tag='Example_RMT_IR_PROTOCOLS') @ttfw_idf.idf_example_test(env_tag='Example_RMT_IR_PROTOCOLS')
def test_examples_rmt_ir_protocols(env, extra_data): def test_examples_rmt_ir_protocols(env, extra_data):
dut = env.get_dut('ir_protocols_example', 'examples/peripherals/rmt/ir_protocols', app_config_name='nec') dut = env.get_dut('ir_protocols_example', 'examples/peripherals/rmt/ir_protocols', app_config_name='nec')
print("Using binary path: {}".format(dut.app.binary_path)) print('Using binary path: {}'.format(dut.app.binary_path))
dut.start_app() dut.start_app()
dut.expect("example: Send command 0x20 to address 0x10", timeout=EXPECT_TIMEOUT) dut.expect('example: Send command 0x20 to address 0x10', timeout=EXPECT_TIMEOUT)
dut.expect("Scan Code --- addr: 0x0010 cmd: 0x0020", timeout=EXPECT_TIMEOUT) dut.expect('Scan Code --- addr: 0x0010 cmd: 0x0020', timeout=EXPECT_TIMEOUT)
dut.expect("Scan Code (repeat) --- addr: 0x0010 cmd: 0x0020", timeout=EXPECT_TIMEOUT) dut.expect('Scan Code (repeat) --- addr: 0x0010 cmd: 0x0020', timeout=EXPECT_TIMEOUT)
env.close_dut(dut.name) env.close_dut(dut.name)
dut = env.get_dut('ir_protocols_example', 'examples/peripherals/rmt/ir_protocols', app_config_name='rc5') dut = env.get_dut('ir_protocols_example', 'examples/peripherals/rmt/ir_protocols', app_config_name='rc5')
print("Using binary path: {}".format(dut.app.binary_path)) print('Using binary path: {}'.format(dut.app.binary_path))
dut.start_app() dut.start_app()
dut.expect("example: Send command 0x20 to address 0x10", timeout=EXPECT_TIMEOUT) dut.expect('example: Send command 0x20 to address 0x10', timeout=EXPECT_TIMEOUT)
dut.expect("Scan Code --- addr: 0x0010 cmd: 0x0020", timeout=EXPECT_TIMEOUT) dut.expect('Scan Code --- addr: 0x0010 cmd: 0x0020', timeout=EXPECT_TIMEOUT)
dut.expect("Scan Code (repeat) --- addr: 0x0010 cmd: 0x0020", timeout=EXPECT_TIMEOUT) dut.expect('Scan Code (repeat) --- addr: 0x0010 cmd: 0x0020', timeout=EXPECT_TIMEOUT)
env.close_dut(dut.name) env.close_dut(dut.name)

View File

@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from tiny_test_fw import TinyFW
import ttfw_idf import ttfw_idf
from tiny_test_fw import TinyFW
@ttfw_idf.idf_example_test(env_tag="Example_SDIO", ignore=True) @ttfw_idf.idf_example_test(env_tag='Example_SDIO', ignore=True)
def test_example_sdio_communication(env, extra_data): def test_example_sdio_communication(env, extra_data):
""" """
Configurations Configurations
@ -36,88 +36,88 @@ def test_example_sdio_communication(env, extra_data):
or use sdio test board, which has two wrover modules connect to a same FT3232 or use sdio test board, which has two wrover modules connect to a same FT3232
Assume that first dut is host and second is slave Assume that first dut is host and second is slave
""" """
dut1 = env.get_dut("sdio_host", "examples/peripherals/sdio/host", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('sdio_host', 'examples/peripherals/sdio/host', dut_class=ttfw_idf.ESP32DUT)
dut2 = env.get_dut("sdio_slave", "examples/peripherals/sdio/slave", dut_class=ttfw_idf.ESP32DUT) dut2 = env.get_dut('sdio_slave', 'examples/peripherals/sdio/slave', dut_class=ttfw_idf.ESP32DUT)
dut1.start_app() dut1.start_app()
# wait until the master is ready to setup the slave # wait until the master is ready to setup the slave
dut1.expect("host ready, start initializing slave...") dut1.expect('host ready, start initializing slave...')
dut2.start_app() dut2.start_app()
dut1.expect("0a 0d 10 13 16 19 1c 1f 22 25 28 2b 2e 31 34 37") dut1.expect('0a 0d 10 13 16 19 1c 1f 22 25 28 2b 2e 31 34 37')
dut1.expect("3a 3d 40 43 46 49 4c 4f 52 55 58 5b 00 00 00 00") dut1.expect('3a 3d 40 43 46 49 4c 4f 52 55 58 5b 00 00 00 00')
dut1.expect("6a 6d 70 73 76 79 7c 7f 82 85 88 8b 8e 91 94 97") dut1.expect('6a 6d 70 73 76 79 7c 7f 82 85 88 8b 8e 91 94 97')
dut1.expect("9a 9d a0 a3 a6 a9 ac af b2 b5 b8 bb be c1 c4 c7") dut1.expect('9a 9d a0 a3 a6 a9 ac af b2 b5 b8 bb be c1 c4 c7')
dut2.expect("================ JOB_WRITE_REG ================") dut2.expect('================ JOB_WRITE_REG ================')
dut2.expect("0a 0d 10 13 16 19 1c 1f 22 25 28 2b 2e 31 34 37") dut2.expect('0a 0d 10 13 16 19 1c 1f 22 25 28 2b 2e 31 34 37')
dut2.expect("3a 3d 40 43 46 49 4c 4f 52 55 58 5b 00 00 00 00") dut2.expect('3a 3d 40 43 46 49 4c 4f 52 55 58 5b 00 00 00 00')
dut2.expect("6a 6d 70 73 76 79 7c 7f 82 85 88 8b 8e 91 94 97") dut2.expect('6a 6d 70 73 76 79 7c 7f 82 85 88 8b 8e 91 94 97')
dut2.expect("9a 9d a0 a3 a6 a9 ac af b2 b5 b8 bb be c1 c4 c7") dut2.expect('9a 9d a0 a3 a6 a9 ac af b2 b5 b8 bb be c1 c4 c7')
dut1.expect("host int: 0") dut1.expect('host int: 0')
dut1.expect("host int: 1") dut1.expect('host int: 1')
dut1.expect("host int: 2") dut1.expect('host int: 2')
dut1.expect("host int: 3") dut1.expect('host int: 3')
dut1.expect("host int: 4") dut1.expect('host int: 4')
dut1.expect("host int: 5") dut1.expect('host int: 5')
dut1.expect("host int: 6") dut1.expect('host int: 6')
dut1.expect("host int: 7") dut1.expect('host int: 7')
dut1.expect("host int: 0") dut1.expect('host int: 0')
dut1.expect("host int: 1") dut1.expect('host int: 1')
dut1.expect("host int: 2") dut1.expect('host int: 2')
dut1.expect("host int: 3") dut1.expect('host int: 3')
dut1.expect("host int: 4") dut1.expect('host int: 4')
dut1.expect("host int: 5") dut1.expect('host int: 5')
dut1.expect("host int: 6") dut1.expect('host int: 6')
dut1.expect("host int: 7") dut1.expect('host int: 7')
dut2.expect("================ JOB_SEND_INT ================") dut2.expect('================ JOB_SEND_INT ================')
dut2.expect("================ JOB_SEND_INT ================") dut2.expect('================ JOB_SEND_INT ================')
dut1.expect("send packet length: 3") dut1.expect('send packet length: 3')
dut1.expect("send packet length: 6") dut1.expect('send packet length: 6')
dut1.expect("send packet length: 12") dut1.expect('send packet length: 12')
dut1.expect("send packet length: 128") dut1.expect('send packet length: 128')
dut1.expect("send packet length: 511") dut1.expect('send packet length: 511')
dut1.expect("send packet length: 512") dut1.expect('send packet length: 512')
dut2.expect("recv len: 3") dut2.expect('recv len: 3')
dut2.expect("recv len: 6") dut2.expect('recv len: 6')
dut2.expect("recv len: 12") dut2.expect('recv len: 12')
dut2.expect("recv len: 128") dut2.expect('recv len: 128')
# 511 # 511
dut2.expect("recv len: 128") dut2.expect('recv len: 128')
dut2.expect("recv len: 128") dut2.expect('recv len: 128')
dut2.expect("recv len: 128") dut2.expect('recv len: 128')
dut2.expect("recv len: 127") dut2.expect('recv len: 127')
# 512 # 512
dut2.expect("recv len: 128") dut2.expect('recv len: 128')
dut2.expect("recv len: 128") dut2.expect('recv len: 128')
dut2.expect("recv len: 128") dut2.expect('recv len: 128')
dut2.expect("recv len: 128") dut2.expect('recv len: 128')
dut1.expect("receive data, size: 3") dut1.expect('receive data, size: 3')
dut1.expect("receive data, size: 6") dut1.expect('receive data, size: 6')
dut1.expect("receive data, size: 12") dut1.expect('receive data, size: 12')
dut1.expect("receive data, size: 128") dut1.expect('receive data, size: 128')
dut1.expect("receive data, size: 128") dut1.expect('receive data, size: 128')
dut1.expect("receive data, size: 128") dut1.expect('receive data, size: 128')
dut1.expect("receive data, size: 128") dut1.expect('receive data, size: 128')
dut1.expect("receive data, size: 127") dut1.expect('receive data, size: 127')
dut1.expect("receive data, size: 128") dut1.expect('receive data, size: 128')
dut1.expect("receive data, size: 128") dut1.expect('receive data, size: 128')
dut1.expect("receive data, size: 128") dut1.expect('receive data, size: 128')
dut1.expect("receive data, size: 128") dut1.expect('receive data, size: 128')
# the last valid line of one round # the last valid line of one round
dut1.expect("ce d3 d8 dd e2 e7 ec f1 f6 fb 00 05 0a 0f 14 19") dut1.expect('ce d3 d8 dd e2 e7 ec f1 f6 fb 00 05 0a 0f 14 19')
# the first 2 lines of the second round # the first 2 lines of the second round
dut1.expect("46 4b 50") dut1.expect('46 4b 50')
dut1.expect("5a 5f 64 69 6e 73") dut1.expect('5a 5f 64 69 6e 73')
if __name__ == '__main__': if __name__ == '__main__':
TinyFW.set_default_config(env_config_file="EnvConfigTemplate.yml", dut=ttfw_idf.IDFDUT) TinyFW.set_default_config(env_config_file='EnvConfigTemplate.yml', dut=ttfw_idf.IDFDUT)
test_example_sdio_communication() test_example_sdio_communication()

View File

@ -4,7 +4,7 @@ from __future__ import print_function
import ttfw_idf import ttfw_idf
# TWAI Self Test Example constants # TWAI Self Test Example constants
STR_EXPECT = ("TWAI Alert and Recovery: Driver installed", "TWAI Alert and Recovery: Driver uninstalled") STR_EXPECT = ('TWAI Alert and Recovery: Driver installed', 'TWAI Alert and Recovery: Driver uninstalled')
EXPECT_TIMEOUT = 20 EXPECT_TIMEOUT = 20

View File

@ -6,9 +6,9 @@ from threading import Thread
import ttfw_idf import ttfw_idf
# Define tuple of strings to expect for each DUT. # Define tuple of strings to expect for each DUT.
master_expect = ("TWAI Master: Driver installed", "TWAI Master: Driver uninstalled") master_expect = ('TWAI Master: Driver installed', 'TWAI Master: Driver uninstalled')
slave_expect = ("TWAI Slave: Driver installed", "TWAI Slave: Driver uninstalled") slave_expect = ('TWAI Slave: Driver installed', 'TWAI Slave: Driver uninstalled')
listen_only_expect = ("TWAI Listen Only: Driver installed", "TWAI Listen Only: Driver uninstalled") listen_only_expect = ('TWAI Listen Only: Driver installed', 'TWAI Listen Only: Driver uninstalled')
def dut_thread_callback(**kwargs): def dut_thread_callback(**kwargs):
@ -31,11 +31,11 @@ def dut_thread_callback(**kwargs):
def test_twai_network_example(env, extra_data): def test_twai_network_example(env, extra_data):
# Get device under test. "dut1", "dut2", and "dut3" must be properly defined in EnvConfig # Get device under test. "dut1", "dut2", and "dut3" must be properly defined in EnvConfig
dut_master = env.get_dut("dut1", "examples/peripherals/twai/twai_network/twai_network_master", dut_master = env.get_dut('dut1', 'examples/peripherals/twai/twai_network/twai_network_master',
dut_class=ttfw_idf.ESP32DUT) dut_class=ttfw_idf.ESP32DUT)
dut_slave = env.get_dut("dut2", "examples/peripherals/twai/twai_network/twai_network_slave", dut_slave = env.get_dut('dut2', 'examples/peripherals/twai/twai_network/twai_network_slave',
dut_class=ttfw_idf.ESP32DUT) dut_class=ttfw_idf.ESP32DUT)
dut_listen_only = env.get_dut("dut3", "examples/peripherals/twai/twai_network/twai_network_listen_only", dut_listen_only = env.get_dut('dut3', 'examples/peripherals/twai/twai_network/twai_network_listen_only',
dut_class=ttfw_idf.ESP32DUT) dut_class=ttfw_idf.ESP32DUT)
# Flash app onto each DUT, each DUT is reset again at the start of each thread # Flash app onto each DUT, each DUT is reset again at the start of each thread
@ -45,14 +45,14 @@ def test_twai_network_example(env, extra_data):
# Create dict of keyword arguments for each dut # Create dict of keyword arguments for each dut
results = [[False], [False], [False]] results = [[False], [False], [False]]
master_kwargs = {"dut": dut_master, "result": results[0], "expected": master_expect} master_kwargs = {'dut': dut_master, 'result': results[0], 'expected': master_expect}
slave_kwargs = {"dut": dut_slave, "result": results[1], "expected": slave_expect} slave_kwargs = {'dut': dut_slave, 'result': results[1], 'expected': slave_expect}
listen_only_kwargs = {"dut": dut_listen_only, "result": results[2], "expected": listen_only_expect} listen_only_kwargs = {'dut': dut_listen_only, 'result': results[2], 'expected': listen_only_expect}
# Create thread for each dut # Create thread for each dut
dut_master_thread = Thread(target=dut_thread_callback, name="Master Thread", kwargs=master_kwargs) dut_master_thread = Thread(target=dut_thread_callback, name='Master Thread', kwargs=master_kwargs)
dut_slave_thread = Thread(target=dut_thread_callback, name="Slave Thread", kwargs=slave_kwargs) dut_slave_thread = Thread(target=dut_thread_callback, name='Slave Thread', kwargs=slave_kwargs)
dut_listen_only_thread = Thread(target=dut_thread_callback, name="Listen Only Thread", kwargs=listen_only_kwargs) dut_listen_only_thread = Thread(target=dut_thread_callback, name='Listen Only Thread', kwargs=listen_only_kwargs)
# Start each thread # Start each thread
dut_listen_only_thread.start() dut_listen_only_thread.start()
@ -67,7 +67,7 @@ def test_twai_network_example(env, extra_data):
# check each thread ran to completion # check each thread ran to completion
for result in results: for result in results:
if result[0] is not True: if result[0] is not True:
raise Exception("One or more threads did not run successfully") raise Exception('One or more threads did not run successfully')
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -3,9 +3,8 @@ from __future__ import print_function
import ttfw_idf import ttfw_idf
# TWAI Self Test Example constants # TWAI Self Test Example constants
STR_EXPECT = ("TWAI Self Test: Driver installed", "TWAI Self Test: Driver uninstalled") STR_EXPECT = ('TWAI Self Test: Driver installed', 'TWAI Self Test: Driver uninstalled')
EXPECT_TIMEOUT = 20 EXPECT_TIMEOUT = 20

View File

@ -1,21 +1,21 @@
import re
import os import os
import re
import socket import socket
from threading import Thread
import time import time
from threading import Thread
import ttfw_idf import ttfw_idf
global g_client_response global g_client_response
global g_msg_to_client global g_msg_to_client
g_client_response = b"" g_client_response = b''
g_msg_to_client = b" 3XYZ" g_msg_to_client = b' 3XYZ'
def get_my_ip(): def get_my_ip():
s1 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s1 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s1.connect(("8.8.8.8", 80)) s1.connect(('8.8.8.8', 80))
my_ip = s1.getsockname()[0] my_ip = s1.getsockname()[0]
s1.close() s1.close()
return my_ip return my_ip
@ -23,14 +23,14 @@ def get_my_ip():
def chat_server_sketch(my_ip): def chat_server_sketch(my_ip):
global g_client_response global g_client_response
print("Starting the server on {}".format(my_ip)) print('Starting the server on {}'.format(my_ip))
port = 2222 port = 2222
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(600) s.settimeout(600)
s.bind((my_ip, port)) s.bind((my_ip, port))
s.listen(1) s.listen(1)
q,addr = s.accept() q,addr = s.accept()
print("connection accepted") print('connection accepted')
q.settimeout(30) q.settimeout(30)
q.send(g_msg_to_client) q.send(g_msg_to_client)
data = q.recv(1024) data = q.recv(1024)
@ -39,12 +39,12 @@ def chat_server_sketch(my_ip):
g_client_response = data g_client_response = data
else: else:
g_client_response = q.recv(1024) g_client_response = q.recv(1024)
print("received from client {}".format(g_client_response)) print('received from client {}'.format(g_client_response))
s.close() s.close()
print("server closed") print('server closed')
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_asio_chat_client(env, extra_data): def test_examples_protocol_asio_chat_client(env, extra_data):
""" """
steps: | steps: |
@ -57,19 +57,19 @@ def test_examples_protocol_asio_chat_client(env, extra_data):
""" """
global g_client_response global g_client_response
global g_msg_to_client global g_msg_to_client
test_msg = "ABC" test_msg = 'ABC'
dut1 = env.get_dut("chat_client", "examples/protocols/asio/chat_client", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('chat_client', 'examples/protocols/asio/chat_client', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "asio_chat_client.bin") binary_file = os.path.join(dut1.app.binary_path, 'asio_chat_client.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("asio_chat_client_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('asio_chat_client_size', '{}KB'.format(bin_size // 1024))
# 1. start a tcp server on the host # 1. start a tcp server on the host
host_ip = get_my_ip() host_ip = get_my_ip()
thread1 = Thread(target=chat_server_sketch, args=(host_ip,)) thread1 = Thread(target=chat_server_sketch, args=(host_ip,))
thread1.start() thread1.start()
# 2. start the dut test and wait till client gets IP address # 2. start the dut test and wait till client gets IP address
dut1.start_app() dut1.start_app()
dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30) dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)
# 3. send host's IP to the client i.e. the `dut1` # 3. send host's IP to the client i.e. the `dut1`
dut1.write(host_ip) dut1.write(host_ip)
# 4. client `dut1` should receive a message # 4. client `dut1` should receive a message
@ -82,10 +82,10 @@ def test_examples_protocol_asio_chat_client(env, extra_data):
print(g_client_response) print(g_client_response)
# 6. evaluate host_server received this message # 6. evaluate host_server received this message
if (g_client_response[4:7] == test_msg): if (g_client_response[4:7] == test_msg):
print("PASS: Received correct message") print('PASS: Received correct message')
pass pass
else: else:
print("Failure!") print('Failure!')
raise ValueError('Wrong data received from asi tcp server: {} (expected:{})'.format(g_client_response[4:7], test_msg)) raise ValueError('Wrong data received from asi tcp server: {} (expected:{})'.format(g_client_response[4:7], test_msg))
thread1.join() thread1.join()

View File

@ -1,11 +1,11 @@
import re
import os import os
import re
import socket import socket
import ttfw_idf import ttfw_idf
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_asio_chat_server(env, extra_data): def test_examples_protocol_asio_chat_server(env, extra_data):
""" """
steps: | steps: |
@ -14,16 +14,16 @@ def test_examples_protocol_asio_chat_server(env, extra_data):
3. Test connects to server and sends a test message 3. Test connects to server and sends a test message
4. Test evaluates received test message from server 4. Test evaluates received test message from server
""" """
test_msg = b" 4ABC\n" test_msg = b' 4ABC\n'
dut1 = env.get_dut("chat_server", "examples/protocols/asio/chat_server", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('chat_server', 'examples/protocols/asio/chat_server', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "asio_chat_server.bin") binary_file = os.path.join(dut1.app.binary_path, 'asio_chat_server.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("asio_chat_server_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('asio_chat_server_bin_size', '{}KB'.format(bin_size // 1024))
# 1. start test # 1. start test
dut1.start_app() dut1.start_app()
# 2. get the server IP address # 2. get the server IP address
data = dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30) data = dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)
# 3. create tcp client and connect to server # 3. create tcp client and connect to server
cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
cli.settimeout(30) cli.settimeout(30)
@ -32,10 +32,10 @@ def test_examples_protocol_asio_chat_server(env, extra_data):
data = cli.recv(1024) data = cli.recv(1024)
# 4. check the message received back from the server # 4. check the message received back from the server
if (data == test_msg): if (data == test_msg):
print("PASS: Received correct message {}".format(data)) print('PASS: Received correct message {}'.format(data))
pass pass
else: else:
print("Failure!") print('Failure!')
raise ValueError('Wrong data received from asi tcp server: {} (expoected:{})'.format(data, test_msg)) raise ValueError('Wrong data received from asi tcp server: {} (expoected:{})'.format(data, test_msg))

View File

@ -1,4 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import ttfw_idf import ttfw_idf

View File

@ -1,11 +1,11 @@
import re
import os import os
import re
import socket import socket
import ttfw_idf import ttfw_idf
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_asio_tcp_server(env, extra_data): def test_examples_protocol_asio_tcp_server(env, extra_data):
""" """
steps: | steps: |
@ -15,16 +15,16 @@ def test_examples_protocol_asio_tcp_server(env, extra_data):
4. Test evaluates received test message from server 4. Test evaluates received test message from server
5. Test evaluates received test message on server stdout 5. Test evaluates received test message on server stdout
""" """
test_msg = b"echo message from client to server" test_msg = b'echo message from client to server'
dut1 = env.get_dut("tcp_echo_server", "examples/protocols/asio/tcp_echo_server", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('tcp_echo_server', 'examples/protocols/asio/tcp_echo_server', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "asio_tcp_echo_server.bin") binary_file = os.path.join(dut1.app.binary_path, 'asio_tcp_echo_server.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("asio_tcp_echo_server_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('asio_tcp_echo_server_bin_size', '{}KB'.format(bin_size // 1024))
# 1. start test # 1. start test
dut1.start_app() dut1.start_app()
# 2. get the server IP address # 2. get the server IP address
data = dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30) data = dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)
# 3. create tcp client and connect to server # 3. create tcp client and connect to server
cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
cli.settimeout(30) cli.settimeout(30)
@ -33,10 +33,10 @@ def test_examples_protocol_asio_tcp_server(env, extra_data):
data = cli.recv(1024) data = cli.recv(1024)
# 4. check the message received back from the server # 4. check the message received back from the server
if (data == test_msg): if (data == test_msg):
print("PASS: Received correct message") print('PASS: Received correct message')
pass pass
else: else:
print("Failure!") print('Failure!')
raise ValueError('Wrong data received from asi tcp server: {} (expected:{})'.format(data, test_msg)) raise ValueError('Wrong data received from asi tcp server: {} (expected:{})'.format(data, test_msg))
# 5. check the client message appears also on server terminal # 5. check the client message appears also on server terminal
dut1.expect(test_msg.decode()) dut1.expect(test_msg.decode())

View File

@ -1,11 +1,11 @@
import re
import os import os
import re
import socket import socket
import ttfw_idf import ttfw_idf
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_asio_udp_server(env, extra_data): def test_examples_protocol_asio_udp_server(env, extra_data):
""" """
steps: | steps: |
@ -15,16 +15,16 @@ def test_examples_protocol_asio_udp_server(env, extra_data):
4. Test evaluates received test message from server 4. Test evaluates received test message from server
5. Test evaluates received test message on server stdout 5. Test evaluates received test message on server stdout
""" """
test_msg = b"echo message from client to server" test_msg = b'echo message from client to server'
dut1 = env.get_dut("udp_echo_server", "examples/protocols/asio/udp_echo_server", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('udp_echo_server', 'examples/protocols/asio/udp_echo_server', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "asio_udp_echo_server.bin") binary_file = os.path.join(dut1.app.binary_path, 'asio_udp_echo_server.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("asio_udp_echo_server_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('asio_udp_echo_server_bin_size', '{}KB'.format(bin_size // 1024))
# 1. start test # 1. start test
dut1.start_app() dut1.start_app()
# 2. get the server IP address # 2. get the server IP address
data = dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30) data = dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)
# 3. create tcp client and connect to server # 3. create tcp client and connect to server
cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
cli.settimeout(30) cli.settimeout(30)
@ -33,10 +33,10 @@ def test_examples_protocol_asio_udp_server(env, extra_data):
data = cli.recv(1024) data = cli.recv(1024)
# 4. check the message received back from the server # 4. check the message received back from the server
if (data == test_msg): if (data == test_msg):
print("PASS: Received correct message") print('PASS: Received correct message')
pass pass
else: else:
print("Failure!") print('Failure!')
raise ValueError('Wrong data received from asio udp server: {} (expected:{})'.format(data, test_msg)) raise ValueError('Wrong data received from asio udp server: {} (expected:{})'.format(data, test_msg))
# 5. check the client message appears also on server terminal # 5. check the client message appears also on server terminal
dut1.expect(test_msg.decode()) dut1.expect(test_msg.decode())

View File

@ -1,6 +1,8 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import re import re
import textwrap import textwrap
import ttfw_idf import ttfw_idf

View File

@ -1,64 +1,64 @@
import re
import os import os
import re
import ttfw_idf import ttfw_idf
@ttfw_idf.idf_example_test(env_tag="Example_EthKitV1") @ttfw_idf.idf_example_test(env_tag='Example_EthKitV1')
def test_examples_protocol_esp_http_client(env, extra_data): def test_examples_protocol_esp_http_client(env, extra_data):
""" """
steps: | steps: |
1. join AP 1. join AP
2. Send HTTP request to httpbin.org 2. Send HTTP request to httpbin.org
""" """
dut1 = env.get_dut("esp_http_client", "examples/protocols/esp_http_client", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('esp_http_client', 'examples/protocols/esp_http_client', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "esp-http-client-example.bin") binary_file = os.path.join(dut1.app.binary_path, 'esp-http-client-example.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("esp_http_client_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('esp_http_client_bin_size', '{}KB'.format(bin_size // 1024))
# start test # start test
dut1.start_app() dut1.start_app()
dut1.expect("Connected to AP, begin http example", timeout=30) dut1.expect('Connected to AP, begin http example', timeout=30)
dut1.expect(re.compile(r"HTTP GET Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP GET Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP POST Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP POST Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP PUT Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP PUT Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP PATCH Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP PATCH Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP DELETE Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP DELETE Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP HEAD Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP HEAD Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP Basic Auth Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP Basic Auth Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP Basic Auth redirect Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP Basic Auth redirect Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP Digest Auth Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP Digest Auth Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTPS Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTPS Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP chunk encoding Status = 200, content_length = (-?\d)")) dut1.expect(re.compile(r'HTTP chunk encoding Status = 200, content_length = (-?\d)'))
# content-len for chunked encoding is typically -1, could be a positive length in some cases # content-len for chunked encoding is typically -1, could be a positive length in some cases
dut1.expect(re.compile(r"HTTP Stream reader Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP Stream reader Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"Last esp error code: 0x8001")) dut1.expect(re.compile(r'Last esp error code: 0x8001'))
dut1.expect("Finish http example") dut1.expect('Finish http example')
# test mbedtls dynamic resource # test mbedtls dynamic resource
dut1 = env.get_dut("esp_http_client", "examples/protocols/esp_http_client", dut_class=ttfw_idf.ESP32DUT, app_config_name='ssldyn') dut1 = env.get_dut('esp_http_client', 'examples/protocols/esp_http_client', dut_class=ttfw_idf.ESP32DUT, app_config_name='ssldyn')
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "esp-http-client-example.bin") binary_file = os.path.join(dut1.app.binary_path, 'esp-http-client-example.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("esp_http_client_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('esp_http_client_bin_size', '{}KB'.format(bin_size // 1024))
# start test # start test
dut1.start_app() dut1.start_app()
dut1.expect("Connected to AP, begin http example", timeout=30) dut1.expect('Connected to AP, begin http example', timeout=30)
dut1.expect(re.compile(r"HTTP GET Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP GET Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP POST Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP POST Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP PUT Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP PUT Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP PATCH Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP PATCH Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP DELETE Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP DELETE Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP HEAD Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP HEAD Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP Basic Auth Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP Basic Auth Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP Basic Auth redirect Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP Basic Auth redirect Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP Digest Auth Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP Digest Auth Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTPS Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTPS Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"HTTP chunk encoding Status = 200, content_length = (-?\d)")) dut1.expect(re.compile(r'HTTP chunk encoding Status = 200, content_length = (-?\d)'))
# content-len for chunked encoding is typically -1, could be a positive length in some cases # content-len for chunked encoding is typically -1, could be a positive length in some cases
dut1.expect(re.compile(r"HTTP Stream reader Status = 200, content_length = (\d)")) dut1.expect(re.compile(r'HTTP Stream reader Status = 200, content_length = (\d)'))
dut1.expect(re.compile(r"Last esp error code: 0x8001")) dut1.expect(re.compile(r'Last esp error code: 0x8001'))
dut1.expect("Finish http example") dut1.expect('Finish http example')
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,7 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import os import os
import re import re
import sys import sys
import ttfw_idf import ttfw_idf

View File

@ -16,22 +16,22 @@
# #
from __future__ import print_function from __future__ import print_function
from future.utils import tobytes
from builtins import input
import os
import sys
import struct
import argparse import argparse
import os
import ssl import ssl
import struct
import sys
from builtins import input
import proto import proto
from future.utils import tobytes
# The tools directory is already in the PATH in environment prepared by install.sh which would allow to import # The tools directory is already in the PATH in environment prepared by install.sh which would allow to import
# esp_prov as file but not as complete module. # esp_prov as file but not as complete module.
sys.path.insert(0, os.path.join(os.environ['IDF_PATH'], 'tools/esp_prov')) sys.path.insert(0, os.path.join(os.environ['IDF_PATH'], 'tools/esp_prov'))
import esp_prov # noqa: E402 import esp_prov # noqa: E402
# Set this to true to allow exceptions to be thrown # Set this to true to allow exceptions to be thrown
config_throw_except = False config_throw_except = False
@ -48,26 +48,26 @@ PROP_FLAG_READONLY = (1 << 0)
def prop_typestr(prop): def prop_typestr(prop):
if prop["type"] == PROP_TYPE_TIMESTAMP: if prop['type'] == PROP_TYPE_TIMESTAMP:
return "TIME(us)" return 'TIME(us)'
elif prop["type"] == PROP_TYPE_INT32: elif prop['type'] == PROP_TYPE_INT32:
return "INT32" return 'INT32'
elif prop["type"] == PROP_TYPE_BOOLEAN: elif prop['type'] == PROP_TYPE_BOOLEAN:
return "BOOLEAN" return 'BOOLEAN'
elif prop["type"] == PROP_TYPE_STRING: elif prop['type'] == PROP_TYPE_STRING:
return "STRING" return 'STRING'
return "UNKNOWN" return 'UNKNOWN'
def encode_prop_value(prop, value): def encode_prop_value(prop, value):
try: try:
if prop["type"] == PROP_TYPE_TIMESTAMP: if prop['type'] == PROP_TYPE_TIMESTAMP:
return struct.pack('q', value) return struct.pack('q', value)
elif prop["type"] == PROP_TYPE_INT32: elif prop['type'] == PROP_TYPE_INT32:
return struct.pack('i', value) return struct.pack('i', value)
elif prop["type"] == PROP_TYPE_BOOLEAN: elif prop['type'] == PROP_TYPE_BOOLEAN:
return struct.pack('?', value) return struct.pack('?', value)
elif prop["type"] == PROP_TYPE_STRING: elif prop['type'] == PROP_TYPE_STRING:
return tobytes(value) return tobytes(value)
return value return value
except struct.error as e: except struct.error as e:
@ -77,13 +77,13 @@ def encode_prop_value(prop, value):
def decode_prop_value(prop, value): def decode_prop_value(prop, value):
try: try:
if prop["type"] == PROP_TYPE_TIMESTAMP: if prop['type'] == PROP_TYPE_TIMESTAMP:
return struct.unpack('q', value)[0] return struct.unpack('q', value)[0]
elif prop["type"] == PROP_TYPE_INT32: elif prop['type'] == PROP_TYPE_INT32:
return struct.unpack('i', value)[0] return struct.unpack('i', value)[0]
elif prop["type"] == PROP_TYPE_BOOLEAN: elif prop['type'] == PROP_TYPE_BOOLEAN:
return struct.unpack('?', value)[0] return struct.unpack('?', value)[0]
elif prop["type"] == PROP_TYPE_STRING: elif prop['type'] == PROP_TYPE_STRING:
return value.decode('latin-1') return value.decode('latin-1')
return value return value
except struct.error as e: except struct.error as e:
@ -93,13 +93,13 @@ def decode_prop_value(prop, value):
def str_to_prop_value(prop, strval): def str_to_prop_value(prop, strval):
try: try:
if prop["type"] == PROP_TYPE_TIMESTAMP: if prop['type'] == PROP_TYPE_TIMESTAMP:
return int(strval) return int(strval)
elif prop["type"] == PROP_TYPE_INT32: elif prop['type'] == PROP_TYPE_INT32:
return int(strval) return int(strval)
elif prop["type"] == PROP_TYPE_BOOLEAN: elif prop['type'] == PROP_TYPE_BOOLEAN:
return bool(strval) return bool(strval)
elif prop["type"] == PROP_TYPE_STRING: elif prop['type'] == PROP_TYPE_STRING:
return strval return strval
return strval return strval
except ValueError as e: except ValueError as e:
@ -108,7 +108,7 @@ def str_to_prop_value(prop, strval):
def prop_is_readonly(prop): def prop_is_readonly(prop):
return (prop["flags"] & PROP_FLAG_READONLY) != 0 return (prop['flags'] & PROP_FLAG_READONLY) != 0
def on_except(err): def on_except(err):
@ -122,8 +122,8 @@ def get_transport(sel_transport, service_name, check_hostname):
try: try:
tp = None tp = None
if (sel_transport == 'http'): if (sel_transport == 'http'):
example_path = os.environ['IDF_PATH'] + "/examples/protocols/esp_local_ctrl" example_path = os.environ['IDF_PATH'] + '/examples/protocols/esp_local_ctrl'
cert_path = example_path + "/main/certs/rootCA.pem" cert_path = example_path + '/main/certs/rootCA.pem'
ssl_ctx = ssl.create_default_context(cafile=cert_path) ssl_ctx = ssl.create_default_context(cafile=cert_path)
ssl_ctx.check_hostname = check_hostname ssl_ctx.check_hostname = check_hostname
tp = esp_prov.transport.Transport_HTTP(service_name, ssl_ctx) tp = esp_prov.transport.Transport_HTTP(service_name, ssl_ctx)
@ -156,15 +156,15 @@ def get_all_property_values(tp):
response = tp.send_data('esp_local_ctrl/control', message) response = tp.send_data('esp_local_ctrl/control', message)
count = proto.get_prop_count_response(response) count = proto.get_prop_count_response(response)
if count == 0: if count == 0:
raise RuntimeError("No properties found!") raise RuntimeError('No properties found!')
indices = [i for i in range(count)] indices = [i for i in range(count)]
message = proto.get_prop_vals_request(indices) message = proto.get_prop_vals_request(indices)
response = tp.send_data('esp_local_ctrl/control', message) response = tp.send_data('esp_local_ctrl/control', message)
props = proto.get_prop_vals_response(response) props = proto.get_prop_vals_response(response)
if len(props) != count: if len(props) != count:
raise RuntimeError("Incorrect count of properties!") raise RuntimeError('Incorrect count of properties!')
for p in props: for p in props:
p["value"] = decode_prop_value(p, p["value"]) p['value'] = decode_prop_value(p, p['value'])
return props return props
except RuntimeError as e: except RuntimeError as e:
on_except(e) on_except(e)
@ -176,7 +176,7 @@ def set_property_values(tp, props, indices, values, check_readonly=False):
if check_readonly: if check_readonly:
for index in indices: for index in indices:
if prop_is_readonly(props[index]): if prop_is_readonly(props[index]):
raise RuntimeError("Cannot set value of Read-Only property") raise RuntimeError('Cannot set value of Read-Only property')
message = proto.set_prop_vals_request(indices, values) message = proto.set_prop_vals_request(indices, values)
response = tp.send_data('esp_local_ctrl/control', message) response = tp.send_data('esp_local_ctrl/control', message)
return proto.set_prop_vals_response(response) return proto.set_prop_vals_response(response)
@ -188,27 +188,27 @@ def set_property_values(tp, props, indices, values, check_readonly=False):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(add_help=False) parser = argparse.ArgumentParser(add_help=False)
parser = argparse.ArgumentParser(description="Control an ESP32 running esp_local_ctrl service") parser = argparse.ArgumentParser(description='Control an ESP32 running esp_local_ctrl service')
parser.add_argument("--version", dest='version', type=str, parser.add_argument('--version', dest='version', type=str,
help="Protocol version", default='') help='Protocol version', default='')
parser.add_argument("--transport", dest='transport', type=str, parser.add_argument('--transport', dest='transport', type=str,
help="transport i.e http or ble", default='http') help='transport i.e http or ble', default='http')
parser.add_argument("--name", dest='service_name', type=str, parser.add_argument('--name', dest='service_name', type=str,
help="BLE Device Name / HTTP Server hostname or IP", default='') help='BLE Device Name / HTTP Server hostname or IP', default='')
parser.add_argument("--dont-check-hostname", action="store_true", parser.add_argument('--dont-check-hostname', action='store_true',
# If enabled, the certificate won't be rejected for hostname mismatch. # If enabled, the certificate won't be rejected for hostname mismatch.
# This option is hidden because it should be used only for testing purposes. # This option is hidden because it should be used only for testing purposes.
help=argparse.SUPPRESS) help=argparse.SUPPRESS)
parser.add_argument("-v", "--verbose", dest='verbose', help="increase output verbosity", action="store_true") parser.add_argument('-v', '--verbose', dest='verbose', help='increase output verbosity', action='store_true')
args = parser.parse_args() args = parser.parse_args()
if args.version != '': if args.version != '':
print("==== Esp_Ctrl Version: " + args.version + " ====") print('==== Esp_Ctrl Version: ' + args.version + ' ====')
if args.service_name == '': if args.service_name == '':
args.service_name = 'my_esp_ctrl_device' args.service_name = 'my_esp_ctrl_device'
@ -217,45 +217,45 @@ if __name__ == '__main__':
obj_transport = get_transport(args.transport, args.service_name, not args.dont_check_hostname) obj_transport = get_transport(args.transport, args.service_name, not args.dont_check_hostname)
if obj_transport is None: if obj_transport is None:
print("---- Invalid transport ----") print('---- Invalid transport ----')
exit(1) exit(1)
if args.version != '': if args.version != '':
print("\n==== Verifying protocol version ====") print('\n==== Verifying protocol version ====')
if not version_match(obj_transport, args.version, args.verbose): if not version_match(obj_transport, args.version, args.verbose):
print("---- Error in protocol version matching ----") print('---- Error in protocol version matching ----')
exit(2) exit(2)
print("==== Verified protocol version successfully ====") print('==== Verified protocol version successfully ====')
while True: while True:
properties = get_all_property_values(obj_transport) properties = get_all_property_values(obj_transport)
if len(properties) == 0: if len(properties) == 0:
print("---- Error in reading property values ----") print('---- Error in reading property values ----')
exit(4) exit(4)
print("\n==== Available Properties ====") print('\n==== Available Properties ====')
print("{0: >4} {1: <16} {2: <10} {3: <16} {4: <16}".format( print('{0: >4} {1: <16} {2: <10} {3: <16} {4: <16}'.format(
"S.N.", "Name", "Type", "Flags", "Value")) 'S.N.', 'Name', 'Type', 'Flags', 'Value'))
for i in range(len(properties)): for i in range(len(properties)):
print("[{0: >2}] {1: <16} {2: <10} {3: <16} {4: <16}".format( print('[{0: >2}] {1: <16} {2: <10} {3: <16} {4: <16}'.format(
i + 1, properties[i]["name"], prop_typestr(properties[i]), i + 1, properties[i]['name'], prop_typestr(properties[i]),
["","Read-Only"][prop_is_readonly(properties[i])], ['','Read-Only'][prop_is_readonly(properties[i])],
str(properties[i]["value"]))) str(properties[i]['value'])))
select = 0 select = 0
while True: while True:
try: try:
inval = input("\nSelect properties to set (0 to re-read, 'q' to quit) : ") inval = input("\nSelect properties to set (0 to re-read, 'q' to quit) : ")
if inval.lower() == 'q': if inval.lower() == 'q':
print("Quitting...") print('Quitting...')
exit(5) exit(5)
invals = inval.split(',') invals = inval.split(',')
selections = [int(val) for val in invals] selections = [int(val) for val in invals]
if min(selections) < 0 or max(selections) > len(properties): if min(selections) < 0 or max(selections) > len(properties):
raise ValueError("Invalid input") raise ValueError('Invalid input')
break break
except ValueError as e: except ValueError as e:
print(str(e) + "! Retry...") print(str(e) + '! Retry...')
if len(selections) == 1 and selections[0] == 0: if len(selections) == 1 and selections[0] == 0:
continue continue
@ -264,15 +264,15 @@ if __name__ == '__main__':
set_indices = [] set_indices = []
for select in selections: for select in selections:
while True: while True:
inval = input("Enter value to set for property (" + properties[select - 1]["name"] + ") : ") inval = input('Enter value to set for property (' + properties[select - 1]['name'] + ') : ')
value = encode_prop_value(properties[select - 1], value = encode_prop_value(properties[select - 1],
str_to_prop_value(properties[select - 1], inval)) str_to_prop_value(properties[select - 1], inval))
if value is None: if value is None:
print("Invalid input! Retry...") print('Invalid input! Retry...')
continue continue
break break
set_values += [value] set_values += [value]
set_indices += [select - 1] set_indices += [select - 1]
if not set_property_values(obj_transport, properties, set_indices, set_values): if not set_property_values(obj_transport, properties, set_indices, set_values):
print("Failed to set values!") print('Failed to set values!')

View File

@ -15,9 +15,11 @@
from __future__ import print_function from __future__ import print_function
from future.utils import tobytes
import os import os
from future.utils import tobytes
def _load_source(name, path): def _load_source(name, path):
try: try:
@ -30,8 +32,8 @@ def _load_source(name, path):
idf_path = os.environ['IDF_PATH'] idf_path = os.environ['IDF_PATH']
constants_pb2 = _load_source("constants_pb2", idf_path + "/components/protocomm/python/constants_pb2.py") constants_pb2 = _load_source('constants_pb2', idf_path + '/components/protocomm/python/constants_pb2.py')
local_ctrl_pb2 = _load_source("esp_local_ctrl_pb2", idf_path + "/components/esp_local_ctrl/python/esp_local_ctrl_pb2.py") local_ctrl_pb2 = _load_source('esp_local_ctrl_pb2', idf_path + '/components/esp_local_ctrl/python/esp_local_ctrl_pb2.py')
def get_prop_count_request(): def get_prop_count_request():
@ -67,10 +69,10 @@ def get_prop_vals_response(response_data):
if (resp.resp_get_prop_vals.status == 0): if (resp.resp_get_prop_vals.status == 0):
for prop in resp.resp_get_prop_vals.props: for prop in resp.resp_get_prop_vals.props:
results += [{ results += [{
"name": prop.name, 'name': prop.name,
"type": prop.type, 'type': prop.type,
"flags": prop.flags, 'flags': prop.flags,
"value": tobytes(prop.value) 'value': tobytes(prop.value)
}] }]
return results return results

View File

@ -14,15 +14,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import division from __future__ import division, print_function, unicode_literals
from __future__ import print_function
from __future__ import unicode_literals import os
import re import re
import os
from tiny_test_fw import Utility
import ttfw_idf import ttfw_idf
from idf_http_server_test import test as client from idf_http_server_test import test as client
from tiny_test_fw import Utility
# When running on local machine execute the following before running this script # When running on local machine execute the following before running this script
# > make app bootloader # > make app bootloader
@ -36,23 +35,23 @@ from idf_http_server_test import test as client
# features to this component. # features to this component.
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_http_server_advanced(env, extra_data): def test_examples_protocol_http_server_advanced(env, extra_data):
# Acquire DUT # Acquire DUT
dut1 = env.get_dut("http_server", "examples/protocols/http_server/advanced_tests", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('http_server', 'examples/protocols/http_server/advanced_tests', dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut1.app.binary_path, "tests.bin") binary_file = os.path.join(dut1.app.binary_path, 'tests.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("http_server_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('http_server_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
Utility.console_log("Starting http_server advanced test app") Utility.console_log('Starting http_server advanced test app')
dut1.start_app() dut1.start_app()
# Parse IP address of STA # Parse IP address of STA
Utility.console_log("Waiting to connect with AP") Utility.console_log('Waiting to connect with AP')
got_ip = dut1.expect(re.compile(r"(?:[\s\S]*)IPv4 address: (\d+.\d+.\d+.\d+)"), timeout=30)[0] got_ip = dut1.expect(re.compile(r'(?:[\s\S]*)IPv4 address: (\d+.\d+.\d+.\d+)'), timeout=30)[0]
got_port = dut1.expect(re.compile(r"(?:[\s\S]*)Started HTTP server on port: '(\d+)'"), timeout=15)[0] got_port = dut1.expect(re.compile(r"(?:[\s\S]*)Started HTTP server on port: '(\d+)'"), timeout=15)[0]
result = dut1.expect(re.compile(r"(?:[\s\S]*)Max URI handlers: '(\d+)'(?:[\s\S]*)Max Open Sessions: " # noqa: W605 result = dut1.expect(re.compile(r"(?:[\s\S]*)Max URI handlers: '(\d+)'(?:[\s\S]*)Max Open Sessions: " # noqa: W605
@ -64,18 +63,18 @@ def test_examples_protocol_http_server_advanced(env, extra_data):
max_uri_len = int(result[3]) max_uri_len = int(result[3])
max_stack_size = int(result[4]) max_stack_size = int(result[4])
Utility.console_log("Got IP : " + got_ip) Utility.console_log('Got IP : ' + got_ip)
Utility.console_log("Got Port : " + got_port) Utility.console_log('Got Port : ' + got_port)
# Run test script # Run test script
# If failed raise appropriate exception # If failed raise appropriate exception
failed = False failed = False
Utility.console_log("Sessions and Context Tests...") Utility.console_log('Sessions and Context Tests...')
if not client.spillover_session(got_ip, got_port, max_sessions): if not client.spillover_session(got_ip, got_port, max_sessions):
Utility.console_log("Ignoring failure") Utility.console_log('Ignoring failure')
if not client.parallel_sessions_adder(got_ip, got_port, max_sessions): if not client.parallel_sessions_adder(got_ip, got_port, max_sessions):
Utility.console_log("Ignoring failure") Utility.console_log('Ignoring failure')
if not client.leftover_data_test(got_ip, got_port): if not client.leftover_data_test(got_ip, got_port):
failed = True failed = True
if not client.async_response_test(got_ip, got_port): if not client.async_response_test(got_ip, got_port):
@ -90,17 +89,17 @@ def test_examples_protocol_http_server_advanced(env, extra_data):
# if not client.packet_size_limit_test(got_ip, got_port, test_size): # if not client.packet_size_limit_test(got_ip, got_port, test_size):
# Utility.console_log("Ignoring failure") # Utility.console_log("Ignoring failure")
Utility.console_log("Getting initial stack usage...") Utility.console_log('Getting initial stack usage...')
if not client.get_hello(got_ip, got_port): if not client.get_hello(got_ip, got_port):
failed = True failed = True
inital_stack = int(dut1.expect(re.compile(r"(?:[\s\S]*)Free Stack for server task: '(\d+)'"), timeout=15)[0]) inital_stack = int(dut1.expect(re.compile(r"(?:[\s\S]*)Free Stack for server task: '(\d+)'"), timeout=15)[0])
if inital_stack < 0.1 * max_stack_size: if inital_stack < 0.1 * max_stack_size:
Utility.console_log("More than 90% of stack being used on server start") Utility.console_log('More than 90% of stack being used on server start')
failed = True failed = True
Utility.console_log("Basic HTTP Client Tests...") Utility.console_log('Basic HTTP Client Tests...')
if not client.get_hello(got_ip, got_port): if not client.get_hello(got_ip, got_port):
failed = True failed = True
if not client.post_hello(got_ip, got_port): if not client.post_hello(got_ip, got_port):
@ -122,7 +121,7 @@ def test_examples_protocol_http_server_advanced(env, extra_data):
if not client.get_test_headers(got_ip, got_port): if not client.get_test_headers(got_ip, got_port):
failed = True failed = True
Utility.console_log("Error code tests...") Utility.console_log('Error code tests...')
if not client.code_500_server_error_test(got_ip, got_port): if not client.code_500_server_error_test(got_ip, got_port):
failed = True failed = True
if not client.code_501_method_not_impl(got_ip, got_port): if not client.code_501_method_not_impl(got_ip, got_port):
@ -138,20 +137,20 @@ def test_examples_protocol_http_server_advanced(env, extra_data):
if not client.code_408_req_timeout(got_ip, got_port): if not client.code_408_req_timeout(got_ip, got_port):
failed = True failed = True
if not client.code_414_uri_too_long(got_ip, got_port, max_uri_len): if not client.code_414_uri_too_long(got_ip, got_port, max_uri_len):
Utility.console_log("Ignoring failure") Utility.console_log('Ignoring failure')
if not client.code_431_hdr_too_long(got_ip, got_port, max_hdr_len): if not client.code_431_hdr_too_long(got_ip, got_port, max_hdr_len):
Utility.console_log("Ignoring failure") Utility.console_log('Ignoring failure')
if not client.test_upgrade_not_supported(got_ip, got_port): if not client.test_upgrade_not_supported(got_ip, got_port):
failed = True failed = True
Utility.console_log("Getting final stack usage...") Utility.console_log('Getting final stack usage...')
if not client.get_hello(got_ip, got_port): if not client.get_hello(got_ip, got_port):
failed = True failed = True
final_stack = int(dut1.expect(re.compile(r"(?:[\s\S]*)Free Stack for server task: '(\d+)'"), timeout=15)[0]) final_stack = int(dut1.expect(re.compile(r"(?:[\s\S]*)Free Stack for server task: '(\d+)'"), timeout=15)[0])
if final_stack < 0.05 * max_stack_size: if final_stack < 0.05 * max_stack_size:
Utility.console_log("More than 95% of stack got used during tests") Utility.console_log('More than 95% of stack got used during tests')
failed = True failed = True
if failed: if failed:

View File

@ -129,20 +129,17 @@
# - Simple GET on /hello/restart_results (returns the leak results) # - Simple GET on /hello/restart_results (returns the leak results)
from __future__ import division from __future__ import division, print_function
from __future__ import print_function
from builtins import str
from builtins import range
from builtins import object
import threading
import socket
import time
import argparse import argparse
import http.client import http.client
import sys
import string
import random import random
import socket
import string
import sys
import threading
import time
from builtins import object, range, str
try: try:
import Utility import Utility
@ -151,7 +148,7 @@ except ImportError:
# This environment variable is expected on the host machine # This environment variable is expected on the host machine
# > export TEST_FW_PATH=~/esp/esp-idf/tools/tiny-test-fw # > export TEST_FW_PATH=~/esp/esp-idf/tools/tiny-test-fw
test_fw_path = os.getenv("TEST_FW_PATH") test_fw_path = os.getenv('TEST_FW_PATH')
if test_fw_path and test_fw_path not in sys.path: if test_fw_path and test_fw_path not in sys.path:
sys.path.insert(0, test_fw_path) sys.path.insert(0, test_fw_path)
@ -177,32 +174,32 @@ class Session(object):
self.client.sendall(data.encode()) self.client.sendall(data.encode())
except socket.error as err: except socket.error as err:
self.client.close() self.client.close()
Utility.console_log("Socket Error in send :", err) Utility.console_log('Socket Error in send :', err)
rval = False rval = False
return rval return rval
def send_get(self, path, headers=None): def send_get(self, path, headers=None):
request = "GET " + path + " HTTP/1.1\r\nHost: " + self.target request = 'GET ' + path + ' HTTP/1.1\r\nHost: ' + self.target
if headers: if headers:
for field, value in headers.items(): for field, value in headers.items():
request += "\r\n" + field + ": " + value request += '\r\n' + field + ': ' + value
request += "\r\n\r\n" request += '\r\n\r\n'
return self.send_err_check(request) return self.send_err_check(request)
def send_put(self, path, data, headers=None): def send_put(self, path, data, headers=None):
request = "PUT " + path + " HTTP/1.1\r\nHost: " + self.target request = 'PUT ' + path + ' HTTP/1.1\r\nHost: ' + self.target
if headers: if headers:
for field, value in headers.items(): for field, value in headers.items():
request += "\r\n" + field + ": " + value request += '\r\n' + field + ': ' + value
request += "\r\nContent-Length: " + str(len(data)) + "\r\n\r\n" request += '\r\nContent-Length: ' + str(len(data)) + '\r\n\r\n'
return self.send_err_check(request, data) return self.send_err_check(request, data)
def send_post(self, path, data, headers=None): def send_post(self, path, data, headers=None):
request = "POST " + path + " HTTP/1.1\r\nHost: " + self.target request = 'POST ' + path + ' HTTP/1.1\r\nHost: ' + self.target
if headers: if headers:
for field, value in headers.items(): for field, value in headers.items():
request += "\r\n" + field + ": " + value request += '\r\n' + field + ': ' + value
request += "\r\nContent-Length: " + str(len(data)) + "\r\n\r\n" request += '\r\nContent-Length: ' + str(len(data)) + '\r\n\r\n'
return self.send_err_check(request, data) return self.send_err_check(request, data)
def read_resp_hdrs(self): def read_resp_hdrs(self):
@ -246,7 +243,7 @@ class Session(object):
return headers return headers
except socket.error as err: except socket.error as err:
self.client.close() self.client.close()
Utility.console_log("Socket Error in recv :", err) Utility.console_log('Socket Error in recv :', err)
return None return None
def read_resp_data(self): def read_resp_data(self):
@ -275,9 +272,9 @@ class Session(object):
rem_len -= len(new_data) rem_len -= len(new_data)
chunk_data_buf = '' chunk_data_buf = ''
# Fetch remaining CRLF # Fetch remaining CRLF
if self.client.recv(2) != "\r\n": if self.client.recv(2) != '\r\n':
# Error in packet # Error in packet
Utility.console_log("Error in chunked data") Utility.console_log('Error in chunked data')
return None return None
if not chunk_len: if not chunk_len:
# If last chunk # If last chunk
@ -290,7 +287,7 @@ class Session(object):
return read_data return read_data
except socket.error as err: except socket.error as err:
self.client.close() self.client.close()
Utility.console_log("Socket Error in recv :", err) Utility.console_log('Socket Error in recv :', err)
return None return None
def close(self): def close(self):
@ -299,10 +296,10 @@ class Session(object):
def test_val(text, expected, received): def test_val(text, expected, received):
if expected != received: if expected != received:
Utility.console_log(" Fail!") Utility.console_log(' Fail!')
Utility.console_log(" [reason] " + text + ":") Utility.console_log(' [reason] ' + text + ':')
Utility.console_log(" expected: " + str(expected)) Utility.console_log(' expected: ' + str(expected))
Utility.console_log(" received: " + str(received)) Utility.console_log(' received: ' + str(received))
return False return False
return True return True
@ -320,7 +317,7 @@ class adder_thread (threading.Thread):
# Pipeline 3 requests # Pipeline 3 requests
if (_verbose_): if (_verbose_):
Utility.console_log(" Thread: Using adder start " + str(self.id)) Utility.console_log(' Thread: Using adder start ' + str(self.id))
for _ in range(self.depth): for _ in range(self.depth):
self.session.send_post('/adder', str(self.id)) self.session.send_post('/adder', str(self.id))
@ -332,10 +329,10 @@ class adder_thread (threading.Thread):
def adder_result(self): def adder_result(self):
if len(self.response) != self.depth: if len(self.response) != self.depth:
Utility.console_log("Error : missing response packets") Utility.console_log('Error : missing response packets')
return False return False
for i in range(len(self.response)): for i in range(len(self.response)):
if not test_val("Thread" + str(self.id) + " response[" + str(i) + "]", if not test_val('Thread' + str(self.id) + ' response[' + str(i) + ']',
str(self.id * (i + 1)), str(self.response[i])): str(self.id * (i + 1)), str(self.response[i])):
return False return False
return True return True
@ -348,177 +345,177 @@ def get_hello(dut, port):
# GET /hello should return 'Hello World!' # GET /hello should return 'Hello World!'
Utility.console_log("[test] GET /hello returns 'Hello World!' =>", end=' ') Utility.console_log("[test] GET /hello returns 'Hello World!' =>", end=' ')
conn = http.client.HTTPConnection(dut, int(port), timeout=15) conn = http.client.HTTPConnection(dut, int(port), timeout=15)
conn.request("GET", "/hello") conn.request('GET', '/hello')
resp = conn.getresponse() resp = conn.getresponse()
if not test_val("status_code", 200, resp.status): if not test_val('status_code', 200, resp.status):
conn.close() conn.close()
return False return False
if not test_val("data", "Hello World!", resp.read().decode()): if not test_val('data', 'Hello World!', resp.read().decode()):
conn.close() conn.close()
return False return False
if not test_val("data", "text/html", resp.getheader('Content-Type')): if not test_val('data', 'text/html', resp.getheader('Content-Type')):
conn.close() conn.close()
return False return False
Utility.console_log("Success") Utility.console_log('Success')
conn.close() conn.close()
return True return True
def put_hello(dut, port): def put_hello(dut, port):
# PUT /hello returns 405' # PUT /hello returns 405'
Utility.console_log("[test] PUT /hello returns 405 =>", end=' ') Utility.console_log('[test] PUT /hello returns 405 =>', end=' ')
conn = http.client.HTTPConnection(dut, int(port), timeout=15) conn = http.client.HTTPConnection(dut, int(port), timeout=15)
conn.request("PUT", "/hello", "Hello") conn.request('PUT', '/hello', 'Hello')
resp = conn.getresponse() resp = conn.getresponse()
if not test_val("status_code", 405, resp.status): if not test_val('status_code', 405, resp.status):
conn.close() conn.close()
return False return False
Utility.console_log("Success") Utility.console_log('Success')
conn.close() conn.close()
return True return True
def post_hello(dut, port): def post_hello(dut, port):
# POST /hello returns 405' # POST /hello returns 405'
Utility.console_log("[test] POST /hello returns 405 =>", end=' ') Utility.console_log('[test] POST /hello returns 405 =>', end=' ')
conn = http.client.HTTPConnection(dut, int(port), timeout=15) conn = http.client.HTTPConnection(dut, int(port), timeout=15)
conn.request("POST", "/hello", "Hello") conn.request('POST', '/hello', 'Hello')
resp = conn.getresponse() resp = conn.getresponse()
if not test_val("status_code", 405, resp.status): if not test_val('status_code', 405, resp.status):
conn.close() conn.close()
return False return False
Utility.console_log("Success") Utility.console_log('Success')
conn.close() conn.close()
return True return True
def post_echo(dut, port): def post_echo(dut, port):
# POST /echo echoes data' # POST /echo echoes data'
Utility.console_log("[test] POST /echo echoes data =>", end=' ') Utility.console_log('[test] POST /echo echoes data =>', end=' ')
conn = http.client.HTTPConnection(dut, int(port), timeout=15) conn = http.client.HTTPConnection(dut, int(port), timeout=15)
conn.request("POST", "/echo", "Hello") conn.request('POST', '/echo', 'Hello')
resp = conn.getresponse() resp = conn.getresponse()
if not test_val("status_code", 200, resp.status): if not test_val('status_code', 200, resp.status):
conn.close() conn.close()
return False return False
if not test_val("data", "Hello", resp.read().decode()): if not test_val('data', 'Hello', resp.read().decode()):
conn.close() conn.close()
return False return False
Utility.console_log("Success") Utility.console_log('Success')
conn.close() conn.close()
return True return True
def put_echo(dut, port): def put_echo(dut, port):
# PUT /echo echoes data' # PUT /echo echoes data'
Utility.console_log("[test] PUT /echo echoes data =>", end=' ') Utility.console_log('[test] PUT /echo echoes data =>', end=' ')
conn = http.client.HTTPConnection(dut, int(port), timeout=15) conn = http.client.HTTPConnection(dut, int(port), timeout=15)
conn.request("PUT", "/echo", "Hello") conn.request('PUT', '/echo', 'Hello')
resp = conn.getresponse() resp = conn.getresponse()
if not test_val("status_code", 200, resp.status): if not test_val('status_code', 200, resp.status):
conn.close() conn.close()
return False return False
if not test_val("data", "Hello", resp.read().decode()): if not test_val('data', 'Hello', resp.read().decode()):
conn.close() conn.close()
return False return False
Utility.console_log("Success") Utility.console_log('Success')
conn.close() conn.close()
return True return True
def get_echo(dut, port): def get_echo(dut, port):
# GET /echo returns 404' # GET /echo returns 404'
Utility.console_log("[test] GET /echo returns 405 =>", end=' ') Utility.console_log('[test] GET /echo returns 405 =>', end=' ')
conn = http.client.HTTPConnection(dut, int(port), timeout=15) conn = http.client.HTTPConnection(dut, int(port), timeout=15)
conn.request("GET", "/echo") conn.request('GET', '/echo')
resp = conn.getresponse() resp = conn.getresponse()
if not test_val("status_code", 405, resp.status): if not test_val('status_code', 405, resp.status):
conn.close() conn.close()
return False return False
Utility.console_log("Success") Utility.console_log('Success')
conn.close() conn.close()
return True return True
def get_test_headers(dut, port): def get_test_headers(dut, port):
# GET /test_header returns data of Header2' # GET /test_header returns data of Header2'
Utility.console_log("[test] GET /test_header =>", end=' ') Utility.console_log('[test] GET /test_header =>', end=' ')
conn = http.client.HTTPConnection(dut, int(port), timeout=15) conn = http.client.HTTPConnection(dut, int(port), timeout=15)
custom_header = {"Header1": "Value1", "Header3": "Value3"} custom_header = {'Header1': 'Value1', 'Header3': 'Value3'}
header2_values = ["", " ", "Value2", " Value2", "Value2 ", " Value2 "] header2_values = ['', ' ', 'Value2', ' Value2', 'Value2 ', ' Value2 ']
for val in header2_values: for val in header2_values:
custom_header["Header2"] = val custom_header['Header2'] = val
conn.request("GET", "/test_header", headers=custom_header) conn.request('GET', '/test_header', headers=custom_header)
resp = conn.getresponse() resp = conn.getresponse()
if not test_val("status_code", 200, resp.status): if not test_val('status_code', 200, resp.status):
conn.close() conn.close()
return False return False
hdr_val_start_idx = val.find("Value2") hdr_val_start_idx = val.find('Value2')
if hdr_val_start_idx == -1: if hdr_val_start_idx == -1:
if not test_val("header: Header2", "", resp.read().decode()): if not test_val('header: Header2', '', resp.read().decode()):
conn.close() conn.close()
return False return False
else: else:
if not test_val("header: Header2", val[hdr_val_start_idx:], resp.read().decode()): if not test_val('header: Header2', val[hdr_val_start_idx:], resp.read().decode()):
conn.close() conn.close()
return False return False
resp.read() resp.read()
Utility.console_log("Success") Utility.console_log('Success')
conn.close() conn.close()
return True return True
def get_hello_type(dut, port): def get_hello_type(dut, port):
# GET /hello/type_html returns text/html as Content-Type' # GET /hello/type_html returns text/html as Content-Type'
Utility.console_log("[test] GET /hello/type_html has Content-Type of text/html =>", end=' ') Utility.console_log('[test] GET /hello/type_html has Content-Type of text/html =>', end=' ')
conn = http.client.HTTPConnection(dut, int(port), timeout=15) conn = http.client.HTTPConnection(dut, int(port), timeout=15)
conn.request("GET", "/hello/type_html") conn.request('GET', '/hello/type_html')
resp = conn.getresponse() resp = conn.getresponse()
if not test_val("status_code", 200, resp.status): if not test_val('status_code', 200, resp.status):
conn.close() conn.close()
return False return False
if not test_val("data", "Hello World!", resp.read().decode()): if not test_val('data', 'Hello World!', resp.read().decode()):
conn.close() conn.close()
return False return False
if not test_val("data", "text/html", resp.getheader('Content-Type')): if not test_val('data', 'text/html', resp.getheader('Content-Type')):
conn.close() conn.close()
return False return False
Utility.console_log("Success") Utility.console_log('Success')
conn.close() conn.close()
return True return True
def get_hello_status(dut, port): def get_hello_status(dut, port):
# GET /hello/status_500 returns status 500' # GET /hello/status_500 returns status 500'
Utility.console_log("[test] GET /hello/status_500 returns status 500 =>", end=' ') Utility.console_log('[test] GET /hello/status_500 returns status 500 =>', end=' ')
conn = http.client.HTTPConnection(dut, int(port), timeout=15) conn = http.client.HTTPConnection(dut, int(port), timeout=15)
conn.request("GET", "/hello/status_500") conn.request('GET', '/hello/status_500')
resp = conn.getresponse() resp = conn.getresponse()
if not test_val("status_code", 500, resp.status): if not test_val('status_code', 500, resp.status):
conn.close() conn.close()
return False return False
Utility.console_log("Success") Utility.console_log('Success')
conn.close() conn.close()
return True return True
def get_false_uri(dut, port): def get_false_uri(dut, port):
# GET /false_uri returns status 404' # GET /false_uri returns status 404'
Utility.console_log("[test] GET /false_uri returns status 404 =>", end=' ') Utility.console_log('[test] GET /false_uri returns status 404 =>', end=' ')
conn = http.client.HTTPConnection(dut, int(port), timeout=15) conn = http.client.HTTPConnection(dut, int(port), timeout=15)
conn.request("GET", "/false_uri") conn.request('GET', '/false_uri')
resp = conn.getresponse() resp = conn.getresponse()
if not test_val("status_code", 404, resp.status): if not test_val('status_code', 404, resp.status):
conn.close() conn.close()
return False return False
Utility.console_log("Success") Utility.console_log('Success')
conn.close() conn.close()
return True return True
def parallel_sessions_adder(dut, port, max_sessions): def parallel_sessions_adder(dut, port, max_sessions):
# POSTs on /adder in parallel sessions # POSTs on /adder in parallel sessions
Utility.console_log("[test] POST {pipelined} on /adder in " + str(max_sessions) + " sessions =>", end=' ') Utility.console_log('[test] POST {pipelined} on /adder in ' + str(max_sessions) + ' sessions =>', end=' ')
t = [] t = []
# Create all sessions # Create all sessions
for i in range(max_sessions): for i in range(max_sessions):
@ -532,90 +529,90 @@ def parallel_sessions_adder(dut, port, max_sessions):
res = True res = True
for i in range(len(t)): for i in range(len(t)):
if not test_val("Thread" + str(i) + " Failed", t[i].adder_result(), True): if not test_val('Thread' + str(i) + ' Failed', t[i].adder_result(), True):
res = False res = False
t[i].close() t[i].close()
if (res): if (res):
Utility.console_log("Success") Utility.console_log('Success')
return res return res
def async_response_test(dut, port): def async_response_test(dut, port):
# Test that an asynchronous work is executed in the HTTPD's context # Test that an asynchronous work is executed in the HTTPD's context
# This is tested by reading two responses over the same session # This is tested by reading two responses over the same session
Utility.console_log("[test] Test HTTPD Work Queue (Async response) =>", end=' ') Utility.console_log('[test] Test HTTPD Work Queue (Async response) =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
s.send_get('/async_data') s.send_get('/async_data')
s.read_resp_hdrs() s.read_resp_hdrs()
if not test_val("First Response", "Hello World!", s.read_resp_data()): if not test_val('First Response', 'Hello World!', s.read_resp_data()):
s.close() s.close()
return False return False
s.read_resp_hdrs() s.read_resp_hdrs()
if not test_val("Second Response", "Hello Double World!", s.read_resp_data()): if not test_val('Second Response', 'Hello Double World!', s.read_resp_data()):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def leftover_data_test(dut, port): def leftover_data_test(dut, port):
# Leftover data in POST is purged (valid and invalid URIs) # Leftover data in POST is purged (valid and invalid URIs)
Utility.console_log("[test] Leftover data in POST is purged (valid and invalid URIs) =>", end=' ') Utility.console_log('[test] Leftover data in POST is purged (valid and invalid URIs) =>', end=' ')
s = http.client.HTTPConnection(dut + ":" + port, timeout=15) s = http.client.HTTPConnection(dut + ':' + port, timeout=15)
s.request("POST", url='/leftover_data', body="abcdefghijklmnopqrstuvwxyz\r\nabcdefghijklmnopqrstuvwxyz") s.request('POST', url='/leftover_data', body='abcdefghijklmnopqrstuvwxyz\r\nabcdefghijklmnopqrstuvwxyz')
resp = s.getresponse() resp = s.getresponse()
if not test_val("Partial data", "abcdefghij", resp.read().decode()): if not test_val('Partial data', 'abcdefghij', resp.read().decode()):
s.close() s.close()
return False return False
s.request("GET", url='/hello') s.request('GET', url='/hello')
resp = s.getresponse() resp = s.getresponse()
if not test_val("Hello World Data", "Hello World!", resp.read().decode()): if not test_val('Hello World Data', 'Hello World!', resp.read().decode()):
s.close() s.close()
return False return False
s.request("POST", url='/false_uri', body="abcdefghijklmnopqrstuvwxyz\r\nabcdefghijklmnopqrstuvwxyz") s.request('POST', url='/false_uri', body='abcdefghijklmnopqrstuvwxyz\r\nabcdefghijklmnopqrstuvwxyz')
resp = s.getresponse() resp = s.getresponse()
if not test_val("False URI Status", str(404), str(resp.status)): if not test_val('False URI Status', str(404), str(resp.status)):
s.close() s.close()
return False return False
# socket would have been closed by server due to error # socket would have been closed by server due to error
s.close() s.close()
s = http.client.HTTPConnection(dut + ":" + port, timeout=15) s = http.client.HTTPConnection(dut + ':' + port, timeout=15)
s.request("GET", url='/hello') s.request('GET', url='/hello')
resp = s.getresponse() resp = s.getresponse()
if not test_val("Hello World Data", "Hello World!", resp.read().decode()): if not test_val('Hello World Data', 'Hello World!', resp.read().decode()):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def spillover_session(dut, port, max_sess): def spillover_session(dut, port, max_sess):
# Session max_sess_sessions + 1 is rejected # Session max_sess_sessions + 1 is rejected
Utility.console_log("[test] Session max_sess_sessions (" + str(max_sess) + ") + 1 is rejected =>", end=' ') Utility.console_log('[test] Session max_sess_sessions (' + str(max_sess) + ') + 1 is rejected =>', end=' ')
s = [] s = []
_verbose_ = True _verbose_ = True
for i in range(max_sess + 1): for i in range(max_sess + 1):
if (_verbose_): if (_verbose_):
Utility.console_log("Executing " + str(i)) Utility.console_log('Executing ' + str(i))
try: try:
a = http.client.HTTPConnection(dut + ":" + port, timeout=15) a = http.client.HTTPConnection(dut + ':' + port, timeout=15)
a.request("GET", url='/hello') a.request('GET', url='/hello')
resp = a.getresponse() resp = a.getresponse()
if not test_val("Connection " + str(i), "Hello World!", resp.read().decode()): if not test_val('Connection ' + str(i), 'Hello World!', resp.read().decode()):
a.close() a.close()
break break
s.append(a) s.append(a)
except Exception: except Exception:
if (_verbose_): if (_verbose_):
Utility.console_log("Connection " + str(i) + " rejected") Utility.console_log('Connection ' + str(i) + ' rejected')
a.close() a.close()
break break
@ -624,134 +621,134 @@ def spillover_session(dut, port, max_sess):
a.close() a.close()
# Check if number of connections is equal to max_sess # Check if number of connections is equal to max_sess
Utility.console_log(["Fail","Success"][len(s) == max_sess]) Utility.console_log(['Fail','Success'][len(s) == max_sess])
return (len(s) == max_sess) return (len(s) == max_sess)
def recv_timeout_test(dut, port): def recv_timeout_test(dut, port):
Utility.console_log("[test] Timeout occurs if partial packet sent =>", end=' ') Utility.console_log('[test] Timeout occurs if partial packet sent =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
s.client.sendall(b"GE") s.client.sendall(b'GE')
s.read_resp_hdrs() s.read_resp_hdrs()
resp = s.read_resp_data() resp = s.read_resp_data()
if not test_val("Request Timeout", "Server closed this connection", resp): if not test_val('Request Timeout', 'Server closed this connection', resp):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def packet_size_limit_test(dut, port, test_size): def packet_size_limit_test(dut, port, test_size):
Utility.console_log("[test] send size limit test =>", end=' ') Utility.console_log('[test] send size limit test =>', end=' ')
retry = 5 retry = 5
while (retry): while (retry):
retry -= 1 retry -= 1
Utility.console_log("data size = ", test_size) Utility.console_log('data size = ', test_size)
s = http.client.HTTPConnection(dut + ":" + port, timeout=15) s = http.client.HTTPConnection(dut + ':' + port, timeout=15)
random_data = ''.join(string.printable[random.randint(0,len(string.printable)) - 1] for _ in list(range(test_size))) random_data = ''.join(string.printable[random.randint(0,len(string.printable)) - 1] for _ in list(range(test_size)))
path = "/echo" path = '/echo'
s.request("POST", url=path, body=random_data) s.request('POST', url=path, body=random_data)
resp = s.getresponse() resp = s.getresponse()
if not test_val("Error", "200", str(resp.status)): if not test_val('Error', '200', str(resp.status)):
if test_val("Error", "500", str(resp.status)): if test_val('Error', '500', str(resp.status)):
Utility.console_log("Data too large to be allocated") Utility.console_log('Data too large to be allocated')
test_size = test_size // 10 test_size = test_size // 10
else: else:
Utility.console_log("Unexpected error") Utility.console_log('Unexpected error')
s.close() s.close()
Utility.console_log("Retry...") Utility.console_log('Retry...')
continue continue
resp = resp.read().decode() resp = resp.read().decode()
result = (resp == random_data) result = (resp == random_data)
if not result: if not result:
test_val("Data size", str(len(random_data)), str(len(resp))) test_val('Data size', str(len(random_data)), str(len(resp)))
s.close() s.close()
Utility.console_log("Retry...") Utility.console_log('Retry...')
continue continue
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
Utility.console_log("Failed") Utility.console_log('Failed')
return False return False
def arbitrary_termination_test(dut, port): def arbitrary_termination_test(dut, port):
Utility.console_log("[test] Arbitrary termination test =>", end=' ') Utility.console_log('[test] Arbitrary termination test =>', end=' ')
cases = [ cases = [
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nCustom: SomeValue\r\n\r\n", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nCustom: SomeValue\r\n\r\n',
"code": "200", 'code': '200',
"header": "SomeValue" 'header': 'SomeValue'
}, },
{ {
"request": "POST /echo HTTP/1.1\nHost: " + dut + "\r\nCustom: SomeValue\r\n\r\n", 'request': 'POST /echo HTTP/1.1\nHost: ' + dut + '\r\nCustom: SomeValue\r\n\r\n',
"code": "200", 'code': '200',
"header": "SomeValue" 'header': 'SomeValue'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\nCustom: SomeValue\r\n\r\n", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\nCustom: SomeValue\r\n\r\n',
"code": "200", 'code': '200',
"header": "SomeValue" 'header': 'SomeValue'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nCustom: SomeValue\n\r\n", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nCustom: SomeValue\n\r\n',
"code": "200", 'code': '200',
"header": "SomeValue" 'header': 'SomeValue'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nCustom: SomeValue\r\n\n", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nCustom: SomeValue\r\n\n',
"code": "200", 'code': '200',
"header": "SomeValue" 'header': 'SomeValue'
}, },
{ {
"request": "POST /echo HTTP/1.1\nHost: " + dut + "\nCustom: SomeValue\n\n", 'request': 'POST /echo HTTP/1.1\nHost: ' + dut + '\nCustom: SomeValue\n\n',
"code": "200", 'code': '200',
"header": "SomeValue" 'header': 'SomeValue'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nContent-Length: 5\n\r\nABCDE", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nContent-Length: 5\n\r\nABCDE',
"code": "200", 'code': '200',
"body": "ABCDE" 'body': 'ABCDE'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nContent-Length: 5\r\n\nABCDE", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nContent-Length: 5\r\n\nABCDE',
"code": "200", 'code': '200',
"body": "ABCDE" 'body': 'ABCDE'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nContent-Length: 5\n\nABCDE", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nContent-Length: 5\n\nABCDE',
"code": "200", 'code': '200',
"body": "ABCDE" 'body': 'ABCDE'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nContent-Length: 5\n\n\rABCD", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nContent-Length: 5\n\n\rABCD',
"code": "200", 'code': '200',
"body": "\rABCD" 'body': '\rABCD'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\r\r\nCustom: SomeValue\r\r\n\r\r\n", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\r\nCustom: SomeValue\r\r\n\r\r\n',
"code": "400" 'code': '400'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\r\nHost: " + dut + "\r\n\r\n", 'request': 'POST /echo HTTP/1.1\r\r\nHost: ' + dut + '\r\n\r\n',
"code": "400" 'code': '400'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\n\rHost: " + dut + "\r\n\r\n", 'request': 'POST /echo HTTP/1.1\r\n\rHost: ' + dut + '\r\n\r\n',
"code": "400" 'code': '400'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\rCustom: SomeValue\r\n", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\rCustom: SomeValue\r\n',
"code": "400" 'code': '400'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nCustom: Some\rValue\r\n", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nCustom: Some\rValue\r\n',
"code": "400" 'code': '400'
}, },
{ {
"request": "POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nCustom- SomeValue\r\n\r\n", 'request': 'POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nCustom- SomeValue\r\n\r\n',
"code": "400" 'code': '400'
} }
] ]
for case in cases: for case in cases:
@ -760,159 +757,159 @@ def arbitrary_termination_test(dut, port):
resp_hdrs = s.read_resp_hdrs() resp_hdrs = s.read_resp_hdrs()
resp_body = s.read_resp_data() resp_body = s.read_resp_data()
s.close() s.close()
if not test_val("Response Code", case["code"], s.status): if not test_val('Response Code', case['code'], s.status):
return False return False
if "header" in case.keys(): if 'header' in case.keys():
resp_hdr_val = None resp_hdr_val = None
if "Custom" in resp_hdrs.keys(): if 'Custom' in resp_hdrs.keys():
resp_hdr_val = resp_hdrs["Custom"] resp_hdr_val = resp_hdrs['Custom']
if not test_val("Response Header", case["header"], resp_hdr_val): if not test_val('Response Header', case['header'], resp_hdr_val):
return False return False
if "body" in case.keys(): if 'body' in case.keys():
if not test_val("Response Body", case["body"], resp_body): if not test_val('Response Body', case['body'], resp_body):
return False return False
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def code_500_server_error_test(dut, port): def code_500_server_error_test(dut, port):
Utility.console_log("[test] 500 Server Error test =>", end=' ') Utility.console_log('[test] 500 Server Error test =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
# Sending a very large content length will cause malloc to fail # Sending a very large content length will cause malloc to fail
content_len = 2**30 content_len = 2**30
s.client.sendall(("POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nContent-Length: " + str(content_len) + "\r\n\r\nABCD").encode()) s.client.sendall(('POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nContent-Length: ' + str(content_len) + '\r\n\r\nABCD').encode())
s.read_resp_hdrs() s.read_resp_hdrs()
s.read_resp_data() s.read_resp_data()
if not test_val("Server Error", "500", s.status): if not test_val('Server Error', '500', s.status):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def code_501_method_not_impl(dut, port): def code_501_method_not_impl(dut, port):
Utility.console_log("[test] 501 Method Not Implemented =>", end=' ') Utility.console_log('[test] 501 Method Not Implemented =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
path = "/hello" path = '/hello'
s.client.sendall(("ABC " + path + " HTTP/1.1\r\nHost: " + dut + "\r\n\r\n").encode()) s.client.sendall(('ABC ' + path + ' HTTP/1.1\r\nHost: ' + dut + '\r\n\r\n').encode())
s.read_resp_hdrs() s.read_resp_hdrs()
s.read_resp_data() s.read_resp_data()
# Presently server sends back 400 Bad Request # Presently server sends back 400 Bad Request
# if not test_val("Server Error", "501", s.status): # if not test_val("Server Error", "501", s.status):
# s.close() # s.close()
# return False # return False
if not test_val("Server Error", "400", s.status): if not test_val('Server Error', '400', s.status):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def code_505_version_not_supported(dut, port): def code_505_version_not_supported(dut, port):
Utility.console_log("[test] 505 Version Not Supported =>", end=' ') Utility.console_log('[test] 505 Version Not Supported =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
path = "/hello" path = '/hello'
s.client.sendall(("GET " + path + " HTTP/2.0\r\nHost: " + dut + "\r\n\r\n").encode()) s.client.sendall(('GET ' + path + ' HTTP/2.0\r\nHost: ' + dut + '\r\n\r\n').encode())
s.read_resp_hdrs() s.read_resp_hdrs()
s.read_resp_data() s.read_resp_data()
if not test_val("Server Error", "505", s.status): if not test_val('Server Error', '505', s.status):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def code_400_bad_request(dut, port): def code_400_bad_request(dut, port):
Utility.console_log("[test] 400 Bad Request =>", end=' ') Utility.console_log('[test] 400 Bad Request =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
path = "/hello" path = '/hello'
s.client.sendall(("XYZ " + path + " HTTP/1.1\r\nHost: " + dut + "\r\n\r\n").encode()) s.client.sendall(('XYZ ' + path + ' HTTP/1.1\r\nHost: ' + dut + '\r\n\r\n').encode())
s.read_resp_hdrs() s.read_resp_hdrs()
s.read_resp_data() s.read_resp_data()
if not test_val("Client Error", "400", s.status): if not test_val('Client Error', '400', s.status):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def code_404_not_found(dut, port): def code_404_not_found(dut, port):
Utility.console_log("[test] 404 Not Found =>", end=' ') Utility.console_log('[test] 404 Not Found =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
path = "/dummy" path = '/dummy'
s.client.sendall(("GET " + path + " HTTP/1.1\r\nHost: " + dut + "\r\n\r\n").encode()) s.client.sendall(('GET ' + path + ' HTTP/1.1\r\nHost: ' + dut + '\r\n\r\n').encode())
s.read_resp_hdrs() s.read_resp_hdrs()
s.read_resp_data() s.read_resp_data()
if not test_val("Client Error", "404", s.status): if not test_val('Client Error', '404', s.status):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def code_405_method_not_allowed(dut, port): def code_405_method_not_allowed(dut, port):
Utility.console_log("[test] 405 Method Not Allowed =>", end=' ') Utility.console_log('[test] 405 Method Not Allowed =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
path = "/hello" path = '/hello'
s.client.sendall(("POST " + path + " HTTP/1.1\r\nHost: " + dut + "\r\n\r\n").encode()) s.client.sendall(('POST ' + path + ' HTTP/1.1\r\nHost: ' + dut + '\r\n\r\n').encode())
s.read_resp_hdrs() s.read_resp_hdrs()
s.read_resp_data() s.read_resp_data()
if not test_val("Client Error", "405", s.status): if not test_val('Client Error', '405', s.status):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def code_408_req_timeout(dut, port): def code_408_req_timeout(dut, port):
Utility.console_log("[test] 408 Request Timeout =>", end=' ') Utility.console_log('[test] 408 Request Timeout =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
s.client.sendall(("POST /echo HTTP/1.1\r\nHost: " + dut + "\r\nContent-Length: 10\r\n\r\nABCD").encode()) s.client.sendall(('POST /echo HTTP/1.1\r\nHost: ' + dut + '\r\nContent-Length: 10\r\n\r\nABCD').encode())
s.read_resp_hdrs() s.read_resp_hdrs()
s.read_resp_data() s.read_resp_data()
if not test_val("Client Error", "408", s.status): if not test_val('Client Error', '408', s.status):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def code_411_length_required(dut, port): def code_411_length_required(dut, port):
Utility.console_log("[test] 411 Length Required =>", end=' ') Utility.console_log('[test] 411 Length Required =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
path = "/echo" path = '/echo'
s.client.sendall(("POST " + path + " HTTP/1.1\r\nHost: " + dut + "\r\nContent-Type: text/plain\r\nTransfer-Encoding: chunked\r\n\r\n").encode()) s.client.sendall(('POST ' + path + ' HTTP/1.1\r\nHost: ' + dut + '\r\nContent-Type: text/plain\r\nTransfer-Encoding: chunked\r\n\r\n').encode())
s.read_resp_hdrs() s.read_resp_hdrs()
s.read_resp_data() s.read_resp_data()
# Presently server sends back 400 Bad Request # Presently server sends back 400 Bad Request
# if not test_val("Client Error", "411", s.status): # if not test_val("Client Error", "411", s.status):
# s.close() # s.close()
# return False # return False
if not test_val("Client Error", "400", s.status): if not test_val('Client Error', '400', s.status):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def send_getx_uri_len(dut, port, length): def send_getx_uri_len(dut, port, length):
s = Session(dut, port) s = Session(dut, port)
method = "GET " method = 'GET '
version = " HTTP/1.1\r\n" version = ' HTTP/1.1\r\n'
path = "/" + "x" * (length - len(method) - len(version) - len("/")) path = '/' + 'x' * (length - len(method) - len(version) - len('/'))
s.client.sendall(method.encode()) s.client.sendall(method.encode())
time.sleep(1) time.sleep(1)
s.client.sendall(path.encode()) s.client.sendall(path.encode())
time.sleep(1) time.sleep(1)
s.client.sendall((version + "Host: " + dut + "\r\n\r\n").encode()) s.client.sendall((version + 'Host: ' + dut + '\r\n\r\n').encode())
s.read_resp_hdrs() s.read_resp_hdrs()
s.read_resp_data() s.read_resp_data()
s.close() s.close()
@ -920,59 +917,59 @@ def send_getx_uri_len(dut, port, length):
def code_414_uri_too_long(dut, port, max_uri_len): def code_414_uri_too_long(dut, port, max_uri_len):
Utility.console_log("[test] 414 URI Too Long =>", end=' ') Utility.console_log('[test] 414 URI Too Long =>', end=' ')
status = send_getx_uri_len(dut, port, max_uri_len) status = send_getx_uri_len(dut, port, max_uri_len)
if not test_val("Client Error", "404", status): if not test_val('Client Error', '404', status):
return False return False
status = send_getx_uri_len(dut, port, max_uri_len + 1) status = send_getx_uri_len(dut, port, max_uri_len + 1)
if not test_val("Client Error", "414", status): if not test_val('Client Error', '414', status):
return False return False
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def send_postx_hdr_len(dut, port, length): def send_postx_hdr_len(dut, port, length):
s = Session(dut, port) s = Session(dut, port)
path = "/echo" path = '/echo'
host = "Host: " + dut host = 'Host: ' + dut
custom_hdr_field = "\r\nCustom: " custom_hdr_field = '\r\nCustom: '
custom_hdr_val = "x" * (length - len(host) - len(custom_hdr_field) - len("\r\n\r\n") + len("0")) custom_hdr_val = 'x' * (length - len(host) - len(custom_hdr_field) - len('\r\n\r\n') + len('0'))
request = ("POST " + path + " HTTP/1.1\r\n" + host + custom_hdr_field + custom_hdr_val + "\r\n\r\n").encode() request = ('POST ' + path + ' HTTP/1.1\r\n' + host + custom_hdr_field + custom_hdr_val + '\r\n\r\n').encode()
s.client.sendall(request[:length // 2]) s.client.sendall(request[:length // 2])
time.sleep(1) time.sleep(1)
s.client.sendall(request[length // 2:]) s.client.sendall(request[length // 2:])
hdr = s.read_resp_hdrs() hdr = s.read_resp_hdrs()
resp = s.read_resp_data() resp = s.read_resp_data()
s.close() s.close()
if hdr and ("Custom" in hdr): if hdr and ('Custom' in hdr):
return (hdr["Custom"] == custom_hdr_val), resp return (hdr['Custom'] == custom_hdr_val), resp
return False, s.status return False, s.status
def code_431_hdr_too_long(dut, port, max_hdr_len): def code_431_hdr_too_long(dut, port, max_hdr_len):
Utility.console_log("[test] 431 Header Too Long =>", end=' ') Utility.console_log('[test] 431 Header Too Long =>', end=' ')
res, status = send_postx_hdr_len(dut, port, max_hdr_len) res, status = send_postx_hdr_len(dut, port, max_hdr_len)
if not res: if not res:
return False return False
res, status = send_postx_hdr_len(dut, port, max_hdr_len + 1) res, status = send_postx_hdr_len(dut, port, max_hdr_len + 1)
if not test_val("Client Error", "431", status): if not test_val('Client Error', '431', status):
return False return False
Utility.console_log("Success") Utility.console_log('Success')
return True return True
def test_upgrade_not_supported(dut, port): def test_upgrade_not_supported(dut, port):
Utility.console_log("[test] Upgrade Not Supported =>", end=' ') Utility.console_log('[test] Upgrade Not Supported =>', end=' ')
s = Session(dut, port) s = Session(dut, port)
# path = "/hello" # path = "/hello"
s.client.sendall(("OPTIONS * HTTP/1.1\r\nHost:" + dut + "\r\nUpgrade: TLS/1.0\r\nConnection: Upgrade\r\n\r\n").encode()) s.client.sendall(('OPTIONS * HTTP/1.1\r\nHost:' + dut + '\r\nUpgrade: TLS/1.0\r\nConnection: Upgrade\r\n\r\n').encode())
s.read_resp_hdrs() s.read_resp_hdrs()
s.read_resp_data() s.read_resp_data()
if not test_val("Client Error", "400", s.status): if not test_val('Client Error', '400', s.status):
s.close() s.close()
return False return False
s.close() s.close()
Utility.console_log("Success") Utility.console_log('Success')
return True return True
@ -997,7 +994,7 @@ if __name__ == '__main__':
_verbose_ = True _verbose_ = True
Utility.console_log("### Basic HTTP Client Tests") Utility.console_log('### Basic HTTP Client Tests')
get_hello(dut, port) get_hello(dut, port)
post_hello(dut, port) post_hello(dut, port)
put_hello(dut, port) put_hello(dut, port)
@ -1009,7 +1006,7 @@ if __name__ == '__main__':
get_false_uri(dut, port) get_false_uri(dut, port)
get_test_headers(dut, port) get_test_headers(dut, port)
Utility.console_log("### Error code tests") Utility.console_log('### Error code tests')
code_500_server_error_test(dut, port) code_500_server_error_test(dut, port)
code_501_method_not_impl(dut, port) code_501_method_not_impl(dut, port)
code_505_version_not_supported(dut, port) code_505_version_not_supported(dut, port)
@ -1024,7 +1021,7 @@ if __name__ == '__main__':
# Not supported yet (Error on chunked request) # Not supported yet (Error on chunked request)
# code_411_length_required(dut, port) # code_411_length_required(dut, port)
Utility.console_log("### Sessions and Context Tests") Utility.console_log('### Sessions and Context Tests')
parallel_sessions_adder(dut, port, max_sessions) parallel_sessions_adder(dut, port, max_sessions)
leftover_data_test(dut, port) leftover_data_test(dut, port)
async_response_test(dut, port) async_response_test(dut, port)

View File

@ -14,49 +14,47 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import division from __future__ import division, print_function, unicode_literals
from __future__ import print_function
from __future__ import unicode_literals
from builtins import str
from builtins import range
import re
import os import os
import random import random
import re
from builtins import range, str
from tiny_test_fw import Utility
import ttfw_idf import ttfw_idf
from idf_http_server_test import adder as client from idf_http_server_test import adder as client
from tiny_test_fw import Utility
# When running on local machine execute the following before running this script # When running on local machine execute the following before running this script
# > make app bootloader # > make app bootloader
# > make print_flash_cmd | tail -n 1 > build/download.config # > make print_flash_cmd | tail -n 1 > build/download.config
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_http_server_persistence(env, extra_data): def test_examples_protocol_http_server_persistence(env, extra_data):
# Acquire DUT # Acquire DUT
dut1 = env.get_dut("http_server", "examples/protocols/http_server/persistent_sockets", dut1 = env.get_dut('http_server', 'examples/protocols/http_server/persistent_sockets',
dut_class=ttfw_idf.ESP32DUT) dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut1.app.binary_path, "persistent_sockets.bin") binary_file = os.path.join(dut1.app.binary_path, 'persistent_sockets.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("http_server_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('http_server_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
Utility.console_log("Starting http_server persistance test app") Utility.console_log('Starting http_server persistance test app')
dut1.start_app() dut1.start_app()
# Parse IP address of STA # Parse IP address of STA
Utility.console_log("Waiting to connect with AP") Utility.console_log('Waiting to connect with AP')
got_ip = dut1.expect(re.compile(r"(?:[\s\S]*)IPv4 address: (\d+.\d+.\d+.\d+)"), timeout=30)[0] got_ip = dut1.expect(re.compile(r'(?:[\s\S]*)IPv4 address: (\d+.\d+.\d+.\d+)'), timeout=30)[0]
got_port = dut1.expect(re.compile(r"(?:[\s\S]*)Starting server on port: '(\d+)'"), timeout=30)[0] got_port = dut1.expect(re.compile(r"(?:[\s\S]*)Starting server on port: '(\d+)'"), timeout=30)[0]
Utility.console_log("Got IP : " + got_ip) Utility.console_log('Got IP : ' + got_ip)
Utility.console_log("Got Port : " + got_port) Utility.console_log('Got Port : ' + got_port)
# Expected Logs # Expected Logs
dut1.expect("Registering URI handlers", timeout=30) dut1.expect('Registering URI handlers', timeout=30)
# Run test script # Run test script
conn = client.start_session(got_ip, got_port) conn = client.start_session(got_ip, got_port)
@ -65,23 +63,23 @@ def test_examples_protocol_http_server_persistence(env, extra_data):
# Test PUT request and initialize session context # Test PUT request and initialize session context
num = random.randint(0,100) num = random.randint(0,100)
client.putreq(conn, "/adder", str(num)) client.putreq(conn, '/adder', str(num))
visitor += 1 visitor += 1
dut1.expect("/adder visitor count = " + str(visitor), timeout=30) dut1.expect('/adder visitor count = ' + str(visitor), timeout=30)
dut1.expect("/adder PUT handler read " + str(num), timeout=30) dut1.expect('/adder PUT handler read ' + str(num), timeout=30)
dut1.expect("PUT allocating new session", timeout=30) dut1.expect('PUT allocating new session', timeout=30)
# Retest PUT request and change session context value # Retest PUT request and change session context value
num = random.randint(0,100) num = random.randint(0,100)
Utility.console_log("Adding: " + str(num)) Utility.console_log('Adding: ' + str(num))
client.putreq(conn, "/adder", str(num)) client.putreq(conn, '/adder', str(num))
visitor += 1 visitor += 1
adder += num adder += num
dut1.expect("/adder visitor count = " + str(visitor), timeout=30) dut1.expect('/adder visitor count = ' + str(visitor), timeout=30)
dut1.expect("/adder PUT handler read " + str(num), timeout=30) dut1.expect('/adder PUT handler read ' + str(num), timeout=30)
try: try:
# Re allocation shouldn't happen # Re allocation shouldn't happen
dut1.expect("PUT allocating new session", timeout=30) dut1.expect('PUT allocating new session', timeout=30)
# Not expected # Not expected
raise RuntimeError raise RuntimeError
except Exception: except Exception:
@ -91,37 +89,37 @@ def test_examples_protocol_http_server_persistence(env, extra_data):
# Test POST request and session persistence # Test POST request and session persistence
random_nums = [random.randint(0,100) for _ in range(100)] random_nums = [random.randint(0,100) for _ in range(100)]
for num in random_nums: for num in random_nums:
Utility.console_log("Adding: " + str(num)) Utility.console_log('Adding: ' + str(num))
client.postreq(conn, "/adder", str(num)) client.postreq(conn, '/adder', str(num))
visitor += 1 visitor += 1
adder += num adder += num
dut1.expect("/adder visitor count = " + str(visitor), timeout=30) dut1.expect('/adder visitor count = ' + str(visitor), timeout=30)
dut1.expect("/adder handler read " + str(num), timeout=30) dut1.expect('/adder handler read ' + str(num), timeout=30)
# Test GET request and session persistence # Test GET request and session persistence
Utility.console_log("Matching final sum: " + str(adder)) Utility.console_log('Matching final sum: ' + str(adder))
if client.getreq(conn, "/adder").decode() != str(adder): if client.getreq(conn, '/adder').decode() != str(adder):
raise RuntimeError raise RuntimeError
visitor += 1 visitor += 1
dut1.expect("/adder visitor count = " + str(visitor), timeout=30) dut1.expect('/adder visitor count = ' + str(visitor), timeout=30)
dut1.expect("/adder GET handler send " + str(adder), timeout=30) dut1.expect('/adder GET handler send ' + str(adder), timeout=30)
Utility.console_log("Ending session") Utility.console_log('Ending session')
# Close connection and check for invocation of context "Free" function # Close connection and check for invocation of context "Free" function
client.end_session(conn) client.end_session(conn)
dut1.expect("/adder Free Context function called", timeout=30) dut1.expect('/adder Free Context function called', timeout=30)
Utility.console_log("Validating user context data") Utility.console_log('Validating user context data')
# Start another session to check user context data # Start another session to check user context data
client.start_session(got_ip, got_port) client.start_session(got_ip, got_port)
num = random.randint(0,100) num = random.randint(0,100)
client.putreq(conn, "/adder", str(num)) client.putreq(conn, '/adder', str(num))
visitor += 1 visitor += 1
dut1.expect("/adder visitor count = " + str(visitor), timeout=30) dut1.expect('/adder visitor count = ' + str(visitor), timeout=30)
dut1.expect("/adder PUT handler read " + str(num), timeout=30) dut1.expect('/adder PUT handler read ' + str(num), timeout=30)
dut1.expect("PUT allocating new session", timeout=30) dut1.expect('PUT allocating new session', timeout=30)
client.end_session(conn) client.end_session(conn)
dut1.expect("/adder Free Context function called", timeout=30) dut1.expect('/adder Free Context function called', timeout=30)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -14,22 +14,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import division from __future__ import division, print_function, unicode_literals
from __future__ import print_function
from __future__ import unicode_literals
from builtins import range
import re
import os
import string
import random
import os
import random
import re
import socket
import string
import threading import threading
import time import time
import socket from builtins import range
from tiny_test_fw import Utility
import ttfw_idf import ttfw_idf
from idf_http_server_test import client from idf_http_server_test import client
from tiny_test_fw import Utility
class http_client_thread(threading.Thread): class http_client_thread(threading.Thread):
@ -64,97 +62,97 @@ class http_client_thread(threading.Thread):
# > make print_flash_cmd | tail -n 1 > build/download.config # > make print_flash_cmd | tail -n 1 > build/download.config
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_http_server_simple(env, extra_data): def test_examples_protocol_http_server_simple(env, extra_data):
# Acquire DUT # Acquire DUT
dut1 = env.get_dut("http_server", "examples/protocols/http_server/simple", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('http_server', 'examples/protocols/http_server/simple', dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut1.app.binary_path, "simple.bin") binary_file = os.path.join(dut1.app.binary_path, 'simple.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("http_server_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('http_server_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
Utility.console_log("Starting http_server simple test app") Utility.console_log('Starting http_server simple test app')
dut1.start_app() dut1.start_app()
# Parse IP address of STA # Parse IP address of STA
Utility.console_log("Waiting to connect with AP") Utility.console_log('Waiting to connect with AP')
got_ip = dut1.expect(re.compile(r"(?:[\s\S]*)IPv4 address: (\d+.\d+.\d+.\d+)"), timeout=30)[0] got_ip = dut1.expect(re.compile(r'(?:[\s\S]*)IPv4 address: (\d+.\d+.\d+.\d+)'), timeout=30)[0]
got_port = dut1.expect(re.compile(r"(?:[\s\S]*)Starting server on port: '(\d+)'"), timeout=30)[0] got_port = dut1.expect(re.compile(r"(?:[\s\S]*)Starting server on port: '(\d+)'"), timeout=30)[0]
Utility.console_log("Got IP : " + got_ip) Utility.console_log('Got IP : ' + got_ip)
Utility.console_log("Got Port : " + got_port) Utility.console_log('Got Port : ' + got_port)
# Expected Logs # Expected Logs
dut1.expect("Registering URI handlers", timeout=30) dut1.expect('Registering URI handlers', timeout=30)
# Run test script # Run test script
# If failed raise appropriate exception # If failed raise appropriate exception
Utility.console_log("Test /hello GET handler") Utility.console_log('Test /hello GET handler')
if not client.test_get_handler(got_ip, got_port): if not client.test_get_handler(got_ip, got_port):
raise RuntimeError raise RuntimeError
# Acquire host IP. Need a way to check it # Acquire host IP. Need a way to check it
dut1.expect(re.compile(r"(?:[\s\S]*)Found header => Host: (\d+.\d+.\d+.\d+)"), timeout=30)[0] dut1.expect(re.compile(r'(?:[\s\S]*)Found header => Host: (\d+.\d+.\d+.\d+)'), timeout=30)[0]
# Match additional headers sent in the request # Match additional headers sent in the request
dut1.expect("Found header => Test-Header-2: Test-Value-2", timeout=30) dut1.expect('Found header => Test-Header-2: Test-Value-2', timeout=30)
dut1.expect("Found header => Test-Header-1: Test-Value-1", timeout=30) dut1.expect('Found header => Test-Header-1: Test-Value-1', timeout=30)
dut1.expect("Found URL query parameter => query1=value1", timeout=30) dut1.expect('Found URL query parameter => query1=value1', timeout=30)
dut1.expect("Found URL query parameter => query3=value3", timeout=30) dut1.expect('Found URL query parameter => query3=value3', timeout=30)
dut1.expect("Found URL query parameter => query2=value2", timeout=30) dut1.expect('Found URL query parameter => query2=value2', timeout=30)
dut1.expect("Request headers lost", timeout=30) dut1.expect('Request headers lost', timeout=30)
Utility.console_log("Test /ctrl PUT handler and realtime handler de/registration") Utility.console_log('Test /ctrl PUT handler and realtime handler de/registration')
if not client.test_put_handler(got_ip, got_port): if not client.test_put_handler(got_ip, got_port):
raise RuntimeError raise RuntimeError
dut1.expect("Unregistering /hello and /echo URIs", timeout=30) dut1.expect('Unregistering /hello and /echo URIs', timeout=30)
dut1.expect("Registering /hello and /echo URIs", timeout=30) dut1.expect('Registering /hello and /echo URIs', timeout=30)
# Generate random data of 10KB # Generate random data of 10KB
random_data = ''.join(string.printable[random.randint(0,len(string.printable)) - 1] for _ in range(10 * 1024)) random_data = ''.join(string.printable[random.randint(0,len(string.printable)) - 1] for _ in range(10 * 1024))
Utility.console_log("Test /echo POST handler with random data") Utility.console_log('Test /echo POST handler with random data')
if not client.test_post_handler(got_ip, got_port, random_data): if not client.test_post_handler(got_ip, got_port, random_data):
raise RuntimeError raise RuntimeError
query = "http://foobar" query = 'http://foobar'
Utility.console_log("Test /hello with custom query : " + query) Utility.console_log('Test /hello with custom query : ' + query)
if not client.test_custom_uri_query(got_ip, got_port, query): if not client.test_custom_uri_query(got_ip, got_port, query):
raise RuntimeError raise RuntimeError
dut1.expect("Found URL query => " + query, timeout=30) dut1.expect('Found URL query => ' + query, timeout=30)
query = "abcd+1234%20xyz" query = 'abcd+1234%20xyz'
Utility.console_log("Test /hello with custom query : " + query) Utility.console_log('Test /hello with custom query : ' + query)
if not client.test_custom_uri_query(got_ip, got_port, query): if not client.test_custom_uri_query(got_ip, got_port, query):
raise RuntimeError raise RuntimeError
dut1.expect("Found URL query => " + query, timeout=30) dut1.expect('Found URL query => ' + query, timeout=30)
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_http_server_lru_purge_enable(env, extra_data): def test_examples_protocol_http_server_lru_purge_enable(env, extra_data):
# Acquire DUT # Acquire DUT
dut1 = env.get_dut("http_server", "examples/protocols/http_server/simple", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('http_server', 'examples/protocols/http_server/simple', dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut1.app.binary_path, "simple.bin") binary_file = os.path.join(dut1.app.binary_path, 'simple.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("http_server_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('http_server_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
Utility.console_log("Starting http_server simple test app") Utility.console_log('Starting http_server simple test app')
dut1.start_app() dut1.start_app()
# Parse IP address of STA # Parse IP address of STA
Utility.console_log("Waiting to connect with AP") Utility.console_log('Waiting to connect with AP')
got_ip = dut1.expect(re.compile(r"(?:[\s\S]*)IPv4 address: (\d+.\d+.\d+.\d+)"), timeout=30)[0] got_ip = dut1.expect(re.compile(r'(?:[\s\S]*)IPv4 address: (\d+.\d+.\d+.\d+)'), timeout=30)[0]
got_port = dut1.expect(re.compile(r"(?:[\s\S]*)Starting server on port: '(\d+)'"), timeout=30)[0] got_port = dut1.expect(re.compile(r"(?:[\s\S]*)Starting server on port: '(\d+)'"), timeout=30)[0]
Utility.console_log("Got IP : " + got_ip) Utility.console_log('Got IP : ' + got_ip)
Utility.console_log("Got Port : " + got_port) Utility.console_log('Got Port : ' + got_port)
# Expected Logs # Expected Logs
dut1.expect("Registering URI handlers", timeout=30) dut1.expect('Registering URI handlers', timeout=30)
threads = [] threads = []
# Open 20 sockets, one from each thread # Open 20 sockets, one from each thread
for _ in range(20): for _ in range(20):
@ -163,7 +161,7 @@ def test_examples_protocol_http_server_lru_purge_enable(env, extra_data):
thread.start() thread.start()
threads.append(thread) threads.append(thread)
except OSError as err: except OSError as err:
Utility.console_log("Error: unable to start thread, " + err) Utility.console_log('Error: unable to start thread, ' + err)
for t in threads: for t in threads:
t.join() t.join()

View File

@ -14,15 +14,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import division from __future__ import division, print_function, unicode_literals
from __future__ import print_function
from __future__ import unicode_literals
import re
from tiny_test_fw import Utility
import ttfw_idf
import os
import websocket
import os
import re
import ttfw_idf
import websocket
from tiny_test_fw import Utility
OPCODE_TEXT = 0x1 OPCODE_TEXT = 0x1
OPCODE_BIN = 0x2 OPCODE_BIN = 0x2
@ -37,7 +36,7 @@ class WsClient:
self.ws = websocket.WebSocket() self.ws = websocket.WebSocket()
def __enter__(self): def __enter__(self):
self.ws.connect("ws://{}:{}/ws".format(self.ip, self.port)) self.ws.connect('ws://{}:{}/ws'.format(self.ip, self.port))
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
@ -46,7 +45,7 @@ class WsClient:
def read(self): def read(self):
return self.ws.recv_data(control_frame=True) return self.ws.recv_data(control_frame=True)
def write(self, data="", opcode=OPCODE_TEXT): def write(self, data='', opcode=OPCODE_TEXT):
if opcode == OPCODE_BIN: if opcode == OPCODE_BIN:
return self.ws.send_binary(data.encode()) return self.ws.send_binary(data.encode())
if opcode == OPCODE_PING: if opcode == OPCODE_PING:
@ -54,27 +53,27 @@ class WsClient:
return self.ws.send(data) return self.ws.send(data)
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_http_ws_echo_server(env, extra_data): def test_examples_protocol_http_ws_echo_server(env, extra_data):
# Acquire DUT # Acquire DUT
dut1 = env.get_dut("http_server", "examples/protocols/http_server/ws_echo_server", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('http_server', 'examples/protocols/http_server/ws_echo_server', dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut1.app.binary_path, "ws_echo_server.bin") binary_file = os.path.join(dut1.app.binary_path, 'ws_echo_server.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("http_ws_server_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('http_ws_server_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
Utility.console_log("Starting ws-echo-server test app based on http_server") Utility.console_log('Starting ws-echo-server test app based on http_server')
dut1.start_app() dut1.start_app()
# Parse IP address of STA # Parse IP address of STA
Utility.console_log("Waiting to connect with AP") Utility.console_log('Waiting to connect with AP')
got_ip = dut1.expect(re.compile(r"IPv4 address: (\d+.\d+.\d+.\d+)"), timeout=60)[0] got_ip = dut1.expect(re.compile(r'IPv4 address: (\d+.\d+.\d+.\d+)'), timeout=60)[0]
got_port = dut1.expect(re.compile(r"Starting server on port: '(\d+)'"), timeout=60)[0] got_port = dut1.expect(re.compile(r"Starting server on port: '(\d+)'"), timeout=60)[0]
Utility.console_log("Got IP : " + got_ip) Utility.console_log('Got IP : ' + got_ip)
Utility.console_log("Got Port : " + got_port) Utility.console_log('Got Port : ' + got_port)
# Start ws server test # Start ws server test
with WsClient(got_ip, int(got_port)) as ws: with WsClient(got_ip, int(got_port)) as ws:
@ -82,23 +81,23 @@ def test_examples_protocol_http_ws_echo_server(env, extra_data):
for expected_opcode in [OPCODE_TEXT, OPCODE_BIN, OPCODE_PING]: for expected_opcode in [OPCODE_TEXT, OPCODE_BIN, OPCODE_PING]:
ws.write(data=DATA, opcode=expected_opcode) ws.write(data=DATA, opcode=expected_opcode)
opcode, data = ws.read() opcode, data = ws.read()
Utility.console_log("Testing opcode {}: Received opcode:{}, data:{}".format(expected_opcode, opcode, data)) Utility.console_log('Testing opcode {}: Received opcode:{}, data:{}'.format(expected_opcode, opcode, data))
data = data.decode() data = data.decode()
if expected_opcode == OPCODE_PING: if expected_opcode == OPCODE_PING:
dut1.expect("Got a WS PING frame, Replying PONG") dut1.expect('Got a WS PING frame, Replying PONG')
if opcode != OPCODE_PONG or data != DATA: if opcode != OPCODE_PONG or data != DATA:
raise RuntimeError("Failed to receive correct opcode:{} or data:{}".format(opcode, data)) raise RuntimeError('Failed to receive correct opcode:{} or data:{}'.format(opcode, data))
continue continue
dut_data = dut1.expect(re.compile(r"Got packet with message: ([A-Za-z0-9_]*)"))[0] dut_data = dut1.expect(re.compile(r'Got packet with message: ([A-Za-z0-9_]*)'))[0]
dut_opcode = int(dut1.expect(re.compile(r"Packet type: ([0-9]*)"))[0]) dut_opcode = int(dut1.expect(re.compile(r'Packet type: ([0-9]*)'))[0])
if opcode != expected_opcode or data != DATA or opcode != dut_opcode or data != dut_data: if opcode != expected_opcode or data != DATA or opcode != dut_opcode or data != dut_data:
raise RuntimeError("Failed to receive correct opcode:{} or data:{}".format(opcode, data)) raise RuntimeError('Failed to receive correct opcode:{} or data:{}'.format(opcode, data))
ws.write(data="Trigger async", opcode=OPCODE_TEXT) ws.write(data='Trigger async', opcode=OPCODE_TEXT)
opcode, data = ws.read() opcode, data = ws.read()
Utility.console_log("Testing async send: Received opcode:{}, data:{}".format(opcode, data)) Utility.console_log('Testing async send: Received opcode:{}, data:{}'.format(opcode, data))
data = data.decode() data = data.decode()
if opcode != OPCODE_TEXT or data != "Async data": if opcode != OPCODE_TEXT or data != 'Async data':
raise RuntimeError("Failed to receive correct opcode:{} or data:{}".format(opcode, data)) raise RuntimeError('Failed to receive correct opcode:{} or data:{}'.format(opcode, data))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,10 +1,11 @@
import os import os
import re import re
import ttfw_idf import ttfw_idf
from tiny_test_fw import Utility from tiny_test_fw import Utility
@ttfw_idf.idf_example_test(env_tag="Example_EthKitV1") @ttfw_idf.idf_example_test(env_tag='Example_EthKitV1')
def test_examples_protocol_https_request(env, extra_data): def test_examples_protocol_https_request(env, extra_data):
""" """
steps: | steps: |
@ -13,24 +14,24 @@ def test_examples_protocol_https_request(env, extra_data):
certificate verification options certificate verification options
3. send http request 3. send http request
""" """
dut1 = env.get_dut("https_request", "examples/protocols/https_request", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('https_request', 'examples/protocols/https_request', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "https_request.bin") binary_file = os.path.join(dut1.app.binary_path, 'https_request.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("https_request_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('https_request_bin_size', '{}KB'.format(bin_size // 1024))
# start tes # start tes
Utility.console_log("Starting https_request simple test app") Utility.console_log('Starting https_request simple test app')
dut1.start_app() dut1.start_app()
# Check for connection using crt bundle # Check for connection using crt bundle
Utility.console_log("Testing for \"https_request using crt bundle\"") Utility.console_log("Testing for \"https_request using crt bundle\"")
try: try:
dut1.expect(re.compile("https_request using crt bundle"), timeout=30) dut1.expect(re.compile('https_request using crt bundle'), timeout=30)
dut1.expect_all("Certificate validated", dut1.expect_all('Certificate validated',
"Connection established...", 'Connection established...',
"Reading HTTP response...", 'Reading HTTP response...',
"HTTP/1.1 200 OK", 'HTTP/1.1 200 OK',
re.compile("connection closed")) re.compile('connection closed'))
except Exception: except Exception:
Utility.console_log("Failed the test for \"https_request using crt bundle\"") Utility.console_log("Failed the test for \"https_request using crt bundle\"")
raise raise
@ -39,11 +40,11 @@ def test_examples_protocol_https_request(env, extra_data):
# Check for connection using cacert_buf # Check for connection using cacert_buf
Utility.console_log("Testing for \"https_request using cacert_buf\"") Utility.console_log("Testing for \"https_request using cacert_buf\"")
try: try:
dut1.expect(re.compile("https_request using cacert_buf"), timeout=20) dut1.expect(re.compile('https_request using cacert_buf'), timeout=20)
dut1.expect_all("Connection established...", dut1.expect_all('Connection established...',
"Reading HTTP response...", 'Reading HTTP response...',
"HTTP/1.1 200 OK", 'HTTP/1.1 200 OK',
re.compile("connection closed")) re.compile('connection closed'))
except Exception: except Exception:
Utility.console_log("Passed the test for \"https_request using cacert_buf\"") Utility.console_log("Passed the test for \"https_request using cacert_buf\"")
raise raise
@ -52,32 +53,32 @@ def test_examples_protocol_https_request(env, extra_data):
# Check for connection using global ca_store # Check for connection using global ca_store
Utility.console_log("Testing for \"https_request using global ca_store\"") Utility.console_log("Testing for \"https_request using global ca_store\"")
try: try:
dut1.expect(re.compile("https_request using global ca_store"), timeout=20) dut1.expect(re.compile('https_request using global ca_store'), timeout=20)
dut1.expect_all("Connection established...", dut1.expect_all('Connection established...',
"Reading HTTP response...", 'Reading HTTP response...',
"HTTP/1.1 200 OK", 'HTTP/1.1 200 OK',
re.compile("connection closed")) re.compile('connection closed'))
except Exception: except Exception:
Utility.console_log("Failed the test for \"https_request using global ca_store\"") Utility.console_log("Failed the test for \"https_request using global ca_store\"")
raise raise
Utility.console_log("Passed the test for \"https_request using global ca_store\"") Utility.console_log("Passed the test for \"https_request using global ca_store\"")
# Check for connection using crt bundle with mbedtls dynamic resource enabled # Check for connection using crt bundle with mbedtls dynamic resource enabled
dut1 = env.get_dut("https_request", "examples/protocols/https_request", dut_class=ttfw_idf.ESP32DUT, app_config_name='ssldyn') dut1 = env.get_dut('https_request', 'examples/protocols/https_request', dut_class=ttfw_idf.ESP32DUT, app_config_name='ssldyn')
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "https_request.bin") binary_file = os.path.join(dut1.app.binary_path, 'https_request.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("https_request_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('https_request_bin_size', '{}KB'.format(bin_size // 1024))
# start test # start test
dut1.start_app() dut1.start_app()
# only check if one connection is established # only check if one connection is established
Utility.console_log("Testing for \"https_request using crt bundle\" with mbedtls dynamic resource enabled") Utility.console_log("Testing for \"https_request using crt bundle\" with mbedtls dynamic resource enabled")
try: try:
dut1.expect(re.compile("https_request using crt bundle"), timeout=30) dut1.expect(re.compile('https_request using crt bundle'), timeout=30)
dut1.expect_all("Connection established...", dut1.expect_all('Connection established...',
"Reading HTTP response...", 'Reading HTTP response...',
"HTTP/1.1 200 OK", 'HTTP/1.1 200 OK',
re.compile("connection closed")) re.compile('connection closed'))
except Exception: except Exception:
Utility.console_log("Failed the test for \"https_request using crt bundle\" when mbedtls dynamic resource was enabled") Utility.console_log("Failed the test for \"https_request using crt bundle\" when mbedtls dynamic resource was enabled")
raise raise

View File

@ -1,9 +1,10 @@
import os import os
import re import re
import ttfw_idf import ttfw_idf
@ttfw_idf.idf_example_test(env_tag="Example_WIFI", ignore=True) @ttfw_idf.idf_example_test(env_tag='Example_WIFI', ignore=True)
def test_examples_protocol_https_x509_bundle(env, extra_data): def test_examples_protocol_https_x509_bundle(env, extra_data):
""" """
steps: | steps: |
@ -11,28 +12,28 @@ def test_examples_protocol_https_x509_bundle(env, extra_data):
2. connect to multiple URLs 2. connect to multiple URLs
3. send http request 3. send http request
""" """
dut1 = env.get_dut("https_x509_bundle", "examples/protocols/https_x509_bundle") dut1 = env.get_dut('https_x509_bundle', 'examples/protocols/https_x509_bundle')
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "https_x509_bundle.bin") binary_file = os.path.join(dut1.app.binary_path, 'https_x509_bundle.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("https_x509_bundle_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('https_x509_bundle_bin_size', '{}KB'.format(bin_size // 1024))
# start test # start test
dut1.start_app() dut1.start_app()
num_URLS = dut1.expect(re.compile(r"Connecting to (\d+) URLs"), timeout=30) num_URLS = dut1.expect(re.compile(r'Connecting to (\d+) URLs'), timeout=30)
dut1.expect(re.compile(r"Connection established to ([\s\S]*)"), timeout=30) dut1.expect(re.compile(r'Connection established to ([\s\S]*)'), timeout=30)
dut1.expect("Completed {} connections".format(num_URLS[0]), timeout=60) dut1.expect('Completed {} connections'.format(num_URLS[0]), timeout=60)
# test mbedtls dynamic resource # test mbedtls dynamic resource
dut1 = env.get_dut("https_x509_bundle", "examples/protocols/https_x509_bundle", app_config_name='ssldyn') dut1 = env.get_dut('https_x509_bundle', 'examples/protocols/https_x509_bundle', app_config_name='ssldyn')
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "https_x509_bundle.bin") binary_file = os.path.join(dut1.app.binary_path, 'https_x509_bundle.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("https_x509_bundle_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('https_x509_bundle_bin_size', '{}KB'.format(bin_size // 1024))
# start test # start test
dut1.start_app() dut1.start_app()
num_URLS = dut1.expect(re.compile(r"Connecting to (\d+) URLs"), timeout=30) num_URLS = dut1.expect(re.compile(r'Connecting to (\d+) URLs'), timeout=30)
dut1.expect(re.compile(r"Connection established to ([\s\S]*)"), timeout=30) dut1.expect(re.compile(r'Connection established to ([\s\S]*)'), timeout=30)
dut1.expect("Completed {} connections".format(num_URLS[0]), timeout=60) dut1.expect('Completed {} connections'.format(num_URLS[0]), timeout=60)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,7 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import re
import ttfw_idf
import os import os
import re
import ttfw_idf
@ttfw_idf.idf_example_test(env_tag='Example_WIFI') @ttfw_idf.idf_example_test(env_tag='Example_WIFI')

View File

@ -1,15 +1,15 @@
import re
import os import os
import re
import socket import socket
import time
import struct import struct
import subprocess
import time
from threading import Event, Thread
import dpkt import dpkt
import dpkt.dns import dpkt.dns
from threading import Thread, Event
import subprocess
from tiny_test_fw import DUT
import ttfw_idf import ttfw_idf
from tiny_test_fw import DUT
stop_mdns_server = Event() stop_mdns_server = Event()
esp_answered = Event() esp_answered = Event()
@ -18,7 +18,7 @@ esp_answered = Event()
def get_dns_query_for_esp(esp_host): def get_dns_query_for_esp(esp_host):
dns = dpkt.dns.DNS(b'\x00\x00\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01') dns = dpkt.dns.DNS(b'\x00\x00\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01')
dns.qd[0].name = esp_host + u'.local' dns.qd[0].name = esp_host + u'.local'
print("Created query for esp host: {} ".format(dns.__repr__())) print('Created query for esp host: {} '.format(dns.__repr__()))
return dns.pack() return dns.pack()
@ -32,26 +32,26 @@ def get_dns_answer_to_mdns(tester_host):
arr.name = tester_host arr.name = tester_host
arr.ip = socket.inet_aton('127.0.0.1') arr.ip = socket.inet_aton('127.0.0.1')
dns. an.append(arr) dns. an.append(arr)
print("Created answer to mdns query: {} ".format(dns.__repr__())) print('Created answer to mdns query: {} '.format(dns.__repr__()))
return dns.pack() return dns.pack()
def get_dns_answer_to_mdns_lwip(tester_host, id): def get_dns_answer_to_mdns_lwip(tester_host, id):
dns = dpkt.dns.DNS(b"\x5e\x39\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x0a\x64\x61\x76\x69\x64" dns = dpkt.dns.DNS(b'\x5e\x39\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x0a\x64\x61\x76\x69\x64'
b"\x2d\x63\x6f\x6d\x70\x05\x6c\x6f\x63\x61\x6c\x00\x00\x01\x00\x01\xc0\x0c" b'\x2d\x63\x6f\x6d\x70\x05\x6c\x6f\x63\x61\x6c\x00\x00\x01\x00\x01\xc0\x0c'
b"\x00\x01\x00\x01\x00\x00\x00\x0a\x00\x04\xc0\xa8\x0a\x6c") b'\x00\x01\x00\x01\x00\x00\x00\x0a\x00\x04\xc0\xa8\x0a\x6c')
dns.qd[0].name = tester_host dns.qd[0].name = tester_host
dns.an[0].name = tester_host dns.an[0].name = tester_host
dns.an[0].ip = socket.inet_aton('127.0.0.1') dns.an[0].ip = socket.inet_aton('127.0.0.1')
dns.an[0].rdata = socket.inet_aton('127.0.0.1') dns.an[0].rdata = socket.inet_aton('127.0.0.1')
dns.id = id dns.id = id
print("Created answer to mdns (lwip) query: {} ".format(dns.__repr__())) print('Created answer to mdns (lwip) query: {} '.format(dns.__repr__()))
return dns.pack() return dns.pack()
def mdns_server(esp_host): def mdns_server(esp_host):
global esp_answered global esp_answered
UDP_IP = "0.0.0.0" UDP_IP = '0.0.0.0'
UDP_PORT = 5353 UDP_PORT = 5353
MCAST_GRP = '224.0.0.251' MCAST_GRP = '224.0.0.251'
TESTER_NAME = u'tinytester.local' TESTER_NAME = u'tinytester.local'
@ -60,7 +60,7 @@ def mdns_server(esp_host):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.bind((UDP_IP,UDP_PORT)) sock.bind((UDP_IP,UDP_PORT))
mreq = struct.pack("4sl", socket.inet_aton(MCAST_GRP), socket.INADDR_ANY) mreq = struct.pack('4sl', socket.inet_aton(MCAST_GRP), socket.INADDR_ANY)
sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq) sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
sock.settimeout(30) sock.settimeout(30)
while not stop_mdns_server.is_set(): while not stop_mdns_server.is_set():
@ -72,14 +72,14 @@ def mdns_server(esp_host):
dns = dpkt.dns.DNS(data) dns = dpkt.dns.DNS(data)
if len(dns.qd) > 0 and dns.qd[0].type == dpkt.dns.DNS_A: if len(dns.qd) > 0 and dns.qd[0].type == dpkt.dns.DNS_A:
if dns.qd[0].name == TESTER_NAME: if dns.qd[0].name == TESTER_NAME:
print("Received query: {} ".format(dns.__repr__())) print('Received query: {} '.format(dns.__repr__()))
sock.sendto(get_dns_answer_to_mdns(TESTER_NAME), (MCAST_GRP,UDP_PORT)) sock.sendto(get_dns_answer_to_mdns(TESTER_NAME), (MCAST_GRP,UDP_PORT))
elif dns.qd[0].name == TESTER_NAME_LWIP: elif dns.qd[0].name == TESTER_NAME_LWIP:
print("Received query: {} ".format(dns.__repr__())) print('Received query: {} '.format(dns.__repr__()))
sock.sendto(get_dns_answer_to_mdns_lwip(TESTER_NAME_LWIP, dns.id), addr) sock.sendto(get_dns_answer_to_mdns_lwip(TESTER_NAME_LWIP, dns.id), addr)
if len(dns.an) > 0 and dns.an[0].type == dpkt.dns.DNS_A: if len(dns.an) > 0 and dns.an[0].type == dpkt.dns.DNS_A:
if dns.an[0].name == esp_host + u'.local': if dns.an[0].name == esp_host + u'.local':
print("Received answer to esp32-mdns query: {}".format(dns.__repr__())) print('Received answer to esp32-mdns query: {}'.format(dns.__repr__()))
esp_answered.set() esp_answered.set()
except socket.timeout: except socket.timeout:
break break
@ -87,7 +87,7 @@ def mdns_server(esp_host):
continue continue
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_mdns(env, extra_data): def test_examples_protocol_mdns(env, extra_data):
global stop_mdns_server global stop_mdns_server
""" """
@ -97,21 +97,21 @@ def test_examples_protocol_mdns(env, extra_data):
3. check the mdns name is accessible 3. check the mdns name is accessible
4. check DUT output if mdns advertized host is resolved 4. check DUT output if mdns advertized host is resolved
""" """
dut1 = env.get_dut("mdns-test", "examples/protocols/mdns", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('mdns-test', 'examples/protocols/mdns', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "mdns-test.bin") binary_file = os.path.join(dut1.app.binary_path, 'mdns-test.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("mdns-test_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('mdns-test_bin_size', '{}KB'.format(bin_size // 1024))
# 1. start mdns application # 1. start mdns application
dut1.start_app() dut1.start_app()
# 2. get the dut host name (and IP address) # 2. get the dut host name (and IP address)
specific_host = dut1.expect(re.compile(r"mdns hostname set to: \[([^\]]+)\]"), timeout=30) specific_host = dut1.expect(re.compile(r'mdns hostname set to: \[([^\]]+)\]'), timeout=30)
specific_host = str(specific_host[0]) specific_host = str(specific_host[0])
thread1 = Thread(target=mdns_server, args=(specific_host,)) thread1 = Thread(target=mdns_server, args=(specific_host,))
thread1.start() thread1.start()
try: try:
ip_address = dut1.expect(re.compile(r" sta ip: ([^,]+),"), timeout=30)[0] ip_address = dut1.expect(re.compile(r' sta ip: ([^,]+),'), timeout=30)[0]
print("Connected to AP with IP: {}".format(ip_address)) print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout: except DUT.ExpectTimeout:
stop_mdns_server.set() stop_mdns_server.set()
thread1.join() thread1.join()
@ -121,15 +121,15 @@ def test_examples_protocol_mdns(env, extra_data):
if not esp_answered.wait(timeout=30): if not esp_answered.wait(timeout=30):
raise ValueError('Test has failed: did not receive mdns answer within timeout') raise ValueError('Test has failed: did not receive mdns answer within timeout')
# 4. check DUT output if mdns advertized host is resolved # 4. check DUT output if mdns advertized host is resolved
dut1.expect(re.compile(r"mdns-test: Query A: tinytester.local resolved to: 127.0.0.1"), timeout=30) dut1.expect(re.compile(r'mdns-test: Query A: tinytester.local resolved to: 127.0.0.1'), timeout=30)
dut1.expect(re.compile(r"mdns-test: gethostbyname: tinytester-lwip.local resolved to: 127.0.0.1"), timeout=30) dut1.expect(re.compile(r'mdns-test: gethostbyname: tinytester-lwip.local resolved to: 127.0.0.1'), timeout=30)
dut1.expect(re.compile(r"mdns-test: getaddrinfo: tinytester-lwip.local resolved to: 127.0.0.1"), timeout=30) dut1.expect(re.compile(r'mdns-test: getaddrinfo: tinytester-lwip.local resolved to: 127.0.0.1'), timeout=30)
# 5. check the DUT answers to `dig` command # 5. check the DUT answers to `dig` command
dig_output = subprocess.check_output(['dig', '+short', '-p', '5353', '@224.0.0.251', dig_output = subprocess.check_output(['dig', '+short', '-p', '5353', '@224.0.0.251',
'{}.local'.format(specific_host)]) '{}.local'.format(specific_host)])
print('Resolving {} using "dig" succeeded with:\n{}'.format(specific_host, dig_output)) print('Resolving {} using "dig" succeeded with:\n{}'.format(specific_host, dig_output))
if not ip_address.encode('utf-8') in dig_output: if not ip_address.encode('utf-8') in dig_output:
raise ValueError("Test has failed: Incorrectly resolved DUT hostname using dig" raise ValueError('Test has failed: Incorrectly resolved DUT hostname using dig'
"Output should've contained DUT's IP address:{}".format(ip_address)) "Output should've contained DUT's IP address:{}".format(ip_address))
finally: finally:
stop_mdns_server.set() stop_mdns_server.set()

View File

@ -1,15 +1,15 @@
# Need Python 3 string formatting functions # Need Python 3 string formatting functions
from __future__ import print_function from __future__ import print_function
import logging
import os import os
import re import re
import logging
from threading import Thread from threading import Thread
import ttfw_idf import ttfw_idf
LOG_LEVEL = logging.DEBUG LOG_LEVEL = logging.DEBUG
LOGGER_NAME = "modbus_test" LOGGER_NAME = 'modbus_test'
# Allowed parameter reads # Allowed parameter reads
TEST_READ_MIN_COUNT = 10 # Minimum number of correct readings TEST_READ_MIN_COUNT = 10 # Minimum number of correct readings
@ -27,38 +27,38 @@ TEST_SLAVE_ASCII = 'slave_ascii'
# Define tuple of strings to expect for each DUT. # Define tuple of strings to expect for each DUT.
# #
master_expect = ("MASTER_TEST: Modbus master stack initialized...", "MASTER_TEST: Start modbus test...", "MASTER_TEST: Destroy master...") master_expect = ('MASTER_TEST: Modbus master stack initialized...', 'MASTER_TEST: Start modbus test...', 'MASTER_TEST: Destroy master...')
slave_expect = ("SLAVE_TEST: Modbus slave stack initialized.", "SLAVE_TEST: Start modbus test...", "SLAVE_TEST: Modbus controller destroyed.") slave_expect = ('SLAVE_TEST: Modbus slave stack initialized.', 'SLAVE_TEST: Start modbus test...', 'SLAVE_TEST: Modbus controller destroyed.')
# The dictionary for expected values in listing # The dictionary for expected values in listing
expect_dict_master_ok = {"START": (), expect_dict_master_ok = {'START': (),
"READ_PAR_OK": (), 'READ_PAR_OK': (),
"ALARM_MSG": (u'7',)} 'ALARM_MSG': (u'7',)}
expect_dict_master_err = {"READ_PAR_ERR": (u'263', u'ESP_ERR_TIMEOUT'), expect_dict_master_err = {'READ_PAR_ERR': (u'263', u'ESP_ERR_TIMEOUT'),
"READ_STK_ERR": (u'107', u'ESP_ERR_TIMEOUT')} 'READ_STK_ERR': (u'107', u'ESP_ERR_TIMEOUT')}
# The dictionary for regular expression patterns to check in listing # The dictionary for regular expression patterns to check in listing
pattern_dict_master_ok = {"START": (r'.*I \([0-9]+\) MASTER_TEST: Start modbus test...'), pattern_dict_master_ok = {'START': (r'.*I \([0-9]+\) MASTER_TEST: Start modbus test...'),
"READ_PAR_OK": (r'.*I\s\([0-9]+\) MASTER_TEST: Characteristic #[0-9]+ [a-zA-Z0-9_]+' 'READ_PAR_OK': (r'.*I\s\([0-9]+\) MASTER_TEST: Characteristic #[0-9]+ [a-zA-Z0-9_]+'
r'\s\([a-zA-Z\%\/]+\) value = [a-zA-Z0-9\.\s]*\(0x[a-zA-Z0-9]+\) read successful.'), r'\s\([a-zA-Z\%\/]+\) value = [a-zA-Z0-9\.\s]*\(0x[a-zA-Z0-9]+\) read successful.'),
"ALARM_MSG": (r'.*I \([0-9]*\) MASTER_TEST: Alarm triggered by cid #([0-9]+).')} 'ALARM_MSG': (r'.*I \([0-9]*\) MASTER_TEST: Alarm triggered by cid #([0-9]+).')}
pattern_dict_master_err = {"READ_PAR_ERR_TOUT": (r'.*E \([0-9]+\) MASTER_TEST: Characteristic #[0-9]+' pattern_dict_master_err = {'READ_PAR_ERR_TOUT': (r'.*E \([0-9]+\) MASTER_TEST: Characteristic #[0-9]+'
r'\s\([a-zA-Z0-9_]+\) read fail, err = [0-9]+ \([_A-Z]+\).'), r'\s\([a-zA-Z0-9_]+\) read fail, err = [0-9]+ \([_A-Z]+\).'),
"READ_STK_ERR_TOUT": (r'.*E \([0-9]+\) MB_CONTROLLER_MASTER: [a-zA-Z0-9_]+\([0-9]+\):\s' 'READ_STK_ERR_TOUT': (r'.*E \([0-9]+\) MB_CONTROLLER_MASTER: [a-zA-Z0-9_]+\([0-9]+\):\s'
r'SERIAL master get parameter failure error=\(0x([a-zA-Z0-9]+)\) \(([_A-Z]+)\).')} r'SERIAL master get parameter failure error=\(0x([a-zA-Z0-9]+)\) \(([_A-Z]+)\).')}
# The dictionary for expected values in listing # The dictionary for expected values in listing
expect_dict_slave_ok = {"START": (), expect_dict_slave_ok = {'START': (),
"READ_PAR_OK": (), 'READ_PAR_OK': (),
"DESTROY": ()} 'DESTROY': ()}
# The dictionary for regular expression patterns to check in listing # The dictionary for regular expression patterns to check in listing
pattern_dict_slave_ok = {"START": (r'.*I \([0-9]+\) SLAVE_TEST: Start modbus test...'), pattern_dict_slave_ok = {'START': (r'.*I \([0-9]+\) SLAVE_TEST: Start modbus test...'),
"READ_PAR_OK": (r'.*I\s\([0-9]+\) SLAVE_TEST: [A-Z]+ READ \([a-zA-Z0-9_]+ us\),\s' 'READ_PAR_OK': (r'.*I\s\([0-9]+\) SLAVE_TEST: [A-Z]+ READ \([a-zA-Z0-9_]+ us\),\s'
r'ADDR:[0-9]+, TYPE:[0-9]+, INST_ADDR:0x[a-zA-Z0-9]+, SIZE:[0-9]+'), r'ADDR:[0-9]+, TYPE:[0-9]+, INST_ADDR:0x[a-zA-Z0-9]+, SIZE:[0-9]+'),
"DESTROY": (r'.*I\s\([0-9]+\) SLAVE_TEST: Modbus controller destroyed.')} 'DESTROY': (r'.*I\s\([0-9]+\) SLAVE_TEST: Modbus controller destroyed.')}
logger = logging.getLogger(LOGGER_NAME) logger = logging.getLogger(LOGGER_NAME)
@ -89,8 +89,8 @@ class DutTestThread(Thread):
# Check DUT exceptions # Check DUT exceptions
dut_exceptions = self.dut.get_exceptions() dut_exceptions = self.dut.get_exceptions()
if "Guru Meditation Error:" in dut_exceptions: if 'Guru Meditation Error:' in dut_exceptions:
raise Exception("%s generated an exception: %s\n" % (str(self.dut), dut_exceptions)) raise Exception('%s generated an exception: %s\n' % (str(self.dut), dut_exceptions))
# Mark thread has run to completion without any exceptions # Mark thread has run to completion without any exceptions
self.data = self.dut.stop_capture_raw_data() self.data = self.dut.stop_capture_raw_data()
@ -102,7 +102,7 @@ def test_filter_output(data=None, start_pattern=None, end_pattern=None):
""" """
start_index = str(data).find(start_pattern) start_index = str(data).find(start_pattern)
end_index = str(data).find(end_pattern) end_index = str(data).find(end_pattern)
logger.debug("Listing start index= %d, end=%d" % (start_index, end_index)) logger.debug('Listing start index= %d, end=%d' % (start_index, end_index))
if start_index == -1 or end_index == -1: if start_index == -1 or end_index == -1:
return data return data
return data[start_index:end_index + len(end_pattern)] return data[start_index:end_index + len(end_pattern)]
@ -145,9 +145,9 @@ def test_check_output(data=None, check_dict=None, expect_dict=None):
for line in data_lines: for line in data_lines:
group, index = test_expect_re(line, pattern) group, index = test_expect_re(line, pattern)
if index is not None: if index is not None:
logger.debug("Found key{%s}=%s, line: \n%s" % (key, group, line)) logger.debug('Found key{%s}=%s, line: \n%s' % (key, group, line))
if expect_dict[key] == group: if expect_dict[key] == group:
logger.debug("The result is correct for the key:%s, expected:%s == returned:%s" % (key, str(expect_dict[key]), str(group))) logger.debug('The result is correct for the key:%s, expected:%s == returned:%s' % (key, str(expect_dict[key]), str(group)))
match_count += 1 match_count += 1
return match_count return match_count
@ -158,7 +158,7 @@ def test_check_mode(dut=None, mode_str=None, value=None):
global logger global logger
try: try:
opt = dut.app.get_sdkconfig()[mode_str] opt = dut.app.get_sdkconfig()[mode_str]
logger.info("%s {%s} = %s.\n" % (str(dut), mode_str, opt)) logger.info('%s {%s} = %s.\n' % (str(dut), mode_str, opt))
return value == opt return value == opt
except Exception: except Exception:
logger.info('ENV_TEST_FAILURE: %s: Cannot find option %s in sdkconfig.' % (str(dut), mode_str)) logger.info('ENV_TEST_FAILURE: %s: Cannot find option %s in sdkconfig.' % (str(dut), mode_str))
@ -170,30 +170,30 @@ def test_modbus_communication(env, comm_mode):
global logger global logger
# Get device under test. "dut1 - master", "dut2 - slave" must be properly connected through RS485 interface driver # Get device under test. "dut1 - master", "dut2 - slave" must be properly connected through RS485 interface driver
dut_master = env.get_dut("modbus_master", "examples/protocols/modbus/serial/mb_master", dut_class=ttfw_idf.ESP32DUT) dut_master = env.get_dut('modbus_master', 'examples/protocols/modbus/serial/mb_master', dut_class=ttfw_idf.ESP32DUT)
dut_slave = env.get_dut("modbus_slave", "examples/protocols/modbus/serial/mb_slave", dut_class=ttfw_idf.ESP32DUT) dut_slave = env.get_dut('modbus_slave', 'examples/protocols/modbus/serial/mb_slave', dut_class=ttfw_idf.ESP32DUT)
try: try:
logger.debug("Environment vars: %s\r\n" % os.environ) logger.debug('Environment vars: %s\r\n' % os.environ)
logger.debug("DUT slave sdkconfig: %s\r\n" % dut_slave.app.get_sdkconfig()) logger.debug('DUT slave sdkconfig: %s\r\n' % dut_slave.app.get_sdkconfig())
logger.debug("DUT master sdkconfig: %s\r\n" % dut_master.app.get_sdkconfig()) logger.debug('DUT master sdkconfig: %s\r\n' % dut_master.app.get_sdkconfig())
# Check Kconfig configuration options for each built example # Check Kconfig configuration options for each built example
if test_check_mode(dut_master, "CONFIG_MB_COMM_MODE_ASCII", "y") and test_check_mode(dut_slave, "CONFIG_MB_COMM_MODE_ASCII", "y"): if test_check_mode(dut_master, 'CONFIG_MB_COMM_MODE_ASCII', 'y') and test_check_mode(dut_slave, 'CONFIG_MB_COMM_MODE_ASCII', 'y'):
logger.info("ENV_TEST_INFO: Modbus ASCII test mode selected in the configuration. \n") logger.info('ENV_TEST_INFO: Modbus ASCII test mode selected in the configuration. \n')
slave_name = TEST_SLAVE_ASCII slave_name = TEST_SLAVE_ASCII
master_name = TEST_MASTER_ASCII master_name = TEST_MASTER_ASCII
elif test_check_mode(dut_master, "CONFIG_MB_COMM_MODE_RTU", "y") and test_check_mode(dut_slave, "CONFIG_MB_COMM_MODE_RTU", "y"): elif test_check_mode(dut_master, 'CONFIG_MB_COMM_MODE_RTU', 'y') and test_check_mode(dut_slave, 'CONFIG_MB_COMM_MODE_RTU', 'y'):
logger.info("ENV_TEST_INFO: Modbus RTU test mode selected in the configuration. \n") logger.info('ENV_TEST_INFO: Modbus RTU test mode selected in the configuration. \n')
slave_name = TEST_SLAVE_RTU slave_name = TEST_SLAVE_RTU
master_name = TEST_MASTER_RTU master_name = TEST_MASTER_RTU
else: else:
logger.error("ENV_TEST_FAILURE: Communication mode in master and slave configuration don't match.\n") logger.error("ENV_TEST_FAILURE: Communication mode in master and slave configuration don't match.\n")
raise Exception("ENV_TEST_FAILURE: Communication mode in master and slave configuration don't match.\n") raise Exception("ENV_TEST_FAILURE: Communication mode in master and slave configuration don't match.\n")
# Check if slave address for example application is default one to be able to communicate # Check if slave address for example application is default one to be able to communicate
if not test_check_mode(dut_slave, "CONFIG_MB_SLAVE_ADDR", "1"): if not test_check_mode(dut_slave, 'CONFIG_MB_SLAVE_ADDR', '1'):
logger.error("ENV_TEST_FAILURE: Slave address option is incorrect.\n") logger.error('ENV_TEST_FAILURE: Slave address option is incorrect.\n')
raise Exception("ENV_TEST_FAILURE: Slave address option is incorrect.\n") raise Exception('ENV_TEST_FAILURE: Slave address option is incorrect.\n')
# Flash app onto each DUT # Flash app onto each DUT
dut_master.start_app() dut_master.start_app()
@ -212,15 +212,15 @@ def test_modbus_communication(env, comm_mode):
dut_master_thread.join(timeout=TEST_THREAD_JOIN_TIMEOUT) dut_master_thread.join(timeout=TEST_THREAD_JOIN_TIMEOUT)
if dut_slave_thread.isAlive(): if dut_slave_thread.isAlive():
logger.error("ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n" % logger.error('ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n' %
(dut_slave_thread.tname, TEST_THREAD_JOIN_TIMEOUT)) (dut_slave_thread.tname, TEST_THREAD_JOIN_TIMEOUT))
raise Exception("ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n" % raise Exception('ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n' %
(dut_slave_thread.tname, TEST_THREAD_JOIN_TIMEOUT)) (dut_slave_thread.tname, TEST_THREAD_JOIN_TIMEOUT))
if dut_master_thread.isAlive(): if dut_master_thread.isAlive():
logger.error("ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n" % logger.error('ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n' %
(dut_master_thread.tname, TEST_THREAD_JOIN_TIMEOUT)) (dut_master_thread.tname, TEST_THREAD_JOIN_TIMEOUT))
raise Exception("ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n" % raise Exception('ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n' %
(dut_master_thread.tname, TEST_THREAD_JOIN_TIMEOUT)) (dut_master_thread.tname, TEST_THREAD_JOIN_TIMEOUT))
finally: finally:
dut_master.close() dut_master.close()
@ -228,43 +228,43 @@ def test_modbus_communication(env, comm_mode):
# Check if test threads completed successfully and captured data # Check if test threads completed successfully and captured data
if not dut_slave_thread.result or dut_slave_thread.data is None: if not dut_slave_thread.result or dut_slave_thread.data is None:
logger.error("The thread %s was not run successfully." % dut_slave_thread.tname) logger.error('The thread %s was not run successfully.' % dut_slave_thread.tname)
raise Exception("The thread %s was not run successfully." % dut_slave_thread.tname) raise Exception('The thread %s was not run successfully.' % dut_slave_thread.tname)
if not dut_master_thread.result or dut_master_thread.data is None: if not dut_master_thread.result or dut_master_thread.data is None:
logger.error("The thread %s was not run successfully." % dut_slave_thread.tname) logger.error('The thread %s was not run successfully.' % dut_slave_thread.tname)
raise Exception("The thread %s was not run successfully." % dut_master_thread.tname) raise Exception('The thread %s was not run successfully.' % dut_master_thread.tname)
# Filter output to get test messages # Filter output to get test messages
master_output = test_filter_output(dut_master_thread.data, master_expect[0], master_expect[len(master_expect) - 1]) master_output = test_filter_output(dut_master_thread.data, master_expect[0], master_expect[len(master_expect) - 1])
if master_output is not None: if master_output is not None:
logger.info("The data for master thread is captured.") logger.info('The data for master thread is captured.')
logger.debug(master_output) logger.debug(master_output)
slave_output = test_filter_output(dut_slave_thread.data, slave_expect[0], slave_expect[len(slave_expect) - 1]) slave_output = test_filter_output(dut_slave_thread.data, slave_expect[0], slave_expect[len(slave_expect) - 1])
if slave_output is not None: if slave_output is not None:
logger.info("The data for slave thread is captured.") logger.info('The data for slave thread is captured.')
logger.debug(slave_output) logger.debug(slave_output)
# Check if parameters are read correctly by master # Check if parameters are read correctly by master
match_count = test_check_output(master_output, pattern_dict_master_ok, expect_dict_master_ok) match_count = test_check_output(master_output, pattern_dict_master_ok, expect_dict_master_ok)
if match_count < TEST_READ_MIN_COUNT: if match_count < TEST_READ_MIN_COUNT:
logger.error("There are errors reading parameters from %s, %d" % (dut_master_thread.tname, match_count)) logger.error('There are errors reading parameters from %s, %d' % (dut_master_thread.tname, match_count))
raise Exception("There are errors reading parameters from %s, %d" % (dut_master_thread.tname, match_count)) raise Exception('There are errors reading parameters from %s, %d' % (dut_master_thread.tname, match_count))
logger.info("OK pattern test for %s, match_count=%d." % (dut_master_thread.tname, match_count)) logger.info('OK pattern test for %s, match_count=%d.' % (dut_master_thread.tname, match_count))
# If the test completed successfully (alarm triggered) but there are some errors during reading of parameters # If the test completed successfully (alarm triggered) but there are some errors during reading of parameters
match_count = test_check_output(master_output, pattern_dict_master_err, expect_dict_master_err) match_count = test_check_output(master_output, pattern_dict_master_err, expect_dict_master_err)
if match_count > TEST_READ_MAX_ERR_COUNT: if match_count > TEST_READ_MAX_ERR_COUNT:
logger.error("There are errors reading parameters from %s, %d" % (dut_master_thread.tname, match_count)) logger.error('There are errors reading parameters from %s, %d' % (dut_master_thread.tname, match_count))
raise Exception("There are errors reading parameters from %s, %d" % (dut_master_thread.tname, match_count)) raise Exception('There are errors reading parameters from %s, %d' % (dut_master_thread.tname, match_count))
logger.info("ERROR pattern test for %s, match_count=%d." % (dut_master_thread.tname, match_count)) logger.info('ERROR pattern test for %s, match_count=%d.' % (dut_master_thread.tname, match_count))
match_count = test_check_output(slave_output, pattern_dict_slave_ok, expect_dict_slave_ok) match_count = test_check_output(slave_output, pattern_dict_slave_ok, expect_dict_slave_ok)
if match_count < TEST_READ_MIN_COUNT: if match_count < TEST_READ_MIN_COUNT:
logger.error("There are errors reading parameters from %s, %d" % (dut_slave_thread.tname, match_count)) logger.error('There are errors reading parameters from %s, %d' % (dut_slave_thread.tname, match_count))
raise Exception("There are errors reading parameters from %s, %d" % (dut_slave_thread.tname, match_count)) raise Exception('There are errors reading parameters from %s, %d' % (dut_slave_thread.tname, match_count))
logger.info("OK pattern test for %s, match_count=%d." % (dut_slave_thread.tname, match_count)) logger.info('OK pattern test for %s, match_count=%d.' % (dut_slave_thread.tname, match_count))
if __name__ == '__main__': if __name__ == '__main__':
@ -282,7 +282,7 @@ if __name__ == '__main__':
fh.setFormatter(formatter) fh.setFormatter(formatter)
logger.addHandler(fh) logger.addHandler(fh)
logger.addHandler(ch) logger.addHandler(ch)
logger.info("Start script %s." % os.path.basename(__file__)) logger.info('Start script %s.' % os.path.basename(__file__))
print("Logging file name: %s" % logger.handlers[0].baseFilename) print('Logging file name: %s' % logger.handlers[0].baseFilename)
test_modbus_communication() test_modbus_communication()
logging.shutdown() logging.shutdown()

View File

@ -1,13 +1,13 @@
import logging
import os import os
import re import re
import logging
from threading import Thread from threading import Thread
import ttfw_idf import ttfw_idf
from tiny_test_fw import DUT from tiny_test_fw import DUT
LOG_LEVEL = logging.DEBUG LOG_LEVEL = logging.DEBUG
LOGGER_NAME = "modbus_test" LOGGER_NAME = 'modbus_test'
# Allowed options for the test # Allowed options for the test
TEST_READ_MAX_ERR_COUNT = 3 # Maximum allowed read errors during initialization TEST_READ_MAX_ERR_COUNT = 3 # Maximum allowed read errors during initialization
@ -69,7 +69,7 @@ class DutTestThread(Thread):
super(DutTestThread, self).__init__() super(DutTestThread, self).__init__()
def __enter__(self): def __enter__(self):
logger.debug("Restart %s." % self.tname) logger.debug('Restart %s.' % self.tname)
# Reset DUT first # Reset DUT first
self.dut.reset() self.dut.reset()
# Capture output from the DUT # Capture output from the DUT
@ -80,7 +80,7 @@ class DutTestThread(Thread):
""" The exit method of context manager """ The exit method of context manager
""" """
if exc_type is not None or exc_value is not None: if exc_type is not None or exc_value is not None:
logger.info("Thread %s rised an exception type: %s, value: %s" % (self.tname, str(exc_type), str(exc_value))) logger.info('Thread %s rised an exception type: %s, value: %s' % (self.tname, str(exc_type), str(exc_value)))
def run(self): def run(self):
""" The function implements thread functionality """ The function implements thread functionality
@ -94,8 +94,8 @@ class DutTestThread(Thread):
# Check DUT exceptions # Check DUT exceptions
dut_exceptions = self.dut.get_exceptions() dut_exceptions = self.dut.get_exceptions()
if "Guru Meditation Error:" in dut_exceptions: if 'Guru Meditation Error:' in dut_exceptions:
raise Exception("%s generated an exception: %s\n" % (str(self.dut), dut_exceptions)) raise Exception('%s generated an exception: %s\n' % (str(self.dut), dut_exceptions))
# Mark thread has run to completion without any exceptions # Mark thread has run to completion without any exceptions
self.data = self.dut.stop_capture_raw_data(capture_id=self.dut.name) self.data = self.dut.stop_capture_raw_data(capture_id=self.dut.name)
@ -108,13 +108,13 @@ class DutTestThread(Thread):
self.dut.read() self.dut.read()
result = self.dut.expect(re.compile(message), TEST_EXPECT_STR_TIMEOUT) result = self.dut.expect(re.compile(message), TEST_EXPECT_STR_TIMEOUT)
if int(result[0]) != index: if int(result[0]) != index:
raise Exception("Incorrect index of IP=%d for %s\n" % (int(result[0]), str(self.dut))) raise Exception('Incorrect index of IP=%d for %s\n' % (int(result[0]), str(self.dut)))
message = "IP%s=%s" % (result[0], self.ip_addr) message = 'IP%s=%s' % (result[0], self.ip_addr)
self.dut.write(message, "\r\n", False) self.dut.write(message, '\r\n', False)
logger.debug("Sent message for %s: %s" % (self.tname, message)) logger.debug('Sent message for %s: %s' % (self.tname, message))
message = r'.*IP\([0-9]+\) = \[([0-9a-zA-Z\.\:]+)\] set from stdin.*' message = r'.*IP\([0-9]+\) = \[([0-9a-zA-Z\.\:]+)\] set from stdin.*'
result = self.dut.expect(re.compile(message), TEST_EXPECT_STR_TIMEOUT) result = self.dut.expect(re.compile(message), TEST_EXPECT_STR_TIMEOUT)
logger.debug("Thread %s initialized with slave IP (%s)." % (self.tname, result[0])) logger.debug('Thread %s initialized with slave IP (%s).' % (self.tname, result[0]))
def test_start(self, timeout_value): def test_start(self, timeout_value):
""" The method to initialize and handle test stages """ The method to initialize and handle test stages
@ -122,37 +122,37 @@ class DutTestThread(Thread):
def handle_get_ip4(data): def handle_get_ip4(data):
""" Handle get_ip v4 """ Handle get_ip v4
""" """
logger.debug("%s[STACK_IPV4]: %s" % (self.tname, str(data))) logger.debug('%s[STACK_IPV4]: %s' % (self.tname, str(data)))
self.test_stage = STACK_IPV4 self.test_stage = STACK_IPV4
def handle_get_ip6(data): def handle_get_ip6(data):
""" Handle get_ip v6 """ Handle get_ip v6
""" """
logger.debug("%s[STACK_IPV6]: %s" % (self.tname, str(data))) logger.debug('%s[STACK_IPV6]: %s' % (self.tname, str(data)))
self.test_stage = STACK_IPV6 self.test_stage = STACK_IPV6
def handle_init(data): def handle_init(data):
""" Handle init """ Handle init
""" """
logger.debug("%s[STACK_INIT]: %s" % (self.tname, str(data))) logger.debug('%s[STACK_INIT]: %s' % (self.tname, str(data)))
self.test_stage = STACK_INIT self.test_stage = STACK_INIT
def handle_connect(data): def handle_connect(data):
""" Handle connect """ Handle connect
""" """
logger.debug("%s[STACK_CONNECT]: %s" % (self.tname, str(data))) logger.debug('%s[STACK_CONNECT]: %s' % (self.tname, str(data)))
self.test_stage = STACK_CONNECT self.test_stage = STACK_CONNECT
def handle_test_start(data): def handle_test_start(data):
""" Handle connect """ Handle connect
""" """
logger.debug("%s[STACK_START]: %s" % (self.tname, str(data))) logger.debug('%s[STACK_START]: %s' % (self.tname, str(data)))
self.test_stage = STACK_START self.test_stage = STACK_START
def handle_par_ok(data): def handle_par_ok(data):
""" Handle parameter ok """ Handle parameter ok
""" """
logger.debug("%s[READ_PAR_OK]: %s" % (self.tname, str(data))) logger.debug('%s[READ_PAR_OK]: %s' % (self.tname, str(data)))
if self.test_stage >= STACK_START: if self.test_stage >= STACK_START:
self.param_ok_count += 1 self.param_ok_count += 1
self.test_stage = STACK_PAR_OK self.test_stage = STACK_PAR_OK
@ -160,14 +160,14 @@ class DutTestThread(Thread):
def handle_par_fail(data): def handle_par_fail(data):
""" Handle parameter fail """ Handle parameter fail
""" """
logger.debug("%s[READ_PAR_FAIL]: %s" % (self.tname, str(data))) logger.debug('%s[READ_PAR_FAIL]: %s' % (self.tname, str(data)))
self.param_fail_count += 1 self.param_fail_count += 1
self.test_stage = STACK_PAR_FAIL self.test_stage = STACK_PAR_FAIL
def handle_destroy(data): def handle_destroy(data):
""" Handle destroy """ Handle destroy
""" """
logger.debug("%s[DESTROY]: %s" % (self.tname, str(data))) logger.debug('%s[DESTROY]: %s' % (self.tname, str(data)))
self.test_stage = STACK_DESTROY self.test_stage = STACK_DESTROY
self.test_finish = True self.test_finish = True
@ -183,7 +183,7 @@ class DutTestThread(Thread):
(re.compile(self.expected[STACK_DESTROY]), handle_destroy), (re.compile(self.expected[STACK_DESTROY]), handle_destroy),
timeout=timeout_value) timeout=timeout_value)
except DUT.ExpectTimeout: except DUT.ExpectTimeout:
logger.debug("%s, expect timeout on stage #%d (%s seconds)" % (self.tname, self.test_stage, timeout_value)) logger.debug('%s, expect timeout on stage #%d (%s seconds)' % (self.tname, self.test_stage, timeout_value))
self.test_finish = True self.test_finish = True
@ -193,7 +193,7 @@ def test_check_mode(dut=None, mode_str=None, value=None):
global logger global logger
try: try:
opt = dut.app.get_sdkconfig()[mode_str] opt = dut.app.get_sdkconfig()[mode_str]
logger.debug("%s {%s} = %s.\n" % (str(dut), mode_str, opt)) logger.debug('%s {%s} = %s.\n' % (str(dut), mode_str, opt))
return value == opt return value == opt
except Exception: except Exception:
logger.error('ENV_TEST_FAILURE: %s: Cannot find option %s in sdkconfig.' % (str(dut), mode_str)) logger.error('ENV_TEST_FAILURE: %s: Cannot find option %s in sdkconfig.' % (str(dut), mode_str))
@ -208,8 +208,8 @@ def test_modbus_communication(env, comm_mode):
# Get device under test. Both duts must be able to be connected to WiFi router # Get device under test. Both duts must be able to be connected to WiFi router
dut_master = env.get_dut('modbus_tcp_master', os.path.join(rel_project_path, TEST_MASTER_TCP)) dut_master = env.get_dut('modbus_tcp_master', os.path.join(rel_project_path, TEST_MASTER_TCP))
dut_slave = env.get_dut('modbus_tcp_slave', os.path.join(rel_project_path, TEST_SLAVE_TCP)) dut_slave = env.get_dut('modbus_tcp_slave', os.path.join(rel_project_path, TEST_SLAVE_TCP))
log_file = os.path.join(env.log_path, "modbus_tcp_test.log") log_file = os.path.join(env.log_path, 'modbus_tcp_test.log')
print("Logging file name: %s" % log_file) print('Logging file name: %s' % log_file)
try: try:
# create file handler which logs even debug messages # create file handler which logs even debug messages
@ -229,29 +229,29 @@ def test_modbus_communication(env, comm_mode):
logger.addHandler(ch) logger.addHandler(ch)
# Check Kconfig configuration options for each built example # Check Kconfig configuration options for each built example
if (test_check_mode(dut_master, "CONFIG_FMB_COMM_MODE_TCP_EN", "y") and if (test_check_mode(dut_master, 'CONFIG_FMB_COMM_MODE_TCP_EN', 'y') and
test_check_mode(dut_slave, "CONFIG_FMB_COMM_MODE_TCP_EN", "y")): test_check_mode(dut_slave, 'CONFIG_FMB_COMM_MODE_TCP_EN', 'y')):
slave_name = TEST_SLAVE_TCP slave_name = TEST_SLAVE_TCP
master_name = TEST_MASTER_TCP master_name = TEST_MASTER_TCP
else: else:
logger.error("ENV_TEST_FAILURE: IP resolver mode do not match in the master and slave implementation.\n") logger.error('ENV_TEST_FAILURE: IP resolver mode do not match in the master and slave implementation.\n')
raise Exception("ENV_TEST_FAILURE: IP resolver mode do not match in the master and slave implementation.\n") raise Exception('ENV_TEST_FAILURE: IP resolver mode do not match in the master and slave implementation.\n')
address = None address = None
if test_check_mode(dut_master, "CONFIG_MB_SLAVE_IP_FROM_STDIN", "y"): if test_check_mode(dut_master, 'CONFIG_MB_SLAVE_IP_FROM_STDIN', 'y'):
logger.info("ENV_TEST_INFO: Set slave IP address through STDIN.\n") logger.info('ENV_TEST_INFO: Set slave IP address through STDIN.\n')
# Flash app onto DUT (Todo: Debug case when the slave flashed before master then expect does not work correctly for no reason # Flash app onto DUT (Todo: Debug case when the slave flashed before master then expect does not work correctly for no reason
dut_slave.start_app() dut_slave.start_app()
dut_master.start_app() dut_master.start_app()
if test_check_mode(dut_master, "CONFIG_EXAMPLE_CONNECT_IPV6", "y"): if test_check_mode(dut_master, 'CONFIG_EXAMPLE_CONNECT_IPV6', 'y'):
address = dut_slave.expect(re.compile(pattern_dict_slave[STACK_IPV6]), TEST_EXPECT_STR_TIMEOUT) address = dut_slave.expect(re.compile(pattern_dict_slave[STACK_IPV6]), TEST_EXPECT_STR_TIMEOUT)
else: else:
address = dut_slave.expect(re.compile(pattern_dict_slave[STACK_IPV4]), TEST_EXPECT_STR_TIMEOUT) address = dut_slave.expect(re.compile(pattern_dict_slave[STACK_IPV4]), TEST_EXPECT_STR_TIMEOUT)
if address is not None: if address is not None:
print("Found IP slave address: %s" % address[0]) print('Found IP slave address: %s' % address[0])
else: else:
raise Exception("ENV_TEST_FAILURE: Slave IP address is not found in the output. Check network settings.\n") raise Exception('ENV_TEST_FAILURE: Slave IP address is not found in the output. Check network settings.\n')
else: else:
raise Exception("ENV_TEST_FAILURE: Slave IP resolver is not configured correctly.\n") raise Exception('ENV_TEST_FAILURE: Slave IP resolver is not configured correctly.\n')
# Create thread for each dut # Create thread for each dut
with DutTestThread(dut=dut_master, name=master_name, ip_addr=address[0], expect=pattern_dict_master) as dut_master_thread: with DutTestThread(dut=dut_master, name=master_name, ip_addr=address[0], expect=pattern_dict_master) as dut_master_thread:
@ -266,21 +266,21 @@ def test_modbus_communication(env, comm_mode):
dut_master_thread.join(timeout=TEST_THREAD_JOIN_TIMEOUT) dut_master_thread.join(timeout=TEST_THREAD_JOIN_TIMEOUT)
if dut_slave_thread.isAlive(): if dut_slave_thread.isAlive():
logger.error("ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n" % logger.error('ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n' %
(dut_slave_thread.tname, TEST_THREAD_JOIN_TIMEOUT)) (dut_slave_thread.tname, TEST_THREAD_JOIN_TIMEOUT))
raise Exception("ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n" % raise Exception('ENV_TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n' %
(dut_slave_thread.tname, TEST_THREAD_JOIN_TIMEOUT)) (dut_slave_thread.tname, TEST_THREAD_JOIN_TIMEOUT))
if dut_master_thread.isAlive(): if dut_master_thread.isAlive():
logger.error("TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n" % logger.error('TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n' %
(dut_master_thread.tname, TEST_THREAD_JOIN_TIMEOUT)) (dut_master_thread.tname, TEST_THREAD_JOIN_TIMEOUT))
raise Exception("TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n" % raise Exception('TEST_FAILURE: The thread %s is not completed successfully after %d seconds.\n' %
(dut_master_thread.tname, TEST_THREAD_JOIN_TIMEOUT)) (dut_master_thread.tname, TEST_THREAD_JOIN_TIMEOUT))
logger.info("TEST_INFO: %s error count = %d, %s error count = %d.\n" % logger.info('TEST_INFO: %s error count = %d, %s error count = %d.\n' %
(dut_master_thread.tname, dut_master_thread.param_fail_count, (dut_master_thread.tname, dut_master_thread.param_fail_count,
dut_slave_thread.tname, dut_slave_thread.param_fail_count)) dut_slave_thread.tname, dut_slave_thread.param_fail_count))
logger.info("TEST_INFO: %s ok count = %d, %s ok count = %d.\n" % logger.info('TEST_INFO: %s ok count = %d, %s ok count = %d.\n' %
(dut_master_thread.tname, dut_master_thread.param_ok_count, (dut_master_thread.tname, dut_master_thread.param_ok_count,
dut_slave_thread.tname, dut_slave_thread.param_ok_count)) dut_slave_thread.tname, dut_slave_thread.param_ok_count))
@ -288,10 +288,10 @@ def test_modbus_communication(env, comm_mode):
(dut_slave_thread.param_fail_count > TEST_READ_MAX_ERR_COUNT) or (dut_slave_thread.param_fail_count > TEST_READ_MAX_ERR_COUNT) or
(dut_slave_thread.param_ok_count == 0) or (dut_slave_thread.param_ok_count == 0) or
(dut_master_thread.param_ok_count == 0)): (dut_master_thread.param_ok_count == 0)):
raise Exception("TEST_FAILURE: %s parameter read error(ok) count = %d(%d), %s parameter read error(ok) count = %d(%d).\n" % raise Exception('TEST_FAILURE: %s parameter read error(ok) count = %d(%d), %s parameter read error(ok) count = %d(%d).\n' %
(dut_master_thread.tname, dut_master_thread.param_fail_count, dut_master_thread.param_ok_count, (dut_master_thread.tname, dut_master_thread.param_fail_count, dut_master_thread.param_ok_count,
dut_slave_thread.tname, dut_slave_thread.param_fail_count, dut_slave_thread.param_ok_count)) dut_slave_thread.tname, dut_slave_thread.param_fail_count, dut_slave_thread.param_ok_count))
logger.info("TEST_SUCCESS: The Modbus parameter test is completed successfully.\n") logger.info('TEST_SUCCESS: The Modbus parameter test is completed successfully.\n')
finally: finally:
dut_master.close() dut_master.close()

View File

@ -1,29 +1,28 @@
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
from builtins import str
import re
import os import os
import sys import re
import ssl import ssl
import sys
from builtins import str
from threading import Event, Thread
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
from threading import Thread, Event
from tiny_test_fw import DUT
import ttfw_idf import ttfw_idf
from tiny_test_fw import DUT
event_client_connected = Event() event_client_connected = Event()
event_stop_client = Event() event_stop_client = Event()
event_client_received_correct = Event() event_client_received_correct = Event()
event_client_received_binary = Event() event_client_received_binary = Event()
message_log = "" message_log = ''
# The callback for when the client receives a CONNACK response from the server. # The callback for when the client receives a CONNACK response from the server.
def on_connect(client, userdata, flags, rc): def on_connect(client, userdata, flags, rc):
print("Connected with result code " + str(rc)) print('Connected with result code ' + str(rc))
event_client_connected.set() event_client_connected.set()
client.subscribe("/topic/qos0") client.subscribe('/topic/qos0')
def mqtt_client_task(client): def mqtt_client_task(client):
@ -36,33 +35,33 @@ def on_message(client, userdata, msg):
global message_log global message_log
global event_client_received_correct global event_client_received_correct
global event_client_received_binary global event_client_received_binary
if msg.topic == "/topic/binary": if msg.topic == '/topic/binary':
binary = userdata binary = userdata
size = os.path.getsize(binary) size = os.path.getsize(binary)
print("Receiving binary from esp and comparing with {}, size {}...".format(binary, size)) print('Receiving binary from esp and comparing with {}, size {}...'.format(binary, size))
with open(binary, "rb") as f: with open(binary, 'rb') as f:
bin = f.read() bin = f.read()
if bin == msg.payload[:size]: if bin == msg.payload[:size]:
print("...matches!") print('...matches!')
event_client_received_binary.set() event_client_received_binary.set()
return return
else: else:
recv_binary = binary + ".received" recv_binary = binary + '.received'
with open(recv_binary, "w") as fw: with open(recv_binary, 'w') as fw:
fw.write(msg.payload) fw.write(msg.payload)
raise ValueError('Received binary (saved as: {}) does not match the original file: {}'.format(recv_binary, binary)) raise ValueError('Received binary (saved as: {}) does not match the original file: {}'.format(recv_binary, binary))
payload = msg.payload.decode() payload = msg.payload.decode()
if not event_client_received_correct.is_set() and payload == "data": if not event_client_received_correct.is_set() and payload == 'data':
client.subscribe("/topic/binary") client.subscribe('/topic/binary')
client.publish("/topic/qos0", "send binary please") client.publish('/topic/qos0', 'send binary please')
if msg.topic == "/topic/qos0" and payload == "data": if msg.topic == '/topic/qos0' and payload == 'data':
event_client_received_correct.set() event_client_received_correct.set()
message_log += "Received data:" + msg.topic + " " + payload + "\n" message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_mqtt_ssl(env, extra_data): def test_examples_protocol_mqtt_ssl(env, extra_data):
broker_url = "" broker_url = ''
broker_port = 0 broker_port = 0
""" """
steps: | steps: |
@ -72,15 +71,15 @@ def test_examples_protocol_mqtt_ssl(env, extra_data):
4. Test ESP32 client received correct qos0 message 4. Test ESP32 client received correct qos0 message
5. Test python client receives binary data from running partition and compares it with the binary 5. Test python client receives binary data from running partition and compares it with the binary
""" """
dut1 = env.get_dut("mqtt_ssl", "examples/protocols/mqtt/ssl", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('mqtt_ssl', 'examples/protocols/mqtt/ssl', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "mqtt_ssl.bin") binary_file = os.path.join(dut1.app.binary_path, 'mqtt_ssl.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("mqtt_ssl_bin_size", "{}KB" ttfw_idf.log_performance('mqtt_ssl_bin_size', '{}KB'
.format(bin_size // 1024)) .format(bin_size // 1024))
# Look for host:port in sdkconfig # Look for host:port in sdkconfig
try: try:
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()["CONFIG_BROKER_URI"]) value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI'])
broker_url = value.group(1) broker_url = value.group(1)
broker_port = int(value.group(2)) broker_port = int(value.group(2))
except Exception: except Exception:
@ -97,31 +96,31 @@ def test_examples_protocol_mqtt_ssl(env, extra_data):
None, None,
None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None) None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
client.tls_insecure_set(True) client.tls_insecure_set(True)
print("Connecting...") print('Connecting...')
client.connect(broker_url, broker_port, 60) client.connect(broker_url, broker_port, 60)
except Exception: except Exception:
print("ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:".format(broker_url, sys.exc_info()[0])) print('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(broker_url, sys.exc_info()[0]))
raise raise
# Starting a py-client in a separate thread # Starting a py-client in a separate thread
thread1 = Thread(target=mqtt_client_task, args=(client,)) thread1 = Thread(target=mqtt_client_task, args=(client,))
thread1.start() thread1.start()
try: try:
print("Connecting py-client to broker {}:{}...".format(broker_url, broker_port)) print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
if not event_client_connected.wait(timeout=30): if not event_client_connected.wait(timeout=30):
raise ValueError("ENV_TEST_FAILURE: Test script cannot connect to broker: {}".format(broker_url)) raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
dut1.start_app() dut1.start_app()
try: try:
ip_address = dut1.expect(re.compile(r" sta ip: ([^,]+),"), timeout=30) ip_address = dut1.expect(re.compile(r' sta ip: ([^,]+),'), timeout=30)
print("Connected to AP with IP: {}".format(ip_address)) print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout: except DUT.ExpectTimeout:
print('ENV_TEST_FAILURE: Cannot connect to AP') print('ENV_TEST_FAILURE: Cannot connect to AP')
raise raise
print("Checking py-client received msg published from esp...") print('Checking py-client received msg published from esp...')
if not event_client_received_correct.wait(timeout=30): if not event_client_received_correct.wait(timeout=30):
raise ValueError('Wrong data received, msg log: {}'.format(message_log)) raise ValueError('Wrong data received, msg log: {}'.format(message_log))
print("Checking esp-client received msg published from py-client...") print('Checking esp-client received msg published from py-client...')
dut1.expect(re.compile(r"DATA=send binary please"), timeout=30) dut1.expect(re.compile(r'DATA=send binary please'), timeout=30)
print("Receiving binary data from running partition...") print('Receiving binary data from running partition...')
if not event_client_received_binary.wait(timeout=30): if not event_client_received_binary.wait(timeout=30):
raise ValueError('Binary not received within timeout') raise ValueError('Binary not received within timeout')
finally: finally:

View File

@ -1,20 +1,20 @@
import re
import os import os
import sys import re
import socket import socket
from threading import Thread
import struct import struct
import sys
import time import time
from threading import Thread
from tiny_test_fw import DUT
import ttfw_idf import ttfw_idf
from tiny_test_fw import DUT
msgid = -1 msgid = -1
def get_my_ip(): def get_my_ip():
s1 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s1 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s1.connect(("8.8.8.8", 80)) s1.connect(('8.8.8.8', 80))
my_ip = s1.getsockname()[0] my_ip = s1.getsockname()[0]
s1.close() s1.close()
return my_ip return my_ip
@ -22,7 +22,7 @@ def get_my_ip():
def mqqt_server_sketch(my_ip, port): def mqqt_server_sketch(my_ip, port):
global msgid global msgid
print("Starting the server on {}".format(my_ip)) print('Starting the server on {}'.format(my_ip))
s = None s = None
try: try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -31,29 +31,29 @@ def mqqt_server_sketch(my_ip, port):
s.listen(1) s.listen(1)
q,addr = s.accept() q,addr = s.accept()
q.settimeout(30) q.settimeout(30)
print("connection accepted") print('connection accepted')
except Exception: except Exception:
print("Local server on {}:{} listening/accepting failure: {}" print('Local server on {}:{} listening/accepting failure: {}'
"Possibly check permissions or firewall settings" 'Possibly check permissions or firewall settings'
"to accept connections on this address".format(my_ip, port, sys.exc_info()[0])) 'to accept connections on this address'.format(my_ip, port, sys.exc_info()[0]))
raise raise
data = q.recv(1024) data = q.recv(1024)
# check if received initial empty message # check if received initial empty message
print("received from client {}".format(data)) print('received from client {}'.format(data))
data = bytearray([0x20, 0x02, 0x00, 0x00]) data = bytearray([0x20, 0x02, 0x00, 0x00])
q.send(data) q.send(data)
# try to receive qos1 # try to receive qos1
data = q.recv(1024) data = q.recv(1024)
msgid = struct.unpack(">H", data[15:17])[0] msgid = struct.unpack('>H', data[15:17])[0]
print("received from client {}, msgid: {}".format(data, msgid)) print('received from client {}, msgid: {}'.format(data, msgid))
data = bytearray([0x40, 0x02, data[15], data[16]]) data = bytearray([0x40, 0x02, data[15], data[16]])
q.send(data) q.send(data)
time.sleep(5) time.sleep(5)
s.close() s.close()
print("server closed") print('server closed')
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_mqtt_qos1(env, extra_data): def test_examples_protocol_mqtt_qos1(env, extra_data):
global msgid global msgid
""" """
@ -63,11 +63,11 @@ def test_examples_protocol_mqtt_qos1(env, extra_data):
3. Test evaluates that qos1 message is queued and removed from queued after ACK received 3. Test evaluates that qos1 message is queued and removed from queued after ACK received
4. Test the broker received the same message id evaluated in step 3 4. Test the broker received the same message id evaluated in step 3
""" """
dut1 = env.get_dut("mqtt_tcp", "examples/protocols/mqtt/tcp", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('mqtt_tcp', 'examples/protocols/mqtt/tcp', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "mqtt_tcp.bin") binary_file = os.path.join(dut1.app.binary_path, 'mqtt_tcp.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("mqtt_tcp_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('mqtt_tcp_bin_size', '{}KB'.format(bin_size // 1024))
# 1. start mqtt broker sketch # 1. start mqtt broker sketch
host_ip = get_my_ip() host_ip = get_my_ip()
thread1 = Thread(target=mqqt_server_sketch, args=(host_ip,1883)) thread1 = Thread(target=mqqt_server_sketch, args=(host_ip,1883))
@ -76,23 +76,23 @@ def test_examples_protocol_mqtt_qos1(env, extra_data):
dut1.start_app() dut1.start_app()
# waiting for getting the IP address # waiting for getting the IP address
try: try:
ip_address = dut1.expect(re.compile(r" sta ip: ([^,]+),"), timeout=30) ip_address = dut1.expect(re.compile(r' sta ip: ([^,]+),'), timeout=30)
print("Connected to AP with IP: {}".format(ip_address)) print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout: except DUT.ExpectTimeout:
raise ValueError('ENV_TEST_FAILURE: Cannot connect to AP') raise ValueError('ENV_TEST_FAILURE: Cannot connect to AP')
print("writing to device: {}".format("mqtt://" + host_ip + "\n")) print('writing to device: {}'.format('mqtt://' + host_ip + '\n'))
dut1.write("mqtt://" + host_ip + "\n") dut1.write('mqtt://' + host_ip + '\n')
thread1.join() thread1.join()
print("Message id received from server: {}".format(msgid)) print('Message id received from server: {}'.format(msgid))
# 3. check the message id was enqueued and then deleted # 3. check the message id was enqueued and then deleted
msgid_enqueued = dut1.expect(re.compile(r"OUTBOX: ENQUEUE msgid=([0-9]+)"), timeout=30) msgid_enqueued = dut1.expect(re.compile(r'OUTBOX: ENQUEUE msgid=([0-9]+)'), timeout=30)
msgid_deleted = dut1.expect(re.compile(r"OUTBOX: DELETED msgid=([0-9]+)"), timeout=30) msgid_deleted = dut1.expect(re.compile(r'OUTBOX: DELETED msgid=([0-9]+)'), timeout=30)
# 4. check the msgid of received data are the same as that of enqueued and deleted from outbox # 4. check the msgid of received data are the same as that of enqueued and deleted from outbox
if (msgid_enqueued[0] == str(msgid) and msgid_deleted[0] == str(msgid)): if (msgid_enqueued[0] == str(msgid) and msgid_deleted[0] == str(msgid)):
print("PASS: Received correct msg id") print('PASS: Received correct msg id')
else: else:
print("Failure!") print('Failure!')
raise ValueError('Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted)) raise ValueError('Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted))

View File

@ -1,26 +1,26 @@
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
from builtins import str
import re
import os
import sys
import paho.mqtt.client as mqtt
from threading import Thread, Event
from tiny_test_fw import DUT import os
import re
import sys
from builtins import str
from threading import Event, Thread
import paho.mqtt.client as mqtt
import ttfw_idf import ttfw_idf
from tiny_test_fw import DUT
event_client_connected = Event() event_client_connected = Event()
event_stop_client = Event() event_stop_client = Event()
event_client_received_correct = Event() event_client_received_correct = Event()
message_log = "" message_log = ''
# The callback for when the client receives a CONNACK response from the server. # The callback for when the client receives a CONNACK response from the server.
def on_connect(client, userdata, flags, rc): def on_connect(client, userdata, flags, rc):
print("Connected with result code " + str(rc)) print('Connected with result code ' + str(rc))
event_client_connected.set() event_client_connected.set()
client.subscribe("/topic/qos0") client.subscribe('/topic/qos0')
def mqtt_client_task(client): def mqtt_client_task(client):
@ -32,16 +32,16 @@ def mqtt_client_task(client):
def on_message(client, userdata, msg): def on_message(client, userdata, msg):
global message_log global message_log
payload = msg.payload.decode() payload = msg.payload.decode()
if not event_client_received_correct.is_set() and payload == "data": if not event_client_received_correct.is_set() and payload == 'data':
client.publish("/topic/qos0", "data_to_esp32") client.publish('/topic/qos0', 'data_to_esp32')
if msg.topic == "/topic/qos0" and payload == "data": if msg.topic == '/topic/qos0' and payload == 'data':
event_client_received_correct.set() event_client_received_correct.set()
message_log += "Received data:" + msg.topic + " " + payload + "\n" message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_mqtt_ws(env, extra_data): def test_examples_protocol_mqtt_ws(env, extra_data):
broker_url = "" broker_url = ''
broker_port = 0 broker_port = 0
""" """
steps: | steps: |
@ -50,14 +50,14 @@ def test_examples_protocol_mqtt_ws(env, extra_data):
3. Test evaluates it received correct qos0 message 3. Test evaluates it received correct qos0 message
4. Test ESP32 client received correct qos0 message 4. Test ESP32 client received correct qos0 message
""" """
dut1 = env.get_dut("mqtt_websocket", "examples/protocols/mqtt/ws", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('mqtt_websocket', 'examples/protocols/mqtt/ws', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "mqtt_websocket.bin") binary_file = os.path.join(dut1.app.binary_path, 'mqtt_websocket.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("mqtt_websocket_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('mqtt_websocket_bin_size', '{}KB'.format(bin_size // 1024))
# Look for host:port in sdkconfig # Look for host:port in sdkconfig
try: try:
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()["CONFIG_BROKER_URI"]) value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI'])
broker_url = value.group(1) broker_url = value.group(1)
broker_port = int(value.group(2)) broker_port = int(value.group(2))
except Exception: except Exception:
@ -66,33 +66,33 @@ def test_examples_protocol_mqtt_ws(env, extra_data):
client = None client = None
# 1. Test connects to a broker # 1. Test connects to a broker
try: try:
client = mqtt.Client(transport="websockets") client = mqtt.Client(transport='websockets')
client.on_connect = on_connect client.on_connect = on_connect
client.on_message = on_message client.on_message = on_message
print("Connecting...") print('Connecting...')
client.connect(broker_url, broker_port, 60) client.connect(broker_url, broker_port, 60)
except Exception: except Exception:
print("ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:".format(broker_url, sys.exc_info()[0])) print('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(broker_url, sys.exc_info()[0]))
raise raise
# Starting a py-client in a separate thread # Starting a py-client in a separate thread
thread1 = Thread(target=mqtt_client_task, args=(client,)) thread1 = Thread(target=mqtt_client_task, args=(client,))
thread1.start() thread1.start()
try: try:
print("Connecting py-client to broker {}:{}...".format(broker_url, broker_port)) print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
if not event_client_connected.wait(timeout=30): if not event_client_connected.wait(timeout=30):
raise ValueError("ENV_TEST_FAILURE: Test script cannot connect to broker: {}".format(broker_url)) raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
dut1.start_app() dut1.start_app()
try: try:
ip_address = dut1.expect(re.compile(r" sta ip: ([^,]+),"), timeout=30) ip_address = dut1.expect(re.compile(r' sta ip: ([^,]+),'), timeout=30)
print("Connected to AP with IP: {}".format(ip_address)) print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout: except DUT.ExpectTimeout:
print('ENV_TEST_FAILURE: Cannot connect to AP') print('ENV_TEST_FAILURE: Cannot connect to AP')
raise raise
print("Checking py-client received msg published from esp...") print('Checking py-client received msg published from esp...')
if not event_client_received_correct.wait(timeout=30): if not event_client_received_correct.wait(timeout=30):
raise ValueError('Wrong data received, msg log: {}'.format(message_log)) raise ValueError('Wrong data received, msg log: {}'.format(message_log))
print("Checking esp-client received msg published from py-client...") print('Checking esp-client received msg published from py-client...')
dut1.expect(re.compile(r"DATA=data_to_esp32"), timeout=30) dut1.expect(re.compile(r'DATA=data_to_esp32'), timeout=30)
finally: finally:
event_stop_client.set() event_stop_client.set()
thread1.join() thread1.join()

View File

@ -1,28 +1,27 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from __future__ import unicode_literals
from builtins import str
import re
import os import os
import sys import re
import ssl import ssl
import sys
from builtins import str
from threading import Event, Thread
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
from threading import Thread, Event
from tiny_test_fw import DUT
import ttfw_idf import ttfw_idf
from tiny_test_fw import DUT
event_client_connected = Event() event_client_connected = Event()
event_stop_client = Event() event_stop_client = Event()
event_client_received_correct = Event() event_client_received_correct = Event()
message_log = "" message_log = ''
# The callback for when the client receives a CONNACK response from the server. # The callback for when the client receives a CONNACK response from the server.
def on_connect(client, userdata, flags, rc): def on_connect(client, userdata, flags, rc):
print("Connected with result code " + str(rc)) print('Connected with result code ' + str(rc))
event_client_connected.set() event_client_connected.set()
client.subscribe("/topic/qos0") client.subscribe('/topic/qos0')
def mqtt_client_task(client): def mqtt_client_task(client):
@ -34,16 +33,16 @@ def mqtt_client_task(client):
def on_message(client, userdata, msg): def on_message(client, userdata, msg):
global message_log global message_log
payload = msg.payload.decode() payload = msg.payload.decode()
if not event_client_received_correct.is_set() and payload == "data": if not event_client_received_correct.is_set() and payload == 'data':
client.publish("/topic/qos0", "data_to_esp32") client.publish('/topic/qos0', 'data_to_esp32')
if msg.topic == "/topic/qos0" and payload == "data": if msg.topic == '/topic/qos0' and payload == 'data':
event_client_received_correct.set() event_client_received_correct.set()
message_log += "Received data:" + msg.topic + " " + payload + "\n" message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_mqtt_wss(env, extra_data): def test_examples_protocol_mqtt_wss(env, extra_data):
broker_url = "" broker_url = ''
broker_port = 0 broker_port = 0
""" """
steps: | steps: |
@ -52,14 +51,14 @@ def test_examples_protocol_mqtt_wss(env, extra_data):
3. Test evaluates it received correct qos0 message 3. Test evaluates it received correct qos0 message
4. Test ESP32 client received correct qos0 message 4. Test ESP32 client received correct qos0 message
""" """
dut1 = env.get_dut("mqtt_websocket_secure", "examples/protocols/mqtt/wss", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('mqtt_websocket_secure', 'examples/protocols/mqtt/wss', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "mqtt_websocket_secure.bin") binary_file = os.path.join(dut1.app.binary_path, 'mqtt_websocket_secure.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("mqtt_websocket_secure_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('mqtt_websocket_secure_bin_size', '{}KB'.format(bin_size // 1024))
# Look for host:port in sdkconfig # Look for host:port in sdkconfig
try: try:
value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()["CONFIG_BROKER_URI"]) value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI'])
broker_url = value.group(1) broker_url = value.group(1)
broker_port = int(value.group(2)) broker_port = int(value.group(2))
except Exception: except Exception:
@ -68,36 +67,36 @@ def test_examples_protocol_mqtt_wss(env, extra_data):
client = None client = None
# 1. Test connects to a broker # 1. Test connects to a broker
try: try:
client = mqtt.Client(transport="websockets") client = mqtt.Client(transport='websockets')
client.on_connect = on_connect client.on_connect = on_connect
client.on_message = on_message client.on_message = on_message
client.tls_set(None, client.tls_set(None,
None, None,
None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None) None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
print("Connecting...") print('Connecting...')
client.connect(broker_url, broker_port, 60) client.connect(broker_url, broker_port, 60)
except Exception: except Exception:
print("ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:".format(broker_url, sys.exc_info()[0])) print('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(broker_url, sys.exc_info()[0]))
raise raise
# Starting a py-client in a separate thread # Starting a py-client in a separate thread
thread1 = Thread(target=mqtt_client_task, args=(client,)) thread1 = Thread(target=mqtt_client_task, args=(client,))
thread1.start() thread1.start()
try: try:
print("Connecting py-client to broker {}:{}...".format(broker_url, broker_port)) print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port))
if not event_client_connected.wait(timeout=30): if not event_client_connected.wait(timeout=30):
raise ValueError("ENV_TEST_FAILURE: Test script cannot connect to broker: {}".format(broker_url)) raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url))
dut1.start_app() dut1.start_app()
try: try:
ip_address = dut1.expect(re.compile(r" sta ip: ([^,]+),"), timeout=30) ip_address = dut1.expect(re.compile(r' sta ip: ([^,]+),'), timeout=30)
print("Connected to AP with IP: {}".format(ip_address)) print('Connected to AP with IP: {}'.format(ip_address))
except DUT.ExpectTimeout: except DUT.ExpectTimeout:
print('ENV_TEST_FAILURE: Cannot connect to AP') print('ENV_TEST_FAILURE: Cannot connect to AP')
raise raise
print("Checking py-client received msg published from esp...") print('Checking py-client received msg published from esp...')
if not event_client_received_correct.wait(timeout=30): if not event_client_received_correct.wait(timeout=30):
raise ValueError('Wrong data received, msg log: {}'.format(message_log)) raise ValueError('Wrong data received, msg log: {}'.format(message_log))
print("Checking esp-client received msg published from py-client...") print('Checking esp-client received msg published from py-client...')
dut1.expect(re.compile(r"DATA=data_to_esp32"), timeout=30) dut1.expect(re.compile(r'DATA=data_to_esp32'), timeout=30)
finally: finally:
event_stop_client.set() event_stop_client.set()
thread1.join() thread1.join()

View File

@ -1,10 +1,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from tiny_test_fw import Utility
import os import os
import serial
import threading import threading
import time import time
import serial
import ttfw_idf import ttfw_idf
from tiny_test_fw import Utility
class SerialThread(object): class SerialThread(object):

View File

@ -1,8 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from tiny_test_fw import Utility
import datetime import datetime
import re import re
import ttfw_idf import ttfw_idf
from tiny_test_fw import Utility
@ttfw_idf.idf_example_test(env_tag='Example_WIFI') @ttfw_idf.idf_example_test(env_tag='Example_WIFI')

View File

@ -6,17 +6,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
from builtins import input
import os import os
import re import re
import sys
import netifaces
import socket import socket
from threading import Thread, Event import sys
import ttfw_idf from builtins import input
from threading import Event, Thread
import netifaces
import ttfw_idf
# ----------- Config ---------- # ----------- Config ----------
PORT = 3333 PORT = 3333
@ -26,7 +26,7 @@ INTERFACE = 'eth0'
def get_my_ip(type): def get_my_ip(type):
for i in netifaces.ifaddresses(INTERFACE)[type]: for i in netifaces.ifaddresses(INTERFACE)[type]:
return i['addr'].replace("%{}".format(INTERFACE), "") return i['addr'].replace('%{}'.format(INTERFACE), '')
class TcpServer: class TcpServer:
@ -44,11 +44,11 @@ class TcpServer:
try: try:
self.socket.bind(('', self.port)) self.socket.bind(('', self.port))
except socket.error as e: except socket.error as e:
print("Bind failed:{}".format(e)) print('Bind failed:{}'.format(e))
raise raise
self.socket.listen(1) self.socket.listen(1)
print("Starting server on port={} family_addr={}".format(self.port, self.family_addr)) print('Starting server on port={} family_addr={}'.format(self.port, self.family_addr))
self.server_thread = Thread(target=self.run_server) self.server_thread = Thread(target=self.run_server)
self.server_thread.start() self.server_thread.start()
return self return self
@ -68,7 +68,7 @@ class TcpServer:
while not self.shutdown.is_set(): while not self.shutdown.is_set():
try: try:
conn, address = self.socket.accept() # accept new connection conn, address = self.socket.accept() # accept new connection
print("Connection from: {}".format(address)) print('Connection from: {}'.format(address))
conn.setblocking(1) conn.setblocking(1)
data = conn.recv(1024) data = conn.recv(1024)
if not data: if not data:
@ -79,13 +79,13 @@ class TcpServer:
conn.send(reply.encode()) conn.send(reply.encode())
conn.close() conn.close()
except socket.error as e: except socket.error as e:
print("Running server failed:{}".format(e)) print('Running server failed:{}'.format(e))
raise raise
if not self.persist: if not self.persist:
break break
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_socket_tcpclient(env, extra_data): def test_examples_protocol_socket_tcpclient(env, extra_data):
""" """
steps: steps:
@ -93,39 +93,39 @@ def test_examples_protocol_socket_tcpclient(env, extra_data):
2. have the board connect to the server 2. have the board connect to the server
3. send and receive data 3. send and receive data
""" """
dut1 = env.get_dut("tcp_client", "examples/protocols/sockets/tcp_client", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('tcp_client', 'examples/protocols/sockets/tcp_client', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "tcp_client.bin") binary_file = os.path.join(dut1.app.binary_path, 'tcp_client.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("tcp_client_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('tcp_client_bin_size', '{}KB'.format(bin_size // 1024))
# start test # start test
dut1.start_app() dut1.start_app()
ipv4 = dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30)[0] ipv4 = dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)[0]
ipv6_r = r':'.join((r'[0-9a-fA-F]{4}',) * 8) # expect all 8 octets from IPv6 (assumes it's printed in the long form) ipv6_r = r':'.join((r'[0-9a-fA-F]{4}',) * 8) # expect all 8 octets from IPv6 (assumes it's printed in the long form)
ipv6 = dut1.expect(re.compile(r' IPv6 address: ({})'.format(ipv6_r)), timeout=30)[0] ipv6 = dut1.expect(re.compile(r' IPv6 address: ({})'.format(ipv6_r)), timeout=30)[0]
print("Connected with IPv4={} and IPv6={}".format(ipv4, ipv6)) print('Connected with IPv4={} and IPv6={}'.format(ipv4, ipv6))
# test IPv4 # test IPv4
with TcpServer(PORT, socket.AF_INET): with TcpServer(PORT, socket.AF_INET):
server_ip = get_my_ip(netifaces.AF_INET) server_ip = get_my_ip(netifaces.AF_INET)
print("Connect tcp client to server IP={}".format(server_ip)) print('Connect tcp client to server IP={}'.format(server_ip))
dut1.write(server_ip) dut1.write(server_ip)
dut1.expect(re.compile(r"OK: Message from ESP32")) dut1.expect(re.compile(r'OK: Message from ESP32'))
# test IPv6 # test IPv6
with TcpServer(PORT, socket.AF_INET6): with TcpServer(PORT, socket.AF_INET6):
server_ip = get_my_ip(netifaces.AF_INET6) server_ip = get_my_ip(netifaces.AF_INET6)
print("Connect tcp client to server IP={}".format(server_ip)) print('Connect tcp client to server IP={}'.format(server_ip))
dut1.write(server_ip) dut1.write(server_ip)
dut1.expect(re.compile(r"OK: Message from ESP32")) dut1.expect(re.compile(r'OK: Message from ESP32'))
if __name__ == '__main__': if __name__ == '__main__':
if sys.argv[1:] and sys.argv[1].startswith("IPv"): # if additional arguments provided: if sys.argv[1:] and sys.argv[1].startswith('IPv'): # if additional arguments provided:
# Usage: example_test.py <IPv4|IPv6> # Usage: example_test.py <IPv4|IPv6>
family_addr = socket.AF_INET6 if sys.argv[1] == "IPv6" else socket.AF_INET family_addr = socket.AF_INET6 if sys.argv[1] == 'IPv6' else socket.AF_INET
with TcpServer(PORT, family_addr, persist=True) as s: with TcpServer(PORT, family_addr, persist=True) as s:
print(input("Press Enter stop the server...")) print(input('Press Enter stop the server...'))
else: else:
test_examples_protocol_socket_tcpclient() test_examples_protocol_socket_tcpclient()

View File

@ -6,14 +6,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
import os import os
import sys
import re import re
import socket import socket
import ttfw_idf import sys
import ttfw_idf
# ----------- Config ---------- # ----------- Config ----------
PORT = 3333 PORT = 3333
@ -46,28 +46,28 @@ def tcp_client(address, payload):
return data.decode() return data.decode()
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_socket_tcpserver(env, extra_data): def test_examples_protocol_socket_tcpserver(env, extra_data):
MESSAGE = "Data to ESP" MESSAGE = 'Data to ESP'
""" """
steps: steps:
1. join AP 1. join AP
2. have the board connect to the server 2. have the board connect to the server
3. send and receive data 3. send and receive data
""" """
dut1 = env.get_dut("tcp_client", "examples/protocols/sockets/tcp_server", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('tcp_client', 'examples/protocols/sockets/tcp_server', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "tcp_server.bin") binary_file = os.path.join(dut1.app.binary_path, 'tcp_server.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("tcp_server_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('tcp_server_bin_size', '{}KB'.format(bin_size // 1024))
# start test # start test
dut1.start_app() dut1.start_app()
ipv4 = dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30)[0] ipv4 = dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)[0]
ipv6_r = r':'.join((r'[0-9a-fA-F]{4}',) * 8) # expect all 8 octets from IPv6 (assumes it's printed in the long form) ipv6_r = r':'.join((r'[0-9a-fA-F]{4}',) * 8) # expect all 8 octets from IPv6 (assumes it's printed in the long form)
ipv6 = dut1.expect(re.compile(r' IPv6 address: ({})'.format(ipv6_r)), timeout=30)[0] ipv6 = dut1.expect(re.compile(r' IPv6 address: ({})'.format(ipv6_r)), timeout=30)[0]
print("Connected with IPv4={} and IPv6={}".format(ipv4, ipv6)) print('Connected with IPv4={} and IPv6={}'.format(ipv4, ipv6))
# test IPv4 # test IPv4
received = tcp_client(ipv4, MESSAGE) received = tcp_client(ipv4, MESSAGE)
@ -75,7 +75,7 @@ def test_examples_protocol_socket_tcpserver(env, extra_data):
raise raise
dut1.expect(MESSAGE) dut1.expect(MESSAGE)
# test IPv6 # test IPv6
received = tcp_client("{}%{}".format(ipv6, INTERFACE), MESSAGE) received = tcp_client('{}%{}'.format(ipv6, INTERFACE), MESSAGE)
if not received == MESSAGE: if not received == MESSAGE:
raise raise
dut1.expect(MESSAGE) dut1.expect(MESSAGE)

View File

@ -6,17 +6,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
from builtins import input
import os import os
import re import re
import netifaces
import socket import socket
from threading import Thread, Event
import ttfw_idf
import sys import sys
from builtins import input
from threading import Event, Thread
import netifaces
import ttfw_idf
# ----------- Config ---------- # ----------- Config ----------
PORT = 3333 PORT = 3333
@ -26,7 +26,7 @@ INTERFACE = 'eth0'
def get_my_ip(type): def get_my_ip(type):
for i in netifaces.ifaddresses(INTERFACE)[type]: for i in netifaces.ifaddresses(INTERFACE)[type]:
return i['addr'].replace("%{}".format(INTERFACE), "") return i['addr'].replace('%{}'.format(INTERFACE), '')
class UdpServer: class UdpServer:
@ -44,10 +44,10 @@ class UdpServer:
try: try:
self.socket.bind(('', self.port)) self.socket.bind(('', self.port))
except socket.error as e: except socket.error as e:
print("Bind failed:{}".format(e)) print('Bind failed:{}'.format(e))
raise raise
print("Starting server on port={} family_addr={}".format(self.port, self.family_addr)) print('Starting server on port={} family_addr={}'.format(self.port, self.family_addr))
self.server_thread = Thread(target=self.run_server) self.server_thread = Thread(target=self.run_server)
self.server_thread.start() self.server_thread.start()
return self return self
@ -72,13 +72,13 @@ class UdpServer:
reply = 'OK: ' + data reply = 'OK: ' + data
self.socket.sendto(reply.encode(), addr) self.socket.sendto(reply.encode(), addr)
except socket.error as e: except socket.error as e:
print("Running server failed:{}".format(e)) print('Running server failed:{}'.format(e))
raise raise
if not self.persist: if not self.persist:
break break
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_socket_udpclient(env, extra_data): def test_examples_protocol_socket_udpclient(env, extra_data):
""" """
steps: steps:
@ -86,39 +86,39 @@ def test_examples_protocol_socket_udpclient(env, extra_data):
2. have the board connect to the server 2. have the board connect to the server
3. send and receive data 3. send and receive data
""" """
dut1 = env.get_dut("udp_client", "examples/protocols/sockets/udp_client", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('udp_client', 'examples/protocols/sockets/udp_client', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "udp_client.bin") binary_file = os.path.join(dut1.app.binary_path, 'udp_client.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("udp_client_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('udp_client_bin_size', '{}KB'.format(bin_size // 1024))
# start test # start test
dut1.start_app() dut1.start_app()
ipv4 = dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30)[0] ipv4 = dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)[0]
ipv6_r = r':'.join((r'[0-9a-fA-F]{4}',) * 8) # expect all 8 octets from IPv6 (assumes it's printed in the long form) ipv6_r = r':'.join((r'[0-9a-fA-F]{4}',) * 8) # expect all 8 octets from IPv6 (assumes it's printed in the long form)
ipv6 = dut1.expect(re.compile(r' IPv6 address: ({})'.format(ipv6_r)), timeout=30)[0] ipv6 = dut1.expect(re.compile(r' IPv6 address: ({})'.format(ipv6_r)), timeout=30)[0]
print("Connected with IPv4={} and IPv6={}".format(ipv4, ipv6)) print('Connected with IPv4={} and IPv6={}'.format(ipv4, ipv6))
# test IPv4 # test IPv4
with UdpServer(PORT, socket.AF_INET): with UdpServer(PORT, socket.AF_INET):
server_ip = get_my_ip(netifaces.AF_INET) server_ip = get_my_ip(netifaces.AF_INET)
print("Connect udp client to server IP={}".format(server_ip)) print('Connect udp client to server IP={}'.format(server_ip))
dut1.write(server_ip) dut1.write(server_ip)
dut1.expect(re.compile(r"OK: Message from ESP32")) dut1.expect(re.compile(r'OK: Message from ESP32'))
# test IPv6 # test IPv6
with UdpServer(PORT, socket.AF_INET6): with UdpServer(PORT, socket.AF_INET6):
server_ip = get_my_ip(netifaces.AF_INET6) server_ip = get_my_ip(netifaces.AF_INET6)
print("Connect udp client to server IP={}".format(server_ip)) print('Connect udp client to server IP={}'.format(server_ip))
dut1.write(server_ip) dut1.write(server_ip)
dut1.expect(re.compile(r"OK: Message from ESP32")) dut1.expect(re.compile(r'OK: Message from ESP32'))
if __name__ == '__main__': if __name__ == '__main__':
if sys.argv[1:] and sys.argv[1].startswith("IPv"): # if additional arguments provided: if sys.argv[1:] and sys.argv[1].startswith('IPv'): # if additional arguments provided:
# Usage: example_test.py <IPv4|IPv6> # Usage: example_test.py <IPv4|IPv6>
family_addr = socket.AF_INET6 if sys.argv[1] == "IPv6" else socket.AF_INET family_addr = socket.AF_INET6 if sys.argv[1] == 'IPv6' else socket.AF_INET
with UdpServer(PORT, family_addr, persist=True) as s: with UdpServer(PORT, family_addr, persist=True) as s:
print(input("Press Enter stop the server...")) print(input('Press Enter stop the server...'))
else: else:
test_examples_protocol_socket_udpclient() test_examples_protocol_socket_udpclient()

View File

@ -6,14 +6,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
import os import os
import sys
import re import re
import socket import socket
import ttfw_idf import sys
import ttfw_idf
# ----------- Config ---------- # ----------- Config ----------
PORT = 3333 PORT = 3333
@ -44,28 +44,28 @@ def udp_client(address, payload):
return reply.decode() return reply.decode()
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_socket_udpserver(env, extra_data): def test_examples_protocol_socket_udpserver(env, extra_data):
MESSAGE = "Data to ESP" MESSAGE = 'Data to ESP'
""" """
steps: steps:
1. join AP 1. join AP
2. have the board connect to the server 2. have the board connect to the server
3. send and receive data 3. send and receive data
""" """
dut1 = env.get_dut("udp_server", "examples/protocols/sockets/udp_server", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('udp_server', 'examples/protocols/sockets/udp_server', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "udp_server.bin") binary_file = os.path.join(dut1.app.binary_path, 'udp_server.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("udp_server_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('udp_server_bin_size', '{}KB'.format(bin_size // 1024))
# start test # start test
dut1.start_app() dut1.start_app()
ipv4 = dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30)[0] ipv4 = dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)[0]
ipv6_r = r':'.join((r'[0-9a-fA-F]{4}',) * 8) # expect all 8 octets from IPv6 (assumes it's printed in the long form) ipv6_r = r':'.join((r'[0-9a-fA-F]{4}',) * 8) # expect all 8 octets from IPv6 (assumes it's printed in the long form)
ipv6 = dut1.expect(re.compile(r' IPv6 address: ({})'.format(ipv6_r)), timeout=30)[0] ipv6 = dut1.expect(re.compile(r' IPv6 address: ({})'.format(ipv6_r)), timeout=30)[0]
print("Connected with IPv4={} and IPv6={}".format(ipv4, ipv6)) print('Connected with IPv4={} and IPv6={}'.format(ipv4, ipv6))
# test IPv4 # test IPv4
received = udp_client(ipv4, MESSAGE) received = udp_client(ipv4, MESSAGE)
@ -73,7 +73,7 @@ def test_examples_protocol_socket_udpserver(env, extra_data):
raise raise
dut1.expect(MESSAGE) dut1.expect(MESSAGE)
# test IPv6 # test IPv6
received = udp_client("{}%{}".format(ipv6, INTERFACE), MESSAGE) received = udp_client('{}%{}'.format(ipv6, INTERFACE), MESSAGE)
if not received == MESSAGE: if not received == MESSAGE:
raise raise
dut1.expect(MESSAGE) dut1.expect(MESSAGE)

View File

@ -1,14 +1,15 @@
from __future__ import print_function from __future__ import print_function, unicode_literals
from __future__ import unicode_literals
import re
import os import os
import socket
import random import random
import re
import socket
import string import string
from threading import Event, Thread
import ttfw_idf
from SimpleWebSocketServer import SimpleWebSocketServer, WebSocket from SimpleWebSocketServer import SimpleWebSocketServer, WebSocket
from tiny_test_fw import Utility from tiny_test_fw import Utility
from threading import Thread, Event
import ttfw_idf
def get_my_ip(): def get_my_ip():
@ -66,15 +67,15 @@ class Websocket(object):
def test_echo(dut): def test_echo(dut):
dut.expect("WEBSOCKET_EVENT_CONNECTED") dut.expect('WEBSOCKET_EVENT_CONNECTED')
for i in range(0, 10): for i in range(0, 10):
dut.expect(re.compile(r"Received=hello (\d)"), timeout=30) dut.expect(re.compile(r'Received=hello (\d)'), timeout=30)
print("All echos received") print('All echos received')
def test_close(dut): def test_close(dut):
code = dut.expect(re.compile(r"WEBSOCKET: Received closed message with code=(\d*)"), timeout=60)[0] code = dut.expect(re.compile(r'WEBSOCKET: Received closed message with code=(\d*)'), timeout=60)[0]
print("Received close frame with code {}".format(code)) print('Received close frame with code {}'.format(code))
def test_recv_long_msg(dut, websocket, msg_len, repeats): def test_recv_long_msg(dut, websocket, msg_len, repeats):
@ -86,17 +87,17 @@ def test_recv_long_msg(dut, websocket, msg_len, repeats):
recv_msg = '' recv_msg = ''
while len(recv_msg) < msg_len: while len(recv_msg) < msg_len:
# Filter out color encoding # Filter out color encoding
match = dut.expect(re.compile(r"Received=([a-zA-Z0-9]*).*\n"), timeout=30)[0] match = dut.expect(re.compile(r'Received=([a-zA-Z0-9]*).*\n'), timeout=30)[0]
recv_msg += match recv_msg += match
if recv_msg == send_msg: if recv_msg == send_msg:
print("Sent message and received message are equal") print('Sent message and received message are equal')
else: else:
raise ValueError("DUT received string do not match sent string, \nexpected: {}\nwith length {}\ raise ValueError('DUT received string do not match sent string, \nexpected: {}\nwith length {}\
\nreceived: {}\nwith length {}".format(send_msg, len(send_msg), recv_msg, len(recv_msg))) \nreceived: {}\nwith length {}'.format(send_msg, len(send_msg), recv_msg, len(recv_msg)))
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_examples_protocol_websocket(env, extra_data): def test_examples_protocol_websocket(env, extra_data):
""" """
steps: steps:
@ -104,17 +105,17 @@ def test_examples_protocol_websocket(env, extra_data):
2. connect to uri specified in the config 2. connect to uri specified in the config
3. send and receive data 3. send and receive data
""" """
dut1 = env.get_dut("websocket", "examples/protocols/websocket", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('websocket', 'examples/protocols/websocket', dut_class=ttfw_idf.ESP32DUT)
# check and log bin size # check and log bin size
binary_file = os.path.join(dut1.app.binary_path, "websocket-example.bin") binary_file = os.path.join(dut1.app.binary_path, 'websocket-example.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("websocket_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('websocket_bin_size', '{}KB'.format(bin_size // 1024))
try: try:
if "CONFIG_WEBSOCKET_URI_FROM_STDIN" in dut1.app.get_sdkconfig(): if 'CONFIG_WEBSOCKET_URI_FROM_STDIN' in dut1.app.get_sdkconfig():
uri_from_stdin = True uri_from_stdin = True
else: else:
uri = dut1.app.get_sdkconfig()["CONFIG_WEBSOCKET_URI"].strip('"') uri = dut1.app.get_sdkconfig()['CONFIG_WEBSOCKET_URI'].strip('"')
uri_from_stdin = False uri_from_stdin = False
except Exception: except Exception:
@ -127,9 +128,9 @@ def test_examples_protocol_websocket(env, extra_data):
if uri_from_stdin: if uri_from_stdin:
server_port = 4455 server_port = 4455
with Websocket(server_port) as ws: with Websocket(server_port) as ws:
uri = "ws://{}:{}".format(get_my_ip(), server_port) uri = 'ws://{}:{}'.format(get_my_ip(), server_port)
print("DUT connecting to {}".format(uri)) print('DUT connecting to {}'.format(uri))
dut1.expect("Please enter uri of websocket endpoint", timeout=30) dut1.expect('Please enter uri of websocket endpoint', timeout=30)
dut1.write(uri) dut1.write(uri)
test_echo(dut1) test_echo(dut1)
# Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte # Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
@ -137,7 +138,7 @@ def test_examples_protocol_websocket(env, extra_data):
test_close(dut1) test_close(dut1)
else: else:
print("DUT connecting to {}".format(uri)) print('DUT connecting to {}'.format(uri))
test_echo(dut1) test_echo(dut1)

View File

@ -15,73 +15,74 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import re
import os
import ttfw_idf import os
import re
import esp_prov import esp_prov
import ttfw_idf
# Have esp_prov throw exception # Have esp_prov throw exception
esp_prov.config_throw_except = True esp_prov.config_throw_except = True
@ttfw_idf.idf_example_test(env_tag="Example_WIFI_BT") @ttfw_idf.idf_example_test(env_tag='Example_WIFI_BT')
def test_examples_provisioning_ble(env, extra_data): def test_examples_provisioning_ble(env, extra_data):
# Acquire DUT # Acquire DUT
dut1 = env.get_dut("ble_prov", "examples/provisioning/legacy/ble_prov", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('ble_prov', 'examples/provisioning/legacy/ble_prov', dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut1.app.binary_path, "ble_prov.bin") binary_file = os.path.join(dut1.app.binary_path, 'ble_prov.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("ble_prov_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('ble_prov_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
dut1.start_app() dut1.start_app()
# Parse BLE devname # Parse BLE devname
devname = dut1.expect(re.compile(r"Provisioning started with BLE devname : '(PROV_\S\S\S\S\S\S)'"), timeout=60)[0] devname = dut1.expect(re.compile(r"Provisioning started with BLE devname : '(PROV_\S\S\S\S\S\S)'"), timeout=60)[0]
print("BLE Device Alias for DUT :", devname) print('BLE Device Alias for DUT :', devname)
# Match additional headers sent in the request # Match additional headers sent in the request
dut1.expect("BLE Provisioning started", timeout=30) dut1.expect('BLE Provisioning started', timeout=30)
print("Starting Provisioning") print('Starting Provisioning')
verbose = False verbose = False
protover = "V0.1" protover = 'V0.1'
secver = 1 secver = 1
pop = "abcd1234" pop = 'abcd1234'
provmode = "ble" provmode = 'ble'
ap_ssid = "myssid" ap_ssid = 'myssid'
ap_password = "mypassword" ap_password = 'mypassword'
print("Getting security") print('Getting security')
security = esp_prov.get_security(secver, pop, verbose) security = esp_prov.get_security(secver, pop, verbose)
if security is None: if security is None:
raise RuntimeError("Failed to get security") raise RuntimeError('Failed to get security')
print("Getting transport") print('Getting transport')
transport = esp_prov.get_transport(provmode, devname) transport = esp_prov.get_transport(provmode, devname)
if transport is None: if transport is None:
raise RuntimeError("Failed to get transport") raise RuntimeError('Failed to get transport')
print("Verifying protocol version") print('Verifying protocol version')
if not esp_prov.version_match(transport, protover): if not esp_prov.version_match(transport, protover):
raise RuntimeError("Mismatch in protocol version") raise RuntimeError('Mismatch in protocol version')
print("Starting Session") print('Starting Session')
if not esp_prov.establish_session(transport, security): if not esp_prov.establish_session(transport, security):
raise RuntimeError("Failed to start session") raise RuntimeError('Failed to start session')
print("Sending Wifi credential to DUT") print('Sending Wifi credential to DUT')
if not esp_prov.send_wifi_config(transport, security, ap_ssid, ap_password): if not esp_prov.send_wifi_config(transport, security, ap_ssid, ap_password):
raise RuntimeError("Failed to send Wi-Fi config") raise RuntimeError('Failed to send Wi-Fi config')
print("Applying config") print('Applying config')
if not esp_prov.apply_wifi_config(transport, security): if not esp_prov.apply_wifi_config(transport, security):
raise RuntimeError("Failed to send apply config") raise RuntimeError('Failed to send apply config')
if not esp_prov.wait_wifi_connected(transport, security): if not esp_prov.wait_wifi_connected(transport, security):
raise RuntimeError("Provisioning failed") raise RuntimeError('Provisioning failed')
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -2,12 +2,14 @@
# source: custom_config.proto # source: custom_config.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import enum_type_wrapper
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -61,7 +63,7 @@ _CUSTOMCONFIGREQUEST = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='info', full_name='CustomConfigRequest.info', index=0, name='info', full_name='CustomConfigRequest.info', index=0,
number=1, type=9, cpp_type=9, label=1, number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'), has_default_value=False, default_value=_b('').decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR), serialized_options=None, file=DESCRIPTOR),

View File

@ -15,88 +15,89 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import re
import os
import ttfw_idf import os
import re
import esp_prov import esp_prov
import ttfw_idf
import wifi_tools import wifi_tools
# Have esp_prov throw exception # Have esp_prov throw exception
esp_prov.config_throw_except = True esp_prov.config_throw_except = True
@ttfw_idf.idf_example_test(env_tag="Example_WIFI_BT") @ttfw_idf.idf_example_test(env_tag='Example_WIFI_BT')
def test_examples_provisioning_softap(env, extra_data): def test_examples_provisioning_softap(env, extra_data):
# Acquire DUT # Acquire DUT
dut1 = env.get_dut("softap_prov", "examples/provisioning/legacy/softap_prov", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('softap_prov', 'examples/provisioning/legacy/softap_prov', dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut1.app.binary_path, "softap_prov.bin") binary_file = os.path.join(dut1.app.binary_path, 'softap_prov.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("softap_prov_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('softap_prov_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
dut1.start_app() dut1.start_app()
# Parse IP address of STA # Parse IP address of STA
dut1.expect("Starting WiFi SoftAP provisioning", timeout=60) dut1.expect('Starting WiFi SoftAP provisioning', timeout=60)
[ssid, password] = dut1.expect(re.compile(r"SoftAP Provisioning started with SSID '(\S+)', Password '(\S+)'"), timeout=30) [ssid, password] = dut1.expect(re.compile(r"SoftAP Provisioning started with SSID '(\S+)', Password '(\S+)'"), timeout=30)
iface = wifi_tools.get_wiface_name() iface = wifi_tools.get_wiface_name()
if iface is None: if iface is None:
raise RuntimeError("Failed to get Wi-Fi interface on host") raise RuntimeError('Failed to get Wi-Fi interface on host')
print("Interface name : " + iface) print('Interface name : ' + iface)
print("SoftAP SSID : " + ssid) print('SoftAP SSID : ' + ssid)
print("SoftAP Password : " + password) print('SoftAP Password : ' + password)
try: try:
ctrl = wifi_tools.wpa_cli(iface, reset_on_exit=True) ctrl = wifi_tools.wpa_cli(iface, reset_on_exit=True)
print("Connecting to DUT SoftAP...") print('Connecting to DUT SoftAP...')
ip = ctrl.connect(ssid, password) ip = ctrl.connect(ssid, password)
got_ip = dut1.expect(re.compile(r"DHCP server assigned IP to a station, IP is: (\d+.\d+.\d+.\d+)"), timeout=60)[0] got_ip = dut1.expect(re.compile(r'DHCP server assigned IP to a station, IP is: (\d+.\d+.\d+.\d+)'), timeout=60)[0]
if ip != got_ip: if ip != got_ip:
raise RuntimeError("SoftAP connected to another host! " + ip + "!=" + got_ip) raise RuntimeError('SoftAP connected to another host! ' + ip + '!=' + got_ip)
print("Connected to DUT SoftAP") print('Connected to DUT SoftAP')
print("Starting Provisioning") print('Starting Provisioning')
verbose = False verbose = False
protover = "V0.1" protover = 'V0.1'
secver = 1 secver = 1
pop = "abcd1234" pop = 'abcd1234'
provmode = "softap" provmode = 'softap'
ap_ssid = "myssid" ap_ssid = 'myssid'
ap_password = "mypassword" ap_password = 'mypassword'
softap_endpoint = ip.split('.')[0] + "." + ip.split('.')[1] + "." + ip.split('.')[2] + ".1:80" softap_endpoint = ip.split('.')[0] + '.' + ip.split('.')[1] + '.' + ip.split('.')[2] + '.1:80'
print("Getting security") print('Getting security')
security = esp_prov.get_security(secver, pop, verbose) security = esp_prov.get_security(secver, pop, verbose)
if security is None: if security is None:
raise RuntimeError("Failed to get security") raise RuntimeError('Failed to get security')
print("Getting transport") print('Getting transport')
transport = esp_prov.get_transport(provmode, softap_endpoint) transport = esp_prov.get_transport(provmode, softap_endpoint)
if transport is None: if transport is None:
raise RuntimeError("Failed to get transport") raise RuntimeError('Failed to get transport')
print("Verifying protocol version") print('Verifying protocol version')
if not esp_prov.version_match(transport, protover): if not esp_prov.version_match(transport, protover):
raise RuntimeError("Mismatch in protocol version") raise RuntimeError('Mismatch in protocol version')
print("Starting Session") print('Starting Session')
if not esp_prov.establish_session(transport, security): if not esp_prov.establish_session(transport, security):
raise RuntimeError("Failed to start session") raise RuntimeError('Failed to start session')
print("Sending Wifi credential to DUT") print('Sending Wifi credential to DUT')
if not esp_prov.send_wifi_config(transport, security, ap_ssid, ap_password): if not esp_prov.send_wifi_config(transport, security, ap_ssid, ap_password):
raise RuntimeError("Failed to send Wi-Fi config") raise RuntimeError('Failed to send Wi-Fi config')
print("Applying config") print('Applying config')
if not esp_prov.apply_wifi_config(transport, security): if not esp_prov.apply_wifi_config(transport, security):
raise RuntimeError("Failed to send apply config") raise RuntimeError('Failed to send apply config')
if not esp_prov.wait_wifi_connected(transport, security): if not esp_prov.wait_wifi_connected(transport, security):
raise RuntimeError("Provisioning failed") raise RuntimeError('Provisioning failed')
finally: finally:
ctrl.reset() ctrl.reset()

View File

@ -15,84 +15,85 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import re
import os
import ttfw_idf import os
import re
import esp_prov import esp_prov
import ttfw_idf
# Have esp_prov throw exception # Have esp_prov throw exception
esp_prov.config_throw_except = True esp_prov.config_throw_except = True
@ttfw_idf.idf_example_test(env_tag="Example_WIFI_BT") @ttfw_idf.idf_example_test(env_tag='Example_WIFI_BT')
def test_examples_wifi_prov_mgr(env, extra_data): def test_examples_wifi_prov_mgr(env, extra_data):
# Acquire DUT # Acquire DUT
dut1 = env.get_dut("wifi_prov_mgr", "examples/provisioning/wifi_prov_mgr", dut_class=ttfw_idf.ESP32DUT) dut1 = env.get_dut('wifi_prov_mgr', 'examples/provisioning/wifi_prov_mgr', dut_class=ttfw_idf.ESP32DUT)
# Get binary file # Get binary file
binary_file = os.path.join(dut1.app.binary_path, "wifi_prov_mgr.bin") binary_file = os.path.join(dut1.app.binary_path, 'wifi_prov_mgr.bin')
bin_size = os.path.getsize(binary_file) bin_size = os.path.getsize(binary_file)
ttfw_idf.log_performance("wifi_prov_mgr_bin_size", "{}KB".format(bin_size // 1024)) ttfw_idf.log_performance('wifi_prov_mgr_bin_size', '{}KB'.format(bin_size // 1024))
# Upload binary and start testing # Upload binary and start testing
dut1.start_app() dut1.start_app()
# Check if BT memory is released before provisioning starts # Check if BT memory is released before provisioning starts
dut1.expect("wifi_prov_scheme_ble: BT memory released", timeout=60) dut1.expect('wifi_prov_scheme_ble: BT memory released', timeout=60)
# Parse BLE devname # Parse BLE devname
devname = dut1.expect(re.compile(r"Provisioning started with service name : (PROV_\S\S\S\S\S\S)"), timeout=30)[0] devname = dut1.expect(re.compile(r'Provisioning started with service name : (PROV_\S\S\S\S\S\S)'), timeout=30)[0]
print("BLE Device Alias for DUT :", devname) print('BLE Device Alias for DUT :', devname)
print("Starting Provisioning") print('Starting Provisioning')
verbose = False verbose = False
protover = "v1.1" protover = 'v1.1'
secver = 1 secver = 1
pop = "abcd1234" pop = 'abcd1234'
provmode = "ble" provmode = 'ble'
ap_ssid = "myssid" ap_ssid = 'myssid'
ap_password = "mypassword" ap_password = 'mypassword'
print("Getting security") print('Getting security')
security = esp_prov.get_security(secver, pop, verbose) security = esp_prov.get_security(secver, pop, verbose)
if security is None: if security is None:
raise RuntimeError("Failed to get security") raise RuntimeError('Failed to get security')
print("Getting transport") print('Getting transport')
transport = esp_prov.get_transport(provmode, devname) transport = esp_prov.get_transport(provmode, devname)
if transport is None: if transport is None:
raise RuntimeError("Failed to get transport") raise RuntimeError('Failed to get transport')
print("Verifying protocol version") print('Verifying protocol version')
if not esp_prov.version_match(transport, protover): if not esp_prov.version_match(transport, protover):
raise RuntimeError("Mismatch in protocol version") raise RuntimeError('Mismatch in protocol version')
print("Verifying scan list capability") print('Verifying scan list capability')
if not esp_prov.has_capability(transport, 'wifi_scan'): if not esp_prov.has_capability(transport, 'wifi_scan'):
raise RuntimeError("Capability not present") raise RuntimeError('Capability not present')
print("Starting Session") print('Starting Session')
if not esp_prov.establish_session(transport, security): if not esp_prov.establish_session(transport, security):
raise RuntimeError("Failed to start session") raise RuntimeError('Failed to start session')
print("Sending Custom Data") print('Sending Custom Data')
if not esp_prov.custom_data(transport, security, "My Custom Data"): if not esp_prov.custom_data(transport, security, 'My Custom Data'):
raise RuntimeError("Failed to send custom data") raise RuntimeError('Failed to send custom data')
print("Sending Wifi credential to DUT") print('Sending Wifi credential to DUT')
if not esp_prov.send_wifi_config(transport, security, ap_ssid, ap_password): if not esp_prov.send_wifi_config(transport, security, ap_ssid, ap_password):
raise RuntimeError("Failed to send Wi-Fi config") raise RuntimeError('Failed to send Wi-Fi config')
print("Applying config") print('Applying config')
if not esp_prov.apply_wifi_config(transport, security): if not esp_prov.apply_wifi_config(transport, security):
raise RuntimeError("Failed to send apply config") raise RuntimeError('Failed to send apply config')
if not esp_prov.wait_wifi_connected(transport, security): if not esp_prov.wait_wifi_connected(transport, security):
raise RuntimeError("Provisioning failed") raise RuntimeError('Provisioning failed')
# Check if BTDM memory is released after provisioning finishes # Check if BTDM memory is released after provisioning finishes
dut1.expect("wifi_prov_scheme_ble: BTDM memory released", timeout=30) dut1.expect('wifi_prov_scheme_ble: BTDM memory released', timeout=30)
if __name__ == '__main__': if __name__ == '__main__':

Some files were not shown because too many files have changed in this diff Show More