diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
index 29829f7f644..a9ea768e27f 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -9,7 +9,6 @@
 from tensorrt_llm._utils import nvtx_range
 
 from ...._utils import mpi_rank, mpi_world_size
-from ....bindings.executor import ExecutorConfig
 from ....bindings.internal.batch_manager import CacheType
 from ....mapping import Mapping
 from ...distributed import MPIDist
@@ -259,7 +258,7 @@ def forward(
         return {"logits": logits_flat}
 
 
-def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: str = None):
+def create_autodeploy_executor(ad_config: LlmArgs):
     """Create an AutoDeploy executor from the given configuration and checkpoint directory.
 
     This is the entrypoint API to the _autodeploy backend.
@@ -276,8 +275,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
 
     # some config
     msg = "pytorch_backend_config must be an AD LlmArgs object"
-    assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
-    ad_config: LlmArgs = executor_config.pytorch_backend_config
+    assert isinstance(ad_config, LlmArgs), msg
     assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
     max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
 
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index ac3bb7a9f53..3c37650f4fb 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -10,8 +10,13 @@
 import tensorrt_llm
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._utils import get_sm_version
-from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
+from tensorrt_llm.bindings.executor import (ContextChunkingPolicy,
+                                            ExecutorConfig,
+                                            LogitsPostProcessorConfig,
+                                            ParallelConfig)
 from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
+from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
+from tensorrt_llm.llmapi.tokenizer import TokenizerBase
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
@@ -203,10 +208,21 @@ def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
 
 
 def create_py_executor(
-        executor_config: ExecutorConfig,
-        checkpoint_dir: str = None,
-        lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
+    llm_args: TorchLlmArgs,
+    checkpoint_dir: str = None,
+    tokenizer: Optional[TokenizerBase] = None,
+    lora_config: Optional[LoraConfig] = None,
+    logits_post_processor_config: Optional[LogitsPostProcessorConfig] = None,
+    parallel_config: Optional[ParallelConfig] = None,
+    kwargs_py_executor: Optional[dict] = None,
+) -> PyExecutor:
+
+    executor_config = llm_args.get_executor_config(checkpoint_dir, tokenizer)
+    executor_config.logits_post_processor_config = logits_post_processor_config
+    executor_config.parallel_config = parallel_config
+
+    garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
+
     _mangle_executor_config(executor_config)
     pytorch_backend_config = executor_config.pytorch_backend_config
 
@@ -294,6 +310,8 @@ def create_py_executor(
             max_seq_len += spec_config.max_draft_len
         executor_config.max_seq_len = max_seq_len
+        if kwargs_py_executor and "max_seq_len" in kwargs_py_executor:
+            kwargs_py_executor["max_seq_len"] = max_seq_len
     executor_config.max_num_tokens = model_engine.max_num_tokens
 
     config = model_engine.model.model_config.pretrained_config
@@ -441,6 +459,9 @@ def create_py_executor(
         # create_kv_cache_manager above, which caps executor_config.max_seq_len. Restoring
         # the original value before creating the final KV cache.
         executor_config.max_seq_len = max_seq_len
+        if kwargs_py_executor and "max_seq_len" in kwargs_py_executor:
+            kwargs_py_executor["max_seq_len"] = max_seq_len
+
         kv_cache_creator.build_managers(resources)
 
         for eng in [model_engine, draft_model_engine]:
diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py
index 14c8eeb3894..41a46014238 100644
--- a/tensorrt_llm/executor/executor.py
+++ b/tensorrt_llm/executor/executor.py
@@ -21,9 +21,11 @@
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..disaggregated_params import DisaggregatedParams
+from ..llmapi.llm_args import BaseLlmArgs
 from ..llmapi.llm_utils import KvCacheRetentionConfig
 from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
                                   need_spawn_mpi_workers)
+from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.utils import (AsyncQueue, enable_llm_debug,
                             enable_worker_single_process_for_tp1,
                             print_colored, print_colored_debug)
@@ -354,7 +356,9 @@ def create(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         # local imports to avoid cyclic importing
         from .proxy import GenerationExecutorProxy
@@ -381,6 +385,9 @@ def create(
             "engine": engine,
             "executor_config": executor_config,
             "batched_logits_processor": batched_logits_processor,
+            "hf_model_dir": hf_model_dir,
+            "tokenizer": tokenizer,
+            "llm_args": llm_args,
         }
 
         if lora_config:
@@ -398,9 +405,7 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)
 
         # WAR: For the performance of gathering logits, we use single process worker
         # for TP1 to avoid the large overhead of IPC.
@@ -411,9 +416,7 @@ def create(
                 "Using single process worker for TP1, this may hurt streaming generation performance."
             )
             return GenerationExecutorWorker(**worker_kwargs,
-                                            is_llm_executor=is_llm_executor,
-                                            garbage_collection_gen0_threshold=
-                                            garbage_collection_gen0_threshold)
+                                            is_llm_executor=is_llm_executor)
 
         # For single-gpu case:
         # Partition the workload to multiple process for streaming performance.
@@ -425,9 +428,7 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=None,  # use mpi4py
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)
         else:
             ctx = multiprocessing.get_context("spawn")
             # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -438,9 +439,7 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)
 
     def wait_first_completed(
             self, futures: List[GenerationResult]
diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py
index 78a0d076200..4026697e072 100644
--- a/tensorrt_llm/executor/proxy.py
+++ b/tensorrt_llm/executor/proxy.py
@@ -45,7 +45,6 @@ def __init__(
         worker_cls: type = GenerationExecutorWorker,
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
     ) -> None:
         postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
         )
@@ -87,14 +86,14 @@ def __init__(
 
         self.model_world_size = model_world_size
 
-        self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
+        self.garbage_collection_gen0_threshold = worker_kwargs[
+            "llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get(
+                "llm_args", None) is not None else None
 
         worker_kwargs = dict(**worker_kwargs,
                              worker_queues=self._setup_queues(),
                              postproc_worker_config=postproc_worker_config,
-                             is_llm_executor=False,
-                             garbage_collection_gen0_threshold=self.
-                             garbage_collection_gen0_threshold)
+                             is_llm_executor=False)
 
         if "log_level" not in worker_kwargs:
             worker_kwargs["log_level"] = logger.level
diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py
index 8a1dab6a237..7eb3dde26c1 100644
--- a/tensorrt_llm/executor/worker.py
+++ b/tensorrt_llm/executor/worker.py
@@ -18,8 +18,9 @@
                      mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror
+from ..llmapi.llm_args import BaseLlmArgs, PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
+from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             clear_sched_affinity, print_colored_debug,
@@ -60,7 +61,9 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -81,8 +84,8 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref
         self._executor_config = executor_config
-        self._is_pytorch_backend = getattr(self._executor_config, "backend",
-                                           None) == "pytorch"
+        self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
+        self.llm_args = llm_args
 
         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)
@@ -90,20 +93,72 @@ def __init__(
         if isinstance(engine, list):
             engine = engine[self.rank]
 
-        if executor_config is None:
-            executor_config = tllm.ExecutorConfig(1)
-
-        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-            processor_batched=batched_logits_processor, replicate=False)
-
-        def _create_engine():
+        def _get_comm_ranks_device_id():
             device_id = self.global_rank % torch.cuda.device_count()
             torch.cuda.set_device(device_id)
-
             # Make sure C++ executor would use same devices/ranks as py_executor
             global_rank = global_mpi_rank()
             comm_ranks = mpi_comm().allgather(global_rank)
             device_ids = mpi_comm().allgather(device_id)
+            return comm_ranks, device_ids
+
+        def _create_py_executor():
+            args = {}
+            assert hasattr(
+                self.llm_args, "backend"
+            ), "llm_args should be with backend in _create_py_executor"
+            # Some variables like executor_config.max_seq_len might be updated in
+            # create_py_executor and further used in the worker like _deduce_max_tokens
+            kwargs_py_executor = {"max_seq_len": self.llm_args.max_seq_len}
+            if self.llm_args.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
+                    create_py_executor
+                create_executor = create_py_executor
+                args["llm_args"] = self.llm_args
+                args["checkpoint_dir"] = hf_model_dir
+                args["tokenizer"] = tokenizer
+                args["lora_config"] = lora_config
+                args[
+                    "logits_post_processor_config"] = tllm.LogitsPostProcessorConfig(
+                        processor_batched=batched_logits_processor,
+                        replicate=False)
+                comm_ranks, device_ids = _get_comm_ranks_device_id()
+                args["parallel_config"] = tllm.ParallelConfig(
+                    participant_ids=comm_ranks, device_ids=device_ids)
+                args["kwargs_py_executor"] = kwargs_py_executor
+            elif self.llm_args.backend == "_autodeploy":
+                from tensorrt_llm._torch.auto_deploy.llm_args import \
+                    LlmArgs as ADLlmArgs
+                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
+                    create_autodeploy_executor
+                create_executor = create_autodeploy_executor
+                assert isinstance(self.llm_args, ADLlmArgs)
+                args["ad_config"] = self.llm_args.get_pytorch_backend_config()
+            else:
+                raise ValueError(
+                    f"Unsupported backend config: {self.llm_args.backend}")
+
+            # Define additional attributes that can be used later, such as in _deduce_max_tokens
+            self.mapping = self.llm_args.parallel_config.to_mapping()
+            self.checkpoint_loader = None
+            if self.llm_args.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.config import \
+                    _construct_checkpoint_loader
+                self.checkpoint_loader = _construct_checkpoint_loader(
+                    self.llm_args.backend, self.llm_args.checkpoint_loader,
+                    self.llm_args.checkpoint_format)
+
+            _executor = create_executor(**args)
+            # Define it after create_executor, since the kwargs_py_executor["max_seq_len"] might be updated inside.
+            self.max_seq_len = kwargs_py_executor["max_seq_len"]
+            return _executor
+
+        def _create_engine(executor_config):
+            if executor_config is None:
+                executor_config = tllm.ExecutorConfig(1)
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            comm_ranks, device_ids = _get_comm_ranks_device_id()
             executor_config.parallel_config = tllm.ParallelConfig(
                 participant_ids=comm_ranks, device_ids=device_ids)
@@ -115,30 +170,12 @@ def _create_engine():
                 executor_config=executor_config,
                 managed_weights=engine.managed_weights)
 
-            if not hasattr(executor_config, "backend"):
-                return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
-                                     executor_config)
-            args = {
-                "executor_config": executor_config,
-                "checkpoint_dir": executor_config.hf_model_dir,
-            }
-            if executor_config.backend == "pytorch":
-                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
-                    create_py_executor
-                create_executor = create_py_executor
-                args["lora_config"] = lora_config
-                args[
-                    "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
-            elif executor_config.backend == "_autodeploy":
-                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
-                    create_autodeploy_executor
-                create_executor = create_autodeploy_executor
-            else:
-                raise ValueError(
-                    f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
+            assert not hasattr(executor_config, "backend")
+            return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
+                                 executor_config)
 
-        self.engine = _create_engine()
+        self.engine = _create_py_executor(
+        ) if self.llm_args is not None else _create_engine(executor_config)
 
         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -161,8 +198,9 @@ def _create_engine():
             if engine_config.build_config.max_prompt_embedding_table_size > 0:
                 self._prompt_adapter_manager = PromptAdapterManager()
 
-        if getattr(executor_config, "backend",
-                   "") == "pytorch" and lora_config is not None:
+        if self.llm_args and getattr(
+                self.llm_args, "backend",
+                "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
             peft_cache_manager = self.engine.resource_manager.resource_managers.get(
@@ -430,44 +468,65 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
                 context_phase_params = request.disaggregated_params.get_context_phase_params(
                 )
 
-        is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
-        if is_overlap_enabled:
-            is_disaggregated = self.engine.kv_cache_transceiver is not None
-            if is_disaggregated and (
-                    request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
-                raise ValueError(
-                    "Context only requests are not supported in pytorch backend when overlap is enabled."
-                )
+        if self._is_pytorch_backend:
+            assert isinstance(self.llm_args, TorchLlmArgs)
+            if not self.llm_args.disable_overlap_scheduler:
+                is_disaggregated = self.engine.kv_cache_transceiver is not None
+                if is_disaggregated and (
+                        request_type
+                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
+                    raise ValueError(
+                        "Context only requests are not supported in pytorch backend when overlap is enabled."
+                    )
 
         assert request.id is not None
 
         def _deduce_max_tokens(request: GenerationRequest,
-                               executor_config: tllm.ExecutorConfig) -> int:
+                               executor_config: Optional[tllm.ExecutorConfig],
+                               llm_args: Optional[BaseLlmArgs] = None) -> int:
             if request.sampling_params.max_tokens:
                 return request.sampling_params.max_tokens
             # deduce max_tokens when it's not set by user
            query_token_len = len(
                request.query_token_ids) if request.query_token_ids else 0
-            cp_size = 1 if (not hasattr(executor_config, "mapping")
-                            or executor_config.mapping.cp_size
-                            is None) else executor_config.mapping.cp_size
-            if not hasattr(executor_config, "max_seq_len"):
-                raise RuntimeError(
-                    "max_tokens for sampling is not set and cannot be deduced")
+            cp_size = 1
+            max_seq_len = None
+            if llm_args is not None:
+                # deduce max_tokens by llm args
+                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
+                if hasattr(self,
+                           "mapping") and self.mapping.cp_size is not None:
+                    cp_size = self.mapping.cp_size
+                if not hasattr(self, "max_seq_len"):
+                    raise RuntimeError(
+                        "max_tokens for sampling is not set and cannot be deduced by llm args"
+                    )
+                max_seq_len = self.max_seq_len
+            else:
+                # deduce max_tokens by executor config
+                if hasattr(executor_config, "mapping"
+                           ) and executor_config.mapping.cp_size is not None:
+                    cp_size = executor_config.mapping.cp_size
+                if not hasattr(executor_config, "max_seq_len"):
+                    raise RuntimeError(
+                        "max_tokens for sampling is not set and cannot be deduced"
+                    )
+                max_seq_len = executor_config.max_seq_len
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
-            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
+            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
             if default_max_tokens < 0:
                 raise ValueError(
                     f"Deduced max_tokens {default_max_tokens} is less than 0, because"
                     f"prompt length {splited_prompt_len} plus query length {query_token_len} "
-                    f"is larger than max_seq_len {executor_config.max_seq_len}")
+                    f"is larger than max_seq_len {max_seq_len}")
             return default_max_tokens
 
         try:
             executor_request = tllm.Request(
                 client_id=request.id,
                 input_token_ids=prompt_token_ids,
-                max_tokens=_deduce_max_tokens(request, self._executor_config),
+                max_tokens=_deduce_max_tokens(request, self._executor_config,
+                                              self.llm_args),
                 streaming=request.streaming,
                 sampling_config=request.sampling_params._get_sampling_config(),
                 end_id=-1 if request.sampling_params.ignore_eos else
@@ -593,11 +652,19 @@ def shutdown(self):
             self.engine.shutdown()
             self.engine = None
 
-            if hasattr(
-                    self._executor_config, "checkpoint_loader"
-            ) and self._executor_config.checkpoint_loader is not None:
-                self._executor_config.checkpoint_loader.cleanup()
-                self._executor_config.checkpoint_loader = None
+            if self.llm_args is not None:
+                assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
+                if (self.llm_args.backend == "pytorch"
+                        and hasattr(self, "checkpoint_loader")
+                        and self.checkpoint_loader is not None):
+                    self.checkpoint_loader.cleanup()
+                    self.checkpoint_loader = None
+            else:
+                if hasattr(
+                        self._executor_config, "checkpoint_loader"
+                ) and self._executor_config.checkpoint_loader is not None:
+                    self._executor_config.checkpoint_loader.cleanup()
+                    self._executor_config.checkpoint_loader = None
 
         # Check if there are any errors from the threads before shutdown.
         self._handle_background_error()
@@ -641,7 +708,9 @@ def worker_main(
         is_llm_executor: Optional[
             bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -768,7 +837,9 @@ def notify_proxy_threads_to_quit():
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
             lora_config=lora_config,
-            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
+            hf_model_dir=hf_model_dir,
+            tokenizer=tokenizer,
+            llm_args=llm_args)
     except Exception as e:
         logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
         logger.error(traceback.format_exc())
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 43edb6b62cb..b95e41f57a2 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -37,8 +37,7 @@
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
-from .tokenizer import (TokenizerBase, _llguidance_tokenizer_info,
-                        _xgrammar_tokenizer_info)
+from .tokenizer import TokenizerBase, _xgrammar_tokenizer_info
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
 from .utils import (append_docstring, exception_handler, get_device_count,
                     print_colored_debug, set_api_status)
@@ -967,90 +966,11 @@ def _build_model(self):
                                                       self.tokenizer)
         self._tokenizer = self.input_processor.tokenizer
 
-        max_batch_size = self.args.max_batch_size
-        max_num_tokens = self.args.max_num_tokens
-        max_seq_len = self.args.max_seq_len
-
-        kwargs = {}
-        if self._on_trt_backend:
-            kwargs[
-                "batching_type"] = self.args.batching_type or tllm.BatchingType.INFLIGHT
-
-        self._executor_config = tllm.ExecutorConfig(
-            max_beam_width=self.args.max_beam_width,
-            scheduler_config=PybindMirror.maybe_to_pybind(
-                self.args.scheduler_config),
-            max_batch_size=max_batch_size,
-            max_num_tokens=max_num_tokens,
-            gather_generation_logits=self.args.gather_generation_logits,
-            fail_fast_on_attention_window_too_large=getattr(
-                self.args, 'fail_fast_on_attention_window_too_large', False),
-            **kwargs)
-
-        if self.args.kv_cache_config is not None:
-            self._executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
-                self.args.kv_cache_config)
-        if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
-            # Disable KV cache reuse for deterministic mode
-            self._executor_config.kv_cache_config.enable_block_reuse = False
-            self._executor_config.kv_cache_config.enable_partial_reuse = False
-        if self.args.peft_cache_config is not None:
-            self._executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
-                self.args.peft_cache_config)
-        if self.args.decoding_config is not None:
-            self._executor_config.decoding_config = self.args.decoding_config
-        if self.args.guided_decoding_backend == 'xgrammar':
-            self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
-                backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
-                XGRAMMAR,
-                **_xgrammar_tokenizer_info(self.tokenizer))
-        elif self.args.guided_decoding_backend == 'llguidance':
-            self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
-                backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
-                LLGUIDANCE,
-                **_llguidance_tokenizer_info(self.tokenizer))
-        elif self.args.guided_decoding_backend is not None:
-            raise ValueError(
-                f"Unsupported guided decoding backend {self.args.guided_decoding_backend}"
-            )
-
-        if self._on_trt_backend:
-            self._executor_config.normalize_log_probs = self.args.normalize_log_probs
-        self._executor_config.enable_chunked_context = self.args.enable_chunked_prefill
-        self._executor_config.max_beam_width = self.args.max_beam_width
-        if self.args.cache_transceiver_config is not None:
-            self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
-                self.args.cache_transceiver_config)
-        from tensorrt_llm._torch.pyexecutor.config import update_executor_config
-
-        spec_config = self.args.speculative_config
-        max_batch_size = self._executor_config.max_batch_size
-
-        if spec_config is not None and spec_config.decoding_type == "AUTO":
-            from tensorrt_llm._torch.speculative import suggest_spec_config
-            spec_config = suggest_spec_config(max_batch_size)
-
-        update_executor_config(
-            self._executor_config,
-            backend=self.args.backend,
-            pytorch_backend_config=self.args.get_pytorch_backend_config()
-            if self.args.backend in ["pytorch", "_autodeploy"] else None,
-            mapping=self.args.parallel_config.to_mapping(),
-            speculative_config=spec_config,
-            hf_model_dir=self._hf_model_dir,
-            max_input_len=self.args.max_input_len,
-            max_seq_len=max_seq_len,
-            checkpoint_format=None if self.args.backend == "_autodeploy" else
-            self.args.checkpoint_format,
-            checkpoint_loader=None if self.args.backend == "_autodeploy" else
-            self.args.checkpoint_loader)
-
         # TODO: revisit gather_context_logits
         return_logits = self.args.gather_generation_logits
-
         self._executor = self._executor_cls.create(
             self._engine_dir,
-            executor_config=self._executor_config,
+            executor_config=None,
             batched_logits_processor=self.args.batched_logits_processor,
             model_world_size=self.args.parallel_config.world_size,
             mpi_session=self.mpi_session,
@@ -1063,8 +983,9 @@ def _build_model(self):
             ),
             is_llm_executor=True,
             lora_config=self.args.lora_config,
-            garbage_collection_gen0_threshold=self.args.
-            garbage_collection_gen0_threshold)
+            hf_model_dir=self._hf_model_dir,
+            tokenizer=self.tokenizer,
+            llm_args=self.args)
 
     def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
         """Validate that users don't pass TrtLlmArgs-specific arguments when using PyTorch backend.
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 6ed4dea76c7..ff9c8c40af1 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -44,7 +44,8 @@
                                         KvCacheConfig as _KvCacheConfig,
                                         LookaheadDecodingConfig as _LookaheadDecodingConfig,
                                         PeftCacheConfig as _PeftCacheConfig,
-                                        SchedulerConfig as _SchedulerConfig)  # isort: skip
+                                        SchedulerConfig as _SchedulerConfig,
+                                        GuidedDecodingConfig as _GuidedDecodingConfig)  # isort: skip
 # isort: on
 # yapf: enable
 
@@ -56,7 +57,8 @@
                            SpeculativeDecodingMode)
 from ..sampling_params import BatchedLogitsProcessor
 from .build_cache import BuildCacheConfig
-from .tokenizer import TokenizerBase, tokenizer_factory
+from .tokenizer import (TokenizerBase, _llguidance_tokenizer_info,
+                        _xgrammar_tokenizer_info, tokenizer_factory)
 from .utils import generate_api_docs_as_docstring, get_type_repr
 
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
@@ -1835,6 +1837,81 @@ def _load_config_from_ckpt(self, ckpt_dir: Path):
                           moe_tp_size=moe_tp_size,
                           moe_ep_size=moe_ep_size)
 
+    def get_executor_config(
+        self,
+        _hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+    ) -> _ExecutorConfig:
+        executor_config = _ExecutorConfig(
+            max_beam_width=self.max_beam_width,
+            scheduler_config=PybindMirror.maybe_to_pybind(
+                self.scheduler_config),
+            max_batch_size=self.max_batch_size,
+            max_num_tokens=self.max_num_tokens,
+            gather_generation_logits=self.gather_generation_logits,
+            fail_fast_on_attention_window_too_large=getattr(
+                self, 'fail_fast_on_attention_window_too_large', False),
+        )
+
+        if self.kv_cache_config is not None:
+            executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
+                self.kv_cache_config)
+        if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
+            # Disable KV cache reuse for deterministic mode
+            executor_config.kv_cache_config.enable_block_reuse = False
+            executor_config.kv_cache_config.enable_partial_reuse = False
+        if self.peft_cache_config is not None:
+            executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
+                self.peft_cache_config)
+        if self.decoding_config is not None:
+            executor_config.decoding_config = self.decoding_config
+        if self.guided_decoding_backend == 'xgrammar':
+            assert tokenizer is not None
+            executor_config.guided_decoding_config = _GuidedDecodingConfig(
+                backend=_GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
+                **_xgrammar_tokenizer_info(tokenizer))
+        elif self.guided_decoding_backend == 'llguidance':
+            assert tokenizer is not None
+            executor_config.guided_decoding_config = _GuidedDecodingConfig(
+                backend=_GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE,
+                **_llguidance_tokenizer_info(tokenizer))
+        elif self.guided_decoding_backend is not None:
+            raise ValueError(
+                f"Unsupported guided decoding backend {self.guided_decoding_backend}"
+            )
+
+        executor_config.enable_chunked_context = self.enable_chunked_prefill
+        executor_config.max_beam_width = self.max_beam_width
+        if self.cache_transceiver_config is not None:
+            executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
+                self.cache_transceiver_config)
+
+        from tensorrt_llm._torch.pyexecutor.config import update_executor_config
+
+        spec_config = self.speculative_config
+        max_batch_size = executor_config.max_batch_size
+
+        if spec_config is not None and spec_config.decoding_type == "AUTO":
+            from tensorrt_llm._torch.speculative import suggest_spec_config
+            spec_config = suggest_spec_config(max_batch_size)
+
+        update_executor_config(
+            executor_config,
+            backend=self.backend,
+            pytorch_backend_config=self.get_pytorch_backend_config()
+            if self.backend in ["pytorch", "_autodeploy"] else None,
+            mapping=self.parallel_config.to_mapping(),
+            speculative_config=spec_config,
+            hf_model_dir=_hf_model_dir,
+            max_input_len=self.max_input_len,
+            max_seq_len=self.max_seq_len,
+            checkpoint_format=None
+            if self.backend == "_autodeploy" else self.checkpoint_format,
+            checkpoint_loader=None
+            if self.backend == "_autodeploy" else self.checkpoint_loader)
+
+        return executor_config
+
 
 class TrtLlmArgs(BaseLlmArgs):
 
@@ -2183,6 +2260,13 @@ class TorchLlmArgs(BaseLlmArgs):
         status="prototype",
     )
 
+    mm_encoder_only: bool = Field(
+        default=False,
+        description=
+        "Only load/execute the vision encoder part of the full model. Defaults to False.",
+        status="prototype",
+    )
+
     # PrivateVars
     _quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
 
@@ -2374,6 +2458,15 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs':
             raise ValueError("batch_wait_timeout_ms must be greater than 0")
         return self
 
+    def get_executor_config(
+        self,
+        _hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+    ) -> _ExecutorConfig:
+        executor_config = super().get_executor_config(_hf_model_dir, tokenizer)
+        executor_config.mm_encoder_only = self.mm_encoder_only
+        return executor_config
+
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
     def get_pytorch_backend_config(self) -> "PyTorchConfig":
         from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
diff --git a/tensorrt_llm/llmapi/mm_encoder.py b/tensorrt_llm/llmapi/mm_encoder.py
index 541068a9a6d..af0f031fc02 100644
--- a/tensorrt_llm/llmapi/mm_encoder.py
+++ b/tensorrt_llm/llmapi/mm_encoder.py
@@ -4,13 +4,12 @@
 from tqdm import tqdm
 
 from tensorrt_llm._utils import nvtx_range_debug
-from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.inputs import create_input_processor, prompt_inputs
 from tensorrt_llm.inputs.data import PromptInputs
 from tensorrt_llm.sampling_params import SamplingParams
 
 from .llm import BaseLLM, RequestOutput, _TorchLLM
-from .llm_args import PybindMirror
+from .llm_args import TorchLlmArgs
 from .mpi_session import external_mpi_comm_available
 
@@ -56,48 +55,20 @@ def _build_model(self):
                                                       self.tokenizer)
         self._tokenizer = self.input_processor.tokenizer
 
-        max_batch_size = self.args.max_batch_size
-        max_num_tokens = self.args.max_num_tokens
-        max_seq_len = self.args.max_seq_len
-
-        kwargs = {}
-        if self._on_trt_backend:
-            kwargs[
-                "batching_type"] = self.args.batching_type or tllm.BatchingType.INFLIGHT
-
-        self._executor_config = tllm.ExecutorConfig(
-            scheduler_config=PybindMirror.maybe_to_pybind(
-                self.args.scheduler_config),
-            max_batch_size=max_batch_size,
-            max_num_tokens=max_num_tokens,
-            **kwargs)
-        from tensorrt_llm._torch.pyexecutor.config import update_executor_config
-        max_batch_size = self._executor_config.max_batch_size
-        update_executor_config(
-            self._executor_config,
-            backend=self.args.backend,
-            pytorch_backend_config=self.args.get_pytorch_backend_config()
-            if self.args.backend in ["pytorch", "_autodeploy"] else None,
-            mapping=self.args.parallel_config.to_mapping(),
-            hf_model_dir=self._hf_model_dir,
-            max_input_len=self.args.max_input_len,
-            max_seq_len=max_seq_len,
-            checkpoint_format=None if self.args.backend == "_autodeploy" else
-            self.args.checkpoint_format,
-            checkpoint_loader=None if self.args.backend == "_autodeploy" else
-            self.args.checkpoint_loader,
-            mm_encoder_only=True)
+        assert isinstance(self.args, TorchLlmArgs)
+        self.args.mm_encoder_only = True
 
         self._executor = self._executor_cls.create(
             self._engine_dir,
-            executor_config=self._executor_config,
+            executor_config=None,
            model_world_size=self.args.parallel_config.world_size,
            mpi_session=self.mpi_session,
            reuse_mpi_comm=external_mpi_comm_available(
                self.args.parallel_config.world_size),
            is_llm_executor=True,  # TODO: check if this is correct or needed
-            garbage_collection_gen0_threshold=self.args.
-            garbage_collection_gen0_threshold)
+            hf_model_dir=self._hf_model_dir,
+            tokenizer=self.tokenizer,
+            llm_args=self.args)
 
     def _validate_mm_args_for_torch_backend(self, kwargs: dict) -> None:
         """Validate that users don't pass LLM-specific arguments when using MultimodalEncoder (PyTorch).
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index d9dcd0f83d2..91511263db1 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -95,6 +95,10 @@ methods:
         annotation: Optional[str]
         default: null
         status: prototype
+      mm_encoder_only:
+        annotation: bool
+        default: False
+        status: prototype
       disable_overlap_scheduler:
         annotation: bool
         default: False
diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py
index 64282a00229..a01d7f591f3 100644
--- a/tests/unittest/llmapi/test_llm_args.py
+++ b/tests/unittest/llmapi/test_llm_args.py
@@ -438,10 +438,12 @@ def test_runtime_sizes(self):
             assert llm.args.max_seq_len == 128
             assert llm.args.max_batch_size == 8
 
-            assert llm._executor_config.max_beam_width == 1
-            assert llm._executor_config.max_num_tokens == 256
-            assert llm._executor_config.max_seq_len == 128
-            assert llm._executor_config.max_batch_size == 8
+            executor_config = llm.args.get_executor_config(
+                llm._hf_model_dir, llm.tokenizer)
+            assert executor_config.max_beam_width == 1
+            assert executor_config.max_num_tokens == 256
+            assert executor_config.max_seq_len == 128
+            assert executor_config.max_batch_size == 8
 
     def test_dynamic_setattr(self):
         with pytest.raises(pydantic_core._pydantic_core.ValidationError):
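Note: the patch above makes the LlmArgs object, rather than a pre-built ExecutorConfig, the single source of executor configuration; the worker now derives the config on demand via llm_args.get_executor_config(...). A minimal sketch of that flow follows. It mirrors the updated test_runtime_sizes test; the model path is a placeholder, and any TorchLlmArgs construction details beyond the fields touched by this patch are assumptions, not part of the change.

    from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

    # Hypothetical checkpoint path; in the real worker flow this comes from
    # hf_model_dir resolved by the LLM frontend.
    llm_args = TorchLlmArgs(model="/path/to/hf_model_dir")

    # Build the ExecutorConfig lazily from the args object instead of
    # receiving it from the caller (tokenizer is only required when a
    # guided decoding backend is configured).
    executor_config = llm_args.get_executor_config(_hf_model_dir=None, tokenizer=None)
    assert executor_config.max_batch_size == llm_args.max_batch_size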