[None][feat] Not CUDA graph captured eagle3 one-model draft loop #10251
base: main
tensorrt_llm/_torch/models/modeling_speculative.py

```diff
@@ -796,6 +796,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
         assert key in model_config.extra_attrs
         model_config.extra_attrs[key].update(value)
         self.layer_idx = -1
+        self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model

     def forward(
         self,
@@ -823,33 +824,15 @@ def forward(
         if attn_metadata.padded_num_tokens is not None:
             hidden_states = hidden_states[:attn_metadata.num_tokens]

+        is_capturing = torch.cuda.is_current_stream_capturing()
+
         if self.draft_model is not None:
-            # get logits
-            logits = self.logits_processor.forward(
-                hidden_states[spec_metadata.gather_ids],
-                self.lm_head,
-                attn_metadata,
-                True,
-            )
-            mtp_input_ids = input_ids
-            mtp_position_ids = position_ids
-            if attn_metadata.padded_num_tokens is not None:
-                if input_ids is not None:
-                    # Slice along the first dimension
-                    mtp_input_ids = input_ids[:attn_metadata.num_tokens]
-                if position_ids is not None:
-                    # Slice along the last dimension
-                    mtp_position_ids = position_ids[:, :attn_metadata.
-                                                    num_tokens]
-
-            # get accepted tokens and next draft tokens
-            return self.spec_worker(input_ids=mtp_input_ids,
-                                    position_ids=mtp_position_ids,
-                                    hidden_states=hidden_states,
-                                    logits=logits,
-                                    attn_metadata=attn_metadata,
-                                    spec_metadata=spec_metadata,
-                                    draft_model=self.draft_model)
+            if is_capturing and not self.enable_cuda_graph_for_draft_model:
+                return hidden_states
+            else:
+                return self.forward_draft(hidden_states, input_ids,
+                                          position_ids, attn_metadata,
+                                          spec_metadata)
```
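For readers unfamiliar with the gating trick the diff relies on, here is a minimal, self-contained sketch (not the PR's classes; `GatedModule` and `draft_step` are hypothetical stand-ins): during capture the module returns its hidden states early, and the skipped step runs eagerly after each replay.

```python
import torch

class GatedModule(torch.nn.Module):
    def __init__(self, draft_in_graph: bool):
        super().__init__()
        self.draft_in_graph = draft_in_graph
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.linear(x)
        # While the current stream is being captured, optionally stop here so
        # the draft step stays outside the CUDA graph.
        if torch.cuda.is_current_stream_capturing() and not self.draft_in_graph:
            return hidden
        return self.draft_step(hidden)

    def draft_step(self, hidden: torch.Tensor) -> torch.Tensor:
        return hidden * 2  # stand-in for the real draft-token loop

model = GatedModule(draft_in_graph=False).cuda()
static_x = torch.randn(4, 16, device="cuda")

# Warm up on a side stream before capture, as the PyTorch CUDA-graph docs advise.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    model(static_x)
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_hidden = model(static_x)  # capture stops at the early return

graph.replay()
out = model.draft_step(static_hidden)  # the skipped step runs eagerly
```

Because `static_hidden` is the tensor recorded at capture time, each `graph.replay()` refreshes it in place, so the eager `draft_step` always sees the latest hidden states.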
|
Comment on lines +827 to +835
Contributor
**Fix return type annotation.** The method's return type is annotated as …
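The visible part of the comment stops mid-sentence; presumably the annotation should cover every path the reworked `forward` can return through. A hedged sketch under that assumption (the class name appears in the repository, but the exact signature and the spec worker's return type are guesses, not the file's actual code):

```python
from typing import Any, Dict, Union

import torch

class SpecDecOneEngineForCausalLM(torch.nn.Module):
    def forward(self, input_ids, position_ids, attn_metadata, spec_metadata,
                **kwargs) -> Union[torch.Tensor, Dict[str, Any]]:
        # Three possible returns: plain logits (no draft model), raw
        # hidden_states (the new early return while a CUDA graph is being
        # captured), or the spec worker's accepted/draft-token output.
        ...
```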
```diff
         else:
             logits = self.logits_processor.forward(
                 hidden_states,
@@ -860,6 +843,34 @@ def forward(

         return logits

+    def forward_draft(self, hidden_states, input_ids, position_ids,
+                      attn_metadata, spec_metadata):
+        # get logits
+        logits = self.logits_processor.forward(
+            hidden_states[spec_metadata.gather_ids],
+            self.lm_head,
+            attn_metadata,
+            True,
+        )
+        mtp_input_ids = input_ids
+        mtp_position_ids = position_ids
+        if attn_metadata.padded_num_tokens is not None:
+            if input_ids is not None:
+                # Slice along the first dimension
+                mtp_input_ids = input_ids[:attn_metadata.num_tokens]
+            if position_ids is not None:
+                # Slice along the last dimension
+                mtp_position_ids = position_ids[:, :attn_metadata.num_tokens]
+
+        # get accepted tokens and next draft tokens
+        return self.spec_worker(input_ids=mtp_input_ids,
+                                position_ids=mtp_position_ids,
+                                hidden_states=hidden_states,
+                                logits=logits,
+                                attn_metadata=attn_metadata,
+                                spec_metadata=spec_metadata,
+                                draft_model=self.draft_model)
+
     def load_weights(self,
                      weights: Dict,
                      weight_mapper: Optional[BaseWeightMapper] = None,
```
tensorrt_llm/_torch/pyexecutor/model_engine.py

```diff
@@ -338,6 +338,7 @@ def __init__(
             ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
             self.max_total_draft_tokens = spec_config.max_total_draft_tokens
+            self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model
         else:
             self.without_logits = False
             self.max_draft_len = 0
```
Contributor

**Guard the draft-only replay path.** Right now the post-replay `forward_draft` call is gated only by `enable_cuda_graph_for_draft_model`, with no check that speculative decoding is active.
You likely only intend to run `forward_draft` for the Eagle3 one-model path, when the draft loop was deliberately kept out of the captured graph.

Proposed fix: initialize the flag safely and gate the `forward_draft` call:

```diff
@@
-        self.llm_args = llm_args
-        self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
-        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
+        self.llm_args = llm_args
+        self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
+        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
@@
-        self.spec_config = spec_config
-        self.is_spec_decode = spec_config is not None
+        self.spec_config = spec_config
+        self.is_spec_decode = spec_config is not None
+        # Default to True so non-speculative executors never take the draft-only path.
+        self.enable_cuda_graph_for_draft_model = (
+            spec_config.enable_cuda_graph_for_draft_model
+            if spec_config is not None else True
+        )
         self.sparse_attention_config = None if is_draft_model else llm_args.sparse_attention_config
         self.enable_spec_decode = self.is_spec_decode
         self.is_draft_model = is_draft_model
@@
-                else:
-                    with MoeLoadBalancerIterContext(moe_load_balancer):
-                        outputs = self.cuda_graph_runner.replay(key, inputs)
-                    if not self.enable_cuda_graph_for_draft_model:
-                        outputs = self.model.forward_draft(
-                            outputs, inputs['input_ids'],
-                            inputs['position_ids'],
-                            inputs['attn_metadata'],
-                            inputs['spec_metadata'])
+                else:
+                    with MoeLoadBalancerIterContext(moe_load_balancer):
+                        outputs = self.cuda_graph_runner.replay(key, inputs)
+                    # When speculative decoding is enabled but we opted out of
+                    # capturing the draft loop in the CUDA graph, run the
+                    # draft-only pass after replay.
+                    if (self.enable_spec_decode
+                            and not self.enable_cuda_graph_for_draft_model):
+                        outputs = self.model.forward_draft(
+                            outputs,
+                            inputs['input_ids'],
+                            inputs['position_ids'],
+                            inputs['attn_metadata'],
+                            inputs['spec_metadata'],
+                        )
```

This keeps non-speculative flows and non-Eagle3 models on the existing path while enabling the new "draft outside CUDA graph" behavior only where it is explicitly enabled.

Also applies to: 3269-3274
|
```diff
@@ -3265,6 +3266,12 @@ def capture_postprocess_fn(inputs: Dict[str, Any]):
             else:
                 with MoeLoadBalancerIterContext(moe_load_balancer):
                     outputs = self.cuda_graph_runner.replay(key, inputs)
+                if not self.enable_cuda_graph_for_draft_model:
+                    outputs = self.model.forward_draft(
+                        outputs, inputs['input_ids'],
+                        inputs['position_ids'],
+                        inputs['attn_metadata'],
+                        inputs['spec_metadata'])

         if self.forward_pass_callable is not None:
             self.forward_pass_callable()
```
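The hunk above is the code the earlier comment wants guarded. A self-contained sketch of the gated control flow, with the suggested `enable_spec_decode` check folded in (`runner` and `model` are hypothetical stand-ins, not TensorRT-LLM's actual classes):

```python
from typing import Any, Dict

def replay_with_optional_draft(runner: Any, model: Any, key: Any,
                               inputs: Dict[str, Any],
                               enable_spec_decode: bool,
                               draft_in_graph: bool) -> Any:
    """Replay the captured graph, then run the draft loop eagerly when it
    was deliberately left out of the capture (a sketch, not repository API)."""
    outputs = runner.replay(key, inputs)
    if enable_spec_decode and not draft_in_graph:
        # `outputs` holds the hidden states the truncated forward returned
        # during capture; forward_draft turns them into draft tokens.
        outputs = model.forward_draft(outputs,
                                      inputs['input_ids'],
                                      inputs['position_ids'],
                                      inputs['attn_metadata'],
                                      inputs['spec_metadata'])
    return outputs
```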
**Potential `AttributeError` when `spec_config` is `None`.** The `spec_config` variable can be `None` (assigned via `getattr(model_config, 'spec_config', None)` on line 741). Accessing `spec_config.enable_cuda_graph_for_draft_model` directly without a null check will raise an `AttributeError`.

Note: defaulting to `True` preserves backward-compatible behavior (CUDA graph capture enabled by default).
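The committable suggestion itself is collapsed in the page, but the note above and the other reviewer's diff imply a guard along these lines (`resolve_draft_cuda_graph_flag` is an illustrative helper, not repository code):

```python
from typing import Any, Optional

def resolve_draft_cuda_graph_flag(spec_config: Optional[Any]) -> bool:
    """Read the flag only when a speculative config exists; default to True
    so CUDA graph capture of the draft loop stays the default behavior."""
    if spec_config is None:
        return True
    return getattr(spec_config, "enable_cuda_graph_for_draft_model", True)
```

With this helper, `self.enable_cuda_graph_for_draft_model = resolve_draft_cuda_graph_flag(spec_config)` stays safe for non-speculative executors.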