Skip to content

Commit 67c87d2

Browse files
committed
clean rebundant code
1 parent ac3cc5a commit 67c87d2

File tree

10 files changed

+331
-807
lines changed

10 files changed

+331
-807
lines changed

verl/utils/qat/__init__.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -37,30 +37,10 @@
3737
invalidate_all_scales,
3838
load_quantization_config,
3939
)
40-
from verl.utils.qat.linear import (
41-
TRITON_AVAILABLE,
42-
QATLinear,
43-
QATMode,
44-
STEFP4QuantTriton,
45-
fp4_fake_quant_weight,
46-
)
47-
from verl.utils.qat.quantizer import (
48-
FUSE_PATTERNS,
49-
QATQuantizer,
50-
ScaleInfo,
51-
WeightScaleResult,
52-
compute_blockwise_scale_only,
53-
compute_global_amax,
54-
compute_global_scale_from_amax,
55-
compute_weight_scales,
56-
fuse_global_scales,
57-
)
5840
from verl.utils.qat.vllm_patch import (
59-
LazyParamsDict,
6041
apply_qat_patches,
6142
manual_process_weights_after_loading,
6243
prepare_qat_for_load_weights,
63-
remove_qat_patches,
6444
)
6545

6646
__all__ = [
@@ -70,26 +50,8 @@
7050
"load_quantization_config",
7151
"enable_qat_fuse",
7252
"invalidate_all_scales",
73-
# Linear (includes Triton kernels)
74-
"QATLinear",
75-
"QATMode",
76-
"STEFP4QuantTriton",
77-
"TRITON_AVAILABLE",
78-
"fp4_fake_quant_weight",
79-
# Quantizer (includes scale computation utilities)
80-
"QATQuantizer",
81-
"ScaleInfo",
82-
"WeightScaleResult",
83-
"compute_weight_scales",
84-
"compute_blockwise_scale_only",
85-
"compute_global_amax",
86-
"compute_global_scale_from_amax",
87-
"fuse_global_scales",
88-
"FUSE_PATTERNS",
8953
# vLLM Patch
9054
"apply_qat_patches",
91-
"remove_qat_patches",
9255
"manual_process_weights_after_loading",
9356
"prepare_qat_for_load_weights",
94-
"LazyParamsDict",
9557
]

verl/utils/qat/core.py

Lines changed: 31 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from dataclasses import dataclass, field
2121
from typing import Any, Optional
2222

23-
import torch
2423
import torch.nn as nn
2524

2625
from verl.base_config import BaseConfig
@@ -109,6 +108,7 @@ def apply_qat(
109108

110109
logger.info(f"Found {len(modules_to_replace)} Linear layers to convert to QAT")
111110

111+
converted_count = 0
112112
for name, module in modules_to_replace:
113113
if isinstance(module, QATLinear):
114114
continue
@@ -121,11 +121,8 @@ def apply_qat(
121121
)
122122

123123
_set_module(model, name, fake_quant_module)
124-
logger.debug(f"Converted {name} to QATLinear")
124+
converted_count += 1
125125

126-
model._qat_config = config
127-
128-
converted_count = sum(1 for name, m in model.named_modules() if isinstance(m, QATLinear))
129126
logger.info(f"Successfully applied QAT to {converted_count} layers")
130127

131128
return model
@@ -140,58 +137,41 @@ def _set_module(model: nn.Module, name: str, new_module: nn.Module):
140137
setattr(parent, parts[-1], new_module)
141138

142139

140+
FUSION_PATTERNS = {
141+
"qkv": ["q_proj", "k_proj", "v_proj"],
142+
"gate_up": ["gate_proj", "up_proj"],
143+
}
144+
145+
143146
def setup_fusion_siblings(model: nn.Module):
144147
"""Setup fusion siblings for QKV and GateUp layers."""
145148
import weakref
146149

147150
from verl.utils.qat.linear import QATLinear
148151

149-
qat_modules = {}
150-
for name, module in model.named_modules():
151-
if isinstance(module, QATLinear):
152-
qat_modules[name] = module
153-
154-
# Setup QKV fusion siblings
155-
qkv_groups = {}
156-
for name, module in qat_modules.items():
157-
for proj in ["q_proj", "k_proj", "v_proj"]:
158-
if name.endswith(proj):
159-
parent = name.rsplit(".", 1)[0]
160-
if parent not in qkv_groups:
161-
qkv_groups[parent] = {}
162-
qkv_groups[parent][proj] = module
163-
164-
qkv_count = 0
165-
for parent, projs in qkv_groups.items():
166-
if len(projs) >= 2:
167-
modules = list(projs.values())
168-
for i, m in enumerate(modules):
169-
siblings = [modules[j] for j in range(len(modules)) if j != i]
170-
m._fusion_siblings_ref = [weakref.ref(s) for s in siblings]
171-
qkv_count += 1
172-
173-
# Setup GateUp fusion siblings
174-
gate_up_groups = {}
175-
for name, module in qat_modules.items():
176-
if name.endswith("gate_proj") or name.endswith("up_proj"):
177-
parent = name.rsplit(".", 1)[0]
178-
proj_type = name.rsplit(".", 1)[1]
179-
if parent not in gate_up_groups:
180-
gate_up_groups[parent] = {}
181-
gate_up_groups[parent][proj_type] = module
182-
183-
gate_up_count = 0
184-
for parent, projs in gate_up_groups.items():
185-
if "gate_proj" in projs and "up_proj" in projs:
186-
gate = projs["gate_proj"]
187-
up = projs["up_proj"]
188-
gate._fusion_siblings_ref = [weakref.ref(up)]
189-
up._fusion_siblings_ref = [weakref.ref(gate)]
190-
gate_up_count += 1
191-
192-
logger.info(f"[QAT Fuse] Setup fusion siblings: {qkv_count} QKV groups, {gate_up_count} GateUp pairs")
193-
194-
return qkv_count, gate_up_count
152+
qat_modules = {name: m for name, m in model.named_modules() if isinstance(m, QATLinear)}
153+
154+
counts = {}
155+
for group_name, suffixes in FUSION_PATTERNS.items():
156+
groups: dict[str, dict[str, nn.Module]] = {}
157+
for name, module in qat_modules.items():
158+
for suffix in suffixes:
159+
if name.endswith(suffix):
160+
parent = name.rsplit(".", 1)[0]
161+
groups.setdefault(parent, {})[suffix] = module
162+
163+
count = 0
164+
for parent, projs in groups.items():
165+
if len(projs) >= 2:
166+
modules = list(projs.values())
167+
for i, m in enumerate(modules):
168+
siblings = modules[:i] + modules[i + 1 :]
169+
m._fusion_siblings_ref = [weakref.ref(s) for s in siblings]
170+
count += 1
171+
counts[group_name] = count
172+
173+
logger.info(f"[QAT Fuse] Setup fusion siblings: {counts}")
174+
return counts
195175

196176

197177
def enable_qat_fuse(model: nn.Module):

0 commit comments

Comments
 (0)