Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
preen
  • Loading branch information
bdraco committed Oct 27, 2025
commit 0c5584fc764ae62510b18e93c33c73b799e22385
180 changes: 91 additions & 89 deletions tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,115 +149,117 @@
protocol: BaseProtocol,
transport: asyncio.Transport,
slow_executor: ThreadPoolExecutor,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test that cancelled compression doesn't corrupt subsequent sends.

Regression test for https://github.com/aio-libs/aiohttp/issues/11725
"""
with mock.patch("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024):
writer = WebSocketWriter(protocol, transport, compress=15)
loop = asyncio.get_running_loop()
queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop)
reader = WebSocketReader(queue, 50000)

# Replace executor with slow one to make race condition reproducible
writer._compressobj = writer._make_compress_obj(writer.compress)
writer._compressobj._executor = slow_executor

# Create large data that will trigger executor-based compression
large_data_1 = b"A" * 10000
large_data_2 = b"B" * 10000

# Start first send and cancel it during compression
async def send_and_cancel():
await writer.send_frame(large_data_1, WSMsgType.BINARY)

task = asyncio.create_task(send_and_cancel())
# Give it a moment to start compression
await asyncio.sleep(0.01)
task.cancel()

# Catch the cancellation
try:
await task
except asyncio.CancelledError:
pass

# Send second message - this should NOT be corrupted
await writer.send_frame(large_data_2, WSMsgType.BINARY)

# Verify the second send produced correct data
last_call = writer.transport.write.call_args_list[-1] # type: ignore[attr-defined]
call_bytes = last_call[0][0]
result, _ = reader.feed_data(call_bytes)
assert result is False
msg = await queue.read()
assert msg.type is WSMsgType.BINARY
# The data should be all B's, not mixed with A's from the cancelled send
assert msg.data == large_data_2
monkeypatch.setattr("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024)
writer = WebSocketWriter(protocol, transport, compress=15)
loop = asyncio.get_running_loop()
queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop)
reader = WebSocketReader(queue, 50000)

# Replace executor with slow one to make race condition reproducible
writer._compressobj = writer._make_compress_obj(writer.compress)
writer._compressobj._executor = slow_executor

# Create large data that will trigger executor-based compression
large_data_1 = b"A" * 10000
large_data_2 = b"B" * 10000

# Start first send and cancel it during compression
async def send_and_cancel():
await writer.send_frame(large_data_1, WSMsgType.BINARY)

task = asyncio.create_task(send_and_cancel())
# Give it a moment to start compression
await asyncio.sleep(0.01)
task.cancel()

# Catch the cancellation
try:
await task
except asyncio.CancelledError:
pass

# Send second message - this should NOT be corrupted
await writer.send_frame(large_data_2, WSMsgType.BINARY)

# Verify the second send produced correct data
last_call = writer.transport.write.call_args_list[-1] # type: ignore[attr-defined]
call_bytes = last_call[0][0]
result, _ = reader.feed_data(call_bytes)
assert result is False
msg = await queue.read()
assert msg.type is WSMsgType.BINARY
# The data should be all B's, not mixed with A's from the cancelled send
assert msg.data == large_data_2


@pytest.mark.usefixtures("parametrize_zlib_backend")
async def test_send_compress_multiple_cancelled(
protocol: BaseProtocol,
transport: asyncio.Transport,
slow_executor: ThreadPoolExecutor,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test that multiple compressed sends all complete despite cancellation.

Regression test for https://github.com/aio-libs/aiohttp/issues/11725
This verifies that once a send operation enters the shield, it completes
even if cancelled. The lock serializes sends, so they process one at a time.
"""
with mock.patch("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024):
writer = WebSocketWriter(protocol, transport, compress=15)
loop = asyncio.get_running_loop()
queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop)
reader = WebSocketReader(queue, 50000)
monkeypatch.setattr("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024)
writer = WebSocketWriter(protocol, transport, compress=15)
loop = asyncio.get_running_loop()
queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop)
reader = WebSocketReader(queue, 50000)

# Replace executor with slow one
writer._compressobj = writer._make_compress_obj(writer.compress)
writer._compressobj._executor = slow_executor

# Create 5 large messages with different content
messages = [bytes([ord("A") + i]) * 10000 for i in range(5)]

# Start sending all 5 messages - they'll queue due to the lock
tasks = [
asyncio.create_task(writer.send_frame(msg, WSMsgType.BINARY))
for msg in messages
]

# Cancel all tasks during execution
# Tasks in the shield will complete, tasks waiting for lock will cancel
await asyncio.sleep(0.03) # Let one or two enter the shield
for task in tasks:
task.cancel()

# Collect results
cancelled_count = 0
for task in tasks:
try:
await task
except asyncio.CancelledError:
cancelled_count += 1

# Replace executor with slow one
writer._compressobj = writer._make_compress_obj(writer.compress)
writer._compressobj._executor = slow_executor

# Create 5 large messages with different content
messages = [bytes([ord("A") + i]) * 10000 for i in range(5)]

# Start sending all 5 messages - they'll queue due to the lock
tasks = [
asyncio.create_task(writer.send_frame(msg, WSMsgType.BINARY))
for msg in messages
]

# Cancel all tasks during execution
# Tasks in the shield will complete, tasks waiting for lock will cancel
await asyncio.sleep(0.03) # Let one or two enter the shield
for task in tasks:
task.cancel()

# Collect results
cancelled_count = 0
for task in tasks:
try:
await task
except asyncio.CancelledError:
cancelled_count += 1

# At least one message should have been sent (the one in the shield)
sent_count = len(writer.transport.write.call_args_list) # type: ignore[attr-defined]
assert sent_count >= 1, "At least one send should complete due to shield"
assert sent_count <= 5, "Can't send more than 5 messages"

# Verify all sent messages are correct (no corruption)
for i in range(sent_count):
call = writer.transport.write.call_args_list[i] # type: ignore[attr-defined]
call_bytes = call[0][0]
result, _ = reader.feed_data(call_bytes)
assert result is False
msg = await queue.read()
assert msg.type is WSMsgType.BINARY
# Verify the data matches the expected message
expected_byte = bytes([ord("A") + i])
assert msg.data == expected_byte * 10000, f"Message {i} corrupted"
# At least one message should have been sent (the one in the shield)
sent_count = len(writer.transport.write.call_args_list) # type: ignore[attr-defined]
assert sent_count >= 1, "At least one send should complete due to shield"
assert sent_count <= 5, "Can't send more than 5 messages"

# Verify all sent messages are correct (no corruption)
for i in range(sent_count):
call = writer.transport.write.call_args_list[i] # type: ignore[attr-defined]
call_bytes = call[0][0]
result, _ = reader.feed_data(call_bytes)
assert result is False
msg = await queue.read()
assert msg.type is WSMsgType.BINARY
# Verify the data matches the expected message
expected_byte = bytes([ord("A") + i])
assert msg.data == expected_byte * 10000, f"Message {i} corrupted"


@pytest.mark.parametrize(
Expand Down
Loading