Skip to content

Commit 8543afe

Browse files
Merge branch 'main' into lucas/fast_fp8_v2
2 parents b3599c0 + 1a17045 commit 8543afe

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ Below are the combinations of the environments that we have tested on.
116116
|--|--|--|--|--|--|--|--|
117117
| `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | 12.6 | 2.43.0 | 0.8.3 | [flash-rl](https://github.com/yaof20/verl/tree/flash-rl/recipe/flash_rl) | 1.0.1 | ✅ Tested | ✅ Tested |
118118
| `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0` | 12.6 | 2.43.0 | 0.8.4 | [flash-rl](https://github.com/yaof20/verl/tree/flash-rl/recipe/flash_rl) | 1.0.1 | ✅ Tested | |
119-
| `hiyouga/verl:ngc-th2.7.0-cu12.6-vllm0.9.1` | 12.6 | 2.43.0 | 0.8.4 | [flash-rl-vllm0.9.1](https://github.com/yaof20/verl/tree/flash-rl-vllm0.9.1/recipe/flash_rl) | 1.0.2| ✅ Tested | |
119+
| `hiyouga/verl:ngc-th2.7.0-cu12.6-vllm0.9.1` | 12.6 | 2.43.0 | 0.9.1 | [flash-rl-vllm0.9.1](https://github.com/yaof20/verl/tree/flash-rl-vllm0.9.1/recipe/flash_rl) | 1.0.2| ✅ Tested | |
120+
| `hiyouga/verl:ngc-th2.7.1-cu12.6-vllm0.10.0` | 12.6 | 2.48.0 | 0.10.0 | [flash-rl-vllm0.9.1](https://github.com/yaof20/verl/tree/flash-rl-vllm0.9.1/recipe/flash_rl) | 1.0.2| ✅ Tested | |
120121

121122
## 🚧 Roadmap & Future Improvements
122123

flash_rl/vllm_patch.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import vllm
44
import torch
5+
from torch import nn
56
import types
67
import gc
78
from packaging.version import parse
@@ -35,6 +36,56 @@ def bond_method_to_cls(func, obj):
3536
'_assert_and_load',
3637
]
3738

39+
# fine-grained process-weights-after-loading implementation where we skip
# post-load processing for certain modules
def process_weights_after_loading_skippable(model: nn.Module, model_config,
                                            target_device: torch.device) -> None:
    """Run vLLM's post-load weight processing, skipping modules that don't need it.

    Mirrors vllm.model_executor.model_loader.utils.process_weights_after_loading,
    but skips the (potentially destructive/expensive) postprocess step for the
    default CUTLASS w8a8-int8 kernel when its config shows no repacking is
    required (channelwise, dynamic, symmetric input).

    Args:
        model: the loaded model whose submodules may carry quant methods.
        model_config: vLLM model config; only ``dtype`` is read (for MLA).
        target_device: device parameters must be on while quant methods
            process them (relevant when CPU offloading is enabled).
    """
    # Imported lazily so this module stays importable without a full vLLM stack.
    from vllm.attention import Attention
    from vllm.model_executor.layers.linear import QKVCrossParallelLinear
    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
    from vllm.model_executor.model_loader.utils import device_loading_context
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsLinearMethod, CompressedTensorsW8A8Int8
    from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import CutlassScaledMMLinearKernel

    for _, module in model.named_modules():
        if isinstance(module, QKVCrossParallelLinear):
            # NOTE(Isotr0py): special case for cross QKV layer because
            # q and kv proj aren't registered as submodules intentionally
            module.process_weights_after_loading()
            continue
        # NOTE(review): the original had this getattr duplicated on two
        # consecutive lines; one copy suffices.
        quant_method = getattr(module, "quant_method", None)

        # hack for skipping postprocess for default cutlass kernel
        scheme = getattr(module, "scheme", None)
        is_cutlass = (isinstance(quant_method, CompressedTensorsLinearMethod)
                      and isinstance(scheme, CompressedTensorsW8A8Int8)
                      and isinstance(scheme.kernel, CutlassScaledMMLinearKernel))
        can_skip_postprocess = False
        if is_cutlass:
            # is_cutlass already guarantees scheme is CompressedTensorsW8A8Int8,
            # so no extra assert is needed (asserts vanish under `python -O`).
            kernel = scheme.kernel
            can_skip_postprocess = (kernel.config.is_channelwise
                                    and not kernel.config.is_static_input_scheme
                                    and kernel.config.input_symmetric)

        if can_skip_postprocess:
            continue

        if isinstance(quant_method, QuantizeMethodBase):
            # When quant methods need to process weights after loading
            # (for repacking, quantizing, etc), they expect parameters
            # to be on the global target device. This scope is for the
            # case where cpu offloading is used, where we will move the
            # parameters onto device for processing and back off after.
            with device_loading_context(module, target_device):
                quant_method.process_weights_after_loading(module)

    # Currently only used by MLA.
    # NOTE: This intentionally happens after other modules so we can easily
    # decompress the weights for MLA.
    for _, module in model.named_modules():
        if isinstance(module, Attention) and \
                hasattr(module, "process_weights_after_loading"):
            # TODO(lucas): see if there is a way to unify the signatures
            # of process_weights_after_loading
            module.process_weights_after_loading(model_config.dtype)
88+
3889
def hacked_process_weights_after_loading(
3990
original_process_weights_after_loading,
4091
model,
@@ -184,6 +235,7 @@ def patch_vllm_process_weights_after_loading():
184235
loader.beforeflashrl_process_weights_after_loading = original_process_weights_after_loading
185236

186237
from functools import partial
238+
# don't skip for vllm < 0.9.1 for now
187239
loader._process_weights_after_loading = partial(hacked_process_weights_after_loading, original_process_weights_after_loading)
188240

189241
logger.debug("Successfully patched the _process_weights_after_loading function of vllm")
@@ -198,8 +250,13 @@ def patch_vllm_process_weights_after_loading():
198250
from vllm.model_executor.model_loader import utils
199251

200252
if not hasattr(utils, 'beforeflashrl_process_weights_after_loading'):
253+
254+
# use the skippable version for 0.9.1 and 0.9.2
255+
if parse(vllm.__version__) in [parse("0.9.1"), parse("0.9.2")]:
256+
original_process_weights_after_loading = process_weights_after_loading_skippable
257+
else:
258+
original_process_weights_after_loading = utils.process_weights_after_loading
201259

202-
original_process_weights_after_loading = utils.process_weights_after_loading
203260
utils.beforeflashrl_process_weights_after_loading = original_process_weights_after_loading
204261

205262
from functools import partial

0 commit comments

Comments
 (0)