[rollout, vllm] feat: support blockwise FP8 rollout for vLLM v0.11 MoE RL #4222
```diff
@@ -325,7 +325,8 @@ def _create_param_from_subclass_attributes(custom_param):
         maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)


-def process_weights_after_loading_moe(self, layer) -> None:
+def process_weights_after_loading_moe_for_vllm10(self, layer) -> None:
+    """Process the weights of a FusedMoE layer after loading; used for vLLM v0.10."""
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
     from vllm.model_executor.layers.quantization.fp8 import _is_col_major, _swap_w13_to_w31
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
```
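Background that may help when reading the `*_weight_scale_inv` tensors in the next hunk: in blockwise FP8 quantization, a weight of shape `[N, K]` is split into `(block_n, block_k)` tiles, and each tile is quantized to `float8_e4m3fn` with its own dequantization factor. Below is a self-contained sketch of that layout (illustrative PyTorch only, not verl or vLLM code; the function name is invented):

```python
import torch

def quantize_blockwise_fp8(weight: torch.Tensor, block_n: int = 128, block_k: int = 128):
    """Quantize a 2-D weight tile-by-tile; returns (fp8 weight, per-tile scales)."""
    n, k = weight.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    q = torch.empty_like(weight, dtype=torch.float8_e4m3fn)
    # One dequantization factor per (block_n x block_k) tile, mirroring the
    # role of w13_weight_scale_inv / w2_weight_scale_inv in the diff below.
    scale_inv = torch.empty(
        (n + block_n - 1) // block_n,
        (k + block_k - 1) // block_k,
        dtype=torch.float32,
    )
    for i in range(scale_inv.shape[0]):
        for j in range(scale_inv.shape[1]):
            rows = slice(i * block_n, min((i + 1) * block_n, n))
            cols = slice(j * block_k, min((j + 1) * block_k, k))
            tile = weight[rows, cols].float()
            scale = tile.abs().amax().clamp(min=1e-12) / fp8_max
            q[rows, cols] = (tile / scale).to(torch.float8_e4m3fn)
            scale_inv[i, j] = scale  # dequant: q[rows, cols].float() * scale
    return q, scale_inv
```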
```diff
@@ -400,6 +401,58 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight):
             layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()


+def process_weights_after_loading_moe_for_vllm11(self, layer) -> None:
+    """Process the weights of a FusedMoE layer after loading; used for vLLM v0.11."""
+    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
+    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+        swap_w13_to_w31,
+    )
+    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+        expert_weight_is_col_major,
+        requant_weight_ue8m0_inplace,
+    )
+    from vllm.utils.deep_gemm import (
+        get_col_major_tma_aligned_tensor,
+        is_deep_gemm_e8m0_used,
+    )
+
+    self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+
+    assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
+    assert self.quant_config.activation_scheme == "dynamic"
+
+    if self.flashinfer_moe_backend is not None:
+        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
+        layer.w13_weight_scale_inv.data = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
+
+    if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
+        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
+            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
+        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
+            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
+
+    if is_deep_gemm_e8m0_used():
+        assert layer.weight_block_size is not None
+        # Re-quantise the expert weights so their scales are UE8M0.
+        block_sz = tuple(layer.weight_block_size)
+        requant_weight_ue8m0_inplace(
+            layer.w13_weight.data,
+            layer.w13_weight_scale_inv.data,
+            block_sz,
+        )
+        requant_weight_ue8m0_inplace(
+            layer.w2_weight.data,
+            layer.w2_weight_scale_inv.data,
+            block_sz,
+        )
+
+        # Ensure column-major TMA alignment expected by DeepGEMM.
+        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
+            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
+        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
+            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
+
+
 def apply_vllm_fp8_patches():
     logger.info("Applying vllm fp8 patches for blockwise quantization")
     func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading"
```
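The patching entry point at the end of this hunk relies on `unittest.mock.patch` used as a long-lived monkey patch rather than a test fixture. A minimal, runnable illustration of that mechanism (the `Greeter` class and method names are invented for the example):

```python
from unittest.mock import patch

class Greeter:
    def greet(self) -> str:
        return "hello"

def patched_greet(self) -> str:
    return "hello, fp8"

# patch(...) replaces the attribute when .start() is called and restores it on
# .stop(); starting without ever stopping turns it into a process-wide monkey
# patch. The PR does the same thing with the dotted-string form, e.g.
# patch("pkg.module.Class.method", replacement).start().
patcher = patch.object(Greeter, "greet", patched_greet)
patcher.start()
assert Greeter().greet() == "hello, fp8"
```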
```diff
@@ -411,5 +464,10 @@ def apply_vllm_fp8_patches():
     )
     patcher1.start()
     func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading"
-    patcher2 = patch(func2_path, process_weights_after_loading_moe)
+    patcher2 = patch(
+        func2_path,
+        process_weights_after_loading_moe_for_vllm11
+        if vllm.__version__ >= "0.11.0"
+        else process_weights_after_loading_moe_for_vllm10,
+    )
     patcher2.start()
```

Contributor: Comparing version strings directly using operators like `>=` relies on lexicographic ordering and can give wrong results (for example, `"0.9.0" >= "0.11.0"` evaluates to `True`). Consider parsing the versions, e.g. with `packaging.version`, before comparing.
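A minimal sketch of the comparison the reviewer is pointing at, assuming `packaging` is importable in the environment; the two function names are the ones introduced in this PR:

```python
import vllm
from packaging.version import Version

# Plain string comparison is lexicographic, so multi-digit parts mis-order:
assert "0.9.0" >= "0.11.0"                        # True: '9' > '1' as characters
assert not Version("0.9.0") >= Version("0.11.0")  # numeric, segment by segment

moe_patch_fn = (
    process_weights_after_loading_moe_for_vllm11
    if Version(vllm.__version__) >= Version("0.11.0")
    else process_weights_after_loading_moe_for_vllm10
)
```

`Version` also parses dev and release-candidate suffixes (e.g. `0.11.0rc1`), which plain string comparison does not order sensibly.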
Contributor: There is duplicated code for aligning DeepGEMM scales; the same logic appears in two separate `if` blocks. To improve maintainability and reduce redundancy, you could extract it into a nested helper function.
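A sketch of the suggested extraction (the helper name `_align_scales_for_deep_gemm` is hypothetical; both callees are already imported in the new function):

```python
from vllm.model_executor.layers.quantization.utils.fp8_utils import expert_weight_is_col_major
from vllm.utils.deep_gemm import get_col_major_tma_aligned_tensor

def _align_scales_for_deep_gemm(layer) -> None:
    # Convert both MoE scale tensors to the column-major TMA layout that
    # DeepGEMM expects, skipping tensors that are already laid out correctly.
    if expert_weight_is_col_major(layer.w13_weight_scale_inv):
        layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
    if expert_weight_is_col_major(layer.w2_weight_scale_inv):
        layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
```

Both the `not is_deep_gemm_e8m0_used()` branch and the UE8M0 branch of `process_weights_after_loading_moe_for_vllm11` would then call `_align_scales_for_deep_gemm(layer)` instead of repeating the two `if` statements.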