Skip to content

Commit 07798c8

Browse files
fix: chunked prefill update (#2307)
Signed-off-by: Ryan Olson <[email protected]> Co-authored-by: Olga Andreeva <[email protected]>
1 parent 89c87ef commit 07798c8

File tree

9 files changed

+130
-59
lines changed

9 files changed

+130
-59
lines changed

lib/bindings/python/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/bindings/python/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ crate-type = ["cdylib", "rlib"]
3535

3636
[features]
3737
default = ["block-manager"]
38-
block-manager = ["dynamo-llm/block-manager", "dep:dlpark"]
38+
block-manager = ["dynamo-llm/block-manager", "dep:dlpark", "dep:cudarc"]
3939

4040
[dependencies]
4141
dynamo-llm = { path = "../../llm" }
@@ -79,6 +79,8 @@ pyo3-async-runtimes = { version = "0.23.0", default-features = false, features =
7979
pythonize = "0.23"
8080

8181
dlpark = { version = "0.5", features = ["pyo3", "half"], optional = true }
82+
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
83+
8284

8385
[dev-dependencies]
8486
rstest = "0.25"

lib/bindings/python/rust/llm/block_manager.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,18 @@ impl BlockManager {
127127
} else {
128128
tracing::info!("Leader not provided. Block transfer functionality will be disabled.");
129129

130-
let num_device_blocks = num_device_blocks
131-
.expect("num_device_blocks must be provided if leader is not provided");
132-
133-
config = config.device_layout(
134-
dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
135-
.num_blocks(num_device_blocks)
136-
.logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
137-
.build()
138-
.map_err(to_pyerr)?,
139-
);
140-
141-
unimplemented!("construct a drt or get one from args")
130+
// let num_device_blocks = num_device_blocks
131+
// .expect("num_device_blocks must be provided if leader is not provided");
132+
133+
// config = config.device_layout(
134+
// dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
135+
// .num_blocks(num_device_blocks)
136+
// .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
137+
// .build()
138+
// .map_err(to_pyerr)?,
139+
// );
140+
141+
unimplemented!("Leader not provided");
142142
// (
143143
// None,
144144
// Arc::new(

lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -139,19 +139,14 @@ impl Leader for KvConnectorLeader {
139139
let shared_slot = self.slot_manager.get_slot(&request_id).map_err(to_pyerr)?;
140140
let mut slot = shared_slot.lock().map_err(to_pyerr)?;
141141

142-
// vllm is telling us that the tokens have been computed, since we do not have insight into the device pool
143-
// we accept this and advance the computed position
144-
slot.advance_computed_position(num_computed_tokens)?;
145-
slot.record_cached_device_tokens(num_computed_tokens);
146-
147142
// early exit if we cannot match full block
148143
if (slot.sequence().total_tokens() - num_computed_tokens) < self.block_size {
149144
return Ok((0, false));
150145
}
151146

152147
// find matches for any remaining tokens
153148
// this will advance the computed position and hold any newly matched blocks in the slot
154-
slot.acquire_all_local_matches()?;
149+
slot.acquire_local_matches(num_computed_tokens)?;
155150

156151
// return the number of external tokens that are ready for onboarding
157152
// we always return true here as we always asynchronously onboard matched blocks
@@ -168,9 +163,6 @@ impl Leader for KvConnectorLeader {
168163
}
169164
}
170165

171-
/// We drop the need to pass in the KvCacheBlocks and the num_external_tokens as they are captured
172-
/// statefully in the [`VllmLeaderKvCacheManagerAndConnector::get_num_new_matched_tokens`] function.
173-
///
174166
/// Note: vLLM will not provide any scheduler output data for requests that are onboarding. It is entirely
175167
/// on the connector's implementation to handle this case.
176168
#[tracing::instrument(level = "debug", skip_all, fields(request_id))]
@@ -190,11 +182,18 @@ impl Leader for KvConnectorLeader {
190182
let shared_slot = self.slot_manager.get_slot(&request_id).map_err(to_pyerr)?;
191183
let mut slot = shared_slot.lock().map_err(to_pyerr)?;
192184

185+
// we have not yet advanced the computed position, but now we can, since we have an indication that we have
186+
// the necessary gpu blocks into which we will load the external tokens.
187+
193188
slot.append_mutable_device_blocks(&block_ids)?;
194189

195190
// the second call will show num_external_tokens == 0
196191
// this call is just letting us know the other blocks that are being used for the remainder of the prefill
197192
if num_external_tokens > 0 {
193+
let num_computed_tokens = block_ids.len() * self.block_size - num_external_tokens;
194+
slot.record_cached_device_tokens(num_computed_tokens);
195+
slot.advance_computed_position(num_computed_tokens)?;
196+
198197
tracing::debug!(
199198
request_id = request_id,
200199
"triggering onboarding for {} external tokens",
@@ -207,8 +206,8 @@ impl Leader for KvConnectorLeader {
207206
Ok(())
208207
}
209208

210-
#[tracing::instrument(level = "debug", skip_all)]
211-
fn build_connector_metadata(
209+
#[tracing::instrument(level = "debug", skip_all, fields(iteration = self.iteration_counter + 1))]
210+
pub fn build_connector_metadata(
212211
&mut self,
213212
scheduler_output: SchedulerOutput,
214213
) -> PyResult<Vec<u8>> {
@@ -220,8 +219,8 @@ impl Leader for KvConnectorLeader {
220219
self.iteration_counter += 1;
221220
let iteration = self.iteration_counter;
222221

223-
tracing::debug!("Building connector metadata; iteration {iteration}");
224-
tracing::debug!("scheduler_output: {scheduler_output:#?}");
222+
tracing::debug!("Building connector metadata");
223+
tracing::debug!("SchedulerOutput: {scheduler_output:#?}");
225224

226225
let mut inflight_requests = self.inflight_requests.clone();
227226
let mut md = ConnectorMetadata::new(iteration);
@@ -322,7 +321,7 @@ impl Leader for KvConnectorLeader {
322321
}
323322
}
324323

325-
tracing::debug!("scheduler_output: {scheduler_output:#?}");
324+
tracing::debug!("metadata: {md:#?}");
326325
serde_json::to_vec(&md).map_err(to_pyerr)
327326
}
328327

lib/bindings/python/rust/llm/block_manager/vllm/connector/leader/slot.rs

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ pub trait Slot: std::fmt::Debug {
104104
/// of any kv blocks for tokens in the isl that are not already in memory on the device, but on some local storage.
105105
///
106106
/// If external tokens are matched, then the slot will transition to the [`SlotState::Onboarding`] state.
107-
fn acquire_all_local_matches(&mut self) -> Result<(), SlotError>;
107+
/// `num_computed_tokens` is the number of tokens that have been computed on the device; this indicates the number of
108+
/// blocks in the ISL sequence that we should skip before we start looking for matches.
109+
fn acquire_local_matches(&mut self, num_computed_tokens: usize) -> Result<(), SlotError>;
108110

109111
/// Trigger the onboarding operation for the slot.
110112
fn trigger_onboarding(&mut self, num_external_tokens: usize) -> Result<(), SlotError>;
@@ -349,17 +351,19 @@ impl Slot for VllmConnectorSlot {
349351
tracing::debug!("recording {} cached disk tokens", num_tokens);
350352
}
351353

354+
#[tracing::instrument(level = "debug", skip_all, fields(request_id = self.request_id.as_str()))]
352355
fn apply_scheduler_output(
353356
&mut self,
354357
tokens: &[u32],
355358
block_ids: &[BlockId],
356359
num_computed_tokens: usize,
357360
num_scheduled_tokens: usize,
358361
) -> Result<(), SlotError> {
359-
// debug_assert!(num_computed_tokens == self.computed_tokens());
360-
361362
if !tokens.is_empty() {
362-
tracing::debug!("appending {} newly decodedtokens to sequence", tokens.len());
363+
tracing::debug!(
364+
"appending {} newly decoded tokens to sequence",
365+
tokens.len()
366+
);
363367
self.state = SlotState::Decoding;
364368
self.sequence.extend(tokens.into()).unwrap();
365369
} else {
@@ -443,6 +447,8 @@ impl Slot for VllmConnectorSlot {
443447

444448
self.offload_blocks(&offload_block_ids, &offload_token_blocks)
445449
.expect("failed to offload blocks");
450+
451+
self.evaluated_blocks += num_candidate_blocks;
446452
}
447453

448454
// done applying policy
@@ -494,7 +500,12 @@ impl Slot for VllmConnectorSlot {
494500
}
495501

496502
#[tracing::instrument(level = "debug", skip_all)]
497-
fn acquire_all_local_matches(&mut self) -> Result<(), SlotError> {
503+
fn acquire_local_matches(&mut self, num_computed_tokens: usize) -> Result<(), SlotError> {
504+
if matches!(self.state(), SlotState::OnboardStaged(_)) {
505+
tracing::debug!("slot is already in the OnboardStaged state; skipping lookup");
506+
return Ok(());
507+
}
508+
498509
if !matches!(self.state(), SlotState::Initialized) {
499510
return Err(SlotError::InvalidOperation(format!(
500511
"slot must be in the NotScheduled state to acquire local matches; got {:?}",
@@ -503,7 +514,6 @@ impl Slot for VllmConnectorSlot {
503514
}
504515

505516
let block_size = self.block_manager.block_size();
506-
let num_computed_tokens = self.computed_tokens();
507517
let num_computed_blocks = num_computed_tokens / block_size;
508518
debug_assert!(num_computed_tokens % block_size == 0);
509519

@@ -571,16 +581,10 @@ impl Slot for VllmConnectorSlot {
571581
return Ok(());
572582
}
573583

574-
// early exit if we need to onboard 0 blocks
575-
if (num_computed_blocks + num_matched_blocks) * block_size == self.sequence().total_tokens()
576-
{
577-
return Ok(());
578-
}
579-
580584
let mut num_new_matched_tokens = num_matched_blocks * block_size;
581585

582586
// we are on a block boundary, so we need to throw away the last block
583-
if num_computed_tokens + num_new_matched_tokens == self.sequence().total_tokens() {
587+
if (num_computed_tokens + num_new_matched_tokens) == self.sequence().total_tokens() {
584588
tracing::debug!("on a block boundary, throwing away the last block");
585589

586590
// we should have matched at least one block
@@ -597,6 +601,11 @@ impl Slot for VllmConnectorSlot {
597601
num_new_matched_tokens -= block_size;
598602
}
599603

604+
// early exit if we need to onboard 0 blocks (after potentially dropping the last block)
605+
if num_new_matched_tokens == 0 {
606+
return Ok(());
607+
}
608+
600609
self.staging_from_host = if !host_blocks.is_empty() {
601610
Some(host_blocks)
602611
} else {
@@ -704,6 +713,7 @@ impl ExternallyManagedDeviceSlot for VllmConnectorSlot {
704713
Ok(())
705714
}
706715

716+
#[tracing::instrument(level = "debug", skip_all, fields(request_id = self.request_id))]
707717
fn append_mutable_device_blocks(&mut self, block_ids: &[BlockId]) -> Result<(), SlotError> {
708718
let count = block_ids.len();
709719
self.device_blocks.extend(block_ids);

lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use dynamo_llm::block_manager::connector::scheduler::{
66
Scheduler, TransferSchedulerClient, WorkerSchedulerClient,
77
};
88

9-
use std::collections::{HashMap, HashSet};
9+
use std::collections::HashSet;
1010
use std::sync::{Arc, OnceLock};
1111

1212
use super::*;
@@ -28,8 +28,7 @@ pub struct KvConnectorWorker {
2828
connector: WorkerSchedulerClient,
2929
transfer_client: TransferSchedulerClient,
3030

31-
// request_slots: HashMap<String, WorkerSlot>,
32-
kv_caches: HashMap<String, Arc<VllmTensor>>,
31+
kv_cache_names: Vec<(String, Arc<VllmTensor>)>,
3332

3433
/// Map of request id to inflight load requests
3534
maybe_finished_onboarding: HashSet<String>,
@@ -43,6 +42,9 @@ pub struct KvConnectorWorker {
4342
bound: bool,
4443
iteration: u64,
4544
layers_complete: usize,
45+
46+
/// cuda events created by the python side
47+
layer_events: Vec<u64>,
4648
}
4749

4850
#[pymethods]
@@ -76,13 +78,14 @@ impl KvConnectorWorker {
7678
kvbm_worker: OnceLock::new(),
7779
connector: worker_client,
7880
transfer_client,
79-
kv_caches: HashMap::new(),
8081
maybe_finished_onboarding: HashSet::new(),
8182
maybe_finished_offloading: HashSet::new(),
8283
offloading_operations: Vec::new(),
8384
bound: false,
8485
iteration: 0,
8586
layers_complete: 0,
87+
kv_cache_names: Vec::new(),
88+
layer_events: Vec::new(),
8689
})
8790
}
8891

@@ -97,6 +100,7 @@ impl KvConnectorWorker {
97100
device_id: usize,
98101
dtype_width_bytes: usize,
99102
kv_caches: Vec<(String, Py<PyAny>)>,
103+
raw_event_handles: Vec<u64>,
100104
) -> PyResult<()> {
101105
if self.kvbm_worker.get().is_some() {
102106
tracing::warn!("kvbm worker already registered");
@@ -105,17 +109,20 @@ impl KvConnectorWorker {
105109

106110
// Process kv_caches in layer execution order (already sorted by layer index)
107111
let mut vllm_tensors = Vec::new();
112+
let mut kv_cache_names = Vec::new();
108113
for (layer_name, torch_tensor) in kv_caches {
109114
let vllm_tensor = Arc::new(VllmTensor::new(torch_tensor).map_err(to_pyerr)?);
110-
tracing::trace!("Registering KV cache layer: {layer_name}, tensor: {vllm_tensor:?}");
111-
112-
// Store for later lookup by name
113-
self.kv_caches.insert(layer_name, vllm_tensor.clone());
115+
tracing::debug!("Registering KV cache layer: {layer_name}, tensor: {vllm_tensor:?}");
114116

115117
// Build ordered tensor list for worker config
118+
kv_cache_names.push((layer_name, vllm_tensor.clone()));
116119
vllm_tensors.push(vllm_tensor as Arc<dyn TorchTensor>);
117120
}
118121

122+
assert_eq!(kv_cache_names.len(), raw_event_handles.len());
123+
self.kv_cache_names = kv_cache_names;
124+
self.layer_events = raw_event_handles;
125+
119126
let config = KvbmWorkerConfig::builder()
120127
.drt(self.drt.clone())
121128
.num_device_blocks(num_device_blocks)
@@ -149,7 +156,7 @@ impl KvConnectorWorker {
149156
/// This action translates the metadata into a set of actions that the worker will perform.
150157
/// All actions must be assigned to a slot before [`KvConnectorWorker::clear_metadata`] is called.
151158
pub fn bind_connector_metadata(&mut self, metadata: Vec<u8>) -> PyResult<()> {
152-
debug_assert!(!self.bound, "connector metadata already bound");
159+
// debug_assert!(!self.bound, "connector metadata already bound");
153160
let metadata: ConnectorMetadata = serde_json::from_slice(&metadata).map_err(to_pyerr)?;
154161
self.bound = true;
155162
self.iteration = metadata.iteration;
@@ -229,8 +236,13 @@ impl KvConnectorWorker {
229236
/// Trigger block-wise completion signals after the last layer.
230237
pub fn save_kv_layer(&mut self, _layer_name: String, _kv_layer: Py<PyAny>) -> PyResult<()> {
231238
self.layers_complete += 1;
232-
if self.layers_complete == self.kv_caches.len() {
239+
if self.layers_complete == self.kv_cache_names.len() {
233240
let offloading_operations = std::mem::take(&mut self.offloading_operations);
241+
242+
// block on the completion of the last layer
243+
// todo(ryan): capture the context, pass this to the scheduler to do the await on another thread
244+
// or put the event on a stream and use stream waits to keep it all on device.
245+
event_sync_blocking(self.layer_events[self.layers_complete - 1]);
234246
for operation in offloading_operations {
235247
self.connector.enqueue_request(operation);
236248
}
@@ -321,3 +333,30 @@ impl KvConnectorWorker {
321333
(is_finished_offloading, is_finished_onboarding)
322334
}
323335
}
336+
337+
use cudarc::driver::sys::{
338+
cuCtxGetCurrent, cuEventSynchronize, cudaError_enum, CUcontext, CUevent,
339+
};
340+
use std::ptr;
341+
342+
// todo(ryan): we will need this if we farm off the cuEventSynchronize to another thread
343+
fn _get_current_context() -> CUcontext {
344+
let mut ctx: CUcontext = ptr::null_mut();
345+
let status = unsafe { cuCtxGetCurrent(&mut ctx) };
346+
assert_eq!(
347+
status,
348+
cudaError_enum::CUDA_SUCCESS,
349+
"cuCtxGetCurrent failed"
350+
);
351+
assert!(!ctx.is_null(), "Torch has not set a CUDA context");
352+
ctx
353+
}
354+
355+
fn event_sync_blocking(event: u64) {
356+
let status = unsafe { cuEventSynchronize(event as CUevent) };
357+
assert_eq!(
358+
status,
359+
cudaError_enum::CUDA_SUCCESS,
360+
"cuEventSynchronize failed"
361+
);
362+
}

0 commit comments

Comments
 (0)