1 change: 1 addition & 0 deletions CHANGES/11725.bugfix.rst
@@ -0,0 +1 @@
Fixed WebSocket compressed sends to be cancellation-safe. Tasks are now shielded during compression to prevent compressor state corruption. This ensures that the stateful compressor remains consistent even when send operations are cancelled -- by :user:`bdraco`.
196 changes: 140 additions & 56 deletions aiohttp/_websocket/writer.py
@@ -2,8 +2,9 @@

import asyncio
import random
import sys
from functools import partial
from typing import Any, Final
from typing import Final

from ..base_protocol import BaseProtocol
from ..client_exceptions import ClientConnectionResetError
@@ -22,14 +23,18 @@

DEFAULT_LIMIT: Final[int] = 2**16

# WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames
# Control frames (ping, pong, close) are never compressed
WS_CONTROL_FRAME_OPCODE: Final[int] = 8
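
# A quick illustration of the boundary check (opcode values from RFC 6455:
# 0x1 text, 0x2 binary, 0x8 close, 0x9 ping, 0xA pong):
#
#     for op in (0x1, 0x2, 0x8, 0x9, 0xA):
#         kind = "control" if op >= WS_CONTROL_FRAME_OPCODE else "data"
#         print(hex(op), kind)  # -> 0x1 data, 0x2 data, 0x8/0x9/0xa control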

# For websockets, keeping latency low is extremely important as implementations
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size than the default to reduce the number of executor calls
# since the executor is a significant source of latency and overhead when
# the chunks are small. A size of 5KiB was chosen because it is also the
# same value python-zlib-ng chose to use as the threshold to release the GIL.
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size to reduce the number of executor calls and avoid task
# creation overhead, since both are significant sources of latency when chunks
# are small. A size of 16KiB was chosen as a balance between avoiding task
# overhead and not blocking the event loop too long with synchronous compression.

WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024


class WebSocketWriter:
@@ -62,7 +67,9 @@ def __init__(
self._closing = False
self._limit = limit
self._output_size = 0
self._compressobj: Any = None # actually compressobj
self._compressobj: ZLibCompressor | None = None
self._send_lock = asyncio.Lock()
self._background_tasks: set[asyncio.Task[None]] = set()

async def send_frame(
self, message: bytes, opcode: int, compress: int | None = None
@@ -71,39 +78,57 @@ async def send_frame(
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ClientConnectionResetError("Cannot write to closing transport")

# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
rsv = 0
# Only compress larger packets (disabled)
# Do small packets need to be compressed?
# if self.compress and opcode < 8 and len(message) > 124:
if (compress or self.compress) and opcode < 8:
# RSV1 (rsv = 0x40) is set for compressed frames
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
rsv = 0x40

if compress:
# Do not set self._compress if compressing is for this frame
compressobj = self._make_compress_obj(compress)
else: # self.compress
if not self._compressobj:
self._compressobj = self._make_compress_obj(self.compress)
compressobj = self._compressobj

message = (
await compressobj.compress(message)
+ compressobj.flush(
ZLibBackend.Z_FULL_FLUSH
if self.notakeover
else ZLibBackend.Z_SYNC_FLUSH
)
).removesuffix(WS_DEFLATE_TRAILING)
# It's critical that we do not return control to the event
# loop until we have finished sending all the compressed
# data. Otherwise we could end up mixing compressed frames
# if there are multiple coroutines compressing data.
if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE:
# Non-compressed frames don't need lock or shield
self._write_websocket_frame(message, opcode, 0)
elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
# Small compressed payloads - compress synchronously in event loop
# We need the lock even though sync compression has no await points.
# This prevents small frames from interleaving with large frames that
# compress in the executor, avoiding compressor state corruption.
async with self._send_lock:
self._send_compressed_frame_sync(message, opcode, compress)
else:
# Large compressed frames need shield to prevent corruption
# For large compressed frames, the entire compress+send
# operation must be atomic. If cancelled after compression but
# before send, the compressor state would be advanced but data
# not sent, corrupting subsequent frames.
# Create a task to shield from cancellation
# The lock is acquired inside the shielded task so the entire
# operation (lock + compress + send) completes atomically.
# Use eager_start on Python 3.12+ to avoid scheduling overhead
loop = asyncio.get_running_loop()
coro = self._send_compressed_frame_async_locked(message, opcode, compress)
if sys.version_info >= (3, 12):
send_task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
send_task = loop.create_task(coro)
# Keep a strong reference to prevent garbage collection
self._background_tasks.add(send_task)
send_task.add_done_callback(self._background_tasks.discard)
await asyncio.shield(send_task)

# It is safe to return control to the event loop when using compression
# after this point as we have already sent or buffered all the data.
# Once we have written output_size up to the limit, we call the
# drain helper which waits for the transport to be ready to accept
# more data. This is a flow control mechanism to prevent the buffer
# from growing too large. The drain helper will return right away
# if the writer is not paused.
if self._output_size > self._limit:
self._output_size = 0
if self.protocol._paused:
await self.protocol._drain_helper()

def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
"""
Write a websocket frame to the transport.

This method handles frame header construction, masking, and writing to transport.
It does not handle compression or flow control - those are the responsibility
of the caller.
"""
msg_length = len(message)

use_mask = self.use_mask
@@ -146,26 +171,85 @@ def send_frame(

self._output_size += header_len + msg_length

# It is safe to return control to the event loop when using compression
# after this point as we have already sent or buffered all the data.
def _get_compressor(self, compress: int | None) -> ZLibCompressor:
"""Get or create a compressor object for the given compression level."""
if compress:
# Do not set self._compressobj if compression is only for this frame
return ZLibCompressor(
level=ZLibBackend.Z_BEST_SPEED,
wbits=-compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
if not self._compressobj:
self._compressobj = ZLibCompressor(
level=ZLibBackend.Z_BEST_SPEED,
wbits=-self.compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
return self._compressobj

# Once we have written output_size up to the limit, we call the
# drain helper which waits for the transport to be ready to accept
# more data. This is a flow control mechanism to prevent the buffer
# from growing too large. The drain helper will return right away
# if the writer is not paused.
if self._output_size > self._limit:
self._output_size = 0
if self.protocol._paused:
await self.protocol._drain_helper()
def _make_compress_obj(self, compress: int) -> ZLibCompressor:
return ZLibCompressor(
level=ZLibBackend.Z_BEST_SPEED,
wbits=-compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
def _send_compressed_frame_sync(
self, message: bytes, opcode: int, compress: int | None
) -> None:
"""
Synchronous send for small compressed frames.

Used for small payloads that are compressed synchronously in the event loop.
Since there are no await points, this is inherently cancellation-safe.
"""
# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
compressobj = self._get_compressor(compress)
# (0x40) RSV1 is set for compressed frames
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
self._write_websocket_frame(
(
compressobj.compress_sync(message)
+ compressobj.flush(
ZLibBackend.Z_FULL_FLUSH
if self.notakeover
else ZLibBackend.Z_SYNC_FLUSH
)
).removesuffix(WS_DEFLATE_TRAILING),
opcode,
0x40,
)

async def _send_compressed_frame_async_locked(
self, message: bytes, opcode: int, compress: int | None
) -> None:
"""
Async send for large compressed frames with lock.

Acquires the lock and compresses large payloads asynchronously in
the executor. The lock is held for the entire operation to ensure
the compressor state is not corrupted by concurrent sends.

MUST be run shielded from cancellation. If cancelled after
compression but before sending, the compressor state would be
advanced but data not sent, corrupting subsequent frames.
"""
async with self._send_lock:
# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
compressobj = self._get_compressor(compress)
# (0x40) RSV1 is set for compressed frames
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
self._write_websocket_frame(
(
await compressobj.compress(message)
+ compressobj.flush(
ZLibBackend.Z_FULL_FLUSH
if self.notakeover
else ZLibBackend.Z_SYNC_FLUSH
)
).removesuffix(WS_DEFLATE_TRAILING),
opcode,
0x40,
)

async def close(self, code: int = 1000, message: bytes | str = b"") -> None:
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
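For reference, a minimal standalone sketch of the shielding pattern used in send_frame above (plain asyncio, invented names): cancelling the caller interrupts only the outer await, while the inner task, kept alive by a strong reference, still runs the atomic lock + compress + send operation to completion.

import asyncio

background_tasks: set[asyncio.Task[None]] = set()

async def atomic_send() -> None:
    # Stand-in for the lock + compress + write sequence that must not be
    # interrupted part-way through.
    await asyncio.sleep(0.1)
    print("frame fully sent")

async def send_shielded() -> None:
    task = asyncio.ensure_future(atomic_send())
    background_tasks.add(task)  # strong reference: no GC mid-flight
    task.add_done_callback(background_tasks.discard)
    await asyncio.shield(task)  # cancellation hits this await, not the task

async def main() -> None:
    caller = asyncio.ensure_future(send_shielded())
    await asyncio.sleep(0)  # let the inner task be created and started
    caller.cancel()  # the caller sees CancelledError...
    await asyncio.sleep(0.2)  # ...but "frame fully sent" still prints

asyncio.run(main())

On Python 3.12+ the writer additionally constructs the task with eager_start=True, so it begins executing synchronously instead of waiting for the next loop iteration.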
42 changes: 28 additions & 14 deletions aiohttp/compression_utils.py
@@ -185,7 +185,6 @@ def __init__(
if level is not None:
kwargs["level"] = level
self._compressor = self._zlib_backend.compressobj(**kwargs)
self._compress_lock = asyncio.Lock()

def compress_sync(self, data: Buffer) -> bytes:
return self._compressor.compress(data)
@@ -198,22 +197,37 @@ async def compress(self, data: Buffer) -> bytes:
If the data size is larger than max_sync_chunk_size, the compression
will be done in the executor. Otherwise, the compression will be done
in the event loop.

**WARNING: This method is NOT cancellation-safe when used with flush().**
If this operation is cancelled, the compressor state may be corrupted.
The connection MUST be closed after cancellation to avoid data corruption
in subsequent compress operations.

For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
compress() + flush() + send operations in a shield and lock to ensure atomicity.
"""
async with self._compress_lock:
# To ensure the stream is consistent in the event
# there are multiple writers, we need to lock
# the compressor so that only one writer can
# compress at a time.
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._compressor.compress, data
)
return self.compress_sync(data)
# For large payloads, offload compression to the executor to avoid blocking the event loop
should_use_executor = (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
)
if should_use_executor:
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._compressor.compress, data
)
return self.compress_sync(data)

def flush(self, mode: int | None = None) -> bytes:
"""Flush the compressor synchronously.

**WARNING: This method is NOT cancellation-safe when called after compress().**
The flush() operation accesses shared compressor state. If compress() was
cancelled, calling flush() may result in corrupted data. The connection MUST
be closed after compress() cancellation.

For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
compress() + flush() + send operations in a shield and lock to ensure atomicity.
"""
return self._compressor.flush(
mode if mode is not None else self._zlib_backend.Z_FINISH
)
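To see why a compressed-but-unsent chunk poisons the stream, here is a small standalone zlib sketch (illustrative only, not aiohttp code) of the shared raw-deflate window that permessage-deflate relies on:

import zlib

comp = zlib.compressobj(wbits=-15)  # raw deflate, one stream for all messages
frame1 = comp.compress(b"hello " * 200) + comp.flush(zlib.Z_SYNC_FLUSH)
frame2 = comp.compress(b"hello " * 200) + comp.flush(zlib.Z_SYNC_FLUSH)

decomp = zlib.decompressobj(wbits=-15)
assert decomp.decompress(frame1) == b"hello " * 200  # in order: fine
assert decomp.decompress(frame2) == b"hello " * 200  # still in order: fine

fresh = zlib.decompressobj(wbits=-15)
try:
    # Simulates a cancelled send: frame1 advanced the compressor's window
    # but never reached the peer, so frame2's back-references dangle.
    fresh.decompress(frame2)
except zlib.error as exc:
    print("corrupted stream:", exc)  # typically "invalid distance too far back"

The same failure mode applies when concurrent coroutines compress frames and they reach the transport out of order, which is what the writer's lock prevents.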
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -305,6 +305,7 @@ SocketSocketTransport
ssl
SSLContext
startup
stateful
subapplication
subclassed
subclasses
26 changes: 26 additions & 0 deletions tests/conftest.py
@@ -6,7 +6,9 @@
import socket
import ssl
import sys
import time
from collections.abc import AsyncIterator, Callable, Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from hashlib import md5, sha1, sha256
from http.cookies import BaseCookie
from pathlib import Path
@@ -450,3 +452,27 @@ def maker(
await request._close()
assert session is not None
await session.close()


@pytest.fixture
def slow_executor() -> Iterator[ThreadPoolExecutor]:
"""Executor that adds delay to simulate slow operations.

Useful for testing cancellation and race conditions in compression tests.
"""

class SlowExecutor(ThreadPoolExecutor):
"""Executor that adds delay to operations."""

def submit(
self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any
) -> Future[Any]:
def slow_fn(*args: Any, **kwargs: Any) -> Any:
time.sleep(0.05) # Add delay to simulate slow operation
return fn(*args, **kwargs)

return super().submit(slow_fn, *args, **kwargs)

executor = SlowExecutor(max_workers=10)
yield executor
executor.shutdown(wait=True)
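
A hypothetical usage sketch (test body invented here; assumes ZLibCompressor and ZLibBackend are importable from aiohttp.compression_utils, as writer.py's calls suggest): installing the slow executor as the loop default widens the window for cancelling a compress call while it runs in a worker thread.

import asyncio
from concurrent.futures import ThreadPoolExecutor

import pytest

from aiohttp.compression_utils import ZLibBackend, ZLibCompressor


async def test_cancel_mid_compress(slow_executor: ThreadPoolExecutor) -> None:
    asyncio.get_running_loop().set_default_executor(slow_executor)
    # max_sync_chunk_size=1 forces even a tiny payload onto the executor path.
    compressor = ZLibCompressor(
        level=ZLibBackend.Z_BEST_SPEED, wbits=-9, max_sync_chunk_size=1
    )
    task = asyncio.ensure_future(compressor.compress(b"x" * 1024))
    await asyncio.sleep(0)  # let the executor job get submitted
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task
    # Per the docstrings above, the compressor is now suspect and the
    # connection must be closed rather than reused.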