Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/bindings/python/rust/llm/block_manager/vllm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fn _vllm_integration(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<SlotUpdate>()?;

m.add_class::<connector::worker::KvConnectorWorker>()?;
m.add_class::<connector::leader::KvConnectorLeader>()?;
m.add_class::<connector::leader::PyKvConnectorLeader>()?;
m.add_class::<connector::SchedulerOutput>()?;
Ok(())
}
Expand Down
132 changes: 122 additions & 10 deletions lib/bindings/python/rust/llm/block_manager/vllm/connector/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0

pub mod slot;
pub mod recorder;

use super::*;
use dynamo_runtime::DistributedRuntime;
Expand All @@ -28,6 +29,8 @@ use std::{
sync::{Arc, Mutex},
};
use tokio::sync::mpsc;
use tokio;
use pyo3_async_runtimes;

type VllmLocality = Logical<DistributedLeaderWorkerResources>;

Expand All @@ -36,8 +39,42 @@ impl From<SlotError> for PyErr {
to_pyerr(err)
}
}
use dynamo_llm::recorder::Recorder;
use tokio_util::sync::CancellationToken;

#[pyclass]

pub trait Leader: Send + Sync + std::fmt::Debug {
fn get_num_new_matched_tokens(
&self,
request_id: String,
request_num_tokens: usize,
num_computed_tokens: usize,
) -> PyResult<(usize, bool)>;

fn update_state_after_alloc(
&mut self,
request_id: String,
block_ids: Vec<BlockId>,
num_external_tokens: usize,
) -> PyResult<()>;

fn build_connector_metadata(
&mut self,
scheduler_output: SchedulerOutput,
) -> PyResult<Vec<u8>>;

fn request_finished(
&mut self,
request_id: String,
block_ids: Vec<BlockId>,
) -> PyResult<bool>;

fn has_slot(&self, request_id: String) -> bool;

fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> PyResult<()>;
}

#[derive(Debug)]
pub struct KvConnectorLeader {
slot_manager: ConnectorSlotManager<String>,
block_size: usize,
Expand All @@ -46,11 +83,9 @@ pub struct KvConnectorLeader {
iteration_counter: u64,
}

#[pymethods]

impl KvConnectorLeader {
#[new]
#[pyo3(signature = (worker_id, drt, block_manager, leader))]
pub fn new(
fn new(
worker_id: String,
drt: PyDistributedRuntime,
block_manager: PyBlockManager,
Expand Down Expand Up @@ -78,15 +113,17 @@ impl KvConnectorLeader {
iteration_counter: 0,
}
}
}

impl Leader for KvConnectorLeader {
/// Match the tokens in the request with the available block pools.
/// Note: the necessary details of the request are captured prior to this call. For vllm,
/// we make a create slot call prior to this call, so a slot is guaranteed to exist.
///
/// To align with the connector interface, we must ensure that if no blocks are matched, we return (0, false).
/// In our implementation, if we match any block, we return (num_matched_tokens, true).
#[tracing::instrument(level = "debug", skip(self, request_num_tokens, num_computed_tokens))]
pub fn get_num_new_matched_tokens(
fn get_num_new_matched_tokens(
&self,
request_id: String,
request_num_tokens: usize,
Expand Down Expand Up @@ -137,7 +174,7 @@ impl KvConnectorLeader {
/// Note: vLLM will not provide any scheduler output data for requests that are onboarding. it is entirely
/// on the connector's implementation to handle this case.
#[tracing::instrument(level = "debug", skip_all, fields(request_id))]
pub fn update_state_after_alloc(
fn update_state_after_alloc(
&mut self,
request_id: String,
block_ids: Vec<BlockId>,
Expand Down Expand Up @@ -171,7 +208,7 @@ impl KvConnectorLeader {
}

#[tracing::instrument(level = "debug", skip_all)]
pub fn build_connector_metadata(
fn build_connector_metadata(
&mut self,
scheduler_output: SchedulerOutput,
) -> PyResult<Vec<u8>> {
Expand Down Expand Up @@ -318,13 +355,13 @@ impl KvConnectorLeader {
}
}

pub fn has_slot(&self, request_id: String) -> bool {
fn has_slot(&self, request_id: String) -> bool {
self.slot_manager.has_slot(&request_id)
}

/// Create a new slot for the given request ID.
/// This is used to create a new slot for the request.
pub fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> PyResult<()> {
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> PyResult<()> {
self.slot_manager
.create_slot(&request.request_id, tokens, request.salt_hash)?;

Expand All @@ -333,3 +370,78 @@ impl KvConnectorLeader {
Ok(())
}
}

#[pyclass]
pub struct PyKvConnectorLeader {
connector_leader: Box<dyn Leader>,
}

#[pymethods]
impl PyKvConnectorLeader {
#[new]
#[pyo3(signature = (worker_id, drt, block_manager, leader))]
pub fn new(
worker_id: String,
drt: PyDistributedRuntime,
block_manager: PyBlockManager,
leader: PyKvbmLeader,
) -> Self {
let enable_kvbm_record = std::env::var("ENABLE_KVBM_RECORD")
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false);

let connector_leader: Box<dyn Leader> = if enable_kvbm_record {
Box::new(recorder::KvConnectorLeaderRecorder::new(
worker_id,
drt,
block_manager,
leader,
))
} else {
Box::new(KvConnectorLeader::new(
worker_id,
drt,
block_manager,
leader,
))
};
Self { connector_leader }
}

fn get_num_new_matched_tokens(
&self,
request_id: String,
request_num_tokens: usize,
num_computed_tokens: usize,
) -> PyResult<(usize, bool)> {
self.connector_leader.get_num_new_matched_tokens(request_id, request_num_tokens, num_computed_tokens)
}

fn update_state_after_alloc(
&mut self,
request_id: String,
block_ids: Vec<BlockId>,
num_external_tokens: usize,
) -> PyResult<()> {
self.connector_leader.update_state_after_alloc(request_id, block_ids, num_external_tokens)
}

fn build_connector_metadata(
&mut self,
scheduler_output: SchedulerOutput,
) -> PyResult<Vec<u8>> {
self.connector_leader.build_connector_metadata(scheduler_output)
}

fn request_finished(&mut self, request_id: &str, block_ids: Vec<BlockId>) -> PyResult<bool> {
self.connector_leader.request_finished(request_id.to_string(), block_ids)
}

fn has_slot(&self, request_id: &str) -> bool {
self.connector_leader.has_slot(request_id.to_string())
}

fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> PyResult<()> {
self.connector_leader.create_slot(request, tokens)
}
}
Loading
Loading