Skip to content

Commit 417e139

Browse files
authored
[rollout, vllm] feat: support blockwise FP8 rollout for vLLM v0.11 MoE RL (#4222)
### What does this PR do? This PR enables support for **blockwise FP8 rollout** for **MoE** models using **vLLM v0.11.0**. **Relationship to previous work:** This is a follow-up to #3519. Please refer to that PR for the full support matrix, detailed usage instructions, experimental results, and other related context. **Implementation Details:** To support FP8 MoE RL with vLLM v0.11.0, this PR applies a monkey patch to the vLLM MoE model method: `vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading`. This modification allows the system to correctly handle model weight loading after quantization. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: #3519 - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 45bff38 commit 417e139

File tree

1 file changed

+60
-2
lines changed

1 file changed

+60
-2
lines changed

verl/utils/vllm/vllm_fp8_utils.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ def _create_param_from_subclass_attributes(custom_param):
325325
maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
326326

327327

328-
def process_weights_after_loading_moe(self, layer) -> None:
328+
def process_weights_after_loading_moe_for_vllm10(self, layer) -> None:
329+
"""This function is used to process the weights after loading for a FusedMoE layer, it is used for vllm v0.10"""
329330
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
330331
from vllm.model_executor.layers.quantization.fp8 import _is_col_major, _swap_w13_to_w31
331332
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -400,6 +401,58 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight):
400401
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
401402

402403

404+
def process_weights_after_loading_moe_for_vllm11(self, layer) -> None:
405+
"""This function is used to process the weights after loading for a FusedMoE layer, it is used for vllm 0.11"""
406+
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
407+
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
408+
swap_w13_to_w31,
409+
)
410+
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
411+
expert_weight_is_col_major,
412+
requant_weight_ue8m0_inplace,
413+
)
414+
from vllm.utils.deep_gemm import (
415+
get_col_major_tma_aligned_tensor,
416+
is_deep_gemm_e8m0_used,
417+
)
418+
419+
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
420+
421+
assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
422+
assert self.quant_config.activation_scheme == "dynamic"
423+
424+
if self.flashinfer_moe_backend is not None:
425+
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
426+
layer.w13_weight_scale_inv.data = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
427+
428+
if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
429+
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
430+
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
431+
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
432+
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
433+
434+
if is_deep_gemm_e8m0_used():
435+
assert layer.weight_block_size is not None
436+
# Re-quantise the expert weights so their scales are UE8M0.
437+
block_sz = tuple(layer.weight_block_size)
438+
requant_weight_ue8m0_inplace(
439+
layer.w13_weight.data,
440+
layer.w13_weight_scale_inv.data,
441+
block_sz,
442+
)
443+
requant_weight_ue8m0_inplace(
444+
layer.w2_weight.data,
445+
layer.w2_weight_scale_inv.data,
446+
block_sz,
447+
)
448+
449+
# Ensure column-major TMA alignment expected by DeepGEMM.
450+
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
451+
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
452+
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
453+
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
454+
455+
403456
def apply_vllm_fp8_patches():
404457
logger.info("Applying vllm fp8 patches for blockwise quantization")
405458
func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading"
@@ -411,5 +464,10 @@ def apply_vllm_fp8_patches():
411464
)
412465
patcher1.start()
413466
func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading"
414-
patcher2 = patch(func2_path, process_weights_after_loading_moe)
467+
patcher2 = patch(
468+
func2_path,
469+
process_weights_after_loading_moe_for_vllm11
470+
if vllm.__version__ >= "0.11.0"
471+
else process_weights_after_loading_moe_for_vllm10,
472+
)
415473
patcher2.start()

0 commit comments

Comments
 (0)