diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index b7e42fc09b0..232683a2765 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -165,6 +165,11 @@ def get_all_reduce_strategy(strategy: str = "AUTO"): self.allreduce_strategy = get_all_reduce_strategy( self.allreduce_strategy) + # Set default moe_max_num_tokens if not specified + # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled + if self.moe_max_num_tokens is None: + self.moe_max_num_tokens = self.max_num_tokens * self.mapping.dp_size + @property def torch_dtype(self) -> torch.dtype: """Get the torch dtype of the model.""" diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index a81a5141fa3..00445a520b7 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -55,7 +55,7 @@ from ..modules.attention import MLA from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding -from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, +from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, MoE, MoEWeightLoadingMode, create_moe) from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE from ..modules.gated_mlp import GatedMLP @@ -382,6 +382,21 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, "gate_proj": "w1", }) module.load_weights(weights=[module_weights]) + elif names[-1] == "backend" and isinstance(module, MoE): + # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE) + # Currently saved MoE weights don't include 'backend' in their names. + # After MoE refactoring, ConfigurableMoE now has a backend submodule, + # and weights loading is done in the backend, so module name includes '.backend'. + # We need to use parent module name (without .backend) to match saved weight names. + # After MoE refactoring is fully complete, all paths will follow this branch. + parent_name = '.'.join(names[:-1]) + module_weights = filter_weights(parent_name, weights) + module_weights = rename_moe_weight(module_weights, { + "down_proj": "w2", + "up_proj": "w3", + "gate_proj": "w1", + }) + module.load_weights(weights=[module_weights]) elif names[-1] == "self_attn": continue elif names[-1] == "next_layer_layernorm": diff --git a/tensorrt_llm/_torch/models/modeling_gpt_oss.py b/tensorrt_llm/_torch/models/modeling_gpt_oss.py index 546d17d5985..ec4621fb08c 100644 --- a/tensorrt_llm/_torch/models/modeling_gpt_oss.py +++ b/tensorrt_llm/_torch/models/modeling_gpt_oss.py @@ -657,6 +657,18 @@ def load_hf_weights(self, weights: Dict): module_weights = {} for k, v in self.hf_params_map.items(): name = name.replace(k, v) + + # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE) + # Currently saved MoE weights don't include 'backend' in their names. + # After MoE refactoring, ConfigurableMoE now has a backend submodule, + # and weights loading is done in the backend, so module name includes '.backend'. + # We need to use parent module name (without .backend) to match saved weight names. + # After MoE refactoring is fully complete, all paths will follow this branch. 
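Editor's note: the ".backend" special case described in the comment above (and in the analogous hunks for the other model files) boils down to a small name remapping before the weight lookup. A minimal sketch of that idea, using hypothetical module and checkpoint key names rather than the patch's actual helpers:

from typing import Dict


def resolve_weight_prefix(module_name: str) -> str:
    """Drop a trailing '.backend' so the lookup key matches checkpoint names."""
    parts = module_name.split('.')
    if parts[-1] == "backend":
        parts = parts[:-1]
    return '.'.join(parts)


def filter_weights_by_prefix(prefix: str, weights: Dict[str, object]) -> Dict[str, object]:
    """Collect the weights stored under `prefix` (analogous to the models' filter_weights)."""
    return {
        k[len(prefix) + 1:]: v
        for k, v in weights.items()
        if k.startswith(prefix + '.')
    }


# The ConfigurableMoE wrapper reports 'model.layers.0.mlp.experts.backend',
# but the checkpoint stores keys under 'model.layers.0.mlp.experts.*'.
weights = {"model.layers.0.mlp.experts.down_proj.weight": "w"}
prefix = resolve_weight_prefix("model.layers.0.mlp.experts.backend")
assert filter_weights_by_prefix(prefix, weights) == {"down_proj.weight": "w"}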
+ names = name.split('.') + if names[-1] == "backend" and isinstance(module, MoE): + # Backend is under experts module (ConfigurableMoE wrapper) + name = '.'.join(names[:-1]) + module_weights = filter_weights(name, weights) if isinstance(module, MoE): diff --git a/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py b/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py index 93890ad532d..89e43d869b2 100644 --- a/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py +++ b/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py @@ -15,8 +15,9 @@ from ..modules.attention import Attention from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding -from ..modules.fused_moe import (CutlassFusedMoE, RenormalizeMoeRoutingMethod, - VanillaMoE, create_moe) +from ..modules.fused_moe import (CutlassFusedMoE, MoE, + RenormalizeMoeRoutingMethod, VanillaMoE, + create_moe) from ..modules.gated_mlp import GatedMLP from ..modules.linear import Linear, TensorParallelMode from ..modules.multi_stream_utils import maybe_execute_in_parallel @@ -364,6 +365,17 @@ def filter_weights(prefix, weights: Dict): "lm_head"): continue names = name.split('.') + + # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE) + # Currently saved MoE weights don't include 'backend' in their names. + # After MoE refactoring, ConfigurableMoE now has a backend submodule, + # and weights loading is done in the backend, so module name includes '.backend'. + # We need to use parent module name (without .backend) to match saved weight names. + # After MoE refactoring is fully complete, all paths will follow this branch. + if names[-1] == "backend" and isinstance(module, MoE): + name = '.'.join(names[:-1]) + names = name.split('.') + if names[-1] in params_map: # model.layers.{idx}.mlp.shared_mlp.gate_up_proj or model.layers.{idx}.self_attn.qkv_proj module_weights = [] diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index 67f5f922370..d17bcab2df4 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -868,6 +868,17 @@ def load_single_module(name, module): return names = name.split('.') + + # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE) + # Currently saved MoE weights don't include 'backend' in their names. + # After MoE refactoring, ConfigurableMoE now has a backend submodule, + # and weights loading is done in the backend, so module name includes '.backend'. + # We need to use parent module name (without .backend) to match saved weight names. + # After MoE refactoring is fully complete, all paths will follow this branch. + if names[-1] == "backend" and isinstance(module, MoE): + name = '.'.join(names[:-1]) + names = name.split('.') + # WAR: better solution is that llama has its own load_weights function. if names[-1] == 'next_layer_layernorm': return @@ -968,6 +979,17 @@ def load_single_module(name, module): return names = name.split('.') + + # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE) + # Currently saved MoE weights don't include 'backend' in their names. + # After MoE refactoring, ConfigurableMoE now has a backend submodule, + # and weights loading is done in the backend, so module name includes '.backend'. + # We need to use parent module name (without .backend) to match saved weight names. + # After MoE refactoring is fully complete, all paths will follow this branch. 
+ if names[-1] == "backend" and isinstance(module, MoE): + name = '.'.join(names[:-1]) + names = name.split('.') + module_names_breakdown, module_name = names[:-1], names[-1] if weight_mapper.does_require_special_handling(module_name): diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py b/tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py index ece7131ecbc..0d44ecd2df1 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py @@ -20,8 +20,8 @@ Available Communication Methods: - AllGatherReduceScatter: Default fallback method, always available -- MnnvlLatency: MNNVL-optimized communication for latency -- MNNVLThroughput: MNNVL-optimized communication for throughput +- NVLinkTwoSided: NVLINK-optimized communication for latency (formerly MNNVLLatency) +- NVLinkOneSided: NVLINK-optimized communication for throughput (formerly MNNVLThroughput) - DeepEP: Deep Expert Parallelism with support for large batches - DeepEPLowLatency: Deep Expert Parallelism optimized for low latency @@ -34,16 +34,16 @@ from .communication_factory import CommunicationFactory from .deep_ep import DeepEP from .deep_ep_low_latency import DeepEPLowLatency -from .mnnvl_latency import MnnvlLatency -from .mnnvl_throughput import MNNVLThroughput +from .nvlink_one_sided import NVLinkOneSided +from .nvlink_two_sided import NVLinkTwoSided __all__ = [ # Base classes and types "Communication", # Communication strategies "AllGatherReduceScatter", - "MnnvlLatency", - "MNNVLThroughput", + "NVLinkTwoSided", + "NVLinkOneSided", "DeepEP", "DeepEPLowLatency", # Factory diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/allgather_reducescatter.py b/tensorrt_llm/_torch/modules/fused_moe/communication/allgather_reducescatter.py index 9e175853d52..706de4247a8 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/allgather_reducescatter.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/allgather_reducescatter.py @@ -42,6 +42,13 @@ def __init__( # Initialize dispatch state self._dispatch_state = {} + @staticmethod + def is_platform_supported() -> bool: + """ + AllGather + ReduceScatter is always supported as the fallback strategy + """ + return True + def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) -> bool: """ Check if AllGather is feasible for the given workload at runtime. diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/base.py b/tensorrt_llm/_torch/modules/fused_moe/communication/base.py index bfacf9f54db..b98f92830e1 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/base.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/base.py @@ -50,6 +50,31 @@ def __init__( self.ep_size = mapping.moe_ep_size self.ep_rank = mapping.moe_ep_rank + # Check platform support and raise error if not supported + if not self.is_platform_supported(): + raise RuntimeError( + f"Communication strategy {self.__class__.__name__} " + f"is not supported on this platform." + ) + self._is_platform_supported = True + + @staticmethod + @abstractmethod + def is_platform_supported() -> bool: + """ + Check if this communication strategy is supported on the current platform. + + This method performs platform/hardware checks to determine if the strategy + can be used on the current system. 
+ + Returns: + True if platform is supported, False otherwise + + Note: This is a static method that can be called before instantiation + to check compatibility without creating an instance. + """ + raise NotImplementedError + @abstractmethod def is_workload_feasible( self, diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py index 586b1cadd4d..0f5e3124b79 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py @@ -26,45 +26,14 @@ import torch from tensorrt_llm._torch.model_config import ModelConfig -from tensorrt_llm._utils import local_mpi_size +from tensorrt_llm.logger import logger from .allgather_reducescatter import AllGatherReduceScatter from .base import Communication from .deep_ep import DeepEP from .deep_ep_low_latency import DeepEPLowLatency -from .mnnvl_latency import MnnvlLatency -from .mnnvl_throughput import MNNVLThroughput - - -def is_high_throughput() -> bool: - """ - Check if high throughput mode is enabled - """ - return True - - -def is_deepep_feasible(num_ranks: int) -> bool: - """ - Check if DeepEP is feasible for the given number of ranks - - DeepEP supports two modes: - 1. Intranode: Single node with 2, 4, or 8 ranks - 2. Internode: 2, 4, 8, or 16 nodes with 8 ranks per node - """ - NUM_INTRANODE_SUPPORTED_RANKS = {2, 4, 8} - REQUIRED_LOCAL_MPI_SIZE = 8 - NUM_INTERNODE_SUPPORTED_RDMA_RANKS = {2, 4, 8, 16} - mpi_size = local_mpi_size() - - # Intranode cases - if num_ranks == mpi_size and num_ranks in NUM_INTRANODE_SUPPORTED_RANKS: - return True - - # Internode cases - if mpi_size != REQUIRED_LOCAL_MPI_SIZE: - return False - num_rdma_nodes = num_ranks // mpi_size - return num_rdma_nodes in NUM_INTERNODE_SUPPORTED_RDMA_RANKS +from .nvlink_one_sided import NVLinkOneSided +from .nvlink_two_sided import NVLinkTwoSided class CommunicationFactory: @@ -72,7 +41,7 @@ class CommunicationFactory: Factory for creating MoE communication methods Selects the best communication method based on: - - Hardware support (MNNVL, DeepEP) + - Hardware support (NVLINK, DeepEP) - Configuration settings - Workload characteristics """ @@ -85,18 +54,19 @@ def create_strategy( top_k: int, expert_size_per_partition: int, payload_in_workspace: bool = False, - alltoall_result_do_sum: bool = False, + alltoall_result_do_sum: bool = True, ) -> Optional[Communication]: """ Create the best communication method for the given configuration - Selection priority: - 1. Force method (if specified via TRTLLM_FORCE_ALLTOALL_METHOD env) - 2. MNNVL (if hardware supports) - - Selects latency or throughput backend based on TRTLLM_MOE_ALLTOALL_BACKEND env - - Default: "mnnvllatency", alternative: "mnnvlthroughput" - 3. DeepEP / DeepEPLowLatency (if enabled and hardware supports) - 4. AllGather + ReduceScatter (fallback, always works) + Selection priority (using try-catch mechanism): + 1. Force method (if specified via TRTLLM_FORCE_COMM_METHOD env) + 2. Auto-selection (tries in order): + - NVLinkOneSided (highest priority for throughput) + - NVLinkTwoSided (high priority for latency) + - DeepEP (if enabled via TRTLLM_CAN_USE_DEEP_EP) + - DeepEPLowLatency (if enabled via TRTLLM_CAN_USE_DEEP_EP) + - AllGather + ReduceScatter (fallback, always works) Args: model_config: Model configuration containing mapping, quant_config, max_num_tokens, etc. 
@@ -104,8 +74,8 @@ def create_strategy( num_slots: Total number of expert slots top_k: Number of experts per token expert_size_per_partition: Number of experts per partition (required for DeepEP) - payload_in_workspace: If True, final_hidden_states is already in workspace (for MNNVLThroughput) - alltoall_result_do_sum: If True, sum the alltoall results (for MnnvlLatency) + payload_in_workspace: If True, final_hidden_states is already in workspace (for NVLinkOneSided) + alltoall_result_do_sum: If True, sum the alltoall results (for NVLinkTwoSided) Returns: The selected communication method, or None if attention does not use DP @@ -134,24 +104,9 @@ def create_strategy( return AllGatherReduceScatter(mapping) # Check if forced method is specified via environment variable - force_method = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD") + force_method = os.environ.get("TRTLLM_FORCE_COMM_METHOD") if force_method is not None: - # Validate platform support for forced method - method_upper = force_method.upper() - if method_upper in ["MNNVLLATENCY", "MNNVLTHROUGHPUT"]: - if not MnnvlLatency.is_platform_supported(): - raise RuntimeError( - f"Forced method '{force_method}' is not supported on this platform. " - "MNNVLLATENCY and MNNVLTHROUGHPUT require compatible hardware." - ) - elif method_upper in ["DEEPEP", "DEEPEPLOWLATENCY"]: - if not DeepEP.is_platform_supported(mapping): - raise RuntimeError( - f"Forced method '{force_method}' is not supported on this platform. " - "DeepEP requires compatible hardware and TRTLLM_CAN_USE_DEEP_EP=1." - ) - return CommunicationFactory._create_forced_method( force_method, model_config, @@ -163,58 +118,75 @@ def create_strategy( alltoall_result_do_sum, ) - # Try MNNVL first (highest priority) - if MnnvlLatency.is_platform_supported(): - if is_high_throughput(): - # Currently, MNNVLThroughput shows better performance at all scenarios - return MNNVLThroughput( + # Auto-selection: Try strategies in priority order using try-catch + # Priority: NVLinkOneSided > NVLinkTwoSided > DeepEP > DeepEPLowLatency > AllGather + + try: + strategy = NVLinkOneSided( + mapping, + num_slots, + top_k, + max_num_tokens_per_rank=max_num_tokens, + payload_in_workspace=payload_in_workspace, + ) + logger.info("Selected communication strategy: NVLinkOneSided") + return strategy + except RuntimeError as e: + logger.debug(f"NVLinkOneSided not available: {e}") + + try: + strategy = NVLinkTwoSided( + mapping, + num_experts, + num_slots, + top_k, + use_low_precision_combine, + alltoall_result_do_sum=alltoall_result_do_sum, + ) + logger.info("Selected communication strategy: NVLinkTwoSided") + return strategy + except RuntimeError as e: + logger.debug(f"NVLinkTwoSided not available: {e}") + + # Try DeepEP (if enabled and weight dtype is bfloat16) + if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1" and weight_dtype == torch.bfloat16: + try: + strategy = DeepEP( mapping, - num_experts, - top_k, - max_num_tokens_per_rank=max_num_tokens, - payload_in_workspace=payload_in_workspace, + num_slots, + hidden_size, + weight_dtype, + quant_config, + expert_size_per_partition, + use_cuda_graph, ) - else: - return MnnvlLatency( + logger.info("Selected communication strategy: DeepEP") + return strategy + except RuntimeError as e: + logger.debug(f"DeepEP not available: {e}") + + # Try DeepEPLowLatency as fallback when DeepEP is not available + try: + strategy = DeepEPLowLatency( mapping, - num_experts, num_slots, - top_k, + hidden_size, + weight_dtype, + quant_config, + expert_size_per_partition, + 
max_num_tokens, use_low_precision_combine, - alltoall_result_do_sum=alltoall_result_do_sum, + moe_max_num_tokens, ) + logger.info("Selected communication strategy: DeepEPLowLatency") + return strategy + except RuntimeError as e: + logger.debug(f"DeepEPLowLatency not available: {e}") - # Try DeepEP - if os.environ.get("TRTLLM_CAN_USE_DEEP_EP", "0") == "1": - if weight_dtype == torch.bfloat16: - if DeepEP.is_platform_supported(mapping) and is_deepep_feasible( - mapping.moe_ep_size - ): - return DeepEP( - mapping, - num_slots, - hidden_size, - weight_dtype, - quant_config, - expert_size_per_partition, - use_cuda_graph, - ) - else: - # Use DeepEP Low Latency as fallback (when not feasible or not supported) - return DeepEPLowLatency( - mapping, - num_slots, - hidden_size, - weight_dtype, - quant_config, - expert_size_per_partition, - max_num_tokens, - use_low_precision_combine, - moe_max_num_tokens, - ) - - # Fallback to AllGather + ReduceScatter - return AllGatherReduceScatter(mapping) + # Fallback to AllGather + ReduceScatter (always works) + strategy = AllGatherReduceScatter(mapping) + logger.info("Selected communication strategy: AllGatherReduceScatter (fallback)") + return strategy @staticmethod def _create_forced_method( @@ -227,7 +199,13 @@ def _create_forced_method( payload_in_workspace: bool, alltoall_result_do_sum: bool, ) -> Communication: - """Create a specific method (for debugging/testing)""" + """ + Create a specific method (for debugging/testing) + + Raises: + RuntimeError: If the forced method is not supported on this platform + ValueError: If method name is unknown + """ # Extract parameters from model_config mapping = model_config.mapping hidden_size = model_config.pretrained_config.hidden_size @@ -240,8 +218,9 @@ def _create_forced_method( method = method.upper() - if method == "MNNVLLATENCY": - return MnnvlLatency( + # Create strategy - will raise RuntimeError if platform not supported + if method in ["NVLINK_TWO_SIDED"]: + return NVLinkTwoSided( mapping, num_experts, num_slots, @@ -249,12 +228,12 @@ def _create_forced_method( use_low_precision_combine, alltoall_result_do_sum=alltoall_result_do_sum, ) - elif method == "MNNVLTHROUGHPUT": - # MNNVLThroughput requires max_num_tokens_per_rank + elif method in ["NVLINK_ONE_SIDED"]: + # NVLinkOneSided requires max_num_tokens_per_rank # max_num_tokens is per-rank value (as passed from callers like cutlass) - return MNNVLThroughput( + return NVLinkOneSided( mapping, - num_experts, + num_slots, top_k, max_num_tokens_per_rank=max_num_tokens, payload_in_workspace=payload_in_workspace, diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py index e188d4479ac..a8d71a1d6b8 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep.py @@ -26,6 +26,7 @@ import torch from tensorrt_llm._torch.modules.fused_moe.deep_ep_utils import buffer_pool, deep_ep_installed +from tensorrt_llm._utils import local_mpi_size from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig @@ -50,6 +51,15 @@ def __init__( ): super().__init__(mapping) + # Check if DeepEP is feasible for the given number of ranks + if not self._is_deepep_feasible(mapping.moe_ep_size): + raise RuntimeError( + f"DeepEP is not feasible for {mapping.moe_ep_size} ranks. 
" + f"DeepEP supports: " + f"1) Intranode: 2, 4, or 8 ranks; " + f"2) Internode: 2, 4, 8, or 16 nodes with 8 ranks per node." + ) + # Store needed parameters self.num_slots = num_slots self.hidden_size = hidden_size @@ -67,7 +77,7 @@ def __init__( self.deep_ep_buffer.reserve(hidden_size, weight_dtype) @staticmethod - def is_platform_supported(mapping: Mapping) -> bool: + def is_platform_supported() -> bool: """ Check if DeepEP is supported on the current platform """ @@ -75,6 +85,30 @@ def is_platform_supported(mapping: Mapping) -> bool: return False return deep_ep_installed + @staticmethod + def _is_deepep_feasible(num_ranks: int) -> bool: + """ + Check if DeepEP is feasible for the given number of ranks + + DeepEP supports two modes: + 1. Intranode: Single node with 2, 4, or 8 ranks + 2. Internode: 2, 4, 8, or 16 nodes with 8 ranks per node + """ + NUM_INTRANODE_SUPPORTED_RANKS = {2, 4, 8} + REQUIRED_LOCAL_MPI_SIZE = 8 + NUM_INTERNODE_SUPPORTED_RDMA_RANKS = {2, 4, 8, 16} + mpi_size = local_mpi_size() + + # Intranode cases + if num_ranks == mpi_size and num_ranks in NUM_INTRANODE_SUPPORTED_RANKS: + return True + + # Internode cases + if mpi_size != REQUIRED_LOCAL_MPI_SIZE: + return False + num_rdma_nodes = num_ranks // mpi_size + return num_rdma_nodes in NUM_INTERNODE_SUPPORTED_RDMA_RANKS + def supports_post_quant_dispatch(self) -> bool: """ DeepEP supports post-quant dispatch only for nvfp4 @@ -94,7 +128,7 @@ def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) return False if self.weight_dtype != torch.bfloat16: return False - return self.is_platform_supported(self.mapping) + return True def dispatch( self, diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py index 9f25956467f..d2c6a8164c6 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/deep_ep_low_latency.py @@ -80,7 +80,7 @@ def __init__( self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens, hidden_size, num_slots) @staticmethod - def is_platform_supported(mapping: Mapping) -> bool: + def is_platform_supported() -> bool: """ Check if DeepEP Low Latency is supported on the current platform """ @@ -113,7 +113,7 @@ def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) return False if self.weight_dtype != torch.bfloat16: return False - return self.is_platform_supported(self.mapping) + return True def dispatch( self, diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/mnnvl_throughput.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py similarity index 74% rename from tensorrt_llm/_torch/modules/fused_moe/communication/mnnvl_throughput.py rename to tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py index d21b0041aca..1dab6d8bad5 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/mnnvl_throughput.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py @@ -14,12 +14,11 @@ # limitations under the License. """ -MNNVL AllToAll Throughput Communication Strategy +NVLINK One-Sided AllToAll Communication Strategy -This module implements the MNNVL AllToAll throughput communication method for MoE. -MNNVL Throughput uses Python-based AllToAll operations for high throughput scenarios. +This module implements the NVLINK one-sided comm AllToAll throughput communication method for MoE. 
-MNNVL Throughput supports post-quant dispatch +NVLINK One-Sided supports post-quant dispatch. """ import os @@ -35,12 +34,16 @@ from .base import Communication -class MNNVLThroughput(Communication): +class NVLinkOneSided(Communication): """ - MNNVL AllToAll strategy for throughput scenarios - - This class uses Python-based AllToAll operations for high throughput scenarios. - It manages workspace allocation and synchronization for cross-GPU communication. + NVLINK one-sided comm AllToAll strategy for throughput scenarios. + + This implementation utilizes symmetric memory to enable peer-to-peer access between GPUs over NVLink. + The kernels only take the role as one side of the communication: the dispatch kernel puts the data + into peer ranks' symmetric memory from local buffer, while the combine kernel gets the data from peer + ranks' symmetric memory and reduces the data into local buffer. It is the most efficient implementation + by now, but requires symmetric memory size proportional to `max_num_tokens * n_ranks`, which may not + scale well for very large-scale parallelization. """ # Constants from C++ (must match moeAlltoAllKernels.h) @@ -88,7 +91,7 @@ def __init__( payload_in_workspace: bool = False, ): """ - Initialize MNNVLThroughput with workspace allocation. + Initialize NVLinkOneSided with workspace allocation. Args: mapping: TensorRT-LLM Mapping object containing rank information @@ -109,8 +112,8 @@ def __init__( # Initialize constants from C++ self._init_constants() - # Get workspace size from environment variable (default 512MB) - workspace_mb = int(os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512")) + # Get workspace size from environment variable (default 2048MB to match MoeAlltoAll) + workspace_mb = int(os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048")) self.workspace_size_per_rank = workspace_mb * 1024 * 1024 # Initialize or reuse workspace MnnvlMemory.initialize() @@ -129,7 +132,7 @@ def __init__( self.ep_size, self.max_num_tokens_per_rank, ) - MNNVLThroughput._WORKSPACE = { + NVLinkOneSided._WORKSPACE = { "workspace_size_per_rank": self.workspace_size_per_rank, "max_num_tokens_per_rank": self.max_num_tokens_per_rank, "ep_rank": self.ep_rank, @@ -163,30 +166,30 @@ def __init__( # Internal state self._state: str = "idle" # idle | dispatched - # Invalid token expert ID (default to num_experts) - self.invalid_token_expert_id: int = self.num_experts + # Invalid token expert ID (default to -1), the kernels in TRTLLM-gen is hard-code to support -1 only. + self.invalid_token_expert_id: int = -1 @staticmethod def is_platform_supported() -> bool: """ - Check if MNNVL is supported on current hardware + Check if NVLINK one-sided comm is supported on current hardware. """ return MnnvlMemory.supports_mnnvl() def supports_post_quant_dispatch(self) -> bool: """ - MNNVL Throughput supports post-quant dispatch + NVLINK one-sided comm supports post-quant dispatch. """ return True def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) -> bool: """ - Check if MNNVL Throughput is feasible for the given workload at runtime. + Check if NVLINK one-sided comm is feasible for the given workload at runtime. This method performs runtime checks based on workload characteristics such as token counts, number of chunks, and other runtime parameters. 
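Editor's note: to make the class docstring's "symmetric memory proportional to max_num_tokens * n_ranks" point concrete, here is a rough back-of-envelope estimate. The real workspace layout is defined by the C++ kernels, so treat the payload accounting below as an illustration, not the actual formula:

def rough_recv_buffer_bytes(ep_size: int, max_tokens_per_rank: int,
                            hidden_size: int, bytes_per_elem: int = 2,
                            top_k: int = 8) -> int:
    """Very rough estimate: hidden states plus routing metadata received from every peer rank."""
    hidden = ep_size * max_tokens_per_rank * hidden_size * bytes_per_elem
    routing = ep_size * max_tokens_per_rank * top_k * 4 * 2  # slots + scales, int32/float32
    return hidden + routing


workspace_budget = 2048 * 1024 * 1024  # default TRTLLM_MOE_A2A_WORKSPACE_MB after this change
estimate = rough_recv_buffer_bytes(ep_size=32, max_tokens_per_rank=4096, hidden_size=7168)
print(f"~{estimate / 2**20:.0f} MiB of payload vs {workspace_budget / 2**20:.0f} MiB budget")
# With these (illustrative) sizes the payload alone is ~1800 MiB, which is why the
# default workspace was raised from 512 MB to 2048 MB.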
""" - return self.is_platform_supported() + return True def dispatch( self, @@ -217,28 +220,27 @@ def dispatch( if self._state == "dispatched": raise RuntimeError("dispatch called twice without an intervening combine") - # Build payloads list - token_selected_slots is always first + # Calculate runtime_max_tokens_per_rank from all_rank_num_tokens + runtime_max_tokens_per_rank = max(all_rank_num_tokens) + + # Build payloads list - match TRTLLMGen baseline order for optimal performance + # Order: [hidden_states, hidden_states_sf (optional), token_selected_slots, token_final_scales (optional)] + payloads = [] - payloads.append(token_selected_slots) payloads.append(hidden_states) if hidden_states_sf is not None: payloads.append(hidden_states_sf) + + payloads.append(token_selected_slots) if token_final_scales is not None: payloads.append(token_final_scales) - # Call AllToAll dispatch - ( - recv_buffers, - send_counters, - recv_counters, - topk_target_ranks, - topk_send_indices, - combine_payload_offset, - ) = torch.ops.trtllm.moe_a2a_dispatch( + recv_buffers, combine_payload_offset = torch.ops.trtllm.moe_a2a_dispatch( token_selected_slots, payloads, self.workspace, - self.max_num_tokens_per_rank, + self.moe_a2a_metainfo, + runtime_max_tokens_per_rank, self.ep_rank, self.ep_size, self.top_k, @@ -247,34 +249,44 @@ def dispatch( self._state = "dispatched" - # Store all dispatch state for combine (no class variables) - self._dispatch_state["topk_target_ranks"] = topk_target_ranks - self._dispatch_state["topk_send_indices"] = topk_send_indices - self._dispatch_state["send_counters"] = send_counters - self._dispatch_state["recv_counters"] = recv_counters self._dispatch_state["combine_payload_offset"] = int(combine_payload_offset) - - # Sanitize expert IDs for invalid tokens if needed - # token_selected_slots is always at index 0 in recv_buffers - recv_token_selected_slots = recv_buffers[0] - torch.ops.trtllm.moe_a2a_sanitize_expert_ids( - recv_token_selected_slots, - recv_counters, - int(self.invalid_token_expert_id), - ) + self._dispatch_state["local_num_tokens"] = token_selected_slots.size(0) + self._dispatch_state["runtime_max_tokens_per_rank"] = runtime_max_tokens_per_rank # Extract results from recv_buffers - # Payload order: [token_selected_slots, hidden_states, hidden_states_sf (optional), - # token_final_scales (optional)] - token_selected_slots_recv = recv_buffers[0] - hidden_states_recv = recv_buffers[1] + # Payload order matches input: + # [hidden_states, hidden_states_sf (optional), token_selected_slots, token_final_scales (optional)] + hidden_states_recv = recv_buffers[0] if hidden_states_sf is not None: - hidden_states_sf_recv = recv_buffers[2] + hidden_states_sf_recv = recv_buffers[1] + token_selected_slots_recv = recv_buffers[2] token_final_scales_recv = recv_buffers[3] if token_final_scales is not None else None else: hidden_states_sf_recv = None + token_selected_slots_recv = recv_buffers[1] token_final_scales_recv = recv_buffers[2] if token_final_scales is not None else None + torch.ops.trtllm.moe_a2a_sanitize_expert_ids( + token_selected_slots_recv, + self.workspace, + self.moe_a2a_metainfo, + self.ep_rank, + int(self.invalid_token_expert_id), + ) + + # Flatten 3D tensors to 2D for compatibility with MoE backends + # recv_buffers have shape [ep_size, max_tokens_per_rank, ...], flatten to [ep_size * max_tokens_per_rank, ...] 
+ hidden_states_recv = hidden_states_recv.view(-1, hidden_states_recv.shape[-1]) + if hidden_states_sf_recv is not None: + hidden_states_sf_recv = hidden_states_sf_recv.view(-1, hidden_states_sf_recv.shape[-1]) + token_selected_slots_recv = token_selected_slots_recv.view( + -1, token_selected_slots_recv.shape[-1] + ) + if token_final_scales_recv is not None: + token_final_scales_recv = token_final_scales_recv.view( + -1, token_final_scales_recv.shape[-1] + ) + return ( hidden_states_recv, hidden_states_sf_recv, @@ -302,13 +314,15 @@ def combine( if self._state != "dispatched": raise RuntimeError("combine called before a successful dispatch") - # Read dispatch state - topk_target_ranks = self._dispatch_state.get("topk_target_ranks") - topk_send_indices = self._dispatch_state.get("topk_send_indices") - recv_counters = self._dispatch_state.get("recv_counters") + local_num_tokens = self._dispatch_state.get("local_num_tokens") combine_payload_offset = self._dispatch_state.get("combine_payload_offset") + runtime_max_tokens_per_rank = self._dispatch_state.get("runtime_max_tokens_per_rank") - if topk_target_ranks is None or topk_send_indices is None or recv_counters is None: + if ( + local_num_tokens is None + or combine_payload_offset is None + or runtime_max_tokens_per_rank is None + ): raise RuntimeError("combine called but dispatch state is missing") # Reshape if needed (handle case where input is flattened) @@ -317,7 +331,7 @@ def combine( # Reshape to: [ep_size, max_tokens_per_rank, hidden_size] hidden_size = final_hidden_states.shape[-1] final_hidden_states = final_hidden_states.view( - self.ep_size, self.max_num_tokens_per_rank, hidden_size + self.ep_size, runtime_max_tokens_per_rank, hidden_size ) elif final_hidden_states.dim() == 3: # Already shaped: [ep_size, max_tokens_per_rank, hidden_size] @@ -326,15 +340,12 @@ def combine( raise ValueError( f"final_hidden_states must be 2D or 3D, got {final_hidden_states.dim()}D" ) - - # Call AllToAll combine output = torch.ops.trtllm.moe_a2a_combine( - topk_target_ranks, - topk_send_indices, - recv_counters, final_hidden_states, + int(local_num_tokens), self.workspace, - self.max_num_tokens_per_rank, + self.moe_a2a_metainfo, + int(runtime_max_tokens_per_rank), self.ep_rank, self.ep_size, self.top_k, @@ -349,7 +360,7 @@ def combine( return output def get_combine_payload_tensor_in_workspace( - self, hidden_size: int, dtype: torch.dtype + self, runtime_max_tokens_per_rank: int, hidden_size: int, dtype: torch.dtype ) -> torch.Tensor: """ Return the combine payload tensor in the workspace, which could be used @@ -357,6 +368,7 @@ def get_combine_payload_tensor_in_workspace( See "payload_in_workspace" in combine method. 
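Editor's note: combine() above accepts either a flattened or an already 3-D expert output and normalizes it with the `runtime_max_tokens_per_rank` recorded at dispatch time (rather than the static maximum). A toy shape check of that contract, assuming nothing beyond what the reshape itself requires:

import torch


def to_3d(final_hidden_states: torch.Tensor, ep_size: int,
          runtime_max_tokens_per_rank: int) -> torch.Tensor:
    """Normalize combine input to [ep_size, runtime_max_tokens_per_rank, hidden]."""
    if final_hidden_states.dim() == 2:
        hidden = final_hidden_states.shape[-1]
        return final_hidden_states.view(ep_size, runtime_max_tokens_per_rank, hidden)
    if final_hidden_states.dim() == 3:
        return final_hidden_states
    raise ValueError(f"expected 2D or 3D, got {final_hidden_states.dim()}D")


ep_size, runtime_max, hidden = 4, 16, 128
flat = torch.zeros(ep_size * runtime_max, hidden)
assert to_3d(flat, ep_size, runtime_max).shape == (ep_size, runtime_max, hidden)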
Args: + runtime_max_tokens_per_rank: Runtime max tokens per rank hidden_size: Hidden dimension size dtype: Data type @@ -372,12 +384,14 @@ def get_combine_payload_tensor_in_workspace( if combine_payload_offset is None: raise RuntimeError("combine_payload_offset not found in dispatch state") - return torch.ops.trtllm.moe_a2a_get_combine_payload_tensor( + result = torch.ops.trtllm.moe_a2a_get_combine_payload_tensor( self.workspace, int(self.ep_rank), int(self.ep_size), - int(self.max_num_tokens_per_rank), + int(runtime_max_tokens_per_rank), int(combine_payload_offset), dtype, int(hidden_size), ) + + return result diff --git a/tensorrt_llm/_torch/modules/fused_moe/communication/mnnvl_latency.py b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_two_sided.py similarity index 79% rename from tensorrt_llm/_torch/modules/fused_moe/communication/mnnvl_latency.py rename to tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_two_sided.py index 82a162beaf4..c38cf3391e3 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/communication/mnnvl_latency.py +++ b/tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_two_sided.py @@ -14,13 +14,11 @@ # limitations under the License. """ -MNNVL AllToAll Communication Strategy +NVLINK Two-Sided AllToAll Communication Strategy -This module implements the MNNVL AllToAll communication method for MoE. -MNNVL is an optimized communication strategy for NVIDIA GPU clusters. - -MNNVL supports post-quant dispatch for all quantization modes +This module implements the NVLINK two-sided comm AllToAll communication method for MoE. +NVLINK Two-Sided supports post-quant dispatch for all quantization modes. """ import os @@ -34,9 +32,14 @@ from .base import Communication -class MnnvlLatency(Communication): +class NVLinkTwoSided(Communication): """ - MNNVL AllToAll strategy for latency scenarios + NVLINK two-sided comm AllToAll strategy. + This implementation utilizes symmetric memory to enable peer-to-peer access between GPUs over NVLink. + The kernel takes the role as both sender and receiver: as the sender, it puts the data into a FIFO + quene in peer ranks' symmetric memory; as the receiver, it gets the data from the FIFO quene to the + local buffer. This communication model is akin to NCCL's collective operations. + The required symmetric memory size is proportional to the communication channels opened. """ def __init__( @@ -62,7 +65,7 @@ def __init__( os.environ.get("TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") == "1" ) - # Initialize MNNVL workspaces + # Initialize NVLINK workspaces MnnvlMemory.initialize() self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(mapping) self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(mapping) @@ -73,24 +76,24 @@ def __init__( @staticmethod def is_platform_supported() -> bool: """ - Check if MNNVL is supported on current hardware + Check if NVLINK two-sided comm is supported on current hardware. """ return MnnvlMemory.supports_mnnvl() def supports_post_quant_dispatch(self) -> bool: """ - MNNVL supports post-quant for all modes + NVLINK two-sided comm supports post-quant for all modes. """ return self.enable_postquant_alltoall def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) -> bool: """ - Check if MNNVL is feasible for the given workload at runtime. + Check if NVLINK two-sided comm is feasible for the given workload at runtime. This method performs runtime checks based on workload characteristics such as token counts, number of chunks, and other runtime parameters. 
""" - return self.is_platform_supported() + return True def prepare_dispatch( self, @@ -99,12 +102,12 @@ def prepare_dispatch( local_statistic_tensor: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: """ - MNNVL prepare dispatch: gather EPLB statistics and prepare alltoall_info + NVLINK two-sided comm prepare dispatch: gather EPLB statistics and prepare alltoall_info. """ all_rank_max_num_tokens = max(all_rank_num_tokens) top_k = token_selected_slots.shape[1] - # Call MNNVL prepare to get alltoall_info and gather EPLB statistics + # Call NVLINK prepare to get alltoall_info and gather EPLB statistics alltoall_info, gathered_local_statistic_tensor = ( MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( token_selected_slots, @@ -135,12 +138,14 @@ def dispatch( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: """ - MNNVL dispatch (post-quant, uses alltoall_info from prepare_dispatch) + NVLINK two-sided comm dispatch (post-quant, uses alltoall_info from prepare_dispatch). """ # Read alltoall_info from dispatch_state (set by prepare_dispatch) alltoall_info = self._dispatch_state.get("alltoall_info") if alltoall_info is None: - raise ValueError("MNNVL dispatch requires prepare_dispatch() to be called first") + raise ValueError( + "NVLinkTwoSided dispatch requires prepare_dispatch() to be called first" + ) all_rank_max_num_tokens = max(all_rank_num_tokens) original_token_count = hidden_states.shape[0] # Store for combine @@ -178,8 +183,7 @@ def combine( **kwargs, ) -> torch.Tensor: """ - MNNVL combine - reads from self._dispatch_state - + NVLINK two-sided comm combine - reads from self._dispatch_state. """ if isinstance(final_hidden_states, list): final_hidden_states = final_hidden_states[0] diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py new file mode 100644 index 00000000000..dfe1f091857 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -0,0 +1,1098 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ConfigurableMoE: Composition-based Configurable MoE Module + +This module provides a universal MoE execution flow using composition pattern: +- MoE Backend: Pluggable computation backend (Cutlass, TRTLLMGen, etc.) +- Communication Strategy: Pluggable communication (AllGather, AllToAll, etc.) +- EPLB: Optional load balancing (can be toggled on/off) + +Design Principles: +1. Use composition instead of inheritance for flexibility +2. Backend declares its capabilities (separated vs fused routing) +3. ConfigurableMoE adapts flow based on backend capabilities +4. 
Unified EPLB integration for backends that support it +""" + +from typing import Dict, List, Optional, Union + +import torch + +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.modules.fused_moe.interface import MoE +from tensorrt_llm._torch.modules.fused_moe.routing import BaseMoeRoutingMethod +from tensorrt_llm._torch.utils import AuxStreamType, EventType, Fp4QuantizedTensor +from tensorrt_llm.logger import logger + +from .communication import ( + AllGatherReduceScatter, + Communication, + CommunicationFactory, + DeepEP, + DeepEPLowLatency, + NVLinkOneSided, + NVLinkTwoSided, +) +from .fused_moe_cutlass import CutlassFusedMoE +from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE + + +class ConfigurableMoE(MoE): + """ + Configurable MoE layer using composition pattern with automatic configuration + + This class orchestrates the MoE execution flow by composing: + - moe_backend: Existing FusedMoE implementation (CutlassFusedMoE, WideEPMoE, etc.) + Note: Current FusedMoE implementations are used as backends (transitional). + Future will have dedicated MoEBackend interface. + - Communication: Handles distributed communication (auto-selected) + - EPLB (optional): Handles expert parallel load balancing (auto-detected) + + Args: + routing_method: Routing method for token-to-expert assignment + num_experts: Total number of experts + hidden_size: Hidden dimension size + intermediate_size: Intermediate dimension size + dtype: Data type for weight + reduce_results: Whether to reduce results + model_config: Model configuration + aux_stream_dict: Auxiliary CUDA streams for overlap + weight_loading_mode: Weight loading mode + layer_idx: Layer index + **kwargs: Additional arguments + - backend_type: Backend type ('cutlass', 'trtllm_gen_min_latency', etc.) 
+ Default: 'cutlass' + - tune_max_num_tokens: Max tokens for profiling (passed to backend) + - Other backend-specific arguments + + Key Attributes: + - backend: MoE computation backend (auto-created attribute) + - comm: Communication strategy (auto-created attribute, can be None) + - layer_load_balancer: EPLB instance (auto-detected, optional) + + Auto-Detection: + - EPLB: Enabled if get_moe_load_balancer() is not None + - Backend: Defaults to CutlassMoEBackend, override via backend_type + - Communication: Auto-selected based on hardware (NVLINK > DeepEP > AllGather) + """ + + def __init__( + self, + *, + routing_method: BaseMoeRoutingMethod, + num_experts: int, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + model_config: ModelConfig = ModelConfig(), + aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None, + weight_loading_mode=None, + apply_router_weight_on_input: bool = False, + layer_idx: Optional[int] = None, + **kwargs, + ): + super().__init__( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + weight_loading_mode=weight_loading_mode, + layer_idx=layer_idx, # ConfigurableMoE needs correct layer_idx for EPLB initialization + **kwargs, + ) + + # Store model_config and aux_stream_dict for later use (e.g., backend setter) + self.model_config = model_config + self.aux_stream_dict = aux_stream_dict + + # If True, the router weight will be multiplied on the input rather than at the end of FC2 + self.apply_router_weight_on_input = apply_router_weight_on_input + + # ========== Create MoE Backend (Default: Cutlass) ========== + from tensorrt_llm._torch.modules.fused_moe.create_moe import create_moe_backend, get_moe_cls + + # Get MoE backend class based on model_config + moe_cls = get_moe_cls(model_config, override_quant_config=None) + + # Call create_moe_backend with all necessary parameters + # init_load_balancer=False: Prevents backend from registering itself with load balancer + # without_comm=True: Prevents backend from initializing communication (ConfigurableMoE handles it) + # skip_create_weights_in_init=True: Prevents backend from creating weights in __init__ + # because backend uses layer_idx=None and may have different expert assignments + # We will create weights after syncing attributes from ConfigurableMoE + tmp_skip_create_weights_in_init = model_config.skip_create_weights_in_init + model_config._frozen = False + model_config.skip_create_weights_in_init = True + model_config._frozen = True + + self.backend = create_moe_backend( + moe_cls=moe_cls, + routing_method=routing_method, + num_experts=self.num_experts, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + dtype=self.dtype, + reduce_results=self.reduce_results, + model_config=model_config, + aux_stream_dict=self.aux_stream_dict, + weight_loading_mode=self.weight_loading_mode, + bias=kwargs.get("bias", False), + apply_router_weight_on_input=self.apply_router_weight_on_input, + layer_idx=None, + swiglu_alpha=kwargs.get("swiglu_alpha"), + swiglu_beta=kwargs.get("swiglu_beta"), + swiglu_limit=kwargs.get("swiglu_limit"), + init_load_balancer=False, + without_comm=True, + ) + + # Sync critical attributes from ConfigurableMoE to backend + # ConfigurableMoE's super().__init__() was called with real layer_idx and initialized load balancer. 
+ # Backend was created with init_load_balancer=False and without_comm=True to avoid + # duplicate initialization. Now sync all attributes from ConfigurableMoE to backend. + self.backend.aux_stream_dict = self.aux_stream_dict + self.backend.layer_idx = self.layer_idx + self.backend.layer_idx_str = self.layer_idx_str + self.backend.num_slots = self.num_slots + self.backend.layer_load_balancer = self.layer_load_balancer + self.backend.repeat_count = self.repeat_count + self.backend.repeat_idx = self.repeat_idx + self.backend.initial_local_expert_ids = self.initial_local_expert_ids + self.backend.initial_global_assignments = self.initial_global_assignments + self.backend.slot_start = self.slot_start + self.backend.slot_end = self.slot_end + self.backend.expert_size_per_partition = self.expert_size_per_partition + + # Create weights here, because the backend needs the layer_load_balancer info to create weights + model_config._frozen = False + model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init + model_config._frozen = True + if not model_config.skip_create_weights_in_init: + self.backend.create_weights() + + # ========== Create Communication Strategy ========== + self._comm = self._create_comm_strategy_auto() + + # ========== Chunking Configuration ========== + # moe_max_num_tokens is set in ModelConfig.__post_init__ if not specified + # The default value is max_num_tokens * dp_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens + default_moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + + # Auxiliary stream for chunking overlap + if self.moe_max_num_tokens < default_moe_max_num_tokens: + self.aux_stream = ( + aux_stream_dict[AuxStreamType.MoeChunkingOverlap] + if aux_stream_dict is not None + else torch.cuda.Stream() + ) + self.event_dict = { + key: torch.cuda.Event() for key in [EventType.Main, EventType.MoeChunkingOverlap] + } + else: + self.aux_stream = None + self.event_dict = None + + # Validate configuration + self.validate_config() + + # Mark as _weights_removed to skip ConfigurableMoE's post_load_weights in model_loader + # The backend's post_load_weights will be called directly by model_loader + # This avoids duplicate post_load_weights calls (once for ConfigurableMoE, once for backend) + # TODO: in the future, all the weights related work should be done only in backend. 
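Editor's note: the save/flip/restore dance around `model_config._frozen` and `skip_create_weights_in_init` (once before `create_moe_backend` and again before `create_weights`) could be factored into a small helper. A sketch of that idea, assuming only that the config exposes those two attributes; this is a possible tidy-up, not what the patch does:

from contextlib import contextmanager


@contextmanager
def temporarily_skip_weight_creation(model_config):
    """Flip skip_create_weights_in_init while the backend is constructed, then restore it."""
    saved = model_config.skip_create_weights_in_init
    model_config._frozen = False
    model_config.skip_create_weights_in_init = True
    model_config._frozen = True
    try:
        yield
    finally:
        model_config._frozen = False
        model_config.skip_create_weights_in_init = saved
        model_config._frozen = True


# Usage sketch:
#   with temporarily_skip_weight_creation(model_config):
#       self.backend = create_moe_backend(...)
#   # weights are then created explicitly once load-balancer attributes are synced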
+ self._weights_removed = True + + def _supports_load_balancer(self) -> bool: + """Check if this MoE implementation supports load balancer.""" + # During initialization, backend might not be created yet + # Return True by default (most backends support it), backend will validate later + if not hasattr(self, "backend") or self.backend is None: + return self.use_dp and self.parallel_size > 1 + return self.backend._supports_load_balancer() + + def validate_config(self): + """ + Validate configuration parameters + + Validates: + - apply_router_weight_on_input: Only supports top-1 routing + """ + if self.apply_router_weight_on_input: + assert self.routing_method.top_k == 1, ( + "apply_router_weight_on_input only supports top-1 routing" + ) + + def _create_comm_strategy(self, model_config: ModelConfig) -> Optional[Communication]: + """ + Create communication strategy based on configuration + + Default: None (will use factory to auto-select when needed) + Auto-selects best strategy based on hardware and configuration + + """ + # Communication strategy is None by default + # Will be created lazily in determine_communication_method() when first needed + # For now, return None and create on-demand + return None + + def _get_quant_config_dict(self, model_config: ModelConfig) -> Optional[Dict]: + """ + Extract quantization configuration from model_config + + """ + if model_config.quant_config is None: + return None + + quant_mode = model_config.quant_config.layer_quant_mode + return { + "has_fp8_qdq": quant_mode.has_fp8_qdq() + if hasattr(quant_mode, "has_fp8_qdq") + else False, + "has_nvfp4": quant_mode.has_nvfp4() if hasattr(quant_mode, "has_nvfp4") else False, + "has_w4afp8": quant_mode.is_int4_weight_only_per_group() + if hasattr(quant_mode, "is_int4_weight_only_per_group") + else False, + "has_fp8_block_scales": quant_mode.has_fp8_block_scales() + if hasattr(quant_mode, "has_fp8_block_scales") + else False, + } + + def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int: + """ + Calculate how many chunks are needed + + """ + num_rows = sum(all_rank_num_tokens) + return (num_rows + self.moe_max_num_tokens - 1) // self.moe_max_num_tokens + + def split_chunk(self, split_token_num: int, split_num_chunks: int) -> List[int]: + """ + Split token count into multiple chunks as evenly as possible + + """ + val_div = split_token_num // split_num_chunks + val_mod = split_token_num % split_num_chunks + split_chunk_size_list = [val_div + 1] * val_mod + [val_div] * (split_num_chunks - val_mod) + return split_chunk_size_list + + def determine_communication_method( + self, all_rank_num_tokens: List[int], num_chunks: int + ) -> None: + """ + Determine and setup communication method with automatic fallback + + This method: + 1. Returns early if comm is None or already AllGather (nothing to validate) + 2. Validates if current AllToAll strategy can be used for given workload + 3. Falls back to AllGather if current strategy cannot be used (logs info message) + + After calling this method, use _is_using_alltoall() to check which method is active. + + Args: + all_rank_num_tokens: Token counts per rank + num_chunks: Number of chunks + + Side effects: + - May switch self.comm to AllGather if current strategy cannot be used + + Note: This method does NOT create strategy if None (creation happens lazily elsewhere). + It only validates and potentially falls back existing AllToAll strategies. 
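Editor's note: `calculate_num_chunks` and `split_chunk` above are plain ceiling division and as-even-as-possible splitting; a worked example (token counts are illustrative):

def calculate_num_chunks(all_rank_num_tokens, moe_max_num_tokens):
    num_rows = sum(all_rank_num_tokens)
    return (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens


def split_chunk(split_token_num, split_num_chunks):
    val_div, val_mod = divmod(split_token_num, split_num_chunks)
    return [val_div + 1] * val_mod + [val_div] * (split_num_chunks - val_mod)


# 4 DP ranks contribute 3000 tokens each; with moe_max_num_tokens = 8192 the
# 12000 total tokens are processed in 2 chunks, each rank splitting its tokens evenly.
assert calculate_num_chunks([3000, 3000, 3000, 3000], 8192) == 2
assert split_chunk(3000, 2) == [1500, 1500]
assert split_chunk(7, 3) == [3, 2, 2]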
+ + """ + + # Early return if nothing to validate: + # - None: Atten is TP or single rank, no communication needed + # - AllGather: Already using fallback strategy, no validation needed + if self.comm is None or isinstance(self.comm, AllGatherReduceScatter): + return + + # Check if current strategy can be used + feasible_workload = self.comm.is_workload_feasible(all_rank_num_tokens, num_chunks) + + if not feasible_workload: + # Current comm cannot be used, fallback to AllGather + all_rank_max_num_tokens = max(all_rank_num_tokens) + logger.info( + f"Communication strategy {self.comm.__class__.__name__} " + f"cannot be used (num_chunks={num_chunks}, max_num_tokens={all_rank_max_num_tokens}). " + f"Falling back to AllGatherReduceScatter." + ) + + # Switch to AllGather (always works) + self.comm = AllGatherReduceScatter(mapping=self.mapping) + + def _is_using_alltoall(self) -> bool: + """ + Check if current communication strategy uses alltoall + + Returns: + True: Strategy uses alltoall (NVLINK, DeepEP, etc.) + False: Strategy uses allgather (AllGatherReduceScatter or None) + + Note: Can be called anytime. If comm is None, returns False (no alltoall). + Typically called after determine_communication_method() to get accurate result. + """ + if self.comm is None: + return False # No strategy means no alltoall + + # AllGather uses allgather, all others use alltoall + return not isinstance(self.comm, AllGatherReduceScatter) + + def _create_comm_strategy_auto(self) -> Communication: + """ + Auto-create the best communication strategy based on hardware and configuration + + Uses factory to select optimal strategy. + + """ + return CommunicationFactory.create_strategy( + model_config=self.model_config, + num_experts=self.num_experts, + num_slots=self.num_slots, + top_k=self.routing_method.experts_per_token, + expert_size_per_partition=self.expert_size_per_partition, + payload_in_workspace=False, # ConfigurableMoE does not use workspace output for now + # Currently the TRTLLMGEN reduce sum internally. + # Keep updated with more supported backends. + alltoall_result_do_sum=True, + ) + + def forward_impl( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + *, + do_finalize: bool = True, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + **kwargs, + ) -> torch.Tensor: + """ + Universal forward implementation framework + + Flow: + 1. Handle padding + 2. Calculate chunk count and determine communication method + 3. Execute MoE computation (single or multiple chunks) + 4. 
Handle output truncation and EPLB repeat + """ + # ========== Step 1: Handle padding ========== + if all_rank_num_tokens is None: + all_rank_num_tokens = [x.shape[0]] + + all_rank_max_num_tokens = max(all_rank_num_tokens) + + if use_dp_padding: + all_rank_num_tokens_padded = [all_rank_max_num_tokens] * len(all_rank_num_tokens) + else: + all_rank_num_tokens_padded = all_rank_num_tokens + + # ========== Step 2: Determine communication method ========== + num_chunks = self.calculate_num_chunks(all_rank_num_tokens_padded) + + # Determine and setup communication strategy (may fallback to AllGather) + self.determine_communication_method(all_rank_num_tokens_padded, num_chunks) + + # ========== Step 3: Execute MoE computation ========== + if num_chunks == 1: + # Single chunk case + outputs = self._forward_single_chunk( + x, + router_logits, + output_dtype, + all_rank_num_tokens_padded, + use_dp_padding, + do_finalize, + ) + else: + # Multiple chunks case + outputs = self._forward_multiple_chunks( + x, + router_logits, + num_chunks, + output_dtype, + all_rank_num_tokens_padded, + use_dp_padding, + do_finalize, + ) + + # ========== Step 4: Handle output truncation and EPLB repeat ========== + if self.use_dp and self.parallel_size > 1: + outputs = outputs[: all_rank_num_tokens[self.mapping.tp_rank]] + + # EPLB repeat logic + self.repeat_idx = (self.repeat_idx + 1) % self.repeat_count + + return outputs + + def _forward_single_chunk( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + output_dtype: Optional[torch.dtype], + all_rank_num_tokens: List[int], + use_dp_padding: Optional[bool], + do_finalize: bool = True, + ) -> torch.Tensor: + """ + Single chunk execution path + + """ + # Calculate EPLB flags (first call or last call) + is_first_call = self.repeat_idx == 0 + is_last_call = self.repeat_idx == self.repeat_count - 1 + + # Execute unified flow (handles both separated and fused routing) + outputs = self._forward_chunk_impl( + x, + router_logits, + output_dtype, + all_rank_num_tokens, + use_dp_padding, + is_first_call, + is_last_call, + do_finalize, + ) + + return outputs + + def _forward_chunk_impl( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + output_dtype: Optional[torch.dtype], + all_rank_num_tokens: List[int], + use_dp_padding: bool, + is_first_call: bool, + is_last_call: bool, + do_finalize: bool = True, + ) -> torch.Tensor: + """ + Unified execution flow for all backends + + Flow (based on EPLB_in_MOE[1].html): + 1. [EPLB] Start wait GPU stage (first call only, if enabled) + 2. Apply routing (only if backend supports routing separation) + 3. [EPLB] Update statistics and route (only if EPLB enabled) + 4. Quantization and Communication (adaptive ordering) + 5. MoE computation (backend) + 6. [EPLB] Start CPU stage (last call only, if enabled) + 7. Communication combine + 8. 
[EPLB] Done CPU stage (last call only, if enabled) + + - Separated routing: fused_moe_wide_ep.py:456-780, fused_moe_cutlass.py:236-443 + - Fused routing: fused_moe_trtllm_gen.py + """ + + # ========== Step 1: EPLB - Start wait GPU stage ========== + self._load_balancer_start_wait_gpu_stage(is_first_call) + + # ========== Step 2: Apply routing (only if backend supports load balancer) ========== + + if self.backend._supports_load_balancer(): + # Separated routing: ConfigurableMoE calls routing_method + token_selected_experts, token_final_scales = self.routing_method.apply(router_logits) + + # Convert to standard dtypes for consistency with other MoE implementations + token_selected_experts = token_selected_experts.to(torch.int32) + + assert token_selected_experts.shape[1] == self.routing_method.experts_per_token + assert token_selected_experts.shape == token_final_scales.shape + # CutlassFusedMoE expects float32, while TRTLLMGenFusedMoE uses bfloat16 + if isinstance(self.backend, CutlassFusedMoE): + assert token_final_scales.dtype == torch.float32 + assert token_selected_experts.dtype == torch.int32 + + # Convert token_final_scales to bfloat16 if needed (TRTLLMGen backend requires it) + if token_final_scales is not None and isinstance(self.backend, TRTLLMGenFusedMoE): + token_final_scales = token_final_scales.to(torch.bfloat16) + + # Apply router weight on input if enabled + if self.apply_router_weight_on_input: + assert x.dtype != torch.float8_e4m3fn, ( + "Current workaround for apply_router_weight_on_input does not support fp8 input" + ) + x = x * token_final_scales.to(x.dtype) + # TODO: remove this once we have correct fusedmoe kernel ready + # Check if using DeepEP strategies (they don't support token_final_scales=None) + if isinstance(self.comm, (DeepEP, DeepEPLowLatency)): + # DeepEP doesn't support token_final_scales is None + token_final_scales = torch.ones_like(token_final_scales) + else: + token_final_scales = None + + else: + # Fused routing: Backend handles routing internally + # EPLB must NOT be enabled for fused routing backends + assert not self._using_load_balancer(), ( + f"EPLB is enabled but backend {self.backend.__class__.__name__} " + f"has fused routing (does not support routing separation)" + ) + + # For fused routing, we don't have token_selected_experts yet + # Will be handled by backend.run_moe_with_routing() later + token_selected_experts = None + token_final_scales = None + + # ========== Step 3: EPLB - Update statistics and route ========== + # Only executed if backend supports routing separation AND EPLB is enabled + if self.layer_load_balancer and token_selected_experts is not None: + self._load_balancer_done_wait_gpu_stage(is_first_call) + + # Update EPLB statistics (method depends on whether using NVLINK two-sided) + # Use base class method: ignore_allreduce=True for NVLINK two-sided (uses local stats only) + ignore_allreduce = self._is_using_nvlink_two_sided() + self._load_balancer_update_statistic( + token_selected_experts, + is_first_call, + is_last_call, + ignore_allreduce=ignore_allreduce, + ) + + # EPLB routing: expert IDs -> slot IDs + token_selected_slots = self._load_balancer_route(token_selected_experts, self.use_dp) + else: + token_selected_slots = token_selected_experts + + # ========== Step 3.5: Communication Prepare Phase (BEFORE quantization) ========== + # NVLINK two-sided has a prepare phase to gather EPLB statistics + + # Only NVLINK two-sided needs prepare_dispatch + if self._is_using_nvlink_two_sided(): + # Get local statistic info if this 
is the last call and EPLB is enabled + local_statistic_tensor = None + if is_last_call: + local_statistic_tensor = self._load_balancer_get_local_statistic_tensor() + + # Call prepare_dispatch (gathers statistics for NVLINK two-sided) + # prepare_dispatch stores alltoall_info in _dispatch_state and returns gathered_stats + gathered_stats = self.comm.prepare_dispatch( + token_selected_slots, all_rank_num_tokens, local_statistic_tensor + ) + + # Update EPLB with gathered statistics (if available) + if gathered_stats is not None: + gathered_stats = gathered_stats.view((self.mapping.moe_ep_size, self.num_experts)) + self._load_balancer_update_statistic_with_gathered_statistic(gathered_stats) + + # ========== Step 4 & 5: Quantization and Communication Dispatch ========== + # Order depends on whether strategy supports post-quant dispatch + if self.comm is not None: + # Check if we should use post-quant dispatch + # supports_post_quant_dispatch checks strategy capability for the current quant mode + supports_post_quant = self.comm.supports_post_quant_dispatch() + + if supports_post_quant: + # ===== Post-quant flow: Quantize → Dispatch ===== + + # Step 4a: Quantization FIRST + x, x_sf = self.backend.quantize_input(x) + + # Step 4b: Dispatch AFTER quantization + # Get pre_quant_scale for W4AFP8 if available (only DeepEPLowLatency needs it) + # Other strategies will ignore this via **kwargs, so it's safe to pass unconditionally + dispatch_kwargs = {} + if hasattr(self, "quant_scales") and self.quant_scales is not None: + if hasattr(self.quant_scales, "pre_quant_scale_1"): + dispatch_kwargs["pre_quant_scale"] = self.quant_scales.pre_quant_scale_1 + + x, x_sf, token_selected_slots, token_final_scales = self.comm.dispatch( + hidden_states=x, + hidden_states_sf=x_sf, + token_selected_slots=token_selected_slots, + token_final_scales=token_final_scales, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + **dispatch_kwargs, + ) + else: + # ===== Pre-quant flow: Dispatch → Quantize ===== + + # Step 4a: Dispatch FIRST (unquantized data) + x, x_sf, token_selected_slots, token_final_scales = self.comm.dispatch( + hidden_states=x, + hidden_states_sf=None, # Not quantized yet + token_selected_slots=token_selected_slots, + token_final_scales=token_final_scales, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + ) + + # Step 4b: Quantization AFTER dispatch + x, x_sf = self.backend.quantize_input(x) + else: + # No communication, just quantize + # (use non-post-quant-comm path for TRTLLMGenFusedMoE) + x, x_sf = self.backend.quantize_input(x, post_quant_comm=False) + + # ========== Step 6: MoE Computation ========== + + # Call unified run_moe interface with common parameters + # If EPLB is enabled, token_selected_slots represents expert slots + # Otherwise, token_selected_experts represents expert IDs + final_hidden_states = self.backend.run_moe( + x=x, + token_selected_experts=token_selected_slots, + token_final_scales=token_final_scales, + x_sf=x_sf, + **self._get_backend_kwargs( + router_logits, do_finalize, all_rank_num_tokens, output_dtype + ), + ) + + # ========== Step 8: EPLB - Start CPU stage ========== + self._load_balancer_start_set_cpu_stage(is_last_call) + + # ========== Step 9: Communication - Combine ========== + if self.comm is not None: + # Use unified combine interface (reads dispatch state from strategy) + final_hidden_states = self.comm.combine(final_hidden_states) + else: + # For non-comm case, It should be attention TP or single rank. 
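(For reference, a minimal standalone sketch of the adaptive ordering in Steps 4 & 5 above — quantize-then-dispatch when the strategy supports post-quant dispatch, otherwise dispatch-then-quantize. The classes below are illustrative stand-ins, not the actual Communication/MoE interfaces from this change.)

from typing import Optional, Tuple

class _ToyBackend:
    """Stand-in for a MoE backend that can quantize its own input."""
    def quantize_input(self, x: list) -> Tuple[list, Optional[list]]:
        # Pretend quantization rounds values and emits one scale factor per token.
        return [round(v, 1) for v in x], [1.0] * len(x)

class _ToyComm:
    """Stand-in for a communication strategy."""
    def __init__(self, post_quant: bool):
        self._post_quant = post_quant
    def supports_post_quant_dispatch(self) -> bool:
        return self._post_quant
    def dispatch(self, x, x_sf):
        # A real strategy would alltoall/allgather here; the toy passes data through.
        return x, x_sf

def quantize_and_dispatch(x, comm: Optional[_ToyComm], backend: _ToyBackend):
    """Mirror of the ordering decision: quantize before or after dispatch."""
    if comm is None:
        return backend.quantize_input(x)      # no communication, just quantize
    if comm.supports_post_quant_dispatch():
        x, x_sf = backend.quantize_input(x)   # quantize first, then move less data
        return comm.dispatch(x, x_sf)
    x, _ = comm.dispatch(x, None)             # move raw activations first
    return backend.quantize_input(x)          # then quantize on the receiving side

print(quantize_and_dispatch([0.123, 4.567], _ToyComm(post_quant=True), _ToyBackend()))
print(quantize_and_dispatch([0.123, 4.567], _ToyComm(post_quant=False), _ToyBackend()))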
+ # only check if allreduce is needed + if self.parallel_size > 1 and self.reduce_results: + final_hidden_states = self.all_reduce(final_hidden_states) + # ========== Step 10: EPLB - Done CPU stage ========== + self._load_balancer_done_set_cpu_stage(is_last_call) + + return final_hidden_states + + def _forward_multiple_chunks( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + num_chunks: int, + output_dtype: Optional[torch.dtype], + all_rank_num_tokens: List[int], + use_dp_padding: Optional[bool], + do_finalize: bool = True, + ) -> torch.Tensor: + """ + Multiple chunks execution path with auxiliary stream for overlapping + + Same as original implementation - chunking logic is backend-agnostic + + Note: use_all_to_all is determined internally via _is_using_alltoall() + + """ + # Determine if using alltoall + use_all_to_all = self._is_using_alltoall() + # ========== Chunk preparation ========== + if self.use_dp: + # When using DP: need all ranks' token counts for reducescatter + all_rank_chunk_size_list = [ + self.split_chunk(val, num_chunks) for val in all_rank_num_tokens + ] + all_rank_num_tokens_list = [ + [val[idx_chunk] for val in all_rank_chunk_size_list] + for idx_chunk in range(num_chunks) + ] + chunk_size_list = all_rank_chunk_size_list[self.rank] + + # For alltoall, replace 0 with 1 (avoid empty tensor) + if use_all_to_all: + all_rank_num_tokens_list = [ + [1 if val == 0 else val for val in val_list] + for val_list in all_rank_num_tokens_list + ] + else: + # When not using DP: only need current rank's input size + all_rank_num_tokens_list = [None] * num_chunks + chunk_size_list = self.split_chunk(x.shape[0], num_chunks) + + x_list = x.split(chunk_size_list) + router_logits_list = router_logits.split(chunk_size_list) + + # ========== Setup auxiliary stream ========== + if not use_all_to_all and self.aux_stream is not None: + self.event_dict[EventType.Main].record() + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.Main].wait() + + # ========== Execute chunking with overlap ========== + outputs_list = [] + for idx_chunk, (x_chunk, router_logits_chunk) in enumerate(zip(x_list, router_logits_list)): + # Calculate EPLB's first/last call + is_first_call = idx_chunk == 0 and self.repeat_idx == 0 + is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1 + + if not use_all_to_all and self.aux_stream is not None: + # Alternate between main stream and auxiliary stream + # Each stream processes complete chunks (forward + reducescatter) + if idx_chunk % 2 == 0: + # Even chunk: execute on auxiliary stream + with torch.cuda.stream(self.aux_stream): + outputs = self._forward_chunk_impl( + x_chunk, + router_logits_chunk, + output_dtype, + all_rank_num_tokens_list[idx_chunk], + use_dp_padding, + is_first_call, + is_last_call, + do_finalize, + ) + else: + # Odd chunk: execute on main stream + outputs = self._forward_chunk_impl( + x_chunk, + router_logits_chunk, + output_dtype, + all_rank_num_tokens_list[idx_chunk], + use_dp_padding, + is_first_call, + is_last_call, + do_finalize, + ) + else: + # No overlap + outputs = self._forward_chunk_impl( + x_chunk, + router_logits_chunk, + output_dtype, + all_rank_num_tokens_list[idx_chunk], + use_dp_padding, + is_first_call, + is_last_call, + do_finalize, + ) + + outputs_list.append(outputs) + + # ========== Wait for auxiliary stream to complete ========== + if not use_all_to_all and self.aux_stream is not None: + # Wait for auxiliary stream to complete all its chunks + 
with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.MoeChunkingOverlap].record() + self.event_dict[EventType.MoeChunkingOverlap].wait() + + # ========== Concatenate outputs from all chunks ========== + outputs = torch.cat(outputs_list) + + return outputs + + # ========== Backend Property with Validation ========== + + @property + def backend(self) -> MoE: + """ + Get the current MoE backend implementation + + Note: Returns a FusedMoE instance (e.g., CutlassFusedMoE, WideEPMoE) + """ + return self._backend + + @backend.setter + def backend(self, backend: MoE): + """ + Set MoE backend with validation + + This setter validates that: + 1. Backend is not None + 2. If EPLB is enabled, backend must support routing separation + + Args: + backend: MoEBackend instance to set + + Raises: + ValueError: If backend is incompatible with current configuration + + Note: EPLB initialization is done in __init__, not in setter. + Setter only validates compatibility. + """ + if backend is None: + raise ValueError("Backend cannot be None") + + # Validate EPLB compatibility + if self._using_load_balancer() and not backend._supports_load_balancer(): + raise ValueError( + f"EPLB is enabled but backend {backend.__class__.__name__} " + f"does not support load balancer. " + f"Either disable EPLB or use a backend that supports load balancer." + ) + + # Set backend (validation passed) + self._backend = backend + + @property + def comm(self) -> Optional[Communication]: + """Get the current communication strategy""" + return self._comm + + @comm.setter + def comm(self, strategy: Optional[Communication]): + """ + Set communication strategy with validation + + This setter validates that the strategy is compatible with the configuration. + + Args: + strategy: Communication instance to set (can be None for lazy creation) + + Raises: + ValueError: If strategy is incompatible with current configuration + + Note: Unlike backend, comm can be None (will be created lazily). + This allows for automatic strategy selection based on hardware. + """ + # comm can be None (lazy creation) + if strategy is None: + self._comm = None + return + + # Set strategy (validation passed) + self._comm = strategy + + # ========== Helper Methods ========== + + def _is_using_nvlink_two_sided(self) -> bool: + """Check if using NVLinkTwoSided communication strategy""" + return isinstance(self.comm, NVLinkTwoSided) + + def _get_backend_kwargs( + self, + router_logits: Optional[torch.Tensor] = None, + do_finalize: bool = True, + all_rank_num_tokens: Optional[List[int]] = None, + output_dtype: Optional[torch.dtype] = None, + ) -> Dict: + """ + Get backend-specific keyword arguments for run_moe + + Returns backend-specific parameters that are not part of the common run_moe interface. + Different backends need different parameters - this method provides them via kwargs. + + TODO: This is not finalized, will be updated later. 
+ Common kwargs (multiple backends): + - cluster_size, cluster_rank: Cutlass, DeepGemm + - min_latency_mode: Cutlass, WideEP, DeepGemm + - use_fused_finalize: Cutlass, WideEP + - tuner_num_tokens, tuner_top_k: Cutlass, WideEP + + Backend-specific kwargs: + - Cutlass: swizzled_input_sf, enable_alltoall, output_tensor + - WideEP: swizzled_input_sf (fixed False), use_all_to_all + - DeepGemm: workspace, permutation tensors + - TRTLLMGen: router_logits, do_finalize, moe_output + + Args: + router_logits: Router logits tensor (for TRTLLMGen backend) + do_finalize: Whether to finalize output (for TRTLLMGen backend) + all_rank_num_tokens: Token counts per rank (for TRTLLMGen backend moe_output) + + Returns: + Dict: Backend-specific keyword arguments + """ + backend_name = self.backend.__class__.__name__ + kwargs = {} + + # Common parameters for Cutlass and DeepGemm + if backend_name in ["CutlassFusedMoE", "DeepGemmFusedMoE"]: + pass + + # Cutlass-specific parameters + if backend_name == "CutlassFusedMoE": + pass + + # WideEP-specific parameters + elif backend_name == "WideEPMoE": + pass + + # DeepGemm-specific parameters + elif backend_name == "DeepGemmFusedMoE": + pass + + # TRTLLMGen-specific parameters + elif backend_name == "TRTLLMGenFusedMoE": + # Determine router_logits based on whether routing has been done + # If backend doesn't support load balancer, routing is done before communication + # In that case, router_logits should be None (routing already done) + router_logits_arg = None + if not self.backend._supports_load_balancer(): + # For fused routing backends, router_logits is only needed if routing hasn't been done yet + router_logits_arg = router_logits + + kwargs["router_logits"] = router_logits_arg + kwargs["do_finalize"] = do_finalize + + # moe_output: workspace output buffer for NVLINK one-sided backend + # TRTLLMGenFusedMoE only supports workspace output for w4a8_mxfp4_mxfp8 quantization. 
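# A minimal, self-contained sketch (not the real _get_backend_kwargs) of how the
# per-backend kwargs listed in the docstring above could be assembled. The backend
# class names mirror the ones referenced in this file; the values are illustrative only.
def _sketch_backend_kwargs(backend_name: str,
                           router_logits=None,
                           routing_already_done: bool = False,
                           do_finalize: bool = True) -> dict:
    kwargs: dict = {}
    if backend_name == "TRTLLMGenFusedMoE":
        # Fused-routing backend: pass router_logits only if routing has not been done upfront.
        kwargs["router_logits"] = None if routing_already_done else router_logits
        kwargs["do_finalize"] = do_finalize
    elif backend_name in ("CutlassFusedMoE", "WideEPMoE", "DeepGemmFusedMoE"):
        # Separated-routing backends take no extra kwargs in this sketch.
        pass
    return kwargs

# Example: _sketch_backend_kwargs("TRTLLMGenFusedMoE", routing_already_done=True)
#          -> {"router_logits": None, "do_finalize": True}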
+ moe_output = None + if isinstance(self.comm, NVLinkOneSided): + # Determine dtype for workspace tensor + # TRTLLMGenFusedMoE always uses bfloat16, other backends use output_dtype + workspace_dtype = output_dtype + if isinstance(self.backend, TRTLLMGenFusedMoE): + self.comm.invalid_token_expert_id = -1 + workspace_dtype = torch.bfloat16 + + # Check if backend supports workspace output for current quantization + backend_supports_workspace = ( + isinstance(self.backend, TRTLLMGenFusedMoE) + and self.backend.has_w4a8_mxfp4_mxfp8 + ) + if backend_supports_workspace: + assert all_rank_num_tokens is not None, ( + "all_rank_num_tokens must be provided for NVLinkOneSided backend with workspace output" + ) + runtime_max_tokens_per_rank = max(all_rank_num_tokens) + + moe_output = self.comm.get_combine_payload_tensor_in_workspace( + runtime_max_tokens_per_rank, self.hidden_size, workspace_dtype + ) + # Dynamically enable payload_in_workspace for this forward pass + self.comm.payload_in_workspace = True + else: + # Ensure payload_in_workspace is False for non-workspace output + self.comm.payload_in_workspace = False + kwargs["moe_output"] = moe_output + + return kwargs + + def create_weights(self): + """ + Create weights - delegated to backend + + """ + assert hasattr(self.backend, "create_weights"), ( + f"Backend {self.backend.__class__.__name__} must implement create_weights()" + ) + return self.backend.create_weights() + + def load_weights(self, weights: List[Dict]): + """ + Load weights - delegated to backend + + """ + assert hasattr(self.backend, "load_weights"), ( + f"Backend {self.backend.__class__.__name__} must implement load_weights()" + ) + return self.backend.load_weights(weights) + + def post_load_weights(self): + """ + Post load weights processing - delegated to backend + + """ + assert hasattr(self.backend, "post_load_weights"), ( + f"Backend {self.backend.__class__.__name__} must implement post_load_weights()" + ) + return self.backend.post_load_weights() + + # ========== Communication and Quantization Properties ========== + + @property + def enable_alltoall(self): + """ + Check if alltoall is enabled + + This delegates to the communication strategy to determine if alltoall is available. + + """ + if self.comm is None: + return False + # Simplified check - AllGather strategy means no alltoall + return not isinstance(self.comm, AllGatherReduceScatter) + + @property + def _weights_created(self): + """Check if weights have been created (required for quantization properties)""" + assert hasattr(self.backend, "_weights_created"), ( + f"Backend {self.backend.__class__.__name__} must have _weights_created attribute" + ) + return self.backend._weights_created + + # ========== Explicit Backend Attribute Proxies ========== + # These properties delegate to backend for commonly accessed attributes + # TODO: Unify the property access to backend in ConfigurableMoE. + # At the same time, we need to keep the existing test cases working. 
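As a general illustration of the proxy pattern used by the properties that follow (a standalone sketch under assumed names — Wrapper/Backend are not classes from this codebase):

class Backend:
    def __init__(self):
        self.quant_method = "fp8_block_scales"
        self.w2_weight = [[1.0, 2.0]]

class Wrapper:
    """Delegates a fixed, documented set of attributes to the wrapped backend."""
    def __init__(self, backend: Backend):
        self._backend = backend

    @property
    def quant_method(self):
        # getattr with a default keeps the proxy safe if the backend lacks the attribute.
        return getattr(self._backend, "quant_method", None)

    @property
    def w2_weight(self):
        return getattr(self._backend, "w2_weight", None)

w = Wrapper(Backend())
assert w.quant_method == "fp8_block_scales"

Explicit read-only properties keep the wrapper's public surface discoverable and easy to type-check, at the cost of listing each delegated attribute, whereas a blanket __getattr__ would forward everything implicitly.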
+ + @property + def quant_method(self): + """Delegate quant_method to backend""" + return getattr(self.backend, "quant_method", None) + + @property + def w3_w1_weight(self): + """Delegate w3_w1_weight to backend""" + return getattr(self.backend, "w3_w1_weight", None) + + @property + def w2_weight(self): + """Delegate w2_weight to backend""" + return getattr(self.backend, "w2_weight", None) + + @property + def has_nvfp4(self): + """Delegate has_nvfp4 to backend""" + return getattr(self.backend, "has_nvfp4", False) + + def forward_fake( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + *, + do_finalize: bool = True, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + **kwargs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Fake forward for shape inference during torch.compile + + Delegates to backend's forward_fake if available, otherwise calls parent's forward_fake + + Args: + x: Input tensor + router_logits: Router logits for expert selection + do_finalize: Whether to finalize MoE output + output_dtype: Output data type + all_rank_num_tokens: Token counts per rank + use_dp_padding: Whether to use data parallel padding + **kwargs: Additional arguments + + Returns: + Empty tensor(s) with correct shape for torch.compile + """ + if hasattr(self.backend, "forward_fake"): + # Backend has forward_fake, delegate to it + return self.backend.forward_fake( + x, + router_logits, + do_finalize=do_finalize, + output_dtype=output_dtype, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + **kwargs, + ) + else: + # Backend doesn't have forward_fake, use parent's implementation + return super().forward_fake( + x, + router_logits, + do_finalize=do_finalize, + output_dtype=output_dtype, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + **kwargs, + ) diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 6a5790bbfca..f8ce1a0c310 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -1,3 +1,4 @@ +import os from typing import Dict, Optional, Type import torch @@ -7,6 +8,7 @@ from ...model_config import ModelConfig from ...utils import AuxStreamType +from .configurable_moe import ConfigurableMoE from .fused_moe_cute_dsl import CuteDslFusedMoE from .fused_moe_cutlass import CutlassFusedMoE from .fused_moe_deepgemm import DeepGemmFusedMoE @@ -57,15 +59,17 @@ def get_moe_cls( raise ValueError(f"Unsupported moe backend: {moe_backend}") -def create_moe( +def create_moe_backend( + moe_cls: Type[MoE], routing_method: BaseMoeRoutingMethod, - num_experts: int, - hidden_size: int, - intermediate_size: int, + # TODO: remove num_experts, hidden_size, intermediate_size, dtype parameters + # these parameters will be inferred from model_config.pretrained_config. 
+ num_experts: Optional[int] = None, + hidden_size: Optional[int] = None, + intermediate_size: Optional[int] = None, dtype: Optional[torch.dtype] = None, reduce_results: bool = False, model_config: ModelConfig = ModelConfig(), - override_quant_config: Optional[QuantConfig] = None, aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None, weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA, bias: bool = False, @@ -74,8 +78,51 @@ def create_moe( swiglu_alpha: Optional[torch.Tensor] = None, swiglu_beta: Optional[torch.Tensor] = None, swiglu_limit: Optional[torch.Tensor] = None, + init_load_balancer: bool = True, + without_comm: bool = False, ) -> MoE: - moe_cls = get_moe_cls(model_config, override_quant_config) + """ + Create MoE backend instance with validation. + + Args: + moe_cls: MoE backend class to instantiate + routing_method: Routing method for token-to-expert assignment + num_experts: Total number of experts (if None, get from model_config.pretrained_config) + hidden_size: Hidden dimension size (if None, get from model_config.pretrained_config) + intermediate_size: Intermediate dimension size (if None, get from model_config.pretrained_config) + dtype: Data type for weights (if None, get from model_config.pretrained_config) + reduce_results: Whether to reduce results + model_config: Model configuration + aux_stream_dict: Auxiliary CUDA streams for overlap + weight_loading_mode: Weight loading mode + bias: Whether to use bias + apply_router_weight_on_input: Whether to apply router weight on input + layer_idx: Layer index + swiglu_alpha: SwiGLU alpha parameter + swiglu_beta: SwiGLU beta parameter + swiglu_limit: SwiGLU limit parameter + + Returns: + MoE: MoE backend instance + """ + # Get parameters from pretrained_config if not explicitly provided + pretrained_config = model_config.pretrained_config + if num_experts is None: + assert pretrained_config is not None, "num_experts must be provided or model_config.pretrained_config must be set" + num_experts = pretrained_config.num_experts + if hidden_size is None: + assert pretrained_config is not None, "hidden_size must be provided or model_config.pretrained_config must be set" + hidden_size = pretrained_config.hidden_size + if intermediate_size is None: + assert pretrained_config is not None, "intermediate_size must be provided or model_config.pretrained_config must be set" + # For MoE models, prefer moe_intermediate_size if available + if hasattr(pretrained_config, 'moe_intermediate_size'): + intermediate_size = pretrained_config.moe_intermediate_size + else: + intermediate_size = pretrained_config.intermediate_size + if dtype is None and pretrained_config is not None and hasattr( + pretrained_config, 'torch_dtype'): + dtype = pretrained_config.torch_dtype moe_load_balancer = get_moe_load_balancer() if moe_load_balancer is not None: @@ -115,6 +162,8 @@ def create_moe( swiglu_alpha=swiglu_alpha, swiglu_beta=swiglu_beta, swiglu_limit=swiglu_limit, + init_load_balancer=init_load_balancer, + without_comm=without_comm, ) elif moe_cls == CutlassFusedMoE: return moe_cls( @@ -133,6 +182,7 @@ def create_moe( swiglu_alpha=swiglu_alpha, swiglu_beta=swiglu_beta, swiglu_limit=swiglu_limit, + init_load_balancer=init_load_balancer, ) elif moe_cls == WideEPMoE: return moe_cls( @@ -147,6 +197,7 @@ def create_moe( weight_loading_mode=weight_loading_mode, apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, + init_load_balancer=init_load_balancer, ) elif moe_cls == VanillaMoE: assert not 
apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE." @@ -210,3 +261,127 @@ def create_moe( ) else: raise ValueError(f"Unsupported moe backend: {moe_cls}") + + +def create_moe( + routing_method: BaseMoeRoutingMethod, + num_experts: Optional[int] = None, + hidden_size: Optional[int] = None, + intermediate_size: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + model_config: ModelConfig = ModelConfig(), + override_quant_config: Optional[QuantConfig] = None, + aux_stream_dict: Optional[Dict[AuxStreamType, torch.cuda.Stream]] = None, + weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA, + bias: bool = False, + apply_router_weight_on_input: bool = False, + layer_idx: Optional[int] = None, + swiglu_alpha: Optional[torch.Tensor] = None, + swiglu_beta: Optional[torch.Tensor] = None, + swiglu_limit: Optional[torch.Tensor] = None, +) -> MoE: + """ + Create MoE instance with automatic parameter inference from model_config. + + Args: + routing_method: Routing method for token-to-expert assignment + num_experts: Total number of experts (if None, get from model_config.pretrained_config) + hidden_size: Hidden dimension size (if None, get from model_config.pretrained_config) + intermediate_size: Intermediate dimension size (if None, get from model_config.pretrained_config) + dtype: Data type for weights (if None, get from model_config.pretrained_config) + reduce_results: Whether to reduce results + model_config: Model configuration + override_quant_config: Override quantization config + aux_stream_dict: Auxiliary CUDA streams for overlap + weight_loading_mode: Weight loading mode + bias: Whether to use bias + apply_router_weight_on_input: Whether to apply router weight on input + layer_idx: Layer index + swiglu_alpha: SwiGLU alpha parameter + swiglu_beta: SwiGLU beta parameter + swiglu_limit: SwiGLU limit parameter + + Returns: + MoE: MoE instance + """ + # Get parameters from pretrained_config if not explicitly provided + pretrained_config = model_config.pretrained_config + if num_experts is None: + assert pretrained_config is not None, "num_experts must be provided or model_config.pretrained_config must be set" + num_experts = pretrained_config.num_experts + if hidden_size is None: + assert pretrained_config is not None, "hidden_size must be provided or model_config.pretrained_config must be set" + hidden_size = pretrained_config.hidden_size + if intermediate_size is None: + assert pretrained_config is not None, "intermediate_size must be provided or model_config.pretrained_config must be set" + # For MoE models, prefer moe_intermediate_size if available + if hasattr(pretrained_config, 'moe_intermediate_size'): + intermediate_size = pretrained_config.moe_intermediate_size + else: + intermediate_size = pretrained_config.intermediate_size + if dtype is None and pretrained_config is not None and hasattr( + pretrained_config, 'torch_dtype'): + dtype = pretrained_config.torch_dtype + + moe_cls = get_moe_cls(model_config, override_quant_config) + + # Check if ENABLE_CONFIGURABLE_MOE environment variable is set + enable_configurable_moe = os.environ.get('ENABLE_CONFIGURABLE_MOE', + '0') == '1' + + if enable_configurable_moe: + # ConfigurableMoE is only supported for TRTLLMGenFusedMoE backend + if moe_cls == TRTLLMGenFusedMoE: + return ConfigurableMoE( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + 
reduce_results=reduce_results, + model_config=model_config, + aux_stream_dict=aux_stream_dict, + weight_loading_mode=weight_loading_mode, + apply_router_weight_on_input=apply_router_weight_on_input, + layer_idx=layer_idx, + bias=bias, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit, + ) + else: + # Check if this is a TRTLLM backend request that fallback to CutlassFusedMoE + requested_backend = model_config.moe_backend.upper() + if requested_backend == "TRTLLM" and moe_cls == CutlassFusedMoE: + # Workaround for test cases where TRTLLM backend fallbacks to CutlassFusedMoE due to quant_config incompatibility + # Log warning and continue with the fallback backend + logger.warning( + f"ENABLE_CONFIGURABLE_MOE is set but TRTLLM backend fallback to {moe_cls.__name__} due to quant_config. " + f"ConfigurableMoE only supports TRTLLMGenFusedMoE backend. " + f"Continuing with legacy MoE backend {moe_cls.__name__}.") + else: + # For other incompatible backends, raise error + raise ValueError( + f"ENABLE_CONFIGURABLE_MOE is set but backend {moe_cls.__name__} is not supported. " + f"ConfigurableMoE only supports TRTLLMGenFusedMoE backend.") + + # Use legacy create_moe_backend for other backends or when ConfigurableMoE is disabled + return create_moe_backend( + moe_cls=moe_cls, + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + aux_stream_dict=aux_stream_dict, + weight_loading_mode=weight_loading_mode, + bias=bias, + apply_router_weight_on_input=apply_router_weight_on_input, + layer_idx=layer_idx, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit, + ) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index f8e83d2bbe4..8b82a9ea6c9 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -1,6 +1,6 @@ import os from functools import cached_property -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch @@ -74,6 +74,7 @@ def __init__( swiglu_alpha: Optional[torch.Tensor] = None, swiglu_beta: Optional[torch.Tensor] = None, swiglu_limit: Optional[torch.Tensor] = None, + init_load_balancer: bool = True, ): super().__init__( @@ -90,6 +91,7 @@ def __init__( swiglu_beta=swiglu_beta, swiglu_limit=swiglu_limit, layer_idx=layer_idx, + init_load_balancer=init_load_balancer, ) # Store original hidden size before any potential padding @@ -105,11 +107,12 @@ def __init__( # slot_start, slot_end, initial_local_expert_ids are all initialized by # base class's _init_load_balancer() method - # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled - moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size - self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens + # moe_max_num_tokens is set in ModelConfig.__post_init__ if not specified + # The default value is max_num_tokens * dp_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied - if self.moe_max_num_tokens < moe_max_num_tokens: + default_moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + if self.moe_max_num_tokens < 
default_moe_max_num_tokens: self.aux_stream = aux_stream_dict[ AuxStreamType. MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream( @@ -243,6 +246,66 @@ def enable_alltoall(self): """ return self.alltoall_method_type != AlltoallMethodType.NotEnabled + def quantize_input( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize input tensor - CutlassFusedMoE implementation + + Handles all quantization cases for Cutlass backend. + """ + # Determine if this is post-quant communication scenario + run_post_quant_allgather = self.use_dp and self.parallel_size > 1 + + x_sf = None + if self.has_any_quant: + if self.has_fp8_qdq or self.has_w4a8_mxfp4_fp8: + x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + x, self.fc31_input_dequant) + elif self.has_deepseek_fp8_block_scales: + # No quantization needed here, handled in kernel + pass + elif self.has_w4afp8: + # No quantization needed here, handled in kernel + pass + elif self.has_w4a16_mxfp4: + pad_size = self.hidden_size - x.shape[1] + x = torch.nn.functional.pad(x, (0, pad_size)) + elif self.has_int8_woq_per_channel: + # No quantization needed here, handled in kernel + pass + elif self.has_nvfp4: + if run_post_quant_allgather or self.enable_alltoall: + if isinstance(x, Fp4QuantizedTensor): + assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication" + x, x_sf = x.fp4_tensor, x.scaling_factor + else: + x, x_sf = torch.ops.trtllm.fp4_quantize( + x, self.fc31_input_scale, self.scaling_vector_size, + False, False) + # Reshape x_sf to 2D + x_sf = x_sf.view((x.shape[0], -1)) + else: + if not isinstance(x, Fp4QuantizedTensor): + x, x_sf = torch.ops.trtllm.fp4_quantize( + x, self.fc31_input_scale, self.scaling_vector_size, + False, True) + elif self.has_w4a8_mxfp4_mxfp8: + if run_post_quant_allgather or self.enable_alltoall: + x, x_sf = torch.ops.trtllm.mxfp8_quantize( + x, False, alignment=self.quant_method.weight_alignment) + else: + x, x_sf = torch.ops.trtllm.mxfp8_quantize( + x, True, alignment=self.quant_method.weight_alignment) + else: + raise ValueError( + f"unsupported quantization mode: {self.quant_config.quant_mode}" + ) + + return x, x_sf + def _supports_load_balancer(self) -> bool: """CutlassFusedMoE supports load balancer.""" return True diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 062176a8fbc..292eed4c9ea 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -381,18 +381,19 @@ def __init__( apply_router_weight_on_input: bool = False, layer_idx: Optional[int] = None, ): - if model_config.moe_max_num_tokens is None: - moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size - # The default moe_max_num_tokens is calculated from the following formula: - # max_isl = 8196, max_batch_size = 1024, mtp = 0 - # max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344 - # moe_max_num_tokens = max_num_tokens * 2 = 18688 - # It can avoid OOM for 8k/1k cases. 
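As a quick check of the arithmetic in the cap formula above (a standalone sketch assuming dp_size = 2, which matches the factor of 2 in the comment; note that the stated 9344 is reproduced with max_isl = 8192, while the 8196 written in the comment would give 9408, so the 8196 looks like a typo):

def default_moe_max_num_tokens(max_batch_size: int, max_isl: int, mtp: int, dp_size: int) -> int:
    # Round the per-iteration token budget up to a multiple of 64, then scale by DP size.
    max_num_tokens = ((mtp + 1) * max_batch_size + max_isl + 128 + 63) // 64 * 64
    return max_num_tokens * dp_size

assert ((0 + 1) * 1024 + 8192 + 128 + 63) // 64 * 64 == 9344
assert default_moe_max_num_tokens(1024, 8192, 0, dp_size=2) == 18688  # matches the 18688 cap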
- default_moe_max_num_tokens = 18688 - if moe_max_num_tokens > default_moe_max_num_tokens: - model_config._frozen = False - model_config.moe_max_num_tokens = default_moe_max_num_tokens - model_config._frozen = True + # moe_max_num_tokens is set in ModelConfig.__post_init__ if not specified + # The default value is max_num_tokens * dp_size + # For DeepGemm, we need to limit moe_max_num_tokens to avoid OOM + # The default moe_max_num_tokens is calculated from the following formula: + # max_isl = 8196, max_batch_size = 1024, mtp = 0 + # max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344 + # moe_max_num_tokens = max_num_tokens * 2 = 18688 + # It can avoid OOM for 8k/1k cases. + default_moe_max_num_tokens = 18688 + if model_config.moe_max_num_tokens > default_moe_max_num_tokens: + model_config._frozen = False + model_config.moe_max_num_tokens = default_moe_max_num_tokens + model_config._frozen = True super().__init__( routing_method=routing_method, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index be52f2b6edd..bcc44479267 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -15,7 +15,7 @@ from ...distributed import allgather from ...expert_statistic import ExpertStatistic from ...model_config import ModelConfig -from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div +from ...utils import AuxStreamType, Fp4QuantizedTensor from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode # isort: off @@ -79,6 +79,8 @@ def __init__( swiglu_alpha: Optional[torch.Tensor] = None, swiglu_beta: Optional[torch.Tensor] = None, swiglu_limit: Optional[torch.Tensor] = None, + init_load_balancer: bool = True, + without_comm: bool = False, ): super().__init__( routing_method=routing_method, @@ -95,6 +97,7 @@ def __init__( swiglu_beta=swiglu_beta, swiglu_limit=swiglu_limit, layer_idx=layer_idx, + init_load_balancer=init_load_balancer, ) sm_version = get_sm_version() @@ -111,40 +114,50 @@ def __init__( # - self.initial_global_assignments, self.slot_start, self.slot_end, etc. # TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future. 
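A plausible usage pattern (not taken verbatim from this change) for the two new constructor flags: a wrapper that owns communication, such as ConfigurableMoE, would construct the backend with without_comm=True so the backend skips its own alltoall setup. The toy class below only mirrors that branching; it is not TRTLLMGenFusedMoE:

from dataclasses import dataclass, field

@dataclass
class ToyGenBackend:
    without_comm: bool = False
    alltoall_method_type: str = field(init=False)

    def __post_init__(self):
        if self.without_comm:
            # The wrapper owns dispatch/combine; keep only inert defaults locally.
            self.alltoall_method_type = "NotEnabled"
            self.alltoall_workspace = None
            self.moe_a2a = None
        else:
            # Standalone use: the backend selects and initializes its own alltoall path.
            self.alltoall_method_type = self._select_alltoall_method_type()

    def _select_alltoall_method_type(self) -> str:
        return "NVLinkOneSided"  # placeholder for the real hardware-dependent selection

assert ToyGenBackend(without_comm=True).alltoall_method_type == "NotEnabled"
assert ToyGenBackend(without_comm=False).alltoall_method_type == "NVLinkOneSided"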
- self.alltoall_method_type = self.select_alltoall_method_type() - logger.info_once( - f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}", - key="alltoall_method_type") - self.alltoall_workspace = None - self.alltoall_prepare_workspace = None - self.use_low_precision_combine = False - if self.enable_alltoall: - self.use_low_precision_combine = model_config.use_low_precision_moe_combine - - if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided: - MnnvlMemory.initialize() - self.alltoall_workspace = MnnvlMoe.get_moe_workspaces( - model_config.mapping) - self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace( - model_config.mapping) - elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided: - workspace_mb = int( - os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048")) - self.moe_a2a = MoeAlltoAll( - mapping=self.mapping, - max_num_tokens=model_config.max_num_tokens, - top_k=self.routing_method.experts_per_token, - num_experts=self.num_slots, - workspace_size_per_rank=workspace_mb * 1024 * 1024, - ) - elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: - raise NotImplementedError( - "DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet" - ) + # When without_comm=True, skip communication initialization (ConfigurableMoE will handle it) + if not without_comm: + self.alltoall_method_type = self.select_alltoall_method_type() + logger.info_once( + f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}", + key="alltoall_method_type") + self.alltoall_workspace = None + self.alltoall_prepare_workspace = None + self.use_low_precision_combine = False + if self.enable_alltoall: + self.use_low_precision_combine = model_config.use_low_precision_moe_combine + + if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided: + MnnvlMemory.initialize() + self.alltoall_workspace = MnnvlMoe.get_moe_workspaces( + model_config.mapping) + self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace( + model_config.mapping) + elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided: + workspace_mb = int( + os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048")) + self.moe_a2a = MoeAlltoAll( + mapping=self.mapping, + max_num_tokens=model_config.max_num_tokens, + top_k=self.routing_method.experts_per_token, + num_experts=self.num_slots, + workspace_size_per_rank=workspace_mb * 1024 * 1024, + ) + elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency: + raise NotImplementedError( + "DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet" + ) + else: + raise NotImplementedError( + f"Unsupported alltoall method type: {self.alltoall_method_type!r}" + ) else: - raise NotImplementedError( - f"Unsupported alltoall method type: {self.alltoall_method_type!r}" - ) + # When without_comm=True, set minimal attributes + # Communication will be handled by parent wrapper (e.g., ConfigurableMoE) + self.alltoall_method_type = AlltoallMethodType.NotEnabled + self.alltoall_workspace = None + self.alltoall_prepare_workspace = None + self.use_low_precision_combine = False + self.moe_a2a = None self._weights_created = False if not model_config.skip_create_weights_in_init: @@ -181,7 +194,7 @@ def select_alltoall_method_type(self) -> AlltoallMethodType: def _supports_load_balancer(self) -> bool: """TRTLLMGenFusedMoE supports load 
balancer.""" - return True + return self.use_dp and self.parallel_size > 1 @cached_property def enable_alltoall(self): @@ -265,26 +278,39 @@ def load_weights(self, def post_load_weights(self): self.quant_method.post_load_weights(self) - def _quantize_for_post_quant_comm(self, x): - """Quantize inputs prior to post-communication (alltoall/allgather). - Returns: (x, x_sf, x_row, x_col) + def quantize_input(self, x, post_quant_comm: bool = True): + """Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation. + + Args: + x: Input tensor to quantize + post_quant_comm: + If True, quantize for post-quant communication path. + If False, quantize for non-communication path + + Returns: (x, x_sf) where x_sf is already reshaped to 2D if needed + + For quantization methods that produce scaling factors: + - x_sf is reshaped from 1D to 2D: [num_elements] -> [batch_size, ceil_div(hidden_size, scaling_vector_size)] + - The 2D shape is required for proper handling in alltoall/allgather operations + - scaling_vector_size is typically the group size for block-wise quantization """ - x_row = x.shape[0] - x_col = x.shape[1] x_sf = None if self.has_w4a8_mxfp4_fp8: - x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( - x, self.fc31_input_dequant[0]) - x_row, x_col = x.shape[0], x.shape[1] + pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1] + x = torch.nn.functional.pad(x, (0, pad_size)) + if post_quant_comm: + x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + x, self.fc31_input_dequant[0]) + else: + x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + x, self.fc31_input_gate_dequant[0]) elif self.has_nvfp4: if isinstance(x, Fp4QuantizedTensor): assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication" x_row = x.shape[0] - x_col = x.shape[1] * 2 x, x_sf = x.fp4_tensor, x.scaling_factor else: x_row = x.shape[0] - x_col = x.shape[1] x, x_sf = torch.ops.trtllm.fp4_quantize( x, self.fc31_input_scale, self.scaling_vector_size, False, False) @@ -293,8 +319,8 @@ def _quantize_for_post_quant_comm(self, x): x, False, alignment=self.quant_method.input_hidden_alignment) x_row, x_col = x.shape[0], x.shape[1] elif self.has_deepseek_fp8_block_scales: - # No change required before communication - pass + x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x) + x_row = x.shape[0] elif self.has_w4a16_mxfp4: pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1] x = torch.nn.functional.pad(x, (0, pad_size)) @@ -305,22 +331,50 @@ def _quantize_for_post_quant_comm(self, x): raise ValueError( f"unsupported quantization mode for post communication: {self.quant_config.quant_mode}" ) - return x, x_sf, x_row, x_col - def forward_impl( - self, - x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, - *, - do_finalize: bool = True, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, - **kwargs, - ) -> torch.Tensor: + if x_sf is not None: + x_sf = x_sf.view(x_row, -1) - assert x.dtype == torch.bfloat16 + return x, x_sf - # DeepSeekV3 style routing + def run_moe( + self, + x: torch.Tensor, + token_selected_experts: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + router_logits: Optional[torch.Tensor] = None, + do_finalize: bool = True, + moe_output: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple]: + """ + Run MoE computation with TRTLLMGen backend. 
+ + This method encapsulates the core MoE computation logic, handling different + quantization schemes (fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, + w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8). + + Args: + # Standard MoE interface parameters: + x: Input hidden states (may be pre-quantized) + token_selected_experts: Expert IDs [num_tokens, top_k]. If EPLB is enabled, + this represents expert slots [num_tokens, top_k] instead. + token_final_scales: Final scaling factors for each token + x_sf: Input scale factors (optional, for certain quantization schemes) + + # TRTLLMGen-specific additional parameters: + router_logits: Router logits for integrated routing in some kernels. + Should be None if routing has already been done (e.g., post_quant_comm). + do_finalize: Whether to finalize the output. If False, returns intermediate + results (tuple) for nvfp4 and w4a8_nvfp4_fp8 schemes. + moe_output: Pre-allocated output buffer from workspace (optional). + Used for mnnvlthroughput alltoall backend to avoid extra copies. + + Returns: + If do_finalize=True: final_hidden_states tensor + If do_finalize=False: tuple of intermediate outputs (for nvfp4 and w4a8_nvfp4_fp8) + """ + # Extract routing parameters from routing_method if isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod): top_k = self.routing_method.routing_impl.top_k routing_bias = self.routing_method.e_score_correction_bias @@ -334,196 +388,23 @@ def forward_impl( topk_group = None routed_scaling_factor = None - run_post_quant_allgather = (self.use_dp and self.parallel_size > 1 - and not self.enable_alltoall) - post_quant_comm = run_post_quant_allgather or self.enable_alltoall - - x_sf = None - token_selected_experts = None - token_final_scales = None - x_row = x.shape[0] - x_col = x.shape[1] - token_count = x.shape[0] - alltoall_info = None - # Determine if this is first/last call (TRTLLMGenFusedMoE doesn't use chunking) - is_first_call = self.repeat_idx == 0 - is_last_call = self.repeat_idx == self.repeat_count - 1 + routing_bias = routing_bias if router_logits is not None else None - if post_quant_comm: - # Start GPU stage for first call - self._load_balancer_start_wait_gpu_stage(is_first_call) - token_selected_experts, token_final_scales = self.routing_method.apply( - router_logits) - token_selected_experts = token_selected_experts.to(torch.int32) - if token_final_scales is not None: - token_final_scales = token_final_scales.to(torch.bfloat16) - - self._load_balancer_done_wait_gpu_stage(is_first_call) - - ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided - self._load_balancer_update_statistic( - token_selected_experts, - is_first_call, - is_last_call, - ignore_allreduce=ignore_allreduce) - - # Route tokens to slots - token_selected_slots = self._load_balancer_route( - token_selected_experts, self.use_dp) - - # Update expert statistics - ExpertStatistic.set_layer(self.layer_idx) - ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) - - # Use routed slots for subsequent processing - token_selected_experts = token_selected_slots - - x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x) - - if self.enable_alltoall: - assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall" - - runtime_max_tokens_per_rank = max( - all_rank_num_tokens) if all_rank_num_tokens else token_count - - if token_final_scales is None: - token_final_scales = torch.ones_like(token_selected_experts, - dtype=torch.float32) - else: - token_final_scales = 
token_final_scales.to(torch.float32) - - if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided: - assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized" - if is_last_call: - loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor( - ) - else: - loadbalancer_local_statistic_info = None - alltoall_info, gathered_loadbalancer_local_statistic_info = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( - token_selected_experts, - loadbalancer_local_statistic_info, - self.alltoall_prepare_workspace, - runtime_max_tokens_per_rank, - self.ep_rank, - self.ep_size, - self.num_experts, - self.num_slots, - top_k, - ) - if gathered_loadbalancer_local_statistic_info is not None: - gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view( - (self.mapping.moe_ep_size, self.num_experts)) - self._load_balancer_update_statistic_with_gathered_statistic( - gathered_loadbalancer_local_statistic_info) - - if x_sf is not None: - x_sf = x_sf.view(x_row, - ceil_div(x_col, self.scaling_vector_size)) - - x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv( - [x, x_sf, token_selected_experts, token_final_scales], - alltoall_info, - self.alltoall_workspace, - self.ep_rank, - self.ep_size, - ) - - torch.ops.trtllm.memset_expert_ids( - token_selected_experts, - alltoall_info.recv_rank_count_cumsum, - runtime_max_tokens_per_rank, - top_k, - -1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id - self.ep_size, - ) - - if x_sf is not None: - x_sf = x_sf.flatten() - - if token_final_scales is not None: - token_final_scales = token_final_scales.to(torch.bfloat16) - elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided: - if x_sf is not None: - x_sf = x_sf.view(x_row, - ceil_div(x_col, self.scaling_vector_size)) - - payloads = [] - payloads.append(x) - if x_sf is not None: - payloads.append(x_sf) - expert_id_payload_index = 2 - else: - expert_id_payload_index = 1 - payloads.append(token_selected_experts) - payloads.append(token_final_scales) - - recv_tensors = self.moe_a2a.dispatch( - token_selected_experts, - payloads, - runtime_max_tokens_per_rank, - invalid_token_expert_id= - -1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id - expert_id_payload_index=expert_id_payload_index, - ) - - if x_sf is not None: - x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors - x_sf = x_sf_recv.view(-1, x_sf_recv.shape[-1]) - else: - x_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors - x = x_recv.view(-1, x_recv.shape[-1]) - token_selected_experts = token_selected_experts_recv.view( - -1, token_selected_experts_recv.shape[-1]) - token_final_scales = token_final_scales_recv.view( - -1, token_final_scales_recv.shape[-1]) - - if x_sf is not None: - x_sf = x_sf.flatten() - - if token_final_scales is not None: - token_final_scales = token_final_scales.to(torch.bfloat16) - else: - raise ValueError( - f"Unsupported moe alltoall method type: {self.alltoall_method_type}" - ) - - elif run_post_quant_allgather: - if x_sf is not None: - x_sf = x_sf.view(x_row, ceil_div(x_col, - self.scaling_vector_size)) - assert len( - x_sf.shape - ) == 2, "The hidden states scaling factor should be 2D tensor before allgather" - x, x_sf, token_selected_experts, token_final_scales = allgather( - [x, x_sf, token_selected_experts, token_final_scales], - self.mapping, - dim=0, - sizes=None if use_dp_padding else 
all_rank_num_tokens) - if x_sf is not None: - x_sf = x_sf.flatten() - - router_logits_arg = router_logits if not post_quant_comm else None - routing_bias_arg = routing_bias if not post_quant_comm else None - - moe_output: Optional[torch.Tensor] = None - use_workspace_output = False - # TODO: use_workspace_output only supports w4a8_mxfp4_mxfp8 (gpt-oss) for now - if self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided and self.has_w4a8_mxfp4_mxfp8: - moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace( - runtime_max_tokens_per_rank, self.hidden_size, torch.bfloat16) - use_workspace_output = True + # Ensure x_sf is 2D before flattening + if x_sf is not None: + assert len( + x_sf.shape + ) == 2, f"x_sf should be 2D tensor, got shape {x_sf.shape}" + x_sf = x_sf.flatten() - # TODO: since routing kernel is integrated into moe_runner for fp8, - # here we just route the I/Os for moe_runner if self.has_deepseek_fp8_block_scales: assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False" - x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x) final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner( - router_logits_arg, - routing_bias_arg, - x_val, - x_scale, + router_logits, + routing_bias, + x, + x_sf, self.w3_w1_weight, self.w3_w1_weight_scaling_factor, self.w2_weight, @@ -533,35 +414,20 @@ def forward_impl( n_group, topk_group, self.intermediate_size_per_partition, - self. - slot_start, # local_expert_start; use ep_rank if stride!=1 - self.expert_size_per_partition, # local_expert_size + self.slot_start, + self.expert_size_per_partition, routed_scaling_factor, self.routing_method.routing_method_type, topk_weights=token_final_scales, topk_ids=token_selected_experts, ) elif self.has_nvfp4: - scale_factor_use_ue8m0 = False - is_scale_factor_swizzled = False # use linear layout here - - if not post_quant_comm: - hidden_states_fp4, hidden_states_scale_linear_fp4 = ( - torch.ops.trtllm.fp4_quantize( - x, - self.fc31_input_scale, - self.scaling_vector_size, - scale_factor_use_ue8m0, - is_scale_factor_swizzled, - )) - else: - hidden_states_fp4, hidden_states_scale_linear_fp4 = x, x_sf outputs = torch.ops.trtllm.fp4_block_scale_moe_runner( - router_logits_arg, - routing_bias_arg, - hidden_states_fp4, - hidden_states_scale_linear_fp4.view(torch.float8_e4m3fn), + router_logits, + routing_bias, + x, + x_sf.view(torch.float8_e4m3fn), self.w3_w1_weight, self.w3_w1_weight_scale.view(torch.float8_e4m3fn), self.w2_weight, @@ -574,9 +440,8 @@ def forward_impl( n_group, topk_group, self.intermediate_size_per_partition, - self. 
- slot_start, # local_expert_start; use ep_rank if stride!=1 - self.expert_size_per_partition, # local_expert_size + self.slot_start, + self.expert_size_per_partition, routed_scaling_factor, self.routing_method.routing_method_type, do_finalize=do_finalize, @@ -591,17 +456,12 @@ def forward_impl( final_hidden_states = outputs[0] elif self.has_w4a16_mxfp4: assert x.dtype == torch.bfloat16 - if not post_quant_comm: - pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1] - x = torch.nn.functional.pad(x, (0, pad_size)) - else: - x = x intermediate_size_per_partition_padded = self.w3_w1_weight.shape[ -2] // 2 final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner( - router_logits_arg, - routing_bias_arg, + router_logits, + routing_bias, x, self.w3_w1_weight, self.w3_w1_weight_scale, @@ -617,12 +477,10 @@ def forward_impl( n_group, topk_group, intermediate_size_per_partition_padded, - self.hidden_size, # valid_hidden_size - self.quant_method. - intermediate_size_per_partition_lean, # valid_intermediate_size - self. - slot_start, # local_expert_start; use ep_rank if stride!=1 - self.expert_size_per_partition, # local_expert_size + self.hidden_size, + self.quant_method.intermediate_size_per_partition_lean, + self.slot_start, + self.expert_size_per_partition, routed_scaling_factor, self.routing_method.routing_method_type, 0, # act_type @@ -633,16 +491,10 @@ def forward_impl( hidden_size].contiguous() elif self.has_w4a8_nvfp4_fp8: - if not post_quant_comm: - hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( - x, 1.0 / self.fc31_input_scale) - else: - hidden_states_fp8 = x - outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner( - router_logits_arg, - routing_bias_arg, - hidden_states_fp8, + router_logits, + routing_bias, + x, self.w3_w1_weight, self.w3_w1_weight_scale.view(torch.float8_e4m3fn), self.w2_weight, @@ -655,9 +507,8 @@ def forward_impl( n_group, topk_group, self.intermediate_size_per_partition, - self. - slot_start, # local_expert_start; use ep_rank if stride!=1 - self.expert_size_per_partition, # local_expert_size + self.slot_start, + self.expert_size_per_partition, routed_scaling_factor, self.routing_method.routing_method_type, do_finalize=do_finalize, @@ -672,19 +523,13 @@ def forward_impl( else: final_hidden_states = outputs[0] elif self.has_w4a8_mxfp4_fp8: - pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1] - if not post_quant_comm: - x = torch.nn.functional.pad(x, (0, pad_size)) - x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( - x, self.fc31_input_gate_dequant[0]) - else: - x = x + intermediate_size_per_partition_padded = self.w3_w1_weight.shape[ -2] // 2 final_hidden_states = torch.ops.trtllm.e4m3_mxe2m1_block_scale_moe_runner( - router_logits_arg, - routing_bias_arg, + router_logits, + routing_bias, x, self.w3_w1_weight, self.w3_w1_weight_scale, @@ -695,20 +540,18 @@ def forward_impl( self.w2_weight, self.w2_weight_scale, self.w2_bias, - self.fc31_input_dequant, # output1_scales_scalar - self.fc31_input_gate_dequant, # output1_scales_gate_scalar - self.fc2_input_dequant, # output2_scales_scalar + self.fc31_input_dequant, + self.fc31_input_gate_dequant, + self.fc2_input_dequant, self.num_slots, top_k, n_group, topk_group, intermediate_size_per_partition_padded, - self.hidden_size, # valid_hidden_size_per_partition - self.quant_method. - intermediate_size_per_partition_lean, # valid_intermediate_size_per_partition - self. 
- slot_start, # local_expert_start; use ep_rank if stride!=1 - self.expert_size_per_partition, # local_expert_size + self.hidden_size, + self.quant_method.intermediate_size_per_partition_lean, + self.slot_start, + self.expert_size_per_partition, routed_scaling_factor, self.routing_method.routing_method_type, 0, # act_type @@ -718,23 +561,17 @@ def forward_impl( final_hidden_states = final_hidden_states[:, :self. hidden_size].contiguous() elif self.has_w4a8_mxfp4_mxfp8: - if not post_quant_comm: - # TRTLLM-Gen uses linear SF layout for the mxfp8 input. - mxfp8_x, sf = torch.ops.trtllm.mxfp8_quantize( - x, - False, - alignment=self.quant_method.input_hidden_alignment) - else: - mxfp8_x, sf = x, x_sf + + mxfp8_x, sf = x, x_sf intermediate_size_per_partition_padded = self.w3_w1_weight.shape[ -2] // 2 final_hidden_states = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner( - router_logits_arg, - routing_bias_arg, - mxfp8_x, - sf, + router_logits, + routing_bias, + x, + x_sf, self.w3_w1_weight, self.w3_w1_weight_scale, self.w3_w1_bias, @@ -749,12 +586,10 @@ def forward_impl( n_group, topk_group, intermediate_size_per_partition_padded, - self.hidden_size, # valid_hidden_size - self.quant_method. - intermediate_size_per_partition_lean, # valid_intermediate_size - self. - slot_start, # local_expert_start; use ep_rank if stride!=1 - self.expert_size_per_partition, # local_expert_size + self.hidden_size, + self.quant_method.intermediate_size_per_partition_lean, + self.slot_start, + self.expert_size_per_partition, routed_scaling_factor, self.routing_method.routing_method_type, 0, # act_type @@ -767,7 +602,202 @@ def forward_impl( "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes." ) - # Handle load balancer CPU stage if needed + return final_hidden_states + + def forward_impl( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + *, + do_finalize: bool = True, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + **kwargs, + ) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + + # Get top_k for routing (other routing parameters are extracted inside run_moe) + if isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod): + top_k = self.routing_method.routing_impl.top_k + else: + top_k = self.routing_method.top_k + + run_post_quant_allgather = (self.use_dp and self.parallel_size > 1 + and not self.enable_alltoall) + post_quant_comm = run_post_quant_allgather or self.enable_alltoall + + x_sf = None + token_selected_experts = None + token_final_scales = None + token_count = x.shape[0] + alltoall_info = None + # Determine if this is first/last call (TRTLLMGenFusedMoE doesn't use chunking) + is_first_call = self.repeat_idx == 0 + is_last_call = self.repeat_idx == self.repeat_count - 1 + + if post_quant_comm: + self._load_balancer_start_wait_gpu_stage(is_first_call) + + token_selected_experts, token_final_scales = self.routing_method.apply( + router_logits) + token_selected_experts = token_selected_experts.to(torch.int32) + if token_final_scales is not None: + token_final_scales = token_final_scales.to(torch.bfloat16) + + self._load_balancer_done_wait_gpu_stage(is_first_call) + + ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided + self._load_balancer_update_statistic( + token_selected_experts, + is_first_call, + is_last_call, + ignore_allreduce=ignore_allreduce) + + # Route tokens to slots + token_selected_slots = 
self._load_balancer_route( + token_selected_experts, self.use_dp) + + # Update expert statistics + ExpertStatistic.set_layer(self.layer_idx) + ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) + + # Use routed slots for subsequent processing + token_selected_experts = token_selected_slots + + x, x_sf = self.quantize_input(x) + + if self.enable_alltoall: + assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall" + + runtime_max_tokens_per_rank = max( + all_rank_num_tokens) if all_rank_num_tokens else token_count + + if token_final_scales is None: + token_final_scales = torch.ones_like(token_selected_experts, + dtype=torch.float32) + else: + token_final_scales = token_final_scales.to(torch.float32) + + if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided: + assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized" + if is_last_call: + loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor( + ) + else: + loadbalancer_local_statistic_info = None + alltoall_info, gathered_loadbalancer_local_statistic_info = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( + token_selected_experts, + loadbalancer_local_statistic_info, + self.alltoall_prepare_workspace, + runtime_max_tokens_per_rank, + self.ep_rank, + self.ep_size, + self.num_experts, + self.num_slots, + top_k, + ) + if gathered_loadbalancer_local_statistic_info is not None: + gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view( + (self.mapping.moe_ep_size, self.num_experts)) + self._load_balancer_update_statistic_with_gathered_statistic( + gathered_loadbalancer_local_statistic_info) + + if self.enable_alltoall: + if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided: + x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv( + [x, x_sf, token_selected_experts, token_final_scales], + alltoall_info, + self.alltoall_workspace, + self.ep_rank, + self.ep_size, + ) + + torch.ops.trtllm.memset_expert_ids( + token_selected_experts, + alltoall_info.recv_rank_count_cumsum, + runtime_max_tokens_per_rank, + top_k, + -1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id + self.ep_size, + ) + + if token_final_scales is not None: + token_final_scales = token_final_scales.to(torch.bfloat16) + elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided: + payloads = [] + payloads.append(x) + if x_sf is not None: + payloads.append(x_sf) + expert_id_payload_index = 2 + else: + expert_id_payload_index = 1 + payloads.append(token_selected_experts) + payloads.append(token_final_scales) + + recv_tensors = self.moe_a2a.dispatch( + token_selected_experts, + payloads, + runtime_max_tokens_per_rank, + invalid_token_expert_id= + -1, # Caution: TRTLLM-Gen uses -1 as invalid token expert id + expert_id_payload_index=expert_id_payload_index, + ) + + if x_sf is not None: + x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors + x_sf = x_sf_recv.view(-1, x_sf_recv.shape[-1]) + else: + x_recv, token_selected_experts_recv, token_final_scales_recv = recv_tensors + x = x_recv.view(-1, x_recv.shape[-1]) + token_selected_experts = token_selected_experts_recv.view( + -1, token_selected_experts_recv.shape[-1]) + token_final_scales = token_final_scales_recv.view( + -1, token_final_scales_recv.shape[-1]) + + if token_final_scales is not None: + token_final_scales = token_final_scales.to(torch.bfloat16) + else: + raise 
ValueError( + f"Unsupported moe alltoall method type: {self.alltoall_method_type}" + ) + + elif run_post_quant_allgather: + if x_sf is not None: + assert len( + x_sf.shape + ) == 2, "The hidden states scaling factor should be 2D tensor before allgather" + x, x_sf, token_selected_experts, token_final_scales = allgather( + [x, x_sf, token_selected_experts, token_final_scales], + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens) + else: + # No communication path: use non-post-quant-comm quantization + x, x_sf = self.quantize_input(x, post_quant_comm=False) + + moe_output: Optional[torch.Tensor] = None + use_workspace_output = False + # TODO: use_workspace_output only supports w4a8_mxfp4_mxfp8 (gpt-oss) for now + if self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided and self.has_w4a8_mxfp4_mxfp8: + moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace( + runtime_max_tokens_per_rank, self.hidden_size, torch.bfloat16) + use_workspace_output = True + + # Call the extracted run_moe interface + # Determine router_logits based on post_quant_comm + router_logits_arg = None if post_quant_comm else router_logits + + final_hidden_states = self.run_moe( + x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + # TRTLLMGenFusedMoE extra parameters + router_logits=router_logits_arg, + do_finalize=do_finalize, + moe_output=moe_output, + ) + self._load_balancer_start_set_cpu_stage(is_last_call) # Combine results if using alltoall diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py index a13ff07bad8..3e8cc6ec887 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py @@ -81,9 +81,9 @@ def __init__( self.num_experts) self.expert_size_per_partition = self.expert_end - self.expert_start - # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled - moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size - self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens + # moe_max_num_tokens is set in ModelConfig.__post_init__ if not specified + # The default value is max_num_tokens * dp_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens self._weights_created = False if not model_config.skip_create_weights_in_init: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 3b3a58ef288..7574652db82 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -84,11 +84,12 @@ def __init__( self.use_cuda_graph = model_config.use_cuda_graph - # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled - moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size - self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens + # moe_max_num_tokens is set in ModelConfig.__post_init__ if not specified + # The default value is max_num_tokens * dp_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied - if self.moe_max_num_tokens < moe_max_num_tokens: + default_moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + if 
self.moe_max_num_tokens < default_moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py
index 1856c287264..cd064d81218 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/interface.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py
@@ -1,7 +1,7 @@
 import weakref
 from abc import abstractmethod
 from enum import Enum, IntEnum
-from typing import Dict, List, Optional, Union, final
+from typing import Dict, List, Optional, Tuple, Union, final

 import torch
 from torch import nn
@@ -147,6 +147,7 @@ def __init__(
         swiglu_limit: Optional[torch.Tensor] = None,
         layer_idx: Optional[int] = None,
         activation_type: ActivationType = ActivationType.Swiglu,
+        init_load_balancer: bool = True,
     ):
         from ...distributed import AllReduce
@@ -197,7 +198,23 @@ def __init__(
                 dtype=self.dtype)

         # Initialize load balancer related attributes
-        self._init_load_balancer(model_config, aux_stream_dict)
+        if init_load_balancer:
+            self._init_load_balancer(model_config, aux_stream_dict)
+        else:
+            # When init_load_balancer=False, initialize minimal attributes
+            # These will be synced from the parent wrapper (e.g., ConfigurableMoE) later
+            self.aux_stream_dict = aux_stream_dict
+            self.layer_load_balancer = None
+            self.repeat_idx = 0
+            self.repeat_count = 1
+            self.expert_size_per_partition = self.num_experts // self.ep_size
+            self.num_slots = self.num_experts
+            self.slot_start = self.ep_rank * self.expert_size_per_partition
+            self.slot_end = self.slot_start + self.expert_size_per_partition
+            self.initial_local_expert_ids = list(
+                range(self.slot_start, self.slot_end))
+            self.initial_global_assignments = list(range(self.num_experts))
+            self.allreduce = None

     def _init_load_balancer(
         self,
@@ -311,6 +328,10 @@ def _supports_load_balancer(self) -> bool:
         """
         return False

+    def _using_load_balancer(self) -> bool:
+        """Check if this MoE is using load balancer."""
+        return self.layer_load_balancer is not None
+
     def _using_dynamic_load_balancer(self) -> bool:
         """Check if this MoE is using dynamic load balancer."""
         if self.layer_load_balancer:
@@ -348,7 +369,7 @@ def _load_balancer_update_statistic(self,
             token_selected_experts: The selected experts of all tokens, has shape of [tokenCount * topK]
             is_first_call: Whether this is the first call for the same weights
             is_last_call: Whether this is the last call for the same weights
-            ignore_allreduce: Whether to ignore allreduce, if True, only update local statistics, need call _load_balancer_get_local_statistic_tensor to get the local statistic tensor and then do external allgather and then call _load_balancer_update_statistic_with_gathered_statistic to update the global statistics. NVLINK_TWO_SIDED supports this.
+            ignore_allreduce: Whether to ignore the allreduce. If True, only local statistics are updated; the caller must call _load_balancer_get_local_statistic_tensor to obtain the local statistic tensor, perform an external allgather, and then call _load_balancer_update_statistic_with_gathered_statistic to update the global statistics. NVLinkTwoSided supports this.
""" if self._using_dynamic_load_balancer(): if ignore_allreduce: @@ -488,6 +509,72 @@ def load_weights(self, weights: List[Dict]): def post_load_weights(self): pass + @abstractmethod + def quantize_input( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + **kwargs, + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], Dict]: + """ + Quantize input tensor - unified interface for all MoE backends + + NOTE: This is a temporary interface. In the future, this method should be moved + to the MoEBackend interface as part of the backend abstraction layer. + + This method handles quantization of input tensors before MoE computation. + All MoE backend implementations must override this method to implement their + specific quantization logic. + + Args: + x: Input tensor [num_tokens, hidden_size] or Fp4QuantizedTensor + **kwargs: Backend-specific arguments (e.g., token_selected_experts, workspace, etc.) + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]] or Dict: + (quantized_x, scaling_factors) + where scaling_factors should be reshaped to 2D if applicable + + Examples: + Simple backends (Cutlass, WideEP, TRTLLMGen): + return x_quantized, x_sf # x_sf is 2D or None + """ + raise NotImplementedError + + @abstractmethod + def run_moe( + self, + # ========== Common parameters (all backends use) ========== + x: torch.Tensor, + token_selected_experts: Optional[torch.Tensor], + token_final_scales: Optional[torch.Tensor], + x_sf: Optional[torch.Tensor] = None, + # ========== Backend-specific parameters (via kwargs) ========== + **kwargs + ) -> torch.Tensor: + """ + Unified MoE computation interface + + NOTE: This is a TEMPORARY interface. In the future, this method should be moved + to the MoEBackend interface as part of the backend abstraction layer. + + This method performs the core MoE computation. Different backends will implement + their specific computation logic while following this unified interface. + + Common parameters (all backends use): + x: Input activations [num_tokens, hidden_size] + token_selected_experts: Expert IDs [num_tokens, top_k] (used by DeepGemm/TRTLLMGen). + If EPLB is enabled, this represents expert slots [num_tokens, top_k]. + token_final_scales: Routing weights [num_tokens, top_k] + x_sf: Input scale factor (for quantization, if applicable) + + Backend-specific parameters (passed via kwargs, obtained from _get_backend_kwargs()): + TODO: This is not finalized, will be updated later. 
+ + Returns: + torch.Tensor: MoE computation result [num_tokens, hidden_size] + """ + raise NotImplementedError + @abstractmethod def forward_impl( self, diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py index 91b906241df..21ab127a366 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py @@ -51,7 +51,9 @@ def create_nemotron_h_llm(use_cuda_graph, if mamba_ssm_cache_dtype is None else mamba_ssm_cache_dtype), sampler_type="TRTLLMSampler", enable_chunked_prefill=enable_chunked_prefill, - max_num_tokens=max_num_tokens, + **({} if max_num_tokens is None else { + "max_num_tokens": max_num_tokens + }), ) diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 1bff8d83d7c..61c72061773 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -16,6 +16,7 @@ per_token_cast_to_fp8_e8m0) from mpi4py import MPI from mpi4py.futures import MPIPoolExecutor +from transformers.configuration_utils import PretrainedConfig from utils.util import (check_accuracy, skip_blackwell, skip_blackwell_geforce, skip_neither_ada_nor_hopper_unittest, skip_non_hopper_unittest, skip_pre_blackwell, @@ -139,14 +140,20 @@ def test_fused_moe(moe_backend, weights[f"{expert_id}.w1.weight"] = w1_weight weights[f"{expert_id}.w2.weight"] = w2_weight weights[f"{expert_id}.w3.weight"] = w3_weight + + # Create pretrained_config with necessary parameters + pretrained_config = PretrainedConfig() + pretrained_config.num_experts = NUM_EXPERTS + pretrained_config.hidden_size = HIDDEN_SIZE + pretrained_config.intermediate_size = INTERMEDIATE_SIZE + pretrained_config.torch_dtype = dtype + fused_moe = create_moe( - num_experts=NUM_EXPERTS, routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, reduce_results=True, - model_config=ModelConfig(mapping=mapping, moe_backend=moe_backend), + model_config=ModelConfig(pretrained_config=pretrained_config, + mapping=mapping, + moe_backend=moe_backend), bias=bias, ) fused_moe.load_weights([weights]) @@ -589,14 +596,18 @@ def test_fused_moe_fp8(moe_backend, dtype, routing_cls, bias): weights[f"{expert_id}.w2.input_scale"] = w2_input_scale weights[f"{expert_id}.w3.input_scale"] = w3_input_scale + # Create pretrained_config with necessary parameters + pretrained_config = PretrainedConfig() + pretrained_config.num_experts = NUM_EXPERTS + pretrained_config.hidden_size = HIDDEN_SIZE + pretrained_config.intermediate_size = INTERMEDIATE_SIZE + pretrained_config.torch_dtype = dtype + quant_config = QuantConfig(quant_algo=QuantAlgo.FP8) - fused_moe = create_moe(num_experts=NUM_EXPERTS, - routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, + fused_moe = create_moe(routing_method=routing_method, reduce_results=False, model_config=ModelConfig( + pretrained_config=pretrained_config, quant_config=quant_config, moe_backend=moe_backend), bias=bias) @@ -1446,14 +1457,19 @@ def test_fused_moe_nvfp4(dtype, moe_backend): weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) + + # Create pretrained_config with necessary parameters + pretrained_config = PretrainedConfig() + pretrained_config.num_experts = NUM_EXPERTS + pretrained_config.hidden_size = HIDDEN_SIZE + 
pretrained_config.intermediate_size = INTERMEDIATE_SIZE + pretrained_config.torch_dtype = dtype + fused_moe = create_moe( - num_experts=NUM_EXPERTS, routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, reduce_results=True, - model_config=ModelConfig(quant_config=quant_config, + model_config=ModelConfig(pretrained_config=pretrained_config, + quant_config=quant_config, moe_backend=moe_backend), ) fused_moe.load_weights([weights]) @@ -1935,14 +1951,19 @@ def test_fused_moe_mxfp4_mxfp8(moe_backend, hidden_unpadded, seq_len, bias): router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda() quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_MXFP4_MXFP8) + + # Create pretrained_config with necessary parameters + pretrained_config = PretrainedConfig() + pretrained_config.num_experts = NUM_EXPERTS + pretrained_config.hidden_size = HIDDEN_SIZE_UNPADDED + pretrained_config.intermediate_size = INTERMEDIATE_SIZE_UNPADDED + pretrained_config.torch_dtype = dtype + fused_moe = create_moe( - num_experts=NUM_EXPERTS, routing_method=routing_method, - hidden_size=HIDDEN_SIZE_UNPADDED, - intermediate_size=INTERMEDIATE_SIZE_UNPADDED, - dtype=dtype, reduce_results=True, - model_config=ModelConfig(quant_config=quant_config, + model_config=ModelConfig(pretrained_config=pretrained_config, + quant_config=quant_config, moe_backend=moe_backend), bias=bias, ) @@ -2238,13 +2259,18 @@ def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend): weights[f"{expert_id}.w3.weight_scale"] = w3_scale quant_config = QuantConfig(quant_algo=QuantAlgo.W4A16_MXFP4) - fused_moe = create_moe(num_experts=NUM_EXPERTS, - routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, + + # Create pretrained_config with necessary parameters + pretrained_config = PretrainedConfig() + pretrained_config.num_experts = NUM_EXPERTS + pretrained_config.hidden_size = HIDDEN_SIZE + pretrained_config.intermediate_size = INTERMEDIATE_SIZE + pretrained_config.torch_dtype = dtype + + fused_moe = create_moe(routing_method=routing_method, reduce_results=False, model_config=ModelConfig( + pretrained_config=pretrained_config, quant_config=quant_config, moe_backend=moe_backend)) fused_moe.load_weights([weights])
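Taken together, the interface.py and TRTLLMGenFusedMoE changes above split MoE execution into a quantize_input hook and a run_moe hook, with routing and expert-parallel communication handled in forward_impl around them. A minimal sketch of how the pieces compose, assuming only the hook names introduced in this diff (EP communication and load-balancer bookkeeping omitted, so this is illustrative rather than the final MoEBackend API):

import torch

def moe_forward_sketch(moe, x, router_logits):
    # Routing runs first and produces expert ids / scales from the logits.
    token_selected_experts, token_final_scales = moe.routing_method.apply(router_logits)
    token_selected_experts = token_selected_experts.to(torch.int32)

    # Unified quantization hook; backends return the (possibly quantized)
    # activations plus an optional 2D scaling-factor tensor.
    x_q, x_sf = moe.quantize_input(x)

    # With attention DP or wide EP enabled, the alltoall / post-quant allgather
    # of x_q, x_sf and the routing tensors would happen at this point.

    # Unified computation hook; backend-specific arguments go through kwargs.
    return moe.run_moe(x=x_q,
                       token_selected_experts=token_selected_experts,
                       token_final_scales=token_final_scales,
                       x_sf=x_sf)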