diff --git a/lib/bindings/python/rust/engine.rs b/lib/bindings/python/rust/engine.rs index 8aeab7ed7d..26d909bed3 100644 --- a/lib/bindings/python/rust/engine.rs +++ b/lib/bindings/python/rust/engine.rs @@ -14,6 +14,8 @@ // limitations under the License. use super::context::{callable_accepts_kwarg, Context}; +use dynamo_llm::protocols::DataStream; +use dynamo_runtime::engine::AsyncEngineContext; use pyo3::prelude::*; use pyo3::types::{PyDict, PyModule}; use pyo3::{PyAny, PyErr}; @@ -73,7 +75,7 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> { /// ``` #[pyclass] #[derive(Clone)] -pub struct PythonAsyncEngine(PythonServerStreamingEngine); +pub struct PythonAsyncEngine(pub PythonServerStreamingEngine); #[pymethods] impl PythonAsyncEngine { @@ -135,31 +137,16 @@ impl PythonServerStreamingEngine { has_context, } } -} -#[derive(Debug, thiserror::Error)] -enum ResponseProcessingError { - #[error("python exception: {0}")] - PythonException(String), - - #[error("python generator exit: {0}")] - PyGeneratorExit(String), - - #[error("deserialize error: {0}")] - DeserializeError(String), - - #[error("gil offload error: {0}")] - OffloadError(String), -} - -#[async_trait] -impl AsyncEngine, ManyOut>, Error> - for PythonServerStreamingEngine -where - Req: Data + Serialize, - Resp: Data + for<'de> Deserialize<'de>, -{ - async fn generate(&self, request: SingleIn) -> Result>, Error> { + /// Generate the response in parts. + pub async fn generate_in_parts( + &self, + request: SingleIn, + ) -> Result<(DataStream>, Arc), Error> + where + Req: Data + Serialize, + Resp: Data + for<'de> Deserialize<'de>, + { // Create a context let (request, context) = request.transfer(()); let ctx = context.context(); @@ -290,8 +277,36 @@ where }); let stream = ReceiverStream::new(rx); + let context = context.context(); + Ok((Box::pin(stream), context)) + } +} + +#[derive(Debug, thiserror::Error)] +enum ResponseProcessingError { + #[error("python exception: {0}")] + PythonException(String), + + #[error("python generator exit: {0}")] + PyGeneratorExit(String), + + #[error("deserialize error: {0}")] + DeserializeError(String), - Ok(ResponseStream::new(Box::pin(stream), context.context())) + #[error("gil offload error: {0}")] + OffloadError(String), +} + +#[async_trait] +impl AsyncEngine, ManyOut>, Error> + for PythonServerStreamingEngine +where + Req: Data + Serialize, + Resp: Data + for<'de> Deserialize<'de>, +{ + async fn generate(&self, request: SingleIn) -> Result>, Error> { + let (stream, context) = self.generate_in_parts(request).await?; + Ok(ResponseStream::new(Box::pin(stream), context)) } } diff --git a/lib/bindings/python/rust/http.rs b/lib/bindings/python/rust/http.rs index 397da921b1..d01bec51d0 100644 --- a/lib/bindings/python/rust/http.rs +++ b/lib/bindings/python/rust/http.rs @@ -177,8 +177,29 @@ where Resp: Data + for<'de> Deserialize<'de>, { async fn generate(&self, request: SingleIn) -> Result>, Error> { - match self.0.generate(request).await { - Ok(res) => Ok(res), + match self.0 .0.generate_in_parts(request).await { + Ok((mut stream, context)) => { + let request_id = context.id().to_string(); + let first_item = match futures::StreamExt::next(&mut stream).await { + // TODO - item may still contain an Annotated error. How do we want to handle that? + // TODO - should we be returning an HttpError here? + Some(item) => item, + None => { + let error_msg = "python async generator stream ended before processing started"; + tracing::warn!(request_id, error_msg); + return Err(Error::new(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + error_msg, + ))); + } + }; + + // Create a new stream that yields the first item followed by the rest of the original stream + let once_stream = futures::stream::once(async { first_item }); + let stream = futures::StreamExt::chain(once_stream, stream); + + Ok(ResponseStream::new(Box::pin(stream), context)) + } // Inspect the error - if it was an HttpError from Python, extract the code and message // and return the rust version of HttpError diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 05239efabf..f97152e30f 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -377,6 +377,12 @@ impl DistributedRuntime { self.inner.runtime().shutdown(); } + fn child_token(&self) -> CancellationToken { + CancellationToken { + inner: self.inner.runtime().child_token(), + } + } + fn event_loop(&self) -> PyObject { self.event_loop.clone() } diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index 5e447c047e..3c5bc7301d 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -52,6 +52,30 @@ class DistributedRuntime: Shutdown the runtime by triggering the cancellation token """ ... + + def child_token(self) -> CancellationToken: + """ + Get a child cancellation token from the runtime + """ + ... + +class CancellationToken: + """ + A cancellation token for coordinating shutdown across components + """ + + def cancel(self) -> None: + """ + Cancel the token + """ + ... + + async def cancelled(self) -> None: + """ + Wait for the token to be cancelled + """ + ... + class EtcdClient: """ Etcd is used for discovery in the DistributedRuntime