8 changes: 6 additions & 2 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -1443,13 +1443,16 @@ class CacheTransceiverConfig
UCX = 2,
NIXL = 3
};
explicit CacheTransceiverConfig(
std::optional<BackendType> backendType = std::nullopt, std::optional<size_t> maxNumTokens = std::nullopt);
explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
std::optional<size_t> maxNumTokens = std::nullopt,
std::optional<MillisecondsType> kvTransferTimeoutMs = std::nullopt);

bool operator==(CacheTransceiverConfig const& other) const;
void setBackendType(std::optional<BackendType> backendType);
void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);
void setKvTransferTimeoutMs(std::optional<MillisecondsType> kvTransferTimeoutMs);

[[nodiscard]] std::optional<std::chrono::milliseconds> getKvTransferTimeoutMs() const;
[[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
[[nodiscard]] std::optional<BackendType> getBackendType() const;

@@ -1459,6 +1462,7 @@ class CacheTransceiverConfig
/// kvCache tokens to be transferred for a single request is greater than this value, the performance of the cache
/// transfer may be degraded.
std::optional<size_t> mMaxTokensInBuffer;
std::optional<std::chrono::milliseconds> mKvTransferTimeoutMs;
};

/// @brief Configuration class for the model executor
18 changes: 15 additions & 3 deletions cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
@@ -21,16 +21,18 @@
namespace tensorrt_llm::executor
{

CacheTransceiverConfig::CacheTransceiverConfig(
std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens)
CacheTransceiverConfig::CacheTransceiverConfig(std::optional<BackendType> backendType,
std::optional<size_t> maxNumTokens, std::optional<std::chrono::milliseconds> kvTransferTimeoutMs)
: mBackendType(backendType)
, mMaxTokensInBuffer(maxNumTokens)
, mKvTransferTimeoutMs(kvTransferTimeoutMs)
{
}

bool CacheTransceiverConfig::operator==(CacheTransceiverConfig const& other) const
{
return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType;
return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType
&& mKvTransferTimeoutMs == other.mKvTransferTimeoutMs;
}

void CacheTransceiverConfig::setBackendType(std::optional<BackendType> backendType)
@@ -53,4 +55,14 @@ std::optional<size_t> CacheTransceiverConfig::getMaxTokensInBuffer() const
return mMaxTokensInBuffer;
}

void CacheTransceiverConfig::setKvTransferTimeoutMs(std::optional<std::chrono::milliseconds> kvTransferTimeoutMs)
{
mKvTransferTimeoutMs = kvTransferTimeoutMs;
}

std::optional<std::chrono::milliseconds> CacheTransceiverConfig::getKvTransferTimeoutMs() const
{
return mKvTransferTimeoutMs;
}

} // namespace tensorrt_llm::executor
15 changes: 10 additions & 5 deletions cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
@@ -429,15 +429,16 @@ void initConfigBindings(nb::module_& m)
.def("__setstate__", guidedDecodingConfigSetstate);

auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
{ return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
{ return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state)
{
if (state.size() != 2)
if (state.size() != 3)
{
throw std::runtime_error("Invalid CacheTransceiverConfig state!");
}
new (&self) tle::CacheTransceiverConfig(
nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), nb::cast<std::optional<size_t>>(state[1]));
nb::cast<std::optional<tle::CacheTransceiverConfig::BackendType>>(state[0]),
nb::cast<std::optional<size_t>>(state[1]), nb::cast<std::optional<std::chrono::milliseconds>>(state[2]));
};

nb::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -460,12 +461,16 @@ void initConfigBindings(nb::module_& m)
});

nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt)
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
std::optional<std::chrono::milliseconds>>(),
nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt,
nb::arg("kv_transfer_timeout_ms") = std::nullopt)
.def_prop_rw(
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
.def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
.def_prop_rw("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
.def("__getstate__", cacheTransceiverConfigGetstate)
.def("__setstate__", cacheTransceiverConfigSetstate);

55 changes: 49 additions & 6 deletions cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
@@ -412,15 +412,26 @@ void initConfigBindings(pybind11::module_& m)
.def(py::pickle(guidedDecodingConfigGetstate, guidedDecodingConfigSetstate));

auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
{ return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
{
auto timeout = self.getKvTransferTimeoutMs();
std::optional<int64_t> timeoutMs = timeout ? std::optional<int64_t>(timeout->count()) : std::nullopt;
return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), timeoutMs);
};
auto cacheTransceiverConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 2)
if (state.size() != 3)
{
throw std::runtime_error("Invalid CacheTransceiverConfig state!");
}
return tle::CacheTransceiverConfig(
state[0].cast<tle::CacheTransceiverConfig::BackendType>(), state[1].cast<std::optional<size_t>>());
auto config
= tle::CacheTransceiverConfig(state[0].cast<std::optional<tle::CacheTransceiverConfig::BackendType>>(),
state[1].cast<std::optional<size_t>>());
auto timeoutMs = state[2].cast<std::optional<int64_t>>();
if (timeoutMs)
{
config.setKvTransferTimeoutMs(std::chrono::milliseconds(*timeoutMs));
}
return config;
};

py::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -443,12 +454,44 @@ void initConfigBindings(pybind11::module_& m)
});

py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt)
.def(py::init(
[](std::optional<tle::CacheTransceiverConfig::BackendType> backend, std::optional<size_t> maxTokens,
std::optional<int64_t> timeoutMs)
{
std::optional<std::chrono::milliseconds> timeout = std::nullopt;
if (timeoutMs)
{
timeout = std::chrono::milliseconds(*timeoutMs);
}
return tle::CacheTransceiverConfig(backend, maxTokens, timeout);
}),
py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt,
py::arg("kv_transfer_timeout_ms") = std::nullopt)
.def_property(
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
.def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
.def_property(
"kv_transfer_timeout_ms",
[](tle::CacheTransceiverConfig const& self) -> std::optional<int64_t>
{
auto timeout = self.getKvTransferTimeoutMs();
return timeout ? std::optional<int64_t>(timeout->count()) : std::nullopt;
},
[](tle::CacheTransceiverConfig& self, std::optional<int64_t> timeoutMs)
{
if (timeoutMs)
{
self.setKvTransferTimeoutMs(std::chrono::milliseconds(*timeoutMs));
}
else
{
self.setKvTransferTimeoutMs(std::nullopt);
}
})
.def("setKvTransferTimeoutMs",
[](tle::CacheTransceiverConfig& self, int64_t timeoutMs)
{ self.setKvTransferTimeoutMs(std::chrono::milliseconds(timeoutMs)); })
.def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate));

auto executorConfigGetState = [](py::object const& self)
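To illustrate how the new pybind surface behaves from Python, here is a minimal sketch; the module path `tensorrt_llm.bindings.executor` and the `UCX` enum member name are assumptions based on common TensorRT-LLM layouts, not part of this diff:

```python
import pickle

# Assumption: the C++ executor bindings are importable from this module path.
from tensorrt_llm.bindings import executor as trtllm

config = trtllm.CacheTransceiverConfig(
    backend=trtllm.CacheTransceiverBackendType.UCX,  # assumed enum member name
    max_tokens_in_buffer=8192,
    kv_transfer_timeout_ms=120_000,  # plain integer milliseconds, per the binding lambda above
)

# The property reads back as an integer (or None when unset).
assert config.kv_transfer_timeout_ms == 120_000

# Pickling round-trips the third state element added in this diff.
restored = pickle.loads(pickle.dumps(config))
assert restored.kv_transfer_timeout_ms == 120_000
```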
3 changes: 3 additions & 0 deletions examples/disaggregated/README.md
@@ -16,6 +16,9 @@ cache_transceiver_config:
backend: <str>
# KV cache buffer size. Set it ≥ the maximum ISL (Input Sequence Length) for best performance.
max_tokens_in_buffer: <int>
# KV cache transfer timeout in milliseconds.
# Requests that do not finish sending or receiving the KV cache within this timeout are terminated and cleaned up.
kv_transfer_timeout_ms: <int>
```

The following is an example, consisting of the `ctx_extra-llm-api-config.yaml` and `gen_extra-llm-api-config.yaml` files needed in the sections below.
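A hedged sketch of driving the same setting from the Python LLM API instead of the YAML file; the `tensorrt_llm.llmapi` import and the `cache_transceiver_config` keyword on `LLM` are assumptions mirroring the llm_args.py change later in this diff:

```python
from tensorrt_llm import LLM  # assumption: top-level LLM entry point
from tensorrt_llm.llmapi import CacheTransceiverConfig  # assumption: re-exported here

cache_cfg = CacheTransceiverConfig(
    backend="UCX",
    max_tokens_in_buffer=8192,       # set >= the maximum ISL for best performance
    kv_transfer_timeout_ms=120_000,  # give up on transfers stuck for more than 2 minutes
)

# Assumption: the LLM constructor accepts the config object directly,
# mirroring the cache_transceiver_config YAML key shown above.
llm = LLM(model="<model checkpoint or HF id>", cache_transceiver_config=cache_cfg)
```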
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -101,6 +101,7 @@ def __init__(self, mapping: Mapping, kv_cache_manager: KVCacheManager,
tokens_per_block = kv_cache_manager.tokens_per_block
dtype = kv_cache_manager.dtype

self.kv_transfer_timeout_ms = cache_transceiver_config.kv_transfer_timeout_ms
self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
total_num_kv_heads_per_layer, head_dim,
tokens_per_block, world_config, dtype,
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -329,6 +329,8 @@ def __init__(
self.is_attention_dp_dummy = False
self.is_cuda_graph_dummy = False
self.py_lora_task_layer_module_configs = None
self.py_kv_transfer_start_time = None
self.py_to_cleanup = False

self.py_return_log_probs = return_log_probs
self.py_return_context_logits = return_context_logits
54 changes: 53 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -841,6 +841,7 @@ def _executor_loop_pp(self):

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._terminate_ctx_finished_requests()
self._check_kv_transfer_timeout()

# march forward in microbatch slots
microbatch_id = (microbatch_id + 1) % self.num_micro_batches
@@ -857,6 +858,7 @@ def _prepare_and_schedule_batch(self):

if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()
self._check_kv_transfer_timeout()

iter_stats = None
if self.enable_iter_perf_stats:
@@ -973,6 +975,7 @@ def _executor_loop(self):
self._add_kv_cache_events()

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._check_kv_transfer_timeout()
self._terminate_ctx_finished_requests()

if self.enable_iter_perf_stats:
@@ -1097,6 +1100,7 @@ def _executor_loop_overlap(self):
ctx_transmission_reqs=ctx_transmission_reqs)

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._check_kv_transfer_timeout()
self._terminate_ctx_finished_requests()

def _process_previous_batch(self):
@@ -1263,6 +1267,36 @@ def _check_disagg_gen_transfer_status(self):

return

def _check_kv_transfer_timeout(self):
if not self.kv_cache_transceiver:
return
timeout_ms = self.kv_cache_transceiver.kv_transfer_timeout_ms
if timeout_ms is None or timeout_ms <= 0:
return
current_time = time.time()

for req in self.ctx_in_transmission_requests:
if req.py_kv_transfer_start_time is None:
continue
elapsed_time = (current_time - req.py_kv_transfer_start_time) * 1000
if elapsed_time > timeout_ms:
logger.warning(
f"Terminating context request {req.py_request_id} due to KV cache transfer timeout"
)
req.py_to_cleanup = True

for req in self.active_requests:
if req.is_disagg_generation_transmission_in_progress and req.py_kv_transfer_start_time is not None:
elapsed_time = (current_time -
req.py_kv_transfer_start_time) * 1000
if elapsed_time > timeout_ms:
logger.warning(
f"Terminating generation request {req.py_request_id} due to KV cache transfer timeout"
)
req.py_to_cleanup = True

return

@nvtx_range("_pad_attention_dp_dummy_request")
def _pad_attention_dp_dummy_request(self):
"""
@@ -1335,6 +1369,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
req.context_current_position = req.prompt_len
req.decoding_iter = 1
req.py_decoding_iter = 1
req.py_kv_transfer_start_time = None
first_gen_tokens = req.context_phase_params.first_gen_tokens
ctx_draft_tokens = req.context_phase_params.draft_tokens
req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
@@ -1358,6 +1393,11 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
for req in new_gen_reqs:
self.kv_cache_transceiver.request_and_receive_async(req)

if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
for req in new_gen_reqs:
if req.state == LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS:
req.py_kv_transfer_start_time = time.time()

block_transfer = all([
req.is_disagg_generation_transmission_in_progress
for req in self.active_requests
@@ -1391,6 +1431,11 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
]

if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
for req in ctx_transmission_reqs:
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
req.py_kv_transfer_start_time = time.time()

return ctx_transmission_reqs

def _forward_step(self,
@@ -1608,6 +1653,12 @@ def _handle_responses(self):
requests_to_terminate.append(request)
continue

# Check if generation request needs cleanup due to KV cache transfer timeout
if request.py_to_cleanup:
request.state = LlmRequestState.GENERATION_COMPLETE
requests_to_terminate.append(request)
continue

if request.is_generation_only_request():
# If the request is in transmission, we don't need to emit a response
# Also, for the first iteration with overlap, we should skip since first
@@ -1649,7 +1700,8 @@
@nvtx_range("_terminate_ctx_finished_requests")
def _terminate_ctx_finished_requests(self):
for request in self.ctx_in_transmission_requests[:]:
if request.is_disagg_context_complete_state:
if request.is_disagg_context_complete_state or request.py_to_cleanup:
request.py_kv_transfer_start_time = None
self._terminate_request(request)
self.ctx_in_transmission_requests.remove(request)

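As a standalone sketch of the timing convention used by `_check_kv_transfer_timeout` above: start times come from `time.time()` in seconds while the configured timeout is in milliseconds, so the elapsed time is scaled by 1000 before comparison. The helper name below is illustrative only:

```python
import time
from typing import Optional


def kv_transfer_timed_out(start_time_s: Optional[float],
                          timeout_ms: Optional[int]) -> bool:
    """Return True when a transfer that started at start_time_s (seconds from
    time.time()) has exceeded timeout_ms; a missing start time or a
    non-positive timeout disables the check, matching the executor loop."""
    if start_time_s is None or timeout_ms is None or timeout_ms <= 0:
        return False
    elapsed_ms = (time.time() - start_time_s) * 1000
    return elapsed_ms > timeout_ms


# Example: a transfer that began 0.2 s ago with a 100 ms timeout has timed out.
print(kv_transfer_timed_out(time.time() - 0.2, 100))  # True
```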
8 changes: 7 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
@@ -1070,10 +1070,16 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
default=None,
description="The max number of tokens the transfer buffer can fit.")

kv_transfer_timeout_ms: Optional[int] = Field(
default=None,
description="Timeout in milliseconds for KV cache transfer operations. "
"Requests exceeding this timeout will be terminated.")

def _to_pybind(self):
return _CacheTransceiverConfig(
backend=_CacheTransceiverBackendType.from_string(self.backend),
max_tokens_in_buffer=self.max_tokens_in_buffer)
max_tokens_in_buffer=self.max_tokens_in_buffer,
kv_transfer_timeout_ms=self.kv_transfer_timeout_ms)


@dataclass