
Commit 26b643d

enhance gptq forward as awq
Signed-off-by: xin3he <[email protected]>
1 parent: 39f649c

File tree (3 files changed: +21 -19 lines)

  • examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only
  • neural_compressor/torch/algorithms/weight_only
  • test/3x/torch/quantization/weight_only


examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py
Lines changed: 6 additions & 9 deletions

@@ -272,15 +272,12 @@ def get_user_model():
 def run_fn_for_gptq(model, dataloader_for_calibration, *args):
     for batch in tqdm(dataloader_for_calibration):
         batch = move_input_to_device(batch, device=None)
-        try:
-            if isinstance(batch, tuple) or isinstance(batch, list):
-                model(batch[0])
-            elif isinstance(batch, dict):
-                model(**batch)
-            else:
-                model(batch)
-        except ValueError:
-            pass
+        if isinstance(batch, tuple) or isinstance(batch, list):
+            model(batch[0])
+        elif isinstance(batch, dict):
+            model(**batch)
+        else:
+            model(batch)
     return
 if args.double_quant_type is not None:
     double_quant_config_dict.update(
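
Why the try/except could be dropped here: the gptq.py change below now catches the deliberate ValueError inside the algorithm itself, so the example's calibration function can call the model directly. For orientation, a minimal sketch of where this calibration function plugs in, assuming the 3.x prepare/convert API used elsewhere in this commit; get_default_gptq_config is an assumed helper name, chosen by analogy with the test file's get_default_rtn_config:

# Hedged sketch, not the example script verbatim: how run_fn_for_gptq is
# consumed by the 3.x quantization flow.
from neural_compressor.torch.quantization import convert, prepare
from neural_compressor.torch.quantization import get_default_gptq_config  # assumed name

quant_config = get_default_gptq_config()
model = prepare(model, quant_config)                # patch forwards for calibration
run_fn_for_gptq(model, dataloader_for_calibration)  # plain calls, no try/except
model = convert(model)                              # run GPTQ, restore forwards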

neural_compressor/torch/algorithms/weight_only/gptq.py
Lines changed: 13 additions & 0 deletions

@@ -345,6 +345,18 @@ def forward(layer, *args, **kwargs):
         self.gptq_related_blocks["transformers"][0].forward = partial(
             forward, self.gptq_related_blocks["transformers"][0]
         )
+        # Step 3: replace model_forward to avoid ValueError
+        self.orig_model_forward_cache = self.model.forward
+        model_forward_cache = self.model.forward
+
+        def model_forward(model, *args, **kwargs):
+            nonlocal model_forward_cache
+            try:
+                model_forward_cache(*args, **kwargs)
+            except ValueError:
+                pass
+
+        self.model.forward = partial(model_forward, self.model)

     @torch.no_grad()
     def remove_prepare_for_calibration(self):

@@ -359,6 +371,7 @@ def remove_prepare_for_calibration(self):
         logger.info("Done.")

         # Step 4: restore original forward function, relocate layers back to cpu.
+        self.model.forward = self.orig_model_forward_cache
         self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
         if not self.use_layer_wise:  # pragma: no cover
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
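
Taken together, these two hunks complete a capture-and-abort pattern: the pre-existing Step 2 patches the first transformer block's forward to record its inputs and raise ValueError so no deeper layers run during calibration; the new Step 3 wraps model.forward so that this deliberate ValueError never escapes to the caller; Step 4 restores both forwards afterwards. Below is a self-contained sketch of the same pattern on a toy model; ToyModel, ToyBlock, and collected_inputs are illustrative names, not the library's:

import torch
from functools import partial

class ToyBlock(torch.nn.Module):
    def forward(self, x):
        return x + 1

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = ToyBlock()

    def forward(self, x):
        return self.block(x)

model = ToyModel()
collected_inputs = []

# Patch the first block: record calibration inputs, then abort the forward
# pass with ValueError so no later layers waste computation.
def block_forward(layer, *args, **kwargs):
    collected_inputs.append(args)
    raise ValueError

block_forward_cache = model.block.forward
model.block.forward = partial(block_forward, model.block)

# This commit's addition: wrap model.forward so the deliberate ValueError
# never reaches the user's calibration function.
orig_model_forward_cache = model.forward
model_forward_cache = model.forward

def model_forward(mod, *args, **kwargs):
    try:
        model_forward_cache(*args, **kwargs)
    except ValueError:  # raised on purpose by the patched block
        pass

model.forward = partial(model_forward, model)

model(torch.zeros(1))               # no exception escapes to the caller
assert len(collected_inputs) == 1   # the block's input was captured

# Mirror of "Step 4": restore the original forwards after calibration.
model.forward = orig_model_forward_cache
model.block.forward = block_forward_cache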

test/3x/torch/quantization/weight_only/test_gptq.py
Lines changed: 2 additions & 10 deletions

@@ -19,19 +19,11 @@
 device = accelerator.current_device_name()


-def run_fn_for_rtn(model):
+def run_fn(model):
     model(torch.tensor([[10, 20, 30]], dtype=torch.long).to(device))
     model(torch.tensor([[40, 50, 60]], dtype=torch.long).to(device))


-def run_fn(model):
-    # GPTQ uses ValueError to reduce computation when collecting input data of the first block
-    # It's special for UTs, no need to add this wrapper in examples.
-    with pytest.raises(ValueError):
-        model(torch.tensor([[10, 20, 30]], dtype=torch.long).to(device))
-    model(torch.tensor([[40, 50, 60]], dtype=torch.long).to(device))
-
-
 class TestGPTQQuant:
     def setup_class(self):
         self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(

@@ -50,7 +42,7 @@ def test_accuracy_improvement(self):
         model = copy.deepcopy(self.tiny_gptj)
         quant_config = get_default_rtn_config()
         model = prepare(model, quant_config)
-        run_fn_for_rtn(model)
+        run_fn(model)
         model = convert(model)
         rtn_label = model(self.example_inputs)[0]
         rtn_atol = (rtn_label - self.label).amax()
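
With the ValueError now caught inside the algorithm, the test suite no longer needs a GPTQ-specific calibration wrapper: a single run_fn serves both the RTN path above and the GPTQ tests. As the removed comment noted, the pytest.raises(ValueError) wrapper was only ever a unit-test artifact and was never required in examples; per the commit title, this brings GPTQ's calibration forward in line with AWQ's.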
