Commits (26 total; the diff below shows changes from a single commit)
86c79ba  some prelim cleanups (PeaBrane, May 30, 2025)
6bee243  router can route to dp ranks (PeaBrane, May 30, 2025)
dab052c  make the bunny hoppy (PeaBrane, May 30, 2025)
be6900e  Merge remote-tracking branch 'origin/main' into rupei/router-general (PeaBrane, May 30, 2025)
25e1291  Merge remote-tracking branch 'origin/main' into rupei/router-general (PeaBrane, May 30, 2025)
34e5c5b  new struct combining worker_id with dp_rank, dirty commit, breaks bin… (PeaBrane, May 30, 2025)
2cef74c  binding works (PeaBrane, May 30, 2025)
10d3326  dummy c binding note (PeaBrane, May 30, 2025)
4483c68  add_class WorkerWithDpRank (PeaBrane, May 30, 2025)
263c12d  renames + comments + fmt (PeaBrane, May 31, 2025)
65ea6b5  allow suffix for dp_rank identification (PeaBrane, Jun 3, 2025)
a2ef896  WIP: fix fn dp_rank, add TODO's (alec-flowers, Jun 3, 2025)
e80d66c  refactor: fix bugs, kv publishing working (alec-flowers, Jun 3, 2025)
7a733bd  fix panicing metric thread issue (alec-flowers, Jun 4, 2025)
1bddc8e  remove verbose log (alec-flowers, Jun 4, 2025)
ee283cc  update v1 worker (alec-flowers, Jun 4, 2025)
183a8fe  put dp_rank in PreprocessedRequest (PeaBrane, Jun 4, 2025)
be7f951  new agg config (PeaBrane, Jun 4, 2025)
e1011d8  updated comments (PeaBrane, Jun 4, 2025)
5bf4fae  update v1 example (alec-flowers, Jun 4, 2025)
d6ded6c  final touches for it working with dp (alec-flowers, Jun 4, 2025)
61b94ac  Merge branch 'main' into rupei/router-general (alec-flowers, Jun 4, 2025)
9335efe  fix cost function trace (PeaBrane, Jun 4, 2025)
931b837  fmt (PeaBrane, Jun 4, 2025)
2a72271  Merge branch 'main' into rupei/router-general (PeaBrane, Jun 4, 2025)
eb7bb10  WIP document current work steps (alec-flowers, Jun 5, 2025)
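The core addition across these commits is a routing key that pairs a worker id with an optional data-parallel rank, plus a wrapper that carries the rank alongside each KV cache event. The sketch below is reconstructed from the diffs in this pull request rather than copied from lib/llm/src/kv_router/protocols.rs; the field types follow the Python binding shown further down (worker_id: i64, dp_rank: Option<u32>), and the DpRank = u32 alias is an assumption, since its definition is not part of this commit.

// Reconstructed sketch, not the crate's exact definitions.
pub type WorkerId = i64;
pub type DpRank = u32; // assumed alias; only `Option<DpRank>` usages appear in the diff

/// Routing key: a worker plus, optionally, one of its data-parallel ranks.
/// `dp_rank: None` covers workers that are not running with data parallelism.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct WorkerWithDpRank {
    pub worker_id: WorkerId,
    pub dp_rank: Option<DpRank>,
}

/// Stand-in for the real KV cache event type from the kv_router protocols.
pub struct KvCacheEvent;

/// Wrapper the publishers now accept; dp-unaware callers (for example the
/// C binding) pass `dp_rank: None`.
pub struct KvCacheEventWithDp {
    pub kv_cache_event: KvCacheEvent,
    pub dp_rank: Option<DpRank>,
}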
binding works
PeaBrane committed May 30, 2025
commit 2cef74c6fe608128f68504b22a69b475477a1dfe
11 changes: 5 additions & 6 deletions components/metrics/src/main.rs
@@ -27,7 +27,7 @@
//! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events
//! - Overlap Blocks: Cumulative count of blocks that were already in the KV cache
use clap::Parser;
use dynamo_llm::kv_router::protocols::{WorkerWithDpRank, KVHitRateEvent, WorkerId};
use dynamo_llm::kv_router::protocols::{KVHitRateEvent, WorkerWithDpRank};
use dynamo_llm::kv_router::KV_HIT_RATE_SUBJECT;
use dynamo_runtime::{
error, logging,
@@ -180,16 +180,15 @@ async fn app(runtime: Runtime) -> Result<()> {
tracing::debug!("Successfully subscribed to KV hit rate events");

while let Some(msg) = subscriber.next().await {
match serde_json::from_slice::<KVHitRateEvent<WorkerWithDpRank>>(&msg.payload)
{
match serde_json::from_slice::<KVHitRateEvent<WorkerWithDpRank>>(&msg.payload) {
Ok(event) => {
// TODO: Lower to debug
let cache_hit_pct =
(event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0;
tracing::debug!(
"Received KV hit rate event: worker_id={}, dp_rank={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%",
event.worker_id_general.0,
event.worker_id_general.1,
event.worker_id_general.worker_id,
event.worker_id_general.dp_rank.unwrap_or(0),
event.isl_blocks,
event.overlap_blocks,
cache_hit_pct
@@ -200,7 +199,7 @@ async fn app(runtime: Runtime) -> Result<()> {
metrics.update_kv_hit_rate(
&config_clone,
// TODO: this will not take care of dp ranks
event.worker_id_general.0,
event.worker_id_general.worker_id,
event.isl_blocks,
event.overlap_blocks,
);
2 changes: 1 addition & 1 deletion components/router/src/main.rs
@@ -25,7 +25,7 @@ use std::sync::Arc;
use clap::Parser;

use dynamo_llm::kv_router::{
protocols::{WorkerWithDpRank, WorkerSelectionResult},
protocols::{WorkerSelectionResult, WorkerWithDpRank},
scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
KvRouter, WorkerSelector,
13 changes: 11 additions & 2 deletions lib/bindings/c/src/lib.rs
@@ -14,6 +14,7 @@
// limitations under the License.

use async_once_cell::OnceCell as AsyncOnceCell;
use dynamo_llm::kv_router::publisher::KvCacheEventWithDp;
use libc::c_char;
use once_cell::sync::OnceCell;
use std::ffi::CStr;
@@ -284,7 +285,11 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
};
let publisher = KV_PUB.get().unwrap();
let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
match publisher.publish(event) {
let event_with_dp = KvCacheEventWithDp {
kv_cache_event: event,
dp_rank: None,
};
match publisher.publish(event_with_dp) {
Ok(_) => DynamoLlmResult::OK,
Err(e) => {
eprintln!("Error publishing stored kv event {:?}", e);
@@ -301,7 +306,11 @@ pub extern "C" fn dynamo_kv_event_publish_removed(
) -> DynamoLlmResult {
let publisher = KV_PUB.get().unwrap();
let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
match publisher.publish(event) {
let event_with_dp = KvCacheEventWithDp {
kv_cache_event: event,
dp_rank: None,
};
match publisher.publish(event_with_dp) {
Ok(_) => DynamoLlmResult::OK,
Err(e) => {
eprintln!("Error publishing removed kv event {:?}", e);
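The C binding above always publishes with dp_rank: None, since the C API has no rank parameter yet. Purely as an illustration of how a rank could be threaded through, here is a hypothetical dp-aware variant of the removed-event entry point; it mirrors the body shown in the diff, but the pointer parameter types and the DynamoLlmResult::ERR variant are assumptions, and this function is not part of the pull request.

// Hypothetical, not in this PR: a dp-aware C entry point. Parameter types for
// `block_ids`/`num_blocks` and the `ERR` variant are assumed.
#[no_mangle]
pub extern "C" fn dynamo_kv_event_publish_removed_dp(
    event_id: u64,
    block_ids: *const u64,
    num_blocks: usize,
    dp_rank: u32,
) -> DynamoLlmResult {
    let publisher = KV_PUB.get().unwrap();
    let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks);
    let event_with_dp = KvCacheEventWithDp {
        kv_cache_event: event,
        dp_rank: Some(dp_rank), // the existing entry points hardcode `None` here
    };
    match publisher.publish(event_with_dp) {
        Ok(_) => DynamoLlmResult::OK,
        Err(e) => {
            eprintln!("Error publishing removed kv event {:?}", e);
            DynamoLlmResult::ERR // assumed error variant
        }
    }
}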
73 changes: 50 additions & 23 deletions lib/bindings/python/rust/llm/kv.rs
@@ -24,6 +24,24 @@ use tracing;
use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig, KvCacheEventWithDp};

#[pyclass]
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct WorkerWithDpRank {
#[pyo3(get, set)]
pub worker_id: i64,
#[pyo3(get, set)]
pub dp_rank: Option<u32>,
}

impl From<llm_rs::kv_router::protocols::WorkerWithDpRank> for WorkerWithDpRank {
fn from(value: llm_rs::kv_router::protocols::WorkerWithDpRank) -> Self {
Self {
worker_id: value.worker_id,
dp_rank: value.dp_rank,
}
}
}

#[pyclass]
pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>,
@@ -57,7 +75,7 @@ impl KvRouter {
.schedule(&token_ids, lora_id)
.await
.map_err(to_pyerr)?;
Ok(worker_id)
Ok(WorkerWithDpRank::from(worker_id))
})
}
}
@@ -107,7 +125,7 @@ impl WorkerMetricsPublisher {
num_requests_waiting: u64,
gpu_cache_usage_perc: f32,
gpu_prefix_cache_hit_rate: f32,
data_parallel_rank: u32,
data_parallel_rank: DpRank,
) -> PyResult<()> {
self.inner
.publish(
@@ -218,7 +236,7 @@ impl KvEventPublisher {
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None))]
#[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, dp_rank=None))]
fn publish_stored(
&mut self,
_py: Python,
@@ -228,6 +246,7 @@ impl KvEventPublisher {
block_hashes: Vec<i64>,
lora_id: u64,
parent_hash: Option<i64>,
dp_rank: Option<DpRank>,
) -> PyResult<()> {
let event = KvCacheEvent {
event_id,
@@ -244,13 +263,14 @@ impl KvEventPublisher {
}),
};
let event_with_dp = KvCacheEventWithDp {
kv_cache_event: event, dp_rank: None,
kv_cache_event: event, dp_rank,
};

self.inner.publish(event_with_dp).map_err(to_pyerr)
}

fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec<i64>) -> PyResult<()> {
#[pyo3(signature = (event_id, block_hashes, dp_rank=None))]
fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec<i64>, dp_rank: Option<DpRank>) -> PyResult<()> {
let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
.iter()
.map(|&h| ExternalSequenceBlockHash::from(h))
@@ -260,7 +280,7 @@ impl KvEventPublisher {
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
};
let event_with_dp = KvCacheEventWithDp {
kv_cache_event: event, dp_rank: None,
kv_cache_event: event, dp_rank,
};

self.inner.publish(event_with_dp).map_err(to_pyerr)
@@ -270,14 +290,16 @@ impl KvEventPublisher {
#[pyclass]
#[derive(Clone)]
pub(crate) struct OverlapScores {
inner: llm_rs::kv_router::indexer::OverlapScores<(WorkerId, DpRank)>,
inner: llm_rs::kv_router::indexer::OverlapScores<llm_rs::kv_router::protocols::WorkerWithDpRank>,
}

#[pymethods]
impl OverlapScores {
#[getter]
fn scores(&self) -> HashMap<(WorkerId, DpRank), u32> {
self.inner.scores.clone()
fn scores(&self) -> HashMap<WorkerWithDpRank, u32> {
self.inner.scores.iter()
.map(|(k, v)| (WorkerWithDpRank::from(*k), *v))
.collect()
}

#[getter]
@@ -288,7 +310,7 @@ impl OverlapScores {

#[pyclass]
pub(crate) struct KvIndexer {
inner: Arc<llm_rs::kv_router::indexer::KvIndexer<(WorkerId, DpRank)>>,
inner: Arc<llm_rs::kv_router::indexer::KvIndexer<llm_rs::kv_router::protocols::WorkerWithDpRank>>,
}

#[pymethods]
@@ -297,7 +319,7 @@ impl KvIndexer {
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer<(WorkerId, DpRank)>> =
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer<llm_rs::kv_router::protocols::WorkerWithDpRank>> =
llm_rs::kv_router::indexer::KvIndexer::new(
component.inner.drt().runtime().child_token(),
kv_block_size,
@@ -316,7 +338,7 @@ impl KvIndexer {
// should have been made to a trait and implemented here? i.e. AsyncEngine style
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
let event: llm_rs::kv_router::protocols::RouterEvent<(WorkerId, DpRank)> =
let event: llm_rs::kv_router::protocols::RouterEvent<llm_rs::kv_router::protocols::WorkerWithDpRank> =
serde_json::from_slice(&event.payload).unwrap();
tracing::debug!("received kv event: {:?}", event);
if let Err(e) = kv_events_tx.send(event).await {
@@ -360,6 +382,8 @@ pub(crate) struct EndpointKvMetrics {
#[pyo3(get, set)]
pub worker_id: i64,
#[pyo3(get, set)]
pub dp_rank: Option<DpRank>,
#[pyo3(get, set)]
pub request_active_slots: u64,
#[pyo3(get, set)]
pub request_total_slots: u64,
@@ -413,15 +437,18 @@ impl KvMetricsAggregator {
let endpoint_kv_metrics = endpoints
.endpoints
.iter()
.map(|(worker_id, x)| EndpointKvMetrics {
worker_id: *worker_id,
request_active_slots: x.data.request_active_slots,
request_total_slots: x.data.request_total_slots,
kv_active_blocks: x.data.kv_active_blocks,
kv_total_blocks: x.data.kv_total_blocks,
num_requests_waiting: x.data.num_requests_waiting,
gpu_cache_usage_perc: x.data.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: x.data.gpu_prefix_cache_hit_rate,
.flat_map(|(worker_id, x)| {
x.data.iter().map(move |data_item| EndpointKvMetrics {
worker_id: *worker_id,
dp_rank: data_item.data_parallel_rank,
request_active_slots: data_item.request_active_slots,
request_total_slots: data_item.request_total_slots,
kv_active_blocks: data_item.kv_active_blocks,
kv_total_blocks: data_item.kv_total_blocks,
num_requests_waiting: data_item.num_requests_waiting,
gpu_cache_usage_perc: data_item.gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: data_item.gpu_prefix_cache_hit_rate,
})
})
.collect();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
Expand All @@ -436,7 +463,7 @@ impl KvMetricsAggregator {

#[pyclass]
pub(crate) struct KvRecorder {
inner: Arc<llm_rs::kv_router::recorder::KvRecorder>,
inner: Arc<llm_rs::kv_router::recorder::KvRecorder<llm_rs::kv_router::protocols::WorkerWithDpRank>>,
}

#[pymethods]
@@ -487,7 +514,7 @@ impl KvRecorder {
// Spawn a task to forward events to the recorder
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
let event: llm_rs::kv_router::indexer::RouterEvent =
let event: llm_rs::kv_router::protocols::RouterEvent<llm_rs::kv_router::protocols::WorkerWithDpRank> =
serde_json::from_slice(&event.payload).unwrap();
tracing::debug!("KvRecorder received kv event: {:?}", event);
if let Err(e) = event_tx.send(event).await {
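One structural change in the kv.rs hunk above is that KvMetricsAggregator now emits one EndpointKvMetrics row per (worker, dp rank) pair instead of one per worker, by flat_mapping over each endpoint's per-rank metrics list. A self-contained sketch of that flattening, using simplified stand-in types rather than the crate's:

// Simplified stand-ins for the aggregator's flattening step; field names
// follow the diff, but these are not the crate's real types.
struct RankMetrics {
    data_parallel_rank: Option<u32>,
    kv_active_blocks: u64,
}

struct Endpoint {
    worker_id: i64,
    data: Vec<RankMetrics>, // one entry per dp rank on this worker
}

struct EndpointKvMetricsRow {
    worker_id: i64,
    dp_rank: Option<u32>,
    kv_active_blocks: u64,
}

fn flatten(endpoints: &[Endpoint]) -> Vec<EndpointKvMetricsRow> {
    endpoints
        .iter()
        .flat_map(|ep| {
            ep.data.iter().map(move |m| EndpointKvMetricsRow {
                worker_id: ep.worker_id,
                dp_rank: m.data_parallel_rank,
                kv_active_blocks: m.kv_active_blocks,
            })
        })
        .collect()
}

A worker that is not running data parallelism simply carries a single RankMetrics entry with data_parallel_rank: None, so the dp-unaware behavior is unchanged.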
7 changes: 3 additions & 4 deletions lib/llm/src/kv_router.rs
@@ -14,7 +14,7 @@ use dynamo_runtime::{
protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
use protocols::{WorkerId, WorkerWithDpRank};
use protocols::WorkerWithDpRank;

pub mod indexer;
pub mod metrics_aggregator;
@@ -171,13 +171,12 @@ impl
async fn generate(
&self,
request: SingleIn<RouterRequest>,
) -> Result<ManyOut<Annotated<RouterResponse<WorkerId>>>> {
) -> Result<ManyOut<Annotated<RouterResponse<WorkerWithDpRank>>>> {
let (request, ctx) = request.into_parts();
let (best_match, _) = self.find_best_match(&request.tokens).await?;

// NOTE: this ignores dp routing
let response = RouterResponse {
worker_id_general: best_match.worker_id,
worker_id_general: best_match,
};
let response = Annotated::from_data(response);
let stream = stream::iter(vec![response]);
2 changes: 1 addition & 1 deletion lib/llm/src/kv_router/scheduler.rs
@@ -226,7 +226,7 @@ impl WorkerSelector for DefaultWorkerSelector {
for dp_rank in ep.data.iter().map(|metrics| metrics.data_parallel_rank) {
let worker_with_dp_rank = WorkerWithDpRank {
worker_id: *worker_id,
dp_rank: dp_rank,
dp_rank,
};
if let Some(score) = request.overlap.scores.get(&worker_with_dp_rank) {
let score = *score as f64 * block_size as f64 / request.isl_tokens as f64;
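The selector hunk above scores each (worker, dp rank) candidate by converting its overlapping cached blocks to tokens and dividing by the request's input sequence length. Isolated as a small helper for clarity (a sketch, not an API in the crate):

// Overlap normalization as used in the selector above: cached blocks are
// converted to tokens and expressed as a fraction of the input sequence length.
fn normalized_overlap(overlap_blocks: u32, block_size: usize, isl_tokens: usize) -> f64 {
    overlap_blocks as f64 * block_size as f64 / isl_tokens as f64
}

For example, 4 overlapping blocks of size 16 against a 128-token prompt gives 0.5.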
2 changes: 1 addition & 1 deletion lib/llm/src/kv_router/scoring.rs
@@ -60,7 +60,7 @@ impl ProcessedEndpoints {
.flat_map(|endpoint| endpoint.data.iter())
.map(|metrics| metrics.kv_active_blocks as f64)
.collect();
if load_values.len() == 0 {
if load_values.is_empty() {
panic!("No endpoints to process!")
};
let load_avg = load_values.iter().copied().sum::<f64>() / load_values.len() as f64;