diff --git a/lib/llm/src/mocker.rs b/lib/llm/src/mocker.rs index 2a9e63a9e2..4315868c49 100644 --- a/lib/llm/src/mocker.rs +++ b/lib/llm/src/mocker.rs @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod engine; pub mod evictor; pub mod kv_manager; pub mod protocols; diff --git a/lib/llm/src/mocker/engine.rs b/lib/llm/src/mocker/engine.rs new file mode 100644 index 0000000000..b367910792 --- /dev/null +++ b/lib/llm/src/mocker/engine.rs @@ -0,0 +1,764 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! MockSchedulerEngine - AsyncEngine wrapper around the Scheduler +//! +//! This module provides an AsyncEngine implementation that wraps the Scheduler +//! to provide streaming token generation with realistic timing simulation. + +use crate::kv_router::publisher::WorkerMetricsPublisher; +use crate::mocker::protocols::DirectRequest; +use crate::mocker::protocols::{MockEngineArgs, OutputSignal}; +use crate::mocker::scheduler::Scheduler; +use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}; +use crate::protocols::TokenIdType; +use dynamo_runtime::protocols::annotated::Annotated; +use dynamo_runtime::DistributedRuntime; +use tokio_util::sync::CancellationToken; + +use dynamo_runtime::{ + component::Component, + engine::AsyncEngineContextProvider, + pipeline::{async_trait, AsyncEngine, Error, ManyOut, ResponseStream, SingleIn}, + traits::DistributedRuntimeProvider, + Result, +}; + +use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData}; +use crate::kv_router::publisher::KvEventPublisher; +use futures::StreamExt; +use rand::Rng; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex, OnceCell}; +use tokio::time::{interval, Duration}; +use tokio_stream::wrappers::ReceiverStream; +use uuid::Uuid; + +pub const MOCKER_COMPONENT: &str = "mocker"; + +/// Generate a random token ID from 1k to 5k +fn generate_random_token() -> TokenIdType { + let mut rng = rand::rng(); + rng.random_range(1000..5000) +} + +/// AsyncEngine wrapper around the Scheduler that generates random character tokens +#[derive(Clone)] +pub struct MockVllmEngine { + active_requests: Arc>>>, + request_senders: Arc>>>, + engine_args: MockEngineArgs, +} + +impl MockVllmEngine { + /// Create a new MockVllmEngine with the given parameters + pub fn new(args: MockEngineArgs) -> Self { + Self { + active_requests: Arc::new(Mutex::new(HashMap::new())), + request_senders: Arc::new(OnceCell::new()), + engine_args: args, + } + } + + pub async fn start(&self, component: Component) -> Result<()> { + let cancel_token = component.drt().runtime().child_token(); + + let (schedulers, kv_event_receiver) = self.start_schedulers( + self.engine_args.clone(), + self.active_requests.clone(), + cancel_token.clone(), + ); + + 
Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone()) + .await?; + + // Start KV events publishing with the actual receivers from schedulers + if self.engine_args.enable_prefix_caching { + Self::start_kv_events_publishing( + kv_event_receiver, + Some(component.clone()), + self.engine_args.block_size, + cancel_token.clone(), + ) + .await?; + } + + Ok(()) + } + + pub fn direct(&self, request: DirectRequest, dp_rank: usize) { + let senders = self.request_senders.get().expect("Not initialized"); + let _ = senders[dp_rank].send(request); + } + + /// Create schedulers and spawn their background tasks for distributing token notifications + /// Returns schedulers and their corresponding KV event receivers + fn start_schedulers( + &self, + args: MockEngineArgs, + active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>, + cancel_token: CancellationToken, + ) -> ( + Vec<Scheduler>, + Vec<mpsc::UnboundedReceiver<KvCacheEventData>>, + ) { + let mut schedulers = Vec::<Scheduler>::new(); + let mut kv_event_receivers = Vec::new(); + let mut senders = Vec::with_capacity(args.dp_size as usize); + + // Create multiple schedulers and their background tasks + for dp_rank in 0..args.dp_size { + // Create a shared output channel that this scheduler will use + let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); + + // Create a channel for KV events from this scheduler + let (kv_events_tx, kv_events_rx) = mpsc::unbounded_channel::<KvCacheEventData>(); + + let scheduler = Scheduler::new( + args.clone(), + Some(dp_rank), + Some(output_tx), + Some(kv_events_tx), // Pass the KV events sender to scheduler + Some(cancel_token.clone()), + ); + + senders.push(scheduler.request_sender()); + schedulers.push(scheduler); + kv_event_receivers.push(kv_events_rx); + + // Spawn a background task for this scheduler to distribute token notifications to active requests + // let output_rx = Arc::new(Mutex::new(output_rx)); + let active_requests_clone = active_requests.clone(); + let cancel_token_cloned = cancel_token.clone(); + + tokio::spawn(async move { + loop { + tokio::select!
{ + signal_result = output_rx.recv() => { + let Some(signal) = signal_result else { + break; // Channel closed + }; + + // Notify the specific request that a token was generated + let active = active_requests_clone.lock().await; + if let Some(request_tx) = active.get(&signal.uuid) { + let _ = request_tx.send(signal); + } + } + _ = cancel_token_cloned.cancelled() => { + break; + } + } + } + }); + } + + // Set the senders once + self.request_senders + .set(senders) + .expect("Already initialized"); + + (schedulers, kv_event_receivers) + } + + /// Start background tasks to poll and publish metrics every second + async fn start_metrics_publishing( + schedulers: &[Scheduler], + component: Option, + cancel_token: CancellationToken, + ) -> Result<()> { + tracing::info!("Creating metrics publisher"); + let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?); + tracing::info!("Metrics publisher created"); + + if let Some(comp) = component { + tracing::info!("Creating metrics endpoint"); + tokio::spawn({ + let publisher = metrics_publisher.clone(); + async move { + if let Err(e) = publisher.create_endpoint(comp.clone()).await { + tracing::error!("Metrics endpoint failed: {e}"); + } + } + }); + + // Give it a moment to start + tokio::time::sleep(Duration::from_millis(100)).await; + tracing::info!("Metrics endpoint started (background)"); + } + + tracing::info!("Starting metrics background tasks"); + for (dp_rank, scheduler) in schedulers.iter().enumerate() { + let scheduler = scheduler.clone(); + let publisher = metrics_publisher.clone(); + let dp_rank = dp_rank as u32; + let cancel_token = cancel_token.clone(); + + tokio::spawn(async move { + let mut interval = interval(Duration::from_millis(100)); + + loop { + tokio::select! { + _ = interval.tick() => { + // Get metrics from scheduler + let metrics = scheduler.get_forward_pass_metrics().await; + + // Publish metrics + if let Err(e) = publisher.publish(Arc::new(metrics)) { + tracing::warn!("Failed to publish metrics for DP rank {dp_rank}: {e}"); + } else { + tracing::trace!("Published metrics for DP rank {}", dp_rank); + } + } + _ = cancel_token.cancelled() => { + tracing::info!("Metrics publishing cancelled for DP rank {dp_rank}"); + break; + } + } + } + }); + } + tracing::info!("Metrics background tasks started"); + Ok(()) + } + + /// Start background tasks to collect and publish KV events from schedulers + async fn start_kv_events_publishing( + kv_event_receivers: Vec>, + component: Option, + block_size: usize, + cancel_token: CancellationToken, + ) -> Result<()> { + tracing::info!("Starting KV events publishing"); + + // Only start KV events publishing if we have a component + let Some(comp) = component else { + tracing::warn!("No component provided, skipping KV events publishing"); + return Ok(()); + }; + tracing::info!("Component found for KV events publishing"); + + tracing::debug!("Getting worker_id"); + let worker_id = comp + .drt() + .primary_lease() + .expect("Cannot publish KV events without lease") // ← This will PANIC on static! 
+ .id(); + // let worker_id = 0; + tracing::debug!("Worker_id set to: {worker_id}"); + + tracing::info!("Creating KV event publisher"); + let kv_event_publisher = Arc::new(KvEventPublisher::new( + comp.clone(), + worker_id, + block_size as u32, + None, + )?); + tracing::info!("KV event publisher created"); + + tracing::info!( + "Starting KV event background tasks for {} receivers", + kv_event_receivers.len() + ); + for (dp_rank, mut kv_events_rx) in kv_event_receivers.into_iter().enumerate() { + tracing::debug!("Starting background task for DP rank {dp_rank}"); + let publisher = kv_event_publisher.clone(); + let dp_rank = dp_rank as u32; + let cancel_token = cancel_token.clone(); + + tokio::spawn(async move { + tracing::debug!("Background task started for DP rank {dp_rank}"); + loop { + tokio::select! { + // Receive actual KV events from the scheduler + Some(event_data) = kv_events_rx.recv() => { + // Convert KvCacheEventData to KvCacheEvent with random UUID as event_id + let event = KvCacheEvent { + event_id: Uuid::new_v4().as_u128() as u64, + data: event_data, + }; + + // Publish the event + if let Err(e) = publisher.publish(event) { + tracing::warn!("Failed to publish KV event for DP rank {dp_rank}: {e}"); + } else { + tracing::trace!("Published KV event for DP rank {dp_rank}"); + } + } + _ = cancel_token.cancelled() => { + tracing::info!("KV events publishing cancelled for DP rank {dp_rank}"); + break; + } + } + } + }); + } + tracing::info!("All KV event background tasks started"); + + Ok(()) + } +} + +#[async_trait] +impl AsyncEngine, ManyOut, Error> + for MockVllmEngine +{ + async fn generate( + &self, + input: SingleIn, + ) -> Result, Error> { + let (request, ctx) = input.into_parts(); + + // Extract dp_rank from annotations if present + let dp_rank = request + .annotations + .iter() + .find_map(|ann| { + if ann.starts_with("dp_rank:") { + ann.strip_prefix("dp_rank:").and_then(|s| s.parse().ok()) + } else { + None + } + }) + .unwrap_or(0); + + // Validate dp_rank + if dp_rank >= self.engine_args.dp_size { + return Err(Error::msg(format!( + "dp_rank {} is out of bounds for dp_size {}", + dp_rank, self.engine_args.dp_size + ))); + } + + let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4()); + + // Convert PreprocessedRequest to DirectRequest for scheduler + let direct_request = DirectRequest { + tokens: request.token_ids.clone(), + max_output_tokens: request + .stop_conditions + .max_tokens + .expect("max_output_tokens must be specified for mocker") + as usize, + uuid: Some(request_uuid), + dp_rank: Some(dp_rank), + }; + + let (request_tx, mut request_rx) = mpsc::unbounded_channel::(); + { + let mut active = self.active_requests.lock().await; + active.insert(request_uuid, request_tx); + } + + // Send the request to the appropriate scheduler based on dp_rank + self.direct(direct_request, dp_rank as usize); + + // Create a simple channel for the stream + let (stream_tx, stream_rx) = mpsc::channel::(64); + + let active_requests = self.active_requests.clone(); + let async_context = ctx.context(); + let max_tokens = request.stop_conditions.max_tokens.unwrap_or(100) as usize; + + // Spawn a task to handle the complex async logic + tokio::spawn(async move { + let mut token_count = 0; + + loop { + tokio::select! 
{ + maybe_signal = request_rx.recv() => { + let Some(signal) = maybe_signal else { + let _ = stream_tx.send(LLMEngineOutput::error("All output transmitters closed".to_string())).await; + break; + }; + + if signal.completed && token_count < max_tokens { + let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string())).await; + break; + } + + if signal.completed { + let _ = stream_tx.send(LLMEngineOutput::length()).await; + break; + } + + // Generate a new token + let token_id = generate_random_token(); + token_count += 1; + + let output = LLMEngineOutput { + token_ids: vec![token_id], + tokens: None, // Let backend handle detokenization + text: None, + cum_log_probs: None, + log_probs: None, + finish_reason: None, + index: None, + }; + + if stream_tx.send(output).await.is_err() { + break; + } + } + + _ = async_context.stopped() => { + let _ = stream_tx.send(LLMEngineOutput::cancelled()).await; + break; + } + } + } + + // Clean up: remove this request from active requests + let mut active = active_requests.lock().await; + active.remove(&request_uuid); + }); + + // Create a simple ReceiverStream which is naturally Send + Sync + let stream = ReceiverStream::new(stream_rx); + Ok(ResponseStream::new(Box::pin(stream), ctx.context())) + } +} + +pub struct AnnotatedMockEngine { + inner: Arc, +} + +impl AnnotatedMockEngine { + pub fn new( + inner: MockVllmEngine, + distributed_runtime: DistributedRuntime, + endpoint: dynamo_runtime::protocols::Endpoint, + ) -> Self { + let inner = Arc::new(inner); + let inner_clone = inner.clone(); + + // Start background task to wait for component service and start the engine + tokio::spawn(async move { + loop { + // Try to create component + let Ok(namespace) = distributed_runtime.namespace(&endpoint.namespace) else { + tracing::debug!("Namespace not available yet, retrying..."); + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + }; + + let Ok(component) = namespace.component(&endpoint.component) else { + tracing::debug!("Component not available yet, retrying..."); + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + }; + + // Check if service is available by trying to list instances + let Ok(instances) = component.list_instances().await else { + tracing::debug!("Cannot list instances yet, retrying..."); + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + }; + + if instances.is_empty() { + tracing::debug!("No instances available yet, retrying..."); + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + } + + tracing::info!("Component service is now available, starting mocker engine"); + + // Start the engine with the component + if let Err(e) = inner_clone.start(component).await { + tracing::error!("Failed to start mocker engine: {e}"); + } + break; + } + }); + + Self { inner } + } +} + +#[async_trait] +impl AsyncEngine, ManyOut>, Error> + for AnnotatedMockEngine +{ + async fn generate( + &self, + input: SingleIn, + ) -> Result>, Error> { + let stream = self.inner.generate(input).await?; + let context = stream.context(); + + // Convert stream of LLMEngineOutput to Annotated + let annotated_stream = stream.map(Annotated::from_data); + + Ok(ResponseStream::new(Box::pin(annotated_stream), context)) + } +} + +/// Create a mocker engine as ExecutionContext +pub async fn make_mocker_engine( + distributed_runtime: DistributedRuntime, + endpoint: dynamo_runtime::protocols::Endpoint, + args: MockEngineArgs, +) -> Result { + // Create the mocker engine + 
tracing::info!("Creating mocker engine (service will be started in background)"); + let annotated_engine = + AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint); + + Ok(Arc::new(annotated_engine)) +} + +#[cfg(test)] +mod integration_tests { + use super::*; + use crate::kv_router::indexer::RouterEvent; + use crate::kv_router::KV_EVENT_SUBJECT; + use crate::protocols::common::{SamplingOptions, StopConditions}; + use dynamo_runtime::{ + pipeline::Context, + pipeline::{network::Ingress, PushRouter}, + traits::events::EventSubscriber, + DistributedRuntime, Worker, + }; + use futures::StreamExt; + use tokio::time::timeout; + + #[tokio::test] + #[ignore] // Run with: cargo test -- --ignored + async fn test_mock_vllm_engine_full_integration() -> Result<()> { + const DP_SIZE: u32 = 2; + const TOKENS_PER_REQUEST: usize = 20; + const BLOCK_SIZE: usize = 2; + + // Create runtime and distributed runtime + let worker = Worker::from_settings()?; + let runtime = worker.runtime(); + let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; + tracing::info!("✓ Runtime and distributed runtime created"); + + // Create component for MockVllmEngine (needed for publishers) + let test_component = distributed + .namespace("test")? + .component(MOCKER_COMPONENT)? + .service_builder() + .create() + .await?; + tracing::info!("✓ Test component created"); + + // Create MockVllmEngine WITH component (enables publishers) + let args = MockEngineArgs::builder() + .speedup_ratio(10.0) + .dp_size(DP_SIZE) + .block_size(BLOCK_SIZE) + .build() + .unwrap(); + + let engine = MockVllmEngine::new(args); + engine.start(test_component.clone()).await?; + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + let engine = Arc::new(engine); + tracing::info!("✓ MockVllmEngine created with DP_SIZE: {DP_SIZE}"); + + // Set up KV events subscriber + let mut kv_events_subscriber = test_component.subscribe(KV_EVENT_SUBJECT).await?; + tracing::info!("✓ KV events subscriber created"); + + // Wrap with Ingress and register with component/endpoint + let ingress = Ingress::for_engine(engine)?; + tracing::info!("✓ Ingress wrapper created"); + + // Start the server in background + let server_handle = tokio::spawn({ + let test_component = test_component.clone(); + async move { + if let Err(e) = test_component + .endpoint("generate") + .endpoint_builder() + .handler(ingress) + .start() + .await + { + eprintln!("❌ Generate endpoint failed: {e}"); + } + } + }); + tracing::info!("✓ Server started in background"); + + // Give server time to start + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + tracing::info!("✓ Server startup delay completed"); + + // Print all registered instances from etcd + match test_component.list_instances().await { + Ok(instances) => { + tracing::info!("📋 Found {} registered instances:", instances.len()); + for instance in instances { + tracing::info!( + " • {}/{}/{} (ID: {})", + instance.namespace, + instance.component, + instance.endpoint, + instance.instance_id + ); + } + } + Err(e) => { + tracing::error!("❌ Failed to list instances: {e}"); + } + } + + // Create client + let client = distributed + .namespace("test")? + .component(MOCKER_COMPONENT)? 
+ .endpoint("generate") + .client() + .await?; + tracing::info!("✓ Client created"); + + let router = PushRouter::from_client(client, Default::default()).await?; + tracing::info!("✓ Router created"); + + // Create test requests for both DP workers + let create_request = |tokens: Vec, dp_rank: u32| PreprocessedRequest { + token_ids: tokens, + batch_token_ids: None, + stop_conditions: StopConditions { + max_tokens: Some(TOKENS_PER_REQUEST as u32), + ..Default::default() + }, + sampling_options: SamplingOptions::default(), + eos_token_ids: vec![], + mdc_sum: None, + annotations: vec![format!("dp_rank:{dp_rank}")], + estimated_prefix_hit_num_blocks: None, + }; + + let requests = vec![ + create_request(vec![1, 2, 3, 4, 5], 0), + create_request(vec![1, 2, 3, 4, 5], 0), + create_request(vec![1, 2, 3, 4, 5], 1), + create_request(vec![1, 2, 3, 4, 5], 1), + ]; + tracing::info!( + "✓ Test requests created ({} requests total)", + requests.len() + ); + + // Test each request + for (i, request) in requests.into_iter().enumerate() { + tracing::info!("Testing request {}", i + 1); + + let response_stream = router.generate(Context::new(request)).await?; + let responses: Vec = response_stream.collect().await; + + // Should have at least one response + assert!( + !responses.is_empty(), + "Request {} should produce at least one response", + i + 1 + ); + + // Count total tokens generated (excluding final message) + let mut total_tokens = 0; + let mut has_finish_reason = false; + + for response in &responses { + total_tokens += response.token_ids.len(); + if response.finish_reason.is_some() { + has_finish_reason = true; + } + } + + // Should have a finish reason in the last response + assert!( + has_finish_reason, + "Request {} should have a finish reason", + i + 1 + ); + + // Verify we got approximately the expected number of tokens + assert!( + total_tokens <= TOKENS_PER_REQUEST + 1, // +1 for potential final empty response + "Request {} generated {} tokens, expected at most {}", + i + 1, + total_tokens, + TOKENS_PER_REQUEST + 1 + ); + + tracing::info!( + "✓ Request {} completed successfully with {} tokens", + i + 1, + total_tokens + ); + } + + tracing::info!("🎉 All requests completed successfully!"); + + // Try to receive at least one KV event with 100ms timeout + tracing::info!("Waiting for KV event with 100ms timeout..."); + let msg = timeout(Duration::from_millis(100), kv_events_subscriber.next()) + .await + .map_err(|_| Error::msg("Timeout waiting for KV event"))? 
+ .ok_or_else(|| Error::msg("KV events stream ended unexpectedly"))?; + + match serde_json::from_slice::(&msg.payload) { + Ok(event) => { + tracing::info!("✓ Received KV event: {event:?}"); + } + Err(e) => { + return Err(Error::msg(format!("Failed to deserialize KV event: {e}"))); + } + } + + // Use KvMetricsAggregator to get metrics more easily + let cancel_token = test_component.drt().runtime().child_token(); + let metrics_aggregator = crate::kv_router::metrics_aggregator::KvMetricsAggregator::new( + test_component.clone(), + cancel_token, + ) + .await; + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + let processed_endpoints = metrics_aggregator.get_endpoints(); + tracing::info!( + "Found {} metrics endpoints", + processed_endpoints.endpoints.len() + ); + + // Verify we found at least one metrics endpoint + assert!( + !processed_endpoints.endpoints.is_empty(), + "Should find at least one metrics endpoint" + ); + tracing::info!( + "✓ Successfully found {} metrics endpoints", + processed_endpoints.endpoints.len() + ); + + // Verify the metrics endpoints contain valid data + for (worker_id, endpoint) in &processed_endpoints.endpoints { + tracing::info!("✓ Worker {} metrics: {:?}", worker_id, endpoint.data); + } + + tracing::info!("🎉 Event verification completed!"); + + // Cleanup + distributed.shutdown(); + server_handle.await?; + + Ok(()) + } +} diff --git a/lib/llm/src/mocker/evictor.rs b/lib/llm/src/mocker/evictor.rs index 47a312eede..63d079180d 100644 --- a/lib/llm/src/mocker/evictor.rs +++ b/lib/llm/src/mocker/evictor.rs @@ -13,167 +13,158 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::cmp::Eq; -use std::collections::{HashMap, VecDeque}; +use std::cmp::{Eq, Ordering}; +use std::collections::{BTreeSet, HashMap}; use std::hash::Hash; -use std::time::Instant; + +/// A wrapper for (T, counter) that implements Ord based only on counter +#[derive(Debug, Clone, Eq, PartialEq)] +struct PriorityItem { + item: T, + counter: i64, +} + +impl Ord for PriorityItem { + fn cmp(&self, other: &Self) -> Ordering { + self.counter.cmp(&other.counter) + } +} + +impl PartialOrd for PriorityItem { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} /// An LRU evictor that maintains objects and evicts them based on their -/// last accessed time. Implements a "lazy" eviction mechanism where: -/// 1. The priority queue does not immediately reflect updates or removes -/// 2. Objects are pushed to the queue in order of increasing priority (older objects first) -/// 3. The user must ensure objects are added in correct priority (temporal order) -/// 4. Remove and update operations are lazy - entries remain in the queue until -/// they are either evicted or cleaned up during maintenance +/// priority counter. Lower counter values are evicted first. 
#[derive(Debug)] pub struct LRUEvictor<T: Eq + Hash + Clone> { - free_table: HashMap<T, f64>, - priority_queue: VecDeque<(T, f64)>, - cleanup_threshold: usize, - start_time: Instant, + free_table: HashMap<T, i64>, + priority_queue: BTreeSet<PriorityItem<T>>, + positive_counter: i64, + negative_counter: i64, } impl<T: Eq + Hash + Clone> Default for LRUEvictor<T> { fn default() -> Self { Self { free_table: HashMap::new(), - priority_queue: VecDeque::new(), - cleanup_threshold: 50, - start_time: Instant::now(), + priority_queue: BTreeSet::new(), + positive_counter: 0, + negative_counter: 0, } } } impl<T: Eq + Hash + Clone> LRUEvictor<T> { - /// Create a new LRUEvictor with the default cleanup threshold - pub fn new(cleanup_threshold: usize) -> Self { - Self { - cleanup_threshold, - ..Default::default() - } + pub fn new(_cleanup_threshold: usize) -> Self { + Self::default() } - /// Get the current timestamp as seconds since initialization - pub fn current_timestamp(&self) -> f64 { - self.start_time.elapsed().as_secs_f64() + pub fn keys(&self) -> std::collections::hash_map::Keys<'_, T, i64> { + self.free_table.keys() } - /// Get an iterator over the keys in the evictor - pub fn keys(&self) -> std::collections::hash_map::Keys<'_, T, f64> { - self.free_table.keys() + fn update(&mut self, object: T, counter: i64) { + self.free_table.insert(object.clone(), counter); + self.priority_queue.insert(PriorityItem { + item: object, + counter, + }); } - /// Insert or update an object in the evictor with current timestamp pub fn insert(&mut self, object: T) { - let timestamp = self.current_timestamp(); - self._insert(object, timestamp); - } + // Remove old entry if it exists + if let Some(&old_counter) = self.free_table.get(&object) { + self.priority_queue.remove(&PriorityItem { + item: object.clone(), + counter: old_counter, + }); + } - /// Check if the evictor contains the given object - pub fn contains(&self, object: &T) -> bool { - self.free_table.contains_key(object) + // Increment positive counter and insert + self.positive_counter += 1; + let counter = self.positive_counter; + + self.update(object, counter); } - /// Evict an object based on LRU policy - /// Returns the evicted object or None if no objects are available - pub fn evict(&mut self) -> Option<T> { - if self.free_table.is_empty() { - return None; + /// Push an object to the front with negative counter (highest priority for eviction) + pub fn push_front(&mut self, object: T) { + // Remove old entry if it exists + if let Some(&old_counter) = self.free_table.get(&object) { + self.priority_queue.remove(&PriorityItem { + item: object.clone(), + counter: old_counter, + }); } - while let Some((object, last_accessed)) = self.priority_queue.pop_front() { - let Some(&current_last_accessed) = self.free_table.get(&object) else { - continue; // entry is already removed - }; + // Decrement negative counter and insert + self.negative_counter -= 1; + let counter = self.negative_counter; - if current_last_accessed == last_accessed { - self.free_table.remove(&object); - return Some(object); - } // otherwise entry is stale - } + self.update(object, counter); + } - None + pub fn contains(&self, object: &T) -> bool { + self.free_table.contains_key(object) } - /// Insert or update an object in the evictor - fn _insert(&mut self, object: T, last_accessed: f64) { - self.free_table.insert(object.clone(), last_accessed); - self.priority_queue.push_back((object, last_accessed)); - self.cleanup_if_necessary(); + /// Evict an object based on LRU policy (lowest counter value) + /// Returns the evicted object or None if no objects are available + pub fn evict(&mut self) -> Option<T>
{ + self.priority_queue.pop_first().map(|item| { + self.free_table.remove(&item.item); + item.item + }) } - /// Remove an object from the evictor - /// We don't remove from the priority queue immediately, as that would be inefficient - /// Outdated entries will be filtered out during eviction or cleanup pub fn remove(&mut self, object: &T) -> bool { - self.free_table.remove(object).is_some() + let Some(&counter) = self.free_table.get(object) else { + return false; + }; + + self.free_table.remove(object); + self.priority_queue.remove(&PriorityItem { + item: object.clone(), + counter, + }); + true } - /// Get the number of objects in the evictor pub fn len(&self) -> usize { self.free_table.len() } - /// Check if the evictor is empty pub fn is_empty(&self) -> bool { self.free_table.is_empty() } - - /// Check if cleanup is necessary and perform it if needed - fn cleanup_if_necessary(&mut self) { - if self.priority_queue.len() > self.cleanup_threshold * self.free_table.len() { - self.cleanup(); - } - } - - /// Clean up the priority queue by removing outdated entries - fn cleanup(&mut self) { - let mut new_priority_queue = VecDeque::new(); - for (object, timestamp) in self.priority_queue.drain(..) { - let Some(¤t_timestamp) = self.free_table.get(&object) else { - continue; - }; - - if current_timestamp == timestamp { - new_priority_queue.push_back((object, timestamp)); - } - } - self.priority_queue = new_priority_queue; - } } #[cfg(test)] mod tests { use super::*; - use rstest::rstest; - #[rstest] - #[case(1)] - #[case(2)] - #[case(3)] - fn test_lru_evictor_eviction_order(#[case] threshold: usize) { - // Create a new LRUEvictor with the given cleanup threshold - let mut evictor = LRUEvictor::::new(threshold); + #[test] + fn test_lru_evictor_eviction_order() { + // Create a new LRUEvictor + let mut evictor = LRUEvictor::::new(1); // threshold value doesn't matter anymore - // Add items in the specified order with small delays between each + // Add items in the specified order evictor.insert(4); - std::thread::sleep(std::time::Duration::from_millis(1)); evictor.insert(3); - std::thread::sleep(std::time::Duration::from_millis(1)); evictor.insert(2); - std::thread::sleep(std::time::Duration::from_millis(1)); evictor.insert(1); - std::thread::sleep(std::time::Duration::from_millis(1)); evictor.insert(5); - std::thread::sleep(std::time::Duration::from_millis(1)); - evictor.insert(1); // Updates timestamp for 1 - std::thread::sleep(std::time::Duration::from_millis(1)); - evictor.insert(4); // Updates timestamp for 4 - std::thread::sleep(std::time::Duration::from_millis(1)); - evictor.insert(2); // Updates timestamp for 2 + evictor.insert(1); // Updates counter for 1 + evictor.insert(4); // Updates counter for 4 + evictor.insert(2); // Updates counter for 2 + evictor.push_front(4); // Verify the eviction order - println!("Testing with threshold {}", threshold); + let evicted = evictor.evict().unwrap(); + assert_eq!(evicted, 4); let evicted = evictor.evict().unwrap(); assert_eq!(evicted, 3); let evicted = evictor.evict().unwrap(); @@ -181,11 +172,11 @@ mod tests { let evicted = evictor.evict().unwrap(); assert_eq!(evicted, 1); let evicted = evictor.evict().unwrap(); - assert_eq!(evicted, 4); - let evicted = evictor.evict().unwrap(); assert_eq!(evicted, 2); let evicted = evictor.evict(); assert_eq!(evicted, None); assert_eq!(evictor.len(), 0); } + + // ... existing test_push_front test ... 
} diff --git a/lib/llm/src/mocker/kv_manager.rs b/lib/llm/src/mocker/kv_manager.rs index d1cd4a41ec..d28e577c44 100644 --- a/lib/llm/src/mocker/kv_manager.rs +++ b/lib/llm/src/mocker/kv_manager.rs @@ -46,10 +46,11 @@ //! implementation of the main block manager. use crate::mocker::evictor::LRUEvictor; -use crate::mocker::protocols::{MoveBlock, PrefillCost, UniqueBlock}; +use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost, UniqueBlock}; use crate::mocker::sequence::ActiveSequence; use derive_getters::Getters; use std::collections::{HashMap, HashSet}; +use tokio::sync::mpsc; #[derive(Getters)] pub struct KvManager { @@ -57,17 +58,27 @@ pub struct KvManager { max_capacity: usize, #[getter(copy)] - block_size: u32, + block_size: usize, active_blocks: HashMap, inactive_blocks: LRUEvictor, all_blocks: HashSet, + + move_block_response_tx: Option>, } impl KvManager { - pub fn new(max_capacity: usize, block_size: u32) -> Self { + pub fn new(max_capacity: usize, block_size: usize) -> Self { + Self::new_with_sender(max_capacity, block_size, None) + } + + pub fn new_with_sender( + max_capacity: usize, + block_size: usize, + move_block_response_tx: Option>, + ) -> Self { let active_blocks = HashMap::new(); let inactive_blocks = LRUEvictor::default(); let all_blocks = HashSet::new(); @@ -78,18 +89,46 @@ impl KvManager { active_blocks, inactive_blocks, all_blocks, + move_block_response_tx, + } + } + + /// Utility method to send block responses with optional reversing + fn send_block_response( + &self, + mut blocks: Vec, + reverse: bool, + store: bool, + parent_hash: Option, + ) { + if let Some(ref tx) = self.move_block_response_tx { + if !blocks.is_empty() { + if reverse { + blocks.reverse(); + } + let response = if store { + MoveBlockResponse::Store(blocks, parent_hash) + } else { + MoveBlockResponse::Remove(blocks) + }; + tx.send(response).unwrap(); + } } } /// Process a MoveBlock instruction synchronously pub fn process(&mut self, event: &MoveBlock) -> bool { match event { - MoveBlock::Use(hashes, _) => { + MoveBlock::Use(hashes) => { + let mut blocks_stored = Vec::::new(); + + let mut parent_block: Option<&UniqueBlock> = None; for hash in hashes { // First check if it already exists in active blocks if let Some(ref_count) = self.active_blocks.get_mut(hash) { // Block already active, just increment reference count *ref_count += 1; + parent_block = Some(hash); continue; } @@ -97,6 +136,7 @@ impl KvManager { if self.inactive_blocks.remove(hash) { // Insert into active with reference count 1 self.active_blocks.insert(hash.clone(), 1); + parent_block = Some(hash); continue; } @@ -106,30 +146,53 @@ impl KvManager { // If at max capacity, evict the oldest entry from inactive blocks if active_count + inactive_count >= self.max_capacity { - if let Some(evicted) = self.inactive_blocks.evict() { - // Remove evicted block from all_blocks - self.all_blocks.remove(&evicted); - } else { - // Cannot evict block, meaning no free blocks left in inactive pool - // Send a signal, scheduler would expect to handle preemption upon receiving this + let Some(evicted) = self.inactive_blocks.evict() else { return false; + }; + self.all_blocks.remove(&evicted); + if let UniqueBlock::FullBlock(evicted_full_block) = evicted { + self.send_block_response(vec![evicted_full_block], false, false, None); } } // Now insert the new block in active blocks with reference count 1 self.active_blocks.insert(hash.clone(), 1); - // Add to all_blocks as it's a new block self.all_blocks.insert(hash.clone()); + if 
self.move_block_response_tx.is_some() { + if let UniqueBlock::FullBlock(stored_full_block) = hash { + blocks_stored.push(*stored_full_block); + } + } } + + let parent_hash = match parent_block { + None => None, + Some(UniqueBlock::FullBlock(block)) => Some(*block), + Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"), + }; + self.send_block_response(blocks_stored, false, true, parent_hash); } + MoveBlock::Destroy(hashes) => { + let mut blocks_destroyed = Vec::::new(); + // Loop in inverse direction for hash in hashes.iter().rev() { self.active_blocks.remove(hash).unwrap(); // Remove from all_blocks when destroyed assert!(self.all_blocks.remove(hash)); + + // Track blocks for batch sending + if self.move_block_response_tx.is_some() { + if let UniqueBlock::FullBlock(destroyed_full_block) = hash { + blocks_destroyed.push(*destroyed_full_block); + } + } } + + self.send_block_response(blocks_destroyed, true, false, None); } + MoveBlock::Deref(hashes) => { // Loop in inverse direction for hash in hashes.iter().rev() { @@ -149,15 +212,15 @@ impl KvManager { } } } - MoveBlock::Promote(uuid, hash) => { + + MoveBlock::Promote(uuid, hash, parent_hash) => { let uuid_block = UniqueBlock::PartialBlock(*uuid); let hash_block = UniqueBlock::FullBlock(*hash); let Some(ref_count) = self.active_blocks.remove(&uuid_block) else { let in_all_blocks = self.all_blocks.contains(&uuid_block); panic!( - "Missing active block for promotion: {:?}. Block still exists: {}", - uuid_block, in_all_blocks + "Missing active block for promotion: {uuid_block:?}. Block still exists: {in_all_blocks}" ); }; @@ -167,6 +230,7 @@ impl KvManager { // Update all_blocks assert!(self.all_blocks.remove(&uuid_block)); self.all_blocks.insert(hash_block); + self.send_block_response(vec![*hash], false, true, *parent_hash); } } @@ -178,6 +242,7 @@ impl KvManager { pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize { blocks .iter() + // .filter(|&block| !self.active_blocks.contains_key(block)) .filter(|&block| !self.all_blocks.contains(block)) .count() } @@ -200,6 +265,11 @@ impl KvManager { self.active_blocks.len() } + /// Get the percentage of active blocks relative to maximum capacity + pub fn get_active_perc(&self) -> f64 { + self.active_blocks.len() as f64 / self.max_capacity as f64 + } + /// Get the number of inactive blocks pub fn num_inactive_blocks(&self) -> usize { self.inactive_blocks.len() @@ -216,63 +286,28 @@ impl KvManager { } /// Check if a sequence can be scheduled and calculate cost if possible - pub fn try_schedule( - &self, - sequence: &ActiveSequence, - watermark: f64, - tokens_budget: usize, - ) -> Option { - // Return None immediately if tokens_budget is 0 - if tokens_budget == 0 { - return None; - } - - // Get unique blocks from the sequence - let unique_blocks = sequence.unique_blocks(); - - // Get the count of new blocks - let new_blocks = self.probe_new_blocks(unique_blocks); - - // Calculate current usage and available capacity - let active_count = self.active_blocks.len(); - - // Check if we can schedule based on the watermark - if (active_count + new_blocks) as f64 > (1.0 - watermark) * self.max_capacity as f64 { - return None; - } - - // Calculate overlap blocks - let overlap_blocks = unique_blocks.len() - new_blocks; - - // Calculate new tokens - let new_tokens = sequence.num_input_tokens() - overlap_blocks * (self.block_size as usize); - - // // Print the full equation with actual values substituted - // println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - 
overlap_blocks * block_size)", - // new_tokens, - // sequence.num_input_tokens(), - // overlap_blocks, - // self.block_size); - - // Return None if new_tokens exceeds tokens_budget - if new_tokens > tokens_budget { - return None; - } + pub fn get_prefill_cost(&self, sequence: &ActiveSequence) -> PrefillCost { + let seq_blocks = sequence.unique_blocks(); + let new_blocks = self.probe_new_blocks(seq_blocks); + let overlap_blocks = seq_blocks.len() - new_blocks; + let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size; // Calculate prefill compute let prefill_compute = - new_tokens as f64 * (new_tokens + overlap_blocks * (self.block_size as usize)) as f64; + 1.25e-6 * (new_tokens as f64).powi(2) + 7.41e-2 * (new_tokens as f64) + 2.62e1; - Some(PrefillCost { + PrefillCost { + new_blocks, new_tokens, prefill_compute, - }) + } } } #[cfg(test)] mod tests { use super::*; + use tokio::sync::mpsc; #[test] fn test_failure_on_max_capacity() { @@ -282,7 +317,7 @@ mod tests { // Helper function to use multiple blocks that returns the response fn use_blocks(manager: &mut KvManager, ids: Vec) -> bool { let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); - manager.process(&MoveBlock::Use(blocks, None)) + manager.process(&MoveBlock::Use(blocks)) } // First use 10 blocks (0 to 9) in a batch @@ -301,15 +336,17 @@ mod tests { } #[test] - // This is taken directly from the example in the vllm v1 prefix caching docs fn test_block_lifecycle_stringent() { - // Create a KvManager with 10 blocks capacity - let mut manager = KvManager::new(10, 16); + // Create a channel to listen to block responses + let (tx, mut rx) = mpsc::unbounded_channel::(); + + // Create a KvManager with 10 blocks capacity and the response sender + let mut manager = KvManager::new_with_sender(10, 16, Some(tx)); // Helper function to use multiple blocks fn use_blocks(manager: &mut KvManager, ids: Vec) { let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect(); - manager.process(&MoveBlock::Use(blocks, None)); + manager.process(&MoveBlock::Use(blocks)); } // Helper function to destroy multiple blocks @@ -324,6 +361,56 @@ mod tests { manager.process(&MoveBlock::Deref(blocks)); } + // Helper function to assert block responses + fn assert_block_response( + rx: &mut mpsc::UnboundedReceiver, + expected_type: &str, + expected_blocks: Vec, + description: &str, + ) { + let response = rx + .try_recv() + .unwrap_or_else(|_| panic!("Expected {expected_type} response {description}")); + + match (&response, expected_type) { + (MoveBlockResponse::Store(blocks, _parent_hash), "Store") => { + assert_eq!( + blocks.len(), + expected_blocks.len(), + "Expected {} blocks in Store response {}", + expected_blocks.len(), + description + ); + assert_eq!( + *blocks, expected_blocks, + "Store blocks don't match expected {description}" + ); + } + (MoveBlockResponse::Remove(blocks), "Remove") => { + assert_eq!( + blocks.len(), + expected_blocks.len(), + "Expected {} blocks in Remove response {}", + expected_blocks.len(), + description + ); + assert_eq!( + *blocks, expected_blocks, + "Remove blocks don't match expected {description}" + ); + } + _ => panic!("Expected {expected_type} response, got {response:?} {description}"), + } + } + + // Helper function to assert no response is received + fn assert_no_response( + rx: &mut mpsc::UnboundedReceiver, + description: &str, + ) { + assert!(rx.try_recv().is_err(), "Expected no response {description}",); + } + // Helper function to check if active blocks contain expected blocks 
with expected ref counts fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) { assert_eq!( @@ -336,14 +423,12 @@ mod tests { let block = UniqueBlock::FullBlock(id); assert!( manager.active_blocks().contains_key(&block), - "Block {} not found in active blocks", - id + "Block {id} not found in active blocks", ); assert_eq!( manager.active_blocks().get(&block), Some(&ref_count), - "Block {} has wrong reference count", - id + "Block {id} has wrong reference count", ); } } @@ -366,17 +451,18 @@ mod tests { let block = UniqueBlock::FullBlock(id); assert!( inactive_blocks.iter().any(|&b| *b == block), - "Block {} not found in inactive blocks", - id + "Block {id} not found in inactive blocks", ); } } // First use blocks 0, 1, 2, 3, 4 in a batch use_blocks(&mut manager, (0..5).collect()); + assert_block_response(&mut rx, "Store", vec![0, 1, 2, 3, 4], "after first use"); // Then use blocks 0, 1, 5, 6 in a batch use_blocks(&mut manager, vec![0, 1, 5, 6]); + assert_block_response(&mut rx, "Store", vec![5, 6], "after second use"); // Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2 assert_active_blocks( @@ -386,9 +472,11 @@ mod tests { // Now destroy block 4 destroy_blocks(&mut manager, vec![4]); + assert_block_response(&mut rx, "Remove", vec![4], "after destroy block 4"); // And deref blocks 3, 2, 1, 0 in this order as a batch deref_blocks(&mut manager, vec![0, 1, 2, 3]); + assert_no_response(&mut rx, "after deref operation"); // Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2 assert_inactive_blocks(&manager, 2, &[3, 2]); @@ -396,6 +484,7 @@ mod tests { // Now destroy block 6 destroy_blocks(&mut manager, vec![6]); + assert_block_response(&mut rx, "Remove", vec![6], "after block 6 eviction"); // And deref blocks 5, 1, 0 as a batch deref_blocks(&mut manager, vec![0, 1, 5]); @@ -406,6 +495,7 @@ mod tests { // Now use 0, 1, 2, 7, 8, 9 as a batch use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]); + assert_block_response(&mut rx, "Store", vec![7, 8, 9], "after [7, 8, 9] use"); // Check that the inactive_blocks is size 2, and contains 3 and 5 assert_inactive_blocks(&manager, 2, &[3, 5]); @@ -420,8 +510,14 @@ mod tests { // Now use blocks 10, 11, 12 as a batch use_blocks(&mut manager, vec![10, 11, 12]); + assert_block_response(&mut rx, "Remove", vec![3], "after block 5 eviction"); + assert_block_response(&mut rx, "Store", vec![10, 11, 12], "after [10, 11, 12] use"); // Check that the inactive_blocks is size 1 and contains only 5 assert_inactive_blocks(&manager, 1, &[5]); + + use_blocks(&mut manager, vec![13]); + assert_block_response(&mut rx, "Remove", vec![5], "after block 5 eviction"); + assert_block_response(&mut rx, "Store", vec![13], "after block 13 use"); } } diff --git a/lib/llm/src/mocker/protocols.rs b/lib/llm/src/mocker/protocols.rs index 2b551db61b..880b97495c 100644 --- a/lib/llm/src/mocker/protocols.rs +++ b/lib/llm/src/mocker/protocols.rs @@ -13,12 +13,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use derive_builder::Builder; use serde::{Deserialize, Serialize}; use uuid::Uuid; +use crate::kv_router::protocols::{ + ExternalSequenceBlockHash, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, + KvCacheStoredBlockData, LocalBlockHash, +}; + pub type Token = u32; -pub type LocalBlockHash = u64; -/// A global hash identifier for blocks pub type GlobalHash = u64; pub type NumBlocks = usize; @@ -39,12 +43,19 @@ impl Default for UniqueBlock { } /// Represents different block movement operations in the cache +/// For the Promote variant, the last field is the optional parent hash #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum MoveBlock { - Use(Vec<UniqueBlock>, Option<GlobalHash>), + Use(Vec<UniqueBlock>), Destroy(Vec<UniqueBlock>), Deref(Vec<UniqueBlock>), - Promote(Uuid, GlobalHash), + Promote(Uuid, GlobalHash, Option<GlobalHash>), +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum MoveBlockResponse { + Store(Vec<GlobalHash>, Option<GlobalHash>), + Remove(Vec<GlobalHash>), } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -52,15 +63,86 @@ pub struct DirectRequest { pub tokens: Vec<Token>, pub max_output_tokens: usize, pub uuid: Option<Uuid>, + pub dp_rank: Option<u32>, } /// Represents the cost of prefilling content in the cache #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PrefillCost { + pub new_blocks: usize, pub new_tokens: usize, pub prefill_compute: f64, } +/// Signal for output token generation with completion status +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OutputSignal { + pub uuid: Uuid, + pub completed: bool, +} + +/// Configuration arguments for MockVllmEngine +#[derive(Debug, Clone, Serialize, Deserialize, Builder)] +#[builder(pattern = "owned", build_fn(public))] +pub struct MockEngineArgs { + #[builder(default = "16384")] + pub num_gpu_blocks: usize, + + #[builder(default = "64")] + pub block_size: usize, + + // This was 1024 in the past but reverted back to 256 + #[builder(default = Some(256))] + pub max_num_seqs: Option<usize>, + + // Default for the OpenAI API server; for the LLM class it's 16384 + #[builder(default = Some(8192))] + pub max_num_batched_tokens: Option<usize>, + + #[builder(default = true)] + pub enable_prefix_caching: bool, + + #[builder(default = "0.01")] + pub watermark: f64, + + #[builder(default = "1.0")] + pub speedup_ratio: f64, + + #[builder(default = "1")] + pub dp_size: u32, +} + +impl MockEngineArgs { + pub fn builder() -> MockEngineArgsBuilder { + MockEngineArgsBuilder::default() + } +} + +/// Note: This assumes block_hash and tokens_hash are the same, which is not correct in rare cases +/// where the sequence-aware hash differs from the token content hash. +pub fn block_response_to_kv_event(response: MoveBlockResponse) -> KvCacheEventData { + match response { + MoveBlockResponse::Store(full_blocks, parent_hash) => { + KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: parent_hash.map(ExternalSequenceBlockHash), + blocks: full_blocks + .into_iter() + .map(|block| KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(block), + tokens_hash: LocalBlockHash(block), + }) + .collect(), + }) + } + MoveBlockResponse::Remove(full_blocks) => KvCacheEventData::Removed(KvCacheRemoveData { + block_hashes: full_blocks + .into_iter() + .map(ExternalSequenceBlockHash) + .collect(), + }), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/lib/llm/src/mocker/scheduler.rs b/lib/llm/src/mocker/scheduler.rs index 604a6e1589..0223b04d5e 100644 --- a/lib/llm/src/mocker/scheduler.rs +++ b/lib/llm/src/mocker/scheduler.rs @@ -40,11 +40,13 @@ //! ## NOTE //!
The current prefill and decoding time simulations are not scientific at all and are WIP -use crate::kv_router::protocols::ForwardPassMetrics; +use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEventData}; use crate::mocker::evictor::LRUEvictor; use crate::mocker::kv_manager::KvManager; -use crate::mocker::protocols::DirectRequest; -use crate::mocker::protocols::{MoveBlock, PrefillCost, UniqueBlock}; +use crate::mocker::protocols::{ + block_response_to_kv_event, MoveBlock, OutputSignal, PrefillCost, UniqueBlock, +}; +use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse}; use crate::mocker::sequence::ActiveSequence; use std::collections::HashMap; use std::collections::VecDeque; @@ -63,8 +65,8 @@ pub enum Request { #[derive(Default)] struct SchedulerState { waiting: VecDeque<Uuid>, - ready: VecDeque<Uuid>, - running: LRUEvictor<Uuid>, + prefill: VecDeque<Uuid>, + decode: LRUEvictor<Uuid>, requests: HashMap<Uuid, Request>, prefill_costs: HashMap<Uuid, Option<PrefillCost>>, } @@ -74,61 +76,70 @@ impl SchedulerState { fn receive(&mut self, request: DirectRequest) -> Uuid { // Use the provided UUID if available, otherwise generate a new one let uuid = request.uuid.unwrap_or_else(Uuid::new_v4); - - // Add the request to the map and waiting queue self.requests.insert(uuid, Request::Direct(request)); self.waiting.push_back(uuid); uuid } /// Get the next UUID from ready or waiting queue and its associated Request. - /// Returns from ready if not empty, otherwise from waiting, or None if both are empty. - /// Also removes the Request from the requests HashMap. fn next(&mut self) -> Option<(Uuid, Request)> { - let uuid = self - .ready - .pop_front() - .or_else(|| self.waiting.pop_front())?; - let request = self.requests.remove(&uuid)?; + let uuid = self.waiting.pop_front()?; + let request = self + .requests + .remove(&uuid) + .expect("Request does not exist."); Some((uuid, request)) } + /// Move a UUID and its Request to the waiting queue (front). + fn first_in_line(&mut self, uuid: Uuid, request: Request) { + self.requests.insert(uuid, request); + self.waiting.push_front(uuid); + } + /// Move a UUID and its Request to the ready queue. - fn make_ready(&mut self, uuid: Uuid, active_seq: ActiveSequence) { + fn start_prefill(&mut self, uuid: Uuid, active_seq: ActiveSequence, cost: Option<PrefillCost>) { self.requests.insert(uuid, Request::Active(active_seq)); - self.ready.push_back(uuid); + self.prefill.push_back(uuid); + self.prefill_costs.insert(uuid, cost); } - /// Schedule the request with the given UUID. - /// Returns the creation signal from the ActiveSequence. - fn run(&mut self, uuid: Uuid, active_seq: ActiveSequence) -> MoveBlock { - // Insert the request into the map - self.requests.insert(uuid, Request::Active(active_seq)); + /// Pop from prefill queue and move to decode queue. + /// Returns the prefill_compute value if available.
+ fn start_decode(&mut self) -> Option<(f64, MoveBlock)> { + let uuid = self.prefill.pop_front()?; + self.decode.insert(uuid); + + // Remove and extract prefill_compute from prefill_costs + let prefill_cost = self + .prefill_costs + .remove(&uuid) + .flatten() + .expect("Expects valid prefill cost."); - // Get the creation signal let Some(Request::Active(sequence)) = self.requests.get(&uuid) else { - panic!("Failed to get ActiveSequence for UUID"); - }; - let Some(signal) = sequence.creation_signal() else { - panic!("Failed to get creation signal from ActiveSequence"); + panic!("Request does not exist."); }; + let creation_signal = sequence + .creation_signal() + .clone() + .expect("Must have creation signal."); - // Add to running requests - self.running.insert(uuid); - signal.clone() + Some((prefill_cost.prefill_compute, creation_signal)) } - /// Set the prefill cost for a UUID - fn set_prefill_cost(&mut self, uuid: Uuid, cost: Option) { - self.prefill_costs.insert(uuid, cost); + fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> { + if !self.decode.contains(&uuid) { + return None; + } + let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else { + panic!("Request does not exist."); + }; + Some(sequence) } - /// Get the prefill compute value for a UUID if available - fn get_prefill_compute(&self, uuid: &Uuid) -> Option { - self.prefill_costs - .get(uuid) - .and_then(|cost| cost.as_ref()) - .map(|cost| cost.prefill_compute) + fn num_active_requests(&self) -> usize { + self.prefill.len() + self.decode.len() } /// Calculate the current running batched tokens @@ -145,7 +156,7 @@ impl SchedulerState { /// Remove a UUID and its associated Request from collections. fn complete(&mut self, uuid: &Uuid) { // println!("Request {} will complete", uuid); - self.running.remove(uuid); + self.decode.remove(uuid); self.requests.remove(uuid); self.prefill_costs.remove(uuid); } @@ -153,76 +164,93 @@ impl SchedulerState { /// Preempt the oldest running request by evicting it from running, resetting the sequence, /// and adding it back to the waiting queue. /// Returns the signal from reset_with_signal or None if no requests are running. - fn preempt(&mut self) -> Option> { + fn preempt(&mut self) -> Vec { // Evict the oldest UUID from running - let uuid = self.running.evict()?; - eprintln!("Request {} will be preempted", uuid); - - // Remove the request from the requests HashMap and ensure it's an ActiveSequence - let request = self.requests.remove(&uuid)?; - - // Remove the prefill cost to force recomputation + let uuid = self + .decode + .evict() + .expect("Nothing to evict for preemption."); + let request = self + .requests + .remove(&uuid) + .expect("Request does not exist."); self.prefill_costs.remove(&uuid); + eprintln!("Request {uuid} will be preempted"); - // Extract the ActiveSequence from the Request enum + // Reset the sequence and get the new sequence and signal + // Insert the new sequence back into the requests map and add to waiting queue let Request::Active(mut active_sequence) = request else { panic!("Expected ActiveSequence in running queue") }; - - // Reset the sequence and get the new sequence and signal let signals = active_sequence.reset_with_signal(); - // Insert the new sequence back into the requests map and add to waiting queue - self.requests.insert(uuid, Request::Active(active_sequence)); - self.waiting.push_back(uuid); + // Note: For preemption, we don't compute hit rate since we don't have access to new_tokens + // and the sequence is being reset anyway. 
Hit rate tracking is primarily for new scheduling attempts. - Some(signals) + self.first_in_line(uuid, Request::Active(active_sequence)); + + signals } } /// Manages scheduling of requests using KvManager resources #[derive(Clone)] pub struct Scheduler { + dp_rank: Option<u32>, state: Arc<Mutex<SchedulerState>>, kv_manager: Arc<Mutex<KvManager>>, - request_tx: mpsc::Sender<DirectRequest>, + request_tx: mpsc::UnboundedSender<DirectRequest>, + hit_rates: Arc<Mutex<VecDeque<f32>>>, } impl Scheduler { /// Create a new Scheduler with the given parameters pub fn new( - kv_capacity: usize, - watermark: f64, - block_size: u32, - chunk_size: Option<usize>, - output_tx: Option<mpsc::Sender<Uuid>>, + args: MockEngineArgs, + dp_rank: Option<u32>, + output_tx: Option<mpsc::UnboundedSender<OutputSignal>>, + kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>, cancellation_token: Option<CancellationToken>, ) -> Self { - // Create KvManager internally - let kv_manager = KvManager::new(kv_capacity, block_size); - - let token_capacity: usize = 8192; let state = Arc::new(Mutex::new(SchedulerState::default())); - let kv_manager = Arc::new(Mutex::new(kv_manager)); - let chunk_size = chunk_size.unwrap_or(256); + // Create internal channel for KV events only if needed + let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() { + let (tx, rx) = mpsc::unbounded_channel::<MoveBlockResponse>(); + (Some(tx), Some(rx)) + } else { + (None, None) + }; - // Create channel for request handling - let (request_tx, mut request_rx) = mpsc::channel::<DirectRequest>(1024); + let kv_manager = Arc::new(Mutex::new(KvManager::new_with_sender( + args.num_gpu_blocks, + args.block_size, + block_resp_tx, + ))); + let hit_rates = Arc::new(Mutex::new(VecDeque::with_capacity(1000))); + + // Assert speedup_ratio is greater than 0 + assert!( + args.speedup_ratio > 0.0, + "speedup_ratio must be greater than 0, got: {}", + args.speedup_ratio + ); - // Use provided cancellation token or create new one - let cancellation_token = cancellation_token.unwrap_or_default(); - let token_clone = cancellation_token.clone(); + // Create channel for request handling + let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>(); // Create a clone for the background task let state_clone = state.clone(); let kv_manager_clone = kv_manager.clone(); let output_tx_clone = output_tx.clone(); + let cancel_token_clone = cancellation_token.unwrap_or_default().clone(); + let hit_rates_clone = hit_rates.clone(); // Spawn main background task with cancellation token tokio::spawn(async move { - let mut schedule_interval = interval(Duration::from_millis(5)); - let mut simulate_interval = interval(Duration::from_millis(1)); + let mut schedule_interval = interval(Duration::from_secs_f64(1e-3)); + let mut simulate_interval = interval(Duration::from_secs_f64(1e-4)); + let mut should_schedule = true; loop { tokio::select! { @@ -234,35 +262,63 @@ impl Scheduler { state.receive(request); } - // Try Scheduling Requests + // Try Scheduling Requests - runs on normal interval or after simulation _ = schedule_interval.tick() => { + // Skip if we just ran scheduling after simulation to prevent consecutive runs + if !should_schedule { + continue; + } + let mut state_guard = state_clone.lock().await; - let mut kv_manager_guard = kv_manager_clone.lock().await; + let kv_manager_guard = kv_manager_clone.lock().await; // Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't // schedule anymore.
+                        let mut current_blocks = kv_manager_guard.num_active_blocks();
+                        let mut current_tokens = state_guard.num_batched_tokens();
+                        let mut current_seqs = state_guard.num_active_requests();
+
                        while let Some((uuid, request)) = state_guard.next() {
-                            let active_sequence = get_active_sequence(request, block_size, chunk_size);
+                            let active_sequence = get_active_sequence(request, args.block_size, args.enable_prefix_caching);

-                            // Calculate token budget using new_tokens from PrefillCost
-                            let total_prefill_tokens = state_guard.num_batched_tokens();
-                            let tokens_budget = token_capacity.saturating_sub(total_prefill_tokens);
+                            // Update predictive budgets
+                            let prefill_cost = kv_manager_guard.get_prefill_cost(&active_sequence);
+                            let total_tokens = active_sequence.len();
+                            let new_blocks = (total_tokens + 1) / args.block_size; // this is conservative, assumes no cache hit
+                            let new_tokens = prefill_cost.new_tokens;
+
+                            current_blocks += new_blocks;
+                            current_tokens += new_tokens;
+                            current_seqs += 1;

                            // Check if it can be scheduled
-                            let Some(prefill_cost) = kv_manager_guard.try_schedule(&active_sequence, watermark, tokens_budget) else {
-                                state_guard.make_ready(uuid, active_sequence);
+                            let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager_guard.max_capacity() as f64;
+                            let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| current_tokens <= limit);
+                            let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
+
+                            // Cannot schedule, put first in line instead
+                            if !(under_block_budget && under_token_budget && under_seq_budget) {
+                                state_guard.first_in_line(uuid, Request::Active(active_sequence));
                                break;
-                            };
+                            }
+
+                            // Compute and store hit rate
+                            let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 };
+                            {
+                                let mut hit_rates_guard = hit_rates_clone.lock().await;
+                                hit_rates_guard.push_back(hit_rate);
+                                if hit_rates_guard.len() > 1000 {
+                                    hit_rates_guard.pop_front();
+                                }
+                            }

-                            // Get creation signal and schedule the request
-                            let signal = state_guard.run(uuid, active_sequence);
-                            kv_manager_guard.process(&signal);
-                            state_guard.set_prefill_cost(uuid, Some(prefill_cost));
+                            state_guard.start_prefill(uuid, active_sequence, Some(prefill_cost));
+                            should_schedule = false;
                        }
                    }

                    // Check for cancellation
-                    _ = token_clone.cancelled() => {
+                    _ = cancel_token_clone.cancelled() => {
                        break;
                    }

@@ -271,75 +327,84 @@ impl Scheduler {
                        let mut state_guard = state_clone.lock().await;
                        let mut kv_manager_guard = kv_manager_clone.lock().await;

-                        // Base time needed for decoding (assumed memory bound on KV cache)
-                        let active_tokens = kv_manager_guard.num_active_blocks() * (block_size as usize);
-                        // TODO: 2 is a dummy / magic scaling factor
-                        let mut generation_time = Duration::from_micros((active_tokens / 2) as u64);
-
-                        // Process each running request
-                        let uuids: Vec<Uuid> = state_guard.running.keys().cloned().collect();
-                        for uuid in uuids {
-                            // Check if UUID is still in running_requests, if not skip this iteration
-                            if !state_guard.running.contains(&uuid) {
-                                continue;
+                        // Base time needed for decoding using active percentage and quadratic formula
+                        let active_perc = kv_manager_guard.get_active_perc();
+                        let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
+                        let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
+
+                        // Process prefilling
+                        while let Some((prefill_compute, creation_signal)) = state_guard.start_decode() {
+                            // NOTE: Prefill cost/time is always incremented for new blocks, even if they
+                            // could be cached by other requests in the same batch. This matches vLLM behavior.
+                            total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
+                            let prefill_success = process_signals(&mut kv_manager_guard, std::slice::from_ref(&creation_signal));
+                            if !prefill_success {
+                                panic!("Block allocation for prefilling cannot fail.");
                            }

-                            // Get prefill compute value first
-                            let prefill_compute = state_guard.get_prefill_compute(&uuid);
-
-                            // Get the active sequence for this UUID
-                            let sequence = state_guard.requests.get_mut(&uuid)
-                                .and_then(|req| if let Request::Active(seq) = req { Some(seq) } else { None })
-                                .expect("UUID in running_requests must have a corresponding active sequence");
+                            // Drain KV events and forward to relay after prefill signal processing
+                            if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
+                                while let Ok(event) = rx.try_recv() {
+                                    let _ = relay_tx.send(block_response_to_kv_event(event));
+                                }
+                            }
+                        }

-                            // Generate token and get signals
+                        // Process decoding
+                        let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
+                        if !uuids.is_empty() { should_schedule = true; }
+                        for uuid in uuids {
+                            let Some(sequence) = state_guard.run(uuid) else {
+                                continue;
+                            };
                            let signals = sequence.generate();

-                            // Accumulate sleep duration based on prefill_compute if available
-                            // prefill compute = (cached_tokens + new_tokens) * new_tokens
-                            let sleep_ms = if let Some(compute) = prefill_compute {
-                                // TODO: 1024 is a dummy / magic scaling factor
-                                (compute / 1024.0) as u64
-                            } else { 0 };
-                            generation_time += Duration::from_micros(sleep_ms);
-
                            // Process all signals with the KvManager
                            // Handling of preemption on failure
                            if !process_signals(&mut kv_manager_guard, &signals) {
                                sequence.pop(); // revert the failed generation op
-
-                                // free_signal derefs the preempted blocks
-                                let Some(free_signal) = state_guard.preempt() else {
-                                    panic!("Failed to acquire signal to free KV blocks from preemption");
-                                };
-
-                                for signal in free_signal {
+                                for signal in state_guard.preempt() {
                                    kv_manager_guard.process(&signal);
                                }
                                continue;
                            }

-                            // Send UUID notification for each generated token
-                            // TODO: hook this up to an AsyncEngine
-                            if let Some(tx) = &output_tx_clone {
-                                let _ = tx.try_send(uuid);
+                            // Drain KV events and forward to relay after decode signal processing
+                            if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
+                                while let Ok(event) = rx.try_recv() {
+                                    let _ = relay_tx.send(block_response_to_kv_event(event));
+                                }
                            }

-                            // Check if we're done after generating
-                            if sequence.generated_tokens() >= sequence.max_output_tokens() {
-                                state_guard.complete(&uuid);
-                                continue;
+                            // Check completion and send notification
+                            let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
+                            let should_output = sequence.generated_tokens() > sequence.already_generated_tokens();
+
+                            let mut send_failed = false;
+                            if should_output {
+                                send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
+                                    tx.send(OutputSignal { uuid, completed: is_complete }).is_err()
+                                });
                            }

-                            // Transition to decode (no prefill cost)
-                            if sequence.generated_tokens() == 1 {
-                                state_guard.set_prefill_cost(uuid, None);
+                            if send_failed {
+                                for signal in &sequence.free_signal() {
+                                    kv_manager_guard.process(signal);
+                                }
+                            }
+
+                            if send_failed || is_complete {
+                                state_guard.complete(&uuid);
+                                continue;
                            }
                        }

-                        // Sleep once for the accumulated duration
-                        if generation_time.as_millis() > 0 {
-                            tokio::time::sleep(generation_time).await;
+                        // Sleep once for the adjusted duration
+                        drop(kv_manager_guard);
+                        drop(state_guard);
+                        let adjusted_time = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
+                        if adjusted_time.as_millis() > 0 {
+                            tokio::time::sleep(adjusted_time).await;
                        }
                    }
                }
@@ -347,15 +412,22 @@
        });

        Self {
+            dp_rank,
            state,
            kv_manager,
            request_tx,
+            hit_rates,
        }
    }

    /// Add a new request to the waiting queue
    pub async fn receive(&self, request: DirectRequest) {
-        let _ = self.request_tx.send(request).await;
+        let _ = self.request_tx.send(request);
+    }
+
+    /// Expose the sender
+    pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
+        self.request_tx.clone()
    }

    /// Get the count of waiting requests
@@ -367,7 +439,7 @@
    /// Get the count of running requests
    pub async fn running_count(&self) -> usize {
        let state = self.state.lock().await;
-        state.running.len()
+        state.decode.len()
    }

    /// Get the current capacity of the KvManager
@@ -378,35 +450,53 @@
    /// Returns forward pass metrics for monitoring purposes
    pub async fn get_forward_pass_metrics(&self) -> ForwardPassMetrics {
+        // Acquire all locks in consistent order: state -> kv_manager -> hit_rates
        let state = self.state.lock().await;
        let kv_manager = self.kv_manager.lock().await;
+        let hit_rates_guard = self.hit_rates.lock().await;
+
+        // Get state metrics
+        let request_active_slots = state.decode.len() as u64;
+        let num_requests_waiting = state.waiting.len() as u64;

-        // Get the active blocks and total capacity from KvManager
+        // Get KV manager metrics
        let active_blocks_count = kv_manager.active_blocks().len() as u64;
        let total_capacity = kv_manager.max_capacity() as u64;
-
-        // Calculate GPU cache usage percentage
        let gpu_cache_usage_perc = if total_capacity > 0 {
            active_blocks_count as f32 / total_capacity as f32
        } else {
            0.0
        };

+        // Get hit rate metrics
+        let gpu_prefix_cache_hit_rate = if hit_rates_guard.is_empty() {
+            0.0
+        } else {
+            let sum: f32 = hit_rates_guard.iter().sum();
+            sum / hit_rates_guard.len() as f32
+        };
+
        ForwardPassMetrics {
-            data_parallel_rank: None, // Default for backwards compatibility
-            request_active_slots: state.running.len() as u64,
-            request_total_slots: 420, // Dummy value as specified
+            data_parallel_rank: self.dp_rank,
+            request_active_slots,
+            // vLLM default max_num_seqs: 1024 for GPUs with >= 70 GB VRAM, otherwise 256; the fallback is 128
+            request_total_slots: 1024,
            kv_active_blocks: active_blocks_count,
            kv_total_blocks: total_capacity,
-            num_requests_waiting: state.waiting.len() as u64,
+            num_requests_waiting,
            gpu_cache_usage_perc,
-            gpu_prefix_cache_hit_rate: 0.0, // Placeholder value as specified
+            gpu_prefix_cache_hit_rate,
        }
+        // Guards drop naturally here in reverse order (LIFO): hit_rates_guard, kv_manager, state
    }
 }

 /// Convert a Request to an ActiveSequence
-fn get_active_sequence(request: Request, block_size: u32, chunk_size: usize) -> ActiveSequence {
+fn get_active_sequence(
+    request: Request,
+    block_size: usize,
+    enable_prefix_caching: bool,
+) -> ActiveSequence {
    if let Request::Active(active_seq) = request {
        return active_seq;
    }
@@ -419,7 +509,7 @@ fn get_active_sequence(request: Request, block_size: u32, chunk_size: usize) ->
        direct_request.tokens,
        direct_request.max_output_tokens,
        Some(block_size),
-        Some(chunk_size),
+        enable_prefix_caching,
    )
 }

@@ -440,7 +530,7 @@
    }

    // Check we have a Use signal with blocks
-    let MoveBlock::Use(blocks, _) = signal else {
+    let MoveBlock::Use(blocks) = signal else {
panic!("Failed signal is Invalid. Has to fail on generation signal."); }; @@ -467,32 +557,37 @@ mod tests { use std::time::Duration; #[rstest] - #[case::random(false)] - #[case::caching(true)] + #[case::random_no_prefix_caching(false, false)] + #[case::random_with_prefix_caching(false, true)] + #[case::caching_no_prefix_caching(true, false)] + #[case::caching_with_prefix_caching(true, true)] #[tokio::test] - async fn test_scheduler_token_generation_patterns(#[case] use_shared_tokens: bool) { + async fn test_scheduler_token_generation_patterns( + #[case] use_shared_tokens: bool, + #[case] enable_prefix_caching: bool, + ) { std::env::set_var("RUST_LOG", "debug"); let kv_capacity: usize = 500; - let watermark: f64 = 0.01; // 1% watermark - let block_size: u32 = 64; - let chunk_size: usize = 256; + let block_size: usize = 64; let num_requests: usize = 100; let input_len: usize = 1000; let max_output_tokens: usize = 100; // Create channel for token output - let (output_tx, mut output_rx) = mpsc::channel::(1024); - - // Create scheduler with internal KvManager - let scheduler = Scheduler::new( - kv_capacity, - watermark, - block_size, - Some(chunk_size), - Some(output_tx), - None, - ); + let (output_tx, mut output_rx) = mpsc::unbounded_channel::(); + + // Create scheduler args using builder - now including enable_prefix_caching + let args = MockEngineArgs::builder() + .num_gpu_blocks(kv_capacity) + .block_size(block_size) + .speedup_ratio(10.0) + .enable_prefix_caching(enable_prefix_caching) + .build() + .unwrap(); + + // Create scheduler with new args struct + let scheduler = Scheduler::new(args, None, Some(output_tx), None, None); // Create shared tokens for caching case let shared_tokens = if use_shared_tokens { @@ -523,6 +618,7 @@ mod tests { tokens: input_tokens, max_output_tokens, uuid: None, + dp_rank: None, }; scheduler.receive(request).await; } @@ -547,7 +643,7 @@ mod tests { // Manual debug ticker that prints forward pass metrics _ = debug_interval.tick() => { let _metrics = scheduler.get_forward_pass_metrics().await; - // println!("Forward Pass Metrics: {:#?}", _metrics); + println!("Forward Pass Metrics: {_metrics:#?}"); } Some(_) = output_rx.recv() => { @@ -566,21 +662,177 @@ mod tests { // Calculate and print elapsed time let elapsed = start_time.elapsed(); println!( - "Test completed in: {:?} for {} case", + "Test completed in: {:?} for {} case with prefix_caching={}", elapsed, if use_shared_tokens { "caching" } else { "random" - } + }, + enable_prefix_caching ); // Assert that we received the expected number of tokens assert!( - received_tokens > expected_tokens, - "Received {} tokens but expected more than {}", - received_tokens, - expected_tokens + received_tokens == expected_tokens, + "Received {received_tokens} tokens but expected exactly {expected_tokens}" + ); + } + + #[tokio::test] + async fn test_cache_hit_rate_with_identical_requests() { + let block_size: usize = 64; + let max_output_tokens: usize = 10; + let speedup_ratio = 10.0; + let num_requests = 10; + let token_length = 65; + + // Create channel for token output + let (output_tx, mut output_rx) = mpsc::unbounded_channel::(); + + // Create scheduler args + let args = MockEngineArgs::builder() + .num_gpu_blocks(100) // Large enough to not be a constraint + .block_size(block_size) + .speedup_ratio(speedup_ratio) + .build() + .unwrap(); + + // Create scheduler + let scheduler = Scheduler::new(args, None, Some(output_tx), None, None); + + // Create identical tokens for all requests + let identical_tokens: Vec = 
(0..token_length).map(|i| i as u32).collect(); + + // Send all requests with identical tokens + for _ in 0..num_requests { + let request = DirectRequest { + tokens: identical_tokens.clone(), + max_output_tokens, + uuid: None, + dp_rank: None, + }; + scheduler.receive(request).await; + // Sleep for 0.1 second after each request + tokio::time::sleep(Duration::from_millis(100)).await; + } + + // Collect all generated tokens + let mut received_tokens = 0; + + // Set up a timeout that resets to 0.5 seconds on each received token + let timeout = tokio::time::sleep(Duration::from_millis(500)); + tokio::pin!(timeout); + + // Set up debug ticker interval + let mut debug_interval = interval(Duration::from_millis(500)); + + loop { + tokio::select! { + biased; + + // Manual debug ticker that prints forward pass metrics + _ = debug_interval.tick() => { + let _metrics = scheduler.get_forward_pass_metrics().await; + println!("Forward Pass Metrics: {_metrics:#?}"); + } + + Some(_signal) = output_rx.recv() => { + received_tokens += 1; + // Reset timeout whenever we receive a token + timeout.set(tokio::time::sleep(Duration::from_millis(500))); + } + + _ = &mut timeout => { + // Break when timeout occurs (no more tokens for 0.5 seconds) + break; + } + } + } + + // Verify forward pass metrics + let metrics = scheduler.get_forward_pass_metrics().await; + + assert_eq!( + metrics.num_requests_waiting, 0, + "Expected no waiting requests, got {}", + metrics.num_requests_waiting + ); + + assert!( + metrics.gpu_prefix_cache_hit_rate > 0.8, + "Expected cache hit rate > 0.8, got {}", + metrics.gpu_prefix_cache_hit_rate + ); + + println!( + "Test passed! Cache hit rate: {:.3}", + metrics.gpu_prefix_cache_hit_rate + ); + println!("Received {received_tokens} tokens"); + } + + #[tokio::test] + async fn test_receiver_drop_cleans_up_resources() { + let block_size: usize = 64; + let input_tokens = 256; + let max_output_tokens = 200; // More than we'll receive + + // Create channel for token output + let (output_tx, mut output_rx) = mpsc::unbounded_channel::(); + + // Create scheduler args + let args = MockEngineArgs::builder() + .num_gpu_blocks(10) // Enough for 256 tokens (4 blocks) + .block_size(block_size) + .speedup_ratio(100.0) // Fast simulation + .build() + .unwrap(); + + // Create scheduler + let scheduler = Scheduler::new(args, None, Some(output_tx), None, None); + + // Create request with 256 tokens + let tokens: Vec = (0..input_tokens).map(|i| i as u32).collect(); + let request = DirectRequest { + tokens, + max_output_tokens, + uuid: None, + dp_rank: None, + }; + + scheduler.receive(request).await; + + // Receive exactly 129 tokens + let mut received_count = 0; + while received_count < 129 { + if let Some(_signal) = output_rx.recv().await { + received_count += 1; + } else { + panic!("Channel closed before receiving 129 tokens"); + } + } + + // Drop the receiver immediately + drop(output_rx); + + // Wait for 1 second to allow cleanup + tokio::time::sleep(Duration::from_secs(1)).await; + + // Check forward pass metrics + let metrics = scheduler.get_forward_pass_metrics().await; + + assert_eq!( + metrics.gpu_cache_usage_perc, + 0.0, + "Expected GPU cache usage to be 0%, got {}%", + metrics.gpu_cache_usage_perc * 100.0 + ); + + assert_eq!( + metrics.kv_active_blocks, 0, + "Expected 0 active blocks, got {}", + metrics.kv_active_blocks ); } } diff --git a/lib/llm/src/mocker/sequence.rs b/lib/llm/src/mocker/sequence.rs index e8900fae2c..2145d8e561 100644 --- a/lib/llm/src/mocker/sequence.rs +++ 
+++ b/lib/llm/src/mocker/sequence.rs
@@ -23,16 +23,23 @@ use uuid;
 fn create_unique_blocks_from_sequence(
     tokens: &TokenBlockSequence,
     uuid: Option<Uuid>,
-    block_size: u32,
+    block_size: usize,
+    enable_prefix_caching: bool,
 ) -> Vec<UniqueBlock> {
     let mut unique_blocks: Vec<UniqueBlock> = tokens
         .blocks()
         .iter()
-        .map(|block| UniqueBlock::FullBlock(block.sequence_hash()))
+        .map(|block| {
+            if enable_prefix_caching {
+                UniqueBlock::FullBlock(block.sequence_hash())
+            } else {
+                UniqueBlock::FullBlock(random::<u64>())
+            }
+        })
         .collect();

     // Only push the partial block if tokens count isn't a multiple of block_size
-    if tokens.total_tokens() % (block_size as usize) != 0 {
+    if tokens.total_tokens() % block_size != 0 {
         unique_blocks.push(match uuid {
             Some(uuid) => UniqueBlock::PartialBlock(uuid),
             None => UniqueBlock::default(),
@@ -50,10 +57,7 @@ pub struct ActiveSequence {
     tokens: TokenBlockSequence,

     #[getter(copy)]
-    block_size: u32,
-
-    #[getter(copy)]
-    chunk_size: usize, // TODO: not actually used
+    block_size: usize,

     #[getter(copy)]
     max_output_tokens: usize,
@@ -61,10 +65,16 @@
     #[getter(copy)]
     generated_tokens: usize,

+    #[getter(copy)]
+    already_generated_tokens: usize,
+
     #[getter(copy)]
     num_input_tokens: usize,

     creation_signal: Option<MoveBlock>,
+
+    #[getter(copy)]
+    enable_prefix_caching: bool,
 }

 impl ActiveSequence {
@@ -72,32 +82,33 @@
     pub fn new(
         tokens: Vec<u32>,
         max_output_tokens: usize,
-        block_size: Option<u32>,
-        chunk_size: Option<usize>,
+        block_size: Option<usize>,
+        enable_prefix_caching: bool,
     ) -> Self {
         let block_size = block_size.unwrap_or(64);
         assert!(block_size > 1, "block_size must be greater than 1");
-        let chunk_size = chunk_size.unwrap_or(256);
         let num_input_tokens = tokens.len();
-        let tokens = Tokens::from(tokens).into_sequence(block_size, None);
-        let unique_blocks = create_unique_blocks_from_sequence(&tokens, None, block_size);
-        let creation_signal = Some(MoveBlock::Use(unique_blocks.clone(), None));
+        let tokens = Tokens::from(tokens).into_sequence(block_size as u32, None);
+        let unique_blocks =
+            create_unique_blocks_from_sequence(&tokens, None, block_size, enable_prefix_caching);
+        let creation_signal = Some(MoveBlock::Use(unique_blocks.clone()));

         Self {
             unique_blocks,
             tokens,
             block_size,
-            chunk_size,
             max_output_tokens,
             generated_tokens: 0,
+            already_generated_tokens: 0,
             num_input_tokens,
             creation_signal,
+            enable_prefix_caching,
         }
     }

     pub fn extra_tokens(&self) -> u32 {
-        (self.len() % self.block_size as usize) as u32
+        (self.len() % self.block_size) as u32
     }

     pub fn len(&self) -> usize {
@@ -112,20 +123,31 @@
     pub fn new_with_signal(
         tokens: Vec<u32>,
         max_output_tokens: usize,
-        block_size: Option<u32>,
-        chunk_size: Option<usize>,
+        block_size: Option<usize>,
+        enable_prefix_caching: bool,
     ) -> (Self, Option<MoveBlock>) {
-        let mut sequence = Self::new(tokens, max_output_tokens, block_size, chunk_size);
+        let mut sequence = Self::new(tokens, max_output_tokens, block_size, enable_prefix_caching);
         let signal = sequence.creation_signal.take();
         (sequence, signal)
     }

+    /// Get the parent hash from the second-to-last block if it exists and is a FullBlock
+    fn get_parent_hash(&self) -> Option<u64> {
+        if self.unique_blocks.len() < 2 {
+            return None;
+        }
+        match &self.unique_blocks[self.unique_blocks.len() - 2] {
+            UniqueBlock::FullBlock(hash) => Some(*hash),
+            _ => panic!("Cannot have a partial block as parent"),
+        }
+    }
+
     /// Push a token to the sequence
     pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
         self.tokens.append(token).expect("Token push failed.");
         self.generated_tokens += 1;
-        if self.len() % (self.block_size as usize) != 1 {
+        if self.len() % self.block_size != 1 {
             return None;
         }

@@ -135,16 +157,24 @@
         // Replace last partial block with full block if it exists
         if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() {
-            let last_block_hash = self.tokens.last_complete_block().unwrap().sequence_hash();
+            let last_block_hash = if self.enable_prefix_caching {
+                self.tokens.last_complete_block().unwrap().sequence_hash()
+            } else {
+                random::<u64>()
+            };
             self.unique_blocks.pop();
             self.unique_blocks
                 .push(UniqueBlock::FullBlock(last_block_hash));
-            signals.push(MoveBlock::Promote(uuid, last_block_hash));
+            signals.push(MoveBlock::Promote(
+                uuid,
+                last_block_hash,
+                self.get_parent_hash(),
+            ));
         }

         let new_partial_block = UniqueBlock::default();
         self.unique_blocks.push(new_partial_block.clone());
-        signals.push(MoveBlock::Use(vec![new_partial_block], None));
+        signals.push(MoveBlock::Use(vec![new_partial_block]));

         Some(signals)
     }
@@ -204,15 +234,19 @@
     }

     /// Reset the sequence to its initial state and return the free signals from freeing current blocks
-    /// maintaining the uuid of the last partial block
     pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
         let free_signal = self.free_signal();
         self.tokens.truncate(self.num_input_tokens).unwrap();
-        self.unique_blocks =
-            create_unique_blocks_from_sequence(&self.tokens, None, self.block_size);
+        self.unique_blocks = create_unique_blocks_from_sequence(
+            &self.tokens,
+            None,
+            self.block_size,
+            self.enable_prefix_caching,
+        );
+        self.already_generated_tokens = self.generated_tokens.max(self.already_generated_tokens);
         self.generated_tokens = 0;
-        self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone(), None));
+        self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone()));
         free_signal
     }
@@ -223,7 +257,7 @@
         self.generated_tokens = self.generated_tokens.saturating_sub(1);

         // Reverts to the last full block
-        if self.tokens.total_tokens() % (self.block_size as usize) == 0 {
+        if self.tokens.total_tokens() % self.block_size == 0 {
             self.unique_blocks.pop();
         }
     }
@@ -238,14 +272,14 @@ mod tests {
         // Create a sequence with block size 16 initialized with tokens [0..15]
         let initial_tokens: Vec<u32> = (0..15).collect();
         let (mut seq1, signal1) =
-            ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), Some(256));
+            ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
         assert_eq!(seq1.num_input_tokens(), 15);
         assert_eq!(seq1.len(), 15);

         // Check that we got a Use signal
         assert!(signal1.is_some());
         match &signal1 {
-            Some(MoveBlock::Use(blocks, _)) => {
+            Some(MoveBlock::Use(blocks)) => {
                 assert_eq!(blocks.len(), 1);
             }
             _ => panic!("Expected Use signal"),
@@ -264,33 +298,31 @@
         let signal_16 = signal_16.unwrap();
         assert_eq!(signal_16.len(), 2);

+        // First signal should be Promote for the previous block
+        match &signal_16[0] {
+            MoveBlock::Promote(_, _, parent_hash) => {
+                assert_eq!(*parent_hash, None);
+            }
+            _ => panic!("Expected Promote signal as first signal"),
+        }
+
+        // Second signal should be Use for new partial block
         match &signal_16[1] {
-            MoveBlock::Use(blocks, _) => {
+            MoveBlock::Use(blocks) => {
                 assert_eq!(blocks.len(), 1);
                 assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
             }
             _ => panic!("Expected Use signal as first signal"),
         }

-        // First signal should be Promote for the previous block
-        match &signal_16[0] {
-            MoveBlock::Promote(uuid, _) => {
-                // The uuid is generated dynamically, so we just check it exists
-                let _ = uuid;
-            }
-            _ => panic!("Expected Promote signal as second signal"),
-        }
-
         // Verify state after pushing tokens
         assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block
         assert_eq!(seq1.len(), 17);
-        assert_eq!(seq1.len() % (seq1.block_size() as usize), 1);
+        assert_eq!(seq1.len() % seq1.block_size(), 1);

         // Create another sequence with block size 16 initialized with tokens [0..17]
         let extended_tokens: Vec<u32> = (0..16).collect();
-        let (mut seq2, _) =
-            ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), Some(256));
+        let (mut seq2, _) = ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), true);
         seq2.push(16);
         seq2.pop();
         seq2.push(16);
@@ -335,12 +367,12 @@
             "seq2 should have exactly 3 blocks"
         );
         assert_eq!(
-            seq1.len() % (seq1.block_size() as usize),
+            seq1.len() % seq1.block_size(),
             1,
             "seq1 should have 1 partial token"
         );
         assert_eq!(
-            seq2.len() % (seq2.block_size() as usize),
+            seq2.len() % seq2.block_size(),
             1,
             "seq2 should have 1 partial token"
         );
@@ -352,9 +384,38 @@
             "First two blocks should be identical"
         );

+        // Push tokens 34..47 to seq1
+        for token in 33..48 {
+            seq1.push(token);
+        }
+
+        // Push token 48 and get the signal - this completes the block and triggers signals
+        let signal = seq1.push(48);
+        let signal = signal.unwrap();
+
+        // Check that signal[0] is promote
+        match &signal[0] {
+            MoveBlock::Promote(_, _, parent_hash) => {
+                // Check that the parent_hash matches unique_blocks[1], which should be a full block
+                if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] {
+                    assert_eq!(
+                        *parent_hash,
+                        Some(expected_hash),
+                        "Parent hash should match unique_blocks[1]"
+                    );
+                } else {
+                    panic!("unique_blocks[1] should be a full block");
+                }
+            }
+            _ => panic!("Expected Promote signal as first signal"),
+        }
+
         // Reset seq1 and check that it equals the original clone
         let free_signals = seq1.reset_with_signal();

+        // 49 - 15 generated tokens
+        assert_eq!(seq1.already_generated_tokens, 34);
+
         // Verify the reset signals include proper cleanup events
         assert!(!free_signals.is_empty());
     }
@@ -363,13 +424,12 @@
     fn test_active_sequence_generate_signals() {
         // Create a sequence with block size 16, max_output_tokens 4, initialized with tokens [0..14)
         let initial_tokens: Vec<u32> = (0..14).collect();
-        let (mut seq, signal) =
-            ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), Some(256));
+        let (mut seq, signal) = ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), true);

         // Initial signal - should have received a Use signal for the partial block
         assert!(signal.is_some());
         match signal {
-            Some(MoveBlock::Use(blocks, _)) => {
+            Some(MoveBlock::Use(blocks)) => {
                 assert_eq!(blocks.len(), 1);
                 assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
             }
@@ -385,25 +445,23 @@
         let signals_second = seq.generate();
         assert_eq!(signals_second.len(), 2);

-        // First signal should be Use for new partial block
+        // First signal should be Promote
+        match &signals_second[0] {
+            MoveBlock::Promote(_, _, parent_hash) => {
+                assert_eq!(*parent_hash, None);
+            }
+            _ => panic!("Expected Promote signal as first signal after second token"),
+        }
+
+        // Second signal should be Use for new partial block
         match &signals_second[1] {
-            MoveBlock::Use(blocks, _) => {
+            MoveBlock::Use(blocks) => {
                 assert_eq!(blocks.len(), 1);
                 assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
             }
             _ => panic!("Expected Use signal as second signal after second token"),
         }

-        // Second signal should be Promote
-        match &signals_second[0] {
-            MoveBlock::Promote(uuid, hash) => {
-                // The uuid and hash values are generated dynamically, so we just check the event type
-                let _ = uuid;
-                let _ = hash;
-            }
-            _ => panic!("Expected Promote signal as first signal after second token"),
-        }
-
         // Generate fourth token - should not trigger new signals as it's adding to partial block
         let signals_third = seq.generate();
         assert_eq!(signals_third.len(), 0);