Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
make safe
  • Loading branch information
bdraco committed Oct 27, 2025
commit 328e3dc5569f80f7785630721fbf0f182474f7b0
36 changes: 36 additions & 0 deletions aiohttp/_websocket/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import random
import sys
from functools import partial
from typing import Any, Final

Expand Down Expand Up @@ -63,6 +64,7 @@ def __init__(
self._limit = limit
self._output_size = 0
self._compressobj: Any = None # actually compressobj
self._send_lock: asyncio.Lock | None = None # Created on first compressed send

async def send_frame(
self, message: bytes, opcode: int, compress: int | None = None
Expand All @@ -71,6 +73,40 @@ async def send_frame(
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ClientConnectionResetError("Cannot write to closing transport")

# For compressed frames, the entire compress+send operation must be atomic.
# If cancelled after compression but before send, the compressor state would
# be advanced but data not sent, corrupting subsequent frames.
# We shield the operation to prevent cancellation mid-compression.
use_compression_lock = (compress or self.compress) and opcode < 8
if use_compression_lock:
if self._send_lock is None:
self._send_lock = asyncio.Lock()
await self._send_lock.acquire()
# Create a task to shield from cancellation
# Use eager_start on Python 3.12+ to avoid scheduling overhead
loop = asyncio.get_running_loop()
coro = self._send_frame_impl(message, opcode, compress)
if sys.version_info >= (3, 12):
send_task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
send_task = loop.create_task(coro)
try:
await asyncio.shield(send_task)
except asyncio.CancelledError:
# Shield will re-raise but task continues - wait for it
await send_task
raise
finally:
self._send_lock.release()
return

# Non-compressed path - no shielding needed
await self._send_frame_impl(message, opcode, compress)

async def _send_frame_impl(
self, message: bytes, opcode: int, compress: int | None
) -> None:
"""Internal implementation of send_frame without locking."""
# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
Expand Down
40 changes: 26 additions & 14 deletions aiohttp/compression_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def __init__(
if level is not None:
kwargs["level"] = level
self._compressor = self._zlib_backend.compressobj(**kwargs)
self._compress_lock = asyncio.Lock()

def compress_sync(self, data: Buffer) -> bytes:
return self._compressor.compress(data)
Expand All @@ -198,22 +197,35 @@ async def compress(self, data: Buffer) -> bytes:
If the data size is larger than the max_sync_chunk_size, the compression
will be done in the executor. Otherwise, the compression will be done
in the event loop.

**WARNING: This method is NOT cancellation-safe when used with flush().**
If this operation is cancelled, the compressor state may be corrupted.
The connection MUST be closed after cancellation to avoid data corruption
in subsequent compress operations.

For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
compress() + flush() + send operations in a shield and lock to ensure atomicity.
"""
async with self._compress_lock:
# To ensure the stream is consistent in the event
# there are multiple writers, we need to lock
# the compressor so that only one writer can
# compress at a time.
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._compressor.compress, data
)
return self.compress_sync(data)
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._compressor.compress, data
)
return self.compress_sync(data)

def flush(self, mode: int | None = None) -> bytes:
"""Flush the compressor synchronously.

**WARNING: This method is NOT cancellation-safe when called after compress().**
The flush() operation accesses shared compressor state. If compress() was
cancelled, calling flush() may result in corrupted data. The connection MUST
be closed after compress() cancellation.

For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
compress() + flush() + send operations in a shield and lock to ensure atomicity.
"""
return self._compressor.flush(
mode if mode is not None else self._zlib_backend.Z_FINISH
)
Expand Down
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import socket
import ssl
import sys
import time
from collections.abc import AsyncIterator, Callable, Iterator
from concurrent.futures import ThreadPoolExecutor
from hashlib import md5, sha1, sha256
from http.cookies import BaseCookie
from pathlib import Path
Expand Down Expand Up @@ -450,3 +452,25 @@ def maker(
await request._close()
assert session is not None
await session.close()


@pytest.fixture
def slow_executor() -> Iterator[ThreadPoolExecutor]:
    """Executor that adds delay to simulate slow operations.

    Useful for testing cancellation and race conditions in compression tests.
    """

    class SlowExecutor(ThreadPoolExecutor):
        """ThreadPoolExecutor whose submitted callables sleep briefly first."""

        def submit(self, fn, *args, **kwargs):
            def delayed(*inner_args, **inner_kwargs):
                # Small fixed delay makes in-flight cancellation reproducible.
                time.sleep(0.05)
                return fn(*inner_args, **inner_kwargs)

            return super().submit(delayed, *args, **kwargs)

    pool = SlowExecutor(max_workers=10)
    yield pool
    pool.shutdown(wait=True)
75 changes: 75 additions & 0 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import zlib
from collections.abc import Generator, Iterable
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from unittest import mock

Expand Down Expand Up @@ -1697,3 +1698,77 @@ async def test_send_headers_with_payload_chunked_eof_no_data(
assert b"GET /test HTTP/1.1\r\n" in buf
assert b"Transfer-Encoding: chunked\r\n" in buf
assert buf.endswith(b"0\r\n\r\n")


@pytest.mark.usefixtures("parametrize_zlib_backend")
async def test_compression_cancelled_closes_connection(
    buf: bytearray,
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
    slow_executor: ThreadPoolExecutor,
) -> None:
    """Test that HTTP doesn't need cancellation safety like WebSocket does.

    This demonstrates that for HTTP, compression cancellation is handled by
    closing the entire connection. The compressor state corruption doesn't
    matter because:
    1. The HTTP response/request is aborted
    2. The connection is closed
    3. A new connection with a fresh compressor is established for next request

    This is different from WebSocket which:
    1. Keeps the connection alive across many messages
    2. Reuses the same stateful compressor
    3. Needs shield protection to prevent state corruption

    Related to https://github.com/aio-libs/aiohttp/issues/11725
    """
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()

    # Route compression through the slow executor so the cancellation
    # reliably lands while compression is still in flight.
    writer._compress = writer._compress or writer._make_compressor()
    writer._compress._executor = slow_executor

    # Shrink the sync-chunk threshold so the payload goes to the executor.
    with mock.patch("aiohttp.compression_utils.MAX_SYNC_CHUNK_SIZE", 1024):
        payload = b"X" * 10000

        # Start the write, then cancel it once compression has begun.
        write_task = asyncio.create_task(writer.write(payload))
        await asyncio.sleep(0.01)  # Let compression start
        write_task.cancel()

        try:
            await write_task
        except asyncio.CancelledError:
            pass

    # In real HTTP scenarios, after cancellation the response is aborted and
    # the connection is closed. This makes compressor state corruption harmless.

    # Assert that the compressor exists and has been used
    assert writer._compress is not None, "Compressor should have been created"

    # The compressor state may be corrupted after cancellation, but for HTTP
    # this doesn't matter because:
    # 1. The response is aborted
    # 2. The connection will be closed
    # 3. The next request gets a fresh connection with a new compressor

    # Key differences from WebSocket:
    #
    # HTTP pattern:
    # - One compressor instance per HTTP response
    # - Compression cancellation -> response aborted -> connection closed
    # - Next request gets a new connection with fresh compressor
    # - Compressor state corruption is harmless
    # - No need for shield protection
    #
    # WebSocket pattern (see test_websocket_writer.py):
    # - One compressor instance for entire WebSocket connection lifetime
    # - Connection stays open across hundreds/thousands of messages
    # - Compression cancellation without shield -> state corruption -> data corruption
    # - Requires shield + lock to ensure atomicity
117 changes: 117 additions & 0 deletions tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import random
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from unittest import mock

import pytest
Expand Down Expand Up @@ -143,6 +144,122 @@ async def test_send_compress_text_per_message(
writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined]


@pytest.mark.usefixtures("parametrize_zlib_backend")
async def test_send_compress_cancelled(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    slow_executor: ThreadPoolExecutor,
) -> None:
    """Test that cancelled compression doesn't corrupt subsequent sends.

    Regression test for https://github.com/aio-libs/aiohttp/issues/11725
    """
    with mock.patch("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024):
        writer = WebSocketWriter(protocol, transport, compress=15)
        loop = asyncio.get_running_loop()
        queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop)
        reader = WebSocketReader(queue, 50000)

        # Route compression through the slow executor so the race window
        # between compression and cancellation is reliably hit.
        writer._compressobj = writer._make_compress_obj(writer.compress)
        writer._compressobj._executor = slow_executor

        # Payloads large enough to trigger executor-based compression.
        first_payload = b"A" * 10000
        second_payload = b"B" * 10000

        # Kick off the first send and cancel it during compression.
        send_task = asyncio.create_task(
            writer.send_frame(first_payload, WSMsgType.BINARY)
        )
        await asyncio.sleep(0.01)  # Give compression a moment to start
        send_task.cancel()

        try:
            await send_task
        except asyncio.CancelledError:
            pass

        # The follow-up send must NOT be corrupted by the cancelled one.
        await writer.send_frame(second_payload, WSMsgType.BINARY)

        # Decode the last frame written to the transport and check that it
        # round-trips to exactly the second payload (all B's, no stray A's).
        frame_bytes = writer.transport.write.call_args_list[-1][0][0]  # type: ignore[attr-defined]
        result, _ = reader.feed_data(frame_bytes)
        assert result is False
        msg = await queue.read()
        assert msg.type is WSMsgType.BINARY
        assert msg.data == second_payload


@pytest.mark.usefixtures("parametrize_zlib_backend")
async def test_send_compress_multiple_cancelled(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    slow_executor: ThreadPoolExecutor,
) -> None:
    """Test that multiple compressed sends all complete despite cancellation.

    Regression test for https://github.com/aio-libs/aiohttp/issues/11725
    This verifies that once a send operation enters the shield, it completes
    even if cancelled. The lock serializes sends, so they process one at a time.
    """
    with mock.patch("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024):
        writer = WebSocketWriter(protocol, transport, compress=15)
        loop = asyncio.get_running_loop()
        queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop)
        reader = WebSocketReader(queue, 50000)

        # Make compression observably slow so cancellations land mid-send.
        writer._compressobj = writer._make_compress_obj(writer.compress)
        writer._compressobj._executor = slow_executor

        # Five distinct large payloads: b"A"*10000, b"B"*10000, ...
        payloads = [bytes([ord("A") + i]) * 10000 for i in range(5)]

        # Start all five sends; the internal lock makes them queue up.
        send_tasks = [
            asyncio.create_task(writer.send_frame(payload, WSMsgType.BINARY))
            for payload in payloads
        ]

        # Let one or two sends enter the shield, then cancel everything.
        # Shielded sends complete; tasks still waiting on the lock cancel.
        await asyncio.sleep(0.03)
        for send_task in send_tasks:
            send_task.cancel()

        # Drain all tasks, tallying how many were actually cancelled.
        cancelled_count = 0
        for send_task in send_tasks:
            try:
                await send_task
            except asyncio.CancelledError:
                cancelled_count += 1

        # The shield guarantees at least the in-flight send finishes.
        written_frames = writer.transport.write.call_args_list  # type: ignore[attr-defined]
        sent_count = len(written_frames)
        assert sent_count >= 1, "At least one send should complete due to shield"
        assert sent_count <= 5, "Can't send more than 5 messages"

        # Every frame that made it onto the wire must decode cleanly back to
        # the payload it was sent with — i.e. no cross-message corruption.
        for i in range(sent_count):
            frame_bytes = written_frames[i][0][0]
            result, _ = reader.feed_data(frame_bytes)
            assert result is False
            msg = await queue.read()
            assert msg.type is WSMsgType.BINARY
            assert msg.data == bytes([ord("A") + i]) * 10000, f"Message {i} corrupted"


@pytest.mark.parametrize(
("max_sync_chunk_size", "payload_point_generator"),
(
Expand Down