Skip to content
Prev Previous commit
Next Next commit
fix: Retain context_id between requests
  • Loading branch information
kthui committed Aug 25, 2025
commit 1c8188a494a369fc1a406a4d6e38a37bc1ea6282
84 changes: 63 additions & 21 deletions lib/llm/src/migration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use crate::{

use dynamo_runtime::{
pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, ServerStreamingEngine,
SingleIn, async_trait,
AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn, async_trait,
},
protocols::{annotated::Annotated, maybe_error::MaybeError},
};
Expand Down Expand Up @@ -50,10 +50,12 @@ impl
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
let (preprocessed_request, context) = request.transfer(());
let context_id = context.id().to_string();
let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone();
let retry_manager =
RetryManager::build(preprocessed_request, next, self.migration_limit).await?;
RetryManager::build(context_id, preprocessed_request, next, self.migration_limit)
.await?;
let response_stream = stream::unfold(retry_manager, move |mut retry_manager| {
let engine_ctx = engine_ctx_.clone();
async move {
Expand All @@ -71,6 +73,7 @@ impl
}

struct RetryManager {
context_id: String,
request: PreprocessedRequest,
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
Expand All @@ -79,11 +82,13 @@ struct RetryManager {

impl RetryManager {
pub async fn build(
context_id: String,
preprocessed_request: PreprocessedRequest,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
retries_left: u32,
) -> Result<Self> {
let mut slf = Self {
context_id,
request: preprocessed_request,
next_generate: next,
next_stream: None,
Expand Down Expand Up @@ -130,8 +135,7 @@ impl RetryManager {
let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
while self.retries_left > 0 {
self.retries_left -= 1;
// TODO: Is there anything needed to pass between context?
let request = SingleIn::new(self.request.clone());
let request = Context::with_id(self.request.clone(), self.context_id.clone());
response_stream = Some(self.next_generate.generate(request).await);
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
&& let Some(req_err) = err.downcast_ref::<NatsRequestError>()
Expand Down Expand Up @@ -237,15 +241,22 @@ mod tests {
num_responses: usize,
token_offset: u32,
call_count: Arc<AtomicU32>,
context_id: String,
}

impl MockEngine {
fn new(behavior: MockBehavior, num_responses: usize, token_offset: u32) -> Self {
fn new(
behavior: MockBehavior,
num_responses: usize,
token_offset: u32,
context_id: String,
) -> Self {
Self {
behavior,
num_responses,
token_offset,
call_count: Arc::new(AtomicU32::new(0)),
context_id,
}
}
}
Expand All @@ -263,7 +274,14 @@ mod tests {
request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
let (preprocessed_request, _) = request.transfer(());
let (preprocessed_request, context) = request.transfer(());

// Assert that the context_id matches the expected one
assert_eq!(
context.id().to_string(),
self.context_id,
"Context ID mismatch"
);

// Calculate how many responses we've already generated based on request token_ids
// Initial request has [1, 2, 3], so anything beyond that are generated responses
Expand Down Expand Up @@ -338,7 +356,7 @@ mod tests {
}

let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
Expand Down Expand Up @@ -369,7 +387,7 @@ mod tests {
});

let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
Expand Down Expand Up @@ -405,7 +423,7 @@ mod tests {
});

let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
Expand All @@ -422,7 +440,7 @@ mod tests {
});

let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
Expand Down Expand Up @@ -457,7 +475,7 @@ mod tests {
});

let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
Expand All @@ -472,12 +490,18 @@ mod tests {
#[tokio::test]
async fn test_retry_manager_no_migration() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(MockBehavior::Success, 10, 100));
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::Success,
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;

let mut retry_manager = RetryManager::build(request, next_generate, 0)
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 0)
.await
.expect("Failed to build RetryManager");

Expand All @@ -504,12 +528,18 @@ mod tests {
#[tokio::test]
async fn test_retry_manager_new_request_migration() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(MockBehavior::FailThenSuccess, 10, 100));
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::FailThenSuccess,
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;

let mut retry_manager = RetryManager::build(request, next_generate, 3)
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3)
.await
.expect("Failed to build RetryManager");

Expand Down Expand Up @@ -537,16 +567,18 @@ mod tests {
async fn test_retry_manager_ongoing_request_migration() {
dynamo_runtime::logging::init();

let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFail { fail_after: 5 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;

let mut retry_manager = RetryManager::build(request, next_generate, 3)
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3)
.await
.expect("Failed to build RetryManager");

Expand Down Expand Up @@ -574,13 +606,19 @@ mod tests {
#[tokio::test]
async fn test_retry_manager_new_request_migration_indefinite_failure() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(0);
let mock_engine = Arc::new(MockEngine::new(MockBehavior::AlwaysFail, 0, 100));
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::AlwaysFail,
0,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;

// Should fail to build due to initial stream creation failure after exhausting all 3 retries
let retry_manager_result = RetryManager::build(request, next_generate, 3).await;
let retry_manager_result = RetryManager::build(context_id, request, next_generate, 3).await;

assert!(retry_manager_result.is_err());
if let Err(error) = retry_manager_result {
Expand All @@ -595,16 +633,18 @@ mod tests {
#[tokio::test]
async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFailAlways { fail_after: 3 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;

let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries
.await
.expect("Failed to build RetryManager");

Expand Down Expand Up @@ -645,16 +685,18 @@ mod tests {
#[tokio::test]
async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;

let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries
.await
.expect("Failed to build RetryManager");

Expand Down