5 changes: 5 additions & 0 deletions tensorrt_llm/_torch/model_config.py
@@ -165,6 +165,11 @@ def get_all_reduce_strategy(strategy: str = "AUTO"):
self.allreduce_strategy = get_all_reduce_strategy(
self.allreduce_strategy)

# Set default moe_max_num_tokens if not specified
# The maximum number of tokens in MoE is multiplied by the DP size when attention DP is enabled
if self.moe_max_num_tokens is None:
self.moe_max_num_tokens = self.max_num_tokens * self.mapping.dp_size

@property
def torch_dtype(self) -> torch.dtype:
"""Get the torch dtype of the model."""
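For context, a minimal standalone sketch of the defaulting rule added above. The `_Mapping`/`_ModelConfig` classes are simplified stand-ins, not the real TensorRT-LLM types:

```python
# Minimal sketch of the moe_max_num_tokens defaulting rule; the classes below
# are simplified stand-ins, not the real TensorRT-LLM ModelConfig/Mapping types.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class _Mapping:
    dp_size: int = 1  # attention data-parallel size


@dataclass
class _ModelConfig:
    max_num_tokens: int = 8192
    moe_max_num_tokens: Optional[int] = None
    mapping: _Mapping = field(default_factory=_Mapping)

    def finalize(self) -> None:
        # With attention DP enabled, each DP rank can contribute up to
        # max_num_tokens, so MoE layers may see dp_size times as many tokens.
        if self.moe_max_num_tokens is None:
            self.moe_max_num_tokens = self.max_num_tokens * self.mapping.dp_size


cfg = _ModelConfig(max_num_tokens=8192, mapping=_Mapping(dp_size=4))
cfg.finalize()
print(cfg.moe_max_num_tokens)  # 32768
```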
17 changes: 16 additions & 1 deletion tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -55,7 +55,7 @@
from ..modules.attention import MLA
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, MoE,
MoEWeightLoadingMode, create_moe)
from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
from ..modules.gated_mlp import GatedMLP
@@ -382,6 +382,21 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
"gate_proj": "w1",
})
module.load_weights(weights=[module_weights])
elif names[-1] == "backend" and isinstance(module, MoE):
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
# Currently, saved MoE checkpoints do not include 'backend' in their weight names.
# After the MoE refactoring, ConfigurableMoE has a backend submodule and weight
# loading happens in that backend, so the module name ends with '.backend'.
# Use the parent module name (without '.backend') to match the saved weight names.
# Once the MoE refactoring is fully complete, all paths will take this branch.
parent_name = '.'.join(names[:-1])
module_weights = filter_weights(parent_name, weights)
module_weights = rename_moe_weight(module_weights, {
"down_proj": "w2",
"up_proj": "w3",
"gate_proj": "w1",
})
module.load_weights(weights=[module_weights])
elif names[-1] == "self_attn":
continue
elif names[-1] == "next_layer_layernorm":
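The `.backend` special case added above (and repeated in the model files that follow) reduces to stripping the trailing `backend` component before looking up checkpoint weights. A hedged standalone sketch of that normalization, where `filter_weights` is simplified to a prefix match and the checkpoint keys are hypothetical, not actual DeepSeek-V3 weight names:

```python
# Hedged sketch of the '.backend' name normalization used when loading MoE
# weights. filter_weights is a simplified prefix matcher and the checkpoint
# keys below are hypothetical examples, not real DeepSeek-V3 weight names.
from typing import Dict


def normalize_moe_module_name(module_name: str) -> str:
    """Drop a trailing '.backend' so the name matches saved checkpoint prefixes."""
    parts = module_name.split('.')
    if parts[-1] == "backend":
        return '.'.join(parts[:-1])
    return module_name


def filter_weights(prefix: str, weights: Dict[str, str]) -> Dict[str, str]:
    # Keep entries under `prefix` and strip the prefix, mirroring the helper
    # used by the model loaders.
    return {
        key[len(prefix) + 1:]: value
        for key, value in weights.items()
        if key.startswith(prefix + '.')
    }


weights = {"model.layers.0.mlp.experts.down_proj": "w2-tensor"}  # hypothetical key
module_name = "model.layers.0.mlp.experts.backend"  # ConfigurableMoE backend submodule
prefix = normalize_moe_module_name(module_name)
print(filter_weights(prefix, weights))  # {'down_proj': 'w2-tensor'}
```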
12 changes: 12 additions & 0 deletions tensorrt_llm/_torch/models/modeling_gpt_oss.py
@@ -657,6 +657,18 @@ def load_hf_weights(self, weights: Dict):
module_weights = {}
for k, v in self.hf_params_map.items():
name = name.replace(k, v)

# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
# Currently, saved MoE checkpoints do not include 'backend' in their weight names.
# After the MoE refactoring, ConfigurableMoE has a backend submodule and weight
# loading happens in that backend, so the module name ends with '.backend'.
# Use the parent module name (without '.backend') to match the saved weight names.
# Once the MoE refactoring is fully complete, all paths will take this branch.
names = name.split('.')
if names[-1] == "backend" and isinstance(module, MoE):
# Backend is under experts module (ConfigurableMoE wrapper)
name = '.'.join(names[:-1])

module_weights = filter_weights(name, weights)

if isinstance(module, MoE):
16 changes: 14 additions & 2 deletions tensorrt_llm/_torch/models/modeling_hunyuan_moe.py
@@ -15,8 +15,9 @@
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (CutlassFusedMoE, RenormalizeMoeRoutingMethod,
VanillaMoE, create_moe)
from ..modules.fused_moe import (CutlassFusedMoE, MoE,
RenormalizeMoeRoutingMethod, VanillaMoE,
create_moe)
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -364,6 +365,17 @@ def filter_weights(prefix, weights: Dict):
"lm_head"):
continue
names = name.split('.')

# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
# Currently, saved MoE checkpoints do not include 'backend' in their weight names.
# After the MoE refactoring, ConfigurableMoE has a backend submodule and weight
# loading happens in that backend, so the module name ends with '.backend'.
# Use the parent module name (without '.backend') to match the saved weight names.
# Once the MoE refactoring is fully complete, all paths will take this branch.
if names[-1] == "backend" and isinstance(module, MoE):
name = '.'.join(names[:-1])
names = name.split('.')

if names[-1] in params_map:
# model.layers.{idx}.mlp.shared_mlp.gate_up_proj or model.layers.{idx}.self_attn.qkv_proj
module_weights = []
22 changes: 22 additions & 0 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -868,6 +868,17 @@ def load_single_module(name, module):
return

names = name.split('.')

# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
# Currently, saved MoE checkpoints do not include 'backend' in their weight names.
# After the MoE refactoring, ConfigurableMoE has a backend submodule and weight
# loading happens in that backend, so the module name ends with '.backend'.
# Use the parent module name (without '.backend') to match the saved weight names.
# Once the MoE refactoring is fully complete, all paths will take this branch.
if names[-1] == "backend" and isinstance(module, MoE):
name = '.'.join(names[:-1])
names = name.split('.')

# WAR: better solution is that llama has its own load_weights function.
if names[-1] == 'next_layer_layernorm':
return
@@ -968,6 +979,17 @@ def load_single_module(name, module):
return

names = name.split('.')

# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
# Currently, saved MoE checkpoints do not include 'backend' in their weight names.
# After the MoE refactoring, ConfigurableMoE has a backend submodule and weight
# loading happens in that backend, so the module name ends with '.backend'.
# Use the parent module name (without '.backend') to match the saved weight names.
# Once the MoE refactoring is fully complete, all paths will take this branch.
if names[-1] == "backend" and isinstance(module, MoE):
name = '.'.join(names[:-1])
names = name.split('.')

module_names_breakdown, module_name = names[:-1], names[-1]

if weight_mapper.does_require_special_handling(module_name):
tensorrt_llm/_torch/modules/fused_moe/communication/__init__.py
@@ -20,8 +20,8 @@

Available Communication Methods:
- AllGatherReduceScatter: Default fallback method, always available
- MnnvlLatency: MNNVL-optimized communication for latency
- MNNVLThroughput: MNNVL-optimized communication for throughput
- NVLinkTwoSided: NVLINK-optimized communication for latency (formerly MnnvlLatency)
- NVLinkOneSided: NVLINK-optimized communication for throughput (formerly MNNVLThroughput)
- DeepEP: Deep Expert Parallelism with support for large batches
- DeepEPLowLatency: Deep Expert Parallelism optimized for low latency

@@ -34,16 +34,16 @@
from .communication_factory import CommunicationFactory
from .deep_ep import DeepEP
from .deep_ep_low_latency import DeepEPLowLatency
from .mnnvl_latency import MnnvlLatency
from .mnnvl_throughput import MNNVLThroughput
from .nvlink_one_sided import NVLinkOneSided
from .nvlink_two_sided import NVLinkTwoSided

__all__ = [
# Base classes and types
"Communication",
# Communication strategies
"AllGatherReduceScatter",
"MnnvlLatency",
"MNNVLThroughput",
"NVLinkTwoSided",
"NVLinkOneSided",
"DeepEP",
"DeepEPLowLatency",
# Factory
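For callers, the rename is import-only. A hedged sketch follows: the package path is taken from the `base.py` path below, and the latency/throughput selection helper is illustrative, not the factory's actual API:

```python
# Hedged migration sketch: MnnvlLatency -> NVLinkTwoSided and
# MNNVLThroughput -> NVLinkOneSided. The package path follows the base.py
# path below; the selection helper is illustrative only.
from tensorrt_llm._torch.modules.fused_moe.communication import (
    NVLinkOneSided, NVLinkTwoSided)


def pick_nvlink_strategy_class(optimize_for_latency: bool):
    """Choose between the two NVLink-optimized strategies by optimization goal."""
    return NVLinkTwoSided if optimize_for_latency else NVLinkOneSided
```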
tensorrt_llm/_torch/modules/fused_moe/communication/ (AllGatherReduceScatter strategy)
@@ -42,6 +42,13 @@ def __init__(
# Initialize dispatch state
self._dispatch_state = {}

@staticmethod
def is_platform_supported() -> bool:
"""
AllGather + ReduceScatter is always supported as the fallback strategy
"""
return True

def is_workload_feasible(self, all_rank_num_tokens: List[int], num_chunks: int) -> bool:
"""
Check if AllGather is feasible for the given workload at runtime.
25 changes: 25 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/communication/base.py
@@ -50,6 +50,31 @@ def __init__(
self.ep_size = mapping.moe_ep_size
self.ep_rank = mapping.moe_ep_rank

# Check platform support and raise error if not supported
if not self.is_platform_supported():
raise RuntimeError(
f"Communication strategy {self.__class__.__name__} "
f"is not supported on this platform."
)
self._is_platform_supported = True

@staticmethod
@abstractmethod
def is_platform_supported() -> bool:
"""
Check if this communication strategy is supported on the current platform.

This method performs platform/hardware checks to determine if the strategy
can be used on the current system.

Returns:
True if platform is supported, False otherwise

Note: This is a static method that can be called before instantiation
to check compatibility without creating an instance.
"""
raise NotImplementedError

@abstractmethod
def is_workload_feasible(
self,
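Because `is_platform_supported` is a static method, callers can filter strategies before constructing one (constructing an unsupported strategy now raises `RuntimeError`). A hedged sketch of that pattern; the classes below are simplified stand-ins mirroring the contract, not the real `Communication` subclasses:

```python
# Hedged sketch of pre-instantiation filtering enabled by the static
# is_platform_supported() contract. The classes below are simplified stand-ins
# mirroring base.py, not the real Communication strategy types.
from abc import ABC, abstractmethod
from typing import List, Type


class _Strategy(ABC):
    @staticmethod
    @abstractmethod
    def is_platform_supported() -> bool:
        ...


class _AllGatherReduceScatter(_Strategy):
    @staticmethod
    def is_platform_supported() -> bool:
        return True  # always-available fallback, as in the diff above


class _HypotheticalNvlinkStrategy(_Strategy):
    @staticmethod
    def is_platform_supported() -> bool:
        return False  # pretend the required interconnect is missing


def select_first_supported(candidates: List[Type[_Strategy]]) -> Type[_Strategy]:
    # Check support on the classes themselves, without constructing anything;
    # construction would raise RuntimeError for an unsupported strategy.
    for cls in candidates:
        if cls.is_platform_supported():
            return cls
    raise RuntimeError("No supported communication strategy on this platform.")


chosen = select_first_supported(
    [_HypotheticalNvlinkStrategy, _AllGatherReduceScatter])
print(chosen.__name__)  # _AllGatherReduceScatter
```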