2020from dataclasses import dataclass , field
2121from typing import Any , Optional
2222
23- import torch
2423import torch .nn as nn
2524
2625from verl .base_config import BaseConfig
@@ -109,6 +108,7 @@ def apply_qat(
109108
110109 logger .info (f"Found { len (modules_to_replace )} Linear layers to convert to QAT" )
111110
111+ converted_count = 0
112112 for name , module in modules_to_replace :
113113 if isinstance (module , QATLinear ):
114114 continue
@@ -121,11 +121,8 @@ def apply_qat(
121121 )
122122
123123 _set_module (model , name , fake_quant_module )
124- logger . debug ( f"Converted { name } to QATLinear" )
124+ converted_count += 1
125125
126- model ._qat_config = config
127-
128- converted_count = sum (1 for name , m in model .named_modules () if isinstance (m , QATLinear ))
129126 logger .info (f"Successfully applied QAT to { converted_count } layers" )
130127
131128 return model
@@ -140,58 +137,41 @@ def _set_module(model: nn.Module, name: str, new_module: nn.Module):
140137 setattr (parent , parts [- 1 ], new_module )
141138
142139
140+ FUSION_PATTERNS = {
141+ "qkv" : ["q_proj" , "k_proj" , "v_proj" ],
142+ "gate_up" : ["gate_proj" , "up_proj" ],
143+ }
144+
145+
143146def setup_fusion_siblings (model : nn .Module ):
144147 """Setup fusion siblings for QKV and GateUp layers."""
145148 import weakref
146149
147150 from verl .utils .qat .linear import QATLinear
148151
149- qat_modules = {}
150- for name , module in model .named_modules ():
151- if isinstance (module , QATLinear ):
152- qat_modules [name ] = module
153-
154- # Setup QKV fusion siblings
155- qkv_groups = {}
156- for name , module in qat_modules .items ():
157- for proj in ["q_proj" , "k_proj" , "v_proj" ]:
158- if name .endswith (proj ):
159- parent = name .rsplit ("." , 1 )[0 ]
160- if parent not in qkv_groups :
161- qkv_groups [parent ] = {}
162- qkv_groups [parent ][proj ] = module
163-
164- qkv_count = 0
165- for parent , projs in qkv_groups .items ():
166- if len (projs ) >= 2 :
167- modules = list (projs .values ())
168- for i , m in enumerate (modules ):
169- siblings = [modules [j ] for j in range (len (modules )) if j != i ]
170- m ._fusion_siblings_ref = [weakref .ref (s ) for s in siblings ]
171- qkv_count += 1
172-
173- # Setup GateUp fusion siblings
174- gate_up_groups = {}
175- for name , module in qat_modules .items ():
176- if name .endswith ("gate_proj" ) or name .endswith ("up_proj" ):
177- parent = name .rsplit ("." , 1 )[0 ]
178- proj_type = name .rsplit ("." , 1 )[1 ]
179- if parent not in gate_up_groups :
180- gate_up_groups [parent ] = {}
181- gate_up_groups [parent ][proj_type ] = module
182-
183- gate_up_count = 0
184- for parent , projs in gate_up_groups .items ():
185- if "gate_proj" in projs and "up_proj" in projs :
186- gate = projs ["gate_proj" ]
187- up = projs ["up_proj" ]
188- gate ._fusion_siblings_ref = [weakref .ref (up )]
189- up ._fusion_siblings_ref = [weakref .ref (gate )]
190- gate_up_count += 1
191-
192- logger .info (f"[QAT Fuse] Setup fusion siblings: { qkv_count } QKV groups, { gate_up_count } GateUp pairs" )
193-
194- return qkv_count , gate_up_count
152+ qat_modules = {name : m for name , m in model .named_modules () if isinstance (m , QATLinear )}
153+
154+ counts = {}
155+ for group_name , suffixes in FUSION_PATTERNS .items ():
156+ groups : dict [str , dict [str , nn .Module ]] = {}
157+ for name , module in qat_modules .items ():
158+ for suffix in suffixes :
159+ if name .endswith (suffix ):
160+ parent = name .rsplit ("." , 1 )[0 ]
161+ groups .setdefault (parent , {})[suffix ] = module
162+
163+ count = 0
164+ for parent , projs in groups .items ():
165+ if len (projs ) >= 2 :
166+ modules = list (projs .values ())
167+ for i , m in enumerate (modules ):
168+ siblings = modules [:i ] + modules [i + 1 :]
169+ m ._fusion_siblings_ref = [weakref .ref (s ) for s in siblings ]
170+ count += 1
171+ counts [group_name ] = count
172+
173+ logger .info (f"[QAT Fuse] Setup fusion siblings: { counts } " )
174+ return counts
195175
196176
197177def enable_qat_fuse (model : nn .Module ):
0 commit comments