Skip to content
Prev Previous commit
Next Next commit
feat: FT Python Context and Unit Tests (#2677)
  • Loading branch information
kthui authored Aug 25, 2025
commit e75ca6d908b3d45633384d5e373372c72a0551d5
30 changes: 24 additions & 6 deletions lib/bindings/python/rust/context.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,46 @@
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.

use dynamo_runtime::pipeline::context::Controller;
pub use dynamo_runtime::pipeline::AsyncEngineContext;
use pyo3::prelude::*;
use std::sync::Arc;

// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Not all methods of the AsyncEngineContext are exposed, just the primary ones for tracing + cancellation.
// Kept as class, to allow for future expansion if needed.
// Cloning duplicates only the `Arc` handle, so every clone refers to the same underlying context.
#[derive(Clone)]
#[pyclass]
pub struct PyContext {
pub inner: Arc<dyn AsyncEngineContext>,
pub struct Context {
// Shared, reference-counted handle to the wrapped engine context.
inner: Arc<dyn AsyncEngineContext>,
}

// Rust-side helpers; these live outside the `#[pymethods]` block.
impl PyContext {
impl Context {
// Wrap an existing engine context so it can be exposed through the Python bindings.
pub fn new(inner: Arc<dyn AsyncEngineContext>) -> Self {
Self { inner }
}

// Return another `Arc` handle to the wrapped context (clones the pointer, not the context itself).
pub fn inner(&self) -> Arc<dyn AsyncEngineContext> {
self.inner.clone()
}
}

#[pymethods]
impl PyContext {
impl Context {
#[new]
#[pyo3(signature = (id=None))]
fn py_new(id: Option<String>) -> Self {
let controller = match id {
Some(id) => Controller::new(id),
None => Controller::default(),
};
Self {
inner: Arc::new(controller),
}
}

// sync method of `await async_is_stopped()`
fn is_stopped(&self) -> bool {
self.inner.is_stopped()
Expand Down
14 changes: 7 additions & 7 deletions lib/bindings/python/rust/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use super::context::{callable_accepts_kwarg, PyContext};
use super::context::{callable_accepts_kwarg, Context};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyModule};
use pyo3::{PyAny, PyErr};
Expand Down Expand Up @@ -114,7 +114,7 @@ pub struct PythonServerStreamingEngine {
_cancel_token: CancellationToken,
generator: Arc<PyObject>,
event_loop: Arc<PyObject>,
has_pycontext: bool,
has_context: bool,
}

impl PythonServerStreamingEngine {
Expand All @@ -123,7 +123,7 @@ impl PythonServerStreamingEngine {
generator: Arc<PyObject>,
event_loop: Arc<PyObject>,
) -> Self {
let has_pycontext = Python::with_gil(|py| {
let has_context = Python::with_gil(|py| {
let callable = generator.bind(py);
callable_accepts_kwarg(py, callable, "context").unwrap_or(false)
});
Expand All @@ -132,7 +132,7 @@ impl PythonServerStreamingEngine {
_cancel_token: cancel_token,
generator,
event_loop,
has_pycontext,
has_context,
}
}
}
Expand Down Expand Up @@ -175,7 +175,7 @@ where
let generator = self.generator.clone();
let event_loop = self.event_loop.clone();
let ctx_python = ctx.clone();
let has_pycontext = self.has_pycontext;
let has_context = self.has_context;

// Acquiring the GIL is similar to acquiring a standard lock/mutex
// Performing this in a tokio async task could block the thread for an undefined amount of time
Expand All @@ -190,9 +190,9 @@ where
let stream = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let py_request = pythonize(py, &request)?;
let py_ctx = Py::new(py, PyContext::new(ctx_python.clone()))?;
let py_ctx = Py::new(py, Context::new(ctx_python.clone()))?;

let gen = if has_pycontext {
let gen = if has_context {
// Pass context as a kwarg
let kwarg = PyDict::new(py);
kwarg.set_item("context", &py_ctx)?;
Expand Down
73 changes: 57 additions & 16 deletions lib/bindings/python/rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use tokio::sync::Mutex;
use dynamo_runtime::{
self as rs, logging,
pipeline::{
network::egress::push_router::RouterMode as RsRouterMode, EngineStream, ManyOut, SingleIn,
context::Context as RsContext, network::egress::push_router::RouterMode as RsRouterMode,
EngineStream, ManyOut, SingleIn,
},
protocols::annotated::Annotated as RsAnnotated,
traits::DistributedRuntimeProvider,
Expand Down Expand Up @@ -104,7 +105,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<context::PyContext>()?;
m.add_class::<context::Context>()?;
m.add_class::<EtcdKvCache>()?;
m.add_class::<ModelType>()?;
m.add_class::<llm::kv::ForwardPassMetrics>()?;
Expand Down Expand Up @@ -697,27 +698,29 @@ impl Client {
}

/// Issue a request to the endpoint using the default routing strategy.
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn generate<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
// The "default" strategy depends on the client: a static client is routed through
// `static`, anything else through `random`. All arguments are forwarded unchanged.
if self.router.client.is_static() {
self.r#static(py, request, annotated)
self.r#static(py, request, annotated, context)
} else {
self.random(py, request, annotated)
self.random(py, request, annotated, context)
}
}

/// Send a request to the next endpoint in a round-robin fashion.
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn round_robin<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
Expand All @@ -726,7 +729,15 @@ impl Client {
let client = self.router.clone();

pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.round_robin(request.into()).await.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.round_robin(request).await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client.round_robin(request.into()).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
Expand All @@ -736,12 +747,13 @@ impl Client {
}

/// Send a request to a random endpoint.
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn random<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
Expand All @@ -750,7 +762,15 @@ impl Client {
let client = self.router.clone();

pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.random(request.into()).await.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.random(request).await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client.random(request.into()).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
Expand All @@ -760,13 +780,14 @@ impl Client {
}

/// Directly send a request to a specific endpoint.
#[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn direct<'p>(
&self,
py: Python<'p>,
request: PyObject,
instance_id: i64,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
Expand All @@ -775,10 +796,21 @@ impl Client {
let client = self.router.clone();

pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client
.direct(request.into(), instance_id)
.await
.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client
.direct(request, instance_id)
.await
.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client
.direct(request.into(), instance_id)
.await
.map_err(to_pyerr)?,
};

tokio::spawn(process_stream(stream, tx));

Expand All @@ -790,12 +822,13 @@ impl Client {
}

/// Directly send a request to a pre-defined static worker
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn r#static<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
Expand All @@ -804,7 +837,15 @@ impl Client {
let client = self.router.clone();

pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.r#static(request.into()).await.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.r#static(request).await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client.r#static(request.into()).await.map_err(to_pyerr)?,
};

tokio::spawn(process_stream(stream, tx));

Expand Down
2 changes: 1 addition & 1 deletion lib/bindings/python/src/dynamo/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
from dynamo._core import Backend as Backend
from dynamo._core import Client as Client
from dynamo._core import Component as Component
from dynamo._core import Context as Context
from dynamo._core import DistributedRuntime as DistributedRuntime
from dynamo._core import Endpoint as Endpoint
from dynamo._core import EtcdKvCache as EtcdKvCache
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
from dynamo._core import PyContext as PyContext


def dynamo_worker(static=False):
Expand Down
Loading
Loading