From 360f36dc3356294d2ce2fe3597f63b3ea8881182 Mon Sep 17 00:00:00 2001 From: xxi Date: Mon, 25 Aug 2025 08:15:34 +0000 Subject: [PATCH] debug Signed-off-by: xxi modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py modified: tensorrt_llm/_torch/modules/fused_moe/moe_backend.py modified: tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py modified: tensorrt_llm/_torch/distributed/ops.py modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py modified: tensorrt_llm/_torch/modules/fused_moe/moe_backend.py modified: tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py --- tensorrt_llm/_torch/distributed/ops.py | 5 ++ .../modules/fused_moe/fused_moe_wide_ep.py | 89 ++++++++++++------- .../_torch/modules/fused_moe/moe_backend.py | 35 ++++++-- .../modules/fused_moe/moe_load_balancer.py | 3 + 4 files changed, 97 insertions(+), 35 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 6bd0fcd6ebc..1b915b3ddbb 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -240,6 +240,11 @@ def reducescatter( if isinstance(input, torch.Tensor): assert input.shape[dim] == sum_split_size else: + for val in input: + if val is not None and val.shape[dim] != sum_split_size: + print( + f"[reducescatter] val.shape={val.shape}, dim={dim}, val.shape[dim]={val.shape[dim]}, sum_split_size={sum_split_size}, sizes={sizes}" + ) assert all([ val.shape[dim] == sum_split_size for val in input if val is not None diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index c43f0491a6d..fb245764282 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -308,13 +308,20 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int: def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens): # Disable alltoall when chunking is used if self.calculate_num_chunks(all_rank_num_tokens) > 1: + print( + f"can not use alltoall due to chunking {self.calculate_num_chunks(all_rank_num_tokens)}" + ) return False # For DeepEPLowLatency, check if tokens exceed the threshold if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency and all_rank_max_num_tokens > self.deep_ep_max_num_tokens): + print( + f"can not use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}" + ) return False + print(f"all to all type {self.alltoall_method_type}") return self.enable_alltoall def _get_quant_method(self): @@ -323,9 +330,18 @@ def _get_quant_method(self): if self.quant_config.layer_quant_mode.has_fp8_qdq(): return FP8QDQFusedMoEMethod() elif self.quant_config.layer_quant_mode.has_fp8_block_scales(): + print( + f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}" + ) if get_sm_version() == 100: + print( + f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm" + ) return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm() else: + print( + f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod" + ) return DeepSeekFP8BlockScalesFusedMoEMethod() elif self.quant_config.layer_quant_mode.has_nvfp4(): return NVFP4CutlassFusedMoEMethod() @@ -399,6 +415,10 @@ def forward_chunk( is_first_call, is_last_call = repeating_info + print( + f"wide_ep forward_chunk: layer_load_balancer={self.layer_load_balancer}, is_first_call={is_first_call}, is_last_call={is_last_call}, x shape: {getattr(x, 'shape', None)}, router_logits shape: {getattr(router_logits, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, repeating_info: {repeating_info}" + ) + if self.layer_load_balancer and is_first_call: self.layer_load_balancer.start_wait_gpu_stage() @@ -475,7 +495,7 @@ def forward_chunk( self.dummy_allreduce() token_count = x.shape[0] alltoall_info = None - if is_last_call: + if self.layer_load_balancer and is_last_call: loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor( ) else: @@ -650,7 +670,35 @@ def forward_chunk( ) # Original fused_moe call (preserved as reference) - final_hidden_states = torch.ops.trtllm.fused_moe( + # final_hidden_states = torch.ops.trtllm.fused_moe( + # x, + # token_selected_slots, + # token_final_scales, + # w3_w1_weight.view(weight_dtype), + # None, # w3_w1_bias + # w2_weight.view(weight_dtype), + # None, # w2_bias + # output_dtype, + # quant_scales=quant_scales, + # input_sf=x_sf, + # swizzled_input_sf=False, + # tp_size=self.tp_size, + # tp_rank=self.tp_rank, + # ep_size=ep_size, + # ep_rank=ep_rank, + # cluster_size=cluster_size, + # cluster_rank=cluster_rank, + # enable_alltoall=use_all_to_all, + # use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, + # use_w4_group_scaling=use_w4_group_scaling, + # min_latency_mode=False, + # tune_max_num_tokens=self.tune_max_num_tokens, + # tuner_num_tokens=tuner_num_tokens, + # tuner_top_k=tuner_top_k, + # ) + + # Use the selected backend to compute MoE with the same parameters as fused_moe + final_hidden_states = self.moe_backend.run_moe( x, token_selected_slots, token_final_scales, @@ -675,36 +723,13 @@ def forward_chunk( tune_max_num_tokens=self.tune_max_num_tokens, tuner_num_tokens=tuner_num_tokens, tuner_top_k=tuner_top_k, + module= + self, # Additional parameter for backend to access module properties ) - # Use the selected backend to compute MoE with the same parameters as fused_moe - # final_hidden_states = self.moe_backend.run_moe( - # x, - # token_selected_slots, - # token_final_scales, - # w3_w1_weight.view(weight_dtype), - # None, # w3_w1_bias - # w2_weight.view(weight_dtype), - # None, # w2_bias - # output_dtype, - # quant_scales=quant_scales, - # input_sf=x_sf, - # swizzled_input_sf=False, - # tp_size=self.tp_size, - # tp_rank=self.tp_rank, - # ep_size=ep_size, - # ep_rank=ep_rank, - # cluster_size=cluster_size, - # cluster_rank=cluster_rank, - # enable_alltoall=use_all_to_all, - # use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - # use_w4_group_scaling=use_w4_group_scaling, - # min_latency_mode=False, - # tune_max_num_tokens=self.tune_max_num_tokens, - # tuner_num_tokens=tuner_num_tokens, - # tuner_top_k=tuner_top_k, - # module=self, # Additional parameter for backend to access module properties - # ) + print( + f"xxi x.shape: {getattr(x, 'shape', None)}, final_hidden_states shape: {getattr(final_hidden_states[0], 'shape', None)}, token_selected_slots shape: {getattr(token_selected_slots, 'shape', None)}, token_final_scales shape: {getattr(token_final_scales, 'shape', None)}, w3_w1_weight shape: {getattr(w3_w1_weight, 'shape', None)}, w2_weight shape: {getattr(w2_weight, 'shape', None)}, quant_scales: {getattr(quant_scales, 'shape', None)}, input_sf: {getattr(x_sf, 'shape', None)}, swizzled_input_sf: False, tp_size: {self.tp_size}, tp_rank: {self.tp_rank}, ep_size: {ep_size}, ep_rank: {ep_rank}, cluster_size: {cluster_size}, cluster_rank: {cluster_rank}, enable_alltoall: {use_all_to_all}, use_deepseek_fp8_block_scale: {use_deepseek_fp8_block_scale}, use_w4_group_scaling: {use_w4_group_scaling}, min_latency_mode: False, tune_max_num_tokens: {self.tune_max_num_tokens}, tuner_num_tokens: {tuner_num_tokens}, tuner_top_k: {tuner_top_k}" + ) if self.layer_load_balancer and is_last_call: self.layer_load_balancer.start_set_cpu_stage() @@ -784,6 +809,10 @@ def forward( all_rank_max_num_tokens=all_rank_max_num_tokens, use_dp_padding=use_dp_padding, repeating_info=(is_first_call, is_last_call)) + # 一行打印所有信息 + print( + f"xxi x.shape: {getattr(x, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_num_tokens_padded: {all_rank_num_tokens_padded}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, outputs.shape: {getattr(outputs, 'shape', None)}, use_dp_padding(again): {use_dp_padding}" + ) outputs = self.reducescatter_or_allreduce( outputs, use_all_to_all, diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_backend.py b/tensorrt_llm/_torch/modules/fused_moe/moe_backend.py index c204489e6fe..eb54659a4d0 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/moe_backend.py +++ b/tensorrt_llm/_torch/modules/fused_moe/moe_backend.py @@ -98,7 +98,6 @@ def compute_moe( Computed MoE output tensor """ - @abstractmethod def run_moe( self, # Positional arguments (same order as torch.ops.trtllm.fused_moe) @@ -542,10 +541,11 @@ def __init__(self): super().__init__() # Import DeepGemm specific functions import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils - from tensorrt_llm import deep_gemm - self.deep_gemm = deep_gemm self.fp8_utils = fp8_utils + from .fused_moe_deepgemm import deepgemm_fp8_group_blockwise_gemm + self.deepgemm_fp8_group_blockwise_gemm = deepgemm_fp8_group_blockwise_gemm + def finalize_tactic( self, module: Any, @@ -664,6 +664,7 @@ def compute_moe( Note: This assumes the data has already been gathered/alltoall'd by the WideEP forward_chunk method. """ + # Import necessary functions for DeepGemm from .fused_moe_deepgemm import (masked_index_copy_group_quant_fp8, preprocess_after_permute, set_strides, @@ -711,6 +712,20 @@ def compute_moe( use_fp8_block_scaling=True, # Always use block scaling for DeepGemm ) + print( + "enter deepgemm backend compute_moe \n" + f"x.shape: {getattr(x, 'shape', None)}, \n" + f"input_sf.shape: {getattr(input_sf, 'shape', None)}, \n" + f"token_selected_slots.shape: {getattr(token_selected_slots, 'shape', None)}, \n" + f"token_final_scales.shape: {getattr(token_final_scales, 'shape', None)}, \n" + f"permuted_row_to_unpermuted_row_tensor.shape: {getattr(permuted_row_to_unpermuted_row_tensor, 'shape', None)}, \n" + f"permuted_token_selected_experts_tensor.shape: {getattr(permuted_token_selected_experts_tensor, 'shape', None)}, \n" + f"permuted_data_tensor.shape: {getattr(permuted_data_tensor, 'shape', None)}, \n" + f"expert_first_token_offset_tensor.shape: {getattr(expert_first_token_offset_tensor, 'shape', None)}, \n" + f"permuted_token_final_scales_tensor.shape: {getattr(permuted_token_final_scales_tensor, 'shape', None)}, \n" + f"unpermuted_row_to_permuted_row_tensor.shape: {getattr(unpermuted_row_to_permuted_row_tensor, 'shape', None)}\n" + ) + if permuted_data_tensor.numel() == 0: return torch.zeros_like(x) @@ -750,7 +765,7 @@ def compute_moe( h1 = set_strides(workspace["workspace_1"], expert_size_per_partition, m_max, intermediate_size * 2) - self.deep_gemm.deepgemm_fp8_group_blockwise_gemm( + self.deepgemm_fp8_group_blockwise_gemm( d=h1, a=act_input_fp8, b=w3_w1_weight, @@ -783,7 +798,7 @@ def compute_moe( h3 = set_strides(workspace["workspace_1"], expert_size_per_partition, m_max, hidden_size) - self.deep_gemm.deepgemm_fp8_group_blockwise_gemm( + self.deepgemm_fp8_group_blockwise_gemm( d=h3, a=act_input_fp8, b=w2_weight, @@ -817,6 +832,16 @@ def compute_moe( ep_size, ep_rank, ) + print( + "exit deepgemm backend compute_moe \n" + f"permuted_data_tensor.shape: {getattr(permuted_data_tensor, 'shape', None)}, " + f"token_final_scales.shape: {getattr(token_final_scales, 'shape', None)}, " + f"unpermuted_row_to_permuted_row_tensor.shape: {getattr(unpermuted_row_to_permuted_row_tensor, 'shape', None)}, " + f"permuted_row_to_unpermuted_row_tensor.shape: {getattr(permuted_row_to_unpermuted_row_tensor, 'shape', None)}, " + f"expert_first_token_offset_tensor.shape: {getattr(expert_first_token_offset_tensor, 'shape', None)}, " + f"x.shape: {getattr(x, 'shape', None)}, " + f"final_hidden_states.shape: {getattr(final_hidden_states, 'shape', None)}" + ) return final_hidden_states diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py index ff26c87687a..8cb51ea55b2 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py +++ b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py @@ -960,6 +960,9 @@ def maybe_create_moe_load_balancer( in_supported_model_arch = model_arch in moe_model_arch_list using_smart_router = mapping and mapping.moe_cluster_size > 1 moe_load_balancer = nullcontext() + print( + f"maybe_create_moe_load_balancer: in_supported_model_arch={in_supported_model_arch}, using_ep={using_ep}, using_smart_router={using_smart_router}, model_config.moe_load_balancer={model_config.moe_load_balancer}" + ) if in_supported_model_arch and using_ep and not using_smart_router and model_config.moe_load_balancer is not None: model_config.moe_load_balancer.setup(ep_rank=ep_rank, ep_size=ep_size) if model_config.moe_load_balancer.layer_updates_per_iter > 0: