Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/11725.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed WebSocket compressed sends to be cancellation safe by adding shield protection to prevent compressor state corruption when tasks are cancelled during compression -- by :user:`bdraco`.
36 changes: 36 additions & 0 deletions aiohttp/_websocket/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import random
import sys
from functools import partial
from typing import Any, Final

Expand Down Expand Up @@ -63,6 +64,7 @@ def __init__(
self._limit = limit
self._output_size = 0
self._compressobj: Any = None # actually compressobj
self._send_lock: asyncio.Lock | None = None # Created on first compressed send

async def send_frame(
self, message: bytes, opcode: int, compress: int | None = None
Expand All @@ -71,6 +73,40 @@ async def send_frame(
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ClientConnectionResetError("Cannot write to closing transport")

# For compressed frames, the entire compress+send operation must be atomic.
# If cancelled after compression but before send, the compressor state would
# be advanced but data not sent, corrupting subsequent frames.
# We shield the operation to prevent cancellation mid-compression.
use_compression_lock = (compress or self.compress) and opcode < 8
if use_compression_lock:
if self._send_lock is None:
self._send_lock = asyncio.Lock()
await self._send_lock.acquire()
# Create a task to shield from cancellation
# Use eager_start on Python 3.12+ to avoid scheduling overhead
loop = asyncio.get_running_loop()
coro = self._send_frame_impl(message, opcode, compress)
if sys.version_info >= (3, 12):
send_task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
send_task = loop.create_task(coro)
try:
await asyncio.shield(send_task)
except asyncio.CancelledError:
# Shield will re-raise but task continues - wait for it
await send_task
raise
finally:
self._send_lock.release()
return

# Non-compressed path - no shielding needed
await self._send_frame_impl(message, opcode, compress)

async def _send_frame_impl(
self, message: bytes, opcode: int, compress: int | None
) -> None:
"""Internal implementation of send_frame without locking."""
# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
Expand Down
40 changes: 26 additions & 14 deletions aiohttp/compression_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def __init__(
if level is not None:
kwargs["level"] = level
self._compressor = self._zlib_backend.compressobj(**kwargs)
self._compress_lock = asyncio.Lock()

def compress_sync(self, data: Buffer) -> bytes:
return self._compressor.compress(data)
Expand All @@ -198,22 +197,35 @@ async def compress(self, data: Buffer) -> bytes:
If the data size is larger than the max_sync_chunk_size, the compression
will be done in the executor. Otherwise, the compression will be done
in the event loop.

**WARNING: This method is NOT cancellation-safe when used with flush().**
If this operation is cancelled, the compressor state may be corrupted.
The connection MUST be closed after cancellation to avoid data corruption
in subsequent compress operations.

For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
compress() + flush() + send operations in a shield and lock to ensure atomicity.
"""
async with self._compress_lock:
# To ensure the stream is consistent in the event
# there are multiple writers, we need to lock
# the compressor so that only one writer can
# compress at a time.
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._compressor.compress, data
)
return self.compress_sync(data)
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._compressor.compress, data
)
return self.compress_sync(data)

def flush(self, mode: int | None = None) -> bytes:
"""Flush the compressor synchronously.

**WARNING: This method is NOT cancellation-safe when called after compress().**
The flush() operation accesses shared compressor state. If compress() was
cancelled, calling flush() may result in corrupted data. The connection MUST
be closed after compress() cancellation.

For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
compress() + flush() + send operations in a shield and lock to ensure atomicity.
"""
return self._compressor.flush(
mode if mode is not None else self._zlib_backend.Z_FINISH
)
Expand Down
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import socket
import ssl
import sys
import time
from collections.abc import AsyncIterator, Callable, Iterator
from concurrent.futures import ThreadPoolExecutor
from hashlib import md5, sha1, sha256
from http.cookies import BaseCookie
from pathlib import Path
Expand Down Expand Up @@ -450,3 +452,25 @@ def maker(
await request._close()
assert session is not None
await session.close()


@pytest.fixture
def slow_executor() -> Iterator[ThreadPoolExecutor]:
    """Yield a thread pool in which every submitted job starts with a short sleep.

    The artificial delay widens the window in which a job is still running,
    which helps tests that cancel work mid-flight (for example compression
    tests) hit the race they are looking for.
    """

    class SlowExecutor(ThreadPoolExecutor):
        """Thread pool that delays each submitted callable by 50 ms."""

        def submit(self, fn: Any, *args: Any, **kwargs: Any) -> Any:
            def delayed_call(*call_args: Any, **call_kwargs: Any) -> Any:
                # The sleep runs inside the worker, right before the real
                # callable, so the submitting side is never blocked.
                time.sleep(0.05)
                return fn(*call_args, **call_kwargs)

            return super().submit(delayed_call, *args, **kwargs)

    pool = SlowExecutor(max_workers=10)
    yield pool
    # Teardown: wait for outstanding jobs before the fixture goes away.
    pool.shutdown(wait=True)
65 changes: 65 additions & 0 deletions tests/test_client_proto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from unittest import mock

import pytest
from multidict import CIMultiDict
from pytest_mock import MockerFixture
from yarl import URL
Expand Down Expand Up @@ -355,3 +357,66 @@ async def test_abort_without_transport(loop: asyncio.AbstractEventLoop) -> None:
# Should not raise and should still clean up
assert proto._exception is None
mock_drop_timeout.assert_not_called()


async def test_compression_cancelled_marks_connection_for_closure(
    loop: asyncio.AbstractEventLoop,
    slow_executor: ThreadPoolExecutor,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that a cancelled compressed HTTP write ends in connection closure.

    This illustrates why HTTP can do without the shield+lock protection
    that WebSocket compression needs.

    HTTP pattern:
    - Compression cancelled → write fails → connection marked for closure
    - The response is aborted and the connection is closed
    - The next request uses a new connection with a fresh compressor
    - A corrupted compressor state is therefore harmless

    WebSocket pattern (see test_websocket_writer.py):
    - The connection stays open across hundreds/thousands of messages
    - One stateful compressor is reused for the whole connection lifetime
    - Cancelling compression without a shield corrupts that state and,
      with it, the data of later messages
    - Shield + lock are required to make each send atomic

    Related to https://github.com/aio-libs/aiohttp/issues/11725
    """
    handler = ResponseHandler(loop=loop)
    fake_transport = mock.Mock()
    fake_transport.is_closing.return_value = False
    handler.connection_made(fake_transport)

    # Stream writer with deflate compression and chunked encoding turned on.
    stream = http.StreamWriter(handler, loop)
    stream.enable_compression("deflate")
    stream.enable_chunking()

    # Swap in the slow executor so the cancellation window is reproducible.
    assert stream._compress is not None
    stream._compress._executor = slow_executor

    # Sanity check: the connection starts out healthy.
    assert not handler.should_close

    # Payload meant to be large enough for executor-based compression.
    monkeypatch.setattr("aiohttp.compression_utils.MAX_SYNC_CHUNK_SIZE", 1024)
    payload = b"X" * 10000

    # Kick off the write, give compression a moment to start, then cancel.
    write_task = asyncio.create_task(stream.write(payload))
    await asyncio.sleep(0.01)
    write_task.cancel()

    try:
        await write_task
    except asyncio.CancelledError:
        # Stand-in for the real HTTP client, where a cancelled write would
        # lead to the protocol marking the connection for closure.
        handler.set_exception(http.HttpProcessingError(message="Write cancelled"))

    # With the exception recorded, the connection must be flagged for closure.
    assert handler.should_close
    assert handler.exception() is not None
119 changes: 119 additions & 0 deletions tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import random
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from unittest import mock

import pytest
Expand Down Expand Up @@ -143,6 +144,124 @@
writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined]


@pytest.mark.usefixtures("parametrize_zlib_backend")
async def test_send_compress_cancelled(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    slow_executor: ThreadPoolExecutor,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that a send cancelled mid-compression leaves later sends intact.

    Regression test for https://github.com/aio-libs/aiohttp/issues/11725
    """
    monkeypatch.setattr(
        "aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024
    )
    writer = WebSocketWriter(protocol, transport, compress=15)
    loop = asyncio.get_running_loop()
    queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop)
    reader = WebSocketReader(queue, 50000)

    # Install a compressor backed by the slow executor so the race between
    # compression and cancellation is reproducible.
    writer._compressobj = writer._make_compress_obj(writer.compress)
    writer._compressobj._executor = slow_executor

    # Payloads sized to go through executor-based compression.
    first_payload = b"A" * 10000
    second_payload = b"B" * 10000

    # Begin the first send, let compression get under way, then cancel it.
    first_send = asyncio.create_task(
        writer.send_frame(first_payload, WSMsgType.BINARY)
    )
    await asyncio.sleep(0.01)
    first_send.cancel()

    # The cancellation itself is expected; swallow it.
    try:
        await first_send
    except asyncio.CancelledError:
        pass

    # The follow-up send must come out uncorrupted.
    await writer.send_frame(second_payload, WSMsgType.BINARY)

    # Decode the most recent transport write and compare it with the payload.
    final_write = writer.transport.write.call_args_list[-1]  # type: ignore[attr-defined]
    frame_bytes = final_write[0][0]
    result, _ = reader.feed_data(frame_bytes)
    assert result is False
    msg = await queue.read()
    assert msg.type is WSMsgType.BINARY
    # Only B's are expected; any A's would have leaked from the cancelled send.
    assert msg.data == second_payload


@pytest.mark.usefixtures("parametrize_zlib_backend")
async def test_send_compress_multiple_cancelled(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    slow_executor: ThreadPoolExecutor,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that concurrent compressed sends stay uncorrupted when cancelled.

    Regression test for https://github.com/aio-libs/aiohttp/issues/11725
    A send that has already entered the shield is expected to run to
    completion even if its task is cancelled, while the lock makes the sends
    run one at a time. Every frame that reaches the transport must decode to
    the message that produced it.
    """
    monkeypatch.setattr(
        "aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024
    )
    writer = WebSocketWriter(protocol, transport, compress=15)
    loop = asyncio.get_running_loop()
    queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop)
    reader = WebSocketReader(queue, 50000)

    # Install a compressor backed by the slow executor.
    writer._compressobj = writer._make_compress_obj(writer.compress)
    writer._compressobj._executor = slow_executor

    # Five large payloads, each filled with its own letter (A, B, C, D, E).
    payloads = [bytes([ord("A") + offset]) * 10000 for offset in range(5)]

    # Launch every send at once; the writer's lock queues them up.
    pending = [
        asyncio.create_task(writer.send_frame(payload, WSMsgType.BINARY))
        for payload in payloads
    ]

    # Let one or two sends reach the shield, then cancel all of them.
    # Shielded sends should still finish; those waiting on the lock cancel.
    await asyncio.sleep(0.03)
    for send_task in pending:
        send_task.cancel()

    # Drain the tasks, tallying how many reported cancellation.
    cancelled_count = 0
    for send_task in pending:
        try:
            await send_task
        except asyncio.CancelledError:
            cancelled_count += 1

    # The shielded send guarantees at least one frame on the transport.
    recorded_writes = writer.transport.write.call_args_list  # type: ignore[attr-defined]
    sent_count = len(recorded_writes)
    assert sent_count >= 1, "At least one send should complete due to shield"
    assert sent_count <= 5, "Can't send more than 5 messages"

    # Decode each recorded frame in order and match it to its source payload.
    for index, recorded in enumerate(recorded_writes):
        frame_bytes = recorded[0][0]
        result, _ = reader.feed_data(frame_bytes)
        assert result is False
        msg = await queue.read()
        assert msg.type is WSMsgType.BINARY
        assert msg.data == payloads[index], f"Message {index} corrupted"


@pytest.mark.parametrize(
("max_sync_chunk_size", "payload_point_generator"),
(
Expand Down
Loading