diff --git a/components/backends/vllm/src/dynamo/vllm/handlers.py b/components/backends/vllm/src/dynamo/vllm/handlers.py
index 2c4590a898..1776db6ebc 100644
--- a/components/backends/vllm/src/dynamo/vllm/handlers.py
+++ b/components/backends/vllm/src/dynamo/vllm/handlers.py
@@ -32,7 +32,7 @@ def __init__(self, component, engine, default_sampling_params):
         self.kv_publisher = None
 
     @abstractmethod
-    async def generate(self, request) -> AsyncGenerator[dict, None]:
+    async def generate(self, request, context) -> AsyncGenerator[dict, None]:
         raise NotImplementedError
 
     async def clear_kv_blocks(self, request=None):
@@ -110,7 +110,7 @@ def cleanup(self):
            self._prefill_check_task.cancel()
        super().cleanup()
 
-    async def generate(self, request):
+    async def generate(self, request, context):
        request_id = str(uuid.uuid4().hex)
        logger.debug(f"New Request ID: {request_id}")
 
@@ -147,9 +147,20 @@ async def generate(self, request):
         # TODO Change to prefill queue
         if self.prefill_worker_client is not None:
-            prefill_response = await anext(
-                await self.prefill_worker_client.round_robin(prefill_request)
-            )
+            try:
+                prefill_response = await anext(
+                    await self.prefill_worker_client.round_robin(
+                        prefill_request, context=context
+                    )
+                )
+            except Exception as e:
+                # TODO: Cancellation does not propagate until the first token is received
+                if context.is_stopped() or context.is_killed():
+                    logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
+                    # TODO: Raise asyncio.CancelledError into bindings
+                    return
+                raise e
+
             prefill_response = MyRequestOutput.model_validate_json(
                 prefill_response.data()
             )
@@ -162,6 +173,12 @@ async def generate(self, request):
             ] = prefill_response.kv_transfer_params
 
         async for tok in self.generate_tokens(prompt, sampling_params, request_id):
+            if context.is_stopped() or context.is_killed():
+                await self.engine_client.abort(request_id)
+                logger.debug(f"Aborted Request ID: {request_id}")
+                # TODO: Raise asyncio.CancelledError into bindings
+                break
+
             yield tok
 
 
@@ -169,7 +186,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
     def __init__(self, component, engine, default_sampling_params):
         super().__init__(component, engine, default_sampling_params)
 
-    async def generate(self, request):
+    async def generate(self, request, context):
        request_id = request["request_id"]
        logger.debug(f"New Prefill Request ID: {request_id}")
 
@@ -181,6 +198,12 @@ async def generate(self, request):
         # Generate only 1 token in prefill
         try:
             async for res in gen:
+                if context.is_stopped() or context.is_killed():
+                    await self.engine_client.abort(request_id)
+                    logger.debug(f"Aborted Prefill Request ID: {request_id}")
+                    # TODO: Raise asyncio.CancelledError into bindings
+                    break
+
                 logger.debug(f"kv transfer params: {res.kv_transfer_params}")
                 yield MyRequestOutput(
                     request_id=res.request_id,
diff --git a/docs/architecture/request_cancellation.md b/docs/architecture/request_cancellation.md
new file mode 100644
index 0000000000..27946d9d81
--- /dev/null
+++ b/docs/architecture/request_cancellation.md
@@ -0,0 +1,86 @@
+# Request Cancellation Architecture
+
+This document describes how Dynamo cancels in-flight requests between workers. Request cancellation allows a request that is no longer needed to terminate early, saving computational resources that would otherwise be spent producing responses nobody will consume.
+
+## AsyncEngineContext Trait
+
+At the core of Dynamo's request cancellation system is the `AsyncEngineContext` trait.
+This trait is associated with every request stream and provides lifecycle management for async operations: stream identification, graceful shutdown, and immediate termination.
+
+### Key Methods
+
+#### Identification
+- **`id()`**: Returns the unique identifier for the stream. The ID is set by the user for request identification, and the same ID can be reused for sub-requests to associate them with the original user request.
+
+#### Status Checking
+- **`is_stopped()`**: Returns `true` if graceful cancellation has been requested via `stop_generating()`. This signals the worker that the request has been cancelled and it should return early.
+- **`is_killed()`**: Returns `true` if a hard stop has been issued via `kill()`. This typically indicates that the network connection between client and server has been cut or that immediate termination is required.
+
+#### Async Status Monitoring
+- **`stopped()`**: An async method that completes when the context becomes stopped. If already stopped, it returns immediately.
+- **`killed()`**: An async method that completes when the context becomes killed. If already killed, it returns immediately.
+
+#### Cancellation Control
+- **`stop_generating()`**: The recommended method for cancelling a request. It informs the engine to gracefully stop producing results for the stream. This method is idempotent and does not invalidate results currently in the stream.
+- **`stop()`**: Alias for `stop_generating()`.
+- **`kill()`**: Extends `stop_generating()` but also indicates a preference to terminate without draining the remaining items in the stream. This is implementation-specific and may not be supported by all engines.
+
+#### Child Request Management
+- **`link_child(child: Arc<dyn AsyncEngineContext>)`**: Links a child `AsyncEngineContext` to this context. When `stop_generating()`, `stop()`, or `kill()` is called on the parent context, the same method is automatically called on all linked child contexts in the order they were linked. This is especially useful in disaggregated serving scenarios: a frontend that receives a cancellation notification can cancel its requests to workers, and each worker can in turn cancel its own sub-requests (e.g., remote prefill operations).
+
+### Thread Safety
+
+The `AsyncEngineContext` trait carries `Send + Sync` bounds, allowing safe concurrent access across multiple threads and async tasks.
+
+## Python Bindings
+
+The `AsyncEngineContext` functionality is exposed to Python through the `Context` class, which provides a largely one-to-one mapping from Rust methods to Python methods.
+
+### Python Context Class
+
+The Python `Context` class wraps the Rust `AsyncEngineContext` and exposes the following methods:
+
+- **`id()`**: Returns the unique identifier for the context
+- **`is_stopped()`**: Synchronous equivalent of the Rust `is_stopped()`
+- **`is_killed()`**: Synchronous equivalent of the Rust `is_killed()`
+- **`stop_generating()`**: Issues a stop-generating signal, equivalent to the Rust method
+- **`async_killed_or_stopped()`**: An async method that completes when the context becomes either killed or stopped, whichever happens first. This combines the Rust `killed()` and `stopped()` async methods using `tokio::select!`.
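+
+A handler can also monitor cancellation concurrently instead of polling `is_stopped()` in a loop. The sketch below is illustrative rather than part of the bindings API beyond `async_killed_or_stopped()` itself; `do_work()` is a hypothetical awaitable standing in for one step of real engine work:
+
+```python
+import asyncio
+
+async def generate(self, request, context):
+    # Background task that completes once the context is stopped or killed
+    cancelled = asyncio.ensure_future(context.async_killed_or_stopped())
+    try:
+        for i in range(1000):
+            step = asyncio.ensure_future(do_work(request, i))  # hypothetical work
+            done, _ = await asyncio.wait(
+                {cancelled, step}, return_when=asyncio.FIRST_COMPLETED
+            )
+            if cancelled in done:
+                step.cancel()  # abandon the in-flight step and stop generating
+                return
+            yield step.result()
+    finally:
+        cancelled.cancel()  # stop monitoring once the stream ends
+```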
+
+### Context Usage in Python
+
+The context is optionally available in both incoming and outgoing request scenarios:
+
+#### Incoming Requests
+For incoming requests, the generate method may optionally accept a `context` argument after the `request` argument. If the `context` parameter is present in the method signature, it receives the context object of the incoming request. Request handlers can:
+
+- Check for cancellation synchronously using `context.is_stopped()` before beginning expensive operations
+- Listen for cancellation asynchronously using `await context.async_killed_or_stopped()`
+
+Example:
+```python
+async def generate(self, request, context):
+    for i in range(1000):
+        # Check for cancellation before expensive work
+        if context.is_stopped():
+            raise asyncio.CancelledError
+
+        # Perform work...
+        await expensive_computation()
+        yield result
+```
+
+#### Outgoing Requests
+For outgoing requests, Python scripts may optionally provide a context object to outgoing runtime endpoint client router operations (such as the `generate`, `round_robin`, `random`, and `direct` methods) as a keyword argument. The script can then cancel the outgoing request via the provided context object.
+
+This is especially useful when child outgoing requests need to be cancelled together with their parent incoming request. In that case, the script can simply pass the incoming context object to the outgoing request, automatically linking the cancellation behavior.
+
+Example:
+```python
+async def generate(self, request, context):
+    # Forward the incoming context to the outgoing request.
+    # If the incoming request is cancelled, the outgoing request is cancelled too.
+    stream = await self.client.generate(request, context=context)
+    async for response in stream:
+        yield response
+```
+
+This design enables seamless cancellation propagation through multi-tier request chains: when a client cancels a request, all associated sub-requests are automatically cancelled, saving computational resources across the entire request pipeline.
diff --git a/docs/guides/backend.md b/docs/guides/backend.md
index 68b0e98432..86685ca7b6 100644
--- a/docs/guides/backend.md
+++ b/docs/guides/backend.md
@@ -137,3 +137,25 @@ class RequestHandler:
 
 When `GeneratorExit` is raised, the frontend receives the incomplete response and can seamlessly continue generation on another available worker instance, preserving the user experience even during worker shutdowns.
 
 For more information about how request migration works, see the [Request Migration Architecture](../architecture/request_migration.md) documentation.
+
+## Request Cancellation
+
+Your Python worker's request handler can optionally support request cancellation by accepting a `context` argument after the `request` argument. This context object allows you to check for cancellation signals and respond appropriately:
+
+```python
+class RequestHandler:
+
+    async def generate(self, request, context):
+        """Generate response with cancellation support"""
+        for result in self.engine.generate_streaming(request):
+            # Check if the request has been cancelled
+            if context.is_stopped():
+                # Stop processing and clean up
+                break
+
+            yield result
+```
+
+The `context` parameter is optional: if your generate method doesn't include it in its signature, Dynamo calls your method without the context argument.
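+
+Both signatures are valid. As a minimal sketch (reusing the same illustrative `RequestHandler` and `self.engine.generate_streaming` from the example above), a handler that opts out of cancellation simply omits the parameter:
+
+```python
+class RequestHandler:
+
+    async def generate(self, request):
+        """Generate response without cancellation support"""
+        for result in self.engine.generate_streaming(request):
+            yield result
+```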
+
+For detailed information about request cancellation, including async cancellation monitoring and context propagation patterns, see the [Request Cancellation Architecture](../architecture/request_cancellation.md) documentation.
diff --git a/docs/guides/dynamo_run.md b/docs/guides/dynamo_run.md
index d90fbe5801..48533e6388 100644
--- a/docs/guides/dynamo_run.md
+++ b/docs/guides/dynamo_run.md
@@ -178,6 +178,12 @@ dynamo-run in=dyn://... out= ... --migration-limit=3
 
 This allows a request to be migrated up to 3 times before failing. See the [Request Migration Architecture](../architecture/request_migration.md) documentation for details on how this works.
 
+### Request Cancellation
+
+When using the HTTP interface (`in=http`), if the client drops the HTTP connection, Dynamo automatically cancels the downstream request to the worker. This ensures that computational resources are not wasted on responses that are no longer needed.
+
+For detailed information about how request cancellation works across the system, see the [Request Cancellation Architecture](../architecture/request_cancellation.md) documentation.
+
 ## Development
 
 `dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features.
diff --git a/lib/bindings/python/rust/context.rs b/lib/bindings/python/rust/context.rs
index b14337682d..1edf38d465 100644
--- a/lib/bindings/python/rust/context.rs
+++ b/lib/bindings/python/rust/context.rs
@@ -1,28 +1,46 @@
 // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 // SPDX-License-Identifier: Apache-2.0
 
-// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
+// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.
+use dynamo_runtime::pipeline::context::Controller;
 pub use dynamo_runtime::pipeline::AsyncEngineContext;
 use pyo3::prelude::*;
 use std::sync::Arc;
 
-// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
+// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.
 // Not all methods of the AsyncEngineContext are exposed, just the primary ones for tracing + cancellation.
 // Kept as a class, to allow for future expansion if needed.
+#[derive(Clone)]
 #[pyclass]
-pub struct PyContext {
-    pub inner: Arc<dyn AsyncEngineContext>,
+pub struct Context {
+    inner: Arc<dyn AsyncEngineContext>,
 }
 
-impl PyContext {
+impl Context {
     pub fn new(inner: Arc<dyn AsyncEngineContext>) -> Self {
         Self { inner }
     }
+
+    pub fn inner(&self) -> Arc<dyn AsyncEngineContext> {
+        self.inner.clone()
+    }
 }
 
 #[pymethods]
-impl PyContext {
+impl Context {
+    #[new]
+    #[pyo3(signature = (id=None))]
+    fn py_new(id: Option<String>) -> Self {
+        let controller = match id {
+            Some(id) => Controller::new(id),
+            None => Controller::default(),
+        };
+        Self {
+            inner: Arc::new(controller),
+        }
+    }
+
     // sync method of `await async_is_stopped()`
     fn is_stopped(&self) -> bool {
         self.inner.is_stopped()
diff --git a/lib/bindings/python/rust/engine.rs b/lib/bindings/python/rust/engine.rs
index c26c0f3bb5..8aeab7ed7d 100644
--- a/lib/bindings/python/rust/engine.rs
+++ b/lib/bindings/python/rust/engine.rs
@@ -13,7 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-use super::context::{callable_accepts_kwarg, PyContext}; +use super::context::{callable_accepts_kwarg, Context}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyModule}; use pyo3::{PyAny, PyErr}; @@ -114,7 +114,7 @@ pub struct PythonServerStreamingEngine { _cancel_token: CancellationToken, generator: Arc, event_loop: Arc, - has_pycontext: bool, + has_context: bool, } impl PythonServerStreamingEngine { @@ -123,7 +123,7 @@ impl PythonServerStreamingEngine { generator: Arc, event_loop: Arc, ) -> Self { - let has_pycontext = Python::with_gil(|py| { + let has_context = Python::with_gil(|py| { let callable = generator.bind(py); callable_accepts_kwarg(py, callable, "context").unwrap_or(false) }); @@ -132,7 +132,7 @@ impl PythonServerStreamingEngine { _cancel_token: cancel_token, generator, event_loop, - has_pycontext, + has_context, } } } @@ -175,7 +175,7 @@ where let generator = self.generator.clone(); let event_loop = self.event_loop.clone(); let ctx_python = ctx.clone(); - let has_pycontext = self.has_pycontext; + let has_context = self.has_context; // Acquiring the GIL is similar to acquiring a standard lock/mutex // Performing this in an tokio async task could block the thread for an undefined amount of time @@ -190,9 +190,9 @@ where let stream = tokio::task::spawn_blocking(move || { Python::with_gil(|py| { let py_request = pythonize(py, &request)?; - let py_ctx = Py::new(py, PyContext::new(ctx_python.clone()))?; + let py_ctx = Py::new(py, Context::new(ctx_python.clone()))?; - let gen = if has_pycontext { + let gen = if has_context { // Pass context as a kwarg let kwarg = PyDict::new(py); kwarg.set_item("context", &py_ctx)?; diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 5989f076b1..7b11e191fc 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -16,7 +16,8 @@ use tokio::sync::Mutex; use dynamo_runtime::{ self as rs, logging, pipeline::{ - network::egress::push_router::RouterMode as RsRouterMode, EngineStream, ManyOut, SingleIn, + context::Context as RsContext, network::egress::push_router::RouterMode as RsRouterMode, + EngineStream, ManyOut, SingleIn, }, protocols::annotated::Annotated as RsAnnotated, traits::DistributedRuntimeProvider, @@ -104,7 +105,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -697,27 +698,29 @@ impl Client { } /// Issue a request to the endpoint using the default routing strategy. - #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))] + #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))] fn generate<'p>( &self, py: Python<'p>, request: PyObject, annotated: Option, + context: Option, ) -> PyResult> { if self.router.client.is_static() { - self.r#static(py, request, annotated) + self.r#static(py, request, annotated, context) } else { - self.random(py, request, annotated) + self.random(py, request, annotated, context) } } /// Send a request to the next endpoint in a round-robin fashion. 
- #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))] + #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))] fn round_robin<'p>( &self, py: Python<'p>, request: PyObject, annotated: Option, + context: Option, ) -> PyResult> { let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?; let annotated = annotated.unwrap_or(false); @@ -726,7 +729,15 @@ impl Client { let client = self.router.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { - let stream = client.round_robin(request.into()).await.map_err(to_pyerr)?; + let stream = match context { + Some(context) => { + let request = RsContext::with_id(request, context.inner().id().to_string()); + let stream = client.round_robin(request).await.map_err(to_pyerr)?; + context.inner().link_child(stream.context()); + stream + } + _ => client.round_robin(request.into()).await.map_err(to_pyerr)?, + }; tokio::spawn(process_stream(stream, tx)); Ok(AsyncResponseStream { rx: Arc::new(Mutex::new(rx)), @@ -736,12 +747,13 @@ impl Client { } /// Send a request to a random endpoint. - #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))] + #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))] fn random<'p>( &self, py: Python<'p>, request: PyObject, annotated: Option, + context: Option, ) -> PyResult> { let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?; let annotated = annotated.unwrap_or(false); @@ -750,7 +762,15 @@ impl Client { let client = self.router.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { - let stream = client.random(request.into()).await.map_err(to_pyerr)?; + let stream = match context { + Some(context) => { + let request = RsContext::with_id(request, context.inner().id().to_string()); + let stream = client.random(request).await.map_err(to_pyerr)?; + context.inner().link_child(stream.context()); + stream + } + _ => client.random(request.into()).await.map_err(to_pyerr)?, + }; tokio::spawn(process_stream(stream, tx)); Ok(AsyncResponseStream { rx: Arc::new(Mutex::new(rx)), @@ -760,13 +780,14 @@ impl Client { } /// Directly send a request to a specific endpoint. 
- #[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING))] + #[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING, context=None))] fn direct<'p>( &self, py: Python<'p>, request: PyObject, instance_id: i64, annotated: Option, + context: Option, ) -> PyResult> { let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?; let annotated = annotated.unwrap_or(false); @@ -775,10 +796,21 @@ impl Client { let client = self.router.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { - let stream = client - .direct(request.into(), instance_id) - .await - .map_err(to_pyerr)?; + let stream = match context { + Some(context) => { + let request = RsContext::with_id(request, context.inner().id().to_string()); + let stream = client + .direct(request, instance_id) + .await + .map_err(to_pyerr)?; + context.inner().link_child(stream.context()); + stream + } + _ => client + .direct(request.into(), instance_id) + .await + .map_err(to_pyerr)?, + }; tokio::spawn(process_stream(stream, tx)); @@ -790,12 +822,13 @@ impl Client { } /// Directly send a request to a pre-defined static worker - #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))] + #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))] fn r#static<'p>( &self, py: Python<'p>, request: PyObject, annotated: Option, + context: Option, ) -> PyResult> { let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?; let annotated = annotated.unwrap_or(false); @@ -804,7 +837,15 @@ impl Client { let client = self.router.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { - let stream = client.r#static(request.into()).await.map_err(to_pyerr)?; + let stream = match context { + Some(context) => { + let request = RsContext::with_id(request, context.inner().id().to_string()); + let stream = client.r#static(request).await.map_err(to_pyerr)?; + context.inner().link_child(stream.context()); + stream + } + _ => client.r#static(request.into()).await.map_err(to_pyerr)?, + }; tokio::spawn(process_stream(stream, tx)); diff --git a/lib/bindings/python/src/dynamo/runtime/__init__.py b/lib/bindings/python/src/dynamo/runtime/__init__.py index 497adb2036..c4ee81edc6 100644 --- a/lib/bindings/python/src/dynamo/runtime/__init__.py +++ b/lib/bindings/python/src/dynamo/runtime/__init__.py @@ -25,12 +25,12 @@ from dynamo._core import Backend as Backend from dynamo._core import Client as Client from dynamo._core import Component as Component +from dynamo._core import Context as Context from dynamo._core import DistributedRuntime as DistributedRuntime from dynamo._core import Endpoint as Endpoint from dynamo._core import EtcdKvCache as EtcdKvCache from dynamo._core import ModelDeploymentCard as ModelDeploymentCard from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor -from dynamo._core import PyContext as PyContext def dynamo_worker(static=False): diff --git a/lib/bindings/python/tests/conftest.py b/lib/bindings/python/tests/conftest.py new file mode 100644 index 0000000000..4fa7027db9 --- /dev/null +++ b/lib/bindings/python/tests/conftest.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +from time import sleep + +import pytest + + +@pytest.fixture(scope="module", autouse=True) +def nats_and_etcd(): + # Setup code + nats_server = subprocess.Popen(["nats-server", "-js"]) + etcd = subprocess.Popen(["etcd"]) + print("Setting up resources") + + sleep(5) # wait for nats-server and etcd to start + yield + + # Teardown code + print("Tearing down resources") + nats_server.terminate() + nats_server.wait() + etcd.terminate() + etcd.wait() diff --git a/lib/bindings/python/tests/test_cancellation/conftest.py b/lib/bindings/python/tests/test_cancellation/conftest.py new file mode 100644 index 0000000000..5c20b150dd --- /dev/null +++ b/lib/bindings/python/tests/test_cancellation/conftest.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import random +import string + +import pytest + +from dynamo._core import DistributedRuntime + + +class MockServer: + """ + Test request handler that simulates a generate method with cancellation support + """ + + def __init__(self): + self.context_is_stopped = False + self.context_is_killed = False + + async def generate(self, request, context): + self.context_is_stopped = False + self.context_is_killed = False + + method_name = request + assert hasattr( + self, method_name + ), f"Method '{method_name}' not found on {self.__class__.__name__}" + method = getattr(self, method_name) + async for response in method(request, context): + yield response + + async def _generate_until_context_cancelled(self, request, context): + """ + Generate method that yields numbers 0-999 every 0.1 seconds + Checks for context.is_stopped() / context.is_killed() before each yield and raises + CancelledError if stopped / killed + """ + for i in range(1000): + print(f"Processing iteration {i}") + + # Check if context is stopped + if context.is_stopped(): + print(f"Context stopped at iteration {i}") + self.context_is_stopped = True + self.context_is_killed = context.is_killed() + raise asyncio.CancelledError + + # Check if context is killed + if context.is_killed(): + print(f"Context killed at iteration {i}") + self.context_is_stopped = context.is_stopped() + self.context_is_killed = True + raise asyncio.CancelledError + + await asyncio.sleep(0.1) + + print(f"Sending iteration {i}") + yield i + + assert ( + False + ), "Test failed: generate_until_cancelled did not raise CancelledError" + + async def _generate_until_asyncio_cancelled(self, request, context): + """ + Generate method that yields numbers 0-999 every 0.1 seconds + """ + i = 0 + try: + for i in range(1000): + print(f"Processing iteration {i}") + await asyncio.sleep(0.1) + print(f"Sending iteration {i}") + yield i + except asyncio.CancelledError: + print(f"Cancelled at iteration {i}") + self.context_is_stopped = context.is_stopped() + self.context_is_killed = context.is_killed() + raise + + assert ( + False + ), "Test failed: generate_until_cancelled did not raise CancelledError" + + async def _generate_and_cancel_context(self, request, context): + """ + Generate method that yields numbers 0-1, and then cancel the context + """ + for i in range(2): + print(f"Processing iteration {i}") + await asyncio.sleep(0.1) + print(f"Sending iteration {i}") + yield i + + context.stop_generating() + + self.context_is_stopped = context.is_stopped() + self.context_is_killed = context.is_killed() + + async def _generate_and_raise_cancelled(self, request, context): + """ + Generate method that yields numbers 0-1, and then raise asyncio.CancelledError + """ + for i in range(2): + print(f"Processing iteration {i}") + await asyncio.sleep(0.1) + print(f"Sending iteration {i}") + yield i + + raise asyncio.CancelledError + + +def random_string(length=10): + """Generate a random string for namespace isolation""" + # Start with a letter to satisfy Prometheus naming requirements + first_char = random.choice(string.ascii_lowercase) + remaining_chars = string.ascii_lowercase + string.digits + rest = "".join(random.choices(remaining_chars, k=length - 1)) + return first_char + rest + + +@pytest.fixture +async def runtime(): + """Create a DistributedRuntime for testing""" + loop = asyncio.get_running_loop() + runtime = DistributedRuntime(loop, True) + yield runtime + runtime.shutdown() + + +@pytest.fixture +def namespace(): + """Generate a random namespace for test 
isolation""" + return random_string() + + +@pytest.fixture +async def server(runtime, namespace): + """Start a test server in the background""" + + handler = MockServer() + + async def init_server(): + """Initialize the test server component and serve the generate endpoint""" + component = runtime.namespace(namespace).component("backend") + await component.create_service() + + endpoint = component.endpoint("generate") + print("Started test server instance") + + # Serve the endpoint - this will block until shutdown + await endpoint.serve_endpoint(handler.generate) + + # Start server in background task + server_task = asyncio.create_task(init_server()) + + # Give server time to start up + await asyncio.sleep(0.5) + + yield server_task, handler + + # Cleanup - cancel server task + if not server_task.done(): + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass + + +@pytest.fixture +async def client(runtime, namespace): + """Create a client connected to the test server""" + # Create client + endpoint = runtime.namespace(namespace).component("backend").endpoint("generate") + client = await endpoint.client() + await client.wait_for_instances() + + return client diff --git a/lib/bindings/python/tests/test_cancellation/test_cancellation.py b/lib/bindings/python/tests/test_cancellation/test_cancellation.py new file mode 100644 index 0000000000..cf170dcb44 --- /dev/null +++ b/lib/bindings/python/tests/test_cancellation/test_cancellation.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import subprocess + +import pytest + +pytestmark = pytest.mark.pre_merge + + +def _run_test_in_subprocess(test_name: str): + """Helper function to run a test file in a separate process""" + test_file = os.path.join(os.path.dirname(__file__), f"{test_name}.py") + result = subprocess.run( + ["pytest", test_file, "-v"], + capture_output=True, + text=True, + cwd=os.path.dirname(__file__), + ) + + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + print("Return code:", result.returncode) + + assert ( + result.returncode == 0 + ), f"Test {test_name} failed with return code {result.returncode}" + + +def test_client_context_cancel(): + _run_test_in_subprocess("test_client_context_cancel") + + +def test_client_loop_break(): + _run_test_in_subprocess("test_client_loop_break") + + +def test_server_context_cancel(): + _run_test_in_subprocess("test_server_context_cancel") + + +def test_server_raise_cancelled(): + _run_test_in_subprocess("test_server_raise_cancelled") diff --git a/lib/bindings/python/tests/test_cancellation/test_client_context_cancel.py b/lib/bindings/python/tests/test_cancellation/test_client_context_cancel.py new file mode 100644 index 0000000000..6a49d6b622 --- /dev/null +++ b/lib/bindings/python/tests/test_cancellation/test_client_context_cancel.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest + +from dynamo._core import Context + + +@pytest.mark.asyncio +async def test_client_context_cancel(server, client): + _, handler = server + context = Context() + stream = await client.generate("_generate_until_context_cancelled", context=context) + + iteration_count = 0 + async for annotated in stream: + number = annotated.data() + print(f"Received iteration: {number}") + + # Verify received valid number + assert number == iteration_count + + # Break after receiving 2 responses + if iteration_count >= 2: + print("Cancelling after 2 responses...") + context.stop_generating() + break + + iteration_count += 1 + + # Give server a moment to process the cancellation + await asyncio.sleep(0.2) + + # Verify server detected the cancellation + assert handler.context_is_stopped + assert handler.context_is_killed + + # TODO: Test with _generate_until_asyncio_cancelled server handler diff --git a/lib/bindings/python/tests/test_cancellation/test_client_loop_break.py b/lib/bindings/python/tests/test_cancellation/test_client_loop_break.py new file mode 100644 index 0000000000..a5031d85e7 --- /dev/null +++ b/lib/bindings/python/tests/test_cancellation/test_client_loop_break.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest + + +@pytest.mark.asyncio +async def test_client_loop_break(server, client): + _, handler = server + stream = await client.generate("_generate_until_context_cancelled") + + iteration_count = 0 + async for annotated in stream: + number = annotated.data() + print(f"Received iteration: {number}") + + # Verify received valid number + assert number == iteration_count + + # Break after receiving 2 responses + if iteration_count >= 2: + print("Cancelling after 2 responses...") + break + + iteration_count += 1 + + # Give server a moment to process the cancellation + await asyncio.sleep(0.2) + + # TODO: Implicit cancellation is not yet implemented, so the server context will not + # show any cancellation. + assert not handler.context_is_stopped + assert not handler.context_is_killed + + # TODO: Test with _generate_until_asyncio_cancelled server handler diff --git a/lib/bindings/python/tests/test_cancellation/test_server_context_cancel.py b/lib/bindings/python/tests/test_cancellation/test_server_context_cancel.py new file mode 100644 index 0000000000..f644314f44 --- /dev/null +++ b/lib/bindings/python/tests/test_cancellation/test_server_context_cancel.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + + +@pytest.mark.asyncio +async def test_server_context_cancel(server, client): + _, handler = server + stream = await client.generate("_generate_and_cancel_context") + + iteration_count = 0 + try: + async for annotated in stream: + number = annotated.data() + print(f"Received iteration: {number}") + assert number == iteration_count + iteration_count += 1 + assert False, "Stream completed without cancellation" + except ValueError as e: + # Verify the expected cancellation exception is received + # TODO: Should this be a asyncio.CancelledError? + assert str(e) == "Stream ended before generation completed" + + # Verify server context cancellation status + assert handler.context_is_stopped + assert not handler.context_is_killed diff --git a/lib/bindings/python/tests/test_cancellation/test_server_raise_cancelled.py b/lib/bindings/python/tests/test_cancellation/test_server_raise_cancelled.py new file mode 100644 index 0000000000..14225d5981 --- /dev/null +++ b/lib/bindings/python/tests/test_cancellation/test_server_raise_cancelled.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + + +@pytest.mark.asyncio +async def test_server_raise_cancelled(server, client): + _, handler = server + stream = await client.generate("_generate_and_raise_cancelled") + + iteration_count = 0 + try: + async for annotated in stream: + number = annotated.data() + print(f"Received iteration: {number}") + assert number == iteration_count + iteration_count += 1 + assert False, "Stream completed without cancellation" + except ValueError as e: + # Verify the expected cancellation exception is received + # TODO: Should this be a asyncio.CancelledError? + assert ( + str(e) + == "a python exception was caught while processing the async generator: CancelledError: " + ) + + # Verify server context cancellation status + # TODO: Server to gracefully stop the stream? + assert not handler.context_is_stopped + assert not handler.context_is_killed diff --git a/lib/bindings/python/tests/test_kv_bindings.py b/lib/bindings/python/tests/test_kv_bindings.py index 7e7aaebc5e..f1a9e426e9 100644 --- a/lib/bindings/python/tests/test_kv_bindings.py +++ b/lib/bindings/python/tests/test_kv_bindings.py @@ -15,8 +15,6 @@ import asyncio -import subprocess -from time import sleep from typing import List import pytest @@ -37,24 +35,6 @@ pytestmark = pytest.mark.pre_merge -@pytest.fixture(scope="module", autouse=True) -def setup_and_teardown(): - # Setup code - nats_server = subprocess.Popen(["nats-server", "-js"]) - etcd = subprocess.Popen(["etcd"]) - print("Setting up resources") - - sleep(5) # wait for nats-server and etcd to start - yield - - # Teardown code - print("Tearing down resources") - nats_server.terminate() - nats_server.wait() - etcd.terminate() - etcd.wait() - - @pytest.fixture(scope="module") async def distributed_runtime(): loop = asyncio.get_running_loop() diff --git a/lib/llm/src/http/client.rs b/lib/llm/src/http/client.rs index 4809979856..1a9d8cf338 100644 --- a/lib/llm/src/http/client.rs +++ b/lib/llm/src/http/client.rs @@ -8,7 +8,7 @@ //! for performance analysis. 
use std::pin::Pin; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; use std::time::Instant; @@ -64,6 +64,8 @@ pub struct HttpRequestContext { created_at: Instant, /// Whether the request has been stopped stopped: Arc, + /// Child contexts to be stopped if this is stopped + child_context: Arc>>>, } impl HttpRequestContext { @@ -74,6 +76,7 @@ impl HttpRequestContext { cancel_token: CancellationToken::new(), created_at: Instant::now(), stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)), + child_context: Arc::new(Mutex::new(Vec::new())), } } @@ -84,6 +87,7 @@ impl HttpRequestContext { cancel_token: CancellationToken::new(), created_at: Instant::now(), stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)), + child_context: Arc::new(Mutex::new(Vec::new())), } } @@ -95,6 +99,7 @@ impl HttpRequestContext { cancel_token: self.cancel_token.child_token(), created_at: Instant::now(), stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)), + child_context: Arc::new(Mutex::new(Vec::new())), } } @@ -105,6 +110,7 @@ impl HttpRequestContext { cancel_token: self.cancel_token.child_token(), created_at: Instant::now(), stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)), + child_context: Arc::new(Mutex::new(Vec::new())), } } @@ -144,17 +150,55 @@ impl AsyncEngineContext for HttpRequestContext { } fn stop(&self) { + // Clone child Arcs to avoid deadlock if parent is accidentally linked under child + let children = self + .child_context + .lock() + .expect("Failed to lock child context") + .iter() + .cloned() + .collect::>(); + for child in children { + child.stop(); + } + self.stopped .store(true, std::sync::atomic::Ordering::Release); self.cancel_token.cancel(); } fn stop_generating(&self) { + // Clone child Arcs to avoid deadlock if parent is accidentally linked under child + let children = self + .child_context + .lock() + .expect("Failed to lock child context") + .iter() + .cloned() + .collect::>(); + for child in children { + child.stop_generating(); + } + // For HTTP clients, stop_generating is the same as stop - self.stop(); + self.stopped + .store(true, std::sync::atomic::Ordering::Release); + self.cancel_token.cancel(); } fn kill(&self) { + // Clone child Arcs to avoid deadlock if parent is accidentally linked under child + let children = self + .child_context + .lock() + .expect("Failed to lock child context") + .iter() + .cloned() + .collect::>(); + for child in children { + child.kill(); + } + self.stopped .store(true, std::sync::atomic::Ordering::Release); self.cancel_token.cancel(); @@ -176,6 +220,13 @@ impl AsyncEngineContext for HttpRequestContext { // For HTTP clients, killed is the same as stopped self.cancel_token.cancelled().await; } + + fn link_child(&self, child: Arc) { + self.child_context + .lock() + .expect("Failed to lock child context") + .push(child); + } } /// Base HTTP client with common functionality diff --git a/lib/llm/src/migration.rs b/lib/llm/src/migration.rs index 7136e0d8d6..6dfd20aaa4 100644 --- a/lib/llm/src/migration.rs +++ b/lib/llm/src/migration.rs @@ -17,8 +17,8 @@ use crate::{ use dynamo_runtime::{ pipeline::{ - AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, ServerStreamingEngine, - SingleIn, async_trait, + AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream, + ServerStreamingEngine, SingleIn, async_trait, }, protocols::{annotated::Annotated, maybe_error::MaybeError}, }; @@ -29,6 +29,11 @@ pub struct Migration { impl Migration { pub async fn from_mdc(mdc: 
ModelDeploymentCard) -> Result> { + tracing::debug!( + "model {} migration limit {}", + mdc.display_name, + mdc.migration_limit + ); Ok(Arc::new(Self { migration_limit: mdc.migration_limit, })) @@ -50,20 +55,30 @@ impl next: ServerStreamingEngine>, ) -> Result>> { let (preprocessed_request, context) = request.transfer(()); + let context_id = context.id().to_string(); let engine_ctx = context.context(); + let engine_ctx_ = engine_ctx.clone(); let retry_manager = - RetryManager::build(preprocessed_request, next, self.migration_limit).await?; - let response_stream = stream::unfold(retry_manager, |mut retry_manager| async move { - retry_manager - .next() - .await - .map(|response| (response, retry_manager)) + RetryManager::build(context_id, preprocessed_request, next, self.migration_limit) + .await?; + let response_stream = stream::unfold(retry_manager, move |mut retry_manager| { + let engine_ctx = engine_ctx_.clone(); + async move { + if engine_ctx.is_stopped() || engine_ctx.is_killed() { + return None; // Stop if the context is cancelled or stopped + } + retry_manager + .next() + .await + .map(|response| (response, retry_manager)) + } }); Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx)) } } struct RetryManager { + context_id: String, request: PreprocessedRequest, next_generate: ServerStreamingEngine>, next_stream: Option>>, @@ -72,11 +87,13 @@ struct RetryManager { impl RetryManager { pub async fn build( + context_id: String, preprocessed_request: PreprocessedRequest, next: ServerStreamingEngine>, retries_left: u32, ) -> Result { let mut slf = Self { + context_id, request: preprocessed_request, next_generate: next, next_stream: None, @@ -123,8 +140,7 @@ impl RetryManager { let mut response_stream: Option>>> = None; while self.retries_left > 0 { self.retries_left -= 1; - // TODO: Is there anything needed to pass between context? 
- let request = SingleIn::new(self.request.clone()); + let request = Context::with_id(self.request.clone(), self.context_id.clone()); response_stream = Some(self.next_generate.generate(request).await); if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() && let Some(req_err) = err.downcast_ref::() @@ -230,15 +246,22 @@ mod tests { num_responses: usize, token_offset: u32, call_count: Arc, + context_id: String, } impl MockEngine { - fn new(behavior: MockBehavior, num_responses: usize, token_offset: u32) -> Self { + fn new( + behavior: MockBehavior, + num_responses: usize, + token_offset: u32, + context_id: String, + ) -> Self { Self { behavior, num_responses, token_offset, call_count: Arc::new(AtomicU32::new(0)), + context_id, } } } @@ -256,7 +279,14 @@ mod tests { request: SingleIn, ) -> Result>> { let call_num = self.call_count.fetch_add(1, Ordering::SeqCst); - let (preprocessed_request, _) = request.transfer(()); + let (preprocessed_request, context) = request.transfer(()); + + // Assert that the context_id matches the expected one + assert_eq!( + context.id().to_string(), + self.context_id, + "Context ID mismatch" + ); // Calculate how many responses we've already generated based on request token_ids // Initial request has [1, 2, 3], so anything beyond that are generated responses @@ -331,7 +361,7 @@ mod tests { } let stream = tokio_stream::wrappers::ReceiverStream::new(rx); - let ctx = Arc::new(Controller::default()); + let ctx = Arc::new(Controller::new(self.context_id.clone())); Ok(dynamo_runtime::pipeline::ResponseStream::new( Box::pin(stream), ctx, @@ -362,7 +392,7 @@ mod tests { }); let stream = tokio_stream::wrappers::ReceiverStream::new(rx); - let ctx = Arc::new(Controller::default()); + let ctx = Arc::new(Controller::new(self.context_id.clone())); Ok(dynamo_runtime::pipeline::ResponseStream::new( Box::pin(stream), ctx, @@ -398,7 +428,7 @@ mod tests { }); let stream = tokio_stream::wrappers::ReceiverStream::new(rx); - let ctx = Arc::new(Controller::default()); + let ctx = Arc::new(Controller::new(self.context_id.clone())); Ok(dynamo_runtime::pipeline::ResponseStream::new( Box::pin(stream), ctx, @@ -415,7 +445,7 @@ mod tests { }); let stream = tokio_stream::wrappers::ReceiverStream::new(rx); - let ctx = Arc::new(Controller::default()); + let ctx = Arc::new(Controller::new(self.context_id.clone())); Ok(dynamo_runtime::pipeline::ResponseStream::new( Box::pin(stream), ctx, @@ -450,7 +480,7 @@ mod tests { }); let stream = tokio_stream::wrappers::ReceiverStream::new(rx); - let ctx = Arc::new(Controller::default()); + let ctx = Arc::new(Controller::new(self.context_id.clone())); Ok(dynamo_runtime::pipeline::ResponseStream::new( Box::pin(stream), ctx, @@ -465,12 +495,18 @@ mod tests { #[tokio::test] async fn test_retry_manager_no_migration() { dynamo_runtime::logging::init(); + let context_id = uuid::Uuid::new_v4().to_string(); let request = create_mock_request(10); - let mock_engine = Arc::new(MockEngine::new(MockBehavior::Success, 10, 100)); + let mock_engine = Arc::new(MockEngine::new( + MockBehavior::Success, + 10, + 100, + context_id.clone(), + )); let next_generate: ServerStreamingEngine> = mock_engine; - let mut retry_manager = RetryManager::build(request, next_generate, 0) + let mut retry_manager = RetryManager::build(context_id, request, next_generate, 0) .await .expect("Failed to build RetryManager"); @@ -497,12 +533,18 @@ mod tests { #[tokio::test] async fn test_retry_manager_new_request_migration() { dynamo_runtime::logging::init(); + let context_id = 
uuid::Uuid::new_v4().to_string(); let request = create_mock_request(10); - let mock_engine = Arc::new(MockEngine::new(MockBehavior::FailThenSuccess, 10, 100)); + let mock_engine = Arc::new(MockEngine::new( + MockBehavior::FailThenSuccess, + 10, + 100, + context_id.clone(), + )); let next_generate: ServerStreamingEngine> = mock_engine; - let mut retry_manager = RetryManager::build(request, next_generate, 3) + let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) .await .expect("Failed to build RetryManager"); @@ -530,16 +572,18 @@ mod tests { async fn test_retry_manager_ongoing_request_migration() { dynamo_runtime::logging::init(); + let context_id = uuid::Uuid::new_v4().to_string(); let request = create_mock_request(10); let mock_engine = Arc::new(MockEngine::new( MockBehavior::MidStreamFail { fail_after: 5 }, 10, 100, + context_id.clone(), )); let next_generate: ServerStreamingEngine> = mock_engine; - let mut retry_manager = RetryManager::build(request, next_generate, 3) + let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) .await .expect("Failed to build RetryManager"); @@ -567,13 +611,19 @@ mod tests { #[tokio::test] async fn test_retry_manager_new_request_migration_indefinite_failure() { dynamo_runtime::logging::init(); + let context_id = uuid::Uuid::new_v4().to_string(); let request = create_mock_request(0); - let mock_engine = Arc::new(MockEngine::new(MockBehavior::AlwaysFail, 0, 100)); + let mock_engine = Arc::new(MockEngine::new( + MockBehavior::AlwaysFail, + 0, + 100, + context_id.clone(), + )); let next_generate: ServerStreamingEngine> = mock_engine; // Should fail to build due to initial stream creation failure after exhausting all 3 retries - let retry_manager_result = RetryManager::build(request, next_generate, 3).await; + let retry_manager_result = RetryManager::build(context_id, request, next_generate, 3).await; assert!(retry_manager_result.is_err()); if let Err(error) = retry_manager_result { @@ -588,16 +638,18 @@ mod tests { #[tokio::test] async fn test_retry_manager_ongoing_request_migration_indefinite_failure() { dynamo_runtime::logging::init(); + let context_id = uuid::Uuid::new_v4().to_string(); let request = create_mock_request(10); let mock_engine = Arc::new(MockEngine::new( MockBehavior::MidStreamFailAlways { fail_after: 3 }, 10, 100, + context_id.clone(), )); let next_generate: ServerStreamingEngine> = mock_engine; - let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries + let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries .await .expect("Failed to build RetryManager"); @@ -638,16 +690,18 @@ mod tests { #[tokio::test] async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() { dynamo_runtime::logging::init(); + let context_id = uuid::Uuid::new_v4().to_string(); let request = create_mock_request(10); let mock_engine = Arc::new(MockEngine::new( MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 }, 10, 100, + context_id.clone(), )); let next_generate: ServerStreamingEngine> = mock_engine; - let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries + let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries .await .expect("Failed to build RetryManager"); diff --git a/lib/llm/src/perf.rs b/lib/llm/src/perf.rs index 68807e1b70..57d87e4428 100644 --- a/lib/llm/src/perf.rs +++ b/lib/llm/src/perf.rs @@ -552,5 +552,9 @@ pub mod tests { async 
fn killed(&self) {
         // No-op for testing
     }
+
+    fn link_child(&self, _: Arc<dyn AsyncEngineContext>) {
+        // No-op for testing
+    }
 }
 }
diff --git a/lib/llm/src/perf/logprobs.rs b/lib/llm/src/perf/logprobs.rs
index 0defe57526..b2e62ed71a 100644
--- a/lib/llm/src/perf/logprobs.rs
+++ b/lib/llm/src/perf/logprobs.rs
@@ -1613,5 +1613,9 @@ mod tests {
     async fn killed(&self) {
         // No-op for testing
     }
+
+    fn link_child(&self, _: Arc<dyn AsyncEngineContext>) {
+        // No-op for testing
+    }
 }
 }
diff --git a/lib/runtime/src/engine.rs b/lib/runtime/src/engine.rs
index c054c681b9..88593b897e 100644
--- a/lib/runtime/src/engine.rs
+++ b/lib/runtime/src/engine.rs
@@ -159,6 +159,12 @@ pub trait AsyncEngineContext: Send + Sync + Debug {
     /// terminate without draining the remaining items in the stream. This is implementation
     /// specific and may not be supported by all engines.
     fn kill(&self);
+
+    /// Links a child AsyncEngineContext to this AsyncEngineContext. If `stop_generating`, `stop`,
+    /// or `kill` is called on this AsyncEngineContext, the same method is called on all linked
+    /// child AsyncEngineContexts, in the order they were linked, before the call on this
+    /// AsyncEngineContext completes.
+    fn link_child(&self, child: Arc<dyn AsyncEngineContext>);
 }
 
 /// Provides access to the [`AsyncEngineContext`] associated with an engine operation.
diff --git a/lib/runtime/src/pipeline/context.rs b/lib/runtime/src/pipeline/context.rs
index 7833fc5c4b..e78c5ad4bd 100644
--- a/lib/runtime/src/pipeline/context.rs
+++ b/lib/runtime/src/pipeline/context.rs
@@ -22,7 +22,7 @@
 //! registry and visitors. `StreamAdaptors` will amend themselves to the [`StreamContext`] to allow for the
 
 use std::ops::{Deref, DerefMut};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use super::{AsyncEngineContext, AsyncEngineContextProvider, Data};
 use crate::engine::AsyncEngineController;
@@ -300,6 +300,10 @@ impl AsyncEngineContext for StreamContext {
     async fn killed(&self) {
         self.controller.killed().await
     }
+
+    fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
+        self.controller.link_child(child);
+    }
 }
 
 impl AsyncEngineContextProvider for StreamContext {
@@ -331,12 +335,18 @@ pub struct Controller {
     id: String,
     tx: Sender<State>,
     rx: Receiver<State>,
+    child_context: Mutex<Vec<Arc<dyn AsyncEngineContext>>>,
 }
 
 impl Controller {
     pub fn new(id: String) -> Self {
         let (tx, rx) = channel(State::Live);
-        Self { id, tx, rx }
+        Self {
+            id,
+            tx,
+            rx,
+            child_context: Mutex::new(Vec::new()),
+        }
     }
 
     pub fn id(&self) -> &str {
@@ -383,16 +393,59 @@ impl AsyncEngineContext for Controller {
     }
 
     fn stop_generating(&self) {
-        self.stop();
+        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
+        let children = self
+            .child_context
+            .lock()
+            .expect("Failed to lock child context")
+            .iter()
+            .cloned()
+            .collect::<Vec<_>>();
+        for child in children {
+            child.stop_generating();
+        }
+
+        let _ = self.tx.send(State::Stopped);
     }
 
     fn stop(&self) {
+        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
+        let children = self
+            .child_context
+            .lock()
+            .expect("Failed to lock child context")
+            .iter()
+            .cloned()
+            .collect::<Vec<_>>();
+        for child in children {
+            child.stop();
+        }
+
         let _ = self.tx.send(State::Stopped);
     }
 
     fn kill(&self) {
+        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
+        let children = self
+            .child_context
+            .lock()
+            .expect("Failed to lock child context")
+            .iter()
+            .cloned()
+            .collect::<Vec<_>>();
+        for child in children {
+            child.kill();
+        }
+
         let _ = self.tx.send(State::Killed);
     }
+
+    fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
+        self.child_context
+            .lock()
+            .expect("Failed to lock child context")
+            .push(child);
+    }
 }
 
 #[cfg(test)]
diff --git a/tests/fault_tolerance/README.md b/tests/fault_tolerance/README.md
index 1fa125f034..4e758b7ac6 100644
--- a/tests/fault_tolerance/README.md
+++ b/tests/fault_tolerance/README.md
@@ -14,7 +14,7 @@ Tests worker fault tolerance with migration support using the `test_request_migr
    - Model: `deepseek-ai/DeepSeek-R1-Distill-Llama-8B`
    - `--enforce-eager`, `--gpu-memory-utilization 0.45`
    - `--max-model-len 8192`, `--migration-limit 3`
-3. Waits for both workers to be fully ready (looking for "Reading Events from" messages)
+3. Waits for both workers to be fully ready (health check returns "ready" status)
 4. Sends a test request ("Who are you?", 100 tokens) to determine which worker handles requests
 5. Determines primary/backup worker roles based on round-robin routing and log analysis
 6. Sends a long completion request ("Tell me a long long long story about yourself?", 8000 tokens) in a separate thread
@@ -22,12 +22,56 @@ Tests worker fault tolerance with migration support using the `test_request_migr
 8. Verifies the request completes successfully despite the worker failure (with 240s timeout)
 9. Checks that the frontend logs contain "Stream disconnected... recreating stream..." indicating migration occurred
 
+### `test_request_cancellation.py`
+
+Tests request cancellation functionality across multiple API endpoints and deployment configurations. Contains three test functions:
+
+#### `test_request_cancellation_vllm`
+Tests basic request cancellation with a single worker:
+
+0. Downloads the DeepSeek-R1-Distill-Llama-8B model from HuggingFace if not already cached
+1. Starts a Dynamo frontend using `python -m dynamo.frontend` with debug logging enabled
+2. Starts a single worker using `python3 -m dynamo.vllm` with specific configuration:
+   - Model: `deepseek-ai/DeepSeek-R1-Distill-Llama-8B`
+   - `--enforce-eager`, `--gpu-memory-utilization 0.45`, `--max-model-len 8192`, `--migration-limit 3`
+   - Debug logging enabled on port 8081
+3. Tests request cancellation across three scenarios:
+   - **Completion API**: `/v1/completions` endpoint cancellation
+   - **Chat Completion API (non-streaming)**: `/v1/chat/completions` endpoint cancellation
+   - **Chat Completion API (streaming)**: `/v1/chat/completions` with streaming cancellation
+4. For each scenario:
+   - Sends a long request with a 1-second timeout to trigger cancellation
+   - Validates that cancellation messages appear in both frontend and worker logs
+   - Uses incremental log offset tracking to avoid false positives from previous tests
+5. Checks for specific cancellation patterns:
+   - Frontend log: "issued control message Kill to sender"
+   - Worker log: "Aborted Request ID: " matching the "New Request ID: "
+
+#### `test_request_cancellation_vllm_decode`
+Tests request cancellation during the disaggregated decode phase:
+
+0. Downloads the DeepSeek-R1-Distill-Llama-8B model from HuggingFace if not already cached
+1. Starts a Dynamo frontend using `python -m dynamo.frontend` with debug logging enabled
+2. Starts a prefill worker using `python3 -m dynamo.vllm --is-prefill-worker` on port 8082
+3. Starts a decode worker using `python3 -m dynamo.vllm` on port 8081
+4. Tests completion request cancellation in the disaggregated setup
+5. Validates that cancellation messages appear in the prefill worker, decode worker, and frontend logs
+
+#### `test_request_cancellation_vllm_prefill`
+Tests request cancellation during the disaggregated prefill phase:
+
+- (Skipped until request cancellation can cancel before receiving the first response)
+
 ## Prerequisites
 
-- vLLM backend installed (`pip install ai-dynamo-vllm`)
+- vLLM backend installed
 - NATS and etcd services running (provided by `runtime_services` fixture)
 - Access to DeepSeek-R1-Distill-Llama-8B model (automatically downloaded from HuggingFace)
-- Sufficient GPU memory (test uses 0.45 GPU memory utilization)
+- Sufficient GPU memory
 
 ## Running the Tests
 
@@ -35,16 +79,12 @@ To run the fault tolerance tests:
 
 ```bash
 # Run all fault tolerance tests
-pytest /workspace/tests/fault_tolerance
-
-# Run specific test with verbose output
-pytest /workspace/tests/fault_tolerance/test_request_migration.py::test_request_migration_vllm -v
-
-# Run with specific markers
 pytest -m "e2e and vllm" /workspace/tests/fault_tolerance
 
-# Run with debug logging
+# Run specific test functions with debug logging
pytest /workspace/tests/fault_tolerance/test_request_migration.py::test_request_migration_vllm -v -s
+pytest /workspace/tests/fault_tolerance/test_request_cancellation.py::test_request_cancellation_vllm -v -s
+pytest /workspace/tests/fault_tolerance/test_request_cancellation.py::test_request_cancellation_vllm_decode -v -s
 ```
 
 ## Test Markers
 
@@ -61,11 +101,11 @@ pytest /workspace/tests/fault_tolerance/test_request_migration.py::test_request_
 
 ## Expected Test Duration
 
-The test typically takes 2-3 minutes to complete, including:
+Each test typically takes 2-3 minutes to complete, including:
 
 - Model download/loading time (if not cached) - can take 1-2 minutes for first run
 - Worker startup and registration
 - Request processing and response validation
-- Worker failure simulation and migration
+- Worker failure simulation and migration (migration test) or request cancellation validation (cancellation tests)
 - Cleanup
 
 ## Troubleshooting
 
@@ -74,7 +114,10 @@ If tests fail:
 
 1. Check that NATS and etcd services are running
 2. Verify vLLM backend is properly installed
-3. Ensure sufficient GPU memory is available (test requires ~45% GPU memory)
+3. Ensure sufficient GPU memory is available
 4. Check internet connectivity for model download from HuggingFace
 5. Review test logs for specific error messages
 6. Verify that the DeepSeek-R1-Distill-Llama-8B model can be accessed
+7. For cancellation tests: check that timeout-based cancellation works and that the expected cancellation patterns appear in the logs
+8. For migration tests: verify worker process termination and stream recreation behavior
+9. For disaggregated cancellation tests: ensure both the prefill and decode workers start properly and that cancellation propagates across the disaggregated setup
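All of the cancellation scenarios above rely on the same client-side trigger: the HTTP request is simply abandoned (via a timeout or by closing the stream) before generation can finish. A minimal standalone reproduction against a locally running frontend, using the same port, endpoint, and model as the tests below, might look like:

```python
import requests

payload = {
    "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "prompt": "Tell me a very long and detailed story about the history of AI.",
    "max_tokens": 8000,
}

try:
    # The 1-second timeout makes the client drop the connection long before
    # an 8000-token generation can finish; the frontend should then issue
    # "Kill" to the worker, which aborts the request in the engine.
    requests.post(
        "http://localhost:8080/v1/completions", json=payload, timeout=1
    )
except requests.exceptions.Timeout:
    print("Client timed out; cancellation should now propagate to the worker")
```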
diff --git a/tests/fault_tolerance/test_request_cancellation.py b/tests/fault_tolerance/test_request_cancellation.py
new file mode 100644
index 0000000000..1c7cf688ab
--- /dev/null
+++ b/tests/fault_tolerance/test_request_cancellation.py
@@ -0,0 +1,506 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import re
+import shutil
+import time
+
+import pytest
+import requests
+from huggingface_hub import snapshot_download
+
+from tests.utils.managed_process import ManagedProcess
+
+logger = logging.getLogger(__name__)
+
+
+class DynamoFrontendProcess(ManagedProcess):
+    """Process manager for Dynamo frontend"""
+
+    def __init__(self, request):
+        command = ["python", "-m", "dynamo.frontend"]
+
+        # Set debug logging environment
+        env = os.environ.copy()
+        env["DYN_LOG"] = "debug"
+
+        log_dir = f"{request.node.name}_frontend"
+
+        # Clean up any existing log directory from previous runs
+        try:
+            shutil.rmtree(log_dir)
+            logger.info(f"Cleaned up existing log directory: {log_dir}")
+        except FileNotFoundError:
+            # Directory doesn't exist, which is fine
+            pass
+
+        super().__init__(
+            command=command,
+            env=env,
+            display_output=True,
+            terminate_existing=True,
+            log_dir=log_dir,
+        )
+
+
+class DynamoWorkerProcess(ManagedProcess):
+    """Process manager for Dynamo worker with vLLM backend"""
+
+    def __init__(self, request, is_prefill: bool = False):
+        command = [
+            "python3",
+            "-m",
+            "dynamo.vllm",
+            "--model",
+            "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+            "--enforce-eager",
+            "--gpu-memory-utilization",
+            "0.45",
+            "--max-model-len",
+            "8192",
+            "--migration-limit",
+            "3",
+        ]
+
+        # Add prefill worker flag if needed
+        if is_prefill:
+            command.append("--is-prefill-worker")
+
+        # Set port based on worker type
+        port = "8082" if is_prefill else "8081"
+
+        # Set debug logging environment
+        env = os.environ.copy()
+        env["DYN_LOG"] = "debug"
+        env["DYN_SYSTEM_ENABLED"] = "true"
+        env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
+        env["DYN_SYSTEM_PORT"] = port
+
+        # Set log directory based on worker type
+        worker_type = "prefill_worker" if is_prefill else "worker"
+        log_dir = f"{request.node.name}_{worker_type}"
+
+        # Clean up any existing log directory from previous runs
+        try:
+            shutil.rmtree(log_dir)
+            logger.info(f"Cleaned up existing log directory: {log_dir}")
+        except FileNotFoundError:
+            # Directory doesn't exist, which is fine
+            pass
+
+        super().__init__(
+            command=command,
+            env=env,
+            health_check_urls=[(f"http://localhost:{port}/health", self.is_ready)],
+            timeout=300,
+            display_output=True,
+            terminate_existing=False,
+            log_dir=log_dir,
+        )
+
+        self.is_prefill = is_prefill
+
+    def get_pid(self):
+        """Get the PID of the worker process"""
+        return self.proc.pid if self.proc else None
+
+    def is_ready(self, response) -> bool:
+        """Check the health of the worker process"""
+        worker_type = "Prefill worker" if self.is_prefill else "Worker"
+        try:
+            data = response.json()
+            if data.get("status") == "ready":
+                logger.info(f"{worker_type} status is ready")
+                return True
+            logger.warning(f"{worker_type} status is not ready: {data.get('status')}")
+        except ValueError:
+            logger.warning(f"{worker_type} health response is not valid JSON")
+        return False
+
+
+def download_model() -> None:
+    """
+    Download the DeepSeek-R1-Distill-Llama-8B model from HuggingFace Hub if not already cached.
+ """ + model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" + logger.info(f"Caching model {model_id}...") + + max_retries = 5 + retry_delay = 30 # seconds + + for attempt in range(max_retries): + try: + # Download the model to the default cache directory + # This will skip download if the model is already cached + snapshot_download( + repo_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", + repo_type="model", + local_files_only=False, + ) + logger.info(f"Model {model_id} is ready for use") + return # Success, exit the function + except Exception as e: + if attempt < max_retries - 1: # Not the last attempt + logger.warning( + f"Failed to download model {model_id} (attempt {attempt + 1}/{max_retries}): {e}" + ) + logger.info(f"Retrying in {retry_delay} seconds...") + time.sleep(retry_delay) + else: # Last attempt failed + logger.error( + f"Failed to download model {model_id} after {max_retries} attempts: {e}" + ) + raise + + +def send_completion_request( + prompt: str, max_tokens: int, timeout: int = 120 +) -> requests.Response: + """Send a completion request to the frontend""" + payload = { + "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", + "prompt": prompt, + "max_tokens": max_tokens, + } + + headers = {"Content-Type": "application/json"} + + logger.info( + f"Sending completion request with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}" + ) + + session = requests.Session() + try: + response = session.post( + "http://localhost:8080/v1/completions", + headers=headers, + json=payload, + timeout=timeout, + ) + logger.info(f"Received response with status code: {response.status_code}") + return response + except requests.exceptions.Timeout: + logger.error(f"Request timed out after {timeout} seconds") + raise + except requests.exceptions.RequestException as e: + logger.error(f"Request failed with error: {e}") + raise + + +def send_chat_completion_request( + prompt: str, max_tokens: int, timeout: int = 120, stream: bool = False +) -> requests.Response: + """Send a chat completion request to the frontend""" + payload = { + "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", + "messages": [{"role": "user", "content": prompt}], + "max_tokens": max_tokens, + "stream": stream, + } + + headers = {"Content-Type": "application/json"} + + logger.info( + f"Sending chat completion request (stream={stream}) with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}" + ) + + session = requests.Session() + try: + response = session.post( + "http://localhost:8080/v1/chat/completions", + headers=headers, + json=payload, + timeout=timeout, + stream=stream, + ) + logger.info(f"Received response with status code: {response.status_code}") + return response + except requests.exceptions.Timeout: + logger.error(f"Request timed out after {timeout} seconds") + raise + except requests.exceptions.RequestException as e: + logger.error(f"Request failed with error: {e}") + raise + + +def send_request_and_cancel(request_type: str = "completion", timeout: int = 1): + """Send a request with short timeout to trigger cancellation""" + logger.info(f"Sending {request_type} request to be cancelled...") + + prompt = "Tell me a very long and detailed story about the history of artificial intelligence, including all major milestones, researchers, and breakthroughs?" 
+    try:
+        if request_type == "completion":
+            response = send_completion_request(prompt, 8000, timeout)
+        elif request_type == "chat_completion":
+            response = send_chat_completion_request(prompt, 8000, timeout, False)
+        elif request_type == "chat_completion_stream":
+            response = send_chat_completion_request(prompt, 8000, timeout, True)
+            # Read a few responses and then disconnect
+            if response.status_code == 200:
+                itr_count, max_itr = 0, 5
+                try:
+                    for res in response.iter_lines():
+                        logger.info(f"Received response {itr_count + 1}: {res[:50]}...")
+                        itr_count += 1
+                        if itr_count >= max_itr:
+                            break
+                        time.sleep(0.1)
+                except Exception as e:
+                    pytest.fail(f"Stream reading failed: {e}")
+
+                response.close()
+                raise Exception("Closed response")
+        else:
+            pytest.fail(f"Unknown request type: {request_type}")
+
+        pytest.fail(
+            f"{request_type} request completed unexpectedly - should have been cancelled"
+        )
+    except Exception as e:
+        logger.info(f"{request_type} request was cancelled: {e}")
+
+
+def read_log_content(log_path: str | None) -> str:
+    """Read log content from a file"""
+    if log_path is None:
+        pytest.fail("Log path is None - cannot read log content")
+
+    try:
+        with open(log_path, "r") as f:
+            return f.read()
+    except Exception as e:
+        pytest.fail(f"Could not read log file {log_path}: {e}")
+
+
+def strip_ansi_codes(text: str) -> str:
+    """Remove ANSI color codes from text"""
+    ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
+    return ansi_escape.sub("", text)
+
+
+def verify_request_cancelled(
+    frontend_process: DynamoFrontendProcess,
+    worker_process: DynamoWorkerProcess,
+    prefill_worker_process: DynamoWorkerProcess | None = None,
+    frontend_log_offset: int = 0,
+    worker_log_offset: int = 0,
+    prefill_worker_log_offset: int = 0,
+) -> tuple[int, int]:
+    """Verify that the worker and frontend logs contain cancellation messages
+
+    Returns:
+        tuple: (new_frontend_log_length, new_worker_log_length)
+    """
+
+    # Check worker log for cancellation pattern
+    worker_log_content = read_log_content(worker_process._log_path)
+    new_worker_content = worker_log_content[worker_log_offset:]
+
+    # Find request ID from "New Request ID: " line
+    request_id = None
+    for line in new_worker_content.split("\n"):
+        # Strip ANSI codes and whitespace for pattern matching
+        clean_line = strip_ansi_codes(line).strip()
+        if "New Request ID: " in clean_line:
+            # Extract the ID from the end of the line
+            parts = clean_line.split("New Request ID: ")
+            if len(parts) > 1:
+                request_id = parts[-1].strip()
+                break
+    if request_id is None:
+        pytest.fail("Could not find 'New Request ID: ' pattern in worker log")
+
+    # Check if the same request ID was cancelled
+    has_worker_cancellation = False
+    cancellation_pattern = f"Aborted Request ID: {request_id}"
+    for line in new_worker_content.split("\n"):
+        # Strip ANSI codes and whitespace for pattern matching
+        clean_line = strip_ansi_codes(line).strip()
+        if clean_line.endswith(cancellation_pattern):
+            has_worker_cancellation = True
+            break
+    if not has_worker_cancellation:
+        pytest.fail(
+            f"Could not find 'Aborted Request ID: {request_id}' pattern in worker log"
+        )
+
+    # Check if the same request ID was remote prefilled
+    if prefill_worker_process is not None:
+        prefill_worker_log_content = read_log_content(prefill_worker_process._log_path)
+        new_prefill_worker_content = prefill_worker_log_content[
+            prefill_worker_log_offset:
+        ]
+
+        has_remote_prefill = False
+        remote_prefill_pattern = f"New Prefill Request ID: {request_id}"
+        for line in new_prefill_worker_content.split("\n"):
+            clean_line = strip_ansi_codes(line).strip()
+            if clean_line.endswith(remote_prefill_pattern):
+                has_remote_prefill = True
+                break
+        if not has_remote_prefill:
+            pytest.fail(
+                f"Could not find 'New Prefill Request ID: {request_id}' pattern in prefill worker log"
+            )
+
+    # Check frontend log for cancellation issued pattern
+    frontend_log_content = read_log_content(frontend_process._log_path)
+    new_frontend_content = frontend_log_content[frontend_log_offset:]
+
+    has_kill_message = False
+    kill_message = "issued control message Kill to sender"
+    for line in new_frontend_content.split("\n"):
+        # Strip ANSI codes and whitespace for pattern matching
+        clean_line = strip_ansi_codes(line).strip()
+        if clean_line.endswith(kill_message):
+            has_kill_message = True
+            break
+    if not has_kill_message:
+        pytest.fail("Could not find cancellation issued in frontend log")
+
+    return len(frontend_log_content), len(worker_log_content)
+
+
+@pytest.mark.vllm
+@pytest.mark.gpu_1
+@pytest.mark.e2e
+@pytest.mark.slow
+def test_request_cancellation_vllm(request, runtime_services):
+    """
+    End-to-end test for request cancellation functionality.
+
+    This test verifies that when a request is cancelled by the client,
+    the system properly handles the cancellation and cleans up resources
+    on the worker side. Tests three scenarios:
+    1. Completion request
+    2. Chat completion request (non-streaming)
+    3. Chat completion request (streaming)
+    """
+    # Step 0: Download the model from HuggingFace if not already cached
+    download_model()
+
+    # Step 1: Start the frontend
+    with DynamoFrontendProcess(request) as frontend:
+        logger.info("Frontend started successfully")
+
+        # Step 2: Start a single worker
+        logger.info("Starting worker...")
+        worker = DynamoWorkerProcess(request)
+
+        with worker:
+            logger.info(f"Worker PID: {worker.get_pid()}")
+
+            # TODO: Why is the model not immediately available at the frontend
+            # after the health check reports success?
+            time.sleep(2)
+
+            # Step 3: Test request cancellation
+            frontend_log_offset, worker_log_offset = 0, 0
+
+            test_scenarios = [
+                ("completion", "Completion request cancellation"),
+                ("chat_completion", "Chat completion request cancellation"),
+                (
+                    "chat_completion_stream",
+                    "Chat completion stream request cancellation",
+                ),
+            ]
+
+            for request_type, description in test_scenarios:
+                logger.info(f"Testing {description.lower()}...")
+                send_request_and_cancel(request_type)
+
+                logger.info(
+                    "Checking for cancellation messages in worker and frontend logs..."
+                )
+                time.sleep(0.5)  # Make sure logs are written before proceeding
+                frontend_log_offset, worker_log_offset = verify_request_cancelled(
+                    frontend,
+                    worker,
+                    frontend_log_offset=frontend_log_offset,
+                    worker_log_offset=worker_log_offset,
+                )
+
+                logger.info(f"{description} detected successfully")
+
+    logger.info(
+        "All request cancellation tests completed successfully - request cancellation is working correctly"
+    )
+
+
+@pytest.mark.vllm
+@pytest.mark.gpu_1
+@pytest.mark.e2e
+@pytest.mark.slow
+def test_request_cancellation_vllm_decode(request, runtime_services):
+    """
+    End-to-end test for request cancellation functionality with remote prefill.
+
+    This test verifies that when a request is cancelled by the client,
+    the system properly handles the cancellation and cleans up resources
+    on the decode worker side in a disaggregated setup.
+ """ + # Step 0: Download the model from HuggingFace if not already cached + download_model() + + # Step 1: Start the frontend + with DynamoFrontendProcess(request) as frontend: + logger.info("Frontend started successfully") + + # Step 2: Start the prefill worker + logger.info("Starting prefill worker...") + prefill_worker = DynamoWorkerProcess(request, is_prefill=True) + + with prefill_worker: + logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") + + # Step 3: Start the decode worker + logger.info("Starting decode worker...") + decode_worker = DynamoWorkerProcess(request, is_prefill=False) + + with decode_worker: + logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") + + # TODO: Why the model is not immediately available at the frontend after health check + # returns success. + time.sleep(2) + + # Step 4: Test request cancellation for completion scenario only + logger.info( + "Testing completion request cancellation in disaggregated mode..." + ) + send_request_and_cancel("completion") + + logger.info( + "Checking for cancellation messages in decode worker, prefill worker, and frontend logs..." + ) + time.sleep(0.5) # Make sure logs are written before proceeding + verify_request_cancelled(frontend, decode_worker, prefill_worker) + + logger.info( + "Completion request cancellation detected successfully in disaggregated mode" + ) + + logger.info( + "Request cancellation test completed successfully in disaggregated mode - request cancellation is working correctly" + ) + + +@pytest.mark.skip(reason="require cancel support before receiving 1st response") +@pytest.mark.vllm +@pytest.mark.gpu_1 +@pytest.mark.e2e +@pytest.mark.slow +def test_request_cancellation_vllm_prefill(request, runtime_services): + """ + End-to-end test for request cancellation on remote prefill. + + This test verifies that when a request is cancelled by the client during the + prefill phase, the system properly handles the cancellation and cleans up + resources on the prefill worker and decode worker sides in a disaggregated + setup. + """