
Commit 9cb9d21

[TRTLLM-7408][feat] Wrap MOE with custom op.
* Let all MoE backends go through the same interface.
* MoE is wrapped with a custom op to improve full-graph torch.compile compatibility.

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
1 parent 23f72c8 commit 9cb9d21
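
Note: the diffs below only rename each backend's forward to forward_impl and drop the now-redundant all_rank_max_num_tokens argument; the custom-op registration itself lives in the shared MoE interface, which is not part of this excerpt. A rough sketch of the pattern the commit message describes, with hypothetical names (trtllm_example::moe_forward, _MOE_REGISTRY) rather than the actual TensorRT-LLM code: wrapping the MoE call in a torch custom op lets torch.compile capture a single opaque node instead of tracing into backend-specific kernels.

import torch

# Hypothetical registry mapping layer_idx -> concrete MoE backend module
# (Cutlass / DeepGEMM / Triton), populated at model build time.
_MOE_REGISTRY: dict = {}


@torch.library.custom_op("trtllm_example::moe_forward", mutates_args=())
def moe_forward(x: torch.Tensor, router_logits: torch.Tensor,
                layer_idx: int) -> torch.Tensor:
    # Dispatch to the backend's eager implementation (the forward_impl
    # methods renamed in the diffs below).
    return _MOE_REGISTRY[layer_idx].forward_impl(x, router_logits)


@moe_forward.register_fake
def _(x: torch.Tensor, router_logits: torch.Tensor,
      layer_idx: int) -> torch.Tensor:
    # Shape/dtype-only implementation so the op can be traced without
    # running the kernels.
    return torch.empty_like(x)

A registry keyed by layer index would also explain why layer_idx is now forwarded to the base class in the Cutlass and Triton backends (see their __init__ hunks below): the op needs a stable key to recover the owning module inside the wrapper.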

File tree

16 files changed: +185 −101 lines changed


tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 1 addition & 16 deletions
@@ -121,12 +121,7 @@ class AttentionMetadata:
         default_factory=AttentionRuntimeFeatures)
 
     # The number of tokens in each rank.
-    _all_rank_num_tokens: Optional[List[int]] = field(init=False,
-                                                      default=None,
-                                                      repr=False)
-    all_rank_num_tokens: Optional[List[int]]
-    # The max number of tokens among all ranks.
-    all_rank_max_num_tokens: Optional[int] = None
+    all_rank_num_tokens: Optional[List[int]] = None
 
     # These fields are set when changing seq_lens and _num_contexts to avoid computation
     # during execution. If the calculation happens during execution, torch compile treats it
@@ -163,16 +158,6 @@ def on_update(self):
         elif self._seq_lens is not None:
             self._num_tokens = self._seq_lens.sum().item()
 
-    @property
-    def all_rank_num_tokens(self) -> Optional[List[int]]:
-        return self._all_rank_num_tokens
-
-    @all_rank_num_tokens.setter
-    def all_rank_num_tokens(self, value: Optional[List[int]]):
-        value = value if value is not AttentionMetadata.all_rank_num_tokens else None
-        self._all_rank_num_tokens = value
-        self.all_rank_max_num_tokens = max(value) if value is not None else None
-
     @property
     def seq_lens(self) -> Optional[torch.Tensor]:
         return self._seq_lens
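
With the property/setter indirection removed, all_rank_num_tokens becomes a plain dataclass field and the per-rank maximum is computed where it is actually needed (see the max(all_rank_num_tokens) changes in the fused MoE backends further down). A small illustrative sketch, using a stand-in class name rather than the real AttentionMetadata:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MetadataSketch:
    # Number of tokens on each attention-DP rank, or None when DP is off.
    all_rank_num_tokens: Optional[List[int]] = None


meta = MetadataSketch(all_rank_num_tokens=[7, 5, 9, 6])
# The old all_rank_max_num_tokens field is now derived on demand:
max_tokens = max(meta.all_rank_num_tokens)  # 9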

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 8 deletions
@@ -531,8 +531,7 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,
         return shared_tp_size, shared_output_scale
 
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
-                              all_rank_num_tokens, all_rank_max_num_tokens,
-                              do_finalize):
+                              all_rank_num_tokens, do_finalize):
         # max-throughput
         use_dp_padding = False
         if self.use_dp and self.mapping.tp_size > 1:
@@ -551,7 +550,6 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
             do_finalize=do_finalize,
             output_dtype=hidden_states.dtype,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=use_dp_padding,
         )
 
@@ -562,7 +560,6 @@ def forward(
         hidden_states: torch.Tensor,
         hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
         all_rank_num_tokens: Optional[list[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
         do_finalize: Optional[bool] = True,
     ) -> torch.Tensor:
@@ -581,7 +578,6 @@ def _compute_routed_output():
             routed_output = self.compute_routed_output(hidden_states,
                                                        hidden_states_fp4,
                                                        all_rank_num_tokens,
-                                                       all_rank_max_num_tokens,
                                                        do_finalize)
             return routed_output
 
@@ -804,7 +800,6 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
             hidden_states,
             hidden_states_fp4,
             all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
             final_all_reduce_params=AllReduceParams(
                 enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
                                       or self.mapping.tp_size == 1)),
@@ -992,7 +987,6 @@ def forward(
         embed_tokens: Embedding,
         attn_metadata: AttentionMetadata,
         all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -1051,7 +1045,6 @@ def norm_hidden():
         hidden_states = self.mlp(
             hidden_states,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
             final_all_reduce_params=AllReduceParams(
                 enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
                                       or self.mapping.tp_size == 1)),

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 4 additions & 7 deletions
@@ -258,7 +258,6 @@ def forward_attn_dp(
 
         # Get attention_dp parameters
         all_rank_num_tokens = attn_metadata.all_rank_num_tokens
-        all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
 
         if self.mapping.tp_size > 1 and all_rank_num_tokens is not None:
             if (isinstance(self.experts, (TRTLLMGenFusedMoE, TritonFusedMoE))):
@@ -276,12 +275,10 @@ def forward_attn_dp(
 
         # Let CutlassFusedMoE handle allgather internally
         # Pass the normalized tensor (t) as input to experts, not x
-        expert_output = self.experts(
-            x=t,
-            router_logits=g,
-            all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
-            use_dp_padding=False)
+        expert_output = self.experts(x=t,
+                                     router_logits=g,
+                                     all_rank_num_tokens=all_rank_num_tokens,
+                                     use_dp_padding=False)
 
         expert_output = expert_output.view(orig_shape)
         return expert_output, residual

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 6 additions & 12 deletions
@@ -309,32 +309,27 @@ def __init__(
         self.aux_stream = aux_stream
 
     def compute_routed_output(self, hidden_states, all_rank_num_tokens,
-                              all_rank_max_num_tokens,
                               cutlass_min_latency_mode):
         router_logits = self.router(hidden_states)
-        routed_output = self.experts(
-            hidden_states,
-            router_logits,
-            do_finalize=not cutlass_min_latency_mode,
-            all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
-            use_dp_padding=False)
+        routed_output = self.experts(hidden_states,
+                                     router_logits,
+                                     do_finalize=not cutlass_min_latency_mode,
+                                     all_rank_num_tokens=all_rank_num_tokens,
+                                     use_dp_padding=False)
         return routed_output
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         all_rank_num_tokens=None,
-        all_rank_max_num_tokens=None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
         cutlass_min_latency_mode: Optional[bool] = False,
     ) -> torch.Tensor:
         # Only enable multi-stream for cuda graph since switch stream has extra host overhead
         # This design is mainly for low latency use case. Need to improve for max throughput use case.
         fn0 = lambda: self.shared_expert(hidden_states)
         fn1 = lambda: self.compute_routed_output(
-            hidden_states, all_rank_num_tokens, all_rank_max_num_tokens,
-            cutlass_min_latency_mode)
+            hidden_states, all_rank_num_tokens, cutlass_min_latency_mode)
         shared_output, routed_output = maybe_execute_in_parallel(
             fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream)
         if cutlass_min_latency_mode:
@@ -536,7 +531,6 @@ def forward(
         hidden_states = self.feed_forward(
             hidden_states,
             all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
             final_all_reduce_params=AllReduceParams(
                 enable_allreduce=not self.disable_feed_forward_allreduce),
             cutlass_min_latency_mode=cutlass_min_latency_mode,

tensorrt_llm/_torch/models/modeling_mixtral.py

Lines changed: 0 additions & 2 deletions
@@ -62,13 +62,11 @@ def forward(
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         all_rank_num_tokens = attn_metadata.all_rank_num_tokens
-        all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
         router_logits = self.gate(hidden_states)
         final_hidden_states = self.experts(
             hidden_states,
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=False)
         return final_hidden_states

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 0 additions & 2 deletions
@@ -127,7 +127,6 @@ def forward(
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         use_dp_padding = False
         all_rank_num_tokens = attn_metadata.all_rank_num_tokens
-        all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
 
         if not do_finalize:
             assert not self.enable_attention_dp
@@ -144,7 +143,6 @@ def forward(
             hidden_states,
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=use_dp_padding,
             do_finalize=do_finalize,
         )

tensorrt_llm/_torch/models/modeling_qwen_moe.py

Lines changed: 0 additions & 2 deletions
@@ -84,13 +84,11 @@ def forward(
         hidden_states = hidden_states.view(-1, self.hidden_dim)
 
         all_rank_num_tokens = attn_metadata.all_rank_num_tokens
-        all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
         router_logits = self.gate(hidden_states)
         final_hidden_states = self.experts(
             hidden_states,
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=False)
 
         shared_expert_output = self.shared_expert(hidden_states)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 3 additions & 5 deletions
@@ -85,6 +85,7 @@ def __init__(
             swiglu_alpha=swiglu_alpha,
             swiglu_beta=swiglu_beta,
             swiglu_limit=swiglu_limit,
+            layer_idx=layer_idx,
         )
 
         # Store original hidden size before any potential padding
@@ -96,8 +97,6 @@ def __init__(
         self.intermediate_size_per_partition = (
             (self.intermediate_size_per_partition + 127) // 128) * 128
 
-        self.layer_idx = layer_idx
-
         self.num_slots = self.num_experts
         self.expert_size_per_partition = self.num_experts // self.ep_size
         self.initial_global_assignments = [
@@ -449,14 +448,13 @@ def split_chunk(self, split_token_num: int, split_num_chunks: int):
                                      split_num_chunks - val_mod)
         return split_chunk_size_list
 
-    def forward(
+    def forward_impl(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
         do_finalize: bool = True,  # used by other MoE backends
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         use_dp_padding: Optional[bool] = None,
     ) -> torch.Tensor:
         assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
@@ -472,7 +470,7 @@ def forward(
                           1) // self.moe_max_num_tokens
 
         if use_dp_padding:
-            all_rank_num_tokens_padded = [all_rank_max_num_tokens
+            all_rank_num_tokens_padded = [max(all_rank_num_tokens)
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
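
The padding branch used to read the precomputed all_rank_max_num_tokens; it now derives the same value with max(all_rank_num_tokens). A worked example with made-up per-rank token counts:

all_rank_num_tokens = [6, 3, 8, 5]  # tokens per attention-DP rank (example values)
use_dp_padding = True

if use_dp_padding:
    # Every rank is padded up to the largest rank's token count.
    all_rank_num_tokens_padded = [max(all_rank_num_tokens)
                                  ] * len(all_rank_num_tokens)  # [8, 8, 8, 8]
else:
    all_rank_num_tokens_padded = all_rank_num_tokens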

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 2 additions & 3 deletions
@@ -637,14 +637,13 @@ def forward_chunk(
 
         return final_hidden_states
 
-    def forward(
+    def forward_impl(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
         do_finalize: bool = True,  # used by other MoE backends
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         use_dp_padding: Optional[bool] = None,
     ) -> torch.Tensor:
         assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
@@ -663,7 +662,7 @@ def forward(
                           1) // self.moe_max_num_tokens
 
         if use_dp_padding:
-            all_rank_num_tokens_padded = [all_rank_max_num_tokens
+            all_rank_num_tokens_padded = [max(all_rank_num_tokens)
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens

tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py

Lines changed: 2 additions & 1 deletion
@@ -1287,6 +1287,7 @@ def __init__(
             reduce_results=reduce_results,
             model_config=model_config,
             weight_loading_mode=weight_loading_mode,
+            layer_idx=layer_idx,
         )
         if not IS_TRITON_KERNELS_AVAILABLE:
             raise ImportError("Triton kernels are not available.")
@@ -1359,7 +1360,7 @@ def create_weights(self):
 
         self._weights_created = True
 
-    def forward(
+    def forward_impl(
         self,
         x: torch.Tensor,
         router_logits: torch.Tensor,