81 changes: 36 additions & 45 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -142,34 +142,29 @@ def __init__(
if self.enable_alltoall:
self.use_low_precision_combine = model_config.use_low_precision_moe_combine

if self.alltoall_method_type == AlltoallMethodType.MNNVL:
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
workspace_mb = int(
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_mb * 1024 * 1024,
)
else:
raise ValueError(
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
)
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
workspace_mb = int(
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_mb * 1024 * 1024,
)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
raise NotImplementedError(
"DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
)
else:
raise NotImplementedError(
f"Not available alltoall method type: {self.alltoall_method_type!r}"
f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
)

# If True, the router weight will be multiplied on the input rather than at the end of FC2
@@ -235,28 +230,18 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
)
return AlltoallMethodType[all2all_method_type]

if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
return AlltoallMethodType.NotEnabled

# TODO: We found that MNNVL performs better than NCCL AllGather/ReduceScatter,
# regardless of the relationship between EP size and topK. We favor AlltoAll for now.
# TODO: We found that NVLinkOneSided performs better than NCCL AllGather/ReduceScatter,
# regardless of the relationship between EP size and topK. We favor NVLinkOneSided for now.
# if not self.mapping.moe_ep_size > self.routing_method.experts_per_token:
# return AlltoallMethodType.NotEnabled

return AlltoallMethodType.MNNVL
return AlltoallMethodType.NVLinkOneSided

@cached_property
def enable_alltoall(self):
""" enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
"""
return self.alltoall_method_type != AlltoallMethodType.NotEnabled

@cached_property
def moe_alltoall_backend(self):
# "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
"NVLINK_ONE_SIDED").strip().upper()

def _supports_load_balancer(self) -> bool:
"""CutlassFusedMoE supports load balancer."""
return True
@@ -328,7 +313,7 @@ def forward_chunk(

if self.layer_load_balancer:
self._load_balancer_done_wait_gpu_stage(is_first_call)
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
self._load_balancer_update_statistic(
token_selected_experts,
is_first_call,
@@ -439,7 +424,7 @@ def forward_chunk(
token_final_scales = torch.ones_like(token_selected_slots,
dtype=torch.float32)

if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
if is_last_call:
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
@@ -472,7 +457,7 @@ def forward_chunk(
token_selected_slots, alltoall_info.recv_rank_count_cumsum,
runtime_max_tokens_per_rank, top_k, self.num_slots,
self.ep_size)
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
# Python MoeAlltoAll path
if x_sf is not None:
x_sf = x_sf.view(x_row,
@@ -510,7 +495,7 @@ def forward_chunk(
-1, token_final_scales_recv.shape[-1])
else:
raise ValueError(
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
)

elif run_post_quant_allgather:
@@ -532,7 +517,7 @@ def forward_chunk(

# Optionally provide an output tensor to fused_moe so it writes directly to our buffer
moe_output: Optional[torch.Tensor] = None
if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
if self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
# Retrieve a workspace-backed output tensor sized by runtime tokens
runtime_max_tokens_per_rank = max(
all_rank_num_tokens) if all_rank_num_tokens else x.shape[0]
@@ -583,7 +568,7 @@ def forward_chunk(

# Combine results if using alltoall
if self.enable_alltoall:
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
if alltoall_info is not None:
top_k = self.routing_method.experts_per_token
final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
@@ -596,7 +581,7 @@ def forward_chunk(
use_low_precision_combine=self.
use_low_precision_combine,
token_count=token_count)
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
output_hidden_size = final_hidden_states.shape[-1]
runtime_max_tokens_per_rank = max(
all_rank_num_tokens) if all_rank_num_tokens else token_count
@@ -608,7 +593,7 @@ def forward_chunk(
payload_in_workspace=True)
else:
raise ValueError(
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
)

self._load_balancer_done_set_cpu_stage(is_last_call)
@@ -708,7 +693,10 @@ def _reducescatter_or_allreduce(x_, idx):
# Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
for idx_chunk, (x, router_logits) in enumerate(
zip(x_list, router_logits_list)):
if not (self.alltoall_method_type == AlltoallMethodType.MNNVL):
if not (self.alltoall_method_type
== AlltoallMethodType.NVLinkOneSided
or self.alltoall_method_type
== AlltoallMethodType.NVLinkTwoSided):
if idx_chunk % 2 == 0:
with torch.cuda.stream(self.aux_stream):
outputs = _forward_chunk(x, router_logits,
@@ -726,7 +714,10 @@ def _reducescatter_or_allreduce(x_, idx):

outputs_list.append(outputs)

if not (self.alltoall_method_type == AlltoallMethodType.MNNVL):
if not (self.alltoall_method_type
== AlltoallMethodType.NVLinkOneSided
or self.alltoall_method_type
== AlltoallMethodType.NVLinkTwoSided):
if num_chunks % 2 == 0:
outputs_list[-1] = _reducescatter_or_allreduce(
outputs_list[-1], -1)
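Note: the diff above drops the TRTLLM_MOE_ALLTOALL_BACKEND string switch and dispatches directly on AlltoallMethodType members. A minimal sketch of what that enum is assumed to look like is below; the real definition lives elsewhere in the codebase and is not part of this diff, so the member values and ordering here are hypothetical. Only the member names appear in the changed code.

    from enum import Enum

    class AlltoallMethodType(Enum):
        # Hypothetical sketch; only the member names are taken from the diff.
        NotEnabled = 0        # fall back to NCCL allgather / reduce-scatter
        NVLinkOneSided = 1    # MoeAlltoAll workspace path; select_alltoall_method_type now returns this by default
        NVLinkTwoSided = 2    # MnnvlMoe prepare/dispatch/combine path
        DeepEP = 3            # raises NotImplementedError in CutlassFusedMoE / TRTLLMGenFusedMoE
        DeepEPLowLatency = 4  # same: not supported by these fused MoE modules yet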
71 changes: 28 additions & 43 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -119,34 +119,29 @@ def __init__(
if self.enable_alltoall:
self.use_low_precision_combine = model_config.use_low_precision_moe_combine

if self.alltoall_method_type == AlltoallMethodType.MNNVL:
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
workspace_mb = int(
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_mb * 1024 * 1024,
)
else:
raise ValueError(
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
)
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
workspace_mb = int(
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_mb * 1024 * 1024,
)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
raise NotImplementedError(
"DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet"
)
else:
raise NotImplementedError(
f"Not available alltoall method type: {self.alltoall_method_type!r}"
f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
)

self._weights_created = False
@@ -176,15 +171,11 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
)
return AlltoallMethodType[all2all_method_type]

if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
return AlltoallMethodType.NotEnabled

# TODO: We found that MNNVL performs better than NCCL AllGather/ReduceScatter,
# regardless of the relationship between EP size and topK. We favor AlltoAll for now.
# TODO: We found that NVLinkOneSided performs better than NCCL AllGather/ReduceScatter,
# regardless of the relationship between EP size and topK. We favor NVLinkOneSided for now.
# if not self.mapping.moe_ep_size > self.routing_method.experts_per_token:
# return AlltoallMethodType.NotEnabled

return AlltoallMethodType.MNNVL
return AlltoallMethodType.NVLinkOneSided

def _supports_load_balancer(self) -> bool:
"""TRTLLMGenFusedMoE supports load balancer."""
@@ -196,12 +187,6 @@ def enable_alltoall(self):
"""
return self.alltoall_method_type != AlltoallMethodType.NotEnabled

@cached_property
def moe_alltoall_backend(self):
# "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
"NVLINK_ONE_SIDED").strip().upper()

def _check_configs(self):
assert self.has_deepseek_fp8_block_scales \
or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
@@ -362,7 +347,7 @@ def forward_impl(

self._load_balancer_done_wait_gpu_stage(is_first_call)

ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
self._load_balancer_update_statistic(
token_selected_experts,
is_first_call,
@@ -394,7 +379,7 @@ def forward_impl(
else:
token_final_scales = token_final_scales.to(torch.float32)

if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
if is_last_call:
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
@@ -444,7 +429,7 @@ def forward_impl(

if token_final_scales is not None:
token_final_scales = token_final_scales.to(torch.bfloat16)
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
if x_sf is not None:
x_sf = x_sf.view(x_row,
ceil_div(x_col, self.scaling_vector_size))
@@ -486,7 +471,7 @@ def forward_impl(
token_final_scales = token_final_scales.to(torch.bfloat16)
else:
raise ValueError(
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
)

elif run_post_quant_allgather:
@@ -510,7 +495,7 @@ def forward_impl(
moe_output: Optional[torch.Tensor] = None
use_workspace_output = False
# TODO: use_workspace_output only supports w4a8_mxfp4_mxfp8 (gpt-oss) for now
if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED" and self.has_w4a8_mxfp4_mxfp8:
if self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided and self.has_w4a8_mxfp4_mxfp8:
moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace(
runtime_max_tokens_per_rank, self.hidden_size, torch.bfloat16)
use_workspace_output = True
@@ -774,7 +759,7 @@ def forward_impl(

# Combine results if using alltoall
if self.enable_alltoall:
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
if alltoall_info is not None:
final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
final_hidden_states,
@@ -787,7 +772,7 @@ def forward_impl(
use_low_precision_combine,
token_count=token_count,
)
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
# If use_workspace_output=True, the MoE result is already in workspace
# Otherwise, we need to reshape and pass it
if use_workspace_output:
@@ -810,7 +795,7 @@ def forward_impl(
payload_in_workspace=False)
else:
raise ValueError(
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
)

final_hidden_states = self.reducescatter_or_allreduce(
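Note: in both files, the NVLinkOneSided path sizes its per-rank MoeAlltoAll workspace from TRTLLM_MOE_A2A_WORKSPACE_MB, defaulting to 2048 MB. A minimal usage sketch, assuming the variable is set in the launching environment before the fused MoE module is constructed:

    import os

    # Give each rank a 4096 MB one-sided A2A workspace instead of the 2048 MB default.
    # Must be set before __init__ runs, since the value is read once at construction time.
    os.environ["TRTLLM_MOE_A2A_WORKSPACE_MB"] = "4096"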