diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index aa0902484a5..89fdcc6d850 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1100,6 +1100,10 @@ def _executor_loop(self):
 
                 sample_state = self._sample_async(scheduled_batch,
                                                   batch_outputs)
+                if self.drafter is not None:
+                    self.drafter.run_drafter_post(scheduled_batch,
+                                                  self.resource_manager,
+                                                  self.is_warmup)
 
                 self._update_request_states(scheduled_batch)
                 self._update_requests(sample_state, self.resource_manager)
diff --git a/tensorrt_llm/_torch/speculative/__init__.py b/tensorrt_llm/_torch/speculative/__init__.py
index 8f6e0254faa..31ed71f76f3 100644
--- a/tensorrt_llm/_torch/speculative/__init__.py
+++ b/tensorrt_llm/_torch/speculative/__init__.py
@@ -3,6 +3,7 @@
 from .interface import SpecMetadata
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
+from .save_hidden_state import SaveHiddenStatesDrafter
 from .spec_tree_manager import SpecTreeManager
 from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
                     get_spec_decoder, get_spec_drafter, get_spec_metadata,
@@ -16,6 +17,7 @@
     "MTPWorker",
     "NGramDrafter",
     "NGramPoolManager",
+    "SaveHiddenStatesDrafter",
     "SpecMetadata",
     "get_num_extra_kv_tokens",
     "get_num_spec_layers",
diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py
index 74384206740..485934f7b5c 100644
--- a/tensorrt_llm/_torch/speculative/drafter.py
+++ b/tensorrt_llm/_torch/speculative/drafter.py
@@ -67,3 +67,15 @@ def pad_draft_tokens_for_cuda_graph(
             num_draft_tokens = get_draft_token_length(req)
             req.py_draft_tokens.extend(
                 0 for _ in range(max_draft_tokens - num_draft_tokens))
+
+    def run_drafter_post(
+        self,
+        scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
+        is_warmup: bool = False,
+    ) -> None:
+        """
+        If the draft forward pass needs to run directly after the target model
+        forward pass, this method can be overridden to do that.
+        Used in SaveHiddenStatesDrafter (to ensure correct input_ids).
+        """
diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py
index 571850c82da..42682c07934 100644
--- a/tensorrt_llm/_torch/speculative/eagle3.py
+++ b/tensorrt_llm/_torch/speculative/eagle3.py
@@ -126,6 +126,10 @@ def __post_init__(self):
                                          self.num_layers - 4)
         else:
             self.layers_to_capture = sorted(list(self.layers_to_capture))
+            if self.layers_to_capture[0] == -1:
+                self.layers_to_capture = self.layers_to_capture[1:] + [
+                    self.layers_to_capture.pop(0)
+                ]
         self.num_capture_layers = len(self.layers_to_capture)
 
         # Initialize to 0 to avoid reading uninitialized memory during warmup
diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index 16522e98320..191eb92c7eb 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -19,6 +19,7 @@ class SpeculativeDecodingMode(IntEnum):
     NGRAM = auto()
     DRAFT_TARGET = auto()
     USER_PROVIDED = auto()
+    SAVE_HIDDEN_STATES = auto()
     NONE = auto()
     AUTO = auto()
 
@@ -55,6 +56,9 @@ def is_none(self):
     def is_draft_target(self):
         return self == SpeculativeDecodingMode.DRAFT_TARGET
 
+    def is_save_hidden_states(self):
+        return self == SpeculativeDecodingMode.SAVE_HIDDEN_STATES
+
     def without_logits(self):
         return self.is_mtp_one_model() or self.is_eagle3_one_model()
 
@@ -95,8 +99,9 @@ def has_spec_decoder(self):
         ) or self.is_eagle3_one_model()
 
     def has_spec_drafter(self):
-        return self.is_eagle3() or self.is_draft_target() or self.is_ngram(
-        ) or self.is_user_provided() or self.is_mtp_eagle()
+        return self.is_eagle3(
+        ) or self.is_draft_target() or self.is_ngram() or self.is_user_provided(
+        ) or self.is_mtp_eagle() or self.is_save_hidden_states()
 
     def extend_ctx(self, attention_backend: Type[AttentionBackend]):
         """
diff --git a/tensorrt_llm/_torch/speculative/save_hidden_state.py b/tensorrt_llm/_torch/speculative/save_hidden_state.py
new file mode 100644
index 00000000000..202088784fe
--- /dev/null
+++ b/tensorrt_llm/_torch/speculative/save_hidden_state.py
@@ -0,0 +1,99 @@
+import os
+from typing import Optional
+
+import torch
+
+from tensorrt_llm._utils import local_mpi_rank
+
+from ..pyexecutor.llm_request import LlmRequest
+from ..pyexecutor.resource_manager import ResourceManager
+from ..pyexecutor.scheduler import ScheduledRequests
+from .drafter import Drafter
+
+
+class SaveHiddenStatesDrafter(Drafter):
+
+    def __init__(
+        self,
+        spec_config: "SaveHiddenStatesDecodingConfig",
+        spec_resource_manager,
+    ):
+        super().__init__(spec_config.max_concurrency)
+        self.spec_config = spec_config
+        self.max_draft_len = spec_config.max_draft_len
+        self._iter = 1
+        self._output_directory = spec_config.output_directory
+        self._file_prefix = spec_config.file_prefix
+        self._write_interval = spec_config.write_interval
+        self._saved_state = []
+        self.spec_resource_manager = spec_resource_manager
+        os.makedirs(self._output_directory, exist_ok=True)
+
+    def _process_request(self, request: LlmRequest, resource_manager) -> None:
+        out_dict = {}
+        if local_mpi_rank() == 0:
+            input_ids = torch.tensor(list(request.get_tokens(0)),
+                                     dtype=torch.long,
+                                     device='cpu')
+            hidden_size = resource_manager.hidden_size
+            num_tokens = input_ids.shape[0]
+            hidden_states = resource_manager.hidden_states[
+                :num_tokens, -hidden_size:].cpu().clone()
+
+            out_dict = {
+                "id": self._iter,
+                "input_ids": input_ids,
+                "hidden_state": hidden_states,
+            }
+            if len(self.spec_config.eagle3_layers_to_capture) > 1:
+                if self.spec_config._last_hidden_in_save:
+                    out_dict["aux_hidden_states"] = resource_manager.hidden_states[
+                        :num_tokens, :].cpu().clone()
+                else:
+                    out_dict["aux_hidden_states"] = resource_manager.hidden_states[
+                        :num_tokens, :-hidden_size].cpu().clone()
+
+        self._saved_state.append(out_dict)
+
+    def _write_to_file(self) -> None:
+        if local_mpi_rank() == 0:
+            output_path = os.path.join(self._output_directory,
+                                       f"{self._file_prefix}_{self._iter}.pt")
+            torch.save(self._saved_state, output_path)
+        self._saved_state = []
+
+    def prepare_draft_tokens(
+        self,
+        scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
+    ) -> None:
+        for request in sorted(
+                scheduled_requests.context_requests,
+                key=lambda r:
+                (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
+        ):
+            request.py_max_new_tokens = 1
+
+    def run_drafter_post(
+        self,
+        scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
+        is_warmup: bool = False,
+    ) -> None:
+        if is_warmup:
+            return
+        for request in sorted(
+                scheduled_requests.context_requests,
+                key=lambda r:
+                (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
+        ):
+            self._process_request(request, self.spec_resource_manager)
+        if self._iter % self._write_interval == 0:
+            self._write_to_file()
+        self._iter += 1
diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py
index 56b44704c0e..152cbd1074e 100644
--- a/tensorrt_llm/_torch/speculative/utils.py
+++ b/tensorrt_llm/_torch/speculative/utils.py
@@ -11,6 +11,7 @@
 from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
                   MTPSpecMetadata, MTPWorker)
 from .ngram import NGramDrafter, NGramPoolManager
+from .save_hidden_state import SaveHiddenStatesDrafter
 
 
 def get_spec_metadata(spec_config,
@@ -55,6 +56,25 @@ def get_spec_metadata(spec_config,
             max_num_tokens=max_num_tokens,
             layers_to_capture=spec_config.eagle3_layers_to_capture,
         )
+    if spec_config.spec_dec_mode.is_save_hidden_states():
+        if spec_config.eagle3_layers_to_capture is None:
+            spec_config.eagle3_layers_to_capture = {
+                1, model_config.num_hidden_layers // 2 - 1,
+                model_config.num_hidden_layers - 4, -1
+            }
+        return Eagle3SpecMetadata(
+            max_draft_len=spec_config.max_draft_len,
+            spec_dec_mode=spec_config.spec_dec_mode,
+            max_num_requests=max_num_requests,
+            num_layers=model_config.num_hidden_layers,
+            hidden_size=model_config.hidden_size,
+            max_num_tokens=max_num_tokens,
+            dtype=model_config.torch_dtype,
+            is_draft_model=is_draft_model,
+            eagle3_resource_manager=spec_resource_manager,
+            layers_to_capture=spec_config.eagle3_layers_to_capture,
+            max_total_draft_tokens=1,
+        )
     if spec_config.spec_dec_mode.is_draft_target() or \
        spec_config.spec_dec_mode.is_ngram() or \
        spec_config.spec_dec_mode.is_user_provided():
@@ -102,6 +122,15 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
             max_seq_len,
             max_num_tokens,
         )
+    if spec_dec_mode.is_save_hidden_states():
+        return Eagle3ResourceManager(
+            spec_config,
+            model_engine.model.config.torch_dtype,
+            model_config.hidden_size,
+            max_num_requests,
+            max_seq_len,
+            max_num_tokens,
+        )
     if spec_dec_mode.is_ngram():
         return NGramPoolManager(spec_config, max_num_requests)
     if spec_dec_mode.is_user_provided():
@@ -151,6 +180,9 @@ def get_spec_drafter(model_engine,
     if spec_config.spec_dec_mode.is_ngram():
         return NGramDrafter(spec_config, spec_resource_manager)
 
+    if spec_config.spec_dec_mode.is_save_hidden_states():
+        return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)
+
     return None
diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py
index 1c3ebd6e2b9..adc0e7e35c3 100644
--- a/tensorrt_llm/llmapi/__init__.py
+++ b/tensorrt_llm/llmapi/__init__.py
@@ -11,7 +11,8 @@
                        DynamicBatchConfig, EagleDecodingConfig,
                        ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
                        LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
-                       MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
+                       MTPDecodingConfig, NGramDecodingConfig,
+                       SaveHiddenStatesDecodingConfig, SchedulerConfig,
                        TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
                        UserProvidedDecodingConfig)
 from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
@@ -59,4 +60,5 @@
     'AutoDecodingConfig',
     'AttentionDpConfig',
     'LoRARequest',
+    'SaveHiddenStatesDecodingConfig',
 ]
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 5a05ee741f3..4bad387a1b0 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -374,6 +374,7 @@ def from_dict(cls, data: dict):
             "Lookahead": LookaheadDecodingConfig,
             "NGram": NGramDecodingConfig,
             "DraftTarget": DraftTargetDecodingConfig,
+            "SaveState": SaveHiddenStatesDecodingConfig,
             "UserProvided": UserProvidedDecodingConfig,
             "AUTO": AutoDecodingConfig,
         }
@@ -556,6 +557,52 @@ def num_capture_layers(self) -> int:
         return 3
 
 
+class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
+    output_directory: str
+    write_interval: int = 20
+    file_prefix: str = "data"
+    eagle3_layers_to_capture: Optional[Set[int]] = None
+
+    max_total_draft_tokens: Optional[int] = Field(default=1, init=False)
+    eagle_choices: Optional[List[List[int]]] = Field(default=None, init=False)
+
+    def model_post_init(self, __context):
+        self._last_hidden_in_save = True
+        if self.eagle3_layers_to_capture is None:
+            self._last_hidden_in_save = False
+        elif -1 not in self.eagle3_layers_to_capture:
+            self._last_hidden_in_save = False
+            self.eagle3_layers_to_capture.add(-1)
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+    decoding_type: ClassVar[str] = "SaveState"
+
+    def validate(self) -> None:
+        if self.output_directory is None or not self.eagle3_layers_to_capture:
+            raise ValueError(
+                "Save directory and layers to capture must be provided")
+
+    @functools.cached_property
+    def spec_dec_mode(self):
+        from tensorrt_llm._torch.speculative.interface import \
+            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
+        return TorchSpeculativeDecodingMode.SAVE_HIDDEN_STATES
+
+    @functools.cached_property
+    def num_capture_layers(self):
+        """
+        Returns the number of target model layers to capture.
+        If eagle3_layers_to_capture is not None, return the size of that set.
+        Otherwise, assume the Eagle3 base set and return 3 + 1 (one extra for
+        the post-norm last hidden state).
+ """ + if self.eagle3_layers_to_capture is None: + return 4 + return len(self.eagle3_layers_to_capture) + + class UserProvidedDecodingConfig(DecodingBaseConfig): # Cannot use real type annotations due to circular imports drafter: object # Type is Drafter @@ -1044,6 +1091,7 @@ def supports_backend(self, backend: str) -> bool: MTPDecodingConfig, NGramDecodingConfig, UserProvidedDecodingConfig, + SaveHiddenStatesDecodingConfig, AutoDecodingConfig, ]] @@ -1863,6 +1911,20 @@ def validate_speculative_config(self): self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO self.build_config.max_draft_len = self.speculative_config.max_draft_len + elif isinstance(self.speculative_config, + SaveHiddenStatesDecodingConfig): + assert self.backend in ['pytorch'] + logger.warning( + "SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None" + ) + self.build_config.max_batch_size = 1 + self.max_batch_size = 1 + self.disable_overlap_scheduler = True + self.cuda_graph_config = None + self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES + self.build_config.max_draft_len = 1 + self.speculative_config.max_draft_len = 1 + else: raise ValueError( f"Unrecognized speculative config type {type(self.speculative_config)}" diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 6f4bfdf0bb0..b2ad8f82dfc 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag): EAGLE = auto() NGRAM = auto() USER_PROVIDED = auto() + SAVE_HIDDEN_STATES = auto() AUTO = auto() @staticmethod @@ -120,6 +121,8 @@ def from_arguments(args: argparse.Namespace): return SpeculativeDecodingMode.USER_PROVIDED elif args.speculative_decoding_mode == "auto": return SpeculativeDecodingMode.AUTO + elif args.speculative_decoding_mode == "save_hidden_states": + return SpeculativeDecodingMode.SAVE_HIDDEN_STATES else: assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode diff --git a/tests/unittest/_torch/speculative/test_save_state.py b/tests/unittest/_torch/speculative/test_save_state.py new file mode 100644 index 00000000000..406dd4f8cf0 --- /dev/null +++ b/tests/unittest/_torch/speculative/test_save_state.py @@ -0,0 +1,138 @@ +import os +import sys +import tempfile +import unittest + +import pytest +import torch +from utils.llm_data import llm_models_root + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig, + SaveHiddenStatesDecodingConfig) + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + + +def test_multi_save_state(): + use_cuda_graph = True + attn_backend = "TRTLLM" + disable_overlap_scheduler = False + enable_block_reuse = False + enable_chunked_prefill = False + layers_to_capture = {10, 11, 12} + + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if total_mem_gb < 80: + pytest.skip("Not enough memory to load target + draft model") + + models_path = llm_models_root() + with tempfile.TemporaryDirectory() as temp_dir: + + target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct" + + max_batch_size = 16 + kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, + free_gpu_memory_fraction=0.5) + cuda_graph_config = CudaGraphConfig( + batch_sizes=[1, 2, 4]) if use_cuda_graph else None + + llm_common_config = dict( + 
model=target_model_dir, + attn_backend=attn_backend, + disable_overlap_scheduler=disable_overlap_scheduler, + cuda_graph_config=cuda_graph_config, + max_batch_size=max_batch_size, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=enable_chunked_prefill, + ) + spec_config = SaveHiddenStatesDecodingConfig( + output_directory=temp_dir, + write_interval=1, + file_prefix="data", + eagle3_layers_to_capture=layers_to_capture) + + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + + tok_ids = llm_spec.tokenizer.encode("The future of AI is") + + sampling_params = SamplingParams(max_tokens=32, temperature=0) + for output in llm_spec.generate_async(tok_ids, + sampling_params, + streaming=True): + pass + llm_spec.shutdown() + assert os.path.exists(os.path.join(temp_dir, "data_1.pt")) + # Read in .pt file + saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0] + + assert saved_data["aux_hidden_states"].shape == (len(tok_ids), 2048 * + len(layers_to_capture)) + assert saved_data["hidden_state"].shape == (len(tok_ids), 2048) + assert saved_data["input_ids"].tolist() == tok_ids + + +@pytest.mark.parametrize("layers_to_capture", [{-1}, None]) +def test_save_state(layers_to_capture): + use_cuda_graph = True + attn_backend = "TRTLLM" + disable_overlap_scheduler = False + enable_block_reuse = False + enable_chunked_prefill = False + + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if total_mem_gb < 80: + pytest.skip("Not enough memory to load target + draft model") + + models_path = llm_models_root() + with tempfile.TemporaryDirectory() as temp_dir: + + target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct" + + max_batch_size = 16 + kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, + free_gpu_memory_fraction=0.5) + cuda_graph_config = CudaGraphConfig( + batch_sizes=[1, 2, 4]) if use_cuda_graph else None + + llm_common_config = dict( + model=target_model_dir, + attn_backend=attn_backend, + disable_overlap_scheduler=disable_overlap_scheduler, + cuda_graph_config=cuda_graph_config, + max_batch_size=max_batch_size, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=enable_chunked_prefill, + ) + spec_config = SaveHiddenStatesDecodingConfig( + output_directory=temp_dir, + write_interval=1, + file_prefix="data", + eagle3_layers_to_capture=layers_to_capture) + + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + + tok_ids = llm_spec.tokenizer.encode("The future of AI is") + + sampling_params = SamplingParams(max_tokens=32, temperature=0) + for output in llm_spec.generate_async(tok_ids, + sampling_params, + streaming=True): + pass + llm_spec.shutdown() + assert os.path.exists(os.path.join(temp_dir, "data_1.pt")) + # Read in .pt file + saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0] + if layers_to_capture is None: + assert saved_data["aux_hidden_states"].shape == (len(tok_ids), + 2048 * 3) + assert saved_data["hidden_state"].shape == (len(tok_ids), 2048) + assert saved_data["input_ids"].tolist() == tok_ids + else: + assert "aux_hidden_states" not in saved_data + assert saved_data["hidden_state"].shape == (len(tok_ids), 2048) + assert saved_data["input_ids"].tolist() == tok_ids + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest/api_stability/references_committed/llm.yaml b/tests/unittest/api_stability/references_committed/llm.yaml index a722da54958..36e2ff28ea9 100644 --- a/tests/unittest/api_stability/references_committed/llm.yaml +++ 
b/tests/unittest/api_stability/references_committed/llm.yaml @@ -59,7 +59,7 @@ methods: default: null # Speculative decoding speculative_config: - annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, NoneType] + annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, tensorrt_llm.llmapi.llm_args.SaveHiddenStatesDecodingConfig, NoneType] default: null # generation constraints max_batch_size:
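
For reference, a minimal usage sketch of the new SaveHiddenStatesDecodingConfig, mirroring the unit test above. The model path, prompt, and captured layer indices are placeholders, not values taken from the diff.

import os

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import SaveHiddenStatesDecodingConfig

# Capture hidden states from a few target-model layers; -1 (the post-norm last
# hidden state) is added automatically by model_post_init if it is missing.
spec_config = SaveHiddenStatesDecodingConfig(
    output_directory="hidden_states",      # .pt files are written here
    write_interval=1,                      # flush to disk every iteration
    file_prefix="data",
    eagle3_layers_to_capture={1, 15, 28},  # placeholder layer indices
)

# validate_speculative_config() forces max_batch_size=1, disables the overlap
# scheduler, and sets cuda_graph_config to None when this config is active.
llm = LLM(model="/path/to/target-model", speculative_config=spec_config)
llm.generate(["The future of AI is"],
             SamplingParams(max_tokens=32, temperature=0))
llm.shutdown()

# Each entry in hidden_states/data_<iter>.pt contains "id", "input_ids",
# "hidden_state", and, when more than one layer is captured, "aux_hidden_states".
print(os.listdir("hidden_states"))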