import os
import vllm
import torch
+from torch import nn
import types
import gc
from packaging.version import parse
@@ -35,6 +36,56 @@ def bond_method_to_cls(func, obj):
    '_assert_and_load',
]

+# fine-grained reimplementation of process_weights_after_loading that skips
+# the post-load processing for modules where it is unnecessary
+def process_weights_after_loading_skippable(model: nn.Module, model_config,
+                                            target_device: torch.device) -> None:
+    from vllm.attention import Attention
+    from vllm.model_executor.layers.linear import QKVCrossParallelLinear
+    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm.model_executor.model_loader.utils import device_loading_context
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+        CompressedTensorsLinearMethod, CompressedTensorsW8A8Int8)
+    from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
+        CutlassScaledMMLinearKernel)
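+    # NOTE: the imports are kept function-local, presumably so this module
+    # stays importable on vllm versions where these internals live elsewhere;
+    # this function is only installed for vllm 0.9.1/0.9.2 (see the version
+    # gate below).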
+
+    for _, module in model.named_modules():
+        if isinstance(module, QKVCrossParallelLinear):
+            # NOTE(Isotr0py): special case for cross QKV layer because
+            # q and kv proj aren't registered as submodules intentionally
+            module.process_weights_after_loading()
+            continue
+        quant_method = getattr(module, "quant_method", None)
+
+        # hack for skipping the postprocess for the default cutlass kernel
+        scheme = getattr(module, "scheme", None)
+        is_cutlass = (isinstance(quant_method, CompressedTensorsLinearMethod)
+                      and isinstance(scheme, CompressedTensorsW8A8Int8)
+                      and isinstance(scheme.kernel, CutlassScaledMMLinearKernel))
+        can_skip_postprocess = False
+        if is_cutlass:
+            kernel = scheme.kernel
+            can_skip_postprocess = (kernel.config.is_channelwise
+                                    and not kernel.config.is_static_input_scheme
+                                    and kernel.config.input_symmetric)
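+        # Rationale (an assumption mirroring the kernel's own config checks,
+        # not re-verified here): for a channelwise, dynamically quantized,
+        # symmetric int8 scheme the cutlass kernel's post-load processing adds
+        # nothing once the weights are already in the expected layout, so it
+        # can be skipped when weights are re-synced.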
+
+        if can_skip_postprocess:
+            continue
+
+        if isinstance(quant_method, QuantizeMethodBase):
+            # When quant methods need to process weights after loading
+            # (for repacking, quantizing, etc), they expect parameters
+            # to be on the global target device. This scope is for the
+            # case where cpu offloading is used, where we will move the
+            # parameters onto device for processing and back off after.
+            with device_loading_context(module, target_device):
+                quant_method.process_weights_after_loading(module)
+
+    # Currently only used by MLA.
+    # NOTE: This intentionally happens after other modules so we can easily
+    # decompress the weights for MLA.
+    for _, module in model.named_modules():
+        if isinstance(module, Attention) and \
+                hasattr(module, "process_weights_after_loading"):
+            # TODO(lucas): see if there is a way to unify the signatures
+            # of process_weights_after_loading
+            module.process_weights_after_loading(model_config.dtype)
+
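+# Usage sketch: process_weights_after_loading_skippable mirrors the signature
+# of vllm's model_loader.utils.process_weights_after_loading, so it can stand
+# in for it, e.g.:
+#     process_weights_after_loading_skippable(model, model_config,
+#                                             torch.device("cuda"))
+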
def hacked_process_weights_after_loading(
    original_process_weights_after_loading,
    model,
@@ -90,6 +141,7 @@ def patch_vllm_process_weights_after_loading():
        loader.beforeflashrl_process_weights_after_loading = original_process_weights_after_loading

        from functools import partial
+        # don't skip for vllm < 0.9.1 for now
        loader._process_weights_after_loading = partial(hacked_process_weights_after_loading, original_process_weights_after_loading)

        logger.debug("Successfully patched the _process_weights_after_loading function of vllm")
@@ -104,8 +156,13 @@ def patch_vllm_process_weights_after_loading():
    from vllm.model_executor.model_loader import utils

    if not hasattr(utils, 'beforeflashrl_process_weights_after_loading'):
+
+        # use the skippable version for vllm 0.9.1 and 0.9.2
+        if parse(vllm.__version__) in [parse("0.9.1"), parse("0.9.2")]:
+            original_process_weights_after_loading = process_weights_after_loading_skippable
+        else:
+            original_process_weights_after_loading = utils.process_weights_after_loading
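+        # NOTE: parse() equality is exact, so only the plain "0.9.1"/"0.9.2"
+        # releases take the skippable path; dev, post, and local builds fall
+        # through to the stock implementation.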

-        original_process_weights_after_loading = utils.process_weights_after_loading
        utils.beforeflashrl_process_weights_after_loading = original_process_weights_after_loading

        from functools import partial