diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index 25af1222aa6..d86a841fab9 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -621,14 +621,12 @@ class AllreduceOp AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size) { - static char* force_nccl_all_reduce_strategy_char = std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY"); - bool force_nccl_all_reduce_strategy = (force_nccl_all_reduce_strategy_char != nullptr); AllReduceStrategyType runtime_strategy; if (mStrategy == AllReduceStrategyType::UB) { runtime_strategy = AllReduceStrategyType::UB; } - else if (force_nccl_all_reduce_strategy || mStrategy == AllReduceStrategyType::NCCL) + else if (mStrategy == AllReduceStrategyType::NCCL) { runtime_strategy = AllReduceStrategyType::NCCL; } @@ -936,10 +934,7 @@ class AllreduceOp bool isUsingLowPrecision(size_t message_size) const noexcept { - static char* force_low_precision_allreduce_strategy_char - = std::getenv("FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY"); - bool force_low_precision = (force_low_precision_allreduce_strategy_char != nullptr) - || (mStrategy == AllReduceStrategyType::LOWPRECISION); + bool force_low_precision = mStrategy == AllReduceStrategyType::LOWPRECISION; #ifdef ENABLE_FP8 // Use LowPrecision if PCIe and p2p support and message size is larger than 2MB diff --git a/docs/source/advanced/lowprecision-pcie-allreduce.md b/docs/source/advanced/lowprecision-pcie-allreduce.md index 57ca754c4e1..b7ab5070370 100644 --- a/docs/source/advanced/lowprecision-pcie-allreduce.md +++ b/docs/source/advanced/lowprecision-pcie-allreduce.md @@ -41,12 +41,12 @@ The Low-Precision-AllReduce algorithm can be enabled in two ways: ``` AllReduce allreduce(mapping=mapping, strategy=AllReduceStrategy.LOWPRECISION); ``` -2. **Environment variable control** with AUTO strategy: + +2. **LlmArgs control**: ``` -// In your code -AllReduce allreduce(mapping=mapping, strategy=AllReduceStrategy.AUTO); -// Set environment variable before running -export FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY=1 +Set the allreduce_strategy field in LlmArgs. +Valid strategies are "AUTO", "NCCL", "UB", "MINLATENCY", "ONESHOT", "TWOSHOT", "LOWPRECISION" and "MNNVL". +If no strategy is set, "AUTO" is used. ``` ## Performance and Accuracy Considerations @@ -58,8 +58,4 @@ Low-Precision-AllReduce reduces communication volume by using FP8 data format fo Users should evaluate the precision impact on their specific models and workloads. -## Environment Variables - -- `FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY`: When set to `1`, forces the use of low-precision algorithm with AUTO strategy. If the algorithm determines it cannot provide performance benefits, it will automatically fall back to other strategies. - **Note**: When compiling TensorRT-LLM without enabling the `ENABLE_FP8` option, setting Low Precision allreduce will not take effect. 
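The documentation change above moves low-precision enablement from an environment variable to the allreduce_strategy field of LlmArgs. A minimal sketch of what that looks like from the Python LLM API, assuming the LLM constructor forwards LlmArgs fields such as allreduce_strategy and using a placeholder model path:

```
# Hedged sketch: enable the low-precision all-reduce path via LlmArgs
# (PyTorch backend). "/path/to/model" is a placeholder checkpoint path, and
# allreduce_strategy is assumed to be forwarded into LlmArgs by the LLM
# constructor, as described in the doc change above.
from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/model",             # placeholder
    tensor_parallel_size=4,             # all-reduce only matters when tp_size > 1
    allreduce_strategy="LOWPRECISION",  # or "AUTO", "NCCL", "UB", "MINLATENCY", "ONESHOT", "TWOSHOT", "MNNVL"
)
```

The string is normalized to the AllReduceStrategy enum in ModelConfig (see the model_config.py hunk below); unrecognized values fall back to AUTO.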
diff --git a/examples/pytorch/out_of_tree_example/modeling_opt.py b/examples/pytorch/out_of_tree_example/modeling_opt.py index 11c8b8d6746..320a431bc74 100644 --- a/examples/pytorch/out_of_tree_example/modeling_opt.py +++ b/examples/pytorch/out_of_tree_example/modeling_opt.py @@ -64,24 +64,22 @@ def __init__( config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine, dtype=config.torch_dtype) - self.fc1 = Linear( - config.hidden_size, - config.ffn_dim, - bias=config.enable_bias, - dtype=config.torch_dtype, - mapping=model_config.mapping, - tensor_parallel_mode=TensorParallelMode.COLUMN, - quant_config=model_config.get_quant_config(), - ) - self.fc2 = Linear( - config.ffn_dim, - config.hidden_size, - bias=config.enable_bias, - dtype=config.torch_dtype, - mapping=model_config.mapping, - tensor_parallel_mode=TensorParallelMode.ROW, - quant_config=model_config.get_quant_config(), - ) + self.fc1 = Linear(config.hidden_size, + config.ffn_dim, + bias=config.enable_bias, + dtype=config.torch_dtype, + mapping=model_config.mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + quant_config=model_config.get_quant_config(), + allreduce_strategy=model_config.allreduce_strategy) + self.fc2 = Linear(config.ffn_dim, + config.hidden_size, + bias=config.enable_bias, + dtype=config.torch_dtype, + mapping=model_config.mapping, + tensor_parallel_mode=TensorParallelMode.ROW, + quant_config=model_config.get_quant_config(), + allreduce_strategy=model_config.allreduce_strategy) self.final_layer_norm = LayerNorm( config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine, diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py index e0ac0db1b8e..e42da002f6d 100644 --- a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py +++ b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py @@ -6,7 +6,7 @@ try: from ....mapping import Mapping from ...distributed import AllReduce, allgather - from ...modules.linear import AllReduceFusionOp, AllReduceParams + from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy def trtllm_allgather(tensor, dim, sizes=None): rank, world_size = get_rank_world_size() @@ -17,7 +17,7 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None): rank, world_size = get_rank_world_size() assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op." p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) - torch_op = AllReduce(p_config) + torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO) return torch_op(tensor, all_reduce_params=all_reduce_params) @torch.library.custom_op( diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 44ab7b1c8dd..7c188ec38d0 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -307,14 +307,17 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype): super().__init__() self.mapping = mapping self.dtype = dtype - self.enable_mnnvl = (os.environ.get("TRTLLM_MNNVL_AR_ENABLED", - "0") == "1" - and dtype in [torch.bfloat16, torch.float32] - and (not mapping.has_cp())) + assert ( + dtype in MNNVLAllReduce.get_supported_dtypes() + and (not mapping.has_cp()) + ), f"MNNVL all reduce only supports dtypes {MNNVLAllReduce.get_supported_dtypes()} and does not support CP." 
- if self.enable_mnnvl: - self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = get_allreduce_mnnvl_workspace( - self.mapping, dtype) + self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = get_allreduce_mnnvl_workspace( + self.mapping, dtype) + + @staticmethod + def get_supported_dtypes(): + return (torch.bfloat16, torch.float32) def forward( self, @@ -330,7 +333,7 @@ def forward( Returns: Union[torch.Tensor, Tuple[torch.Tensor, ...]]: Reduced tensor(s) """ - if not self.enable_mnnvl or input.numel() > self.max_num_elements_mnnvl: + if input.numel() > self.max_num_elements_mnnvl: return None fusion_op = all_reduce_params.fusion_op @@ -411,27 +414,27 @@ def __init__(self, For the reference implementation for each pattern, please refer to the following unit test: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/multi_gpu/test_allreduce.py - The LOWPRECISION strategy can be selected either by directly specifying it in the constructor - or by setting the environment variable FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY when using - the AUTO strategy. + The LOWPRECISION strategy can only be selected by directly specifying it in the constructor. """ self.mapping = mapping self.workspace = None self.strategy = strategy + self.mnnvl_allreduce = None - self.force_low_precision_env = os.environ.get( - "FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY") if self.mapping.tp_size > 1: # When Strategy is UB, it is guaranteed that the workspace is not used. if self.strategy != AllReduceStrategy.UB: - if self.strategy == AllReduceStrategy.LOWPRECISION or self.force_low_precision_env is not None: + if self.strategy == AllReduceStrategy.LOWPRECISION: allocate_low_presicion_allreduce_workspace(self.mapping) self.workspace = get_allreduce_workspace(self.mapping) # Initialize MNNVL AllReduce if needed - self.mnnvl_allreduce = MNNVLAllReduce(mapping, - dtype) if dtype else None + if self.strategy == AllReduceStrategy.MNNVL and ( + dtype and dtype in MNNVLAllReduce.get_supported_dtypes() + ) and (not self.mapping.has_cp()): + self.mnnvl_allreduce = MNNVLAllReduce(self.mapping, + dtype) def forward( self, diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index c2f817c25a2..f5a3d5f4199 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -8,6 +8,8 @@ from tensorrt_llm import logger from tensorrt_llm._utils import torch_dtype_to_binding +from tensorrt_llm.functional import AllReduceStrategy +from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization.mode import QuantAlgo @@ -77,6 +79,7 @@ class ModelConfig(Generic[TConfig]): attn_backend: str = 'TRTLLM' moe_backend: str = 'CUTLASS' # options can be CUTLASS, TRTLLM + allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO # If true, enable min-latency mode. Currently only used for Llama4. 
enable_min_latency: bool = False @@ -106,6 +109,24 @@ def __post_init__(self): self.is_generation = self.is_generation_model( self.pretrained_config.architectures) + def get_all_reduce_strategy(strategy: str = "AUTO"): + maps = { + "AUTO": AllReduceStrategy.AUTO, + "NCCL": AllReduceStrategy.NCCL, + "UB": AllReduceStrategy.UB, + "MINLATENCY": AllReduceStrategy.MIN_LATENCY, + "ONESHOT": AllReduceStrategy.ONESHOT, + "TWOSHOT": AllReduceStrategy.TWOSHOT, + "LOWPRECISION": AllReduceStrategy.LOWPRECISION, + "MNNVL": AllReduceStrategy.MNNVL + } + key = strategy.upper() + return maps[key] if key in maps else AllReduceStrategy.AUTO + + if isinstance(self.allreduce_strategy, str): + self.allreduce_strategy = get_all_reduce_strategy( + self.allreduce_strategy) + @property def fuse_pos_embd(self): if self.attn_backend == 'TRTLLM': diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index ff22d3717ce..f5d3417f88f 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -399,7 +399,8 @@ def __init__(self, overridden_tp_size=shared_tp_size, reduce_output=False) - self.allreduce = AllReduce(self.mapping) + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] self.event_dict = { key: torch.cuda.Event() @@ -628,7 +629,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], eps=config.rms_norm_eps, dtype=config.torch_dtype) self.layer_idx = layer_idx - self.allreduce = AllReduce(self.mapping, dtype=config.torch_dtype) + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy, + dtype=config.torch_dtype) self.moe_allreduce = MoEAllReduce(self.mapping) self.next_layer_layernorm: RMSNorm = None diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 600808c6b61..d6ffeac2ca1 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -282,7 +282,10 @@ def __init__( quant_config=None) self.mapping = model_config.mapping - self.all_reduce = AllReduce(self.mapping) + self.all_reduce = AllReduce( + mapping=model_config.mapping, + strategy=model_config.allreduce_strategy, + ) self.moe_event = [torch.cuda.Event(), torch.cuda.Event()] self.aux_stream = aux_stream @@ -414,7 +417,8 @@ def __init__( dtype=config.torch_dtype) self.mapping = model_config.mapping - self.all_reduce = AllReduce(self.mapping) + self.all_reduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) self.next_layer_layernorm: RMSNorm = None self.next_attn: LlamaAttention = None @@ -625,7 +629,7 @@ def __init__( quant_config=model_config.get_quant_config(), skip_create_weights_in_init=model_config. 
skip_create_weights_in_init, - ) + allreduce_strategy=model_config.allreduce_strategy) class Eagle3LlamaDecoderLayer(DecoderLayer): diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_nas.py b/tensorrt_llm/_torch/models/modeling_nemotron_nas.py index ef562979543..333f52532aa 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_nas.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_nas.py @@ -44,7 +44,7 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig], gather_output=True, quant_config=model_config.get_quant_config(), skip_create_weights_in_init=model_config.skip_create_weights_in_init, - ) + allreduce_strategy=model_config.allreduce_strategy) class NemotronNASAttention(Attention): diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py index f15a21df31d..5e6f67a8d42 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py @@ -89,7 +89,8 @@ def __init__( self.top_k = config.num_experts_per_tok self.enable_attention_dp = model_config.mapping.enable_attention_dp self.mapping = model_config.mapping - self.allreduce = AllReduce(self.mapping) + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) self.enable_alltoall = Qwen3MoE.should_enable_alltoall( model_config, self.top_k) if self.enable_alltoall: @@ -202,7 +203,8 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig], dtype=config.torch_dtype) self.layer_idx = layer_idx - self.allreduce = AllReduce(self.mapping) + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) self.next_layer_layernorm: RMSNorm = None self.fusion_config = EagerFusionConfig() diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index cc9031bc288..94574d3f9d7 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -126,7 +126,7 @@ def __init__( weight_mode=WeightMode.FUSED_QKV_LINEAR), quant_config=config.get_quant_config(), skip_create_weights_in_init=config.skip_create_weights_in_init, - ) + allreduce_strategy=config.allreduce_strategy) self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE], [self.hidden_size]) @@ -140,7 +140,7 @@ def __init__( quant_config=config.get_quant_config(), skip_create_weights_in_init=config.skip_create_weights_in_init, lora=self.o_lora, - ) + allreduce_strategy=config.allreduce_strategy) self.quant_config = config.get_quant_config() self.attn_backend = config.attn_backend @@ -481,7 +481,8 @@ def __init__( mapping=mapping, tensor_parallel_mode=TensorParallelMode.COLUMN, quant_config=quant_config, - skip_create_weights_in_init=config.skip_create_weights_in_init) + skip_create_weights_in_init=config.skip_create_weights_in_init, + allreduce_strategy=config.allreduce_strategy) else: self.fused_a = Linear( hidden_size, @@ -501,7 +502,7 @@ def __init__( tensor_parallel_mode=TensorParallelMode.COLUMN, quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, - ) + allreduce_strategy=config.allreduce_strategy) self.q_b_proj = self.q_proj self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank, @@ -517,7 +518,8 @@ def __init__( mapping=mapping, tensor_parallel_mode=TensorParallelMode.COLUMN, quant_config=quant_config, - skip_create_weights_in_init=config.skip_create_weights_in_init) + skip_create_weights_in_init=config.skip_create_weights_in_init, + 
allreduce_strategy=config.allreduce_strategy) # This parameter will view into self.kv_b_proj.weight after loading weights. # For dummy weight initialization, this parameter is initialized with empty tensor. # Used in forward_generation only @@ -538,7 +540,7 @@ def __init__( tensor_parallel_mode=TensorParallelMode.ROW, quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, - ) + allreduce_strategy=config.allreduce_strategy) def yarn_get_mscale(scale=1, mscale=1): if scale <= 1: diff --git a/tensorrt_llm/_torch/modules/fused_moe.py b/tensorrt_llm/_torch/modules/fused_moe.py new file mode 100755 index 00000000000..334919050ec --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe.py @@ -0,0 +1,2515 @@ +import copy +import math +import os +import threading +from enum import Enum, IntEnum +from typing import Dict, List, NamedTuple, Optional, Union + +import torch +from torch import nn + +from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo +from tensorrt_llm._utils import get_sm_version, logger +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.quantization.utils import fp4_utils +from tensorrt_llm.quantization.utils.fp4_utils import ( + get_reorder_rows_for_gated_act_gemm_row_indices, + get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices, + shuffle_matrix_a, shuffle_matrix_sf_a) + +from ...quantization.utils.fp4_utils import float4_sf_dtype +from ..distributed import allgather, reducescatter +from ..expert_statistic import ExpertStatistic +from ..model_config import ModelConfig, MoeLoadBalancerConfig +from ..utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather, + reswizzle_sf, swizzle_sf, unswizzle_sf) +from .gated_mlp import GatedMLP +from .linear import TensorParallelMode, load_weight_shard +from .moe_load_balancer import MoeLoadBalancer + +# The declarations aligns with moe_kernels.h +# pack inputs into int64, e.g. 4 x bf16 input values +FUSED_MOE_NVFP4_INPUT_DTYPE = torch.int64 +# pack weights into int64, e.g. 16 x nvfp4 weight values +FUSED_MOE_NVFP4_WEIGHT_DTYPE = torch.int64 +# pack weight block scales into int32, e.g. 4 x fp8 weight values +FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE = torch.int32 + + +# The type of method in top-K routing, for use in torch custom op +# Please keep this in sync with the counterpart defined in cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h +class RoutingMethodType(IntEnum): + # Default: Softmax -> TopK + Default = 0, + # Renormalize: TopK -> Softmax + Renormalize = 1, + # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts from the Top4 groups + DeepSeekV3 = 2, + # Llama4: Top1 -> Sigmoid + Llama4 = 3, + # Qwen3: Softmax -> TopK -> Renormalize + Qwen3 = 4, + # Unspecified + Unspecified = 5. + + +class BaseMoeRoutingMethod(nn.Module): + + def apply(self, _router_logits) -> (torch.Tensor, torch.Tensor): + """ + Applies the routing method to the router logits. + Router logits are usually the output of the router Linear layer, but can be any type for more complex routing methods. + Returns (token_selected_experts: torch.Tensor, token_final_scales: torch.Tensor): + token_selected_experts: shape (num_tokens, experts_per_token). + It is a list of selected expert indices for each token + token_final_scales: shape (num_tokens, experts_per_token). 
May be None + It contains a final scaling/weighting factor applied to the output of each selected expert before summing the results + """ + raise NotImplementedError("Subclasses must implement this method") + + def get_experts_per_token(self): + return self.top_k + + @property + def experts_per_token(self): + return self.get_experts_per_token() + + @property + def routing_method_type(self): + return RoutingMethodType.Unspecified + + +class DefaultMoeRoutingMethod(BaseMoeRoutingMethod): + + def __init__(self, top_k: int): + super().__init__() + self.top_k = top_k + + def apply(self, + router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + topk_values, topk_indices = torch.topk(torch.nn.functional.softmax( + router_logits.float(), dim=-1), + k=self.top_k, + dim=-1) + return topk_indices.to(torch.int32), topk_values + + @property + def routing_method_type(self): + return RoutingMethodType.Default + + +class DeepSeekV3MoeRoutingMethod(BaseMoeRoutingMethod): + + # Intentionally leave apply() unimplemented. + # See comments in DeepseekV3Gate on why routing is done by DeepseekV3Gate. + def __init__(self, top_k: int): + super().__init__() + self.top_k = top_k + + @property + def routing_method_type(self): + return RoutingMethodType.DeepSeekV3 + + +class RenormalizeMoeRoutingMethod(BaseMoeRoutingMethod): + + def __init__( + self, + top_k: int, + ): + super().__init__() + self.top_k = top_k + + def apply(self, + router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + topk_values, topk_indices = torch.topk(router_logits, + k=self.top_k, + dim=-1) + return topk_indices.to(torch.int32), torch.nn.functional.softmax( + topk_values.float(), dim=-1) + + @property + def routing_method_type(self): + return RoutingMethodType.Renormalize + + +class Llama4RenormalizeMoeRoutingMethod(BaseMoeRoutingMethod): + + def __init__(self, top_k: int): + super().__init__() + self.top_k = top_k + + def apply(self, + router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + topk_values, topk_indices = torch.topk(router_logits, + k=self.top_k, + dim=-1) + return topk_indices.to(torch.int32), torch.sigmoid(topk_values.float()) + + @property + def routing_method_type(self): + return RoutingMethodType.Llama4 + + +# TODO: re-enable this once the custom op is working. 
+# class Llama4RenormalizeMoeRoutingMethod(BaseMoeRoutingMethod): + +# def __init__(self, top_k: int, num_experts_total: int, ep_size: int, +# ep_rank: int): +# super().__init__() +# self.top_k = top_k +# self.num_experts_total = num_experts_total +# self.num_experts_per_node = self.num_experts_total // ep_size +# self.start_expert = self.num_experts_per_node * ep_rank +# self.end_expert = self.start_expert + self.num_experts_per_node + +# def apply(self, +# router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): +# unpermuted_scales, indices = torch.ops.trtllm.fused_topk_softmax( +# router_logits, self.top_k, self.num_experts_total, +# self.start_expert, self.end_expert) +# return indices, unpermuted_scales + + +# TODO Test this for Phi models +class SparseMixerMoeRoutingMethod(BaseMoeRoutingMethod): + + def __init__(self, top_k: int, eps: float): + super().__init__() + self.top_k = top_k + self.eps = eps + + def apply(self, + router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + router_logits = router_logits.float() + topk_values = torch.empty(router_logits.shape[0], + self.top_k, + device=router_logits.device, + dtype=torch.float32) + topk_indices = torch.empty(router_logits.shape[0], + self.top_k, + device=router_logits.device, + dtype=torch.int32) + for i in range(self.top_k): + if i > 0: + max_elem = torch.argmax(router_logits, dim=-1) + # Mask out the previously selected indices to negative infinity + router_logits.scatter_(-1, max_elem.unsqueeze(-1), + -float('inf')) + # Get the max value of the remaining indices + max_values, max_indices = torch.max(router_logits, + dim=-1, + keepdim=True) + assert torch.all(max_values != -float('inf')) + + topk_indices[:, i] = max_indices.squeeze(-1) + + # Mask out any values that fail the condition '(max - value) / std::max(abs(value), max) > 2 * epsilon' + mask = ( + (max_values - router_logits) / + torch.max(torch.abs(router_logits), max_values)) > 2 * self.eps + masked_logits = torch.where(mask, -float('inf'), router_logits) + softmax_masked_logits = torch.nn.functional.softmax(masked_logits, + dim=-1) + selected_values = torch.gather(softmax_masked_logits, -1, + max_indices) + topk_values[:, i] = selected_values.squeeze(-1) + + return topk_indices.to(torch.int32), topk_values + + +class StaticMoeRoutingMethod(BaseMoeRoutingMethod): + + def __init__(self, + routing_tensor: torch.Tensor, + routing_scales: Optional[torch.Tensor] = None): + super().__init__() + assert routing_tensor.dtype == torch.int32 + if routing_scales is not None: + assert routing_tensor.shape[0] == routing_scales.shape[0] + assert routing_tensor.shape[1] == routing_scales.shape[1] + assert routing_scales.dtype == torch.float32 + self.routing_tensor = routing_tensor + self.routing_scales = routing_scales + + def apply(self, + router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + return self.routing_tensor, self.routing_scales + + def get_experts_per_token(self): + return self.routing_tensor.shape[1] + + +class LoadBalancedMoeRoutingMethod(BaseMoeRoutingMethod): + + def __init__(self, top_k: int): + super().__init__() + self.top_k = top_k + + def apply(self, + router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + balanced_values = torch.ones(router_logits.shape[0], + self.top_k, + device=router_logits.device, + dtype=torch.float32) + balanced_indices = torch.empty(router_logits.shape[0], + self.top_k, + device=router_logits.device, + dtype=torch.int32) + + # Fill the balanced_indices with each expert in round-robin fashion + final_size = 
router_logits.shape[0] * self.top_k + repeat_count = math.ceil(final_size / router_logits.shape[1]) + indices = torch.arange(router_logits.shape[1], + device=router_logits.device, + dtype=torch.int32) + indices = indices.repeat(repeat_count) + indices = indices[:final_size] + balanced_indices = indices.view(router_logits.shape[0], + self.top_k).contiguous() + + return balanced_indices, balanced_values + + +class Qwen3MoeRoutingMethod(BaseMoeRoutingMethod): + + def __init__(self, top_k: int): + super().__init__() + self.top_k = top_k + + def apply(self, + router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + + routing_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float) + topk_values, topk_indices = torch.topk(routing_weights, + k=self.top_k, + dim=-1) + topk_values /= topk_values.sum(dim=-1, keepdim=True) + return topk_indices.to(torch.int32), topk_values + + @property + def routing_method_type(self) -> RoutingMethodType: + return RoutingMethodType.Qwen3 + + +class MoEWeightLoadingMode(Enum): + VANILLA = 0 + FUSED_GATE_UP_PROJ = 1 + + +class VanillaMoE(nn.ModuleList): + + def __init__( + self, + *, + routing_method: BaseMoeRoutingMethod, + num_experts: int, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + model_config: ModelConfig = ModelConfig(), + aux_stream: Optional[torch.cuda.Stream] = None, + weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. + VANILLA, + apply_router_weight_on_input: bool = False, + enable_alltoall: bool = False, + pack_weights: bool = False, + ): + from ..distributed import AllReduce + + super().__init__() + self.routing_method = routing_method + self.num_experts = num_experts + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.weight_loading_mode = weight_loading_mode + self.pack_weights = pack_weights + + self.dtype = dtype + self.reduce_results = reduce_results + self.model_config = model_config + # could be modified later + self.quant_config = model_config.quant_config + + self.cluster_rank = model_config.mapping.moe_cluster_rank + self.cluster_size = model_config.mapping.moe_cluster_size + self.smart_router = True if self.cluster_size > 1 else False + assert not self.smart_router, ( + "Smart router is not supported in vanilla MoE, " + "please set moe_cluster_size to 1.") + + self.rank = model_config.mapping.rank + + self.tp_rank = model_config.mapping.moe_tp_rank + self.tp_size = model_config.mapping.moe_tp_size + + self.ep_size = model_config.mapping.moe_ep_size + self.ep_rank = model_config.mapping.moe_ep_rank + self.moe_backend = model_config.moe_backend + self.use_dp = model_config.mapping.enable_attention_dp + + # All ranks participate in allreduce regardless of EP/TP combination + self.mapping = model_config.mapping + self.parallel_size = self.mapping.tp_size + + self.all_reduce = AllReduce(mapping=self.mapping, + strategy=model_config.allreduce_strategy) + + self.intermediate_size_per_partition = intermediate_size // self.tp_size + + self.expert_size_per_partition = num_experts // self.ep_size + self.expert_start = self.ep_rank * self.expert_size_per_partition + self.expert_end = min( + self.expert_start + self.expert_size_per_partition, + self.num_experts) + self.expert_size_per_partition = self.expert_end - self.expert_start + + max_num_tokens = model_config.max_num_tokens + # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled + if self.use_dp: + 
max_num_tokens *= model_config.mapping.world_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens + + self.enable_alltoall = False + + self._weights_created = False + if not model_config.skip_create_weights_in_init: + self.create_weights() + + # If True, the router weight will be multiplied on the input rather than at the end of FC2 + self.apply_router_weight_on_input = apply_router_weight_on_input + + def create_experts(self, module_list: nn.ModuleList = None): + if module_list is None: + module_list = self + model_config = copy.copy(self.model_config) + model_config.mapping = Mapping( + world_size=self.mapping.moe_tp_size, + tp_size=self.mapping.moe_tp_size, + rank=self.mapping.moe_tp_rank, + ) + model_config.quant_config = self.quant_config + model_config.skip_create_weights_in_init = False + for expert_idx in range(self.num_experts): + if self.expert_start <= expert_idx < self.expert_end: + module_list[expert_idx] = GatedMLP( + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + bias=False, + dtype=self.dtype, + config=model_config, + reduce_output=False, + ) + else: + # use identity as placeholder for unused experts + module_list[expert_idx] = nn.Identity() + + def create_weights(self): + if self._weights_created: + return + self._weights_created = True + + if not self.pack_weights: + self.create_experts() + return + + self.has_any_quant = False + self.has_fp8_qdq = False + self.has_fp8_block_scales = False + self.has_nvfp4 = False + gate_up_proj_shape = ( + self.expert_size_per_partition, + self.intermediate_size_per_partition * 2, + self.hidden_size, + ) + down_proj_shape = ( + self.expert_size_per_partition, + self.hidden_size, + self.intermediate_size_per_partition, + ) + if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True): + self.has_any_quant = True + qc = self.quant_config + if qc.layer_quant_mode.has_fp8_qdq(): + self.has_fp8_qdq = True + + self.gate_up_proj_weight = nn.Parameter( + torch.empty( + gate_up_proj_shape, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + self.gate_up_proj_weight_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.gate_up_proj_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.gate_up_proj_inv_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + + self.down_proj_weight = nn.Parameter( + torch.empty( + down_proj_shape, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + self.down_proj_weight_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.down_proj_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.down_proj_inv_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + elif qc.layer_quant_mode.has_fp8_block_scales(): + self.has_fp8_block_scales = True + + self.gate_up_proj_weight = nn.Parameter( + torch.empty( + gate_up_proj_shape, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + gate_up_proj_scale_shape = ( + self.expert_size_per_partition, + 
math.ceil(self.intermediate_size_per_partition * 2 / 128), + math.ceil(self.hidden_size / 128), + ) + self.gate_up_proj_weight_scale = nn.Parameter( + torch.empty( + gate_up_proj_scale_shape, + dtype=torch.float32, + ), + requires_grad=False, + ) + # Not really used for Gemm now. + # Only used to quantize output of FP8 attention. + self.gate_up_proj_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.gate_up_proj_inv_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + + self.down_proj_weight = nn.Parameter( + torch.empty( + down_proj_shape, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + down_proj_scale_shape = ( + self.expert_size_per_partition, + math.ceil(self.hidden_size / 128), + math.ceil(self.intermediate_size_per_partition / 128), + ) + self.down_proj_weight_scale = nn.Parameter( + torch.empty( + down_proj_scale_shape, + dtype=torch.float32, + ), + requires_grad=False, + ) + # Not really used for Gemm now. + # Only used to quantize output of FP8 attention. + self.down_proj_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.down_proj_inv_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + elif qc.layer_quant_mode.has_nvfp4(): + self.has_nvfp4 = True + self.scaling_vector_size = 16 + + assert self.hidden_size % self.scaling_vector_size == 0, f"hidden_size {self.hidden_size} must be divisible by scaling_vector_size {self.scaling_vector_size}" + + # Quantized weights + self.gate_up_proj_weight = nn.Parameter( + torch.empty( + [ + self.expert_size_per_partition, + self.intermediate_size_per_partition * 2, + self.hidden_size // 2, + ], + dtype=fp4_utils.float4_e2m1x2, + ), + requires_grad=False, + ) + + # FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE + # Padding is required. See computeSFSize in quantization.h + nrows = fp4_utils.pad_up( + self.intermediate_size_per_partition * 2, 128) + ncols = fp4_utils.pad_up( + self.hidden_size // self.scaling_vector_size, 4) + self.gate_up_proj_weight_scale = nn.Parameter( + torch.empty( + [self.expert_size_per_partition, nrows * ncols], + dtype=fp4_utils.float4_sf_dtype, + ), + requires_grad=False, + ) + + # FP32 per-tensor global scaling factor = 448*6/amax_input + self.gate_up_proj_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.gate_up_proj_inv_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + + # (amax_input*amax_weight) / (448*6*448*6) + self.gate_up_proj_alpha = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + + assert self.intermediate_size_per_partition % self.scaling_vector_size == 0, f"intermediate_size_per_partition {self.intermediate_size_per_partition} must be divisible by scaling_vector_size {self.scaling_vector_size}" + + # Quantized weights + self.down_proj_weight = nn.Parameter( + torch.empty( + [ + self.expert_size_per_partition, + self.hidden_size, + self.intermediate_size_per_partition // 2, + ], + dtype=fp4_utils.float4_e2m1x2, + ), + requires_grad=False, + ) + + # FP8 per-block scaling factors. 
dtype must be aligned with SF_DTYPE + # Padding is required. See computeSFSize in quantization.h + nrows = fp4_utils.pad_up(self.hidden_size, 128) + ncols = fp4_utils.pad_up( + self.intermediate_size_per_partition // + self.scaling_vector_size, 4) + self.down_proj_weight_scale = nn.Parameter( + torch.empty( + [self.expert_size_per_partition, nrows * ncols], + dtype=fp4_utils.float4_sf_dtype, + ), + requires_grad=False, + ) + + # FP32 per-tensor global scaling factor = 448*6/amax_input + self.down_proj_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.down_proj_inv_input_scale = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + + # (amax_input*amax_weight) / (448*6*448*6) + self.down_proj_alpha = nn.Parameter( + torch.empty( + self.expert_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) + else: + raise ValueError(f'unsupported quant mode: {qc.quant_mode}') + else: + self.gate_up_proj_weight = nn.Parameter( + torch.empty(gate_up_proj_shape, dtype=self.dtype), + requires_grad=False, + ) + self.down_proj_weight = nn.Parameter( + torch.empty(down_proj_shape, dtype=self.dtype), + requires_grad=False, + ) + + def pack_params(self, experts, module_name: str, weight_name: str): + weights = [] + for expert_idx in range(self.expert_start, self.expert_end): + weights.append( + getattr(getattr(experts[expert_idx], module_name), weight_name)) + packed_weight = torch._utils._flatten_dense_tensors(weights) + weights_data = torch._utils._unflatten_dense_tensors( + packed_weight, weights) + for weight, data in zip(weights, weights_data): + weight.data = data + packed_weight = packed_weight.view(len(weights), *weights_data[0].shape) + getattr(self, f"{module_name}_{weight_name}").data = packed_weight + + def load_weights(self, weights: List[Dict]): + from ..models.modeling_utils import filter_weights + + assert self._weights_created + assert len(weights) == 1 + weights = weights[0] + + if self.pack_weights: + experts = nn.ModuleList([None] * self.num_experts) + self.create_experts(experts) + experts.to("cuda") + else: + experts = self + + for expert_idx in range(self.expert_start, self.expert_end): + experts[expert_idx].gate_up_proj.load_weights([ + filter_weights(f"{expert_idx}.w1", weights), + filter_weights(f"{expert_idx}.w3", weights), + ]) + experts[expert_idx].down_proj.load_weights([ + filter_weights(f"{expert_idx}.w2", weights), + ]) + + if self.pack_weights: + for module_name in ["gate_up_proj", "down_proj"]: + for weight_name, _ in getattr(experts[self.expert_start], + module_name).named_parameters(): + self.pack_params(experts, module_name, weight_name) + + def reducescatter_or_allreduce( + self, + inputs, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + ): + outputs = inputs + if self.parallel_size > 1 and not self.enable_alltoall: + if self.use_dp: + outputs = reducescatter( + inputs, + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens) + elif self.reduce_results: + outputs = self.all_reduce(inputs) + return outputs + + def run_experts( + self, + input: torch.Tensor, + expanded_inputs: torch.Tensor, + expanded_scales: torch.Tensor, + sorted_experts: torch.Tensor, + batch_indices: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros( + input.shape, + dtype=input.dtype, + device=input.device, + ) + for expert_idx in 
range(self.expert_start, self.expert_end): + expert_mask = sorted_experts == expert_idx + if not torch.any(expert_mask): + continue + expanded_input = expanded_inputs[expert_mask] + batch_idx = batch_indices[expert_mask] + expanded_scale = expanded_scales[expert_mask] + + output = self[expert_idx](expanded_input) + final_hidden_states[batch_idx] += output * expanded_scale + return final_hidden_states + + def forward( + self, + x: torch.Tensor, + router_logits: torch.Tensor, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + **kwargs, + ) -> torch.Tensor: + assert x.shape[-1] == self.hidden_size + x = x.view(-1, self.hidden_size) + + token_selected_experts, token_final_scales = self.routing_method.apply( + router_logits) + + if self.use_dp and self.parallel_size > 1: + x, token_selected_experts, token_final_scales = allgather( + [x, token_selected_experts, token_final_scales], + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens) + + expert_masks = ((token_selected_experts >= self.expert_start) + & (token_selected_experts < self.expert_end)) + local_selected_experts = token_selected_experts[expert_masks] + sort_indices = torch.argsort(local_selected_experts) + sorted_experts = local_selected_experts[sort_indices] + + batch_indices, nth_experts = torch.where(expert_masks) + batch_indices = batch_indices[sort_indices] + nth_experts = nth_experts[sort_indices] + expanded_inputs = x[batch_indices] + expanded_scales = token_final_scales[batch_indices, nth_experts, None] + + final_hidden_states = self.run_experts( + x, + expanded_inputs, + expanded_scales, + sorted_experts, + batch_indices, + ) + + final_hidden_states = self.reducescatter_or_allreduce( + final_hidden_states, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + ) + return final_hidden_states + + +class FusedMoE(nn.Module): + """ + Fused Mixture of Experts (MoE) Layer with performance tuning. + + Args: + num_experts (int): Number of experts in the MoE layer. + top_k (int): Number of top experts to select for each input token. + hidden_size (int): Size of the hidden state. + intermediate_size (int): Size of the intermediate state. + aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks. + dtype (Optional[torch.dtype]): Data type for the weights. + reduce_results (bool): Whether to reduce the results across devices. + model_config (ModelConfig): Configuration object for the model. + enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter + + MoE torch custom op: + cutlass Backend + In min-latency mode: + Quant: + fp8 block scales (SM90 Hopper only): + FusedMoE Op: dynamic quant + gemm1 + swiglu + gemm2 (return tensor list). + fp8 qdq, nvfp4: + FusedMoE Op: gemm1 + swiglu + gemm2 (return tensor list). + + In max-throughput mode: + Quant: + fp8 block scales (SM90 Hopper only): + FusedMoE Op: dynamic quant + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute (return one tensor) + p8 qdq, nvfp4: + FusedMoE Op: scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute (return one tensor) + + trtllm_gen backend: + Only support min-latency mode now (SM100 Blackwell only). + Quant: fp8 block scales quant and nvfp4 quant + FusedMoE Op: routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalize MoeRoute + + FusedMoE module: + cutlass Backend (moe_backend="CUTLASS"): + min-latency mode: + routing(topK, etc.) + FusedMoE Op + equals to: routing(topK, etc.) 
[+ dynamic quant fp8 qdq | optional dynamic quant nvfp4] + gemm1 + swiglu + gemm2 + + max-throughput mode: + routing(topK, etc.) [+ dynamic quant for fp8 qdq and nvfp4 ] [+ fp4_allgather] + FusedMoe Op[no allreduce] + reducescatter, with AttentionDP on + equals to: dynamic quant + routing(topK, etc.) [+ fp4_allgather] + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute [no allreduce] + reducescatter + + trtllm_gen backend (moe_backend="TRTLLM"): + min-latency mode (cutlass_min_latency_mode flag of forward has no effect when trtllm_gen is used): + dynamic quant + FusedMoe Op + equals to: dynamic quant + routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalize MoeRoute + + In min-latency mode, setting `reduce_results=False` disables the AllReduce in the FusedMoE module, so any necessary AllReduce operations must be added explicitly in the model definition. + AttentionDP should be turned off for min-latency mode. + + When we have redundant expert, we have more weight slots than `num_experts`, in that case, we separate the concepts of expert and slot. + Expert is the concept from model's perspective while slot is the concept from model engine's perspective. + There should be at lease `num_experts` slots in the model engine. More than that is OK, in that case, some experts may have multiple replicas. + """ + + def __init__( + self, + *, + routing_method: BaseMoeRoutingMethod, + num_experts: int, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + model_config: ModelConfig = ModelConfig(), + aux_stream: Optional[torch.cuda.Stream] = None, + weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. + VANILLA, + apply_router_weight_on_input: bool = False, + enable_alltoall: bool = False, + moe_load_balancer: Optional[MoeLoadBalancer] = None, + layer_idx: Optional[int] = None, + ): + from ..distributed import AllReduce + + super().__init__() + self.routing_method = routing_method + self.num_experts = num_experts + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.weight_loading_mode = weight_loading_mode + + self.dtype = dtype + self.reduce_results = reduce_results + # could be modified later + self.quant_config = model_config.quant_config + + self.cluster_rank = model_config.mapping.moe_cluster_rank + self.cluster_size = model_config.mapping.moe_cluster_size + self.smart_router = True if self.cluster_size > 1 else False + + self.rank = model_config.mapping.rank + + self.tp_rank = model_config.mapping.moe_tp_rank + self.tp_size = model_config.mapping.moe_tp_size + + self.ep_size = model_config.mapping.moe_ep_size + self.ep_rank = model_config.mapping.moe_ep_rank + self.moe_backend = model_config.moe_backend + self.use_dp = model_config.mapping.enable_attention_dp + + # All ranks participate in allreduce regardless of EP/TP combination + self.mapping = model_config.mapping + self.parallel_size = self.mapping.tp_size + + self.all_reduce = AllReduce(mapping=self.mapping, + strategy=model_config.allreduce_strategy) + + self.intermediate_size_per_partition = intermediate_size // self.tp_size + + self.layer_idx = layer_idx + moe_load_balancer_config = model_config.moe_load_balancer + if moe_load_balancer_config is None: + assert moe_load_balancer is None + # A dummy MoeLoadBalancerConfig to generate default initial_global_assignments + moe_load_balancer_config = MoeLoadBalancerConfig() + moe_load_balancer_config.setup(num_experts=num_experts, + ep_rank=self.ep_rank, + 
ep_size=self.ep_size) + else: + assert moe_load_balancer is not None + + self.num_slots = moe_load_balancer_config.num_slots + if self.smart_router: + assert self.num_slots == self.num_experts, "Smart router should not have redundant slots" + + self.initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments( + layer_idx) + self.expert_size_per_partition = moe_load_balancer_config.num_local_slots + self.slot_start = moe_load_balancer_config.slot_start + self.slot_end = moe_load_balancer_config.slot_end + self.initial_local_expert_ids = self.initial_global_assignments[ + self.slot_start:self.slot_end] + assert len( + self.initial_local_expert_ids) == self.expert_size_per_partition + + self.balancer_layer = None + if moe_load_balancer is not None: + self.balancer_layer = moe_load_balancer.add_layer( + expert_count=num_experts, + top_k=routing_method.experts_per_token, + slot_count_per_rank=self.expert_size_per_partition, + ) + self.balancer_layer.set_initial_weight_assignments( + self.initial_global_assignments) + logger.info( + f"MoE load balancer enabled. num_experts = {num_experts}, num_slots = {self.num_slots}, ep_size = {self.ep_size}" + ) + logger.info( + f"initial_global_assignments (layer {layer_idx}) = {self.initial_global_assignments}" + ) + + max_num_tokens = model_config.max_num_tokens + # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled + if self.use_dp: + max_num_tokens *= model_config.mapping.world_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens + # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied + if self.moe_max_num_tokens < max_num_tokens: + self.aux_stream = aux_stream if aux_stream is not None else torch.cuda.Stream( + ) + self.event_dict = { + key: torch.cuda.Event() + for key in [EventType.Main, EventType.MoeChunkingOverlap] + } + else: + self.aux_stream = None + self.event_dict = None + + # The profiler converges on the same best tactic when the number of tokens is large enough. + # To avoid long profiling time, the max number of tokens used in the profiling is capped to + # around 16k tokens per expert, which is well into the compute bound domain. 
+ self.tune_max_num_tokens = min( + self.moe_max_num_tokens, + 16384 * self.num_slots // routing_method.get_experts_per_token(), + ) + self.has_been_profiled = False + self.has_been_profiled_min_latency = False + + self.enable_alltoall = enable_alltoall + self.use_postquant_alltoall = False + if self.enable_alltoall: + assert self.use_dp and self.parallel_size > 1,\ + "alltoall should only enabled with attention dp and parallel_size > 1" + qm = self.quant_config.quant_mode + self.use_postquant_alltoall = (os.environ.get( + "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1") + == "1") and qm.has_nvfp4() + self.alltoall_workspace = MnnvlMoe.get_moe_workspaces( + model_config.mapping) if enable_alltoall else None + + self._weights_created = False + if not model_config.skip_create_weights_in_init: + self.create_weights() + + # If True, the router weight will be multiplied on the input rather than at the end of FC2 + self.apply_router_weight_on_input = apply_router_weight_on_input + self._check_configs() + + @property + def has_any_quant(self): + return self.quant_config and self.quant_config.quant_mode.has_any_quant( + exclude_kv_cache=True) + + def _check_configs(self): + if self.enable_alltoall: + assert self.use_dp and self.parallel_size > 1,\ + "alltoall should only enabled with attention dp and parallel_size > 1" + + if self.is_trtllm(): + # trtllm_gen backend only support min-latency mode now + assert not self.apply_router_weight_on_input, "TRTLLM backend does not support applying router weight on input yet." + assert not self.reduce_results + assert self.quant_config and ( + self.quant_config.quant_mode.has_nvfp4() + | self.quant_config.quant_mode.has_fp8_block_scales() + ), "The TRTLLM backend of FusedMoE only supports fp8_block_scaling and nvfp4 dtypes." + else: + if self.apply_router_weight_on_input: + assert self.routing_method.top_k == 1, "Current walkaround only supports top-1 routing" + if self.quant_config and self.quant_config.quant_mode.has_any_quant( + exclude_kv_cache=True): + if not (self.quant_config.quant_mode.has_nvfp4() + | self.quant_config.quant_mode.has_fp8_block_scales() + | self.quant_config.quant_mode.has_fp8_qdq() + | self.quant_config.quant_mode. 
+ is_int4_weight_only_per_group()): + raise ValueError( + f"unsupported quantization mode: {self.quant_config.quant_mode}" + ) + + def setup_quant_scales(self): + self.quant_scales = None + if not self.has_any_quant: + return + if self.has_fp8_qdq: + self.quant_scales = FusedMoEQuantScalesFP8( + fc1_dequant=self.fc31_dequant, + fc2_quant=self.fc2_quant, + fc2_dequant=self.fc2_dequant, + fc1_input_dequant=self.fc31_input_dequant, + ) + elif self.has_fp8_block_scales: + self.quant_scales = FusedMoEQuantScalesFP8BlockScales( + fc_weight_scales=self.w3_w1_weight_scaling_factor, + proj_weight_scales=self.w2_weight_scaling_factor, + ) + elif self.has_nvfp4: + self.quant_scales = FusedMoEQuantScalesNVFP4( + fc1_act_global=self.fc31_input_scale, + fc1_weight_block=self.w3_w1_weight_scale, + fc1_global=self.fc31_alpha, + fc2_act_global=self.fc2_input_scale, + fc2_weight_block=self.w2_weight_scale, + fc2_global=self.fc2_alpha, + ) + elif self.has_w4afp8: + self.quant_scales = FusedMoEQuantScalesW4A8( + scale_1_interleaved=self.fc31_weight_scale, + scale_2_interleaved=self.fc2_weight_scale, + pre_quant_scale_1=self.fc31_act_scale, + pre_quant_scale_2=self.fc2_act_scale, + zero_1=torch.Tensor(), + zero_2=torch.Tensor(), + alpha_1=self.fc31_alpha, + alpha_2=self.fc2_alpha, + ) + + def is_trtllm(self): + return self.moe_backend == "TRTLLM" and self.has_any_quant + + def is_cutlass(self): + return not self.is_trtllm() + + def get_quant_scales(self, slot_start, slot_end): + assert self.smart_router + + if self.has_fp8_block_scales: + return FusedMoEQuantScalesFP8BlockScales( + fc_weight_scales=self.w3_w1_weight_scaling_factor.narrow( + 0, slot_start, slot_end - slot_start), + proj_weight_scales=self.w2_weight_scaling_factor.narrow( + 0, slot_start, slot_end - slot_start), + ) + elif self.has_nvfp4: + return FusedMoEQuantScalesNVFP4( + fc1_act_global=self.fc31_input_scale, + fc1_weight_block=self.w3_w1_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + fc1_global=self.fc31_alpha.narrow(0, slot_start, + slot_end - slot_start), + fc2_act_global=self.fc2_input_scale, + fc2_weight_block=self.w2_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + fc2_global=self.fc2_alpha.narrow(0, slot_start, + slot_end - slot_start), + ) + elif self.has_w4afp8: + return FusedMoEQuantScalesW4A8( + scale_1_interleaved=self.fc31_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + scale_2_interleaved=self.fc2_weight_scale.narrow( + 0, slot_start, slot_end - slot_start), + pre_quant_scale_1=self.fc31_act_scale.narrow( + 0, slot_start, slot_end - slot_start), + pre_quant_scale_2=self.fc2_act_scale.narrow( + 0, slot_start, slot_end - slot_start), + zero_1=torch.Tensor(), + zero_2=torch.Tensor(), + alpha_1=self.fc31_alpha.narrow(0, slot_start, + slot_end - slot_start), + alpha_2=self.fc2_alpha.narrow(0, slot_start, + slot_end - slot_start), + ) + else: + return self.quant_scales + + def create_weights(self): + if self._weights_created: + return + weight_dtype = self.dtype + w3_w1_weight_shape = (self.expert_size_per_partition, + self.intermediate_size_per_partition * 2, + self.hidden_size) + w2_weight_shape = ( + self.expert_size_per_partition, + self.hidden_size, + self.intermediate_size_per_partition, + ) + + self.quant_scales = [] + self.has_fp8_qdq = False + self.has_fp8_block_scales = False + self.has_nvfp4 = False + self.has_w4afp8 = False + if self.quant_config and self.quant_config.quant_mode.has_any_quant( + exclude_kv_cache=True): + qc = self.quant_config + if 
qc.quant_mode.has_fp8_qdq(): + self.has_fp8_qdq = True + weight_dtype = torch.float8_e4m3fn + + fc31_dequant = nn.Parameter(torch.empty( + self.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc31_dequant", fc31_dequant) + + fc2_dequant = nn.Parameter(torch.empty( + self.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc2_dequant", fc2_dequant) + + fc2_quant = nn.Parameter(torch.tensor(1., dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc2_quant", fc2_quant) + + fc31_input_dequant = nn.Parameter(torch.tensor( + 1., dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc31_input_dequant", + fc31_input_dequant) + elif qc.quant_mode.has_fp8_block_scales(): + self.has_fp8_block_scales = True + weight_dtype = torch.float8_e4m3fn + cell_div = lambda x, y: (x + y - 1) // y + w3_w1_weight_scaling_factor = nn.Parameter(torch.empty( + (self.expert_size_per_partition, + cell_div(self.intermediate_size_per_partition, 128) * 2, + cell_div(w3_w1_weight_shape[2], 128)), + dtype=torch.float32), + requires_grad=False) + self.register_parameter("w3_w1_weight_scaling_factor", + w3_w1_weight_scaling_factor) + + w2_weight_scaling_factor = nn.Parameter(torch.empty( + (self.expert_size_per_partition, + cell_div(w2_weight_shape[1], + 128), cell_div(w2_weight_shape[2], 128)), + dtype=torch.float32), + requires_grad=False) + self.register_parameter("w2_weight_scaling_factor", + w2_weight_scaling_factor) + elif qc.quant_mode.is_int4_weight_only_per_group(): + self.has_w4afp8 = True + self.sm_version = get_sm_version() + if self.sm_version == 89: + self.interleave = [1, 1] + elif self.sm_version == 90: + self.interleave = [] + for k_shape in [ + self.hidden_size, + self.intermediate_size_per_partition + ]: + if k_shape % 512 == 0: + self.interleave.append(4) + elif k_shape % 256 == 0: + self.interleave.append(2) + elif k_shape % 128 == 0: + self.interleave.append(1) + else: + raise NotImplementedError( + f"K shape is required to be multiple of 128, received {k_shape}." 
+ ) + else: + raise NotImplementedError( + f"W4AFP8 MoE is unsupported on SM{self.sm_version}.") + weight_dtype = torch.int8 + w3_w1_weight_shape = (self.expert_size_per_partition, + self.intermediate_size_per_partition * 2, + self.hidden_size // 2) + w2_weight_shape = (self.expert_size_per_partition, + self.hidden_size, + self.intermediate_size_per_partition // 2) + + fc31_act_scale = nn.Parameter(torch.empty( + self.expert_size_per_partition, 1, dtype=self.dtype), + requires_grad=False) + self.register_parameter("fc31_act_scale", fc31_act_scale) + + fc2_act_scale = nn.Parameter(torch.empty( + self.expert_size_per_partition, 1, dtype=self.dtype), + requires_grad=False) + self.register_parameter("fc2_act_scale", fc2_act_scale) + + # col parallel + fc31_weight_scale = nn.Parameter( + torch.empty(self.expert_size_per_partition, + self.hidden_size // (128 * self.interleave[0]), + self.intermediate_size_per_partition * 2 * + self.interleave[0], + dtype=self.dtype), + requires_grad=False) + self.register_parameter("fc31_weight_scale", fc31_weight_scale) + + # row parallel + fc2_weight_scale = nn.Parameter( + torch.empty(self.expert_size_per_partition, + self.intermediate_size_per_partition // + (128 * self.interleave[1]), + self.hidden_size * self.interleave[1], + dtype=self.dtype), + requires_grad=False) + self.register_parameter("fc2_weight_scale", fc2_weight_scale) + + fc31_alpha = nn.Parameter(torch.empty( + self.expert_size_per_partition, 1, dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc31_alpha", fc31_alpha) + + fc2_alpha = nn.Parameter(torch.empty( + self.expert_size_per_partition, 1, dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc2_alpha", fc2_alpha) + elif qc.quant_mode.has_nvfp4(): + self.has_nvfp4 = True + if self.is_trtllm(): + weight_dtype = float4_sf_dtype + weight_vec_size = torch.iinfo(weight_dtype).bits // 4 + block_scales_dtype = torch.float8_e4m3fn + block_scales_vec_size = 1 + else: + weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE + weight_vec_size = torch.iinfo(weight_dtype).bits // 4 + block_scales_dtype = FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE + block_scales_vec_size = torch.iinfo( + block_scales_dtype).bits // 8 + + self.scaling_vector_size = 16 + # Divide by 16 because we use int64 to pack 16 fp4 values + w3_w1_weight_shape = (self.expert_size_per_partition, + self.intermediate_size_per_partition * 2, + self.hidden_size // weight_vec_size) + w2_weight_shape = (self.expert_size_per_partition, + self.hidden_size, + self.intermediate_size_per_partition // + weight_vec_size) + + # Divide by 4 because we use int32 to pack 4 fp8 values + # column parallel + w3_w1_weight_scale = nn.Parameter( + torch.ones(self.expert_size_per_partition, + self.intermediate_size_per_partition * 2, + self.hidden_size // self.scaling_vector_size // + block_scales_vec_size, + dtype=block_scales_dtype), + requires_grad=False) + self.register_parameter("w3_w1_weight_scale", + w3_w1_weight_scale) + + # row parallel + w2_weight_scale = nn.Parameter(torch.ones( + self.expert_size_per_partition, + self.hidden_size, + self.intermediate_size_per_partition // + self.scaling_vector_size // block_scales_vec_size, + dtype=block_scales_dtype), + requires_grad=False) + self.register_parameter("w2_weight_scale", w2_weight_scale) + + fc31_input_scale = nn.Parameter(torch.tensor( + 1., dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc31_input_scale", fc31_input_scale) + + fc2_input_scale = nn.Parameter(torch.tensor( + 1., 
dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc2_input_scale", fc2_input_scale) + + fc31_alpha = nn.Parameter(torch.ones( + self.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc31_alpha", fc31_alpha) + + fc2_alpha = nn.Parameter(torch.ones( + self.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc2_alpha", fc2_alpha) + + if self.is_trtllm(): + fc31_scale_c = nn.Parameter(torch.ones( + self.expert_size_per_partition, dtype=torch.float32), + requires_grad=False) + self.register_parameter("fc31_scale_c", fc31_scale_c) + + else: + # TODO: support other quant mode + raise ValueError( + f"unsupported quantization mode: {qc.quant_mode}") + self.setup_quant_scales() + + # Fused gate_up_proj (column parallel) + w3_w1_weight = nn.Parameter(torch.empty(w3_w1_weight_shape, + dtype=weight_dtype), + requires_grad=False) + self.register_parameter("w3_w1_weight", w3_w1_weight) + + # down_proj (row parallel) + w2_weight = nn.Parameter(torch.empty(w2_weight_shape, + dtype=weight_dtype), + requires_grad=False) + self.register_parameter("w2_weight", w2_weight) + self._weights_created = True + + def reducescatter_or_allreduce( + self, + inputs, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + ): + outputs = inputs + if self.parallel_size > 1 and not self.enable_alltoall: + if self.use_dp: + outputs = reducescatter( + inputs, + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens) + elif self.reduce_results: + outputs = self.all_reduce(inputs) + return outputs + + def forward_chunk( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + cutlass_min_latency_mode: bool = False, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + ) -> torch.Tensor: + if isinstance(x, Fp4QuantizedTensor): + assert output_dtype is not None + output_dtype = output_dtype + else: + output_dtype = x.dtype + + use_fp8_block_scaling = False + use_w4a8_group_scaling = False + weight_dtype = self.w3_w1_weight.dtype + + token_selected_experts, token_final_scales = self.routing_method.apply( + router_logits) + if self.balancer_layer is None: + token_selected_slots = token_selected_experts + else: + # If attention DP is enabled, token_selected_experts is a local rank tensor, + # so we need to offset the round robin position by ep_rank + token_selected_slots = self.balancer_layer.route( + token_selected_experts, offset_by_ep_rank=self.use_dp) + + # If load balancer is disabled, the statistics are collected from expert IDs. + # If load balancer is enabled, the statistics are collected from expert slot IDs. 
+ ExpertStatistic.set_layer(self.layer_idx) + ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) + + assert token_selected_slots.shape[ + 1] == self.routing_method.experts_per_token + assert token_selected_slots.shape == token_final_scales.shape + assert token_selected_slots.shape[0] == router_logits.shape[0] + assert token_final_scales.dtype == torch.float32 + assert token_selected_slots.dtype == torch.int32 + + if self.apply_router_weight_on_input: + assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing" + assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input" + x = x * token_final_scales.to(x.dtype) + # TODO: remove this once we have correct fusedmoe kernel ready + token_final_scales = None + + token_count = x.shape[0] + + alltoall_info = None + + if self.enable_alltoall: + x, token_selected_slots, token_final_scales, alltoall_info = \ + self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens, + x, + token_selected_slots, + token_final_scales) + + x_sf = None + if self.has_any_quant: + if self.has_fp8_qdq: + x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor( + x, self.fc31_input_dequant) + elif self.has_nvfp4: + if not disable_fp4_allgather() or self.use_postquant_alltoall: + if isinstance(x, Fp4QuantizedTensor): + x, x_sf = x.fp4_tensor, x.scaling_factor + x_row = x.shape[0] + # note: we use uint8 to store 2 fp4 values + x_col = x.shape[1] * 2 + else: + x_row = x.shape[0] + x_col = x.shape[1] + x, x_sf = torch.ops.trtllm.fp4_quantize( + x, self.fc31_input_scale, self.scaling_vector_size, + False) + + elif self.has_fp8_block_scales: + use_fp8_block_scaling = True + elif self.has_w4afp8: + use_w4a8_group_scaling = True + weight_dtype = torch.quint4x2 + else: + raise ValueError( + f"unsupported quantization mode: {self.quant_config.quant_mode}" + ) + + if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather( + ) and not self.enable_alltoall: + x, x_sf, token_selected_slots, token_final_scales = allgather( + [x, x_sf, token_selected_slots, token_final_scales], + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens) + # Fp4 gemm has extra scaling factor + if x_sf is not None: + x_sf = reswizzle_sf(x_sf, x_row, x_col, + self.scaling_vector_size) + + if self.smart_router and not cutlass_min_latency_mode: + ep_size = self.cluster_size + ep_rank = self.cluster_rank + expert_start = ep_rank * self.num_experts // ep_size + expert_end = min(self.num_experts, + (ep_rank + 1) * self.num_experts // ep_size) + w3_w1_weight = self.w3_w1_weight.narrow(0, expert_start, + expert_end - expert_start) + w2_weight = self.w2_weight.narrow(0, expert_start, + expert_end - expert_start) + cluster_size = self.ep_size + cluster_rank = self.ep_rank + quant_scales = self.get_quant_scales(expert_start, expert_end) + else: + ep_size = self.ep_size + ep_rank = self.ep_rank + w3_w1_weight = self.w3_w1_weight + w2_weight = self.w2_weight + cluster_size = self.cluster_size + cluster_rank = self.cluster_rank + quant_scales = self.quant_scales + + if self.use_postquant_alltoall: + x, x_sf = self.alltoall_postquant_dispatch(x, x_sf, x_row, x_col, + alltoall_info) + + final_hidden_states = torch.ops.trtllm.fused_moe( + x, + token_selected_slots, + token_final_scales, + w3_w1_weight.view(weight_dtype), + w2_weight.view(weight_dtype), + output_dtype, + quant_scales=quant_scales, + input_sf=x_sf, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + ep_size=ep_size, 
+ ep_rank=ep_rank, + cluster_size=cluster_size, + cluster_rank=cluster_rank, + use_fp8_block_scaling=use_fp8_block_scaling, + use_w4a8_group_scaling=use_w4a8_group_scaling, + min_latency_mode=cutlass_min_latency_mode, + tune_max_num_tokens=self.tune_max_num_tokens, + ) + + if cutlass_min_latency_mode: + assert not self.reduce_results + return final_hidden_states + else: + # Custom op requires all inputs are in the same type. + # Only in cutlass_min_latency_mode, the output is a list of tensors. + # Otherwise, the output should be unpacked as a single tensor. + final_hidden_states = final_hidden_states[0] + + if not self.enable_alltoall: + return final_hidden_states + else: + return self.alltoall_combine(final_hidden_states, alltoall_info, + token_count) + + def forward( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + cutlass_min_latency_mode: bool = False, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + ) -> torch.Tensor: + """ + cutlass_min_latency_mode has no effect when trtllm_gen backend is enabled. + """ + if self.is_cutlass(): + return self.forward_cutlass(x, router_logits, + cutlass_min_latency_mode, output_dtype, + all_rank_num_tokens, use_dp_padding) + elif self.is_trtllm(): + return self.forward_trtllmgen(x, router_logits) + else: + raise NotImplementedError( + f"FusedMoE only supports CUTLASS or TRTLLM backends, not {self.moe_backend}" + ) + + def forward_cutlass( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + cutlass_min_latency_mode: bool = False, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + ) -> torch.Tensor: + assert self.is_cutlass() + + if self.use_dp: + assert all_rank_num_tokens is not None + assert use_dp_padding is not None + num_rows = sum(all_rank_num_tokens) + else: + num_rows = x.shape[0] + + # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks + num_chunks = (num_rows + self.moe_max_num_tokens - + 1) // self.moe_max_num_tokens + + if cutlass_min_latency_mode: + assert num_chunks == 1 and ( + not self.reduce_results + ), "cutlass_min_latency_mode must be used with a single chunk and reduce_results must be False" + + if use_dp_padding: + all_rank_num_tokens_padded = [max(all_rank_num_tokens) + ] * len(all_rank_num_tokens) + else: + all_rank_num_tokens_padded = all_rank_num_tokens + if num_chunks == 1: + outputs = self.forward_chunk( + x, + router_logits, + cutlass_min_latency_mode, + output_dtype, + all_rank_num_tokens=all_rank_num_tokens_padded, + use_dp_padding=use_dp_padding) + outputs = self.reducescatter_or_allreduce( + outputs, + all_rank_num_tokens=all_rank_num_tokens_padded, + use_dp_padding=use_dp_padding) + else: + + def split_chunk(split_token_num: int, split_num_chunks: int): + val_div = split_token_num // split_num_chunks + val_mod = split_token_num % split_num_chunks + split_chunk_size_list = [val_div + 1] * val_mod + [val_div] * ( + split_num_chunks - val_mod) + return split_chunk_size_list + + if self.use_dp: + all_rank_chunk_size_list = [ + split_chunk(val, num_chunks) + for val in all_rank_num_tokens_padded + ] + all_rank_num_tokens_list = [[ + val[idx_chunk] for val in all_rank_chunk_size_list + ] for idx_chunk in range(num_chunks)] + chunk_size_list = all_rank_chunk_size_list[self.rank] + if self.enable_alltoall: + 
all_rank_num_tokens_list = [[ + 1 if val == 0 else val for val in val_list + ] for val_list in all_rank_num_tokens_list] + else: + all_rank_num_tokens_list = [None] * num_chunks + chunk_size_list = split_chunk(x.shape[0], num_chunks) + + x_list = x.split(chunk_size_list) + router_logits_list = router_logits.split(chunk_size_list) + + if not self.enable_alltoall: + self.event_dict[EventType.Main].record() + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.Main].wait() + + outputs_list = [] + # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap + for idx_chunk, (x, router_logits) in enumerate( + zip(x_list, router_logits_list)): + if not self.enable_alltoall: + if idx_chunk % 2 == 0: + with torch.cuda.stream(self.aux_stream): + outputs = self.forward_chunk( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens_list[ + idx_chunk] if self.use_dp else None, + use_dp_padding=use_dp_padding) + if idx_chunk > 0: + outputs_list[-1] = self.reducescatter_or_allreduce( + outputs_list[-1], + all_rank_num_tokens=all_rank_num_tokens_list[ + idx_chunk - 1], + use_dp_padding=use_dp_padding) + else: + outputs = self.forward_chunk( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens_list[ + idx_chunk] if self.use_dp else None, + use_dp_padding=use_dp_padding) + with torch.cuda.stream(self.aux_stream): + outputs_list[-1] = self.reducescatter_or_allreduce( + outputs_list[-1], + all_rank_num_tokens=all_rank_num_tokens_list[ + idx_chunk - 1], + use_dp_padding=use_dp_padding) + else: + outputs = self.forward_chunk( + x, + router_logits, + all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk] + if self.use_dp else None) + + outputs_list.append(outputs) + if not self.enable_alltoall: + if num_chunks % 2 == 0: + outputs_list[-1] = self.reducescatter_or_allreduce( + outputs_list[-1], + all_rank_num_tokens=all_rank_num_tokens_list[-1], + use_dp_padding=use_dp_padding) + else: + with torch.cuda.stream(self.aux_stream): + outputs_list[-1] = self.reducescatter_or_allreduce( + outputs_list[-1], + all_rank_num_tokens=all_rank_num_tokens_list[-1], + use_dp_padding=use_dp_padding) + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.MoeChunkingOverlap].record() + self.event_dict[EventType.MoeChunkingOverlap].wait() + outputs = torch.cat(outputs_list) + if self.use_dp: + rank = self.mapping.tp_rank + outputs = outputs[:all_rank_num_tokens[rank]] + return outputs + + def forward_trtllmgen(self, x: torch.Tensor, + router_logits: torch.Tensor) -> torch.Tensor: + assert self.is_trtllm() + assert x.dtype == torch.bfloat16 + + # DeepSeekV3 style routing + if isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod): + top_k = self.routing_method.routing_impl.top_k + routing_bias = self.routing_method.e_score_correction_bias + n_group = self.routing_method.routing_impl.n_group + topk_group = self.routing_method.routing_impl.topk_group + routed_scaling_factor = self.routing_method.routing_impl.routed_scaling_factor + else: + top_k = self.routing_method.top_k + routing_bias = None + n_group = None + topk_group = None + routed_scaling_factor = None + + # TODO: since routing kernel is integrated into moe_runner for fp8, + # here we just route the I/Os for moe_runner + if self.quant_config and self.quant_config.quant_mode.has_fp8_block_scales( + ): + x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x) + + final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner( + router_logits, + routing_bias, + x_val, + x_scale, + 
self.w3_w1_weight, + self.w3_w1_weight_scaling_factor, + self.w2_weight, + self.w2_weight_scaling_factor, + self.num_slots, + top_k, + n_group, + topk_group, + self.intermediate_size_per_partition, + self. + slot_start, # local_expert_start; use ep_rank if stride!=1 + self.expert_size_per_partition, # local_expert_size + routed_scaling_factor, + self.routing_method.routing_method_type, + ) + elif self.quant_config and self.quant_config.quant_mode.has_nvfp4(): + scale_factor_use_ue8m0 = False + is_scale_factor_swizzled = False # use linear layout here + hidden_states_fp4, hidden_states_scale_linear_fp4 = torch.ops.trtllm.fp4_quantize( + x, self.fc31_input_scale, 16, scale_factor_use_ue8m0, + is_scale_factor_swizzled) + + final_hidden_states = torch.ops.trtllm.fp4_block_scale_moe_runner( + router_logits, + routing_bias, + hidden_states_fp4, + hidden_states_scale_linear_fp4.view(torch.float8_e4m3fn), + self.w3_w1_weight, + self.w3_w1_weight_scale.view(torch.float8_e4m3fn), + self.w2_weight, + self.w2_weight_scale.view(torch.float8_e4m3fn), + self.fc31_scale_c.data, + self.fc31_alpha.data, + self.fc2_alpha.data, + self.num_slots, + top_k, + n_group, + topk_group, + self.intermediate_size_per_partition, + self. + slot_start, # local_expert_start; use ep_rank if stride!=1 + self.expert_size_per_partition, # local_expert_size + routed_scaling_factor, + self.routing_method.routing_method_type, + ) + else: + raise NotImplementedError( + "The TRTLLM backend of FusedMoE only supports fp8_block_scaling and nvfp4 dtypes." + ) + + if self.reduce_results and self.parallel_size > 1: + final_hidden_states = self.all_reduce(final_hidden_states) + + return final_hidden_states + + def alltoall_prepare_maybe_dispatch(self, all_rank_num_tokens: list, + x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: torch.Tensor): + top_k = self.routing_method.experts_per_token + expert_count = self.num_experts + # gather router info + max_num_token = max(all_rank_num_tokens) + token_selected_slots = torch.nn.functional.pad( + token_selected_slots, + (0, 0, 0, max_num_token - token_selected_slots.shape[0]), + 'constant', self.num_experts) + token_final_scales = torch.nn.functional.pad( + token_final_scales, + (0, 0, 0, max_num_token - token_final_scales.shape[0])) + gathered_token_selected_slots, gathered_token_final_scales = allgather( + [token_selected_slots, token_final_scales], self.mapping, dim=0) + gathered_token_selected_slots = torch.flatten( + gathered_token_selected_slots.contiguous(), start_dim=0, end_dim=-2) + gathered_token_final_scales = torch.flatten( + gathered_token_final_scales.contiguous(), start_dim=0, end_dim=-2) + gathered_target_rank_ids = MnnvlMoe.compute_target_rank_id( + gathered_token_selected_slots, self.num_experts, self.ep_size) + alltoall_info, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare( + gathered_target_rank_ids, None, gathered_token_selected_slots, + gathered_token_final_scales, max_num_token, expert_count, top_k, + self.ep_rank, self.ep_size) + + if not self.use_postquant_alltoall: + assert not isinstance( + x, Fp4QuantizedTensor + ), "pre-quant alltoall doesn't support fp4 tensor" + x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info, + self.alltoall_workspace, + self.ep_rank, self.ep_size) + + return x, token_selected_slots, token_final_scales, alltoall_info + + def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor, + x_row: int, x_col: int, + alltoall_info: MoEAlltoallInfo): + x = 
MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info, + self.alltoall_workspace, self.ep_rank, + self.ep_size) + + if x_sf is not None: + if self.has_nvfp4: + x_sf = unswizzle_sf(x_sf, x_row, x_col, + self.scaling_vector_size) + + x_sf = MnnvlMoe.mnnvl_moe_alltoallv(x_sf, alltoall_info, + self.alltoall_workspace, + self.ep_rank, self.ep_size) + + if self.has_nvfp4: + x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2, + self.scaling_vector_size) + + return x, x_sf + + def alltoall_combine(self, final_hidden_states: torch.Tensor, + alltoall_info: MoEAlltoallInfo, token_count: int): + top_k = self.routing_method.experts_per_token + if isinstance(final_hidden_states, list): + final_hidden_states = final_hidden_states[0] + final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine( + final_hidden_states, + alltoall_info, + self.alltoall_workspace, + ep_rank=self.ep_rank, + ep_size=self.ep_size, + top_k=top_k, + token_count=token_count) + + return final_hidden_states + + def load_weights(self, weights: List[Dict]): + assert self._weights_created + assert len(weights) == 1 + weights = weights[0] + + def load_expert_w3_w1_weight(w1_weight, + w3_weight, + dst_w3_w1_weight: torch.Tensor, + is_trtllm: bool = False): + w1_weight_shard = load_weight_shard(w1_weight, self.tp_size, + self.tp_rank, + TensorParallelMode.COLUMN) + w3_weight_shard = load_weight_shard(w3_weight, self.tp_size, + self.tp_rank, + TensorParallelMode.COLUMN) + + if is_trtllm: + # FIXME: this depends on the kernel internals + epilogue_tile_m = 128 + + # Keep weights in device buffer + dst_w3_weight = dst_w3_w1_weight.narrow( + dim=0, start=0, length=self.intermediate_size_per_partition) + dst_w1_weight = dst_w3_w1_weight.narrow( + dim=0, + start=self.intermediate_size_per_partition, + length=self.intermediate_size_per_partition) + dst_w3_weight.copy_(w3_weight_shard.view(dst_w3_weight.dtype)) + dst_w1_weight.copy_(w1_weight_shard.view(dst_w1_weight.dtype)) + + # Get permute indices and chain them together + permute0 = get_reorder_rows_for_gated_act_gemm_row_indices( + dst_w3_w1_weight) + permute1 = get_shuffle_matrix_a_row_indices( + dst_w3_w1_weight, epilogue_tile_m) + permute = permute0[permute1] + + # Shuffle the weight according to permute indices + processed_w31_weight_shard = torch.ops.trtllm.shuffle_matrix( + dst_w3_w1_weight, permute.to(dst_w3_w1_weight.device)) + # Copy the result into device buffer + dst_w3_w1_weight.copy_(processed_w31_weight_shard.view( + dst_w3_w1_weight.dtype), + non_blocking=True) + # We are done here so do not continue + return + + w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], + dim=0) + + if self.has_w4afp8 and self.sm_version == 89: + import tensorrt_llm.quantization.functional + preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm + packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 + unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 + w31_weight_shard = packer( + unpacker(w31_weight_shard.cpu()).T.contiguous()).to( + w31_weight_shard.device) + w31_weight_shard = preprocessor(w31_weight_shard, + torch.quint4x2, + torch.float8_e4m3fn, + 89).view(dst_w3_w1_weight.shape) + dst_w3_w1_weight.copy_(w31_weight_shard.view( + dst_w3_w1_weight.dtype), + non_blocking=True) + + def load_expert_w2_weight(w2_weight, + dst_w2_weight: torch.Tensor, + is_trtllm: bool = False): + w2_weight_shard = load_weight_shard(w2_weight, self.tp_size, + self.tp_rank, + TensorParallelMode.ROW) + if is_trtllm: + # FIXME: this depends on the kernel internals + 
epilogue_tile_m = 128 + + # Keep weights in device buffer + dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), + non_blocking=True) + # Get permuted result + processed_w2_weight = shuffle_matrix_a(dst_w2_weight, + epilogue_tile_m) + # Copy the result into device buffer + dst_w2_weight.copy_(processed_w2_weight.view( + dst_w2_weight.dtype), + non_blocking=True) + # We are done here so do not continue + return + + if self.has_w4afp8 and self.sm_version == 89: + import tensorrt_llm.quantization.functional + preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm + packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4 + unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 + w2_weight_shard = packer( + unpacker(w2_weight_shard.cpu()).T.contiguous()).to( + w2_weight_shard.device) + w2_weight_shard = preprocessor(w2_weight_shard, torch.quint4x2, + torch.float8_e4m3fn, + 89).view(dst_w2_weight.shape) + + dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), + non_blocking=True) + + # Use multi-threading to load expert weights in parallel. + # Even though CPython has global interpreter lock (GIL), + # it's still faster to load weights in parallel because it can utilize + # CPU memory bandwidth better. + threads = [] + + for local_slot_id, expert_id in enumerate( + self.initial_local_expert_ids): + # expert_idx is the local slot index of current rank + expert_idx = local_slot_id + + if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_weight = weights[f"{expert_id}.w1.weight"] + w3_weight = weights[f"{expert_id}.w3.weight"] + w2_weight = weights[f"{expert_id}.w2.weight"] + elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_w3_weight = weights["gate_up_proj"][expert_id].transpose( + 0, 1) + w1_weight, w3_weight = w1_w3_weight.chunk(2, dim=0) + w2_weight = weights["down_proj"][expert_id].transpose( + 0, 1).contiguous() + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {self.weight_loading_mode}" + ) + + is_trtllm_nvfp4 = self.is_trtllm( + ) and self.quant_config.quant_mode.has_nvfp4() + + thread = threading.Thread(target=load_expert_w3_w1_weight, + args=(w1_weight, w3_weight, + self.w3_w1_weight.data[expert_idx], + is_trtllm_nvfp4)) + thread.start() + threads.append(thread) + + thread = threading.Thread(target=load_expert_w2_weight, + args=(w2_weight, + self.w2_weight.data[expert_idx], + is_trtllm_nvfp4)) + thread.start() + threads.append(thread) + + for thread in threads: + thread.join() + + if self.quant_config and self.quant_config.quant_mode.has_any_quant( + exclude_kv_cache=True): + if self.quant_config.quant_mode.has_fp8_qdq(): + self._load_fp8_qdq_scales(weights) + elif self.quant_config.quant_mode.has_nvfp4(): + self._load_nvfp4_scales(weights) + elif self.quant_config.quant_mode.has_fp8_block_scales(): + self._load_fp8_block_scales_scales(weights) + elif self.quant_config.quant_mode.is_int4_weight_only_per_group(): + self._load_int4_groupwise_scales(weights) + else: + raise ValueError( + f"unsupported quantization mode: {self.quant_config.quant_mode}" + ) + # Re-setup quant scales after loading weights as the tensors may have been modified. 
+ self.setup_quant_scales() + + def _load_fp8_block_scales_scales(self, weights: Dict): + all_w2_scales = [ + load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"], + self.tp_size, self.tp_rank, + TensorParallelMode.ROW) + for expert_id in self.initial_local_expert_ids + ] + + w2_scales = torch.stack(all_w2_scales) + self.w2_weight_scaling_factor.data.copy_(w2_scales) + + all_w3_scales = [ + load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"], + self.tp_size, self.tp_rank, + TensorParallelMode.COLUMN) + for expert_id in self.initial_local_expert_ids + ] + + all_w1_scales = [ + load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"], + self.tp_size, self.tp_rank, + TensorParallelMode.COLUMN) + for expert_id in self.initial_local_expert_ids + ] + + w3_w1_scales = torch.cat( + [torch.stack(all_w3_scales), + torch.stack(all_w1_scales)], dim=-2) + self.w3_w1_weight_scaling_factor.data.copy_(w3_w1_scales) + + def _load_fp8_qdq_scales(self, weights: Dict): + # Step1: Load input scales. + def load_expert_fc31_input_scale_fp8_qdq( + w1_input_scale, w3_input_scale, + dst_fc31_input_scale: torch.Tensor): + dst_fc31_input_scale.copy_( + max(w1_input_scale[...].reshape([]), + w3_input_scale[...].reshape([]))) + + def load_expert_fc2_input_scale_fp8_qdq( + w2_input_scale, dst_fc2_input_scale: torch.Tensor): + dst_fc2_input_scale.copy_(w2_input_scale[...].reshape([])) + + tmp_fc31_input_scale = torch.empty(self.num_experts, + dtype=torch.float32) + tmp_fc2_input_scale = torch.empty(self.num_experts, dtype=torch.float32) + for expert_id in range(self.num_experts): + if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_input_scale = weights[f"{expert_id}.w1.input_scale"] + w3_input_scale = weights[f"{expert_id}.w3.input_scale"] + w2_input_scale = weights[f"{expert_id}.w2.input_scale"] + elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_input_scale = weights[f"gate_up_proj_input_scale"] + w3_input_scale = weights[f"gate_up_proj_input_scale"] + w2_input_scale = weights[f"down_proj_input_scale"] + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {self.weight_loading_mode}" + ) + + load_expert_fc31_input_scale_fp8_qdq( + w1_input_scale, w3_input_scale, tmp_fc31_input_scale[expert_id]) + + load_expert_fc2_input_scale_fp8_qdq(w2_input_scale, + tmp_fc2_input_scale[expert_id]) + + # max_fc31_input_scale is the maximum of all w1 input scales and w3 input scales. + # It's used to quantize fc31 input inside the MOE op + max_fc31_input_scale = tmp_fc31_input_scale.max() + # max_fc2_input_scale is the maximum of all w2 input scales. + max_fc2_input_scale = tmp_fc2_input_scale.max() + + # Step2: Load weight scales and requantize w3_w1_weight. 
+ tmp_w3_w1_weight_scale = torch.empty(self.expert_size_per_partition, + dtype=torch.float32) + tmp_w2_weight_scale = torch.empty(self.expert_size_per_partition, + dtype=torch.float32) + + def load_expert_w3_w1_weight_scale_fp8_qdq( + w1_weight_scale, w3_weight_scale, + dst_w3_w1_weight_scale: torch.Tensor): + w1_weight_scale = w1_weight_scale[...].reshape([]) + w3_weight_scale = w3_weight_scale[...].reshape([]) + dst_w3_w1_weight_scale.copy_(max(w1_weight_scale, w3_weight_scale)) + + def requantize_expert_w3_w1_weight_fp8_qdq( + w1_weight_scale, w3_weight_scale, + dst_w3_w1_weight: torch.Tensor): + w1_weight_scale = w1_weight_scale[...].reshape([]) + w3_weight_scale = w3_weight_scale[...].reshape([]) + max_w3_w1_weight_scale = max(w1_weight_scale, w3_weight_scale) + + w3_weight = dst_w3_w1_weight.narrow( + dim=0, start=0, length=self.intermediate_size_per_partition).to( + dtype=self.dtype) + w1_weight = dst_w3_w1_weight.narrow( + dim=0, + start=self.intermediate_size_per_partition, + length=self.intermediate_size_per_partition).to( + dtype=self.dtype) + dequant_w3_weight = w3_weight * w3_weight_scale + dequant_w1_weight = w1_weight * w1_weight_scale + requant_w3_weight = (dequant_w3_weight / max_w3_w1_weight_scale).to( + torch.float8_e4m3fn) + requant_w1_weight = (dequant_w1_weight / max_w3_w1_weight_scale).to( + torch.float8_e4m3fn) + + dst_w3_w1_weight.narrow( + dim=0, start=0, + length=self.intermediate_size_per_partition).copy_( + requant_w3_weight) + dst_w3_w1_weight.narrow( + dim=0, + start=self.intermediate_size_per_partition, + length=self.intermediate_size_per_partition).copy_( + requant_w1_weight) + + def load_expert_w2_weight_scale_fp8(w2_weight_scale, + dst_w2_weight_scale: torch.Tensor): + dst_w2_weight_scale.copy_(w2_weight_scale[...].reshape([])) + + for local_slot_id, expert_id in enumerate( + self.initial_local_expert_ids): + if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"] + w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"] + w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"] + elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_weight_scale = weights[f"gate_up_proj_weight_scale"] + w3_weight_scale = weights[f"gate_up_proj_weight_scale"] + w2_weight_scale = weights[f"down_proj_weight_scale"] + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {self.weight_loading_mode}" + ) + + expert_idx = local_slot_id + + load_expert_w3_w1_weight_scale_fp8_qdq( + w1_weight_scale, w3_weight_scale, + tmp_w3_w1_weight_scale[expert_idx]) + + requantize_expert_w3_w1_weight_fp8_qdq( + w1_weight_scale, w3_weight_scale, + self.w3_w1_weight.data[expert_idx]) + + load_expert_w2_weight_scale_fp8(w2_weight_scale, + tmp_w2_weight_scale[expert_idx]) + + # Step3: calculate and store final loaded weights + self.fc31_dequant.data.copy_(tmp_w3_w1_weight_scale * + max_fc31_input_scale) + self.fc2_quant.data.copy_(max_fc2_input_scale.reciprocal()) + self.fc2_dequant.data.copy_(tmp_w2_weight_scale * max_fc2_input_scale) + self.fc31_input_dequant.data.copy_(max_fc31_input_scale) + + def _load_nvfp4_scales(self, weights: Dict): + # Step1: Load input scales. 
+ tmp_fc31_input_scale = torch.empty(self.num_experts, + dtype=torch.float32) + tmp_fc2_input_scale = torch.empty(self.num_experts, dtype=torch.float32) + + def load_expert_fc31_input_scale_nvfp4( + w1_input_scale, w3_input_scale, + dst_fc31_input_scale: torch.Tensor): + w1_input_scale = w1_input_scale[...].reshape([]) + w3_input_scale = w3_input_scale[...].reshape([]) + assert torch.allclose( + w1_input_scale, + w3_input_scale), "w1_input_scale != w3_input_scale" + dst_fc31_input_scale.copy_(w1_input_scale) + + def load_expert_fc2_input_scale_nvfp4( + w2_input_scale, dst_fc2_input_scale: torch.Tensor): + dst_fc2_input_scale.copy_(w2_input_scale[...].reshape([])) + + for expert_id in range(self.num_experts): + if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_input_scale = weights[f"{expert_id}.w1.input_scale"] + w3_input_scale = weights[f"{expert_id}.w3.input_scale"] + w2_input_scale = weights[f"{expert_id}.w2.input_scale"] + elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_input_scale = weights["gate_up_proj_input_scale"] + w3_input_scale = weights["gate_up_proj_input_scale"] + w2_input_scale = weights["down_proj_input_scale"] + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {self.weight_loading_mode}" + ) + + load_expert_fc31_input_scale_nvfp4(w1_input_scale, w3_input_scale, + tmp_fc31_input_scale[expert_id]) + load_expert_fc2_input_scale_nvfp4(w2_input_scale, + tmp_fc2_input_scale[expert_id]) + + # fc31_input_scale is the reciprocal of the maximum of all w1 input scales and w3 input scales. + self.fc31_input_scale.data.copy_( + tmp_fc31_input_scale.max().reciprocal()) + # fc2_input_scale is the reciprocal of the maximum of all w2 input scales. + self.fc2_input_scale.data.copy_(tmp_fc2_input_scale.max().reciprocal()) + + if self.is_trtllm(): + block_scales_dtype = torch.float8_e4m3fn + else: + block_scales_dtype = FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE + + # Step2: Load weight block scales and alphas. 
+ def load_expert_w3_w1_weight_scale_nvfp4( + w1_weight_scale, w3_weight_scale, + dst_w3_w1_weight_scale: torch.Tensor, is_trtllm: bool): + w1_weight_scale = load_weight_shard(w1_weight_scale, self.tp_size, + self.tp_rank, + TensorParallelMode.COLUMN) + w3_weight_scale = load_weight_shard(w3_weight_scale, self.tp_size, + self.tp_rank, + TensorParallelMode.COLUMN) + # Keep weights in device buffer + # w3 + dst_w3_weight_scale = dst_w3_w1_weight_scale.narrow( + dim=0, start=0, length=self.intermediate_size_per_partition) + dst_w3_weight_scale.copy_( + w3_weight_scale.view(dst_w3_weight_scale.dtype)) + + # w1 + dst_w1_weight_scale = dst_w3_w1_weight_scale.narrow( + dim=0, + start=self.intermediate_size_per_partition, + length=self.intermediate_size_per_partition) + dst_w1_weight_scale.copy_( + w1_weight_scale.view(dst_w1_weight_scale.dtype)) + + orig_shape = dst_w3_w1_weight_scale.shape + + if is_trtllm: + # FIXME + epilogue_tile_m = 128 + + # Get permute indices and chain them together + permute0 = get_reorder_rows_for_gated_act_gemm_row_indices( + dst_w3_w1_weight_scale) + permute1 = get_shuffle_matrix_sf_a_row_indices( + dst_w3_w1_weight_scale.view(float4_sf_dtype), + epilogue_tile_m, 16) + permute = permute0[permute1] + + # Shuffle the weight according to permute indices + w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix( + dst_w3_w1_weight_scale.view(float4_sf_dtype), + permute.cuda()) + # Assert should only be removed during debugging + assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed" + # Interleave the weight. + processed_w3_w1_weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave( + w3_w1_weight_scale.view(float4_sf_dtype).reshape( + orig_shape)) + # Copy the result into device buffer + dst_w3_w1_weight_scale.copy_( + processed_w3_w1_weight_scale.view( + block_scales_dtype).reshape(orig_shape)) + else: + dst_w3_w1_weight_scale.copy_( + torch.ops.tensorrt_llm.nvfp4_block_scale_interleave( + dst_w3_w1_weight_scale.view(float4_sf_dtype)).view( + block_scales_dtype).reshape(orig_shape)) + + def load_expert_w2_weight_scale_nvfp4(w2_weight_scale, + dst_w2_weight_scale: torch.Tensor, + is_trtllm: bool): + w2_weight_scale = load_weight_shard(w2_weight_scale, self.tp_size, + self.tp_rank, + TensorParallelMode.ROW) + # Keep weights in device buffer + dst_w2_weight_scale.copy_( + w2_weight_scale.view(dst_w2_weight_scale.dtype)) + + orig_shape = dst_w2_weight_scale.shape + if is_trtllm: + epilogue_tile_m = 128 # FIXME: read from kernel + # Assert should only be removed during debugging + assert dst_w2_weight_scale.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed" + # Interleave the weight and copy + dst_w2_weight_scale.copy_( + shuffle_matrix_sf_a( + dst_w2_weight_scale.view(float4_sf_dtype), + epilogue_tile_m, + 16).view(block_scales_dtype).reshape(orig_shape)) + else: + dst_w2_weight_scale.copy_( + torch.ops.tensorrt_llm.nvfp4_block_scale_interleave( + dst_w2_weight_scale.view(float4_sf_dtype)).view( + block_scales_dtype).reshape(orig_shape)) + + def load_expert_fc31_alpha_nvfp4(w1_weight_scale_2, w3_weight_scale_2, + final_fc31_input_scale: torch.Tensor, + dst_fc31_alpha: torch.Tensor): + w1_weight_scale_2 = w1_weight_scale_2[...].reshape([]) + w3_weight_scale_2 = w3_weight_scale_2[...].reshape([]) + assert torch.allclose( + w1_weight_scale_2, + w3_weight_scale_2), "w1_weight_scale_2 != w3_weight_scale_2" + + w3_w1_weight_scale_2 = 1.0 / w1_weight_scale_2 + dst_fc31_alpha.copy_( + 1.0 / 
(final_fc31_input_scale * w3_w1_weight_scale_2)) + + def load_expert_fc2_alpha_nvfp4(w2_weight_scale_2, + final_fc2_input_scale: torch.Tensor, + dst_w2_alpha: torch.Tensor): + w2_weight_scale_2 = 1.0 / w2_weight_scale_2[...].reshape([]) + dst_w2_alpha.copy_(1.0 / + (final_fc2_input_scale * w2_weight_scale_2)) + + for local_slot_id, expert_id in enumerate( + self.initial_local_expert_ids): + if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"] + w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"] + w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"] + w1_weight_scale_2 = weights[f"{expert_id}.w1.weight_scale_2"] + w3_weight_scale_2 = weights[f"{expert_id}.w3.weight_scale_2"] + w2_weight_scale_2 = weights[f"{expert_id}.w2.weight_scale_2"] + elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_w3_weight_scale = weights["gate_up_proj_weight_scale"][ + expert_id].transpose(0, 1).contiguous() + w1_weight_scale, w3_weight_scale = w1_w3_weight_scale.chunk( + 2, dim=0) + w2_weight_scale = weights["down_proj_weight_scale"][ + expert_id].transpose(0, 1).contiguous() + w1_weight_scale_2 = weights["gate_up_proj_weight_scale_2"] + w3_weight_scale_2 = weights["gate_up_proj_weight_scale_2"] + w2_weight_scale_2 = weights["down_proj_weight_scale_2"] + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {self.weight_loading_mode}" + ) + + expert_idx = local_slot_id + + load_expert_w3_w1_weight_scale_nvfp4( + w1_weight_scale, w3_weight_scale, + self.w3_w1_weight_scale.data[expert_idx], self.is_trtllm()) + load_expert_w2_weight_scale_nvfp4( + w2_weight_scale, self.w2_weight_scale.data[expert_idx], + self.is_trtllm()) + + load_expert_fc31_alpha_nvfp4(w1_weight_scale_2, w3_weight_scale_2, + self.fc31_input_scale.data, + self.fc31_alpha.data[expert_idx]) + load_expert_fc2_alpha_nvfp4(w2_weight_scale_2, + self.fc2_input_scale.data, + self.fc2_alpha.data[expert_idx]) + if self.is_trtllm(): + self.fc31_scale_c.data.copy_(self.fc2_input_scale.data * + self.fc31_alpha.data, + non_blocking=True) + + def _load_int4_groupwise_scales(self, weights: Dict): + # fc31 scales + assert (len(self.interleave) == 2) + all_w3_input_scales = [ + load_weight_shard(weights[f"{expert_id}.w3.input_scale"]) + for expert_id in self.initial_local_expert_ids + ] + all_w1_input_scales = [ + load_weight_shard(weights[f"{expert_id}.w1.input_scale"]) + for expert_id in self.initial_local_expert_ids + ] + all_w3_w1_input_scales = torch.max(torch.stack(all_w3_input_scales), + torch.stack(all_w1_input_scales)) + all_w3_w1_input_scales = torch.ones_like( + all_w3_w1_input_scales) * all_w3_w1_input_scales.max() + self.fc31_act_scale.data.copy_(1 / all_w3_w1_input_scales) + self.fc31_alpha.data.copy_(all_w3_w1_input_scales.float()) + + all_w3_scales = [ + load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"], + self.tp_size, self.tp_rank, + TensorParallelMode.COLUMN) + for expert_id in self.initial_local_expert_ids + ] + all_w1_scales = [ + load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"], + self.tp_size, self.tp_rank, + TensorParallelMode.COLUMN) + for expert_id in self.initial_local_expert_ids + ] + all_w3_w1_scales = torch.cat( + [torch.stack(all_w3_scales), + torch.stack(all_w1_scales)], dim=-2) + if self.sm_version == 89: + w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(self.dtype) + else: + w3_w1_scales = all_w3_w1_scales.to(torch.bfloat16).view(self.dtype) + w3_w1_s_shape = 
w3_w1_scales.shape + w3_w1_scales_interleaved = w3_w1_scales.reshape( + w3_w1_s_shape[0], w3_w1_s_shape[1], + (w3_w1_s_shape[2] // self.interleave[0]), self.interleave[0]) + w3_w1_scales_interleaved = w3_w1_scales_interleaved.permute(0, 2, 1, 3) + w3_w1_scales_interleaved = w3_w1_scales_interleaved.reshape( + w3_w1_s_shape[0], w3_w1_s_shape[2] // self.interleave[0], + w3_w1_s_shape[1] * self.interleave[0]) + self.fc31_weight_scale.data.copy_(w3_w1_scales_interleaved.contiguous()) + + # fc2 scales + all_w2_input_scales = [ + load_weight_shard(weights[f"{expert_id}.w2.input_scale"]) + for expert_id in self.initial_local_expert_ids + ] + all_w2_input_scales = torch.stack(all_w2_input_scales).to(self.dtype) + all_w2_input_scales = torch.ones_like( + all_w2_input_scales) * all_w2_input_scales.max() + self.fc2_act_scale.data.copy_(1 / all_w2_input_scales) + self.fc2_alpha.data.copy_(all_w2_input_scales.float()) + + all_w2_scales = [ + load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"], + self.tp_size, self.tp_rank, + TensorParallelMode.ROW) + for expert_id in self.initial_local_expert_ids + ] + if self.sm_version == 89: + w2_scales = torch.stack(all_w2_scales).to(torch.float16).view( + self.dtype) + else: + w2_scales = torch.stack(all_w2_scales).to(torch.bfloat16).view( + self.dtype) + w2_s_shape = w2_scales.shape + w2_scales_interleaved = w2_scales.reshape( + w2_s_shape[0], w2_s_shape[1], (w2_s_shape[2] // self.interleave[1]), + self.interleave[1]) + w2_scales_interleaved = w2_scales_interleaved.permute(0, 2, 1, 3) + w2_scales_interleaved = w2_scales_interleaved.reshape( + w2_s_shape[0], w2_s_shape[2] // self.interleave[1], + w2_s_shape[1] * self.interleave[1]) + self.fc2_weight_scale.data.copy_(w2_scales_interleaved.contiguous()) + + +class FusedMoEQuantScalesFP8(NamedTuple): + fc1_dequant: torch.Tensor + fc2_quant: torch.Tensor + fc2_dequant: torch.Tensor + fc1_input_dequant: torch.Tensor + + +class FusedMoEQuantScalesNVFP4(NamedTuple): + fc1_act_global: torch.Tensor + fc1_weight_block: torch.Tensor + # fc1_global_scale = 1.0 / (fc1_weight_global_scale * fc1_act_global_scale) + fc1_global: torch.Tensor + + fc2_act_global: torch.Tensor + fc2_weight_block: torch.Tensor + # fc2_global_scale = 1.0 / (fc2_weight_global_scale * fc2_act_global_scale) + fc2_global: torch.Tensor + + +class FusedMoEQuantScalesFP8BlockScales(NamedTuple): + fc_weight_scales: torch.Tensor + proj_weight_scales: torch.Tensor + + +class FusedMoEQuantScalesW4A8(NamedTuple): + scale_1_interleaved: torch.Tensor + scale_2_interleaved: torch.Tensor + pre_quant_scale_1: torch.Tensor + pre_quant_scale_2: torch.Tensor + zero_1: torch.Tensor + zero_2: torch.Tensor + alpha_1: torch.Tensor + alpha_2: torch.Tensor diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py index e65d96daafb..f87647ce511 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py @@ -69,7 +69,8 @@ def __init__( self.mapping = model_config.mapping self.parallel_size = self.mapping.tp_size - self.all_reduce = AllReduce(self.mapping) + self.all_reduce = AllReduce(mapping=self.mapping, + strategy=model_config.allreduce_strategy) self.intermediate_size_per_partition = intermediate_size // self.tp_size diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index bcf51067a72..d305a3b763e 100644 --- 
a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -78,7 +78,8 @@ def __init__( self.parallel_size = self.mapping.tp_size self.intermediate_size_per_partition = intermediate_size // self.tp_size - self.all_reduce = AllReduce(self.mapping) + self.all_reduce = AllReduce(mapping=self.mapping, + strategy=model_config.allreduce_strategy) @abstractmethod def create_weights(self): diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index a727cc93ab9..7fab30e1eee 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -73,7 +73,7 @@ def __init__(self, quant_config=config.get_quant_config(), reduce_output=False, skip_create_weights_in_init=config.skip_create_weights_in_init, - ) + allreduce_strategy=config.allreduce_strategy) self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], [self.hidden_size]) @@ -89,7 +89,7 @@ def __init__(self, reduce_output=reduce_output, skip_create_weights_in_init=config.skip_create_weights_in_init, lora=self.down_lora, - ) + allreduce_strategy=config.allreduce_strategy) # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used, # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index b0062d043e9..b97f2ea489b 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -13,7 +13,8 @@ import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils from tensorrt_llm._torch.peft.lora.layer import LoraLayer -from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams +from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, + AllReduceStrategy) from tensorrt_llm.mapping import Mapping from ...models.modeling_utils import QuantConfig @@ -658,6 +659,7 @@ def __init__( skip_create_weights_in_init: bool = False, use_custom_cublas_mm: bool = False, lora: Optional[LoraLayer] = None, + allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO, ): from ..distributed import AllReduce @@ -694,7 +696,9 @@ def __init__( self.in_features = local_in_features self.out_features = local_out_features - self.all_reduce = AllReduce(self.mapping) if reduce_output else None + self.all_reduce = AllReduce( + mapping=self.mapping, + strategy=allreduce_strategy) if reduce_output else None self._weights_created = False self.reduce_output = reduce_output self.use_custom_cublas_mm = use_custom_cublas_mm diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 2b9019be6eb..55a21dae991 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -88,15 +88,14 @@ def __init__( self.is_paged_state = False # in_proj - self.in_proj = Linear( - d_model, - d_in_proj, - bias=bias, - dtype=dtype, - mapping=self.mapping, - tensor_parallel_mode=TensorParallelMode.COLUMN, - quant_config=config.get_quant_config(), - ) + self.in_proj = Linear(d_model, + d_in_proj, + bias=bias, + dtype=dtype, + mapping=self.mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + quant_config=config.get_quant_config(), + allreduce_strategy=config.allreduce_strategy) # conv1d, reuse Linear to store weights since it has support for TP > 1 already self.conv1d = Linear( @@ -108,7 +107,7 @@ def 
__init__( tensor_parallel_mode=TensorParallelMode.COLUMN, quant_config=config.get_quant_config(), skip_create_weights_in_init=config.skip_create_weights_in_init, - ) + allreduce_strategy=config.allreduce_strategy) # A self.A = nn.Parameter( @@ -138,15 +137,14 @@ def __init__( ) # out_proj - self.out_proj = Linear( - d_inner, - d_model, - bias=bias, - dtype=dtype, - mapping=self.mapping, - tensor_parallel_mode=TensorParallelMode.ROW, - quant_config=config.get_quant_config(), - ) + self.out_proj = Linear(d_inner, + d_model, + bias=bias, + dtype=dtype, + mapping=self.mapping, + tensor_parallel_mode=TensorParallelMode.ROW, + quant_config=config.get_quant_config(), + allreduce_strategy=config.allreduce_strategy) def forward( self, diff --git a/tensorrt_llm/_torch/modules/mlp.py b/tensorrt_llm/_torch/modules/mlp.py index 8d026e1fa2f..b38da2177bd 100644 --- a/tensorrt_llm/_torch/modules/mlp.py +++ b/tensorrt_llm/_torch/modules/mlp.py @@ -43,7 +43,8 @@ def __init__(self, weight_mode=WeightMode.VANILLA), quant_config=config.get_quant_config(), skip_create_weights_in_init=config.skip_create_weights_in_init, - lora=self.up_lora) + lora=self.up_lora, + allreduce_strategy=config.allreduce_strategy) self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], [self.hidden_size]) @@ -56,7 +57,8 @@ def __init__(self, tensor_parallel_mode=TensorParallelMode.ROW, quant_config=config.get_quant_config(), skip_create_weights_in_init=config.skip_create_weights_in_init, - lora=self.down_lora) + lora=self.down_lora, + allreduce_strategy=config.allreduce_strategy) def forward( self, diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py index 533b21b0502..041ee1f6dad 100644 --- a/tensorrt_llm/_torch/pyexecutor/config.py +++ b/tensorrt_llm/_torch/pyexecutor/config.py @@ -86,6 +86,7 @@ class PyTorchConfig: # If true, enable min-latency mode. Currently only used for Llama4. enable_min_latency: bool = False + allreduce_strategy: str = "AUTO" EXETENDED_EXECUTOR_CONFIG_FIELDS = [ diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index c15e00c8568..e67156ec1ab 100755 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -3880,6 +3880,7 @@ class AllReduceStrategy(IntEnum): ONESHOT = 4 TWOSHOT = 5 LOWPRECISION = 6 + MNNVL = 7 class AllReduceFusionOp(IntEnum): diff --git a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py index a2902ede1a8..595ff09d12e 100644 --- a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os import pickle import sys import traceback @@ -27,6 +26,7 @@ import tensorrt_llm from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp, AllReduceParams) +from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.mapping import Mapping cloudpickle.register_pickle_by_value(sys.modules[__name__]) @@ -97,7 +97,6 @@ def row_linear_residual_norm_fusion_forward( reference_output = tuple(t.cuda() for t in reference_output) MPI.COMM_WORLD.barrier() - os.environ["TRTLLM_MNNVL_AR_ENABLED"] = "1" allreduce = AllReduce( mapping=Mapping( @@ -105,6 +104,7 @@ def row_linear_residual_norm_fusion_forward( tp_size=tensor_parallel_size, rank=tensor_parallel_rank, ), + strategy=AllReduceStrategy.MNNVL, dtype=dtype, ) diff --git a/tests/unittest/_torch/multi_gpu/test_user_buffers.py b/tests/unittest/_torch/multi_gpu/test_user_buffers.py index 32b0af5ef8c..66934a7ccc4 100644 --- a/tests/unittest/_torch/multi_gpu/test_user_buffers.py +++ b/tests/unittest/_torch/multi_gpu/test_user_buffers.py @@ -128,7 +128,7 @@ def run_single_rank_ar_rms_norm(tensor_parallel_size, a, b, c, gamma): tp_size=tensor_parallel_size, rank=rank, ) - ar = AllReduce(mapping, strategy=AllReduceStrategy.UB) + ar = AllReduce(mapping=mapping, strategy=AllReduceStrategy.UB) ar_params = AllReduceParams( strategy=AllReduceStrategy.UB, fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, @@ -220,7 +220,7 @@ def run_single_rank_ar_rms_norm_fp8(tensor_parallel_size, a, b, c, gamma, tp_size=tensor_parallel_size, rank=rank, ) - ar = AllReduce(mapping, strategy=AllReduceStrategy.UB) + ar = AllReduce(mapping=mapping, strategy=AllReduceStrategy.UB) ar_params = AllReduceParams( strategy=AllReduceStrategy.UB, fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8, @@ -605,7 +605,7 @@ def run_single_rank_ar_rms_norm_fp4(tensor_parallel_size, a, b, c, gamma): tp_size=tensor_parallel_size, rank=rank, ) - ar = AllReduce(mapping, strategy=AllReduceStrategy.UB) + ar = AllReduce(mapping=mapping, strategy=AllReduceStrategy.UB) ar_params = AllReduceParams( strategy=AllReduceStrategy.UB, fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, @@ -692,9 +692,9 @@ def __init__(self, tp_size, rank, hidden_size, dtype, eps, norm0_gamma, tp_size=tp_size, rank=rank, ) - self.ar_0 = AllReduce(mapping).cuda() - self.ar_1 = AllReduce(mapping).cuda() - self.ar_2 = AllReduce(mapping).cuda() + self.ar_0 = AllReduce(mapping=mapping).cuda() + self.ar_1 = AllReduce(mapping=mapping).cuda() + self.ar_2 = AllReduce(mapping=mapping).cuda() self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda() self.norm1 = RMSNorm(hidden_size=hidden_size, eps=eps, diff --git a/tests/unittest/api_stability/references_committed/llm.yaml b/tests/unittest/api_stability/references_committed/llm.yaml index f2c90635fbe..cbb0f5681e1 100644 --- a/tests/unittest/api_stability/references_committed/llm.yaml +++ b/tests/unittest/api_stability/references_committed/llm.yaml @@ -105,6 +105,9 @@ methods: kv_cache_config: annotation: tensorrt_llm.llmapi.llm_args.KvCacheConfig default: null + allreduce_strategy: + annotation: Optional[Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT', 'LOWPRECISION', 'MNNVL']] + default: AUTO return_annotation: None generate: parameters: