diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py
index 8087ef30dff..892c021b631 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -48,8 +48,9 @@ def __init__(
             rope=RopeParams.from_config(config),
         )

-        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712)
-        # TODO: Consider adding disable_deep_gemm support to QKNormRoPEAttention if accuracy still remains
+        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
+        # and https://nvbugspro.nvidia.com/bug/5505402)
+        disable_deep_gemm = True

         super().__init__(
             hidden_size=config.hidden_size,
@@ -63,6 +64,7 @@ def __init__(
             dtype=config.torch_dtype,
             dense_bias=config.attention_bias,
             config=model_config,
+            disable_deep_gemm=disable_deep_gemm,
         )


@@ -81,12 +83,17 @@ def __init__(
             layer_idx=layer_idx,
         )

+        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
+        # and https://nvbugspro.nvidia.com/bug/5505402)
+        disable_deep_gemm = True
+
         self.mlp = GatedMLP(
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
             dtype=config.torch_dtype,
             config=model_config,
+            disable_deep_gemm=disable_deep_gemm,
         )

         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index c488e2cd3f9..81cc480dd11 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -116,6 +116,7 @@ def __init__(
         config: Optional[ModelConfig] = None,
         q_scaling: float = 1.0,
         attention_chunk_size: Optional[int] = None,
+        disable_deep_gemm: bool = False,
     ):
         """
         Initialize the Attention module.
@@ -134,6 +135,7 @@ def __init__(
             config (Optional[ModelConfig]): The model configuration.
             q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
             attention_chunk_size (Optional[int]): See [Chunked Attention] below.
+            disable_deep_gemm (bool): Whether to disable the use of DeepGEMM in Linear layers (currently only matters on SM100 + FP8).
         """
         super().__init__()
         self.layer_idx = layer_idx
@@ -215,7 +217,9 @@ def __init__(
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            disable_deep_gemm=disable_deep_gemm,
+        )

         self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                 [self.hidden_size])
@@ -231,7 +235,9 @@ def __init__(
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             lora=self.o_lora,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            disable_deep_gemm=disable_deep_gemm,
+        )

         self.quant_config = config.get_quant_config()
         self.attn_backend = config.attn_backend
diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py
index cf381ea2c27..90af4440c36 100644
--- a/tensorrt_llm/_torch/modules/gated_mlp.py
+++ b/tensorrt_llm/_torch/modules/gated_mlp.py
@@ -18,18 +18,21 @@

 class GatedMLP(nn.Module):

-    def __init__(self,
-                 *,
-                 hidden_size: int,
-                 intermediate_size: int,
-                 bias: bool,
-                 activation: Callable[[torch.Tensor], torch.Tensor] = F.silu,
-                 dtype: Optional[torch.dtype] = None,
-                 config: Optional[ModelConfig] = None,
-                 overridden_tp_size: Optional[int] = None,
-                 reduce_output: bool = True,
-                 layer_idx: Optional[int] = None,
-                 use_cute_dsl_blockscaling_mm: bool = False):
+    def __init__(
+        self,
+        *,
+        hidden_size: int,
+        intermediate_size: int,
+        bias: bool,
+        activation: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+        dtype: Optional[torch.dtype] = None,
+        config: Optional[ModelConfig] = None,
+        overridden_tp_size: Optional[int] = None,
+        reduce_output: bool = True,
+        layer_idx: Optional[int] = None,
+        use_cute_dsl_blockscaling_mm: bool = False,
+        disable_deep_gemm: bool = False,
+    ):
         super().__init__()
         self.layer_idx = layer_idx

@@ -68,7 +71,9 @@ def __init__(self,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             allreduce_strategy=config.allreduce_strategy,
             force_dynamic_quantization=config.force_dynamic_quantization,
-            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)
+            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
+            disable_deep_gemm=disable_deep_gemm,
+        )

         self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
                                    [self.hidden_size])
@@ -86,7 +91,9 @@ def __init__(self,
             lora=self.down_lora,
             allreduce_strategy=config.allreduce_strategy,
             force_dynamic_quantization=config.force_dynamic_quantization,
-            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)
+            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
+            disable_deep_gemm=disable_deep_gemm,
+        )

         # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
         # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index c91e4532ab4..f797995dc0f 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -614,7 +614,7 @@ def apply(self, module: Linear, input: torch.Tensor,
         assert input.dtype == torch.bfloat16

         if get_sm_version() == 100:
-            if module.use_cute_dsl_blockscaling_mm:
+            if module.use_cute_dsl_blockscaling_mm or module.disable_deep_gemm:
                 # TODO (@lmin): replace with cute_dsl gemm
                 act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                     input)
@@ -1789,6 +1789,7 @@ def __init__(
         allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
         force_dynamic_quantization: bool = False,
         use_cute_dsl_blockscaling_mm: bool = False,
+        disable_deep_gemm: bool = False,
     ):
         from ..distributed import AllReduce

@@ -1806,6 +1807,7 @@ def __init__(
         self.gather_output = gather_output
         self.force_dynamic_quantization = force_dynamic_quantization
         self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm
+        self.disable_deep_gemm = disable_deep_gemm

         local_in_features = in_features
         local_out_features = out_features
diff --git a/tensorrt_llm/_torch/modules/qk_norm_attention.py b/tensorrt_llm/_torch/modules/qk_norm_attention.py
index 6c146e4bfcd..c64b24ca693 100644
--- a/tensorrt_llm/_torch/modules/qk_norm_attention.py
+++ b/tensorrt_llm/_torch/modules/qk_norm_attention.py
@@ -155,6 +155,7 @@ def __init__(
         dense_bias: Optional[bool] = None,
         config: ModelConfig,
         q_scaling: float = 1.0,
+        disable_deep_gemm: bool = False,
     ):
         self.pretrained_config = config.pretrained_config

@@ -178,6 +179,7 @@ def __init__(
             dense_bias=dense_bias,
             config=config,
             q_scaling=q_scaling,
+            disable_deep_gemm=disable_deep_gemm,
         )

         self.q_norm = RMSNorm(hidden_size=self.head_dim,
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index ad9a781c1ec..c73eb62eebe 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -2329,7 +2329,6 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "Qwen3/Qwen3-8B"

     @skip_pre_hopper
-    @pytest.mark.skip(reason="https://nvbugs/5505402")
     @pytest.mark.parametrize(
         "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler",
         [(1, 1, 1, False, True, True)],
diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml
index 031548955b4..0b78108bfe8 100644
--- a/tests/integration/test_lists/test-db/l0_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_b200.yml
@@ -37,7 +37,7 @@ l0_b200:
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton]
-  - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # nvbugs 5505402
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # Cover nvbugs 5461712 and 5505402
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
   - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
   - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
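The only consumer of the new flag is the FP8 block-scales `apply` path in `linear.py` (the `@@ -614` hunk above), where `disable_deep_gemm`, like `use_cute_dsl_blockscaling_mm`, steers SM100 off the DeepGEMM kernel; per the new docstring it currently has no effect elsewhere. The sketch below restates that dispatch in isolation as a minimal illustration: the `select_fp8_blockscale_backend` helper and the backend labels are hypothetical stand-ins, not the actual TensorRT-LLM implementation, and only the flag logic mirrors the diff.

```python
# Minimal sketch of the SM100 FP8 block-scales dispatch that the linear.py hunk
# modifies. Helper name and backend labels are hypothetical; only the flag
# logic mirrors the change above.
def select_fp8_blockscale_backend(sm_version: int,
                                  use_cute_dsl_blockscaling_mm: bool,
                                  disable_deep_gemm: bool) -> str:
    if sm_version == 100:
        # New in this change: disable_deep_gemm joins use_cute_dsl_blockscaling_mm
        # as a way to route Blackwell (SM100) off the DeepGEMM kernel.
        if use_cute_dsl_blockscaling_mm or disable_deep_gemm:
            return "blockscale_gemm"  # non-DeepGEMM fallback path
        return "deep_gemm"
    # Per the new docstring, the flag currently only matters on SM100 + FP8.
    return "blockscale_gemm"


# Qwen3 attention/MLP layers hard-code disable_deep_gemm=True, so on SM100 they
# take the non-DeepGEMM path; other models keep the default and still use DeepGEMM.
assert select_fp8_blockscale_backend(100, False, True) == "blockscale_gemm"
assert select_fp8_blockscale_backend(100, False, False) == "deep_gemm"
```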