Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
adding backpressure test
  • Loading branch information
graebm committed Dec 17, 2022
commit a07f0ac8766b05778ca0ae7f2c27b3808c292b95
57 changes: 52 additions & 5 deletions test/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from io import StringIO
import logging
from os import urandom
from queue import Queue
from queue import Empty, Queue
import secrets
import socket
from test import NativeResourceTest
import threading
from time import sleep
from time import sleep, time
from typing import Optional

# using a 3rdparty websocket library for the server
Expand All @@ -26,7 +27,7 @@
# logging.basicConfig(format="%(message)s", level=logging.DEBUG)

# uncomment this for logging from our websockets client
# init_logging(LogLevel.Trace, 'stderr')
init_logging(LogLevel.Trace, 'stderr')


@dataclass
Expand All @@ -46,15 +47,16 @@ def __init__(self):
self.incoming_frame_payload = bytearray()
self.exception = None

def connect_sync(self, host, port):
def connect_sync(self, host, port, **connect_kwargs):
connect(host=host,
port=port,
handshake_request=create_handshake_request(host=host),
on_connection_setup=self._on_connection_setup,
on_connection_shutdown=self._on_connection_shutdown,
on_incoming_frame_begin=self._on_incoming_frame_begin,
on_incoming_frame_payload=self._on_incoming_frame_payload,
on_incoming_frame_complete=self._on_incoming_frame_complete)
on_incoming_frame_complete=self._on_incoming_frame_complete,
**connect_kwargs)
# wait for on_connection_setup to fire
setup_data = self.setup_future.result(TIMEOUT)
assert setup_data.exception is None
Expand Down Expand Up @@ -128,6 +130,8 @@ def __enter__(self):
# don't return until the server signals that it's started up and is listening for connections
assert self._server_started_event.wait(TIMEOUT)

return self

def __exit__(self, exc_type, exc_value, exc_tb):
# main thread is exiting the `with` block: tell the server to stop...

Expand Down Expand Up @@ -158,6 +162,7 @@ async def _run_asyncio_server(self):
async def _run_connection(self, server_connection: websockets_server_3rdparty.WebSocketServerProtocol):
# this coroutine runs once for each connection to the server
# when this coroutine exits, the connection gets shut down
self._current_connection = server_connection
try:
# await each message...
async for msg in server_connection:
Expand All @@ -170,6 +175,12 @@ async def _run_connection(self, server_connection: websockets_server_3rdparty.We
# even if the connection ends cleanly, so just swallow it
pass

finally:
self._current_connection = None

def send_async(self, msg):
asyncio.run_coroutine_threadsafe(self._current_connection.send(msg), self._server_loop)


class TestClient(NativeResourceTest):
def setUp(self):
Expand Down Expand Up @@ -504,3 +515,39 @@ def bad_incoming_frame_callback(data):
# wait for the frame to echo back, firing the bad callback,
# which raises an exception, which should result in the WebSocket closing
shutdown_future.result(TIMEOUT)

def test_backpressure_enabled(self):
# test that we can use read backpressure to control how much data is read
with WebSocketServer(self.host, self.port) as server:
handler = ClientHandler()
handler.connect_sync(self.host, self.port, enable_read_backpressure=True, initial_read_window=0)
# handler.connect_sync(self.host, self.port, enable_read_backpressure=True, initial_read_window=1000)

# # window is 1000-bytes
# # send 10 100-byte messages
# # they should all get through
# for i in range(10):
# msg = secrets.token_bytes(100)
# server.send_async(msg)
# recv: RecvFrame = handler.complete_frames.get(timeout=TIMEOUT)
# self.assertEqual(recv.payload, msg, "did not receive expected payload")

# now window is 0
# send a 1000 byte message, NONE of its payload should arrive
msg = secrets.token_bytes(1000)
server.send_async(msg)
with self.assertRaises(Empty):
handler.complete_frames.get(timeout=1.0)
self.assertEqual(len(handler.incoming_frame_payload), 0, "No payload should arrive while window is 0")

# now increment the window and let half the (500/1000) bytes in
handler.websocket.increment_read_window(500)
max_wait_until = time() + TIMEOUT
while len(handler.incoming_frame_payload) < 500:
sleep(0.001)
self.assertLess(time(), max_wait_until, "timed out waiting for all bytes")
sleep(1.0) # sleep a moment to be sure we don't receive MORE than 500 bytes
self.assertEqual(len(handler.incoming_frame_payload), 500, "received more bytes than expected")


handler.close_sync()