Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions lib/bindings/python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ mod utils;
mod worker;

pub use leader::KvbmLeader;
pub use utils::get_barrier_id_prefix;
pub use utils::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
pub use worker::{KvbmWorker, PyLayoutType, VllmTensor};
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
// SPDX-License-Identifier: Apache-2.0

use super::*;
use utils::get_barrier_id_prefix;

use derive_getters::Dissolve;
use llm_rs::block_manager::distributed::{
KvbmLeader as KvbmLeaderImpl, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig,
};
use utils::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};

const CPU_CACHE: &str = "DYN_KVBM_CPU_CACHE_GB";
const CPU_CACHE_OVERRIDE: &str = "DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS";
Expand Down Expand Up @@ -72,17 +71,16 @@ impl KvbmLeader {
#[new]
#[pyo3(signature = (world_size, drt))]
fn new(world_size: usize, drt: DistributedRuntime) -> PyResult<Self> {
let barrier_id_prefix = get_barrier_id_prefix();
let leader_init_timeout_sec: u64 =
get_leader_init_timeout_secs(LEADER_WORKER_INIT_TIMEOUT_SECS);

let config = KvbmLeaderConfig::builder()
.barrier_id_prefix(barrier_id_prefix)
.world_size(world_size)
.leader_init_timeout_secs(leader_init_timeout_sec)
.drt(drt.inner().clone())
.host_blocks_config(get_blocks_config(CPU_CACHE, CPU_CACHE_OVERRIDE))
.disk_blocks_config(get_blocks_config(DISK_CACHE, DISK_CACHE_OVERRIDE))
.leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.build()
.map_err(to_pyerr)?;

Expand Down
63 changes: 59 additions & 4 deletions lib/bindings/python/rust/llm/block_manager/distributed/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,64 @@
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::env;

pub fn get_barrier_id_prefix() -> String {
std::env::var("DYN_KVBM_BARRIER_ID_PREFIX")
const DEFAULT_LEADER_ZMQ_HOST: &str = "127.0.0.1";
const DEFAULT_LEADER_ZMQ_PUB_PORT: u16 = 56001;
const DEFAULT_LEADER_ZMQ_ACK_PORT: u16 = 56002;

fn read_env_trimmed(key: &str) -> Option<String> {
env::var(key)
.ok()
.filter(|s| !s.trim().is_empty())
.unwrap_or_else(|| "kvbm".to_string())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
}

fn parse_port_u16(s: &str) -> Option<u16> {
match s.parse::<u32>() {
Ok(v) if (1..=65535).contains(&v) => Some(v as u16),
_ => None,
}
}

fn validated_port_from_env(key: &str, default_port: u16) -> u16 {
if let Some(val) = read_env_trimmed(key) {
if let Some(p) = parse_port_u16(&val) {
if p < 1024 {
tracing::warn!("{key} is a privileged port ({p}); binding may require extra caps");
}
return p;
} else {
tracing::warn!("{key} invalid value '{val}', falling back to default {default_port}");
}
}
default_port
}

fn get_leader_zmq_host() -> String {
read_env_trimmed("DYN_KVBM_LEADER_ZMQ_HOST")
.unwrap_or_else(|| DEFAULT_LEADER_ZMQ_HOST.to_string())
}

fn get_leader_zmq_pub_port() -> String {
validated_port_from_env("DYN_KVBM_LEADER_ZMQ_PUB_PORT", DEFAULT_LEADER_ZMQ_PUB_PORT).to_string()
}

fn get_leader_zmq_ack_port() -> String {
validated_port_from_env("DYN_KVBM_LEADER_ZMQ_ACK_PORT", DEFAULT_LEADER_ZMQ_ACK_PORT).to_string()
}

pub fn get_leader_zmq_pub_url() -> String {
format!(
"tcp://{}:{}",
get_leader_zmq_host(),
get_leader_zmq_pub_port()
)
}

pub fn get_leader_zmq_ack_url() -> String {
format!(
"tcp://{}:{}",
get_leader_zmq_host(),
get_leader_zmq_ack_port()
)
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use utils::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};

use super::*;

use std::sync::Arc;
use utils::get_barrier_id_prefix;

use llm_rs::block_manager::distributed::{
BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl,
Expand Down Expand Up @@ -171,16 +172,13 @@ impl KvbmWorker {
vllm_tensors.push(Arc::new(vllm_tensor));
}

let barrier_id_prefix = get_barrier_id_prefix();

let config = KvbmWorkerConfig::builder()
.drt(drt)
.num_device_blocks(num_device_blocks)
.page_size(page_size)
.tensors(vllm_tensors)
.device_id(device_id)
.dtype_width_bytes(dtype_width_bytes)
.barrier_id_prefix(barrier_id_prefix)
.device_layout_type(
device_layout_type
.map(|py_layout| py_layout.into())
Expand All @@ -196,6 +194,8 @@ impl KvbmWorker {
.map(|py_layout| py_layout.into())
.unwrap_or(LayoutType::FullyContiguous),
)
.leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.build()
.map_err(to_pyerr)?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,6 @@ impl KvConnectorLeader {

let _ = slot_manager_cell.set(sm);

// another barrier sync to make sure worker init won't return before leader is ready
let _ = leader.run_leader_readiness_barrier_blocking(drt);

if leader_ready_tx.send("finished".to_string()).is_err() {
tracing::error!("main routine receiver dropped before result was sent");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,6 @@ impl KvConnectorLeaderRecorder {

let _ = slot_manager_cell.set(sm);

// another barrier sync to make sure worker init won't return before leader is ready
leader.spawn_leader_readiness_barrier(drt);

if leader_ready_tx.send("finished".to_string()).is_err() {
tracing::error!("main routine receiver dropped before result was sent");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,6 @@ impl KvConnectorLeader {

let _ = slot_manager_cell.set(sm);

// another barrier sync to make sure worker init won't return before leader is ready
leader.spawn_leader_readiness_barrier(drt);

tracing::info!("KvConnectorLeader init complete.");
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::collections::HashSet;
use std::sync::{Arc, OnceLock};

use super::*;
use crate::llm::block_manager::distributed::get_barrier_id_prefix;
use crate::llm::block_manager::distributed::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
use crate::llm::block_manager::vllm::connector::worker::event_sync_blocking;
use crate::{
DistributedRuntime as PyDistributedRuntime, llm::block_manager::distributed::VllmTensor,
Expand Down Expand Up @@ -138,7 +138,8 @@ impl Worker for KvConnectorWorker {
.device_layout_type(LayoutType::FullyContiguous)
.host_layout_type(LayoutType::FullyContiguous)
.disk_layout_type(LayoutType::FullyContiguous)
.barrier_id_prefix(get_barrier_id_prefix())
.leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.scheduler_client(Some(self.transfer_client.clone()))
.build()?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::collections::HashSet;
use std::sync::{Arc, OnceLock};

use super::*;
use crate::llm::block_manager::distributed::get_barrier_id_prefix;
use crate::llm::block_manager::distributed::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
use crate::{
DistributedRuntime as PyDistributedRuntime, llm::block_manager::distributed::VllmTensor,
to_pyerr,
Expand Down Expand Up @@ -200,7 +200,8 @@ impl Worker for KvConnectorWorker {
.tensors(vllm_tensors)
.device_id(device_id)
.dtype_width_bytes(dtype_width_bytes)
.barrier_id_prefix(get_barrier_id_prefix())
.leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.scheduler_client(Some(self.transfer_client.clone()))
.device_layout_type(detected_device_layout_type)
.host_layout_type(host_layout_type.unwrap_or(LayoutType::FullyContiguous))
Expand Down
1 change: 1 addition & 0 deletions lib/llm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ offset-allocator = "0.2"
regex = "1"
rayon = "1"
dashmap = { version = "5.5.3" }
bincode = "1"

# input/text
dialoguer = { version = "0.11", default-features = false, features = [
Expand Down
3 changes: 0 additions & 3 deletions lib/llm/src/block_manager/distributed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,12 @@ mod tests {

async fn build_leader_and_workers(num_workers: usize) -> Result<(KvbmLeader, Vec<KvbmWorker>)> {
let mut workers = Vec::new();
let barrier_id = get_unique_barrier_id();

for i in 0..num_workers {
let tensors: Vec<Arc<dyn TorchTensor>> =
vec![Arc::new(MockTensor::new(vec![2, NUM_BLOCKS, 4096]))];

let config = KvbmWorkerConfig::builder()
.barrier_id_prefix(barrier_id.clone())
.num_device_blocks(NUM_BLOCKS)
.tensors(tensors)
.device_id(i)
Expand All @@ -151,7 +149,6 @@ mod tests {
};

let leader_config = KvbmLeaderConfig::builder()
.barrier_id_prefix(barrier_id)
.world_size(num_workers)
.host_blocks_config(host_blocks)
.disk_blocks_config(disk_blocks)
Expand Down
Loading
Loading