feat: wide_ep supports block-wise FP8 on Blackwell
Signed-off-by: xxi <xxi@nvidia.com>

	modified:   tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
	new file:   tensorrt_llm/_torch/modules/fused_moe/moe_backend.py
	modified:   tests/unittest/_torch/modules/test_fused_moe.py
xxi-nv committed Sep 1, 2025
commit 77d53e94b4b089db973c8a22ff762ec7e1a4d425
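In short, this commit does two things for WideEPMoE: it routes the DeepSeek FP8 block-scales path through a DeepGemm-based quantization method on SM 100 (Blackwell), and it replaces the direct torch.ops.trtllm.fused_moe call with a pluggable MoE backend selected at runtime. A minimal sketch of the SM dispatch, using only names that appear in the diff below:

    from tensorrt_llm._utils import get_sm_version
    from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                               DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm)

    def _select_block_scales_method():
        # On SM 100 (Blackwell), use the DeepGemm-backed block-scales method;
        # every other architecture keeps the existing CUTLASS-based method.
        if get_sm_version() == 100:
            return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
        return DeepSeekFP8BlockScalesFusedMoEMethod()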
52 changes: 31 additions & 21 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -5,8 +5,9 @@
import torch

from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import logger
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from ...distributed import AllReduce, allgather, reducescatter
@@ -15,8 +16,10 @@
from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .deep_ep_utils import buffer_pool, deep_ep_installed
from .interface import MoE
from .moe_backend import MoEBackend, MoEBackendSelection
from .moe_load_balancer import get_moe_load_balancer
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
NVFP4CutlassFusedMoEMethod,
UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
@@ -90,6 +93,9 @@ def __init__(
self.apply_router_weight_on_input = apply_router_weight_on_input
self.layer_idx = layer_idx

# Store original hidden size before any potential padding
self.unpadded_hidden_size = self.hidden_size

moe_load_balancer = get_moe_load_balancer()
self.layer_load_balancer = None
self.repeat_idx = 0
@@ -227,6 +233,9 @@ def __init__(
self.enable_dummy_allreduce = os.environ.get(
"TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"

# MoE backend will be lazily initialized when first accessed (see moe_backend_impl property)
self._moe_backend_impl = None

def _check_configs(self):
assert self._weights_created

@@ -316,7 +325,10 @@ def _get_quant_method(self):
if self.quant_config.layer_quant_mode.has_fp8_qdq():
return FP8QDQFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
return DeepSeekFP8BlockScalesFusedMoEMethod()
if get_sm_version() == 100:
return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
else:
return DeepSeekFP8BlockScalesFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_nvfp4():
return NVFP4CutlassFusedMoEMethod()
elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
@@ -339,6 +351,19 @@ def create_weights(self):
self._weights_created = True
self._check_configs()

@property
def moe_backend_impl(self) -> MoEBackend:
"""
Lazily initialize and return the MoE backend.

The backend is selected based on hardware capabilities and quantization
configuration, which are only available after weights are created.
"""
if self._moe_backend_impl is None:
assert self._weights_created, "Weights must be created before accessing moe_backend_impl"
self._moe_backend_impl = MoEBackendSelection.select_backend(self)
return self._moe_backend_impl

def dummy_allreduce(self):
"""
Debug function for eliminating imbalance during performance analysis.
@@ -389,8 +414,6 @@ def forward_chunk(
if self.layer_load_balancer and is_first_call:
self.layer_load_balancer.start_wait_gpu_stage()

use_deepseek_fp8_block_scale = False
use_w4_group_scaling = False
weight_dtype = self.w3_w1_weight.dtype

token_selected_experts, token_final_scales = self.routing_method.apply(
@@ -544,9 +567,8 @@
x_sf = x_sf.view((x_row, -1))

elif self.has_deepseek_fp8_block_scales:
use_deepseek_fp8_block_scale = True
pass
elif self.has_w4afp8:
use_w4_group_scaling = True
weight_dtype = torch.quint4x2
else:
raise ValueError(
@@ -569,12 +591,8 @@
sizes=None if use_dp_padding else all_rank_num_tokens)
x_row = x.shape[0]

ep_size = self.ep_size
ep_rank = self.ep_rank
w3_w1_weight = self.w3_w1_weight
w2_weight = self.w2_weight
cluster_size = self.cluster_size
cluster_rank = self.cluster_rank
quant_scales = self.quant_scales

if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -640,7 +658,8 @@ def forward_chunk(
f"Not available alltoall method type: {self.alltoall_method_type!r}"
)

final_hidden_states = torch.ops.trtllm.fused_moe(
final_hidden_states = self.moe_backend_impl.run_moe(
self,
x,
token_selected_slots,
token_final_scales,
@@ -652,17 +671,8 @@
quant_scales=quant_scales,
input_sf=x_sf,
swizzled_input_sf=False,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
cluster_size=cluster_size,
cluster_rank=cluster_rank,
enable_alltoall=use_all_to_all,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
use_w4_group_scaling=use_w4_group_scaling,
min_latency_mode=False,
tune_max_num_tokens=self.tune_max_num_tokens,
use_fused_finalize=True,
tuner_num_tokens=tuner_num_tokens,
tuner_top_k=tuner_top_k,
)
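Because backend selection depends on the quantization config and the hardware, it only happens on first access to moe_backend_impl, after create_weights() has run. A hypothetical usage sketch (make_wide_ep_moe() stands in for however the tests construct the module; it is not an API added by this PR):

    # Hypothetical driver code; only the property contract shown above is assumed.
    moe = make_wide_ep_moe()
    moe.create_weights()                     # weights and quant method must exist first
    backend = moe.moe_backend_impl           # first access calls MoEBackendSelection.select_backend(moe)
    assert backend is moe.moe_backend_impl   # later accesses reuse the cached backend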