diff --git a/docs/guides/backend.md b/docs/guides/backend.md index c087c7a9ba..9af77d73fe 100644 --- a/docs/guides/backend.md +++ b/docs/guides/backend.md @@ -108,3 +108,31 @@ Example 4: Multiple component in a pipeline. In the P/D disaggregated setup you would have `deepseek-distill-llama8b.prefill.generate` (possibly multiple instances of this) and `deepseek-distill-llama8b.decode.generate`. +## Migrate Ongoing Requests + +A Python worker may need to be shut down promptly, for example when the node running the worker is to be reclaimed and there isn't enough time to complete all ongoing requests before the shutdown deadline. + +In such cases, you can signal incomplete responses by raising a `GeneratorExit` exception in your generate loop. This will immediately close the response stream, signaling to the frontend that the stream is incomplete. With request migration enabled (see the [`migration_limit`](../architecture/request_migration.md) parameter), the frontend will automatically migrate the partially completed request to another worker instance, if available, to be completed. + +> [!WARNING] +> We will update the `GeneratorExit` exception to a new Dynamo exception. Please expect a minor breaking code change in the near future. + +Here's an example of how to implement this in your `RequestHandler`: + +```python +class RequestHandler: + +    async def generate(self, request): +        """Generate response, with support for request migration""" +        for result in self.engine.generate_streaming(request): +            # Check if we need to migrate before yielding each token +            if is_shutting_down(): +                # Raising GeneratorExit closes the stream and triggers migration +                raise GeneratorExit("Worker shutting down, migrating request") + +            yield result +``` + +When `GeneratorExit` is raised, the frontend receives the incomplete response and can seamlessly continue generation on another available worker instance, preserving the user experience even during worker shutdowns. 
+ +For more information about how request migration works, see the [Request Migration Architecture](../architecture/request_migration.md) documentation. diff --git a/lib/bindings/python/rust/engine.rs b/lib/bindings/python/rust/engine.rs index aea385d0a9..a4816c22e0 100644 --- a/lib/bindings/python/rust/engine.rs +++ b/lib/bindings/python/rust/engine.rs @@ -134,6 +134,9 @@ enum ResponseProcessingError { #[error("python exception: {0}")] PythonException(String), + #[error("python generator exit: {0}")] + PyGeneratorExit(String), + #[error("deserialize error: {0}")] DeserializeError(String), @@ -225,6 +228,9 @@ where let msg = format!("critical error: invalid response object from python async generator; application-logic-mismatch: {}", e); msg } + ResponseProcessingError::PyGeneratorExit(_) => { + "Stream ended before generation completed".to_string() + } ResponseProcessingError::PythonException(e) => { let msg = format!("a python exception was caught while processing the async generator: {}", e); msg } @@ -276,8 +282,16 @@ { let item = item.map_err(|e| { println!(); - Python::with_gil(|py| e.display(py)); - ResponseProcessingError::PythonException(e.to_string()) + let mut is_py_generator_exit = false; + Python::with_gil(|py| { + e.display(py); + is_py_generator_exit = e.is_instance_of::<pyo3::exceptions::PyGeneratorExit>(py); + }); + if is_py_generator_exit { + ResponseProcessingError::PyGeneratorExit(e.to_string()) + } else { + ResponseProcessingError::PythonException(e.to_string()) + } })?; let response = tokio::task::spawn_blocking(move || { Python::with_gil(|py| depythonize::<Resp>(&item.into_bound(py))) diff --git a/lib/runtime/src/pipeline/network/ingress/push_handler.rs b/lib/runtime/src/pipeline/network/ingress/push_handler.rs index ec8baa044f..9cdd839352 100644 --- a/lib/runtime/src/pipeline/network/ingress/push_handler.rs +++ b/lib/runtime/src/pipeline/network/ingress/push_handler.rs @@ -14,6 +14,7 @@ // limitations under the License. 
use super::*; +use crate::protocols::maybe_error::MaybeError; use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge}; use serde::{Deserialize, Serialize}; use std::sync::Arc; @@ -105,7 +106,7 @@ impl WorkHandlerMetrics { impl<T, U> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>> where T: Data + for<'de> Deserialize<'de> + std::fmt::Debug, - U: Data + Serialize + std::fmt::Debug, + U: Data + Serialize + MaybeError + std::fmt::Debug, { fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()> { // Call the Ingress-specific add_metrics implementation @@ -220,6 +221,14 @@ where let mut send_complete_final = true; while let Some(resp) = stream.next().await { tracing::trace!("Sending response: {:?}", resp); + if let Some(err) = resp.err() { + const STREAM_ERR_MSG: &str = "Stream ended before generation completed"; + if format!("{:?}", err) == STREAM_ERR_MSG { + tracing::warn!(STREAM_ERR_MSG); + send_complete_final = false; + break; + } + } let resp_wrapper = NetworkStreamWrapper { data: Some(resp), complete_final: false,