
Conversation

@lfr-0531 (Collaborator) commented May 30, 2025

Description

  • Enable the overlap scheduler between different draft forwards.
  • Add a new Eagle3ResourceManager to manage the hidden states and remove the extra model input (see the sketch after this list).
  • Move the Eagle3 FC layer into the model forward.
  • Move the host-to-device (h2d) copy to the end of _prepare_tp_inputs to hide the CPU time.
  • Disable CUDA graph for the first draft forward.
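
To make the Eagle3ResourceManager item concrete, here is a minimal sketch of what a hidden-state resource manager along these lines could look like. Only the class name comes from this PR; the buffer layout, slot bookkeeping, and method names below are illustrative assumptions, not the actual TensorRT-LLM implementation.

```python
import torch

class Eagle3ResourceManager:
    """Illustrative sketch: keeps per-request hidden states between the
    target forward and the draft forwards, so they no longer need to be
    threaded through as an extra model input."""

    def __init__(self, max_num_requests: int, hidden_size: int,
                 dtype: torch.dtype = torch.bfloat16, device: str = "cuda"):
        # Preallocated slots avoid per-iteration allocations, which would
        # add CPU overhead and interfere with CUDA graph capture.
        self.hidden_states = torch.empty(
            (max_num_requests, hidden_size), dtype=dtype, device=device)
        self.free_slots = list(range(max_num_requests))
        self.slot_of_request: dict[int, int] = {}

    def prepare_resources(self, request_id: int) -> int:
        # Assign a slot when a request is scheduled.
        slot = self.free_slots.pop()
        self.slot_of_request[request_id] = slot
        return slot

    def write(self, request_id: int, hidden: torch.Tensor) -> None:
        # Called from the target model's forward to stash the hidden
        # states that the Eagle3 draft forwards will consume.
        self.hidden_states[self.slot_of_request[request_id]].copy_(
            hidden, non_blocking=True)

    def read(self, request_id: int) -> torch.Tensor:
        return self.hidden_states[self.slot_of_request[request_id]]

    def free_resources(self, request_id: int) -> None:
        self.free_slots.append(self.slot_of_request.pop(request_id))
```

Preallocating the slots keeps the buffer addresses stable across iterations, which is what lets the draft forwards read the hidden states in place instead of receiving them as a model input.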

I'll collect more performance and accuracy data and update the results here.

@lfr-0531 lfr-0531 force-pushed the user/fanrongl/reduce_eagle3_cpu_overhead branch 4 times, most recently from 652be47 to 275a6c3 Compare June 6, 2025 10:13
@lfr-0531 lfr-0531 marked this pull request as ready for review June 6, 2025 10:15
@lfr-0531 lfr-0531 requested review from a team as code owners June 6, 2025 10:15
@lfr-0531 lfr-0531 requested review from hyukn, juney-nvidia and mikeiovine and removed request for hyukn and juney-nvidia June 6, 2025 10:15
@lfr-0531 lfr-0531 changed the title from "draft: enable overlap scheduler between draft forwards" to "[TRTLLM-4983] feat: enable overlap scheduler between draft forwards" Jun 6, 2025
@lfr-0531 (Collaborator Author) commented Jun 6, 2025

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #7897 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #7897 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #5706 completed with status: 'FAILURE'

@mikeiovine (Collaborator) left a comment

Shall we measure the performance gain? Since our initial goal is to hit parity with vLLM, I think using Llama 3.3 70B on Hopper makes sense. I got these numbers on 8xH200 today.

| Framework | Max Batch Size | OSL | Eagle (draft len = 3)? | Output tok/sec |
| --------- | -------------- | --- | ---------------------- | -------------- |
| TRTLLM    | 8              | 256 | No                     | 876            |
| TRTLLM    | 8              | 256 | Yes                    | 1154           |
| vLLM      | 8              | 256 | No                     | 715            |
| vLLM      | 8              | 256 | Yes                    | 1430           |

A note about the dataset: I used gsm8k from here. For the Llama 3.3 Eagle drafters we have from the paper authors, you have to apply the tokenizer's chat template, or the acceptance rate (AR) will drop significantly and performance will regress (see the sketch below). I will send you a preprocessed dataset that is compatible with trtllm-bench.
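
For reference, applying the chat template with a Hugging Face tokenizer looks roughly like this. The model ID and the helper function are assumptions for illustration, not the exact preprocessing used for the shared dataset.

```python
from transformers import AutoTokenizer

# Assumed model ID for illustration; the gated repo requires access approval.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.3-70B-Instruct")

def to_chat_prompt(question: str) -> str:
    # Wrap a raw gsm8k question in the model's chat template. Skipping this
    # step makes the Eagle drafter's acceptance rate drop sharply, because
    # the draft heads were trained on chat-formatted inputs.
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )
```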

@lfr-0531 lfr-0531 force-pushed the user/fanrongl/reduce_eagle3_cpu_overhead branch 3 times, most recently from affc695 to 095d2c2 Compare June 10, 2025 11:56
@lfr-0531 (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #8288 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #8288 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #6000 completed with status: 'FAILURE'

@lfr-0531 lfr-0531 force-pushed the user/fanrongl/reduce_eagle3_cpu_overhead branch from 095d2c2 to 52328f9 Compare June 10, 2025 15:42
@lfr-0531 (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #8319 [ run ] triggered by Bot

@lfr-0531 (Collaborator Author)

/bot kill

@lfr-0531 lfr-0531 force-pushed the user/fanrongl/reduce_eagle3_cpu_overhead branch 2 times, most recently from 775cd4c to 5267115 Compare June 11, 2025 02:38
@tensorrt-cicd (Collaborator)

PR_Github #8377 [ kill ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #8802 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #8802 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #6390 completed with status: 'FAILURE'

@mikeiovine (Collaborator)

@lfr-0531 That plan sounds good to me. It can still be useful for other use cases (@IzzyPutterman's draft/target speculative decoding should land soon, and it'll be useful there since the draft models are a bit bigger than EAGLE's).

@IzzyPutterman (Collaborator)

Do we already have logic to make sure that the extra overlap pass doesn't run (and infringe on the verify pass)? I think I have an internal MR that does this from a while ago.

@IzzyPutterman (Collaborator)

Something like this: #5211

@lfr-0531 lfr-0531 force-pushed the user/fanrongl/reduce_eagle3_cpu_overhead branch from 9ec2aa9 to be5b542 Compare June 15, 2025 00:35
@lfr-0531 (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #8907 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #8907 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #6492 completed with status: 'FAILURE'

@lfr-0531 (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #8925 [ run ] triggered by Bot

@lfr-0531 (Collaborator Author)

> Do we already have logic to make sure that the extra overlap pass doesn't run (and infringe on the verify pass)? I think I have an internal MR that does this from a while ago.

We didn't add that in this PR; this PR only overlaps the different forwards within the same iteration (roughly as sketched below). Once PR #5211 lands, the last iteration, including _prepare_draft_tokens, can be skipped.
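
As a rough illustration of what overlapping forwards within one iteration buys: while the GPU executes draft step i, the CPU already prepares the inputs for step i+1, and the h2d copy is issued last with non_blocking=True so it hides under the running kernels. Everything named below (scheduler, prepare_inputs, draft_model) is a hypothetical stand-in, not the executor's real API.

```python
import torch

def run_draft_loop(draft_model, scheduler, num_draft_steps: int):
    # Sketch of the overlap: GPU kernels for step i run while the CPU
    # builds the (pinned, CPU-side) input tensor for step i + 1.
    cpu_inputs = scheduler.prepare_inputs(step=0)
    for step in range(num_draft_steps):
        # Issue the h2d copy last, then enqueue the forward; both calls
        # return quickly, leaving the GPU busy with the draft kernels.
        gpu_inputs = cpu_inputs.to("cuda", non_blocking=True)
        draft_model.forward(gpu_inputs)
        if step + 1 < num_draft_steps:
            # Overlapped CPU work: sampled draft tokens stay on the
            # device, so no synchronization is needed here.
            cpu_inputs = scheduler.prepare_inputs(step=step + 1)
    torch.cuda.synchronize()
```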

@tensorrt-cicd (Collaborator)

PR_Github #8925 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #6509 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@lfr-0531 lfr-0531 merged commit 39bba63 into NVIDIA:main Jun 15, 2025
3 checks passed
@lfr-0531 lfr-0531 deleted the user/fanrongl/reduce_eagle3_cpu_overhead branch June 27, 2025 12:43
@jhaotingc (Collaborator)

Hi @lfr-0531,
So, is the overlap scheduler supported in the Eagle3 two-model path yet? (support_overlap_scheduler)
cc @mikeiovine

Thanks!
