diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_next.py b/tensorrt_llm/_torch/models/modeling_qwen3_next.py
index e8b2021fb6f..8061be539e9 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3_next.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3_next.py
@@ -647,11 +647,10 @@ def fused_gdn_gating(
 
 class Qwen3NextGatedDeltaNet(nn.Module):
 
-    def __init__(
-        self,
-        model_config: ModelConfig[Qwen3NextConfig],
-        layer_idx: Optional[int] = None,
-    ):
+    def __init__(self,
+                 model_config: ModelConfig[Qwen3NextConfig],
+                 aux_stream: torch.cuda.Stream,
+                 layer_idx: Optional[int] = None):
         super().__init__()
         config = model_config.pretrained_config
         self.model_config = model_config
@@ -778,6 +777,12 @@ def __init__(
             force_dynamic_quantization=model_config.force_dynamic_quantization,
             use_cute_dsl_blockscaling_mm=False)
 
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.Attention]
+        }
+        self.aux_stream = aux_stream
+
     def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
         """
         Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
@@ -1032,8 +1037,19 @@ def forward(
             ssm_states[state_indices_p] = 0
             # conv_states[state_indices_p] = 0 # not necessary
 
-        projected_states_qkvz = self.in_proj_qkvz(hidden_states)
-        projected_states_ba = self.in_proj_ba(hidden_states)
+        def _compute_projected_states_qkvz():
+            return self.in_proj_qkvz(hidden_states)
+
+        def _compute_projected_states_ba():
+            return self.in_proj_ba(hidden_states)
+
+        projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel(
+            _compute_projected_states_qkvz,
+            _compute_projected_states_ba,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.Attention],
+            self.aux_stream,
+        )
 
         # Use fused kernel when possible to avoid elementwise ops
         if self.num_v_heads // self.num_k_heads in [1, 2,
@@ -1098,7 +1114,8 @@ def __init__(
         super().__init__()
         self.model_config = model_config
         config = model_config.pretrained_config
-        self.linear_attn = Qwen3NextGatedDeltaNet(model_config, layer_idx)
+        self.linear_attn = Qwen3NextGatedDeltaNet(model_config, aux_stream,
+                                                  layer_idx)
         self.mapping = model_config.mapping
         self.enable_attention_dp = self.mapping.enable_attention_dp