diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index 0a58298c279..0f51d30b714 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -1443,13 +1443,16 @@ class CacheTransceiverConfig
         UCX = 2,
         NIXL = 3
     };
-    explicit CacheTransceiverConfig(
-        std::optional<BackendType> backendType = std::nullopt, std::optional<size_t> maxNumTokens = std::nullopt);
+    explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
+        std::optional<size_t> maxNumTokens = std::nullopt,
+        std::optional<std::chrono::milliseconds> kvTransferTimeoutMs = std::nullopt);

     bool operator==(CacheTransceiverConfig const& other) const;

     void setBackendType(std::optional<BackendType> backendType);
     void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);
+    void setKvTransferTimeoutMs(std::optional<std::chrono::milliseconds> kvTransferTimeoutMs);

+    [[nodiscard]] std::optional<std::chrono::milliseconds> getKvTransferTimeoutMs() const;
     [[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
     [[nodiscard]] std::optional<BackendType> getBackendType() const;

@@ -1459,6 +1462,7 @@ class CacheTransceiverConfig
     /// kvCache tokens to be transferred for a single request is greater than this value, the performance of the cache
     /// transfer may be degraded.
     std::optional<size_t> mMaxTokensInBuffer;
+    std::optional<std::chrono::milliseconds> mKvTransferTimeoutMs;
 };

 /// @brief Configuration class for the model executor
diff --git a/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp b/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
index 6919d213642..3ea105e2dfb 100644
--- a/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
+++ b/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
@@ -21,16 +21,18 @@
 namespace tensorrt_llm::executor
 {

-CacheTransceiverConfig::CacheTransceiverConfig(
-    std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens)
+CacheTransceiverConfig::CacheTransceiverConfig(std::optional<BackendType> backendType,
+    std::optional<size_t> maxNumTokens, std::optional<std::chrono::milliseconds> kvTransferTimeoutMs)
     : mBackendType(backendType)
     , mMaxTokensInBuffer(maxNumTokens)
+    , mKvTransferTimeoutMs(kvTransferTimeoutMs)
 {
 }

 bool CacheTransceiverConfig::operator==(CacheTransceiverConfig const& other) const
 {
-    return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType;
+    return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType
+        && mKvTransferTimeoutMs == other.mKvTransferTimeoutMs;
 }

 void CacheTransceiverConfig::setBackendType(std::optional<BackendType> backendType)
@@ -53,4 +55,14 @@ std::optional<size_t> CacheTransceiverConfig::getMaxTokensInBuffer() const
     return mMaxTokensInBuffer;
 }

+void CacheTransceiverConfig::setKvTransferTimeoutMs(std::optional<std::chrono::milliseconds> kvTransferTimeoutMs)
+{
+    mKvTransferTimeoutMs = kvTransferTimeoutMs;
+}
+
+std::optional<std::chrono::milliseconds> CacheTransceiverConfig::getKvTransferTimeoutMs() const
+{
+    return mKvTransferTimeoutMs;
+}
+
 } // namespace tensorrt_llm::executor
diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
index 505ecfca595..5b192c9678a 100644
--- a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
+++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
@@ -429,15 +429,16 @@ void initConfigBindings(nb::module_& m)
         .def("__setstate__", guidedDecodingConfigSetstate);

     auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
-    { return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
+    { return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
     auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state)
     {
-        if (state.size() != 2)
+        if (state.size() != 3)
         {
             throw std::runtime_error("Invalid CacheTransceiverConfig state!");
         }
         new (&self) tle::CacheTransceiverConfig(
-            nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), nb::cast<std::optional<size_t>>(state[1]));
+            nb::cast<std::optional<tle::CacheTransceiverConfig::BackendType>>(state[0]),
+            nb::cast<std::optional<size_t>>(state[1]), nb::cast<std::optional<std::chrono::milliseconds>>(state[2]));
     };

     nb::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -460,12 +461,16 @@
         });

     nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
-        .def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
-            nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt)
+        .def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+                 std::optional<std::chrono::milliseconds>>(),
+            nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt,
+            nb::arg("kv_transfer_timeout_ms") = std::nullopt)
         .def_prop_rw(
             "backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
         .def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
            &tle::CacheTransceiverConfig::setMaxTokensInBuffer)
+        .def_prop_rw("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
+            &tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
         .def("__getstate__", cacheTransceiverConfigGetstate)
         .def("__setstate__", cacheTransceiverConfigSetstate);
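For reference, a minimal sketch of how the new field surfaces in Python through these bindings. The module path, the `UCX` enum member, and all numeric values are assumptions for illustration, not part of this change:

```python
# Sketch only: assumes the bindings import as tensorrt_llm.bindings.executor and
# that CacheTransceiverBackendType exposes UCX; the numbers are placeholders.
import pickle

from tensorrt_llm.bindings import executor as tle

config = tle.CacheTransceiverConfig(
    backend=tle.CacheTransceiverBackendType.UCX,
    max_tokens_in_buffer=8192,        # illustrative buffer size
    kv_transfer_timeout_ms=120000,    # illustrative 2-minute timeout
)

# Pickling now round-trips a 3-tuple state; older 2-tuple states raise
# "Invalid CacheTransceiverConfig state!".
restored = pickle.loads(pickle.dumps(config))
assert restored.kv_transfer_timeout_ms is not None

# Assigning None through the property clears the timeout entirely.
restored.kv_transfer_timeout_ms = None
assert restored.kv_transfer_timeout_ms is None
```

Note the asymmetry between the two binding layers: the pybind11 binding below converts a plain integer to `std::chrono::milliseconds` on the way in and back to an integer count on the way out, while the nanobind binding relies on its chrono type caster directly.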
diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
index 0e279a3e47b..dce82051467 100644
--- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
@@ -412,15 +412,26 @@ void initConfigBindings(pybind11::module_& m)
         .def(py::pickle(guidedDecodingConfigGetstate, guidedDecodingConfigSetstate));

     auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
-    { return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
+    {
+        auto timeout = self.getKvTransferTimeoutMs();
+        std::optional<int64_t> timeoutMs = timeout ?
+            std::optional<int64_t>(timeout->count()) : std::nullopt;
+        return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), timeoutMs);
+    };
     auto cacheTransceiverConfigSetstate = [](py::tuple const& state)
     {
-        if (state.size() != 2)
+        if (state.size() != 3)
         {
             throw std::runtime_error("Invalid CacheTransceiverConfig state!");
         }
-        return tle::CacheTransceiverConfig(
-            state[0].cast<tle::CacheTransceiverConfig::BackendType>(), state[1].cast<std::optional<size_t>>());
+        auto config
+            = tle::CacheTransceiverConfig(state[0].cast<std::optional<tle::CacheTransceiverConfig::BackendType>>(),
+                state[1].cast<std::optional<size_t>>());
+        auto timeoutMs = state[2].cast<std::optional<int64_t>>();
+        if (timeoutMs)
+        {
+            config.setKvTransferTimeoutMs(std::chrono::milliseconds(*timeoutMs));
+        }
+        return config;
     };

     py::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -443,12 +454,44 @@
         });

     py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
-        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
-            py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt)
+        .def(py::init(
+                 [](std::optional<tle::CacheTransceiverConfig::BackendType> backend, std::optional<size_t> maxTokens,
+                     std::optional<int64_t> timeoutMs)
+                 {
+                     std::optional<std::chrono::milliseconds> timeout = std::nullopt;
+                     if (timeoutMs)
+                     {
+                         timeout = std::chrono::milliseconds(*timeoutMs);
+                     }
+                     return tle::CacheTransceiverConfig(backend, maxTokens, timeout);
+                 }),
+            py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt,
+            py::arg("kv_transfer_timeout_ms") = std::nullopt)
         .def_property(
             "backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
         .def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
            &tle::CacheTransceiverConfig::setMaxTokensInBuffer)
+        .def_property(
+            "kv_transfer_timeout_ms",
+            [](tle::CacheTransceiverConfig const& self) -> std::optional<int64_t>
+            {
+                auto timeout = self.getKvTransferTimeoutMs();
+                return timeout ? std::optional<int64_t>(timeout->count()) : std::nullopt;
+            },
+            [](tle::CacheTransceiverConfig& self, std::optional<int64_t> timeoutMs)
+            {
+                if (timeoutMs)
+                {
+                    self.setKvTransferTimeoutMs(std::chrono::milliseconds(*timeoutMs));
+                }
+                else
+                {
+                    self.setKvTransferTimeoutMs(std::nullopt);
+                }
+            })
+        .def("setKvTransferTimeoutMs",
+            [](tle::CacheTransceiverConfig& self, int64_t timeoutMs)
+            { self.setKvTransferTimeoutMs(std::chrono::milliseconds(timeoutMs)); })
         .def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate));

     auto executorConfigGetState = [](py::object const& self)
diff --git a/examples/disaggregated/README.md b/examples/disaggregated/README.md
index 196113d9872..4e5cf30eaf2 100644
--- a/examples/disaggregated/README.md
+++ b/examples/disaggregated/README.md
@@ -16,6 +16,9 @@ cache_transceiver_config:
   backend: <str>
   # KV cache buffer size. Set it ≥ the maximum ISL (Input Sequence Length) for best performance.
   max_tokens_in_buffer: <int>
+  # KV cache transfer timeout in milliseconds.
+  # Requests that fail to send or receive their KV cache within this window are terminated and cleaned up.
+  kv_transfer_timeout_ms: <int>
 ```

 The following is an example, consisting of the `ctx_extra-llm-api-config.yaml` and `gen_extra-llm-api-config.yaml` files needed in the sections below.
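As a cross-check of the keys documented above, such a block can be validated with the llmapi pydantic model added at the end of this diff. A hedged sketch; the backend string and numeric values are placeholders, not recommendations:

```python
# Sketch: parse a cache_transceiver_config block like the README's and validate
# it with the llmapi model (field names from llm_args.py; values illustrative).
import yaml

from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig

doc = yaml.safe_load("""
cache_transceiver_config:
  backend: UCX
  max_tokens_in_buffer: 8192
  kv_transfer_timeout_ms: 120000
""")

cfg = CacheTransceiverConfig(**doc["cache_transceiver_config"])
assert cfg.kv_transfer_timeout_ms == 120000
```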
diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
index eb1f2019781..aaf7ec2fe4a 100644
--- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
+++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -101,6 +101,7 @@ def __init__(self, mapping: Mapping, kv_cache_manager: KVCacheManager,
         tokens_per_block = kv_cache_manager.tokens_per_block
         dtype = kv_cache_manager.dtype

+        self.kv_transfer_timeout_ms = cache_transceiver_config.kv_transfer_timeout_ms
         self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
                                         total_num_kv_heads_per_layer, head_dim,
                                         tokens_per_block, world_config, dtype,
diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index 80f1153e504..6be170146ab 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -329,6 +329,8 @@ def __init__(
         self.is_attention_dp_dummy = False
         self.is_cuda_graph_dummy = False
         self.py_lora_task_layer_module_configs = None
+        self.py_kv_transfer_start_time = None
+        self.py_to_cleanup = False

         self.py_return_log_probs = return_log_probs
         self.py_return_context_logits = return_context_logits
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index d4886887cef..a6cf82c31c9 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -841,6 +841,7 @@ def _executor_loop_pp(self):

                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
                     self._terminate_ctx_finished_requests()
+                    self._check_kv_transfer_timeout()

                 # march forward in microbatch slots
                 microbatch_id = (microbatch_id + 1) % self.num_micro_batches
@@ -857,6 +858,7 @@ def _prepare_and_schedule_batch(self):

         if self.kv_cache_transceiver:
             self._check_disagg_gen_transfer_status()
+            self._check_kv_transfer_timeout()

         iter_stats = None
         if self.enable_iter_perf_stats:
@@ -973,6 +975,7 @@ def _executor_loop(self):
                     self._add_kv_cache_events()

                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    self._check_kv_transfer_timeout()
                     self._terminate_ctx_finished_requests()

                 if self.enable_iter_perf_stats:
@@ -1097,6 +1100,7 @@ def _executor_loop_overlap(self):
                         ctx_transmission_reqs=ctx_transmission_reqs)

                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    self._check_kv_transfer_timeout()
                     self._terminate_ctx_finished_requests()

     def _process_previous_batch(self):
@@ -1263,6 +1267,36 @@ def _check_disagg_gen_transfer_status(self):

         return

+    def _check_kv_transfer_timeout(self):
+        if not self.kv_cache_transceiver:
+            return
+        timeout_ms = self.kv_cache_transceiver.kv_transfer_timeout_ms
+        if timeout_ms is None or timeout_ms <= 0:
+            return
+        current_time = time.time()
+
+        for req in self.ctx_in_transmission_requests:
+            if req.py_kv_transfer_start_time is None:
+                continue
+            elapsed_time = (current_time - req.py_kv_transfer_start_time) * 1000
+            if elapsed_time > timeout_ms:
+                logger.warning(
+                    f"Terminating context request {req.py_request_id} due to KV cache transfer timeout"
+                )
+                req.py_to_cleanup = True
+
+        for req in self.active_requests:
+            if req.is_disagg_generation_transmission_in_progress and req.py_kv_transfer_start_time is not None:
+                elapsed_time = (current_time -
+                                req.py_kv_transfer_start_time) * 1000
+                if elapsed_time > timeout_ms:
+                    logger.warning(
+                        f"Terminating generation request {req.py_request_id} due to KV cache transfer timeout"
+                    )
+                    req.py_to_cleanup = True
+
+        return
+
     @nvtx_range("_pad_attention_dp_dummy_request")
     def _pad_attention_dp_dummy_request(self):
         """
@@ -1335,6 +1369,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
             req.context_current_position = req.prompt_len
             req.decoding_iter = 1
             req.py_decoding_iter = 1
+            req.py_kv_transfer_start_time = None
             first_gen_tokens = req.context_phase_params.first_gen_tokens
             ctx_draft_tokens = req.context_phase_params.draft_tokens
             req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
@@ -1358,6 +1393,11 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
         for req in new_gen_reqs:
             self.kv_cache_transceiver.request_and_receive_async(req)

+        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+            for req in new_gen_reqs:
+                if req.state == LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS:
+                    req.py_kv_transfer_start_time = time.time()
+
         block_transfer = all([
             req.is_disagg_generation_transmission_in_progress
             for req in self.active_requests
@@ -1391,6 +1431,11 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
             if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
         ]

+        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+            for req in ctx_transmission_reqs:
+                if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
+                    req.py_kv_transfer_start_time = time.time()
+
         return ctx_transmission_reqs

     def _forward_step(self,
@@ -1608,6 +1653,12 @@ def _handle_responses(self):
                     requests_to_terminate.append(request)
                     continue

+                # Check if generation request needs cleanup due to KV cache transfer timeout
+                if request.py_to_cleanup:
+                    request.state = LlmRequestState.GENERATION_COMPLETE
+                    requests_to_terminate.append(request)
+                    continue
+
                 if request.is_generation_only_request():
                     # If request is in transmission, so we don't need to emit a response
                     # Also, for the first iteration with overlap, we should skip since first
@@ -1649,7 +1700,8 @@ def _handle_responses(self):
     @nvtx_range("_terminate_ctx_finished_requests")
     def _terminate_ctx_finished_requests(self):
         for request in self.ctx_in_transmission_requests[:]:
-            if request.is_disagg_context_complete_state:
+            if request.is_disagg_context_complete_state or request.py_to_cleanup:
+                request.py_kv_transfer_start_time = None
                 self._terminate_request(request)
                 self.ctx_in_transmission_requests.remove(request)

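The timeout test itself is plain wall-clock arithmetic; restated standalone below (the helper name is ours, not the executor's):

```python
# Standalone restatement of the check in _check_kv_transfer_timeout:
# start times come from time.time() (seconds), the budget is in milliseconds,
# and a missing start time or a non-positive budget disables the check.
import time
from typing import Optional


def transfer_timed_out(start_time: Optional[float],
                       timeout_ms: Optional[int]) -> bool:
    if start_time is None or timeout_ms is None or timeout_ms <= 0:
        return False
    elapsed_ms = (time.time() - start_time) * 1000
    return elapsed_ms > timeout_ms


# A transfer that started 0.2 s ago exceeds a 100 ms budget...
assert transfer_timed_out(time.time() - 0.2, 100)
# ...while None or 0 leaves the request alone.
assert not transfer_timed_out(time.time() - 0.2, None)
assert not transfer_timed_out(time.time() - 0.2, 0)
```

Note that requests tripping the check are only flagged via `py_to_cleanup`; actual termination happens later, in `_handle_responses` for generation requests and in `_terminate_ctx_finished_requests` for context requests.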
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 0f377657261..5914f4b93c2 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -1070,10 +1070,16 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
         default=None,
         description="The max number of tokens the transfer buffer can fit.")

+    kv_transfer_timeout_ms: Optional[int] = Field(
+        default=None,
+        description="Timeout in milliseconds for KV cache transfer operations. "
+        "Requests exceeding this timeout will be terminated.")
+
     def _to_pybind(self):
         return _CacheTransceiverConfig(
             backend=_CacheTransceiverBackendType.from_string(self.backend),
-            max_tokens_in_buffer=self.max_tokens_in_buffer)
+            max_tokens_in_buffer=self.max_tokens_in_buffer,
+            kv_transfer_timeout_ms=self.kv_transfer_timeout_ms)


 @dataclass
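End to end, the pydantic field is a plain `Optional[int]` that `_to_pybind()` forwards to the C++ executor config, where the binding converts it to `std::chrono::milliseconds`. A minimal sketch under those assumptions; the backend string and values are illustrative:

```python
# Sketch: forwarding the llmapi field to the executor config via _to_pybind()
# (a private bridge shown in the diff above; "UCX" and the numbers are placeholders).
from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig

cfg = CacheTransceiverConfig(backend="UCX",
                             max_tokens_in_buffer=8192,
                             kv_transfer_timeout_ms=120000)

executor_cfg = cfg._to_pybind()
# Under the pybind11 binding the property reads back as a plain integer count.
assert executor_cfg.kv_transfer_timeout_ms is not None
```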