Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/runtime/hello_world/hello_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@

from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.runtime import Context

logger = logging.getLogger(__name__)
configure_dynamo_logging(service_name="backend")


@dynamo_endpoint(str, str)
async def content_generator(request: str):
logger.info(f"Received request: {request}")
async def content_generator(request: str, context: Context):
logger.info(f"Received request: {request} with `id={context.id()}`")
for word in request.split(","):
await asyncio.sleep(1)
yield f"Hello {word}!"
Expand Down
64 changes: 64 additions & 0 deletions lib/bindings/python/rust/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
pub use dynamo_runtime::pipeline::AsyncEngineContext;
use pyo3::prelude::*;
use std::sync::Arc;
use tokio::time::{timeout, Duration};

// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Not all methods of the AsyncEngineContext are exposed, just the primary ones for tracing + cancellation.
// Kept as a class, to allow for future expansion if needed.
#[pyclass]
pub struct PyContext {
// Shared handle to the wrapped engine context; the Python-facing methods delegate to it.
pub inner: Arc<dyn AsyncEngineContext>,
}

impl PyContext {
pub fn new(inner: Arc<dyn AsyncEngineContext>) -> Self {
Self { inner }
}
}

#[pymethods]
impl PyContext {
// sync method of `await async_is_stopped()`
fn is_stopped(&self) -> bool {
self.inner.is_stopped();
}

// sync method of `await async_is_killed()`
fn is_killed(&self) -> bool {
self.inner.is_killed()
}
// issues a stop generating
fn stop_generating(&self) {
self.inner.stop_generating();
}

fn id(&self) -> &str {
self.inner.id()
}

// allows building a async callback.
// since async tasks in python get canceled, but memory is not freed in rust.
// allow for up to 360 seconds for the async task to cycle and free memory.
// however, calling `is_stopped()` would take a long time, therefore its preferable to have a async method
#[pyo3(signature = (wait_for=60))]
fn async_is_stopped<'a>(&self, py: Python<'a>, wait_for: u16) -> PyResult<Bound<'a, PyAny>> {
let inner = self.inner.clone();
// allow wait_for to be 360 seconds max
if wait_for > 360 || wait_for < 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"wait_for must be between 1 and 360 seconds to allow for async task to cycle.",
));
}

pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Wait up to `wait_for` seconds for inner.stopped() to complete.
if inner.is_stopped() {
return Ok(true);
}
let _ = timeout(Duration::from_secs(wait_for as u64), inner.stopped()).await;

Ok(inner.is_stopped() || inner.is_killed())
})
}
}
7 changes: 6 additions & 1 deletion lib/bindings/python/rust/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

use std::sync::Arc;

use super::context::PyContext;
use pyo3::prelude::*;
use pyo3_async_runtimes::TaskLocals;
use pythonize::{depythonize, pythonize};
Expand Down Expand Up @@ -163,6 +164,7 @@ where

let generator = self.generator.clone();
let event_loop = self.event_loop.clone();
let ctx_python = ctx.clone();

// Acquiring the GIL is similar to acquiring a standard lock/mutex
// Performing this in an tokio async task could block the thread for an undefined amount of time
Expand All @@ -177,7 +179,10 @@ where
let stream = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let py_request = pythonize(py, &request)?;
let gen = generator.call1(py, (py_request,))?;
let py_ctx = Py::new(py, PyContext::new(ctx_python.clone()))?;

let gen = generator.call1(py, (py_request, py_ctx))?;

let locals = TaskLocals::new(event_loop.bind(py).clone());
pyo3_async_runtimes::tokio::into_stream_with_locals_v1(locals, gen.into_bound(py))
})
Expand Down
2 changes: 2 additions & 0 deletions lib/bindings/python/rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ impl From<RouterMode> for RsRouterMode {
}
}

mod context;
mod engine;
mod http;
mod llm;
Expand Down Expand Up @@ -100,6 +101,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<context::PyContext>()?;
m.add_class::<EtcdKvCache>()?;
m.add_class::<ModelType>()?;
m.add_class::<llm::kv::ForwardPassMetrics>()?;
Expand Down
11 changes: 9 additions & 2 deletions lib/bindings/python/src/dynamo/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import asyncio
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, Union

import inspect
from pydantic import BaseModel, ValidationError

# List all the classes in the _core module for re-export
Expand All @@ -29,7 +29,7 @@
from dynamo._core import EtcdKvCache as EtcdKvCache
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor

from dynamo._core import PyContext as Context

def dynamo_worker(static=False):
def decorator(func):
Expand Down Expand Up @@ -66,11 +66,18 @@ def dynamo_endpoint(
def decorator(
func: Callable[..., AsyncGenerator[Any, None]],
) -> Callable[..., AsyncGenerator[Any, None]]:
has_context_kwarg = 'context' in inspect.signature(func).parameters

@wraps(func)
async def wrapper(*args, **kwargs) -> AsyncGenerator[Any, None]:
# Validate the request
try:
if isinstance(args[-1], Context):
args, context = args[:-1], args[-1]
if has_context_kwarg:
kwargs['context'] = context
args_list = list(args)

if len(args) in [1, 2] and issubclass(request_model, BaseModel):
if isinstance(args[-1], str):
args_list[-1] = request_model.parse_raw(args[-1])
Expand Down
Loading