pre-refactor
ryanolson committed Sep 25, 2025
commit 7fcf58de94d0f8930be1cfbb10a926697a7c6c6e
1 change: 0 additions & 1 deletion dynamo.code-workspace
@@ -7,7 +7,6 @@
"settings": {
"rust-analyzer.linkedProjects": [
"Cargo.toml",
"launch/dynamo-run/Cargo.toml",
"lib/bindings/python/Cargo.toml"
],
"rust-analyzer.procMacro.enable": true,
2 changes: 1 addition & 1 deletion lib/bindings/python/Cargo.toml
@@ -34,7 +34,7 @@ name = "_core"
crate-type = ["cdylib", "rlib"]

[features]
default = []
default = ["block-manager"]
block-manager = ["dynamo-llm/block-manager", "dep:dlpark", "dep:cudarc"]

[dependencies]
1 change: 1 addition & 0 deletions lib/bindings/python/rust/lib.rs
@@ -121,6 +121,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {

engine::add_to_module(m)?;
parsers::add_to_module(m)?;
llm::scheduler_connector::register_module(m)?;

#[cfg(feature = "block-manager")]
llm::block_manager::add_to_module(m)?;
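Once register_module runs here, the new submodule is exposed as an attribute of the compiled _core extension, which is how the Python connector worker further down imports it. A minimal smoke test, assuming the bindings wheel has been rebuilt with this change:

from dynamo._core import scheduler_connector

# The class exported by the new submodule (see scheduler_connector.rs below).
print(scheduler_connector.WorkerDeviceBlocks)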
1 change: 1 addition & 0 deletions lib/bindings/python/rust/llm.rs
@@ -34,6 +34,7 @@ pub mod local_model;
pub mod model_card;
pub mod nats;
pub mod preprocessor;
pub mod scheduler_connector;
pub mod vllm_scheduler;

#[cfg(feature = "block-manager")]
202 changes: 202 additions & 0 deletions lib/bindings/python/rust/llm/scheduler_connector.rs
@@ -0,0 +1,202 @@
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Python bindings for scheduler worker device blocks.

use std::sync::Arc;

use pyo3::prelude::*;

use dynamo_llm::integrations::vllm::scheduler::worker::WorkerDeviceBlocks as RustWorkerDeviceBlocks;
use dynamo_llm::block_manager::storage::torch::{TorchDevice, TorchTensor};

use crate::to_pyerr;

/// A wrapper around a Torch tensor for the scheduler connector.
/// We hold onto the Python object so it isn't garbage-collected while Rust keeps the raw pointer.
#[derive(Clone, Debug)]
pub struct SchedulerTensor {
_py_tensor: Py<PyAny>,
device: TorchDevice,
data_ptr: u64,
size_bytes: usize,
shape: Vec<usize>,
stride: Vec<usize>,
}

impl SchedulerTensor {
pub fn new(py_tensor: Py<PyAny>) -> anyhow::Result<Self> {
Python::with_gil(|py| {
let device = py_tensor.getattr(py, "device")?;
let device_type = device.getattr(py, "type")?.extract::<String>(py)?;

let device = if device_type == "cuda" {
TorchDevice::Cuda(device.getattr(py, "index")?.extract::<usize>(py)?)
} else {
TorchDevice::Other(device_type)
};

let data_ptr = py_tensor.call_method0(py, "data_ptr")?.extract::<u64>(py)?;
let size_bytes = py_tensor.getattr(py, "nbytes")?.extract::<usize>(py)?;
let shape = py_tensor.getattr(py, "shape")?.extract::<Vec<usize>>(py)?;
let stride = py_tensor
.call_method0(py, "stride")?
.extract::<Vec<usize>>(py)?;

Ok(Self {
_py_tensor: py_tensor,
device,
data_ptr,
size_bytes,
shape,
stride,
})
})
}
}

impl TorchTensor for SchedulerTensor {
fn device(&self) -> TorchDevice {
self.device.clone()
}

fn data_ptr(&self) -> u64 {
self.data_ptr
}

fn size_bytes(&self) -> usize {
self.size_bytes
}

fn shape(&self) -> Vec<usize> {
self.shape.clone()
}

fn stride(&self) -> Vec<usize> {
self.stride.clone()
}
}

/// Python wrapper for WorkerDeviceBlocks.
///
/// This class provides worker device block construction for the scheduler
/// without requiring leader/worker synchronization.
#[pyclass]
pub struct WorkerDeviceBlocks {
inner: Arc<RustWorkerDeviceBlocks>,
}

#[pymethods]
impl WorkerDeviceBlocks {
/// Create local blocks from KV cache tensors.
///
/// Args:
/// tensors: List of torch tensors (one per layer)
/// num_device_blocks: Number of device blocks
/// page_size: Page size (typically 16)
/// device_id: CUDA device ID
/// dtype_width_bytes: Bytes per dtype element (e.g., 2 for fp16)
/// is_fully_contiguous: Whether layout is fully contiguous
#[new]
#[pyo3(signature = (tensors, num_device_blocks, page_size, device_id=0, dtype_width_bytes=2, is_fully_contiguous=false))]
fn new(
tensors: Vec<Py<PyAny>>,
num_device_blocks: usize,
page_size: usize,
device_id: usize,
dtype_width_bytes: usize,
is_fully_contiguous: bool,
) -> PyResult<Self> {
// Convert Python tensors to Rust tensors
let mut rust_tensors: Vec<Arc<dyn TorchTensor>> = Vec::with_capacity(tensors.len());

for tensor in tensors {
let scheduler_tensor = SchedulerTensor::new(tensor).map_err(to_pyerr)?;
rust_tensors.push(Arc::new(scheduler_tensor));
}

// Build worker device blocks
let worker_blocks = RustWorkerDeviceBlocks::from_tensors(
rust_tensors,
num_device_blocks,
page_size,
device_id,
dtype_width_bytes,
is_fully_contiguous,
)
.map_err(to_pyerr)?;

Ok(Self {
inner: Arc::new(worker_blocks),
})
}

/// Get the number of device blocks.
#[getter]
fn num_device_blocks(&self) -> usize {
self.inner.num_device_blocks
}

/// Get the number of layers.
#[getter]
fn num_layers(&self) -> usize {
self.inner.num_layers
}

/// Get the outer dimension.
#[getter]
fn outer_dim(&self) -> usize {
self.inner.outer_dim
}

/// Get the page size.
#[getter]
fn page_size(&self) -> usize {
self.inner.page_size
}

/// Get the inner dimension.
#[getter]
fn inner_dim(&self) -> usize {
self.inner.inner_dim
}

/// Get the dtype width in bytes.
#[getter]
fn dtype_width_bytes(&self) -> usize {
self.inner.dtype_width_bytes
}

/// Get the total bytes per block.
#[getter]
fn bytes_per_block(&self) -> usize {
self.inner.bytes_per_block
}

/// Get the number of blocks that were created.
fn num_blocks(&self) -> usize {
self.inner.device_blocks.len()
}

/// String representation for debugging.
fn __repr__(&self) -> String {
format!(
"WorkerDeviceBlocks(num_blocks={}, num_layers={}, outer_dim={}, page_size={}, inner_dim={}, dtype_width_bytes={}, bytes_per_block={})",
self.inner.device_blocks.len(),
self.inner.num_layers,
self.inner.outer_dim,
self.inner.page_size,
self.inner.inner_dim,
self.inner.dtype_width_bytes,
self.inner.bytes_per_block
)
}
}

/// Register the module with Python.
pub fn register_module(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
let m = PyModule::new(parent_module.py(), "scheduler_connector")?;
m.add_class::<WorkerDeviceBlocks>()?;
parent_module.add_submodule(&m)?;
Ok(())
}
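For context, a minimal sketch of how this binding can be exercised from Python once the extension is built with the block-manager feature enabled; the tensor shapes and parameter values below are illustrative placeholders, not taken from this PR, and a CUDA device is assumed to be available:

import torch
from dynamo._core import scheduler_connector

# One KV cache tensor per layer; the (2, blocks, page, heads, head_dim) layout is only an example.
num_layers, num_blocks, page_size = 2, 8, 16
kv_caches = [
    torch.zeros(2, num_blocks, page_size, 8, 64, dtype=torch.float16, device="cuda")
    for _ in range(num_layers)
]

blocks = scheduler_connector.WorkerDeviceBlocks(
    tensors=kv_caches,
    num_device_blocks=num_blocks,
    page_size=page_size,
    device_id=0,
    dtype_width_bytes=2,        # fp16
    is_fully_contiguous=False,  # layer-separate layout, as in the connector worker below
)

print(blocks)                         # uses the __repr__ defined above
print(blocks.num_blocks(), blocks.num_layers, blocks.bytes_per_block)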
Python scheduler connector worker (file header not shown)
@@ -13,6 +13,11 @@
from typing import TYPE_CHECKING, Optional

import torch
from vllm.model_executor.models.utils import extract_layer_index
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

# Import our local block builder
from dynamo._core import scheduler_connector

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -32,16 +37,83 @@ def __init__(self, vllm_config: "VllmConfig", engine_id: str, **kwargs):
"""Initialize the scheduler connector worker."""
self.vllm_config = vllm_config
self.engine_id = engine_id
self.local_blocks = None
print(f"SchedulerConnectorWorker initialized with engine_id: {engine_id}")

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
Register KV caches - no-op for now.
Register KV caches - builds local blocks without leader sync.

Will be implemented in a later phase.
This creates device blocks locally from the provided tensors
without requiring any network setup or synchronization.
"""
# TODO: Implement in future phase
pass
if not kv_caches:
print("Warning: register_kv_caches called with empty kv_caches")
return

print(
f"SchedulerConnectorWorker.register_kv_caches called with {len(kv_caches)} layers"
)

# Extract configuration from vLLM config
cache_config = self.vllm_config.cache_config

# Sort tensors by layer index to ensure correct ordering
ordered_kv_caches = sorted(
kv_caches.items(), key=lambda item: extract_layer_index(item[0])
)

# Extract tensors in order
tensors = [tensor for _, tensor in ordered_kv_caches]

# Get first tensor to extract common properties
first_tensor = tensors[0]
shape = first_tensor.shape

# Validate all tensors have same shape
if not all(t.shape == shape for t in tensors):
raise NotImplementedError(
"Hybrid models with different KV cache shapes are not supported yet."
)

# Extract parameters
# TODO: Assumes the block-count dimension is one of the first two axes; this will break for unusual KV cache layouts
num_device_blocks = max(shape[0], shape[1])
page_size = cache_config.block_size
device_id = (
first_tensor.device.index if first_tensor.device.type == "cuda" else 0
)

# Determine cache dtype
if cache_config.cache_dtype == "auto":
kv_cache_dtype = self.vllm_config.model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

dtype_width_bytes = kv_cache_dtype.itemsize

# Build worker device blocks
try:
self.local_blocks = scheduler_connector.WorkerDeviceBlocks(
tensors=tensors,
num_device_blocks=num_device_blocks,
page_size=page_size,
device_id=device_id,
dtype_width_bytes=dtype_width_bytes,
is_fully_contiguous=False, # Default to layer-separate layout
)

print(f"Successfully built worker device blocks: {self.local_blocks}")
print(f" - Blocks created: {self.local_blocks.num_blocks()}")
print(f" - Layers: {self.local_blocks.num_layers}")
print(f" - Outer dim: {self.local_blocks.outer_dim}")
print(f" - Page size: {self.local_blocks.page_size}")
print(f" - Inner dim: {self.local_blocks.inner_dim}")
print(f" - Bytes per block: {self.local_blocks.bytes_per_block}")

except Exception as e:
print(f"Failed to build worker device blocks: {e}")
raise

def bind_connector_metadata(self, data: bytes) -> None:
"""
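As a standalone illustration of the parameter extraction performed in register_kv_caches above, assuming a typical paged-attention KV cache layout of (2, num_blocks, block_size, num_kv_heads, head_size); the concrete sizes are made up for the example:

import torch

# (K/V, blocks, page, kv_heads, head_dim) -- an assumed layout, not asserted by this PR.
kv = torch.empty(2, 1024, 16, 8, 64, dtype=torch.float16)
shape = kv.shape

# The block-count axis is assumed to be one of the first two, so max() picks
# 1024 (the block count) rather than 2 (the K/V axis).
num_device_blocks = max(shape[0], shape[1])   # 1024
page_size = 16                                # cache_config.block_size in the connector
dtype_width_bytes = kv.dtype.itemsize         # 2 for float16

print(num_device_blocks, page_size, dtype_width_bytes)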
2 changes: 1 addition & 1 deletion lib/llm/Cargo.toml
@@ -25,7 +25,7 @@ readme.workspace = true
description = "Dynamo LLM Library"

[features]
default = []
default = ["block-manager"]
# todo(ops): get this working in CI as a default.
# default = ["block-manager", "testing-full"]

2 changes: 2 additions & 0 deletions lib/llm/src/block_manager.rs
@@ -19,6 +19,8 @@
//! mechanisms. It handles storage allocation, block management, and safe access
//! patterns for both system memory and remote (NIXL) storage.

pub mod v2;

pub mod config;
mod state;
