diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
index abac6d17ed8..8a7f73f3b06 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
@@ -84,21 +84,15 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
         .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus)
         .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete);
 
-    nb::enum_<tb::CacheTransceiver::CommType>(m, "CommType")
-        .value("UNKNOWN", tb::CacheTransceiver::CommType::UNKNOWN)
-        .value("MPI", tb::CacheTransceiver::CommType::MPI)
-        .value("UCX", tb::CacheTransceiver::CommType::UCX)
-        .value("NIXL", tb::CacheTransceiver::CommType::NIXL);
-
     nb::enum_<executor::kv_cache::CacheState::AttentionType>(m, "AttentionType")
         .value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT)
        .value("MLA", executor::kv_cache::CacheState::AttentionType::kMLA);
 
     nb::class_<tb::CacheTransceiver, tb::BaseCacheTransceiver>(m, "CacheTransceiver")
-        .def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, tb::CacheTransceiver::CommType,
-                 std::vector<SizeType32>, SizeType32, SizeType32, runtime::WorldConfig, nvinfer1::DataType,
-                 executor::kv_cache::CacheState::AttentionType, std::optional<executor::CacheTransceiverConfig>>(),
-            nb::arg("cache_manager"), nb::arg("comm_type"), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"),
+        .def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::vector<SizeType32>, SizeType32, SizeType32,
+                 runtime::WorldConfig, nvinfer1::DataType, executor::kv_cache::CacheState::AttentionType,
+                 std::optional<executor::CacheTransceiverConfig>>(),
+            nb::arg("cache_manager"), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"),
             nb::arg("tokens_per_block"), nb::arg("world_config"), nb::arg("dtype"), nb::arg("attention_type"),
             nb::arg("cache_transceiver_config") = std::nullopt);
 
@@ -106,5 +100,5 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
         .def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), nb::arg("cache_manager"),
             nb::arg("max_num_tokens") = std::nullopt)
         .def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize,
-            nb::arg("max_num_tokens") = std::nullopt);
+            nb::arg("cache_size_bytes_per_token_per_window"), nb::arg("cache_transceiver_config") = nb::none());
 }
diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
index c2d9fe25dff..e79bf3fa799 100644
--- a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
+++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
@@ -424,21 +424,44 @@ void initConfigBindings(nb::module_& m)
         .def("__getstate__", guidedDecodingConfigGetstate)
         .def("__setstate__", guidedDecodingConfigSetstate);
 
-    auto cacheTransceiverConfigGetstate
-        = [](tle::CacheTransceiverConfig const& self) { return nb::make_tuple(self.getMaxNumTokens()); };
+    auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
+    { return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
     auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state)
     {
-        if (state.size() != 1)
+        if (state.size() != 2)
         {
             throw std::runtime_error("Invalid CacheTransceiverConfig state!");
         }
-        new (&self) tle::CacheTransceiverConfig(nb::cast<std::optional<size_t>>(state[0]));
+        new (&self) tle::CacheTransceiverConfig(
+            nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), nb::cast<std::optional<size_t>>(state[1]));
     };
 
+    nb::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
+        .value("DEFAULT", tle::CacheTransceiverConfig::BackendType::DEFAULT)
+        .value("MPI", tle::CacheTransceiverConfig::BackendType::MPI)
+        .value("UCX", tle::CacheTransceiverConfig::BackendType::UCX)
+        .value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL)
+        .def("from_string",
+                [](std::string const& str)
+                {
+                    if (str == "DEFAULT" || str == "default")
+                        return tle::CacheTransceiverConfig::BackendType::DEFAULT;
+                    if (str == "MPI" || str == "mpi")
+                        return tle::CacheTransceiverConfig::BackendType::MPI;
+                    if (str == "UCX" || str == "ucx")
+                        return tle::CacheTransceiverConfig::BackendType::UCX;
+                    if (str == "NIXL" || str == "nixl")
+                        return tle::CacheTransceiverConfig::BackendType::NIXL;
+                    throw std::runtime_error("Invalid backend type: " + str);
+                });
+
     nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
-        .def(nb::init<std::optional<size_t>>(), nb::arg("max_num_tokens") = nb::none())
-        .def_prop_rw("max_num_tokens", &tle::CacheTransceiverConfig::getMaxNumTokens,
-            &tle::CacheTransceiverConfig::setMaxNumTokens)
+        .def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
+            nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt)
+        .def_prop_rw(
+            "backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
+        .def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
+            &tle::CacheTransceiverConfig::setMaxTokensInBuffer)
         .def("__getstate__", cacheTransceiverConfigGetstate)
         .def("__setstate__", cacheTransceiverConfigSetstate);
 
diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py
index af72d9ac44b..08082584cda 100644
--- a/tests/unittest/bindings/test_executor_bindings.py
+++ b/tests/unittest/bindings/test_executor_bindings.py
@@ -2478,8 +2478,9 @@ def test_guided_decoding_config_pickle():
 
 
 def test_cache_transceiver_config_pickle():
-    config = trtllm.CacheTransceiverConfig(backend="UCX",
-                                           max_tokens_in_buffer=1024)
+    config = trtllm.CacheTransceiverConfig(
+        backend=trtllm.CacheTransceiverBackendType.UCX,
+        max_tokens_in_buffer=1024)
     config_copy = pickle.loads(pickle.dumps(config))
     assert config_copy.backend == config.backend
     assert config_copy.max_tokens_in_buffer == config.max_tokens_in_buffer
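
For reference, a minimal usage sketch (not part of the diff) of the rebound Python API follows. It assumes `trtllm` resolves to `tensorrt_llm.bindings.executor`, as it does in `test_executor_bindings.py`; the constructor arguments, properties, and pickle round-trip mirror the bindings added above.

# Hedged usage sketch: exercises the rebound CacheTransceiverConfig API.
# Assumes `trtllm` is `tensorrt_llm.bindings.executor`, as in the test module above.
import pickle

import tensorrt_llm.bindings.executor as trtllm

# Construct the config with the enum-typed backend argument introduced by this change.
config = trtllm.CacheTransceiverConfig(
    backend=trtllm.CacheTransceiverBackendType.UCX,
    max_tokens_in_buffer=1024)

# Properties bound via def_prop_rw are plain read/write attributes.
config.max_tokens_in_buffer = 2048

# __getstate__/__setstate__ make the config picklable.
config_copy = pickle.loads(pickle.dumps(config))
assert config_copy.backend == trtllm.CacheTransceiverBackendType.UCX
assert config_copy.max_tokens_in_buffer == 2048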