diff --git a/components/src/dynamo/trtllm/main.py b/components/src/dynamo/trtllm/main.py index 59a35b39d3..5ab688be27 100644 --- a/components/src/dynamo/trtllm/main.py +++ b/components/src/dynamo/trtllm/main.py @@ -29,6 +29,7 @@ SchedulerConfig, ) from tensorrt_llm.llmapi.llm import SamplingParams +from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options from tensorrt_llm.llmapi.tokenizer import tokenizer_factory from tensorrt_llm.metrics import MetricsCollector @@ -110,6 +111,20 @@ async def get_engine_runtime_config( return runtime_config +def build_kv_connector_config(config: Config): + if config.connector is not None: + if config.connector == "kvbm": + return KvCacheConnectorConfig( + connector_module="kvbm.trtllm_integration.connector", + connector_scheduler_class="DynamoKVBMConnectorLeader", + connector_worker_class="DynamoKVBMConnectorWorker", + ) + else: + logging.error(f"Invalid connector: {config.connector}") + sys.exit(1) + return None + + async def worker(): config = cmd_line_args() @@ -173,6 +188,9 @@ async def init(runtime: DistributedRuntime, config: Config): free_gpu_memory_fraction=config.free_gpu_memory_fraction ) + if config.connector is not None and "kvbm" in config.connector: + kv_cache_config.enable_partial_reuse = False + dynamic_batch_config = DynamicBatchConfig( enable_batch_size_tuning=True, enable_max_num_tokens_tuning=False, @@ -182,6 +200,8 @@ async def init(runtime: DistributedRuntime, config: Config): capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, dynamic_batch_config=dynamic_batch_config, ) + kv_connector_config = build_kv_connector_config(config) + modality = getattr(config, "modality", None) or "text" arg_map = { "model": model_path, @@ -198,6 +218,7 @@ async def init(runtime: DistributedRuntime, config: Config): "max_beam_width": config.max_beam_width, "max_batch_size": config.max_batch_size, "return_perf_metrics": 
config.publish_events_and_metrics, + "kv_connector_config": kv_connector_config, } if config.extra_engine_args != "": diff --git a/components/src/dynamo/trtllm/utils/trtllm_utils.py b/components/src/dynamo/trtllm/utils/trtllm_utils.py index 5bcb94af16..805f45c53d 100644 --- a/components/src/dynamo/trtllm/utils/trtllm_utils.py +++ b/components/src/dynamo/trtllm/utils/trtllm_utils.py @@ -61,6 +61,7 @@ def __init__(self) -> None: self.dyn_endpoint_types: str = "chat,completions" self.store_kv: str = "" self.request_plane: str = "" + self.connector: Optional[str] = None def __str__(self) -> str: return ( @@ -276,6 +277,13 @@ def cmd_line_args(): choices=get_reasoning_parser_names(), help="Reasoning parser name for the model. If not specified, no reasoning parsing is performed.", ) + parser.add_argument( + "--connector", + type=str, + default=None, + choices=["kvbm"], + help="Connector to use for the model.", + ) add_config_dump_args(parser) parser.add_argument( "--custom-jinja-template", @@ -365,6 +373,7 @@ def cmd_line_args(): config.dyn_endpoint_types = args.dyn_endpoint_types config.store_kv = args.store_kv config.request_plane = args.request_plane + config.connector = args.connector # Handle custom jinja template path expansion (environment variables and home directory) if args.custom_jinja_template: diff --git a/docs/kvbm/trtllm-setup.md b/docs/kvbm/trtllm-setup.md index 3884fad4c2..5621d16ffa 100644 --- a/docs/kvbm/trtllm-setup.md +++ b/docs/kvbm/trtllm-setup.md @@ -23,10 +23,9 @@ To learn what KVBM is, please check [here](kvbm_architecture.md) > [!Note] > - Ensure that `etcd` and `nats` are running before starting. -> - KVBM does not currently support CUDA graphs in TensorRT-LLM. > - KVBM only supports TensorRT-LLM’s PyTorch backend. > - Disable partial reuse `enable_partial_reuse: false` in the LLM API config’s `kv_connector_config` to increase offloading cache hits. -> - KVBM requires TensorRT-LLM v1.1.0rc5 or newer. 
+> - KVBM requires TensorRT-LLM v1.2.0rc2 or newer. > - Enabling KVBM metrics with TensorRT-LLM is still a work in progress. ## Quick Start @@ -107,6 +106,16 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" ``` +KVBM is also supported on the prefill worker of disaggregated serving. To launch the prefill worker, run: +```bash +# [DYNAMO] To serve an LLM model with dynamo +python3 -m dynamo.trtllm \ + --model-path Qwen/Qwen3-0.6B \ + --served-model-name Qwen/Qwen3-0.6B \ + --extra-engine-args /tmp/kvbm_llm_api_config.yaml \ + --disaggregation-mode prefill & +``` + Alternatively, can use "trtllm-serve" with KVBM by replacing the above two [DYNAMO] cmds with below: ```bash trtllm-serve Qwen/Qwen3-0.6B --host localhost --port 8000 --backend pytorch --extra_llm_api_options /tmp/kvbm_llm_api_config.yaml diff --git a/lib/bindings/kvbm/Cargo.lock b/lib/bindings/kvbm/Cargo.lock index b9ff2fc358..78a9d24714 100644 --- a/lib/bindings/kvbm/Cargo.lock +++ b/lib/bindings/kvbm/Cargo.lock @@ -199,13 +199,13 @@ dependencies = [ [[package]] name = "async-nats" -version = "0.40.0" +version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e23419d455dc57d3ae60a2f4278cf561fc74fe866e548e14d2b0ad3e1b8ca0b2" +checksum = "86dde77d8a733a9dbaf865a9eb65c72e09c88f3d14d3dd0d2aecf511920ee4fe" dependencies = [ "base64 0.22.1", "bytes", - "futures", + "futures-util", "memchr", "nkeys", "nuid", @@ -226,6 +226,7 @@ dependencies = [ "time", "tokio", "tokio-rustls", + "tokio-stream", "tokio-util", "tokio-websockets", "tracing", @@ -1473,7 +1474,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -1630,6 +1631,8 @@ dependencies = [ "modelexpress-client", "modelexpress-common", "ndarray", + "ndarray-interp", + "ndarray-npy", "nix 0.26.4", "nixl-sys", "offset-allocator", @@ -1657,7 +1660,7 @@ dependencies = [ "toktrie", "toktrie_hf_tokenizers", "tonic 
0.13.1", - "tonic-build", + "tonic-build 0.13.1", "tower", "tower-http", "tracing", @@ -1731,6 +1734,7 @@ dependencies = [ "opentelemetry-otlp", "opentelemetry_sdk", "parking_lot", + "percent-encoding", "prometheus", "rand 0.9.2", "rayon", @@ -1890,7 +1894,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -1904,16 +1908,18 @@ dependencies = [ [[package]] name = "etcd-client" -version = "0.16.1" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88365f1a5671eb2f7fc240adb216786bc6494b38ce15f1d26ad6eaa303d5e822" +checksum = "8acfe553027cd07fc5fafa81a84f19a7a87eaffaccd2162b6db05e8d6ce98084" dependencies = [ "http", - "prost 0.13.5", + "prost 0.14.1", "tokio", "tokio-stream", - "tonic 0.13.1", - "tonic-build", + "tonic 0.14.2", + "tonic-build 0.14.2", + "tonic-prost", + "tonic-prost-build", "tower", "tower-service", ] @@ -2825,7 +2831,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.1", "system-configuration", "tokio", "tower-service", @@ -3215,7 +3221,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "serde_core", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -3937,7 +3943,7 @@ dependencies = [ "thiserror 2.0.17", "tokio", "tonic 0.13.1", - "tonic-build", + "tonic-build 0.13.1", "tracing", ] @@ -3994,6 +4000,31 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "ndarray-interp" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e43087829efb5ec2736598e88587df286425b59df5a9ce991994cdd2c5855d3f" +dependencies = [ + "ndarray", + "num-traits", + "thiserror 2.0.17", +] + +[[package]] +name = "ndarray-npy" +version = "0.9.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b313788c468c49141a9d9b6131fc15f403e6ef4e8446a0b2e18f664ddb278a9" +dependencies = [ + "byteorder", + "ndarray", + "num-complex", + "num-traits", + "py_literal", + "zip 2.4.2", +] + [[package]] name = "neli" version = "0.6.5" @@ -4086,9 +4117,9 @@ dependencies = [ [[package]] name = "nixl-sys" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a73b92494c94b2ff2d004cd9274d966863089e867dc9cd98bc640aefe7622036" +checksum = "6d80bd4b5b8363cfd933000a8757a453e58ee10ee6e400c38ae31db512444a31" dependencies = [ "bindgen 0.71.1", "cc", @@ -4170,7 +4201,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -5097,7 +5128,29 @@ dependencies = [ "petgraph", "prettyplease", "prost 0.13.5", - "prost-types", + "prost-types 0.13.5", + "regex", + "syn 2.0.110", + "tempfile", +] + +[[package]] +name = "prost-build" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" +dependencies = [ + "heck", + "itertools 0.14.0", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost 0.14.1", + "prost-types 0.14.1", + "pulldown-cmark", + "pulldown-cmark-to-cmark", "regex", "syn 2.0.110", "tempfile", @@ -5138,6 +5191,15 @@ dependencies = [ "prost 0.13.5", ] +[[package]] +name = "prost-types" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" +dependencies = [ + "prost 0.14.1", +] + [[package]] name = "protobuf" version = "3.7.2" @@ -5158,6 +5220,26 @@ dependencies 
= [ "thiserror 1.0.69", ] +[[package]] +name = "pulldown-cmark" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0" +dependencies = [ + "bitflags 2.10.0", + "memchr", + "unicase", +] + +[[package]] +name = "pulldown-cmark-to-cmark" +version = "21.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8246feae3db61428fd0bb94285c690b460e4517d83152377543ca802357785f1" +dependencies = [ + "pulldown-cmark", +] + [[package]] name = "pulp" version = "0.18.22" @@ -5193,6 +5275,19 @@ dependencies = [ "num-traits", ] +[[package]] +name = "py_literal" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "102df7a3d46db9d3891f178dcc826dc270a6746277a9ae6436f8d29fd490a8e1" +dependencies = [ + "num-bigint", + "num-complex", + "num-traits", + "pest", + "pest_derive", +] + [[package]] name = "pyo3" version = "0.23.5" @@ -5322,7 +5417,7 @@ dependencies = [ "quinn-udp", "rustc-hash 2.1.1", "rustls", - "socket2 0.5.10", + "socket2 0.6.1", "thiserror 2.0.17", "tokio", "tracing", @@ -5359,9 +5454,9 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.1", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] @@ -5864,7 +5959,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -6660,7 +6755,7 @@ dependencies = [ "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -7070,7 +7165,6 @@ dependencies = [ "prost 0.13.5", "socket2 0.5.10", "tokio", - "tokio-rustls", "tokio-stream", "tower", "tower-layer", @@ -7085,8 +7179,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb7613188ce9f7df5bfe185db26c5814347d110db17920415cf2fbcad85e7203" dependencies = 
[ "async-trait", + "axum", "base64 0.22.1", "bytes", + "h2", "http", "http-body", "http-body-util", @@ -7095,8 +7191,10 @@ dependencies = [ "hyper-util", "percent-encoding", "pin-project", + "socket2 0.6.1", "sync_wrapper", "tokio", + "tokio-rustls", "tokio-stream", "tower", "tower-layer", @@ -7112,8 +7210,20 @@ checksum = "eac6f67be712d12f0b41328db3137e0d0757645d8904b4cb7d51cd9c2279e847" dependencies = [ "prettyplease", "proc-macro2", - "prost-build", - "prost-types", + "prost-build 0.13.5", + "prost-types 0.13.5", + "quote", + "syn 2.0.110", +] + +[[package]] +name = "tonic-build" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c40aaccc9f9eccf2cd82ebc111adc13030d23e887244bc9cfa5d1d636049de3" +dependencies = [ + "prettyplease", + "proc-macro2", "quote", "syn 2.0.110", ] @@ -7129,6 +7239,22 @@ dependencies = [ "tonic 0.14.2", ] +[[package]] +name = "tonic-prost-build" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4a16cba4043dc3ff43fcb3f96b4c5c154c64cbd18ca8dce2ab2c6a451d058a2" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build 0.14.1", + "prost-types 0.14.1", + "quote", + "syn 2.0.110", + "tempfile", + "tonic-build 0.14.2", +] + [[package]] name = "tower" version = "0.5.2" @@ -7840,7 +7966,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] @@ -8407,6 +8533,23 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "zip" +version = "2.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabe6324e908f85a1c52063ce7aa26b68dcb7eb6dbc83a2d148403c9bc3eba50" +dependencies = [ + "arbitrary", + "crc32fast", + "crossbeam-utils", + "displaydoc", + "flate2", + "indexmap 
2.12.0", + "memchr", + "thiserror 2.0.17", + "zopfli", +] + [[package]] name = "zip" version = "3.0.0" diff --git a/lib/bindings/kvbm/python/kvbm/trtllm_integration/connector/kvbm_connector_leader.py b/lib/bindings/kvbm/python/kvbm/trtllm_integration/connector/kvbm_connector_leader.py index 836fd9d7fb..aa7628fd2f 100644 --- a/lib/bindings/kvbm/python/kvbm/trtllm_integration/connector/kvbm_connector_leader.py +++ b/lib/bindings/kvbm/python/kvbm/trtllm_integration/connector/kvbm_connector_leader.py @@ -118,6 +118,12 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput) -> bytes: output = RustSchedulerOutput() for req in scheduler_output.new_requests: + if not hasattr(req, "num_scheduled_tokens"): + raise ValueError( + """num_scheduled_tokens is not found in the SchedulerOutput! + This indicates you're using an older, unsupported version of TensorRT-LLM. + Are you running at least TRTLLM v1.2.0rc2?""" + ) output.add_new_request( str(req.request_id), req.new_tokens, @@ -135,6 +141,14 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput) -> bytes: req.computed_position, ) + output.add_num_scheduled_tokens( + { + str(req.request_id): req.num_scheduled_tokens + for req in scheduler_output.new_requests + + scheduler_output.cached_requests + } + ) + return self._connector.build_connector_metadata(output) def get_num_new_matched_tokens( diff --git a/lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs b/lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs index f753d46496..16706f1210 100644 --- a/lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs +++ b/lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs @@ -110,18 +110,6 @@ pub trait Slot: std::fmt::Debug { num_scheduled_tokens: usize, ) -> Result<(), SlotError>; - // TRT-LLM does not include scheduled tokens in the scheduler output. - // Ideally, we should have a dedicated implementation for the TRT-LLM slot. 
- // However, since only this single function needs to be rewritten for now, - // we keep it as a separate function in Slot. - fn apply_scheduler_output_with_computed_position( - &mut self, - tokens: &[u32], - block_ids: &[usize], - computed_position: usize, - is_new_request: bool, - ) -> Result<(), SlotError>; - fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>; fn mark_as_prefilling(&mut self, iteration: u64) -> Result<(), SlotError>; @@ -642,111 +630,6 @@ impl Slot for VllmConnectorSlot { Ok(()) } - #[tracing::instrument(level = "debug", skip_all, fields(request_id = self.request_id.as_str()))] - fn apply_scheduler_output_with_computed_position( - &mut self, - tokens: &[u32], - block_ids: &[usize], - computed_position: usize, - is_new_request: bool, - ) -> Result<(), SlotError> { - // TRTLLM's KV Connector Manager will have (computed_position - external matches) - // in onborading case - if computed_position < self.current_position { - tracing::debug!( - "computed_position={} < current_position={}, so we are onboarding during prefilling phase", - computed_position, - self.current_position - ); - return Ok(()); - } - - // now we decide what we should do for the new computed tokens - tracing::debug!( - "applying scheduler output, computed_position={}, sequence_total_tokens={}", - computed_position, - self.sequence.total_tokens() - ); - - if computed_position < self.sequence.total_tokens() { - // no need to apply new tokens, since it's applied when created the slot during prefilling - self.state = SlotState::Prefilling; - } else { - tracing::debug!( - "appending {} newly decoded tokens to sequence", - tokens.len() - ); - self.sequence.extend(tokens.into()).unwrap(); - self.state = SlotState::Decoding; - } - - // apply new block_ids, this should be applied for both prefilling and decoding - // because this is unknown when creating the slot - if !block_ids.is_empty() { - tracing::debug!("assigning {} new device blocks slot", 
block_ids.len()); - self.device_blocks.extend(block_ids); - } - - // This approach is fragile, but it’s the only way currently to skip evaluating - // the device matched blocks and to avoid offloading them again. - // TODO: Consider adding an indicator in the scheduler output to distinguish between - // matched and unmatched device blocks/tokens from the scheduler. - let maybe_have_device_matched_blocks = - is_new_request && computed_position > 0 && self.evaluated_blocks == 0; - - if maybe_have_device_matched_blocks { - self.evaluated_blocks = (computed_position + 1) / self.block_size; - } - - let num_candidate_blocks = - ((computed_position + 1) / self.block_size).saturating_sub(self.evaluated_blocks); - - if num_candidate_blocks > 0 { - // do we have a mechanism for skipping gpu cache hit blocks? not sure yet. - // for now, offload all the blocks to the host - let offload_block_ids: Vec = self - .device_blocks - .iter() - .skip(self.evaluated_blocks) - .take(num_candidate_blocks) - .copied() - .collect::>(); - - assert_eq!( - offload_block_ids.len(), - num_candidate_blocks, - "device block overflow - candidate blocks exceed block count at offset {}", - self.evaluated_blocks - ); - - let offload_token_blocks: Vec = self - .sequence - .blocks() - .iter() - .skip(self.evaluated_blocks) - .take(num_candidate_blocks) - .cloned() - .collect::>(); - - self.offload_blocks(&offload_block_ids, &offload_token_blocks) - .expect("failed to offload blocks"); - - self.evaluated_blocks += num_candidate_blocks; - } - - // done applying policy - tracing::debug!( - "done applying kv cache policy at current_position: {}; computed_position: {}", - self.current_position, - computed_position, - ); - - // advance current position to computed position - self.current_position = computed_position; - - Ok(()) - } - fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError> { if self.iteration_first_scheduled.is_none() { self.iteration_first_scheduled = Some(iteration); 
diff --git a/lib/bindings/kvbm/src/block_manager/vllm/connector/trtllm_leader.rs b/lib/bindings/kvbm/src/block_manager/vllm/connector/trtllm_leader.rs index e734dedb7b..020d72c13e 100644 --- a/lib/bindings/kvbm/src/block_manager/vllm/connector/trtllm_leader.rs +++ b/lib/bindings/kvbm/src/block_manager/vllm/connector/trtllm_leader.rs @@ -351,11 +351,16 @@ impl Leader for KvConnectorLeader { slot.state() ); - slot.apply_scheduler_output_with_computed_position( + let scheduled_tokens = *scheduler_output + .num_scheduled_tokens + .get(request_id) + .unwrap_or(&0); + + slot.apply_scheduler_output( &new_req.prompt_token_ids, &new_req.block_ids, new_req.num_computed_tokens, - true, + scheduled_tokens, )?; if let Some(pending_ops) = slot.take_pending_operations() { @@ -382,11 +387,16 @@ impl Leader for KvConnectorLeader { .lock() .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; - slot.apply_scheduler_output_with_computed_position( + let scheduled_tokens = *scheduler_output + .num_scheduled_tokens + .get(request_id) + .unwrap_or(&0); + + slot.apply_scheduler_output( &cached_req.new_token_ids, &cached_req.new_block_ids, cached_req.num_computed_tokens, - false, + scheduled_tokens, )?; if let Some(pending_ops) = slot.take_pending_operations() { diff --git a/tests/kvbm_integration/test_determinism_disagg.py b/tests/kvbm_integration/test_determinism_disagg.py index 5451fb0a6c..422578dc26 100755 --- a/tests/kvbm_integration/test_determinism_disagg.py +++ b/tests/kvbm_integration/test_determinism_disagg.py @@ -21,12 +21,14 @@ import signal import subprocess import time +from copy import deepcopy from datetime import datetime from pathlib import Path -from typing import Optional, TextIO +from typing import Any, Dict, Optional, TextIO import pytest import requests +import yaml from .common import DeterminismTester, ServerType from .common import TestDeterminism as BaseTestDeterminism @@ -105,6 +107,8 @@ def __init__( if self.server_type == ServerType.vllm: 
self._set_up_vllm_config(gpu_cache_blocks) + elif self.server_type == ServerType.trtllm: + self._set_up_trtllm_config(gpu_cache_blocks) else: raise ValueError( f"{self.server_type} is not supported yet in the KVBM test suite" ) @@ -165,6 +169,84 @@ def _set_up_vllm_config(self, gpu_cache_blocks): ["--num-gpu-blocks-override", str(gpu_cache_blocks)] ) + def _set_up_trtllm_config(self, gpu_cache_blocks): + # Mostly the same parameters here as in the KVBM TRT-LLM setup guide (docs/kvbm/trtllm-setup.md). + prefill_config_path = os.environ.get( + "KVBM_TRTLLM_LLMAPI_PREFILL_CONFIG_PATH", + "/tmp/kvbm_llm_api_prefill_config.yaml", + ) + + decode_config_path = os.environ.get( + "KVBM_TRTLLM_LLMAPI_DECODE_CONFIG_PATH", + "/tmp/kvbm_llm_api_decode_config.yaml", + ) + + llm_api_config: Dict[str, Any] = {} + llm_api_config["kv_cache_config"] = { + "enable_partial_reuse": False, + "free_gpu_memory_fraction": 0.10, + "tokens_per_block": 16, + } + + # GPU blocks override + if gpu_cache_blocks is not None: + del llm_api_config["kv_cache_config"]["free_gpu_memory_fraction"] + llm_api_config["kv_cache_config"]["max_tokens"] = ( + int(gpu_cache_blocks) * 16 + ) # matches the tokens_per_block=16 configured above + + prefill_config = deepcopy(llm_api_config) + prefill_config["disable_overlap_scheduler"] = True + prefill_config["cache_transceiver_config"] = { + "backend": "DEFAULT", + "max_tokens_in_buffer": 16384, + } + prefill_config["cuda_graph_config"] = None + + decode_config = deepcopy(llm_api_config) + decode_config["disable_overlap_scheduler"] = False + decode_config["cache_transceiver_config"] = { + "backend": "DEFAULT", + "max_tokens_in_buffer": 65536, + } + + model = os.environ.get( + "KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" + ) + + cmd_root = [ + "python3", + "-m", + "dynamo.trtllm", + "--model", + model, + "--kv-block-size", + "16", + "--max-num-tokens", + "8000", + ] + + self.prefiller_cmd = cmd_root + [ + "--extra-engine-args", + prefill_config_path, + "--disaggregation-mode", + "prefill", + "--connector", + "kvbm", + ] 
+ + self.decoder_cmd = cmd_root + [ + "--extra-engine-args", + decode_config_path, + "--disaggregation-mode", + "decode", + ] + + with open(prefill_config_path, "w") as f: + yaml.dump(prefill_config, f, default_flow_style=False, sort_keys=False) + with open(decode_config_path, "w") as f: + yaml.dump(decode_config, f, default_flow_style=False, sort_keys=False) + def start_server(self, timeout: int = 300) -> bool: """Start LLM server and wait for readiness.""" if self.is_server_running(): @@ -345,6 +427,7 @@ def is_server_running(self) -> bool: # First check basic health response = requests.get(f"{self.base_url}/health", timeout=5) if response.status_code != 200: + print(f"Health check failed with status code: {response.status_code}") return False # Then check if the model endpoint is ready with a simple test request @@ -363,9 +446,14 @@ def is_server_running(self) -> bool: json=test_payload, timeout=10, ) + if response.status_code != 200: + print( + f"Model endpoint test failed with status code: {response.status_code}" + ) return response.status_code == 200 - except requests.exceptions.RequestException: + except requests.exceptions.RequestException as e: + print(f"Error checking server status: {e}") return False @@ -419,6 +507,8 @@ def llm_server(request, runtime_services): if importlib.util.find_spec("vllm") is not None: server_type = ServerType.vllm + elif importlib.util.find_spec("tensorrt_llm") is not None: + server_type = ServerType.trtllm else: pytest.skip("vllm module is not available in the current environment.")