@@ -531,8 +531,7 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,
         return shared_tp_size, shared_output_scale
 
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
-                              all_rank_num_tokens, all_rank_max_num_tokens,
-                              do_finalize):
+                              all_rank_num_tokens, do_finalize):
         # max-throughput
         use_dp_padding = False
         if self.use_dp and self.mapping.tp_size > 1:
@@ -551,7 +550,6 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
             do_finalize=do_finalize,
             output_dtype=hidden_states.dtype,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=use_dp_padding,
         )
 
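The two hunks above are the callee side of the change: `compute_routed_output` drops `all_rank_max_num_tokens` from its signature and stops forwarding it to the experts call. Assuming, as the names suggest, that the removed value was simply `max(all_rank_num_tokens)`, the list that is still passed fully determines it, so a callee can recompute it locally. A minimal, self-contained sketch of that idea; `max_tokens_for_padding` is a hypothetical stand-in, not the actual fused-MoE internals:

```python
from typing import List, Optional

# Hypothetical stand-in for the fused-MoE callee: instead of receiving
# all_rank_max_num_tokens as an argument, it derives the value from
# all_rank_num_tokens whenever data-parallel padding is in effect.
def max_tokens_for_padding(all_rank_num_tokens: List[int],
                           use_dp_padding: bool) -> Optional[int]:
    if not use_dp_padding:
        return None
    # This max is the only information the removed parameter carried.
    return max(all_rank_num_tokens)

# Ranks hold 5, 8, and 3 tokens; the derived value matches what the
# removed all_rank_max_num_tokens argument would have been.
assert max_tokens_for_padding([5, 8, 3], use_dp_padding=True) == 8
assert max_tokens_for_padding([5, 8, 3], use_dp_padding=False) is None
```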
@@ -562,7 +560,6 @@ def forward(
         hidden_states: torch.Tensor,
         hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
         all_rank_num_tokens: Optional[list[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
         do_finalize: Optional[bool] = True,
     ) -> torch.Tensor:
@@ -581,7 +578,6 @@ def _compute_routed_output():
             routed_output = self.compute_routed_output(hidden_states,
                                                        hidden_states_fp4,
                                                        all_rank_num_tokens,
-                                                       all_rank_max_num_tokens,
                                                        do_finalize)
             return routed_output
 
@@ -804,7 +800,6 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
                 hidden_states,
                 hidden_states_fp4,
                 all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-                all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
                 final_all_reduce_params=AllReduceParams(
                     enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
                                           or self.mapping.tp_size == 1)),
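In the hunk above only the argument plumbing changes; the `AllReduceParams` context lines also show when the layer still issues its own all-reduce: only when post-MoE fusion is off and there is more than one tensor-parallel rank, presumably because the fused path performs the reduction itself. A small restatement of that guard, with hypothetical stub classes standing in for the real fusion config and parallel mapping:

```python
# Stubs for illustration only; the real objects are TensorRT-LLM's
# fusion config and parallel mapping.
class FusionConfigStub:
    def __init__(self, post_moe_fusion: bool):
        self.POST_MOE_FUSION = post_moe_fusion

class MappingStub:
    def __init__(self, tp_size: int):
        self.tp_size = tp_size

def allreduce_enabled(fusion_config, mapping) -> bool:
    # Mirrors: enable_allreduce=not (POST_MOE_FUSION or tp_size == 1)
    return not (fusion_config.POST_MOE_FUSION or mapping.tp_size == 1)

assert allreduce_enabled(FusionConfigStub(False), MappingStub(4)) is True
assert allreduce_enabled(FusionConfigStub(True), MappingStub(4)) is False
assert allreduce_enabled(FusionConfigStub(False), MappingStub(1)) is False
```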
@@ -992,7 +987,6 @@ def forward(
         embed_tokens: Embedding,
         attn_metadata: AttentionMetadata,
         all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -1051,7 +1045,6 @@ def norm_hidden():
             hidden_states = self.mlp(
                 hidden_states,
                 all_rank_num_tokens=all_rank_num_tokens,
-                all_rank_max_num_tokens=all_rank_max_num_tokens,
                 final_all_reduce_params=AllReduceParams(
                     enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
                                           or self.mapping.tp_size == 1)),
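The last two hunks are the caller side: the model forward and the decoder layer's `self.mlp(...)` call now thread only `all_rank_num_tokens` down to the MoE. For context on what the removed value was used for: under attention data parallelism, ranks process different token counts, and the `use_dp_padding` path in the first hunk pads every rank to a common maximum. A self-contained sketch of such a padding step, assuming the max is now derived locally; `pad_to_dp_max` is an invented helper name, not a TensorRT-LLM API:

```python
import torch

# Hypothetical illustration of the use_dp_padding path: each rank pads its
# local activations to the max token count across ranks, with that max
# derived from all_rank_num_tokens rather than passed in separately.
def pad_to_dp_max(local_hidden: torch.Tensor,
                  all_rank_num_tokens: list[int]) -> torch.Tensor:
    max_tokens = max(all_rank_num_tokens)
    pad_rows = max_tokens - local_hidden.shape[0]
    if pad_rows == 0:
        return local_hidden
    pad = local_hidden.new_zeros(pad_rows, local_hidden.shape[1])
    return torch.cat([local_hidden, pad], dim=0)

local = torch.randn(3, 16)                 # this rank holds 3 tokens
padded = pad_to_dp_max(local, [5, 8, 3])   # ranks hold 5, 8, 3 tokens
assert padded.shape == (8, 16)
```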