
Commit 1a17045

Merge pull request #24 from SumanthRH/skyrl-changes
Upstreaming changes from SkyRL integration
2 parents 9d61357 + 398dfcd · commit 1a17045

File tree

1 file changed: +58 −1 lines changed

flash_rl/vllm_patch.py

Lines changed: 58 additions & 1 deletion
@@ -2,6 +2,7 @@
 import os
 import vllm
 import torch
+from torch import nn
 import types
 import gc
 from packaging.version import parse
@@ -35,6 +36,56 @@ def bond_method_to_cls(func, obj):
     '_assert_and_load',
 ]

+# fine-grained process_weights_after_loading implementation where we skip operations for certain modules
+def process_weights_after_loading_skippable(model: nn.Module, model_config,
+                                            target_device: torch.device) -> None:
+    from vllm.attention import Attention
+    from vllm.model_executor.layers.linear import QKVCrossParallelLinear
+    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm.model_executor.model_loader.utils import device_loading_context
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsLinearMethod, CompressedTensorsW8A8Int8
+    from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import CutlassScaledMMLinearKernel
+
+    for _, module in model.named_modules():
+        if isinstance(module, QKVCrossParallelLinear):
+            # NOTE(Isotr0py): special case for cross QKV layer because
+            # q and kv proj aren't registered as submodules intentionally
+            module.process_weights_after_loading()
+            continue
+        quant_method = getattr(module, "quant_method", None)
+        quant_method = getattr(module, "quant_method", None)
+
+        # hack for skipping postprocess for default cutlass kernel
+        scheme = getattr(module, "scheme", None)
+        is_cutlass = isinstance(quant_method, CompressedTensorsLinearMethod) and isinstance(scheme, CompressedTensorsW8A8Int8) and isinstance(scheme.kernel, CutlassScaledMMLinearKernel)
+        can_skip_postprocess = False
+        if is_cutlass:
+            assert isinstance(scheme, CompressedTensorsW8A8Int8)
+            kernel = scheme.kernel
+            can_skip_postprocess = kernel.config.is_channelwise and not kernel.config.is_static_input_scheme and kernel.config.input_symmetric
+
+        if can_skip_postprocess:
+            continue
+
+        if isinstance(quant_method, QuantizeMethodBase):
+            # When quant methods need to process weights after loading
+            # (for repacking, quantizing, etc), they expect parameters
+            # to be on the global target device. This scope is for the
+            # case where cpu offloading is used, where we will move the
+            # parameters onto device for processing and back off after.
+            with device_loading_context(module, target_device):
+                quant_method.process_weights_after_loading(module)
+
+    # Currently only used by MLA.
+    # NOTE: This intentionally happens after other modules so we can easily
+    # decompress the weights for MLA.
+    for _, module in model.named_modules():
+        if isinstance(module, Attention) and \
+            hasattr(module, "process_weights_after_loading"):
+            # TODO(lucas): see if there is a way to unify the signatures
+            # of process_weights_after_loading
+            module.process_weights_after_loading(model_config.dtype)
+
 def hacked_process_weights_after_loading(
     original_process_weights_after_loading,
     model,
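As a point of reference (not part of the commit), a minimal usage sketch of the helper added above: it assumes a model and model_config obtained from the vLLM model loader, and only illustrates the call signature for re-running post-load processing after an in-place weight update; the wrapper name is hypothetical.

    # Hypothetical usage sketch, not from the commit; only
    # process_weights_after_loading_skippable comes from the patch above.
    import torch
    from flash_rl.vllm_patch import process_weights_after_loading_skippable

    def reprocess_after_weight_update(model, model_config):
        # Re-run post-load weight processing after weights were refreshed in
        # place; modules whose CUTLASS int8 postprocess can be elided are skipped.
        process_weights_after_loading_skippable(model, model_config, torch.device("cuda"))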
@@ -90,6 +141,7 @@ def patch_vllm_process_weights_after_loading():
         loader.beforeflashrl_process_weights_after_loading = original_process_weights_after_loading

         from functools import partial
+        # don't skip for vllm < 0.9.1 for now
         loader._process_weights_after_loading = partial(hacked_process_weights_after_loading, original_process_weights_after_loading)

     logger.debug("Successfully patched the _process_weights_after_loading function of vllm")
@@ -104,8 +156,13 @@ def patch_vllm_process_weights_after_loading():
     from vllm.model_executor.model_loader import utils

     if not hasattr(utils, 'beforeflashrl_process_weights_after_loading'):
+
+        # use the skippable version for 0.9.1 and 0.9.2
+        if parse(vllm.__version__) in [parse("0.9.1"), parse("0.9.2")]:
+            original_process_weights_after_loading = process_weights_after_loading_skippable
+        else:
+            original_process_weights_after_loading = utils.process_weights_after_loading

-        original_process_weights_after_loading = utils.process_weights_after_loading
         utils.beforeflashrl_process_weights_after_loading = original_process_weights_after_loading

         from functools import partial
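For clarity (not part of the commit), the version gate above compares parsed Version objects, so the `in` check matches exact releases only; a small standalone illustration with example version strings:

    from packaging.version import parse

    supported = [parse("0.9.1"), parse("0.9.2")]

    print(parse("0.9.1") in supported)        # True  -> patch in process_weights_after_loading_skippable
    print(parse("0.8.5") in supported)        # False -> keep utils.process_weights_after_loading
    print(parse("0.9.1.post1") in supported)  # False -> exact releases only; post/dev builds take the fallback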
