address comments
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
meenchen committed Dec 1, 2025
commit 78eeaa172402b28f87375f147cba61aceca0a89d
5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/modules/attention.py
@@ -422,7 +422,10 @@ def _attn_impl(
         # and keeps attention output in BF16 for better precision when applying pre_quant_scale
         if self.has_quant_scale and not self.attn_output_gate and not has_awq_pre_quant_scale:
             out_scale = self.o_proj.inv_input_scale
-        if has_awq_pre_quant_scale:
+        if has_awq_pre_quant_scale and enable_attn_nvfp4_output:
+            logger.warning_once(
+                "Disable attn nvfp4 output because o_proj has pre_quant_scale for AWQ."
+            )
             enable_attn_nvfp4_output = False
         if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output and not self.attn_output_gate:
             out_scale_sf = self.o_proj.input_scale
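Note: the snippet below is a minimal standalone sketch of the gating change above, not code from the repository. The function name resolve_nvfp4_output and the use of warnings.warn in place of the module's logger.warning_once are hypothetical stand-ins; only the condition itself (warn and disable NVFP4 attention output when o_proj carries an AWQ pre_quant_scale, and only if NVFP4 output was actually requested) mirrors the diff.

# Hypothetical sketch of the new gating logic, assuming the flag semantics shown in the diff.
import warnings

def resolve_nvfp4_output(enable_attn_nvfp4_output: bool,
                         has_awq_pre_quant_scale: bool) -> bool:
    # Warn only when NVFP4 output was actually going to be used; otherwise stay silent.
    if has_awq_pre_quant_scale and enable_attn_nvfp4_output:
        warnings.warn(
            "Disable attn nvfp4 output because o_proj has pre_quant_scale for AWQ.")
        enable_attn_nvfp4_output = False
    return enable_attn_nvfp4_output

# AWQ pre_quant_scale present and NVFP4 requested: warn and disable.
assert resolve_nvfp4_output(True, True) is False
# No AWQ pre_quant_scale: the caller's choice is preserved.
assert resolve_nvfp4_output(True, False) is True
# NVFP4 already disabled: no warning is emitted, output stays disabled.
assert resolve_nvfp4_output(False, True) is False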
33 changes: 15 additions & 18 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -1918,27 +1918,24 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         w3_reference = all_w3_pre_quant_scales[0]
         w1_reference = all_w1_pre_quant_scales[0]
 
+        def check_consistency(scale, ref_scale, scale_name, expert_id):
+            if not torch.allclose(scale, ref_scale, rtol=1e-5, atol=1e-8):
+                max_diff = (scale - ref_scale).abs().max()
+                msg = (
+                    f"MoE pre_quant_scale: expert {expert_id} {scale_name} "
+                    f"differs from expert {module.initial_local_expert_ids[0]}! Max diff: {max_diff:.6e}. "
+                    f"All experts should have identical pre_quant_scale since they share the same input."
+                )
+                logger.error(msg)
+                raise ValueError(msg)
+
         for i, (w3_scale, w1_scale) in enumerate(
                 zip(all_w3_pre_quant_scales[1:],
                     all_w1_pre_quant_scales[1:]), 1):
-            if not torch.allclose(
-                    w3_scale, w3_reference, rtol=1e-5, atol=1e-8):
-                max_diff = (w3_scale - w3_reference).abs().max()
-                logger.warning(
-                    f"MoE pre_quant_scale: expert {module.initial_local_expert_ids[i]} w3.pre_quant_scale "
-                    f"differs from expert {module.initial_local_expert_ids[0]}! Max diff: {max_diff:.6e}. "
-                    f"All experts should have identical pre_quant_scale since they share the same input. "
-                    f"Using the first expert's value.")
-                break
-            if not torch.allclose(
-                    w1_scale, w1_reference, rtol=1e-5, atol=1e-8):
-                max_diff = (w1_scale - w1_reference).abs().max()
-                logger.warning(
-                    f"MoE pre_quant_scale: expert {module.initial_local_expert_ids[i]} w1.pre_quant_scale "
-                    f"differs from expert {module.initial_local_expert_ids[0]}! Max diff: {max_diff:.6e}. "
-                    f"All experts should have identical pre_quant_scale since they share the same input. "
-                    f"Using the first expert's value.")
-                break
+            check_consistency(w3_scale, w3_reference, "w3.pre_quant_scale",
+                              module.initial_local_expert_ids[i])
+            check_consistency(w1_scale, w1_reference, "w1.pre_quant_scale",
+                              module.initial_local_expert_ids[i])
 
         # Take the maximum pre_quant_scale between w3 and w1 from the first expert
         # (all experts should have the same values)
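Note: the following is a hypothetical standalone sketch, not the repository's code, of the pattern the refactor above introduces: a single helper compares each expert's pre_quant_scale against the first expert's and raises ValueError on mismatch instead of warning and breaking out of the loop. The function name validate_pre_quant_scales and the toy tensors are illustrative only.

# Hypothetical sketch of the fail-fast consistency check, assuming per-expert scale tensors.
from typing import Sequence
import torch

def validate_pre_quant_scales(scales: Sequence[torch.Tensor],
                              expert_ids: Sequence[int],
                              scale_name: str) -> None:
    # All experts share the same input, so their pre_quant_scale tensors must match.
    reference = scales[0]
    for i, scale in enumerate(scales[1:], 1):
        if not torch.allclose(scale, reference, rtol=1e-5, atol=1e-8):
            max_diff = (scale - reference).abs().max().item()
            raise ValueError(
                f"MoE pre_quant_scale: expert {expert_ids[i]} {scale_name} "
                f"differs from expert {expert_ids[0]}! Max diff: {max_diff:.6e}. "
                f"All experts should have identical pre_quant_scale since they "
                f"share the same input.")

# Identical scales pass silently; a perturbed expert raises at load time.
ids = [0, 1, 2]
same = [torch.ones(4) for _ in ids]
validate_pre_quant_scales(same, ids, "w3.pre_quant_scale")  # no error
bad = [torch.ones(4), torch.full((4,), 2.0), torch.ones(4)]
try:
    validate_pre_quant_scales(bad, ids, "w1.pre_quant_scale")
except ValueError as err:
    print(err)  # reports expert 1 w1.pre_quant_scale mismatch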