Skip to content
Merged
Next Next commit
trtllm integration connector api
fix

fix

fix

fix interface

fix interface

fix

add logs

async leader

fix

fix

fix

fix

fix scheduled tokens

fix

fix

fix

fix

add logs

add logs

fix

fix and log

fix and log

fix

fix

fix layout

fix

fix

fix

fix

fmt

fix

fix

comments

fmt

fix comment

Signed-off-by: richardhuo-nv <[email protected]>
  • Loading branch information
richardhuo-nv committed Aug 29, 2025
commit 147dc1b801a5bef20a21170e64cb02e6b9aceb6a
111 changes: 111 additions & 0 deletions lib/bindings/python/rust/llm/block_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
// limitations under the License.

use super::*;
use anyhow::Result;
use dynamo_llm::block_manager::block::{
data::logical::distributed_leader_worker::DistributedLeaderWorkerResources, locality::Logical,
};
Expand Down Expand Up @@ -220,3 +221,113 @@ impl BlockManager {
&self.inner
}
}

/// Builder for a [`BlockManager`] backed by distributed leader/worker
/// resources. The leader is mandatory; the distributed runtime is always
/// taken from it at build time.
#[derive(Default)]
pub struct BlockManagerBuilder {
    /// Unique id of this worker within the distributed deployment.
    worker_id: u64,
    /// Leader handle; required — `build()` fails without it.
    leader: Option<distributed::KvbmLeader>,
    /// KV-block page size (tokens per block).
    page_size: usize,
    /// When true, no device (GPU) block pool layout is configured.
    disable_device_pool: bool,
}

impl BlockManagerBuilder {
    /// Creates a builder with every field at its default value
    /// (worker_id 0, no leader, page_size 0, device pool enabled).
    pub fn new() -> Self {
        // `Default` already zeroes `page_size`; the previous explicit
        // `Self { page_size: 0, ..Default::default() }` was redundant.
        Self::default()
    }

    /// Sets the unique id of this worker within the distributed deployment.
    pub fn worker_id(mut self, id: u64) -> Self {
        self.worker_id = id;
        self
    }

    /// Sets the KV-block page size (tokens per block).
    pub fn page_size(mut self, ps: usize) -> Self {
        self.page_size = ps;
        self
    }

    /// Provides the leader handle. Required: the distributed runtime is
    /// always taken from the leader in `build()`.
    pub fn leader(mut self, l: distributed::KvbmLeader) -> Self {
        self.leader = Some(l);
        self
    }

    /// When `yes` is true, skips configuring the device (GPU) block pool.
    pub fn disable_device_pool(mut self, yes: bool) -> Self {
        self.disable_device_pool = yes;
        self
    }

    /// Async build (call from an async context).
    ///
    /// # Errors
    /// Fails if no leader was supplied, or if any of the underlying config
    /// builders or the block-manager construction fail.
    pub async fn build(self) -> Result<BlockManager> {
        let worker_id = self.worker_id;
        let leader = self.leader.ok_or_else(|| {
            anyhow::anyhow!("leader is required (runtime is always taken from leader)")
        })?;

        // Get (inner leader handle, runtime) from the provided leader.
        let (leader_inner, drt) = leader.dissolve();

        let cancel_token = CancellationToken::new();

        // Runtime & model config
        let runtime_config = dynamo_llm::block_manager::KvManagerRuntimeConfig::builder()
            .worker_id(worker_id)
            .cancellation_token(cancel_token.clone())
            .build()?;

        let mut config =
            dynamo_llm::block_manager::KvBlockManagerConfig::builder().runtime(runtime_config);

        // Minimal logical model config: only `page_size` is meaningful for
        // the leader/worker-sharded logical layout; the other dimensions are
        // placeholders (1).
        let model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder()
            .num_layers(1)
            .outer_dim(1)
            .page_size(self.page_size)
            .inner_dim(1)
            .build()?;

        config = config.model(model_config);

        // Layouts derived from the leader's block counts.
        // NOTE(review): unlike host/disk below, the device layout is gated
        // only by `disable_device_pool`, not by a non-zero count — confirm
        // num_device_blocks > 0 is guaranteed whenever the pool is enabled.
        if !self.disable_device_pool {
            config = config.device_layout(
                dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
                    .num_blocks(leader_inner.num_device_blocks())
                    .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
                    .build()?,
            );
        }

        if leader_inner.num_host_blocks() > 0 {
            config = config.host_layout(
                dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
                    .num_blocks(leader_inner.num_host_blocks())
                    .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
                    .build()?,
            );
        }

        if leader_inner.num_disk_blocks() > 0 {
            config = config.disk_layout(
                dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
                    .num_blocks(leader_inner.num_disk_blocks())
                    .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
                    .build()?,
            );
        }

        let config = config.build()?;

        let resources =
            DistributedLeaderWorkerResources::new(Some(leader_inner), cancel_token.child_token())?;

        let inner = dynamo_llm::block_manager::KvBlockManager::<
            Logical<DistributedLeaderWorkerResources>,
            BasicMetadata,
        >::new(config, resources)
        .await?;

        Ok(BlockManager {
            inner,
            drt,
            _controller: None,
        })
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ mod utils;
mod worker;

pub use leader::KvbmLeader;
pub use utils::get_barrier_id;
pub use utils::get_barrier_id_prefix;
pub use worker::{KvbmWorker, VllmTensor};
56 changes: 36 additions & 20 deletions lib/bindings/python/rust/llm/block_manager/distributed/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
// SPDX-License-Identifier: Apache-2.0

use super::*;
use utils::get_barrier_id;
use utils::get_barrier_id_prefix;

use derive_getters::Dissolve;
use llm_rs::block_manager::distributed::{KvbmLeader as KvbmLeaderImpl, KvbmLeaderConfig};
use llm_rs::block_manager::distributed::{
KvbmLeader as KvbmLeaderImpl, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig,
};

const CPU_CACHE: &str = "DYN_KVBM_CPU_CACHE_GB";
const CPU_CACHE_OVERRIDE: &str = "DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS";
Expand All @@ -16,15 +18,32 @@ const DISK_CACHE_OVERRIDE: &str = "DYN_KVBM_DISK_CACHE_OVERRIDE_NUM_BLOCKS";
const LEADER_WORKER_INIT_TIMEOUT_SECS: &str = "DYN_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS";
const DEFAULT_INIT_TIMEOUT_SECS: u64 = 120;

fn compute_num_blocks(cache_size_key: &str, override_key: &str, bytes_per_block: usize) -> usize {
if let Ok(override_num_blocks) = std::env::var(override_key) {
override_num_blocks.parse::<usize>().unwrap_or(0)
} else {
let cache_size_gb = std::env::var(cache_size_key)
.unwrap_or_default()
.parse::<f64>()
.unwrap_or(0.0);
((cache_size_gb * 1_000_000_000.0) / bytes_per_block as f64) as usize
/// Reads the env var `key` and parses it as a `usize`.
/// Returns `None` when the var is unset, non-UTF-8, or unparseable.
fn read_env_usize(key: &str) -> Option<usize> {
    match std::env::var(key) {
        Ok(raw) => raw.trim().parse().ok(),
        Err(_) => None,
    }
}

/// Reads the env var `key` as a cache size in GB (`f64`).
/// Returns 0.0 when the var is unset or unparseable.
///
/// Trims surrounding whitespace before parsing, matching the behavior of
/// `read_env_usize` (previously `" 1.5 "` silently parsed as 0.0).
fn read_cache_size_float(key: &str) -> f64 {
    std::env::var(key)
        .ok()
        .and_then(|raw| raw.trim().parse::<f64>().ok())
        .unwrap_or(0.0)
}

/// Builds the block-count config for one cache tier (CPU or disk).
///
/// The override env var (`override_key`, a raw block count) takes precedence
/// over the cache-size env var (`cache_size_key`, in GB). The cache size is
/// always recorded for observability, even when the override wins.
fn get_blocks_config(cache_size_key: &str, override_key: &str) -> KvbmLeaderNumBlocksConfig {
    // Read the cache size once for both cases (was duplicated per branch).
    let cache_gb = read_cache_size_float(cache_size_key);
    KvbmLeaderNumBlocksConfig {
        cache_size_in_gb: cache_gb,
        // 0 signals "no override"; an explicit override of 0 is equivalent.
        num_blocks_overriden: read_env_usize(override_key).unwrap_or(0),
    }
}

Expand All @@ -51,22 +70,19 @@ impl KvbmLeader {
#[pymethods]
impl KvbmLeader {
#[new]
#[pyo3(signature = (bytes_per_block, world_size, drt))]
fn new(bytes_per_block: usize, world_size: usize, drt: DistributedRuntime) -> PyResult<Self> {
let num_host_blocks = compute_num_blocks(CPU_CACHE, CPU_CACHE_OVERRIDE, bytes_per_block);
let num_disk_blocks = compute_num_blocks(DISK_CACHE, DISK_CACHE_OVERRIDE, bytes_per_block);

let barrier_id = get_barrier_id();
#[pyo3(signature = (world_size, drt))]
fn new(world_size: usize, drt: DistributedRuntime) -> PyResult<Self> {
let barrier_id_prefix = get_barrier_id_prefix();
let leader_init_timeout_sec: u64 =
get_leader_init_timeout_secs(LEADER_WORKER_INIT_TIMEOUT_SECS);

let config = KvbmLeaderConfig::builder()
.barrier_id(barrier_id)
.num_host_blocks(num_host_blocks)
.num_disk_blocks(num_disk_blocks)
.barrier_id_prefix(barrier_id_prefix)
.world_size(world_size)
.leader_init_timeout_secs(leader_init_timeout_sec)
.drt(drt.inner().clone())
.host_blocks_config(get_blocks_config(CPU_CACHE, CPU_CACHE_OVERRIDE))
.disk_blocks_config(get_blocks_config(DISK_CACHE, DISK_CACHE_OVERRIDE))
.build()
.map_err(to_pyerr)?;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

pub fn get_barrier_id() -> String {
std::env::var("DYN_KVBM_BARRIER_ID").unwrap_or("kvbm".to_string())
/// Returns the barrier-id prefix used for leader/worker rendezvous, from
/// `DYN_KVBM_BARRIER_ID_PREFIX`; defaults to `"kvbm"` when unset.
pub fn get_barrier_id_prefix() -> String {
    // unwrap_or_else: avoid allocating the default String when the env var
    // is set (clippy::or_fun_call).
    std::env::var("DYN_KVBM_BARRIER_ID_PREFIX").unwrap_or_else(|_| "kvbm".to_string())
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use super::*;

use std::sync::Arc;
use utils::get_barrier_id;
use utils::get_barrier_id_prefix;

use llm_rs::block_manager::distributed::{
BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl,
Expand Down Expand Up @@ -131,7 +131,7 @@ impl KvbmWorker {
vllm_tensors.push(Arc::new(vllm_tensor));
}

let barrier_id = get_barrier_id();
let barrier_id_prefix = get_barrier_id_prefix();

let config = KvbmWorkerConfig::builder()
.drt(drt)
Expand All @@ -140,7 +140,7 @@ impl KvbmWorker {
.tensors(vllm_tensors)
.device_id(device_id)
.dtype_width_bytes(dtype_width_bytes)
.barrier_id(barrier_id)
.barrier_id_prefix(barrier_id_prefix)
.build()
.map_err(to_pyerr)?;

Expand Down
3 changes: 3 additions & 0 deletions lib/bindings/python/rust/llm/block_manager/vllm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ fn _vllm_integration(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<connector::worker::PyKvConnectorWorker>()?;
m.add_class::<connector::leader::PyKvConnectorLeader>()?;
m.add_class::<connector::SchedulerOutput>()?;
// TODO: use TRTLLM own integration module
m.add_class::<connector::trtllm_worker::PyTrtllmKvConnectorWorker>()?;
m.add_class::<connector::trtllm_leader::PyTrtllmKvConnectorLeader>()?;
Ok(())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use dynamo_llm::block_manager::{
};

pub mod leader;
pub mod trtllm_leader;
pub mod trtllm_worker;
pub mod worker;

use pyo3::prelude::*;
Expand Down
Loading