diff --git a/Cargo.lock b/Cargo.lock index 174759fa26..cba8d9f485 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2663,7 +2663,7 @@ dependencies = [ "bytes", "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", "chrono", - "clap 4.5.52", + "clap 4.5.53", "criterion 0.3.6", "cudarc", "dashmap 5.5.3", @@ -4065,8 +4065,8 @@ checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97" dependencies = [ "dirs", "futures", - "indicatif 0.17.11", "http 1.4.0", + "indicatif 0.17.11", "libc", "log", "num_cpus", diff --git a/components/src/dynamo/mocker/args.py b/components/src/dynamo/mocker/args.py index 3bdde43daf..6180a1c9b6 100644 --- a/components/src/dynamo/mocker/args.py +++ b/components/src/dynamo/mocker/args.py @@ -113,6 +113,7 @@ def create_temp_engine_args_file(args) -> Path: else None, "is_prefill": getattr(args, "is_prefill_worker", None), "is_decode": getattr(args, "is_decode_worker", None), + "enable_local_indexer": getattr(args, "enable_local_indexer", None), } # Remove None values to only include explicitly set arguments @@ -284,6 +285,12 @@ def parse_args(): default=False, help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)", ) + parser.add_argument( + "--enable-local-indexer", + action="store_true", + default=False, + help="Enable worker-local KV indexer for tracking this worker's own KV cache state (default: False)", + ) parser.add_argument( "--store-kv", type=str, diff --git a/components/src/dynamo/vllm/args.py b/components/src/dynamo/vllm/args.py index 64c4677db6..de2a2eaa37 100644 --- a/components/src/dynamo/vllm/args.py +++ b/components/src/dynamo/vllm/args.py @@ -40,6 +40,7 @@ class Config: custom_jinja_template: Optional[str] = None store_kv: str request_plane: str + enable_local_indexer: bool = False # mirror vLLM model: str @@ -204,6 +205,13 @@ def parse_args() -> Config: default=os.environ.get("DYN_REQUEST_PLANE", "nats"), help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", ) + parser.add_argument( + "--enable-local-indexer", + type=str, + choices=["true", "false"], + default=os.environ.get("DYN_LOCAL_INDEXER", "false"), + help="Enable worker-local KV indexer for tracking this worker's own KV cache state (can also be toggled with env var DYN_LOCAL_INDEXER).", + ) parser.add_argument( "--use-vllm-tokenizer", action="store_true", @@ -214,6 +222,7 @@ def parse_args() -> Config: parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() + args.enable_local_indexer = str(args.enable_local_indexer).lower() == "true" engine_args = AsyncEngineArgs.from_cli_args(args) # Workaround for vLLM GIL contention bug with NIXL connector when using UniProcExecutor. 
@@ -312,6 +321,7 @@ def parse_args() -> Config: config.mm_prompt_template = args.mm_prompt_template config.store_kv = args.store_kv config.request_plane = args.request_plane + config.enable_local_indexer = args.enable_local_indexer config.use_vllm_tokenizer = args.use_vllm_tokenizer # Validate custom Jinja template file exists if provided diff --git a/components/src/dynamo/vllm/main.py b/components/src/dynamo/vllm/main.py index 3f8ab182d0..b26d663891 100644 --- a/components/src/dynamo/vllm/main.py +++ b/components/src/dynamo/vllm/main.py @@ -224,6 +224,7 @@ def setup_kv_event_publisher( worker_id=generate_endpoint.connection_id(), kv_block_size=vllm_config.cache_config.block_size, zmq_endpoint=zmq_endpoint, + enable_local_indexer=config.enable_local_indexer, ) kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config) kv_publishers.append(kv_publisher) @@ -336,6 +337,7 @@ async def register_vllm_model( runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"] runtime_config.max_num_seqs = runtime_values["max_num_seqs"] runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"] + runtime_config.enable_local_indexer = config.enable_local_indexer # Add tool/reasoning parsers for decode models if model_type != ModelType.Prefill: diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 986a95464d..e4802083ba 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -21,7 +21,7 @@ use rs::traits::events::EventSubscriber; use tracing; use llm_rs::kv_router::protocols::*; -use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks}; +use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener}; use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; #[pyfunction] @@ -106,6 +106,9 @@ pub struct ZmqKvEventPublisherConfig { pub zmq_endpoint: String, #[pyo3(get, set)] pub zmq_topic: String, + #[pyo3(get, set)] + pub enable_local_indexer: bool, // whether the underlying KvEventPublisher publishes to + // both global and worker-local KvIndexers } #[pymethods] @@ -115,19 +118,22 @@ impl ZmqKvEventPublisherConfig { worker_id, kv_block_size, zmq_endpoint = "tcp://127.0.0.1:5557".to_string(), - zmq_topic = "".to_string() + zmq_topic = "".to_string(), + enable_local_indexer = false ))] pub fn new( worker_id: WorkerId, kv_block_size: usize, zmq_endpoint: String, zmq_topic: String, + enable_local_indexer: bool, ) -> Self { Self { worker_id, kv_block_size, zmq_endpoint, zmq_topic, + enable_local_indexer, } } } @@ -141,13 +147,14 @@ pub(crate) struct ZmqKvEventPublisher { impl ZmqKvEventPublisher { #[new] fn new(component: Component, config: ZmqKvEventPublisherConfig) -> PyResult { - let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( + let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer( component.inner, config.kv_block_size as u32, Some(KvEventSourceConfig::Zmq { endpoint: config.zmq_endpoint, topic: config.zmq_topic, }), + config.enable_local_indexer, ) .map_err(to_pyerr)?; Ok(Self { inner }) @@ -179,7 +186,7 @@ impl ZmqKvEventListener { let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); let shutdown_token = tokio_util::sync::CancellationToken::new(); - tokio::spawn(llm_rs::kv_router::publisher::start_zmq_listener( + tokio::spawn(start_zmq_listener( zmq_endpoint, zmq_topic, tx, diff --git a/lib/bindings/python/rust/llm/local_model.rs 
b/lib/bindings/python/rust/llm/local_model.rs index 15fb24f373..3917c7a089 100644 --- a/lib/bindings/python/rust/llm/local_model.rs +++ b/lib/bindings/python/rust/llm/local_model.rs @@ -49,6 +49,11 @@ impl ModelRuntimeConfig { self.inner.data_parallel_size = data_parallel_size; } + #[setter] + fn set_enable_local_indexer(&mut self, enable_local_indexer: bool) { + self.inner.enable_local_indexer = enable_local_indexer; + } + fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?; self.inner @@ -103,6 +108,11 @@ impl ModelRuntimeConfig { self.inner.reasoning_parser.clone() } + #[getter] + fn enable_local_indexer(&self) -> bool { + self.inner.enable_local_indexer + } + #[getter] fn runtime_data(&self, py: Python<'_>) -> PyResult { let dict = PyDict::new(py); diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index ba8e9a8434..1a0d1913aa 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -460,6 +460,7 @@ class ModelRuntimeConfig: max_num_batched_tokens: int | None tool_call_parser: str | None reasoning_parser: str | None + enable_local_indexer: bool runtime_data: dict[str, Any] tensor_model_config: Any | None @@ -843,7 +844,8 @@ class ZmqKvEventPublisherConfig: worker_id: int, kv_block_size: int, zmq_endpoint: str = "tcp://127.0.0.1:5557", - zmq_topic: str = "" + zmq_topic: str = "", + enable_local_indexer: bool = False ) -> None: """ Configuration for the ZmqKvEventPublisher. @@ -852,6 +854,7 @@ class ZmqKvEventPublisherConfig: :param kv_block_size: The block size for the key-value store. :param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557". :param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string. + :param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to False. """ ... 
diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 59ec3c5cab..665c3ed7ee 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -34,8 +34,11 @@ pub mod scheduler; pub mod scoring; pub mod sequence; pub mod subscriber; +pub mod worker_query; +use indexer::WorkerKvQueryResponse; pub use prefill_router::PrefillRouter; +use worker_query::WorkerQueryClient; use crate::{ kv_router::{ @@ -45,11 +48,12 @@ use crate::{ compute_block_hash_for_seq, compute_seq_hash_for_block, }, protocols::{ - LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank, + LocalBlockHash, RouterRequest, RouterResponse, WorkerId, WorkerSelectionResult, + WorkerWithDpRank, }, scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest}, sequence::SequenceError, - subscriber::start_kv_router_background, + subscriber::{recover_from_all_workers, start_kv_router_background}, }, local_model::runtime_config::ModelRuntimeConfig, model_card::ModelDeploymentCard, @@ -77,6 +81,10 @@ pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events"; pub const RADIX_STATE_BUCKET: &str = "radix-bucket"; pub const RADIX_STATE_FILE: &str = "radix-state"; +// for worker-local kvindexer query +pub const WORKER_KV_INDEXER_QUERY_SUBJECT: &str = "worker_kv_indexer_query"; +pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer + // for router discovery registration pub const KV_ROUTER_COMPONENT: &str = "kv-router"; pub const KV_ROUTER_ENDPOINT: &str = "generate"; @@ -270,6 +278,8 @@ pub struct KvRouter { cancellation_token: tokio_util::sync::CancellationToken, client: Client, + + worker_query_client: Option, } impl KvRouter { @@ -296,7 +306,7 @@ impl KvRouter { endpoint: endpoint_id.name.clone(), }; let discovery_stream = discovery - .list_and_watch(discovery_key, Some(cancellation_token.clone())) + .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone())) .await?; let runtime_configs_rx = watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| { @@ -333,13 +343,19 @@ impl KvRouter { component.clone(), block_size, instance_ids_rx, - runtime_configs_rx, + runtime_configs_rx.clone(), selector, kv_router_config.router_replica_sync, consumer_id.clone(), ) .await?; + // Initialize worker query client using namespace abstraction + // (created before background task so we can use it for startup recovery) + let worker_query_client = + worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone()); + tracing::info!("Worker query client initialized"); + // Start KV event subscriber background process (only when use_kv_events is enabled) if kv_router_config.use_kv_events && let Indexer::KvIndexer(ref kv_indexer) = indexer @@ -360,6 +376,47 @@ impl KvRouter { kv_router_config.router_reset_states, ) .await?; + + // Perform startup recovery from workers with local indexers + // This catches up on any events missed while the router was offline + let last_event_ids = kv_indexer + .get_last_received_event_ids() + .await + .unwrap_or_default(); + let instances = client.instance_source.as_ref().borrow().clone(); + let worker_ids: Vec = instances.iter().map(|i| i.instance_id).collect(); + + if !worker_ids.is_empty() { + tracing::info!( + worker_count = worker_ids.len(), + "Starting recovery from workers with local indexers" + ); + + // NOTE: recover_from_all_workers() is a no-op if + // Worker with worker_id is not associated with a + // local indexer instance. 
+ let recovered = recover_from_all_workers( + &worker_query_client, + &last_event_ids, + &worker_ids, + &kv_indexer.event_sender(), + ) + .await; + + if recovered > 0 { + tracing::info!( + recovered_events = recovered, + "KV Router startup: Recovered {} KV events from workers {:?}", + recovered, + worker_ids + ); + } else { + tracing::info!( + "KV Router startup: No KV events recovered from workers {:?}", + worker_ids + ); + } + } } tracing::info!("KV Routing initialized"); @@ -370,6 +427,7 @@ impl KvRouter { kv_router_config, cancellation_token, client, + worker_query_client: Some(worker_query_client), }) } @@ -502,6 +560,62 @@ impl KvRouter { pub async fn dump_events(&self) -> Result, KvRouterError> { self.indexer.dump_events().await } + + /// Query a specific worker's local KV indexer for its events + /// (See docstring for `WorkerQueryClient.query_worker()`) + pub async fn query_worker_local_kv( + &self, + worker_id: WorkerId, + start_event_id: Option, + end_event_id: Option, + ) -> Result { + let query_client = self + .worker_query_client + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?; + + query_client + .query_worker(worker_id, start_event_id, end_event_id) + .await + } + + /// Recover missed KV events from a specific worker. + /// + /// Queries the worker's local KV indexer for events starting from + /// `start_event_id` and applies them to the router's indexer. + /// + /// # Arguments + /// + /// * `worker_id` - The worker to recover from + /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning + /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all + pub async fn recover_from_worker( + &self, + worker_id: WorkerId, + start_event_id: Option, + end_event_id: Option, + ) -> Result { + let query_client = self + .worker_query_client + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?; + + let event_tx = match &self.indexer { + Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(), + Indexer::None => { + anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)") + } + }; + + subscriber::recover_from_worker( + query_client, + worker_id, + start_event_id, + end_event_id, + &event_tx, + ) + .await + } } // NOTE: KVRouter works like a PushRouter, diff --git a/lib/llm/src/kv_router/indexer.rs b/lib/llm/src/kv_router/indexer.rs index 6349590403..37060b93f9 100644 --- a/lib/llm/src/kv_router/indexer.rs +++ b/lib/llm/src/kv_router/indexer.rs @@ -44,7 +44,7 @@ use std::{ collections::{HashMap, VecDeque}, iter, rc::Rc, - sync::{Arc, OnceLock}, + sync::{Arc, Mutex, OnceLock}, thread::JoinHandle, time::{Duration, Instant}, }; @@ -199,6 +199,31 @@ impl RouterEvent { } } +// ------- +// Distributed router - Worker KV Query types +// ------- + +/// Request to query a worker's local KV indexer. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct WorkerKvQueryRequest { + /// The worker ID of the worker to query. + pub worker_id: WorkerId, + + /// The query can specify the [start, end] range (inclusive) of event IDs to return. + /// If neither is specified, the worker dumps all events. + /// If only one bound is specified, the missing `start` defaults to the oldest logged event + /// and the missing `end` defaults to the newest logged event. + pub start_event_id: Option, + pub end_event_id: Option, +} + +/// Response from a worker's local KV indexer.
+#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct WorkerKvQueryResponse { + /// The events from the worker local KvIndexer. + pub events: Vec, +} + /// A block in the Radix Tree. #[derive(Debug)] struct RadixBlock { @@ -781,6 +806,13 @@ pub struct GetWorkersRequest { pub resp: oneshot::Sender>, } +/// A request to get the last received event ID per worker. +/// Used for fault tolerance recovery to determine which events to request from workers. +pub struct GetLastReceivedEventIdsRequest { + /// Channel to send the last received event IDs per worker + pub resp: oneshot::Sender>, +} + #[async_trait] pub trait KvIndexerInterface { /// Find matches for a given sequence of `LocalBlockHash`es. @@ -885,6 +917,8 @@ pub struct KvIndexer { dump_tx: mpsc::Sender, /// A sender for routing decision requests. routing_tx: mpsc::Sender, + /// A sender for getting last received event IDs (for fault tolerance recovery). + last_event_ids_tx: mpsc::Sender, /// A handle to the background task managing the KV store. task: OnceLock>, /// The size of the KV block this indexer can handle. @@ -918,6 +952,9 @@ impl KvIndexer { let (dump_tx, dump_rx) = mpsc::channel::(16); let (routing_tx, mut routing_rx) = mpsc::channel::(2048); let (prune_tx, mut prune_rx) = mpsc::channel::<()>(1); + let (last_event_ids_tx, mut last_event_ids_rx) = + mpsc::channel::(16); + let cancel_clone = token.clone(); let task = std::thread::spawn(move || { @@ -942,6 +979,10 @@ impl KvIndexer { }); let mut event_id_counter = 0u64; + // Track last received event ID per worker (for fault tolerance recovery) + // Only used when enable_event_tracking is true + let mut last_received_event_id: HashMap = HashMap::new(); + loop { // Create a future that sleeps until the next expiration time let expiry_fut = if let Some(ref pm) = prune_manager @@ -968,6 +1009,10 @@ impl KvIndexer { let _ = get_workers_req.resp.send(workers); } + Some(req) = last_event_ids_rx.recv() => { + let _ = req.resp.send(last_received_event_id.clone()); + } + Some(_) = prune_rx.recv() => { // Tree size-based pruning triggered let Some(ref mut pm) = prune_manager else { continue }; @@ -990,6 +1035,33 @@ impl KvIndexer { } Some(event) = event_rx.recv() => { + // Track last received event ID per worker + // Check for gaps before updating the last received ID + // TODO should this trigger a recovery event? + let last_id = *last_received_event_id.get(&event.worker_id).unwrap_or(&0); + let incoming_id = event.event.event_id; + + // Detect gap: if incoming ID is more than 1 greater than last received + if incoming_id > last_id + 1 && last_id > 0 { + let gap_start = last_id + 1; + let gap_end = incoming_id - 1; + tracing::warn!( + worker_id = event.worker_id, + gap_start, + gap_end, + gap_size = gap_end - gap_start + 1, + "Event ID gap detected! Missed events [{}, {}]. 
\ + If this is a global KvIndexer within a KvRouter context, + consider calling KvRouter::query_worker_local_kv() to potentially recover worker-stored events.", + gap_start, + gap_end, + ); + } + + // Update last received event ID (use max to handle out-of-order events) + let entry = last_received_event_id.entry(event.worker_id).or_insert(0); + *entry = (*entry).max(event.event.event_id); + let event_type = KvIndexerMetrics::get_event_type(&event.event.data); let result = trie.apply_event(event.clone()); let result_is_ok = result.is_ok(); @@ -1121,6 +1193,7 @@ impl KvIndexer { get_workers_tx, dump_tx, routing_tx, + last_event_ids_tx, task: once, kv_block_size, } @@ -1173,6 +1246,48 @@ impl KvIndexer { pub fn get_workers_sender(&self) -> mpsc::Sender { self.get_workers_tx.clone() } + + /// Get a sender for last received event IDs requests. + /// + /// ### Returns + /// + /// A `mpsc::Sender` for `GetLastReceivedEventIdsRequest`s. + pub fn last_event_ids_sender(&self) -> mpsc::Sender { + self.last_event_ids_tx.clone() + } + + /// Get the last received event ID for each worker. + /// + /// This method is used for **fault tolerance recovery** when the router needs to + /// catch up on missed events after a disconnect. By tracking the last event ID + /// received from each worker, the router can query workers for events starting + /// from `last_id + 1` to recover missed state. + /// + /// **Note**: This method is intended for the global `KvIndexer` used by routers, + /// not for `LocalKvIndexer` (worker-side) or `KvIndexerSharded`. + /// + /// ### Returns + /// + /// A `HashMap` mapping worker IDs to their last received event ID. + /// + pub async fn get_last_received_event_ids( + &self, + ) -> Result, KvRouterError> { + let (resp_tx, resp_rx) = oneshot::channel(); + let req = GetLastReceivedEventIdsRequest { resp: resp_tx }; + + if let Err(e) = self.last_event_ids_tx.send(req).await { + tracing::error!( + "Failed to send last event IDs request: {:?}; the indexer may be offline", + e + ); + return Err(KvRouterError::IndexerOffline); + } + + resp_rx + .await + .map_err(|_| KvRouterError::IndexerDroppedRequest) + } } #[async_trait] @@ -1285,6 +1400,574 @@ impl Drop for KvIndexer { } } +// ------------------------------------------------- +// Decentralized router: LocalKvIndexer for workers +// ------------------------------------------------- + +/// A thin wrapper around KvIndexer that buffers recent events +/// (e.g. so they can be queried by the router upon startup) +/// +pub struct LocalKvIndexer { + /// The underlying indexer + indexer: KvIndexer, + /// Circular buffer of recent events + event_buffer: Mutex>, + /// Maximum number of events to keep in buffer + max_buffer_size: usize, // Router sets this to WORKER_KV_INDEXER_BUFFER_SIZE +} + +impl LocalKvIndexer { + /// Create a new LocalKvIndexer wrapping a new KvIndexer. + pub fn new( + token: CancellationToken, + kv_block_size: u32, + metrics: Arc, + max_buffer_size: usize, + ) -> Self { + Self { + indexer: KvIndexer::new(token, kv_block_size, metrics), + event_buffer: Mutex::new(VecDeque::with_capacity(max_buffer_size)), + max_buffer_size, + } + } + + /// Get all buffered events (oldest first). + pub fn get_all_events_in_buffer(&self) -> Vec { + let buffer = self.event_buffer.lock().unwrap(); + buffer.iter().cloned().collect() + } + + /// Query events by ID range, returning events in `[start_id, end_id]` (both inclusive). + /// + /// This method attempts to serve the request from the in-memory event buffer when possible.
+ /// If the requested range extends beyond what's available in the buffer, a full tree dump + /// is performed instead. + /// + /// ### Arguments + /// + /// * `start_id` - Starting event ID (inclusive). If `None`, returns from oldest available. + /// * `end_id` - Ending event ID (inclusive). If `None`, returns up to newest available. + /// + /// ### Behavior + /// + /// - **Buffer path**: If `start_id >= first_buffered_id`, events are retrieved directly + /// from the buffer with their original event IDs. + /// + /// - **Tree dump path**: If the range extends before the buffer or no range is specified, + /// a full tree dump is performed. **Note**: Tree dumps generate synthetic 0-indexed + /// event IDs that do NOT correspond to the original event IDs. The entire tree state + /// is returned regardless of the requested range. + /// + /// ### Returns + /// + /// A vector of `RouterEvent`s. When served from buffer, events have their original IDs. + /// When served from tree dump, events have synthetic sequential IDs starting from 0. + pub async fn get_events_in_id_range( + &self, + start_id: Option, + end_id: Option, + ) -> Vec { + // Validate range if both specified + if let (Some(s), Some(e)) = (start_id, end_id) + && s > e + { + tracing::warn!( + start_id = s, + end_id = e, + "Requested start_id > end_id; returning empty result." + ); + return Vec::new(); + } + + // Check if we can serve from buffer + let buffer_range = { + let buffer = self.event_buffer.lock().unwrap(); + if buffer.is_empty() { + None + } else { + Some(( + buffer.front().unwrap().event.event_id, + buffer.back().unwrap().event.event_id, + )) + } + }; + + // Determine if request can be served from buffer + let can_use_buffer = match (start_id, buffer_range) { + // No start specified means we need everything from the beginning -> tree dump + (None, _) => false, + // Buffer is empty -> tree dump + (_, None) => false, + // start_id is within or after buffer range -> can use buffer + (Some(s), Some((first_buffered, _))) => s >= first_buffered, + }; + + if can_use_buffer { + // Serve from buffer - these have real event IDs + self.get_buffer_events_in_id_range(start_id, end_id) + } else { + // Must dump entire tree + if let (Some(s), Some(e)) = (start_id, end_id) { + tracing::warn!( + requested_start_id = s, + requested_end_id = e, + buffer_range = ?buffer_range, + "Requested event ID range extends before buffer; dumping entire tree. \ + Note: Tree dump returns synthetic 0-indexed event IDs, not original IDs." + ); + } else if start_id.is_some() || end_id.is_some() { + tracing::warn!( + requested_start_id = ?start_id, + requested_end_id = ?end_id, + buffer_range = ?buffer_range, + "Partial range specified but cannot serve from buffer; dumping entire tree. \ + Note: Tree dump returns synthetic 0-indexed event IDs, not original IDs." + ); + } + // Return full tree dump - no filtering since IDs are synthetic + self.dump_events().await.unwrap_or_default() + } + } + + /// Get events from the buffer in the range `[start_id, end_id]` (both inclusive). 
+ pub fn get_buffer_events_in_id_range( + &self, + start_id: Option, + end_id: Option, + ) -> Vec { + let buffer = self.event_buffer.lock().unwrap(); + if buffer.is_empty() { + tracing::warn!("No events in buffer yet; returning empty result."); + return Vec::new(); + } + + let first_id = buffer.front().map(|e| e.event.event_id).unwrap(); + let last_id = buffer.back().map(|e| e.event.event_id).unwrap(); + + let start_id = start_id.unwrap_or(first_id); + let end_id = end_id.unwrap_or(last_id); + + if start_id > end_id { + tracing::warn!( + start_id, + end_id, + "Requested start_id > end_id; returning empty result." + ); + return Vec::new(); + } + + let start_idx = match buffer.binary_search_by_key(&start_id, |e| e.event.event_id) { + Ok(idx) => idx, + Err(_) if start_id < first_id => { + tracing::warn!( + start_id, + first_id, + "Requested start_id precedes buffer; clamping to oldest." + ); + 0 + } + Err(_) if start_id > last_id => { + tracing::error!( + start_id, + last_id, + "Requested start_id is newer than buffer; returning empty." + ); + return Vec::new(); + } + Err(insertion_point) => insertion_point, + }; + + // For inclusive end, we need idx + 1 when we find an exact match + let end_idx = match buffer.binary_search_by_key(&end_id, |e| e.event.event_id) { + Ok(idx) => idx + 1, // Include the matched element + Err(_) if end_id < first_id => { + return Vec::new(); + } + Err(_) if end_id > last_id => { + tracing::warn!( + end_id, + last_id, + "Requested end_id exceeds buffer; clamping to newest." + ); + buffer.len() + } + Err(insertion_point) => insertion_point, + }; + + buffer + .iter() + .skip(start_idx) + .take(end_idx.saturating_sub(start_idx)) + .cloned() + .collect() + } + + /// Record an event in the buffer + fn record_event(&self, event: RouterEvent) { + let mut buffer = self.event_buffer.lock().unwrap(); + + // Check that event id is consecutive to last one + if let Some(last_event) = buffer.back() + && event.event.event_id != last_event.event.event_id + 1 + { + let expected = last_event.event.event_id + 1; + tracing::error!( + worker_id = event.worker_id, + expected, + got = event.event.event_id, + "Non-consecutive KV event id; buffer may have gaps" + ); + } + tracing::info!( + "Recorded event {:?} in buffer, now size is {}", + event, + buffer.len() + ); + + // Add to back + buffer.push_back(event); + + // Remove from front if over capacity (circular buffer behavior) + while buffer.len() > self.max_buffer_size { + buffer.pop_front(); + } + } + + /// Apply event with buffering. + /// + /// This records the event in the buffer and forwards it to the underlying indexer. + pub async fn apply_event_with_buffer(&self, event: RouterEvent) -> Result<(), KvRouterError> { + // Record in buffer + self.record_event(event.clone()); + + // Forward to underlying indexer + self.indexer + .event_sender() + .send(event) + .await + .map_err(|_| KvRouterError::IndexerOffline) + } + + /// Clear the event buffer. + pub fn clear_buffer(&self) { + let mut buffer = self.event_buffer.lock().unwrap(); + buffer.clear(); + } + + /// Get the current buffer size. + pub fn buffer_len(&self) -> usize { + let buffer = self.event_buffer.lock().unwrap(); + buffer.len() + } + + // Delegation methods to underlying KvIndexer + /// Get a sender for `RouterEvent`s. + pub fn event_sender(&self) -> mpsc::Sender { + self.indexer.event_sender() + } + + /// Get a sender for dump requests (snapshot events). 
+ pub fn snapshot_event_sender(&self) -> mpsc::Sender { + self.indexer.snapshot_event_sender() + } + + /// Get a sender for worker removal requests. + pub fn remove_worker_sender(&self) -> mpsc::Sender { + self.indexer.remove_worker_sender() + } + + /// Get a sender for get workers requests. + pub fn get_workers_sender(&self) -> mpsc::Sender { + self.indexer.get_workers_sender() + } + + /// Get the KV block size. + pub fn block_size(&self) -> u32 { + self.indexer.block_size() + } +} + +#[cfg(test)] +mod local_kv_indexer_tests { + use super::*; + + fn make_indexer_with_events(ids: &[u64]) -> LocalKvIndexer { + let indexer = LocalKvIndexer::new( + CancellationToken::new(), + 4, + Arc::new(KvIndexerMetrics::new_unregistered()), + 32, + ); + { + let mut buffer = indexer.event_buffer.lock().unwrap(); + for &id in ids { + buffer.push_back(RouterEvent::new( + 0, + KvCacheEvent { + event_id: id, + data: KvCacheEventData::Cleared, + dp_rank: 0, + }, + )); + } + } + indexer + } + + #[test] + fn returns_slice_within_range() { + let indexer = make_indexer_with_events(&[1, 2, 3, 4, 5]); + + // Test get_buffer_events_in_id_range (buffer-only queries) + // Range is [start, end] inclusive + let mut result = indexer.get_buffer_events_in_id_range(Some(2), Some(4)); + let mut ids: Vec = result + .iter() + .map(|router_event| router_event.event.event_id) + .collect(); + assert_eq!(ids, vec![2, 3, 4]); // inclusive range [2, 4] + + result = indexer.get_buffer_events_in_id_range(Some(2), Some(6)); + ids = result + .iter() + .map(|router_event| router_event.event.event_id) + .collect(); + assert_eq!(ids, vec![2, 3, 4, 5]); // clamp end to buffer max + + result = indexer.get_buffer_events_in_id_range(Some(0), Some(4)); + ids = result + .iter() + .map(|router_event| router_event.event.event_id) + .collect(); + assert_eq!(ids, vec![1, 2, 3, 4]); // clamp start to buffer min, inclusive end + + result = indexer.get_buffer_events_in_id_range(Some(3), Some(3)); + ids = result + .iter() + .map(|router_event| router_event.event.event_id) + .collect(); + assert_eq!(ids, vec![3]); // single element when start == end + + result = indexer.get_buffer_events_in_id_range(Some(5), Some(2)); + ids = result + .iter() + .map(|router_event| router_event.event.event_id) + .collect(); + assert!(ids.is_empty()); // return empty when start > end + } + + #[tokio::test] + async fn test_get_events_in_id_range_all_cases() { + use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash}; + + // Create indexer with small buffer (5 events max) + // This way older events will only be in the tree, not the buffer + let indexer = LocalKvIndexer::new( + CancellationToken::new(), + 4, // block_size + Arc::new(KvIndexerMetrics::new_unregistered()), + 5, // max_buffer_size - only keeps 5 most recent events + ); + + // Helper to create a test event + let make_event = |id: u64| { + RouterEvent::new( + 0, // worker_id + KvCacheEvent { + event_id: id, + data: KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: None, + blocks: vec![KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(id * 100), + tokens_hash: LocalBlockHash(id * 200), + }], + }), + dp_rank: 0, + }, + ) + }; + + // Add 10 events (IDs 5-14) + // Buffer will only keep the last 5: events 10-14 + // Tree will have all blocks + for id in 5..15 { + indexer + .apply_event_with_buffer(make_event(id)) + .await + .unwrap(); + } + + // Wait for events to be processed by the tree + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Helper to 
extract event IDs from result + let get_ids = |events: Vec| -> Vec { + events.iter().map(|e| e.event.event_id).collect() + }; + + // Verify buffer state: should have events 10-14 (last 5) + let buffer_events = indexer.get_all_events_in_buffer(); + assert_eq!( + get_ids(buffer_events), + vec![10, 11, 12, 13, 14], + "Buffer should have events 10-14" + ); + + // ========== BUFFER PATH TESTS (start_id >= first_buffered) ========== + // Range is [start, end] inclusive + + // Test: start_id within buffer, no end + let result = indexer.get_events_in_id_range(Some(11), None).await; + assert_eq!( + get_ids(result), + vec![11, 12, 13, 14], + "start_id=11 (in buffer) should return [11, 14]" + ); + + // Test: start_id at buffer boundary + let result = indexer.get_events_in_id_range(Some(10), None).await; + assert_eq!( + get_ids(result), + vec![10, 11, 12, 13, 14], + "start_id=10 (buffer start) should return [10, 14]" + ); + + // Test: both start and end within buffer (inclusive) + let result = indexer.get_events_in_id_range(Some(11), Some(13)).await; + assert_eq!( + get_ids(result), + vec![11, 12, 13], + "range [11, 13] inclusive should return 3 events" + ); + + let result = indexer.get_events_in_id_range(Some(10), Some(14)).await; + assert_eq!( + get_ids(result), + vec![10, 11, 12, 13, 14], + "range [10, 14] should return all buffer events" + ); + + // ========== TREE DUMP PATH TESTS (range extends before buffer) ========== + // Note: Tree dumps return synthetic 0-indexed event IDs, so we just check + // that we get events back (the IDs won't match original IDs) + + // Test: (None, None) dumps entire tree + let result = indexer.get_events_in_id_range(None, None).await; + assert_eq!( + result.len(), + 10, + "(None, None) should dump entire tree (10 events)" + ); + + // Test: (None, Some(_)) dumps entire tree + let result = indexer.get_events_in_id_range(None, Some(8)).await; + assert_eq!( + result.len(), + 10, + "(None, Some(_)) dumps entire tree - end_id is ignored for tree dumps" + ); + + // Test: start_id before buffer triggers tree dump + let result = indexer.get_events_in_id_range(Some(7), None).await; + assert_eq!( + result.len(), + 10, + "start_id=7 (before buffer) should dump entire tree" + ); + + let result = indexer.get_events_in_id_range(Some(5), Some(12)).await; + assert_eq!( + result.len(), + 10, + "range [5, 12] extending before buffer should dump entire tree" + ); + + // ========== EDGE CASES ========== + + // Single element when start == end (inclusive range) + let result = indexer.get_events_in_id_range(Some(12), Some(12)).await; + assert_eq!( + get_ids(result), + vec![12], + "start == end should return single event" + ); + + // Empty when start > end + let result = indexer.get_events_in_id_range(Some(15), Some(10)).await; + assert!(result.is_empty(), "start > end should return empty"); + + // Request beyond buffer but valid range -> buffer returns what it has + let result = indexer.get_events_in_id_range(Some(12), Some(100)).await; + assert_eq!( + get_ids(result), + vec![12, 13, 14], + "range with end beyond buffer should return available buffer events" + ); + } +} + +// Implement KvIndexerInterface by delegating to the underlying indexer +#[async_trait] +impl KvIndexerInterface for LocalKvIndexer { + async fn find_matches( + &self, + sequence: Vec, + ) -> Result { + self.indexer.find_matches(sequence).await + } + + async fn find_matches_for_request( + &self, + tokens: &[u32], + ) -> Result { + self.indexer.find_matches_for_request(tokens).await + } + + async fn apply_event(&mut 
self, event: RouterEvent) { + // Use the buffering version + let _ = self.apply_event_with_buffer(event).await; + } + + async fn remove_worker(&mut self, worker: WorkerId) { + let _ = self.indexer.remove_worker_sender().send(worker).await; + } + + fn shutdown(&mut self) { + // Note: Since indexer is Arc, we can't call mutable methods directly. + // The indexer will be shut down when the CancellationToken is cancelled + // or when the last Arc reference is dropped. + } + + async fn dump_events(&self) -> Result, KvRouterError> { + self.indexer.dump_events().await + } + + async fn process_routing_decision( + &self, + worker: WorkerWithDpRank, + local_hashes: Vec, + sequence_hashes: Vec, + ) -> Result<(), KvRouterError> { + // TODO I guess the local kvindexers have little use for this method? + // Keeping it here now to implement the trait fully + self.indexer + .process_routing_decision(worker, local_hashes, sequence_hashes) + .await + } + + async fn process_routing_decision_for_request( + &self, + tokens: &[u32], + worker: WorkerWithDpRank, + ) -> Result<(), KvRouterError> { + // TODO I guess the local kvindexers have little use for this method? + // Keeping it here now to implement the trait fully + self.indexer + .process_routing_decision_for_request(tokens, worker) + .await + } +} + #[derive(Debug, Clone)] pub struct ShardedMatchRequest { sequence: Vec, @@ -2978,3 +3661,158 @@ mod tests { assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_2))); } } + +#[cfg(test)] +mod tests_local_indexer { + use super::*; + use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash}; + use tokio::time; + use tokio_util::sync::CancellationToken; + + fn setup() { + dynamo_runtime::logging::init(); + } + + fn make_blocks(hashes: Vec) -> Vec { + hashes + .iter() + .map(|i| KvCacheStoredBlockData { + tokens_hash: LocalBlockHash(*i), + block_hash: ExternalSequenceBlockHash(*i * 100), + }) + .collect() + } + + fn create_store_event( + worker_id: WorkerId, + event_id: u64, + hashes: Vec, + parent: Option, + ) -> RouterEvent { + RouterEvent { + worker_id, + event: KvCacheEvent { + event_id, + data: KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: parent, + blocks: make_blocks(hashes), + }), + dp_rank: 0, + }, + } + } + + #[tokio::test] + async fn test_local_indexer_buffer_and_serialization() { + // Tests components of the LocalKvIndexer query without using nats + + let worker_id = 42u64; + + // Create a local indexer + let token = CancellationToken::new(); + let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); + let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100)); + + // Add events to local indexer's buffer + let test_event_1 = RouterEvent::new( + worker_id, + KvCacheEvent { + event_id: 1, + data: KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: None, + blocks: vec![KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(100), + tokens_hash: LocalBlockHash(200), + }], + }), + dp_rank: 0, + }, + ); + + // Apply events with buffer + local_indexer + .apply_event_with_buffer(test_event_1) + .await + .unwrap(); + + // Wait for events to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + // Get buffered events (what the query service would return) + let buffered_events = local_indexer.get_all_events_in_buffer(); + + // Verify buffer contents + assert_eq!(buffered_events.len(), 1, "Buffer should have 1 event"); + assert_eq!(buffered_events[0].worker_id, worker_id); + 
assert_eq!(buffered_events[0].event.event_id, 1); + + // Build the response that would be sent + let response = WorkerKvQueryResponse { + events: buffered_events.clone(), + }; + + // Test serialization/deserialization (simulating NATS round-trip) + let serialized = serde_json::to_vec(&response).unwrap(); + let deserialized: WorkerKvQueryResponse = serde_json::from_slice(&serialized).unwrap(); + + // Verify response correctness + assert_eq!(deserialized.events.len(), 1); + assert_eq!(deserialized.events[0].worker_id, worker_id); + assert_eq!(deserialized.events[0].event.event_id, 1); + + // Verify event data + match &deserialized.events[0].event.data { + KvCacheEventData::Stored(store_data) => { + assert_eq!(store_data.blocks.len(), 1); + assert_eq!(store_data.blocks[0].block_hash.0, 100); + assert_eq!(store_data.blocks[0].tokens_hash.0, 200); + } + _ => panic!("Expected Stored event"), + } + } + + #[tokio::test] + async fn test_gap_detection_per_worker() { + setup(); + + let token = CancellationToken::new(); + let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); + let indexer = KvIndexer::new(token.clone(), 4, metrics); + + let worker_a: WorkerId = 100; + let worker_b: WorkerId = 200; + let event_tx = indexer.event_sender(); + + // Worker A: events 1, 2, 3 (no gap) + for id in 1..=3 { + let event = create_store_event(worker_a, id, vec![id], None); + event_tx.send(event).await.unwrap(); + } + + // Worker B: events 1, then 5 (gap of 2, 3, 4) + let event_b1 = create_store_event(worker_b, 1, vec![10], None); + event_tx.send(event_b1).await.unwrap(); + + let event_b5 = create_store_event(worker_b, 5, vec![50], None); + event_tx.send(event_b5).await.unwrap(); + + // Give time for events to be processed + time::sleep(Duration::from_millis(20)).await; + + // Verify each worker has correct last_received_event_id + let last_ids = indexer.get_last_received_event_ids().await.unwrap(); + assert_eq!( + last_ids.get(&worker_a), + Some(&3), + "Worker A should have last_id = 3 (no gap)" + ); + assert_eq!( + last_ids.get(&worker_b), + Some(&5), + "Worker B should have last_id = 5 (despite gap)" + ); + + // Cleanup + token.cancel(); + } +} diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index fca3783fe6..b6f80a49a0 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -315,6 +315,9 @@ impl<'de> Deserialize<'de> for ExternalSequenceBlockHash { } } +// ------ +// Tests +// ------ #[cfg(test)] mod tests { use super::*; diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 703afcf867..0eb695610e 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -16,15 +16,22 @@ use tokio_util::sync::CancellationToken; use zeromq::{Socket, SocketRecv, SubSocket}; use dynamo_runtime::metrics::{MetricsHierarchy, prometheus_names::kvstats}; -use dynamo_runtime::traits::{DistributedRuntimeProvider, events::EventPublisher}; +use dynamo_runtime::traits::{ + DistributedRuntimeProvider, events::EventPublisher, events::EventSubscriber, +}; use dynamo_runtime::{ component::{Component, Namespace}, transports::nats::{NatsQueue, QUEUE_NAME, Slug}, }; +use futures::StreamExt; use crate::kv_router::{ - KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, - indexer::{RouterEvent, compute_block_hash_for_seq}, + KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE, + WORKER_KV_INDEXER_QUERY_SUBJECT, + indexer::{ + KvIndexerInterface, KvIndexerMetrics, LocalKvIndexer, RouterEvent, 
WorkerKvQueryRequest, + WorkerKvQueryResponse, compute_block_hash_for_seq, + }, protocols::*, scoring::LoadEvent, }; @@ -102,6 +109,15 @@ impl KvEventPublisher { component: Component, kv_block_size: u32, source_config: Option, + ) -> Result { + Self::new_with_local_indexer(component, kv_block_size, source_config, false) + } + + pub fn new_with_local_indexer( + component: Component, + kv_block_size: u32, + source_config: Option, + enable_local_indexer: bool, ) -> Result { let cancellation_token = CancellationToken::new(); @@ -110,6 +126,18 @@ impl KvEventPublisher { // Infer worker_id from component's connection let worker_id = component.drt().connection_id(); + tracing::info!( + worker_id, + component = component.name(), + "Initializing KvEventPublisher for worker {worker_id} in component {component}" + ); + + if enable_local_indexer { + tracing::info!( + "LocalKvIndexer enabled for worker {worker_id} in component {component}" + ); + } + // Create our event source (if any) let mut source = None; if let Some(config) = source_config { @@ -122,6 +150,36 @@ impl KvEventPublisher { )?); } + // Create local indexer if requested + let local_indexer = if enable_local_indexer { + let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); + Some(Arc::new(LocalKvIndexer::new( + cancellation_token.clone(), + kv_block_size, + metrics, + WORKER_KV_INDEXER_BUFFER_SIZE, + ))) + } else { + None + }; + + // Spawn runtime for router->local indexer comm if requested + let _local_indexer_query_handle = local_indexer.as_ref().map(|local_indexer_ref| { + let component = component.clone(); + let local_indexer = local_indexer_ref.clone(); + + component + .drt() + .runtime() + .secondary() + .spawn(start_worker_kv_query_service( + component, + worker_id, + local_indexer, + cancellation_token.clone(), + )) + }); + let stream_name = Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT)) .to_string() .replace("_", "-"); @@ -136,12 +194,20 @@ impl KvEventPublisher { // Connect the NatsQueue before passing it to the event processor let cancellation_token_clone = cancellation_token.clone(); + let local_indexer_clone = local_indexer.clone(); component.drt().runtime().secondary().spawn(async move { if let Err(e) = nats_queue.connect().await { tracing::error!("Failed to connect NatsQueue: {}", e); return; } - start_event_processor(nats_queue, worker_id, cancellation_token_clone, rx).await + start_event_processor( + nats_queue, + worker_id, + cancellation_token_clone, + rx, + local_indexer_clone, + ) + .await }); Ok(Self { @@ -182,6 +248,7 @@ async fn start_event_processor( worker_id: u64, cancellation_token: CancellationToken, mut rx: mpsc::UnboundedReceiver, + local_indexer: Option>, ) { loop { tokio::select! { @@ -195,17 +262,129 @@ async fn start_event_processor( break; }; - // Encapsulate in a router event and publish. + // Encapsulate in a router event. 
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data); let router_event = RouterEvent::new(worker_id, event); + + // Apply to local indexer first (if present) + if let Some(indexer) = &local_indexer { + // Adds event into local indexer, and logs it into internal buffer + if let Err(e) = indexer.apply_event_with_buffer(router_event.clone()).await { + tracing::warn!( + "Failed to send event to local indexer for worker {}: {}", + worker_id, + e + ); + } + } + + // Then publish to NATS for global distribution if let Err(e) = publisher.publish(QUEUE_NAME, &router_event).await { - tracing::error!("Failed to publish event: {}", e); + tracing::error!("Failed to publish event to NATS: {}", e); } + } } } } +// Processor for Router -> LocalKvIndexer query service +async fn start_worker_kv_query_service( + component: Component, + worker_id: u64, + local_indexer: Arc, + cancellation_token: CancellationToken, +) { + // Create NATS subscriber on a subject specific to worker's id + let subject = format!("{}.{}", WORKER_KV_INDEXER_QUERY_SUBJECT, worker_id); + let mut subscriber = match component.subscribe(&subject).await { + Ok(sub) => sub, + Err(e) => { + tracing::error!("Failed to subscribe to {}: {}", subject, e); + return; // No ? because function doesn't return Result + } + }; + tracing::debug!( + "Query service on worker {} listening on NATS subject: {}", + worker_id, + subject + ); + + // Receive query request from router, retrieve event(s) from LocalKvIndexer, return response + loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + tracing::info!("Router-Worker communication channel received cancellation signal"); + break; + } + + msg = subscriber.next() => { + let Some(msg) = msg else { + tracing::debug!("Router-Worker stream ended."); + break; + }; + + // deserialize from msg (async_nats::Message) + let request: WorkerKvQueryRequest = match serde_json::from_slice(&msg.payload) { + Ok(request) => request, + Err(e) => { + tracing::error!("Failed to deserialize WorkerKvQueryRequest: {}", e); + continue; + } + }; + + // TODO extract request event id range. 
For now, just debug print + tracing::debug!("Received WorkerKvQueryRequest: {:?}", request); + + // Resolve which events to return based on optional start/end ids + let events = match (request.start_event_id, request.end_event_id) { + (None, None) => { + match local_indexer.dump_events().await { + Ok(events) => events, + Err(err) => { + tracing::error!( + error = %err, + worker_id, + "Failed to dump events for WorkerKvQueryRequest; returning buffered events instead" + ); + local_indexer.get_all_events_in_buffer() + } + } + } + _ => { + local_indexer.get_events_in_id_range(request.start_event_id, request.end_event_id).await + } + }; + + // Build WorkerKvQueryResponse + let response = WorkerKvQueryResponse { events }; + + // Send reply back (if reply subject exists) + if let Some(reply_subject) = msg.reply { + let payload = match serde_json::to_vec(&response) { + Ok(p) => p, + Err(e) => { + tracing::error!("Failed to serialize response: {}", e); + continue; + } + }; + + // Publish through DRT/NATS directly instead of namespace (adds a prefix) + if let Err(e) = component + .drt() + .kv_router_nats_publish(reply_subject.to_string(), payload.into()) + .await + { + tracing::error!("Failed to send reply: {}", e); + } + } + + } + + } + } +} + // Error handling configuration for ZMQ operations const INITIAL_BACKOFF_MS: u64 = 10; const MAX_BACKOFF_MS: u64 = 5000; @@ -1008,7 +1187,9 @@ mod test_event_processing { #[cfg(test)] mod tests_startup_helpers { use super::*; - use crate::kv_router::protocols::ExternalSequenceBlockHash; + use crate::kv_router::KvIndexer; + use crate::kv_router::indexer::KvIndexerInterface; + use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash}; use async_trait; use bytes::Bytes; use std::sync::{Arc, Mutex}; @@ -1089,7 +1270,7 @@ mod tests_startup_helpers { tx.send(event).unwrap(); drop(tx); - let handle = tokio::spawn(start_event_processor(component, 1, token, rx)); + let handle = tokio::spawn(start_event_processor(component, 1, token, rx, None)); tokio::time::timeout(tokio::time::Duration::from_secs(1), handle) .await @@ -1102,6 +1283,300 @@ mod tests_startup_helpers { assert_eq!(subject, QUEUE_NAME); } + //-------------------------------------------------------------------- + // Test start_event_processor with local indexer + //-------------------------------------------------------------------- + #[tokio::test] + async fn test_start_event_processor_with_local_indexer() { + let (component, published) = MockComponent::new(); + + // Create a local indexer + let token = CancellationToken::new(); + let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); + let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100)); + + // Create BlockStored event + let event = KvCacheEvent { + event_id: 1, + data: KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: None, + blocks: vec![ + KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(100), + tokens_hash: LocalBlockHash(200), + }, + KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(101), + tokens_hash: LocalBlockHash(201), + }, + ], + }), + dp_rank: 0, + }; + + let (tx, rx) = mpsc::unbounded_channel::(); + tx.send(event).unwrap(); + drop(tx); + + // Start event processor with local indexer + let handle = tokio::spawn(start_event_processor( + component, + 1, + token.clone(), + rx, + Some(local_indexer.clone()), // arc::clone just increments atomic counters + )); + + // Wait for processing + tokio::time::timeout(tokio::time::Duration::from_secs(1), handle) 
+ .await + .unwrap() + .unwrap(); + + // Verify event was published to NATS (same as test_start_event_processor) + { + let published_events = published.lock().unwrap(); + assert_eq!(published_events.len(), 1); + let (subject, _) = &published_events[0]; + assert_eq!(subject, QUEUE_NAME); + } // drop lock + + // Verify event was applied to local indexer + // We can check by querying the workers that have blocks + let get_workers_tx = local_indexer.get_workers_sender(); + let mut found = false; + for _ in 0..20 { + // Try up to 20 times (200ms total) + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); + get_workers_tx + .send(crate::kv_router::indexer::GetWorkersRequest { resp: resp_tx }) + .await + .unwrap(); + let workers: Vec = resp_rx.await.unwrap(); + + if workers.contains(&1) { + found = true; + break; + } + + // Wait before retrying + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + } + + // Worker 1 should be in the set (we used worker_id=1) + assert!( + found, + "Worker 1 was not found in the indexer after processing" + ); + + // Cleanup + token.cancel(); + } + + //-------------------------------------------------------------------- + // Test BlockRemoved event with local indexer + //-------------------------------------------------------------------- + #[tokio::test] + async fn test_event_processor_block_removed_with_local_indexer() { + let (component, published) = MockComponent::new(); + + let token = CancellationToken::new(); + let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); + let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100)); + + // First, store a block + let store_event = KvCacheEvent { + event_id: 1, + data: KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: None, + blocks: vec![KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(100), + tokens_hash: LocalBlockHash(200), + }], + }), + dp_rank: 0, + }; + + let (tx, rx) = mpsc::unbounded_channel::(); + tx.send(store_event).unwrap(); + + // Start event processor with local indexer + let handle = tokio::spawn(start_event_processor( + component, + 1, + token.clone(), + rx, + Some(local_indexer.clone()), + )); + + // Then remove the stored block + let remove_event = KvCacheEvent { + event_id: 2, + data: KvCacheEventData::Removed(KvCacheRemoveData { + block_hashes: vec![ExternalSequenceBlockHash(100)], + }), + dp_rank: 0, + }; + tx.send(remove_event).unwrap(); + drop(tx); + + tokio::time::timeout(tokio::time::Duration::from_secs(1), handle) + .await + .unwrap() + .unwrap(); + + // Local indexer should have no block + let mut no_blocks = false; + for _ in 0..20 { + // Try up to 20 times (200ms total) + let scores = local_indexer + .find_matches(vec![LocalBlockHash(200)]) + .await + .unwrap(); + if scores.scores.is_empty() { + no_blocks = true; + break; + } + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + } + assert!(no_blocks, "worker should have no blocks after removal"); + + // Global kvindexer should have received two events (create/remove) + let published = published.lock().unwrap(); + assert_eq!( + published.len(), + 2, + "expected 2 published events, found {}", + published.len() + ); + + token.cancel(); + } + + //-------------------------------------------------------------------- + // Test AllBlocksCleared event with local indexer + //-------------------------------------------------------------------- + #[tokio::test] + async fn test_event_processor_all_blocks_cleared_with_local_indexer() { + let (component, published) = 
MockComponent::new(); + + let token = CancellationToken::new(); + let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); + let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100)); + + // Store a block + let store_event = KvCacheEvent { + event_id: 1, + data: KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: None, + blocks: vec![KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(100), + tokens_hash: LocalBlockHash(200), + }], + }), + dp_rank: 0, + }; + + let (tx, rx) = mpsc::unbounded_channel::(); + tx.send(store_event).unwrap(); + + // Clear all blocks + let clear_event = KvCacheEvent { + event_id: 2, + data: KvCacheEventData::Cleared, + dp_rank: 0, + }; + tx.send(clear_event).unwrap(); + drop(tx); + + // Create event processor and wait + let handle = tokio::spawn(start_event_processor( + component, + 1, + token.clone(), + rx, + Some(local_indexer.clone()), + )); + + tokio::time::timeout(tokio::time::Duration::from_secs(1), handle) + .await + .unwrap() + .unwrap(); + + // Local indexer should have no block + let mut no_blocks = false; + for _ in 0..20 { + // Try up to 20 times (200ms total) + let scores = local_indexer + .find_matches(vec![LocalBlockHash(200)]) + .await + .unwrap(); + if scores.scores.is_empty() { + no_blocks = true; + break; + } + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + } + assert!(no_blocks, "worker should have no blocks after clearing"); + + // Global kvindexer should have recieved two events (create/remove) + let published = published.lock().unwrap(); + assert_eq!( + published.len(), + 2, + "expected 2 published events, found {}", + published.len() + ); + + token.cancel(); + } + + //-------------------------------------------------------------------- + // Test that local indexer failure doesn't break NATS publishing + //-------------------------------------------------------------------- + #[tokio::test] + async fn test_event_processor_local_indexer_failure_continues() { + let (component, published) = MockComponent::new(); + + let token = CancellationToken::new(); + let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); + let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100)); + + // cancel indexer immediately to simulate failure + token.cancel(); + + let event = KvCacheEvent { + event_id: 1, + data: KvCacheEventData::Removed(KvCacheRemoveData { + block_hashes: vec![ExternalSequenceBlockHash(1)], + }), + dp_rank: 0, + }; + + let new_token = CancellationToken::new(); + let (tx, rx) = mpsc::unbounded_channel::(); + tx.send(event).unwrap(); + drop(tx); + + // Despite local indexer being cancelled, event processor should continue + let handle = tokio::spawn(start_event_processor( + component, + 1, + new_token, + rx, + Some(local_indexer), + )); + + tokio::time::timeout(tokio::time::Duration::from_secs(1), handle) + .await + .unwrap() + .unwrap(); + + // Verify event was still published to NATS despite local indexer failure + let published_events = published.lock().unwrap(); + assert_eq!(published_events.len(), 1); + } + //-------------------------------------------------------------------- // Test start_zmq_listener without a real socket // (feed it frames through a ZMQ PAIR tcp socket) @@ -1185,6 +1660,215 @@ mod tests_startup_helpers { token.cancel(); let _ = listener_handle.await; } + + //-------------------------------------------------------------------- + // Test distributed recovery: Router queries worker's LocalKvIndexer after outage + 
//-------------------------------------------------------------------- + #[tokio::test] + async fn test_distributed_kvindexer_recovery_from_outage() { + let worker_1_id = 1u64; + let block_size = 4u32; + let token = CancellationToken::new(); + + // === SETUP: Worker Components === + let (worker_component, worker_published) = MockComponent::new(); + let local_indexer_1 = Arc::new(LocalKvIndexer::new( + token.clone(), + block_size, + Arc::new(KvIndexerMetrics::new_unregistered()), + 100, // buffer size + )); + + let (worker_tx, worker_rx) = mpsc::unbounded_channel::(); + + // Start worker's event processor + tokio::spawn(start_event_processor( + worker_component, + worker_1_id, + token.clone(), + worker_rx, + Some(local_indexer_1.clone()), + )); + + // === SETUP: Router Components === + let router_indexer = Arc::new(KvIndexer::new( + token.clone(), + block_size, + Arc::new(KvIndexerMetrics::new_unregistered()), + )); + + // === STEP 1: Normal Operation === + let event_1 = KvCacheEvent { + event_id: 1, + data: KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: None, + blocks: vec![ + KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(100), + tokens_hash: LocalBlockHash(200), + }, + KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(101), + tokens_hash: LocalBlockHash(201), + }, + ], + }), + dp_rank: 0, + }; + + worker_tx.send(event_1.clone()).unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Simulate JetStream: forward worker's published event to router + let (subject, bytes) = { + let published = worker_published.lock().unwrap(); + assert_eq!(published.len(), 1, "Worker should have published 1 event"); + (published[0].0.clone(), published[0].1.clone()) + }; // drop worker_published before await + assert_eq!(subject, QUEUE_NAME); + + let router_event: RouterEvent = rmp_serde::from_slice(&bytes).unwrap(); + router_indexer + .event_sender() + .send(router_event) + .await + .unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // assert: Router's indexer has event + let get_workers_tx = router_indexer.get_workers_sender(); + let mut router_has_worker = false; + for _ in 0..20 { + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); + get_workers_tx + .send(crate::kv_router::indexer::GetWorkersRequest { resp: resp_tx }) + .await + .unwrap(); + let workers: Vec = resp_rx.await.unwrap(); + if workers.contains(&worker_1_id) { + router_has_worker = true; + break; + } + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + } + assert!( + router_has_worker, + "Router should see worker 1 after normal operation" + ); + + // assert: Worker's local indexer buffered event + let buffered = local_indexer_1.get_all_events_in_buffer(); + assert_eq!(buffered.len(), 1, "Local indexer should buffer 1 event"); + + // === STEP 2 & 3: Simulate Outage - Stop forwarding to router === + let event_2 = KvCacheEvent { + event_id: 2, + data: KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: None, + blocks: vec![ + KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(100), // Shared prefix + tokens_hash: LocalBlockHash(200), + }, + KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(102), // New block + tokens_hash: LocalBlockHash(202), + }, + ], + }), + dp_rank: 0, + }; + + worker_tx.send(event_2.clone()).unwrap(); // send to worker but not to router + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // assert: Worker published event_2 to "NATS" 
(MockComponent) + { + let published = worker_published.lock().unwrap(); + assert_eq!( + published.len(), + 2, + "Worker should have published 2 events total" + ); + } + + // assert: Worker's local indexer has both events + let buffered = local_indexer_1.get_all_events_in_buffer(); + assert_eq!( + buffered.len(), + 2, + "Local indexer should have both events during outage" + ); + + // assert: Router DOESN'T have event_2 + let block_hashes_2 = vec![LocalBlockHash(200), LocalBlockHash(202)]; + let overlap = router_indexer + .find_matches(block_hashes_2.clone()) + .await + .unwrap(); + let router_overlap = overlap + .scores + .get(&crate::kv_router::protocols::WorkerWithDpRank::from_worker_id(worker_1_id)) + .copied() + .unwrap_or(0); + assert_eq!( + router_overlap, 1, + "Router should only see 1 shared block (not the new block from event_2)" + ); + + // === STEP 4 & 5: Recovery - Query last received event IDs and fetch missed events === + // Step 4a: Router queries its last received event ID per worker + let last_ids = router_indexer.get_last_received_event_ids().await.unwrap(); + let last_known_id = last_ids.get(&worker_1_id).copied().unwrap_or(0); + assert_eq!( + last_known_id, 1, + "Router should have last_received_event_id = 1 for worker (only event_1 was forwarded)" + ); + + // Step 4b: Query worker's local indexer for events after last_known_id + let missed_events = local_indexer_1 + .get_events_in_id_range(Some(last_known_id + 1), None) + .await; + assert_eq!( + missed_events.len(), + 1, + "Should get 1 missed event (event_2 with id=2)" + ); + + // Step 5: Apply missed events to router + for router_event in missed_events { + router_indexer + .event_sender() + .send(router_event) + .await + .unwrap(); + } + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // assert: Router now has complete state + let overlap = router_indexer.find_matches(block_hashes_2).await.unwrap(); + let router_overlap_after = overlap + .scores + .get(&crate::kv_router::protocols::WorkerWithDpRank::from_worker_id(worker_1_id)) + .copied() + .unwrap_or(0); + assert_eq!( + router_overlap_after, 2, + "Router should now see both blocks after recovery" + ); + + // assert: Router's last_received_event_id is updated after recovery + let last_ids_after = router_indexer.get_last_received_event_ids().await.unwrap(); + assert_eq!( + last_ids_after.get(&worker_1_id), + Some(&2), + "Router should have last_received_event_id = 2 after recovery" + ); + + token.cancel(); + } } #[cfg(test)] @@ -1430,3 +2114,402 @@ mod test_integration_publisher { ); } } + +#[cfg(all(test, feature = "integration"))] +mod test_integration_publisher_with_kvindexer { + use super::*; + + use crate::kv_router::scheduler::DefaultWorkerSelector; + use crate::kv_router::{KvPushRouter, KvRouter, KvRouterConfig}; + use crate::local_model::LocalModelBuilder; + use crate::local_model::runtime_config::ModelRuntimeConfig; + use crate::mocker::engine::{MOCKER_COMPONENT, MockVllmEngine}; + use crate::mocker::protocols::MockEngineArgs; + use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}; + use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; + use dynamo_runtime::distributed_test_utils::create_test_shared_drt_async; + use dynamo_runtime::engine::AsyncEngine; + use dynamo_runtime::pipeline::{Context, PushRouter, RouterMode, network::Ingress}; + use dynamo_runtime::protocols::annotated::Annotated; + + /// Integration test: KvPushRouter end-to-end routing with mock engines. 
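The outage-recovery unit test above reduces to a single rule: replay every buffered event whose id is greater than the last id the router acknowledged for that worker, or the whole buffer if the router has never heard from that worker. A minimal standalone sketch of that selection step, using simplified stand-in types rather than the crate's RouterEvent buffer (names here are illustrative only):

    // Stand-in for the (event_id, event) pairs held by a worker-local buffer.
    #[derive(Clone)]
    struct BufferedEvent {
        event_id: u64,
        payload: String,
    }

    /// Everything newer than what the router last acknowledged; `None` means the
    /// router has no state for this worker, so the whole buffer is replayed.
    fn events_to_replay(buffer: &[BufferedEvent], last_received: Option<u64>) -> Vec<BufferedEvent> {
        buffer
            .iter()
            .filter(|e| last_received.map_or(true, |last| e.event_id > last))
            .cloned()
            .collect()
    }

    fn main() {
        let buffer = vec![
            BufferedEvent { event_id: 1, payload: "store blocks 100, 101".into() },
            BufferedEvent { event_id: 2, payload: "store blocks 100, 102".into() },
        ];

        // The router saw event 1 before the outage, so only event 2 is missing.
        let missed = events_to_replay(&buffer, Some(1));
        assert_eq!(missed.len(), 1);
        assert_eq!(missed[0].event_id, 2);

        // A router with no record of this worker replays everything.
        assert_eq!(events_to_replay(&buffer, None).len(), 2);
    }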
+ #[tokio::test(flavor = "multi_thread")] + #[ignore] // Requires NATS/etcd. Run with: cargo test --package dynamo-llm --lib --features integration test_distributed_kvindexer_e2e -- --ignored --nocapture + async fn test_distributed_kvindexer_e2e() -> anyhow::Result<()> { + const BLOCK_SIZE: u32 = 4; + const NUM_REQUESTS: usize = 4; + + dynamo_runtime::logging::init(); + + // === SETUP: Distributed runtimes and namespaces === + let shared_store_dir = tempfile::tempdir()?; + let shared_store_path = shared_store_dir.path().to_path_buf(); + + // Make both runtimes point at the same file-backed storage backend so worker + // registrations and heartbeats remain visible to every DRT instance. + let distributed1 = create_test_shared_drt_async(&shared_store_path).await; + let distributed2 = create_test_shared_drt_async(&shared_store_path).await; + let component1 = distributed1 + .namespace("test_e2e_router")? + .component(MOCKER_COMPONENT)?; + let component2 = distributed2 + .namespace("test_e2e_router")? + .component(MOCKER_COMPONENT)?; + + // === SETUP: Start mocker workers === + let mocker_args = MockEngineArgs::builder() + .block_size(BLOCK_SIZE as usize) + .dp_size(1) // single worker per runtime + .enable_prefix_caching(true) + .enable_local_indexer(true) // affects scheduler/publisher args + .build()?; + + let worker_components = vec![component1.clone(), component2.clone()]; + let mut server_handles = Vec::new(); + let mut worker_ids = Vec::new(); + + for comp in worker_components { + let engine = Arc::new(MockVllmEngine::new(mocker_args.clone())); + engine.start(comp.clone()).await?; + tracing::info!("MockVllmEngine started for {:?}", comp); + + // Register MDC with runtime_config so router can discover enable_local_indexer. + // (Without this step, the MDC-based assert in query_worker() in worker_query.rs will fail.) 
+ // This inlines code which in the Python path would be performed by: + // - local_model.rs: LocalModelBuilder::build() sets runtime_config from MockEngineArgs + // - entrypoint/input/endpoint.rs: LocalModel::attach() registers MDC via discovery + let endpoint = comp.endpoint("generate"); + let runtime_config = ModelRuntimeConfig { + enable_local_indexer: true, + ..Default::default() + }; + let mut builder = LocalModelBuilder::default(); + builder + .model_name(Some("mock".to_string())) + .kv_cache_block_size(Some(BLOCK_SIZE)) + .runtime_config(runtime_config); + let mut local_model = builder.build().await?; + local_model + .attach( + &endpoint, + crate::model_type::ModelType::Chat, + crate::model_type::ModelInput::Tokens, + None, + ) + .await?; + + let ingress = Ingress::for_engine(engine.clone())?; + let endpoint_component = comp.clone(); + let handle = tokio::spawn(async move { + if let Err(e) = endpoint_component + .endpoint("generate") + .endpoint_builder() + .handler(ingress) + .start() + .await + { + tracing::error!("Generate endpoint failed: {e}"); + } + }); + server_handles.push(handle); + worker_ids.push(comp.drt().connection_id()); + } + tracing::info!("Generate endpoint servers launched"); + + tokio::time::sleep(Duration::from_millis(500)).await; + + // === SETUP: Build KvPushRouter === + let router_distributed = create_test_shared_drt_async(&shared_store_path).await; + let router_namespace = router_distributed.namespace("test_e2e_router")?; + let backend_component = router_namespace.component(MOCKER_COMPONENT)?; + let backend_endpoint = backend_component.endpoint("generate"); + let client = backend_endpoint.client().await?; + let kv_router_config = KvRouterConfig::default(); + let selector = Box::new(DefaultWorkerSelector::new(Some(kv_router_config))); + let consumer_id = format!("test-router-{}", router_distributed.connection_id()); + + let kv_router: Arc = Arc::new( + KvRouter::new( + backend_endpoint.clone(), + client.clone(), + BLOCK_SIZE, + Some(selector), + Some(kv_router_config), + consumer_id, + ) + .await?, + ); + + let push_router = + PushRouter::>::from_client_with_threshold( + client, + RouterMode::KV, + None, + None, + ) + .await?; + + let kv_push_router = KvPushRouter::new(push_router, kv_router.clone()); + + // ===== TEST PART 1: ROUTE & SEND REQUESTS TO WORKERS (ROUTER -> WORKER) ===== + let create_request = |tokens: Vec| { + PreprocessedRequest::builder() + .model("mock".to_string()) + .token_ids(tokens) + .stop_conditions(StopConditions { + max_tokens: Some(10), + ..Default::default() + }) + .sampling_options(SamplingOptions::default()) + .output_options(OutputOptions::default()) + .build() + .unwrap() + }; // from mocker/engine.rs + + for i in 0..NUM_REQUESTS { + tracing::info!("Sending routed request {}", i + 1); + let tokens = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, i as u32]; + let request = create_request(tokens.clone()); + + let response_stream = kv_push_router.generate(Context::new(request)).await?; + let responses: Vec> = response_stream.collect().await; + assert!( + !responses.is_empty(), + "Request {} should produce at least one response", + i + 1 + ); + } + + tracing::info!("KvPushRouter generate() succeeded for {NUM_REQUESTS} requests"); + + // ===== TEST PART 2: QUERY WORKER-LOCAL KVINDEXERS DIRECTLY ===== + // TODO: This could be refactored as router function (e.g. 
router.refresh_from_worker(worker_id)) + // (which should also update the global kvIndexer with the buffer from the local kvIndexer) + let mut best_worker_info: Option<(u64, usize)> = None; + + // Exactly one worker should have been routed requests. Find that worker + for &worker_id in &worker_ids { + let response = kv_router + .query_worker_local_kv(worker_id, None, None) + .await?; + if response.events.is_empty() { + continue; + } + + let event_count = response.events.len(); + tracing::info!( + worker_id, + events = event_count, + "Worker query on worker {worker_id} returned buffered KV events" + ); + best_worker_info = Some((worker_id, event_count)); + break; + } + + // Verify that only one worker has KV events in buffer + let (best_worker_id, best_worker_event_count) = + best_worker_info.expect("At least one worker should have buffered KV events"); + + tracing::info!( + "Best worker is {best_worker_id} with {best_worker_event_count} buffered KV events" + ); + + for &worker_id in &worker_ids { + if worker_id == best_worker_id { + continue; + } + + let response = kv_router + .query_worker_local_kv(worker_id, None, None) + .await?; + assert!( + response.events.is_empty(), + "Worker {worker_id} should not report buffered KV events; best worker {best_worker_id} reported {best_worker_event_count}" + ); + } + + // === Cleanup === + for handle in server_handles { + handle.abort(); + } + distributed1.shutdown(); + distributed2.shutdown(); + router_distributed.shutdown(); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + #[ignore] + async fn test_distributed_kvindexer_e2e_startup() -> anyhow::Result<()> { + const BLOCK_SIZE: u32 = 4; + + dynamo_runtime::logging::init(); + + // === SETUP: Distributed runtimes and namespaces === + let shared_store_dir = tempfile::tempdir()?; + let shared_store_path = shared_store_dir.path().to_path_buf(); + + // Use a unique namespace per test run for full isolation + let test_namespace = format!("test_e2e_{}", uuid::Uuid::new_v4().simple()); + + // Make both runtimes point at the same file-backed storage backend so worker + // registrations and heartbeats remain visible to every DRT instance. + let distributed1 = create_test_shared_drt_async(&shared_store_path).await; + let distributed2 = create_test_shared_drt_async(&shared_store_path).await; + let component1 = distributed1 + .namespace(&test_namespace)? + .component(MOCKER_COMPONENT)?; + let component2 = distributed2 + .namespace(&test_namespace)? + .component(MOCKER_COMPONENT)?; + + // === SETUP: Start mocker workers === + let mocker_args = MockEngineArgs::builder() + .block_size(BLOCK_SIZE as usize) + .dp_size(1) // single worker per runtime + .enable_prefix_caching(true) + .enable_local_indexer(true) // affects scheduler/publisher args + .build()?; + + let worker_components = vec![component1.clone(), component2.clone()]; + let mut server_handles = Vec::new(); + let mut worker_ids = Vec::new(); + + for comp in worker_components { + let engine: Arc = Arc::new(MockVllmEngine::new(mocker_args.clone())); + engine.start(comp.clone()).await?; + tracing::info!("MockVllmEngine started for {:?}", comp); + + // Register MDC with runtime_config so router can discover enable_local_indexer. + // (Without this step, the MDC-based assert in query_worker() in worker_query.rs will fail.) 
+ // This inlines code which in the Python path would be performed by: + // - local_model.rs: LocalModelBuilder::build() sets runtime_config from MockEngineArgs + // - entrypoint/input/endpoint.rs: LocalModel::attach() registers MDC via discovery + let endpoint = comp.endpoint("generate"); + let runtime_config = ModelRuntimeConfig { + enable_local_indexer: true, + ..Default::default() + }; + let mut builder = LocalModelBuilder::default(); + builder + .model_name(Some("mock".to_string())) + .kv_cache_block_size(Some(BLOCK_SIZE)) + .runtime_config(runtime_config); + let mut local_model = builder.build().await?; + local_model + .attach( + &endpoint, + crate::model_type::ModelType::Chat, + crate::model_type::ModelInput::Tokens, + None, + ) + .await?; + + let ingress = Ingress::for_engine(engine.clone())?; + let endpoint_component = comp.clone(); + let handle = tokio::spawn(async move { + if let Err(e) = endpoint_component + .endpoint("generate") + .endpoint_builder() + .handler(ingress) + .start() + .await + { + tracing::error!("Generate endpoint failed: {e}"); + } + }); + server_handles.push(handle); + worker_ids.push(comp.drt().connection_id()); + } + tracing::info!("Generate endpoint servers launched"); + + tokio::time::sleep(Duration::from_millis(500)).await; + + // === STEP 1: Send request to worker_ids[0] to populate its local indexer === + // This simulates a situation where KvPushRouter is initialized + // to route to workers which already have KV events + let pre_router_distributed = create_test_shared_drt_async(&shared_store_path).await; + let pre_backend_endpoint = pre_router_distributed + .namespace(&test_namespace)? + .component(MOCKER_COMPONENT)? + .endpoint("generate"); + let pre_client = pre_backend_endpoint.client().await?; + + // Create a PushRouter to send requests directly to a specific worker + let pre_push_router = + PushRouter::>::from_client_with_threshold( + pre_client, + RouterMode::Random, // We'll use direct() so mode doesn't matter + None, + None, + ) + .await?; + + // Force sending one requests each to the two workers + for &worker_id in &worker_ids { + let tokens: Vec = vec![0, 1, 2, 3]; + let request = PreprocessedRequest::builder() + .model("mock".to_string()) + .token_ids(tokens.clone()) + .sampling_options(SamplingOptions::default()) + .output_options(OutputOptions::default()) + .stop_conditions(StopConditions { + max_tokens: Some(5), + ..Default::default() + }) + .build()?; + let response_stream = pre_push_router + .direct(Context::new(request), worker_id) + .await?; + // Consume the stream to complete the request + let _responses: Vec<_> = response_stream.collect().await; + tracing::debug!( + "Sent request {:?} directly to worker {} to populate its local indexer", + tokens, + worker_id + ); + } + tokio::time::sleep(Duration::from_millis(1000)).await; + + // === SETUP: Build KvPushRouter === + let router_distributed = create_test_shared_drt_async(&shared_store_path).await; + let router_namespace = router_distributed.namespace(&test_namespace)?; + let backend_component = router_namespace.component(MOCKER_COMPONENT)?; + let backend_endpoint = backend_component.endpoint("generate"); + let client = backend_endpoint.client().await?; + let kv_router_config = KvRouterConfig::default(); + let selector = Box::new(DefaultWorkerSelector::new(Some(kv_router_config))); + let consumer_id = format!("test-router-{}", router_distributed.connection_id()); + + let kv_router: Arc = Arc::new( + KvRouter::new( + backend_endpoint.clone(), + client.clone(), + BLOCK_SIZE, + 
Some(selector), + Some(kv_router_config), + consumer_id, + ) + .await?, + ); + + // At this point kvrouter's indexer should already have the + // events stored in the workers, due to the catch-up built into KvRouter::new. + // Each request generates 2 events: input block (parent_hash: None) + output block (parent_hash: Some) + // With 2 workers, that's 4 events total. + let global_kv_events = kv_router.indexer.dump_events().await?; + tracing::debug!("Global KV events: {:?}", global_kv_events); + assert_eq!(global_kv_events.len(), 4); // 2 workers × 2 events per request (input + output) + + // === Cleanup === + for handle in server_handles { + handle.abort(); + } + distributed1.shutdown(); + distributed2.shutdown(); + router_distributed.shutdown(); + + Ok(()) + } +} diff --git a/lib/llm/src/kv_router/subscriber.rs b/lib/llm/src/kv_router/subscriber.rs index d9afb3fc56..ba2ad43676 100644 --- a/lib/llm/src/kv_router/subscriber.rs +++ b/lib/llm/src/kv_router/subscriber.rs @@ -1,9 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Background processes for the KV Router including event consumption and snapshot uploads. - -use std::{collections::HashSet, time::Duration}; +use std::{collections::HashMap, collections::HashSet, time::Duration}; use anyhow::Result; use dynamo_runtime::{ @@ -24,6 +22,7 @@ use crate::kv_router::{ indexer::{DumpRequest, GetWorkersRequest, RouterEvent}, protocols::WorkerId, router_discovery_query, + worker_query::WorkerQueryClient, }; /// Delay between snapshot reads to verify stability @@ -33,6 +32,163 @@ const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10; const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1); const CHECK_INTERVAL_JITTER_MS: i64 = 100; +// ============================================================================ +// Local KvIndexer-based Recovery +// ============================================================================ + +/// Recover missed events from all workers with local indexers. +/// +/// This function should be called on router startup to catch up on any events +/// that were missed while the router was offline. 
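As a sketch of the intended call site for this helper: the wrapper function and the concrete KvIndexer handle below are assumptions; only `get_last_received_event_ids`, `event_sender`, and `recover_from_all_workers` itself appear in this diff, with types as imported in subscriber.rs.

    // Hypothetical startup hook, not part of this change.
    async fn catch_up_on_startup(
        indexer: &crate::kv_router::indexer::KvIndexer,
        query_client: &WorkerQueryClient,
        worker_ids: &Vec<WorkerId>,
    ) -> usize {
        // Last event id the router has applied, keyed by worker.
        let last_ids = indexer
            .get_last_received_event_ids()
            .await
            .unwrap_or_default();

        // Replay anything newer from each worker's local buffer into the indexer.
        recover_from_all_workers(query_client, &last_ids, worker_ids, &indexer.event_sender()).await
    }

Workers without `enable_local_indexer=true` are skipped inside the helper, so the caller does not need to filter the worker list first.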
+/// +/// # Arguments +/// +/// * `worker_query_client` - Client for querying worker local indexers +/// * `last_received_event_ids` - Map of worker ID to last received event ID +/// * `worker_ids` - List of worker IDs to recover from +/// * `event_tx` - Channel to send recovered events to the indexer +/// +/// # Returns +/// +/// Total number of events recovered across all workers +pub async fn recover_from_all_workers( + worker_query_client: &WorkerQueryClient, + last_received_event_ids: &HashMap, + worker_ids: &Vec, + event_tx: &mpsc::Sender, +) -> usize { + let mut total_recovered = 0; + let mut successful_workers = 0; + let mut failed_workers = 0; + + for &worker_id in worker_ids { + // Skip workers without local indexer + if !worker_query_client.has_local_indexer(worker_id) { + tracing::debug!( + worker_id, + "Skipping recovery - worker does not have local indexer enabled" + ); + continue; + } + + // If we haven't seen any events from this worker, start from beginning (None) + // If we've seen events, start from last_known_id + 1 + let start_event_id = last_received_event_ids + .get(&worker_id) + .map(|&last_id| last_id + 1); + + match recover_from_worker( + worker_query_client, + worker_id, + start_event_id, + None, // Get all events after start_event_id + event_tx, + ) + .await + { + Ok(count) => { + total_recovered += count; + if count > 0 { + successful_workers += 1; + } + } + Err(_) => { + failed_workers += 1; + } + } + } + + // Log summary + if total_recovered > 0 || failed_workers > 0 { + tracing::info!( + total_recovered, + successful_workers, + failed_workers, + "Startup recovery completed" + ); + } + + total_recovered +} + +/// Recover missed KV events from a specific worker. +/// +/// # Arguments +/// +/// * `worker_query_client` - Client for querying worker local indexers +/// * `worker_id` - The worker to recover from +/// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning +/// * `end_event_id` - Last event ID to fetch (inclusive), or None for all +/// * `event_tx` - Channel to send recovered events to the indexer +/// +/// # Returns +/// +/// Number of events recovered, or error if recovery failed +pub async fn recover_from_worker( + worker_query_client: &WorkerQueryClient, + worker_id: WorkerId, + start_event_id: Option, + end_event_id: Option, + event_tx: &mpsc::Sender, +) -> Result { + if worker_query_client.has_local_indexer(worker_id) { + tracing::debug!( + worker_id, + start_event_id = ?start_event_id, + end_event_id = ?end_event_id, + "Attempting recovery from worker" + ); + } else { + tracing::warn!( + "Worker {} does not have local indexer enabled, skipping recovery", + worker_id + ); + return Ok(0); + } + + // Query worker for events in range + let response = worker_query_client + .query_worker(worker_id, start_event_id, end_event_id) + .await?; + + let events_count = response.events.len(); + + if events_count == 0 { + tracing::debug!( + worker_id, + start_event_id = ?start_event_id, + "No missed events to recover from worker" + ); + return Ok(0); + } + + tracing::info!( + worker_id, + start_event_id = ?start_event_id, + events_count, + "Recovered {} missed events from worker", + events_count + ); + + // Apply recovered events to the indexer + for event in response.events { + if let Err(e) = event_tx.send(event).await { + tracing::error!( + worker_id, + error = %e, + "Failed to send recovered event to indexer" + ); + anyhow::bail!("Failed to send recovered event: {}", e); + } + } + + Ok(events_count) +} + +// 
============================================================================ +// Snapshot Management +// ============================================================================ + /// Download a stable snapshot from object store and send events to the indexer. /// Retries until two consecutive reads match or max attempts is reached. async fn download_stable_snapshot( diff --git a/lib/llm/src/kv_router/worker_query.rs b/lib/llm/src/kv_router/worker_query.rs new file mode 100644 index 0000000000..7fb9a312e8 --- /dev/null +++ b/lib/llm/src/kv_router/worker_query.rs @@ -0,0 +1,104 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; + +use anyhow::{Context, Result}; +use dynamo_runtime::component::Component; +use dynamo_runtime::traits::DistributedRuntimeProvider; +use dynamo_runtime::traits::events::EventPublisher; +use tokio::sync::watch; + +use crate::kv_router::WORKER_KV_INDEXER_QUERY_SUBJECT; +use crate::kv_router::indexer::{WorkerKvQueryRequest, WorkerKvQueryResponse}; +use crate::kv_router::protocols::WorkerId; +use crate::local_model::runtime_config::ModelRuntimeConfig; + +/// Router-side client for querying worker local KV indexers +/// +/// Performs request/reply communication with workers via NATS. +/// (Only queries workers that have `enable_local_indexer=true` in their MDC user_data) +/// The client is spawned by KvRouter; it watches same discovery stream as the router. +pub struct WorkerQueryClient { + component: Component, + /// Watch receiver for enable_local_indexer state per worker + model_runtime_config_rx: watch::Receiver>, +} + +impl WorkerQueryClient { + /// Create a new WorkerQueryClient with a watch receiver for local indexer states + pub fn new( + component: Component, + model_runtime_config_rx: watch::Receiver>, + ) -> Self { + Self { + component, + model_runtime_config_rx, + } + } + + /// Check if a worker has local indexer enabled + pub fn has_local_indexer(&self, worker_id: WorkerId) -> bool { + self.model_runtime_config_rx + .borrow() + .get(&worker_id) + .map(|config| config.enable_local_indexer) + .unwrap_or(false) + } + + /// Query a specific worker's local KV indexer and return its buffered events. + /// Returns an error if the worker does not have enable_local_indexer=true. 
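The worker-side responder that answers these queries lives in publisher.rs (start_worker_kv_query_service) and is not part of this diff. As a rough illustration of the shape such a responder takes over plain async_nats: the struct fields below are assumed stand-ins for WorkerKvQueryRequest / WorkerKvQueryResponse, and the real service reads the LocalKvIndexer buffer instead of returning an empty list.

    use futures::StreamExt;

    #[derive(serde::Serialize, serde::Deserialize)]
    struct QueryRequest {
        worker_id: u64,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    }

    #[derive(serde::Serialize, serde::Deserialize)]
    struct QueryResponse {
        events: Vec<serde_json::Value>,
    }

    // Subscribe on "{component.subject()}.{WORKER_KV_INDEXER_QUERY_SUBJECT}.{worker_id}"
    // and answer each request on its reply subject.
    async fn serve_worker_queries(client: async_nats::Client, subject: String) -> anyhow::Result<()> {
        let mut sub = client.subscribe(subject).await?;
        while let Some(msg) = sub.next().await {
            let _req: QueryRequest = serde_json::from_slice(&msg.payload)?;
            // The real service would slice the local indexer's buffer by the requested id range.
            let resp = QueryResponse { events: Vec::new() };
            if let Some(reply) = msg.reply {
                client.publish(reply, serde_json::to_vec(&resp)?.into()).await?;
            }
        }
        Ok(())
    }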
+ pub async fn query_worker( + &self, + worker_id: WorkerId, + start_event_id: Option, + end_event_id: Option, + ) -> Result { + // Check if worker has local indexer enabled + if !self.has_local_indexer(worker_id) { + anyhow::bail!( + "Worker {} does not have local indexer enabled (enable_local_indexer=false or not set in MDC user_data)", + worker_id + ); + } + + // Match worker's subscribe format + let subject_str = format!("{}.{}", WORKER_KV_INDEXER_QUERY_SUBJECT, worker_id); // see publisher.rs/start_worker_kv_query_service() + let subject = format!("{}.{}", self.component.subject(), subject_str); + + tracing::debug!( + "Router sending query request to worker {} on NATS subject: {}", + worker_id, + subject + ); + + // Create and serialize request + let request = WorkerKvQueryRequest { + worker_id, + start_event_id, + end_event_id, + }; + let request_bytes = + serde_json::to_vec(&request).context("Failed to serialize WorkerKvQueryRequest")?; + + // Send NATS request with timeout using DRT helper + let timeout = tokio::time::Duration::from_secs(1); + let response_msg = self + .component + .drt() + .kv_router_nats_request(subject.clone(), request_bytes.into(), timeout) + .await + .with_context(|| { + format!( + "Failed to send request to worker {} on subject {}", + worker_id, subject + ) + })?; + + // Deserialize response + let response: WorkerKvQueryResponse = serde_json::from_slice(&response_msg.payload) + .context("Failed to deserialize WorkerKvQueryResponse")?; + + Ok(response) + } +} diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 3c0adf0b26..9ab5e7f6be 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -234,6 +234,7 @@ impl LocalModelBuilder { self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64); self.runtime_config.max_num_batched_tokens = mocker_engine_args.max_num_batched_tokens.map(|v| v as u64); + self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer; self.runtime_config.data_parallel_size = mocker_engine_args.dp_size; self.media_decoder = Some(MediaDecoder::default()); self.media_fetcher = Some(MediaFetcher::default()); diff --git a/lib/llm/src/local_model/runtime_config.rs b/lib/llm/src/local_model/runtime_config.rs index 833465a672..482b77578f 100644 --- a/lib/llm/src/local_model/runtime_config.rs +++ b/lib/llm/src/local_model/runtime_config.rs @@ -23,6 +23,10 @@ pub struct ModelRuntimeConfig { #[serde(default = "default_data_parallel_size")] pub data_parallel_size: u32, + /// Enable worker-local KV indexer for tracking this worker's own KV cache state + #[serde(default)] + pub enable_local_indexer: bool, + /// Mapping of engine-specific runtime configs #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub runtime_data: HashMap, @@ -51,6 +55,7 @@ impl Default for ModelRuntimeConfig { tool_call_parser: None, reasoning_parser: None, data_parallel_size: default_data_parallel_size(), + enable_local_indexer: false, runtime_data: HashMap::new(), tensor_model_config: None, } diff --git a/lib/llm/src/mocker/kv_manager.rs b/lib/llm/src/mocker/kv_manager.rs index 17d7491162..a949139475 100644 --- a/lib/llm/src/mocker/kv_manager.rs +++ b/lib/llm/src/mocker/kv_manager.rs @@ -72,7 +72,7 @@ pub struct KvManager { impl KvManager { pub fn new(max_capacity: usize, block_size: usize) -> Self { - Self::new_with_publisher(max_capacity, block_size, None, 0) + Self::new_with_publisher(max_capacity, block_size, None, 0, false) } pub fn new_with_publisher( @@ -80,6 +80,7 @@ 
impl KvManager { block_size: usize, component: Option, dp_rank: u32, + enable_local_indexer: bool, ) -> Self { let active_blocks = HashMap::new(); let inactive_blocks = LRUEvictor::default(); @@ -87,10 +88,10 @@ impl KvManager { let kv_event_publisher = component.map(|comp| { tracing::info!( - "Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}" + "Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}, enable_local_indexer={enable_local_indexer}" ); Arc::new( - KvEventPublisher::new(comp, block_size as u32, None) + KvEventPublisher::new_with_local_indexer(comp, block_size as u32, None, enable_local_indexer) .expect("Failed to create KV event publisher"), ) }); diff --git a/lib/llm/src/mocker/protocols.rs b/lib/llm/src/mocker/protocols.rs index d67e707ee6..4e62d836e4 100644 --- a/lib/llm/src/mocker/protocols.rs +++ b/lib/llm/src/mocker/protocols.rs @@ -120,6 +120,10 @@ pub struct MockEngineArgs { #[serde(skip)] #[builder(default = "Arc::new(PerfModel::default())")] pub perf_model: Arc, + + /// Enable worker-local KV indexer for tracking this worker's own KV cache state + #[builder(default = "false")] + pub enable_local_indexer: bool, } impl Default for MockEngineArgs { @@ -158,6 +162,7 @@ impl MockEngineArgs { "is_prefill", "is_decode", "planner_profile_data", + "enable_local_indexer", ] .iter() .cloned() @@ -239,6 +244,12 @@ impl MockEngineArgs { builder = builder.startup_time(Some(num)); } + if let Some(value) = extra_args.get("enable_local_indexer") + && let Some(enabled) = value.as_bool() + { + builder = builder.enable_local_indexer(enabled); + } + // Parse worker type from is_prefill and is_decode flags let is_prefill = extra_args .get("is_prefill") diff --git a/lib/llm/src/mocker/scheduler.rs b/lib/llm/src/mocker/scheduler.rs index e09aa17282..5ea205a4cb 100644 --- a/lib/llm/src/mocker/scheduler.rs +++ b/lib/llm/src/mocker/scheduler.rs @@ -275,6 +275,7 @@ impl Scheduler { args.block_size, component, dp_rank, + args.enable_local_indexer, ); let mut hit_rates = RunningMean::new(1000); diff --git a/lib/runtime/src/distributed.rs b/lib/runtime/src/distributed.rs index b6a440da12..d517664467 100644 --- a/lib/runtime/src/distributed.rs +++ b/lib/runtime/src/distributed.rs @@ -397,7 +397,7 @@ impl DistributedRuntime { /// TODO: This is a temporary KV router measure for component/component.rs EventPublisher impl for /// Component, to allow it to publish to NATS. KV Router is the only user. - pub(crate) async fn kv_router_nats_publish( + pub async fn kv_router_nats_publish( &self, subject: String, payload: bytes::Bytes, @@ -420,6 +420,25 @@ impl DistributedRuntime { Ok(nats_client.client().subscribe(subject).await?) } + /// TODO (karenc): This is a temporary KV router measure for worker query requests. + /// Allows KV Router to perform request/reply with workers. (versus the pub/sub pattern above) + /// KV Router is the only user, made public for use in dynamo-llm crate + pub async fn kv_router_nats_request( + &self, + subject: String, + payload: bytes::Bytes, + timeout: std::time::Duration, + ) -> anyhow::Result { + let Some(nats_client) = self.nats_client.as_ref() else { + anyhow::bail!("KV router's request requires NATS"); + }; + let response = + tokio::time::timeout(timeout, nats_client.client().request(subject, payload)) + .await + .map_err(|_| anyhow::anyhow!("Request timed out after {:?}", timeout))??; + Ok(response) + } + /// DEPRECATED: This method exists only for NATS request plane support. 
/// Once everything uses the TCP request plane, this can be removed along with /// the NATS service registration infrastructure. @@ -633,6 +652,26 @@ pub mod distributed_test_utils { }; super::DistributedRuntime::new(rt, config).await.unwrap() } + + /// Helper function to create a DRT instance which points at + /// a (shared) file-backed KV store and ephemeral NATS transport so that + /// multiple DRT instances may observe the same registration state. + /// NOTE: This gets around the fact that create_test_drt_async() is + /// hardcoded to spin up a memory-backed discovery store + /// which means we can't share discovery state across runtimes. + pub async fn create_test_shared_drt_async( + store_path: &std::path::Path, + ) -> super::DistributedRuntime { + use crate::{storage::kv, transports::nats}; + + let rt = crate::Runtime::from_current().unwrap(); + let config = super::DistributedConfig { + store_backend: kv::Selector::File(store_path.to_path_buf()), + nats_config: Some(nats::ClientOptions::default()), + request_plane: crate::distributed::RequestPlaneMode::default(), + }; + super::DistributedRuntime::new(rt, config).await.unwrap() + } } #[cfg(all(test, feature = "integration"))]
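Taken together, the two runtime additions above give integration tests a way to share discovery state across DRT instances and give the router a request/reply path over NATS. A hedged sketch of how they combine, as a hypothetical test body that assumes a reachable NATS server and a Tokio runtime, like the other integration tests in this diff:

    use std::time::Duration;
    use dynamo_runtime::distributed_test_utils::create_test_shared_drt_async;

    async fn shared_drt_request_smoke() -> anyhow::Result<()> {
        // Two runtimes over the same file-backed store observe the same registrations.
        let store = tempfile::tempdir()?;
        let drt_a = create_test_shared_drt_async(store.path()).await;
        let drt_b = create_test_shared_drt_async(store.path()).await;

        // Request/reply goes through NATS with an explicit timeout, so a missing
        // responder surfaces as an error instead of hanging the caller.
        let reply = drt_a
            .kv_router_nats_request(
                "demo.worker.query".to_string(),
                bytes::Bytes::from_static(b"ping"),
                Duration::from_secs(1),
            )
            .await;
        assert!(reply.is_err(), "nothing is subscribed to demo.worker.query");

        drt_a.shutdown();
        drt_b.shutdown();
        Ok(())
    }

In the real router path, the subject is the per-worker query subject built by WorkerQueryClient and the payload is a serialized WorkerKvQueryRequest rather than a raw byte string.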