diff --git a/examples/protocols/http_server/ws_echo_server/ws_server_example_test.py b/examples/protocols/http_server/ws_echo_server/ws_server_example_test.py index 3dd77b4697..2d53b698f0 100644 --- a/examples/protocols/http_server/ws_echo_server/ws_server_example_test.py +++ b/examples/protocols/http_server/ws_echo_server/ws_server_example_test.py @@ -21,11 +21,7 @@ import re from tiny_test_fw import Utility import ttfw_idf import os -import six -import socket -import hashlib -import base64 -import struct +import websocket OPCODE_TEXT = 0x1 @@ -38,63 +34,24 @@ class WsClient: def __init__(self, ip, port): self.port = port self.ip = ip - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.client_key = "abcdefghjk" - self.socket.settimeout(10.0) + self.ws = websocket.WebSocket() def __enter__(self): - self.socket.connect((self.ip, self.port)) - self._handshake() + self.ws.connect("ws://{}:{}/ws".format(self.ip, self.port)) return self def __exit__(self, exc_type, exc_value, traceback): - self.socket.close() - - def _handshake(self): - MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - client_key = self.client_key + MAGIC_STRING - expected_accept = base64.standard_b64encode(hashlib.sha1(client_key.encode()).digest()) - request = ('GET /ws HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: ' - 'Upgrade\r\nSec-WebSocket-Key: {}\r\n' - 'Sec-WebSocket-Version: 13\r\n\r\n'.format(self.client_key)) - self.socket.send(request.encode('utf-8')) - response = self.socket.recv(1024) - ws_accept = re.search(b'Sec-WebSocket-Accept: (.*)\r\n', response, re.IGNORECASE) - if ws_accept and ws_accept.group(1) is not None and ws_accept.group(1) == expected_accept: - pass - else: - raise("Unexpected Sec-WebSocket-Accept, handshake response: {}".format(response)) - - def _masked(self, data): - mask = struct.unpack('B' * 4, os.urandom(4)) - out = list(mask) - for i, d in enumerate(struct.unpack('B' * len(data), data)): - out.append(d ^ mask[i % 4]) - return struct.pack('B' * len(out), *out) - - def _ws_encode(self, data="", opcode=OPCODE_TEXT, mask=1): - data = data.encode('utf-8') - length = len(data) - if length >= 126: - raise("Packet length of {} not supported!".format(length)) - frame_header = chr(1 << 7 | opcode) - frame_header += chr(mask << 7 | length) - frame_header = six.b(frame_header) - if not mask: - return frame_header + data - return frame_header + self._masked(data) + self.ws.close() def read(self): - header = self.socket.recv(2) - if not six.PY3: - header = [ord(character) for character in header] - opcode = header[0] & 15 - length = header[1] & 127 - payload = self.socket.recv(length) - return opcode, payload.decode('utf-8') + return self.ws.recv_data(control_frame=True) - def write(self, data="", opcode=OPCODE_TEXT, mask=1): - return self.socket.sendall(self._ws_encode(data=data, opcode=opcode, mask=mask)) + def write(self, data="", opcode=OPCODE_TEXT): + if opcode == OPCODE_BIN: + return self.ws.send_binary(data.encode()) + if opcode == OPCODE_PING: + return self.ws.ping(data) + return self.ws.send(data) @ttfw_idf.idf_example_test(env_tag="Example_WIFI")