diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py
index 5e447e6a0e4..de4c9b23a87 100644
--- a/examples/llm-api/quickstart_advanced.py
+++ b/examples/llm-api/quickstart_advanced.py
@@ -1,10 +1,10 @@
 import argparse
 
 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import (CudaGraphConfig, DraftTargetDecodingConfig,
-                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
-                                 MTPDecodingConfig, NGramDecodingConfig,
-                                 TorchCompileConfig)
+from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
+                                 DraftTargetDecodingConfig, EagleDecodingConfig,
+                                 KvCacheConfig, MoeConfig, MTPDecodingConfig,
+                                 NGramDecodingConfig, TorchCompileConfig)
 
 example_prompts = [
     "Hello, my name is",
@@ -181,6 +181,8 @@ def setup_llm(args, **kwargs):
             is_use_oldest=True,
             is_public_pool=True,
         )
+    elif spec_decode_algo == "AUTO":
+        spec_config = AutoDecodingConfig()
     else:
         spec_config = None
 
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index cf466700d54..014e0880347 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -864,7 +864,7 @@ def _prepare_and_schedule_batch(self):
             self.use_spec_decode = self.drafter.should_use_spec_decode(
                 self.active_requests)
             self.model_engine.enable_spec_decode = self.use_spec_decode
-            self._prepare_draft_requests(self.active_requests)
+            self._prepare_draft_requests()
 
         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
         )
@@ -965,14 +965,15 @@ def _executor_loop(self):
                             iter_stats=iter_stats,
                             iter_start_time=iter_start_time))
 
-    def _prepare_draft_requests(self, requests):
+    def _prepare_draft_requests(self):
         try:
             # Set draft tokens here to make the KV cache manager
             # and scheduler aware of them.
-            for req in requests:
+            for req in self.active_requests:
                 if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
                                      LlmRequestState.DISAGG_GENERATION_INIT):
                     continue
 
+                req.py_last_draft_tokens = req.py_draft_tokens
                 max_draft_len = self.model_engine.spec_config.max_draft_len
 
diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index 46fe18e0584..d606073f000 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -18,6 +18,7 @@ class SpeculativeDecodingMode(IntEnum):
     DRAFT_TARGET = auto()
     USER_PROVIDED = auto()
     NONE = auto()
+    AUTO = auto()
 
     def is_mtp(self):
         return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE
diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py
index 9113900ef94..39267f5da26 100644
--- a/tensorrt_llm/_torch/speculative/ngram.py
+++ b/tensorrt_llm/_torch/speculative/ngram.py
@@ -2,6 +2,7 @@
 
 from ordered_set import OrderedSet
 
+from tensorrt_llm.llmapi import NGramDecodingConfig
 from tensorrt_llm.logger import logger
 
 from ..pyexecutor.llm_request import *
@@ -163,10 +164,11 @@ class NGramDrafter(Drafter):
 
     def __init__(
        self,
-        spec_config: "NGramDecodingConfig",
+        spec_config: NGramDecodingConfig,
        ngram_pool_manager: NGramPoolManager = None,
     ):
         assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
+        self.spec_config = spec_config
         self.max_draft_len = spec_config.max_draft_len
         self.spec_resource_manager = ngram_pool_manager
 
@@ -175,6 +177,10 @@ def prepare_draft_tokens(
         scheduled_requests: ScheduledRequests,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
+        # Disable NGram speculative decoding auto heuristic for batch size > 32.
+        if self.spec_config.is_auto_heuristic and len(
+                scheduled_requests.all_requests()) > 32:
+            return
         # Sort by request_id when py_batch_idx is None as a fallback.
         # This happens in the disagg case: for a set of new requests, we draft
         # before forward_step, so py_batch_idx is not assigned.
diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py
index 24f7ad00e75..bef7ded9948 100644
--- a/tensorrt_llm/llmapi/__init__.py
+++ b/tensorrt_llm/llmapi/__init__.py
@@ -4,15 +4,15 @@
 from .build_cache import BuildCacheConfig
 from .llm import LLM, RequestOutput
 # yapf: disable
-from .llm_args import (BatchingType, CacheTransceiverConfig, CalibConfig,
-                       CapacitySchedulerPolicy, ContextChunkingPolicy,
-                       CudaGraphConfig, DraftTargetDecodingConfig,
-                       DynamicBatchConfig, EagleDecodingConfig,
-                       ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
-                       LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
-                       MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
-                       TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
-                       UserProvidedDecodingConfig)
+from .llm_args import (AutoDecodingConfig, BatchingType, CacheTransceiverConfig,
+                       CalibConfig, CapacitySchedulerPolicy,
+                       ContextChunkingPolicy, CudaGraphConfig,
+                       DraftTargetDecodingConfig, DynamicBatchConfig,
+                       EagleDecodingConfig, ExtendedRuntimePerfKnobConfig,
+                       KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
+                       MedusaDecodingConfig, MoeConfig, MTPDecodingConfig,
+                       NGramDecodingConfig, SchedulerConfig, TorchCompileConfig,
+                       TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig)
 from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
                         QuantConfig)
 from .mpi_session import MpiCommSession
@@ -53,4 +53,5 @@
     'LlmArgs',
     'TorchLlmArgs',
     'TrtLlmArgs',
+    'AutoDecodingConfig',
 ]
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 2ddd9beaf1a..bd24bfc2a5d 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -31,8 +31,8 @@
 from ..logger import logger
 from ..sampling_params import SamplingParams
 from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
-                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
-                       PybindMirror, TorchLlmArgs, TrtLlmArgs)
+                       TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
+                       PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -995,13 +995,43 @@ def _build_model(self):
         self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
             self.args.cache_transceiver_config)
         from tensorrt_llm._torch.pyexecutor.config import update_executor_config
+
+        spec_config = self.args.speculative_config
+        max_batch_size = self._executor_config.max_batch_size
+        # Apply default heuristic to AutoDecodingConfig based on benchmark results
+        # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
+        # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
+        # With concurrency > 32, speculative decoding is disabled.
+        if spec_config is not None and spec_config.decoding_type == "AUTO":
+            if not self.args.disable_overlap_scheduler:
+                logger.info(
+                    "Disable overlap scheduler to enable Auto speculative decoding with Ngram."
+                )
+                # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32.
+                # Therefore, we disable overlap scheduler to enable NGram speculative decoding.
+                self.args.disable_overlap_scheduler = True
+
+            spec_config = NGramDecodingConfig(
+                max_draft_len=5 if max_batch_size <= 4 else 3,
+                max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
+                is_keep_all=True,
+                is_use_oldest=True,
+                is_public_pool=True,
+                # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
+                is_auto_heuristic=True,
+            )
+
+            logger.info(
+                f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
+            )
+
         update_executor_config(
             self._executor_config,
             backend=self.args.backend,
             pytorch_backend_config=self.args.get_pytorch_backend_config()
             if self.args.backend in ["pytorch", "_autodeploy"] else None,
             mapping=self.args.parallel_config.to_mapping(),
-            speculative_config=self.args.speculative_config,
+            speculative_config=spec_config,
             hf_model_dir=self._hf_model_dir,
             max_input_len=self.args.max_input_len,
             max_seq_len=max_seq_len,
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 1c5885fb604..44016980fc4 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -336,6 +336,7 @@ def from_dict(cls, data: dict):
             "NGram": NGramDecodingConfig,
             "DraftTarget": DraftTargetDecodingConfig,
             "UserProvided": UserProvidedDecodingConfig,
+            "AUTO": AutoDecodingConfig,
         }
 
         config_class = config_classes.get(decoding_type)
@@ -446,11 +447,13 @@ class NGramDecodingConfig(DecodingBaseConfig):
     is_public_pool: bool = True
         Whether to use a common pool for all requests, or the pool is private for each request if False.
     """
-
-    max_matching_ngram_size: int = 4
+    max_matching_ngram_size: int = 0
     is_keep_all: bool = True
     is_use_oldest: bool = True
     is_public_pool: bool = True
+    # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
+    # User should not set this flag. Use AutoDecodingConfig instead.
+    is_auto_heuristic: bool = False
 
     @classmethod
     def from_dict(cls, data: dict):
@@ -510,6 +513,29 @@ def spec_dec_mode(self):
         return TorchSpeculativeDecodingMode.MTP
 
 
+class AutoDecodingConfig(DecodingBaseConfig):
+    """
+    Configuration for auto speculative decoding.
+
+    This config is used to automatically select the best speculative decoding algorithm.
+
+    According to benchmark results, the best algorithm in general is NGRAM with low concurrency <= 32.
+    Default heuristic:
+    With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
+    With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
+    With concurrency > 32, speculative decoding is disabled.
+    """
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+    decoding_type: ClassVar[str] = "AUTO"
+
+    def supports_backend(self, backend: str) -> bool:
+        return backend == "pytorch"
+
+
 class PybindMirror(ABC):
     ''' A class containing the utilities for mirroring Python classes to
     pybinding classes.
@@ -872,6 +898,7 @@ def supports_backend(self, backend: str) -> bool:
     MTPDecodingConfig,
     NGramDecodingConfig,
     UserProvidedDecodingConfig,
+    AutoDecodingConfig,
 ]]
 
 
@@ -1292,7 +1319,6 @@ def from_kwargs(cls, **kwargs: Any) -> "BaseLlmArgs":
             tensorrt_llm.llmapi.llm_utils.BaseLlmArgs: The `BaseLlmArgs` instance.
         """
         kwargs = BaseLlmArgs._check_consistency(dict(kwargs))
-
         ret = cls(**kwargs)
         return ret
 
@@ -1621,6 +1647,11 @@ def validate_speculative_config(self):
                 self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.USER_PROVIDED
                 self.build_config.max_draft_len = self.speculative_config.max_draft_len
 
+            elif isinstance(self.speculative_config, AutoDecodingConfig):
+                assert self.backend in ['pytorch', '_autodeploy']
+                self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
+                self.build_config.max_draft_len = self.speculative_config.max_draft_len
+
             else:
                 raise ValueError(
                     f"Unrecognized speculative config type {type(self.speculative_config)}"
diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py
index 07d930064c6..7b2af7af15e 100644
--- a/tensorrt_llm/models/modeling_utils.py
+++ b/tensorrt_llm/models/modeling_utils.py
@@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag):
     EAGLE = auto()
     NGRAM = auto()
     USER_PROVIDED = auto()
+    AUTO = auto()
 
     @staticmethod
     def from_arguments(args: argparse.Namespace):
@@ -117,6 +118,8 @@ def from_arguments(args: argparse.Namespace):
             return SpeculativeDecodingMode.NGRAM
         elif args.speculative_decoding_mode == "user_provided":
             return SpeculativeDecodingMode.USER_PROVIDED
+        elif args.speculative_decoding_mode == "auto":
+            return SpeculativeDecodingMode.AUTO
         else:
             assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode
 
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index 0e4a1619cfe..ded14aadd82 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -1775,6 +1775,31 @@ def test_ptp_quickstart_advanced_ngram(llm_root, llm_venv, model_name,
         _check_mem_usage(running_log, [27.0, 0, 0, 0])
 
 
+@pytest.mark.parametrize("model_name,model_path", [
+    ("Llama-3.1-8B-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct"),
+])
+def test_ptp_quickstart_advanced_auto(llm_root, llm_venv, model_name,
+                                      model_path):
+    print(f"Testing {model_name}.")
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
+    with tempfile.NamedTemporaryFile(mode='w+t',
+                                     suffix=f".{model_name}.log",
+                                     dir="./",
+                                     delete=True,
+                                     delete_on_close=True) as running_log:
+        llm_venv.run_cmd([
+            str(example_root / "quickstart_advanced.py"),
+            "--model_dir",
+            f"{llm_models_root()}/{model_path}",
+            "--spec_decode_algo",
+            "AUTO",
+            "--use_cuda_graph",
+            "--max_batch_size=4",
+        ],
+                         stdout=running_log)
+        _check_mem_usage(running_log, [27.0, 0, 0, 0])
+
+
 @skip_post_blackwell
 @pytest.mark.skip_less_device_memory(110000)
 @pytest.mark.skip_less_device(8)
diff --git a/tests/unittest/api_stability/references_committed/llm.yaml b/tests/unittest/api_stability/references_committed/llm.yaml
index d0d6c8ce0bf..a722da54958 100644
--- a/tests/unittest/api_stability/references_committed/llm.yaml
+++ b/tests/unittest/api_stability/references_committed/llm.yaml
@@ -59,7 +59,7 @@ methods:
       default: null
     # Speculative decoding
     speculative_config:
-      annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, NoneType]
+      annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, NoneType]
       default: null
     # generation constraints
    max_batch_size:
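Usage sketch (illustrative, not part of the patch): with the changes above, the AUTO mode can be requested directly through the LLM API. The model name, prompt, and sampling settings below are assumptions; what the diff establishes is that AutoDecodingConfig is exported from tensorrt_llm.llmapi, that _build_model() in llm.py resolves it to an NGramDecodingConfig keyed off max_batch_size (max_draft_len=5 / max_matching_ngram_size=3 when max_batch_size <= 4, otherwise 3 / 5), and that the overlap scheduler is disabled when the heuristic is applied.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import AutoDecodingConfig

# Assumed checkpoint; any model supported by the PyTorch backend applies.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    max_batch_size=4,  # <= 4, so the heuristic picks max_draft_len=5, max_matching_ngram_size=3
    speculative_config=AutoDecodingConfig(),  # resolved to an NGramDecodingConfig in _build_model()
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)

The equivalent command-line path is exercised by the new integration test: quickstart_advanced.py with --spec_decode_algo AUTO --use_cuda_graph --max_batch_size=4.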