diff --git a/CHANGES/11725.bugfix.rst b/CHANGES/11725.bugfix.rst new file mode 100644 index 00000000000..e78fc054230 --- /dev/null +++ b/CHANGES/11725.bugfix.rst @@ -0,0 +1 @@ +Fixed WebSocket compressed sends to be cancellation safe. Tasks are now shielded during compression to prevent compressor state corruption. This ensures that the stateful compressor remains consistent even when send operations are cancelled -- by :user:`bdraco`. diff --git a/aiohttp/_websocket/writer.py b/aiohttp/_websocket/writer.py index fdbcda45c3c..1b27dff9371 100644 --- a/aiohttp/_websocket/writer.py +++ b/aiohttp/_websocket/writer.py @@ -2,8 +2,9 @@ import asyncio import random +import sys from functools import partial -from typing import Any, Final +from typing import Final from ..base_protocol import BaseProtocol from ..client_exceptions import ClientConnectionResetError @@ -22,14 +23,18 @@ DEFAULT_LIMIT: Final[int] = 2**16 +# WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames +# Control frames (ping, pong, close) are never compressed +WS_CONTROL_FRAME_OPCODE: Final[int] = 8 + # For websockets, keeping latency low is extremely important as implementations -# generally expect to be able to send and receive messages quickly. We use a -# larger chunk size than the default to reduce the number of executor calls -# since the executor is a significant source of latency and overhead when -# the chunks are small. A size of 5KiB was chosen because it is also the -# same value python-zlib-ng choose to use as the threshold to release the GIL. +# generally expect to be able to send and receive messages quickly. We use a +# larger chunk size to reduce the number of executor calls and avoid task +# creation overhead, since both are significant sources of latency when chunks +# are small. A size of 16KiB was chosen as a balance between avoiding task +# overhead and not blocking the event loop too long with synchronous compression. -WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024 +WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024 class WebSocketWriter: @@ -62,7 +67,9 @@ def __init__( self._closing = False self._limit = limit self._output_size = 0 - self._compressobj: Any = None # actually compressobj + self._compressobj: ZLibCompressor | None = None + self._send_lock = asyncio.Lock() + self._background_tasks: set[asyncio.Task[None]] = set() async def send_frame( self, message: bytes, opcode: int, compress: int | None = None @@ -71,39 +78,57 @@ async def send_frame( if self._closing and not (opcode & WSMsgType.CLOSE): raise ClientConnectionResetError("Cannot write to closing transport") - # RSV are the reserved bits in the frame header. They are used to - # indicate that the frame is using an extension. - # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 - rsv = 0 - # Only compress larger packets (disabled) - # Does small packet needs to be compressed? - # if self.compress and opcode < 8 and len(message) > 124: - if (compress or self.compress) and opcode < 8: - # RSV1 (rsv = 0x40) is set for compressed frames - # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 - rsv = 0x40 - - if compress: - # Do not set self._compress if compressing is for this frame - compressobj = self._make_compress_obj(compress) - else: # self.compress - if not self._compressobj: - self._compressobj = self._make_compress_obj(self.compress) - compressobj = self._compressobj - - message = ( - await compressobj.compress(message) - + compressobj.flush( - ZLibBackend.Z_FULL_FLUSH - if self.notakeover - else ZLibBackend.Z_SYNC_FLUSH - ) - ).removesuffix(WS_DEFLATE_TRAILING) - # Its critical that we do not return control to the event - # loop until we have finished sending all the compressed - # data. Otherwise we could end up mixing compressed frames - # if there are multiple coroutines compressing data. + if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE: + # Non-compressed frames don't need lock or shield + self._write_websocket_frame(message, opcode, 0) + elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE: + # Small compressed payloads - compress synchronously in event loop + # We need the lock even though sync compression has no await points. + # This prevents small frames from interleaving with large frames that + # compress in the executor, avoiding compressor state corruption. + async with self._send_lock: + self._send_compressed_frame_sync(message, opcode, compress) + else: + # Large compressed frames need shield to prevent corruption + # For large compressed frames, the entire compress+send + # operation must be atomic. If cancelled after compression but + # before send, the compressor state would be advanced but data + # not sent, corrupting subsequent frames. + # Create a task to shield from cancellation + # The lock is acquired inside the shielded task so the entire + # operation (lock + compress + send) completes atomically. + # Use eager_start on Python 3.12+ to avoid scheduling overhead + loop = asyncio.get_running_loop() + coro = self._send_compressed_frame_async_locked(message, opcode, compress) + if sys.version_info >= (3, 12): + send_task = asyncio.Task(coro, loop=loop, eager_start=True) + else: + send_task = loop.create_task(coro) + # Keep a strong reference to prevent garbage collection + self._background_tasks.add(send_task) + send_task.add_done_callback(self._background_tasks.discard) + await asyncio.shield(send_task) + + # It is safe to return control to the event loop when using compression + # after this point as we have already sent or buffered all the data. + # Once we have written output_size up to the limit, we call the + # drain helper which waits for the transport to be ready to accept + # more data. This is a flow control mechanism to prevent the buffer + # from growing too large. The drain helper will return right away + # if the writer is not paused. + if self._output_size > self._limit: + self._output_size = 0 + if self.protocol._paused: + await self.protocol._drain_helper() + def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None: + """ + Write a websocket frame to the transport. + + This method handles frame header construction, masking, and writing to transport. + It does not handle compression or flow control - those are the responsibility + of the caller. + """ msg_length = len(message) use_mask = self.use_mask @@ -146,26 +171,85 @@ async def send_frame( self._output_size += header_len + msg_length - # It is safe to return control to the event loop when using compression - # after this point as we have already sent or buffered all the data. + def _get_compressor(self, compress: int | None) -> ZLibCompressor: + """Get or create a compressor object for the given compression level.""" + if compress: + # Do not set self._compress if compressing is for this frame + return ZLibCompressor( + level=ZLibBackend.Z_BEST_SPEED, + wbits=-compress, + max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, + ) + if not self._compressobj: + self._compressobj = ZLibCompressor( + level=ZLibBackend.Z_BEST_SPEED, + wbits=-self.compress, + max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, + ) + return self._compressobj - # Once we have written output_size up to the limit, we call the - # drain helper which waits for the transport to be ready to accept - # more data. This is a flow control mechanism to prevent the buffer - # from growing too large. The drain helper will return right away - # if the writer is not paused. - if self._output_size > self._limit: - self._output_size = 0 - if self.protocol._paused: - await self.protocol._drain_helper() + def _send_compressed_frame_sync( + self, message: bytes, opcode: int, compress: int | None + ) -> None: + """ + Synchronous send for small compressed frames. - def _make_compress_obj(self, compress: int) -> ZLibCompressor: - return ZLibCompressor( - level=ZLibBackend.Z_BEST_SPEED, - wbits=-compress, - max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, + This is used for small compressed payloads that compress synchronously in the event loop. + Since there are no await points, this is inherently cancellation-safe. + """ + # RSV are the reserved bits in the frame header. They are used to + # indicate that the frame is using an extension. + # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + compressobj = self._get_compressor(compress) + # (0x40) RSV1 is set for compressed frames + # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 + self._write_websocket_frame( + ( + compressobj.compress_sync(message) + + compressobj.flush( + ZLibBackend.Z_FULL_FLUSH + if self.notakeover + else ZLibBackend.Z_SYNC_FLUSH + ) + ).removesuffix(WS_DEFLATE_TRAILING), + opcode, + 0x40, ) + async def _send_compressed_frame_async_locked( + self, message: bytes, opcode: int, compress: int | None + ) -> None: + """ + Async send for large compressed frames with lock. + + Acquires the lock and compresses large payloads asynchronously in + the executor. The lock is held for the entire operation to ensure + the compressor state is not corrupted by concurrent sends. + + MUST be run shielded from cancellation. If cancelled after + compression but before sending, the compressor state would be + advanced but data not sent, corrupting subsequent frames. + """ + async with self._send_lock: + # RSV are the reserved bits in the frame header. They are used to + # indicate that the frame is using an extension. + # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + compressobj = self._get_compressor(compress) + # (0x40) RSV1 is set for compressed frames + # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 + self._write_websocket_frame( + ( + await compressobj.compress(message) + + compressobj.flush( + ZLibBackend.Z_FULL_FLUSH + if self.notakeover + else ZLibBackend.Z_SYNC_FLUSH + ) + ).removesuffix(WS_DEFLATE_TRAILING), + opcode, + 0x40, + ) + async def close(self, code: int = 1000, message: bytes | str = b"") -> None: """Close the websocket, sending the specified code and message.""" if isinstance(message, str): diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index c6c6f0d71fc..d9c74fa5400 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -185,7 +185,6 @@ def __init__( if level is not None: kwargs["level"] = level self._compressor = self._zlib_backend.compressobj(**kwargs) - self._compress_lock = asyncio.Lock() def compress_sync(self, data: Buffer) -> bytes: return self._compressor.compress(data) @@ -198,22 +197,37 @@ async def compress(self, data: Buffer) -> bytes: If the data size is large than the max_sync_chunk_size, the compression will be done in the executor. Otherwise, the compression will be done in the event loop. + + **WARNING: This method is NOT cancellation-safe when used with flush().** + If this operation is cancelled, the compressor state may be corrupted. + The connection MUST be closed after cancellation to avoid data corruption + in subsequent compress operations. + + For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap + compress() + flush() + send operations in a shield and lock to ensure atomicity. """ - async with self._compress_lock: - # To ensure the stream is consistent in the event - # there are multiple writers, we need to lock - # the compressor so that only one writer can - # compress at a time. - if ( - self._max_sync_chunk_size is not None - and len(data) > self._max_sync_chunk_size - ): - return await asyncio.get_running_loop().run_in_executor( - self._executor, self._compressor.compress, data - ) - return self.compress_sync(data) + # For large payloads, offload compression to executor to avoid blocking event loop + should_use_executor = ( + self._max_sync_chunk_size is not None + and len(data) > self._max_sync_chunk_size + ) + if should_use_executor: + return await asyncio.get_running_loop().run_in_executor( + self._executor, self._compressor.compress, data + ) + return self.compress_sync(data) def flush(self, mode: int | None = None) -> bytes: + """Flush the compressor synchronously. + + **WARNING: This method is NOT cancellation-safe when called after compress().** + The flush() operation accesses shared compressor state. If compress() was + cancelled, calling flush() may result in corrupted data. The connection MUST + be closed after compress() cancellation. + + For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap + compress() + flush() + send operations in a shield and lock to ensure atomicity. + """ return self._compressor.flush( mode if mode is not None else self._zlib_backend.Z_FINISH ) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 347ec198e24..13ec2f15e98 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -305,6 +305,7 @@ SocketSocketTransport ssl SSLContext startup +stateful subapplication subclassed subclasses diff --git a/tests/conftest.py b/tests/conftest.py index e5dc79cad4d..6ee8520a4a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,9 @@ import socket import ssl import sys +import time from collections.abc import AsyncIterator, Callable, Iterator +from concurrent.futures import Future, ThreadPoolExecutor from hashlib import md5, sha1, sha256 from http.cookies import BaseCookie from pathlib import Path @@ -450,3 +452,27 @@ def maker( await request._close() assert session is not None await session.close() + + +@pytest.fixture +def slow_executor() -> Iterator[ThreadPoolExecutor]: + """Executor that adds delay to simulate slow operations. + + Useful for testing cancellation and race conditions in compression tests. + """ + + class SlowExecutor(ThreadPoolExecutor): + """Executor that adds delay to operations.""" + + def submit( + self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any + ) -> Future[Any]: + def slow_fn(*args: Any, **kwargs: Any) -> Any: + time.sleep(0.05) # Add delay to simulate slow operation + return fn(*args, **kwargs) + + return super().submit(slow_fn, *args, **kwargs) + + executor = SlowExecutor(max_workers=10) + yield executor + executor.shutdown(wait=True) diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 3fcd9f06eb4..14032f42e83 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -1,6 +1,8 @@ import asyncio import random from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress from unittest import mock import pytest @@ -143,6 +145,130 @@ async def test_send_compress_text_per_message( writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined] +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_send_compress_cancelled( + protocol: BaseProtocol, + transport: asyncio.Transport, + slow_executor: ThreadPoolExecutor, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that cancelled compression doesn't corrupt subsequent sends. + + Regression test for https://github.com/aio-libs/aiohttp/issues/11725 + """ + monkeypatch.setattr("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024) + writer = WebSocketWriter(protocol, transport, compress=15) + loop = asyncio.get_running_loop() + queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop) + reader = WebSocketReader(queue, 50000) + + # Replace executor with slow one to make race condition reproducible + writer._compressobj = writer._get_compressor(None) + writer._compressobj._executor = slow_executor + + # Create large data that will trigger executor-based compression + large_data_1 = b"A" * 10000 + large_data_2 = b"B" * 10000 + + # Start first send and cancel it during compression + async def send_and_cancel() -> None: + await writer.send_frame(large_data_1, WSMsgType.BINARY) + + task = asyncio.create_task(send_and_cancel()) + # Give it a moment to start compression + await asyncio.sleep(0.01) + task.cancel() + + # Await task cancellation (expected and intentionally ignored) + with suppress(asyncio.CancelledError): + await task + + # Send second message - this should NOT be corrupted + await writer.send_frame(large_data_2, WSMsgType.BINARY) + + # Verify the second send produced correct data + last_call = writer.transport.write.call_args_list[-1] # type: ignore[attr-defined] + call_bytes = last_call[0][0] + result, _ = reader.feed_data(call_bytes) + assert result is False + msg = await queue.read() + assert msg.type is WSMsgType.BINARY + # The data should be all B's, not mixed with A's from the cancelled send + assert msg.data == large_data_2 + + +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_send_compress_multiple_cancelled( + protocol: BaseProtocol, + transport: asyncio.Transport, + slow_executor: ThreadPoolExecutor, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that multiple compressed sends all complete despite cancellation. + + Regression test for https://github.com/aio-libs/aiohttp/issues/11725 + This verifies that once a send operation enters the shield, it completes + even if cancelled. With the lock inside the shield, all tasks that enter + the shield will complete their sends, even while waiting for the lock. + """ + monkeypatch.setattr("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024) + writer = WebSocketWriter(protocol, transport, compress=15) + loop = asyncio.get_running_loop() + queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop) + reader = WebSocketReader(queue, 50000) + + # Replace executor with slow one + writer._compressobj = writer._get_compressor(None) + writer._compressobj._executor = slow_executor + + # Create 5 large messages with different content + messages = [bytes([ord("A") + i]) * 10000 for i in range(5)] + + # Start sending all 5 messages - they'll queue due to the lock + tasks = [ + asyncio.create_task(writer.send_frame(msg, WSMsgType.BINARY)) + for msg in messages + ] + + # Cancel all tasks during execution + # With lock inside shield, all tasks that enter the shield will complete + # even while waiting for the lock + await asyncio.sleep(0.1) # Let tasks enter the shield + for task in tasks: + task.cancel() + + # Collect results + cancelled_count = 0 + for task in tasks: + try: + await task + except asyncio.CancelledError: + cancelled_count += 1 + + # Wait for all background tasks to complete + # (they continue running even after cancellation due to shield) + await asyncio.gather(*writer._background_tasks, return_exceptions=True) + + # All tasks that entered the shield should complete, even if cancelled + # With lock inside shield, all tasks enter shield immediately then wait for lock + sent_count = len(writer.transport.write.call_args_list) # type: ignore[attr-defined] + assert ( + sent_count == 5 + ), "All 5 sends should complete due to shield protecting lock acquisition" + + # Verify all sent messages are correct (no corruption) + for i in range(sent_count): + call = writer.transport.write.call_args_list[i] # type: ignore[attr-defined] + call_bytes = call[0][0] + result, _ = reader.feed_data(call_bytes) + assert result is False + msg = await queue.read() + assert msg.type is WSMsgType.BINARY + # Verify the data matches the expected message + expected_byte = bytes([ord("A") + i]) + assert msg.data == expected_byte * 10000, f"Message {i} corrupted" + + @pytest.mark.parametrize( ("max_sync_chunk_size", "payload_point_generator"), ( @@ -206,3 +332,6 @@ async def test_concurrent_messages( # we want to validate that all the bytes are # the same value assert bytes_data == bytes_data[0:1] * char_val + + # Wait for any background tasks to complete + await asyncio.gather(*writer._background_tasks, return_exceptions=True)