4 changes: 2 additions & 2 deletions components/backends/vllm/deploy/disagg_planner.yaml
@@ -190,7 +190,7 @@ spec:
- /bin/sh
- -c
args:
- "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B 2>&1 | tee /tmp/vllm.log"
- "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --migration-limit=3 2>&1 | tee /tmp/vllm.log"
VllmPrefillWorker:
dynamoNamespace: vllm-disagg-planner
envFromSecret: hf-token-secret
@@ -240,4 +240,4 @@ spec:
- /bin/sh
- -c
args:
- python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker 2>&1 | tee /tmp/vllm.log
- python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker --migration-limit=3 2>&1 | tee /tmp/vllm.log
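The `--migration-limit=3` flag added to both workers presumably caps how many times an in-flight request may be re-dispatched to another worker after a mid-stream disconnect; the handler changes below raise a distinct error for exactly that case. A minimal sketch of such a retry budget, with every name (`StreamDisconnected`, `fake_worker`) invented for illustration:

```python
import asyncio
import itertools

_attempt_counter = itertools.count(1)


class StreamDisconnected(Exception):
    """Stand-in for the 'Stream ended before generation completed' error."""


async def fake_worker(request_id: str) -> list[int]:
    # Simulated decode worker: shuts down twice, then completes.
    if next(_attempt_counter) <= 2:
        raise StreamDisconnected(request_id)
    return [101, 102, 103]


async def generate_with_migration(request_id: str, migration_limit: int = 3) -> list[int]:
    for attempt in range(migration_limit + 1):
        try:
            return await fake_worker(request_id)
        except StreamDisconnected:
            if attempt == migration_limit:
                raise  # migration budget exhausted; surface the failure
            # otherwise loop: re-dispatch to another worker
    raise AssertionError("unreachable")


print(asyncio.run(generate_with_migration("req-1")))  # [101, 102, 103] after two migrations
```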
80 changes: 46 additions & 34 deletions components/backends/vllm/src/dynamo/vllm/handlers.py
@@ -50,28 +50,34 @@ async def generate_tokens(self, prompt, sampling_params, request_id):
gen = self.engine_client.generate(prompt, sampling_params, request_id)

num_output_tokens_so_far = 0
async for res in gen:
# res is vllm's RequestOutput

# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it.
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break

if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break

output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
try:
async for res in gen:
# res is vllm's RequestOutput

# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it.
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break

if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break

output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except asyncio.CancelledError:
# raise GeneratorExit when the engine shuts down so that the frontend can migrate the request
raise GeneratorExit(
"Decode engine was shut down during token generation"
) from None


class DecodeWorkerHandler(BaseWorkerHandler):
@@ -173,15 +179,21 @@ async def generate(self, request):
gen = self.engine_client.generate(prompt, sampling_params, request_id)

# Generate only 1 token in prefill
async for res in gen:
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
yield MyRequestOutput(
request_id=res.request_id,
prompt=res.prompt,
prompt_token_ids=res.prompt_token_ids,
prompt_logprobs=res.prompt_logprobs,
outputs=res.outputs,
finished=res.finished,
metrics=res.metrics,
kv_transfer_params=res.kv_transfer_params,
).model_dump_json()
try:
async for res in gen:
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
yield MyRequestOutput(
request_id=res.request_id,
prompt=res.prompt,
prompt_token_ids=res.prompt_token_ids,
prompt_logprobs=res.prompt_logprobs,
outputs=res.outputs,
finished=res.finished,
metrics=res.metrics,
kv_transfer_params=res.kv_transfer_params,
).model_dump_json()
except asyncio.CancelledError:
# raise GeneratorExit because prefill requests cannot be migrated
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
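Both handlers now convert `asyncio.CancelledError`, thrown into the generator when the engine shuts down and cancels its in-flight tasks, into `GeneratorExit`, which the Rust binding further down maps to a dedicated stream error. A self-contained sketch of that control flow, with the engine faked:

```python
import asyncio


async def fake_engine():
    # Stand-in for engine_client.generate(...)
    for i in range(1000):
        await asyncio.sleep(0.01)
        yield {"token_ids": [i]}


async def generate_tokens():
    try:
        async for out in fake_engine():
            yield out
    except asyncio.CancelledError:
        # Surfaces to the consumer as GeneratorExit rather than CancelledError,
        # so the caller can distinguish "engine went away" from a plain cancel.
        raise GeneratorExit("engine shut down during generation") from None


async def consume():
    async for _ in generate_tokens():
        pass


async def main():
    task = asyncio.create_task(consume())
    await asyncio.sleep(0.05)
    task.cancel()  # simulate engine shutdown cancelling the in-flight handler
    try:
        await task
    except GeneratorExit:
        print("GeneratorExit: request can be migrated")
    except asyncio.CancelledError:
        print("plain cancellation")


asyncio.run(main())
```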
21 changes: 15 additions & 6 deletions components/backends/vllm/src/dynamo/vllm/main.py
@@ -30,10 +30,10 @@

async def graceful_shutdown(runtime):
"""
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
However, in-flight requests will still be processed until they are finished.
After all in-flight requests are finished, the `serve_endpoint` functions will return
and the engine will be shutdown by Python's garbage collector.
Shut down the Dynamo distributed runtime.
The endpoints are invalidated immediately, so no new requests are accepted.
For endpoints served with graceful_shutdown=True, the serving function waits until all in-flight requests have finished.
For endpoints served with graceful_shutdown=False, the serving function returns immediately.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
@@ -113,7 +113,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):

try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
# for prefill, we want to shut down the engine only after all prefill requests are finished because
# (short-term reason): we don't support re-routing prefill requests
# (long-term reason): the prefill engine should pull from a global queue, so there are
# only a few in-flight requests and they can finish quickly
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=True),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
)
except Exception as e:
@@ -142,6 +146,9 @@ async def init(runtime: DistributedRuntime, config: Config):
)

if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print(f"Migration limit: {config.migration_limit}")
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
await register_llm(
ModelType.Backend,
generate_endpoint,
@@ -188,7 +195,9 @@

try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting for them to finish can take a long time for long output sequence lengths (OSLs)
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=False),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
)
except Exception as e:
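The `graceful_shutdown` coroutine documented above is presumably invoked from a signal handler so that `runtime.shutdown()` runs on SIGTERM or SIGINT; that wiring is outside this diff. A hedged, runnable sketch with a stand-in runtime (Unix only, since it uses `loop.add_signal_handler`):

```python
import asyncio
import logging
import signal


class FakeRuntime:
    """Stand-in for the Dynamo DistributedRuntime; only shutdown() is modeled."""

    def __init__(self) -> None:
        self.stopped = asyncio.Event()

    def shutdown(self) -> None:
        self.stopped.set()


async def graceful_shutdown(runtime: FakeRuntime) -> None:
    logging.info("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()


async def main() -> None:
    logging.basicConfig(level=logging.INFO)
    runtime = FakeRuntime()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: loop.create_task(graceful_shutdown(runtime))
        )
    # In the real worker, the serve_endpoint(...) calls would run here.
    await runtime.stopped.wait()
    logging.info("Runtime stopped; serving functions may now return")


asyncio.run(main())
```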
18 changes: 16 additions & 2 deletions lib/bindings/python/rust/engine.rs
@@ -134,6 +134,9 @@ enum ResponseProcessingError {
#[error("python exception: {0}")]
PythonException(String),

#[error("python generator exit: {0}")]
PyGeneratorExit(String),

#[error("deserialize error: {0}")]
DeserializeError(String),

@@ -225,6 +228,9 @@
let msg = format!("critical error: invalid response object from python async generator; application-logic-mismatch: {}", e);
msg
}
ResponseProcessingError::PyGeneratorExit(_) => {
"Stream ended before generation completed".to_string()
}
ResponseProcessingError::PythonException(e) => {
let msg = format!("a python exception was caught while processing the async generator: {}", e);
msg
@@ -276,8 +282,16 @@
{
let item = item.map_err(|e| {
println!();
Python::with_gil(|py| e.display(py));
ResponseProcessingError::PythonException(e.to_string())
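// GeneratorExit raised by the Python handler signals an engine shutdown;
// classify it separately from ordinary exceptions so the stream can be
// treated as ended-early (and potentially migrated) rather than failed.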
let mut is_py_generator_exit = false;
Python::with_gil(|py| {
e.display(py);
is_py_generator_exit = e.is_instance_of::<pyo3::exceptions::PyGeneratorExit>(py);
});
if is_py_generator_exit {
ResponseProcessingError::PyGeneratorExit(e.to_string())
} else {
ResponseProcessingError::PythonException(e.to_string())
}
})?;
let response = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| depythonize::<Resp>(&item.into_bound(py)))
6 changes: 4 additions & 2 deletions lib/bindings/python/rust/lib.rs
@@ -475,20 +475,22 @@ impl Component {

#[pymethods]
impl Endpoint {
#[pyo3(signature = (generator))]
#[pyo3(signature = (generator, graceful_shutdown = true))]
fn serve_endpoint<'p>(
&self,
py: Python<'p>,
generator: PyObject,
graceful_shutdown: Option<bool>,
) -> PyResult<Bound<'p, PyAny>> {
let engine = Arc::new(engine::PythonAsyncEngine::new(
generator,
self.event_loop.clone(),
)?);
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
let builder = self.inner.endpoint_builder().handler(ingress);
let graceful_shutdown = graceful_shutdown.unwrap_or(true);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
builder.start().await.map_err(to_pyerr)?;
builder.graceful_shutdown(graceful_shutdown).start().await.map_err(to_pyerr)?;
Ok(())
})
}
6 changes: 5 additions & 1 deletion lib/bindings/python/src/dynamo/_core.pyi
@@ -216,10 +216,14 @@ class Endpoint:

...

async def serve_endpoint(self, handler: RequestHandler) -> None:
async def serve_endpoint(self, handler: RequestHandler, graceful_shutdown: bool = True) -> None:
"""
Serve an endpoint discoverable by all connected clients at
`{{ namespace }}/components/{{ component_name }}/endpoints/{{ endpoint_name }}`

Args:
handler: The request handler function
graceful_shutdown: Whether to wait for inflight requests to complete during shutdown (default: True)
"""
...

7 changes: 6 additions & 1 deletion lib/runtime/src/component/endpoint.rs
@@ -40,6 +40,10 @@ pub struct EndpointConfig {
#[educe(Debug(ignore))]
#[builder(default, private)]
_stats_handler: Option<EndpointStatsHandler>,

/// Whether to wait for inflight requests to complete during shutdown
#[builder(default = "true")]
graceful_shutdown: bool,
}

impl EndpointConfigBuilder {
@@ -55,7 +59,7 @@ impl EndpointConfigBuilder {
}

pub async fn start(self) -> Result<()> {
let (endpoint, lease, handler, stats_handler) = self.build_internal()?.dissolve();
let (endpoint, lease, handler, stats_handler, graceful_shutdown) = self.build_internal()?.dissolve();
let lease = lease.or(endpoint.drt().primary_lease());
let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);

Expand Down Expand Up @@ -109,6 +113,7 @@ impl EndpointConfigBuilder {
let push_endpoint = PushEndpoint::builder()
.service_handler(handler)
.cancellation_token(cancel_token.clone())
.graceful_shutdown(graceful_shutdown)
.build()
.map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;

22 changes: 14 additions & 8 deletions lib/runtime/src/pipeline/network/ingress/push_endpoint.rs
@@ -31,6 +31,8 @@ use tokio_util::sync::CancellationToken;
pub struct PushEndpoint {
pub service_handler: Arc<dyn PushWorkHandler>,
pub cancellation_token: CancellationToken,
#[builder(default = "true")]
pub graceful_shutdown: bool,
}

/// version of crate
@@ -116,15 +118,19 @@ impl PushEndpoint {
.await
.set_endpoint_health_status(endpoint_name.clone(), HealthStatus::NotReady);

// await for all inflight requests to complete
tracing::info!(
"Waiting for {} inflight requests to complete",
inflight.load(Ordering::SeqCst)
);
while inflight.load(Ordering::SeqCst) > 0 {
notify.notified().await;
// wait for all in-flight requests to complete if graceful shutdown is enabled
if self.graceful_shutdown {
tracing::info!(
"Waiting for {} inflight requests to complete",
inflight.load(Ordering::SeqCst)
);
while inflight.load(Ordering::SeqCst) > 0 {
notify.notified().await;
}
tracing::info!("All inflight requests completed");
} else {
tracing::info!("Skipping graceful shutdown, not waiting for inflight requests");
}
tracing::info!("All inflight requests completed");

Ok(())
}
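A toy model of the branch above, in Python for brevity: the Rust code pairs an atomic in-flight counter with a `Notify`; here a single `asyncio.Condition` plays both roles. Illustrative only, not the runtime's API:

```python
import asyncio


class PushEndpointModel:
    def __init__(self, graceful_shutdown: bool) -> None:
        self.graceful_shutdown = graceful_shutdown
        self.inflight = 0
        self.cond = asyncio.Condition()

    async def handle_request(self, seconds: float) -> None:
        async with self.cond:
            self.inflight += 1
        await asyncio.sleep(seconds)  # the "work"
        async with self.cond:
            self.inflight -= 1
            self.cond.notify_all()

    async def shutdown(self) -> None:
        if self.graceful_shutdown:
            print(f"Waiting for {self.inflight} inflight requests to complete")
            async with self.cond:
                await self.cond.wait_for(lambda: self.inflight == 0)
            print("All inflight requests completed")
        else:
            print("Skipping graceful shutdown, not waiting for inflight requests")


async def main() -> None:
    ep = PushEndpointModel(graceful_shutdown=True)
    reqs = [asyncio.create_task(ep.handle_request(0.1)) for _ in range(3)]
    await asyncio.sleep(0.01)  # let the requests start
    await ep.shutdown()  # blocks until all three finish
    await asyncio.gather(*reqs)


asyncio.run(main())
```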
11 changes: 10 additions & 1 deletion lib/runtime/src/pipeline/network/ingress/push_handler.rs
@@ -14,6 +14,7 @@
// limitations under the License.

use super::*;
use crate::protocols::maybe_error::MaybeError;
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
@@ -105,7 +106,7 @@ impl WorkHandlerMetrics {
impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
where
T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
U: Data + Serialize + std::fmt::Debug,
U: Data + Serialize + MaybeError + std::fmt::Debug,
{
fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()> {
// Call the Ingress-specific add_metrics implementation
@@ -220,6 +221,14 @@
let mut send_complete_final = true;
while let Some(resp) = stream.next().await {
tracing::trace!("Sending response: {:?}", resp);
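// If the response carries the shutdown sentinel, stop streaming without
// emitting the final-complete marker, presumably so the frontend can
// detect the early end and migrate the request.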
if let Some(err) = resp.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG {
tracing::warn!(STREAM_ERR_MSG);
send_complete_final = false;
break;
}
}
let resp_wrapper = NetworkStreamWrapper {
data: Some(resp),
complete_final: false,