Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions lib/bindings/python/rust/context.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,46 @@
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.

use dynamo_runtime::pipeline::context::Controller;
pub use dynamo_runtime::pipeline::AsyncEngineContext;
use pyo3::prelude::*;
use std::sync::Arc;

// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Not all methods of the AsyncEngineContext are exposed, jsut the primary ones for tracing + cancellation.
// Kept as class, to allow for future expansion if needed.
#[derive(Clone)]
#[pyclass]
pub struct PyContext {
pub inner: Arc<dyn AsyncEngineContext>,
pub struct Context {
inner: Arc<dyn AsyncEngineContext>,
}

impl PyContext {
impl Context {
pub fn new(inner: Arc<dyn AsyncEngineContext>) -> Self {
Self { inner }
}

pub fn inner(&self) -> Arc<dyn AsyncEngineContext> {
self.inner.clone()
}
}

#[pymethods]
impl PyContext {
impl Context {
#[new]
#[pyo3(signature = (id=None))]
fn py_new(id: Option<String>) -> Self {
let controller = match id {
Some(id) => Controller::new(id),
None => Controller::default(),
};
Self {
inner: Arc::new(controller),
}
}

// sync method of `await async_is_stopped()`
fn is_stopped(&self) -> bool {
self.inner.is_stopped()
Expand Down
14 changes: 7 additions & 7 deletions lib/bindings/python/rust/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use super::context::{callable_accepts_kwarg, PyContext};
use super::context::{callable_accepts_kwarg, Context};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyModule};
use pyo3::{PyAny, PyErr};
Expand Down Expand Up @@ -114,7 +114,7 @@ pub struct PythonServerStreamingEngine {
_cancel_token: CancellationToken,
generator: Arc<PyObject>,
event_loop: Arc<PyObject>,
has_pycontext: bool,
has_context: bool,
}

impl PythonServerStreamingEngine {
Expand All @@ -123,7 +123,7 @@ impl PythonServerStreamingEngine {
generator: Arc<PyObject>,
event_loop: Arc<PyObject>,
) -> Self {
let has_pycontext = Python::with_gil(|py| {
let has_context = Python::with_gil(|py| {
let callable = generator.bind(py);
callable_accepts_kwarg(py, callable, "context").unwrap_or(false)
});
Expand All @@ -132,7 +132,7 @@ impl PythonServerStreamingEngine {
_cancel_token: cancel_token,
generator,
event_loop,
has_pycontext,
has_context,
}
}
}
Expand Down Expand Up @@ -175,7 +175,7 @@ where
let generator = self.generator.clone();
let event_loop = self.event_loop.clone();
let ctx_python = ctx.clone();
let has_pycontext = self.has_pycontext;
let has_context = self.has_context;

// Acquiring the GIL is similar to acquiring a standard lock/mutex
// Performing this in an tokio async task could block the thread for an undefined amount of time
Expand All @@ -190,9 +190,9 @@ where
let stream = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let py_request = pythonize(py, &request)?;
let py_ctx = Py::new(py, PyContext::new(ctx_python.clone()))?;
let py_ctx = Py::new(py, Context::new(ctx_python.clone()))?;

let gen = if has_pycontext {
let gen = if has_context {
// Pass context as a kwarg
let kwarg = PyDict::new(py);
kwarg.set_item("context", &py_ctx)?;
Expand Down
73 changes: 57 additions & 16 deletions lib/bindings/python/rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use tokio::sync::Mutex;
use dynamo_runtime::{
self as rs, logging,
pipeline::{
network::egress::push_router::RouterMode as RsRouterMode, EngineStream, ManyOut, SingleIn,
context::Context as RsContext, network::egress::push_router::RouterMode as RsRouterMode,
EngineStream, ManyOut, SingleIn,
},
protocols::annotated::Annotated as RsAnnotated,
traits::DistributedRuntimeProvider,
Expand Down Expand Up @@ -104,7 +105,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<context::PyContext>()?;
m.add_class::<context::Context>()?;
m.add_class::<EtcdKvCache>()?;
m.add_class::<ModelType>()?;
m.add_class::<llm::kv::ForwardPassMetrics>()?;
Expand Down Expand Up @@ -697,27 +698,29 @@ impl Client {
}

/// Issue a request to the endpoint using the default routing strategy.
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn generate<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
if self.router.client.is_static() {
self.r#static(py, request, annotated)
self.r#static(py, request, annotated, context)
} else {
self.random(py, request, annotated)
self.random(py, request, annotated, context)
}
}

/// Send a request to the next endpoint in a round-robin fashion.
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn round_robin<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
Expand All @@ -726,7 +729,15 @@ impl Client {
let client = self.router.clone();

pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.round_robin(request.into()).await.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.round_robin(request).await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client.round_robin(request.into()).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
Expand All @@ -736,12 +747,13 @@ impl Client {
}

/// Send a request to a random endpoint.
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn random<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
Expand All @@ -750,7 +762,15 @@ impl Client {
let client = self.router.clone();

pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.random(request.into()).await.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.random(request).await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client.random(request.into()).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
Expand All @@ -760,13 +780,14 @@ impl Client {
}

/// Directly send a request to a specific endpoint.
#[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn direct<'p>(
&self,
py: Python<'p>,
request: PyObject,
instance_id: i64,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
Expand All @@ -775,10 +796,21 @@ impl Client {
let client = self.router.clone();

pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client
.direct(request.into(), instance_id)
.await
.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client
.direct(request, instance_id)
.await
.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client
.direct(request.into(), instance_id)
.await
.map_err(to_pyerr)?,
};

tokio::spawn(process_stream(stream, tx));

Expand All @@ -790,12 +822,13 @@ impl Client {
}

/// Directly send a request to a pre-defined static worker
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn r#static<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
Expand All @@ -804,7 +837,15 @@ impl Client {
let client = self.router.clone();

pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.r#static(request.into()).await.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.r#static(request).await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client.r#static(request.into()).await.map_err(to_pyerr)?,
};

tokio::spawn(process_stream(stream, tx));

Expand Down
2 changes: 1 addition & 1 deletion lib/bindings/python/src/dynamo/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
from dynamo._core import Backend as Backend
from dynamo._core import Client as Client
from dynamo._core import Component as Component
from dynamo._core import Context as Context
from dynamo._core import DistributedRuntime as DistributedRuntime
from dynamo._core import Endpoint as Endpoint
from dynamo._core import EtcdKvCache as EtcdKvCache
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
from dynamo._core import PyContext as PyContext


def dynamo_worker(static=False):
Expand Down
Loading