[vllm_fp8_utils] fix: preserve vLLM>=0.12 FP8 Marlin workspace init#5153

Open
JohnConnor123 wants to merge 6 commits into verl-project:main from JohnConnor123:bugfix/vllm_fp8_workspace

Conversation

@JohnConnor123
Contributor

Summary

  • Fix a vLLM FP8 (Marlin) crash in verl: a missing workspace attribute, seen as an AttributeError during engine init / profiling.
  • Root cause: verl monkey-patches vLLM's FP8 process_weights_after_loading and, for vLLM>=0.12, bypassed the vLLM-native Marlin preparation step where vLLM allocates and sets layer.workspace.
  • Fix: for vLLM>=0.12, wrap and call the original vLLM post-load method (preserving the Marlin workspace init) and then restore verl-specific custom Parameter attrs; additionally, ensure QKVParallelLinear instances always have a workspace attribute (lazy + idempotent) to avoid an early AttributeError (see the sketch below).
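
Below is a rough, self-contained sketch of this approach, not the exact verl code: the helper names, the attribute-copying logic, and the workspace size are illustrative placeholders.

import functools

import torch

def patch_fp8_post_load(fp8_linear_method_cls):
    """Wrap process_weights_after_loading so that vLLM's own Marlin preparation
    (which allocates layer.workspace) still runs, then re-attach the custom
    attributes verl stores on the layer's Parameters."""
    original = fp8_linear_method_cls.process_weights_after_loading

    @functools.wraps(original)
    def wrapped(self, layer):
        # Save any non-standard attributes attached to this layer's Parameters.
        saved = {
            name: dict(vars(param))
            for name, param in layer.named_parameters(recurse=False)
            if vars(param)
        }
        # Call the original vLLM>=0.12 method; Marlin prep sets layer.workspace here.
        original(self, layer)
        # Restore the saved attributes onto the (possibly re-created) tensors.
        for name, attrs in saved.items():
            param = getattr(layer, name, None)
            if isinstance(param, torch.Tensor):
                for key, value in attrs.items():
                    setattr(param, key, value)

    fp8_linear_method_cls.process_weights_after_loading = wrapped

def ensure_workspace(layer, device: str = "cuda") -> None:
    """Lazily and idempotently attach a workspace buffer so an early access to
    layer.workspace cannot raise AttributeError (the buffer size is illustrative)."""
    if not hasattr(layer, "workspace"):
        layer.workspace = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)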

Reproduction

The bug shows up on the Marlin FP8 path (typically on GPUs without native FP8 support). If you are on a native-FP8 GPU, you can force Marlin for testing (see below).

Minimal repro script (crashes on main, fixed on this PR)

set -euo pipefail

# Force Marlin even on native-FP8 GPUs (H100/SM90+) so the bug reproduces deterministically.
export VLLM_TEST_FORCE_FP8_MARLIN=1

python3 - <<'PY'
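# Apply verl's FP8 monkey-patches to vLLM before constructing the engine.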
from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches
apply_vllm_fp8_patches()

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-1.7B",
    tensor_parallel_size=1,
    dtype="bfloat16",
    quantization="fp8",
    trust_remote_code=True,
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
PY

Expected error (before)

AttributeError: 'QKVParallelLinear' object has no attribute 'workspace'

(Triggered from the Marlin FP8 execution path where vLLM uses workspace=layer.workspace; see vLLM v0.12.0 Fp8LinearMethod.apply:
https://raw.githubusercontent.com/vllm-project/vllm/v0.12.0/vllm/model_executor/layers/quantization/fp8.py.)
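
For illustration only, a minimal stand-in (not vLLM code; the stub class and apply_stub are made up) showing where that attribute access fails:

import torch.nn as nn

class QKVParallelLinearStub(nn.Module):
    """Stand-in for vLLM's QKVParallelLinear on which no Marlin prep has run."""

def apply_stub(layer):
    # Mirrors the Marlin call site, which passes workspace=layer.workspace.
    return {"workspace": layer.workspace}

layer = QKVParallelLinearStub()
try:
    apply_stub(layer)
except AttributeError as exc:
    print(exc)  # 'QKVParallelLinearStub' object has no attribute 'workspace'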

Notes on GPU support (native FP8 vs Marlin)

  • On GPUs without native FP8, vLLM commonly falls back to Marlin weight-only FP8, which requires a workspace buffer.
  • On GPUs with native FP8, vLLM usually uses a non-Marlin FP8 path and does not rely on layer.workspace, so this crash often does not appear.
  • You can force the Marlin path for testing on native-FP8 GPUs with VLLM_TEST_FORCE_FP8_MARLIN=1 (as in the repro script above).

Test plan

  • Run the script above on main → observe the workspace AttributeError (the script forces Marlin, so this reproduces even on native-FP8 GPUs).
  • Run the same script on this PR branch → it should complete and print a short generation.

@CLAassistant

CLAassistant commented Jan 31, 2026

CLA assistant check
All committers have signed the CLA.

wuxibin89 requested a review from Agoniii on February 2, 2026 at 05:40
@wuxibin89
Collaborator

wuxibin89 commented Feb 2, 2026

@JohnConnor123 Can you clarify what GPU you're using? I have reservations about supporting fp8 rollout on GPUs without native fp8 capability.

@JohnConnor123
Contributor Author

@wuxibin89 I use 2x RTX 3090.
