Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions docs/guides/dynamo_run.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.

Usage:
```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--verbosity (-v|-vv)]
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--use-kv-events=true] [--verbosity (-v|-vv)]
```

Example: `dynamo run Qwen/Qwen3-0.6B`
Expand Down Expand Up @@ -201,7 +201,13 @@ The only difference from the distributed system above is `--router-mode kv`. The

For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing.

The argument `--kv-overlap-score-weight` sets the amount weighting on overlaps with prefix caches, which directly contributes to the prefill cost, so a large weight is expected to yield a better TTFT (at the expense of worse ITL). When this is set 0, we do not consider the prefix caches at all (falling back to pure load balancing behavior on the active blocks), in which case we do not require the backend engines to emit any KV events. The argument `--router-temperature` sets the temperature when randomly selecting the workers to route to via softmax sampling on the router cost logits, setting it to 0 recovers the deterministic behavior where the min logit is picked.
The KV-aware routing arguments:

- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks).

- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 recovers the deterministic behavior where the min logit is picked.

- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, then we use the `ApproxKvIndexer` to predict them based on the incoming requests. Set false if your backend engine does not emit KV events.

## Full usage details

Expand Down
8 changes: 8 additions & 0 deletions launch/dynamo-run/src/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ pub struct Flags {
#[arg(long)]
pub router_temperature: Option<f64>,

/// KV Router: Whether to use KV events to maintain the view of cached blocks
/// If false, would use ApproxKvRouter for predicting block creation / deletion
/// based only on incoming requests at a timer.
/// Default: true
#[arg(long)]
pub use_kv_events: Option<bool>,

/// Max model context length. Reduce this if you don't have enough VRAM for the full model
/// context length (e.g. Llama 4).
/// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json.
Expand Down Expand Up @@ -215,6 +222,7 @@ impl Flags {
KvRouterConfig::new(
self.kv_overlap_score_weight,
self.router_temperature,
self.use_kv_events,
self.max_num_batched_tokens,
),
)
Expand Down
10 changes: 2 additions & 8 deletions lib/llm/src/discovery/model_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,12 @@ impl ModelManager {
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> {
// Determine if we should use KV events based on overlap score weight
let use_kv_events = kv_router_config
.as_ref()
.map(|config| config.overlap_score_weight > 0.0)
.unwrap_or(false);

let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config.clone()));
let chooser = KvRouter::new(
component.clone(),
kv_cache_block_size,
Some(selector),
use_kv_events,
kv_router_config.unwrap_or_default().use_kv_events,
)
.await?;
let new_kv_chooser = Arc::new(chooser);
Expand Down
85 changes: 66 additions & 19 deletions lib/llm/src/kv_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use dynamo_runtime::{
Expand All @@ -14,6 +15,7 @@ use dynamo_runtime::{
protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
use tokio::sync::Mutex;

pub mod approx;
pub mod indexer;
Expand All @@ -27,15 +29,18 @@ pub mod sequence;

use crate::{
kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
approx::ApproxKvIndexer,
indexer::{
compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError,
OverlapScores, RouterEvent,
},
metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
},
preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput,
tokens::TokenBlockSequence,
};

use dynamo_runtime::traits::events::EventSubscriber;
Expand Down Expand Up @@ -63,6 +68,8 @@ pub struct KvRouterConfig {

pub router_temperature: f64,

pub use_kv_events: bool,

// note: this is not actually used for now
pub max_num_batched_tokens: u32,
}
Expand All @@ -72,6 +79,7 @@ impl Default for KvRouterConfig {
Self {
overlap_score_weight: 1.0,
router_temperature: 0.5,
use_kv_events: true,
max_num_batched_tokens: 8192,
}
}
Expand All @@ -83,24 +91,52 @@ impl KvRouterConfig {
pub fn new(
overlap_score_weight: Option<f64>,
temperature: Option<f64>,
use_kv_events: Option<bool>,
max_num_batched_tokens: Option<u32>,
) -> Self {
let default = Self::default();
Self {
overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
router_temperature: temperature.unwrap_or(default.router_temperature),
use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens),
}
}
}

// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
KvIndexer(KvIndexer),
ApproxKvIndexer(ApproxKvIndexer),
}

impl Indexer {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
match self {
Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
}
}
}

/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
pub struct KvRouter {
indexer: Option<KvIndexer>,
indexer: Indexer,

// How about a Box<dyn KvIndexerInterface>
scheduler: KvScheduler,

block_size: u32,

// To ensure blocking reads / writes
// TODO: benchmark tradeoffs
find_best_match_mutex: Mutex<()>,
}

impl KvRouter {
Expand All @@ -118,8 +154,16 @@ impl KvRouter {
let metrics_aggregator =
EndpointCollector::new(component.clone(), cancellation_token.clone()).await;

let maybe_indexer =
use_kv_events.then(|| KvIndexer::new(cancellation_token.clone(), block_size));
let indexer = if use_kv_events {
Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
} else {
// hard code 120 seconds for now
Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
cancellation_token.clone(),
block_size,
Duration::from_secs(120),
))
};

let scheduler = KvScheduler::start(
component.namespace().clone(),
Expand All @@ -131,9 +175,9 @@ impl KvRouter {

// [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different.
if let Some(ref indexer) = maybe_indexer {
if let Indexer::KvIndexer(ref kv_indexer) = indexer {
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_events_tx = indexer.event_sender();
let kv_events_tx = kv_indexer.event_sender();

tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
Expand All @@ -158,9 +202,10 @@ impl KvRouter {

tracing::info!("KV Routing initialized");
Ok(Self {
indexer: maybe_indexer,
indexer,
scheduler,
block_size,
find_best_match_mutex: Mutex::new(()), // Add this
})
}

Expand All @@ -172,20 +217,15 @@ impl KvRouter {
context_id: &str,
tokens: &[u32],
) -> anyhow::Result<(i64, u32)> {
// Acquire mutex to serialize access
// TODO: may as well make all the subroutines synchronous if benchmarking favors this
let _guard = self.find_best_match_mutex.lock().await;

let isl_tokens = tokens.len();
let block_size = self.block_size;

let (complete_blocks, _partial_block) =
TokenBlockSequence::split_tokens(tokens, block_size, 1337_u64);

let local_block_hashes = complete_blocks
.into_iter()
.map(|block| LocalBlockHash(block.block_hash()))
.collect();
let overlap_scores = match &self.indexer {
Some(indexer) => indexer.find_matches(local_block_hashes).await?,
None => Default::default(), // Returns empty/default instance
};
let local_block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;

let best_worker_id = self
.scheduler
Expand All @@ -198,6 +238,13 @@ impl KvRouter {
)
.await?;

if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
indexer
.process_routing_decision_for_request(tokens, best_worker_id)
.await
.unwrap();
};

let overlap_amount = overlap_scores
.scores
.get(&best_worker_id)
Expand Down
48 changes: 29 additions & 19 deletions lib/llm/src/kv_router/approx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,25 +72,40 @@ struct TimerEntry {
struct TimerManager<K: Clone + Hash + Eq + Ord> {
/// The source of truth. Maps a key to its current expiration instant.
timers: HashMap<K, Instant>,

/// A min-heap of (expiration_instant, key) used to efficiently find the
/// next expiring timer. An entry in this heap is "stale" if the instant
/// does not match the one in the `timers` map.
expirations: BinaryHeap<Reverse<(Instant, K)>>,

/// The expiration duration of the timers.
ttl: Duration,

/// Threshold for rebuilding the heap.
/// The heap will be rebuilt from scratch to remove stale entries.
threshold: usize,
}

impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
/// Creates a new, empty TimerManager.
pub fn new(ttl: Duration) -> Self {
pub fn new(ttl: Duration, threshold: usize) -> Self {
TimerManager {
timers: HashMap::new(),
expirations: BinaryHeap::new(),
ttl,
threshold,
}
}

/// Rebuilds the expirations heap from the timers map, removing all stale entries.
fn rebuild_heap(&mut self) {
self.expirations = self
.timers
.iter()
.map(|(key, &expiry)| Reverse((expiry, key.clone())))
.collect();
}

/// Inserts a new timer or updates an existing one for the given key.
///
/// # Arguments
Expand All @@ -108,6 +123,11 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
// which will be ignored when it's popped.
self.expirations.push(Reverse((expiry_time, key)));
}

// Check if we should rebuild the heap to remove stale entries
if self.expirations.len() > self.timers.len() * self.threshold {
self.rebuild_heap();
}
}

/// Polls for expired timers and returns a list of keys for all timers
Expand All @@ -123,23 +143,12 @@ impl<K: Clone + Hash + Eq + Ord> TimerManager<K> {
}

// The timer might be expired, so pop it from the heap.
// We can safely unwrap because we just peeked.
let Reverse((expiry_time, key)) = self.expirations.pop().unwrap();

// CRUCIAL STEP: Check if the popped timer is stale.
// A timer is stale if its key is no longer in our authoritative map,
// or if the expiration time in the map is different (i.e., it was updated).
match self.timers.get(&key) {
Some(authoritative_expiry) if *authoritative_expiry == expiry_time => {
// This is a valid, non-stale, expired timer.
// Remove it from the map and add its key to our results.
self.timers.remove(&key);
expired_keys.push(key);
}
_ => {
// This entry in the heap was stale. It was either removed
// or updated with a new time. We just ignore it and continue.
}
if self.timers.get(&key) == Some(&expiry_time) {
// This is a valid, non-stale, expired timer.
self.timers.remove(&key);
expired_keys.push(key);
}
}

Expand Down Expand Up @@ -184,7 +193,8 @@ impl ApproxKvIndexer {

runtime.block_on(async move {
let mut trie = RadixTree::new();
let mut timer_manager: TimerManager<TimerEntry> = TimerManager::new(ttl);
// Use a reasonable threshold - can be made configurable if needed
let mut timer_manager: TimerManager<TimerEntry> = TimerManager::new(ttl, 50);
let mut event_id = 0;
loop {
// Create a future that sleeps until the next expiration time.
Expand Down Expand Up @@ -398,7 +408,7 @@ mod tests {
#[tokio::test]
async fn test_timer_manager_expiry() {
const TTL: Duration = Duration::from_millis(50);
let mut tm: TimerManager<u32> = TimerManager::new(TTL);
let mut tm: TimerManager<u32> = TimerManager::new(TTL, 50);

tm.insert(vec![1, 2, 3]);
assert!(tm.get_expiry(&1).is_some());
Expand All @@ -419,7 +429,7 @@ mod tests {
async fn test_timer_manager_update_resets_ttl() {
// Validate that reinserting an existing key extends its TTL and prevents premature expiry.
const TTL: Duration = Duration::from_millis(50);
let mut tm: TimerManager<u32> = TimerManager::new(TTL);
let mut tm: TimerManager<u32> = TimerManager::new(TTL, 50);

// Initial insert and capture the original expiry.
tm.insert(vec![42]);
Expand Down
Loading