Remove some useless code in ModelDrafter
Signed-off-by: ziyixiong-nv <fxiong@nvidia.com>
ziyixiong-nv committed Jul 17, 2025
commit df74a1a543ea9882ee13bfb062b346b9bfe5258b
47 changes: 15 additions & 32 deletions tensorrt_llm/_torch/speculative/model_drafter.py
--- a/tensorrt_llm/_torch/speculative/model_drafter.py
+++ b/tensorrt_llm/_torch/speculative/model_drafter.py
@@ -47,14 +47,6 @@ def __init__(
         # Sampling
         self.sampler = sampler
 
-    def _should_process_request(self, request: LlmRequest) -> bool:
-        """Check if request should be processed for drafting."""
-        return request.py_draft_pages_allocated > 0  # type: ignore
-
-    def _exceeds_max_sequence_length(self, request: LlmRequest) -> bool:
-        """Check if the request exceeds maximum sequence length for drafting."""
-        return request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len
-
     def _create_draft_request(self, request_id: int, max_new_tokens: int,
                               input_tokens: Optional[List],
                               sampling_config: SamplingConfig,
@@ -81,10 +73,6 @@ def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:
 
         return num_draft_tokens, num_accepted_tokens
 
-    def _get_draft_model_input(self, request: LlmRequest) -> Any:
-        """Get input tokens for draft model."""
-        return self.spec_config.get_draft_model_prompt(request.get_tokens()[0])
-
     def _create_context_request(self, request: LlmRequest,
                                 input_tokens: Any) -> LlmRequest:
         """Create a context request for first-time drafting."""
@@ -116,10 +104,6 @@ def _create_chunked_context_request(self, request: LlmRequest,
                                            request.sampling_config,
                                            request.return_perf_metrics)
         new_request.context_chunk_size = num_accepted_tokens + 1
-        new_request.context_current_position = len(
-            input_tokens) - num_accepted_tokens - 1
-        # Note: Original code has duplicate assignment (appears to be a bug, but keeping it)
-        new_request.context_chunk_size = num_accepted_tokens + 1
         new_request.context_current_position = len(
             input_tokens) - num_accepted_tokens - 1
         return new_request
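
For reference, a small worked example of the chunk arithmetic that remains after this hunk; the numbers are illustrative only and not taken from the change:

# Illustrative only: the retained chunk arithmetic with made-up numbers.
# Suppose the draft prompt currently holds 20 tokens and 3 draft tokens were accepted.
input_tokens = list(range(20))  # stand-in for real token IDs
num_accepted_tokens = 3

context_chunk_size = num_accepted_tokens + 1  # 4
context_current_position = len(input_tokens) - num_accepted_tokens - 1  # 16

# The chunked context request re-processes only the last 4 positions (indices 16-19),
# i.e. the accepted draft tokens plus one newly sampled token.
print(context_chunk_size, context_current_position)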
@@ -129,7 +113,8 @@ def _create_draft_request_for_request(
         """Create a draft request based on the original request state."""
         num_draft_tokens, num_accepted_tokens = self._initialize_draft_tokens(
             request)
-        input_tokens = self._get_draft_model_input(request)
+        input_tokens = self.spec_config.get_draft_model_prompt(
+            request.get_tokens()[0])
 
         # First time seeing this request - context request
         if request.max_beam_num_tokens - 1 == request.py_prompt_len:
@@ -184,11 +169,18 @@ def _prepare_draft_batch(
         draft_batch = ScheduledRequests()
 
         for request in scheduled_requests.generation_requests:
-            if not self._should_process_request(request):
+            if request.py_draft_pages_allocated == 0:
+                # No space for draft tokens
                 continue
 
-            # Stop drafting when we hit the max seqlen
-            if self._exceeds_max_sequence_length(request):
+            # Stop drafting when we hit the max seqlen. We still need dummy draft
+            # tokens attached to the requests to make sure everything works properly
+            # with CUDA graph. These dummy tokens are already added by
+            # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
+            # that we want to do spec decoding, so no need to do anything else here.
+            # This makes the perf for this case suboptimal, but that's OK - this is
+            # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
+            if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
                 continue
 
             draft_request = self._create_draft_request_for_request(request)
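
A minimal sketch of the two guards that this hunk inlines, pulled out into a standalone helper for readability; the helper name is illustrative and is not part of the commit:

def should_skip_drafting(request, max_seq_len: int) -> bool:
    """Illustrative helper, not part of the commit: the two inlined guards above."""
    # No draft KV pages were allocated for this request, so there is nowhere
    # to put draft tokens.
    if request.py_draft_pages_allocated == 0:
        return True
    # The request already sits at the draft engine's max sequence length; drafting
    # stops here, and the dummy draft tokens attached earlier keep CUDA graph
    # shapes consistent.
    if request.max_beam_num_tokens - 1 >= max_seq_len:
        return True
    return False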
Expand Down Expand Up @@ -255,17 +247,8 @@ def _update_request_states(self,

def _update_requests(self, sample_state: SampleState) -> None:
"""Update requests with sample state."""
try:
if self.sampler is not None:
self.sampler.update_requests(sample_state)
except Exception as e:
logger.error(f"Error updating requests: {str(e)}")

def _handle_errors(self, error_msg: str) -> None:
"""Handle errors during draft token generation."""
logger.error(f"Draft token generation error: {error_msg}")
# For now, just log the error. In a full implementation, this could
# clean up resources, notify other components, etc.
if self.sampler is not None:
self.sampler.update_requests(sample_state)

def _process_decoded_tokens(
self, draft_batch: ScheduledRequests,
@@ -277,7 +260,7 @@ def _process_decoded_tokens(
             target_model_req.py_draft_tokens.append(req.get_last_tokens(0))
             if req.state != LlmRequestState.GENERATION_COMPLETE and len(
                     target_model_req.py_draft_tokens
-            ) < target_model_req.py_draft_pages_allocated:  # type: ignore
+            ) < target_model_req.py_draft_pages_allocated:
                 new_requests.append(req)
             else:
                 self.draft_seq_slot_manager.free_resources(req)
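
With the try/except and the _handle_errors helper gone, exceptions raised by sampler.update_requests now propagate out of _update_requests. A hypothetical caller-side wrapper, illustrative only and not part of this commit, that restores the old log-on-error behavior might look like this:

import logging

logger = logging.getLogger(__name__)

def update_requests_with_logging(drafter, sample_state):
    """Hypothetical wrapper, not part of the commit: log draft-token errors before re-raising."""
    try:
        drafter._update_requests(sample_state)
    except Exception as e:
        logger.error(f"Draft token generation error: {e}")
        raise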