diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index aedac8c2ac7..3320c6b0929 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -2027,7 +2027,7 @@ class GenericLlmRequest // Scatter the input tokens to other beam mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens); - mLastTokens = VecTokens(mSamplingConfig.beamWidth); + mLastTokens = VecTokens(mSamplingConfig.beamWidth, inputTokens.back()); // Init mUniqueTokens VecUniqueTokens uniqueTokens{inputTokens.size()}; diff --git a/docs/source/torch/features/feature_combination_matrix.md b/docs/source/torch/features/feature_combination_matrix.md index 35a10a49596..6990c61e182 100644 --- a/docs/source/torch/features/feature_combination_matrix.md +++ b/docs/source/torch/features/feature_combination_matrix.md @@ -8,11 +8,11 @@ | Disaggregated Serving | Yes | Yes | Yes | --- | | | | | | | | | | | | Chunked Prefill | Yes | Yes | Yes | Untested | --- | | | | | | | | | | | MTP | Yes | Yes | Yes | Yes | Untested | --- | | | | | | | | | -| EAGLE-3(One Model Engine) | Yes | Yes | Yes | No | Yes | No | --- | | | | | | | | -| EAGLE-3(Two Model Engine) | NO | Yes | Yes | No | Yes | No | No | --- | | | | | | | +| EAGLE-3(One Model Engine) | Yes | Yes | Yes | No | Yes | No | --- | | | | | | | | +| EAGLE-3(Two Model Engine) | NO | Yes | Yes | No | Yes | No | No | --- | | | | | | | | Torch Sampler | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | --- | | | | | | | TLLM C++ Sampler | Yes | Yes | Yes | Yes | Yes | No | No | No | No | --- | | | | | | KV Cache Reuse | Yes | Yes | Yes | Untested | Yes | Untested | Yes | No | Yes | Yes | --- | | | | -| Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | | -| Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | | -| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- | +| Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | | +| Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | | +| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- | diff --git a/requirements.txt b/requirements.txt index baf4b5341cf..591cfc7a055 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,6 +52,8 @@ einops flashinfer-python==0.2.5 opencv-python-headless xgrammar==0.1.21 +llguidance==0.7.29 +jsonschema backoff nvtx matplotlib # FIXME: this is added to make nvtx happy @@ -59,7 +61,6 @@ meson ninja etcd3 blake3 -llguidance==0.7.29 soundfile triton==3.3.1; platform_machine == "x86_64" tiktoken diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py index 7d51af7ae19..22f24752c77 100644 --- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py +++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py @@ -16,11 +16,19 @@ class GrammarMatcher(ABC): def accept_token(self, token_id: int) -> bool: pass + @abstractmethod + def rollback(self, num_tokens: int) -> None: + pass + @abstractmethod def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, index: int) -> None: pass + @abstractmethod + def is_terminated(self) -> bool: + pass + class GrammarMatcherFactory(ABC): @@ -39,15 +47,23 @@ def __init__(self, matcher: 
xgrammar.GrammarMatcher): def accept_token(self, token_id: int) -> bool: return self._matcher.accept_token(token_id) + def rollback(self, num_tokens: int) -> None: + self._matcher.rollback(num_tokens) + def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, index: int) -> None: self._matcher.fill_next_token_bitmask(next_token_bitmask, index) + def is_terminated(self) -> bool: + return self._matcher.is_terminated() + class XGrammarMatcherFactory(GrammarMatcherFactory): - def __init__(self, guided_decoding_config: GuidedDecodingConfig, - vocab_size_padded: int): + def __init__(self, + guided_decoding_config: GuidedDecodingConfig, + vocab_size_padded: int, + max_num_draft_tokens: int = 0): super().__init__() vocab_type = xgrammar.VocabType.RAW add_prefix_space = False @@ -72,6 +88,7 @@ def __init__(self, guided_decoding_config: GuidedDecodingConfig, cache_enabled=True, cache_limit_bytes=cache_limit_bytes, ) + self.max_num_draft_tokens = max_num_draft_tokens def create(self, guided_decoding_params: GuidedDecodingParams) -> XGrammarMatcher: @@ -106,20 +123,38 @@ def create(self, case _: raise ValueError(f"Unsupported guide type: {guide_type}.") - matcher = xgrammar.GrammarMatcher(compiled_grammar) + matcher = xgrammar.GrammarMatcher( + compiled_grammar, max_rollback_tokens=self.max_num_draft_tokens) return XGrammarMatcher(matcher) class LLGuidanceMatcher(GrammarMatcher): - def __init__(self, matcher: llguidance.LLMatcher): + def __init__(self, matcher: llguidance.LLMatcher, eos_token: int): super().__init__() self._matcher = matcher + self._eos_token = eos_token + self._is_terminated = False def accept_token(self, token_id: int) -> bool: - result = self._matcher.consume_token(token_id) + if self._matcher.is_stopped(): + # Accept EOS token only if the matcher is stopped. 
+ if token_id == self._eos_token: + self._is_terminated = True + return True + else: + return False + + num_accepted = self._matcher.try_consume_tokens([token_id]) + self._check_err() + return num_accepted > 0 + + def rollback(self, num_tokens: int) -> None: + if self._is_terminated: + self._is_terminated = False + num_tokens -= 1 + self._matcher.rollback(num_tokens) self._check_err() - return result def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, index: int) -> None: @@ -127,6 +162,9 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, next_token_bitmask, index) self._check_err() + def is_terminated(self) -> bool: + return self._is_terminated + def _check_err(self) -> None: if self._matcher.is_error(): raise ValueError( @@ -181,4 +219,4 @@ def create( if matcher.is_error(): raise ValueError(f"LLGuidance matcher error: {matcher.get_error()}") - return LLGuidanceMatcher(matcher) + return LLGuidanceMatcher(matcher, self._tokenizer.eos_token) diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py index f1b21339b9a..fa95a0a7a15 100644 --- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py +++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py @@ -5,19 +5,25 @@ from ..._utils import nvtx_range from ...bindings.executor import GuidedDecodingConfig +from ...logger import logger from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactory, LLGuidanceMatcherFactory, XGrammarMatcherFactory) +from .llm_request import LlmRequest from .scheduler import ScheduledRequests class GuidedDecoder: bitmask_dtype = torch.int32 - def __init__(self, guided_decoding_config: GuidedDecodingConfig, - max_num_sequences: int, vocab_size_padded: int): + def __init__(self, + guided_decoding_config: GuidedDecodingConfig, + max_num_sequences: int, + vocab_size_padded: int, + max_num_draft_tokens: int = 0): self.guided_decoding_backend = guided_decoding_config.backend self.max_num_sequences = max_num_sequences self.vocab_size_padded = vocab_size_padded + self.max_num_draft_tokens = max_num_draft_tokens self.grammar_matcher_factory: Optional[GrammarMatcherFactory] = None self.grammar_matchers: List[ @@ -25,71 +31,216 @@ def __init__(self, guided_decoding_config: GuidedDecodingConfig, if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR: self.grammar_matcher_factory = XGrammarMatcherFactory( - guided_decoding_config, vocab_size_padded) + guided_decoding_config, + vocab_size_padded, + max_num_draft_tokens=max_num_draft_tokens) elif self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE: self.grammar_matcher_factory = LLGuidanceMatcherFactory( guided_decoding_config, vocab_size_padded) else: raise ValueError( - f"invalid guided_decoding_backend: {self.guided_decoding_backend}" + f"invalid guided decoding backend: {self.guided_decoding_backend}" ) + logger.info( + f"Guided decoder initialized with backend: {self.guided_decoding_backend}" + ) self.bitmask = torch.empty(self.max_num_sequences, + self.max_num_draft_tokens + 1, self.bitmask_size, dtype=self.bitmask_dtype, device='cuda') self.bitmask_host = torch.empty(self.max_num_sequences, + self.max_num_draft_tokens + 1, self.bitmask_size, dtype=self.bitmask_dtype, pin_memory=True) + # The number of tokens accepted by the grammar matcher in a build step. + self.num_advanced_tokens: List[int] = [0] * self.max_num_sequences + # The number of tokens with filled bitmask in a build step. 
+ self.num_guided_tokens: List[int] = [0] * self.max_num_sequences + # The accumulated number of tokens accepted by the grammar matcher in a drafting loop. + self.num_advanced_draft_tokens: List[int] = [0] * self.max_num_sequences + # Whether guided drafting is terminated because of unacceptable draft tokens. + self.is_draft_terminated: List[bool] = [False] * self.max_num_sequences + self._stream = torch.cuda.Stream() @property def bitmask_size(self) -> int: return math.ceil(self.vocab_size_padded / 32) + def _is_matcher_init(self, llm_req: LlmRequest) -> bool: + if llm_req.guided_decoding_params is None: + return False + if llm_req.py_is_draft: + return False + # The request is in the last chunk of a context forward step. + return llm_req.is_context_init_state and llm_req.is_last_context_chunk + + def _is_matcher_in_progress(self, llm_req: LlmRequest) -> bool: + if llm_req.guided_decoding_params is None: + return False + if llm_req.py_is_draft: + return True + # The request is in a generation forward step. + return llm_req.is_generation_in_progress_state + + @torch.inference_mode() @nvtx_range("GuidedDecoder.build") def build(self, scheduled_requests: ScheduledRequests) -> None: + """Build the bitmask for requests with guided decoding enabled. + + Specifically, this method: + - builds or advances the grammar matcher for context and generation requests, respectively; + - calls the grammar matcher to fill the bitmask on CPU; + - asynchronously copies the bitmask to GPU. + """ for llm_req in scheduled_requests.all_requests(): - if llm_req.guided_decoding_params is None: - continue - slot = llm_req.py_seq_slot - if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len: - self.grammar_matchers[ - slot] = self.grammar_matcher_factory.create( - llm_req.guided_decoding_params) - - elif llm_req.is_generation_in_progress_state: - # The request is in a generation forward step. - # Currently, guided decoding does not support with beam search. - self.grammar_matchers[slot].accept_token( - llm_req.get_last_tokens(0)) + slot: int = llm_req.py_target_seq_slot if llm_req.py_is_draft else llm_req.py_seq_slot + self.num_advanced_tokens[slot] = 0 + self.num_guided_tokens[slot] = 0 + + if self._is_matcher_init(llm_req): + matcher = self.grammar_matcher_factory.create( + llm_req.guided_decoding_params) + self.grammar_matchers[slot] = matcher + + elif self._is_matcher_in_progress(llm_req): + matcher = self.grammar_matchers[slot] + # The last new token must be acceptable unless the matcher is terminated in a drafting loop. + if llm_req.py_is_draft and (matcher.is_terminated() + or self.is_draft_terminated[slot]): + continue + last_new_token = llm_req.get_last_tokens(0) + accepted = matcher.accept_token(last_new_token) + if not accepted: + if llm_req.py_is_draft: + self.is_draft_terminated[slot] = True + logger.debug( + f"Draft request {llm_req.py_request_id} failed to accept last new token: {last_new_token}." + ) + continue + # TODO: Make this an error response. + raise ValueError( + f"Request {llm_req.py_request_id} failed to accept last new token: {last_new_token}." + ) + else: continue - # Fill the bitmask on host and asynchorously copy to device.
- self.grammar_matchers[slot].fill_next_token_bitmask( - self.bitmask_host, slot) - with torch.cuda.stream(self._stream): - self.bitmask[slot].copy_(self.bitmask_host[slot], - non_blocking=True) - + self.num_advanced_tokens[slot] += 1 + if not matcher.is_terminated(): + matcher.fill_next_token_bitmask(self.bitmask_host[slot], 0) + self.num_guided_tokens[slot] += 1 + # Process draft tokens + for i, tid in enumerate(llm_req.py_draft_tokens, 1): + accepted = matcher.accept_token(tid) + if not accepted: + break + self.num_advanced_tokens[slot] += 1 + if matcher.is_terminated(): + break + matcher.fill_next_token_bitmask(self.bitmask_host[slot], i) + self.num_guided_tokens[slot] += 1 + + if llm_req.py_is_draft: + assert len(llm_req.py_draft_tokens) == 0 + self.num_advanced_draft_tokens[ + slot] += self.num_advanced_tokens[slot] + + if (num_guided_tokens := self.num_guided_tokens[slot]) > 0: + with torch.cuda.stream(self._stream): + self.bitmask[slot, :num_guided_tokens].copy_( + self.bitmask_host[slot, :num_guided_tokens], + non_blocking=True) + + @torch.inference_mode() @nvtx_range("GuidedDecoder.execute") - def execute(self, scheduled_requests: ScheduledRequests, - logits: torch.Tensor) -> None: - assert logits.size(0) == len(scheduled_requests.context_requests) + len( - scheduled_requests.generation_requests) + def execute(self, + scheduled_requests: ScheduledRequests, + logits: torch.Tensor, + d2t: Optional[torch.Tensor] = None) -> None: + """Apply the bitmask to the corresponding logits for requests with guided decoding enabled. + + This method modifies the logits tensor in place so that any tokens that violate the grammar constraints are masked out. + """ torch.cuda.current_stream().wait_stream(self._stream) + # TODO: Fuse index_copy and index_select to logits_bitmask. + if d2t is not None: + draft_logits = logits + d2t_mapping = d2t + torch.arange(d2t.size(0), device=d2t.device) + logits = torch.empty(draft_logits.size(0), + self.vocab_size_padded, + dtype=draft_logits.dtype, + device=draft_logits.device) + logits.index_copy_(-1, d2t_mapping, draft_logits) + batched_logits, batched_bitmask = [], [] - for i, llm_req in enumerate(scheduled_requests.all_requests()): - if llm_req.guided_decoding_params is None: - continue - if llm_req.is_context_init_state and not llm_req.is_last_context_chunk: - continue - batched_logits.append(logits[i]) - batched_bitmask.append(self.bitmask[llm_req.py_seq_slot]) + offset = 0 + for llm_req in scheduled_requests.all_requests(): + slot: int = llm_req.py_target_seq_slot if llm_req.py_is_draft else llm_req.py_seq_slot + for i in range(self.num_guided_tokens[slot]): + batched_logits.append(logits[offset + i]) + batched_bitmask.append(self.bitmask[slot, i]) + offset += len(llm_req.py_draft_tokens) + 1 + + assert offset == logits.size(0) if len(batched_logits) > 0: torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask) + + if d2t is not None: + torch.index_select(logits, -1, d2t_mapping, out=draft_logits) + + @nvtx_range("GuidedDecoder.rollback_rejected_tokens") + def rollback_rejected_tokens(self, + scheduled_requests: ScheduledRequests) -> None: + """Roll back the grammar matcher for rejected tokens. + + This method should be called: + - after the verification (so that the accepted tokens are ready) and + - before the first guided decoding build of the next drafting loop.
+ """ + if self.max_num_draft_tokens <= 0: + return + + for llm_req in scheduled_requests.all_requests(): + assert not llm_req.py_is_draft + slot: int = llm_req.py_seq_slot + if self.num_advanced_tokens[slot] <= 0: + continue + # Rollback the grammar matcher to the last accepted token. + num_rollback_tokens = self.num_advanced_tokens[slot] - ( + 1 + llm_req.py_num_accepted_draft_tokens) + # TODO: Make this an error response. + if num_rollback_tokens < 0: + raise ValueError( + f"Failed to rollback: num_advanced_tokens={self.num_advanced_tokens[slot]}, num_accepted_draft_tokens={llm_req.py_num_accepted_draft_tokens}, num_rollback_tokens={num_rollback_tokens}" + ) + self.grammar_matchers[slot].rollback(num_rollback_tokens) + + @nvtx_range("GuidedDecoder.rollback_draft_tokens") + def rollback_draft_tokens(self, + scheduled_requests: ScheduledRequests) -> None: + """Rollback the grammar matcher for draft tokens. + + This method should be called: + - after the the drafting loop and + - before the guided decoding build of the target model. + """ + if self.max_num_draft_tokens <= 0: + return + + for llm_req in scheduled_requests.all_requests(): + assert not llm_req.py_is_draft + slot: int = llm_req.py_seq_slot + if self.num_advanced_draft_tokens[slot] <= 0: + continue + self.grammar_matchers[slot].rollback( + self.num_advanced_draft_tokens[slot]) + # Reset the drafting states. + self.num_advanced_draft_tokens[slot] = 0 + self.is_draft_terminated[slot] = False diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 72793a8a2e4..bd72fcc53e5 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -281,6 +281,8 @@ def __init__( llm_request: Optional[ tensorrt_llm.bindings.internal.batch_manager.LlmRequest] = None, is_draft: bool = False, + seq_slot: Optional[int] = None, + target_seq_slot: Optional[int] = None, **kwargs): self.py_logits_post_processors = kwargs.pop("py_logits_post_processors", @@ -309,6 +311,7 @@ def __init__( self.py_orig_prompt_len = self.orig_prompt_len self.py_max_new_tokens = self.max_new_tokens self.py_batch_idx = None + self.py_draft_pages_allocated = 0 self.py_rewind_len = 0 self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens self.py_last_context_chunk = (None, None) @@ -326,7 +329,8 @@ def __init__( self.py_return_generation_logits = return_generation_logits self.py_return_logits_device_memory = return_logits_device_memory self.py_is_draft = is_draft - self.py_seq_slot = None + self.py_seq_slot = seq_slot + self.py_target_seq_slot = target_seq_slot # TODO: remove this when use DynamicDecodeOp in pytorch flow. # currently, keep py_stop_words_list as python list, rather than tensor. 
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 2d00cee05f0..247f6da1754 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -453,7 +453,7 @@ def use_mrope(self): 'type'] == 'mrope' except Exception: pass - logger.info(f"Detected use_mrope: {use_mrope}") + logger.debug(f"Detected use_mrope: {use_mrope}") return use_mrope @property diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index d87dbef4e7d..1b028d097e1 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -893,7 +893,8 @@ def _prepare_and_schedule_batch(self): f'{len(scheduled_batch.generation_requests)} generation requests') return scheduled_batch, iter_stats - def _execute_guided_decoder(self, scheduled_batch, logits): + def _execute_guided_decoder(self, scheduled_batch: ScheduledRequests, + logits: torch.Tensor): if self.guided_decoder is not None: self.guided_decoder.build(scheduled_batch) self.guided_decoder.execute(scheduled_batch, logits) @@ -931,6 +932,9 @@ def _executor_loop(self): self.resource_manager.prepare_resources(scheduled_batch) if self.drafter is not None and self.use_spec_decode: + if self.guided_decoder is not None: + self.guided_decoder.rollback_rejected_tokens( + scheduled_batch) self.drafter.prepare_draft_tokens( scheduled_batch, self.resource_manager) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index bcd006be71e..af3ee4040a5 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -33,6 +33,7 @@ class _ExecutorCreationStage(enum.Enum): SAMPLER = "Sampler" DRAFTER = "Drafter" + GUIDED_DECODER = "Guided decoder" INIT_KV_CACHE = "Initial KV cache (temporary for KV cache size estimation)" INIT_EXTRA_RESOURCES = "Additional executor resources (temporary for KV cache size estimation)" MODEL_EXTRA = "Model resources created during usage" @@ -326,21 +327,28 @@ def create_py_executor( else: ctx_chunk_config = None + with mem_monitor.observe_creation_stage( + _ExecutorCreationStage.GUIDED_DECODER): + guided_decoder: Optional[GuidedDecoder] = None + if executor_config.guided_decoding_config is not None: + if spec_config is not None and not has_spec_drafter: + raise ValueError( + "Guided decoding is only supported with speculative decoding that has a dedicated drafter (two-model engine)." 
+ ) + if mapping.is_last_pp_rank(): + max_num_draft_tokens = 0 + if spec_config is not None: + max_num_draft_tokens = spec_config.max_draft_len + guided_decoder = GuidedDecoder( + executor_config.guided_decoding_config, + executor_config.max_batch_size, + model_engine.model.vocab_size_padded, + max_num_draft_tokens=max_num_draft_tokens) + with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER): sampler = instantiate_sampler(model_engine, executor_config, pytorch_backend_config, mapping) - guided_decoder: Optional[GuidedDecoder] = None - if executor_config.guided_decoding_config is not None: - if spec_config is not None: - raise ValueError( - "Guided decoding is not supported with speculative decoding.") - if mapping.is_last_pp_rank(): - guided_decoder = GuidedDecoder( - executor_config.guided_decoding_config, - executor_config.max_batch_size, - model_engine.model.vocab_size_padded) - resources = {} estimating_kv_cache = False kv_cache_creator = None @@ -368,8 +376,11 @@ def create_py_executor( # Drafter for speculative decoding with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER): - drafter = get_spec_drafter(model_engine, draft_model_engine, sampler, - spec_resource_manager) + drafter = get_spec_drafter(model_engine, + draft_model_engine, + sampler, + spec_resource_manager=spec_resource_manager, + guided_decoder=guided_decoder) with mem_monitor.observe_creation_stage( _ExecutorCreationStage.INIT_EXTRA_RESOURCES diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 06406016804..fba50db175f 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -8,8 +8,9 @@ from tensorrt_llm._utils import nvtx_range from tensorrt_llm.logger import logger +from ..pyexecutor.guided_decoder import GuidedDecoder from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState, - SamplingConfig, get_draft_token_length) + get_draft_token_length) from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager from ..pyexecutor.sampler import Sampler, SampleState, TorchSampler from ..pyexecutor.scheduler import ScheduledRequests @@ -45,6 +46,7 @@ def __init__( draft_seq_slot_manager: SeqSlotManager, sampler: Sampler, spec_resource_manager: Optional[BaseResourceManager] = None, + guided_decoder: Optional[GuidedDecoder] = None, ): # Validate required parameters if draft_model_engine is None: @@ -65,17 +67,18 @@ def __init__( self._request_draft_logits = False if isinstance(sampler, TorchSampler): self._request_draft_logits = sampler.enable_mixed_sampler + self.guided_decoder = guided_decoder - def _create_draft_request(self, request_id: int, max_new_tokens: int, - input_tokens: Optional[List], - sampling_config: SamplingConfig, - return_perf_metrics: bool) -> LlmRequest: + def _create_draft_request(self, request: LlmRequest, + input_tokens: Optional[List]) -> LlmRequest: """Create a draft request with common parameters.""" - return LlmRequest(request_id=request_id, - max_new_tokens=max_new_tokens, - input_tokens=input_tokens, - sampling_config=sampling_config, - return_perf_metrics=return_perf_metrics, + return LlmRequest(input_tokens=input_tokens, + request_id=request.py_request_id, + max_new_tokens=request.py_max_new_tokens, + sampling_config=request.sampling_config, + guided_decoding_params=request.guided_decoding_params, + target_seq_slot=request.py_seq_slot, + return_perf_metrics=request.return_perf_metrics, is_streaming=False, 
is_draft=True, return_generation_logits=self._request_draft_logits) @@ -96,11 +99,7 @@ def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]: def _create_context_request(self, request: LlmRequest, input_tokens: Any) -> LlmRequest: """Create a context request for first-time drafting.""" - new_request = self._create_draft_request(request.py_request_id, - request.py_max_new_tokens, - input_tokens, - request.sampling_config, - request.return_perf_metrics) + new_request = self._create_draft_request(request, input_tokens) begin_compute, end_compute = request.py_last_context_chunk if begin_compute is not None: @@ -111,13 +110,7 @@ def _create_context_request(self, request: LlmRequest, def _create_generation_request(self, request: LlmRequest, input_tokens: Any) -> LlmRequest: """Create a generation request when no tokens were accepted.""" - new_request = self._create_draft_request(request.py_request_id, - request.py_max_new_tokens, - input_tokens[:-1], - request.sampling_config, - request.return_perf_metrics) - # Explicitly add the last token so get_last_tokens() returns the right value - new_request.add_new_token(input_tokens[-1], 0) + new_request = self._create_draft_request(request, input_tokens) new_request.state = LlmRequestState.GENERATION_IN_PROGRESS return new_request @@ -128,11 +121,7 @@ def _create_accepted_tokens_request(self, request: LlmRequest, Create a chunked context request for accepted tokens. Only applicable if the draft model needs to recompute KV cache for accepted tokens (e.g. eagle 3) """ - new_request = self._create_draft_request(request.py_request_id, - request.py_max_new_tokens, - input_tokens, - request.sampling_config, - request.return_perf_metrics) + new_request = self._create_draft_request(request, input_tokens) new_request.context_chunk_size = num_accepted_tokens + 1 new_request.context_current_position = len( input_tokens) - num_accepted_tokens - 1 @@ -144,7 +133,7 @@ def _create_draft_request_for_request( num_draft_tokens, num_accepted_tokens = self._initialize_draft_tokens( request) input_tokens = get_draft_model_prompt(self.spec_config.spec_dec_mode, - request.get_tokens()[0]) + request.get_tokens(0)) # First time seeing this request - context request if request.max_beam_num_tokens - 1 == request.py_prompt_len: @@ -206,7 +195,7 @@ def _prepare_draft_batch( # We hit this path if we're doing chunked prefill. The target model processed # a prefill chunk on the last iteration. Now, we need to fill in the KV cache # for the draft model too. 
- all_tokens = request.get_tokens()[0] + all_tokens = request.get_tokens(0) input_tokens = get_draft_model_prompt( self.spec_config.spec_dec_mode, all_tokens) @@ -329,6 +318,14 @@ def _pad_to_max_draft_tokens(self, req.py_draft_tokens.extend( 0 for _ in range(max_draft_tokens - num_draft_tokens)) + def _execute_guided_decoder(self, + scheduled_batch: ScheduledRequests, + logits: torch.Tensor, + d2t: Optional[torch.Tensor] = None): + if self.guided_decoder is not None: + self.guided_decoder.build(scheduled_batch) + self.guided_decoder.execute(scheduled_batch, logits, d2t=d2t) + @nvtx_range("prepare_draft_tokens") def prepare_draft_tokens( self, @@ -363,6 +360,9 @@ def prepare_draft_tokens( # Initial forward pass outputs = self._forward_draft_model(draft_batch, resource_manager) + self._execute_guided_decoder(draft_batch, + outputs['logits'], + d2t=outputs.get('d2t')) sample_state = self._sample_async(draft_batch, outputs) previous_batch = sample_state @@ -380,10 +380,14 @@ def prepare_draft_tokens( outputs = self._forward_draft_model(draft_batch, resource_manager, previous_batch) + if previous_batch is not None: + self._update_requests(previous_batch) + self._execute_guided_decoder(draft_batch, + outputs['logits'], + d2t=outputs.get('d2t')) sample_state = self._sample_async(draft_batch, outputs) self._update_request_states(draft_batch) if previous_batch is not None: - self._update_requests(previous_batch) new_requests = self._process_decoded_tokens( previous_batch.scheduled_requests, req_id_to_old_request) @@ -399,6 +403,9 @@ def prepare_draft_tokens( req_id_to_old_request) self._pad_to_max_draft_tokens(scheduled_requests) + if self.guided_decoder is not None: + self.guided_decoder.rollback_draft_tokens(scheduled_requests) + except Exception as e: traceback.print_exc() error_msg = str(e) diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index 39267f5da26..6ca615de34b 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -1,11 +1,12 @@ from itertools import chain +from typing import Optional from ordered_set import OrderedSet from tensorrt_llm.llmapi import NGramDecodingConfig from tensorrt_llm.logger import logger -from ..pyexecutor.llm_request import * +from ..pyexecutor.llm_request import LlmRequest, LlmRequestState from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager from ..pyexecutor.scheduler import ScheduledRequests from .drafter import Drafter @@ -181,6 +182,7 @@ def prepare_draft_tokens( if self.spec_config.is_auto_heuristic and len( scheduled_requests.all_requests()) > 32: return + # Sort by request_id when py_batch_idx is None as a fallback. # This happens in the disagg case: for a set of new requests, we draft # before forward_step, so py_batch_idx is not assigned. 
@@ -190,7 +192,7 @@ def prepare_draft_tokens( (r.py_batch_idx is None, r.py_batch_idx or r.request_id), ): # Add new token to a copy of the generated tokens to find new draft tokens - prefix = list(request.get_tokens()[0]) # Get a copy + prefix = list(request.get_tokens(0)) # Get a copy # Generate draft tokens draft_tokens = self.spec_resource_manager.get_draft_tokens( diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index e8db9d1f561..ad7fbf8fd56 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -1,7 +1,9 @@ -from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler -from tensorrt_llm._torch.speculative.interface import SpecMetadata +from typing import Optional +from ..pyexecutor.guided_decoder import GuidedDecoder +from ..pyexecutor.sampler import TorchSampler from ..pyexecutor.seq_slot_manager import SeqSlotManager +from ..speculative.interface import SpecMetadata from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata, Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3SpecMetadata) @@ -114,8 +116,11 @@ def get_spec_decoder(sampler_args: TorchSampler.Args, f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}") -def get_spec_drafter(model_engine, draft_model_engine, sampler, - spec_resource_manager): +def get_spec_drafter(model_engine, + draft_model_engine, + sampler, + spec_resource_manager, + guided_decoder: Optional[GuidedDecoder] = None): spec_config = model_engine.spec_config if spec_config is None: return None @@ -126,10 +131,13 @@ def get_spec_drafter(model_engine, draft_model_engine, sampler, max_num_requests = model_engine.batch_size if spec_config.spec_dec_mode.is_draft_target( ) or spec_config.spec_dec_mode.is_eagle3(): - return ModelDrafter(spec_config, draft_model_engine, + return ModelDrafter(spec_config, + draft_model_engine, spec_config.max_draft_len, - SeqSlotManager(max_num_requests), sampler, - spec_resource_manager) + SeqSlotManager(max_num_requests), + sampler, + spec_resource_manager=spec_resource_manager, + guided_decoder=guided_decoder) if spec_config.spec_dec_mode.is_ngram(): return NGramDrafter(spec_config, spec_resource_manager) diff --git a/tensorrt_llm/evaluate/json_mode_eval.py b/tensorrt_llm/evaluate/json_mode_eval.py index 1854488bce3..37360754e50 100644 --- a/tensorrt_llm/evaluate/json_mode_eval.py +++ b/tensorrt_llm/evaluate/json_mode_eval.py @@ -18,6 +18,7 @@ import click import datasets +import jsonschema import numpy as np from .. 
import LLM as PyTorchLLM @@ -65,23 +66,30 @@ def generate_samples(self) -> Iterable[tuple]: sampling_args = { "guided_decoding": GuidedDecodingParams(json=schema) } - yield sample["prompt"], sampling_args, sample["completion"] + yield sample["prompt"], sampling_args, sample["completion"], sample[ + "schema"] - def compute_score(self, outputs: List[RequestOutput], - references: List[str]) -> float: - all_corrections = [] - for output, ref in zip(outputs, references): + def compute_score(self, outputs: List[RequestOutput], references: List[str], + schemas: List[str]) -> float: + all_corrections, all_grammar_corrections = [], [] + for output, ref, schema in zip(outputs, references, schemas): try: output_json = json.loads(output.outputs[0].text) - except json.JSONDecodeError: + jsonschema.validate(output_json, json.loads(schema)) + except (json.JSONDecodeError, jsonschema.ValidationError): all_corrections.append(False) + all_grammar_corrections.append(False) continue - ref_json = json.loads(ref) - all_corrections.append(output_json == ref_json) + all_corrections.append(output_json == json.loads(ref)) + all_grammar_corrections.append(True) acc = np.mean(all_corrections) * 100 logger.info( f"JSON Mode Eval accuracy: {acc:.2f} ({len(all_corrections)})") + grammar_acc = np.mean(all_grammar_corrections) * 100 + logger.info( + f"JSON Mode Eval grammar accuracy: {grammar_acc:.2f} ({len(all_grammar_corrections)})" + ) return acc @click.command("json_mode_eval") diff --git a/tests/integration/defs/accuracy/references/json_mode_eval.yaml b/tests/integration/defs/accuracy/references/json_mode_eval.yaml index d22461d8aa7..f8b82fef8e0 100644 --- a/tests/integration/defs/accuracy/references/json_mode_eval.yaml +++ b/tests/integration/defs/accuracy/references/json_mode_eval.yaml @@ -1,2 +1,6 @@ meta-llama/Llama-3.1-8B-Instruct: - accuracy: 74.00 + - spec_dec_algo: Eagle + accuracy: 74.00 + - spec_dec_algo: NGram + accuracy: 74.00 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 8e3751f149c..ada8352b9ad 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -304,9 +304,7 @@ def test_ngram(self): @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) def test_guided_decoding(self, backend: str, mocker): mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) - llm = LLM(self.MODEL_PATH, - guided_decoding_backend=backend, - cuda_graph_config=CudaGraphConfig()) + llm = LLM(self.MODEL_PATH, guided_decoding_backend=backend) with llm: task = JsonModeEval(self.MODEL_NAME) task.evaluate(llm) @@ -318,12 +316,46 @@ def test_guided_decoding_4gpus(self, backend: str, mocker): mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) with LLM(self.MODEL_PATH, guided_decoding_backend=backend, - cuda_graph_config=CudaGraphConfig(), tensor_parallel_size=2, pipeline_parallel_size=2) as llm: task = JsonModeEval(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_hopper + @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) + def test_guided_decoding_with_eagle3(self, backend: str, mocker): + mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8) + spec_config = EagleDecodingConfig( + max_draft_len=3, + speculative_model_dir= + f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B", + eagle3_one_model=False) + llm = LLM(self.MODEL_PATH, + 
guided_decoding_backend=backend, + kv_cache_config=kv_cache_config, + speculative_config=spec_config, + disable_overlap_scheduler=True) + with llm: + task = JsonModeEval(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_hopper + @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) + def test_guided_decoding_with_ngram(self, backend: str, mocker): + mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8) + spec_config = NGramDecodingConfig(max_draft_len=3, + max_matching_ngram_size=3) + llm = LLM(self.MODEL_PATH, + guided_decoding_backend=backend, + kv_cache_config=kv_cache_config, + speculative_config=spec_config, + disable_overlap_scheduler=True) + with llm: + task = JsonModeEval(self.MODEL_NAME) + task.evaluate(llm) + class TestLlama3_2_1B(LlmapiAccuracyTestHarness): MODEL_NAME = "meta-llama/Llama-3.2-1B" diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index a7ff13bed8e..43ee39de1af 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -31,6 +31,7 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90) - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=TRTLLM-torch_compile=False] @@ -209,6 +210,9 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance] - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True] - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] - condition:
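For reference, the grammar-accuracy metric added to JsonModeEval.compute_score reduces to the check sketched below; the schema and model output here are invented for illustration.

import json
import jsonschema

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
output_text = '{"name": "Ada", "age": 36}'  # stand-in for output.outputs[0].text

try:
    output_json = json.loads(output_text)
    jsonschema.validate(output_json, schema)  # raises ValidationError on schema violations
    grammar_correct = True                    # output parses and conforms to the schema
except (json.JSONDecodeError, jsonschema.ValidationError):
    grammar_correct = False
# Exact-match accuracy additionally compares output_json against the reference completion.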