tensorrt_llm/_torch/pyexecutor/model_engine.py (5 additions, 2 deletions)

```diff
@@ -722,8 +722,11 @@ def disable_optimization(backend: Backend):
         # For non-draft model, we also capture the CUDA graph instance for draft length 0,
         # so that when we disable spec decode at runtime, we can still run the captured graph.
         # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-        if not self.is_draft_model and self.max_draft_len > 0 and not self.spec_config.spec_dec_mode.use_one_engine(
-        ):
+        if (not self.is_draft_model and self.max_draft_len > 0
+                and not self.spec_config.spec_dec_mode.use_one_engine()
+                # Assume that speculation is always on if the user didn't give us a max_concurrency
+                # value. This will save on memory.
+                and self.spec_config.max_concurrency is not None):
             draft_lengths.append(0)
 
         for bs in cuda_graph_batch_sizes:
```
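The net effect of the new `max_concurrency is not None` guard is that the extra draft-length-0 CUDA graph is only captured when spec decode can actually be toggled off at runtime, saving the memory of an unused graph otherwise. A minimal standalone sketch of that decision, using a hypothetical helper and illustrative values rather than the real TensorRT-LLM API:

```python
from typing import List, Optional

def draft_lengths_to_capture(is_draft_model: bool,
                             max_draft_len: int,
                             uses_one_engine: bool,
                             max_concurrency: Optional[int]) -> List[int]:
    # Illustrative only: assume the engine always captures a graph for
    # its configured draft length.
    lengths = [max_draft_len]
    # Mirror of the new condition above: only capture an extra
    # draft-length-0 graph when speculation can be turned off at
    # runtime, i.e. when the user supplied a max_concurrency threshold.
    if (not is_draft_model and max_draft_len > 0 and not uses_one_engine
            and max_concurrency is not None):
        lengths.append(0)
    return lengths

# No max_concurrency: speculation is assumed always-on, so the extra
# graph (and its memory) is skipped.
assert draft_lengths_to_capture(False, 3, False, None) == [3]
assert draft_lengths_to_capture(False, 3, False, 8) == [3, 0]
```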
tensorrt_llm/_torch/speculative/drafter.py (7 additions, 2 deletions)

```diff
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, final
 
 from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import ResourceManager
@@ -26,8 +26,13 @@ def prepare_draft_tokens(
         """
         raise NotImplementedError
 
+    @final
     def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
-        """Check if spec decode should be used for the current iteration."""
+        """
+        You probably don't want to override this. ModelEngine
+        assumes that speculation is always on if max_concurrency
+        is not specified by the user's spec config.
+        """
        if self.max_concurrency is not None:
             return len(requests) <= self.max_concurrency
         return True
```
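Since `should_use_spec_decode` is now marked `@final`, every drafter inherits this gate unchanged. A small stub showing how it behaves, assuming a hypothetical minimal class (not the real `Drafter`, which is abstract and takes real `LlmRequest` objects):

```python
from typing import List, Optional

class StubDrafter:
    """Hypothetical stand-in that reuses the base-class gating logic."""

    def __init__(self, max_concurrency: Optional[int] = None):
        self.max_concurrency = max_concurrency

    def should_use_spec_decode(self, requests: List[object]) -> bool:
        # Same logic as the (now @final) Drafter method: speculate only
        # while the batch is small enough, or always if no cap was set.
        if self.max_concurrency is not None:
            return len(requests) <= self.max_concurrency
        return True

reqs = [object()] * 4
assert StubDrafter().should_use_spec_decode(reqs)                   # no cap: always on
assert StubDrafter(max_concurrency=8).should_use_spec_decode(reqs)  # under the cap
assert not StubDrafter(max_concurrency=2).should_use_spec_decode(reqs)
```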