4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -59,6 +59,7 @@
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
MoEWeightLoadingMode, TRTLLMGenFusedMoE,
create_moe)
from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -559,6 +560,9 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
**({
"alltoall_result_do_sum": False
} if isinstance(self.experts, WideEPMoE) else {}),
)

return routed_output
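For readers unfamiliar with the conditional-kwarg splat in the call above: `alltoall_result_do_sum` is only accepted by `WideEPMoE`, so the DeepseekV3 call site builds the extra keyword dictionary only when `self.experts` is that type. A minimal sketch of the pattern, using simplified stand-in classes rather than the real TensorRT-LLM modules:

```python
# Illustrative sketch only -- simplified stand-ins for the real MoE classes.
class FusedMoE:
    def forward(self, hidden_states, **kwargs):
        return hidden_states


class WideEPMoE(FusedMoE):
    def forward(self, hidden_states, alltoall_result_do_sum: bool = True, **kwargs):
        # The wide-EP backend is the only one that understands the extra flag.
        return hidden_states


def compute_routed_output(experts, hidden_states):
    # Mirror of the call site in the diff: splat the keyword only when it is accepted.
    extra = {"alltoall_result_do_sum": False} if isinstance(experts, WideEPMoE) else {}
    return experts.forward(hidden_states, **extra)
```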
43 changes: 27 additions & 16 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -368,15 +368,16 @@ def reducescatter_or_allreduce(
return outputs

def forward_chunk(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
use_all_to_all: bool,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: Tuple = (True, True),
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
use_all_to_all: bool,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: Tuple = (True, True),
alltoall_result_do_sum: bool = True,
) -> torch.Tensor:
if isinstance(x, Fp4QuantizedTensor):
assert output_dtype is not None
@@ -389,6 +390,9 @@ def forward_chunk(
if self.layer_load_balancer and is_first_call:
self.layer_load_balancer.start_wait_gpu_stage()

if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
alltoall_result_do_sum = True

use_deepseek_fp8_block_scale = False
use_w4_group_scaling = False
weight_dtype = self.w3_w1_weight.dtype
@@ -679,7 +683,8 @@ def forward_chunk(
if self.enable_dummy_allreduce:
self.dummy_allreduce()
final_hidden_states = self.alltoall_combine(
final_hidden_states, alltoall_info, token_count)
final_hidden_states, alltoall_info, token_count,
alltoall_result_do_sum)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
final_hidden_states = self.unpad_tensors(
padded, final_hidden_states)
@@ -719,6 +724,7 @@ def forward(
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
alltoall_result_do_sum: bool = True,
) -> torch.Tensor:
assert all_rank_num_tokens is not None
assert use_dp_padding is not None
@@ -744,7 +750,8 @@ def forward(
all_rank_num_tokens=all_rank_num_tokens_padded,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
repeating_info=(is_first_call, is_last_call),
alltoall_result_do_sum=alltoall_result_do_sum)
outputs = self.reducescatter_or_allreduce(
outputs,
use_all_to_all,
@@ -804,7 +811,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
all_rank_max_num_tokens=
all_rank_max_num_tokens_list[idx_chunk],
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
repeating_info=(is_first_call, is_last_call),
alltoall_result_do_sum=alltoall_result_do_sum)
if idx_chunk > 0:
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
@@ -822,7 +830,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
all_rank_max_num_tokens=all_rank_max_num_tokens_list[
idx_chunk],
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
repeating_info=(is_first_call, is_last_call),
alltoall_result_do_sum=alltoall_result_do_sum)
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
@@ -838,7 +847,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
all_rank_max_num_tokens=all_rank_max_num_tokens_list[
idx_chunk],
repeating_info=(is_first_call, is_last_call))
repeating_info=(is_first_call, is_last_call),
alltoall_result_do_sum=alltoall_result_do_sum)

outputs_list.append(outputs)
if not use_all_to_all:
@@ -894,7 +904,8 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
return x, x_sf, token_selected_slots, token_final_scales

def alltoall_combine(self, final_hidden_states: torch.Tensor,
alltoall_info: MoEAlltoallInfo, token_count: int):
alltoall_info: MoEAlltoallInfo, token_count: int,
alltoall_result_do_sum: bool):
top_k = self.routing_method.experts_per_token
if isinstance(final_hidden_states, list):
final_hidden_states = final_hidden_states[0]
@@ -907,7 +918,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
top_k=top_k,
token_count=token_count,
use_low_precision_combine=self.use_low_precision_combine,
do_reduce=False)
do_reduce=alltoall_result_do_sum)

return final_hidden_states

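Taken together, the fused_moe_wide_ep.py changes thread the new flag from `forward` into every `forward_chunk` call and on into `alltoall_combine`, where it replaces the previously hard-coded `do_reduce=False`; `forward_chunk` also forces the flag back to `True` whenever the MNNVL all-to-all path is not in use. A condensed sketch of that flow, using placeholder classes and a stand-in for the MNNVL combine kernel rather than the real implementation:

```python
from enum import Enum, auto


class AlltoallMethodType(Enum):
    MNNVL = auto()
    DeepEP = auto()


def mnnvl_combine(hidden_states, do_reduce: bool):
    # Placeholder for the real MNNVL all-to-all combine call that takes do_reduce.
    return hidden_states


class WideEPMoESketch:
    def __init__(self, alltoall_method_type=AlltoallMethodType.MNNVL):
        self.alltoall_method_type = alltoall_method_type

    def forward(self, x, use_all_to_all=True, alltoall_result_do_sum=True):
        # The real forward() may split x into chunks; each chunk receives the flag.
        return self.forward_chunk(x, use_all_to_all, alltoall_result_do_sum)

    def forward_chunk(self, x, use_all_to_all, alltoall_result_do_sum=True):
        # Only the MNNVL all-to-all path honors a deferred reduction.
        if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
            alltoall_result_do_sum = True
        return self.alltoall_combine(x, alltoall_result_do_sum)

    def alltoall_combine(self, final_hidden_states, alltoall_result_do_sum):
        # Previously hard-coded do_reduce=False; now driven by the caller.
        return mnnvl_combine(final_hidden_states, do_reduce=alltoall_result_do_sum)
```

Under this reading, the DeepseekV3 call site's `alltoall_result_do_sum=False` keeps the earlier non-reducing combine behavior for that model, while other WideEPMoE callers now get a reduced result by default.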
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -16,6 +16,7 @@ l0_dgx_b200:
tests:
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -80,6 +80,7 @@ l0_dgx_h100:
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype0]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0]
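The two test-db entries above add the `MNNVL` parametrization of the existing fused-MoE all-to-all tests to the B200 and H100 pre-merge lists. To reproduce one of these cases locally, the same node ID can be selected with standard pytest selection (shown via `pytest.main` to keep the example in Python; the path is taken from the YAML entries):

```python
import pytest

# Run a single parametrized case by its node ID, exactly as listed in the test-db YAML.
pytest.main([
    "tests/unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL]",
    "-q",
])
```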
13 changes: 9 additions & 4 deletions tests/unittest/_torch/modules/test_fused_moe.py
@@ -212,11 +212,14 @@ def per_rank_test_fused_moe_alltoall(job_id):
weights = {}
for expert_id in range(NUM_EXPERTS):
w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
dtype=dtype)
dtype=dtype,
device="cuda")
w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
dtype=dtype)
dtype=dtype,
device="cuda")
w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
dtype=dtype)
dtype=dtype,
device="cuda")
torch.nn.init.xavier_uniform_(w1_weight)
torch.nn.init.xavier_uniform_(w2_weight)
torch.nn.init.xavier_uniform_(w3_weight)
@@ -289,7 +292,6 @@ def per_rank_test_fused_moe_alltoall(job_id):
assert r is None


@pytest.mark.skip(reason="https://nvbugs/5467531")
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
@pytest.mark.parametrize("alltoall_method_type", [
@@ -299,6 +301,9 @@
ids=lambda s: s.name)
def test_fused_moe_alltoall_fp4(alltoall_method_type):

if alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
pytest.skip("Skipped due to https://nvbugs/5467531")

world_size = 4
dtype = torch.bfloat16
HIDDEN_SIZE = 2560
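The test_fused_moe.py changes do two things: expert weights are now allocated with `device="cuda"` so the in-place Xavier initialization runs on the GPU, and the blanket `@pytest.mark.skip` on `test_fused_moe_alltoall_fp4` is narrowed to a `pytest.skip` inside the test body, so only the `DeepEPLowLatency` variant tracked by nvbugs/5467531 is skipped while the newly added `MNNVL` variant runs. A small illustrative pattern (hypothetical test name and parameter values, not the real ones):

```python
import pytest
import torch


@pytest.mark.parametrize("method", ["MNNVL", "DeepEPLowLatency"])
def test_alltoall_variants(method):
    # Skip only the variant blocked by the known bug instead of the whole test.
    if method == "DeepEPLowLatency":
        pytest.skip("Skipped due to https://nvbugs/5467531")

    # Allocate directly on the GPU so the in-place init happens on-device.
    w = torch.empty((128, 64), dtype=torch.bfloat16, device="cuda")
    torch.nn.init.xavier_uniform_(w)
    assert w.is_cuda
```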