Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add test pycontext
  • Loading branch information
michaelfeil committed Jul 28, 2025
commit c96e96c92b4d9754646984f0012dd016bda2125d
5 changes: 3 additions & 2 deletions examples/runtime/hello_world/hello_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@

from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.runtime import Context

logger = logging.getLogger(__name__)
configure_dynamo_logging(service_name="backend")


@dynamo_endpoint(str, str)
async def content_generator(request: str):
logger.info(f"Received request: {request}")
async def content_generator(request: str, context: Context):
logger.info(f"Received request: {request} with `id={context.id()}`")
for word in request.split(","):
await asyncio.sleep(1)
yield f"Hello {word}!"
Expand Down
64 changes: 64 additions & 0 deletions lib/bindings/python/rust/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
pub use dynamo_runtime::pipeline::AsyncEngineContext;
use pyo3::prelude::*;
use std::sync::Arc;
use tokio::time::{timeout, Duration};

// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Not all methods of the AsyncEngineContext are exposed, jsut the primary ones for tracing + cancellation.
// Kept as class, to allow for future expansion if needed.
#[pyclass]
pub struct PyContext {
pub inner: Arc<dyn AsyncEngineContext>,
}

impl PyContext {
pub fn new(inner: Arc<dyn AsyncEngineContext>) -> Self {
Self { inner }
}
}

#[pymethods]
impl PyContext {
// sync method of `await async_is_stopped()`
fn is_stopped(&self) -> bool {
self.inner.is_stopped();
}

// sync method of `await async_is_killed()`
fn is_killed(&self) -> bool {
self.inner.is_killed()
}
// issues a stop generating
fn stop_generating(&self) {
self.inner.stop_generating();
}

fn id(&self) -> &str {
self.inner.id()
}

// allows building a async callback.
// since async tasks in python get canceled, but memory is not freed in rust.
// allow for up to 360 seconds for the async task to cycle and free memory.
// however, calling `is_stopped()` would take a long time, therefore its preferable to have a async method
#[pyo3(signature = (wait_for=60))]
fn async_is_stopped<'a>(&self, py: Python<'a>, wait_for: u16) -> PyResult<Bound<'a, PyAny>> {
let inner = self.inner.clone();
// allow wait_for to be 360 seconds max
if wait_for > 360 || wait_for < 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"wait_for must be between 1 and 360 seconds to allow for async task to cycle.",
));
}

pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Wait up to `wait_for` seconds for inner.stopped() to complete.
if inner.is_stopped() {
return Ok(true);
}
let _ = timeout(Duration::from_secs(wait_for as u64), inner.stopped()).await;

Ok(inner.is_stopped() || inner.is_killed())
})
}
}
7 changes: 6 additions & 1 deletion lib/bindings/python/rust/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

use std::sync::Arc;

use super::context::PyContext;
use pyo3::prelude::*;
use pyo3_async_runtimes::TaskLocals;
use pythonize::{depythonize, pythonize};
Expand Down Expand Up @@ -163,6 +164,7 @@ where

let generator = self.generator.clone();
let event_loop = self.event_loop.clone();
let ctx_python = ctx.clone();

// Acquiring the GIL is similar to acquiring a standard lock/mutex
// Performing this in an tokio async task could block the thread for an undefined amount of time
Expand All @@ -177,7 +179,10 @@ where
let stream = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let py_request = pythonize(py, &request)?;
let gen = generator.call1(py, (py_request,))?;
let py_ctx = Py::new(py, PyContext::new(ctx_python.clone()))?;

let gen = generator.call1(py, (py_request, py_ctx))?;

let locals = TaskLocals::new(event_loop.bind(py).clone());
pyo3_async_runtimes::tokio::into_stream_with_locals_v1(locals, gen.into_bound(py))
})
Expand Down
2 changes: 2 additions & 0 deletions lib/bindings/python/rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ impl From<RouterMode> for RsRouterMode {
}
}

mod context;
mod engine;
mod http;
mod llm;
Expand Down Expand Up @@ -100,6 +101,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<context::PyContext>()?;
m.add_class::<EtcdKvCache>()?;
m.add_class::<ModelType>()?;
m.add_class::<llm::kv::ForwardPassMetrics>()?;
Expand Down
11 changes: 9 additions & 2 deletions lib/bindings/python/src/dynamo/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import asyncio
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, Union

import inspect
from pydantic import BaseModel, ValidationError

# List all the classes in the _core module for re-export
Expand All @@ -29,7 +29,7 @@
from dynamo._core import EtcdKvCache as EtcdKvCache
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor

from dynamo._core import PyContext as Context

def dynamo_worker(static=False):
def decorator(func):
Expand Down Expand Up @@ -66,11 +66,18 @@ def dynamo_endpoint(
def decorator(
func: Callable[..., AsyncGenerator[Any, None]],
) -> Callable[..., AsyncGenerator[Any, None]]:
has_context_kwarg = 'context' in inspect.signature(func).parameters

@wraps(func)
async def wrapper(*args, **kwargs) -> AsyncGenerator[Any, None]:
# Validate the request
try:
if isinstance(args[-1], Context):
args, context = args[:-1], args[-1]
if has_context_kwarg:
kwargs['context'] = context
args_list = list(args)

if len(args) in [1, 2] and issubclass(request_model, BaseModel):
if isinstance(args[-1], str):
args_list[-1] = request_model.parse_raw(args[-1])
Expand Down
Loading