Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 60 additions & 2 deletions verl/utils/vllm/vllm_fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ def _create_param_from_subclass_attributes(custom_param):
maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)


def process_weights_after_loading_moe(self, layer) -> None:
def process_weights_after_loading_moe_for_vllm10(self, layer) -> None:
"""This function is used to process the weights after loading for a FusedMoE layer, it is used for vllm v0.10"""
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
from vllm.model_executor.layers.quantization.fp8 import _is_col_major, _swap_w13_to_w31
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
Expand Down Expand Up @@ -400,6 +401,58 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()


def process_weights_after_loading_moe_for_vllm11(self, layer) -> None:
    """Post-process the weights of a FusedMoE layer after loading (vllm 0.11).

    Patched in place of ``Fp8MoEMethod.process_weights_after_loading`` for
    vllm >= 0.11. Handles three concerns for block-quantized fp8 MoE weights:

    - swaps w13 into w31 layout when a FlashInfer MoE backend is active;
    - re-quantizes expert weights in place so scales are UE8M0 when the
      DeepGEMM E8M0 path is used;
    - ensures scale tensors are column-major TMA-aligned as DeepGEMM expects.

    Args:
        layer: The FusedMoE layer whose weights were just loaded. Mutated
            in place (weights, scale tensors).

    Raises:
        AssertionError: If the checkpoint is not block-quantized fp8 with a
            dynamic activation scheme, or ``layer.weight_block_size`` is
            missing on the E8M0 path.
    """
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        swap_w13_to_w31,
    )
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        expert_weight_is_col_major,
        requant_weight_ue8m0_inplace,
    )
    from vllm.utils.deep_gemm import (
        get_col_major_tma_aligned_tensor,
        is_deep_gemm_e8m0_used,
    )

    def _align_scales(moe_layer):
        # DeepGEMM expects scale tensors in column-major TMA-aligned form;
        # convert each scale tensor only if it is currently column-major.
        if expert_weight_is_col_major(moe_layer.w13_weight_scale_inv):
            moe_layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(moe_layer.w13_weight_scale_inv)
        if expert_weight_is_col_major(moe_layer.w2_weight_scale_inv):
            moe_layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(moe_layer.w2_weight_scale_inv)

    self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

    # This patch only supports block-quantized fp8 checkpoints with dynamic
    # activation quantization; anything else is a configuration error.
    assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
    assert self.quant_config.activation_scheme == "dynamic"

    if self.flashinfer_moe_backend is not None:
        # FlashInfer kernels expect the w31 gate/up ordering.
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        layer.w13_weight_scale_inv.data = swap_w13_to_w31(layer.w13_weight_scale_inv.data)

    if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
        _align_scales(layer)

    if is_deep_gemm_e8m0_used():
        assert layer.weight_block_size is not None
        # Re-quantise the expert weights so their scales are UE8M0.
        block_sz = tuple(layer.weight_block_size)
        requant_weight_ue8m0_inplace(
            layer.w13_weight.data,
            layer.w13_weight_scale_inv.data,
            block_sz,
        )
        requant_weight_ue8m0_inplace(
            layer.w2_weight.data,
            layer.w2_weight_scale_inv.data,
            block_sz,
        )

        # Ensure column-major TMA alignment expected by DeepGEMM.
        _align_scales(layer)
Comment on lines +428 to +453
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There's duplicated code for aligning DeepGEMM scales. This logic appears in two separate if blocks. To improve maintainability and reduce redundancy, you can extract this logic into a nested helper function.

    def _align_scales(l):
        if expert_weight_is_col_major(l.w13_weight_scale_inv):
            l.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(l.w13_weight_scale_inv)
        if expert_weight_is_col_major(l.w2_weight_scale_inv):
            l.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(l.w2_weight_scale_inv)

    if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
        _align_scales(layer)

    if is_deep_gemm_e8m0_used():
        assert layer.weight_block_size is not None
        # Re-quantise the expert weights so their scales are UE8M0.
        block_sz = tuple(layer.weight_block_size)
        requant_weight_ue8m0_inplace(
            layer.w13_weight.data,
            layer.w13_weight_scale_inv.data,
            block_sz,
        )
        requant_weight_ue8m0_inplace(
            layer.w2_weight.data,
            layer.w2_weight_scale_inv.data,
            block_sz,
        )

        # Ensure column-major TMA alignment expected by DeepGEMM.
        _align_scales(layer)



def apply_vllm_fp8_patches():
logger.info("Applying vllm fp8 patches for blockwise quantization")
func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading"
Expand All @@ -411,5 +464,10 @@ def apply_vllm_fp8_patches():
)
patcher1.start()
func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading"
patcher2 = patch(func2_path, process_weights_after_loading_moe)
patcher2 = patch(
func2_path,
process_weights_after_loading_moe_for_vllm11
if vllm.__version__ >= "0.11.0"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Comparing version strings directly using operators like >= can lead to incorrect results for some versioning schemes (e.g., '0.9.0' is lexicographically greater than '0.11.0'). For robust version comparison, it's recommended to use a dedicated library like packaging.version. This will prevent potential bugs if vllm releases versions like 0.12.0 or 1.0.0. Note that a similar issue exists on line 462 for patcher1.

Suggested change
if vllm.__version__ >= "0.11.0"
if __import__("packaging").version.parse(vllm.__version__) >= __import__("packaging").version.parse("0.11.0")

else process_weights_after_loading_moe_for_vllm10,
)
patcher2.start()
Loading