diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock
index 222af9b2a0..b5e9bed027 100644
--- a/lib/bindings/python/Cargo.lock
+++ b/lib/bindings/python/Cargo.lock
@@ -1243,6 +1243,7 @@ dependencies = [
  "async-openai",
  "async-stream",
  "async-trait",
+ "cudarc",
  "derive-getters",
  "dlpark",
  "dynamo-llm",
diff --git a/lib/bindings/python/Cargo.toml b/lib/bindings/python/Cargo.toml
index 7fb8c83539..0b43583898 100644
--- a/lib/bindings/python/Cargo.toml
+++ b/lib/bindings/python/Cargo.toml
@@ -35,7 +35,7 @@ crate-type = ["cdylib", "rlib"]

 [features]
 default = ["block-manager"]
-block-manager = ["dynamo-llm/block-manager", "dep:dlpark"]
+block-manager = ["dynamo-llm/block-manager", "dep:dlpark", "dep:cudarc"]

 [dependencies]
 dynamo-llm = { path = "../../llm" }
@@ -79,6 +79,8 @@ pyo3-async-runtimes = { version = "0.23.0", default-features = false, features =
 pythonize = "0.23"
 dlpark = { version = "0.5", features = ["pyo3", "half"], optional = true }
+cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
+
 [dev-dependencies]
 rstest = "0.25"
diff --git a/lib/bindings/python/rust/llm/block_manager.rs b/lib/bindings/python/rust/llm/block_manager.rs
index 648dd07701..0dfcfacea2 100644
--- a/lib/bindings/python/rust/llm/block_manager.rs
+++ b/lib/bindings/python/rust/llm/block_manager.rs
@@ -127,18 +127,18 @@ impl BlockManager {
         } else {
             tracing::info!("Leader not provided. Block transfer functionality will be disabled.");

-            let num_device_blocks = num_device_blocks
-                .expect("num_device_blocks must be provided if leader is not provided");
-
-            config = config.device_layout(
-                dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
-                    .num_blocks(num_device_blocks)
-                    .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
-                    .build()
-                    .map_err(to_pyerr)?,
-            );
-
-            unimplemented!("construct a drt or get one from args")
+            // let num_device_blocks = num_device_blocks
+            //     .expect("num_device_blocks must be provided if leader is not provided");
+
+            // config = config.device_layout(
+            //     dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
+            //         .num_blocks(num_device_blocks)
+            //         .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
+            //         .build()
+            //         .map_err(to_pyerr)?,
+            // );
+
+            unimplemented!("Leader not provided");
             // (
             //     None,
             //     Arc::new(
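Note: `cudarc` enters as an optional dependency activated only through the `block-manager` feature. A minimal sketch (not part of this diff; `cuda_available` is a hypothetical helper) of how downstream code can stay buildable when the feature, and therefore `dep:cudarc`, is absent:

```rust
// Sketch: feature-gated probe for a usable CUDA driver via cudarc's raw sys bindings.
#[cfg(feature = "block-manager")]
fn cuda_available() -> bool {
    use cudarc::driver::sys::{cuInit, cudaError_enum};
    // cuInit is idempotent; CUDA_SUCCESS means a usable driver is present.
    unsafe { cuInit(0) == cudaError_enum::CUDA_SUCCESS }
}

#[cfg(not(feature = "block-manager"))]
fn cuda_available() -> bool {
    false
}
```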
diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs
index 7365db3d70..d58f123248 100644
--- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs
+++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs
@@ -139,11 +139,6 @@ impl Leader for KvConnectorLeader {
         let shared_slot = self.slot_manager.get_slot(&request_id).map_err(to_pyerr)?;
         let mut slot = shared_slot.lock().map_err(to_pyerr)?;

-        // vllm is telling us that the tokens have been computed, since we do not have insight into the device pool
-        // we accept this and advance the computed position
-        slot.advance_computed_position(num_computed_tokens)?;
-        slot.record_cached_device_tokens(num_computed_tokens);
-
         // early exit if we cannot match full block
         if (slot.sequence().total_tokens() - num_computed_tokens) < self.block_size {
             return Ok((0, false));
@@ -151,7 +146,7 @@

         // find matches for any remaining tokens
         // this will advance the computed position and hold any newly matched blocks in the slot
-        slot.acquire_all_local_matches()?;
+        slot.acquire_local_matches(num_computed_tokens)?;

         // return the number of external tokens that are ready for onboarding
         // we always return true here as we always asynchronously onboard matched blocks
@@ -168,9 +163,6 @@
         }
     }

-    /// We drop the need to pass in the KvCacheBlocks and the num_external_tokens as they are captured
-    /// statefully in the [`VllmLeaderKvCacheManagerAndConnector::get_num_new_matched_tokens`] function.
-    ///
     /// Note: vLLM will not provide any scheduler output data for requests that are onboarding. It is entirely
     /// on the connector's implementation to handle this case.
    #[tracing::instrument(level = "debug", skip_all, fields(request_id))]
@@ -190,11 +182,18 @@
         let shared_slot = self.slot_manager.get_slot(&request_id).map_err(to_pyerr)?;
         let mut slot = shared_slot.lock().map_err(to_pyerr)?;

+        // we have not yet advanced the computed position, but now we can, since we have an indication that we have
+        // the necessary gpu blocks into which we will load the external tokens.
+        slot.append_mutable_device_blocks(&block_ids)?;
         // the second call will show num_external_tokens == 0
         // this call is just letting us know the other blocks that are being used for the remainder of the prefill
         if num_external_tokens > 0 {
+            let num_computed_tokens = block_ids.len() * self.block_size - num_external_tokens;
+            slot.record_cached_device_tokens(num_computed_tokens);
+            slot.advance_computed_position(num_computed_tokens)?;
+
             tracing::debug!(
                 request_id = request_id,
                 "triggering onboarding for {} external tokens",
@@ -207,8 +206,8 @@
         Ok(())
     }

-    #[tracing::instrument(level = "debug", skip_all)]
-    fn build_connector_metadata(
+    #[tracing::instrument(level = "debug", skip_all, fields(iteration = self.iteration_counter + 1))]
+    pub fn build_connector_metadata(
         &mut self,
         scheduler_output: SchedulerOutput,
     ) -> PyResult<Vec<u8>> {
@@ -220,8 +219,8 @@
         self.iteration_counter += 1;
         let iteration = self.iteration_counter;

-        tracing::debug!("Building connector metadata; iteration {iteration}");
-        tracing::debug!("scheduler_output: {scheduler_output:#?}");
+        tracing::debug!("Building connector metadata");
+        tracing::debug!("SchedulerOutput: {scheduler_output:#?}");

         let mut inflight_requests = self.inflight_requests.clone();
         let mut md = ConnectorMetadata::new(iteration);
@@ -322,7 +321,7 @@
             }
         }

-        tracing::debug!("scheduler_output: {scheduler_output:#?}");
+        tracing::debug!("metadata: {md:#?}");

         serde_json::to_vec(&md).map_err(to_pyerr)
     }
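For clarity, the arithmetic introduced in `update_state_after_alloc` above derives the locally cached token count from the block allocation vLLM reports. A standalone worked example (illustrative names, same formula):

```rust
// Worked example of: num_computed_tokens = block_ids.len() * block_size - num_external_tokens
fn locally_computed_tokens(
    num_allocated_blocks: usize,
    block_size: usize,
    num_external_tokens: usize,
) -> usize {
    // vLLM allocates enough device blocks to cover both device-cached and externally
    // matched tokens; subtracting the external portion leaves the device-cached prefix.
    num_allocated_blocks * block_size - num_external_tokens
}

fn main() {
    // 8 blocks * 16 tokens = 128 slots; 96 tokens arrive from host/disk,
    // so 32 tokens must already have been computed on the device.
    assert_eq!(locally_computed_tokens(8, 16, 96), 32);
}
```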
diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs
index 156b5ef964..25a67593b0 100644
--- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs
+++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs
@@ -104,7 +104,9 @@ pub trait Slot: std::fmt::Debug {
     /// of any kv blocks for tokens in the isl that are not already in memory on the device, but on some local storage.
     ///
     /// If external tokens are matched, then the slot will transition to the [`SlotState::Onboarding`] state.
-    fn acquire_all_local_matches(&mut self) -> Result<(), SlotError>;
+    /// `num_computed_tokens` is the number of tokens that have been computed on the device; this indicates the number of
+    /// blocks in the ISL sequence that we should skip before we start looking for matches.
+    fn acquire_local_matches(&mut self, num_computed_tokens: usize) -> Result<(), SlotError>;

     /// Trigger the onboarding operation for the slot.
     fn trigger_onboarding(&mut self, num_external_tokens: usize) -> Result<(), SlotError>;
@@ -349,6 +351,7 @@ impl Slot for VllmConnectorSlot {
         tracing::debug!("recording {} cached disk tokens", num_tokens);
     }

+    #[tracing::instrument(level = "debug", skip_all, fields(request_id = self.request_id.as_str()))]
     fn apply_scheduler_output(
         &mut self,
         tokens: &[u32],
@@ -356,10 +359,11 @@
         num_computed_tokens: usize,
         num_scheduled_tokens: usize,
     ) -> Result<(), SlotError> {
-        // debug_assert!(num_computed_tokens == self.computed_tokens());
-
         if !tokens.is_empty() {
-            tracing::debug!("appending {} newly decodedtokens to sequence", tokens.len());
+            tracing::debug!(
+                "appending {} newly decoded tokens to sequence",
+                tokens.len()
+            );
             self.state = SlotState::Decoding;
             self.sequence.extend(tokens.into()).unwrap();
         } else {
@@ -443,6 +447,8 @@
             self.offload_blocks(&offload_block_ids, &offload_token_blocks)
                 .expect("failed to offload blocks");
+
+            self.evaluated_blocks += num_candidate_blocks;
         }

         // done applying policy
@@ -494,7 +500,12 @@
     }

     #[tracing::instrument(level = "debug", skip_all)]
-    fn acquire_all_local_matches(&mut self) -> Result<(), SlotError> {
+    fn acquire_local_matches(&mut self, num_computed_tokens: usize) -> Result<(), SlotError> {
+        if matches!(self.state(), SlotState::OnboardStaged(_)) {
+            tracing::debug!("slot is already in the OnboardStaged state; skipping lookup");
+            return Ok(());
+        }
+
         if !matches!(self.state(), SlotState::Initialized) {
             return Err(SlotError::InvalidOperation(format!(
                 "slot must be in the NotScheduled state to acquire local matches; got {:?}",
@@ -503,7 +514,6 @@
         }

         let block_size = self.block_manager.block_size();
-        let num_computed_tokens = self.computed_tokens();
         let num_computed_blocks = num_computed_tokens / block_size;

         debug_assert!(num_computed_tokens % block_size == 0);
@@ -571,16 +581,10 @@
             return Ok(());
         }

-        // early exit if we need to onboard 0 blocks
-        if (num_computed_blocks + num_matched_blocks) * block_size == self.sequence().total_tokens()
-        {
-            return Ok(());
-        }
-
         let mut num_new_matched_tokens = num_matched_blocks * block_size;

         // we are on a block boundary, so we need to throw away the last block
-        if num_computed_tokens + num_new_matched_tokens == self.sequence().total_tokens() {
+        if (num_computed_tokens + num_new_matched_tokens) == self.sequence().total_tokens() {
             tracing::debug!("on a block boundary, throwing away the last block");

             // we should have matched at least one block
@@ -597,6 +601,11 @@
             num_new_matched_tokens -= block_size;
         }

+        // early exit if we need to onboard 0 blocks (after potentially dropping the last block)
+        if num_new_matched_tokens == 0 {
+            return Ok(());
+        }
+
         self.staging_from_host = if !host_blocks.is_empty() {
             Some(host_blocks)
         } else {
@@ -704,6 +713,7 @@ impl ExternallyManagedDeviceSlot for VllmConnectorSlot {
         Ok(())
     }

+    #[tracing::instrument(level = "debug", skip_all, fields(request_id = self.request_id))]
     fn append_mutable_device_blocks(&mut self, block_ids: &[BlockId]) -> Result<(), SlotError> {
         let count = block_ids.len();
         self.device_blocks.extend(block_ids);
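The block-boundary rule in `acquire_local_matches` is subtle: when computed plus matched tokens land exactly on the end of the sequence, the last matched block is discarded so the engine still has at least one token left to schedule. A self-contained sketch of that rule (illustrative names, mirrors the logic above):

```rust
/// Sketch of the boundary rule: returns the matched-token count after dropping
/// the final block when the match would otherwise cover the whole sequence.
fn trim_boundary_match(
    num_computed_tokens: usize,
    num_matched_blocks: usize,
    block_size: usize,
    total_tokens: usize,
) -> usize {
    let mut num_new_matched_tokens = num_matched_blocks * block_size;
    if num_computed_tokens + num_new_matched_tokens == total_tokens {
        // hitting the boundary implies we matched at least one block
        assert!(num_matched_blocks > 0);
        num_new_matched_tokens -= block_size;
    }
    num_new_matched_tokens
}

fn main() {
    // 64-token ISL, 4 matched blocks of 16: exact boundary, keep only 3 blocks.
    assert_eq!(trim_boundary_match(0, 4, 16, 64), 48);
    // 65-token ISL: not on the boundary, all 4 matched blocks are kept.
    assert_eq!(trim_boundary_match(0, 4, 16, 65), 64);
}
```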
diff --git a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs
index c3da870ebf..6bd58ad511 100644
--- a/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs
+++ b/lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs
@@ -6,7 +6,7 @@ use dynamo_llm::block_manager::connector::scheduler::{
     Scheduler, TransferSchedulerClient, WorkerSchedulerClient,
 };

-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
 use std::sync::{Arc, OnceLock};

 use super::*;
@@ -28,8 +28,7 @@ pub struct KvConnectorWorker {
     connector: WorkerSchedulerClient,
     transfer_client: TransferSchedulerClient,

-    // request_slots: HashMap,
-    kv_caches: HashMap<String, Arc<VllmTensor>>,
+    kv_cache_names: Vec<(String, Arc<VllmTensor>)>,

     /// Map of request id to inflight load requests
     maybe_finished_onboarding: HashSet<String>,
@@ -43,6 +42,9 @@
     bound: bool,
     iteration: u64,
     layers_complete: usize,
+
+    /// cuda events created by the python side
+    layer_events: Vec<u64>,
 }

 #[pymethods]
@@ -76,13 +78,14 @@
             kvbm_worker: OnceLock::new(),
             connector: worker_client,
             transfer_client,
-            kv_caches: HashMap::new(),
             maybe_finished_onboarding: HashSet::new(),
             maybe_finished_offloading: HashSet::new(),
             offloading_operations: Vec::new(),
             bound: false,
             iteration: 0,
             layers_complete: 0,
+            kv_cache_names: Vec::new(),
+            layer_events: Vec::new(),
         })
     }

@@ -97,6 +100,7 @@
         device_id: usize,
         dtype_width_bytes: usize,
         kv_caches: Vec<(String, Py<PyAny>)>,
+        raw_event_handles: Vec<u64>,
     ) -> PyResult<()> {
         if self.kvbm_worker.get().is_some() {
             tracing::warn!("kvbm worker already registered");
@@ -105,17 +109,20 @@
         // Process kv_caches in layer execution order (already sorted by layer index)
         let mut vllm_tensors = Vec::new();
+        let mut kv_cache_names = Vec::new();

         for (layer_name, torch_tensor) in kv_caches {
             let vllm_tensor = Arc::new(VllmTensor::new(torch_tensor).map_err(to_pyerr)?);
-            tracing::trace!("Registering KV cache layer: {layer_name}, tensor: {vllm_tensor:?}");
-
-            // Store for later lookup by name
-            self.kv_caches.insert(layer_name, vllm_tensor.clone());
+            tracing::debug!("Registering KV cache layer: {layer_name}, tensor: {vllm_tensor:?}");

             // Build ordered tensor list for worker config
+            kv_cache_names.push((layer_name, vllm_tensor.clone()));
             vllm_tensors.push(vllm_tensor as Arc<dyn TorchTensor>);
         }

+        assert_eq!(kv_cache_names.len(), raw_event_handles.len());
+        self.kv_cache_names = kv_cache_names;
+        self.layer_events = raw_event_handles;
+
         let config = KvbmWorkerConfig::builder()
             .drt(self.drt.clone())
             .num_device_blocks(num_device_blocks)
@@ -149,7 +156,7 @@
     /// This action translates the metadata into a set of actions that the worker will perform.
     /// All actions must be assigned to a slot before [`KvConnectorWorker::clear_metadata`] is called.
     pub fn bind_connector_metadata(&mut self, metadata: Vec<u8>) -> PyResult<()> {
-        debug_assert!(!self.bound, "connector metadata already bound");
+        // debug_assert!(!self.bound, "connector metadata already bound");
         let metadata: ConnectorMetadata = serde_json::from_slice(&metadata).map_err(to_pyerr)?;
         self.bound = true;
         self.iteration = metadata.iteration;
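Aside: the new `assert_eq!` in `register_kv_caches` above pins down the contract that makes the per-layer event lookup in `save_kv_layer` (next hunk) safe: one raw event handle per KV-cache layer, in the same layer-execution order. A minimal illustration with a hypothetical helper:

```rust
/// Hypothetical helper illustrating the invariant: layer names and raw torch
/// event handles are index-aligned, so layer i can be awaited via handles[i].
fn pair_layers_with_events(
    layer_names: Vec<String>,
    raw_event_handles: Vec<u64>,
) -> Vec<(String, u64)> {
    assert_eq!(layer_names.len(), raw_event_handles.len());
    layer_names.into_iter().zip(raw_event_handles).collect()
}
```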
@@ -229,8 +236,13 @@
     /// Trigger block-wise completion signals after last layer.
     pub fn save_kv_layer(&mut self, _layer_name: String, _kv_layer: Py<PyAny>) -> PyResult<()> {
         self.layers_complete += 1;
-        if self.layers_complete == self.kv_caches.len() {
+        if self.layers_complete == self.kv_cache_names.len() {
             let offloading_operations = std::mem::take(&mut self.offloading_operations);
+
+            // block on the completion of the last layer
+            // todo(ryan): capture the context, pass this to the scheduler to do the await on another thread
+            // or put the event on a stream and use stream waits to keep it all on device.
+            event_sync_blocking(self.layer_events[self.layers_complete - 1]);
             for operation in offloading_operations {
                 self.connector.enqueue_request(operation);
             }
@@ -321,3 +333,30 @@
         (is_finished_offloading, is_finished_onboarding)
     }
 }
+
+use cudarc::driver::sys::{
+    cuCtxGetCurrent, cuEventSynchronize, cudaError_enum, CUcontext, CUevent,
+};
+use std::ptr;
+
+// todo(ryan): we will need this if we farm off the cuEventSynchronize to another thread
+fn _get_current_context() -> CUcontext {
+    let mut ctx: CUcontext = ptr::null_mut();
+    let status = unsafe { cuCtxGetCurrent(&mut ctx) };
+    assert_eq!(
+        status,
+        cudaError_enum::CUDA_SUCCESS,
+        "cuCtxGetCurrent failed"
+    );
+    assert!(!ctx.is_null(), "Torch has not set a CUDA context");
+    ctx
+}
+
+fn event_sync_blocking(event: u64) {
+    let status = unsafe { cuEventSynchronize(event as CUevent) };
+    assert_eq!(
+        status,
+        cudaError_enum::CUDA_SUCCESS,
+        "cuEventSynchronize failed"
+    );
+}
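The `todo(ryan)` above sketches the follow-up: capture the CUDA context on the torch thread (see `_get_current_context`) and perform the `cuEventSynchronize` on another thread. One possible shape, as an assumption-laden sketch rather than part of this diff; the raw values travel as `u64` because `CUcontext`/`CUevent` are raw pointers and not `Send`:

```rust
use cudarc::driver::sys::{cuCtxSetCurrent, cuEventSynchronize, cudaError_enum, CUcontext, CUevent};
use std::thread;

/// Sketch: wait for a torch-recorded event off the main thread. `ctx` would come
/// from `_get_current_context()`, captured on the thread that owns the context.
fn sync_event_on_thread(ctx: u64, event: u64) -> thread::JoinHandle<()> {
    thread::spawn(move || unsafe {
        // A fresh thread has no current CUDA context; bind the captured one first.
        assert_eq!(cuCtxSetCurrent(ctx as CUcontext), cudaError_enum::CUDA_SUCCESS);
        assert_eq!(cuEventSynchronize(event as CUevent), cudaError_enum::CUDA_SUCCESS);
    })
}
```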
""" + self.events[layer_name].record(torch.cuda.current_stream()) self._connector.save_kv_layer(layer_name, kv_layer) def get_finished( diff --git a/lib/llm/src/block_manager/connector/scheduler.rs b/lib/llm/src/block_manager/connector/scheduler.rs index f8ac3ba18a..b93a9ea4ce 100644 --- a/lib/llm/src/block_manager/connector/scheduler.rs +++ b/lib/llm/src/block_manager/connector/scheduler.rs @@ -107,10 +107,10 @@ impl WorkerSchedulerClient { } pub fn start_next_iteration(&mut self) -> Result<(), SchedulerError> { - debug_assert!( - self.iteration_complete, - "previous iteration must be complete before starting a new iteration" - ); + // debug_assert!( + // self.iteration_complete, + // "previous iteration must be complete before starting a new iteration" + // ); self.iteration += 1; self.iteration_complete = false; self.layers_complete = 0; @@ -421,11 +421,11 @@ impl Scheduler { } fn start_iteration(&mut self, iteration: u64) { - tracing::debug!(iteration, "engine state updating iteration"); - debug_assert!( - self.iteration_complete, - "previous iteration must be complete before starting a new iteration" - ); + // tracing::debug!(iteration, "engine state updating iteration"); + // debug_assert!( + // self.iteration_complete, + // "previous iteration must be complete before starting a new iteration" + // ); debug_assert_eq!( self.iteration, iteration - 1, diff --git a/lib/llm/src/block_manager/distributed/transfer.rs b/lib/llm/src/block_manager/distributed/transfer.rs index 000b9ff7c9..6629a2fa87 100644 --- a/lib/llm/src/block_manager/distributed/transfer.rs +++ b/lib/llm/src/block_manager/distributed/transfer.rs @@ -132,6 +132,8 @@ impl BlockTransferHandler { request.to_pool() ); + tracing::debug!("request: {request:#?}"); + let notify = match (request.from_pool(), request.to_pool()) { (Device, Host) => self.begin_transfer(&self.device, &self.host, request).await, (Host, Device) => self.begin_transfer(&self.host, &self.device, request).await,