diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
index 80a4475e3c6..c1a74bb9afe 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -142,34 +142,29 @@ def __init__(
 
         if self.enable_alltoall:
             self.use_low_precision_combine = model_config.use_low_precision_moe_combine
-            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
-                if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
-                    MnnvlMemory.initialize()
-                    self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
-                        model_config.mapping)
-                    self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
-                        model_config.mapping)
-                elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
-                    workspace_mb = int(
-                        os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
-                    self.moe_a2a = MoeAlltoAll(
-                        mapping=self.mapping,
-                        max_num_tokens=model_config.max_num_tokens,
-                        top_k=self.routing_method.experts_per_token,
-                        num_experts=self.num_slots,
-                        workspace_size_per_rank=workspace_mb * 1024 * 1024,
-                    )
-                else:
-                    raise ValueError(
-                        f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
-                    )
+            if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
+                MnnvlMemory.initialize()
+                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                    model_config.mapping)
+                self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                    model_config.mapping)
+            elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
+                workspace_mb = int(
+                    os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
+                self.moe_a2a = MoeAlltoAll(
+                    mapping=self.mapping,
+                    max_num_tokens=model_config.max_num_tokens,
+                    top_k=self.routing_method.experts_per_token,
+                    num_experts=self.num_slots,
+                    workspace_size_per_rank=workspace_mb * 1024 * 1024,
+                )
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                 raise NotImplementedError(
                     "DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
                 )
             else:
                 raise NotImplementedError(
-                    f"Not available alltoall method type: {self.alltoall_method_type!r}"
+                    f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
                 )
 
         # If True, the router weight will be multiplied on the input rather than at the end of FC2
@@ -235,15 +230,11 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
             )
             return AlltoallMethodType[all2all_method_type]
 
-        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
-            return AlltoallMethodType.NotEnabled
-
-        # TODO: We found that MNNVL performs better than NCCL AllGather/ReduceScatter,
-        # regardless of the relationship between EP size and topK. We favor AlltoAll for now.
+        # TODO: We found that NVLinkOneSided performs better than NCCL AllGather/ReduceScatter,
+        # regardless of the relationship between EP size and topK. We favor NVLinkOneSided for now.
         # if not self.mapping.moe_ep_size > self.routing_method.experts_per_token:
         #     return AlltoallMethodType.NotEnabled
-
-        return AlltoallMethodType.MNNVL
+        return AlltoallMethodType.NVLinkOneSided
 
     @cached_property
     def enable_alltoall(self):
@@ -251,12 +242,6 @@ def enable_alltoall(self):
         """
        return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
-    @cached_property
-    def moe_alltoall_backend(self):
-        # "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
-        return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
-                              "NVLINK_ONE_SIDED").strip().upper()
-
     def _supports_load_balancer(self) -> bool:
         """CutlassFusedMoE supports load balancer."""
         return True
@@ -328,7 +313,7 @@ def forward_chunk(
 
         if self.layer_load_balancer:
             self._load_balancer_done_wait_gpu_stage(is_first_call)
-            ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
+            ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
             self._load_balancer_update_statistic(
                 token_selected_experts,
                 is_first_call,
@@ -439,7 +424,7 @@ def forward_chunk(
                 token_final_scales = torch.ones_like(token_selected_slots,
                                                      dtype=torch.float32)
 
-            if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
+            if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
                 assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
                 if is_last_call:
                     loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
@@ -472,7 +457,7 @@ def forward_chunk(
                     token_selected_slots,
                     alltoall_info.recv_rank_count_cumsum,
                     runtime_max_tokens_per_rank, top_k, self.num_slots, self.ep_size)
-            elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
+            elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
                 # Python MoeAlltoAll path
                 if x_sf is not None:
                     x_sf = x_sf.view(x_row,
@@ -510,7 +495,7 @@ def forward_chunk(
                         -1, token_final_scales_recv.shape[-1])
             else:
                 raise ValueError(
-                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                    f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
                 )
 
         elif run_post_quant_allgather:
@@ -532,7 +517,7 @@ def forward_chunk(
 
         # Optionally provide an output tensor to fused_moe so it writes directly to our buffer
         moe_output: Optional[torch.Tensor] = None
-        if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
+        if self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
             # Retrieve a workspace-backed output tensor sized by runtime tokens
             runtime_max_tokens_per_rank = max(
                 all_rank_num_tokens) if all_rank_num_tokens else x.shape[0]
@@ -583,7 +568,7 @@ def forward_chunk(
 
         # Combine results if using alltoall
         if self.enable_alltoall:
-            if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
+            if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
                 if alltoall_info is not None:
                     top_k = self.routing_method.experts_per_token
                     final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
@@ -596,7 +581,7 @@ def forward_chunk(
                         use_low_precision_combine=self.
                         use_low_precision_combine,
                         token_count=token_count)
-            elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
+            elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
                 output_hidden_size = final_hidden_states.shape[-1]
                 runtime_max_tokens_per_rank = max(
                     all_rank_num_tokens) if all_rank_num_tokens else token_count
@@ -608,7 +593,7 @@ def forward_chunk(
                     payload_in_workspace=True)
             else:
                 raise ValueError(
-                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                    f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
                 )
 
         self._load_balancer_done_set_cpu_stage(is_last_call)
@@ -708,7 +693,10 @@ def _reducescatter_or_allreduce(x_, idx):
             # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
             for idx_chunk, (x, router_logits) in enumerate(
                     zip(x_list, router_logits_list)):
-                if not (self.alltoall_method_type == AlltoallMethodType.MNNVL):
+                if not (self.alltoall_method_type
+                        == AlltoallMethodType.NVLinkOneSided
+                        or self.alltoall_method_type
+                        == AlltoallMethodType.NVLinkTwoSided):
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
                             outputs = _forward_chunk(x, router_logits,
@@ -726,7 +714,10 @@ def _reducescatter_or_allreduce(x_, idx):
 
                 outputs_list.append(outputs)
 
-            if not (self.alltoall_method_type == AlltoallMethodType.MNNVL):
+            if not (self.alltoall_method_type
+                    == AlltoallMethodType.NVLinkOneSided
+                    or self.alltoall_method_type
+                    == AlltoallMethodType.NVLinkTwoSided):
                 if num_chunks % 2 == 0:
                     outputs_list[-1] = _reducescatter_or_allreduce(
                         outputs_list[-1], -1)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
index fbfd7808e3c..e080486c8b7 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -119,34 +119,29 @@ def __init__(
 
         if self.enable_alltoall:
             self.use_low_precision_combine = model_config.use_low_precision_moe_combine
-            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
-                if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
-                    MnnvlMemory.initialize()
-                    self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
-                        model_config.mapping)
-                    self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
-                        model_config.mapping)
-                elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
-                    workspace_mb = int(
-                        os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
-                    self.moe_a2a = MoeAlltoAll(
-                        mapping=self.mapping,
-                        max_num_tokens=model_config.max_num_tokens,
-                        top_k=self.routing_method.experts_per_token,
-                        num_experts=self.num_slots,
-                        workspace_size_per_rank=workspace_mb * 1024 * 1024,
-                    )
-                else:
-                    raise ValueError(
-                        f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
-                    )
+            if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
+                MnnvlMemory.initialize()
+                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                    model_config.mapping)
+                self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                    model_config.mapping)
+            elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
+                workspace_mb = int(
+                    os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
+                self.moe_a2a = MoeAlltoAll(
+                    mapping=self.mapping,
+                    max_num_tokens=model_config.max_num_tokens,
+                    top_k=self.routing_method.experts_per_token,
+                    num_experts=self.num_slots,
+                    workspace_size_per_rank=workspace_mb * 1024 * 1024,
+                )
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
                 raise NotImplementedError(
                     "DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet"
                 )
             else:
                 raise NotImplementedError(
-                    f"Not available alltoall method type: {self.alltoall_method_type!r}"
+                    f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
                 )
 
         self._weights_created = False
@@ -176,15 +171,11 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
             )
             return AlltoallMethodType[all2all_method_type]
 
-        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
-            return AlltoallMethodType.NotEnabled
-
-        # TODO: We found that MNNVL performs better than NCCL AllGather/ReduceScatter,
-        # regardless of the relationship between EP size and topK. We favor AlltoAll for now.
+        # TODO: We found that NVLinkOneSided performs better than NCCL AllGather/ReduceScatter,
+        # regardless of the relationship between EP size and topK. We favor NVLinkOneSided for now.
         # if not self.mapping.moe_ep_size > self.routing_method.experts_per_token:
         #     return AlltoallMethodType.NotEnabled
-
-        return AlltoallMethodType.MNNVL
+        return AlltoallMethodType.NVLinkOneSided
 
     def _supports_load_balancer(self) -> bool:
         """TRTLLMGenFusedMoE supports load balancer."""
         return True
@@ -196,12 +187,6 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
-    @cached_property
-    def moe_alltoall_backend(self):
-        # "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
-        return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
-                              "NVLINK_ONE_SIDED").strip().upper()
-
     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
             or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
@@ -362,7 +347,7 @@ def forward_impl(
 
             self._load_balancer_done_wait_gpu_stage(is_first_call)
 
-            ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
+            ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
             self._load_balancer_update_statistic(
                 token_selected_experts,
                 is_first_call,
@@ -394,7 +379,7 @@ def forward_impl(
             else:
                 token_final_scales = token_final_scales.to(torch.float32)
 
-            if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
+            if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
                 assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
                 if is_last_call:
                     loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
@@ -444,7 +429,7 @@ def forward_impl(
                 if token_final_scales is not None:
                     token_final_scales = token_final_scales.to(torch.bfloat16)
 
-            elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
+            elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
                 if x_sf is not None:
                     x_sf = x_sf.view(x_row,
                                      ceil_div(x_col, self.scaling_vector_size))
@@ -486,7 +471,7 @@ def forward_impl(
                     token_final_scales = token_final_scales.to(torch.bfloat16)
             else:
                 raise ValueError(
-                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                    f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
                 )
 
         elif run_post_quant_allgather:
@@ -510,7 +495,7 @@ def forward_impl(
         moe_output: Optional[torch.Tensor] = None
         use_workspace_output = False
         # TODO: use_workspace_output only supports w4a8_mxfp4_mxfp8 (gpt-oss) for now
-        if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED" and self.has_w4a8_mxfp4_mxfp8:
+        if self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided and self.has_w4a8_mxfp4_mxfp8:
            moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace(
                 runtime_max_tokens_per_rank, self.hidden_size, torch.bfloat16)
            use_workspace_output = True
@@ -774,7 +759,7 @@ def forward_impl(
 
        # Combine results if using alltoall
        if self.enable_alltoall:
-            if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
+            if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
                 if alltoall_info is not None:
                     final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
                         final_hidden_states,
@@ -787,7 +772,7 @@ def forward_impl(
                         use_low_precision_combine,
                         token_count=token_count,
                     )
-            elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
+            elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
                 # If use_workspace_output=True, the MoE result is already in workspace
                 # Otherwise, we need to reshape and pass it
                 if use_workspace_output:
@@ -810,7 +795,7 @@ def forward_impl(
                         payload_in_workspace=False)
             else:
                 raise ValueError(
-                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                    f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
                 )
 
         final_hidden_states = self.reducescatter_or_allreduce(
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 7f1e819484a..4e164ecafb6 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -1,5 +1,4 @@
 import os
-from functools import cached_property
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -124,7 +123,7 @@ def __init__(
             "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") == "1")
         self.use_low_precision_combine = model_config.use_low_precision_moe_combine
 
-        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+        if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
             MnnvlMemory.initialize()
             self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
                 model_config.mapping)
@@ -150,7 +149,7 @@ def __init__(
                 hidden_size, self.num_slots)
         else:
             raise NotImplementedError(
-                f"Not available alltoall method type: {self.alltoall_method_type!r}"
+                f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
             )
 
         self.use_fused_finalize = not model_config.moe_disable_finalize_fusion
@@ -206,9 +205,14 @@ def is_deepep_feasible(num_ranks: int) -> bool:
             num_rdma_nodes = num_ranks // mpi_size
             return num_rdma_nodes in NUM_INTERNODE_SUPPORTED_RDMA_RANKS
 
-        all2all_method_type = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
-        if all2all_method_type is not None:
-            return AlltoallMethodType[all2all_method_type]
+        all2all_method_type_env = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
+        if all2all_method_type_env is not None:
+            alltoall_method_type = AlltoallMethodType[all2all_method_type_env]
+            if alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
+                raise NotImplementedError(
+                    "NVLinkOneSided is not supported for WideEPMoE. Please use NVLinkTwoSided or switch to CutlassFusedMoE."
+                )
+            return alltoall_method_type
 
         if not mapping.enable_attention_dp:
             return AlltoallMethodType.NotEnabled
@@ -216,14 +220,11 @@ def is_deepep_feasible(num_ranks: int) -> bool:
         if mapping.tp_size == 1:
             return AlltoallMethodType.NotEnabled
 
-        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
-            return AlltoallMethodType.NotEnabled
-
         if mapping.moe_ep_size <= top_k:
             return AlltoallMethodType.NotEnabled
 
         if MnnvlMemory.supports_mnnvl():
-            return AlltoallMethodType.MNNVL
+            return AlltoallMethodType.NVLinkTwoSided
 
         if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1":
             if deep_ep_installed and dtype == torch.bfloat16:
@@ -246,19 +247,13 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
-    @cached_property
-    def moe_alltoall_backend(self):
-        # "NVLINK_TWO_SIDED" (default) or "NVLINK_ONE_SIDED"
-        return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
-                              "NVLINK_TWO_SIDED").strip().upper()
-
     def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
         num_rows = sum(all_rank_num_tokens)
         return (num_rows + self.moe_max_num_tokens -
                 1) // self.moe_max_num_tokens
 
     def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
-        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+        if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
             return True
 
         # Disable alltoall when chunking is used
@@ -377,7 +372,7 @@ def reducescatter_or_allreduce(
     def is_post_quant_all2all_supported(self):
         if not self.use_postquant_alltoall:
             return False
-        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+        if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
             return True
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             return self.has_nvfp4
@@ -407,7 +402,7 @@ def forward_chunk(
 
             self._load_balancer_start_wait_gpu_stage(is_first_call)
 
-        if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
+        if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.NVLinkTwoSided:
             alltoall_result_do_sum = True
 
         weight_dtype = self.w3_w1_weight.dtype
@@ -436,7 +431,7 @@ def forward_chunk(
 
        if self.layer_load_balancer:
             self._load_balancer_done_wait_gpu_stage(is_first_call)
-            ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
+            ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
             self._load_balancer_update_statistic(token_selected_experts,
                                                  is_first_call, is_last_call,
                                                  ignore_allreduce)
@@ -470,7 +465,7 @@ def forward_chunk(
         tuner_top_k = None
         alltoall_info = None
         if use_all_to_all:
-            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+            if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
                 token_count = x.shape[0]
@@ -561,14 +556,14 @@ def forward_chunk(
             w2_weight = self.w2_weight
             quant_scales = self.quant_scales
 
-        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+        if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
             top_k = self.routing_method.experts_per_token
             x, x_sf, token_selected_slots, token_final_scales = self.alltoall_dispatch(
                 x, x_sf, token_selected_slots, token_final_scales,
                 all_rank_max_num_tokens, top_k, alltoall_info)
 
         if use_postquant_alltoall:
-            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+            if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
                 pass
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
                 assert self.has_nvfp4, "DeepEP postquant alltoall should have nvfp4"
@@ -629,7 +624,7 @@ def forward_chunk(
                     x, x_sf, recv_expert_count, token_final_scales.dtype)
             else:
                 raise NotImplementedError(
-                    f"Not available alltoall method type: {self.alltoall_method_type!r}"
+                    f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
                 )
 
         final_hidden_states = self.moe_op_impl.run_moe(
@@ -659,7 +654,7 @@ def forward_chunk(
             final_hidden_states = final_hidden_states[0]
 
         if use_all_to_all:
-            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+            if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
                 final_hidden_states = self.alltoall_combine(
@@ -692,7 +687,7 @@ def forward_chunk(
                     deep_ep_topk_weights, deep_ep_handle)
             else:
                 raise NotImplementedError(
-                    f"Not available alltoall method type: {self.alltoall_method_type!r}"
+                    f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
                 )
 
         self._load_balancer_done_set_cpu_stage(is_last_call)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py
index d8690e5cc88..1856c287264 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/interface.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py
@@ -26,12 +26,14 @@ class MoEWeightLoadingMode(Enum):
 class AlltoallMethodType(IntEnum):
     # Not available
     NotEnabled = 0
-    # MNNVL
-    MNNVL = 1
+    # NVLink One-Sided
+    NVLinkOneSided = 1
+    # NVLink Two-Sided
+    NVLinkTwoSided = 2
     # DeepEP intranode or internode: CUDA Graphs are supported, IBGDA is required by internode
-    DeepEP = 2
+    DeepEP = 3
     # DeepEP low latency: CUDA Graphs are supported, IBGDA is required
-    DeepEPLowLatency = 3
+    DeepEPLowLatency = 4
 
 
 def extract_extra_attrs(layer_idx: str):
diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py
index 65fb89cf608..c2e3594e2fb 100644
--- a/tests/unittest/_torch/modules/test_fused_moe.py
+++ b/tests/unittest/_torch/modules/test_fused_moe.py
@@ -210,7 +210,7 @@ def test_fused_moe_multi_gpu(moe_cls, ep_size):
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="needs 4 GPUs to run this test")
 @pytest.mark.parametrize("alltoall_method_type", [
-    AlltoallMethodType.MNNVL, AlltoallMethodType.DeepEP,
+    AlltoallMethodType.NVLinkTwoSided, AlltoallMethodType.DeepEP,
     AlltoallMethodType.DeepEPLowLatency
 ],
                          ids=lambda s: s.name)
@@ -302,7 +302,7 @@ def per_rank_test_fused_moe_alltoall(job_id):
             all_rank_num_tokens=all_rank_num_tokens,
             use_dp_padding=False)
 
-        if alltoall_method_type == AlltoallMethodType.MNNVL and output.ndim == 3:
+        if output.ndim == 3:
             output = output.sum(dim=1)
         print(f"output: {output.shape}")
         print(f"ref_output: {ref_output.shape}")
@@ -323,7 +323,7 @@ def per_rank_test_fused_moe_alltoall(job_id):
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="needs 4 GPUs to run this test")
 @pytest.mark.parametrize("alltoall_method_type", [
-    AlltoallMethodType.MNNVL, AlltoallMethodType.DeepEP,
+    AlltoallMethodType.NVLinkTwoSided, AlltoallMethodType.DeepEP,
     AlltoallMethodType.DeepEPLowLatency
 ],
                          ids=lambda s: s.name)
@@ -684,7 +684,7 @@ def set_tensor_value_4(x, num_row, num_cols):
                     reason="needs 4 GPUs to run this test")
 @pytest.mark.parametrize(
     "alltoall_method_type",
-    [AlltoallMethodType.MNNVL, AlltoallMethodType.NotEnabled],
+    [AlltoallMethodType.NVLinkTwoSided, AlltoallMethodType.NotEnabled],
     ids=lambda s: s.name)
 def test_fused_moe_fp8_blockwise_wide_ep(alltoall_method_type):
     """Test WideEPMoE with FP8
     block-wise quantization using DeepGemmFusedMoE as reference."""
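
Usage note (not part of the patch): the renamed enum values are still selected through the existing TRTLLM_FORCE_ALLTOALL_METHOD override, which select_alltoall_method_type() resolves by enum-member name. A minimal sketch, assuming the import path implied by the file paths above:

# Illustrative only: forcing the all-to-all method with the renamed enum values.
import os

from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType

# Force the one-sided NVLink all-to-all path (the new default for CutlassFusedMoE
# and TRTLLMGenFusedMoE in this patch). Per the patch, WideEPMoE rejects this value
# with NotImplementedError and expects NVLinkTwoSided or a DeepEP variant instead.
os.environ["TRTLLM_FORCE_ALLTOALL_METHOD"] = "NVLinkOneSided"

# The override is resolved by enum-name lookup, so the value must match a member name
# exactly (NotEnabled, NVLinkOneSided, NVLinkTwoSided, DeepEP, DeepEPLowLatency).
method = AlltoallMethodType[os.environ["TRTLLM_FORCE_ALLTOALL_METHOD"]]
assert method == AlltoallMethodType.NVLinkOneSided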