Deepseek R1 FP8 Support on Blackwell #6486
Merged
Commits (40)
09b012e  [Draft] DeepGEMM Blackwell integration (Barry-Delaney)
ec400ab  Clean up fused_moe_deepgemm.py (Barry-Delaney)
d9a85ac  Move permute space allocation to GPU (Barry-Delaney)
7c4045c  Optimize padding in DeepGEMM MoE (lfr-0531)
20b2592  Add torch.compile to per_token_cast_to_fp8_e8m0 and remove the two syncs (lfr-0531)
c74a31a  Improve bmm (yuxianq)
d3e1797  Online resmooth for FP8 checkpoint on Blackwell (#2) (yuxianq)
d83cc25  Fix OOM issue for FP8 resmooth (#4) (yuxianq)
e1e96fd  Enable masked grouped GEMM (#5) (Barry-Delaney)
09b0465  Pin DeepGEMM's version to commit cc416ee (#6) (yuxianq)
35b4e23  Improve resmooth (#7) (yuxianq)
dce291f  Add compile for quantization kernels (#8) (Barry-Delaney)
b3ab47d  Move SF transform to TRTLLM (#11) (Barry-Delaney)
65d05d6  Use local barrier to avoid multi-node hang issue (#12) (yuxianq)
d65bdac  Optimize the masked index copy and index gather (#13) (lfr-0531)
0af69ac  Fix ADP for DeepGEMM MoE backend (#10) (zongfeijing)
6f431f6  Use DeepGEMM main branch instead (yuxianq)
481fd50  Revert "Use DeepGEMM main branch instead." (Barry-Delaney)
ab7175f  Use DeepGEMM main branch and disable ue8m0 cast (#16) (yuxianq)
97a21fd  Fuse masked index_copy and grouped FP8 quantization (lfr-0531)
f668fa7  Fix quantization accuracy issue (lfr-0531)
c6b8985  Fuse swiglu and quant 2 (#18) (Barry-Delaney)
11053b7  Optimize gather kernel (#19) (zongfeijing)
0173836  Optimize the performance of masked_index_copy_group_quant_fp8 (lfr-0531)
bd94e37  Fix duplicate load (lfr-0531)
f1d3115  Fuse scaling-factor transform into _masked_index_copy_group_quant_fp8 (lfr-0531)
acd4381  Fix (lfr-0531)
2d5beab  Add another for loop on the group dim (lfr-0531)
5653eea  Remove SFB transform from forward process (#23) (Barry-Delaney)
49dcb98  Change DeepGEMM to a new commit with torch dependency (#24) (lfr-0531)
9997006  Fix format and rebase bug (lfr-0531)
d8ae02c  Fix dummy requests when estimate kv cache with attention DP enabled t… (lfr-0531)
fb3e467  Fuse quantize and transform e8m0 scales (#26) (Barry-Delaney)
9107cfa  Revert "Fuse quantize and transform e8m0 scales (#26)" (#27) (Barry-Delaney)
59b3957  Fix CI install error for DeepGEMM (#28) (yuxianq)
3c413be  Reapply "Fuse quantize and transform e8m0 scales (#26)" (#27) (#29) (Barry-Delaney)
2e9fcbe  Fix UT for DeepGEMM (zongfeijing)
10bfbb5  Fix sanity check for DeepGEMM (zongfeijing)
a7e54c9  Merge branch 'main' into user/zongfeij/ci-clean (zongfeijing)
dfd021c  Merge branch 'main' into user/zongfeij/ci-clean (zongfeijing)
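Several of the commits above center on per-token FP8 quantization with E8M0 (power-of-two) scales for the Blackwell DeepGEMM path. The sketch below is a plain-PyTorch reference for what a `per_token_cast_to_fp8_e8m0`-style cast does; the 128-element group size and the exact signature are assumptions, not taken from this PR's diff, and the shipped kernel is fused and compiled rather than written like this.

```python
import torch

FP8_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def per_token_cast_to_fp8_e8m0(x: torch.Tensor, group_size: int = 128):
    """Quantize each group of `group_size` elements in every row to FP8
    (E4M3), with one power-of-two (E8M0-style) scale per group.
    Illustrative reference only."""
    m, n = x.shape
    assert n % group_size == 0
    groups = x.view(m, n // group_size, group_size)
    # Per-group absolute maximum; clamp so all-zero groups don't divide by 0.
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    # E8M0 scales carry no mantissa bits, so round the scale up to the next
    # power of two; rounding up keeps scaled values inside the FP8 range.
    scale = torch.exp2(torch.ceil(torch.log2(amax / FP8_MAX)))
    q = (groups / scale).to(torch.float8_e4m3fn)
    return q.view(m, n), scale.squeeze(-1)

# Per commit 20b2592, the real kernel is wrapped in torch.compile to fuse
# these elementwise steps and avoid the two device syncs, e.g.:
# per_token_cast_to_fp8_e8m0 = torch.compile(per_token_cast_to_fp8_e8m0)
```

Later commits fuse this cast with the masked index copy that scatters routed tokens into padded per-expert buffers (`masked_index_copy_group_quant_fp8`). An unfused reference under the same assumed semantics, reusing the cast above:

```python
def masked_index_copy_group_quant_fp8_ref(dst_q, dst_scale, src, dst_rows):
    """Quantize routed tokens and scatter them to their per-expert slots
    in one logical step; the fused kernel avoids materializing `q` and
    `scale` as intermediate tensors. Semantics inferred from the commit
    titles; the real kernel's signature and layout may differ."""
    q, scale = per_token_cast_to_fp8_e8m0(src)
    dst_q[dst_rows] = q
    dst_scale[dst_rows] = scale
```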
🛠️ Refactor suggestion
Use model dtype instead of hardcoded bfloat16.
The assertion at line 711 enforces bfloat16 dtype, which may be too restrictive. The dequantized tensors should use the model's configured dtype.
Apply this diff to use the model's dtype:
```diff
 if get_sm_version() == 100:
-    assert self.dtype == torch.bfloat16
     self.k_b_proj_trans_dequant = nn.Parameter(
         torch.empty(
             (self.num_heads, self.kv_lora_rank, self.qk_nope_head_dim),
             dtype=self.dtype,
         ),
         requires_grad=False,
     )
     self.v_b_proj_dequant = nn.Parameter(
         torch.empty(
             (self.num_heads, self.v_head_dim, self.kv_lora_rank),
             dtype=self.dtype,
         ),
         requires_grad=False,
     )
```
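For context, here is a toy module in the same pattern; the buffer names and shapes mirror the diff, but the class itself is hypothetical. It shows that allocating with the configured dtype lets non-bfloat16 models construct these buffers instead of tripping the old assert:

```python
import torch
from torch import nn

class DequantBuffers(nn.Module):
    """Hypothetical stand-in for the attention module in the diff: the
    dequantized projection buffers follow the model's configured dtype
    instead of asserting bfloat16."""

    def __init__(self, num_heads, kv_lora_rank, qk_nope_head_dim,
                 v_head_dim, dtype: torch.dtype):
        super().__init__()
        self.k_b_proj_trans_dequant = nn.Parameter(
            torch.empty((num_heads, kv_lora_rank, qk_nope_head_dim),
                        dtype=dtype),
            requires_grad=False,
        )
        self.v_b_proj_dequant = nn.Parameter(
            torch.empty((num_heads, v_head_dim, kv_lora_rank),
                        dtype=dtype),
            requires_grad=False,
        )

# A float16 configuration now works as well as bfloat16:
m = DequantBuffers(128, 512, 128, 128, dtype=torch.float16)
assert m.v_b_proj_dequant.dtype == torch.float16
```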