@@ -463,15 +463,14 @@ def forward_chunk(
         if not use_postquant_alltoall:
             deep_ep_topk_idx = token_selected_slots
             deep_ep_topk_weights = token_final_scales
+            assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
             x, recv_expert_count, deep_ep_handle = \
-                self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
-            # x shape: [#local experts, EP size * deep_ep_max_num_tokens, hidden_size]
+                self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
+            # x shape: [#local experts, EP size * all_rank_max_num_tokens, hidden_size]
             # recv_expert_count shape: [#local experts]

             # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
             # TODO: remove the adapter by changing `torch.ops.trtllm.fused_moe` API
-            x = x[:, :self.mapping.moe_ep_size *
-                  all_rank_max_num_tokens]
             mask = torch.arange(
                 x.shape[1], dtype=torch.int32, device=x.device).expand(
                     x.shape[0],
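With the dispatch capacity now taken from `all_rank_max_num_tokens`, the buffer returned by `low_latency_dispatch` already has width `EP size * all_rank_max_num_tokens`, so the slice adapter above is removed; the new assert only guards that this per-iteration count still fits the preallocated DeepEP buffer. A minimal, self-contained sketch of the resulting shape convention and of the padding mask that the surrounding `torch.arange(...).expand(...)` code builds, with made-up sizes and hypothetical `recv_expert_count` values:

```python
import torch

# Hypothetical sizes for illustration only (not taken from this PR).
num_local_experts = 2
moe_ep_size = 4
all_rank_max_num_tokens = 8
hidden_size = 16

# After low_latency_dispatch, every local expert's slab is padded to the same
# capacity: [#local experts, EP size * all_rank_max_num_tokens, hidden_size].
x = torch.randn(num_local_experts, moe_ep_size * all_rank_max_num_tokens, hidden_size)

# recv_expert_count[i] = number of valid tokens actually received for local expert i.
recv_expert_count = torch.tensor([5, 12], dtype=torch.int32)

# Mirror of the mask construction in forward_chunk: slot j of expert i holds a
# real token only while j < recv_expert_count[i]; the remaining slots are padding.
mask = torch.arange(x.shape[1], dtype=torch.int32).expand(x.shape[0], x.shape[1])
valid = mask < recv_expert_count.unsqueeze(1)
print(valid.sum(dim=1))  # tensor([ 5, 12], dtype=torch.int64 on CPU)
```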
@@ -615,26 +614,14 @@ def forward_chunk(

             deep_ep_topk_idx = token_selected_slots
             deep_ep_topk_weights = token_final_scales
-            # Each LL combine/dispatch kernel call requires that the `dispatch_rdma_recv_count_buffer` be properly cleaned.
-            # However, the offset of this buffer within the entire RDMA buffer changes according to the hidden size.
-            # Therefore, if the hidden size for the next LL dispatch/combine call is different from the current kernel call, manual cleaning is necessary.
-            if packed_hidden_size != hidden_size:
-                self.deep_ep_buffer.clean_low_latency_buffer(
-                    self.deep_ep_max_num_tokens, packed_hidden_size,
-                    self.num_slots)
+
+            assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
             fp4_packed_tensor, recv_expert_count, deep_ep_handle = \
-                self.deep_ep_buffer.low_latency_dispatch(fp4_packed_tensor, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
-            if packed_hidden_size != hidden_size:
-                self.deep_ep_buffer.clean_low_latency_buffer(
-                    self.deep_ep_max_num_tokens, hidden_size,
-                    self.num_slots)
+                self.deep_ep_buffer.low_latency_dispatch(fp4_packed_tensor, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
             deep_ep_handle = list(deep_ep_handle)
             deep_ep_handle[3] = hidden_size
             deep_ep_handle = tuple(deep_ep_handle)

-            fp4_packed_tensor = fp4_packed_tensor[:, :self.mapping.
-                                                  moe_ep_size *
-                                                  all_rank_max_num_tokens]
             assert fp4_packed_tensor.ndim == 3 and fp4_packed_tensor.shape[
                 2] == packed_hidden_size
             x_sf = fp4_packed_tensor[:, :, x.shape[1]:x.shape[1] +
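In the post-quant path the FP4 payload and its scale factors travel through a single dispatch as one tensor packed along the hidden dimension, which is why the code above slices `x_sf` back out of `fp4_packed_tensor` by column range after the dispatch. A rough sketch of that pack-then-slice pattern, shown with uint8 bytes and invented widths (the real layout, i.e. the NVFP4 payload width and scale-factor width, comes from the surrounding quantization code):

```python
import torch

# Hypothetical sizes; this only illustrates the packing convention.
num_tokens, payload_bytes, sf_bytes = 4, 32, 8
packed_hidden_size = payload_bytes + sf_bytes

x_fp4 = torch.randint(0, 256, (num_tokens, payload_bytes), dtype=torch.uint8)
x_sf = torch.randint(0, 256, (num_tokens, sf_bytes), dtype=torch.uint8)

# Pack: payload first, then scale factors, along the hidden dimension,
# so one low-latency dispatch moves both.
fp4_packed_tensor = torch.cat([x_fp4, x_sf], dim=-1)
assert fp4_packed_tensor.shape[-1] == packed_hidden_size

# Unpack after dispatch: slice the same column ranges back out, mirroring
# `fp4_packed_tensor[:, :, x.shape[1]:x.shape[1] + ...]` in the patch.
x_fp4_out = fp4_packed_tensor[..., :payload_bytes]
x_sf_out = fp4_packed_tensor[..., payload_bytes:payload_bytes + sf_bytes]
assert torch.equal(x_fp4_out, x_fp4) and torch.equal(x_sf_out, x_sf)
```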
@@ -707,23 +694,9 @@ def forward_chunk(
                 final_hidden_states, deep_ep_handle)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
             num_tokens_per_expert_for_fused_moe = self.mapping.moe_ep_size * all_rank_max_num_tokens
-            num_tokens_per_expert_for_deep_ep = self.deep_ep_max_num_tokens * self.mapping.moe_ep_size
             final_hidden_states = final_hidden_states.view(
                 self.expert_size_per_partition,
                 num_tokens_per_expert_for_fused_moe, self.hidden_size)
-            if num_tokens_per_expert_for_deep_ep != num_tokens_per_expert_for_fused_moe:
-                # Adapter between fused_moe num_tokens and DeepEP num_tokens
-                # This adapter can be removed if fused_moe accepts DeepEP num_tokens without overhead
-                final_hidden_states_for_fused_moe = final_hidden_states
-                final_hidden_states = torch.empty(
-                    self.expert_size_per_partition,
-                    self.deep_ep_max_num_tokens * self.mapping.moe_ep_size,
-                    self.hidden_size,
-                    dtype=final_hidden_states.dtype,
-                    device=final_hidden_states.device)
-                final_hidden_states[:, :
-                                    num_tokens_per_expert_for_fused_moe] = final_hidden_states_for_fused_moe
-                del final_hidden_states_for_fused_moe  # Release memory
             final_hidden_states = self.deep_ep_buffer.low_latency_combine(
                 final_hidden_states, deep_ep_topk_idx, deep_ep_topk_weights,
                 deep_ep_handle)
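Since the dispatch now uses `all_rank_max_num_tokens`, the output of `torch.ops.trtllm.fused_moe` can be viewed straight into the layout that `low_latency_combine` expects, and the copy into a larger `deep_ep_max_num_tokens`-sized buffer is dropped. A shape-only sketch of that view, with made-up sizes:

```python
import torch

# Hypothetical sizes (not from the PR).
expert_size_per_partition = 2
moe_ep_size = 4
all_rank_max_num_tokens = 8
hidden_size = 16

num_tokens_per_expert_for_fused_moe = moe_ep_size * all_rank_max_num_tokens

# fused_moe output, flattened over (local expert, padded token slot).
final_hidden_states = torch.randn(
    expert_size_per_partition * num_tokens_per_expert_for_fused_moe, hidden_size)

# Restore [#local experts, EP size * all_rank_max_num_tokens, hidden_size],
# which is already the width the combine step consumes after this change.
final_hidden_states = final_hidden_states.view(
    expert_size_per_partition, num_tokens_per_expert_for_fused_moe, hidden_size)
print(final_hidden_states.shape)  # torch.Size([2, 32, 16])
```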