diff --git a/docs/guides/dynamo_run.md b/docs/guides/dynamo_run.md index 06df9dada08..e5995eb0f6e 100644 --- a/docs/guides/dynamo_run.md +++ b/docs/guides/dynamo_run.md @@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm. Usage: ``` -dynamo-run in=[http|text|dyn://|batch:] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path ] [--model-name ] [--model-config ] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--verbosity (-v|-vv)] +dynamo-run in=[http|text|dyn://|batch:] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path ] [--model-name ] [--model-config ] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--use-kv-events=true] [--verbosity (-v|-vv)] ``` Example: `dynamo run Qwen/Qwen3-0.6B` @@ -201,7 +201,13 @@ The only difference from the distributed system above is `--router-mode kv`. The For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing. -The argument `--kv-overlap-score-weight` sets the amount weighting on overlaps with prefix caches, which directly contributes to the prefill cost, so a large weight is expected to yield a better TTFT (at the expense of worse ITL). When this is set 0, we do not consider the prefix caches at all (falling back to pure load balancing behavior on the active blocks), in which case we do not require the backend engines to emit any KV events. The argument `--router-temperature` sets the temperature when randomly selecting the workers to route to via softmax sampling on the router cost logits, setting it to 0 recovers the deterministic behavior where the min logit is picked. +The KV-aware routing arguments: + +- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks). + +- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 recovers the deterministic behavior where the min logit is picked. + +- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events. ## Full usage details diff --git a/launch/dynamo-run/src/flags.rs b/launch/dynamo-run/src/flags.rs index 94889d9b822..b7fe662b529 100644 --- a/launch/dynamo-run/src/flags.rs +++ b/launch/dynamo-run/src/flags.rs @@ -128,6 +128,13 @@ pub struct Flags { #[arg(long)] pub router_temperature: Option, + /// KV Router: Whether to use KV events to maintain the view of cached blocks + /// If false, would use ApproxKvRouter for predicting block creation / deletion + /// based only on incoming requests at a timer. + /// Default: true + #[arg(long)] + pub use_kv_events: Option, + /// Max model context length. Reduce this if you don't have enough VRAM for the full model /// context length (e.g. Llama 4). /// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json. @@ -215,6 +222,7 @@ impl Flags { KvRouterConfig::new( self.kv_overlap_score_weight, self.router_temperature, + self.use_kv_events, self.max_num_batched_tokens, ), ) diff --git a/lib/llm/src/discovery/model_manager.rs b/lib/llm/src/discovery/model_manager.rs index 83a98fda51f..c64e40d5d91 100644 --- a/lib/llm/src/discovery/model_manager.rs +++ b/lib/llm/src/discovery/model_manager.rs @@ -212,18 +212,12 @@ impl ModelManager { kv_cache_block_size: u32, kv_router_config: Option, ) -> anyhow::Result> { - // Determine if we should use KV events based on overlap score weight - let use_kv_events = kv_router_config - .as_ref() - .map(|config| config.overlap_score_weight > 0.0) - .unwrap_or(false); - - let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); + let selector = Box::new(DefaultWorkerSelector::new(kv_router_config.clone())); let chooser = KvRouter::new( component.clone(), kv_cache_block_size, Some(selector), - use_kv_events, + kv_router_config.unwrap_or_default().use_kv_events, ) .await?; let new_kv_chooser = Arc::new(chooser); diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index da82587e4fb..83514e18ccb 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use std::sync::Arc; +use std::time::Duration; use anyhow::Result; use dynamo_runtime::{ @@ -14,6 +15,7 @@ use dynamo_runtime::{ protocols::annotated::Annotated, }; use futures::stream::{self, StreamExt}; +use tokio::sync::Mutex; pub mod approx; pub mod indexer; @@ -27,7 +29,11 @@ pub mod sequence; use crate::{ kv_router::{ - indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, + approx::ApproxKvIndexer, + indexer::{ + compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError, + OverlapScores, RouterEvent, + }, metrics_aggregator::EndpointCollector, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, @@ -35,7 +41,6 @@ use crate::{ }, preprocessor::PreprocessedRequest, protocols::common::llm_backend::LLMEngineOutput, - tokens::TokenBlockSequence, }; use dynamo_runtime::traits::events::EventSubscriber; @@ -63,6 +68,8 @@ pub struct KvRouterConfig { pub router_temperature: f64, + pub use_kv_events: bool, + // note: this is not actually used for now pub max_num_batched_tokens: u32, } @@ -72,6 +79,7 @@ impl Default for KvRouterConfig { Self { overlap_score_weight: 1.0, router_temperature: 0.5, + use_kv_events: true, max_num_batched_tokens: 8192, } } @@ -83,24 +91,52 @@ impl KvRouterConfig { pub fn new( overlap_score_weight: Option, temperature: Option, + use_kv_events: Option, max_num_batched_tokens: Option, ) -> Self { let default = Self::default(); Self { overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight), router_temperature: temperature.unwrap_or(default.router_temperature), + use_kv_events: use_kv_events.unwrap_or(default.use_kv_events), max_num_batched_tokens: max_num_batched_tokens .unwrap_or(default.max_num_batched_tokens), } } } +// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this +// since both variants implement it +pub enum Indexer { + KvIndexer(KvIndexer), + ApproxKvIndexer(ApproxKvIndexer), +} + +impl Indexer { + async fn find_matches( + &self, + sequence: Vec, + ) -> Result { + match self { + Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await, + Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await, + } + } +} + /// A KvRouter only decides which worker you should use. It doesn't send you there. /// TODO: Rename this to indicate it only selects a worker, it does not route. pub struct KvRouter { - indexer: Option, + indexer: Indexer, + + // How about a Box scheduler: KvScheduler, + block_size: u32, + + // To ensure blocking reads / writes + // TODO: benchmark tradeoffs + find_best_match_mutex: Mutex<()>, } impl KvRouter { @@ -118,8 +154,16 @@ impl KvRouter { let metrics_aggregator = EndpointCollector::new(component.clone(), cancellation_token.clone()).await; - let maybe_indexer = - use_kv_events.then(|| KvIndexer::new(cancellation_token.clone(), block_size)); + let indexer = if use_kv_events { + Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size)) + } else { + // hard code 120 seconds for now + Indexer::ApproxKvIndexer(ApproxKvIndexer::new( + cancellation_token.clone(), + block_size, + Duration::from_secs(120), + )) + }; let scheduler = KvScheduler::start( component.namespace().clone(), @@ -131,9 +175,9 @@ impl KvRouter { // [gluo TODO] try subscribe_with_type::, // error checking below will be different. - if let Some(ref indexer) = maybe_indexer { + if let Indexer::KvIndexer(ref kv_indexer) = indexer { let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?; - let kv_events_tx = indexer.event_sender(); + let kv_events_tx = kv_indexer.event_sender(); tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { @@ -158,9 +202,10 @@ impl KvRouter { tracing::info!("KV Routing initialized"); Ok(Self { - indexer: maybe_indexer, + indexer, scheduler, block_size, + find_best_match_mutex: Mutex::new(()), // Add this }) } @@ -172,20 +217,15 @@ impl KvRouter { context_id: &str, tokens: &[u32], ) -> anyhow::Result<(i64, u32)> { + // Acquire mutex to serialize access + // TODO: may as well make all the subroutines synchronous if benchmarking favors this + let _guard = self.find_best_match_mutex.lock().await; + let isl_tokens = tokens.len(); let block_size = self.block_size; - let (complete_blocks, _partial_block) = - TokenBlockSequence::split_tokens(tokens, block_size, 1337_u64); - - let local_block_hashes = complete_blocks - .into_iter() - .map(|block| LocalBlockHash(block.block_hash())) - .collect(); - let overlap_scores = match &self.indexer { - Some(indexer) => indexer.find_matches(local_block_hashes).await?, - None => Default::default(), // Returns empty/default instance - }; + let local_block_hashes = compute_block_hash_for_seq(tokens, self.block_size); + let overlap_scores = self.indexer.find_matches(local_block_hashes).await?; let best_worker_id = self .scheduler @@ -198,6 +238,13 @@ impl KvRouter { ) .await?; + if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer { + indexer + .process_routing_decision_for_request(tokens, best_worker_id) + .await + .unwrap(); + }; + let overlap_amount = overlap_scores .scores .get(&best_worker_id) diff --git a/lib/llm/src/kv_router/approx.rs b/lib/llm/src/kv_router/approx.rs index fca9cd585af..20bf1469857 100644 --- a/lib/llm/src/kv_router/approx.rs +++ b/lib/llm/src/kv_router/approx.rs @@ -72,6 +72,7 @@ struct TimerEntry { struct TimerManager { /// The source of truth. Maps a key to its current expiration instant. timers: HashMap, + /// A min-heap of (expiration_instant, key) used to efficiently find the /// next expiring timer. An entry in this heap is "stale" if the instant /// does not match the one in the `timers` map. @@ -79,18 +80,32 @@ struct TimerManager { /// The expiration duration of the timers. ttl: Duration, + + /// Threshold for rebuilding the heap. + /// The heap will be rebuilt from scratch to remove stale entries. + threshold: usize, } impl TimerManager { /// Creates a new, empty TimerManager. - pub fn new(ttl: Duration) -> Self { + pub fn new(ttl: Duration, threshold: usize) -> Self { TimerManager { timers: HashMap::new(), expirations: BinaryHeap::new(), ttl, + threshold, } } + /// Rebuilds the expirations heap from the timers map, removing all stale entries. + fn rebuild_heap(&mut self) { + self.expirations = self + .timers + .iter() + .map(|(key, &expiry)| Reverse((expiry, key.clone()))) + .collect(); + } + /// Inserts a new timer or updates an existing one for the given key. /// /// # Arguments @@ -108,6 +123,11 @@ impl TimerManager { // which will be ignored when it's popped. self.expirations.push(Reverse((expiry_time, key))); } + + // Check if we should rebuild the heap to remove stale entries + if self.expirations.len() > self.timers.len() * self.threshold { + self.rebuild_heap(); + } } /// Polls for expired timers and returns a list of keys for all timers @@ -123,23 +143,12 @@ impl TimerManager { } // The timer might be expired, so pop it from the heap. - // We can safely unwrap because we just peeked. let Reverse((expiry_time, key)) = self.expirations.pop().unwrap(); - // CRUCIAL STEP: Check if the popped timer is stale. - // A timer is stale if its key is no longer in our authoritative map, - // or if the expiration time in the map is different (i.e., it was updated). - match self.timers.get(&key) { - Some(authoritative_expiry) if *authoritative_expiry == expiry_time => { - // This is a valid, non-stale, expired timer. - // Remove it from the map and add its key to our results. - self.timers.remove(&key); - expired_keys.push(key); - } - _ => { - // This entry in the heap was stale. It was either removed - // or updated with a new time. We just ignore it and continue. - } + if self.timers.get(&key) == Some(&expiry_time) { + // This is a valid, non-stale, expired timer. + self.timers.remove(&key); + expired_keys.push(key); } } @@ -184,7 +193,8 @@ impl ApproxKvIndexer { runtime.block_on(async move { let mut trie = RadixTree::new(); - let mut timer_manager: TimerManager = TimerManager::new(ttl); + // Use a reasonable threshold - can be made configurable if needed + let mut timer_manager: TimerManager = TimerManager::new(ttl, 50); let mut event_id = 0; loop { // Create a future that sleeps until the next expiration time. @@ -398,7 +408,7 @@ mod tests { #[tokio::test] async fn test_timer_manager_expiry() { const TTL: Duration = Duration::from_millis(50); - let mut tm: TimerManager = TimerManager::new(TTL); + let mut tm: TimerManager = TimerManager::new(TTL, 50); tm.insert(vec![1, 2, 3]); assert!(tm.get_expiry(&1).is_some()); @@ -419,7 +429,7 @@ mod tests { async fn test_timer_manager_update_resets_ttl() { // Validate that reinserting an existing key extends its TTL and prevents premature expiry. const TTL: Duration = Duration::from_millis(50); - let mut tm: TimerManager = TimerManager::new(TTL); + let mut tm: TimerManager = TimerManager::new(TTL, 50); // Initial insert and capture the original expiry. tm.insert(vec![42]);