diff --git a/lib/bindings/python/rust/engine.rs b/lib/bindings/python/rust/engine.rs index c26c0f3bb5..dd3bfed42a 100644 --- a/lib/bindings/python/rust/engine.rs +++ b/lib/bindings/python/rust/engine.rs @@ -14,6 +14,7 @@ // limitations under the License. use super::context::{callable_accepts_kwarg, PyContext}; +use futures::stream::{self, StreamExt as FuturesStreamExt}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyModule}; use pyo3::{PyAny, PyErr}; @@ -21,7 +22,7 @@ use pyo3_async_runtimes::TaskLocals; use pythonize::{depythonize, pythonize}; use std::sync::Arc; use tokio::sync::mpsc; -use tokio_stream::{wrappers::ReceiverStream, StreamExt}; +use tokio_stream::wrappers::ReceiverStream; pub use dynamo_runtime::{ pipeline::{ @@ -96,6 +97,10 @@ impl PythonAsyncEngine { Arc::new(event_loop), ))) } + + pub fn block_until_stream_item(&mut self, enabled: bool) { + self.0.block_until_stream_item(enabled); + } } #[async_trait] @@ -115,6 +120,7 @@ pub struct PythonServerStreamingEngine { generator: Arc, event_loop: Arc, has_pycontext: bool, + block_until_stream_item: bool, } impl PythonServerStreamingEngine { @@ -133,8 +139,13 @@ impl PythonServerStreamingEngine { generator, event_loop, has_pycontext, + block_until_stream_item: false, } } + + pub fn block_until_stream_item(&mut self, enabled: bool) { + self.block_until_stream_item = enabled; + } } #[derive(Debug, thiserror::Error)] @@ -208,14 +219,46 @@ where }) .await??; - let stream = Box::pin(stream); - // process the stream // any error thrown in the stream will be caught and complete the processing task // errors are captured by a task that is watching the processing task // the error will be emitted as an annotated error let request_id = id.clone(); + let mut stream = Box::pin(stream); + + let stream = if self.block_until_stream_item { + let first_item = match FuturesStreamExt::next(&mut stream).await { + Some(Ok(item)) => item, + Some(Err(e)) => { + // Any Python exception (including HttpError) is already wrapped in PyErr + // The HttpAsyncEngine will inspect this PyErr later to see if it's an HttpError + tracing::warn!( + request_id, + "Python exception occurred before finish of first iteration: {}", + e + ); + return Err(Error::new(e)); + } + None => { + tracing::warn!( + request_id, + "python async generator stream ended before processing started" + ); + return Err(Error::new(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "python async generator stream ended before processing started", + ))); + } + }; + // Create a new stream that yields the first item followed by the rest of the original stream + let stream = + futures::StreamExt::chain(stream::once(futures::future::ok(first_item)), stream); + FuturesStreamExt::boxed(stream) + } else { + stream + }; + tokio::spawn(async move { tracing::debug!( request_id, @@ -225,7 +268,8 @@ where let mut stream = stream; let mut count = 0; - while let Some(item) = stream.next().await { + // Fix the third error by explicitly using FuturesStreamExt::next + while let Some(item) = FuturesStreamExt::next(&mut stream).await { count += 1; tracing::trace!( request_id, diff --git a/lib/bindings/python/rust/http.rs b/lib/bindings/python/rust/http.rs index 3a22092334..e7862d92ba 100644 --- a/lib/bindings/python/rust/http.rs +++ b/lib/bindings/python/rust/http.rs @@ -147,6 +147,10 @@ impl HttpAsyncEngine { pub fn new(generator: PyObject, event_loop: PyObject) -> PyResult { Ok(PythonAsyncEngine::new(generator, event_loop)?.into()) } + + pub fn block_until_stream_item(&mut self, enabled: bool) { + self.0.block_until_stream_item(enabled); + } } #[async_trait]