adding cuda events to default pytorch stream; awaiting event for last layer before triggering offload
ryanolson committed Aug 6, 2025
commit a99d8d0dd0efe65edad0df8764d5226e269e7242
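
The pattern, sketched in plain PyTorch (illustrative only; `num_layers` and the in-place ops below stand in for the real KV-cache save kernels, they are not the connector's API):

```python
import torch

num_layers = 4
kv_layers = [torch.randn(1024, 1024, device="cuda") for _ in range(num_layers)]

# One event per layer, recorded on the default (current) stream right after
# that layer's save work is enqueued.
events = [torch.cuda.Event(enable_timing=False) for _ in range(num_layers)]

for kv, event in zip(kv_layers, events):
    kv.mul_(2.0)  # stands in for the layer's KV-cache save kernel
    event.record(torch.cuda.current_stream())

# Work on a single stream completes in order, so waiting on the last layer's
# event guarantees every earlier layer has finished; only then is the offload
# triggered.
events[-1].synchronize()
```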
1 change: 1 addition & 0 deletions lib/bindings/python/Cargo.lock

Some generated files are not rendered by default.

4 changes: 3 additions & 1 deletion lib/bindings/python/Cargo.toml
@@ -35,7 +35,7 @@ crate-type = ["cdylib", "rlib"]

[features]
default = ["block-manager"]
block-manager = ["dynamo-llm/block-manager", "dep:dlpark"]
block-manager = ["dynamo-llm/block-manager", "dep:dlpark", "dep:cudarc"]

[dependencies]
dynamo-llm = { path = "../../llm" }
@@ -79,6 +79,8 @@ pyo3-async-runtimes = { version = "0.23.0", default-features = false, features =
pythonize = "0.23"

dlpark = { version = "0.5", features = ["pyo3", "half"], optional = true }
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }


[dev-dependencies]
rstest = "0.25"
24 changes: 12 additions & 12 deletions lib/bindings/python/rust/llm/block_manager.rs
@@ -127,18 +127,18 @@ impl BlockManager {
} else {
tracing::info!("Leader not provided. Block transfer functionality will be disabled.");

let num_device_blocks = num_device_blocks
.expect("num_device_blocks must be provided if leader is not provided");

config = config.device_layout(
dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
.num_blocks(num_device_blocks)
.logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
.build()
.map_err(to_pyerr)?,
);

unimplemented!("construct a drt or get one from args")
// let num_device_blocks = num_device_blocks
// .expect("num_device_blocks must be provided if leader is not provided");

// config = config.device_layout(
// dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
// .num_blocks(num_device_blocks)
// .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
// .build()
// .map_err(to_pyerr)?,
// );

unimplemented!("Leader not provided");
// (
// None,
// Arc::new(
@@ -6,7 +6,7 @@ use dynamo_llm::block_manager::connector::scheduler::{
Scheduler, TransferSchedulerClient, WorkerSchedulerClient,
};

use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use std::sync::{Arc, OnceLock};

use super::*;
Expand Down Expand Up @@ -42,6 +42,9 @@ pub struct KvConnectorWorker {
bound: bool,
iteration: u64,
layers_complete: usize,

/// cuda events created by the python side
layer_events: Vec<u64>,
}

#[pymethods]
@@ -82,6 +85,7 @@ impl KvConnectorWorker {
iteration: 0,
layers_complete: 0,
kv_cache_names: Vec::new(),
layer_events: Vec::new(),
})
}

@@ -96,6 +100,7 @@
device_id: usize,
dtype_width_bytes: usize,
kv_caches: Vec<(String, Py<PyAny>)>,
raw_event_handles: Vec<u64>,
) -> PyResult<()> {
if self.kvbm_worker.get().is_some() {
tracing::warn!("kvbm worker already registered");
@@ -114,7 +119,9 @@
vllm_tensors.push(vllm_tensor as Arc<dyn TorchTensor>);
}

assert_eq!(kv_cache_names.len(), raw_event_handles.len());
self.kv_cache_names = kv_cache_names;
self.layer_events = raw_event_handles;

let config = KvbmWorkerConfig::builder()
.drt(self.drt.clone())
@@ -231,6 +238,11 @@ impl KvConnectorWorker {
self.layers_complete += 1;
if self.layers_complete == self.kv_cache_names.len() {
let offloading_operations = std::mem::take(&mut self.offloading_operations);

// block on the completion of the last layer
// todo(ryan): capture the context, pass this to the scheduler to do the await on another thread
// or put the event on a stream and use stream waits to keep it all on device.
event_sync_blocking(self.layer_events[self.layers_complete - 1]);
for operation in offloading_operations {
self.connector.enqueue_request(operation);
}
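
The todo above mentions an alternative to the blocking host-side wait: put the event on a stream and let stream ordering enforce the dependency. In PyTorch terms (a hedged sketch of that alternative, not part of this change), it would look like:

```python
import torch

compute_stream = torch.cuda.current_stream()
offload_stream = torch.cuda.Stream()

last_layer_event = torch.cuda.Event()
last_layer_event.record(compute_stream)

# Instead of blocking the host in cuEventSynchronize, make the offload stream
# wait for the event; the CPU thread returns immediately and ordering is
# enforced entirely on the device.
offload_stream.wait_event(last_layer_event)
with torch.cuda.stream(offload_stream):
    pass  # enqueue offload copies here (hypothetical)
```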
@@ -321,3 +333,30 @@ impl KvConnectorWorker {
(is_finished_offloading, is_finished_onboarding)
}
}

use cudarc::driver::sys::{
cuCtxGetCurrent, cuEventSynchronize, cudaError_enum, CUcontext, CUevent,
};
use std::ptr;

// todo(ryan): we will need this if we farm off the cuEventSynchronize to another thread
fn _get_current_context() -> CUcontext {
let mut ctx: CUcontext = ptr::null_mut();
let status = unsafe { cuCtxGetCurrent(&mut ctx) };
assert_eq!(
status,
cudaError_enum::CUDA_SUCCESS,
"cuCtxGetCurrent failed"
);
assert!(!ctx.is_null(), "Torch has not set a CUDA context");
ctx
}

fn event_sync_blocking(event: u64) {
let status = unsafe { cuEventSynchronize(event as CUevent) };
assert_eq!(
status,
cudaError_enum::CUDA_SUCCESS,
"cuEventSynchronize failed"
);
}
@@ -72,6 +72,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
)
]

events = [
torch.cuda.Event(enable_timing=False, interprocess=False)
for _ in range(len(ordered_kv_caches))
]

# events are lazy; if we don't record them once here, the raw handles we pass to Rust will be null
for event in events:
event.record(torch.cuda.current_stream())

raw_event_handles = [event.cuda_event for event in events]

self.events = {
layer_name: event
for (layer_name, _tensor), event in zip(ordered_kv_caches, events)
}

# Get first tensor to extract common properties
first_tensor = ordered_kv_caches[0][1]
shape = first_tensor.shape
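
The warm-up record in the hunk above matters because torch.cuda.Event allocates the underlying CUDA event lazily; until the first record, the raw handle exposed via cuda_event is null. A quick check of that assumption:

```python
import torch

event = torch.cuda.Event(enable_timing=False, interprocess=False)
print(event.cuda_event)  # 0 -- the underlying cudaEvent_t has not been created yet

event.record(torch.cuda.current_stream())
print(event.cuda_event)  # non-zero raw handle, safe to pass across the FFI boundary
```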
@@ -101,6 +117,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
device_id,
kv_cache_dtype.itemsize,
ordered_kv_caches,
raw_event_handles,
)

def bind_connector_metadata(self, data: bytes) -> None:
@@ -159,6 +176,7 @@ def save_kv_layer(
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
self.events[layer_name].record(torch.cuda.current_stream())
self._connector.save_kv_layer(layer_name, kv_layer)

def get_finished(
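
Taken together with register_kv_caches, each layer's save now records its own event on the current stream. A minimal stand-alone sketch of that pairing (layer names and the copy kernel are placeholders, not the connector's code):

```python
import torch

# Hypothetical per-layer event map, mirroring self.events in the connector.
layer_names = ["model.layers.0", "model.layers.1"]
events = {name: torch.cuda.Event() for name in layer_names}

def save_kv_layer(layer_name: str, kv_layer: torch.Tensor) -> None:
    kv_layer.mul_(1.0)  # stands in for the actual save/copy kernel
    # Record on the current stream so completion of this layer's save
    # becomes observable through its event.
    events[layer_name].record(torch.cuda.current_stream())

kv = torch.randn(4, 4, device="cuda")
save_kv_layer("model.layers.0", kv)
events["model.layers.0"].synchronize()  # the Rust worker only does this for the last layer
```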