Cleanup
Signed-off-by: jthomson04 <[email protected]>
jthomson04 committed Oct 24, 2025
commit 1cad35a5011c25ce746eca73d444729d4c168e8e
6 changes: 3 additions & 3 deletions components/src/dynamo/trtllm/main.py
@@ -28,8 +28,8 @@
KvCacheConfig,
SchedulerConfig,
)
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.metrics import MetricsCollector
@@ -102,8 +102,8 @@ async def get_engine_runtime_config(
# Return config with default/None values if retrieval fails
return runtime_config

def build_kv_connector_config(config: Config):

def build_kv_connector_config(config: Config):
if config.connector is not None:
if config.connector == "kvbm":
return KvCacheConnectorConfig(
@@ -287,7 +287,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# Populate default sampling params from the model
tokenizer = tokenizer_factory(arg_map["model"])
default_sampling_params = SamplingParams()
default_sampling_params.end_id = tokenizer.eos_token_id
default_sampling_params._setup(tokenizer)
default_sampling_params.stop = None
model_input = ModelInput.Tokens
model_type = ModelType.Chat | ModelType.Completions
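Note: a minimal sketch of the sampling-params change above, assuming only the TensorRT-LLM LLM API names that appear in this diff (SamplingParams._setup, tokenizer_factory); the model id is illustrative, taken from the test file below, not from main.py.

    from tensorrt_llm.llmapi.llm import SamplingParams
    from tensorrt_llm.llmapi.tokenizer import tokenizer_factory

    # Illustrative model id only.
    tokenizer = tokenizer_factory("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")

    default_sampling_params = SamplingParams()
    # Old approach: copy just the EOS token id by hand.
    #   default_sampling_params.end_id = tokenizer.eos_token_id
    # New approach per this diff: let SamplingParams derive its
    # tokenizer-dependent defaults in one call.
    default_sampling_params._setup(tokenizer)
    default_sampling_params.stop = None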
37 changes: 23 additions & 14 deletions tests/kvbm/test_determinism_disagg.py
@@ -15,23 +15,23 @@
Example reference: https://github.com/vllm-project/vllm/issues/7779#issuecomment-2304967870
"""

from copy import deepcopy
import importlib.util
import logging
import os
import signal
import subprocess
import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional, TextIO, Dict, Any
import yaml
from typing import Any, Dict, Optional, TextIO

import pytest
import requests
import yaml

from common import DeterminismTester, ServerType
from common import TestDeterminism as BaseTestDeterminism
from .common import DeterminismTester, ServerType
from .common import TestDeterminism as BaseTestDeterminism

# Test markers to align with repository conventions
# Todo: enable the rest when kvbm is built in the ci
@@ -167,14 +167,17 @@ def _set_up_vllm_config(self, gpu_cache_blocks):
self.prefiller_cmd.extend(
["--num-gpu-blocks-override", str(gpu_cache_blocks)]
)

def _set_up_trtllm_config(self, gpu_cache_blocks):
# Mostly the same parameters here as in the
# Mostly the same parameters here as in the
prefill_config_path = os.environ.get(
"KVBM_TRTLLM_LLMAPI_PREFILL_CONFIG_PATH", "/tmp/kvbm_llm_api_prefill_config.yaml"
"KVBM_TRTLLM_LLMAPI_PREFILL_CONFIG_PATH",
"/tmp/kvbm_llm_api_prefill_config.yaml",
)

decode_config_path = os.environ.get(
"KVBM_TRTLLM_LLMAPI_DECODE_CONFIG_PATH", "/tmp/kvbm_llm_api_decode_config.yaml"
"KVBM_TRTLLM_LLMAPI_DECODE_CONFIG_PATH",
"/tmp/kvbm_llm_api_decode_config.yaml",
)

llm_api_config: Dict[str, Any] = {}
@@ -183,7 +186,7 @@ def _set_up_trtllm_config(self, gpu_cache_blocks):
"free_gpu_memory_fraction": 0.10,
"tokens_per_block": 16,
}

# GPU blocks override
if gpu_cache_blocks is not None:
del llm_api_config["kv_cache_config"]["free_gpu_memory_fraction"]
@@ -206,7 +209,9 @@ def _set_up_trtllm_config(self, gpu_cache_blocks):
"max_tokens_in_buffer": 65536,
}

model = os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
model = os.environ.get(
"KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
)

cmd_root = [
"python3",
@@ -217,7 +222,7 @@ def _set_up_trtllm_config(self, gpu_cache_blocks):
"--kv-block-size",
"16",
"--max-num-tokens",
"8000"
"8000",
]

self.prefiller_cmd = cmd_root + [
@@ -226,7 +231,7 @@ def _set_up_trtllm_config(self, gpu_cache_blocks):
"--disaggregation-mode",
"prefill",
"--connector",
"kvbm"
"kvbm",
]

self.decoder_cmd = cmd_root + [
@@ -431,7 +436,9 @@ def is_server_running(self) -> bool:
timeout=10,
)
if response.status_code != 200:
print(f"Model endpoint test failed with status code: {response.status_code}")
print(
f"Model endpoint test failed with status code: {response.status_code}"
)
return response.status_code == 200

except requests.exceptions.RequestException as e:
@@ -492,7 +499,9 @@ def llm_server(request, runtime_services):
elif importlib.util.find_spec("tensorrt_llm") is not None:
server_type = ServerType.trtllm
else:
raise Exception("Neither the vllm nor the tensorrt_llm module is available in the current environment.")
raise Exception(
"Neither the vllm nor the tensorrt_llm module is available in the current environment."
)

server_manager = LLMServerManager(
port=port,
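Note: a minimal sketch of the backend-detection logic in llm_server() above, assuming the vllm branch outside the visible hunk mirrors the tensorrt_llm check shown in the diff; detect_server_type and the ServerType stand-in are illustrative names, not part of the test module (which imports ServerType from .common).

    import importlib.util
    from enum import Enum

    class ServerType(Enum):  # stand-in for tests/kvbm/common.ServerType
        vllm = "vllm"
        trtllm = "trtllm"

    def detect_server_type() -> ServerType:
        # Prefer vllm when it is importable; otherwise fall back to TensorRT-LLM.
        if importlib.util.find_spec("vllm") is not None:
            return ServerType.vllm
        if importlib.util.find_spec("tensorrt_llm") is not None:
            return ServerType.trtllm
        raise Exception(
            "Neither the vllm nor the tensorrt_llm module is available in the current environment."
        )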