diff --git a/CHANGES/11632.bugfix.rst b/CHANGES/11632.bugfix.rst new file mode 100644 index 00000000000..c07bfb2b1f7 --- /dev/null +++ b/CHANGES/11632.bugfix.rst @@ -0,0 +1 @@ +Fixed cookie parser to continue parsing subsequent cookies when encountering a malformed cookie that fails regex validation, such as Google's ``g_state`` cookie with unescaped quotes -- by :user:`bdraco`. diff --git a/CHANGES/11713.bugfix.rst b/CHANGES/11713.bugfix.rst new file mode 100644 index 00000000000..dbb45a5254f --- /dev/null +++ b/CHANGES/11713.bugfix.rst @@ -0,0 +1 @@ +Fixed loading netrc credentials from the default :file:`~/.netrc` (:file:`~/_netrc` on Windows) location when the :envvar:`NETRC` environment variable is not set -- by :user:`bdraco`. diff --git a/CHANGES/11714.bugfix.rst b/CHANGES/11714.bugfix.rst new file mode 120000 index 00000000000..5a506f1ded3 --- /dev/null +++ b/CHANGES/11714.bugfix.rst @@ -0,0 +1 @@ +11713.bugfix.rst \ No newline at end of file diff --git a/CHANGES/11725.bugfix.rst b/CHANGES/11725.bugfix.rst new file mode 100644 index 00000000000..e78fc054230 --- /dev/null +++ b/CHANGES/11725.bugfix.rst @@ -0,0 +1 @@ +Fixed WebSocket compressed sends to be cancellation safe. Tasks are now shielded during compression to prevent compressor state corruption. This ensures that the stateful compressor remains consistent even when send operations are cancelled -- by :user:`bdraco`. diff --git a/aiohttp/_cookie_helpers.py b/aiohttp/_cookie_helpers.py index 7fe8f43d12b..20a278b0d5b 100644 --- a/aiohttp/_cookie_helpers.py +++ b/aiohttp/_cookie_helpers.py @@ -166,7 +166,10 @@ def parse_cookie_header(header: str) -> list[tuple[str, Morsel[str]]]: attribute names (like 'path' or 'secure') should be treated as cookies. This parser uses the same regex-based approach as parse_set_cookie_headers - to properly handle quoted values that may contain semicolons. + to properly handle quoted values that may contain semicolons. When the + regex fails to match a malformed cookie, it falls back to simple parsing + to ensure subsequent cookies are not lost + https://github.com/aio-libs/aiohttp/issues/11632 Args: header: The Cookie header value to parse @@ -177,6 +180,7 @@ def parse_cookie_header(header: str) -> list[tuple[str, Morsel[str]]]: if not header: return [] + morsel: Morsel[str] cookies: list[tuple[str, Morsel[str]]] = [] i = 0 n = len(header) @@ -185,7 +189,32 @@ def parse_cookie_header(header: str) -> list[tuple[str, Morsel[str]]]: # Use the same pattern as parse_set_cookie_headers to find cookies match = _COOKIE_PATTERN.match(header, i) if not match: - break + # Fallback for malformed cookies https://github.com/aio-libs/aiohttp/issues/11632 + # Find next semicolon to skip or attempt simple key=value parsing + next_semi = header.find(";", i) + eq_pos = header.find("=", i) + + # Try to extract key=value if '=' comes before ';' + if eq_pos != -1 and (next_semi == -1 or eq_pos < next_semi): + end_pos = next_semi if next_semi != -1 else n + key = header[i:eq_pos].strip() + value = header[eq_pos + 1 : end_pos].strip() + + # Validate the name (same as regex path) + if not _COOKIE_NAME_RE.match(key): + internal_logger.warning( + "Can not load cookie: Illegal cookie name %r", key + ) + else: + morsel = Morsel() + morsel.__setstate__( # type: ignore[attr-defined] + {"key": key, "value": _unquote(value), "coded_value": value} + ) + cookies.append((key, morsel)) + + # Move to next cookie or end + i = next_semi + 1 if next_semi != -1 else n + continue key = match.group("key") value = match.group("val") or "" @@ -197,7 +226,7 @@ def parse_cookie_header(header: str) -> list[tuple[str, Morsel[str]]]: continue # Create new morsel - morsel: Morsel[str] = Morsel() + morsel = Morsel() # Preserve the original value as coded_value (with quotes if present) # We use __setstate__ instead of the public set() API because it allows us to # bypass validation and set already validated state. This is more stable than diff --git a/aiohttp/_websocket/writer.py b/aiohttp/_websocket/writer.py index fdbcda45c3c..1b27dff9371 100644 --- a/aiohttp/_websocket/writer.py +++ b/aiohttp/_websocket/writer.py @@ -2,8 +2,9 @@ import asyncio import random +import sys from functools import partial -from typing import Any, Final +from typing import Final from ..base_protocol import BaseProtocol from ..client_exceptions import ClientConnectionResetError @@ -22,14 +23,18 @@ DEFAULT_LIMIT: Final[int] = 2**16 +# WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames +# Control frames (ping, pong, close) are never compressed +WS_CONTROL_FRAME_OPCODE: Final[int] = 8 + # For websockets, keeping latency low is extremely important as implementations -# generally expect to be able to send and receive messages quickly. We use a -# larger chunk size than the default to reduce the number of executor calls -# since the executor is a significant source of latency and overhead when -# the chunks are small. A size of 5KiB was chosen because it is also the -# same value python-zlib-ng choose to use as the threshold to release the GIL. +# generally expect to be able to send and receive messages quickly. We use a +# larger chunk size to reduce the number of executor calls and avoid task +# creation overhead, since both are significant sources of latency when chunks +# are small. A size of 16KiB was chosen as a balance between avoiding task +# overhead and not blocking the event loop too long with synchronous compression. -WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024 +WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024 class WebSocketWriter: @@ -62,7 +67,9 @@ def __init__( self._closing = False self._limit = limit self._output_size = 0 - self._compressobj: Any = None # actually compressobj + self._compressobj: ZLibCompressor | None = None + self._send_lock = asyncio.Lock() + self._background_tasks: set[asyncio.Task[None]] = set() async def send_frame( self, message: bytes, opcode: int, compress: int | None = None @@ -71,39 +78,57 @@ async def send_frame( if self._closing and not (opcode & WSMsgType.CLOSE): raise ClientConnectionResetError("Cannot write to closing transport") - # RSV are the reserved bits in the frame header. They are used to - # indicate that the frame is using an extension. - # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 - rsv = 0 - # Only compress larger packets (disabled) - # Does small packet needs to be compressed? - # if self.compress and opcode < 8 and len(message) > 124: - if (compress or self.compress) and opcode < 8: - # RSV1 (rsv = 0x40) is set for compressed frames - # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 - rsv = 0x40 - - if compress: - # Do not set self._compress if compressing is for this frame - compressobj = self._make_compress_obj(compress) - else: # self.compress - if not self._compressobj: - self._compressobj = self._make_compress_obj(self.compress) - compressobj = self._compressobj - - message = ( - await compressobj.compress(message) - + compressobj.flush( - ZLibBackend.Z_FULL_FLUSH - if self.notakeover - else ZLibBackend.Z_SYNC_FLUSH - ) - ).removesuffix(WS_DEFLATE_TRAILING) - # Its critical that we do not return control to the event - # loop until we have finished sending all the compressed - # data. Otherwise we could end up mixing compressed frames - # if there are multiple coroutines compressing data. + if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE: + # Non-compressed frames don't need lock or shield + self._write_websocket_frame(message, opcode, 0) + elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE: + # Small compressed payloads - compress synchronously in event loop + # We need the lock even though sync compression has no await points. + # This prevents small frames from interleaving with large frames that + # compress in the executor, avoiding compressor state corruption. + async with self._send_lock: + self._send_compressed_frame_sync(message, opcode, compress) + else: + # Large compressed frames need shield to prevent corruption + # For large compressed frames, the entire compress+send + # operation must be atomic. If cancelled after compression but + # before send, the compressor state would be advanced but data + # not sent, corrupting subsequent frames. + # Create a task to shield from cancellation + # The lock is acquired inside the shielded task so the entire + # operation (lock + compress + send) completes atomically. + # Use eager_start on Python 3.12+ to avoid scheduling overhead + loop = asyncio.get_running_loop() + coro = self._send_compressed_frame_async_locked(message, opcode, compress) + if sys.version_info >= (3, 12): + send_task = asyncio.Task(coro, loop=loop, eager_start=True) + else: + send_task = loop.create_task(coro) + # Keep a strong reference to prevent garbage collection + self._background_tasks.add(send_task) + send_task.add_done_callback(self._background_tasks.discard) + await asyncio.shield(send_task) + + # It is safe to return control to the event loop when using compression + # after this point as we have already sent or buffered all the data. + # Once we have written output_size up to the limit, we call the + # drain helper which waits for the transport to be ready to accept + # more data. This is a flow control mechanism to prevent the buffer + # from growing too large. The drain helper will return right away + # if the writer is not paused. + if self._output_size > self._limit: + self._output_size = 0 + if self.protocol._paused: + await self.protocol._drain_helper() + def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None: + """ + Write a websocket frame to the transport. + + This method handles frame header construction, masking, and writing to transport. + It does not handle compression or flow control - those are the responsibility + of the caller. + """ msg_length = len(message) use_mask = self.use_mask @@ -146,26 +171,85 @@ async def send_frame( self._output_size += header_len + msg_length - # It is safe to return control to the event loop when using compression - # after this point as we have already sent or buffered all the data. + def _get_compressor(self, compress: int | None) -> ZLibCompressor: + """Get or create a compressor object for the given compression level.""" + if compress: + # Do not set self._compress if compressing is for this frame + return ZLibCompressor( + level=ZLibBackend.Z_BEST_SPEED, + wbits=-compress, + max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, + ) + if not self._compressobj: + self._compressobj = ZLibCompressor( + level=ZLibBackend.Z_BEST_SPEED, + wbits=-self.compress, + max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, + ) + return self._compressobj - # Once we have written output_size up to the limit, we call the - # drain helper which waits for the transport to be ready to accept - # more data. This is a flow control mechanism to prevent the buffer - # from growing too large. The drain helper will return right away - # if the writer is not paused. - if self._output_size > self._limit: - self._output_size = 0 - if self.protocol._paused: - await self.protocol._drain_helper() + def _send_compressed_frame_sync( + self, message: bytes, opcode: int, compress: int | None + ) -> None: + """ + Synchronous send for small compressed frames. - def _make_compress_obj(self, compress: int) -> ZLibCompressor: - return ZLibCompressor( - level=ZLibBackend.Z_BEST_SPEED, - wbits=-compress, - max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, + This is used for small compressed payloads that compress synchronously in the event loop. + Since there are no await points, this is inherently cancellation-safe. + """ + # RSV are the reserved bits in the frame header. They are used to + # indicate that the frame is using an extension. + # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + compressobj = self._get_compressor(compress) + # (0x40) RSV1 is set for compressed frames + # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 + self._write_websocket_frame( + ( + compressobj.compress_sync(message) + + compressobj.flush( + ZLibBackend.Z_FULL_FLUSH + if self.notakeover + else ZLibBackend.Z_SYNC_FLUSH + ) + ).removesuffix(WS_DEFLATE_TRAILING), + opcode, + 0x40, ) + async def _send_compressed_frame_async_locked( + self, message: bytes, opcode: int, compress: int | None + ) -> None: + """ + Async send for large compressed frames with lock. + + Acquires the lock and compresses large payloads asynchronously in + the executor. The lock is held for the entire operation to ensure + the compressor state is not corrupted by concurrent sends. + + MUST be run shielded from cancellation. If cancelled after + compression but before sending, the compressor state would be + advanced but data not sent, corrupting subsequent frames. + """ + async with self._send_lock: + # RSV are the reserved bits in the frame header. They are used to + # indicate that the frame is using an extension. + # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + compressobj = self._get_compressor(compress) + # (0x40) RSV1 is set for compressed frames + # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 + self._write_websocket_frame( + ( + await compressobj.compress(message) + + compressobj.flush( + ZLibBackend.Z_FULL_FLUSH + if self.notakeover + else ZLibBackend.Z_SYNC_FLUSH + ) + ).removesuffix(WS_DEFLATE_TRAILING), + opcode, + 0x40, + ) + async def close(self, code: int = 1000, message: bytes | str = b"") -> None: """Close the websocket, sending the specified code and message.""" if isinstance(message, str): diff --git a/aiohttp/client.py b/aiohttp/client.py index 059f1adc401..026006023ce 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -590,14 +590,7 @@ async def _request( auth = self._default_auth # Try netrc if auth is still None and trust_env is enabled. - # Only check if NETRC environment variable is set to avoid - # creating an expensive executor job unnecessarily. - if ( - auth is None - and self._trust_env - and url.host is not None - and os.environ.get("NETRC") - ): + if auth is None and self._trust_env and url.host is not None: auth = await self._loop.run_in_executor( None, self._get_netrc_auth, url.host ) diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index c6c6f0d71fc..d9c74fa5400 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -185,7 +185,6 @@ def __init__( if level is not None: kwargs["level"] = level self._compressor = self._zlib_backend.compressobj(**kwargs) - self._compress_lock = asyncio.Lock() def compress_sync(self, data: Buffer) -> bytes: return self._compressor.compress(data) @@ -198,22 +197,37 @@ async def compress(self, data: Buffer) -> bytes: If the data size is large than the max_sync_chunk_size, the compression will be done in the executor. Otherwise, the compression will be done in the event loop. + + **WARNING: This method is NOT cancellation-safe when used with flush().** + If this operation is cancelled, the compressor state may be corrupted. + The connection MUST be closed after cancellation to avoid data corruption + in subsequent compress operations. + + For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap + compress() + flush() + send operations in a shield and lock to ensure atomicity. """ - async with self._compress_lock: - # To ensure the stream is consistent in the event - # there are multiple writers, we need to lock - # the compressor so that only one writer can - # compress at a time. - if ( - self._max_sync_chunk_size is not None - and len(data) > self._max_sync_chunk_size - ): - return await asyncio.get_running_loop().run_in_executor( - self._executor, self._compressor.compress, data - ) - return self.compress_sync(data) + # For large payloads, offload compression to executor to avoid blocking event loop + should_use_executor = ( + self._max_sync_chunk_size is not None + and len(data) > self._max_sync_chunk_size + ) + if should_use_executor: + return await asyncio.get_running_loop().run_in_executor( + self._executor, self._compressor.compress, data + ) + return self.compress_sync(data) def flush(self, mode: int | None = None) -> bytes: + """Flush the compressor synchronously. + + **WARNING: This method is NOT cancellation-safe when called after compress().** + The flush() operation accesses shared compressor state. If compress() was + cancelled, calling flush() may result in corrupted data. The connection MUST + be closed after compress() cancellation. + + For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap + compress() + flush() + send operations in a shield and lock to ensure atomicity. + """ return self._compressor.flush( mode if mode is not None else self._zlib_backend.Z_FINISH ) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 347ec198e24..74ace02c5ec 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -305,6 +305,7 @@ SocketSocketTransport ssl SSLContext startup +stateful subapplication subclassed subclasses @@ -341,8 +342,9 @@ tuples UI un unawaited -undercounting unclosed +undercounting +unescaped unhandled unicode unittest diff --git a/tests/conftest.py b/tests/conftest.py index e5dc79cad4d..80b94ffa50a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,10 +3,13 @@ import asyncio import base64 import os +import platform import socket import ssl import sys +import time from collections.abc import AsyncIterator, Callable, Iterator +from concurrent.futures import Future, ThreadPoolExecutor from hashlib import md5, sha1, sha256 from http.cookies import BaseCookie from pathlib import Path @@ -329,6 +332,23 @@ def netrc_other_host(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path: return netrc_file +@pytest.fixture +def netrc_home_directory(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path: + """Create a netrc file in a mocked home directory without setting NETRC env var.""" + home_dir = tmp_path / "home" + home_dir.mkdir() + netrc_filename = "_netrc" if platform.system() == "Windows" else ".netrc" + netrc_file = home_dir / netrc_filename + netrc_file.write_text("default login netrc_user password netrc_pass\n") + + home_env_var = "USERPROFILE" if platform.system() == "Windows" else "HOME" + monkeypatch.setenv(home_env_var, str(home_dir)) + # Ensure NETRC env var is not set + monkeypatch.delenv("NETRC", raising=False) + + return netrc_file + + @pytest.fixture def start_connection() -> Iterator[mock.Mock]: with mock.patch( @@ -450,3 +470,27 @@ def maker( await request._close() assert session is not None await session.close() + + +@pytest.fixture +def slow_executor() -> Iterator[ThreadPoolExecutor]: + """Executor that adds delay to simulate slow operations. + + Useful for testing cancellation and race conditions in compression tests. + """ + + class SlowExecutor(ThreadPoolExecutor): + """Executor that adds delay to operations.""" + + def submit( + self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any + ) -> Future[Any]: + def slow_fn(*args: Any, **kwargs: Any) -> Any: + time.sleep(0.05) # Add delay to simulate slow operation + return fn(*args, **kwargs) + + return super().submit(slow_fn, *args, **kwargs) + + executor = SlowExecutor(max_workers=10) + yield executor + executor.shutdown(wait=True) diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 473427278f8..95b40cce9bb 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -3748,12 +3748,12 @@ async def test_netrc_auth_from_env( # type: ignore[misc] @pytest.mark.usefixtures("no_netrc") -async def test_netrc_auth_skipped_without_env_var( # type: ignore[misc] +async def test_netrc_auth_skipped_without_netrc_file( # type: ignore[misc] headers_echo_client: Callable[ ..., Awaitable[TestClient[web.Request, web.Application]] ], ) -> None: - """Test that netrc authentication is skipped when NETRC env var is not set.""" + """Test that netrc authentication is skipped when no netrc file exists.""" client = await headers_echo_client(trust_env=True) async with client.get("/") as r: assert r.status == 200 @@ -3762,6 +3762,20 @@ async def test_netrc_auth_skipped_without_env_var( # type: ignore[misc] assert "Authorization" not in content["headers"] +@pytest.mark.usefixtures("netrc_home_directory") +async def test_netrc_auth_from_home_directory( # type: ignore[misc] + headers_echo_client: Callable[ + ..., Awaitable[TestClient[web.Request, web.Application]] + ], +) -> None: + """Test that netrc authentication works from default ~/.netrc without NETRC env var.""" + client = await headers_echo_client(trust_env=True) + async with client.get("/") as r: + assert r.status == 200 + content = await r.json() + assert content["headers"]["Authorization"] == "Basic bmV0cmNfdXNlcjpuZXRyY19wYXNz" + + @pytest.mark.usefixtures("netrc_default_contents") async def test_netrc_auth_overridden_by_explicit_auth( # type: ignore[misc] headers_echo_client: Callable[ diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 21057d3fbb5..e9106c3443d 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -1368,8 +1368,8 @@ async def test_netrc_auth_skipped_without_trust_env(auth_server: TestServer) -> @pytest.mark.usefixtures("no_netrc") -async def test_netrc_auth_skipped_without_netrc_env(auth_server: TestServer) -> None: - """Test that netrc authentication is skipped when NETRC env var is not set.""" +async def test_netrc_auth_skipped_without_netrc_file(auth_server: TestServer) -> None: + """Test that netrc authentication is skipped when no netrc file exists.""" async with ( ClientSession(trust_env=True) as session, session.get(auth_server.make_url("/")) as resp, @@ -1378,6 +1378,17 @@ async def test_netrc_auth_skipped_without_netrc_env(auth_server: TestServer) -> assert text == "no_auth" +@pytest.mark.usefixtures("netrc_home_directory") +async def test_netrc_auth_from_home_directory(auth_server: TestServer) -> None: + """Test that netrc authentication works from default ~/.netrc location without NETRC env var.""" + async with ( + ClientSession(trust_env=True) as session, + session.get(auth_server.make_url("/")) as resp, + ): + text = await resp.text() + assert text == "auth:Basic bmV0cmNfdXNlcjpuZXRyY19wYXNz" + + @pytest.mark.usefixtures("netrc_default_contents") async def test_netrc_auth_overridden_by_explicit_auth(auth_server: TestServer) -> None: """Test that explicit auth parameter overrides netrc authentication.""" diff --git a/tests/test_cookie_helpers.py b/tests/test_cookie_helpers.py index 575bbe54d01..577e3156560 100644 --- a/tests/test_cookie_helpers.py +++ b/tests/test_cookie_helpers.py @@ -1137,7 +1137,6 @@ def test_parse_cookie_header_empty() -> None: assert parse_cookie_header(" ") == [] -@pytest.mark.xfail(reason="https://github.com/aio-libs/aiohttp/issues/11632") def test_parse_cookie_gstate_header() -> None: header = ( "_ga=ga; " @@ -1444,6 +1443,142 @@ def test_parse_cookie_header_illegal_names(caplog: pytest.LogCaptureFixture) -> assert "Can not load cookie: Illegal cookie name 'invalid,cookie'" in caplog.text +def test_parse_cookie_header_large_value() -> None: + """Test that large cookie values don't cause DoS.""" + large_value = "A" * 8192 + header = f"normal=value; large={large_value}; after=cookie" + + result = parse_cookie_header(header) + cookie_names = [name for name, _ in result] + + assert len(result) == 3 + assert "normal" in cookie_names + assert "large" in cookie_names + assert "after" in cookie_names + + large_cookie = next(morsel for name, morsel in result if name == "large") + assert len(large_cookie.value) == 8192 + + +def test_parse_cookie_header_multiple_equals() -> None: + """Test handling of multiple equals signs in cookie values.""" + header = "session=abc123; data=key1=val1&key2=val2; token=xyz" + + result = parse_cookie_header(header) + + assert len(result) == 3 + + name1, morsel1 = result[0] + assert name1 == "session" + assert morsel1.value == "abc123" + + name2, morsel2 = result[1] + assert name2 == "data" + assert morsel2.value == "key1=val1&key2=val2" + + name3, morsel3 = result[2] + assert name3 == "token" + assert morsel3.value == "xyz" + + +def test_parse_cookie_header_fallback_preserves_subsequent_cookies() -> None: + """Test that fallback parser doesn't lose subsequent cookies.""" + header = 'normal=value; malformed={"json":"value"}; after1=cookie1; after2=cookie2' + + result = parse_cookie_header(header) + cookie_names = [name for name, _ in result] + + assert len(result) == 4 + assert cookie_names == ["normal", "malformed", "after1", "after2"] + + name1, morsel1 = result[0] + assert morsel1.value == "value" + + name2, morsel2 = result[1] + assert morsel2.value == '{"json":"value"}' + + name3, morsel3 = result[2] + assert morsel3.value == "cookie1" + + name4, morsel4 = result[3] + assert morsel4.value == "cookie2" + + +def test_parse_cookie_header_whitespace_in_fallback() -> None: + """Test that fallback parser handles whitespace correctly.""" + header = "a=1; b = 2 ; c= 3; d =4" + + result = parse_cookie_header(header) + + assert len(result) == 4 + for name, morsel in result: + assert name in ("a", "b", "c", "d") + assert morsel.value in ("1", "2", "3", "4") + + +def test_parse_cookie_header_empty_value_in_fallback() -> None: + """Test that fallback handles empty values correctly.""" + header = "normal=value; empty=; another=test" + + result = parse_cookie_header(header) + + assert len(result) == 3 + + name1, morsel1 = result[0] + assert name1 == "normal" + assert morsel1.value == "value" + + name2, morsel2 = result[1] + assert name2 == "empty" + assert morsel2.value == "" + + name3, morsel3 = result[2] + assert name3 == "another" + assert morsel3.value == "test" + + +def test_parse_cookie_header_invalid_name_in_fallback( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test that fallback parser rejects cookies with invalid names.""" + header = 'normal=value; invalid,name={"x":"y"}; another=test' + + result = parse_cookie_header(header) + + assert len(result) == 2 + + name1, morsel1 = result[0] + assert name1 == "normal" + assert morsel1.value == "value" + + name2, morsel2 = result[1] + assert name2 == "another" + assert morsel2.value == "test" + + assert "Can not load cookie: Illegal cookie name 'invalid,name'" in caplog.text + + +def test_parse_cookie_header_empty_key_in_fallback( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test that fallback parser logs warning for empty cookie names.""" + header = 'normal=value; ={"malformed":"json"}; another=test' + + result = parse_cookie_header(header) + + assert len(result) == 2 + + name1, morsel1 = result[0] + assert name1 == "normal" + assert morsel1.value == "value" + + name2, morsel2 = result[1] + assert name2 == "another" + assert morsel2.value == "test" + + assert "Can not load cookie: Illegal cookie name ''" in caplog.text + + @pytest.mark.parametrize( ("input_str", "expected"), [ diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 3fcd9f06eb4..14032f42e83 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -1,6 +1,8 @@ import asyncio import random from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress from unittest import mock import pytest @@ -143,6 +145,130 @@ async def test_send_compress_text_per_message( writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined] +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_send_compress_cancelled( + protocol: BaseProtocol, + transport: asyncio.Transport, + slow_executor: ThreadPoolExecutor, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that cancelled compression doesn't corrupt subsequent sends. + + Regression test for https://github.com/aio-libs/aiohttp/issues/11725 + """ + monkeypatch.setattr("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024) + writer = WebSocketWriter(protocol, transport, compress=15) + loop = asyncio.get_running_loop() + queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop) + reader = WebSocketReader(queue, 50000) + + # Replace executor with slow one to make race condition reproducible + writer._compressobj = writer._get_compressor(None) + writer._compressobj._executor = slow_executor + + # Create large data that will trigger executor-based compression + large_data_1 = b"A" * 10000 + large_data_2 = b"B" * 10000 + + # Start first send and cancel it during compression + async def send_and_cancel() -> None: + await writer.send_frame(large_data_1, WSMsgType.BINARY) + + task = asyncio.create_task(send_and_cancel()) + # Give it a moment to start compression + await asyncio.sleep(0.01) + task.cancel() + + # Await task cancellation (expected and intentionally ignored) + with suppress(asyncio.CancelledError): + await task + + # Send second message - this should NOT be corrupted + await writer.send_frame(large_data_2, WSMsgType.BINARY) + + # Verify the second send produced correct data + last_call = writer.transport.write.call_args_list[-1] # type: ignore[attr-defined] + call_bytes = last_call[0][0] + result, _ = reader.feed_data(call_bytes) + assert result is False + msg = await queue.read() + assert msg.type is WSMsgType.BINARY + # The data should be all B's, not mixed with A's from the cancelled send + assert msg.data == large_data_2 + + +@pytest.mark.usefixtures("parametrize_zlib_backend") +async def test_send_compress_multiple_cancelled( + protocol: BaseProtocol, + transport: asyncio.Transport, + slow_executor: ThreadPoolExecutor, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that multiple compressed sends all complete despite cancellation. + + Regression test for https://github.com/aio-libs/aiohttp/issues/11725 + This verifies that once a send operation enters the shield, it completes + even if cancelled. With the lock inside the shield, all tasks that enter + the shield will complete their sends, even while waiting for the lock. + """ + monkeypatch.setattr("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024) + writer = WebSocketWriter(protocol, transport, compress=15) + loop = asyncio.get_running_loop() + queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop) + reader = WebSocketReader(queue, 50000) + + # Replace executor with slow one + writer._compressobj = writer._get_compressor(None) + writer._compressobj._executor = slow_executor + + # Create 5 large messages with different content + messages = [bytes([ord("A") + i]) * 10000 for i in range(5)] + + # Start sending all 5 messages - they'll queue due to the lock + tasks = [ + asyncio.create_task(writer.send_frame(msg, WSMsgType.BINARY)) + for msg in messages + ] + + # Cancel all tasks during execution + # With lock inside shield, all tasks that enter the shield will complete + # even while waiting for the lock + await asyncio.sleep(0.1) # Let tasks enter the shield + for task in tasks: + task.cancel() + + # Collect results + cancelled_count = 0 + for task in tasks: + try: + await task + except asyncio.CancelledError: + cancelled_count += 1 + + # Wait for all background tasks to complete + # (they continue running even after cancellation due to shield) + await asyncio.gather(*writer._background_tasks, return_exceptions=True) + + # All tasks that entered the shield should complete, even if cancelled + # With lock inside shield, all tasks enter shield immediately then wait for lock + sent_count = len(writer.transport.write.call_args_list) # type: ignore[attr-defined] + assert ( + sent_count == 5 + ), "All 5 sends should complete due to shield protecting lock acquisition" + + # Verify all sent messages are correct (no corruption) + for i in range(sent_count): + call = writer.transport.write.call_args_list[i] # type: ignore[attr-defined] + call_bytes = call[0][0] + result, _ = reader.feed_data(call_bytes) + assert result is False + msg = await queue.read() + assert msg.type is WSMsgType.BINARY + # Verify the data matches the expected message + expected_byte = bytes([ord("A") + i]) + assert msg.data == expected_byte * 10000, f"Message {i} corrupted" + + @pytest.mark.parametrize( ("max_sync_chunk_size", "payload_point_generator"), ( @@ -206,3 +332,6 @@ async def test_concurrent_messages( # we want to validate that all the bytes are # the same value assert bytes_data == bytes_data[0:1] * char_val + + # Wait for any background tasks to complete + await asyncio.gather(*writer._background_tasks, return_exceptions=True)