diff --git a/verl/utils/vllm/vllm_fp8_utils.py b/verl/utils/vllm/vllm_fp8_utils.py
index eab4e4d8d77..c7e587da244 100644
--- a/verl/utils/vllm/vllm_fp8_utils.py
+++ b/verl/utils/vllm/vllm_fp8_utils.py
@@ -325,7 +325,8 @@ def _create_param_from_subclass_attributes(custom_param):
     maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
 
 
-def process_weights_after_loading_moe(self, layer) -> None:
+def process_weights_after_loading_moe_for_vllm10(self, layer) -> None:
+    """Process the weights of a FusedMoE layer after loading (vLLM 0.10 variant)."""
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
     from vllm.model_executor.layers.quantization.fp8 import _is_col_major, _swap_w13_to_w31
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -400,6 +401,86 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight):
         layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
 
 
+def process_weights_after_loading_moe_for_vllm11(self, layer) -> None:
+    """Process the weights of a FusedMoE layer after loading (vLLM 0.11 variant).
+
+    Installed by ``apply_vllm_fp8_patches`` in place of
+    ``Fp8MoEMethod.process_weights_after_loading``. Only block-quantized,
+    FP8-serialized checkpoints with the "dynamic" activation scheme are
+    supported; anything else fails the assertions below.
+
+    Args:
+        self: The ``Fp8MoEMethod`` instance the patched method is bound to.
+        layer: The layer whose ``w13``/``w2`` weights and ``*_weight_scale_inv``
+            tensors are updated in place.
+    """
+    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
+    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+        swap_w13_to_w31,
+    )
+    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+        expert_weight_is_col_major,
+        requant_weight_ue8m0_inplace,
+    )
+    from vllm.utils.deep_gemm import (
+        get_col_major_tma_aligned_tensor,
+        is_deep_gemm_e8m0_used,
+    )
+
+    self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+
+    assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
+    assert self.quant_config.activation_scheme == "dynamic"
+
+    if self.flashinfer_moe_backend is not None:
+        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
+        layer.w13_weight_scale_inv.data = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
+
+    if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
+        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
+            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
+        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
+            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
+
+    if is_deep_gemm_e8m0_used():
+        assert layer.weight_block_size is not None
+        # Re-quantise the expert weights so their scales are UE8M0.
+        block_sz = tuple(layer.weight_block_size)
+        requant_weight_ue8m0_inplace(
+            layer.w13_weight.data,
+            layer.w13_weight_scale_inv.data,
+            block_sz,
+        )
+        requant_weight_ue8m0_inplace(
+            layer.w2_weight.data,
+            layer.w2_weight_scale_inv.data,
+            block_sz,
+        )
+
+        # Ensure column-major TMA alignment expected by DeepGEMM.
+        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
+            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
+        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
+            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
+
+
+def _vllm_version_at_least(major: int, minor: int) -> bool:
+    """Return whether the installed vLLM is at least version ``major.minor``.
+
+    Only the leading ``major.minor`` of ``vllm.__version__`` is compared, and
+    numerically: as plain strings, "0.9.2" >= "0.11.0" is True. A version
+    string without such a numeric prefix is treated as satisfying the bound.
+    """
+    import re
+
+    import vllm
+
+    parsed = re.match(r"(\d+)\.(\d+)", vllm.__version__)
+    if parsed is None:
+        return True
+    return (int(parsed.group(1)), int(parsed.group(2))) >= (major, minor)
+
+
 def apply_vllm_fp8_patches():
     logger.info("Applying vllm fp8 patches for blockwise quantization")
     func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading"
@@ -411,5 +492,11 @@ def apply_vllm_fp8_patches():
     )
     patcher1.start()
     func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading"
-    patcher2 = patch(func2_path, process_weights_after_loading_moe)
+    # Install the MoE replacement written for the installed vLLM release.
+    patcher2 = patch(
+        func2_path,
+        process_weights_after_loading_moe_for_vllm11
+        if _vllm_version_at_least(0, 11)
+        else process_weights_after_loading_moe_for_vllm10,
+    )
     patcher2.start()