7 changes: 5 additions & 2 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -1456,13 +1456,15 @@ class CacheTransceiverConfig
UCX = 2,
NIXL = 3
};
explicit CacheTransceiverConfig(
std::optional<BackendType> backendType = std::nullopt, std::optional<size_t> maxNumTokens = std::nullopt);
explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt);

bool operator==(CacheTransceiverConfig const& other) const;
void setBackendType(std::optional<BackendType> backendType);
void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);
void setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs);

[[nodiscard]] std::optional<int> getKvTransferTimeoutMs() const;
[[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
[[nodiscard]] std::optional<BackendType> getBackendType() const;

@@ -1472,6 +1474,7 @@ class CacheTransceiverConfig
/// kvCache tokens to be transferred for a single request is greater than this value, the performance of the cache
/// transfer may be degraded.
std::optional<size_t> mMaxTokensInBuffer;
std::optional<int> mKvTransferTimeoutMs;
};

/// @brief Configuration class for the model executor
16 changes: 14 additions & 2 deletions cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
@@ -22,15 +22,17 @@ namespace tensorrt_llm::executor
{

CacheTransceiverConfig::CacheTransceiverConfig(
std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens)
std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
: mBackendType(backendType)
, mMaxTokensInBuffer(maxNumTokens)
, mKvTransferTimeoutMs(kvTransferTimeoutMs)
{
}

bool CacheTransceiverConfig::operator==(CacheTransceiverConfig const& other) const
{
return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType;
return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType
&& mKvTransferTimeoutMs == other.mKvTransferTimeoutMs;
}

void CacheTransceiverConfig::setBackendType(std::optional<BackendType> backendType)
@@ -43,6 +45,11 @@ void CacheTransceiverConfig::setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer)
mMaxTokensInBuffer = maxTokensInBuffer;
}

void CacheTransceiverConfig::setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs)
{
mKvTransferTimeoutMs = kvTransferTimeoutMs;
}

std::optional<CacheTransceiverConfig::BackendType> CacheTransceiverConfig::getBackendType() const
{
return mBackendType;
@@ -53,4 +60,9 @@ std::optional<size_t> CacheTransceiverConfig::getMaxTokensInBuffer() const
return mMaxTokensInBuffer;
}

std::optional<int> CacheTransceiverConfig::getKvTransferTimeoutMs() const
{
return mKvTransferTimeoutMs;
}

} // namespace tensorrt_llm::executor
16 changes: 10 additions & 6 deletions cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
@@ -433,15 +433,15 @@ void initConfigBindings(nb::module_& m)
.def("__setstate__", guidedDecodingConfigSetstate);

auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
{ return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
{ return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state)
{
if (state.size() != 2)
if (state.size() != 3)
{
throw std::runtime_error("Invalid CacheTransceiverConfig state!");
}
new (&self) tle::CacheTransceiverConfig(
nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), nb::cast<std::optional<size_t>>(state[1]));
new (&self) tle::CacheTransceiverConfig(nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]),
nb::cast<std::optional<size_t>>(state[1]), nb::cast<std::optional<int>>(state[2]));
};

nb::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -464,12 +464,16 @@
});

nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt)
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
std::optional<int>>(),
nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt,
nb::arg("kv_transfer_timeout_ms") = std::nullopt)
.def_prop_rw(
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
.def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
.def_prop_rw("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
.def("__getstate__", cacheTransceiverConfigGetstate)
.def("__setstate__", cacheTransceiverConfigSetstate);

16 changes: 10 additions & 6 deletions cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
@@ -415,15 +415,15 @@ void initConfigBindings(pybind11::module_& m)
.def(py::pickle(guidedDecodingConfigGetstate, guidedDecodingConfigSetstate));

auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
{ return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
{ return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
auto cacheTransceiverConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 2)
if (state.size() != 3)
{
throw std::runtime_error("Invalid CacheTransceiverConfig state!");
}
return tle::CacheTransceiverConfig(
state[0].cast<tle::CacheTransceiverConfig::BackendType>(), state[1].cast<std::optional<size_t>>());
return tle::CacheTransceiverConfig(state[0].cast<tle::CacheTransceiverConfig::BackendType>(),
state[1].cast<std::optional<size_t>>(), state[2].cast<std::optional<int>>());
};

py::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -446,12 +446,16 @@
});

py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt)
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
std::optional<int>>(),
py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt,
py::arg("kv_transfer_timeout_ms") = std::nullopt)
.def_property(
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
.def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
.def_property("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
.def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate));

auto executorConfigGetState = [](py::object const& self)
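Both binding layers (nanobind above, pybind11 here) expose the same Python surface, so a quick round trip exercises the new keyword argument, the read/write property, and the widened 3-tuple pickle state. A minimal sketch, assuming the executor bindings are importable as `tensorrt_llm.bindings.executor` and that the enum exposes `UCX` as declared in the C++ header:

```python
import pickle

from tensorrt_llm.bindings import executor as tle

# Construct with the new keyword argument; the values are illustrative.
config = tle.CacheTransceiverConfig(
    backend=tle.CacheTransceiverBackendType.UCX,
    max_tokens_in_buffer=8192,
    kv_transfer_timeout_ms=5000,  # give up on transfers stuck longer than 5 s
)

# __getstate__/__setstate__ now carry a 3-tuple, so the timeout survives pickling.
restored = pickle.loads(pickle.dumps(config))
assert restored.kv_transfer_timeout_ms == 5000
assert restored.max_tokens_in_buffer == 8192
```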
3 changes: 3 additions & 0 deletions examples/disaggregated/README.md
@@ -16,6 +16,9 @@ cache_transceiver_config:
backend: <str>
# KV cache buffer size. Set it ≥ the maximum ISL (Input Sequence Length) for best performance.
max_tokens_in_buffer: <int>
  # KV cache transfer timeout in milliseconds.
  # Requests that do not finish sending or receiving their KV cache within this window are cancelled and cleaned up.
  kv_transfer_timeout_ms: <int>
```
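For instance, a server that should abandon transfers stuck for more than five seconds could set (illustrative values, not tuned recommendations):

```yaml
cache_transceiver_config:
  backend: UCX
  max_tokens_in_buffer: 8192
  kv_transfer_timeout_ms: 5000
```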

The following is an example, consisting of the `ctx_extra-llm-api-config.yaml` and `gen_extra-llm-api-config.yaml` files needed in the sections below.
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -109,6 +109,8 @@ def __init__(self, mapping: Mapping, dist: Distributed,
# get the layer num per pp rank, which is required by cache transceiver.
pp_layer_num = len(kv_cache_manager.pp_layers)
pp_layer_num_per_pp_rank = dist.pp_allgather(pp_layer_num)

self.kv_transfer_timeout_ms = cache_transceiver_config.kv_transfer_timeout_ms
self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
total_num_kv_heads_per_layer, head_dim,
tokens_per_block, world_config,
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -442,6 +442,8 @@ def __init__(
self.py_lora_task_layer_module_configs: list[
tensorrt_llm.bindings.internal.runtime.
TaskLayerModuleConfig] | None = None
self.py_kv_transfer_start_time = None
self.py_kv_transfer_timed_out = False

self.py_num_logprobs = num_logprobs
self.py_return_log_probs = return_log_probs
61 changes: 61 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -978,6 +978,7 @@ def _executor_loop_pp(self):
self.micro_batches[prev_microbatch_id] = None

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._check_kv_transfer_timeout()
self._terminate_ctx_finished_requests()

if self._disagg_pp_termination_handler is not None:
@@ -1006,6 +1007,7 @@ def _prepare_and_schedule_batch(self):

if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()
self._check_kv_transfer_timeout()

iter_stats = None
if self.enable_iter_perf_stats:
@@ -1179,6 +1181,7 @@ def _executor_loop(self):
self._add_kv_cache_events()

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._check_kv_transfer_timeout()
self._terminate_ctx_finished_requests()

self._kv_connector_terminate_requests()
@@ -1364,6 +1367,7 @@ def _executor_loop_overlap(self):
ctx_transmission_reqs=ctx_transmission_reqs)

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._check_kv_transfer_timeout()
self._terminate_ctx_finished_requests()

self._kv_connector_terminate_requests()
@@ -1572,6 +1576,38 @@ def _check_disagg_gen_transfer_status(self):

return

@nvtx_range("_check_kv_transfer_timeout")
def _check_kv_transfer_timeout(self):
if not self.kv_cache_transceiver:
return
timeout_ms = self.kv_cache_transceiver.kv_transfer_timeout_ms
if timeout_ms is None or timeout_ms <= 0:
return

current_time = time.time()

        for req, _block_id in self.ctx_in_transmission_requests:
if req.py_kv_transfer_start_time is None:
continue
elapsed_time = (current_time - req.py_kv_transfer_start_time) * 1000
if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
logger.warning(
f"Terminating context request {req.py_request_id} due to KV cache transfer timeout"
)
req.py_kv_transfer_timed_out = True

for req in self.active_requests:
if req.is_disagg_generation_transmission_in_progress and req.py_kv_transfer_start_time is not None:
elapsed_time = (current_time -
req.py_kv_transfer_start_time) * 1000
if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
logger.warning(
f"Terminating generation request {req.py_request_id} due to KV cache transfer timeout"
)
req.py_kv_transfer_timed_out = True

return

@nvtx_range("_pad_attention_dp_dummy_request")
def _pad_attention_dp_dummy_request(self):
"""
@@ -1646,6 +1682,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
req.context_current_position = req.prompt_len
req.decoding_iter = 1
req.py_decoding_iter = 1
req.py_kv_transfer_start_time = None
first_gen_tokens = req.context_phase_params.first_gen_tokens
ctx_draft_tokens = req.context_phase_params.draft_tokens
req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
@@ -1669,6 +1706,11 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
for req in new_gen_reqs:
self.kv_cache_transceiver.request_and_receive_async(req)

if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
for req in new_gen_reqs:
if req.state == LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS:
req.py_kv_transfer_start_time = time.time()

block_transfer = all([
req.is_disagg_generation_transmission_in_progress
for req in self.active_requests
@@ -1701,6 +1743,11 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
]

if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
for req in ctx_in_transmission_requests:
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
req.py_kv_transfer_start_time = time.time()

return ctx_transmission_reqs

def _get_disagg_reqs_in_error_state(self):
@@ -2018,6 +2065,15 @@ def _handle_responses(self):
requests_to_terminate.append(request)
continue

# Check if generation request needs cleanup due to KV cache transfer timeout
Review thread on the timeout check below:

Collaborator: Is it possible for a ctx request to ever enter this code path? If not, should we move this inside the `if request.is_generation_only_request()` branch to make it clearer?

Contributor (author): No, I don't think a ctx request should be able to enter this code path, since it loops over the generation requests in `self.active_requests`. But my understanding was that `is_generation_only_request` is for gen-only server requests, and if the ctx and gen disagg servers are separate, we would never go down that code path. I may be misunderstanding the `is_generation_only_request` case. Either way, I will test it.

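A hypothetical sketch of the restructuring the reviewer suggests: scoping the timeout check under the generation-only branch inside `_handle_responses`, so a ctx request can never reach the cancel path (names follow the diff below; untested):

```python
# Inside the _handle_responses loop over self.active_requests:
if request.is_generation_only_request():
    # Timeout cleanup now applies only to generation requests.
    if request.py_kv_transfer_timed_out:
        if self.kv_cache_transceiver.cancel_request(request):
            self._handle_errors(
                error_msg=f"Request {request.py_request_id} timed out",
                requests=[request])
            continue
    # ... existing generation-only handling follows ...
```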
if request.py_kv_transfer_timed_out:
is_cancelled = self.kv_cache_transceiver.cancel_request(request)
if is_cancelled:
self._handle_errors(
error_msg=f"Request {py.request_id} timed out",
requests=[request])
continue

if request.is_generation_only_request():
                # If the request is in transmission, we don't need to emit a response
# Also, for the first iteration with overlap, we should skip since first
@@ -2068,6 +2124,11 @@
def _terminate_ctx_finished_requests(self):
for request, block_id in self.ctx_in_transmission_requests[:]:
if request.is_disagg_context_complete_state:
if request.py_kv_transfer_timed_out:
is_cancelled = self.kv_cache_transceiver.cancel_request(
request)
if is_cancelled:
request.py_kv_transfer_start_time = None
if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa:
self._terminate_request(request)
else:
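Taken together, the executor-side mechanic is: stamp `py_kv_transfer_start_time` when an async send or receive starts, flip `py_kv_transfer_timed_out` once the elapsed wall-clock time exceeds the budget, and cancel the request on a later pass through response handling. A standalone sketch of that flow with hypothetical names (no TensorRT-LLM imports):

```python
import time
from dataclasses import dataclass


@dataclass
class FakeRequest:
    request_id: int
    kv_transfer_start_time: float | None = None
    kv_transfer_timed_out: bool = False


def check_kv_transfer_timeout(requests: list[FakeRequest], timeout_ms: int) -> None:
    """Mark requests whose KV transfer has been in flight longer than timeout_ms."""
    now = time.time()
    for req in requests:
        if req.kv_transfer_start_time is None or req.kv_transfer_timed_out:
            continue
        elapsed_ms = (now - req.kv_transfer_start_time) * 1000
        if elapsed_ms > timeout_ms:
            req.kv_transfer_timed_out = True  # cancelled on the next handling pass


# A transfer that started one second ago trips a 500 ms budget.
req = FakeRequest(request_id=1, kv_transfer_start_time=time.time() - 1.0)
check_kv_transfer_timeout([req], timeout_ms=500)
assert req.kv_transfer_timed_out
```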
9 changes: 8 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
@@ -1286,10 +1286,17 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
default=None,
description="The max number of tokens the transfer buffer can fit.")

kv_transfer_timeout_ms: Optional[int] = Field(
default=None,
description=
"Timeout in milliseconds for KV cache transfer. Requests exceeding this timeout will be cancelled."
)

def _to_pybind(self):
return _CacheTransceiverConfig(
backend=_CacheTransceiverBackendType.from_string(self.backend),
max_tokens_in_buffer=self.max_tokens_in_buffer)
max_tokens_in_buffer=self.max_tokens_in_buffer,
kv_transfer_timeout_ms=self.kv_transfer_timeout_ms)


@dataclass
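End to end, the option flows from the YAML/LLM-API config down to the C++ executor. A minimal usage sketch, assuming the public `CacheTransceiverConfig` is exported from `tensorrt_llm.llmapi` (illustrative values):

```python
from tensorrt_llm.llmapi import CacheTransceiverConfig

config = CacheTransceiverConfig(
    backend="UCX",  # backend strings follow the README's `backend: <str>` field
    max_tokens_in_buffer=8192,
    kv_transfer_timeout_ms=5000,
)

# _to_pybind() now forwards all three fields to the C++ executor config.
pybind_config = config._to_pybind()
assert pybind_config.kv_transfer_timeout_ms == 5000
```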