21 changes: 21 additions & 0 deletions components/src/dynamo/trtllm/main.py
@@ -29,6 +29,7 @@
SchedulerConfig,
)
from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.metrics import MetricsCollector
@@ -102,6 +103,20 @@ async def get_engine_runtime_config(
return runtime_config


def build_kv_connector_config(config: Config):
if config.connector is not None:
if config.connector == "kvbm":
return KvCacheConnectorConfig(
connector_module="kvbm.trtllm_integration.connector",
connector_scheduler_class="DynamoKVBMConnectorLeader",
connector_worker_class="DynamoKVBMConnectorWorker",
)
else:
logging.error(f"Invalid connector: {config.connector}")
sys.exit(1)
return None
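
A quick illustration of the branch logic above (a sketch only: `SimpleNamespace` stands in for the real `Config` object produced by `cmd_line_args()`, and only the `connector` attribute matters here):

```python
from types import SimpleNamespace

# Hypothetical stand-in for Config; build_kv_connector_config only reads `connector`.
cfg = SimpleNamespace(connector="kvbm")
print(build_kv_connector_config(cfg))  # KvCacheConnectorConfig wired to the KVBM TRT-LLM integration

cfg = SimpleNamespace(connector=None)
print(build_kv_connector_config(cfg))  # None: the engine runs without a KV connector

# Any other value (e.g. connector="foo") logs an error and exits the process.
```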


async def worker():
config = cmd_line_args()

@@ -165,6 +180,9 @@ async def init(runtime: DistributedRuntime, config: Config):
free_gpu_memory_fraction=config.free_gpu_memory_fraction
)

if config.connector is not None and "kvbm" in config.connector:
kv_cache_config.enable_partial_reuse = False

dynamic_batch_config = DynamicBatchConfig(
enable_batch_size_tuning=True,
enable_max_num_tokens_tuning=False,
@@ -174,6 +192,8 @@ async def init(runtime: DistributedRuntime, config: Config):
capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
dynamic_batch_config=dynamic_batch_config,
)
kv_connector_config = build_kv_connector_config(config)

modality = getattr(config, "modality", None) or "text"
arg_map = {
"model": model_path,
@@ -190,6 +210,7 @@ async def init(runtime: DistributedRuntime, config: Config):
"max_beam_width": config.max_beam_width,
"max_batch_size": config.max_batch_size,
"return_perf_metrics": config.publish_events_and_metrics,
"kv_connector_config": kv_connector_config,
}

if config.extra_engine_args != "":
9 changes: 9 additions & 0 deletions components/src/dynamo/trtllm/utils/trtllm_utils.py
@@ -60,6 +60,7 @@ def __init__(self) -> None:
self.custom_jinja_template: Optional[str] = None
self.store_kv: str = ""
self.request_plane: str = ""
self.connector: Optional[str] = None

def __str__(self) -> str:
return (
@@ -275,6 +276,13 @@ def cmd_line_args():
choices=get_reasoning_parser_names(),
help="Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
)
parser.add_argument(
"--connector",
type=str,
default=None,
choices=["kvbm"],
help="Connector to use for the model.",
)
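
For reference, the parse-time behavior of the new flag in isolation (a plain-argparse sketch, not Dynamo code; `choices=["kvbm"]` means any other value is rejected before the worker starts):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--connector", type=str, default=None, choices=["kvbm"])

print(parser.parse_args([]).connector)                       # None (flag omitted)
print(parser.parse_args(["--connector", "kvbm"]).connector)  # "kvbm"
# parser.parse_args(["--connector", "other"]) exits with an "invalid choice" error.
```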
add_config_dump_args(parser)
parser.add_argument(
"--custom-jinja-template",
@@ -357,6 +365,7 @@ def cmd_line_args():
config.dump_config_to = args.dump_config_to
config.store_kv = args.store_kv
config.request_plane = args.request_plane
config.connector = args.connector

# Handle custom jinja template path expansion (environment variables and home directory)
if args.custom_jinja_template:
13 changes: 11 additions & 2 deletions docs/kvbm/trtllm-setup.md
@@ -23,10 +23,9 @@ To learn what KVBM is, please check [here](kvbm_architecture.md)

> [!Note]
> - Ensure that `etcd` and `nats` are running before starting.
> - KVBM does not currently support CUDA graphs in TensorRT-LLM.
> - KVBM only supports TensorRT-LLM’s PyTorch backend.
> - Disable partial reuse `enable_partial_reuse: false` in the LLM API config’s `kv_connector_config` to increase offloading cache hits.
> - KVBM requires TensorRT-LLM v1.1.0rc5 or newer.
> - KVBM requires TensorRT-LLM v1.2.0rc2 or newer.
> - Enabling KVBM metrics with TensorRT-LLM is still a work in progress.
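
As a sketch of how the extra-options file used below could be generated, the snippet after this note writes `/tmp/kvbm_llm_api_config.yaml` with the connector wiring. The nested key names mirror the `KvCacheConnectorConfig` fields used in this PR; the authoritative schema is TensorRT-LLM's LLM API options, so treat it as illustrative only.

```python
# Sketch only: write an LLM API extra-options file that points TRT-LLM at the KVBM connector.
# Key names mirror KvCacheConnectorConfig from this PR; verify against the TensorRT-LLM
# LLM API documentation before relying on them.
import yaml

llm_api_config = {
    "kv_connector_config": {
        "connector_module": "kvbm.trtllm_integration.connector",
        "connector_scheduler_class": "DynamoKVBMConnectorLeader",
        "connector_worker_class": "DynamoKVBMConnectorWorker",
    },
}

with open("/tmp/kvbm_llm_api_config.yaml", "w") as f:
    yaml.safe_dump(llm_api_config, f)
```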

## Quick Start
@@ -107,6 +106,16 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json"

```

KVBM is also supported on the prefill worker of disaggregated serving. To launch the prefill worker, run:
```bash
# [DYNAMO] To serve an LLM model with dynamo
python3 -m dynamo.trtllm \
--model-path Qwen/Qwen3-0.6B \
--served-model-name Qwen/Qwen3-0.6B \
--extra-engine-args /tmp/kvbm_llm_api_config.yaml \
--disaggregation-mode prefill &
```

Alternatively, you can use `trtllm-serve` with KVBM by replacing the two [DYNAMO] commands above with the following:
```bash
trtllm-serve Qwen/Qwen3-0.6B --host localhost --port 8000 --backend pytorch --extra_llm_api_options /tmp/kvbm_llm_api_config.yaml
@@ -73,6 +73,14 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput) -> bytes:
req.computed_position,
)

output.add_num_scheduled_tokens(
{
str(req.request_id): req.num_scheduled_tokens
for req in scheduler_output.new_requests
+ scheduler_output.cached_requests
}
)

return self._connector.build_connector_metadata(output)

def get_num_new_matched_tokens(
117 changes: 0 additions & 117 deletions lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs
@@ -110,18 +110,6 @@ pub trait Slot: std::fmt::Debug {
num_scheduled_tokens: usize,
) -> Result<(), SlotError>;

// TRT-LLM does not include scheduled tokens in the scheduler output.
Review thread on this change:

Contributor: so trtllm 1.2.0 includes scheduled tokens now?

Contributor Author (@jthomson04, Nov 24, 2025): Yes, 1.2.0rc2 supports it.

Contributor: we can run the connector with trtllm starting from 1.1.0rc5; if we remove this part, what happens with 1.1.0rc5?

Contributor Author: 1.1.0rc5 would break with this MR. We could (in theory) detect the TRT-LLM version and fall back to the non-scheduled-tokens implementation, but that could be super ugly.

Contributor Author: But the scheduled-tokens output in rc2 (as well as the scheduled-tokens handling on the KVBM side) is required for Dynamo + KVBM to work.

Contributor: we'd need to handle it in some way, even if it's just "detect trtllm version -> fail if incompatible".

Contributor: we'd also need to adjust the docs, where 1.1.0rc5 is mentioned as supported.

Contributor Author: Updated the docs, and added a small check that the num_scheduled_tokens field exists; it throws an error if it doesn't.

Contributor Author: @oandreeva-nv Can you please re-review?

// Ideally, we should have a dedicated implementation for the TRT-LLM slot.
// However, since only this single function needs to be rewritten for now,
// we keep it as a separate function in Slot.
fn apply_scheduler_output_with_computed_position(
&mut self,
tokens: &[u32],
block_ids: &[usize],
computed_position: usize,
is_new_request: bool,
) -> Result<(), SlotError>;

fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>;

fn mark_as_prefilling(&mut self, iteration: u64) -> Result<(), SlotError>;
@@ -642,111 +630,6 @@ impl Slot for VllmConnectorSlot {
Ok(())
}

#[tracing::instrument(level = "debug", skip_all, fields(request_id = self.request_id.as_str()))]
fn apply_scheduler_output_with_computed_position(
&mut self,
tokens: &[u32],
block_ids: &[usize],
computed_position: usize,
is_new_request: bool,
) -> Result<(), SlotError> {
// TRTLLM's KV Connector Manager will have (computed_position - external matches)
// in the onboarding case
if computed_position < self.current_position {
tracing::debug!(
"computed_position={} < current_position={}, so we are onboarding during prefilling phase",
computed_position,
self.current_position
);
return Ok(());
}

// now we decide what we should do for the new computed tokens
tracing::debug!(
"applying scheduler output, computed_position={}, sequence_total_tokens={}",
computed_position,
self.sequence.total_tokens()
);

if computed_position < self.sequence.total_tokens() {
// no need to apply new tokens, since it's applied when created the slot during prefilling
self.state = SlotState::Prefilling;
} else {
tracing::debug!(
"appending {} newly decoded tokens to sequence",
tokens.len()
);
self.sequence.extend(tokens.into()).unwrap();
self.state = SlotState::Decoding;
}

// apply new block_ids, this should be applied for both prefilling and decoding
// because this is unknown when creating the slot
if !block_ids.is_empty() {
tracing::debug!("assigning {} new device blocks slot", block_ids.len());
self.device_blocks.extend(block_ids);
}

// This approach is fragile, but it’s the only way currently to skip evaluating
// the device matched blocks and to avoid offloading them again.
// TODO: Consider adding an indicator in the scheduler output to distinguish between
// matched and unmatched device blocks/tokens from the scheduler.
let maybe_have_device_matched_blocks =
is_new_request && computed_position > 0 && self.evaluated_blocks == 0;

if maybe_have_device_matched_blocks {
self.evaluated_blocks = (computed_position + 1) / self.block_size;
}

let num_candidate_blocks =
((computed_position + 1) / self.block_size).saturating_sub(self.evaluated_blocks);

if num_candidate_blocks > 0 {
// do we have a mechanism for skipping gpu cache hit blocks? not sure yet.
// for now, offload all the blocks to the host
let offload_block_ids: Vec<usize> = self
.device_blocks
.iter()
.skip(self.evaluated_blocks)
.take(num_candidate_blocks)
.copied()
.collect::<Vec<_>>();

assert_eq!(
offload_block_ids.len(),
num_candidate_blocks,
"device block overflow - candidate blocks exceed block count at offset {}",
self.evaluated_blocks
);

let offload_token_blocks: Vec<TokenBlock> = self
.sequence
.blocks()
.iter()
.skip(self.evaluated_blocks)
.take(num_candidate_blocks)
.cloned()
.collect::<Vec<_>>();

self.offload_blocks(&offload_block_ids, &offload_token_blocks)
.expect("failed to offload blocks");

self.evaluated_blocks += num_candidate_blocks;
}

// done applying policy
tracing::debug!(
"done applying kv cache policy at current_position: {}; computed_position: {}",
self.current_position,
computed_position,
);

// advance current position to computed position
self.current_position = computed_position;

Ok(())
}

fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError> {
if self.iteration_first_scheduled.is_none() {
self.iteration_first_scheduled = Some(iteration);
@@ -325,11 +325,16 @@ impl Leader for KvConnectorLeader {
slot.state()
);

slot.apply_scheduler_output_with_computed_position(
let scheduled_tokens = *scheduler_output
.num_scheduled_tokens
.get(request_id)
.unwrap_or(&0);

slot.apply_scheduler_output(
&new_req.prompt_token_ids,
&new_req.block_ids,
new_req.num_computed_tokens,
true,
scheduled_tokens,
)?;

if let Some(pending_ops) = slot.take_pending_operations() {
@@ -356,11 +361,16 @@
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;

slot.apply_scheduler_output_with_computed_position(
let scheduled_tokens = *scheduler_output
.num_scheduled_tokens
.get(request_id)
.unwrap_or(&0);

slot.apply_scheduler_output(
&cached_req.new_token_ids,
&cached_req.new_block_ids,
cached_req.num_computed_tokens,
false,
scheduled_tokens,
)?;

if let Some(pending_ops) = slot.take_pending_operations() {