diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp index a3a8e087e34..a22a62bf808 100644 --- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -220,10 +220,10 @@ void initBindings(nb::module_& m) nb::class_(m, "DecoderState") .def(nb::init<>()) - .def("setup", &tr::decoder::DecoderState::setup, nb::arg("max_batch_size"), nb::arg("max_beam_width"), + .def("setup", &tr::decoder::DecoderState::setup, nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("max_attention_window"), nb::arg("sink_token_length"), nb::arg("max_sequence_length"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")) - .def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, nb::arg("max_batch_size"), + .def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("max_attention_window"), nb::arg("buffer_manager")) .def("setup_speculative_decoding", &tr::decoder::DecoderState::setupSpeculativeDecoding, nb::arg("speculative_decoding_mode"), nb::arg("max_tokens_per_engine_step"), nb::arg("dtype"), @@ -277,7 +277,7 @@ void initBindings(nb::module_& m) nb::class_(m, "GptDecoderBatched") .def(nb::init(), nb::arg("stream")) - .def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_batch_size"), + .def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config")) .def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("output"), nb::arg("input")) .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, nb::rv_policy::reference)