From 58c105e34c120649d8fd7ba5681e638249095b23 Mon Sep 17 00:00:00 2001
From: larkzhang-nv
Date: Thu, 20 Nov 2025 20:05:31 -0800
Subject: [PATCH 1/2] update patch for vllm0.11 moe fp8 rollout

---
 verl/utils/vllm/vllm_fp8_utils.py | 73 ++++++++++++++++++++++++++++++-
 1 file changed, 71 insertions(+), 2 deletions(-)

diff --git a/verl/utils/vllm/vllm_fp8_utils.py b/verl/utils/vllm/vllm_fp8_utils.py
index eab4e4d8d77..a35779a9477 100644
--- a/verl/utils/vllm/vllm_fp8_utils.py
+++ b/verl/utils/vllm/vllm_fp8_utils.py
@@ -325,7 +325,8 @@ def _create_param_from_subclass_attributes(custom_param):
     maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
 
 
-def process_weights_after_loading_moe(self, layer) -> None:
+def process_weights_after_loading_moe_for_vllm10(self, layer) -> None:
+    """This function is used to process the weights after loading for a FusedMoE layer, it is used for vllm v0.10"""
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
     from vllm.model_executor.layers.quantization.fp8 import _is_col_major, _swap_w13_to_w31
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -400,6 +401,69 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight):
         layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
 
 
+def process_weights_after_loading_moe_for_vllm11(self, layer) -> None:
+    """This function is used to process the weights after loading for a FusedMoE layer, it is used for vllm 0.11"""
+    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+        is_rocm_aiter_moe_enabled, shuffle_weights)
+    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+        swap_w13_to_w31,
+    )
+    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+        expert_weight_is_col_major,
+        requant_weight_ue8m0_inplace,
+    )
+    from vllm.utils.deep_gemm import (
+        get_col_major_tma_aligned_tensor,
+        is_deep_gemm_e8m0_used,
+    )
+
+    self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+
+    assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
+    assert self.quant_config.activation_scheme == "dynamic"
+
+    if self.flashinfer_moe_backend is not None:
+        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
+        layer.w13_weight_scale_inv.data = swap_w13_to_w31(
+            layer.w13_weight_scale_inv.data
+        )
+
+    if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
+        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
+            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
+                layer.w13_weight_scale_inv
+            )
+        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
+            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
+                layer.w2_weight_scale_inv
+            )
+
+    if is_deep_gemm_e8m0_used():
+        assert layer.weight_block_size is not None
+        # Re-quantise the expert weights so their scales are UE8M0.
+        block_sz = tuple(layer.weight_block_size)
+        requant_weight_ue8m0_inplace(
+            layer.w13_weight.data,
+            layer.w13_weight_scale_inv.data,
+            block_sz,
+        )
+        requant_weight_ue8m0_inplace(
+            layer.w2_weight.data,
+            layer.w2_weight_scale_inv.data,
+            block_sz,
+        )
+
+        # Ensure column-major TMA alignment expected by DeepGEMM.
+        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
+            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
+                layer.w13_weight_scale_inv
+            )
+        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
+            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
+                layer.w2_weight_scale_inv
+            )
+
+
 def apply_vllm_fp8_patches():
     logger.info("Applying vllm fp8 patches for blockwise quantization")
     func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading"
@@ -411,5 +475,10 @@ def apply_vllm_fp8_patches():
     )
     patcher1.start()
     func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading"
-    patcher2 = patch(func2_path, process_weights_after_loading_moe)
+    patcher2 = patch(
+        func2_path,
+        process_weights_after_loading_moe_for_vllm11
+        if vllm.__version__ >= "0.11.0"
+        else process_weights_after_loading_moe_for_vllm10,
+    )
     patcher2.start()

From 918d2e46e164297a6eeffeb17194796cf9f8297f Mon Sep 17 00:00:00 2001
From: larkz-nv
Date: Thu, 20 Nov 2025 21:18:29 -0800
Subject: [PATCH 2/2] pass the pre-commit

---
 verl/utils/vllm/vllm_fp8_utils.py | 29 +++++++++--------------------
 1 file changed, 9 insertions(+), 20 deletions(-)

diff --git a/verl/utils/vllm/vllm_fp8_utils.py b/verl/utils/vllm/vllm_fp8_utils.py
index a35779a9477..c7e587da244 100644
--- a/verl/utils/vllm/vllm_fp8_utils.py
+++ b/verl/utils/vllm/vllm_fp8_utils.py
@@ -403,8 +403,7 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight):
 
 def process_weights_after_loading_moe_for_vllm11(self, layer) -> None:
     """This function is used to process the weights after loading for a FusedMoE layer, it is used for vllm 0.11"""
-    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-        is_rocm_aiter_moe_enabled, shuffle_weights)
+    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
     from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
         swap_w13_to_w31,
     )
@@ -424,19 +423,13 @@ def process_weights_after_loading_moe_for_vllm11(self, layer) -> None:
 
     if self.flashinfer_moe_backend is not None:
         layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
-        layer.w13_weight_scale_inv.data = swap_w13_to_w31(
-            layer.w13_weight_scale_inv.data
-        )
+        layer.w13_weight_scale_inv.data = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
 
     if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
         if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w13_weight_scale_inv
-            )
+            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
         if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w2_weight_scale_inv
-            )
+            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
 
     if is_deep_gemm_e8m0_used():
         assert layer.weight_block_size is not None
@@ -455,13 +448,9 @@ def process_weights_after_loading_moe_for_vllm11(self, layer) -> None:
 
         # Ensure column-major TMA alignment expected by DeepGEMM.
         if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w13_weight_scale_inv
-            )
+            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
         if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w2_weight_scale_inv
-            )
+            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
 
 
 def apply_vllm_fp8_patches():
@@ -476,9 +465,9 @@ def apply_vllm_fp8_patches():
     patcher1.start()
     func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading"
     patcher2 = patch(
-        func2_path, 
-        process_weights_after_loading_moe_for_vllm11 
-        if vllm.__version__ >= "0.11.0" 
+        func2_path,
+        process_weights_after_loading_moe_for_vllm11
+        if vllm.__version__ >= "0.11.0"
         else process_weights_after_loading_moe_for_vllm10,
     )
     patcher2.start()
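Below is a minimal sketch, separate from the patches above, of how the version gate in apply_vllm_fp8_patches chooses between the two MoE weight-processing hooks. The helper name select_moe_weight_processor is hypothetical (it does not exist in verl or vllm), and the sketch assumes the packaging library is installed; it uses packaging.version.Version because the raw string comparison in the patch is lexicographic (for example, "0.9.0" >= "0.11.0" evaluates to True as strings), which matters only if vllm versions older than 0.10 need to be distinguished.

# Hypothetical sketch, not part of the patch: version-gated selection of the
# MoE weight-processing hook, using packaging.version for a numeric comparison.
from packaging.version import Version

def select_moe_weight_processor(vllm_version: str) -> str:
    # Return the name of the hook that apply_vllm_fp8_patches() would install
    # for the given vllm version.
    if Version(vllm_version) >= Version("0.11.0"):
        return "process_weights_after_loading_moe_for_vllm11"
    return "process_weights_after_loading_moe_for_vllm10"

print(select_moe_weight_processor("0.11.0"))  # process_weights_after_loading_moe_for_vllm11
print(select_moe_weight_processor("0.10.2"))  # process_weights_after_loading_moe_for_vllm10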