11 changes: 9 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -48,8 +48,9 @@ def __init__(
rope=RopeParams.from_config(config),
)

# Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712)
# TODO: Consider adding disable_deep_gemm support to QKNormRoPEAttention if accuracy still remains
# Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
# and https://nvbugspro.nvidia.com/bug/5505402)
disable_deep_gemm = True

super().__init__(
hidden_size=config.hidden_size,
@@ -63,6 +64,7 @@ def __init__(
dtype=config.torch_dtype,
dense_bias=config.attention_bias,
config=model_config,
disable_deep_gemm=disable_deep_gemm,
)


@@ -81,12 +83,17 @@ def __init__(
layer_idx=layer_idx,
)

# Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
# and https://nvbugspro.nvidia.com/bug/5505402)
disable_deep_gemm = True

self.mlp = GatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
dtype=config.torch_dtype,
config=model_config,
disable_deep_gemm=disable_deep_gemm,
)

self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
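The new keyword is pure plumbing: the Qwen3 layer decides once that deep_gemm should be avoided, and every Linear created underneath it receives the flag. A minimal standalone sketch of that pattern, using invented Toy* classes rather than the real TensorRT-LLM modules:

```python
# Illustrative-only sketch of the plumbing added in this PR: the layer sets
# disable_deep_gemm once and forwards it through its submodules so every
# Linear they build sees the flag. Class names are stand-ins, not real APIs.
class ToyLinear:
    def __init__(self, *, disable_deep_gemm: bool = False):
        self.disable_deep_gemm = disable_deep_gemm


class ToyGatedMLP:
    def __init__(self, *, disable_deep_gemm: bool = False):
        # Both projections inherit the caller's choice.
        self.gate_up_proj = ToyLinear(disable_deep_gemm=disable_deep_gemm)
        self.down_proj = ToyLinear(disable_deep_gemm=disable_deep_gemm)


class ToyQwen3DecoderLayer:
    def __init__(self):
        # Mirrors the hard-coded choice above: Qwen3 opts out of deep_gemm.
        disable_deep_gemm = True
        self.mlp = ToyGatedMLP(disable_deep_gemm=disable_deep_gemm)


layer = ToyQwen3DecoderLayer()
assert layer.mlp.down_proj.disable_deep_gemm  # the flag reaches the leaf Linear
```

Because the keyword defaults to False at every level, models other than Qwen3 keep their existing behavior.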
10 changes: 8 additions & 2 deletions tensorrt_llm/_torch/modules/attention.py
@@ -116,6 +116,7 @@ def __init__(
config: Optional[ModelConfig] = None,
q_scaling: float = 1.0,
attention_chunk_size: Optional[int] = None,
disable_deep_gemm: bool = False,
):
"""
Initialize the Attention module.
@@ -134,6 +135,7 @@ def __init__(
config (Optional[ModelConfig]): The model configuration.
q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
attention_chunk_size (Optional[int]): See [Chunked Attention] below.
disable_deep_gemm (bool): Whether to disable the use of DeepGEMM in Linear layers (currently only takes effect for FP8 block-scale GEMMs on SM100).
"""
super().__init__()
self.layer_idx = layer_idx
@@ -215,7 +217,9 @@ def __init__(
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
disable_deep_gemm=disable_deep_gemm,
)

self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])
@@ -231,7 +235,9 @@ def __init__(
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.o_lora,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
disable_deep_gemm=disable_deep_gemm,
)

self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
35 changes: 21 additions & 14 deletions tensorrt_llm/_torch/modules/gated_mlp.py
@@ -18,18 +18,21 @@

class GatedMLP(nn.Module):

def __init__(self,
*,
hidden_size: int,
intermediate_size: int,
bias: bool,
activation: Callable[[torch.Tensor], torch.Tensor] = F.silu,
dtype: Optional[torch.dtype] = None,
config: Optional[ModelConfig] = None,
overridden_tp_size: Optional[int] = None,
reduce_output: bool = True,
layer_idx: Optional[int] = None,
use_cute_dsl_blockscaling_mm: bool = False):
def __init__(
self,
*,
hidden_size: int,
intermediate_size: int,
bias: bool,
activation: Callable[[torch.Tensor], torch.Tensor] = F.silu,
dtype: Optional[torch.dtype] = None,
config: Optional[ModelConfig] = None,
overridden_tp_size: Optional[int] = None,
reduce_output: bool = True,
layer_idx: Optional[int] = None,
use_cute_dsl_blockscaling_mm: bool = False,
disable_deep_gemm: bool = False,
):

super().__init__()
self.layer_idx = layer_idx
@@ -68,7 +71,9 @@ def __init__(self,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
disable_deep_gemm=disable_deep_gemm,
)

self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
[self.hidden_size])
@@ -86,7 +91,9 @@ def __init__(self,
lora=self.down_lora,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
disable_deep_gemm=disable_deep_gemm,
)

# These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
# but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/modules/linear.py
@@ -614,7 +614,7 @@ def apply(self, module: Linear, input: torch.Tensor,
assert input.dtype == torch.bfloat16

if get_sm_version() == 100:
if module.use_cute_dsl_blockscaling_mm:
if module.use_cute_dsl_blockscaling_mm or module.disable_deep_gemm:
# TODO (@lmin): replace with cute_dsl gemm
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
input)
@@ -1789,6 +1789,7 @@ def __init__(
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
force_dynamic_quantization: bool = False,
use_cute_dsl_blockscaling_mm: bool = False,
disable_deep_gemm: bool = False,
):
from ..distributed import AllReduce

@@ -1806,6 +1807,7 @@ def __init__(
self.gather_output = gather_output
self.force_dynamic_quantization = force_dynamic_quantization
self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm
self.disable_deep_gemm = disable_deep_gemm

local_in_features = in_features
local_out_features = out_features
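On SM100, the FP8 block-scale path now skips DeepGEMM whenever either use_cute_dsl_blockscaling_mm or the new disable_deep_gemm attribute is set on the module. A standalone sketch of that selection rule follows; pick_fp8_backend, GemmFlags, and the backend labels are invented for illustration and are not the real Linear implementation:

```python
# Standalone sketch of the backend-selection rule added in linear.py; the
# function and backend labels are illustrative, not the actual implementation.
from dataclasses import dataclass


@dataclass
class GemmFlags:
    use_cute_dsl_blockscaling_mm: bool = False
    disable_deep_gemm: bool = False


def pick_fp8_backend(sm_version: int, flags: GemmFlags) -> str:
    """Choose which FP8 block-scale GEMM path this sketch would take."""
    if sm_version == 100:
        # Either flag routes the layer away from deep_gemm on SM100.
        if flags.use_cute_dsl_blockscaling_mm or flags.disable_deep_gemm:
            return "fp8_block_scaling_gemm"
        return "deep_gemm"
    # Non-SM100 paths are out of scope here; assume a generic block-scale GEMM.
    return "fp8_block_scaling_gemm"


assert pick_fp8_backend(100, GemmFlags()) == "deep_gemm"
assert pick_fp8_backend(100, GemmFlags(disable_deep_gemm=True)) == "fp8_block_scaling_gemm"
assert pick_fp8_backend(90, GemmFlags()) == "fp8_block_scaling_gemm"
```

Judging from the hunk, the disabled case reuses the same fp8_quantize_1x128 path that already served use_cute_dsl_blockscaling_mm, so no new kernel is introduced.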
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/modules/qk_norm_attention.py
@@ -155,6 +155,7 @@ def __init__(
dense_bias: Optional[bool] = None,
config: ModelConfig,
q_scaling: float = 1.0,
disable_deep_gemm: bool = False,
):
self.pretrained_config = config.pretrained_config

@@ -178,6 +179,7 @@ def __init__(
dense_bias=dense_bias,
config=config,
q_scaling=q_scaling,
disable_deep_gemm=disable_deep_gemm,
)

self.q_norm = RMSNorm(hidden_size=self.head_dim,
1 change: 0 additions & 1 deletion tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -2329,7 +2329,6 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
MODEL_NAME = "Qwen3/Qwen3-8B"

@skip_pre_hopper
@pytest.mark.skip(reason="https://nvbugs/5505402")
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler",
[(1, 1, 1, False, True, True)],
2 changes: 1 addition & 1 deletion tests/integration/test_lists/test-db/l0_b200.yml
@@ -37,7 +37,7 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # nvbugs 5505402
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # Covers nvbugs 5461712 and 5505402
- disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]