From 9b45499caa217e756bc6d2b9a89e524b63bce00f Mon Sep 17 00:00:00 2001 From: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> Date: Thu, 17 Jul 2025 18:05:45 +0800 Subject: [PATCH 1/9] test: update max_beam_width to 1 due to torchsampler changes. (#6101) Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> --- tests/unittest/llmapi/test_llm_args.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index c1bfdcc4001..801a2bf12a9 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -372,18 +372,18 @@ class TestTorchLlmArgs: def test_runtime_sizes(self): llm = TorchLLM( llama_model_path, - max_beam_width=4, + max_beam_width=1, max_num_tokens=256, max_seq_len=128, max_batch_size=8, ) - assert llm.args.max_beam_width == 4 + assert llm.args.max_beam_width == 1 assert llm.args.max_num_tokens == 256 assert llm.args.max_seq_len == 128 assert llm.args.max_batch_size == 8 - assert llm._executor_config.max_beam_width == 4 + assert llm._executor_config.max_beam_width == 1 assert llm._executor_config.max_num_tokens == 256 assert llm._executor_config.max_seq_len == 128 assert llm._executor_config.max_batch_size == 8 From a7184869001d28ca70a738e9862ea91cb147da8c Mon Sep 17 00:00:00 2001 From: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com> Date: Thu, 17 Jul 2025 18:24:49 +0800 Subject: [PATCH 2/9] fix: Fix DeepSeek R1 CI (#6129) Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com> --- tests/integration/defs/accuracy/test_llm_api_pytorch.py | 4 ++-- tests/integration/test_lists/waives.txt | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 8c5b75e65fb..4e12889fa98 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1352,7 +1352,7 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, attention_dp, cuda_graph, overlap_scheduler, max_batch_size, moe_backend): - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.85) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig() if cuda_graph else None, @@ -1374,7 +1374,7 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, enable_attention_dp=attention_dp, speculative_config=mtp_config) as llm: - assert llm.args.moe_backend == moe_backend + assert llm.args.moe_config.backend == moe_backend assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 task = MMLU(self.MODEL_NAME) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index e9f4ed4401e..cd453839d9a 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -439,5 +439,3 @@ examples/test_multimodal.py::test_llm_multimodal_general[fuyu-8b-pp:1-tp:1-float test_e2e.py::test_ptp_quickstart SKIP (https://nvbugs/5387762) triton_server/test_triton_llm.py::test_llava_onevision[test_basic-False-1---False-True-False-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-max_utilization---1-1-1-False-tensorrt_llm_bls] SKIP (https://nvbugs/5396437) 
triton_server/test_triton_llm.py::test_llava_onevision[test_video-False-1---False-True-False-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-tensorrt_llm_bls] SKIP (https://nvbugs/5396437) -accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] SKIP (https://nvbugs/5397036) -accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5397036) From 9518e14f69e408ce74f4128522ab5cbf516bb7f1 Mon Sep 17 00:00:00 2001 From: Stanley Sun <190317771+StanleySun639@users.noreply.github.com> Date: Thu, 17 Jul 2025 18:55:04 +0800 Subject: [PATCH 3/9] test: fix PytestUnknownMarkWarning: Unknown pytest.mark.timeout (#6115) Signed-off-by: Stanley Sun <190317771+StanleySun639@users.noreply.github.com> --- tests/integration/defs/pytest.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/defs/pytest.ini b/tests/integration/defs/pytest.ini index 24b270884c0..69629dce95c 100644 --- a/tests/integration/defs/pytest.ini +++ b/tests/integration/defs/pytest.ini @@ -12,3 +12,4 @@ markers = skip_less_host_memory: skip when less host memory detected than the requested support_fp8: skip when fp8 is not supported on the device skip_device_not_contain: skip when the device does not contain the specified keyword + timeout: set test timeout in seconds From 58d22a72f1f2b893b8b937a01c3d827efb4815e6 Mon Sep 17 00:00:00 2001 From: Ziyi Xiong <219238287+ziyixiong-nv@users.noreply.github.com> Date: Thu, 17 Jul 2025 21:15:01 +0800 Subject: [PATCH 4/9] [TRTLLM-6352][feat] Migrate EAGLE3 and draft/target speculation to Drafter (#6007) Signed-off-by: ziyixiong-nv --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 198 +--------- .../_torch/pyexecutor/py_executor_creator.py | 3 +- tensorrt_llm/_torch/speculative/drafter.py | 7 + .../_torch/speculative/model_drafter.py | 353 ++++++++++++++++++ tensorrt_llm/_torch/speculative/ngram.py | 7 +- tensorrt_llm/_torch/speculative/utils.py | 20 +- 6 files changed, 388 insertions(+), 200 deletions(-) create mode 100644 tensorrt_llm/_torch/speculative/model_drafter.py diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index c402480b7d9..6826cda6114 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -11,7 +11,7 @@ import weakref from collections import deque, namedtuple from contextlib import contextmanager -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import torch @@ -308,7 +308,7 @@ def __init__(self, if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): self.event_loop = trace_func(self.event_loop) - if self.draft_model_engine is not None: + if self.drafter is not None: if self.event_loop.__name__ != self._executor_loop.__name__: raise NotImplementedError( "Drafting is not supported for selected executor loop. 
" @@ -905,10 +905,6 @@ def _executor_loop_pp(self): def _executor_loop(self): torch.cuda.set_device(self.device_id) - is_ngram = hasattr( - self.model_engine, "spec_config" - ) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram( - ) with self._profiler() as profile_step: sample_state = None iter_start_time = time.time() @@ -931,7 +927,7 @@ def _executor_loop(self): self._pad_attention_dp_dummy_request() - if self.draft_model_engine is not None or is_ngram or self.drafter is not None: + if self.drafter is not None: self._prepare_draft_requests(self.active_requests) scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( @@ -971,11 +967,9 @@ def _executor_loop(self): scheduled_batch) self.resource_manager.prepare_resources(scheduled_batch) - if self.draft_model_engine is not None: - self._prepare_draft_tokens(scheduled_batch) - if self.drafter is not None: - self.drafter.prepare_draft_tokens(scheduled_batch) + self.drafter.prepare_draft_tokens( + scheduled_batch, self.resource_manager) if self.kv_cache_transceiver: # For generation requests which have completed KV cache transfer @@ -1798,188 +1792,6 @@ def _update_requests(self, sample_state: SampleState): logger.error(f"Encountered an error in sampling: {error_msg}") self._handle_errors(error_msg) - @nvtx_range("_prepare_draft_batch") - def _prepare_draft_batch( - self, scheduled_requests: ScheduledRequests - ) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]: - """ - Prepares a batch for the draft model engine. Draft tokens are only produced - for generation requests. - - The requests are prepared as follows: - 1. The first time the draft engine sees a request, it's a context request. - 2. Otherwise, if draft tokens were accepted on the last target model decoding - step, it's a chunked context request (we process all the accepted tokens together). - 3. Otherwise, it's a generation request. - """ - try: - draft_batch = ScheduledRequests() - - for request in scheduled_requests.generation_requests: - if request.py_draft_pages_allocated == 0: - # No space for draft tokens. - continue - - # Stop drafting when we hit the max seqlen. We still need dummy draft - # tokens attached to the requests to make sure everything works properly - # with CUDA graph. These dummy tokens are already added by - # _prepare_draft_requests to make the KV cache/scheduler aware of the fact - # that we want to do spec decoding, so no need to do anything else here. - # This makes the perf for this case suboptimal, but that's OK - this is - # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation. 
- if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len: - continue - - num_draft_tokens = len( - request.py_last_draft_tokens - ) if request.py_last_draft_tokens is not None else 0 - request.py_draft_tokens = [] - - num_accepted_tokens = request.py_num_accepted_draft_tokens - num_rejected_tokens = num_draft_tokens - num_accepted_tokens - assert num_rejected_tokens >= 0 - - spec_config = self.model_engine.spec_config - beam_idx = 0 - input_tokens = spec_config.get_draft_model_prompt( - request.get_tokens()[beam_idx]) - - def create_new_request(input_tokens): - return LlmRequest( - request_id=request.py_request_id, - max_new_tokens=request.py_max_new_tokens, - input_tokens=input_tokens, - sampling_config=request.sampling_config, - return_perf_metrics=request.return_perf_metrics, - is_streaming=False, - is_draft=True) - - if request.max_beam_num_tokens - 1 == request.py_prompt_len: - # This is the first time the draft model is seeing this request. - # Prepare a context request. We discard the first token and take - # the newly decoded one - this is the convention for EAGLE 2 and 3. - new_request = create_new_request(input_tokens) - draft_batch.context_requests.append(new_request) - elif num_accepted_tokens == 0: - new_request = create_new_request(input_tokens[:-1]) - # Explicitly add the last token so get_last_tokens() returns - # the right value - new_request.add_new_token(input_tokens[-1], beam_idx) - new_request.state = LlmRequestState.GENERATION_IN_PROGRESS - draft_batch.generation_requests.append(new_request) - else: - new_request = create_new_request(input_tokens) - new_request.context_chunk_size = num_accepted_tokens + 1 - new_request.context_current_position = len( - input_tokens) - num_accepted_tokens - 1 - new_request.context_chunk_size = num_accepted_tokens + 1 - new_request.context_current_position = len( - input_tokens) - num_accepted_tokens - 1 - - draft_batch.context_requests.append(new_request) - - new_request.py_stop_words_list = request.py_stop_words_list - - return draft_batch - - except Exception as e: - traceback.print_exc() - error_msg = str(e) - logger.error(f"Encountered an error in decode: {error_msg}") - self._handle_errors(error_msg) - - @nvtx_range("_prepare_draft_tokens") - def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests): - if not self.draft_model_engine: - raise ValueError("Draft model engine is not set") - - try: - draft_batch = self._prepare_draft_batch(scheduled_requests) - - if draft_batch.batch_size == 0: - return - self.draft_seq_slot_manager.prepare_resources(draft_batch) - - req_id_to_old_request = { - req.py_request_id: req - for req in scheduled_requests.all_requests() - } - - # Disable cuda graph for the 1st draft model forward - if self.model_engine.spec_config.spec_dec_mode.needs_kv_cache_recompute( - ): - with self.draft_model_engine.no_cuda_graph(): - outputs = self.draft_model_engine.forward( - draft_batch, self.resource_manager) - else: - outputs = self.draft_model_engine.forward( - draft_batch, self.resource_manager) - if hasattr(self.draft_model_engine.model.model, 'd2t'): - outputs['d2t'] = self.draft_model_engine.model.model.d2t.data - - sample_state = self._sample_async(draft_batch, outputs) - previous_batch = sample_state - - self._update_request_states(draft_batch) - - def _process_decoded_tokens(draft_batch): - new_requests = [] - for req in draft_batch.all_requests(): - target_model_req = req_id_to_old_request[req.py_request_id] - target_model_req.py_draft_tokens.append( - 
req.get_last_tokens(0)) - if req.state != LlmRequestState.GENERATION_COMPLETE and len( - target_model_req.py_draft_tokens - ) < target_model_req.py_draft_pages_allocated: - new_requests.append(req) - else: - self.draft_seq_slot_manager.free_resources(req) - - return new_requests - - # The TRTLLM attention kernels cannot handle generation requests with - # different seqlens. No issues with flashinfer, should we look into removing - # this? Just needs proper kernel support. - def _pad_to_max_draft_tokens(): - for req in scheduled_requests.generation_requests: - max_draft_len = self.max_draft_len - num_draft_tokens = len(req.py_draft_tokens) - req.py_draft_tokens.extend( - 0 for _ in range(max_draft_len - num_draft_tokens)) - - draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests - draft_batch.context_requests = [] - - for i in range(self.max_draft_len - 1): - if len(draft_batch.generation_requests) == 0: - break - - outputs = self.draft_model_engine.forward( - draft_batch, - self.resource_manager, - new_tensors_device=previous_batch.device) - - if hasattr(self.draft_model_engine.model.model, 'd2t'): - outputs[ - 'd2t'] = self.draft_model_engine.model.model.d2t.data - sample_state = self._sample_async(draft_batch, outputs) - self._update_request_states(draft_batch) - self._update_requests(previous_batch) - new_requests = _process_decoded_tokens( - previous_batch.scheduled_requests) - draft_batch.generation_requests = new_requests - previous_batch = sample_state - self._update_requests(previous_batch) - new_requests = _process_decoded_tokens( - previous_batch.scheduled_requests) - _pad_to_max_draft_tokens() - - except Exception as e: - traceback.print_exc() - error_msg = str(e) - logger.error(f"Encountered an error in decode: {error_msg}") - self._handle_errors(error_msg) - def _handle_errors(self, error_msg: Optional[str] = None): error_responses = {} error_msg = error_msg or "error" diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index b9eccc90601..446b647618d 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -382,7 +382,8 @@ def create_py_executor( # Drafter for speculative decoding with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER): - drafter = get_spec_drafter(model_engine, spec_resource_manager) + drafter = get_spec_drafter(model_engine, draft_model_engine, sampler, + spec_resource_manager) with mem_monitor.observe_creation_stage( _ExecutorCreationStage.INIT_EXTRA_RESOURCES diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index d99c5dd92d8..e08044cbb4f 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -1,16 +1,23 @@ from abc import ABC, abstractmethod +from typing import Optional +from ..pyexecutor.resource_manager import ResourceManager from ..pyexecutor.scheduler import ScheduledRequests class Drafter(ABC): + """Abstract base class for all drafter implementations.""" @abstractmethod def prepare_draft_tokens( self, scheduled_requests: ScheduledRequests, + resource_manager: Optional[ResourceManager] = None, ) -> None: """ Prepare the drafter tokens for the forward computation this step. 
+
+
+        Args:
+            scheduled_requests: The scheduled requests for this iteration
         """
         raise NotImplementedError
diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py
new file mode 100644
index 00000000000..ac195ccf515
--- /dev/null
+++ b/tensorrt_llm/_torch/speculative/model_drafter.py
@@ -0,0 +1,353 @@
+from __future__ import annotations
+
+import traceback
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+from tensorrt_llm._utils import nvtx_range
+from tensorrt_llm.logger import logger
+
+from ..pyexecutor.llm_request import LlmRequest, LlmRequestState, SamplingConfig
+from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
+from ..pyexecutor.sampler import Sampler, SampleState
+from ..pyexecutor.scheduler import ScheduledRequests
+from ..pyexecutor.seq_slot_manager import SeqSlotManager
+from .drafter import Drafter
+
+if TYPE_CHECKING:
+    from ..pyexecutor.model_engine import ModelEngine
+
+
+class ModelDrafter(Drafter):
+    """Model-based drafter that uses a draft model to generate draft tokens."""
+
+    def __init__(
+        self,
+        spec_config: "DecodingBaseConfig",
+        draft_model_engine: "ModelEngine",
+        max_draft_tokens: int,
+        draft_seq_slot_manager: SeqSlotManager,
+        sampler: Sampler,
+        spec_resource_manager: Optional[BaseResourceManager] = None,
+    ):
+        # Validate required parameters
+        if draft_model_engine is None:
+            raise ValueError("draft_model_engine cannot be None")
+        if max_draft_tokens < 0:
+            raise ValueError(f"max_draft_tokens must be >= 0, got {max_draft_tokens}")
+
+        # Model and resource management
+        self.draft_model_engine = draft_model_engine
+        self.draft_seq_slot_manager = draft_seq_slot_manager
+        self.spec_resource_manager = spec_resource_manager
+
+        # Configuration
+        self.spec_config = spec_config
+        self.max_draft_tokens = max_draft_tokens
+
+        # Sampling
+        self.sampler = sampler
+
+    def _create_draft_request(self, request_id: int, max_new_tokens: int,
+                              input_tokens: Optional[List],
+                              sampling_config: SamplingConfig,
+                              return_perf_metrics: bool) -> LlmRequest:
+        """Create a draft request with common parameters."""
+        return LlmRequest(request_id=request_id,
+                          max_new_tokens=max_new_tokens,
+                          input_tokens=input_tokens,
+                          sampling_config=sampling_config,
+                          return_perf_metrics=return_perf_metrics,
+                          is_streaming=False,
+                          is_draft=True)
+
+    def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:
+        """Initialize draft token tracking for a request."""
+        num_draft_tokens = len(
+            request.py_last_draft_tokens
+        ) if request.py_last_draft_tokens is not None else 0
+        request.py_draft_tokens = []
+
+        num_accepted_tokens = request.py_num_accepted_draft_tokens
+        num_rejected_tokens = num_draft_tokens - num_accepted_tokens
+        assert num_rejected_tokens >= 0
+
+        return num_draft_tokens, num_accepted_tokens
+
+    def _create_context_request(self, request: LlmRequest,
+                                input_tokens: Any) -> LlmRequest:
+        """Create a context request for first-time drafting."""
+        return self._create_draft_request(request.py_request_id,
+                                          request.py_max_new_tokens,
+                                          input_tokens, request.sampling_config,
+                                          request.return_perf_metrics)
+
+    def _create_generation_request(self, request: LlmRequest,
+                                   input_tokens: Any) -> LlmRequest:
+        """Create a generation request when no tokens were accepted."""
+        new_request = self._create_draft_request(request.py_request_id,
+                                                 request.py_max_new_tokens,
+                                                 input_tokens[:-1],
+                                                 request.sampling_config,
+                                                 request.return_perf_metrics)
+        # Explicitly add the last token so get_last_tokens() 
returns the right value + new_request.add_new_token(input_tokens[-1], 0) + new_request.state = LlmRequestState.GENERATION_IN_PROGRESS + return new_request + + def _create_chunked_context_request(self, request: LlmRequest, + input_tokens: Any, + num_accepted_tokens: int) -> LlmRequest: + """Create a chunked context request when some tokens were accepted.""" + new_request = self._create_draft_request(request.py_request_id, + request.py_max_new_tokens, + input_tokens, + request.sampling_config, + request.return_perf_metrics) + new_request.context_chunk_size = num_accepted_tokens + 1 + new_request.context_current_position = len( + input_tokens) - num_accepted_tokens - 1 + return new_request + + def _create_draft_request_for_request( + self, request: LlmRequest) -> Optional[LlmRequest]: + """Create a draft request based on the original request state.""" + num_draft_tokens, num_accepted_tokens = self._initialize_draft_tokens( + request) + input_tokens = self.spec_config.get_draft_model_prompt( + request.get_tokens()[0]) + + # First time seeing this request - context request + if request.max_beam_num_tokens - 1 == request.py_prompt_len: + # This is the first time the draft model is seeing this request. + # Prepare a context request. We discard the first token and take + # the newly decoded one - this is the convention for EAGLE 2 and 3. + assert num_draft_tokens == 0 + return self._create_context_request(request, input_tokens) + + # No tokens accepted - generation request + elif num_accepted_tokens == 0: + return self._create_generation_request(request, input_tokens) + + # Tokens accepted - chunked context request + else: + return self._create_chunked_context_request(request, input_tokens, + num_accepted_tokens) + + def _add_to_draft_batch(self, draft_batch: ScheduledRequests, + draft_request: LlmRequest, + original_request: LlmRequest) -> None: + """Add the draft request to the appropriate batch list.""" + # Copy additional properties + draft_request.py_stop_words_list = original_request.py_stop_words_list + + # Add to appropriate batch based on request type + if draft_request.state == LlmRequestState.GENERATION_IN_PROGRESS: + draft_batch.generation_requests.append(draft_request) + else: + draft_batch.context_requests.append(draft_request) + + @nvtx_range("_prepare_draft_batch") + def _prepare_draft_batch( + self, scheduled_requests: ScheduledRequests) -> ScheduledRequests: + """ + Prepares a batch for the draft model engine. Draft tokens are only produced + for generation requests. + + The requests are prepared as follows: + 1. The first time the draft engine sees a request, it's a context request. + 2. Otherwise, if draft tokens were accepted on the last target model decoding + step, it's a chunked context request (we process all the accepted tokens together). + 3. Otherwise, it's a generation request. + + Args: + scheduled_requests: The scheduled requests to prepare draft batch for + + Returns: + ScheduledRequests: The prepared draft batch + """ + try: + draft_batch = ScheduledRequests() + + for request in scheduled_requests.generation_requests: + if request.py_draft_pages_allocated == 0: + # No space for draft tokens + continue + + # Stop drafting when we hit the max seqlen. We still need dummy draft + # tokens attached to the requests to make sure everything works properly + # with CUDA graph. These dummy tokens are already added by + # _prepare_draft_requests to make the KV cache/scheduler aware of the fact + # that we want to do spec decoding, so no need to do anything else here. 
+ # This makes the perf for this case suboptimal, but that's OK - this is + # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation. + if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len: + continue + + draft_request = self._create_draft_request_for_request(request) + if draft_request is not None: + self._add_to_draft_batch(draft_batch, draft_request, + request) + + return draft_batch + + except Exception as e: + logger.error(f"Error in _prepare_draft_batch: {str(e)}") + traceback.print_exc() + raise e + + def _should_disable_cuda_graph( + self, previous_batch: Optional[SampleState]) -> bool: + """Check if CUDA graph should be disabled for the current forward pass.""" + if previous_batch is not None: + return False + return self.spec_config.spec_dec_mode.needs_kv_cache_recompute() + + def _forward_draft_model( + self, + draft_batch: ScheduledRequests, + resource_manager: ResourceManager, + previous_batch: Optional[SampleState] = None) -> Dict[str, Any]: + """Forward pass through the draft model.""" + if self._should_disable_cuda_graph(previous_batch): + with self.draft_model_engine.no_cuda_graph(): + outputs = self.draft_model_engine.forward( + draft_batch, resource_manager) + else: + new_tensors_device = previous_batch.device if previous_batch else None + outputs = self.draft_model_engine.forward( + draft_batch, + resource_manager, + new_tensors_device=new_tensors_device) + + # Handle d2t data if available + if hasattr(self.draft_model_engine.model.model, 'd2t'): + outputs['d2t'] = self.draft_model_engine.model.model.d2t.data + + return outputs + + def _sample_async(self, draft_batch: ScheduledRequests, + outputs: Dict[str, Any]) -> Optional[SampleState]: + """Sample tokens from draft model outputs.""" + try: + if self.sampler is not None: + return self.sampler.sample_async(draft_batch, outputs) + return None + except Exception as e: + logger.error(f"Error in sampling: {str(e)}") + return None + + def _update_request_states(self, + scheduled_requests: ScheduledRequests) -> None: + """Update request states after processing.""" + for request in scheduled_requests.context_requests: + if request.state != LlmRequestState.GENERATION_COMPLETE: + request.move_to_next_context_chunk() + if request.context_remaining_length == 0: + request.state = LlmRequestState.GENERATION_IN_PROGRESS + + def _update_requests(self, sample_state: SampleState) -> None: + """Update requests with sample state.""" + if self.sampler is not None: + self.sampler.update_requests(sample_state) + + def _process_decoded_tokens( + self, draft_batch: ScheduledRequests, + req_id_to_old_request: Dict[int, LlmRequest]) -> List[LlmRequest]: + """Process decoded tokens and determine which requests to continue processing.""" + new_requests = [] + for req in draft_batch.all_requests(): + target_model_req = req_id_to_old_request[req.py_request_id] + target_model_req.py_draft_tokens.append(req.get_last_tokens(0)) + if req.state != LlmRequestState.GENERATION_COMPLETE and len( + target_model_req.py_draft_tokens + ) < target_model_req.py_draft_pages_allocated: + new_requests.append(req) + else: + self.draft_seq_slot_manager.free_resources(req) + + return new_requests + + def _pad_to_max_draft_tokens(self, + scheduled_requests: ScheduledRequests) -> None: + """Pad draft tokens to maximum length for all generation requests.""" + for req in scheduled_requests.generation_requests: + max_draft_tokens = self.max_draft_tokens + num_draft_tokens = len(req.py_draft_tokens) + req.py_draft_tokens.extend( + 0 
for _ in range(max_draft_tokens - num_draft_tokens)) + + @nvtx_range("prepare_draft_tokens") + def prepare_draft_tokens( + self, + scheduled_requests: ScheduledRequests, + resource_manager: Optional[ResourceManager] = None, + ) -> None: + """ + Prepare draft tokens for the scheduled requests. + + Args: + scheduled_requests: The scheduled requests for this iteration + resource_manager: The resource manager for this iteration + """ + if not self.draft_model_engine: + raise ValueError("Draft model engine is not set") + + if resource_manager is None: + raise ValueError("Resource manager is required") + + try: + draft_batch = self._prepare_draft_batch(scheduled_requests) + + if draft_batch.batch_size == 0: + return + + self.draft_seq_slot_manager.prepare_resources(draft_batch) + + req_id_to_old_request = { + req.py_request_id: req + for req in scheduled_requests.all_requests() + } + + # Initial forward pass + outputs = self._forward_draft_model(draft_batch, resource_manager) + sample_state = self._sample_async(draft_batch, outputs) + previous_batch = sample_state + + self._update_request_states(draft_batch) + + # Convert context requests to generation requests + draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests + draft_batch.context_requests = [] + + # Generate remaining draft tokens iteratively + for i in range(self.max_draft_tokens - 1): + if len(draft_batch.generation_requests) == 0: + break + + outputs = self._forward_draft_model(draft_batch, + resource_manager, + previous_batch) + sample_state = self._sample_async(draft_batch, outputs) + self._update_request_states(draft_batch) + if previous_batch is not None: + self._update_requests(previous_batch) + new_requests = self._process_decoded_tokens( + previous_batch.scheduled_requests, + req_id_to_old_request) + else: + new_requests = [] + draft_batch.generation_requests = new_requests + previous_batch = sample_state + + # Final cleanup + if previous_batch is not None: + self._update_requests(previous_batch) + self._process_decoded_tokens(previous_batch.scheduled_requests, + req_id_to_old_request) + self._pad_to_max_draft_tokens(scheduled_requests) + + except Exception as e: + traceback.print_exc() + error_msg = str(e) + logger.error(f"Encountered an error in decode: {error_msg}") + raise e diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index 57f3045e664..9113900ef94 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -5,7 +5,7 @@ from tensorrt_llm.logger import logger from ..pyexecutor.llm_request import * -from ..pyexecutor.resource_manager import BaseResourceManager +from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager from ..pyexecutor.scheduler import ScheduledRequests from .drafter import Drafter @@ -59,10 +59,10 @@ def __init__(self, spec_config: "NGramDecodingConfig", self.start_index = {} def get_max_resource_count(self) -> int: - raise self.max_num_requests + return self.max_num_requests def get_needed_resource_to_completion(self, request: LlmRequest) -> int: - raise 0 + return 0 def prepare_resources(self, scheduled_batch: ScheduledRequests): pass @@ -173,6 +173,7 @@ def __init__( def prepare_draft_tokens( self, scheduled_requests: ScheduledRequests, + resource_manager: Optional[ResourceManager] = None, ) -> None: # Sort by request_id when py_batch_idx is None as a fallback. 
# This happens in the disagg case: for a set of new requests, we draft diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 667d1a14b0e..2519584274f 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -1,9 +1,11 @@ from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler from tensorrt_llm._torch.speculative.interface import SpecMetadata +from ..pyexecutor.seq_slot_manager import SeqSlotManager from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata, Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3SpecMetadata) +from .model_drafter import ModelDrafter from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler, MTPSpecMetadata, MTPWorker) from .ngram import NGramDrafter, NGramPoolManager @@ -112,14 +114,26 @@ def get_spec_decoder(sampler_args: TorchSampler.Args, f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}") -def get_spec_drafter(model_engine, spec_resource_manager): +def get_spec_drafter(model_engine, draft_model_engine, sampler, + spec_resource_manager): spec_config = model_engine.spec_config if spec_config is None: return None - if spec_config.spec_dec_mode.is_ngram(): - return NGramDrafter(spec_config, spec_resource_manager) + if spec_config.spec_dec_mode.is_user_provided(): return spec_config.drafter + + max_num_requests = model_engine.batch_size + if spec_config.spec_dec_mode.is_draft_target( + ) or spec_config.spec_dec_mode.is_eagle3(): + return ModelDrafter(spec_config, draft_model_engine, + spec_config.max_draft_len, + SeqSlotManager(max_num_requests), sampler, + spec_resource_manager) + + if spec_config.spec_dec_mode.is_ngram(): + return NGramDrafter(spec_config, spec_resource_manager) + return None From 5bff317abf528b03a8ab3ee8d05857addb221af8 Mon Sep 17 00:00:00 2001 From: Linda <57756729+Linda-Stadter@users.noreply.github.com> Date: Thu, 17 Jul 2025 16:42:52 +0200 Subject: [PATCH 5/9] feat: nanobind bindings (#5961) Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> --- cpp/CMakeLists.txt | 4 +- .../batch_manager/runtimeBuffers.h | 2 +- .../batch_manager/runtimeBuffers.cpp | 2 +- cpp/tensorrt_llm/nanobind/CMakeLists.txt | 37 +- .../nanobind/batch_manager/algorithms.cpp | 178 ++++ .../nanobind/batch_manager/algorithms.h | 29 + .../nanobind/batch_manager/bindings.cpp | 525 ++++++++++ .../nanobind/batch_manager/bindings.h | 28 + .../nanobind/batch_manager/buffers.cpp | 108 ++ .../nanobind/batch_manager/buffers.h | 29 + .../batch_manager/cacheTransceiver.cpp | 110 +++ .../nanobind/batch_manager/cacheTransceiver.h | 29 + .../nanobind/batch_manager/kvCacheManager.cpp | 478 +++++++++ .../nanobind/batch_manager/kvCacheManager.h | 39 + .../nanobind/batch_manager/llmRequest.cpp | 131 +++ .../nanobind/batch_manager/llmRequest.h | 160 +++ cpp/tensorrt_llm/nanobind/bindings.cpp | 471 ++++++++- cpp/tensorrt_llm/nanobind/common/bindTypes.h | 100 ++ .../nanobind/common/customCasters.h | 345 +++++++ .../nanobind/executor/bindings.cpp | 263 +++++ cpp/tensorrt_llm/nanobind/executor/bindings.h | 29 + .../nanobind/executor/executor.cpp | 241 +++++ cpp/tensorrt_llm/nanobind/executor/executor.h | 129 +++ .../nanobind/executor/executorConfig.cpp | 616 ++++++++++++ .../nanobind/executor/executorConfig.h | 30 + .../nanobind/executor/request.cpp | 935 ++++++++++++++++++ cpp/tensorrt_llm/nanobind/executor/request.h | 29 + .../nanobind/runtime/bindings.cpp | 388 ++++++++ 
cpp/tensorrt_llm/nanobind/runtime/bindings.h | 30 + .../nanobind/runtime/moeBindings.cpp | 124 +++ .../nanobind/runtime/moeBindings.h | 29 + .../nanobind/testing/modelSpecBinding.cpp | 87 ++ .../nanobind/testing/modelSpecBinding.h | 29 + .../nanobind/userbuffers/bindings.cpp | 47 + .../nanobind/userbuffers/bindings.h | 30 + cpp/tensorrt_llm/pybind/bindings.cpp | 2 +- cpp/tensorrt_llm/pybind/executor/bindings.cpp | 12 +- .../pybind/executor/executorConfig.cpp | 2 +- examples/models/core/llama/summarize_long.py | 2 +- examples/models/core/qwen2audio/run.py | 3 +- examples/models/core/qwenvl/run.py | 3 +- jenkins/Build.groovy | 18 + jenkins/L0_Test.groovy | 8 + tensorrt_llm/builder.py | 2 +- tensorrt_llm/commands/build.py | 19 +- tensorrt_llm/runtime/model_runner.py | 2 +- .../integration/test_lists/test-db/l0_a10.yml | 15 + tests/unittest/bindings/test_bindings_ut.py | 7 + .../bindings/test_executor_bindings.py | 17 +- 49 files changed, 5932 insertions(+), 21 deletions(-) create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/algorithms.h create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/bindings.h create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/buffers.h create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.h create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h create mode 100644 cpp/tensorrt_llm/nanobind/common/bindTypes.h create mode 100644 cpp/tensorrt_llm/nanobind/common/customCasters.h create mode 100644 cpp/tensorrt_llm/nanobind/executor/bindings.cpp create mode 100644 cpp/tensorrt_llm/nanobind/executor/bindings.h create mode 100644 cpp/tensorrt_llm/nanobind/executor/executor.cpp create mode 100644 cpp/tensorrt_llm/nanobind/executor/executor.h create mode 100644 cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp create mode 100644 cpp/tensorrt_llm/nanobind/executor/executorConfig.h create mode 100644 cpp/tensorrt_llm/nanobind/executor/request.cpp create mode 100644 cpp/tensorrt_llm/nanobind/executor/request.h create mode 100644 cpp/tensorrt_llm/nanobind/runtime/bindings.cpp create mode 100644 cpp/tensorrt_llm/nanobind/runtime/bindings.h create mode 100644 cpp/tensorrt_llm/nanobind/runtime/moeBindings.cpp create mode 100644 cpp/tensorrt_llm/nanobind/runtime/moeBindings.h create mode 100644 cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.cpp create mode 100644 cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.h create mode 100644 cpp/tensorrt_llm/nanobind/userbuffers/bindings.cpp create mode 100644 cpp/tensorrt_llm/nanobind/userbuffers/bindings.h diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a76b3e21558..d9e8c206f46 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -198,7 +198,7 @@ set(TRT_LIB TensorRT::NvInfer) get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH) set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty) -if(BINDING_TYPE STREQUAL "pybind") +if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP) add_subdirectory(${3RDPARTY_DIR}/pybind11 
${CMAKE_CURRENT_BINARY_DIR}/pybind11)
 endif()
@@ -217,7 +217,7 @@ include_directories(
   ${3RDPARTY_DIR}/cutlass/tools/util/include
   ${3RDPARTY_DIR}/NVTX/include
   ${3RDPARTY_DIR}/json/include)
-if(BINDING_TYPE STREQUAL "pybind")
+if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP)
   include_directories(${3RDPARTY_DIR}/pybind11/include)
 endif()
 if(BINDING_TYPE STREQUAL "nanobind")
diff --git a/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h b/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h
index 13bde6d07a5..fa43d084b27 100644
--- a/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h
+++ b/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h
@@ -168,7 +168,7 @@ class RuntimeBuffers
 public:
     //! Additional buffers depending on model type
-    std::unique_ptr<TransformerBuffers> transformerBuffers;
+    std::shared_ptr<TransformerBuffers> transformerBuffers;
     std::unique_ptr<RnnStateBuffers> rnnStateBuffers;
 
     //! Encoder-Decoder
diff --git a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp
index 691fb9c7efd..e8b71d065f3 100644
--- a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp
+++ b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp
@@ -84,7 +84,7 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
     if (modelConfig.isTransformerBased())
     {
-        transformerBuffers = std::make_unique<TransformerBuffers>(maxBatchSize, maxBeamWidth, maxAttentionWindowVec,
+        transformerBuffers = std::make_shared<TransformerBuffers>(maxBatchSize, maxBeamWidth, maxAttentionWindowVec,
             maxAttentionWindow, sinkTokenLen, runtime, modelConfig, worldConfig);
     }
     if (modelConfig.isRnnBased())
diff --git a/cpp/tensorrt_llm/nanobind/CMakeLists.txt b/cpp/tensorrt_llm/nanobind/CMakeLists.txt
index d2e7eac20c2..3d570f024d7 100755
--- a/cpp/tensorrt_llm/nanobind/CMakeLists.txt
+++ b/cpp/tensorrt_llm/nanobind/CMakeLists.txt
@@ -3,7 +3,23 @@
 set(TRTLLM_NB_MODULE
     ${TRTLLM_NB_MODULE}
     PARENT_SCOPE)
 
-set(SRCS ../runtime/ipcNvlsMemory.cu bindings.cpp)
+set(SRCS
+    batch_manager/algorithms.cpp
+    batch_manager/bindings.cpp
+    batch_manager/buffers.cpp
+    batch_manager/cacheTransceiver.cpp
+    batch_manager/kvCacheManager.cpp
+    batch_manager/llmRequest.cpp
+    executor/bindings.cpp
+    executor/executor.cpp
+    executor/executorConfig.cpp
+    executor/request.cpp
+    runtime/bindings.cpp
+    testing/modelSpecBinding.cpp
+    runtime/moeBindings.cpp
+    userbuffers/bindings.cpp
+    ../runtime/ipcNvlsMemory.cu
+    bindings.cpp)
 
 include_directories(${PROJECT_SOURCE_DIR}/include)
 
@@ -14,20 +30,29 @@ set_property(TARGET ${TRTLLM_NB_MODULE} PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_directories(${TRTLLM_NB_MODULE} PUBLIC
                         "${TORCH_INSTALL_PREFIX}/lib")
 
+if(ENABLE_NVSHMEM)
+  target_link_libraries(${TRTLLM_NB_MODULE} PUBLIC nvshmem::nvshmem_host
+                        nvshmem::nvshmem_device)
+endif()
+
 target_link_libraries(
   ${TRTLLM_NB_MODULE}
-  PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG} ${NO_AS_NEEDED_FLAG}
-         ${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python)
-
+  PUBLIC ${SHARED_TARGET}
+         ${UNDEFINED_FLAG}
+         ${NO_AS_NEEDED_FLAG}
+         ${Python3_LIBRARIES}
+         ${TORCH_LIBRARIES}
+         torch_python
+         ${CUDA_NVML_LIB})
 target_compile_definitions(
   ${TRTLLM_NB_MODULE}
   PUBLIC TRTLLM_NB_MODULE=${TRTLLM_NB_MODULE}
-         NB_DETAILED_ERROR_MESSAGES=1)
+         PYBIND11_DETAILED_ERROR_MESSAGES=1)
 
 if(NOT WIN32)
   set_target_properties(
     ${TRTLLM_NB_MODULE}
     PROPERTIES LINK_FLAGS
-      "-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}"
+      "-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' -Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/lib/stubs' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}"
   )
 endif()
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp
new file mode 100644
index 00000000000..637401555e8
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp
@@ -0,0 +1,178 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "algorithms.h"
+#include "tensorrt_llm/batch_manager/allocateKvCache.h"
+#include "tensorrt_llm/batch_manager/assignReqSeqSlots.h"
+#include "tensorrt_llm/batch_manager/capacityScheduler.h"
+#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
+#include "tensorrt_llm/batch_manager/handleContextLogits.h"
+#include "tensorrt_llm/batch_manager/handleGenerationLogits.h"
+#include "tensorrt_llm/batch_manager/kvCacheManager.h"
+#include "tensorrt_llm/batch_manager/llmRequest.h"
+#include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
+#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"
+#include "tensorrt_llm/batch_manager/medusaBuffers.h"
+#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
+#include "tensorrt_llm/batch_manager/pauseRequests.h"
+#include "tensorrt_llm/batch_manager/peftCacheManager.h"
+#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
+#include "tensorrt_llm/batch_manager/updateDecoderBuffers.h"
+#include "tensorrt_llm/nanobind/common/customCasters.h"
+#include "tensorrt_llm/runtime/decoderState.h"
+#include "tensorrt_llm/runtime/torch.h"
+#include "tensorrt_llm/runtime/torchView.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace nb = nanobind;
+
+namespace tr = tensorrt_llm::runtime;
+using namespace tensorrt_llm::batch_manager;
+
+void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_& m)
+{
+    nb::class_<CapacityScheduler>(m, CapacityScheduler::name)
+        .def(nb::init<SizeType32, executor::CapacitySchedulerPolicy, bool, bool, LlmRequestState, LlmRequestState>(),
+            nb::arg("max_num_requests"), nb::arg("capacity_scheduler_policy"), nb::arg("has_kv_cache_manager"),
+            nb::arg("two_step_lookahead") = false, nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT,
+            nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE)
+        .def("__call__", &CapacityScheduler::operator(), nb::arg("active_requests"),
+            nb::arg("kv_cache_manager") = nullptr, nb::arg("peft_cache_manager") = nullptr,
+            nb::arg("cross_kv_cache_manager") = nullptr)
+        .def("name", [](CapacityScheduler const&) { return CapacityScheduler::name; });
+
+    nb::class_<MicroBatchScheduler>(m, MicroBatchScheduler::name)
+        .def(nb::init<std::optional<batch_scheduler::ContextChunkingConfig>, std::optional<SizeType32>, LlmRequestState,
+                 LlmRequestState>(),
+            nb::arg("ctx_chunk_config") = std::nullopt, nb::arg("max_context_length") = std::nullopt,
+            nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT,
+            nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE)
+        .def("__call__", &MicroBatchScheduler::operator(), nb::arg("active_requests"), nb::arg("inflight_req_ids"),
nb::arg("active_requests"), nb::arg("inflight_req_ids"), + nb::arg("max_batch_size_runtime"), nb::arg("max_num_tokens_runtime")) + .def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; }); + + nb::class_(m, PauseRequests::name) + .def(nb::init(), nb::arg("max_input_len")) + .def("__call__", &PauseRequests::operator(), nb::arg("requests_to_pause"), nb::arg("inflight_req_ids"), + nb::arg("req_ids_to_pause"), nb::arg("pause_flagged"), nb::arg("seq_slot_manager"), + nb::arg("kv_cache_manager") = std::nullopt, nb::arg("cross_kv_cache_manager") = std::nullopt, + nb::arg("peft_cache_manager") = std::nullopt) + .def("name", [](PauseRequests const&) { return PauseRequests::name; }); + + nb::class_(m, AssignReqSeqSlots::name) + .def(nb::init<>()) + .def("__call__", &AssignReqSeqSlots::operator(), nb::arg("seq_slot_manager"), nb::arg("context_requests"), + nb::arg("generation_requests")) + .def("name", [](AssignReqSeqSlots const&) { return AssignReqSeqSlots::name; }); + + nb::class_(m, AllocateKvCache::name) + .def(nb::init<>()) + .def("__call__", &AllocateKvCache::operator(), nb::arg("kv_cache_manager"), nb::arg("context_requests"), + nb::arg("generation_requests"), nb::arg("model_config"), nb::arg("cross_kv_cache_manager") = std::nullopt) + .def("name", [](AllocateKvCache const&) { return AllocateKvCache::name; }); + + nb::class_(m, HandleContextLogits::name) + .def(nb::init<>()) + .def( + "__call__", + [](HandleContextLogits const& self, DecoderInputBuffers& inputBuffers, RequestVector const& contextRequests, + at::Tensor const& logits, std::vector const& numContextLogitsVec, + tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, + OptionalRef medusaBuffers = std::nullopt) + { + return self(inputBuffers, contextRequests, tr::TorchView::of(logits), numContextLogitsVec, modelConfig, + manager, medusaBuffers); + }, + nb::arg("decoder_input_buffers"), nb::arg("context_requests"), nb::arg("logits"), + nb::arg("num_context_logits"), nb::arg("model_config"), nb::arg("buffer_manager"), + nb::arg("medusa_buffers") = std::nullopt) + .def("name", [](HandleContextLogits const&) { return HandleContextLogits::name; }); + + nb::class_(m, HandleGenerationLogits::name) + .def(nb::init<>()) + .def( + "__call__", + [](HandleGenerationLogits const& self, DecoderInputBuffers& inputBuffers, + RequestVector const& generationRequests, at::Tensor const& logits, tr::SizeType32 logitsIndex, + tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, + OptionalRef genRuntimeBuffers = std::nullopt, + OptionalRef medusaBuffers = std::nullopt) + { + self(inputBuffers, generationRequests, tr::TorchView::of(logits), logitsIndex, modelConfig, manager, + genRuntimeBuffers, medusaBuffers); + }, + nb::arg("decoder_input_buffers"), nb::arg("generation_requests"), nb::arg("logits"), + nb::arg("logits_index"), nb::arg("model_config"), nb::arg("buffer_manager"), + nb::arg("gen_runtime_buffers") = std::nullopt, nb::arg("medusa_buffers") = std::nullopt) + .def("name", [](HandleGenerationLogits const&) { return HandleGenerationLogits::name; }); + + nb::class_(m, MakeDecodingBatchInputOutput::name) + .def(nb::init<>()) + .def("__call__", &MakeDecodingBatchInputOutput::operator(), nb::arg("context_requests"), + nb::arg("generation_requests"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"), + nb::arg("model_config"), nb::arg("max_num_sequences"), nb::arg("fused_runtime_buffers") = std::nullopt) + .def("name", [](MakeDecodingBatchInputOutput const&) { return 
+
+    nb::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
+        .def(nb::init<>())
+        .def("__call__", &LogitsPostProcessor::operator(), nb::arg("context_requests"), nb::arg("generation_requests"),
+            nb::arg("replicate_logits_post_processor"), nb::arg("decoder_buffers"), nb::arg("world_config"),
+            nb::arg("runtime"), nb::arg("logits_post_processor_batched") = std::nullopt)
+        .def("name", [](LogitsPostProcessor const&) { return LogitsPostProcessor::name; });
+
+    nb::class_<CreateNewDecoderRequests>(m, CreateNewDecoderRequests::name)
+        .def(nb::init<bool, bool, bool>(), nb::arg("speculative_decoding_fast_logits"),
+            nb::arg("is_leader_in_orch_mode"), nb::arg("is_normalize_log_probs"))
+        .def(
+            "__call__",
+            [](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig,
+                executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
+                tr::BufferManager const& bufferManager, nvinfer1::DataType logitsType,
+                DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
+                tensorrt_llm::runtime::CudaStream const& runtimeStream,
+                tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
+                SizeType32 beamWidth, OptionalRef<MedusaBuffers> medusaBuffers = std::nullopt)
+            {
+                auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
+                    worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
+                    runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
+
+                return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs),
+                    std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
+            },
+            nb::arg("model_config"), nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("context_requests"),
+            nb::arg("buffer_manager"), nb::arg("logits_type"), nb::arg("decoder_input_buffers"),
+            nb::arg("decoder_state"), nb::arg("runtime_stream"), nb::arg("decoder_stream"),
+            nb::arg("max_sequence_length"), nb::arg("beam_width"), nb::arg("medusa_buffers") = std::nullopt)
+        .def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
+
+    nb::class_<UpdateDecoderBuffers>(m, UpdateDecoderBuffers::name)
+        .def(nb::init<>())
+        .def("__call__", &UpdateDecoderBuffers::operator(), nb::arg("model_config"), nb::arg("decoder_output_buffers"),
+            nb::arg("copy_buffer_manager"), nb::arg("decoder_state"), nb::arg("return_log_probs"),
+            nb::arg("decoder_finish_event"))
+        .def("name", [](UpdateDecoderBuffers const&) { return UpdateDecoderBuffers::name; });
+}
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.h b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.h
new file mode 100644
index 00000000000..cac81d73f27
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.h
@@ -0,0 +1,29 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <nanobind/nanobind.h>
+
+namespace nb = nanobind;
+
+namespace tensorrt_llm::nanobind::batch_manager::algorithms
+{
+
+void initBindings(nb::module_& m);
+
+}
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
new file mode 100644
index 00000000000..d44a957aad9
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -0,0 +1,525 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "bindings.h"
+#include "tensorrt_llm/nanobind/common/customCasters.h"
+
+#include "tensorrt_llm/batch_manager/common.h"
+#include "tensorrt_llm/batch_manager/decoderBuffers.h"
+#include "tensorrt_llm/batch_manager/medusaBuffers.h"
+#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
+#include "tensorrt_llm/batch_manager/peftCacheManager.h"
+#include "tensorrt_llm/batch_manager/rnnStateManager.h"
+#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
+#include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
+#include "tensorrt_llm/nanobind/common/bindTypes.h"
+#include "tensorrt_llm/runtime/gptDecoderBatched.h"
+#include "tensorrt_llm/runtime/runtimeKernels.h"
+#include "tensorrt_llm/runtime/torch.h"
+#include "tensorrt_llm/runtime/torchView.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace nb = nanobind;
+namespace tb = tensorrt_llm::batch_manager;
+namespace tle = tensorrt_llm::executor;
+namespace tr = tensorrt_llm::runtime;
+
+using namespace tensorrt_llm::runtime;
+
+namespace tensorrt_llm::nanobind::batch_manager
+{
+
+void initBindings(nb::module_& m)
+{
+    using GenLlmReq = tb::GenericLlmRequest<runtime::ITensor::SharedPtr>;
+
+    // Create and register exceptions in module scope
+    nb::exception<tb::PeftTaskNotCachedException>(m, "PeftTaskNotCachedException");
+    nb::exception<tr::LoraCacheFullException>(m, "LoraCacheFullException");
+
+    // Register with no captures
+    nb::register_exception_translator(
+        [](std::exception_ptr const& p, void*)
+        {
+            try
+            {
+                if (p)
+                    std::rethrow_exception(p);
+            }
+            catch (const tb::PeftTaskNotCachedException& e)
+            {
+                PyErr_SetString(nb::type<tb::PeftTaskNotCachedException>().ptr(), e.what());
+            }
+            catch (const tr::LoraCacheFullException& e)
+            {
+                PyErr_SetString(nb::type<tr::LoraCacheFullException>().ptr(), e.what());
+            }
+        });
+
+    PybindUtils::bindSet<tb::ReqIdsSet>(m, "ReqIdsSet");
+
+    nb::enum_<tb::LlmRequestType>(m, "LlmRequestType")
+        .value("LLMREQUEST_TYPE_CONTEXT_AND_GENERATION", tb::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION)
+        .value("LLMREQUEST_TYPE_CONTEXT_ONLY", tb::LLMREQUEST_TYPE_CONTEXT_ONLY)
+        .value("LLMREQUEST_TYPE_GENERATION_ONLY", tb::LLMREQUEST_TYPE_GENERATION_ONLY)
+        .export_values();
+
+    nb::class_<tb::batch_scheduler::ContextChunkingConfig>(m, "ContextChunkingConfig")
+        .def(nb::init<tle::ContextChunkingPolicy, tr::SizeType32>(), nb::arg("chunking_policy"),
+            nb::arg("chunk_unit_size"))
+        .def_rw("chunking_policy", &tb::batch_scheduler::ContextChunkingConfig::chunkingPolicy)
+        .def_rw("chunk_unit_size", &tb::batch_scheduler::ContextChunkingConfig::chunkUnitSize);
+
+    nb::class_<GenLlmReq>(m, 
"GenericLlmRequest") + .def("set_exclude_input_from_output", &GenLlmReq::setExcludeInputFromOutput, nb::arg("exclude")) + .def("get_num_tokens", &GenLlmReq::getNumTokens, nb::arg("beam")) + .def_prop_ro("max_beam_num_tokens", &GenLlmReq::getMaxBeamNumTokens) + .def("get_token", &GenLlmReq::getToken, nb::arg("beam"), nb::arg("pos")) + .def("get_tokens", nb::overload_cast(&GenLlmReq::getTokens, nb::const_), nb::arg("beam")) + .def("get_tokens", nb::overload_cast<>(&GenLlmReq::getTokens, nb::const_)) + .def("get_last_tokens", nb::overload_cast(&GenLlmReq::getLastTokens), nb::arg("beam")) + .def("get_last_tokens", nb::overload_cast<>(&GenLlmReq::getLastTokens)) + .def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, nb::arg("for_next_iteration") = false) + .def_prop_ro("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens) + .def("add_new_token", &GenLlmReq::addNewToken, nb::arg("token"), nb::arg("beam")) + .def("add_new_tokens", &GenLlmReq::addNewTokens, nb::arg("beam_tokens")) + .def_prop_ro("num_draft_tokens", &GenLlmReq::getNumDraftTokens) + .def("set_generated_tokens", &GenLlmReq::setGeneratedTokens, nb::arg("generated_beam_tokens")) + .def("pause", &GenLlmReq::pause, nb::arg("max_input_len")) + .def_prop_rw("max_sent_token_len", &GenLlmReq::getMaxSentTokenLen, &GenLlmReq::setMaxSentTokenLen) + .def_prop_ro("prompt_embedding_table", &GenLlmReq::getPromptEmbeddingTable) + .def_prop_ro("multimodal_embedding", &GenLlmReq::getMultimodalEmbedding) + .def_prop_ro("mrope_rotary_cos_sin", &GenLlmReq::getMropeRotaryCosSin) + .def_prop_ro("bad_words_list", &GenLlmReq::getBadWordsList) + .def_prop_rw("draft_logits", &GenLlmReq::getDraftLogits, &GenLlmReq::setDraftLogits) + .def_prop_ro("embedding_bias", &GenLlmReq::getEmbeddingBias) + .def_prop_rw("lora_config", &GenLlmReq::getLoraConfig, &GenLlmReq::setLoraConfig) + .def_prop_rw("lora_weights", &GenLlmReq::getLoraWeights, &GenLlmReq::setLoraWeights) + .def_prop_ro("stop_words_list", &GenLlmReq::getStopWordsList) + .def_prop_ro("context_logits", &GenLlmReq::getContextLogitsHost) + .def_prop_ro("generation_logits", &GenLlmReq::getGenerationLogitsHost) + .def_prop_ro("prompt_vocab_size", &GenLlmReq::getPromptVocabSize) + .def_prop_ro("mrope_position_deltas", &GenLlmReq::getMropePositionDeltas) + .def_prop_ro("lora_task_id", &GenLlmReq::getLoraTaskId) + .def_prop_ro("lookahead_config", &GenLlmReq::getLookaheadConfig) + .def_prop_rw("context_chunk_size", &GenLlmReq::getContextChunkSize, &GenLlmReq::setContextChunkSize) + .def_prop_rw("decoding_iter", &GenLlmReq::getDecodingIter, &GenLlmReq::setDecodingIter) + .def_rw("request_id", &GenLlmReq::mRequestId) + .def_rw("prompt_len", &GenLlmReq::mPromptLen) + .def_rw("max_new_tokens", &GenLlmReq::mMaxNewTokens) + .def_rw("sampling_config", &GenLlmReq::mSamplingConfig) + .def_prop_rw("state", &GenLlmReq::getState, &GenLlmReq::setState) + .def_prop_rw("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming) + .def_rw("end_id", &GenLlmReq::mEndId) + .def_rw("pad_id", &GenLlmReq::mPadId) + .def_rw("seq_slot", &GenLlmReq::mSeqSlot) + .def_prop_ro("return_log_probs", &GenLlmReq::returnLogProbs) + .def_prop_ro("return_context_logits", &GenLlmReq::getReturnContextLogits) + .def_prop_ro("return_generation_logits", &GenLlmReq::getReturnGenerationLogits) + .def_prop_ro("log_probs", nb::overload_cast<>(&GenLlmReq::getLogProbs, nb::const_)) + .def("get_log_probs", nb::overload_cast(&GenLlmReq::getLogProbs, nb::const_)) + .def("set_log_probs", &GenLlmReq::setLogProbs, nb::arg("log_probs"), 
nb::arg("beam")) + .def("set_return_encoder_output", &GenLlmReq::setReturnEncoderOutput, nb::arg("return_encoder_output")) + .def("get_return_encoder_output", &GenLlmReq::getReturnEncoderOutput) + .def("priority", nb::overload_cast<>(&GenLlmReq::priority, nb::const_)) + .def("set_priority", nb::overload_cast(&GenLlmReq::setPriority)) + .def_prop_ro("cum_log_probs", &GenLlmReq::getCumLogProbs) + .def("set_cum_log_prob", &GenLlmReq::setCumLogProb, nb::arg("cum_log_prob"), nb::arg("beam")) + .def("update_num_tokens_per_iteration", &GenLlmReq::updateNumTokensPerIteration, + nb::arg("num_tokens_per_iteration"), nb::arg("model_config")) + .def_prop_ro("orig_prompt_len", &GenLlmReq::getOrigPromptLen) + .def("has_draft_tokens", &GenLlmReq::hasDraftTokens) + .def("move_to_next_context_chunk", &GenLlmReq::moveToNextContextChunk) + .def_prop_ro("is_last_context_chunk", &GenLlmReq::isLastContextChunk) + .def_prop_ro("is_first_context_chunk", &GenLlmReq::isFirstContextChunk) + .def_prop_ro("context_remaining_length", &GenLlmReq::getContextRemainingLength) + .def_prop_ro("context_logits", &GenLlmReq::getContextLogitsHost) + .def_prop_ro("num_draft_tokens", &GenLlmReq::getNumDraftTokens) + .def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam")) + .def_prop_ro("is_finished", &GenLlmReq::isFinished) + .def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength) + .def_prop_rw( + "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition) + .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen) + .def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams) + .def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams) + .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest) + .def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest) + .def_prop_ro("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState) + .def_prop_ro("is_context_finished", &GenLlmReq::isContextFinished) + .def_prop_ro("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState) + .def_prop_ro("is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete) + .def_prop_ro( + "is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress) + .def_prop_ro("is_context_init_state", &GenLlmReq::isContextInitState) + .def_prop_ro("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState) + .def_prop_ro("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState) + .def_prop_ro("is_disagg_context_complete_state", &GenLlmReq::isDisaggContextCompleteState) + .def_prop_ro("stage", &GenLlmReq::getRequestStage) + .def_prop_ro("kv_cache_transfer_time_ms", &GenLlmReq::getKvCacheTransferTimeMS) + .def_prop_ro("kv_cache_size", &GenLlmReq::getKvCacheSize) + .def_prop_ro("avg_decoded_tokens_per_iter", &GenLlmReq::getAvgDecodedTokensPerIter) + .def_prop_ro("alloc_total_blocks", &GenLlmReq::getAllocTotalBlocksPerRequest) + .def_prop_ro("alloc_new_blocks", &GenLlmReq::getAllocNewBlocksPerRequest) + .def("alloc_context_logits", &GenLlmReq::allocContextLogitsHost, nb::arg("vocab_size"), nb::arg("logit_dtype")) + .def_prop_ro("reused_blocks", &GenLlmReq::getReusedBlocksPerRequest) + .def_prop_ro("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest) + .def_prop_ro("kv_cache_hit_rate", 
&GenLlmReq::getKVCacheHitRatePerRequest) + .def_prop_ro("llm_request_type", &GenLlmReq::getLlmRequestType) + .def_prop_ro("multimodal_hashes", + [](GenLlmReq& self) + { + std::optional>> hashes = std::nullopt; + if (self.getMultimodalHashes()) + { + hashes = *self.getMultimodalHashes().value(); + } + return hashes; + }) + .def_prop_ro("multimodal_positions", + [](GenLlmReq& self) + { + std::optional> positions = std::nullopt; + if (self.getMultimodalPositions()) + { + positions = *self.getMultimodalPositions().value(); + } + return positions; + }) + .def_prop_ro("multimodal_lengths", + [](GenLlmReq& self) + { + std::optional> lengths = std::nullopt; + if (self.getMultimodalLengths()) + { + lengths = *self.getMultimodalLengths().value(); + } + return lengths; + }) + .def_prop_ro("position_ids", + [](GenLlmReq& self) + { + std::optional> positionIds = std::nullopt; + if (self.getPositionIds()) + { + positionIds = *self.getPositionIds().value(); + } + return positionIds; + }) + .def_prop_rw( + "draft_tokens", + [](GenLlmReq& self) + { + std::optional draftTokens = std::nullopt; + if (self.hasDraftTokens()) + { + draftTokens = *self.getDraftTokens(); + } + return draftTokens; + }, + [](GenLlmReq& self, std::optional const& draftTokens) + { + if (draftTokens) + { + self.setDraftTokens(std::make_shared(draftTokens.value())); + } + }) + .def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest) + .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics); + + nb::class_(m, "LlmRequest", nb::dynamic_attr()) + .def( + "__init__", + [](tb::LlmRequest* self, tb::LlmRequest::RequestIdType request_id, + tb::LlmRequest::SizeType32 max_new_tokens, std::vector input_tokens, + runtime::SamplingConfig sampling_config, bool is_streaming, + std::optional end_id, std::optional pad_id, + std::optional embedding_bias, std::optional bad_words_list, + std::optional stop_words_list, + std::optional> position_ids, + std::optional prompt_embedding_table, + std::optional prompt_vocab_size, + std::optional>> multimodal_hashes, + std::optional> multimodal_positions, + std::optional> multimodal_lengths, + std::optional multimodal_embedding, std::optional mrope_rotary_cos_sin, + std::optional mrope_position_deltas, + std::optional lora_task_id, std::optional lora_weights, + std::optional lora_config, + std::optional lookahead_config, + std::optional kv_cache_retention_config, bool return_log_probs, + bool return_context_logits, bool return_generation_logits, + std::optional draft_tokens, std::optional draft_logits, + bool exclude_input_from_output, + std::optional logits_post_processor, + bool apply_logits_post_processor_batched, std::optional encoder_input_tokens, + bool return_encoder_output, std::optional client_id, + executor::PriorityType priority, std::optional encoder_input_features, + std::optional encoder_output_length, + std::optional cross_attention_mask, tb::LlmRequestType llm_request_type, + std::optional input_token_extra_ids, + tb::LlmRequest::SizeType32 num_return_sequences, std::optional eagle_config, + std::optional skip_cross_attn_blocks, bool return_perf_metrics, + std::optional guided_decoding_params, + std::optional language_adapter_uid, + std::optional allotted_time_ms, + std::optional context_phase_params) + { + auto makeOptionalTensor = [](std::optional const& atTensor, bool unsqueeze = false) + { + std::optional tensorPtr = std::nullopt; + if (atTensor) + { + tensorPtr = tr::TorchView::of(atTensor.value()); + if (unsqueeze) + { + 
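+                        // embedding_bias and the bad/stop word lists arrive from Python without a
+                        // batch dimension; unsqueeze(0) prepends one before the runtime sees them.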
(*tensorPtr)->unsqueeze(0); + } + } + return tensorPtr; + }; + + auto embedding_bias_tensor_ptr = makeOptionalTensor(embedding_bias, true); + auto bad_words_list_tensor_ptr = makeOptionalTensor(bad_words_list, true); + auto stop_words_list_tensor_ptr = makeOptionalTensor(stop_words_list, true); + auto prompt_embedding_table_tensor_ptr = makeOptionalTensor(prompt_embedding_table); + auto multimodal_embedding_tensor_ptr = makeOptionalTensor(multimodal_embedding); + auto lora_weights_tensor_ptr = makeOptionalTensor(lora_weights); + auto mrope_rotary_cos_sin_tensor_ptr = makeOptionalTensor(mrope_rotary_cos_sin); + auto lora_config_tensor_ptr = makeOptionalTensor(lora_config); + auto draft_logits_tensor_ptr = makeOptionalTensor(draft_logits); + auto encoder_input_features_tensor_ptr = makeOptionalTensor(encoder_input_features); + auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask); + auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks); + + // 49 parameters + new (self) tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming, + end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr, stop_words_list_tensor_ptr, + position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size, multimodal_hashes, + multimodal_positions, multimodal_lengths, multimodal_embedding_tensor_ptr, + mrope_rotary_cos_sin_tensor_ptr, mrope_position_deltas, lora_task_id, lora_weights_tensor_ptr, + lora_config_tensor_ptr, lookahead_config, kv_cache_retention_config, return_log_probs, + return_context_logits, return_generation_logits, draft_tokens, draft_logits_tensor_ptr, + exclude_input_from_output, logits_post_processor, apply_logits_post_processor_batched, + encoder_input_tokens, return_encoder_output, client_id, priority, encoder_input_features_tensor_ptr, + encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, + num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, + guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params}; + }, + nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"), + nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt, + nb::arg("embedding_bias") = std::nullopt, nb::arg("bad_words_list") = std::nullopt, + nb::arg("stop_words_list") = std::nullopt, nb::arg("position_ids") = std::nullopt, + nb::arg("prompt_embedding_table") = std::nullopt, nb::arg("prompt_vocab_size") = std::nullopt, + nb::arg("multimodal_hashes") = std::nullopt, nb::arg("multimodal_positions") = std::nullopt, + nb::arg("multimodal_lengths") = std::nullopt, nb::arg("multimodal_embedding") = std::nullopt, + nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt, + nb::arg("lora_task_id") = std::nullopt, nb::arg("lora_weights") = std::nullopt, + nb::arg("lora_config") = std::nullopt, nb::arg("lookahead_config") = std::nullopt, + nb::arg("kv_cache_retention_config") = std::nullopt, nb::arg("return_log_probs") = false, + nb::arg("return_context_logits") = false, nb::arg("return_generation_logits") = false, + nb::arg("draft_tokens") = std::nullopt, nb::arg("draft_logits") = std::nullopt, + nb::arg("exclude_input_from_output") = false, nb::arg("logits_post_processor") = std::nullopt, + nb::arg("apply_logits_post_processor_batched") = false, nb::arg("encoder_input_tokens") = std::nullopt, + 
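+            // The keyword defaults below mirror the C++ constructor defaults declared in
+            // llmRequest.h, so Python callers can omit most arguments.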
nb::arg("return_encoder_output") = false, nb::arg("client_id") = std::nullopt, + nb::arg("priority") = executor::Request::kDefaultPriority, nb::arg("encoder_input_features") = std::nullopt, + nb::arg("encoder_output_len") = std::nullopt, nb::arg("cross_attention_mask") = std::nullopt, + nb::arg("llm_request_type") = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, + nb::arg("input_token_extra_ids") = std::nullopt, nb::arg("num_return_sequences") = 1, + nb::arg("eagle_config") = std::nullopt, nb::arg("skip_cross_attn_blocks") = std::nullopt, + nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt, + nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt, + nb::arg("context_phase_params") = std::nullopt) + .def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"), + nb::arg("max_draft_len"), nb::arg("vocab_size_padded"), nb::arg("max_endocer_input_len") = std::nullopt, + nb::arg("enable_kv_cache_reuse") = false) + .def("create_response", &tb::LlmRequest::createResponse, nb::arg("use_fast_logits") = false, + nb::arg("mpi_world_rank") = 0) + .def("create_result", &tb::LlmRequest::createResult, nb::arg("use_fast_logits") = false, + nb::arg("mpi_world_rank") = 0) + .def("create_serialized_result", + [](tb::LlmRequest& self, bool use_fast_logits = false, int mpi_world_rank = 0) + { + std::vector serialized_result; + bool is_final = false; + self.createSerializedResult(serialized_result, is_final, use_fast_logits, mpi_world_rank); + return std::make_tuple(nb::bytes(serialized_result.data(), serialized_result.size()), is_final); + }) + .def("move_prompt_embedding_table_to_gpu", &tb::LlmRequest::movePromptEmbeddingTableToGpu, nb::arg("manager")) + .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, nb::arg("manager")) + .def("finish_by_reason", &tb::LlmRequest::finishByReason, nb::arg("finish_reason")) + .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime) + .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter")); + + nb::class_(m, "SequenceSlotManager") + .def(nb::init(), nb::arg("max_num_slots"), + nb::arg("max_sequence_idle_microseconds")) + .def("get_sequence_slot", &tb::SequenceSlotManager::getSequenceSlot, nb::arg("start_flag"), + nb::arg("sequence_id")) + .def("free_sequence_slot", &tb::SequenceSlotManager::freeSequenceSlot, nb::arg("sequence_id")) + .def("free_idle_sequence_slots", &tb::SequenceSlotManager::freeIdleSequenceSlots); + + nb::class_(m, "RnnStateManager") + .def(nb::init(), + nb::arg("max_num_sequences"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")); + + nb::class_(m, "DecoderInputBuffers") + .def(nb::init(), + nb::arg("max_num_sequences"), nb::arg("max_batch_size"), nb::arg("max_tokens_per_engine_step"), + nb::arg("manager")) + .def_rw("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots) + .def_rw("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice) + .def_rw("fill_values", &tb::DecoderInputBuffers::fillValues) + .def_rw("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice) + .def_rw("inputs_ids", &tb::DecoderInputBuffers::inputsIds) + .def_rw("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots) + .def_rw("logits", &tb::DecoderInputBuffers::logits); + + nb::class_(m, "DecoderOutputBuffers") + .def_rw("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost) + 
.def_rw("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost) + .def_prop_ro("new_output_tokens_host", + [](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); }) + .def_rw("cum_log_probs_host", &tb::DecoderOutputBuffers::cumLogProbsHost) + .def_rw("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost) + .def_rw("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost); + + nb::class_(m, "SlotDecoderBuffers") + .def(nb::init(), + nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager")) + .def_rw("output_ids", &tb::SlotDecoderBuffers::outputIds) + .def_rw("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost) + .def_rw("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost) + .def_rw("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs) + .def_rw("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost) + .def_rw("log_probs", &tb::SlotDecoderBuffers::logProbs) + .def_rw("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost) + .def_rw("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost); + + nb::class_(m, "MedusaBuffers") + .def(nb::init(), + nb::arg("max_beam_width"), nb::arg("max_seq_len"), nb::arg("buffer_manager"), nb::arg("model_config"), + nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("runtime")); + + m.def( + "add_new_tokens_to_requests", + [](std::vector>& requests, + std::vector const& tokens, int beam_idx) + { + TLLM_CHECK_WITH_INFO(requests.size() == tokens.size(), "Expected the same number of requests and tokens."); + + for (int i = 0; i < requests.size(); ++i) + { + requests[i]->addNewToken(tokens[i], beam_idx); + } + }, + nb::arg("requests"), nb::arg("tokens"), nb::arg("beam_idx"), + "Add new tokens to multiple LLM requests. 
The tokens vector should contain tokens for beam beam_idx of all " + "requests in order."); + + m.def( + "make_decoding_batch_input", + [](std::vector>& contextRequests, + std::vector>& genRequests, tr::ITensor::SharedPtr logits, int beamWidth, + std::vector const& numContextLogitsPrefixSum, tb::DecoderInputBuffers const& decoderInputBuffers, + runtime::decoder::DecoderState& decoderState, tr::BufferManager const& manager) + { + std::vector activeSlots; + std::vector generationSteps; + std::vector> logitsVec = {{}}; + + for (int i = 0; i < contextRequests.size(); ++i) + { + if (contextRequests[i]->isLastContextChunk()) + { + activeSlots.push_back(*contextRequests[i]->mSeqSlot); + generationSteps.push_back(contextRequests[i]->getDecodingIter()); + auto contextLogitsOffset = numContextLogitsPrefixSum[i + 1] - 1; + tr::ITensor::SharedPtr logitsView = ITensor::slice(logits, contextLogitsOffset, 1); + + if (beamWidth > 1) + { + // Tile logits of context requests + auto const logitsShape = logitsView->getShape(); + auto const logitsType = logitsView->getDataType(); + auto decoderLogits = manager.gpu(ITensor::makeShape({beamWidth, logitsShape.d[1]}), logitsType); + tensorrt_llm::runtime::kernels::tileTensor( + *decoderLogits, *logitsView, beamWidth, manager.getStream()); + decoderLogits->unsqueeze(0); + logitsVec[0].push_back(std::move(decoderLogits)); + } + else + { + logitsView->unsqueeze(1); + logitsVec[0].push_back(std::move(logitsView)); + } + } + } + + auto genLogitsOffset = numContextLogitsPrefixSum.back(); + for (int i = 0; i < genRequests.size(); ++i) + { + if (genRequests[i]->isGenerationInProgressState()) + { + activeSlots.push_back(*genRequests[i]->mSeqSlot); + generationSteps.push_back(genRequests[i]->getDecodingIter()); + + auto logitsOffset = genLogitsOffset + i * beamWidth; + auto numberOfLogits = beamWidth; + tr::ITensor::SharedPtr logitsView = ITensor::slice(logits, logitsOffset, numberOfLogits); + logitsView->unsqueeze(0); + logitsVec[0].push_back(std::move(logitsView)); + } + } + + auto& batchSlots = decoderInputBuffers.forwardBatchSlots; + batchSlots[0]->resize(activeSlots.size()); + auto batchSlotsRange = tr::BufferRange(*batchSlots[0]); + for (int i = 0; i < activeSlots.size(); ++i) + { + batchSlotsRange[i] = activeSlots[i]; + } + + auto decodingInput = std::make_unique(logitsVec, 1); + decodingInput->batchSlots = batchSlots; + + auto const maxBeamWidth = decoderState.getMaxBeamWidth(); + if (maxBeamWidth > 1) + { + // For Variable-Beam-Width-Search + decoderState.getJointDecodingInput().generationSteps = generationSteps; + } + + return decodingInput; + }, + nb::arg("context_requests"), nb::arg("generation_requests"), nb::arg("logits"), nb::arg("beam_width"), + nb::arg("num_context_logits_prefix_sum"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"), + nb::arg("buffer_manager"), "Make decoding batch input."); +} + +} // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.h b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.h new file mode 100644 index 00000000000..3d5a0f5d5b2 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.h @@ -0,0 +1,28 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::batch_manager +{ + +void initBindings(nb::module_& m); + +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp new file mode 100644 index 00000000000..b6edcca1c24 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.cpp @@ -0,0 +1,108 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "buffers.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" + +#include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/batch_manager/runtimeBuffers.h" +#include "tensorrt_llm/batch_manager/transformerBuffers.h" + +#include +#include +#include +#include +#include +#include + +namespace nb = nanobind; +namespace tb = tensorrt_llm::batch_manager; +namespace tr = tensorrt_llm::runtime; + +using tr::SizeType32; + +namespace tensorrt_llm::nanobind::batch_manager +{ + +void Buffers::initBindings(nb::module_& m) +{ + nb::class_(m, "TransformerBuffers") + .def(nb::init const&, SizeType32, SizeType32, + runtime::TllmRuntime const&, runtime::ModelConfig const&, runtime::WorldConfig const&>(), + nb::arg("max_batch_size"), nb::arg("max_beam_width"), nb::arg("max_attention_window_vec"), + nb::arg("max_attention_window"), nb::arg("sink_token_len"), nb::arg("runtime"), nb::arg("model_config"), + nb::arg("world_config")) + .def("reshape", &tb::TransformerBuffers::reshape, nb::arg("num_sequences"), nb::arg("num_input_tokens")) + .def("reshape_kv_tensors", &tb::TransformerBuffers::reshapeKvTensors, nb::arg("max_batch_size"), + nb::arg("max_beam_width"), nb::arg("max_blocks_per_seq"), nb::arg("kv_cache_type"), nb::arg("num_pools"), + nb::arg("buffer_manager")) + .def("get_buffers", &tb::TransformerBuffers::getBuffers, nb::arg("input_buffers"), nb::arg("output_buffers"), + nb::arg("model_config")) + .def("copy_position_ids", &tb::TransformerBuffers::copyPositionIds, nb::arg("runtime"), + nb::arg("position_ids_host"), nb::arg("is_chat_glm"), nb::arg("decoder_position_ids")) + .def("copy_kv_block_offsets", &tb::TransformerBuffers::copyKvBlockOffsets, nb::arg("context_requests"), + nb::arg("gen_requests"), nb::arg("kv_cache_manager"), nb::arg("cross_kv_cache_manager"), + nb::arg("buffer_manager")) + .def("copy_cache_indirection", &tb::TransformerBuffers::copyCacheIndirection, nb::arg("gen_requests"), + 
nb::arg("decoder_cache_indirection_output"), nb::arg("runtime")) + .def_rw("past_key_value_lengths", &tb::TransformerBuffers::pastKeyValueLengths) + .def_rw("position_ids", &tb::TransformerBuffers::positionIds) + .def_rw("max_attention_windows", &tb::TransformerBuffers::maxAttentionWindows) + .def_rw("sink_token_lengths", &tb::TransformerBuffers::sinkTokenLengths) + .def_rw("cache_indirection", &tb::TransformerBuffers::cacheIndirection) + .def_rw("kv_cache_block_offsets_host", &tb::TransformerBuffers::kvCacheBlockOffsetsHost) + .def_rw("kv_cache_block_offsets_device", &tb::TransformerBuffers::kvCacheBlockOffsetsDevice) + .def_rw("cross_kv_cache_block_pool_pointers", &tb::TransformerBuffers::crossKvCacheBlockPoolPointers) + .def_rw("cross_kv_cache_block_offsets_host", &tb::TransformerBuffers::crossKvCacheBlockOffsetsHost) + .def_rw("cross_kv_cache_block_offsets_device", &tb::TransformerBuffers::crossKvCacheBlockOffsetsDevice) + .def_rw("cache_indir_batched_copy_src_offsets", &tb::TransformerBuffers::cacheIndirBatchedCopySrcOffsets) + .def_rw("cache_indir_batched_copy_dst_offsets", &tb::TransformerBuffers::cacheIndirBatchedCopyDstOffsets) + .def_rw("cache_indir_batched_copy_sizes", &tb::TransformerBuffers::cacheIndirBatchedCopySizes) + .def_rw("fill_values_alt", &tb::TransformerBuffers::fillValuesAlt) + .def_rw("fill_values_alt_device", &tb::TransformerBuffers::fillValuesAltDevice) + .def_rw("seq_slots_alt", &tb::TransformerBuffers::seqSlotsAlt) + .def_rw("seq_slots_alt_device", &tb::TransformerBuffers::seqSlotsAltDevice); + + nb::class_(m, "RuntimeBuffers") + .def(nb::init const&, SizeType32, SizeType32, + runtime::TllmRuntime const&, runtime::ModelConfig const&, runtime::WorldConfig const&, + executor::DecodingConfig const&, bool, std::optional>(), + nb::arg("max_batch_size"), nb::arg("max_beam_width"), nb::arg("max_attention_window_vec"), + nb::arg("max_attention_window"), nb::arg("sink_token_len"), nb::arg("runtime"), nb::arg("model_config"), + nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("gather_generation_logits"), + nb::arg("max_num_tokens") = std::nullopt) + .def_prop_rw( + "transformer_buffers", [](tb::RuntimeBuffers& self) { return self.transformerBuffers; }, + [](tb::RuntimeBuffers& self, std::shared_ptr val) + { self.transformerBuffers = val; }) + .def_rw("num_context_logits", &tb::RuntimeBuffers::numContextLogits) + .def_rw("cache_indir_decoder_io_batched_copy_src_offsets", + &tb::RuntimeBuffers::cacheIndirDecoderIOBatchedCopySrcOffsets) + .def_rw("cache_indir_decoder_io_batched_copy_dst_offsets", + &tb::RuntimeBuffers::cacheIndirDecoderIOBatchedCopyDstOffsets) + .def_rw("cache_indir_decoder_io_batched_copy_sizes", &tb::RuntimeBuffers::cacheIndirDecoderIOBatchedCopySizes) + .def_rw("logits", &tb::RuntimeBuffers::logits) + .def_rw("seq_slots", &tb::RuntimeBuffers::seqSlots) + .def_rw("seq_slots_device", &tb::RuntimeBuffers::seqSlotsDevice) + .def_rw("cache_indir_decoder_io_batched_copy_src_offsets_slice_device", + &tb::RuntimeBuffers::mCacheIndirDecoderIOBatchedCopySrcOffsetsSliceDevice) + .def_rw("cache_indir_decoder_io_batched_copy_dst_offsets_slice_device", + &tb::RuntimeBuffers::mCacheIndirDecoderIOBatchedCopyDstOffsetsSliceDevice) + .def_rw("cache_indir_decoder_io_batched_copy_copy_sizes_device", + &tb::RuntimeBuffers::mCacheIndirDecoderIOBatchedCopyCopySizesDevice); +} +} // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/buffers.h b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.h new file mode 100644 
index 00000000000..34df07e4073 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/buffers.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::batch_manager +{ +class Buffers +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp new file mode 100644 index 00000000000..abac6d17ed8 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp @@ -0,0 +1,110 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cacheTransceiver.h" +#include "tensorrt_llm/batch_manager/cacheTransceiver.h" +#include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include +#include +#include +#include +#include +#include +#include + +using SizeType32 = tensorrt_llm::runtime::SizeType32; + +namespace tb = tensorrt_llm::batch_manager; +namespace nb = nanobind; + +namespace +{ + +class PyCacheTransceiver : public tb::BaseCacheTransceiver +{ +public: + // using BaseCacheTransceiver::BaseCacheTransceiver; // Inherit constructors + NB_TRAMPOLINE(tb::BaseCacheTransceiver, 6); + + void respondAndSendAsync(tb::LlmRequest* llmRequest) override + { + NB_OVERRIDE_PURE(respondAndSendAsync, llmRequest); + } + + void requestAndReceiveSync(tb::LlmRequest* llmRequest) override + { + NB_OVERRIDE_PURE(requestAndReceiveSync, llmRequest); + } + + void requestAndReceiveAsync(tb::LlmRequest* llmRequest) override + { + NB_OVERRIDE_PURE(requestAndReceiveAsync, llmRequest); + } + + void checkContextTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) override + { + NB_OVERRIDE_PURE(checkContextTransferStatus, atLeastRequestNum); + } + + void checkGenTransferStatus(std::optional const& atLeastRequestNum = std::nullopt) override + { + NB_OVERRIDE_PURE(checkGenTransferStatus, atLeastRequestNum); + } + + bool checkGenTransferComplete() const override + { + NB_OVERRIDE_PURE(checkGenTransferComplete); + } +}; +} // namespace + +void tb::CacheTransceiverBindings::initBindings(nb::module_& m) +{ + nb::class_(m, "BaseCacheTransceiver") + .def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync) + .def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync) + .def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync) + .def("check_context_transfer_status", &BaseCacheTransceiver::checkContextTransferStatus) + .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus) + .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete); + + nb::enum_(m, "CommType") + .value("UNKNOWN", tb::CacheTransceiver::CommType::UNKNOWN) + .value("MPI", tb::CacheTransceiver::CommType::MPI) + .value("UCX", tb::CacheTransceiver::CommType::UCX) + .value("NIXL", tb::CacheTransceiver::CommType::NIXL); + + nb::enum_(m, "AttentionType") + .value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT) + .value("MLA", executor::kv_cache::CacheState::AttentionType::kMLA); + + nb::class_(m, "CacheTransceiver") + .def(nb::init, SizeType32, SizeType32, runtime::WorldConfig, nvinfer1::DataType, + executor::kv_cache::CacheState::AttentionType, std::optional>(), + nb::arg("cache_manager"), nb::arg("comm_type"), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), + nb::arg("tokens_per_block"), nb::arg("world_config"), nb::arg("dtype"), nb::arg("attention_type"), + nb::arg("cache_transceiver_config") = std::nullopt); + + nb::class_(m, "CacheTransBufferManager") + .def(nb::init>(), nb::arg("cache_manager"), + nb::arg("max_num_tokens") = std::nullopt) + .def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize, + nb::arg("max_num_tokens") = std::nullopt); +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h new file mode 100644 index 00000000000..90fc63d4fde --- /dev/null +++ 
b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +namespace nb = nanobind; + +namespace tensorrt_llm::batch_manager +{ +class CacheTransceiverBindings +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp new file mode 100644 index 00000000000..f1c398d31f0 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -0,0 +1,478 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kvCacheManager.h" +#include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/batch_manager/peftCacheManager.h" +#include "tensorrt_llm/nanobind/common/bindTypes.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/torch.h" +#include "tensorrt_llm/runtime/torchView.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tb = tensorrt_llm::batch_manager; +namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager; +namespace tr = tensorrt_llm::runtime; +namespace nb = nanobind; +using BlockKey = tbk::BlockKey; +using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens; +using SizeType32 = tensorrt_llm::runtime::SizeType32; +using TokenIdType = tensorrt_llm::runtime::TokenIdType; +using VecTokens = std::vector; +using CudaStreamPtr = std::shared_ptr; + +namespace +{ +std::optional from_torch(std::optional torchPtr) +{ + if (torchPtr) + { + return tr::TorchView::of(torchPtr.value()); + } + return std::nullopt; +} + +class PyKvCacheManager : public tbk::BaseKVCacheManager +{ +public: + NB_TRAMPOLINE(tbk::BaseKVCacheManager, 28); + + // using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors + void allocatePools(bool useUvm = false) override + { + NB_OVERRIDE_PURE(allocatePools, useUvm); + } + + void releasePools() override + { + NB_OVERRIDE_PURE(releasePools); + } + + void startScheduling() override + { + NB_OVERRIDE_PURE(startScheduling); + } + + SizeType32 getTokensPerBlock() const override + { + NB_OVERRIDE_PURE(getTokensPerBlock); + } + + SizeType32 getMaxNumBlocks() const override + { + NB_OVERRIDE_PURE(getMaxNumBlocks); + } + + SizeType32 getNumPools() const override + { + NB_OVERRIDE_PURE(getNumPools); + } + + tbk::KvCacheStats getKvCacheStats() const override + { + NB_OVERRIDE_PURE(getKvCacheStats); + } + + void addToken(tb::LlmRequest::RequestIdType requestId) override + { + NB_OVERRIDE_PURE(addToken, requestId); + } + + void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, + tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override + { + NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest); + } + + void removeSequence(tb::LlmRequest::RequestIdType requestId, + tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override + { + NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest); + } + + tbk::GenerationRequest const& getSequence(tb::LlmRequest::RequestIdType requestId) const override + { + NB_OVERRIDE_PURE(getSequence, requestId); + } + + void schedulingRemoveSequence(tb::LlmRequest::RequestIdType requestId) override + { + NB_OVERRIDE_PURE(schedulingRemoveSequence, requestId); + } + + tensorrt_llm::runtime::ITensor::SharedPtr getBlockPoolPointers() const override + { + NB_OVERRIDE_PURE(getBlockPoolPointers); + } + + tensorrt_llm::runtime::ITensor::SharedPtr getLayerToPoolMapping() const override + { + NB_OVERRIDE_PURE(getLayerToPoolMapping); + } + + void getBlockOffsetsOfBatch(tensorrt_llm::runtime::ITensor& output, SizeType32 firstBatchSlotIdx, + SizeType32 batchSize, SizeType32 beamWidth) const override + { + NB_OVERRIDE_PURE(getBlockOffsetsOfBatch, output, firstBatchSlotIdx, batchSize, beamWidth); + } + + SizeType32 copyBlockOffsets(tensorrt_llm::runtime::ITensor& output, SizeType32 outputSlotOffset, + tb::LlmRequest::RequestIdType requestId) const override + { + 
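+        // NB_OVERRIDE_PURE forwards each pure-virtual call to the Python override and
+        // raises if the Python subclass does not implement it.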
NB_OVERRIDE_PURE(copyBlockOffsets, output, outputSlotOffset, requestId); + } + + bool isEnableBlockReuse() const override + { + NB_OVERRIDE_PURE(isEnableBlockReuse); + } + + void rewindKVCache(tb::LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) override + { + NB_OVERRIDE_PURE(rewindKVCache, requestId, rewindLengths); + } + + bool isCrossKv() const override + { + NB_OVERRIDE_PURE(isCrossKv); + } + + std::optional findNewContextBlock( + VecUniqueTokens const& uniqueTokens, tb::LlmRequest const& llmRequest) const override + { + NB_OVERRIDE_PURE(findNewContextBlock, uniqueTokens, llmRequest); + } + + void storeContextBlocks(tb::LlmRequest const& llmRequest) override + { + NB_OVERRIDE_PURE(storeContextBlocks, llmRequest); + } + + std::vector> const& getCacheBlockIds( + tb::LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override + { + NB_OVERRIDE_PURE(getCacheBlockIds, requestId, windowSize); + } + + std::vector>> getBatchCacheBlockIds( + std::vector const& requestIds, SizeType32 windowSize) const override + { + NB_OVERRIDE_PURE(getBatchCacheBlockIds, requestIds, windowSize); + } + + std::vector getNewlyAllocatedBlockIds( + tb::LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override + { + NB_OVERRIDE_PURE(getNewlyAllocatedBlockIds, requestId, windowSize); + } + + SizeType32 getUsedNumBlocks() const override + { + NB_OVERRIDE_PURE(getUsedNumBlocks); + } + + SizeType32 getNumFreeBlocks() const override + { + NB_OVERRIDE_PURE(getNumFreeBlocks); + } + + tbk::BlockManager const& getBlockManager() const override + { + NB_OVERRIDE_PURE(getBlockManager); + } + + std::deque getLatestEvents( + std::optional timeout = std::nullopt) const override + { + NB_OVERRIDE_PURE(getLatestEvents, timeout); + } + + tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override + { + NB_OVERRIDE_PURE(getPrimaryPool, layer_idx); + } + + SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override + { + NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx); + } + + void refreshBlocks() override + { + NB_OVERRIDE_PURE(refreshBlocks); + } + + void flushIterationEvents() override + { + NB_OVERRIDE_PURE(flushIterationEvents); + } +}; + +// TODO: Deduplicate executor bindings KvCacheStats +class PyBasePeftCacheManager : public tb::BasePeftCacheManager +{ +public: + ~PyBasePeftCacheManager() override = default; + + NB_TRAMPOLINE(tb::BasePeftCacheManager, 8); + + void addRequestPeft(tb::BasePeftCacheManager::LlmRequestPtr llmRequest, bool tryGpuCache = true) override + { + NB_OVERRIDE_PURE(addRequestPeft, llmRequest, tryGpuCache); + } + + tb::BasePeftCacheManager::PeftTable ensureBatch(tb::RequestVector const& contextRequests, + tb::RequestVector const& generationRequests, bool resetGpuCache = false) override + { + NB_OVERRIDE_PURE(ensureBatch, contextRequests, generationRequests, resetGpuCache); + } + + void resetDeviceCache() override + { + NB_OVERRIDE_PURE(resetDeviceCache); + } + + void markRequestDone(tb::LlmRequest const& llmReq, bool pause = false) override + { + NB_OVERRIDE_PURE(markRequestDone, llmReq, pause); + } + + tr::SizeType32 getMaxDevicePages() const override + { + NB_OVERRIDE_PURE(getMaxDevicePages); + } + + tr::SizeType32 getMaxHostPages() const override + { + NB_OVERRIDE_PURE(getMaxHostPages); + } + + tr::SizeType32 determineNumPages(std::shared_ptr llmRequest) const override + { + NB_OVERRIDE_PURE(determineNumPages, llmRequest); + } + + bool enabled() const override + { + NB_OVERRIDE_PURE(enabled); + } +}; +} // namespace + +void 
tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) +{ + nb::class_(m, "KvCacheStats") + .def(nb::init<>()) + .def_rw("max_num_blocks", &tbk::KvCacheStats::maxNumBlocks) + .def_rw("free_num_blocks", &tbk::KvCacheStats::freeNumBlocks) + .def_rw("used_num_blocks", &tbk::KvCacheStats::usedNumBlocks) + .def_rw("tokens_per_block", &tbk::KvCacheStats::toksPerBlock) + .def_rw("alloc_total_blocks", &tbk::KvCacheStats::allocTotalBlocks) + .def_rw("alloc_new_blocks", &tbk::KvCacheStats::allocNewBlocks) + .def_rw("reused_blocks", &tbk::KvCacheStats::reusedBlocks) + .def_rw("missed_blocks", &tbk::KvCacheStats::missedBlocks) + .def_rw("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate) + .def_rw("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize); + + nb::class_(m, "TempAttentionWindowInputs") + .def(nb::init<>()) + .def_rw("paged_context_fmha", &tbk::TempAttentionWindowInputs::pagedContextFMHA) + .def_rw("max_input_len", &tbk::TempAttentionWindowInputs::maxInputLen) + .def_rw("max_num_tokens", &tbk::TempAttentionWindowInputs::maxNumTokens); + + nb::class_(m, "BlockKey") + .def(nb::init<>()) + .def(nb::init>(), nb::arg("tokens"), + nb::arg("lora_task_id") = std::nullopt) + .def(nb::init, VecUniqueTokens const&>(), nb::arg("uses_extra_ids"), + nb::arg("lora_task_id"), nb::arg("unique_tokens")) + .def_ro("uses_extra_ids", &tbk::BlockKey::usesExtraIds) + .def_ro("lora_task_id", &tbk::BlockKey::loraTaskId) + .def_ro("unique_tokens", &tbk::BlockKey::uniqueTokens); + + nb::class_(m, "BlockKeyHasher") + .def_static("hash", &tbk::BlockKeyHasher::hash, nb::arg("block_key"), nb::arg("parent_hash") = 0); + + nb::class_(m, "KVCacheEventManager") + .def(nb::init(), nb::arg("max_kv_event_entries")); + + nb::class_(m, "BaseKVCacheManager") + .def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, nb::arg("config"), + nb::arg("is_cross_attention"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"), + nb::arg("window_size_to_layers"), nb::arg("allotted_primary_mem_bytes"), + nb::arg("allotted_secondary_mem_bytes"), nb::arg("extra_cost_memory"), nb::arg("kv_factor")) + .def("allocate_pools", &BaseKVCacheManager::allocatePools) + .def("release_pools", &BaseKVCacheManager::releasePools) + .def("start_scheduling", &BaseKVCacheManager::startScheduling) + .def_prop_ro("tokens_per_block", &BaseKVCacheManager::getTokensPerBlock) + .def_prop_ro("max_num_blocks", &BaseKVCacheManager::getMaxNumBlocks) + .def_prop_ro("num_pools", &BaseKVCacheManager::getNumPools) + .def("get_kv_cache_stats", &BaseKVCacheManager::getKvCacheStats) + .def_prop_ro("max_blocks_per_seq", + [](tbk::BaseKVCacheManager& self) { return self.getOffsetTableDimensions().maxBlocksPerSeq; }) + .def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep) + .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion) + .def("add_token", &BaseKVCacheManager::addToken) + .def("add_sequence", &BaseKVCacheManager::addSequence) + .def("remove_sequence", &BaseKVCacheManager::removeSequence) + .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence) + .def("get_block_pool_pointers", + [](tbk::BaseKVCacheManager& self) + { + std::optional block_pool_pointers{std::nullopt}; + auto tensor = self.getBlockPoolPointers(); + if (tensor) + { + std::shared_ptr _tensor = std::move(tensor); + block_pool_pointers = tr::Torch::tensor(_tensor); + } + return block_pool_pointers; + }) + 
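+        // The accessors below repeat this pattern: a null ITensor is surfaced to Python
+        // as None instead of a dangling handle.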
.def("get_layer_to_pool_mapping", + [](tbk::BaseKVCacheManager& self) + { + std::optional layer_to_pool_mapping{std::nullopt}; + auto tensor = self.getLayerToPoolMapping(); + if (tensor) + { + std::shared_ptr _tensor = std::move(tensor); + layer_to_pool_mapping = tr::Torch::tensor(_tensor); + } + return layer_to_pool_mapping; + }) + .def("get_primary_pool_data", + [](tbk::BaseKVCacheManager& self, SizeType32 layer_idx) -> at::Tensor + { + auto pool = tr::Torch::tensor(self.getPrimaryPool(layer_idx)); + auto pool_layer_idx = self.getPoolLayerIdx(layer_idx); + return pool.index({torch::indexing::Slice(), pool_layer_idx}); + }) + .def("get_block_offsets_of_batch", + [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, + SizeType32 beamWidth) + { + auto _output = from_torch(output); + TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor."); + self.getBlockOffsetsOfBatch(*(_output.value()), firstBatchSlotIdx, batchSize, beamWidth); + }) + .def("copy_block_offsets", + [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 outputSlotOffset, + tb::LlmRequest::RequestIdType requestId) + { + auto _output = from_torch(output); + TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor."); + auto maxBlockCount = self.copyBlockOffsets(*(_output.value()), outputSlotOffset, requestId); + return maxBlockCount; + }) + .def("copy_batch_block_offsets", + [](tbk::BaseKVCacheManager& self, at::Tensor output, + std::vector const& requestIds, SizeType32 const beamWidth, + SizeType32 const offset) + { + auto _output = from_torch(output); + TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor."); + for (size_t i = 0; i < requestIds.size(); ++i) + { + self.copyBlockOffsets(*(_output.value()), i * beamWidth + offset, requestIds[i]); + } + }) + .def( + "get_latest_events", + [](tbk::BaseKVCacheManager& self, std::optional timeout_ms = std::nullopt) + { + if (timeout_ms) + { + return self.getLatestEvents(std::chrono::milliseconds(static_cast(*timeout_ms))); + } + return self.getLatestEvents(std::nullopt); + }, + nb::arg("timeout_ms") = std::nullopt) + .def_prop_ro("enable_block_reuse", &BaseKVCacheManager::isEnableBlockReuse) + .def("rewind_kv_cache", &BaseKVCacheManager::rewindKVCache) + .def_prop_ro("cross_kv", &BaseKVCacheManager::isCrossKv) + .def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks) + .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds) + .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds) + .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds) + .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents); + + nb::bind_vector>>(m, "CacheBlockIds"); + + nb::enum_(m, "CacheType") + .value("SELF", tbk::CacheType::kSELF) + .value("CROSS", tbk::CacheType::kCROSS) + .value("SELFKONLY", tbk::CacheType::kSELFKONLY); + + nb::class_(m, "KVCacheManager") + .def(nb::init const&, SizeType32, SizeType32, + std::map> const&, SizeType32, SizeType32, + std::vector const&, std::optional const&, + nvinfer1::DataType, SizeType32, int64_t, std::optional, bool, bool, + tbk::CacheType, std::optional, + std::shared_ptr, bool, bool>(), + nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), nb::arg("tokens_per_block"), + nb::arg("blocks_per_window"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"), + nb::arg("max_attention_window_vec"), nb::arg("temp_attention_window_inputs").none(), nb::arg("dtype"), + 
nb::arg("sink_token_length"), nb::arg("stream"), nb::arg("max_sequence_length").none(), + nb::arg("enable_block_reuse") = false, nb::arg("onboard_blocks") = true, + nb::arg("cache_type") = tbk::CacheType::kSELF, nb::arg("secondary_offload_min_priority") = std::nullopt, + nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true, + nb::arg("copy_on_partial_reuse") = true); +} + +void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m) +{ + nb::class_(m, "BasePeftCacheManager") + .def("add_request_peft", &tb::BasePeftCacheManager::addRequestPeft, nb::arg("request"), + nb::arg("try_gpu_cache") = true) + .def( + "ensure_batch", + [](tb::BasePeftCacheManager& self, tb::RequestVector const& contextRequests, + tb::RequestVector const& generationRequests, bool resetGpuCache) + { + nb::gil_scoped_release release; + return self.ensureBatch(contextRequests, generationRequests, resetGpuCache); + }, + nb::arg("context_requests"), nb::arg("generation_requests"), nb::arg("reset_gpu_cache") = false) + .def("reset_device_cache", &tb::BasePeftCacheManager::resetDeviceCache) + .def("mark_request_done", &tb::BasePeftCacheManager::markRequestDone, nb::arg("request"), + nb::arg("pause") = false) + .def_prop_ro("max_device_pages", &tb::BasePeftCacheManager::getMaxDevicePages) + .def_prop_ro("max_host_pages", &tb::BasePeftCacheManager::getMaxHostPages) + .def("determine_num_pages", &tb::BasePeftCacheManager::determineNumPages, nb::arg("request")) + .def_prop_ro("enabled", &tb::BasePeftCacheManager::enabled); + + nb::class_(m, "PeftCacheManager") + .def(nb::init(), + nb::arg("config"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")); + + nb::class_(m, "NoOpPeftCacheManager").def(nb::init<>()); +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.h b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.h new file mode 100644 index 00000000000..786c0d391df --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.h @@ -0,0 +1,39 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +namespace nb = nanobind; + +namespace tensorrt_llm::batch_manager::kv_cache_manager +{ +class KVCacheManagerBindings +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::batch_manager::kv_cache_manager + +namespace tensorrt_llm::batch_manager +{ +class BasePeftCacheManagerBindings +{ +public: + static void initBindings(nb::module_& m); +}; +} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp new file mode 100644 index 00000000000..d8f45cb865f --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp @@ -0,0 +1,131 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "llmRequest.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" + +#include "tensorrt_llm/batch_manager/llmRequest.h" +#include "tensorrt_llm/nanobind/common/bindTypes.h" +#include "tensorrt_llm/runtime/torch.h" +#include "tensorrt_llm/runtime/torchUtils.h" +#include "tensorrt_llm/runtime/torchView.h" + +#include +#include + +#include + +namespace tb = tensorrt_llm::batch_manager; +namespace tr = tensorrt_llm::runtime; +namespace tle = tensorrt_llm::executor; + +using namespace tensorrt_llm::nanobind::batch_manager; + +using LlmRequestPtr = std::shared_ptr; +using RequestList = std::list; + +namespace +{ + +std::optional from_torch(std::optional torchPtr) +{ + if (torchPtr) + { + return tr::TorchView::of(torchPtr.value()); + } + return std::nullopt; +} + +} // namespace + +std::optional LlmRequest::callbackAdapter( + std::optional callback) +{ + if (!callback) + { + return std::nullopt; + } + + return [callback](RequestIdType reqId, tr::ITensor::SharedPtr& tensor, tb::LlmRequest::BeamTokens const& tokens, + tr::BufferManager::CudaStreamPtr stream, std::optional clientId) + { + at::Tensor atTensor = tr::Torch::tensor(tensor); + callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap(), clientId); + }; +} + +std::shared_ptr LlmRequest::toTrtLlm() const +{ + + auto const draftTokens = std::make_shared>(*mDraftTokens.get()); + auto const optDraftTokens = std::optional>>(draftTokens); + auto const encoderInputTokens = mEncoderTokens.has_value() + ? 
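+        // Deep-copy the token containers so the returned TRT-LLM request owns its data
+        // independently of this Python-side wrapper.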
std::make_shared>(*mEncoderTokens.value().get()) + : nullptr; + auto const optEncoderInputTokens = std::optional>>(encoderInputTokens); + // 49 parameters + return std::make_shared( // + mRequestId, // + mMaxNewTokens, // + std::make_shared>(mTokens.at(0)), // + mSamplingConfig, // + mIsStreaming, // + mEndId, // + mPadId, // + from_torch(mEmbeddingBias), // + from_torch(mBadWordsList), // + from_torch(mStopWordsList), // + mPositionIds, // + from_torch(mPromptEmbeddingTable), // + mPromptVocabSize, // + mMultimodalHashes, // + mMultimodalPositions, // + mMultimodalLengths, // + from_torch(mMultimodalEmbedding), // + from_torch(mMropeRotaryCosSin), // + mMropePositionDeltas, // + mLoraTaskId, // + from_torch(mLoraWeights), // + from_torch(mLoraConfig), // + mLookaheadConfig, // + mKvCacheRetentionConfig, // + mReturnLogProbs, // + mReturnContextLogits, // + mReturnGenerationLogits, // + optDraftTokens, // + from_torch(mDraftLogits), // + mExcludeInputFromOutput, // + callbackAdapter(mLogitsPostProcessor), // + mApplyLogitsPostProcessorBatched, // + optEncoderInputTokens, // + mReturnEncoderOutput, // + mClientId, // + mPriority, // + from_torch(mEncoderInputFeatures), // + mEncoderOutputLength, // + from_torch(mCrossAttentionMask), // + getLlmRequestType(), // + std::nullopt, // inputTokenExtraIds + mNumReturnSequences, // + mEagleConfig, // + from_torch(mSkipCrossAttnBlocks), // + false, // returnPerfMetrics + mGuidedDecodingParams, // + mLanguageAdapterUid, // + mAllottedTimeMs, // + mContextPhaseParams // + ); +} diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h new file mode 100644 index 00000000000..624dc55112d --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h @@ -0,0 +1,160 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/batch_manager/llmRequest.h" + +#include +#include +#include +#include +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::batch_manager +{ + +namespace tb = tensorrt_llm::batch_manager; + +/* Unfortunately, torch's default nanobind bindings don't know about c10::cuda::CUDAStream, + * so we have to pass the more generic c10::Stream, and convert it back to a full-fledged + * torch.cuda.Stream in python. 
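+ * (On the Python side this typically means rebuilding a usable stream from the raw
+ * handle, e.g. via torch.cuda.ExternalStream; the exact helper may vary by torch version.)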
See example in test/bindings/test_gpt_manager.py + */ +class LlmRequest : public tb::GenericLlmRequest +{ +public: + using Base = GenericLlmRequest; + using TensorPtr = Base::TensorPtr; + using SizeType32 = Base::SizeType32; + using TokenIdType = Base::TokenIdType; + using RequestIdType = Base::RequestIdType; + using LoraTaskIdType = Base::LoraTaskIdType; + using VecLogProbs = Base::VecLogProbs; + using BeamTokens = Base::BeamTokens; + using VecTokens = Base::VecTokens; + using VecTokenExtraIds = Base::VecTokenExtraIds; + using LogitsPostProcessor = Base::LogitsPostProcessor; + + // 49 parameters + LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector inputTokens, + runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional endId = std::nullopt, + std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt, + std::optional badWordsList = std::nullopt, std::optional stopWordsList = std::nullopt, + std::optional> positionIds = std::nullopt, + std::optional promptEmbeddingTable = std::nullopt, + std::optional promptVocabSize = std::nullopt, + std::optional>> multimodalHashes = std::nullopt, + std::optional> multimodalPositions = std::nullopt, + std::optional> multimodalLengths = std::nullopt, + std::optional multimodalEmbedding = std::nullopt, + std::optional mropeRotaryCosSin = std::nullopt, + std::optional mropePositionDeltas = std::nullopt, + std::optional loraTaskId = std::nullopt, std::optional loraWeights = std::nullopt, + std::optional loraConfig = std::nullopt, + std::optional lookaheadConfig = std::nullopt, + std::optional kvCacheRetentionConfig = std::nullopt, + bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, + std::optional draftTokens = std::nullopt, std::optional draftLogits = std::nullopt, + bool excludeInputFromOutput = false, std::optional logitsPostProcessor = std::nullopt, + bool applyLogitsPostProcessorBatched = false, std::optional encoderInputTokens = std::nullopt, + bool returnEncoderOutput = false, std::optional clientId = std::nullopt, + executor::PriorityType priority = executor::Request::kDefaultPriority, + std::optional encoderInputFeatures = std::nullopt, + std::optional encoderOutputLength = std::nullopt, + std::optional crossAttentionMask = std::nullopt, + tb::LlmRequestType llmRequestType = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, + std::optional inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1, + std::optional eagleConfig = std::nullopt, + std::optional skipCrossAttnBlocks = std::nullopt, bool returnPerfMetrics = false, + std::optional guidedDecodingParams = std::nullopt, + std::optional languageAdapterUid = std::nullopt, + std::optional allottedTimeMs = std::nullopt, + std::optional const& contextPhaseParams = std::nullopt) + : Base(requestId, // + maxNewTokens, // + std::make_shared>(std::move(inputTokens)), // + samplingConfig, // + isStreaming, // + endId, // + padId, // + embeddingBias, // + badWordsList, // + stopWordsList, // + positionIds.has_value() ? std::make_shared>(std::move(positionIds.value())) // + : std::optional>>(std::nullopt), // + promptEmbeddingTable, // + promptVocabSize, // + multimodalHashes.has_value() + ? std::make_optional( + std::make_shared>>(std::move(multimodalHashes.value()))) // + : std::optional>>>(std::nullopt), // + multimodalPositions.has_value() + ? std::make_shared>(std::move(multimodalPositions.value())) // + : std::optional>>(std::nullopt), // + multimodalLengths.has_value() + ? 
std::make_shared>(std::move(multimodalLengths.value())) // + : std::optional>>(std::nullopt), // + multimodalEmbedding, // + mropeRotaryCosSin, // + mropePositionDeltas, // + loraTaskId, // + loraWeights, // + loraConfig, // + lookaheadConfig, // + kvCacheRetentionConfig, // + returnLogProbs, // + returnContextLogits, // + returnGenerationLogits, // + draftTokens.has_value() ? std::make_shared(std::move(draftTokens.value())) // + : std::make_shared(), // + draftLogits, // + excludeInputFromOutput, // + logitsPostProcessor, // + applyLogitsPostProcessorBatched, // + encoderInputTokens ? std::make_optional(std::make_shared(std::move(*encoderInputTokens))) // + : std::optional>(std::nullopt), // + returnEncoderOutput, // + clientId, // + priority, // + encoderInputFeatures, // + encoderOutputLength, // + crossAttentionMask, // + llmRequestType, // + inputTokenExtraIds // + ? std::make_optional(std::make_shared(std::move(*inputTokenExtraIds))) // + : std::optional>(std::nullopt), // + numReturnSequences, // + eagleConfig, // + skipCrossAttnBlocks, // + returnPerfMetrics, // + guidedDecodingParams, // + languageAdapterUid, // + allottedTimeMs, // + contextPhaseParams // + ) + { + } + + static std::optional callbackAdapter( + std::optional callback); + + [[nodiscard]] std::shared_ptr toTrtLlm() const; +}; + +} // namespace tensorrt_llm::nanobind::batch_manager diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index adc82587433..dd01d21cced 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,14 +15,483 @@ * limitations under the License. 
*/ +#include "tensorrt_llm/nanobind/common/customCasters.h" #include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h" +#include "tensorrt_llm/common/quantization.h" +#include "tensorrt_llm/nanobind/batch_manager/algorithms.h" +#include "tensorrt_llm/nanobind/batch_manager/bindings.h" +#include "tensorrt_llm/nanobind/batch_manager/buffers.h" +#include "tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h" +#include "tensorrt_llm/nanobind/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/nanobind/batch_manager/llmRequest.h" +#include "tensorrt_llm/nanobind/executor/bindings.h" +#include "tensorrt_llm/nanobind/runtime/bindings.h" +#include "tensorrt_llm/nanobind/testing/modelSpecBinding.h" +#include "tensorrt_llm/nanobind/userbuffers/bindings.h" +#include "tensorrt_llm/runtime/common.h" +#include "tensorrt_llm/runtime/cudaStream.h" +#include "tensorrt_llm/runtime/gptJsonConfig.h" +#include "tensorrt_llm/runtime/ipcNvlsMemory.h" +#include "tensorrt_llm/runtime/memoryCounters.h" +#include "tensorrt_llm/runtime/samplingConfig.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" + +namespace nb = nanobind; +namespace tb = tensorrt_llm::batch_manager; +namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager; +namespace tpb = tensorrt_llm::nanobind::batch_manager; +namespace tc = tensorrt_llm::common; +namespace tr = tensorrt_llm::runtime; +namespace tle = tensorrt_llm::executor; +using SizeType32 = tr::SizeType32; +using TokenIdType = tr::TokenIdType; +template +using OptVec = std::optional>; #if not defined(TRTLLM_NB_MODULE) #error "TRTLLM_NB_MODULE must be defined" #endif +namespace +{ +tr::SamplingConfig makeSamplingConfig(std::vector const& configs) +{ + return tr::SamplingConfig(configs); +} +} // namespace + NB_MODULE(TRTLLM_NB_MODULE, m) { m.doc() = "TensorRT-LLM Python bindings for C++ runtime"; m.attr("binding_type") = "nanobind"; + nb::set_leak_warnings(false); + + // Create MpiComm binding first since it's used in the executor bindings + nb::class_(m, "MpiComm") + .def_static("rank", + []() + { + auto& session = tensorrt_llm::mpi::MpiComm::session(); + return session.tensorrt_llm::mpi::MpiComm::getRank(); + }) + .def_static("size", + []() + { + auto& session = tensorrt_llm::mpi::MpiComm::session(); + return session.tensorrt_llm::mpi::MpiComm::getSize(); + }) + .def_static("local_size", + []() + { + auto& session = tensorrt_llm::mpi::MpiComm::localSession(); + return session.tensorrt_llm::mpi::MpiComm::getSize(); + }) + .def_static("local_init", []() { tensorrt_llm::mpi::MpiComm::localSession(); }) + .def_static("set_raw_mpi_session_by_fortran_handle", + [](int64_t fortran_handle) { tensorrt_llm::mpi::MpiComm::setRawSessionByFortran(fortran_handle); }) + .def_static("split", + [](size_t color, size_t rank) + { + auto& world = tensorrt_llm::mpi::MpiComm::world(); + tensorrt_llm::mpi::MpiComm::setSession(world.split(color, rank)); + }); + + nb::class_(m, "CudaStream") + .def( + "__init__", + [](tr::CudaStream* self, nb::object py_stream) + { + cudaStream_t stream = reinterpret_cast(nb::cast(py_stream)); + new (self) tr::CudaStream{stream}; + }, + nb::arg("stream_ptr")) + .def("get_device", &tr::CudaStream::getDevice); + + // Create submodule for executor bindings. 
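+    // The nanobind module mirrors the pybind11 module layout so existing Python callers
+    // stay source-compatible. Hypothetical usage sketch (the import path depends on the
+    // TRTLLM_NB_MODULE build setting):
+    //
+    //   import tensorrt_llm.bindings as tllm  # assumed install location
+    //   assert tllm.binding_type == "nanobind"
+    //   print(tllm.MemoryCounters.instance().gpu)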
+ auto mExecutor = m.def_submodule("executor", "Executor bindings"); + auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime"); + auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings"); + auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings"); + auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings"); + + tensorrt_llm::nanobind::executor::initBindings(mExecutor); + tensorrt_llm::nanobind::runtime::initBindingsEarly(mInternalRuntime); + + auto buildInfo = m.def_submodule("BuildInfo"); + buildInfo.attr("ENABLE_MULTI_DEVICE") = nb::int_(ENABLE_MULTI_DEVICE); + + nb::class_(m, "PeftCacheManagerConfig") + .def(nb::init, std::optional, std::optional>(), + nb::arg("num_host_module_layer") = 0, nb::arg("num_device_module_layer") = 0, + nb::arg("optimal_adapter_size") = 8, nb::arg("max_adapter_size") = 64, nb::arg("num_put_workers") = 1, + nb::arg("num_ensure_workers") = 1, nb::arg("num_copy_streams") = 1, + nb::arg("max_pages_per_block_host") = 24, nb::arg("max_pages_per_block_device") = 8, + nb::arg("device_cache_percent") = std::nullopt, nb::arg("host_cache_size") = std::nullopt, + nb::arg("lora_prefetch_dir") = std::nullopt) + .def_rw("num_host_module_layer", &tb::PeftCacheManagerConfig::numHostModuleLayer) + .def_rw("num_device_module_layer", &tb::PeftCacheManagerConfig::numDeviceModuleLayer) + .def_rw("optimal_adapter_size", &tb::PeftCacheManagerConfig::optimalAdapterSize) + .def_rw("max_adapter_size", &tb::PeftCacheManagerConfig::maxAdapterSize) + .def_rw("num_put_workers", &tb::PeftCacheManagerConfig::numPutWorkers) + .def_rw("num_ensure_workers", &tb::PeftCacheManagerConfig::numEnsureWorkers) + .def_rw("num_copy_streams", &tb::PeftCacheManagerConfig::numCopyStreams) + .def_rw("max_pages_per_block_host", &tb::PeftCacheManagerConfig::maxPagesPerBlockHost) + .def_rw("max_pages_per_block_device", &tb::PeftCacheManagerConfig::maxPagesPerBlockDevice) + .def_rw("device_cache_percent", &tb::PeftCacheManagerConfig::deviceCachePercent) + .def_rw("host_cache_size", &tb::PeftCacheManagerConfig::hostCacheSize) + .def_rw("lora_prefetch_dir", &tb::PeftCacheManagerConfig::loraPrefetchDir); + + nb::enum_(m, "DataType") + .value("FLOAT", nvinfer1::DataType::kFLOAT) + .value("HALF", nvinfer1::DataType::kHALF) + .value("INT8", nvinfer1::DataType::kINT8) + .value("INT32", nvinfer1::DataType::kINT32) + .value("BOOL", nvinfer1::DataType::kBOOL) + .value("UINT8", nvinfer1::DataType::kUINT8) + .value("FP8", nvinfer1::DataType::kFP8) + .value("BF16", nvinfer1::DataType::kBF16) + .value("INT64", nvinfer1::DataType::kINT64) + .export_values(); + + nb::enum_(m, "GptModelVariant") + .value("GPT", tr::ModelConfig::ModelVariant::kGpt) + .value("GLM", tr::ModelConfig::ModelVariant::kGlm) + .value("CHATGLM", tr::ModelConfig::ModelVariant::kChatGlm) + .value("MAMBA", tr::ModelConfig::ModelVariant::kMamba) + .value("RECURRENTGEMMA", tr::ModelConfig::ModelVariant::kRecurrentGemma); + + nb::enum_(m, "KVCacheType") + .value("CONTINUOUS", tr::ModelConfig::KVCacheType::kCONTINUOUS) + .value("PAGED", tr::ModelConfig::KVCacheType::kPAGED) + .value("DISABLED", tr::ModelConfig::KVCacheType::kDISABLED) + .def("from_string", tr::ModelConfig::KVCacheTypeFromString); + + nb::enum_(m, "LayerType") + .value("ATTENTION", tr::ModelConfig::LayerType::kATTENTION) + .value("RECURRENT", tr::ModelConfig::LayerType::kRECURRENT); + + nb::enum_(m, "LoraModuleType") + .value("INVALID", 
tr::LoraModule::ModuleType::kINVALID) + .value("ATTN_QKV", tr::LoraModule::ModuleType::kATTN_QKV) + .value("ATTN_Q", tr::LoraModule::ModuleType::kATTN_Q) + .value("ATTN_K", tr::LoraModule::ModuleType::kATTN_K) + .value("ATTN_V", tr::LoraModule::ModuleType::kATTN_V) + .value("ATTN_DENSE", tr::LoraModule::ModuleType::kATTN_DENSE) + .value("MLP_H_TO_4H", tr::LoraModule::ModuleType::kMLP_H_TO_4H) + .value("MLP_4H_TO_H", tr::LoraModule::ModuleType::kMLP_4H_TO_H) + .value("MLP_GATE", tr::LoraModule::ModuleType::kMLP_GATE) + .value("CROSS_ATTN_QKV", tr::LoraModule::ModuleType::kCROSS_ATTN_QKV) + .value("CROSS_ATTN_Q", tr::LoraModule::ModuleType::kCROSS_ATTN_Q) + .value("CROSS_ATTN_K", tr::LoraModule::ModuleType::kCROSS_ATTN_K) + .value("CROSS_ATTN_V", tr::LoraModule::ModuleType::kCROSS_ATTN_V) + .value("CROSS_ATTN_DENSE", tr::LoraModule::ModuleType::kCROSS_ATTN_DENSE) + .value("MOE_H_TO_4H", tr::LoraModule::ModuleType::kMOE_H_TO_4H) + .value("MOE_4H_TO_H", tr::LoraModule::ModuleType::kMOE_4H_TO_H) + .value("MOE_GATE", tr::LoraModule::ModuleType::kMOE_GATE) + .value("MOE_ROUTER", tr::LoraModule::ModuleType::kMOE_ROUTER) + .value("MLP_ROUTER", tr::LoraModule::ModuleType::kMLP_ROUTER) + .value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP); + + nb::class_(m, "LoraModule") + .def(nb::init(), + nb::arg("module_type"), nb::arg("in_dim"), nb::arg("out_dim"), nb::arg("in_dim_first"), + nb::arg("out_dim_first"), nb::arg("in_tp_split_dim"), nb::arg("out_tp_split_dim")) + .def_prop_ro("module_type", &tr::LoraModule::name) + .def_prop_ro("in_dim", &tr::LoraModule::inDim) + .def_prop_ro("out_dim", &tr::LoraModule::outDim) + .def_prop_ro("in_dim_first", &tr::LoraModule::inDimFirst) + .def_prop_ro("out_dim_first", &tr::LoraModule::outDimFirst) + .def_prop_ro("in_tp_split_dim", &tr::LoraModule::inTpSplitDim) + .def_prop_ro("out_tp_split_dim", &tr::LoraModule::outTpSplitDim) + .def_static("create_lora_modules", &tr::LoraModule::createLoraModules, nb::arg("lora_module_names"), + nb::arg("hidden_size"), nb::arg("mlp_hidden_size"), nb::arg("num_attention_heads"), + nb::arg("num_kv_attention_heads"), nb::arg("attention_head_size"), nb::arg("tp_size") = 1, + nb::arg("num_experts") = 0); + + nb::class_(m, "QuantMode") + .def_static("none", &tc::QuantMode::none) + .def_static("int4_weights", &tc::QuantMode::int4Weights) + .def_static("int8_weights", &tc::QuantMode::int8Weights) + .def_static("activations", &tc::QuantMode::activations) + .def_static("per_channel_scaling", &tc::QuantMode::perChannelScaling) + .def_static("per_token_scaling", &tc::QuantMode::perTokenScaling) + .def_static("per_group_scaling", &tc::QuantMode::perGroupScaling) + .def_static("int8_kv_cache", &tc::QuantMode::int8KvCache) + .def_static("fp8_kv_cache", &tc::QuantMode::fp8KvCache) + .def_static("fp8_qdq", &tc::QuantMode::fp8Qdq) + .def_prop_ro("value", &tc::QuantMode::value) + .def("is_set", &tc::QuantMode::isSet, nb::arg("mode")) + .def_prop_ro("has_int4_weights", &tc::QuantMode::hasInt4Weights) + .def_prop_ro("has_int8_weights", &tc::QuantMode::hasInt8Weights) + .def_prop_ro("has_activations", &tc::QuantMode::hasActivations) + .def_prop_ro("has_per_channel_scaling", &tc::QuantMode::hasPerChannelScaling) + .def_prop_ro("has_per_token_scaling", &tc::QuantMode::hasPerTokenScaling) + .def_prop_ro("has_per_group_scaling", &tc::QuantMode::hasPerGroupScaling) + .def_prop_ro("has_static_activation_scaling", &tc::QuantMode::hasStaticActivationScaling) + .def_prop_ro("has_int8_kv_cache", &tc::QuantMode::hasInt8KvCache) + 
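+        // QuantMode is a composable bitmask; the arithmetic operators bound further down
+        // combine flags, e.g. from Python (hypothetical):
+        //   mode = QuantMode.int8_weights() + QuantMode.int8_kv_cache()
+        //   assert mode.has_int8_weights and mode.has_int8_kv_cache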
.def_prop_ro("has_fp8_kv_cache", &tc::QuantMode::hasFp8KvCache) + .def_prop_ro("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq) + .def_prop_ro("has_nvfp4", &tc::QuantMode::hasNvfp4) + .def_prop_ro("has_w4a8_mxfp4_fp8", &tc::QuantMode::hasW4a8Mxfp4Fp8) + .def_prop_ro("has_kv_cache_quant", &tc::QuantMode::hasKvCacheQuant) + .def_static("from_description", &tc::QuantMode::fromDescription, nb::arg("quantize_weights"), + nb::arg("quantize_activations"), nb::arg("per_token"), nb::arg("per_channel"), nb::arg("per_group"), + nb::arg("use_int4_weights"), nb::arg("use_int8_kv_cache"), nb::arg("use_fp8_kv_kache"), + nb::arg("use_fp8_qdq"), nb::arg("use_fp8_rowwise"), nb::arg("use_w4a8_qserve"), nb::arg("use_nvfp4"), + nb::arg("use_fp8_block_scales"), nb::arg("use_w4a8_mxfp4_fp8")) + .def_static("use_smooth_quant", &tc::QuantMode::useSmoothQuant, nb::arg("per_token") = false, + nb::arg("per_channel") = false) + .def_static("use_weight_only", &tc::QuantMode::useWeightOnly, nb::arg("use_int4_weights") = false, + nb::arg("per_group") = false) + .def_static("from_quant_algo", &tc::QuantMode::fromQuantAlgo, nb::arg("quant_algo") = nb::none(), + nb::arg("kv_cache_quant_algo") = nb::none()) + .def(nb::self + nb::self) + .def(nb::self += nb::self) + .def(nb::self - nb::self) + .def(nb::self -= nb::self) + .def(nb::self == nb::self) + .def(nb::self != nb::self); + + nb::class_(m, "ModelConfig") + .def(nb::init(), + nb::arg("vocab_size"), nb::arg("num_layers"), nb::arg("num_attention_layers"), nb::arg("num_rnn_layers"), + nb::arg("num_heads"), nb::arg("hidden_size"), nb::arg("data_type")) + .def_prop_ro("vocab_size", &tr::ModelConfig::getVocabSize) + .def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, nb::arg("world_size")) + .def("num_layers", &tr::ModelConfig::getNbLayers, nb::arg("pipeline_parallelism") = 1, + nb::arg("pipeline_parallelism_rank") = 0) + .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, nb::arg("pipeline_parallelism") = 1, + nb::arg("pipeline_parallelism_rank") = 0) + .def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, nb::arg("pipeline_parallelism") = 1, + nb::arg("pipeline_parallelism_rank") = 0) + .def("num_kv_heads", &tr::ModelConfig::getNbKvHeads, nb::arg("layer_idx")) + .def("set_num_kv_heads", &tr::ModelConfig::setNbKvHeads, nb::arg("num_kv_heads")) + .def_prop_ro("num_heads", &tr::ModelConfig::getNbHeads) + .def_prop_ro("hidden_size", &tr::ModelConfig::getHiddenSize) + .def_prop_ro("size_per_head", &tr::ModelConfig::getSizePerHead) + .def_prop_ro("data_type", &tr::ModelConfig::getDataType) + .def_prop_ro("speculative_decoding_mode", &tr::ModelConfig::getSpeculativeDecodingMode) + .def_prop_rw("head_size", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead) + .def_prop_rw( + "num_kv_heads_per_layer", &tr::ModelConfig::getNumKvHeadsPerLayer, &tr::ModelConfig::setNumKvHeadsPerLayer) + .def_prop_rw("use_gpt_attention_plugin", + nb::overload_cast<>(&tr::ModelConfig::useGptAttentionPlugin, nb::const_), + nb::overload_cast(&tr::ModelConfig::useGptAttentionPlugin)) + .def_prop_rw("use_packed_input", nb::overload_cast<>(&tr::ModelConfig::usePackedInput, nb::const_), + nb::overload_cast(&tr::ModelConfig::usePackedInput)) + .def_prop_rw("kv_cache_type", nb::overload_cast<>(&tr::ModelConfig::getKVCacheType, nb::const_), + nb::overload_cast(&tr::ModelConfig::setKVCacheType)) + .def_prop_rw("tokens_per_block", &tr::ModelConfig::getTokensPerBlock, &tr::ModelConfig::setTokensPerBlock) + .def_prop_rw("quant_mode", &tr::ModelConfig::getQuantMode, 
&tr::ModelConfig::setQuantMode) + .def_prop_ro("supports_inflight_batching", &tr::ModelConfig::supportsInflightBatching) + .def_prop_rw("max_batch_size", &tr::ModelConfig::getMaxBatchSize, &tr::ModelConfig::setMaxBatchSize) + .def_prop_rw("max_beam_width", &tr::ModelConfig::getMaxBeamWidth, &tr::ModelConfig::setMaxBeamWidth) + .def_prop_rw("max_input_len", &tr::ModelConfig::getMaxInputLen, &tr::ModelConfig::setMaxInputLen) + .def_prop_rw("max_seq_len", &tr::ModelConfig::getMaxSequenceLen, &tr::ModelConfig::setMaxSequenceLen) + .def_prop_rw("max_num_tokens", &tr::ModelConfig::getMaxNumTokens, &tr::ModelConfig::setMaxNumTokens) + .def_prop_rw("max_prompt_embedding_table_size", &tr::ModelConfig::getMaxPromptEmbeddingTableSize, + &tr::ModelConfig::setMaxPromptEmbeddingTableSize) + .def_prop_ro("use_prompt_tuning", &tr::ModelConfig::usePromptTuning) + .def_prop_ro("use_mrope", &tr::ModelConfig::useMrope) + .def_prop_rw("use_lora_plugin", nb::overload_cast<>(&tr::ModelConfig::useLoraPlugin, nb::const_), + nb::overload_cast(&tr::ModelConfig::useLoraPlugin)) + .def_prop_rw("layer_types", &tr::ModelConfig::getLayerTypes, &tr::ModelConfig::setLayerTypes) + .def_prop_rw("compute_context_logits", nb::overload_cast<>(&tr::ModelConfig::computeContextLogits, nb::const_), + nb::overload_cast(&tr::ModelConfig::computeContextLogits)) + .def_prop_rw("compute_generation_logits", + nb::overload_cast<>(&tr::ModelConfig::computeGenerationLogits, nb::const_), + nb::overload_cast(&tr::ModelConfig::computeGenerationLogits)) + .def_prop_rw("model_variant", &tr::ModelConfig::getModelVariant, &tr::ModelConfig::setModelVariant) + .def_prop_rw("use_cross_attention", &tr::ModelConfig::useCrossAttention, &tr::ModelConfig::setUseCrossAttention) + .def_prop_rw("lora_modules", &tr::ModelConfig::getLoraModules, &tr::ModelConfig::setLoraModules) + .def_prop_rw("max_lora_rank", &tr::ModelConfig::getMaxLoraRank, &tr::ModelConfig::setMaxLoraRank) + .def_prop_rw("mlp_hidden_size", &tr::ModelConfig::getMlpHiddenSize, &tr::ModelConfig::setMlpHiddenSize) + .def_prop_rw("size_per_head", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead); + + nb::class_(m, "WorldConfig") + .def(nb::init> const&, bool>(), + nb::arg("tensor_parallelism") = 1, nb::arg("pipeline_parallelism") = 1, nb::arg("context_parallelism") = 1, + nb::arg("rank") = 0, nb::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, + nb::arg("device_ids") = nb::none(), nb::arg("enable_attention_dp") = false) + .def_prop_ro("size", &tr::WorldConfig::getSize) + .def_prop_ro("tensor_parallelism", &tr::WorldConfig::getTensorParallelism) + .def_prop_ro("pipeline_parallelism", &tr::WorldConfig::getPipelineParallelism) + .def_prop_ro("context_parallelism", &tr::WorldConfig::getContextParallelism) + .def_prop_ro("is_tensor_parallel", &tr::WorldConfig::isTensorParallel) + .def_prop_ro("is_pipeline_parallel", &tr::WorldConfig::isPipelineParallel) + .def_prop_ro("is_context_parallel", &tr::WorldConfig::isContextParallel) + .def_prop_ro("rank", &tr::WorldConfig::getRank) + .def_prop_ro("local_rank", &tr::WorldConfig::getLocalRank) + .def_prop_ro("node_rank", &tr::WorldConfig::getNodeRank) + .def_prop_ro("gpus_per_node", &tr::WorldConfig::getGpusPerNode) + .def_prop_ro("gpus_per_group", &tr::WorldConfig::getGpusPerGroup) + .def_prop_ro("device", &tr::WorldConfig::getDevice) + .def_prop_ro("pipeline_parallel_rank", &tr::WorldConfig::getPipelineParallelRank) + .def_prop_ro("tensor_parallel_rank", &tr::WorldConfig::getTensorParallelRank) + 
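+        // The rank-derived properties above are usually not set by hand; the static "mpi"
+        // factory bound below reads them from the MPI session, so a launch typically only
+        // states the parallelism split (hypothetical):
+        //   wc = WorldConfig.mpi(gpus_per_node=8, tensor_parallelism=4)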
.def_prop_ro("context_parallel_rank", &tr::WorldConfig::getContextParallelRank) + .def_prop_ro("enable_attention_dp", &tr::WorldConfig::enableAttentionDP) + .def_static("mpi", + nb::overload_cast, std::optional, + std::optional, std::optional> const&, bool>(&tr::WorldConfig::mpi), + nb::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, nb::arg("tensor_parallelism") = nb::none(), + nb::arg("pipeline_parallelism") = nb::none(), nb::arg("context_parallelism") = nb::none(), + nb::arg("device_ids") = nb::none(), nb::arg("enable_attention_dp") = false); + + auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> nb::tuple + { + return nb::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty, + config.presencePenalty, config.frequencyPenalty, config.topK, config.topP, config.randomSeed, + config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, config.lengthPenalty, + config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP, + config.beamWidthArray); + }; + auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t) -> tr::SamplingConfig + { + assert(t.size() == 19); + + tr::SamplingConfig config; + config.beamWidth = nb::cast(t[0]); + config.temperature = nb::cast>(t[1]); + config.minLength = nb::cast>(t[2]); + config.repetitionPenalty = nb::cast>(t[3]); + config.presencePenalty = nb::cast>(t[4]); + config.frequencyPenalty = nb::cast>(t[5]); + config.topK = nb::cast>(t[6]); + config.topP = nb::cast>(t[7]); + config.randomSeed = nb::cast>(t[8]); + config.topPDecay = nb::cast>(t[9]); + config.topPMin = nb::cast>(t[10]); + config.topPResetIds = nb::cast>(t[11]); + config.beamSearchDiversityRate = nb::cast>(t[12]); + config.lengthPenalty = nb::cast>(t[13]); + config.earlyStopping = nb::cast>(t[14]); + config.noRepeatNgramSize = nb::cast>(t[15]); + config.numReturnSequences = nb::cast(t[16]); + config.minP = nb::cast>(t[17]); + config.beamWidthArray = nb::cast>>(t[18]); + + return config; + }; + + nb::class_(m, "SamplingConfig") + .def(nb::init(), nb::arg("beam_width") = 1) + .def(nb::init>(), + nb::arg("executor_sample_config"), nb::arg("external_draft_tokens_config") = std::nullopt) + .def_rw("beam_width", &tr::SamplingConfig::beamWidth) + .def_rw("temperature", &tr::SamplingConfig::temperature) + .def_rw("min_length", &tr::SamplingConfig::minLength) + .def_rw("repetition_penalty", &tr::SamplingConfig::repetitionPenalty) + .def_rw("presence_penalty", &tr::SamplingConfig::presencePenalty) + .def_rw("frequency_penalty", &tr::SamplingConfig::frequencyPenalty) + .def_rw("top_k", &tr::SamplingConfig::topK) + .def_rw("top_p", &tr::SamplingConfig::topP) + .def_rw("random_seed", &tr::SamplingConfig::randomSeed) + .def_rw("top_p_decay", &tr::SamplingConfig::topPDecay) + .def_rw("top_p_min", &tr::SamplingConfig::topPMin) + .def_rw("top_p_reset_ids", &tr::SamplingConfig::topPResetIds) + .def_rw("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate) + .def_rw("length_penalty", &tr::SamplingConfig::lengthPenalty) + .def_rw("early_stopping", &tr::SamplingConfig::earlyStopping) + .def_rw("no_repeat_ngram_size", &tr::SamplingConfig::noRepeatNgramSize) + .def_rw("num_return_sequences", &tr::SamplingConfig::numReturnSequences) + .def_rw("min_p", &tr::SamplingConfig::minP) + .def_rw("beam_width_array", &tr::SamplingConfig::beamWidthArray) + .def_rw("normalize_log_probs", &tr::SamplingConfig::normalizeLogProbs) + .def("__getstate__", SamplingConfigGetState) + 
.def("__setstate__", SamplingConfigSetState) + .def("__eq__", &tr::SamplingConfig::operator==); + + nb::bind_vector>(m, "SamplingConfigVector"); + + m.def("make_sampling_config", &makeSamplingConfig, nb::arg("configs")); + + nb::class_(m, "GptJsonConfig") + .def(nb::init>(), + nb::arg("name"), nb::arg("version"), nb::arg("precision"), nb::arg("tensor_parallelism"), + nb::arg("pipeline_parallelism"), nb::arg("context_parallelism"), nb::arg("gpus_per_node"), + nb::arg("model_config"), nb::arg("runtime_defaults") = nb::none()) + .def_static("parse", nb::overload_cast(&tr::GptJsonConfig::parse), nb::arg("json")) + .def_static( + "parse_file", nb::overload_cast(&tr::GptJsonConfig::parse), nb::arg("path")) + .def_prop_ro("model_config", &tr::GptJsonConfig::getModelConfig) + .def_prop_ro("name", &tr::GptJsonConfig::getName) + .def_prop_ro("version", &tr::GptJsonConfig::getVersion) + .def_prop_ro("precision", &tr::GptJsonConfig::getPrecision) + .def_prop_ro("tensor_parallelism", &tr::GptJsonConfig::getTensorParallelism) + .def_prop_ro("pipeline_parallelism", &tr::GptJsonConfig::getPipelineParallelism) + .def_prop_ro("context_parallelism", &tr::GptJsonConfig::getContextParallelism) + .def_prop_ro("gpus_per_node", &tr::GptJsonConfig::getGpusPerNode) + .def_prop_ro("world_size", &tr::GptJsonConfig::getWorldSize) + .def_prop_ro("runtime_defaults", &tr::GptJsonConfig::getRuntimeDefaults) + .def("engine_filename", + nb::overload_cast( + &tr::GptJsonConfig::engineFilename, nb::const_), + nb::arg("world_config"), nb::arg("model")) + .def("engine_filename", + nb::overload_cast(&tr::GptJsonConfig::engineFilename, nb::const_), + nb::arg("world_config")); + + nb::enum_(m, "LlmRequestState") + .value("UNKNOWN", tb::LlmRequestState::kUNKNOWN) + .value("ENCODER_INIT", tb::LlmRequestState::kENCODER_INIT) + .value("CONTEXT_INIT", tb::LlmRequestState::kCONTEXT_INIT) + .value("GENERATION_IN_PROGRESS", tb::LlmRequestState::kGENERATION_IN_PROGRESS) + .value("GENERATION_TO_COMPLETE", tb::LlmRequestState::kGENERATION_TO_COMPLETE) + .value("GENERATION_COMPLETE", tb::LlmRequestState::kGENERATION_COMPLETE) + .value("DISAGG_GENERATION_INIT", tb::LlmRequestState::kDISAGG_GENERATION_INIT) + .value("DISAGG_CONTEXT_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS) + .value("DISAGG_CONTEXT_COMPLETE", tb::LlmRequestState::kDISAGG_CONTEXT_COMPLETE) + .value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS) + .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE) + .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS); + + nb::class_(m, "MemoryCounters") + .def_static("instance", &tr::MemoryCounters::getInstance, nb::rv_policy::reference) + .def_prop_ro("gpu", &tr::MemoryCounters::getGpu) + .def_prop_ro("cpu", &tr::MemoryCounters::getCpu) + .def_prop_ro("pinned", &tr::MemoryCounters::getPinned) + .def_prop_ro("uvm", &tr::MemoryCounters::getUVM); + + tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime); + tensorrt_llm::nanobind::testing::initBindings(mInternalTesting); + tpb::initBindings(mInternalBatchManager); + tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager); + tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager); + tb::CacheTransceiverBindings::initBindings(mInternalBatchManager); + tpb::Buffers::initBindings(mInternalBatchManager); + + auto mInternalAlgorithms = mInternal.def_submodule("algorithms", "Algorithms 
internal bindings"); + tpb::algorithms::initBindings(mInternalAlgorithms); + + auto mUserbuffers = mInternal.def_submodule("userbuffers", "User buffers internal bindings"); + tensorrt_llm::kernels::userbuffers::UserBufferBindings::initBindings(mUserbuffers); + + // NVLS allocators + nb::class_(m, "IpcNvlsHandle") + .def(nb::init<>()) + .def_rw("uc_ptr", &tr::IpcNvlsHandle::uc_ptr) + .def_rw("mc_ptr", &tr::IpcNvlsHandle::mc_ptr) + .def_rw("size", &tr::IpcNvlsHandle::size) + .def("get_ipc_ptrs", + [](tr::IpcNvlsHandle& self) { return reinterpret_cast(self.ipc_uc_ptrs.data()); }); + + m.def("ipc_nvls_allocate", &tr::ipcNvlsAllocate, nb::rv_policy::reference); + m.def("ipc_nvls_free", &tr::ipcNvlsFree); + m.def("ipc_nvls_supported", &tr::ipcNvlsSupported); } diff --git a/cpp/tensorrt_llm/nanobind/common/bindTypes.h b/cpp/tensorrt_llm/nanobind/common/bindTypes.h new file mode 100644 index 00000000000..5cd714e458a --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/common/bindTypes.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace PybindUtils +{ + +namespace nb = nanobind; + +template +void bindList(nb::module_& m, std::string const& name) +{ + nb::class_(m, name.c_str()) + .def(nb::init<>()) + .def("push_back", [](T& lst, const typename T::value_type& value) { lst.push_back(value); }) + .def("pop_back", [](T& lst) { lst.pop_back(); }) + .def("push_front", [](T& lst, const typename T::value_type& value) { lst.push_front(value); }) + .def("pop_front", [](T& lst) { lst.pop_front(); }) + .def("__len__", [](T const& lst) { return lst.size(); }) + .def( + "__iter__", [](T& lst) { return nb::make_iterator(nb::type(), "iterator", lst.begin(), lst.end()); }, + nb::keep_alive<0, 1>()) + .def("__getitem__", + [](T const& lst, size_t index) + { + if (index >= lst.size()) + throw nb::index_error(); + auto it = lst.begin(); + std::advance(it, index); + return *it; + }) + .def("__setitem__", + [](T& lst, size_t index, const typename T::value_type& value) + { + if (index >= lst.size()) + throw nb::index_error(); + auto it = lst.begin(); + std::advance(it, index); + *it = value; + }); +} + +template +void bindSet(nb::module_& m, std::string const& name) +{ + nb::class_(m, name.c_str()) + .def(nb::init<>()) + .def("clear", &T::clear) + .def("size", &T::size) + .def("insert", [](T& s, typename T::value_type const& value) { s.insert(value); }) + .def("erase", nb::overload_cast(&T::erase)) + .def("__len__", [](T const& lst) { return lst.size(); }) + .def("__contains__", [](T const& s, typename T::value_type x) { return s.find(x) != s.end(); }) + .def( + "__iter__", [](T& s) { return nb::make_iterator(nb::type(), "iterator", s.begin(), s.end()); }, + nb::keep_alive<0, 1>()) + .def("__eq__", [](T const& s, T const& other) { return s == other; }) + .def("__getstate__", + [](T const& v) + { + /* Return a tuple 
that fully encodes the state of the object */ + return nb::make_tuple(std::vector(v.begin(), v.end())); + }) + .def("__setstate__", + [](T& v, nb::tuple const& t) + { + if (t.size() != 1) + throw std::runtime_error("Invalid state!"); + /* Create a new C++ instance */ + T s; + /* Assign any additional state */ + auto state_list = nb::cast>(t[0]); + for (auto& item : state_list) + { + s.insert(item); + } + return s; + }); +} + +} // namespace PybindUtils diff --git a/cpp/tensorrt_llm/nanobind/common/customCasters.h b/cpp/tensorrt_llm/nanobind/common/customCasters.h new file mode 100644 index 00000000000..7cfa07d249a --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/common/customCasters.h @@ -0,0 +1,345 @@ +/* + * Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/batch_manager/common.h" +#include "tensorrt_llm/batch_manager/decoderBuffers.h" +#include "tensorrt_llm/common/optionalRef.h" +#include "tensorrt_llm/runtime/cudaStream.h" +#include "tensorrt_llm/runtime/request.h" +#include "tensorrt_llm/runtime/samplingConfig.h" +#include "tensorrt_llm/runtime/torch.h" +#include "tensorrt_llm/runtime/torchView.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Pybind requires to have a central include in order for type casters to work. +// Opaque bindings add a type caster, so they have the same requirement. 
+// See the warning in https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html + +// Opaque bindings +NB_MAKE_OPAQUE(tensorrt_llm::batch_manager::ReqIdsSet) +NB_MAKE_OPAQUE(std::vector) +NB_MAKE_OPAQUE(std::vector) +NB_MAKE_OPAQUE(std::vector) +NB_MAKE_OPAQUE(std::vector>) + +namespace nb = nanobind; + +// Custom casters +namespace NB_NAMESPACE +{ + +namespace detail +{ + +template +struct type_caster> +{ + using Type = std::deque; + NB_TYPE_CASTER(Type, const_name("List")); + + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept + { + sequence seq(src, nanobind::detail::borrow_t{}); + value.clear(); + make_caster caster; + for (auto const& item : seq) + { + if (!caster.from_python(item, flags, cleanup)) + return false; + value.push_back(caster.operator T&()); + } + return true; + } + + static handle from_cpp(Type const& deque, rv_policy policy, cleanup_list* cleanup) noexcept + { + nb::list list; + + for (auto const& item : deque) + { + nb::object py_item = steal(make_caster::from_cpp(item, policy, cleanup)); + if (!py_item) + return {}; + list.append(py_item); + } + return list.release(); + } +}; + +template +struct type_caster> +{ + using value_conv = make_caster; + + NB_TYPE_CASTER(tensorrt_llm::common::OptionalRef, value_conv::Name); + + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) + { + if (src.is_none()) + { + // If the Python object is None, create an empty OptionalRef + value = tensorrt_llm::common::OptionalRef(); + return true; + } + + value_conv conv; + if (!conv.from_python(src, flags, cleanup)) + return false; + + // Create an OptionalRef with a reference to the converted value + value = tensorrt_llm::common::OptionalRef(conv); + return true; + } + + static handle from_cpp(tensorrt_llm::common::OptionalRef const& src, rv_policy policy, cleanup_list* cleanup) + { + if (!src.has_value()) + return none().release(); + + return value_conv::from_cpp(*src, policy, cleanup); + } +}; + +template +struct PathCaster +{ + +private: + static PyObject* unicode_from_fs_native(std::string const& w) + { + return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), ssize_t(w.size())); + } + + static PyObject* unicode_from_fs_native(std::wstring const& w) + { + return PyUnicode_FromWideChar(w.c_str(), ssize_t(w.size())); + } + +public: + static handle from_cpp(T const& path, rv_policy, cleanup_list* cleanup) + { + if (auto py_str = unicode_from_fs_native(path.native())) + { + return module_::import_("pathlib").attr("Path")(steal(py_str), cleanup).release(); + } + return nullptr; + } + + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) + { + PyObject* native = nullptr; + if constexpr (std::is_same_v) + { + if (PyUnicode_FSConverter(src.ptr(), &native) != 0) + { + if (auto* c_str = PyBytes_AsString(native)) + { + // AsString returns a pointer to the internal buffer, which + // must not be free'd. + value = c_str; + } + } + } + else if constexpr (std::is_same_v) + { + if (PyUnicode_FSDecoder(src.ptr(), &native) != 0) + { + if (auto* c_str = PyUnicode_AsWideCharString(native, nullptr)) + { + // AsWideCharString returns a new string that must be free'd. + value = c_str; // Copies the string. 
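+                    // The assignment above deep-copied the buffer into the std::wstring,
+                    // so the Python-allocated wide-char string can be freed right away.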
+ PyMem_Free(c_str); + } + } + } + Py_XDECREF(native); + if (PyErr_Occurred()) + { + PyErr_Clear(); + return false; + } + return true; + } + + NB_TYPE_CASTER(T, const_name("os.PathLike")); +}; + +template <> +class type_caster +{ +public: + NB_TYPE_CASTER(tensorrt_llm::executor::StreamPtr, const_name("int")); + + bool from_python([[maybe_unused]] handle src, uint8_t flags, cleanup_list* cleanup) + { + auto stream_ptr = nanobind::cast(src); + value = std::make_shared(reinterpret_cast(stream_ptr)); + + return true; + } + + static handle from_cpp( + tensorrt_llm::executor::StreamPtr const& src, rv_policy /* policy */, cleanup_list* /* cleanup */) + { + // Return cudaStream_t as integer. + return PyLong_FromVoidPtr(src->get()); + } +}; + +template <> +struct type_caster +{ +public: + NB_TYPE_CASTER(tensorrt_llm::executor::Tensor, const_name("torch.Tensor")); + + // Convert PyObject(torch.Tensor) -> tensorrt_llm::executor::Tensor + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) + { + PyObject* obj = src.ptr(); + if (THPVariable_Check(obj)) + { + at::Tensor const& t = THPVariable_Unpack(obj); + value = tensorrt_llm::executor::detail::ofITensor(tensorrt_llm::runtime::TorchView::of(t)); + return true; + } + return false; + } + + // Convert tensorrt_llm::executor::Tensor -> PyObject(torch.Tensor) + static handle from_cpp( + tensorrt_llm::executor::Tensor const& src, rv_policy /* policy */, cleanup_list* /* cleanup */) + { + return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(tensorrt_llm::executor::detail::toITensor(src))); + } +}; + +template <> +struct type_caster +{ +public: + NB_TYPE_CASTER(tensorrt_llm::runtime::ITensor::SharedPtr, const_name("torch.Tensor")); + + // Convert PyObject(torch.Tensor) -> tensorrt_llm::runtime::ITensor::SharedPtr + bool from_python(handle src, uint8_t, cleanup_list*) + { + PyObject* obj = src.ptr(); + if (THPVariable_Check(obj)) + { + at::Tensor const& t = THPVariable_Unpack(obj); + value = std::move(tensorrt_llm::runtime::TorchView::of(t)); + return true; + } + return false; + } + + // Convert tensorrt_llm::runtime::ITensor::SharedPtr -> PyObject(torch.Tensor) + static handle from_cpp( + tensorrt_llm::runtime::ITensor::SharedPtr const& src, rv_policy /* policy */, cleanup_list* /* cleanup */) + { + if (src == nullptr) + { + return none().release(); + } + return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(src)); + } +}; + +template <> +struct type_caster +{ +public: + NB_TYPE_CASTER(tensorrt_llm::runtime::ITensor::SharedConstPtr, const_name("torch.Tensor")); + + // Convert PyObject(torch.Tensor) -> tensorrt_llm::runtime::ITensor::SharedConstPtr + bool from_python(handle src, uint8_t, cleanup_list*) + { + PyObject* obj = src.ptr(); + if (THPVariable_Check(obj)) + { + at::Tensor const& t = THPVariable_Unpack(obj); + value = std::move(tensorrt_llm::runtime::TorchView::of(t)); + return true; + } + return false; + } + + // Convert tensorrt_llm::runtime::ITensor::SharedConstPtr -> PyObject(torch.Tensor) + static handle from_cpp( + tensorrt_llm::runtime::ITensor::SharedConstPtr const& src, rv_policy /* policy */, cleanup_list* /* cleanup */) + { + if (src == nullptr) + { + return none().release(); + } + return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor( + reinterpret_cast(src))); + } +}; + +template <> +struct type_caster +{ + NB_TYPE_CASTER(at::Tensor, const_name("torch.Tensor")); + + bool from_python(nb::handle src, uint8_t, cleanup_list*) noexcept + { + nb::object capsule = nb::getattr(src, "__dlpack__")(); + 
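+        // __dlpack__ hands back a "dltensor" capsule; the lines below take over the
+        // DLManagedTensor and disarm the capsule destructor so it is not freed twice.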
DLManagedTensor* dl_managed = static_cast(PyCapsule_GetPointer(capsule.ptr(), "dltensor")); + PyCapsule_SetDestructor(capsule.ptr(), nullptr); + value = at::fromDLPack(dl_managed).alias(); + return true; + } + + static handle from_cpp(at::Tensor tensor, rv_policy, cleanup_list*) noexcept + { + DLManagedTensor* dl_managed = at::toDLPack(tensor); + if (!dl_managed) + return nullptr; + + nanobind::object capsule = nb::steal(PyCapsule_New(dl_managed, "dltensor", + [](PyObject* obj) + { + DLManagedTensor* dl = static_cast(PyCapsule_GetPointer(obj, "dltensor")); + dl->deleter(dl); + })); + if (!capsule.is_valid()) + { + dl_managed->deleter(dl_managed); + return nullptr; + } + nanobind::module_ torch = nanobind::module_::import_("torch"); + nanobind::object result = torch.attr("from_dlpack")(capsule); + capsule.release(); + return result.release(); + } +}; +} // namespace detail +} // namespace NB_NAMESPACE diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp new file mode 100644 index 00000000000..d3f482df899 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp @@ -0,0 +1,263 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "bindings.h" +#include "executor.h" +#include "executorConfig.h" +#include "request.h" +#include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/types.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" + +#include +#include +#include +#include +#include +#include + +namespace nb = nanobind; +namespace tle = tensorrt_llm::executor; +using SizeType32 = tle::SizeType32; + +namespace tensorrt_llm::nanobind::executor +{ + +template +void instantiateEventDiff(nb::module_& m, std::string const& name) +{ + nb::class_>(m, ("KVCacheEventDiff" + name).c_str()) + .def_ro("old_value", &tle::KVCacheEventDiff::oldValue) + .def_ro("new_value", &tle::KVCacheEventDiff::newValue); +} + +void initBindings(nb::module_& m) +{ + m.attr("__version__") = tle::version(); + nb::enum_(m, "ModelType") + .value("DECODER_ONLY", tle::ModelType::kDECODER_ONLY) + .value("ENCODER_ONLY", tle::ModelType::kENCODER_ONLY) + .value("ENCODER_DECODER", tle::ModelType::kENCODER_DECODER); + + auto decodingModeGetstate = [](tle::DecodingMode const& self) { return nb::make_tuple(self.getState()); }; + auto decodingModeSetstate = [](tle::DecodingMode& self, nb::tuple const& state) + { + if (state.size() != 1) + { + throw std::runtime_error("Invalid state!"); + } + new (&self) tle::DecodingMode(nb::cast(state[0])); + }; + nb::class_(m, "DecodingMode") + .def("Auto", &tle::DecodingMode::Auto) + .def("TopK", &tle::DecodingMode::TopK) + .def("TopP", &tle::DecodingMode::TopP) + .def("TopKTopP", &tle::DecodingMode::TopKTopP) + .def("BeamSearch", &tle::DecodingMode::BeamSearch) + .def("Medusa", &tle::DecodingMode::Medusa) + .def("Lookahead", &tle::DecodingMode::Lookahead) + .def("ExplicitDraftTokens", &tle::DecodingMode::ExplicitDraftTokens) + .def("Eagle", &tle::DecodingMode::Eagle) + .def("isAuto", &tle::DecodingMode::isAuto) + .def("isTopK", &tle::DecodingMode::isTopK) + .def("isTopP", &tle::DecodingMode::isTopP) + .def("isTopKorTopP", &tle::DecodingMode::isTopKorTopP) + .def("isTopKandTopP", &tle::DecodingMode::isTopKandTopP) + .def("isBeamSearch", &tle::DecodingMode::isBeamSearch) + .def("isMedusa", &tle::DecodingMode::isMedusa) + .def("isLookahead", &tle::DecodingMode::isLookahead) + .def("isExplicitDraftTokens", &tle::DecodingMode::isExplicitDraftTokens) + .def("isEagle", &tle::DecodingMode::isEagle) + .def("useVariableBeamWidthSearch", &tle::DecodingMode::useVariableBeamWidthSearch) + .def_prop_ro("name", &tle::DecodingMode::getName) + .def("__getstate__", decodingModeGetstate) + .def("__setstate__", decodingModeSetstate); + + nb::enum_(m, "CapacitySchedulerPolicy") + .value("MAX_UTILIZATION", tle::CapacitySchedulerPolicy::kMAX_UTILIZATION) + .value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT) + .value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH); + + nb::enum_(m, "ContextChunkingPolicy") + .value("EQUAL_PROGRESS", tle::ContextChunkingPolicy::kEQUAL_PROGRESS) + .value("FIRST_COME_FIRST_SERVED", tle::ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED); + + nb::enum_(m, "CommunicationType").value("MPI", tle::CommunicationType::kMPI); + + nb::enum_(m, "CommunicationMode") + .value("LEADER", tle::CommunicationMode::kLEADER) + .value("ORCHESTRATOR", tle::CommunicationMode::kORCHESTRATOR); + + nb::class_(m, "KvCacheStats") + .def(nb::init<>()) + .def_rw("max_num_blocks", &tle::KvCacheStats::maxNumBlocks) + .def_rw("free_num_blocks", &tle::KvCacheStats::freeNumBlocks) + .def_rw("used_num_blocks", &tle::KvCacheStats::usedNumBlocks) + 
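+        // The reuse counters below feed the hit rate (roughly reused / (reused + missed)
+        // blocks). Hypothetical monitoring snippet, assuming iteration stats are enabled:
+        //   for it in executor.get_latest_iteration_stats():
+        //       if it.kv_cache_stats:
+        //           print(it.kv_cache_stats.cache_hit_rate)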
.def_rw("tokens_per_block", &tle::KvCacheStats::tokensPerBlock) + .def_rw("alloc_total_blocks", &tle::KvCacheStats::allocTotalBlocks) + .def_rw("alloc_new_blocks", &tle::KvCacheStats::allocNewBlocks) + .def_rw("reused_blocks", &tle::KvCacheStats::reusedBlocks) + .def_rw("missed_blocks", &tle::KvCacheStats::missedBlocks) + .def_rw("cache_hit_rate", &tle::KvCacheStats::cacheHitRate); + + nb::class_(m, "StaticBatchingStats") + .def(nb::init<>()) + .def_rw("num_scheduled_requests", &tle::StaticBatchingStats::numScheduledRequests) + .def_rw("num_context_requests", &tle::StaticBatchingStats::numContextRequests) + .def_rw("num_ctx_tokens", &tle::StaticBatchingStats::numCtxTokens) + .def_rw("num_gen_tokens", &tle::StaticBatchingStats::numGenTokens) + .def_rw("empty_gen_slots", &tle::StaticBatchingStats::emptyGenSlots); + + nb::class_(m, "InflightBatchingStats") + .def(nb::init<>()) + .def_rw("num_scheduled_requests", &tle::InflightBatchingStats::numScheduledRequests) + .def_rw("num_context_requests", &tle::InflightBatchingStats::numContextRequests) + .def_rw("num_gen_requests", &tle::InflightBatchingStats::numGenRequests) + .def_rw("num_paused_requests", &tle::InflightBatchingStats::numPausedRequests) + .def_rw("num_ctx_tokens", &tle::InflightBatchingStats::numCtxTokens) + .def_rw("micro_batch_id", &tle::InflightBatchingStats::microBatchId) + .def_rw("avg_num_decoded_tokens_per_iter", &tle::InflightBatchingStats::avgNumDecodedTokensPerIter); + + nb::class_(m, "SpecDecodingStats") + .def(nb::init<>()) + .def_rw("num_draft_tokens", &tle::SpecDecodingStats::numDraftTokens) + .def_rw("num_accepted_tokens", &tle::SpecDecodingStats::numAcceptedTokens) + .def_rw("num_requests_with_draft_tokens", &tle::SpecDecodingStats::numRequestsWithDraftTokens) + .def_rw("acceptance_length", &tle::SpecDecodingStats::acceptanceLength) + .def_rw("iter_latency_ms", &tle::SpecDecodingStats::iterLatencyMS) + .def_rw("draft_overhead", &tle::SpecDecodingStats::draftOverhead); + + nb::class_(m, "IterationStats") + .def(nb::init<>()) + .def_rw("timestamp", &tle::IterationStats::timestamp) + .def_rw("iter", &tle::IterationStats::iter) + .def_rw("iter_latency_ms", &tle::IterationStats::iterLatencyMS) + .def_rw("new_active_requests_queue_latency_ms", &tle::IterationStats::newActiveRequestsQueueLatencyMS) + .def_rw("num_new_active_requests", &tle::IterationStats::numNewActiveRequests) + .def_rw("num_active_requests", &tle::IterationStats::numActiveRequests) + .def_rw("num_queued_requests", &tle::IterationStats::numQueuedRequests) + .def_rw("num_completed_requests", &tle::IterationStats::numCompletedRequests) + .def_rw("max_num_active_requests", &tle::IterationStats::maxNumActiveRequests) + .def_rw("gpu_mem_usage", &tle::IterationStats::gpuMemUsage) + .def_rw("cpu_mem_usage", &tle::IterationStats::cpuMemUsage) + .def_rw("pinned_mem_usage", &tle::IterationStats::pinnedMemUsage) + .def_rw("kv_cache_stats", &tle::IterationStats::kvCacheStats) + .def_rw("cross_kv_cache_stats", &tle::IterationStats::crossKvCacheStats) + .def_rw("static_batching_stats", &tle::IterationStats::staticBatchingStats) + .def_rw("inflight_batching_stats", &tle::IterationStats::inflightBatchingStats) + .def_rw("specdec_stats", &tle::IterationStats::specDecodingStats) + .def("to_json_str", + [](tle::IterationStats const& iterationStats) + { return tle::JsonSerialization::toJsonStr(iterationStats); }); + + nb::class_(m, "DebugTensorsPerIteration") + .def(nb::init<>()) + .def_rw("iter", &tle::DebugTensorsPerIteration::iter) + .def_rw("debug_tensors", 
&tle::DebugTensorsPerIteration::debugTensors); + + nb::enum_(m, "RequestStage") + .value("QUEUED", tle::RequestStage::kQUEUED) + .value("ENCODER_IN_PROGRESS", tle::RequestStage::kENCODER_IN_PROGRESS) + .value("CONTEXT_IN_PROGRESS", tle::RequestStage::kCONTEXT_IN_PROGRESS) + .value("GENERATION_IN_PROGRESS", tle::RequestStage::kGENERATION_IN_PROGRESS) + .value("GENERATION_COMPLETE", tle::RequestStage::kGENERATION_COMPLETE); + + nb::class_(m, "DisServingRequestStats") + .def(nb::init<>()) + .def_rw("kv_cache_transfer_ms", &tle::DisServingRequestStats::kvCacheTransferMS) + .def_rw("kv_cache_size", &tle::DisServingRequestStats::kvCacheSize); + + nb::class_(m, "RequestStats") + .def(nb::init<>()) + .def_rw("id", &tle::RequestStats::id) + .def_rw("stage", &tle::RequestStats::stage) + .def_rw("context_prefill_position", &tle::RequestStats::contextPrefillPosition) + .def_rw("num_generated_tokens", &tle::RequestStats::numGeneratedTokens) + .def_rw("avg_num_decoded_tokens_per_iter", &tle::RequestStats::avgNumDecodedTokensPerIter) + .def_rw("scheduled", &tle::RequestStats::scheduled) + .def_rw("paused", &tle::RequestStats::paused) + .def_rw("dis_serving_stats", &tle::RequestStats::disServingStats) + .def_rw("alloc_total_blocks_per_request", &tle::RequestStats::allocTotalBlocksPerRequest) + .def_rw("alloc_new_blocks_per_request", &tle::RequestStats::allocNewBlocksPerRequest) + .def_rw("reused_blocks_per_request", &tle::RequestStats::reusedBlocksPerRequest) + .def_rw("missed_blocks_per_request", &tle::RequestStats::missedBlocksPerRequest) + .def_rw("kv_cache_hit_rate_per_request", &tle::RequestStats::kvCacheHitRatePerRequest) + .def("to_json_str", + [](tle::RequestStats const& iterationStats) { return tle::JsonSerialization::toJsonStr(iterationStats); }); + + nb::class_(m, "RequestStatsPerIteration") + .def(nb::init<>()) + .def_rw("iter", &tle::RequestStatsPerIteration::iter) + .def_rw("request_stats", &tle::RequestStatsPerIteration::requestStats) + .def("to_json_str", + [](tle::RequestStatsPerIteration const& iterationStats) + { return tle::JsonSerialization::toJsonStr(iterationStats); }); + + nb::module_ executor_kv_cache = m.def_submodule("kv_cache", "Executor KV Cache Manager"); + + nb::class_(executor_kv_cache, "KVCacheCreatedData") + .def_ro("num_blocks_per_cache_level", &tle::KVCacheCreatedData::numBlocksPerCacheLevel); + + nb::class_(executor_kv_cache, "UniqueToken") + .def_ro("token_id", &tensorrt_llm::runtime::UniqueToken::tokenId) + .def_ro("token_extra_id", &tensorrt_llm::runtime::UniqueToken::tokenExtraId); + + nb::class_(executor_kv_cache, "KVCacheStoredBlockData") + .def_ro("block_hash", &tle::KVCacheStoredBlockData::blockHash) + .def_ro("tokens", &tle::KVCacheStoredBlockData::tokens) + .def_ro("lora_id", &tle::KVCacheStoredBlockData::loraId) + .def_ro("cache_level", &tle::KVCacheStoredBlockData::cacheLevel) + .def_ro("priority", &tle::KVCacheStoredBlockData::priority); + + nb::class_(executor_kv_cache, "KVCacheStoredData") + .def_ro("parent_hash", &tle::KVCacheStoredData::parentHash) + .def_ro("blocks", &tle::KVCacheStoredData::blocks); + + nb::class_(executor_kv_cache, "KVCacheRemovedData") + .def_ro("block_hashes", &tle::KVCacheRemovedData::blockHashes); + + instantiateEventDiff(executor_kv_cache, "Int"); + + nb::class_(executor_kv_cache, "KVCacheUpdatedData") + .def_ro("block_hash", &tle::KVCacheUpdatedData::blockHash) + .def_ro("cache_level", &tle::KVCacheUpdatedData::cacheLevel) + .def_ro("priority", &tle::KVCacheUpdatedData::priority); + + nb::class_(executor_kv_cache, 
"KVCacheEvent") + .def_ro("event_id", &tle::KVCacheEvent::eventId) + .def_ro("data", &tle::KVCacheEvent::data) + .def_ro("window_size", &tle::KVCacheEvent::windowSize); + + nb::class_(executor_kv_cache, "KVCacheEventManager") + .def( + "get_latest_events", + [](tle::KVCacheEventManager& self, std::optional timeout_ms = std::nullopt) + { + if (timeout_ms) + { + return self.getLatestEvents(std::chrono::milliseconds(static_cast(*timeout_ms))); + } + return self.getLatestEvents(std::nullopt); + }, + nb::arg("timeout_ms") = std::nullopt); + + tensorrt_llm::nanobind::executor::initRequestBindings(m); + tensorrt_llm::nanobind::executor::initConfigBindings(m); + tensorrt_llm::nanobind::executor::Executor::initBindings(m); +} + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.h b/cpp/tensorrt_llm/nanobind/executor/bindings.h new file mode 100644 index 00000000000..4df52c2d34e --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/bindings.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::executor +{ + +// Register bindings for executor API. +void initBindings(nb::module_& m); + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/executor/executor.cpp b/cpp/tensorrt_llm/nanobind/executor/executor.cpp new file mode 100644 index 00000000000..59c7d2a3dc1 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/executor.cpp @@ -0,0 +1,241 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+
+    tensorrt_llm::nanobind::executor::initRequestBindings(m);
+    tensorrt_llm::nanobind::executor::initConfigBindings(m);
+    tensorrt_llm::nanobind::executor::Executor::initBindings(m);
+}
+
+} // namespace tensorrt_llm::nanobind::executor
diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.h b/cpp/tensorrt_llm/nanobind/executor/bindings.h
new file mode 100644
index 00000000000..4df52c2d34e
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/executor/bindings.h
@@ -0,0 +1,29 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <nanobind/nanobind.h>
+namespace nb = nanobind;
+
+namespace tensorrt_llm::nanobind::executor
+{
+
+// Register bindings for executor API.
+void initBindings(nb::module_& m);
+
+} // namespace tensorrt_llm::nanobind::executor
diff --git a/cpp/tensorrt_llm/nanobind/executor/executor.cpp b/cpp/tensorrt_llm/nanobind/executor/executor.cpp
new file mode 100644
index 00000000000..59c7d2a3dc1
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/executor/executor.cpp
@@ -0,0 +1,241 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "executor.h"
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/logger.h"
+#include "tensorrt_llm/executor/tensor.h"
+#include "tensorrt_llm/nanobind/common/customCasters.h"
+
+#include <cuda_fp16.h>
+#include <nanobind/nanobind.h>
+#include <nanobind/ndarray.h>
+#include <nanobind/stl/chrono.h>
+#include <nanobind/stl/filesystem.h>
+#include <nanobind/stl/map.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/shared_ptr.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/vector.h>
+
+namespace nb = nanobind;
+namespace tle = tensorrt_llm::executor;
+
+namespace nanobind::detail
+{
+
+template <>
+struct dtype_traits<half>
+{
+    static constexpr dlpack::dtype value{
+        (uint8_t) dlpack::dtype_code::Float, // type code
+        16,                                  // size in bits
+        1                                    // lanes (simd), usually set to 1
+    };
+    static constexpr auto name = const_name("float16");
+};
+} // namespace nanobind::detail
+
+namespace
+{
+// todo: Properly support FP8 and BF16 and verify functionality
+tle::Tensor numpyToTensor(nb::ndarray<nb::numpy> const& array)
+{
+    auto npDtype = array.dtype();
+    char kind = '\0';
+    switch (npDtype.code)
+    {
+    case static_cast<uint8_t>(nb::dlpack::dtype_code::Int):
+        kind = 'i'; // signed integer
+        break;
+    case static_cast<uint8_t>(nb::dlpack::dtype_code::UInt):
+        kind = 'u'; // unsigned integer
+        break;
+    case static_cast<uint8_t>(nb::dlpack::dtype_code::Float):
+        kind = 'f'; // floating point
+        break;
+    case static_cast<uint8_t>(nb::dlpack::dtype_code::Bfloat):
+        kind = 'f'; // brain floating point (treat as float kind)
+        break;
+    case static_cast<uint8_t>(nb::dlpack::dtype_code::Complex):
+        kind = 'c'; // complex
+        break;
+    default:
+        kind = 'V'; // void/other
+        break;
+    }
+    tle::DataType dtype;
+    if (npDtype == nb::dtype<half>())
+    {
+        dtype = tle::DataType::kFP16;
+    }
+    else if (npDtype == nb::dtype<float>())
+    {
+        dtype = tle::DataType::kFP32;
+    }
+    else if (npDtype == nb::dtype<std::int8_t>())
+    {
+        dtype = tle::DataType::kINT8;
+    }
+    else if (npDtype == nb::dtype<std::int32_t>())
+    {
+        dtype = tle::DataType::kINT32;
+    }
+    else if (npDtype == nb::dtype<std::int64_t>())
+    {
+        dtype = tle::DataType::kINT64;
+    }
+    else if (kind == 'V' && array.itemsize() == 1)
+    {
+        dtype = tle::DataType::kFP8;
+    }
+    else if (kind == 'V' && array.itemsize() == 2)
+    {
+        dtype = tle::DataType::kBF16;
+    }
+    else
+    {
+        TLLM_THROW("Unsupported numpy dtype.");
+    }
+
+    // todo: improve the following code
+    std::vector<tle::Shape::DimType64> dims;
+    dims.reserve(array.ndim());
+    for (size_t i = 0; i < array.ndim(); ++i)
+    {
+        dims.push_back(static_cast<tle::Shape::DimType64>(array.shape(i)));
+    }
+    tle::Shape shape(dims.data(), dims.size());
+
+    return tle::Tensor::of(dtype, const_cast<void*>(array.data()), shape);
+}
+
+} // namespace
+
+namespace tensorrt_llm::nanobind::executor
+{
+
+Executor::Executor(
+    std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig)
+{
+    mExecutor = std::make_unique<tle::Executor>(modelPath, modelType, executorConfig);
+}
+
+Executor::Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath,
+    tle::ModelType modelType, tle::ExecutorConfig const& executorConfig)
+{
+    mExecutor = std::make_unique<tle::Executor>(encoderModelPath, decoderModelPath, modelType, executorConfig);
+}
+
+Executor::Executor(nb::bytes const& engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType,
+    tle::ExecutorConfig const& executorConfig, std::optional<nb::dict> managedWeights)
+{
+    uint8_t const* data = static_cast<uint8_t const*>(engineBuffer.data());
+    size_t size = engineBuffer.size();
+    std::optional<std::map<std::string, tle::Tensor>> managedWeightsMap = std::nullopt;
+    if (managedWeights.has_value() && !managedWeights.value().empty())
+    {
+        managedWeightsMap = std::map<std::string, tle::Tensor>();
+        for (auto const& [rawName, rawArray] : managedWeights.value())
+        {
+            std::string name = nb::cast<std::string>(rawName);
+            nb::ndarray<nb::numpy> array = nb::cast<nb::ndarray<nb::numpy>>(rawArray);
+            managedWeightsMap->emplace(name, numpyToTensor(array));
+        }
+    }
+    mExecutor = std::make_unique<tle::Executor>(
+        tle::BufferView(data, size), jsonConfigStr, modelType, executorConfig, managedWeightsMap);
+}
+
+Executor::Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr,
+    std::string const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, tle::ModelType modelType,
+    tle::ExecutorConfig const& executorConfig)
+{
+    uint8_t const* encoderData = reinterpret_cast<uint8_t const*>(encoderEngineBuffer.data());
+    size_t encoderSize = encoderEngineBuffer.size();
+    uint8_t const* decoderData = reinterpret_cast<uint8_t const*>(decoderEngineBuffer.data());
+    size_t decoderSize = decoderEngineBuffer.size();
+    mExecutor = std::make_unique<tle::Executor>(tle::BufferView(encoderData, encoderSize), encoderJsonConfigStr,
+        tle::BufferView(decoderData, decoderSize), decoderJsonConfigStr, modelType, executorConfig);
+}
+
+nb::object Executor::enter()
+{
+    TLLM_CHECK(static_cast<bool>(mExecutor));
+    return nb::cast(this);
+}
+
+void Executor::exit(
+    [[maybe_unused]] nb::handle type, [[maybe_unused]] nb::handle value, [[maybe_unused]] nb::handle traceback)
+{
+    shutdown();
+    mExecutor = nullptr;
+}
+
+void Executor::shutdown()
+{
+    // NOTE: we must release the GIL here. Executor has spawned a thread for the execution loop. That thread must be
+    // able to do forward progress for the shutdown process to succeed. It takes the GIL during its callbacks, so
+    // we release it now. Note that we shouldn't do anything related to python objects after that.
+    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+    nb::gil_scoped_release release;
+    mExecutor->shutdown();
+    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+}
+
+void Executor::initBindings(nb::module_& m)
+{
+    nb::class_<Executor>(m, "Executor")
+        .def(nb::init<std::filesystem::path const&, tle::ModelType, tle::ExecutorConfig const&>(),
+            nb::arg("model_path"), nb::arg("model_type"), nb::arg("executor_config"))
+        .def(nb::init<std::filesystem::path const&, std::filesystem::path const&, tle::ModelType,
+                 tle::ExecutorConfig const&>(),
+            nb::arg("encoder_model_path"), nb::arg("decoder_model_path"), nb::arg("model_type"),
+            nb::arg("executor_config"))
+        .def(nb::init<nb::bytes const&, std::string const&, tle::ModelType, tle::ExecutorConfig const&,
+                 std::optional<nb::dict>>(),
+            nb::arg("engine_buffer"), nb::arg("json_config_str"), nb::arg("model_type"), nb::arg("executor_config"),
+            nb::arg("managed_weights") = nb::dict())
+        .def(nb::init<std::string const&, std::string const&, std::string const&, std::string const&, tle::ModelType,
+                 tle::ExecutorConfig const&>(),
+            nb::arg("encoder_engine_buffer"), nb::arg("encoder_json_config_str"), nb::arg("decoder_engine_buffer"),
+            nb::arg("decoder_json_config_str"), nb::arg("model_type"), nb::arg("executor_config"))
+        .def("shutdown", &Executor::shutdown)
+        .def("__enter__", &Executor::enter)
+        .def("__exit__", &Executor::exit)
+        .def("enqueue_request", &Executor::enqueueRequest, nb::arg("request"))
+        .def("enqueue_requests", &Executor::enqueueRequests, nb::arg("requests"))
+        .def("await_responses",
+            nb::overload_cast<std::optional<std::chrono::milliseconds> const&>(&Executor::awaitResponses),
+            nb::arg("timeout") = nb::none())
+        .def("await_responses",
+            nb::overload_cast<tle::IdType const&, std::optional<std::chrono::milliseconds> const&>(
+                &Executor::awaitResponses),
+            nb::arg("id"), nb::arg("timeout") = nb::none())
+        .def("await_responses",
+            nb::overload_cast<std::vector<tle::IdType> const&, std::optional<std::chrono::milliseconds> const&>(
+                &Executor::awaitResponses),
+            nb::arg("ids"), nb::arg("timeout") = nb::none())
+        .def("get_num_responses_ready", &Executor::getNumResponsesReady, nb::arg("id") = nb::none())
+        .def("cancel_request", &Executor::cancelRequest, nb::arg("id") = nb::none())
+        .def("get_latest_iteration_stats", &Executor::getLatestIterationStats)
+        .def("get_latest_request_stats", &Executor::getLatestRequestStats)
+        .def("get_latest_debug_tensors", &Executor::getLatestDebugTensors)
+        .def("can_enqueue_requests", &Executor::canEnqueueRequests)
+        .def("get_kv_cache_event_manager", &Executor::getKVCacheEventManager);
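+
+    // Usage sketch from Python (hypothetical engine path and module alias; Request and
+    // ModelType are bound elsewhere in this module). __enter__/__exit__ make the executor
+    // a context manager, and await_responses releases the GIL while it blocks:
+    //
+    //     import tensorrt_llm.bindings.executor as trtllm
+    //     config = trtllm.ExecutorConfig(max_beam_width=1)
+    //     with trtllm.Executor("/path/to/engine_dir", trtllm.ModelType.DECODER_ONLY, config) as executor:
+    //         request_id = executor.enqueue_request(trtllm.Request(input_token_ids=[1, 2, 3], max_tokens=16))
+    //         responses = executor.await_responses(request_id)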
+}
+
+} // namespace tensorrt_llm::nanobind::executor
diff --git a/cpp/tensorrt_llm/nanobind/executor/executor.h b/cpp/tensorrt_llm/nanobind/executor/executor.h
new file mode 100644
index 00000000000..22c24abb4bf
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/executor/executor.h
@@ -0,0 +1,129 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tensorrt_llm/executor/executor.h"
+#include "tensorrt_llm/executor/types.h"
+#include <nanobind/nanobind.h>
+
+namespace nb = nanobind;
+namespace tle = tensorrt_llm::executor;
+
+namespace tensorrt_llm::nanobind::executor
+{
+
+class Executor
+{
+public:
+    Executor(
+        std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig);
+
+    Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath,
+        tle::ModelType modelType, tle::ExecutorConfig const& executorConfig);
+
+    Executor(nb::bytes const& engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType,
+        tle::ExecutorConfig const& executorConfig, std::optional<nb::dict> managedWeights);
+
+    Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr,
+        std::string const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, tle::ModelType modelType,
+        tle::ExecutorConfig const& executorConfig);
+
+    nb::object enter();
+    void exit(
+        [[maybe_unused]] nb::handle type, [[maybe_unused]] nb::handle value, [[maybe_unused]] nb::handle traceback);
+    void shutdown();
+
+    [[nodiscard]] tle::IdType enqueueRequest(tle::Request const& request)
+    {
+        return mExecutor->enqueueRequest(request);
+    }
+
+    [[nodiscard]] std::vector<tle::IdType> enqueueRequests(std::vector<tle::Request> const& requests)
+    {
+        return mExecutor->enqueueRequests(requests);
+    }
+
+    [[nodiscard]] std::vector<tle::Response> awaitResponses(
+        std::optional<std::chrono::milliseconds> const& timeout = std::nullopt)
+    {
+        // Await responses blocks until a response is received. Release GIL so that it can be run in a background
+        // thread.
+        nb::gil_scoped_release release;
+        return mExecutor->awaitResponses(timeout);
+    }
+
+    [[nodiscard]] std::vector<tle::Response> awaitResponses(
+        tle::IdType const& requestId, std::optional<std::chrono::milliseconds> const& timeout = std::nullopt)
+    {
+        // Await responses blocks until a response is received. Release GIL so that it can be run in a background
+        // thread.
+        nb::gil_scoped_release release;
+        return mExecutor->awaitResponses(requestId, timeout);
+    }
+
+    [[nodiscard]] std::vector<std::vector<tle::Response>> awaitResponses(std::vector<tle::IdType> const& requestIds,
+        std::optional<std::chrono::milliseconds> const& timeout = std::nullopt)
+    {
+        // Await responses blocks until a response is received. Release GIL so that it can be run in a background
+        // thread.
+        nb::gil_scoped_release release;
+        return mExecutor->awaitResponses(requestIds, timeout);
+    }
+
+    [[nodiscard]] tle::SizeType32 getNumResponsesReady(
+        std::optional<tle::IdType> const& requestId = std::nullopt) const
+    {
+        return mExecutor->getNumResponsesReady(requestId);
+    }
+
+    void cancelRequest(tle::IdType requestId)
+    {
+        mExecutor->cancelRequest(requestId);
+    }
+
+    std::deque<tle::IterationStats> getLatestIterationStats()
+    {
+        return mExecutor->getLatestIterationStats();
+    }
+
+    std::deque<tle::RequestStatsPerIteration> getLatestRequestStats()
+    {
+        return mExecutor->getLatestRequestStats();
+    }
+
+    std::deque<tle::DebugTensorsPerIteration> getLatestDebugTensors()
+    {
+        return mExecutor->getLatestDebugTensors();
+    }
+
+    [[nodiscard]] bool canEnqueueRequests() const
+    {
+        return mExecutor->canEnqueueRequests();
+    }
+
+    [[nodiscard]] std::optional<std::shared_ptr<tle::KVCacheEventManager>> getKVCacheEventManager() const
+    {
+        return mExecutor->getKVCacheEventManager();
+    }
+
+    static void initBindings(nb::module_& m);
+
+private:
+    std::unique_ptr<tle::Executor> mExecutor;
+};
+
+} // namespace tensorrt_llm::nanobind::executor
diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
new file mode 100644
index 00000000000..c2d9fe25dff
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
@@ -0,0 +1,616 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "executorConfig.h"
+#include "tensorrt_llm/executor/executor.h"
+#include "tensorrt_llm/executor/types.h"
+#include "tensorrt_llm/nanobind/common/customCasters.h"
+#include "tensorrt_llm/runtime/cudaStream.h"
+#include "tensorrt_llm/runtime/utils/mpiUtils.h"
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/chrono.h>
+#include <nanobind/stl/filesystem.h>
+#include <nanobind/stl/function.h>
+#include <nanobind/stl/list.h>
+#include <nanobind/stl/map.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/shared_ptr.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/tuple.h>
+#include <nanobind/stl/vector.h>
+#include <optional>
+#include <vector>
+
+namespace nb = nanobind;
+namespace tle = tensorrt_llm::executor;
+using SizeType32 = tle::SizeType32;
+using RuntimeDefaults = tensorrt_llm::runtime::RuntimeDefaults;
+
+namespace tensorrt_llm::nanobind::executor
+{
+
+void initConfigBindings(nb::module_& m)
+{
+    nb::enum_<tle::BatchingType>(m, "BatchingType")
+        .value("STATIC", tle::BatchingType::kSTATIC)
+        .value("INFLIGHT", tle::BatchingType::kINFLIGHT);
+
+    auto dynamicBatchConfigGetstate = [](tle::DynamicBatchConfig const& self)
+    {
+        return nb::make_tuple(self.getEnableBatchSizeTuning(), self.getEnableMaxNumTokensTuning(),
+            self.getDynamicBatchMovingAverageWindow(), self.getBatchSizeTable());
+    };
+    auto dynamicBatchConfigSetstate = [](tle::DynamicBatchConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 4)
+        {
+            throw std::runtime_error("Invalid state!");
+        }
+        new (&self) tle::DynamicBatchConfig(nb::cast<bool>(state[0]), nb::cast<bool>(state[1]),
+            nb::cast<SizeType32>(state[2]), nb::cast<std::vector<std::vector<SizeType32>>>(state[3]));
+    };
+    nb::class_<tle::DynamicBatchConfig>(m, "DynamicBatchConfig")
+        .def(nb::init<bool, bool, SizeType32>(), nb::arg("enable_batch_size_tuning"),
+            nb::arg("enable_max_num_tokens_tuning"), nb::arg("dynamic_batch_moving_average_window"))
+        .def_prop_ro("enable_batch_size_tuning", &tle::DynamicBatchConfig::getEnableBatchSizeTuning)
+        .def_prop_ro("enable_max_num_tokens_tuning", &tle::DynamicBatchConfig::getEnableMaxNumTokensTuning)
+        .def_prop_ro(
+            "dynamic_batch_moving_average_window", &tle::DynamicBatchConfig::getDynamicBatchMovingAverageWindow)
+        .def("__getstate__", dynamicBatchConfigGetstate)
+        .def("__setstate__", dynamicBatchConfigSetstate);
+
+    auto schedulerConfigSetstate = [](tle::SchedulerConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 3)
+        {
+            throw std::runtime_error("Invalid state!");
+        }
+        new (&self) tle::SchedulerConfig(nb::cast<tle::CapacitySchedulerPolicy>(state[0]),
+            nb::cast<std::optional<tle::ContextChunkingPolicy>>(state[1]),
+            nb::cast<std::optional<tle::DynamicBatchConfig>>(state[2]));
+    };
+    auto schedulerConfigGetstate = [](tle::SchedulerConfig const& self)
+    {
+        return nb::make_tuple(
+            self.getCapacitySchedulerPolicy(), self.getContextChunkingPolicy(), self.getDynamicBatchConfig());
+    };
+    nb::class_<tle::SchedulerConfig>(m, "SchedulerConfig")
+        .def(nb::init<tle::CapacitySchedulerPolicy, std::optional<tle::ContextChunkingPolicy>,
+                 std::optional<tle::DynamicBatchConfig>>(),
+            nb::arg("capacity_scheduler_policy") = tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT,
+            nb::arg("context_chunking_policy") = nb::none(), nb::arg("dynamic_batch_config") = nb::none())
+        .def_prop_ro("capacity_scheduler_policy", &tle::SchedulerConfig::getCapacitySchedulerPolicy)
+        .def_prop_ro("context_chunking_policy", &tle::SchedulerConfig::getContextChunkingPolicy)
+        .def_prop_ro("dynamic_batch_config", &tle::SchedulerConfig::getDynamicBatchConfig)
+        .def("__getstate__", schedulerConfigGetstate)
+        .def("__setstate__", schedulerConfigSetstate);
+
+    nb::class_<RuntimeDefaults>(m, "RuntimeDefaults")
+        .def(nb::init<std::optional<std::vector<SizeType32>>, std::optional<SizeType32>>(),
+            nb::arg("max_attention_window") = nb::none(), nb::arg("sink_token_length") = nb::none())
+        .def_ro("max_attention_window", &RuntimeDefaults::maxAttentionWindowVec)
+        .def_ro("sink_token_length", &RuntimeDefaults::sinkTokenLength);
+
+    auto kvCacheConfigGetstate = [](tle::KvCacheConfig const& self)
+    {
+        return nb::make_tuple(self.getEnableBlockReuse(), self.getMaxTokens(), self.getMaxAttentionWindowVec(),
+            self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(),
+            self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(),
+            self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm());
+    };
+    auto kvCacheConfigSetstate = [](tle::KvCacheConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 13)
+        {
+            throw std::runtime_error("Invalid state!");
+        }
+        new (&self) tle::KvCacheConfig(nb::cast<bool>(state[0]), nb::cast<std::optional<SizeType32>>(state[1]),
+            nb::cast<std::optional<std::vector<SizeType32>>>(state[2]), nb::cast<std::optional<SizeType32>>(state[3]),
+            nb::cast<std::optional<float>>(state[4]), nb::cast<std::optional<size_t>>(state[5]),
+            nb::cast<bool>(state[6]), nb::cast<std::optional<float>>(state[7]),
+            nb::cast<std::optional<tle::RetentionPriority>>(state[8]), nb::cast<size_t>(state[9]),
+            nb::cast<bool>(state[10]), nb::cast<bool>(state[11]), nb::cast<bool>(state[12]));
+    };
+    nb::class_<tle::KvCacheConfig>(m, "KvCacheConfig")
+        .def(nb::init<bool, std::optional<SizeType32> const&, std::optional<std::vector<SizeType32>> const&,
+                 std::optional<SizeType32> const&, std::optional<float> const&, std::optional<size_t> const&, bool,
+                 std::optional<float> const&, std::optional<tle::RetentionPriority>, size_t const&, bool, bool, bool,
+                 std::optional<RuntimeDefaults> const&>(),
+            nb::arg("enable_block_reuse") = true, nb::arg("max_tokens") = nb::none(),
+            nb::arg("max_attention_window") = nb::none(), nb::arg("sink_token_length") = nb::none(),
+            nb::arg("free_gpu_memory_fraction") = nb::none(), nb::arg("host_cache_size") = nb::none(),
+            nb::arg("onboard_blocks") = true, nb::arg("cross_kv_cache_fraction") = nb::none(),
+            nb::arg("secondary_offload_min_priority") = nb::none(), nb::arg("event_buffer_max_size") = 0, nb::kw_only(),
+            nb::arg("enable_partial_reuse") = true, nb::arg("copy_on_partial_reuse") = true, nb::arg("use_uvm") = false,
+            nb::arg("runtime_defaults") = nb::none())
+        .def_prop_rw(
+            "enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse)
+        .def_prop_rw("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens)
+        .def_prop_rw("max_attention_window", &tle::KvCacheConfig::getMaxAttentionWindowVec,
+            &tle::KvCacheConfig::setMaxAttentionWindowVec)
+        .def_prop_rw(
+            "sink_token_length", &tle::KvCacheConfig::getSinkTokenLength, &tle::KvCacheConfig::setSinkTokenLength)
+        .def_prop_rw("free_gpu_memory_fraction", &tle::KvCacheConfig::getFreeGpuMemoryFraction,
+            &tle::KvCacheConfig::setFreeGpuMemoryFraction)
+        .def_prop_rw("host_cache_size", &tle::KvCacheConfig::getHostCacheSize, &tle::KvCacheConfig::setHostCacheSize)
+        .def_prop_rw("onboard_blocks", &tle::KvCacheConfig::getOnboardBlocks, &tle::KvCacheConfig::setOnboardBlocks)
+        .def_prop_rw("cross_kv_cache_fraction", &tle::KvCacheConfig::getCrossKvCacheFraction,
+            &tle::KvCacheConfig::setCrossKvCacheFraction)
+        .def_prop_rw("secondary_offload_min_priority", &tle::KvCacheConfig::getSecondaryOffloadMinPriority,
+            &tle::KvCacheConfig::setSecondaryOffloadMinPriority)
+        .def_prop_rw("event_buffer_max_size", &tle::KvCacheConfig::getEventBufferMaxSize,
+            &tle::KvCacheConfig::setEventBufferMaxSize)
+        .def_prop_rw("enable_partial_reuse", &tle::KvCacheConfig::getEnablePartialReuse,
+            &tle::KvCacheConfig::setEnablePartialReuse)
+        .def_prop_rw("copy_on_partial_reuse", &tle::KvCacheConfig::getCopyOnPartialReuse,
+            &tle::KvCacheConfig::setCopyOnPartialReuse)
+        .def_prop_rw("use_uvm", &tle::KvCacheConfig::getUseUvm, &tle::KvCacheConfig::setUseUvm)
+        .def("fill_empty_fields_from_runtime_defaults", &tle::KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults)
+        .def("__getstate__", kvCacheConfigGetstate)
+        .def("__setstate__", kvCacheConfigSetstate);
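+
+    // The __getstate__/__setstate__ pair above is what makes KvCacheConfig picklable,
+    // which multi-process launchers rely on. A minimal round-trip sketch from Python
+    // (hypothetical import path; exercises the lambdas registered above):
+    //
+    //     import pickle
+    //     import tensorrt_llm.bindings.executor as trtllm
+    //     config = trtllm.KvCacheConfig(free_gpu_memory_fraction=0.8)
+    //     restored = pickle.loads(pickle.dumps(config))
+    //     assert restored.free_gpu_memory_fraction == config.free_gpu_memory_fraction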
+
+    nb::class_<tle::OrchestratorConfig>(m, "OrchestratorConfig")
+        .def(nb::init<bool, std::string, std::shared_ptr<tensorrt_llm::mpi::MpiComm>, bool>(),
+            nb::arg("is_orchestrator") = true, nb::arg("worker_executable_path") = "",
+            nb::arg("orch_leader_comm").none() = nullptr, nb::arg("spawn_processes") = true)
+        .def_prop_rw(
+            "is_orchestrator", &tle::OrchestratorConfig::getIsOrchestrator, &tle::OrchestratorConfig::setIsOrchestrator)
+        .def_prop_rw("worker_executable_path", &tle::OrchestratorConfig::getWorkerExecutablePath,
+            &tle::OrchestratorConfig::setWorkerExecutablePath)
+        .def_prop_rw("orch_leader_comm", &tle::OrchestratorConfig::getOrchLeaderComm,
+            &tle::OrchestratorConfig::setOrchLeaderComm)
+        .def_prop_rw("spawn_processes", &tle::OrchestratorConfig::getSpawnProcesses,
+            &tle::OrchestratorConfig::setSpawnProcesses);
+
+    auto parallelConfigGetstate = [](tle::ParallelConfig const& self)
+    {
+        return nb::make_tuple(self.getCommunicationType(), self.getCommunicationMode(), self.getDeviceIds(),
+            self.getParticipantIds(), self.getOrchestratorConfig(), self.getNumNodes());
+    };
+    auto parallelConfigSetstate = [](tle::ParallelConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 6)
+        {
+            throw std::runtime_error("Invalid state!");
+        }
+        new (&self) tle::ParallelConfig(nb::cast<tle::CommunicationType>(state[0]),
+            nb::cast<tle::CommunicationMode>(state[1]), nb::cast<std::optional<std::vector<SizeType32>>>(state[2]),
+            nb::cast<std::optional<std::vector<SizeType32>>>(state[3]),
+            nb::cast<std::optional<tle::OrchestratorConfig>>(state[4]), nb::cast<std::optional<SizeType32>>(state[5]));
+    };
+    nb::class_<tle::ParallelConfig>(m, "ParallelConfig")
+        .def(nb::init<tle::CommunicationType, tle::CommunicationMode, std::optional<std::vector<SizeType32>> const&,
+                 std::optional<std::vector<SizeType32>> const&, std::optional<tle::OrchestratorConfig> const&,
+                 std::optional<SizeType32> const&>(),
+            nb::arg("communication_type") = tle::CommunicationType::kMPI,
+            nb::arg("communication_mode") = tle::CommunicationMode::kLEADER, nb::arg("device_ids") = nb::none(),
+            nb::arg("participant_ids") = nb::none(), nb::arg("orchestrator_config") = nb::none(),
+            nb::arg("num_nodes") = nb::none())
+        .def_prop_rw("communication_type", &tle::ParallelConfig::getCommunicationType,
+            &tle::ParallelConfig::setCommunicationType)
+        .def_prop_rw("communication_mode", &tle::ParallelConfig::getCommunicationMode,
+            &tle::ParallelConfig::setCommunicationMode)
+        .def_prop_rw("device_ids", &tle::ParallelConfig::getDeviceIds, &tle::ParallelConfig::setDeviceIds)
+        .def_prop_rw(
+            "participant_ids", &tle::ParallelConfig::getParticipantIds, &tle::ParallelConfig::setParticipantIds)
+        .def_prop_rw("orchestrator_config", &tle::ParallelConfig::getOrchestratorConfig,
+            &tle::ParallelConfig::setOrchestratorConfig)
+        .def_prop_rw("num_nodes", &tle::ParallelConfig::getNumNodes, &tle::ParallelConfig::setNumNodes)
+        .def("__getstate__", parallelConfigGetstate)
+        .def("__setstate__", parallelConfigSetstate);
+
+    auto peftCacheConfigSetstate = [](tle::PeftCacheConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 11)
+        {
+            throw std::runtime_error("Invalid state!");
+        }
+        new (&self) tle::PeftCacheConfig(nb::cast<SizeType32>(state[0]), nb::cast<SizeType32>(state[1]),
+            nb::cast<SizeType32>(state[2]), nb::cast<SizeType32>(state[3]), nb::cast<SizeType32>(state[4]),
+            nb::cast<SizeType32>(state[5]), nb::cast<SizeType32>(state[6]), nb::cast<SizeType32>(state[7]),
+            nb::cast<SizeType32>(state[8]), nb::cast<std::optional<float>>(state[9]),
+            nb::cast<std::optional<size_t>>(state[10]));
+    };
+    auto peftCacheConfigGetstate = [](tle::PeftCacheConfig const& self)
+    {
+        return nb::make_tuple(self.getNumHostModuleLayer(), self.getNumDeviceModuleLayer(),
+            self.getOptimalAdapterSize(), self.getMaxAdapterSize(), self.getNumPutWorkers(), self.getNumEnsureWorkers(),
+            self.getNumCopyStreams(), self.getMaxPagesPerBlockHost(), self.getMaxPagesPerBlockDevice(),
+            self.getDeviceCachePercent(), self.getHostCacheSize());
+    };
+    nb::class_<tle::PeftCacheConfig>(m, "PeftCacheConfig")
+        .def(nb::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32,
+                 SizeType32, std::optional<float> const&, std::optional<size_t> const&,
+                 std::optional<std::string> const&>(),
+            nb::arg("num_host_module_layer") = 0, nb::arg("num_device_module_layer") = 0,
+            nb::arg("optimal_adapter_size") = 8, nb::arg("max_adapter_size") = 64, nb::arg("num_put_workers") = 1,
+            nb::arg("num_ensure_workers") = 1, nb::arg("num_copy_streams") = 1,
+            nb::arg("max_pages_per_block_host") = 24, nb::arg("max_pages_per_block_device") = 8,
+            nb::arg("device_cache_percent") = nb::none(), nb::arg("host_cache_size") = nb::none(),
+            nb::arg("lora_prefetch_dir") = nb::none())
+        .def_prop_ro("num_host_module_layer", &tle::PeftCacheConfig::getNumHostModuleLayer)
+        .def_prop_ro("num_device_module_layer", &tle::PeftCacheConfig::getNumDeviceModuleLayer)
+        .def_prop_ro("optimal_adapter_size", &tle::PeftCacheConfig::getOptimalAdapterSize)
+        .def_prop_ro("max_adapter_size", &tle::PeftCacheConfig::getMaxAdapterSize)
+        .def_prop_ro("num_put_workers", &tle::PeftCacheConfig::getNumPutWorkers)
+        .def_prop_ro("num_ensure_workers", &tle::PeftCacheConfig::getNumEnsureWorkers)
+        .def_prop_ro("num_copy_streams", &tle::PeftCacheConfig::getNumCopyStreams)
+        .def_prop_ro("max_pages_per_block_host", &tle::PeftCacheConfig::getMaxPagesPerBlockHost)
+        .def_prop_ro("max_pages_per_block_device", &tle::PeftCacheConfig::getMaxPagesPerBlockDevice)
+        .def_prop_ro("device_cache_percent", &tle::PeftCacheConfig::getDeviceCachePercent)
+        .def_prop_ro("host_cache_size", &tle::PeftCacheConfig::getHostCacheSize)
+        .def_prop_ro("lora_prefetch_dir", &tle::PeftCacheConfig::getLoraPrefetchDir)
+        .def("__getstate__", peftCacheConfigGetstate)
+        .def("__setstate__", peftCacheConfigSetstate);
+
+    auto decodingConfigGetstate = [](tle::DecodingConfig const& self)
+    {
+        return nb::make_tuple(
+            self.getDecodingMode(), self.getLookaheadDecodingConfig(), self.getMedusaChoices(), self.getEagleConfig());
+    };
+    auto decodingConfigSetstate = [](tle::DecodingConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 4)
+        {
+            throw std::runtime_error("Invalid state!");
+        }
+        new (&self) tle::DecodingConfig(nb::cast<std::optional<tle::DecodingMode>>(state[0]), // DecodingMode
+            nb::cast<std::optional<tle::LookaheadDecodingConfig>>(state[1]), // LookaheadDecodingConfig
+            nb::cast<std::optional<tle::MedusaChoices>>(state[2]),           // MedusaChoices
+            nb::cast<std::optional<tle::EagleConfig>>(state[3])              // EagleConfig
+        );
+    };
+    nb::class_<tle::DecodingConfig>(m, "DecodingConfig")
+        .def(nb::init<std::optional<tle::DecodingMode>, std::optional<tle::LookaheadDecodingConfig>,
+                 std::optional<tle::MedusaChoices>, std::optional<tle::EagleConfig>>(),
+            nb::arg("decoding_mode") = nb::none(), nb::arg("lookahead_decoding_config") = nb::none(),
+            nb::arg("medusa_choices") = nb::none(), nb::arg("eagle_config") = nb::none())
+        .def_prop_rw("decoding_mode", &tle::DecodingConfig::getDecodingMode, &tle::DecodingConfig::setDecodingMode)
+        .def_prop_rw("lookahead_decoding_config", &tle::DecodingConfig::getLookaheadDecodingConfig,
+            &tle::DecodingConfig::setLookaheadDecodingConfig)
+        .def_prop_rw("medusa_choices", &tle::DecodingConfig::getMedusaChoices, &tle::DecodingConfig::setMedusaChoices)
+        .def_prop_rw("eagle_config", &tle::DecodingConfig::getEagleConfig, &tle::DecodingConfig::setEagleConfig)
+        .def("__getstate__", decodingConfigGetstate)
+        .def("__setstate__", decodingConfigSetstate);
+
+    auto debugConfigGetstate = [](tle::DebugConfig const& self)
+    {
+        return nb::make_tuple(self.getDebugInputTensors(), self.getDebugOutputTensors(), self.getDebugTensorNames(),
+            self.getDebugTensorsMaxIterations());
+    };
+    auto debugConfigSetstate = [](tle::DebugConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 4)
+        {
+            throw std::runtime_error("Invalid state!");
+        }
+        new (&self) tle::DebugConfig(nb::cast<bool>(state[0]), nb::cast<bool>(state[1]),
+            nb::cast<std::vector<std::string>>(state[2]), nb::cast<SizeType32>(state[3]));
+    };
+    nb::class_<tle::DebugConfig>(m, "DebugConfig")
+        .def(nb::init<bool, bool, std::vector<std::string>, SizeType32>(), nb::arg("debug_input_tensors") = false,
+            nb::arg("debug_output_tensors") = false, nb::arg("debug_tensor_names") = nb::none(),
+            nb::arg("debug_tensors_max_iterations") = false)
+        .def_prop_rw(
+            "debug_input_tensors", &tle::DebugConfig::getDebugInputTensors, &tle::DebugConfig::setDebugInputTensors)
+        .def_prop_rw(
+            "debug_output_tensors", &tle::DebugConfig::getDebugOutputTensors, &tle::DebugConfig::setDebugOutputTensors)
+        .def_prop_rw(
+            "debug_tensor_names", &tle::DebugConfig::getDebugTensorNames, &tle::DebugConfig::setDebugTensorNames)
+        .def_prop_rw("debug_tensors_max_iterations", &tle::DebugConfig::getDebugTensorsMaxIterations,
+            &tle::DebugConfig::setDebugTensorsMaxIterations)
+        .def("__getstate__", debugConfigGetstate)
+        .def("__setstate__", debugConfigSetstate);
+
+    auto logitsPostProcessorConfigGetstate = [](tle::LogitsPostProcessorConfig const& self)
+    { return nb::make_tuple(self.getProcessorMap(), self.getProcessorBatched(), self.getReplicate()); };
+
+    auto logitsPostProcessorConfigSetstate = [](tle::LogitsPostProcessorConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 3)
+        {
+            throw std::runtime_error("Invalid LogitsPostProcessorConfig state!");
+        }
+        new (&self) tle::LogitsPostProcessorConfig(nb::cast<std::optional<tle::LogitsPostProcessorMap>>(state[0]),
+            nb::cast<std::optional<tle::LogitsPostProcessorBatched>>(state[1]), nb::cast<bool>(state[2]));
+    };
+
+    nb::class_<tle::LogitsPostProcessorConfig>(m, "LogitsPostProcessorConfig")
+        .def(nb::init<std::optional<tle::LogitsPostProcessorMap>, std::optional<tle::LogitsPostProcessorBatched>,
+                 bool>(),
+            nb::arg("processor_map") = nb::none(), nb::arg("processor_batched") = nb::none(),
+            nb::arg("replicate") = true)
+        .def_prop_rw("processor_map", &tle::LogitsPostProcessorConfig::getProcessorMap,
+            &tle::LogitsPostProcessorConfig::setProcessorMap)
+        .def_prop_rw("processor_batched", &tle::LogitsPostProcessorConfig::getProcessorBatched,
+            &tle::LogitsPostProcessorConfig::setProcessorBatched)
+        .def_prop_rw(
+            "replicate", &tle::LogitsPostProcessorConfig::getReplicate, &tle::LogitsPostProcessorConfig::setReplicate)
+        .def("__getstate__", logitsPostProcessorConfigGetstate)
+        .def("__setstate__", logitsPostProcessorConfigSetstate);
+
+    auto extendedRuntimePerfKnobConfigSetstate = [](tle::ExtendedRuntimePerfKnobConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 4)
+        {
+            throw std::runtime_error("Invalid extendedRuntimePerfKnobConfig state!");
+        }
+        new (&self) tle::ExtendedRuntimePerfKnobConfig(nb::cast<bool>(state[0]), nb::cast<bool>(state[1]),
+            nb::cast<bool>(state[2]), nb::cast<SizeType32>(state[3]));
+    };
+    auto extendedRuntimePerfKnobConfigGetstate = [](tle::ExtendedRuntimePerfKnobConfig const& self)
+    {
+        return nb::make_tuple(self.getMultiBlockMode(), self.getEnableContextFMHAFP32Acc(), self.getCudaGraphMode(),
+            self.getCudaGraphCacheSize());
+    };
+    nb::class_<tle::ExtendedRuntimePerfKnobConfig>(m, "ExtendedRuntimePerfKnobConfig")
+        .def(
+            nb::init<bool, bool>(), nb::arg("multi_block_mode") = true, nb::arg("enable_context_fmha_fp32_acc") = false)
+        .def_prop_rw("multi_block_mode", &tle::ExtendedRuntimePerfKnobConfig::getMultiBlockMode,
+            &tle::ExtendedRuntimePerfKnobConfig::setMultiBlockMode)
+        .def_prop_rw("enable_context_fmha_fp32_acc", &tle::ExtendedRuntimePerfKnobConfig::getEnableContextFMHAFP32Acc,
+            &tle::ExtendedRuntimePerfKnobConfig::setEnableContextFMHAFP32Acc)
+        .def_prop_rw("cuda_graph_mode", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphMode,
+            &tle::ExtendedRuntimePerfKnobConfig::setCudaGraphMode)
+        .def_prop_rw("cuda_graph_cache_size", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphCacheSize,
+            &tle::ExtendedRuntimePerfKnobConfig::setCudaGraphCacheSize)
+        .def("__getstate__", extendedRuntimePerfKnobConfigGetstate)
+        .def("__setstate__", extendedRuntimePerfKnobConfigSetstate);
+
+    auto SpeculativeDecodingConfigGetState
+        = [](tle::SpeculativeDecodingConfig const& self) { return nb::make_tuple(self.fastLogits); };
+    auto SpeculativeDecodingConfigSetState = [](tle::SpeculativeDecodingConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 1)
+        {
+            throw std::runtime_error("Invalid SpeculativeDecodingConfig state!");
+        }
+        new (&self) tle::SpeculativeDecodingConfig(nb::cast<bool>(state[0]));
+    };
+    nb::class_<tle::SpeculativeDecodingConfig>(m, "SpeculativeDecodingConfig")
+        .def(nb::init<bool>(), nb::arg("fast_logits") = false)
+        .def_rw("fast_logits", &tle::SpeculativeDecodingConfig::fastLogits)
+        .def("__getstate__", SpeculativeDecodingConfigGetState)
+        .def("__setstate__", SpeculativeDecodingConfigSetState);
+
+    // Guided decoding config
+    auto pyGuidedDecodingConfig = nb::class_<tle::GuidedDecodingConfig>(m, "GuidedDecodingConfig");
+
+    nb::enum_<tle::GuidedDecodingConfig::GuidedDecodingBackend>(pyGuidedDecodingConfig, "GuidedDecodingBackend")
+        .value("XGRAMMAR", tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR)
+        .value("LLGUIDANCE", tle::GuidedDecodingConfig::GuidedDecodingBackend::kLLGUIDANCE);
+
+    auto guidedDecodingConfigGetstate = [](tle::GuidedDecodingConfig const& self)
+    {
+        return nb::make_tuple(
+            self.getBackend(), self.getEncodedVocab(), self.getTokenizerStr(), self.getStopTokenIds());
+    };
+    auto guidedDecodingConfigSetstate = [](tle::GuidedDecodingConfig& self, nb::tuple state)
+    {
+        if (state.size() != 4)
+        {
+            throw std::runtime_error("Invalid GuidedDecodingConfig state!");
+        }
+        new (&self) tle::GuidedDecodingConfig(nb::cast<tle::GuidedDecodingConfig::GuidedDecodingBackend>(state[0]),
+            nb::cast<std::optional<std::vector<std::string>>>(state[1]), nb::cast<std::optional<std::string>>(state[2]),
+            nb::cast<std::optional<std::vector<tle::TokenIdType>>>(state[3]));
+    };
+
+    pyGuidedDecodingConfig
+        .def(nb::init<tle::GuidedDecodingConfig::GuidedDecodingBackend, std::optional<std::vector<std::string>>,
+                 std::optional<std::string>, std::optional<std::vector<tle::TokenIdType>>>(),
+            nb::arg("backend"), nb::arg("encoded_vocab") = nb::none(), nb::arg("tokenizer_str") = nb::none(),
+            nb::arg("stop_token_ids") = nb::none())
+        .def_prop_rw("backend", &tle::GuidedDecodingConfig::getBackend, &tle::GuidedDecodingConfig::setBackend)
+        .def_prop_rw(
+            "encoded_vocab", &tle::GuidedDecodingConfig::getEncodedVocab, &tle::GuidedDecodingConfig::setEncodedVocab)
+        .def_prop_rw(
+            "tokenizer_str", &tle::GuidedDecodingConfig::getTokenizerStr, &tle::GuidedDecodingConfig::setTokenizerStr)
+        .def_prop_rw(
+            "stop_token_ids", &tle::GuidedDecodingConfig::getStopTokenIds, &tle::GuidedDecodingConfig::setStopTokenIds)
+        .def("__getstate__", guidedDecodingConfigGetstate)
+        .def("__setstate__", guidedDecodingConfigSetstate);
+
+    auto cacheTransceiverConfigGetstate
+        = [](tle::CacheTransceiverConfig const& self) { return nb::make_tuple(self.getMaxNumTokens()); };
+    auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state)
+    {
+        if (state.size() != 1)
+        {
+            throw std::runtime_error("Invalid CacheTransceiverConfig state!");
+        }
+        new (&self) tle::CacheTransceiverConfig(nb::cast<std::optional<size_t>>(state[0]));
+    };
+
+    nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
+        .def(nb::init<std::optional<size_t>>(), nb::arg("max_num_tokens") = nb::none())
+        .def_prop_rw("max_num_tokens", &tle::CacheTransceiverConfig::getMaxNumTokens,
+            &tle::CacheTransceiverConfig::setMaxNumTokens)
+        .def("__getstate__", cacheTransceiverConfigGetstate)
+        .def("__setstate__", cacheTransceiverConfigSetstate);
+
+    auto executorConfigGetState = [](nb::object const& self)
+    {
+        auto& c = nb::cast<tle::ExecutorConfig&>(self);
+        // Return a tuple containing C++ data and the Python __dict__
+        auto cpp_states = nb::make_tuple(c.getMaxBeamWidth(), c.getSchedulerConfig(), c.getKvCacheConfig(),
+            c.getEnableChunkedContext(), c.getNormalizeLogProbs(), c.getIterStatsMaxIterations(),
+            c.getRequestStatsMaxIterations(), c.getBatchingType(), c.getMaxBatchSize(), c.getMaxNumTokens(),
+            c.getParallelConfig(), c.getPeftCacheConfig(), c.getLogitsPostProcessorConfig(), c.getDecodingConfig(),
+            c.getUseGpuDirectStorage(), c.getGpuWeightsPercent(), c.getMaxQueueSize(),
+            c.getExtendedRuntimePerfKnobConfig(), c.getDebugConfig(), c.getRecvPollPeriodMs(),
+            c.getMaxSeqIdleMicroseconds(), c.getSpecDecConfig(), c.getGuidedDecodingConfig(),
+            c.getAdditionalModelOutputs(), c.getCacheTransceiverConfig(), c.getGatherGenerationLogits(),
+            c.getPromptTableOffloading(), c.getEnableTrtOverlap());
+        auto pickle_tuple = nb::make_tuple(cpp_states, nb::getattr(self, "__dict__"));
+        return pickle_tuple;
+    };
+
+    auto executorConfigSetState = [](nb::object self, nb::tuple const& state)
+    {
+        if (state.size() != 2)
+        {
+            throw std::runtime_error("Invalid state!");
+        }
+
+        auto cpp_states = nb::cast<nb::tuple>(state[0]);
+        if (cpp_states.size() != 28)
+        {
+            throw std::runtime_error("Invalid cpp_states!");
+        }
+
+        // Restore C++ data
+        tle::ExecutorConfig* cpp_self = nb::inst_ptr<tle::ExecutorConfig>(self);
+        new (cpp_self) tle::ExecutorConfig(                                             //
+            nb::cast<SizeType32>(cpp_states[0]),                                        // MaxBeamWidth
+            nb::cast<tle::SchedulerConfig>(cpp_states[1]),                              // SchedulerConfig
+            nb::cast<tle::KvCacheConfig>(cpp_states[2]),                                // KvCacheConfig
+            nb::cast<bool>(cpp_states[3]),                                              // EnableChunkedContext
+            nb::cast<bool>(cpp_states[4]),                                              // NormalizeLogProbs
+            nb::cast<SizeType32>(cpp_states[5]),                                        // IterStatsMaxIterations
+            nb::cast<SizeType32>(cpp_states[6]),                                        // RequestStatsMaxIterations
+            nb::cast<tle::BatchingType>(cpp_states[7]),                                 // BatchingType
+            nb::cast<std::optional<SizeType32>>(cpp_states[8]),                         // MaxBatchSize
+            nb::cast<std::optional<SizeType32>>(cpp_states[9]),                         // MaxNumTokens
+            nb::cast<std::optional<tle::ParallelConfig>>(cpp_states[10]),               // ParallelConfig
+            nb::cast<tle::PeftCacheConfig>(cpp_states[11]),                             // PeftCacheConfig
+            nb::cast<std::optional<tle::LogitsPostProcessorConfig>>(cpp_states[12]),    // LogitsPostProcessorConfig
+            nb::cast<std::optional<tle::DecodingConfig>>(cpp_states[13]),               // DecodingConfig
+            nb::cast<bool>(cpp_states[14]),                                             // UseGpuDirectStorage
+            nb::cast<float>(cpp_states[15]),                                            // GpuWeightsPercent
+            nb::cast<std::optional<SizeType32>>(cpp_states[16]),                        // MaxQueueSize
+            nb::cast<tle::ExtendedRuntimePerfKnobConfig>(cpp_states[17]),               // ExtendedRuntimePerfKnobConfig
+            nb::cast<std::optional<tle::DebugConfig>>(cpp_states[18]),                  // DebugConfig
+            nb::cast<SizeType32>(cpp_states[19]),                                       // RecvPollPeriodMs
+            nb::cast<uint64_t>(cpp_states[20]),                                         // MaxSeqIdleMicroseconds
+            nb::cast<std::optional<tle::SpeculativeDecodingConfig>>(cpp_states[21]),    // SpecDecConfig
+            nb::cast<std::optional<tle::GuidedDecodingConfig>>(cpp_states[22]),         // GuidedDecodingConfig
+            nb::cast<std::optional<std::vector<tle::AdditionalModelOutput>>>(cpp_states[23]), // AdditionalModelOutputs
+            nb::cast<std::optional<tle::CacheTransceiverConfig>>(cpp_states[24]),       // CacheTransceiverConfig
+            nb::cast<bool>(cpp_states[25]),                                             // GatherGenerationLogits
+            nb::cast<bool>(cpp_states[26]),                                             // PromptTableOffloading
+            nb::cast<bool>(cpp_states[27])                                              // EnableTrtOverlap
+        );
+
+        // Restore Python data
+        auto py_state = nb::cast<nb::dict>(state[1]);
+        self.attr("__dict__").attr("update")(py_state);
+
+        nb::inst_mark_ready(self);
+    };
+
+    nb::class_<tle::ExecutorConfig>(m, "ExecutorConfig", nb::dynamic_attr())
+        .def(nb::init<                                            //
+                 SizeType32,                                      // MaxBeamWidth
+                 tle::SchedulerConfig const&,                     // SchedulerConfig
+                 tle::KvCacheConfig const&,                       // KvCacheConfig
+                 bool,                                            // EnableChunkedContext
+                 bool,                                            // NormalizeLogProbs
+                 SizeType32,                                      // IterStatsMaxIterations
+                 SizeType32,                                      // RequestStatsMaxIterations
+                 tle::BatchingType,                               // BatchingType
+                 std::optional<SizeType32>,                       // MaxBatchSize
+                 std::optional<SizeType32>,                       // MaxNumTokens
+                 std::optional<tle::ParallelConfig>,              // ParallelConfig
+                 tle::PeftCacheConfig const&,                     // PeftCacheConfig
+                 std::optional<tle::LogitsPostProcessorConfig>,   // LogitsPostProcessorConfig
+                 std::optional<tle::DecodingConfig>,              // DecodingConfig
+                 bool,                                            // UseGpuDirectStorage
+                 float,                                           // GpuWeightsPercent
+                 std::optional<SizeType32>,                       // MaxQueueSize
+                 tle::ExtendedRuntimePerfKnobConfig const&,       // ExtendedRuntimePerfKnobConfig
+                 std::optional<tle::DebugConfig>,                 // DebugConfig
+                 SizeType32,                                      // RecvPollPeriodMs
+                 uint64_t,                                        // MaxSeqIdleMicroseconds
+                 std::optional<tle::SpeculativeDecodingConfig>,   // SpecDecConfig
+                 std::optional<tle::GuidedDecodingConfig>,        // GuidedDecodingConfig
+                 std::optional<std::vector<tle::AdditionalModelOutput>>, // AdditionalModelOutputs
+                 std::optional<tle::CacheTransceiverConfig>,      // CacheTransceiverConfig
+                 bool,                                            // GatherGenerationLogits
+                 bool,                                            // PromptTableOffloading
+                 bool                                             // EnableTrtOverlap
+                 >(),
+            nb::arg("max_beam_width") = 1, nb::arg("scheduler_config") = tle::SchedulerConfig(),
+            nb::arg("kv_cache_config") = tle::KvCacheConfig(), nb::arg("enable_chunked_context") = false,
+            nb::arg("normalize_log_probs") = true,
+            nb::arg("iter_stats_max_iterations") = tle::ExecutorConfig::kDefaultIterStatsMaxIterations,
+            nb::arg("request_stats_max_iterations") = tle::ExecutorConfig::kDefaultRequestStatsMaxIterations,
+            nb::arg("batching_type") = tle::BatchingType::kINFLIGHT, nb::arg("max_batch_size") = nb::none(),
+            nb::arg("max_num_tokens") = nb::none(), nb::arg("parallel_config") = nb::none(),
+            nb::arg("peft_cache_config") = tle::PeftCacheConfig(), nb::arg("logits_post_processor_config") = nb::none(),
+            nb::arg("decoding_config") = nb::none(), nb::arg("use_gpu_direct_storage") = false,
+            nb::arg("gpu_weights_percent") = 1.0, nb::arg("max_queue_size") = nb::none(),
+            nb::arg("extended_runtime_perf_knob_config") = tle::ExtendedRuntimePerfKnobConfig(),
+            nb::arg("debug_config") = nb::none(), nb::arg("recv_poll_period_ms") = 0,
+            nb::arg("max_seq_idle_microseconds") = tle::ExecutorConfig::kDefaultMaxSeqIdleMicroseconds,
+            nb::arg("spec_dec_config") = nb::none(), nb::arg("guided_decoding_config") = nb::none(),
+            nb::arg("additional_model_outputs") = nb::none(), nb::arg("cache_transceiver_config") = nb::none(),
+            nb::arg("gather_generation_logits") = false, nb::arg("mm_embedding_offloading") = false,
+            nb::arg("enable_trt_overlap") = false)
+        .def_prop_rw("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth)
+        .def_prop_rw("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize)
+        .def_prop_rw("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens)
+        .def_prop_rw(
+            "scheduler_config", &tle::ExecutorConfig::getSchedulerConfigRef, &tle::ExecutorConfig::setSchedulerConfig)
+        .def_prop_rw(
+            "kv_cache_config", &tle::ExecutorConfig::getKvCacheConfigRef, &tle::ExecutorConfig::setKvCacheConfig)
+        .def_prop_rw("enable_chunked_context", &tle::ExecutorConfig::getEnableChunkedContext,
+            &tle::ExecutorConfig::setEnableChunkedContext)
+        .def_prop_rw("normalize_log_probs", &tle::ExecutorConfig::getNormalizeLogProbs,
+            &tle::ExecutorConfig::setNormalizeLogProbs)
+        .def_prop_rw("iter_stats_max_iterations", &tle::ExecutorConfig::getIterStatsMaxIterations,
+            &tle::ExecutorConfig::setIterStatsMaxIterations)
+        .def_prop_rw("request_stats_max_iterations", &tle::ExecutorConfig::getRequestStatsMaxIterations,
+            &tle::ExecutorConfig::setRequestStatsMaxIterations)
+        .def_prop_rw("batching_type", &tle::ExecutorConfig::getBatchingType, &tle::ExecutorConfig::setBatchingType)
+        .def_prop_rw(
+            "parallel_config", &tle::ExecutorConfig::getParallelConfig, &tle::ExecutorConfig::setParallelConfig)
+        .def_prop_rw(
+            "peft_cache_config", &tle::ExecutorConfig::getPeftCacheConfig, &tle::ExecutorConfig::setPeftCacheConfig)
+        .def_prop_rw("logits_post_processor_config", &tle::ExecutorConfig::getLogitsPostProcessorConfig,
+            &tle::ExecutorConfig::setLogitsPostProcessorConfig)
+        .def_prop_rw(
+            "decoding_config", &tle::ExecutorConfig::getDecodingConfig, &tle::ExecutorConfig::setDecodingConfig)
+        .def_prop_rw("use_gpu_direct_storage", &tle::ExecutorConfig::getUseGpuDirectStorage,
+            &tle::ExecutorConfig::setUseGpuDirectStorage)
+        .def_prop_rw("gpu_weights_percent", &tle::ExecutorConfig::getGpuWeightsPercent,
+            &tle::ExecutorConfig::setGpuWeightsPercent)
+        .def_prop_rw("max_queue_size", &tle::ExecutorConfig::getMaxQueueSize, &tle::ExecutorConfig::setMaxQueueSize)
+        .def_prop_rw("extended_runtime_perf_knob_config", &tle::ExecutorConfig::getExtendedRuntimePerfKnobConfig,
+            &tle::ExecutorConfig::setExtendedRuntimePerfKnobConfig)
+        .def_prop_rw("debug_config", &tle::ExecutorConfig::getDebugConfig, &tle::ExecutorConfig::setDebugConfig)
+        .def_prop_rw(
+            "recv_poll_period_ms", &tle::ExecutorConfig::getRecvPollPeriodMs, &tle::ExecutorConfig::setRecvPollPeriodMs)
+        .def_prop_rw("max_seq_idle_microseconds", &tle::ExecutorConfig::getMaxSeqIdleMicroseconds,
+            &tle::ExecutorConfig::setMaxSeqIdleMicroseconds)
+        .def_prop_rw("spec_dec_config", &tle::ExecutorConfig::getSpecDecConfig, &tle::ExecutorConfig::setSpecDecConfig)
+        .def_prop_rw("guided_decoding_config", &tle::ExecutorConfig::getGuidedDecodingConfig,
+            &tle::ExecutorConfig::setGuidedDecodingConfig)
+        .def_prop_rw("additional_model_outputs", &tle::ExecutorConfig::getAdditionalModelOutputs,
+            &tle::ExecutorConfig::setAdditionalModelOutputs)
+        .def_prop_rw("cache_transceiver_config", &tle::ExecutorConfig::getCacheTransceiverConfig,
+            &tle::ExecutorConfig::setCacheTransceiverConfig)
+        .def_prop_rw("gather_generation_logits", &tle::ExecutorConfig::getGatherGenerationLogits,
+            &tle::ExecutorConfig::setGatherGenerationLogits)
+        .def_prop_rw("mm_embedding_offloading", &tle::ExecutorConfig::getPromptTableOffloading,
+            &tle::ExecutorConfig::setPromptTableOffloading)
+        .def_prop_rw(
+            "enable_trt_overlap", &tle::ExecutorConfig::getEnableTrtOverlap, &tle::ExecutorConfig::setEnableTrtOverlap)
+        .def("__getstate__", executorConfigGetState)
+        .def("__setstate__", executorConfigSetState);
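+
+    // Construction sketch from Python, tying several of the configs above together
+    // (hypothetical import path; keyword names follow the nb::arg strings above):
+    //
+    //     import tensorrt_llm.bindings.executor as trtllm
+    //     config = trtllm.ExecutorConfig(
+    //         max_beam_width=1,
+    //         scheduler_config=trtllm.SchedulerConfig(trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT),
+    //         kv_cache_config=trtllm.KvCacheConfig(free_gpu_memory_fraction=0.9),
+    //         batching_type=trtllm.BatchingType.INFLIGHT)
+    //     config.enable_chunked_context = True  # read/write via def_prop_rw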
+}
+
+} // namespace tensorrt_llm::nanobind::executor
diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.h b/cpp/tensorrt_llm/nanobind/executor/executorConfig.h
new file mode 100644
index 00000000000..5b63e7c5a3e
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.h
@@ -0,0 +1,30 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <nanobind/nanobind.h>
+
+namespace nb = nanobind;
+
+namespace tensorrt_llm::nanobind::executor
+{
+
+// Register bindings for executor API.
+void initConfigBindings(nb::module_& m);
+
+} // namespace tensorrt_llm::nanobind::executor
diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp
new file mode 100644
index 00000000000..9c3d34aa8fd
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp
@@ -0,0 +1,935 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "request.h"
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/logger.h"
+#include "tensorrt_llm/executor/executor.h"
+#include "tensorrt_llm/executor/serializeUtils.h"
+#include "tensorrt_llm/executor/tensor.h"
+#include "tensorrt_llm/executor/types.h"
+#include "tensorrt_llm/nanobind/common/customCasters.h"
+#include "tensorrt_llm/runtime/cudaStream.h"
+
+#include <nanobind/nanobind.h>
+#include <nanobind/ndarray.h>
+#include <nanobind/stl/chrono.h>
+#include <nanobind/stl/map.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/shared_ptr.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/tuple.h>
+#include <nanobind/stl/vector.h>
+
+#include <memory>
+#include <string_view>
+
+namespace nb = nanobind;
+namespace tle = tensorrt_llm::executor;
+using Tensor = tle::Tensor;
+using SizeType32 = tle::SizeType32;
+using FloatType = tle::FloatType;
+using VecTokens = tle::VecTokens;
+using IdType = tle::IdType;
+using VecTokenExtraIds = tle::VecTokenExtraIds;
+
+namespace tensorrt_llm::nanobind::executor
+{
+
+void initRequestBindings(nb::module_& m)
+{
+    nb::enum_<tle::RequestType>(m, "RequestType")
+        .value("REQUEST_TYPE_CONTEXT_AND_GENERATION", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION)
+        .value("REQUEST_TYPE_CONTEXT_ONLY", tle::RequestType::REQUEST_TYPE_CONTEXT_ONLY)
+        .value("REQUEST_TYPE_GENERATION_ONLY", tle::RequestType::REQUEST_TYPE_GENERATION_ONLY);
+
+    nb::enum_<tle::FinishReason>(m, "FinishReason")
+        .value("NOT_FINISHED", tle::FinishReason::kNOT_FINISHED)
+        .value("END_ID", tle::FinishReason::kEND_ID)
+        .value("STOP_WORDS", tle::FinishReason::kSTOP_WORDS)
+        .value("LENGTH", tle::FinishReason::kLENGTH)
+        .value("TIMED_OUT", tle::FinishReason::kTIMED_OUT)
+        .value("CANCELLED", tle::FinishReason::kCANCELLED);
+
+    nb::enum_<tle::KvCacheTransferMode>(m, "KvCacheTransferMode")
+        .value("DRAM", tle::KvCacheTransferMode::DRAM)
+        .value("GDS", tle::KvCacheTransferMode::GDS)
+        .value("POSIX_DEBUG_FALLBACK", tle::KvCacheTransferMode::POSIX_DEBUG_FALLBACK);
+
+    auto samplingConfigGetstate = [](tle::SamplingConfig const& self)
+    {
+        return nb::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(),
+            self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(),
+            self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(),
+            self.getFrequencyPenalty(), self.getLengthPenalty(), self.getEarlyStopping(), self.getNoRepeatNgramSize(),
+            self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
+    };
+    auto samplingConfigSetstate = [](tle::SamplingConfig& samplingConfig, nb::tuple const& state)
+    {
+        if (state.size() != 19)
+        {
+            throw std::runtime_error("Invalid SamplingConfig state!");
+        }
+        new (&samplingConfig) tle::SamplingConfig(nb::cast<SizeType32>(state[0]),  // BeamWidth
+            nb::cast<std::optional<SizeType32>>(state[1]),                         // TopK
+            nb::cast<std::optional<FloatType>>(state[2]),                          // TopP
+            nb::cast<std::optional<FloatType>>(state[3]),                          // TopPMin
+            nb::cast<std::optional<tle::TokenIdType>>(state[4]),                   // TopPResetIds
+            nb::cast<std::optional<FloatType>>(state[5]),                          // TopPDecay
+            nb::cast<std::optional<tle::RandomSeedType>>(state[6]),                // Seed
+            nb::cast<std::optional<FloatType>>(state[7]),                          // Temperature
+            nb::cast<std::optional<SizeType32>>(state[8]),                         // MinTokens
+            nb::cast<std::optional<FloatType>>(state[9]),                          // BeamSearchDiversityRate
+            nb::cast<std::optional<FloatType>>(state[10]),                         // RepetitionPenalty
+            nb::cast<std::optional<FloatType>>(state[11]),                         // PresencePenalty
+            nb::cast<std::optional<FloatType>>(state[12]),                         // FrequencyPenalty
+            nb::cast<std::optional<FloatType>>(state[13]),                         // LengthPenalty
+            nb::cast<std::optional<SizeType32>>(state[14]),                        // EarlyStopping
+            nb::cast<std::optional<SizeType32>>(state[15]),                        // NoRepeatNgramSize
+            nb::cast<std::optional<SizeType32>>(state[16]),                        // NumReturnSequences
+            nb::cast<std::optional<FloatType>>(state[17]),                         // MinP
+            nb::cast<std::optional<std::vector<SizeType32>>>(state[18])            // BeamWidthArray
+        );
+    };
+    nb::class_<tle::SamplingConfig>(m, "SamplingConfig")
+        .def(nb::init<SizeType32,                                 // beamWidth
+                 std::optional<SizeType32> const&,                // topK
+                 std::optional<FloatType> const&,                 // topP
+                 std::optional<FloatType> const&,                 // topPMin
+                 std::optional<tle::TokenIdType> const&,          // topPResetIds
+                 std::optional<FloatType> const&,                 // topPDecay
+                 std::optional<tle::RandomSeedType> const&,       // seed
+                 std::optional<FloatType> const&,                 // temperature
+                 std::optional<SizeType32> const&,                // minTokens
+                 std::optional<FloatType> const&,                 // beamSearchDiversityRate
+                 std::optional<FloatType> const&,                 // repetitionPenalty
+                 std::optional<FloatType> const&,                 // presencePenalty
+                 std::optional<FloatType> const&,                 // frequencyPenalty
+                 std::optional<FloatType> const&,                 // lengthPenalty
+                 std::optional<SizeType32> const&,                // earlyStopping
+                 std::optional<SizeType32> const&,                // noRepeatNgramSize
+                 std::optional<SizeType32> const&,                // numReturnSequences
+                 std::optional<FloatType> const&,                 // minP
+                 std::optional<std::vector<SizeType32>> const&    // beamWidthArray
+                 >(),
+            // clang-format off
+            nb::arg("beam_width") = 1,
+            nb::kw_only(),
+            nb::arg("top_k") = nb::none(),
+            nb::arg("top_p") = nb::none(),
+            nb::arg("top_p_min") = nb::none(),
+            nb::arg("top_p_reset_ids") = nb::none(),
+            nb::arg("top_p_decay") = nb::none(),
+            nb::arg("seed") = nb::none(),
+            nb::arg("temperature") = nb::none(),
+            nb::arg("min_tokens") = nb::none(),
+            nb::arg("beam_search_diversity_rate") = nb::none(),
+            nb::arg("repetition_penalty") = nb::none(),
+            nb::arg("presence_penalty") = nb::none(),
+            nb::arg("frequency_penalty") = nb::none(),
+            nb::arg("length_penalty") = nb::none(),
+            nb::arg("early_stopping") = nb::none(),
+            nb::arg("no_repeat_ngram_size") = nb::none(),
+            nb::arg("num_return_sequences") = nb::none(),
+            nb::arg("min_p") = nb::none(),
+            nb::arg("beam_width_array") = nb::none()) // clang-format on
+        .def_prop_rw("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth)
+        .def_prop_rw("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK)
+        .def_prop_rw("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP)
+        .def_prop_rw("top_p_min", &tle::SamplingConfig::getTopPMin, &tle::SamplingConfig::setTopPMin)
+        .def_prop_rw("top_p_reset_ids", &tle::SamplingConfig::getTopPResetIds, &tle::SamplingConfig::setTopPResetIds)
+        .def_prop_rw("top_p_decay", &tle::SamplingConfig::getTopPDecay, &tle::SamplingConfig::setTopPDecay)
+        .def_prop_rw("seed", &tle::SamplingConfig::getSeed, &tle::SamplingConfig::setSeed)
+        .def_prop_rw("temperature", &tle::SamplingConfig::getTemperature, &tle::SamplingConfig::setTemperature)
+        .def_prop_rw("min_tokens", &tle::SamplingConfig::getMinTokens, &tle::SamplingConfig::setMinTokens)
+        .def_prop_rw("beam_search_diversity_rate", &tle::SamplingConfig::getBeamSearchDiversityRate,
+            &tle::SamplingConfig::setBeamSearchDiversityRate)
+        .def_prop_rw("repetition_penalty", &tle::SamplingConfig::getRepetitionPenalty,
+            &tle::SamplingConfig::setRepetitionPenalty)
+        .def_prop_rw("presence_penalty", &tle::SamplingConfig::getPresencePenalty,
+            [](tle::SamplingConfig& self, std::optional<FloatType> v) { self.setPresencePenalty(v); })
+        .def_prop_rw(
+            "frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty)
+        .def_prop_rw("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty)
+        .def_prop_rw("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping)
+        .def_prop_rw("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize,
+            &tle::SamplingConfig::setNoRepeatNgramSize)
+        .def_prop_rw("num_return_sequences", &tle::SamplingConfig::getNumReturnSequences,
+            &tle::SamplingConfig::setNumReturnSequences)
+        .def_prop_rw("min_p", &tle::SamplingConfig::getMinP, &tle::SamplingConfig::setMinP)
+        .def_prop_rw(
+            "beam_width_array", &tle::SamplingConfig::getBeamWidthArray, &tle::SamplingConfig::setBeamWidthArray)
+        .def("__getstate__", samplingConfigGetstate)
+        .def("__setstate__", samplingConfigSetstate);
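+
+    // Usage sketch from Python (hypothetical import path): beam_width is the only
+    // positional argument, everything after nb::kw_only() must be passed by keyword.
+    //
+    //     import tensorrt_llm.bindings.executor as trtllm
+    //     sampling = trtllm.SamplingConfig(1, top_k=50, top_p=0.95, temperature=0.8, seed=42)
+    //     sampling.min_tokens = 4  # properties remain settable via def_prop_rw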
new (&externalDraftTokensConfig) tle::ExternalDraftTokensConfig(nb::cast(state[0]), + nb::cast>(state[1]), nb::cast>(state[2])); + }; + nb::class_(m, "ExternalDraftTokensConfig") + .def(nb::init, std::optional const&, std::optional>(), + nb::arg("tokens"), nb::arg("logits") = nb::none(), nb::arg("acceptance_threshold") = nb::none(), + nb::arg("fast_logits") = nb::none()) + .def_prop_ro("tokens", &tle::ExternalDraftTokensConfig::getTokens) + .def_prop_ro("logits", &tle::ExternalDraftTokensConfig::getLogits) + .def_prop_ro("acceptance_threshold", &tle::ExternalDraftTokensConfig::getAcceptanceThreshold) + .def("__getstate__", externalDraftTokensConfigGetstate) + .def("__setstate__", externalDraftTokensConfigSetstate) + .def_prop_ro("fast_logits", &tle::ExternalDraftTokensConfig::getFastLogits); + + auto promptTuningConfigGetstate = [](tle::PromptTuningConfig const& self) + { return nb::make_tuple(self.getEmbeddingTable(), self.getInputTokenExtraIds()); }; + auto promptTuningConfigSetstate = [](tle::PromptTuningConfig& promptTuningConfig, nb::tuple const& state) + { + if (state.size() != 2) + { + throw std::runtime_error("Invalid PromptTuningConfig state!"); + } + new (&promptTuningConfig) + tle::PromptTuningConfig(nb::cast(state[0]), nb::cast>(state[1])); + }; + nb::class_(m, "PromptTuningConfig") + .def(nb::init>(), nb::arg("embedding_table"), + nb::arg("input_token_extra_ids") = nb::none()) + .def_prop_ro("embedding_table", &tle::PromptTuningConfig::getEmbeddingTable) + .def_prop_ro("input_token_extra_ids", &tle::PromptTuningConfig::getInputTokenExtraIds) + .def("__getstate__", promptTuningConfigGetstate) + .def("__setstate__", promptTuningConfigSetstate); + + auto loraConfigGetstate = [](tle::LoraConfig const& self) + { return nb::make_tuple(self.getTaskId(), self.getWeights(), self.getConfig()); }; + auto loraConfigSetstate = [](tle::LoraConfig& loraConfig, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid LoraConfig state!"); + } + new (&loraConfig) tle::LoraConfig(nb::cast(state[0]), nb::cast>(state[1]), + nb::cast>(state[2])); + }; + nb::class_(m, "LoraConfig") + .def(nb::init, std::optional>(), nb::arg("task_id"), + nb::arg("weights") = nb::none(), nb::arg("config") = nb::none()) + .def_prop_ro("task_id", &tle::LoraConfig::getTaskId) + .def_prop_ro("weights", &tle::LoraConfig::getWeights) + .def_prop_ro("config", &tle::LoraConfig::getConfig) + .def("__getstate__", loraConfigGetstate) + .def("__setstate__", loraConfigSetstate); + + auto multimodalInputGetstate = [](tle::MultimodalInput const& self) + { return nb::make_tuple(self.getMultimodalHashes(), self.getMultimodalPositions(), self.getMultimodalLengths()); }; + auto multimodalInputSetstate = [](tle::MultimodalInput& multimodalInput, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid MultimodalInput state!"); + } + new (&multimodalInput) tle::MultimodalInput(nb::cast>>(state[0]), + nb::cast>(state[1]), nb::cast>(state[2])); + }; + nb::class_(m, "MultimodalInput") + .def(nb::init>, std::vector, std::vector>(), + nb::arg("multimodal_hashes"), nb::arg("multimodal_positions"), nb::arg("multimodal_lengths")) + .def_prop_ro("multimodal_hashes", &tle::MultimodalInput::getMultimodalHashes) + .def_prop_ro("multimodal_positions", &tle::MultimodalInput::getMultimodalPositions) + .def_prop_ro("multimodal_lengths", &tle::MultimodalInput::getMultimodalLengths) + .def("__getstate__", multimodalInputGetstate) + .def("__setstate__", multimodalInputSetstate); + + 
auto MropeConfigGetstate = [](tle::MropeConfig const& self) + { return nb::make_tuple(self.getMRopeRotaryCosSin(), self.getMRopePositionDeltas()); }; + auto MropeConfigSetstate = [](tle::MropeConfig& mropeConfig, nb::tuple const& state) + { + if (state.size() != 2) + { + throw std::runtime_error("Invalid MropeConfig state!"); + } + new (&mropeConfig) tle::MropeConfig(nb::cast(state[0]), nb::cast(state[1])); + }; + nb::class_(m, "MropeConfig") + .def(nb::init(), nb::arg("mrope_rotary_cos_sin"), nb::arg("mrope_position_deltas")) + .def_prop_ro("mrope_rotary_cos_sin", &tle::MropeConfig::getMRopeRotaryCosSin) + .def_prop_ro("mrope_position_deltas", &tle::MropeConfig::getMRopePositionDeltas) + .def("__getstate__", MropeConfigGetstate) + .def("__setstate__", MropeConfigSetstate); + + auto lookaheadDecodingConfigGetstate = [](tle::LookaheadDecodingConfig const& self) + { return nb::make_tuple(self.getWindowSize(), self.getNgramSize(), self.getVerificationSetSize()); }; + auto lookaheadDecodingConfigSetstate + = [](tle::LookaheadDecodingConfig& lookaheadDecodingConfig, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid LookaheadDecodingConfig state!"); + } + new (&lookaheadDecodingConfig) tle::LookaheadDecodingConfig( + nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2])); + }; + nb::class_(m, "LookaheadDecodingConfig") + .def(nb::init(), nb::arg("max_window_size"), nb::arg("max_ngram_size"), + nb::arg("max_verification_set_size")) + .def_prop_ro("max_window_size", &tle::LookaheadDecodingConfig::getWindowSize) + .def_prop_ro("max_ngram_size", &tle::LookaheadDecodingConfig::getNgramSize) + .def_prop_ro("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize) + .def("calculate_speculative_resource", &tle::LookaheadDecodingConfig::calculateSpeculativeResource) + .def_static( + "calculate_speculative_resource_tuple", &tle::LookaheadDecodingConfig::calculateSpeculativeResourceTuple) + .def("__getstate__", lookaheadDecodingConfigGetstate) + .def("__setstate__", lookaheadDecodingConfigSetstate) + .def_static("get_default_lookahead_decoding_window", + []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingWindow; }) + .def_static("get_default_lookahead_decoding_ngram", + []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingNgram; }) + .def_static("get_default_lookahead_decoding_verification_set", + []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingVerificationSet; }); + + auto TokenRangeRetentionConfigGetstate = [](tle::KvCacheRetentionConfig::TokenRangeRetentionConfig const& self) + { return nb::make_tuple(self.tokenStart, self.tokenEnd, self.priority, self.durationMs); }; + auto TokenRangeRetentionConfigSetstate + = [](tle::KvCacheRetentionConfig::TokenRangeRetentionConfig& tokenRangeRetentionConfig, nb::tuple const& state) + { + if (state.size() != 4) + { + throw std::runtime_error("Invalid state!"); + } + new (&tokenRangeRetentionConfig) tle::KvCacheRetentionConfig::TokenRangeRetentionConfig( + nb::cast(state[0]), nb::cast>(state[1]), + nb::cast(state[2]), nb::cast>(state[3])); + }; + auto kvCacheRetentionConfigGetstate = [](tle::KvCacheRetentionConfig const& self) + { + return nb::make_tuple(self.getTokenRangeRetentionConfigs(), self.getDecodeRetentionPriority(), + self.getDecodeDurationMs(), self.getTransferMode(), self.getDirectory()); + }; + auto kvCacheRetentionConfigSetstate + = [](tle::KvCacheRetentionConfig& kvCacheRetentionConfig, nb::tuple const& state) + { + if 
(state.size() != 5)
+        {
+            throw std::runtime_error("Invalid KvCacheRetentionConfig state!");
+        }
+        new (&kvCacheRetentionConfig) tle::KvCacheRetentionConfig(
+            nb::cast>(state[0]),
+            nb::cast(state[1]), nb::cast>(state[2]),
+            nb::cast(state[3]), nb::cast>(state[4]));
+    };
+
+    auto kvCacheRetentionConfig = nb::class_(m, "KvCacheRetentionConfig");
+
+    nb::class_(
+        kvCacheRetentionConfig, "TokenRangeRetentionConfig")
+        .def(nb::init, tle::RetentionPriority,
+                 std::optional>(),
+            nb::arg("token_start"), nb::arg("token_end"), nb::arg("priority"), nb::arg("duration_ms") = nb::none())
+        .def_rw("token_start", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenStart)
+        .def_rw("token_end", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenEnd)
+        .def_rw("priority", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::priority)
+        .def_rw("duration_ms", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::durationMs)
+        .def("__getstate__", TokenRangeRetentionConfigGetstate)
+        .def("__setstate__", TokenRangeRetentionConfigSetstate)
+        .def("__eq__", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::operator==);
+
+    // There's a circular dependency between the declaration of the TokenRangeRetentionConfig and
+    // KvCacheRetentionConfig bindings. Defer definition of the KvCacheRetentionConfig bindings until the
+    // TokenRangeRetentionConfig bindings have been defined.
+    kvCacheRetentionConfig
+        .def(nb::init, tle::RetentionPriority,
+                 std::optional, tle::KvCacheTransferMode, std::optional>(),
+            nb::arg("token_range_retention_configs"),
+            nb::arg("decode_retention_priority") = tle::KvCacheRetentionConfig::kDefaultRetentionPriority,
+            nb::arg("decode_duration_ms") = nb::none(), nb::arg("transfer_mode") = tle::KvCacheTransferMode::DRAM,
+            nb::arg("directory") = nb::none())
+        .def_prop_ro("token_range_retention_configs", &tle::KvCacheRetentionConfig::getTokenRangeRetentionConfigs)
+        .def_prop_ro("decode_retention_priority", &tle::KvCacheRetentionConfig::getDecodeRetentionPriority)
+        .def_prop_ro("decode_duration_ms", &tle::KvCacheRetentionConfig::getDecodeDurationMs)
+        .def_prop_ro("transfer_mode", &tle::KvCacheRetentionConfig::getTransferMode)
+        .def_prop_ro("directory", &tle::KvCacheRetentionConfig::getDirectory)
+        .def("__getstate__", kvCacheRetentionConfigGetstate)
+        .def("__setstate__", kvCacheRetentionConfigSetstate)
+        .def("__eq__", &tle::KvCacheRetentionConfig::operator==);
+
+    auto ContextPhaseParamsGetState = [](tle::ContextPhaseParams const& self)
+    {
+        if (self.getState() != nullptr)
+        {
+            auto serializedState = self.getSerializedState();
+            return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(),
+                nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens());
+        }
+        return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens());
+    };
+
+    auto ContextPhaseParamsSetState = [](tle::ContextPhaseParams& contextPhaseParams, nb::tuple const& state)
+    {
+        if (state.size() != 4)
+        {
+            throw std::runtime_error("Invalid ContextPhaseParams state!");
+        }
+        if (!state[2].is_none())
+        {
+            auto opaque_state = nb::cast(state[2]);
+            auto opaque_state_str_view = std::string_view(opaque_state.c_str(), opaque_state.size());
+            new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast(state[0]),
+                nb::cast(state[1]),
+                std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()),
+                nb::cast>(state[3]));
+            // Return here: falling through would re-construct the object and drop the opaque state.
+            return;
+        }
+        new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast(state[0]),
+            nb::cast(state[1]), nb::cast>(state[3]));
+ }; + + nb::class_(m, "ContextPhaseParams") + .def("__init__", + [](tle::ContextPhaseParams const& self, VecTokens const& first_gen_tokens, + tle::ContextPhaseParams::RequestIdType req_id, std::optional const& opaque_state, + std::optional const& draft_tokens) + { + if (opaque_state) + { + auto opaque_state_str_view + = std::string_view(opaque_state.value().c_str(), opaque_state.value().size()); + return std::make_unique(first_gen_tokens, req_id, + std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens); + } + return std::make_unique(first_gen_tokens, req_id, draft_tokens); + }) + .def_prop_ro("first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); }) + .def_prop_ro("draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); }) + .def_prop_ro("req_id", &tle::ContextPhaseParams::getReqId) + .def_prop_ro("opaque_state", + [](tle::ContextPhaseParams const& self) + { + std::optional opaque_state{std::nullopt}; + if (self.getState() != nullptr) + { + auto serializedState = self.getSerializedState(); + opaque_state = nb::bytes(serializedState.data(), serializedState.size()); + } + return opaque_state; + }) + .def("__getstate__", ContextPhaseParamsGetState) + .def("__setstate__", ContextPhaseParamsSetState); + + auto EagleDecodingConfigGetstate = [](tle::EagleConfig const& self) + { + return nb::make_tuple(self.getEagleChoices(), self.isGreedySampling(), self.getPosteriorThreshold(), + self.useDynamicTree(), self.getDynamicTreeMaxTopK()); + }; + auto EagleDecodingConfigSetstate = [](tle::EagleConfig& eagleConfig, nb::tuple const& state) + { + if (state.size() != 5) + { + throw std::runtime_error("Invalid EagleConfig state!"); + } + new (&eagleConfig) tle::EagleConfig(nb::cast>(state[0]), + nb::cast(state[1]), nb::cast>(state[2]), nb::cast(state[3]), + nb::cast>(state[4])); + }; + nb::class_(m, "EagleConfig") + .def(nb::init, bool, std::optional, bool, std::optional>(), + nb::arg("eagle_choices") = nb::none(), nb::arg("greedy_sampling") = true, + nb::arg("posterior_threshold") = nb::none(), nb::arg("use_dynamic_tree") = false, + nb::arg("dynamic_tree_max_topK") = nb::none()) + .def_prop_ro("eagle_choices", &tle::EagleConfig::getEagleChoices) + .def_prop_ro("greedy_sampling", &tle::EagleConfig::isGreedySampling) + .def_prop_ro("posterior_threshold", &tle::EagleConfig::getPosteriorThreshold) + .def_prop_ro("use_dynamic_tree", &tle::EagleConfig::useDynamicTree) + .def_prop_ro("dynamic_tree_max_topK", &tle::EagleConfig::getDynamicTreeMaxTopK) + .def("__getstate__", EagleDecodingConfigGetstate) + .def("__setstate__", EagleDecodingConfigSetstate); + + // Guided decoding params + auto pyGuidedDecodingParams = nb::class_(m, "GuidedDecodingParams"); + + nb::enum_(pyGuidedDecodingParams, "GuideType") + .value("JSON", tle::GuidedDecodingParams::GuideType::kJSON) + .value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA) + .value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX) + .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR) + .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG); + + auto guidedDecodingParamsGetstate + = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide()); }; + + auto guidedDecodingParamsSetstate = [](tle::GuidedDecodingParams& guidedDecodingParams, nb::tuple const& state) + { + if (state.size() != 2) + { + throw std::runtime_error("Invalid GuidedDecodingParams 
state!"); + } + new (&guidedDecodingParams) tle::GuidedDecodingParams( + nb::cast(state[0]), nb::cast>(state[1])); + }; + + pyGuidedDecodingParams + .def(nb::init>(), nb::arg("guide_type"), + nb::arg("guide") = nb::none()) + .def_prop_ro("guide_type", &tle::GuidedDecodingParams::getGuideType) + .def_prop_ro("guide", &tle::GuidedDecodingParams::getGuide) + .def("__getstate__", guidedDecodingParamsGetstate) + .def("__setstate__", guidedDecodingParamsSetstate); + + auto requestGetstate = [](tle::Request const& self) + { + return nb::make_tuple(self.getInputTokenIds(), self.getMaxTokens(), self.getStreaming(), + self.getSamplingConfig(), self.getOutputConfig(), self.getEndId(), self.getPadId(), self.getPositionIds(), + self.getBadWords(), self.getStopWords(), self.getEmbeddingBias(), self.getExternalDraftTokensConfig(), + self.getPromptTuningConfig(), self.getMultimodalInput(), self.getMultimodalEmbedding(), + self.getMropeConfig(), self.getLoraConfig(), self.getLookaheadConfig(), self.getKvCacheRetentionConfig(), + self.getLogitsPostProcessorName(), self.getLogitsPostProcessor(), self.getEncoderInputTokenIds(), + self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(), + self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(), + self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(), + self.getGuidedDecodingParams()); + }; + auto requestSetstate = [](tle::Request& request, nb::tuple const& state) + { + if (state.size() != 33) + { + throw std::runtime_error("Invalid Request state!"); + } + new (&request) tle::Request(nb::cast(state[0]), nb::cast(state[1]), + nb::cast(state[2]), nb::cast(state[3]), nb::cast(state[4]), + nb::cast>(state[5]), nb::cast>(state[6]), + nb::cast>>(state[7]), + nb::cast>>(state[8]), + nb::cast>>(state[9]), nb::cast>(state[10]), + nb::cast>(state[11]), + nb::cast>(state[12]), + nb::cast>(state[13]), nb::cast>(state[14]), + nb::cast>(state[15]), nb::cast>(state[16]), + nb::cast>(state[17]), + nb::cast>(state[18]), + nb::cast>(state[19]), + nb::cast>(state[20]), nb::cast>(state[21]), + nb::cast>(state[22]), nb::cast(state[23]), + nb::cast(state[24]), nb::cast(state[25]), + nb::cast>(state[26]), + nb::cast>(state[27]), nb::cast>(state[28]), + nb::cast>(state[29]), 1, nb::cast>(state[30]), + nb::cast>(state[31]), + nb::cast>(state[32])); + }; + + nb::class_ request(m, "Request", nb::dynamic_attr()); + request + .def(nb::init const&, // endId + std::optional const&, // padId + std::optional>, // positionIds + std::optional>, // badWords + std::optional>, // stopWords + std::optional, // embeddingBias + std::optional, // externalDraftTokensConfig + std::optional, // pTuningConfig + std::optional, // multimodalInput + std::optional, // multimodalEmbedding + std::optional, // mRopeConfig + std::optional, // loraConfig + std::optional, // lookaheadConfig + std::optional, // kvCacheRetentionConfig + std::optional, // logitsPostProcessorName + std::optional, // logitsPostProcessor + std::optional, // encoderInputTokenIds + std::optional, // clientId + bool, // returnAllGeneratedTokens + tle::PriorityType, // priority + tle::RequestType, // type + std::optional, // contextPhaseParams + std::optional, // encoderInputFeatures + std::optional, // encoderOutputLength + std::optional, // crossAttentionMask + SizeType32, // numReturnSequences + std::optional, // eagleConfig + std::optional, // skipCrossAttnBlocks + std::optional, // guidedDecodingParams + std::optional, // 
languageAdapterUid + std::optional // allottedTimeMs + >(), + // clang-format off + nb::arg("input_token_ids"), + nb::arg("max_tokens"), + nb::kw_only(), + nb::arg("streaming") = false, + nb::arg("sampling_config") = tle::SamplingConfig(), + nb::arg("output_config") = tle::OutputConfig(), + nb::arg("end_id") = nb::none(), + nb::arg("pad_id") = nb::none(), + nb::arg("position_ids") = nb::none(), + nb::arg("bad_words") = nb::none(), + nb::arg("stop_words") = nb::none(), + nb::arg("embedding_bias") = nb::none(), + nb::arg("external_draft_tokens_config") = nb::none(), + nb::arg("prompt_tuning_config") = nb::none(), + nb::arg("multimodal_input") = nb::none(), + nb::arg("multimodal_embedding") = nb::none(), + nb::arg("mrope_config") = nb::none(), + nb::arg("lora_config") = nb::none(), + nb::arg("lookahead_config") = nb::none(), + nb::arg("kv_cache_retention_config") = nb::none(), + nb::arg("logits_post_processor_name") = nb::none(), + nb::arg("logits_post_processor") = nb::none(), + nb::arg("encoder_input_token_ids") = nb::none(), + nb::arg("client_id") = nb::none(), + nb::arg("return_all_generated_tokens") = false, + nb::arg("priority") = tle::Request::kDefaultPriority, + nb::arg("type") = tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION, + nb::arg("context_phase_params") = nb::none(), + nb::arg("encoder_input_features") = nb::none(), + nb::arg("encoder_output_length") = nb::none(), + nb::arg("cross_attention_mask") = nb::none(), + nb::arg("num_return_sequences") = 1, + nb::arg("eagle_config") = nb::none(), + nb::arg("skip_cross_attn_blocks") = nb::none(), + nb::arg("guided_decoding_params") = nb::none(), + nb::arg("language_adapter_uid") = nb::none(), + nb::arg("allotted_time_ms") = nb::none() + ) // clang-format on + .def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds) + .def_prop_ro("max_tokens", &tle::Request::getMaxTokens) + .def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) + .def_prop_rw("sampling_config", &tle::Request::getSamplingConfig, &tle::Request::setSamplingConfig) + .def_prop_rw("output_config", &tle::Request::getOutputConfig, &tle::Request::setOutputConfig) + .def_prop_rw("end_id", &tle::Request::getEndId, &tle::Request::setEndId) + .def_prop_rw("pad_id", &tle::Request::getPadId, &tle::Request::setPadId) + .def_prop_rw("position_ids", &tle::Request::getPositionIds, &tle::Request::setPositionIds) + .def_prop_rw("bad_words", &tle::Request::getBadWords, &tle::Request::setBadWords) + .def_prop_rw("stop_words", &tle::Request::getStopWords, &tle::Request::setStopWords) + .def_prop_rw("embedding_bias", &tle::Request::getEmbeddingBias, &tle::Request::setEmbeddingBias) + .def_prop_rw("external_draft_tokens_config", &tle::Request::getExternalDraftTokensConfig, + &tle::Request::setExternalDraftTokensConfig) + .def_prop_rw("prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig) + .def_prop_rw("multimodal_input", &tle::Request::getMultimodalInput, &tle::Request::setMultimodalInput) + .def_prop_rw( + "multimodal_embedding", &tle::Request::getMultimodalEmbedding, &tle::Request::setMultimodalEmbedding) + .def_prop_rw("mrope_config", &tle::Request::getMropeConfig, &tle::Request::setMropeConfig) + .def_prop_rw("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig) + .def_prop_rw("lookahead_config", &tle::Request::getLookaheadConfig, &tle::Request::setLookaheadConfig) + .def_prop_rw("kv_cache_retention_config", &tle::Request::getKvCacheRetentionConfig, + 
&tle::Request::setKvCacheRetentionConfig) + .def_prop_rw("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName, + &tle::Request::setLogitsPostProcessorName) + .def_prop_rw( + "logits_post_processor", &tle::Request::getLogitsPostProcessor, &tle::Request::setLogitsPostProcessor) + .def_prop_rw( + "encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds) + .def_prop_rw("client_id", &tle::Request::getClientId, &tle::Request::setClientId) + .def_prop_rw("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens, + &tle::Request::setReturnAllGeneratedTokens) + .def_prop_rw("request_type", &tle::Request::getRequestType, &tle::Request::setRequestType) + .def_prop_rw( + "encoder_input_features", &tle::Request::getEncoderInputFeatures, &tle::Request::setEncoderInputFeatures) + .def_prop_rw("cross_attention_mask", &tle::Request::getCrossAttentionMask, &tle::Request::setCrossAttentionMask) + .def_prop_rw("eagle_config", &tle::Request::getEagleConfig, &tle::Request::setEagleConfig) + .def_prop_rw( + "skip_cross_attn_blocks", &tle::Request::getSkipCrossAttnBlocks, &tle::Request::setSkipCrossAttnBlocks) + .def_prop_rw( + "guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams) + .def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs) + .def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams) + .def("__getstate__", requestGetstate) + .def("__setstate__", requestSetstate); + request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName; + + nb::class_(m, "SpeculativeDecodingFastLogitsInfo") + .def(nb::init<>()) + .def_rw("draft_request_id", &tle::SpeculativeDecodingFastLogitsInfo::draftRequestId) + .def_rw("draft_participant_id", &tle::SpeculativeDecodingFastLogitsInfo::draftParticipantId) + .def("to_tensor", &tle::SpeculativeDecodingFastLogitsInfo::toTensor); + + auto requestPerfMetrics = nb::class_(m, "RequestPerfMetrics"); + + auto timingMetricsGetstate = [](tle::RequestPerfMetrics::TimingMetrics const& self) + { + return nb::make_tuple(self.arrivalTime, self.firstScheduledTime, self.firstTokenTime, self.lastTokenTime, + self.kvCacheTransferStart, self.kvCacheTransferEnd, self.kvCacheSize); + }; + auto timingMetricsSetstate = [](tle::RequestPerfMetrics::TimingMetrics& timingMetrics, nb::tuple const& state) + { + if (state.size() != 7) + { + throw std::runtime_error("Invalid TimingMetrics state!"); + } + new (&timingMetrics) + tle::RequestPerfMetrics::TimingMetrics{nb::cast(state[0]), + nb::cast(state[1]), + nb::cast(state[2]), + nb::cast(state[3]), + nb::cast(state[4]), + nb::cast(state[5]), nb::cast(state[6])}; + }; + nb::class_(m, "TimingMetrics") + .def(nb::init<>()) + .def_rw("arrival_time", &tle::RequestPerfMetrics::TimingMetrics::arrivalTime) + .def_rw("first_scheduled_time", &tle::RequestPerfMetrics::TimingMetrics::firstScheduledTime) + .def_rw("first_token_time", &tle::RequestPerfMetrics::TimingMetrics::firstTokenTime) + .def_rw("last_token_time", &tle::RequestPerfMetrics::TimingMetrics::lastTokenTime) + .def_rw("kv_cache_transfer_start", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferStart) + .def_rw("kv_cache_transfer_end", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferEnd) + .def_rw("kv_cache_size", &tle::RequestPerfMetrics::TimingMetrics::kvCacheSize) + .def("__getstate__", timingMetricsGetstate) + .def("__setstate__", 
timingMetricsSetstate); + + auto kvCacheMetricsGetstate = [](tle::RequestPerfMetrics::KvCacheMetrics const& self) + { + return nb::make_tuple(self.numTotalAllocatedBlocks, self.numNewAllocatedBlocks, self.numReusedBlocks, + self.numMissedBlocks, self.kvCacheHitRate); + }; + auto kvCacheMetricsSetstate = [](tle::RequestPerfMetrics::KvCacheMetrics& kvCacheMetrics, nb::tuple const& state) + { + if (state.size() != 5) + { + throw std::runtime_error("Invalid KvCacheMetrics state!"); + } + new (&kvCacheMetrics) + tle::RequestPerfMetrics::KvCacheMetrics{nb::cast(state[0]), nb::cast(state[1]), + nb::cast(state[2]), nb::cast(state[3]), nb::cast(state[4])}; + }; + nb::class_(m, "KvCacheMetrics") + .def(nb::init<>()) + .def_rw("num_total_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numTotalAllocatedBlocks) + .def_rw("num_new_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numNewAllocatedBlocks) + .def_rw("num_reused_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numReusedBlocks) + .def_rw("num_missed_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numMissedBlocks) + .def_rw("kv_cache_hit_rate", &tle::RequestPerfMetrics::KvCacheMetrics::kvCacheHitRate) + .def("__getstate__", kvCacheMetricsGetstate) + .def("__setstate__", kvCacheMetricsSetstate); + + auto speculativeDecodingMetricsGetstate = [](tle::RequestPerfMetrics::SpeculativeDecodingMetrics const& self) + { return nb::make_tuple(self.acceptanceRate, self.totalAcceptedDraftTokens, self.totalDraftTokens); }; + auto speculativeDecodingMetricsSetstate + = [](tle::RequestPerfMetrics::SpeculativeDecodingMetrics& speculativeDecodingMetrics, nb::tuple const& state) + { + if (state.size() != 3) + { + throw std::runtime_error("Invalid SpeculativeDecodingMetrics state!"); + } + new (&speculativeDecodingMetrics) tle::RequestPerfMetrics::SpeculativeDecodingMetrics{ + nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2])}; + }; + + nb::class_(m, "SpeculativeDecodingMetrics") + .def(nb::init<>()) + .def_rw("acceptance_rate", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::acceptanceRate) + .def_rw("total_accepted_draft_tokens", + &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalAcceptedDraftTokens) + .def_rw("total_draft_tokens", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalDraftTokens) + .def("__getstate__", speculativeDecodingMetricsGetstate) + .def("__setstate__", speculativeDecodingMetricsSetstate); + + auto requestPerfMetricsGetstate = [](tle::RequestPerfMetrics const& self) + { + return nb::make_tuple(self.timingMetrics, self.kvCacheMetrics, self.speculativeDecoding, self.firstIter, + self.lastIter, self.iter); + }; + auto requestPerfMetricsSetstate = [](tle::RequestPerfMetrics& requestPerfMetrics, nb::tuple const& state) + { + if (state.size() != 6) + { + throw std::runtime_error("Invalid RequestPerfMetrics state!"); + } + new (&requestPerfMetrics) tle::RequestPerfMetrics{nb::cast(state[0]), + nb::cast(state[1]), + nb::cast(state[2]), + nb::cast>(state[3]), + nb::cast>(state[4]), + nb::cast>(state[5])}; + }; + + // There's a circular dependency between the declaration of the TimingMetrics and RequestPerfMetrics bindings. + // Defer definition of the RequestPerfMetrics bindings until the TimingMetrics have been defined. 
+    requestPerfMetrics.def(nb::init<>())
+        .def_rw("timing_metrics", &tle::RequestPerfMetrics::timingMetrics)
+        .def_rw("kv_cache_metrics", &tle::RequestPerfMetrics::kvCacheMetrics)
+        .def_rw("speculative_decoding", &tle::RequestPerfMetrics::speculativeDecoding)
+        .def_rw("first_iter", &tle::RequestPerfMetrics::firstIter)
+        .def_rw("last_iter", &tle::RequestPerfMetrics::lastIter)
+        .def_rw("iter", &tle::RequestPerfMetrics::iter)
+        .def("__getstate__", requestPerfMetricsGetstate)
+        .def("__setstate__", requestPerfMetricsSetstate);
+
+    nb::class_(m, "AdditionalOutput")
+        .def("__init__",
+            [](tle::AdditionalOutput const& self, std::string const& name, tle::Tensor const& output)
+            { return std::make_unique(name, output); })
+        .def_rw("name", &tle::AdditionalOutput::name)
+        .def_rw("output", &tle::AdditionalOutput::output);
+
+    auto resultSetstate = [](tle::Result& result, nb::tuple const& state)
+    {
+        if (state.size() != 13)
+        {
+            throw std::runtime_error("Invalid Result state!");
+        }
+        new (&result) tle::Result();
+        result.isFinal = nb::cast(state[0]);
+        result.outputTokenIds = nb::cast>(state[1]);
+        result.cumLogProbs = nb::cast>>(state[2]);
+        result.logProbs = nb::cast>>>(state[3]);
+        result.contextLogits = nb::cast>(state[4]);
+        result.generationLogits = nb::cast>(state[5]);
+        result.encoderOutput = nb::cast>(state[6]);
+        result.finishReasons = nb::cast>(state[7]);
+        result.sequenceIndex = nb::cast(state[8]);
+        result.isSequenceFinal = nb::cast(state[9]);
+        result.decodingIter = nb::cast(state[10]);
+        result.contextPhaseParams = nb::cast>(state[11]);
+        result.requestPerfMetrics = nb::cast>(state[12]);
+    };
+
+    auto resultGetstate = [](tle::Result const& self)
+    {
+        return nb::make_tuple(self.isFinal, self.outputTokenIds, self.cumLogProbs, self.logProbs, self.contextLogits,
+            self.generationLogits, self.encoderOutput, self.finishReasons, self.sequenceIndex, self.isSequenceFinal,
+            self.decodingIter, self.contextPhaseParams, self.requestPerfMetrics);
+    };
+
+    nb::class_(m, "Result")
+        .def(nb::init<>())
+        .def_rw("is_final", &tle::Result::isFinal)
+        .def_rw("output_token_ids", &tle::Result::outputTokenIds)
+        .def_rw("cum_log_probs", &tle::Result::cumLogProbs)
+        .def_rw("log_probs", &tle::Result::logProbs)
+        .def_rw("context_logits", &tle::Result::contextLogits)
+        .def_rw("generation_logits", &tle::Result::generationLogits)
+        .def_rw("spec_dec_fast_logits_info", &tle::Result::specDecFastLogitsInfo)
+        .def_rw("encoder_output", &tle::Result::encoderOutput)
+        .def_rw("finish_reasons", &tle::Result::finishReasons)
+        .def_rw("sequence_index", &tle::Result::sequenceIndex)
+        .def_rw("is_sequence_final", &tle::Result::isSequenceFinal)
+        .def_rw("decoding_iter", &tle::Result::decodingIter)
+        .def_rw("context_phase_params", &tle::Result::contextPhaseParams)
+        .def_rw("request_perf_metrics", &tle::Result::requestPerfMetrics)
+        .def_rw("additional_outputs", &tle::Result::additionalOutputs)
+        .def("__getstate__", resultGetstate)
+        .def("__setstate__", resultSetstate);
+
+    m.def("deserialize_result",
+        [](nb::bytes& x)
+        {
+            std::string str(x.c_str(), x.size());
+            std::istringstream is(str);
+            return tle::serialize_utils::deserialize(is);
+        });
+
+    auto responseGetstate = [](tle::Response const& self)
+    { return nb::make_tuple(self.getRequestId(), self.getResult(), self.getClientId()); };
+
+    auto responseSetstate = [](tle::Response& response, nb::tuple const& state)
+    {
+        if (state.size() != 3)
+        {
+            throw std::runtime_error("Invalid Response state!");
+        }
+        new (&response) tle::Response(
nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2])); + }; + + nb::class_(m, "Response") + .def(nb::init>(), nb::arg("request_id"), nb::arg("error_msg"), + nb::arg("client_id") = std::nullopt) + .def(nb::init>(), nb::arg("request_id"), nb::arg("result"), + nb::arg("client_id") = std::nullopt) + .def_prop_ro("request_id", &tle::Response::getRequestId) + .def_prop_ro("client_id", &tle::Response::getClientId) + .def("has_error", &tle::Response::hasError) + .def_prop_ro("error_msg", &tle::Response::getErrorMsg) + .def_prop_ro("result", &tle::Response::getResult) + .def("clear_context_logits", + [](tle::Response& self) + { + if (!self.hasError()) + { + auto& result = const_cast(self.getResult()); + result.contextLogits.reset(); + } + }) + .def("clear_generation_logits", + [](tle::Response& self) + { + if (!self.hasError()) + { + auto& result = const_cast(self.getResult()); + result.generationLogits.reset(); + } + }) + .def("__getstate__", responseGetstate) + .def("__setstate__", responseSetstate); +} + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/executor/request.h b/cpp/tensorrt_llm/nanobind/executor/request.h new file mode 100644 index 00000000000..5a5cf9acbee --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/executor/request.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::executor +{ + +// Register bindings for executor API. +void initRequestBindings(nb::module_& m); + +} // namespace tensorrt_llm::nanobind::executor diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp new file mode 100644 index 00000000000..f3be85bbbf2 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -0,0 +1,388 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "bindings.h" +#include "moeBindings.h" +#include "tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h" +#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h" +#include "tensorrt_llm/kernels/customAllReduceKernels.h" +#include "tensorrt_llm/kernels/delayStream.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/cudaEvent.h" +#include "tensorrt_llm/runtime/cudaStream.h" +#include "tensorrt_llm/runtime/decoderState.h" +#include "tensorrt_llm/runtime/decodingInput.h" +#include "tensorrt_llm/runtime/decodingOutput.h" +#include "tensorrt_llm/runtime/gptDecoder.h" +#include "tensorrt_llm/runtime/gptDecoderBatched.h" +#include "tensorrt_llm/runtime/iBuffer.h" +#include "tensorrt_llm/runtime/iGptDecoderBatched.h" +#include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/ipcUtils.h" +#include "tensorrt_llm/runtime/lookaheadBuffers.h" +#include "tensorrt_llm/runtime/loraCache.h" +#include "tensorrt_llm/runtime/mcastGPUBuffer.h" +#include "tensorrt_llm/runtime/request.h" +#include "tensorrt_llm/runtime/speculativeDecodingMode.h" +#include "tensorrt_llm/runtime/tllmRuntime.h" +#include "tensorrt_llm/runtime/torchView.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace tr = tensorrt_llm::runtime; +namespace te = tensorrt_llm::executor; + +class PyIGptDecoder : public tr::IGptDecoder +{ +public: + NB_TRAMPOLINE(tr::IGptDecoder, 5); + + void setup(tr::SamplingConfig const& samplingConfig, size_t batchSize, + tr::DecodingInput::TensorConstPtr const& batchSlots, + std::optional const& output = std::nullopt, + std::optional explicitDraftTokensDType = std::nullopt, + std::optional> const& lookaheadPrompt = std::nullopt, + std::optional> const& lookaheadAlgoConfigs = std::nullopt) override + { + NB_OVERRIDE_PURE(setup, samplingConfig, batchSize, batchSlots, output, explicitDraftTokensDType, + lookaheadPrompt, lookaheadAlgoConfigs); + } + + void forwardAsync(tr::DecodingOutput& output, tr::DecodingInput const& input) override + { + NB_OVERRIDE_PURE(forwardAsync, output, input); + } + + void forwardSync(tr::DecodingOutput& output, tr::DecodingInput const& input) override + { + NB_OVERRIDE_PURE(forwardSync, output, input); + } + + tr::SamplingConfig const& getSamplingConfig() override + { + NB_OVERRIDE_PURE(getSamplingConfig); + } + + void disableLookahead(std::optional const& samplingConfig, tr::SizeType32 batchSize, + tr::DecodingInput::TensorConstPtr batchSlots) override + { + NB_OVERRIDE_PURE(disableLookahead, samplingConfig, batchSize, batchSlots); + } +}; + +namespace tensorrt_llm::nanobind::runtime +{ + +void initBindings(nb::module_& m) +{ + + nb::class_(m, "TaskLayerModuleConfig") + .def(nb::init<>()) + .def_rw("page_id", &tr::LoraCache::TaskLayerModuleConfig::pageId) + .def_rw("slot_idx", &tr::LoraCache::TaskLayerModuleConfig::slotIdx) + .def_rw("in_size", &tr::LoraCache::TaskLayerModuleConfig::inSize) + .def_rw("out_size", &tr::LoraCache::TaskLayerModuleConfig::outSize) + .def_rw("module_id", &tr::LoraCache::TaskLayerModuleConfig::moduleId) + .def_rw("layer_id", &tr::LoraCache::TaskLayerModuleConfig::layerId) + .def_rw("adapter_size", &tr::LoraCache::TaskLayerModuleConfig::adapterSize) + .def_rw("num_slots", &tr::LoraCache::TaskLayerModuleConfig::numSlots) + .def_rw("weights_in_pointer", &tr::LoraCache::TaskLayerModuleConfig::weightsInPointer) + .def_rw("weights_out_pointer", 
&tr::LoraCache::TaskLayerModuleConfig::weightsOutPointer) + .def_rw("scaling_vec_pointer", &tr::LoraCache::TaskLayerModuleConfig::scalingVecPointer) + .def(nb::self == nb::self); + + nb::class_(m, "BufferManager") + .def(nb::init(), nb::arg("stream"), nb::arg("trim_pool") = false) + .def_prop_ro("stream", &tr::BufferManager::getStream); + + nb::class_(m, "TllmRuntime") + .def( + "__init__", + [](tr::TllmRuntime* self, std::filesystem::path engine_path, float gpu_weights_percent = 1.0f, + bool use_shape_inference = true) + { + // Using default logger by passing nullptr + new (self) + tr::TllmRuntime(tr::RawEngine(engine_path), nullptr, gpu_weights_percent, use_shape_inference); + }, + nb::arg("engine_path"), nb::arg("gpu_weights_percent") = 1.0f, nb::arg("use_shape_inference") = true) + .def( + "__init__", + [](tr::TllmRuntime* self, nb::ndarray engine_buffer, float gpu_weights_percent = 1.0f, + bool use_shape_inference = true) + { + if (engine_buffer.ndim() != 1) + throw std::runtime_error("Expected 1-D array for engine buffer"); + new (self) tr::TllmRuntime(tr::RawEngine(engine_buffer.data(), engine_buffer.size()), nullptr, + gpu_weights_percent, use_shape_inference); + }, + nb::arg("engine_buffer"), nb::arg("gpu_weights_percent") = 1.0f, nb::arg("use_shape_inference") = true) + .def_prop_ro("num_contexts", &tr::TllmRuntime::getNbContexts) + .def_prop_ro("num_profiles", &tr::TllmRuntime::getNbProfiles) + .def("get_opt_profile_id", &tr::TllmRuntime::getOptProfileId, nb::arg("num_tokens"), nb::arg("split_points")) + .def("clear_contexts", &tr::TllmRuntime::clearContexts) + .def("execute_context", &tr::TllmRuntime::executeContext, nb::arg("context_id")) + .def_prop_ro("stream_ptr", &tr::TllmRuntime::getStreamPtr) + .def_prop_ro("buffer_manager", + static_cast(&tr::TllmRuntime::getBufferManager)) + .def("set_layer_profiler", &tr::TllmRuntime::setLayerProfiler) + .def("has_layer_profiler", &tr::TllmRuntime::hasLayerProfiler, nb::arg("context_id")) + .def_prop_ro("layer_profiler_info", &tr::TllmRuntime::getLayerProfileInfo) + .def("report_to_profiler", &tr::TllmRuntime::reportToProfiler, nb::arg("context_id")) + .def_prop_ro("logits_dtype_from_engine", + [](tr::TllmRuntime& self) { return self.getEngine().getTensorDataType("logits"); }); + + nb::class_(m, "Request") + .def(nb::init, + std::optional>(), + nb::arg("ids"), nb::arg("input_len"), nb::arg("max_new_tokens") = std::nullopt, + nb::arg("end_id") = std::nullopt) + .def_rw("ids", &tr::decoder_batch::Request::ids) + .def_rw("input_len", &tr::decoder_batch::Request::inputLen) + .def_rw("max_new_tokens", &tr::decoder_batch::Request::maxNewTokens) + .def_rw("end_id", &tr::decoder_batch::Request::endId) + .def_rw("draft_logits", &tr::decoder_batch::Request::draftLogits) + .def_rw("embedding_bias", &tr::decoder_batch::Request::embeddingBias) + .def_rw("bad_words_list", &tr::decoder_batch::Request::badWordsList) + .def_rw("stop_words_list", &tr::decoder_batch::Request::stopWordsList) + .def_rw("generated_tokens_per_engine_step", &tr::decoder_batch::Request::generatedTokensPerEngineStep) + .def_rw("medusa_paths", &tr::decoder_batch::Request::medusaPaths) + .def_rw("medusa_tree_ids", &tr::decoder_batch::Request::medusaTreeIds) + .def_rw("lookahead_runtime_config", &tr::decoder_batch::Request::lookaheadRuntimeConfig); + nb::bind_vector>(m, "RequestVector"); + + nb::class_(m, "DecoderBatchInput") + .def(nb::init>, tr::SizeType32>(), nb::arg("logits"), + nb::arg("max_decoding_engine_tokens")) + .def(nb::init>(), nb::arg("logits")) + .def_rw("logits", 
&tr::decoder_batch::Input::logits) + .def_rw("max_decoder_steps", &tr::decoder_batch::Input::maxDecoderSteps) + .def_rw("batch_slots", &tr::decoder_batch::Input::batchSlots); + + nb::class_(m, "LookaheadDecodingBuffers") + .def(nb::init(), nb::arg("max_num_sequences"), + nb::arg("max_tokens_per_step"), nb::arg("buffer_manager")) + .def_rw("generation_lengths", &tr::LookaheadDecodingBuffers::generationLengths) + .def_rw("position_offsets", &tr::LookaheadDecodingBuffers::positionOffsets) + .def_rw("packed_masks", &tr::LookaheadDecodingBuffers::packedMasks) + .def_rw("position_ids", &tr::LookaheadDecodingBuffers::positionIds); + + nb::class_(m, "ExplicitDraftTokensBuffersInputs") + .def("create", &tr::ExplicitDraftTokensBuffers::Inputs::create, nb::arg("max_num_sequences"), + nb::arg("runtime"), nb::arg("model_config"), nb::arg("world_config")) + .def_rw("temperatures", &tr::ExplicitDraftTokensBuffers::Inputs::temperatures) + .def_rw("position_ids_base", &tr::ExplicitDraftTokensBuffers::Inputs::positionIdsBase) + .def_rw("generation_lengths", &tr::ExplicitDraftTokensBuffers::Inputs::generationLengths) + .def_rw("random_data_sample", &tr::ExplicitDraftTokensBuffers::Inputs::randomDataSample) + .def_rw("random_data_validation", &tr::ExplicitDraftTokensBuffers::Inputs::randomDataValidation) + .def_rw("draft_tokens", &tr::ExplicitDraftTokensBuffers::Inputs::draftTokens) + .def_rw("draft_indices", &tr::ExplicitDraftTokensBuffers::Inputs::draftIndices) + .def_rw("draft_probs", &tr::ExplicitDraftTokensBuffers::Inputs::draftProbs) + .def_rw("packed_masks", &tr::ExplicitDraftTokensBuffers::Inputs::packedMasks) + .def_rw("position_ids", &tr::ExplicitDraftTokensBuffers::Inputs::positionIds) + .def_rw("max_gen_length_host", &tr::ExplicitDraftTokensBuffers::Inputs::maxGenLengthHost) + .def_rw("generation_lengths_host", &tr::ExplicitDraftTokensBuffers::Inputs::generationLengthsHost); + + nb::class_(m, "DecodingInput"); + nb::class_(m, "DecodingOutput"); + + nb::class_(m, "CudaEvent") + .def(nb::init(), nb::arg("flags") = cudaEventDisableTiming) + .def("synchronize", &tr::CudaEvent::synchronize); + + nb::class_(m, "IGptDecoder") + .def( + "setup", + [](tr::IGptDecoder& self, tr::SamplingConfig const& samplingConfig, size_t batchSize, + at::Tensor const& batchSlots, std::optional const& output = std::nullopt, + std::optional explicitDraftTokensDType = std::nullopt, + std::optional> const& lookaheadPrompt = std::nullopt, + std::optional> const& lookaheadAlgoConfigs = std::nullopt) + { + auto tensorPtrBatchSlots = tr::TorchView::of(batchSlots); + self.setup(samplingConfig, batchSize, std::move(tensorPtrBatchSlots), output, explicitDraftTokensDType, + lookaheadPrompt, lookaheadAlgoConfigs); + }, + nb::arg("sampling_config"), nb::arg("batch_size"), nb::arg("batch_slots"), nb::arg("output") = std::nullopt, + nb::arg("explicit_draft_tokens_d_type") = std::nullopt, nb::arg("lookahead_prompt") = std::nullopt, + nb::arg("lookahead_algo_configs") = std::nullopt); + + nb::class_(m, "DecoderState") + .def(nb::init<>()) + .def("setup", &tr::decoder::DecoderState::setup, nb::arg("max_batch_size"), nb::arg("max_beam_width"), + nb::arg("max_attention_window"), nb::arg("sink_token_length"), nb::arg("max_sequence_length"), + nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")) + .def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, nb::arg("max_batch_size"), + nb::arg("max_beam_width"), nb::arg("max_attention_window"), nb::arg("buffer_manager")) + 
.def("setup_speculative_decoding", &tr::decoder::DecoderState::setupSpeculativeDecoding, + nb::arg("speculative_decoding_mode"), nb::arg("max_tokens_per_engine_step"), nb::arg("dtype"), + nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager")) + .def_prop_ro("joint_decoding_input", &tr::decoder::DecoderState::getJointDecodingInput) + .def_prop_ro("joint_decoding_output", &tr::decoder::DecoderState::getJointDecodingOutput) + .def_prop_ro("cache_indirection_input", &tr::decoder::DecoderState::getCacheIndirectionInput) + .def_prop_ro("cache_indirection_output", &tr::decoder::DecoderState::getCacheIndirectionOutput) + .def_prop_ro( + "sequence_lengths", nb::overload_cast<>(&tr::decoder::DecoderState::getSequenceLengths, nb::const_)) + .def("get_sequence_lengths", + nb::overload_cast(&tr::decoder::DecoderState::getSequenceLengths, nb::const_), + nb::arg("batch_idx")) + .def_prop_ro("all_new_tokens", &tr::decoder::DecoderState::getAllNewTokens) + .def_prop_ro("finished_sum", &tr::decoder::DecoderState::getFinishedSum) + .def_prop_ro("finish_reasons", &tr::decoder::DecoderState::getFinishReasons) + .def_prop_ro("ids", nb::overload_cast<>(&tr::decoder::DecoderState::getIds, nb::const_)) + .def("get_ids", nb::overload_cast(&tr::decoder::DecoderState::getIds, nb::const_), + nb::arg("batch_idx")) + .def_prop_ro("gathered_ids", nb::overload_cast<>(&tr::decoder::DecoderState::getGatheredIds, nb::const_)) + .def("get_gathered_ids", + nb::overload_cast(&tr::decoder::DecoderState::getGatheredIds, nb::const_), + nb::arg("batch_idx")) + .def_prop_ro("parent_ids", &tr::decoder::DecoderState::getParentIds) + .def_prop_ro("cum_log_probs", nb::overload_cast<>(&tr::decoder::DecoderState::getCumLogProbs, nb::const_)) + .def("get_cum_log_probs", + nb::overload_cast(&tr::decoder::DecoderState::getCumLogProbs, nb::const_), + nb::arg("batch_idx")) + .def_prop_ro("log_probs", nb::overload_cast<>(&tr::decoder::DecoderState::getLogProbs, nb::const_)) + .def("get_log_probs", nb::overload_cast(&tr::decoder::DecoderState::getLogProbs, nb::const_), + nb::arg("batch_idx")) + .def_prop_ro("next_draft_tokens", &tr::decoder::DecoderState::getNextDraftTokens) + .def_prop_ro("prev_draft_tokens_lengths", &tr::decoder::DecoderState::getPrevDraftTokensLengths) + .def_prop_ro("next_draft_tokens_lengths", &tr::decoder::DecoderState::getNextDraftTokensLengths) + .def_prop_ro("accepted_lengths_cum_sum", &tr::decoder::DecoderState::getAcceptedLengthsCumSum) + .def_prop_ro("accepted_packed_paths", &tr::decoder::DecoderState::getAcceptedPackedPaths) + .def_prop_ro("finished_steps", &tr::decoder::DecoderState::getFinishedSteps) + .def_prop_ro("max_beam_width", &tr::decoder::DecoderState::getMaxBeamWidth) + .def_prop_ro("max_sequence_length", &tr::decoder::DecoderState::getMaxSequenceLength) + .def_prop_ro("max_decoding_decoder_tokens", &tr::decoder::DecoderState::getMaxDecodingDecoderTokens) + .def_prop_ro("max_decoding_engine_tokens", &tr::decoder::DecoderState::getMaxDecodingEngineTokens) + .def_prop_ro("num_decoding_engine_tokens", + nb::overload_cast<>(&tr::decoder::DecoderState::getNumDecodingEngineTokens, nb::const_)) + .def("get_num_decoding_engine_tokens", + nb::overload_cast(&tr::decoder::DecoderState::getNumDecodingEngineTokens, nb::const_), + nb::arg("batch_idx")) + .def("set_num_decoding_engine_tokens", &tr::decoder::DecoderState::setNumDecodingEngineTokens, + nb::arg("batch_idx"), nb::arg("num_tokens")) + .def_prop_ro("speculative_decoding_mode", &tr::decoder::DecoderState::getSpeculativeDecodingMode) + 
.def_prop_rw("generation_steps", &tr::decoder::DecoderState::getGenerationSteps, + &tr::decoder::DecoderState::setGenerationSteps); + + nb::class_(m, "GptDecoderBatched") + .def(nb::init(), nb::arg("stream")) + .def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_batch_size"), + nb::arg("max_beam_width"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config")) + .def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("output"), nb::arg("input")) + .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, nb::rv_policy::reference) + .def("finalize", &tr::GptDecoderBatched::finalize, nb::arg("decoder_state"), nb::arg("batch_idx"), + nb::arg("sampling_config"), nb::arg("streaming")) + .def_prop_ro( + "decoder_stream", + [](tr::GptDecoderBatched& self) -> tr::CudaStream const& { return *self.getDecoderStream(); }, + nb::rv_policy::reference); + + m.def( + "lamport_initialize_all", + [](intptr_t buffer_0, intptr_t buffer_1, intptr_t buffer_2, size_t size) + { + tr::lamportInitializeAll(reinterpret_cast(buffer_0), reinterpret_cast(buffer_1), + reinterpret_cast(buffer_2), size); + }, + "Lamport initialize all buffers"); + m.def( + "lamport_initialize", + [](intptr_t buffer, size_t size) + { tensorrt_llm::kernels::ar_fusion::lamport_initialize(reinterpret_cast(buffer), size, 0); }, + "Lmaport initialize buffer"); + m.def( + "delay_kernel", + [](int64_t delay_micro_secs, nb::object py_stream) + { + // Get the raw stream handle from PyTorch stream object + auto stream_ptr = nb::cast(py_stream.attr("cuda_stream")); + cudaStream_t stream = reinterpret_cast(stream_ptr); + tensorrt_llm::kernels::invokeDelayStreamKernel(delay_micro_secs, stream); + }, + "Delay kernel launch on the default stream"); + m.def( + "max_workspace_size_lowprecision", + [](int32_t tp_size) { return tensorrt_llm::kernels::max_workspace_size_lowprecision(tp_size); }, + "Calculate the maximum workspace size needed for low precision all-reduce operations"); + + nb::class_(m, "McastGPUBuffer") + .def(nb::init()) + .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer) + .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer); + + nb::enum_(m, "AllReduceFusionOp") + .value("NONE", tensorrt_llm::kernels::AllReduceFusionOp::NONE) + .value("RESIDUAL_RMS_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM) + .value("LAST_PROCESS_FOR_UB", tensorrt_llm::kernels::AllReduceFusionOp::LAST_PROCESS_FOR_UB) + .value("RESIDUAL_RMS_PREPOST_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_PREPOST_NORM) + .value("RESIDUAL_RMS_NORM_QUANT_FP8", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8) + .value("RESIDUAL_RMS_NORM_QUANT_NVFP4", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4) + .value("RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4", + tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4) + .value("RESIDUAL_RMS_NORM_OUT_QUANT_FP8", + tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8); + + nb::enum_(m, "AllReduceStrategy") + .value("NCCL", tensorrt_llm::kernels::AllReduceStrategyType::NCCL) + .value("MIN_LATENCY", tensorrt_llm::kernels::AllReduceStrategyType::MIN_LATENCY) + .value("AUTO", tensorrt_llm::kernels::AllReduceStrategyType::AUTO) + .value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB) + .value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT) + .value("TWOSHOT", 
tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT); + + // Initialize MoeLoadBalancer bindings + initMoeBindings(m); +} + +void initBindingsEarly(nb::module_& m) +{ + nb::class_(m, "SpeculativeDecodingMode") + .def(nb::init(), nb::arg("state")) + .def_static("NoneType", &tr::SpeculativeDecodingMode::None) + .def_static("DraftTokensExternal", &tr::SpeculativeDecodingMode::DraftTokensExternal) + .def_static("Medusa", &tr::SpeculativeDecodingMode::Medusa) + .def_static("Eagle", &tr::SpeculativeDecodingMode::Eagle) + .def_static("LookaheadDecoding", &tr::SpeculativeDecodingMode::LookaheadDecoding) + .def_static("ExplicitDraftTokens", &tr::SpeculativeDecodingMode::ExplicitDraftTokens) + .def_prop_ro("is_none", &tr::SpeculativeDecodingMode::isNone) + .def_prop_ro("is_draft_tokens_external", &tr::SpeculativeDecodingMode::isDraftTokensExternal) + .def_prop_ro("is_medusa", &tr::SpeculativeDecodingMode::isMedusa) + .def_prop_ro("is_eagle", &tr::SpeculativeDecodingMode::isEagle) + .def_prop_ro("is_lookahead_decoding", &tr::SpeculativeDecodingMode::isLookaheadDecoding) + .def_prop_ro("is_explicit_draft_tokens", &tr::SpeculativeDecodingMode::isExplicitDraftTokens) + .def_prop_ro("updates_position_ids", &tr::SpeculativeDecodingMode::updatesPositionIds) + .def_prop_ro("requires_attention_mask", &tr::SpeculativeDecodingMode::requiresAttentionMask) + .def_prop_ro("predicts_draft_tokens", &tr::SpeculativeDecodingMode::predictsDraftTokens) + .def_prop_ro("needs_kv_cache_rewind", &tr::SpeculativeDecodingMode::needsKVCacheRewind) + .def_prop_ro("variable_draft_length", &tr::SpeculativeDecodingMode::variableDraftLength) + .def_prop_ro("has_draft_logits", &tr::SpeculativeDecodingMode::hasDraftLogits) + .def_prop_ro("needs_decoder_prologue", &tr::SpeculativeDecodingMode::needsDecoderPrologue); +} +} // namespace tensorrt_llm::nanobind::runtime diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.h b/cpp/tensorrt_llm/nanobind/runtime/bindings.h new file mode 100644 index 00000000000..410dac80b05 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.h @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::nanobind::runtime +{ + +void initBindings(nb::module_& m); +void initBindingsEarly(nb::module_& m); + +} // namespace tensorrt_llm::nanobind::runtime diff --git a/cpp/tensorrt_llm/nanobind/runtime/moeBindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/moeBindings.cpp new file mode 100644 index 00000000000..c26fa84b661 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/runtime/moeBindings.cpp @@ -0,0 +1,124 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moeBindings.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/moeLoadBalancer/hostAccessibleDeviceAllocator.h" +#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h" +#include +#include +#include + +namespace nb = nanobind; +namespace tr = tensorrt_llm::runtime; +namespace tk = tensorrt_llm::kernels; + +namespace tensorrt_llm::nanobind::runtime +{ + +void pyDoReplication(tk::MoeLoadBalanceMetaInfo const& metaInfo, std::vector& expertLoadFactor, + tr::MoePlacementCpuInfo* cpuPlacement) +{ + TLLM_CHECK_WITH_INFO( + metaInfo.expertCount == expertLoadFactor.size(), "expert_count and expert_load_factor size mismatch"); + tr::doReplication(metaInfo, expertLoadFactor.data(), cpuPlacement); +}; + +void pyDoPlacement(tk::MoeLoadBalanceMetaInfo const& metaInfo, std::vector& expertLoadFactor, + tr::MoePlacementCpuInfo* cpuPlacement) +{ + TLLM_CHECK_WITH_INFO( + metaInfo.expertCount == expertLoadFactor.size(), "expert_count and expert_load_factor size mismatch"); + tr::doPlacement(metaInfo, expertLoadFactor.data(), cpuPlacement); +}; + +void initMoeBindings(nb::module_& m) +{ + // Bind MoeWeight struct + nb::class_(m, "MoeWeight") + .def(nb::init<>()) + .def_prop_rw("weight_ptr", &tr::MoeWeight::getWeightPtr, &tr::MoeWeight::setWeightPtr) + .def_rw("height", &tr::MoeWeight::mHeight) + .def_rw("width", &tr::MoeWeight::mWidth) + .def_rw("pitch", &tr::MoeWeight::mPitch) + .def("__repr__", + [](tr::MoeWeight const& self) + { + return ""; + }); + + // Bind MoeLoadBalanceMetaInfo struct + nb::class_(m, "MoeLoadBalanceMetaInfo") + .def(nb::init(), nb::arg("expert_count"), nb::arg("top_k"), nb::arg("ep_rank"), + nb::arg("ep_size"), nb::arg("slot_count_per_rank")) + .def_rw("expert_count", &tk::MoeLoadBalanceMetaInfo::expertCount) + .def_rw("top_k", &tk::MoeLoadBalanceMetaInfo::topK) + .def_rw("ep_rank", &tk::MoeLoadBalanceMetaInfo::epRank) + .def_rw("ep_size", &tk::MoeLoadBalanceMetaInfo::epSize) + .def_rw("slot_count_per_rank", &tk::MoeLoadBalanceMetaInfo::slotCountPerRank); + + // Bind MoePlacementCpuInfo struct + nb::class_(m, "MoePlacementCpuInfo") + .def(nb::init<>()) + .def_rw("expert_replica_count", &tr::MoePlacementCpuInfo::expertReplicaCount) + .def_rw("rank_expert_ids", &tr::MoePlacementCpuInfo::rankExpertIds); + + // Bind SingleLayerMoeLoadBalancer class + nb::class_(m, "SingleLayerMoeLoadBalancer") + .def("add_single_weight_slot", &tr::SingleLayerMoeLoadBalancer::addSingleWeightSlot, nb::arg("slot_id"), + nb::arg("name"), nb::arg("weight_slot"), "Add a single weight slot for a specific slot ID") + .def("add_single_host_weight", &tr::SingleLayerMoeLoadBalancer::addSingleHostWeight, nb::arg("expert_id"), + nb::arg("name"), nb::arg("host_weight"), "Add a single host weight for a specific expert ID") + .def("set_initial_weight_assignments", &tr::SingleLayerMoeLoadBalancer::setInitialWeightAssignments, + 
nb::arg("initial_weight_assignments"), "Set initial weight assignments for each slot") + .def("get_pointer", &tr::SingleLayerMoeLoadBalancer::getSelfPtr, + "Get the pointer of the SingleLayerMoeLoadBalancer") + .def("get_layer_id", &tr::SingleLayerMoeLoadBalancer::getLayerId, + "Get the layer id of the SingleLayerMoeLoadBalancer"); + + // Bind MoeLoadBalancer class + nb::class_(m, "MoeLoadBalancer") + .def(nb::init(), nb::arg("ep_rank"), nb::arg("ep_size"), nb::arg("layer_updates_per_iter"), + "Initialize the MoeLoadBalancer with the specified expert parallel rank, size, and update frequency") + .def("set_use_gpu_memcpy", &tr::MoeLoadBalancer::setUseGpuMemcpy, nb::arg("use_gpu_memcpy"), + "Set whether to use GPU memcpy for weight updates") + .def("add_layer", &tr::MoeLoadBalancer::AddLayer, nb::arg("expert_count"), nb::arg("top_k"), + nb::arg("slot_count_per_rank"), "Add a new MOE layer to the load balancer") + .def("finalize_model", &tr::MoeLoadBalancer::finalizeModel, + "Finalize the model structure, must be called after all layers are added") + .def("set_warm_up_iter_count", &tr::MoeLoadBalancer::setWarmUpIterCount, nb::arg("iter_count"), + "Set the number of warm-up iterations") + .def("start_iter", &tr::MoeLoadBalancer::startIter, nb::arg("iter_id"), nb::arg("enable_statistic"), + nb::arg("enable_update_weights"), "Start a new iteration with the given ID and settings") + .def("end_iter", &tr::MoeLoadBalancer::endIter, nb::arg("iter_id"), "End the iteration with the given ID") + .def("shutdown", &tr::MoeLoadBalancer::shutdown, "Shutdown the load balancer and clean up resources"); + + m.def("is_host_accessible_device_memory_supported", &tr::HostAccessibleDeviceAllocator::isSupported, + "If current system support host accessible device memory"); + + // Bind do_replication function for testing + m.def("do_replication", &pyDoReplication, nb::arg("meta_info"), nb::arg("expert_load_factor"), + nb::arg("cpu_placement"), "Do replication"); + + // Bind do_placement function for testing + m.def("do_placement", &pyDoPlacement, nb::arg("meta_info"), nb::arg("expert_load_factor"), nb::arg("cpu_placement"), + "Do placement"); +} + +} // namespace tensorrt_llm::nanobind::runtime diff --git a/cpp/tensorrt_llm/nanobind/runtime/moeBindings.h b/cpp/tensorrt_llm/nanobind/runtime/moeBindings.h new file mode 100644 index 00000000000..73b9a3ceec8 --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/runtime/moeBindings.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/cpp/tensorrt_llm/nanobind/runtime/moeBindings.h b/cpp/tensorrt_llm/nanobind/runtime/moeBindings.h
new file mode 100644
index 00000000000..73b9a3ceec8
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/runtime/moeBindings.h
@@ -0,0 +1,29 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <nanobind/nanobind.h>
+
+namespace nb = nanobind;
+
+namespace tensorrt_llm::nanobind::runtime
+{
+
+void initMoeBindings(nb::module_& m);
+
+} // namespace tensorrt_llm::nanobind::runtime
diff --git a/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.cpp b/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.cpp
new file mode 100644
index 00000000000..caef94c5def
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.cpp
@@ -0,0 +1,87 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "modelSpecBinding.h"
+#include "tensorrt_llm/nanobind/common/customCasters.h"
+#include "tensorrt_llm/testing/modelSpec.h"
+
+#include <nanobind/nanobind.h>
+
+namespace nb = nanobind;
+using tensorrt_llm::testing::ModelSpec;
+using tensorrt_llm::testing::KVCacheType;
+using tensorrt_llm::testing::QuantMethod;
+using tensorrt_llm::testing::OutputContentType;
+
+namespace tensorrt_llm::nanobind::testing
+{
+
+void initBindings(nb::module_& m)
+{
+    nb::enum_<QuantMethod>(m, "QuantMethod", nb::is_arithmetic(), "Quantization Method")
+        .value("NONE", QuantMethod::kNONE, "No Quantization")
+        .value("SMOOTH_QUANT", QuantMethod::kSMOOTH_QUANT, "Smooth Quantization");
+
+    nb::enum_<OutputContentType>(m, "OutputContentType", nb::is_arithmetic(), "Output Content Type")
+        .value("NONE", OutputContentType::kNONE, "No Output Content")
+        .value("CONTEXT_LOGITS", OutputContentType::kCONTEXT_LOGITS, "Context Logits")
+        .value("GENERATION_LOGITS", OutputContentType::kGENERATION_LOGITS, "Generation Logits")
+        .value("LOG_PROBS", OutputContentType::kLOG_PROBS, "Log Probs")
+        .value("CUM_LOG_PROBS", OutputContentType::kCUM_LOG_PROBS, "Cumulative Log Probs");
+
+    nb::class_<ModelSpec>(m, "ModelSpec")
+        .def(nb::init())
+        .def("use_gpt_plugin", &ModelSpec::useGptAttentionPlugin, nb::rv_policy::reference_internal)
+        .def("use_packed_input", &ModelSpec::usePackedInput, nb::rv_policy::reference_internal)
+        .def("set_kv_cache_type", &ModelSpec::setKVCacheType, nb::rv_policy::reference_internal)
+        .def("use_decoder_per_request", &ModelSpec::useDecoderPerRequest, nb::rv_policy::reference_internal)
+        .def("use_tensor_parallelism", &ModelSpec::useTensorParallelism, nb::rv_policy::reference_internal)
+        .def("use_pipeline_parallelism", &ModelSpec::usePipelineParallelism, nb::rv_policy::reference_internal)
+        .def("use_context_parallelism", &ModelSpec::useContextParallelism, nb::rv_policy::reference_internal)
+        .def("set_draft_tokens", &ModelSpec::setDraftTokens, nb::rv_policy::reference_internal)
+        .def("use_accept_by_logits", &ModelSpec::useAcceptByLogits, nb::rv_policy::reference_internal)
+        .def("use_mamba_plugin", &ModelSpec::useMambaPlugin, nb::rv_policy::reference_internal)
+        .def("gather_logits", &ModelSpec::gatherLogits, nb::rv_policy::reference_internal)
+        .def("replace_logits", &ModelSpec::replaceLogits, nb::rv_policy::reference_internal)
+        .def("return_log_probs", &ModelSpec::returnLogProbs,
+            nb::rv_policy::reference_internal)
+        .def("smoke_test", &ModelSpec::smokeTest, nb::rv_policy::reference_internal)
+        .def("use_medusa", &ModelSpec::useMedusa, nb::rv_policy::reference_internal)
+        .def("use_eagle", &ModelSpec::useEagle, nb::rv_policy::reference_internal)
+        .def("use_lookahead_decoding", &ModelSpec::useLookaheadDecoding, nb::rv_policy::reference_internal)
+        .def("use_explicit_draft_tokens_decoding", &ModelSpec::useExplicitDraftTokensDecoding,
+            nb::rv_policy::reference_internal)
+        .def("use_draft_tokens_external_decoding", &ModelSpec::useDraftTokensExternalDecoding,
+            nb::rv_policy::reference_internal)
+        .def("use_logits", &ModelSpec::useLogits)
+        .def("use_multiple_profiles", &ModelSpec::useMultipleProfiles, nb::rv_policy::reference_internal)
+        .def("set_max_input_length", &ModelSpec::setMaxInputLength, nb::rv_policy::reference_internal)
+        .def("set_max_output_length", &ModelSpec::setMaxOutputLength, nb::rv_policy::reference_internal)
+        .def("set_quant_method", &ModelSpec::setQuantMethod, nb::rv_policy::reference_internal)
+        .def("use_lora_plugin", &ModelSpec::useLoraPlugin, nb::rv_policy::reference_internal)
+        .def("get_input_file", &ModelSpec::getInputFile)
+        .def("get_model_path", &ModelSpec::getModelPath)
+        .def("get_results_file", &ModelSpec::getResultsFile)
+        .def("get_generation_logits_file", &ModelSpec::getGenerationLogitsFile)
+        .def("get_context_logits_file", &ModelSpec::getContextLogitsFile)
+        .def("get_cum_log_probs_file", &ModelSpec::getCumLogProbsFile)
+        .def("get_log_probs_file", &ModelSpec::getLogProbsFile)
+        .def("enable_context_fmha_fp32_acc", &ModelSpec::enableContextFMHAFp32Acc, nb::rv_policy::reference_internal)
+        .def("get_enable_context_fmha_fp32_acc", &ModelSpec::getEnableContextFMHAFp32Acc)
+        .def("__copy__", [](ModelSpec const& self) { return ModelSpec(self); });
+}
+
+} // namespace tensorrt_llm::nanobind::testing
diff --git a/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.h b/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.h
new file mode 100644
index 00000000000..1aababc6ff8
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/testing/modelSpecBinding.h
@@ -0,0 +1,29 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <nanobind/nanobind.h>
+
+namespace nb = nanobind;
+
+namespace tensorrt_llm::nanobind::testing
+{
+
+void initBindings(nb::module_& m);
+
+} // namespace tensorrt_llm::nanobind::testing
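The fluent `ModelSpec` helper mirrors its pybind11 counterpart and is meant for test code; a rough Python sketch follows, where the import path and the constructor arguments are placeholders rather than values from this patch (see `tensorrt_llm/testing/modelSpec.h` for the real signature):

```python
# Hypothetical use of the ModelSpec test helper bound above.
from tensorrt_llm.bindings.internal.testing import ModelSpec, QuantMethod

# Constructor arguments are illustrative only.
spec = ModelSpec("input_tokens.npy", "gpt2")
# reference_internal return policies make the setters chainable:
spec.use_gpt_plugin().use_packed_input().set_quant_method(QuantMethod.SMOOTH_QUANT)
print(spec.get_model_path(), spec.get_results_file())
```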
diff --git a/cpp/tensorrt_llm/nanobind/userbuffers/bindings.cpp b/cpp/tensorrt_llm/nanobind/userbuffers/bindings.cpp
new file mode 100644
index 00000000000..82e0d0a1f0c
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/userbuffers/bindings.cpp
@@ -0,0 +1,47 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "bindings.h"
+#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
+#include "tensorrt_llm/kernels/userbuffers/userbuffersManager.h"
+#include "tensorrt_llm/nanobind/common/customCasters.h"
+#include <nanobind/nanobind.h>
+
+namespace nb = nanobind;
+namespace tub = tensorrt_llm::runtime::ub;
+
+namespace tensorrt_llm::kernels::userbuffers
+{
+
+void UserBufferBindings::initBindings(nb::module_& m)
+{
+    nb::class_<tub::UBBuffer>(m, "UBBuffer")
+        .def_ro("size", &tub::UBBuffer::size)
+        .def_prop_ro("addr", [](tub::UBBuffer& self) { return reinterpret_cast<intptr_t>(self.addr); })
+        .def_ro("handle", &tub::UBBuffer::handle)
+        .def("invalid", &tub::UBBuffer::invalid);
+
+    m.def("ub_initialize", [](int tp_size) { tub::ub_initialize(tp_size); });
+    m.def("ub_is_initialized", &tub::ub_is_initialized);
+    m.def("ub_allocate", [](size_t bytes) { return tub::ub_allocate(bytes); });
+    m.def("ub_deallocate", [](intptr_t addr) { return tub::ub_deallocate(reinterpret_cast<void*>(addr)); });
+    m.def("ub_get", &tub::ub_get);
+    m.def("ub_supported", &tub::ub_supported);
+
+    m.def("initialize_userbuffers_manager", &tub::initialize_userbuffers_manager);
+}
+} // namespace tensorrt_llm::kernels::userbuffers
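The user-buffer helpers bound above translate to a Python surface along these lines; a minimal sketch with an assumed import path and illustrative sizes:

```python
# Hypothetical usage of the user-buffer bindings (import path assumed).
from tensorrt_llm.bindings.internal.userbuffers import (
    ub_initialize, ub_is_initialized, ub_allocate, ub_deallocate, ub_supported)

if ub_supported():
    ub_initialize(2)             # tp_size
    assert ub_is_initialized()
    buf = ub_allocate(16 << 20)  # returns a UBBuffer
    print(buf.size, hex(buf.addr))
    ub_deallocate(buf.addr)      # addr round-trips as a plain integer
```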
diff --git a/cpp/tensorrt_llm/nanobind/userbuffers/bindings.h b/cpp/tensorrt_llm/nanobind/userbuffers/bindings.h
new file mode 100644
index 00000000000..15728bf6c1d
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/userbuffers/bindings.h
@@ -0,0 +1,30 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <nanobind/nanobind.h>
+namespace nb = nanobind;
+
+namespace tensorrt_llm::kernels::userbuffers
+{
+class UserBufferBindings
+{
+public:
+    static void initBindings(nb::module_& m);
+};
+} // namespace tensorrt_llm::kernels::userbuffers
diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp
index 1a5841d4b7a..962071c4857 100644
--- a/cpp/tensorrt_llm/pybind/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/bindings.cpp
@@ -170,7 +170,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
         .value("CONTINUOUS", tr::ModelConfig::KVCacheType::kCONTINUOUS)
         .value("PAGED", tr::ModelConfig::KVCacheType::kPAGED)
         .value("DISABLED", tr::ModelConfig::KVCacheType::kDISABLED)
-        .def(py::init(&tr::ModelConfig::KVCacheTypeFromString));
+        .def("from_string", &tr::ModelConfig::KVCacheTypeFromString);

     py::enum_<tr::ModelConfig::LayerType>(m, "LayerType")
         .value("ATTENTION", tr::ModelConfig::LayerType::kATTENTION)
diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp
index d09157e1a8b..a8f6aaef73d 100644
--- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp
@@ -244,7 +244,17 @@ void initBindings(pybind11::module_& m)

     py::class_<tle::KVCacheEventManager, std::shared_ptr<tle::KVCacheEventManager>>(
         executor_kv_cache, "KVCacheEventManager")
-        .def("get_latest_events", &tle::KVCacheEventManager::getLatestEvents, py::arg("timeout") = std::nullopt);
+        .def(
+            "get_latest_events",
+            [](tle::KVCacheEventManager& self, std::optional<double> timeout_ms = std::nullopt)
+            {
+                if (timeout_ms)
+                {
+                    return self.getLatestEvents(std::chrono::milliseconds(static_cast<int64_t>(*timeout_ms)));
+                }
+                return self.getLatestEvents(std::nullopt);
+            },
+            py::arg("timeout_ms") = std::nullopt);

     tensorrt_llm::pybind::executor::initRequestBindings(m);
     tensorrt_llm::pybind::executor::initConfigBindings(m);
diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
index bc0d997e337..1153ca13a8e 100644
--- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
@@ -336,7 +336,7 @@ void initConfigBindings(pybind11::module_& m)
             throw std::runtime_error("Invalid extendedRuntimePerfKnobConfig state!");
         }
         return tle::ExtendedRuntimePerfKnobConfig(
-            state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<bool>(), state[2].cast<SizeType32>());
+            state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<bool>(), state[3].cast<SizeType32>());
     };
     auto extendedRuntimePerfKnobConfigGetstate = [](tle::ExtendedRuntimePerfKnobConfig const& self)
     {
diff --git a/examples/models/core/llama/summarize_long.py b/examples/models/core/llama/summarize_long.py
index 9f127bc32a6..cee2e07fdd5 100644
--- a/examples/models/core/llama/summarize_long.py
+++ b/examples/models/core/llama/summarize_long.py
@@ -97,7 +97,7 @@ def TRTLLaMA(args, config):
     quantization_config = pretrained_config['quantization']

     build_config = config['build_config']
-    kv_cache_type = KVCacheType(build_config['kv_cache_type'])
+    kv_cache_type = KVCacheType.from_string(build_config['kv_cache_type'])
     plugin_config = build_config['plugin_config']

     dtype = pretrained_config['dtype']
diff --git a/examples/models/core/qwen2audio/run.py b/examples/models/core/qwen2audio/run.py
index e0d495a67f8..93e161c7e08 100644
--- a/examples/models/core/qwen2audio/run.py
+++ b/examples/models/core/qwen2audio/run.py
@@ -122,7 +122,8 @@ def get_model(self):
         num_kv_heads = config["pretrained_config"].get("num_key_value_heads",
                                                        num_heads)
         if "kv_cache_type" in config["build_config"]:
-            kv_cache_type = 
KVCacheType(config["build_config"]["kv_cache_type"]) + kv_cache_type = KVCacheType.from_string( + config["build_config"]["kv_cache_type"]) else: kv_cache_type = KVCacheType.CONTINUOUS diff --git a/examples/models/core/qwenvl/run.py b/examples/models/core/qwenvl/run.py index a04c2b142e3..06ce341a9a0 100644 --- a/examples/models/core/qwenvl/run.py +++ b/examples/models/core/qwenvl/run.py @@ -118,7 +118,8 @@ def get_model(self): num_kv_heads = config["pretrained_config"].get("num_key_value_heads", num_heads) if "kv_cache_type" in config["build_config"]: - kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"]) + kv_cache_type = KVCacheType.from_string( + config["build_config"]["kv_cache_type"]) else: kv_cache_type = KVCacheType.CONTINUOUS diff --git a/jenkins/Build.groovy b/jenkins/Build.groovy index bb8fd7816ce..77e12ee5100 100644 --- a/jenkins/Build.groovy +++ b/jenkins/Build.groovy @@ -47,6 +47,12 @@ CONFIG_LINUX_AARCH64 = "linux_aarch64" @Field def CONFIG_LINUX_AARCH64_LLVM = "linux_aarch64_LLVM" +@Field +def CONFIG_LINUX_X86_64_NANOBIND = "linux_x86_64_Nanobind" + +@Field +def CONFIG_LINUX_AARCH64_NANOBIND = "linux_aarch64_Nanobind" + @Field def BUILD_CONFIGS = [ // Vanilla TARNAME is used for packaging in runLLMPackage @@ -56,6 +62,11 @@ def BUILD_CONFIGS = [ (TARNAME) : "TensorRT-LLM.tar.gz", (WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real", ], + (CONFIG_LINUX_X86_64_NANOBIND) : [ + (WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks", + (TARNAME) : "nanobind-TensorRT-LLM.tar.gz", + (WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real", + ], (CONFIG_LINUX_X86_64_SINGLE_DEVICE) : [ (WHEEL_EXTRA_ARGS) : "--extra-cmake-vars ENABLE_MULTI_DEVICE=0 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars ENABLE_UCX=0 --micro_benchmarks", (TARNAME) : "single-device-TensorRT-LLM.tar.gz", @@ -71,6 +82,11 @@ def BUILD_CONFIGS = [ (TARNAME) : "TensorRT-LLM-GH200.tar.gz", (WHEEL_ARCHS): "90-real;100-real;120-real", ], + (CONFIG_LINUX_AARCH64_NANOBIND): [ + (WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars WARNING_IS_ERROR=ON", + (TARNAME) : "nanobind-TensorRT-LLM-GH200.tar.gz", + (WHEEL_ARCHS): "90-real;100-real;120-real", + ], (CONFIG_LINUX_AARCH64_LLVM) : [ (WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_CUDA_HOST_COMPILER=clang -DCMAKE_LINKER_TYPE=LLD", (TARNAME) : "llvm-TensorRT-LLM-GH200.tar.gz", @@ -523,6 +539,8 @@ def launchStages(pipeline, cpu_arch, enableFailFast, globalVars) pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64 : CONFIG_LINUX_X86_64_VANILLA), "Build TRT-LLM LLVM": [LLM_DOCKER_IMAGE] + prepareLLMBuild( pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_LLVM : CONFIG_LINUX_X86_64_LLVM), + "Build TRT-LLM Nanobind": [LLM_DOCKER_IMAGE] + prepareLLMBuild( + pipeline, cpu_arch == AARCH64_TRIPLE ? 
CONFIG_LINUX_AARCH64_NANOBIND : CONFIG_LINUX_X86_64_NANOBIND), ] if (cpu_arch == X86_64_TRIPLE) { diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index 6f6ae7c1186..35e7140ebda 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -64,6 +64,9 @@ def LLVM_CONFIG = "LLVM" @Field LINUX_AARCH64_CONFIG = "linux_aarch64" +@Field +def NANOBIND_CONFIG = "Nanobind" + @Field def BUILD_CONFIGS = [ // Vanilla TARNAME is used for packaging in runLLMPackage @@ -71,6 +74,7 @@ def BUILD_CONFIGS = [ (SINGLE_DEVICE_CONFIG) : [(TARNAME) : "single-device-TensorRT-LLM.tar.gz"], (LLVM_CONFIG) : [(TARNAME) : "llvm-TensorRT-LLM.tar.gz"], (LINUX_AARCH64_CONFIG) : [(TARNAME) : "TensorRT-LLM-GH200.tar.gz"], + (NANOBIND_CONFIG) : [(TARNAME) : "nanobind-TensorRT-LLM.tar.gz"], ] // TODO: Move common variables to an unified location @@ -1724,6 +1728,7 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) "A10-TensorRT-4": ["a10", "l0_a10", 4, 6], "A10-TensorRT-5": ["a10", "l0_a10", 5, 6], "A10-TensorRT-6": ["a10", "l0_a10", 6, 6], + "A10-Nanobind": ["a10", "l0_a10_nanobind", 1, 1], "A30-Triton-1": ["a30", "l0_a30", 1, 1], "A30-PyTorch-1": ["a30", "l0_a30", 1, 2], "A30-PyTorch-2": ["a30", "l0_a30", 2, 2], @@ -1800,6 +1805,9 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null) if (key.contains("llvm")) { config = LLVM_CONFIG } + if (key.contains("Nanobind")) { + config = NANOBIND_CONFIG + } runLLMTestlistOnPlatform(pipeline, values[0], values[1], config, key.contains("Perf"), key, values[2], values[3]) }]]} fullSet = parallelJobs.keySet() diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index e2dc543ac42..11d528a853d 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -593,7 +593,7 @@ def from_dict(cls, config, plugin_config=None): defaults.get('max_prompt_embedding_table_size')) if "kv_cache_type" in config and config["kv_cache_type"] is not None: - kv_cache_type = KVCacheType(config.pop('kv_cache_type')) + kv_cache_type = KVCacheType.from_string(config.pop('kv_cache_type')) else: kv_cache_type = None gather_context_logits = config.pop( diff --git a/tensorrt_llm/commands/build.py b/tensorrt_llm/commands/build.py index a47e1485b71..e6b55f6e040 100644 --- a/tensorrt_llm/commands/build.py +++ b/tensorrt_llm/commands/build.py @@ -38,6 +38,23 @@ from tensorrt_llm.quantization.mode import QuantAlgo +def enum_type(enum_class): + + def parse_enum(value): + if isinstance(value, enum_class): + return value + + if isinstance(value, str): + return enum_class.from_string(value) + + valid_values = [e.name for e in enum_class] + raise argparse.ArgumentTypeError( + f"Invalid value '{value}' of type {type(value).__name__}. Expected one of {valid_values}" + ) + + return parse_enum + + def parse_arguments(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -131,7 +148,7 @@ def parse_arguments(): parser.add_argument( '--kv_cache_type', default=argparse.SUPPRESS, - type=KVCacheType, + type=enum_type(KVCacheType), help= "Set KV cache type (continuous, paged, or disabled). For disabled case, KV cache is disabled and only context phase is allowed." 
) diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index 486c58f6d15..a9f0fe8de40 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -86,7 +86,7 @@ def _builder_to_model_config(config: dict) -> Tuple[ModelConfig, dict]: dtype = builder_config['precision'] tp_size = builder_config['tensor_parallel'] pp_size = builder_config.get('pipeline_parallel', 1) - kv_cache_type = KVCacheType(builder_config.get('kv_cache_type')) + kv_cache_type = KVCacheType.from_string(builder_config.get('kv_cache_type')) world_size = tp_size * pp_size assert world_size == mpi_world_size(), \ f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({mpi_world_size()})' diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 2f63ab45f3a..5799ea27945 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -190,3 +190,18 @@ l0_a10: tests: - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] +l0_a10_nanobind: +- condition: + ranges: + system_gpu_count: + gte: 1 + lte: 1 + wildcards: + gpu: + - '*a10*' + linux_distribution_name: ubuntu* + terms: + stage: pre_merge + backend: tensorrt + tests: + - unittest/bindings diff --git a/tests/unittest/bindings/test_bindings_ut.py b/tests/unittest/bindings/test_bindings_ut.py index 774accb080f..6fd46040b66 100644 --- a/tests/unittest/bindings/test_bindings_ut.py +++ b/tests/unittest/bindings/test_bindings_ut.py @@ -5,6 +5,7 @@ from pathlib import Path import numpy as np +import pytest import torch from utils.runtime_defaults import assert_runtime_defaults_are_parsed_correctly @@ -309,6 +310,8 @@ def parse_runtime_defaults(defaults_dict: dict | None = None): strict_keys=strict_keys) +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_llm_request(): beam_width = 2 sampling_config = _tb.SamplingConfig(beam_width) @@ -418,6 +421,8 @@ def test_Mpicomm(): assert size2 == session_size +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_SamplingConfig_pickle(): config = _tb.SamplingConfig() config.beam_width = 5 @@ -497,6 +502,8 @@ def test_KvCache_events_binding(): torch.cuda.empty_cache() +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_ReqIdsSet_pickle(): ids = _tb.internal.batch_manager.ReqIdsSet() ids1 = pickle.loads(pickle.dumps(ids)) diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py index 935c4c9bfc3..af72d9ac44b 100644 --- a/tests/unittest/bindings/test_executor_bindings.py +++ b/tests/unittest/bindings/test_executor_bindings.py @@ -14,6 +14,7 @@ from binding_test_utils import * from pydantic import BaseModel +import tensorrt_llm.bindings as _tb import tensorrt_llm.bindings.executor as trtllm import tensorrt_llm.version as trtllm_version from tensorrt_llm.models.modeling_utils import PretrainedConfig @@ -484,6 +485,8 @@ def test_get_num_responses_ready(streaming: bool, assert executor.get_num_responses_ready() == num_expected_responses +@pytest.mark.skipif(_tb.binding_type == "nanobind", + 
reason="Test not supported for nanobind yet") @pytest.mark.parametrize("batching_type", [trtllm.BatchingType.INFLIGHT]) @pytest.mark.parametrize("streaming", [False, True]) @pytest.mark.parametrize("beam_width", [1]) @@ -688,6 +691,8 @@ def verify_output(beam_tokens, test_data, given_input_lengths): verify_output(tokens, test_data, given_input_lengths) +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") @pytest.mark.parametrize("streaming", [False, True]) @pytest.mark.parametrize("beam_width", [1]) def test_finish_reason(streaming: bool, beam_width: int, model_files, @@ -1112,6 +1117,8 @@ def test_spec_dec_fast_logits_info(): assert fast_logits_info.draft_participant_id == 5 +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_result(): result = trtllm.Result() result.is_final = True @@ -1149,6 +1156,8 @@ def test_result(): assert (additional_output.output == torch.ones(1, 4, 100)).all() +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_result_pickle(): result = trtllm.Result() result.is_final = True @@ -1495,6 +1504,8 @@ def test_eagle_config(): assert getattr(config, k) == v +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_eagle_config_pickle(): config = trtllm.EagleConfig([[0, 0], [0, 1]], False, 0.5) config_copy = pickle.loads(pickle.dumps(config)) @@ -1867,6 +1878,8 @@ def logits_post_processor(req_id: int, logits: torch.Tensor, assert tokens[-max_tokens:] == [42] * max_tokens +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") def test_logits_post_processor_batched(model_files, model_path): # Define the logits post-processor callback @@ -2141,6 +2154,8 @@ def test_request_perf_metrics_kv_cache(model_path): assert kv_cache_metrics.kv_cache_hit_rate == 1.0 +@pytest.mark.skipif(_tb.binding_type == "nanobind", + reason="Test not supported for nanobind yet") @pytest.mark.parametrize("exclude_input_from_output", [False, True]) def test_request_perf_metrics_draft(model_path_draft_tokens_external, exclude_input_from_output: bool): @@ -2221,7 +2236,7 @@ def test_kv_event_stream_timeout(model_path): assert len(events) == 1 start = datetime.datetime.now() - events = cache_manager.get_latest_events(datetime.timedelta(seconds=1)) + events = cache_manager.get_latest_events(1000) end = datetime.datetime.now() # Make sure that it actually waited assert abs(end - start) > datetime.timedelta(milliseconds=900) From d71c6fe5267f4b61c51cc39d4594cdcb417f0703 Mon Sep 17 00:00:00 2001 From: ixlmar <206748156+ixlmar@users.noreply.github.com> Date: Thu, 17 Jul 2025 17:22:25 +0200 Subject: [PATCH 6/9] [fix] Update jenkins container images (#6094) Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com> --- docker/Makefile | 3 +- docker/README.md | 41 +++++++++++++++++++++++---- jenkins/current_image_tags.properties | 11 ++++--- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/docker/Makefile b/docker/Makefile index 926c8cea1aa..2b5022b1ee8 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -180,7 +180,8 @@ jenkins-aarch64_%: IMAGE_WITH_TAG = $(shell . ../jenkins/current_image_tags.prop jenkins-aarch64_%: STAGE = tritondevel # For x86_64 -jenkins-rockylinux8_%: IMAGE_WITH_TAG = $(shell . 
../jenkins/current_image_tags.properties && echo $$LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE) +jenkins-rockylinux8_%: PYTHON_VERSION_TAG_ID = $(if $(findstring 3.12,${PYTHON_VERSION}),PY312,$(if $(findstring 3.10,${PYTHON_VERSION}),PY310,$(error Unknown PYTHON_VERSION specified))) +jenkins-rockylinux8_%: IMAGE_WITH_TAG = $(shell . ../jenkins/current_image_tags.properties && echo $$LLM_ROCKYLINUX8_${PYTHON_VERSION_TAG_ID}_DOCKER_IMAGE) jenkins-rockylinux8_%: STAGE = tritondevel jenkins-rockylinux8_%: BASE_IMAGE = nvidia/cuda jenkins-rockylinux8_%: BASE_TAG = 12.9.0-devel-rockylinux8 diff --git a/docker/README.md b/docker/README.md index 3bfac62a2c4..fa1b80a9fd7 100644 --- a/docker/README.md +++ b/docker/README.md @@ -89,13 +89,10 @@ equivalent containers as [described above](#building-docker-images-with-gnu-make ### Jenkins Integration [`Makefile`](Makefile) has special targets for building, pushing and running the Docker build image used on Jenkins. -The full image name and tag is defined in [`L0_MergeRequest.groovy`](../jenkins/L0_MergeRequest.groovy). The `make` -system will parse this name as the value of `LLM_DOCKER_IMAGE`. To build and push a new Docker image for Jenkins, -define a new image name and tag in [`L0_MergeRequest.groovy`](../jenkins/L0_MergeRequest.groovy) and run +The full image names and tags are defined in [`current_image_tags.properties`](../jenkins/current_image_tags.properties). The `make` +system will parse the names/tags from this file. -```bash -make -C docker jenkins_push -``` +#### Running Start a new container using the same image as Jenkins using your local user account with @@ -134,6 +131,38 @@ make -C docker trtllm_run LOCAL_USER=1 DOCKER_PULL=1 The argument `DOCKER_PULL=1` instructs `make` to pull the latest version of the image before deploying it in the container. By default, the release images built in the above manner are tagged by their `git` branch name and may be frequently updated. +#### Building CI images + +To build and push a new Docker image for Jenkins, define new image names and tags in [`current_image_tags.properties`](../jenkins/current_image_tags.properties) and run + +```bash +# Commands assume an amd64 host +make -C docker jenkins_build +# +docker buildx create --name multi-builder +make -C docker jenkins-aarch64_build \ + DOCKER_BUILD_ARGS="--platform arm64 --builder=multi-builder" +# +# check jenkins/BuildDockerImage.groovy for current Python versions +make -C docker jenkins-rockylinux8_build PYTHON_VERSION=3.12.3 +make -C docker jenkins-rockylinux8_build PYTHON_VERSION=3.10.12 +``` + +The resulting images then need to be pushed: + +```bash +sh -c '. jenkins/current_image_tags.properties && echo $LLM_DOCKER_IMAGE $LLM_SBSA_DOCKER_IMAGE $LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE $LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE' | tr ' ' '\n' | xargs -I{} docker push {} +``` + +Alternatively, it is possible to trigger the image build by opening a new pull request and commenting + +```text +/bot run --stage-list "Build-Docker-Images" +``` + +The resulting images can then be re-tagged using `scripts/rename_docker_images.py` +and the new tags included in [`current_image_tags.properties`](../jenkins/current_image_tags.properties). + ### Docker rootless Some aspects require special treatment when using [Docker rootless mode](https://docs.docker.com/engine/security/rootless/). The `docker/Makefile` contains heuristics to detect Docker rootless mode. 
When assuming
diff --git a/jenkins/current_image_tags.properties b/jenkins/current_image_tags.properties
index 5836d212c5e..6e4863a11ed 100644
--- a/jenkins/current_image_tags.properties
+++ b/jenkins/current_image_tags.properties
@@ -8,7 +8,10 @@
 # NB: Although string interpolation is supported, redundant substrings are
 #     kept in the variables below for interoperability with
 #     scripts/rename_docker_images.py
-LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.05-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507150652-9504
-LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.05-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507150652-9504
-LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.11.0.33-skip-tritondevel-202507150652-9504
-LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.11.0.33-skip-tritondevel-202507150652-9504
+#
+# NB: Typically, the suffix indicates the PR whose CI pipeline generated the images. In case that
+#     images are adopted from PostMerge pipelines, the abbreviated commit hash is used instead.
+LLM_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.05-py3-x86_64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507162011-ec3ebae
+LLM_SBSA_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.05-py3-aarch64-ubuntu24.04-trt10.11.0.33-skip-tritondevel-202507162011-ec3ebae
+LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.11.0.33-skip-tritondevel-202507162011-ec3ebae
+LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE=urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.11.0.33-skip-tritondevel-202507162011-ec3ebae

From fc13e00bf325a7ac681afee0c14ef3974e6fd34d Mon Sep 17 00:00:00 2001
From: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Date: Fri, 4 Jul 2025 11:49:05 +0000
Subject: [PATCH 7/9] fix: Update disaggregation handling in sampler and
 executor

- Set `is_disagg` in `TRTLLMSampler` to False by default.
- Modify sequence length calculation to consider `is_disagg` when determining
  overlap behavior in the sampling process.
- Ensure `kv_cache_transceiver` presence updates `sampler.is_disagg` in the
  executor instance creation.
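For intuition, the adjusted length computation reduces to roughly the sketch below; the free-standing function is invented for illustration, and the real logic lives in `update_requests_multiple_beams_or_drafting` (see the diff that follows).

```python
# Schematic: how many new tokens to emit for one beam of one request.
# seq_len_host: decoder-reported sequence length for this beam
# num_generated: tokens produced by the current step
# num_tokens: tokens already recorded on the request for this beam
def num_new_tokens(seq_len_host, num_generated, num_tokens,
                   is_trt_overlap, is_disagg):
    # With the overlap scheduler, the reported length in the disaggregated
    # path lags one step behind, so compensate by one token.
    if is_trt_overlap and is_disagg:
        seq_len_host += 1
    return min(num_generated, seq_len_host - num_tokens)
```

Patches 8 and 9 below revert this compensation once the root cause is fixed on the C++ side.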
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/_util.py   | 2 +-
 tensorrt_llm/_torch/pyexecutor/sampler.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 29f1c5d3ac8..e5a6fac535e 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -520,7 +520,7 @@ def create_py_executor_instance(
     cache_transceiver_config = executor_config.cache_transceiver_config
     kv_cache_transceiver = create_kv_cache_transceiver(
         mapping, kv_cache_manager, attention_type, cache_transceiver_config)
-
+    sampler.is_disagg = kv_cache_transceiver is not None
     return PyExecutor(
         resource_manager,
         scheduler,
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index b4dfdf25d45..d9b73553086 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -510,6 +510,7 @@ def __init__(
         self.max_num_sequences = mapping.pp_size * self.executor_config.max_batch_size
         self.max_seq_idle_microseconds = 180 * 1000 * 1000
         self.is_trt_overlap = not disable_overlap_scheduler
+        self.is_disagg = False
         self.num_micro_batches = mapping.pp_size if mapping.pp_size > 1 else (
             2 if self.is_trt_overlap else 1)
         self.micro_batch_idx = 0
@@ -821,7 +822,7 @@ def update_requests_multiple_beams_or_drafting(self,
             for beam in range(beam_width):
                 seq_len = sequence_lengths_host_data[seq_slot * beam_width +
                                                      beam]
-                seq_len = seq_len + 1 if self.is_trt_overlap else seq_len
+                seq_len = seq_len + 1 if self.is_trt_overlap and self.is_disagg else seq_len
                 num_new_tokens[beam] = min(
                     num_generated_tokens,
                     seq_len - request.get_num_tokens(beam))

From 38fe83b856908db5f5acfc8d6c296aef4ff38ba3 Mon Sep 17 00:00:00 2001
From: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Date: Mon, 14 Jul 2025 14:32:46 +0000
Subject: [PATCH 8/9] [fix] Fix bug 5369799: Beam search with disaggregated
 generation.

Fixed Disaggregated Serving + overlap scheduler dropping tokens in the
output. CopySequenceLengths now correctly respects tokens generated in the
first generation phase. Removed the line in sampler.py that masked this
issue and caused bug 5367999.

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
---
 cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp | 5 ++++-
 tensorrt_llm/_torch/pyexecutor/sampler.py                   | 4 +---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
index 1d06ac0e860..baa51f47e73 100644
--- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
+++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
@@ -63,7 +63,10 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe
     SizeType32 batchIdx{0};
     for (auto const& llmReq : contextRequests)
     {
-        auto const currentSequenceLen = llmReq->mPromptLen + llmReq->getMaxNumGeneratedTokens();
+        auto const disaggFirstGenTokenSize
+            = llmReq->getContextPhaseParams() ? 
llmReq->getContextPhaseParams().value().getFirstGenTokens().size() : 0; + auto const currentSequenceLen + = llmReq->mPromptLen + llmReq->getMaxNumGeneratedTokens() + disaggFirstGenTokenSize; // Get position of the current sequence in the decoder auto const seqSlot = llmReq->mSeqSlot.value(); batchSlotsRange[batchIdx] = seqSlot; diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index d9b73553086..9dd6bd60b8a 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -752,8 +752,7 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM): reqs_with_new_tokens = [ r for r in reqs - if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0) - or self.is_trt_overlap) + if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0)) ] # Add new tokens @@ -822,7 +821,6 @@ def update_requests_multiple_beams_or_drafting(self, for beam in range(beam_width): seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam] - seq_len = seq_len + 1 if self.is_trt_overlap and self.is_disagg else seq_len num_new_tokens[beam] = min( num_generated_tokens, seq_len - request.get_num_tokens(beam)) From 3d73f2d2895ffbdfb9cfcf4c22c161499289643f Mon Sep 17 00:00:00 2001 From: Stefan Niebler <82932102+stnie@users.noreply.github.com> Date: Thu, 17 Jul 2025 10:01:18 +0000 Subject: [PATCH 9/9] chore: removed unused code lines. Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 1 - tensorrt_llm/_torch/pyexecutor/sampler.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index e5a6fac535e..0bfba50a9c9 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -520,7 +520,6 @@ def create_py_executor_instance( cache_transceiver_config = executor_config.cache_transceiver_config kv_cache_transceiver = create_kv_cache_transceiver( mapping, kv_cache_manager, attention_type, cache_transceiver_config) - sampler.is_disagg = kv_cache_transceiver is not None return PyExecutor( resource_manager, scheduler, diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 9dd6bd60b8a..31e17c1247d 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -510,7 +510,6 @@ def __init__( self.max_num_sequences = mapping.pp_size * self.executor_config.max_batch_size self.max_seq_idle_microseconds = 180 * 1000 * 1000 self.is_trt_overlap = not disable_overlap_scheduler - self.is_disagg = False self.num_micro_batches = mapping.pp_size if mapping.pp_size > 1 else ( 2 if self.is_trt_overlap else 1) self.micro_batch_idx = 0
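Taken together, patches 7 through 9 move the disaggregated-serving length accounting out of the Python sampler and into `copySequenceLengths` on the C++ side; schematically (a sketch, not code from these patches):

```python
# Schematic: the decoder is now seeded with a sequence length that already
# counts the tokens produced by the context (prefill) instance, so the
# sampler needs no disaggregation special case.
def current_sequence_len(prompt_len, max_generated_tokens, first_gen_tokens):
    # first_gen_tokens: tokens shipped from the context instance in
    # disaggregated serving; an empty list for normal serving.
    return prompt_len + max_generated_tokens + len(first_gen_tokens)
```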