Merged
Changes from all commits
Commits
40 commits
09b012e
[Draft] DeepGEMM Blackwell integration
Barry-Delaney Jul 13, 2025
ec400ab
Clean up fused_moe_deepgemm.py
Barry-Delaney Jul 13, 2025
d9a85ac
Moving permute space allocation to GPU
Barry-Delaney Jul 15, 2025
7c4045c
optimize padding in deepgemm moe.
lfr-0531 Jul 16, 2025
20b2592
add torch compile to per_token_cast_to_fp8_e8m0 and rm the two sync.
lfr-0531 Jul 16, 2025
c74a31a
Improve bmm.
yuxianq Jul 16, 2025
d3e1797
Online resmooth for fp8 checkpoint on Blackwell. (#2)
yuxianq Jul 16, 2025
d83cc25
Fix OOM issue for fp8 resmooth. (#4)
yuxianq Jul 17, 2025
e1e96fd
Enable masked grouped GEMM (#5)
Barry-Delaney Jul 18, 2025
09b0465
Pin DeepGEMM's version to commit cc416ee. (#6)
yuxianq Jul 18, 2025
35b4e23
Improve resmooth. (#7)
yuxianq Jul 18, 2025
dce291f
Add compile for quantization kernels (#8)
Barry-Delaney Jul 18, 2025
b3ab47d
Move SF transform to TRTLLM (#11)
Barry-Delaney Jul 21, 2025
65d05d6
Use local barrier to avoid multi-node hang issue. (#12)
yuxianq Jul 21, 2025
d65bdac
optimize the masked index copy and index gather (#13)
lfr-0531 Jul 21, 2025
0af69ac
Fix adp for deepgemm moe backend (#10)
zongfeijing Jul 21, 2025
6f431f6
Use DeepGEMM main branch instead.
yuxianq Jul 21, 2025
481fd50
Revert "Use DeepGEMM main branch instead."
Barry-Delaney Jul 21, 2025
ab7175f
Use DeepGEMM main branch and disable ue8m0 cast. (#16)
yuxianq Jul 21, 2025
97a21fd
fuse masked index_copy and grouped fp8 quantization.
lfr-0531 Jul 21, 2025
f668fa7
fix quantization accuracy issue.
lfr-0531 Jul 21, 2025
c6b8985
Fuse swiglu and quant 2 (#18)
Barry-Delaney Jul 21, 2025
11053b7
Opt gather kernel (#19)
zongfeijing Jul 24, 2025
0173836
optimize the perf of masked_index_copy_group_quant_fp8.
lfr-0531 Jul 23, 2025
bd94e37
fix duplicate load.
lfr-0531 Jul 23, 2025
f1d3115
fuse scaling factor transform to _masked_index_copy_group_quant_fp8.
lfr-0531 Jul 24, 2025
acd4381
fix.
lfr-0531 Jul 24, 2025
2d5beab
add another for loop on the group dim.
lfr-0531 Jul 24, 2025
5653eea
Remove SFB transform from forward process (#23)
Barry-Delaney Jul 25, 2025
49dcb98
change deepgemm to a new commit with torch dependency. (#24)
lfr-0531 Jul 25, 2025
9997006
fix format and rebase bug.
lfr-0531 Jul 25, 2025
d8ae02c
fix dummy requests when estimate kv cache with attention DP enabled t…
lfr-0531 Jul 28, 2025
fb3e467
Fuse quantize and transform e8m0 scales (#26)
Barry-Delaney Jul 28, 2025
9107cfa
Revert "Fuse quantize and transform e8m0 scales (#26)" (#27)
Barry-Delaney Jul 28, 2025
59b3957
Fix CI install error for DeepGEMM. (#28)
yuxianq Jul 28, 2025
3c413be
Reapply "Fuse quantize and transform e8m0 scales (#26)" (#27) (#29)
Barry-Delaney Jul 28, 2025
2e9fcbe
Fix UT for DeepGEMM
zongfeijing Jul 30, 2025
10bfbb5
Fix sanity check for deepgemm
zongfeijing Jul 31, 2025
a7e54c9
Merge branch 'main' into user/zongfeij/ci-clean
zongfeijing Jul 31, 2025
dfd021c
Merge branch 'main' into user/zongfeij/ci-clean
zongfeijing Aug 1, 2025
5 changes: 4 additions & 1 deletion examples/llm-api/quickstart_advanced.py
@@ -50,7 +50,10 @@ def add_llm_args(parser):
parser.add_argument('--moe_backend',
type=str,
default='CUTLASS',
choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
choices=[
'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
'DEEPGEMM', 'CUTEDSL'
])
parser.add_argument('--enable_attention_dp',
default=False,
action='store_true')
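For reference, a minimal self-contained sketch of how the new choice is exercised; it repeats only the two arguments from this hunk, and the parse_args call with an explicit argument list is illustrative rather than part of quickstart_advanced.py:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe_backend',
                    type=str,
                    default='CUTLASS',
                    choices=[
                        'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
                        'DEEPGEMM', 'CUTEDSL'
                    ])
parser.add_argument('--enable_attention_dp',
                    default=False,
                    action='store_true')

# Selecting the new backend, e.g. `--moe_backend DEEPGEMM` on the CLI:
args = parser.parse_args(['--moe_backend', 'DEEPGEMM'])
print(args.moe_backend)  # DEEPGEMM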
1 change: 1 addition & 0 deletions requirements.txt
@@ -61,4 +61,5 @@ etcd3
blake3
llguidance==0.7.29
soundfile
deep_gemm @ git+https://github.com/zongfeijing/DeepGEMM.git@a9d538ef4dff0326fe521c6ca0bfde115703b56a
triton==3.3.1
@@ -12,7 +12,8 @@
BaseWeightLoader
from tensorrt_llm._torch.models.modeling_utils import (
register_checkpoint_weight_loader, run_concurrently)
from tensorrt_llm._utils import local_mpi_rank, local_mpi_size
from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank,
local_mpi_size)
from tensorrt_llm.logger import logger


@@ -38,6 +39,8 @@ def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files."
)
self.prefetch_files(weight_files)
# Ensure that all local ranks have finished prefetching before loading weights
local_mpi_barrier()

return self._load_weights_in_parallel(
weight_files, self._load_safetensors_file,
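The added local_mpi_barrier() turns prefetching into a synchronization point: no local rank may start deserializing weights until every local rank has finished warming the files it was assigned. A minimal, self-contained illustration of that pattern follows; multiprocessing stands in for MPI, and the file names, rank count, and prefetch body are invented:

import multiprocessing as mp

def load_checkpoint(rank, world_size, barrier):
    files = [f"model-{i:05d}-of-00004.safetensors" for i in range(4)]
    # Each rank prefetches a disjoint slice of the checkpoint files
    # (page-cache warmup), analogous to prefetch_files() above.
    print(f"rank {rank} prefetched {files[rank::world_size]}")
    # Analogous to local_mpi_barrier(): without it, a fast rank could start
    # loading a file that a slower rank is still prefetching.
    barrier.wait()
    print(f"rank {rank} now loads weights in parallel")

if __name__ == "__main__":
    world_size = 2
    barrier = mp.Barrier(world_size)
    procs = [mp.Process(target=load_checkpoint, args=(r, world_size, barrier))
             for r in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()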
59 changes: 52 additions & 7 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -38,12 +38,15 @@
from tqdm import tqdm
from transformers import PretrainedConfig

from tensorrt_llm import logger
from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.llmapi.utils import enable_llm_debug
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.utils.fp8_utils import (
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -1244,7 +1247,7 @@ def load_kv_b_proj_and_k_b_proj_trans(module_name: str,
weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size
local_num_heads = num_heads // weight_divisor

k_nope_weight_trans = k_nope_weight.transpose(2, 1)
k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous()

kv_b_proj = torch.concat([
k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim,
@@ -1256,11 +1259,6 @@

return kv_b_proj, k_nope_weight_trans

def check_weight_dtype(module_name: str, dtype):
weight_name = "weight"
w_dtype = weights[f"{module_name}.{weight_name}"].dtype
return w_dtype == dtype

def load_kv_b_proj_and_k_b_proj_trans_dequant(
module_name: str) -> torch.Tensor:
weight_name = "weight"
@@ -1290,7 +1288,7 @@ def load_kv_b_proj_and_k_b_proj_trans_dequant(
weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size
local_num_heads = num_heads // weight_divisor

k_nope_weight_trans = k_nope_weight.transpose(2, 1)
k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous()

kv_b_proj = torch.concat([
k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim,
@@ -1333,6 +1331,21 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
params_map = {'gate_up_proj': ['gate_proj', 'up_proj']}
all_named_modules = dict(self.named_modules())

if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
) and get_sm_version() == 100:
for name in list(weights.keys()):
# Use ".experts." to exclude shared_experts.
if name.endswith(
"weight_scale_inv") and ".experts." not in name:
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
weights[weight_name] = weights[weight_name].cpu()
weights[name] = weights[name].cpu()

for name, module in tqdm(all_named_modules.items(),
desc="Loading weights"):
if len(module._parameters) > 0:
@@ -1384,6 +1397,26 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
attn_module.v_b_proj_scale = nn.Parameter(
v_b_proj_scale, requires_grad=False)

if attn_module.k_b_proj_trans_dequant is not None:
attn_module.k_b_proj_trans_dequant.data.copy_(
weight_dequant(
k_b_proj_trans.view(
-1, k_b_proj_trans.shape[-1]).cuda(),
k_b_proj_trans_scale.view(
-1,
k_b_proj_trans_scale.shape[-1]).cuda(),
).view(
*attn_module.k_b_proj_trans_dequant.shape).
to(attn_module.k_b_proj_trans_dequant.dtype))
if attn_module.v_b_proj_dequant is not None:
attn_module.v_b_proj_dequant.data.copy_(
weight_dequant(
v_b_proj.view(-1,
v_b_proj.shape[-1]).cuda(),
v_b_proj_scale.view(
-1, v_b_proj_scale.shape[-1]).cuda(),
).view(*attn_module.v_b_proj_dequant.shape).to(
attn_module.v_b_proj_dequant.dtype))
elif names[-1] == "kv_a_proj_with_mqa":
fused_a = weights[
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
@@ -1431,6 +1464,18 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
for n, p in module.named_parameters():
p.data.copy_(module_weights[n][:])

if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
) and get_sm_version() == 100 and hasattr(
module, "weight_scale"):
transformed_scale = transform_sf_into_required_layout(
module.weight_scale,
mn=module.weight.shape[0],
k=module.weight.shape[1],
recipe=(1, 128, 128),
is_sfa=False)
module.weight_scale = nn.Parameter(transformed_scale,
requires_grad=False)

for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
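The resmoothing pass above rewrites every non-expert FP8 block-scale weight for SM100. As a rough mental model only: the assumption is that resmooth_to_fp8_e8m0 requantizes 128x128-block-scaled FP8 weights so each block scale becomes a power of two and is therefore representable as a UE8M0 scale factor; the real helper in tensorrt_llm.quantization.utils.fp8_utils may differ in layout and edge cases.

import torch

def resmooth_to_e8m0_sketch(weight_fp8, scale, block=128):
    # weight_fp8: [N, K] torch.float8_e4m3fn; scale: [N/block, K/block] float32.
    w = weight_fp8.float()
    n, k = w.shape
    # Dequantize block-wise with the original scales.
    s = scale.repeat_interleave(block, 0)[:n].repeat_interleave(block, 1)[:, :k]
    w_deq = w * s
    # Round each block scale up to the next power of two so requantized values
    # cannot grow in magnitude, then requantize.
    new_scale = torch.exp2(torch.ceil(torch.log2(scale)))
    ns = new_scale.repeat_interleave(block, 0)[:n].repeat_interleave(block, 1)[:, :k]
    return (w_deq / ns).to(torch.float8_e4m3fn), new_scale

w = torch.randn(256, 256).to(torch.float8_e4m3fn)
s = torch.rand(2, 2) + 0.5
new_w, new_s = resmooth_to_e8m0_sketch(w, s)
print(new_s)  # every entry is an exact power of two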
90 changes: 61 additions & 29 deletions tensorrt_llm/_torch/modules/attention.py
@@ -357,6 +357,7 @@ def fp8_block_scaling_bmm_out(
mat2_fp8: torch.Tensor,
mat2_scale: torch.Tensor,
out: torch.Tensor,
mat2_dequant: Optional[torch.Tensor] = None,
) -> torch.Tensor:
sm_version = get_sm_version()
if sm_version == 90 or sm_version == 89:
@@ -365,30 +366,33 @@
torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
mat1_scale, mat2_scale, out)
elif sm_version == 100:
low_latency = True
use_deep_seek_fp8 = True
tile_size = 8
epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
m_size = mat1.shape[0]
if m_size % tile_size != 0:
tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size
mat1 = torch.nn.functional.pad(
mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0)

mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
mat1)
output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
mat1_fp8,
mat2_fp8,
tile_size=tile_size,
epilogue_tile_m=epilogue_tile_m,
use_deep_seek_fp8=use_deep_seek_fp8,
low_latency=low_latency,
dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1),
dq_sfs_b=mat2_scale,
out_dtype=out.dtype,
)
out.copy_(output[:, :m_size])
output = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2))
out.copy_(output)

# low_latency = True
# use_deep_seek_fp8 = True
# tile_size = 8
# epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
# m_size = mat1.shape[0]
# if m_size % tile_size != 0:
# tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size
# mat1 = torch.nn.functional.pad(
# mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0)

# mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
# mat1)
# output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
# mat1_fp8,
# mat2_fp8,
# tile_size=tile_size,
# epilogue_tile_m=epilogue_tile_m,
# use_deep_seek_fp8=use_deep_seek_fp8,
# low_latency=low_latency,
# dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1),
# dq_sfs_b=mat2_scale,
# out_dtype=out.dtype,
# )
# out.copy_(output[:, :m_size])
else:
raise NotImplementedError(f"SM{sm_version} is not supported")

@@ -676,6 +680,8 @@ def create_weights(self):
requires_grad=False,
)

self.k_b_proj_trans_dequant = None
self.v_b_proj_dequant = None
if has_fp8_block_scales:
self.k_b_proj_trans_scale = nn.Parameter(
torch.empty(
@@ -701,6 +707,23 @@
),
requires_grad=False,
)
if get_sm_version() == 100:
assert self.dtype == torch.bfloat16
self.k_b_proj_trans_dequant = nn.Parameter(
torch.empty(
(self.num_heads, self.kv_lora_rank,
self.qk_nope_head_dim),
dtype=self.dtype,
),
requires_grad=False,
)
self.v_b_proj_dequant = nn.Parameter(
torch.empty(
(self.num_heads, self.v_head_dim, self.kv_lora_rank),
dtype=self.dtype,
),
requires_grad=False,
)
Comment on lines +710 to +726
🛠️ Refactor suggestion

Use model dtype instead of hardcoded bfloat16.

The assertion at line 711 enforces bfloat16 dtype, which may be too restrictive. The dequantized tensors should use the model's configured dtype.

Apply this diff to use the model's dtype:

             if get_sm_version() == 100:
-                assert self.dtype == torch.bfloat16
                 self.k_b_proj_trans_dequant = nn.Parameter(
                     torch.empty(
                         (self.num_heads, self.kv_lora_rank,
                          self.qk_nope_head_dim),
                         dtype=self.dtype,
                     ),
                     requires_grad=False,
                 )
                 self.v_b_proj_dequant = nn.Parameter(
                     torch.empty(
                         (self.num_heads, self.v_head_dim, self.kv_lora_rank),
                         dtype=self.dtype,
                     ),
                     requires_grad=False,
                 )

else:
self.k_b_proj_trans_scale = None
self.v_b_proj_scale = None
@@ -1197,8 +1220,13 @@ def forward_generation(
# [num_heads, num_tokens, self.kv_lora_rank]
q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)

fp8_block_scaling_bmm_out(q_nope, self.k_b_proj_trans,
self.k_b_proj_trans_scale, q_nope_out)
fp8_block_scaling_bmm_out(
q_nope,
self.k_b_proj_trans,
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
)
else:
raise NotImplementedError(
f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")
@@ -1247,9 +1275,13 @@ def forward_generation(
self.v_b_proj.transpose(1, 2),
attn_output.transpose(0, 1))
elif self.v_b_proj.dtype == torch.float8_e4m3fn:
fp8_block_scaling_bmm_out(attn_out_latent, self.v_b_proj,
self.v_b_proj_scale,
attn_output.transpose(0, 1))
fp8_block_scaling_bmm_out(
attn_out_latent,
self.v_b_proj,
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
)
else:
raise NotImplementedError(
f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")
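On SM100 this PR currently falls back to a plain bf16 bmm against weights that were dequantized once at load time (k_b_proj_trans_dequant / v_b_proj_dequant), with the trtllm-gen FP8 batched GEMM path left commented out. Below is a self-contained sketch of that fallback with toy shapes; the 128x128 block dequantization mirrors what weight_dequant is assumed to do and is not the library kernel:

import torch

def dequant_block_fp8(w_fp8, scale, block=128):
    # Expand per-128x128-block scales and dequantize to bf16 (load-time step).
    n, k = w_fp8.shape
    s = scale.repeat_interleave(block, 0)[:n].repeat_interleave(block, 1)[:, :k]
    return (w_fp8.float() * s).to(torch.bfloat16)

num_heads, kv_lora_rank, qk_nope_head_dim, num_tokens = 4, 256, 128, 3
w_fp8 = torch.randn(num_heads * kv_lora_rank, qk_nope_head_dim).to(torch.float8_e4m3fn)
w_scale = torch.rand(num_heads * kv_lora_rank // 128, qk_nope_head_dim // 128) + 0.5
mat2_dequant = dequant_block_fp8(w_fp8, w_scale).view(num_heads, kv_lora_rank, qk_nope_head_dim)

# Per-step work (mirrors the new SM100 branch in fp8_block_scaling_bmm_out):
mat1 = torch.randn(num_tokens, num_heads, qk_nope_head_dim, dtype=torch.bfloat16)
out = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2))
print(out.shape)  # torch.Size([4, 3, 256]) -> [num_heads, num_tokens, kv_lora_rank]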
17 changes: 17 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/create_moe.py
@@ -8,6 +8,7 @@
from ...model_config import ModelConfig
from .fused_moe_cute_dsl import CuteDslFusedMoE
from .fused_moe_cutlass import CutlassFusedMoE
from .fused_moe_deepgemm import DeepGemmFusedMoE
from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
from .fused_moe_vanilla import VanillaMoE
from .fused_moe_wide_ep import WideEPMoE
@@ -31,6 +32,8 @@ def get_moe_cls(
return VanillaMoE
elif moe_backend.upper() == "CUTEDSL":
return CuteDslFusedMoE
elif moe_backend.upper() == "DEEPGEMM":
return DeepGemmFusedMoE
elif moe_backend.upper() == "TRTLLM":
if quant_config is not None and (
quant_config.quant_mode.has_fp8_block_scales()
@@ -139,5 +142,19 @@ def create_moe(
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
)
elif moe_cls == DeepGemmFusedMoE:
return moe_cls(
routing_method=routing_method,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
)
else:
raise ValueError(f"Unsupported moe backend: {moe_cls}")
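The dispatch in get_moe_cls()/create_moe() grows by one elif per backend; the same mapping can be pictured as a string-keyed registry. A purely illustrative sketch of that design choice follows; the class names come from this PR, but the dict-based registry is not TRT-LLM's actual implementation:

class CutlassFusedMoE: ...
class DeepGemmFusedMoE: ...
class CuteDslFusedMoE: ...

_MOE_BACKENDS = {
    "CUTLASS": CutlassFusedMoE,
    "DEEPGEMM": DeepGemmFusedMoE,
    "CUTEDSL": CuteDslFusedMoE,
}

def get_moe_cls_sketch(moe_backend: str):
    try:
        return _MOE_BACKENDS[moe_backend.upper()]
    except KeyError:
        raise ValueError(f"Unsupported moe backend: {moe_backend}") from None

print(get_moe_cls_sketch("deepgemm").__name__)  # DeepGemmFusedMoE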