2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -814,6 +814,8 @@ void CacheFormatter::unformat(TransferSession& session)
if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size())
{
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support same number of layers");
TLLM_LOG_WARNING("self: %d dest %d", selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
Comment on lines 816 to +818
⚠️ Potential issue

Incorrect printf specifier for size_t — will trigger warnings or UB

size() returns size_t, yet the format string uses %d, which expects int. On LP64/LLP64 targets this causes a type/width mismatch.

-        TLLM_LOG_WARNING("self: %d dest %d", selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
-            destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
+        TLLM_LOG_WARNING("self: %zu dest %zu",
+            selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
+            destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp around lines 816 to 818,
the format specifier %d is incorrectly used for size_t values returned by
size(), which can cause warnings or undefined behavior. Replace %d with %zu in
the TLLM_LOG_WARNING calls to correctly format size_t types.

return false;
}
int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();
5 changes: 4 additions & 1 deletion scripts/build_wheel.py
@@ -71,7 +71,10 @@ def clear_folder(folder_path):
if os.path.isdir(item_path) and not os.path.islink(item_path):
rmtree(item_path)
else:
os.remove(item_path)
try:
os.remove(item_path)
except:
print(f"failed to remove {item_path}")
Comment on lines +74 to +77
@coderabbitai coderabbitai bot Jul 28, 2025
⚠️ Potential issue

Improve error handling specificity and reporting.

The bare except clause catches all exceptions including system errors and keyboard interrupts, which can mask critical issues. Additionally, the current implementation doesn't provide information about what went wrong.

Apply this diff to improve error handling:

-            try:
-                os.remove(item_path)
-            except:
-                print(f"failed to remove {item_path}")
+            try:
+                os.remove(item_path)
+            except (OSError, IOError) as e:
+                print(f"Failed to remove {item_path}: {e}", file=sys.stderr)

This change:

  • Catches specific filesystem-related exceptions
  • Provides detailed error information including the underlying cause
  • Uses sys.stderr for error output following best practices
  • Allows critical exceptions like KeyboardInterrupt to propagate normally (see the sketch below)
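
For illustration only, here is a minimal standalone sketch of the narrowed handler (the helper name is hypothetical, not part of the PR):

import os
import sys

def remove_quietly(path: str) -> None:
    """Hypothetical helper: delete a file, reporting failures without hiding them."""
    try:
        os.remove(path)
    except OSError as e:  # IOError is an alias of OSError in Python 3, so it is covered too
        # Report the failure and its cause on stderr, then keep going.
        print(f"Failed to remove {path}: {e}", file=sys.stderr)
    # KeyboardInterrupt and SystemExit are not caught, so Ctrl-C still aborts the run.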
🧰 Tools
🪛 Ruff (0.12.2)

76-76: Do not use bare except

(E722)

🤖 Prompt for AI Agents
In scripts/build_wheel.py around lines 74 to 77, replace the bare except clause
with a more specific exception handler that catches only filesystem-related
errors such as OSError. Modify the except block to capture the exception
instance and print a detailed error message including the exception details to
sys.stderr. This will improve error reporting and allow critical exceptions like
KeyboardInterrupt to propagate normally.

@reasonsolo can we apply this suggestion?

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!




def sysconfig_scheme(override_vars=None):
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -96,13 +96,13 @@ def __init__(self, mapping: Mapping, kv_cache_manager: KVCacheManager,
attention_type: AttentionTypeCpp,
cache_transceiver_config: CacheTransceiverConfig):
world_config = mapping_to_world_config(mapping)
num_kv_heads_per_layer = kv_cache_manager.num_kv_heads_per_layer
total_num_kv_heads_per_layer = kv_cache_manager.total_num_kv_heads_per_layer
head_dim = kv_cache_manager.head_dim
tokens_per_block = kv_cache_manager.tokens_per_block
dtype = kv_cache_manager.dtype

self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
num_kv_heads_per_layer, head_dim,
total_num_kv_heads_per_layer, head_dim,
tokens_per_block, world_config, dtype,
attention_type,
cache_transceiver_config)
57 changes: 55 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -122,6 +122,7 @@ class BatchState:
@dataclasses.dataclass
class BatchStatePP(BatchState):
microbatch_id: int = -1
scheduled_ctx_reqs: list[LlmRequest] = None


class PyExecutor:
@@ -643,6 +644,7 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
return False

def _executor_loop_pp(self):
logger.info(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
torch.cuda.set_device(self.device_id)
microbatch_id = 0
with self._profiler() as profile_step:
@@ -656,6 +658,9 @@ def _executor_loop_pp(self):
if self.should_stop_processing:
break

if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()

if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
@@ -664,9 +669,27 @@ def _executor_loop_pp(self):

self._pad_attention_dp_dummy_request()

scheduled_batch, _, _ = self._schedule()
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)

if self.kv_cache_transceiver:
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
self._prepare_disagg_gen_init(
fitting_disagg_gen_init_requests)

if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self.kv_cache_transceiver.check_context_transfer_status(
1)
else:
assert scheduled_batch.batch_size > 0, (
"fail to schedule any pending request, "
"probably run out of resource.")

self.num_scheduled_requests = scheduled_batch.batch_size

logger.debug(
f'has {len(self.active_requests)} active_request, '
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
@@ -679,7 +702,7 @@ def _executor_loop_pp(self):
can_queue = 0 not in tp_batch_sizes
else:
can_queue = scheduled_batch.batch_size > 0
if not can_queue:
if not can_queue and not self.kv_cache_transceiver:
assert len(self.inflight_req_ids) > 0, (
"fail to schedule any pending request, probably run out of resource"
)
@@ -688,8 +711,28 @@ def _executor_loop_pp(self):
self.micro_batches[microbatch_id] = None
else:
self._add_inflight_ids(scheduled_batch)

if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
self._prepare_disagg_gen_transmission_complete(
scheduled_batch)

self.resource_manager.prepare_resources(scheduled_batch)

                        # Generation requests that do not have a batch_idx need to
                        # be in front of the batch due to the assumptions made in
                        # model_engine.py::_forward_step. This is only important
                        # for disaggregated serving. For non-disaggregated serving,
                        # the generation requests always have a batch_idx.
scheduled_batch.generation_requests = sorted( # stable sort
scheduled_batch.generation_requests,
key=lambda req: int(req.py_batch_idx is not None),
)

if self.kv_cache_transceiver:
# Return the first token to the client
self._handle_first_token_response(scheduled_batch)

# Stage 1: Async forward (all ranks) and decoding pass (last rank only)
if not self.dist.is_last_pp_rank:
sample_state = self._forward_step_inter_pp(
@@ -720,6 +763,7 @@ def _executor_loop_pp(self):
iter_start_time=iter_start_time,
iter_stats=iter_stats,
microbatch_id=microbatch_id,
scheduled_ctx_reqs=scheduled_batch.context_requests,
)

self.micro_batches[microbatch_id] = batch_state
@@ -784,6 +828,12 @@ def _executor_loop_pp(self):
if previous_batch is not None:
with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
self._update_requests(previous_batch.sample_state)

if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
self._send_disagg_ctx_cache(
previous_batch.scheduled_ctx_reqs
) if self.kv_cache_transceiver else []

Comment on lines +832 to +836
@coderabbitai coderabbitai bot Jul 28, 2025

🛠️ Refactor suggestion

Remove redundant conditional check.

The ternary operator on line 835 is redundant since the code is already inside an if self.kv_cache_transceiver block.

-                        if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
-                            self._send_disagg_ctx_cache(
-                                previous_batch.scheduled_ctx_reqs
-                            ) if self.kv_cache_transceiver else []
+                        if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
+                            self._send_disagg_ctx_cache(
+                                previous_batch.scheduled_ctx_reqs
+                            )
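
As an aside, here is a tiny standalone sketch (names hypothetical) of why the trailing conditional expression adds nothing here:

def send(reqs):  # hypothetical stand-in for self._send_disagg_ctx_cache
    print(f"sending {len(reqs)} requests")

flag, reqs = True, ["r0", "r1"]

if flag and reqs:
    # Original form: a conditional expression used as a statement. `flag` is already
    # known to be truthy inside this block, so the `else []` branch is dead code and
    # the expression's value is discarded either way.
    send(reqs) if flag else []

if flag and reqs:
    # Suggested form: the same work, with no dead branch.
    send(reqs)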
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/py_executor.py around lines 832 to 836, remove
the redundant ternary conditional on line 835 that checks
self.kv_cache_transceiver again inside an if block that already ensures
self.kv_cache_transceiver is truthy. Simply call
self._send_disagg_ctx_cache(previous_batch.scheduled_ctx_reqs) directly without
the ternary operator.

@reasonsolo @raayandhar can you apply this suggestion.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

self._handle_canceled_requests()
finished_requests = self._handle_responses()
previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
Expand All @@ -792,6 +842,9 @@ def _executor_loop_pp(self):
self._remove_inflight_ids(previous_scheduled_batch)
self.micro_batches[prev_microbatch_id] = None

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._terminate_ctx_finished_requests()

# march forward in microbatch slots
microbatch_id = (microbatch_id + 1) % self.num_micro_batches

25 changes: 20 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -155,18 +155,33 @@ def __init__(
(num_kv_heads + tp_size - 1) // tp_size
for _ in range(self.num_local_layers)
]
self.total_num_kv_heads_per_layer = [
(num_kv_heads + tp_size - 1) // tp_size
for _ in range(self.num_layers)
]
else:
assert len(num_kv_heads) == self.num_layers

def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
kv_head: Optional[int]):
if kv_head is not None:
num_kv_heads_per_layer.append(
(kv_head + tp_size - 1) // tp_size)
else:
num_kv_heads_per_layer.append(0)

self.num_kv_heads_per_layer = []
if self.num_local_layers > 0:
for i in self.pp_layers:
kv_head = num_kv_heads[i]
if kv_head is not None:
self.num_kv_heads_per_layer.append(
(kv_head + tp_size - 1) // tp_size)
else:
self.num_kv_heads_per_layer.append(0)
append_to_kv_heads_per_layer(self.num_kv_heads_per_layer,
kv_head)

self.total_num_kv_heads_per_layer = []
for i in range(self.num_layers):
kv_head = num_kv_heads[i]
append_to_kv_heads_per_layer(self.total_num_kv_heads_per_layer,
kv_head)

self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
11 changes: 11 additions & 0 deletions tests/integration/defs/accuracy/accuracy_core.py
@@ -735,3 +735,14 @@ def setup_class(cls):
logger.set_level("info")
yield
logger.set_level(original_level)


def get_accuracy_task(dataset_name: str):
try:
task_class = globals()[dataset_name]
if issubclass(task_class, AccuracyTask):
return task_class
else:
raise ValueError(f"Unknown dataset: {dataset_name}.")
except KeyError:
raise ValueError(f"Not registered dataset: {dataset_name}.")
104 changes: 93 additions & 11 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -20,9 +20,11 @@
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
from tensorrt_llm.llmapi.llm_args import LlmArgs

from ..conftest import llm_models_root, parametrize_with_ids, skip_pre_hopper
from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
skip_pre_hopper)
from ..trt_test_alternative import popen
from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness
from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
get_accuracy_task)


class Result(GenerationResultBase):
Expand Down Expand Up @@ -71,6 +73,12 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
temp_dir = tempfile.TemporaryDirectory()
disaggregated_serving_config_path = os.path.join(
temp_dir.name, "disaggregated_serving_config.yaml")

if tensor_parallel_size > 1:
print(
f"Using unified tp parameter for testing is not recommended. Please use server configs instead."
)

with open(disaggregated_serving_config_path, "w") as f:
yaml.dump(disaggregated_server_config, f)
ctx_server_config_path = os.path.join(temp_dir.name,
@@ -88,27 +96,38 @@
trtllm_serve_path = "trtllm-serve"
# Common arguments for both servers
common_args = [
trtllm_serve_path, model_name, "--host", "localhost", "--backend",
"pytorch"
trtllm_serve_path,
model_name,
"--host",
"localhost",
"--backend",
"pytorch",
]
gen_tp, gen_pp = gen_server_config.get("tensor_parallel_size",
1), gen_server_config.get(
"pipeline_parallel_size", 1)
ctx_tp, ctx_pp = ctx_server_config.get("tensor_parallel_size",
1), ctx_server_config.get(
"pipeline_parallel_size", 1)

if tensor_parallel_size > 1:
common_args.append(f"--tp_size={tensor_parallel_size}")
ctx_total_gpus = ctx_tp * ctx_pp
gen_total_gpus = gen_tp * gen_pp

env_ctx = os.environ.copy()
env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, range(tensor_parallel_size)))
env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))

env_gen = os.environ.copy()
env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
ctx_server_args = common_args + [
"--port", "8001", "--extra_llm_api_options", ctx_server_config_path
"--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
]
gen_server_args = common_args + [
"--port", "8002", "--extra_llm_api_options", gen_server_config_path
"--port", "8002", "--extra_llm_api_options", gen_server_config_path,
f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
]
if "max_num_tokens" in ctx_server_config:
ctx_server_args.append(
@@ -315,6 +334,69 @@ def test_eagle3(self, overlap_scheduler):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

def run_parallel_test(self, ctx_pp: int, ctx_tp: int, gen_pp: int,
gen_tp: int, test_set: LlmapiAccuracyTestHarness):
if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
pytest.skip(
f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
)

kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False
}
ctx_server_config = {
"pipeline_parallel_size": ctx_pp,
"tensor_parallel_size": ctx_tp,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
"cache_transceiver_config": {
"backend": "default"
}
}
gen_server_config = {
"tensor_parallel_size": gen_tp,
"pipeline_parallel_size": gen_pp,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
"cache_transceiver_config": {
"backend": "default"
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
task = test_set(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
ids=["tp1pp2", "tp2pp1", "tp2pp2"])
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_tp_pp_symmetric(self, tp, pp, testset):
return self.run_parallel_test(pp, tp, pp, tp,
get_accuracy_task(testset))

# We focus on ctx+pp and gen+tp usecases for RTX6000D
@parametrize_with_ids("ctx_pp", [2, 4])
@parametrize_with_ids("gen_tp", [1, 2])
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
return self.run_parallel_test(ctx_pp, 1, gen_tp,
get_accuracy_task(testset))
Comment on lines +396 to +398

⚠️ Potential issue

Fix incorrect method signature in test_ctx_pp_gen_tp_asymmetric.

The call is missing a gen_pp argument: gen_tp is passed where run_parallel_test expects gen_pp, and get_accuracy_task(testset) lands in the gen_tp slot, leaving test_set unfilled.

 def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
-    return self.run_parallel_test(ctx_pp, 1, gen_tp,
+    return self.run_parallel_test(ctx_pp, 1, 1, gen_tp,
                                   get_accuracy_task(testset))
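
To make the argument mismatch concrete, here is a standalone sketch with hypothetical values (the real method also takes self):

def run_parallel_test(ctx_pp, ctx_tp, gen_pp, gen_tp, test_set):
    # Same parameter order as the test helper above.
    print(f"ctx_pp={ctx_pp} ctx_tp={ctx_tp} gen_pp={gen_pp} gen_tp={gen_tp} test_set={test_set}")

ctx_pp, gen_tp, task = 2, 2, "GSM8K"

# Original call: four arguments for five parameters, so gen_tp fills the gen_pp slot,
# the task lands in gen_tp, and test_set is left unfilled -> TypeError at call time.
try:
    run_parallel_test(ctx_pp, 1, gen_tp, task)
except TypeError as e:
    print(f"broken call: {e}")

# Suggested call: gen_pp is pinned to 1 and every argument lines up.
run_parallel_test(ctx_pp, 1, 1, gen_tp, task)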
🤖 Prompt for AI Agents
In tests/integration/defs/accuracy/test_disaggregated_serving.py around lines
480 to 482, the method test_ctx_pp_gen_tp_asymmetric is missing the gen_pp
parameter in its signature but is passing gen_tp where gen_pp is expected in the
call to run_parallel_test. Fix this by adding the gen_pp parameter to the method
signature and passing gen_pp instead of gen_tp to run_parallel_test.



@pytest.mark.skip_less_device_memory(140000)
@pytest.mark.timeout(3600)
@@ -0,0 +1,36 @@
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
hostname: localhost
port: 8000
backend: "pytorch"
cuda_graph_config: null
free_gpu_memory_fraction: 0.2
context_servers:
num_instances: 1
max_batch_size: 1
max_num_tokens: 3000
max_seq_len: 4096
tensor_parallel_size: 1
pipeline_parallel_size: 2
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 2
max_batch_size: 256
max_num_tokens: 4096
max_seq_len: 4096
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: default
urls:
- "localhost:8002"