enhance gptq forward to work like awq
Signed-off-by: xin3he <[email protected]>
xin3he committed Jun 20, 2024
commit 26b643d9362a98d0312567268e865b1e8c7aa757
@@ -272,15 +272,12 @@ def get_user_model():
 def run_fn_for_gptq(model, dataloader_for_calibration, *args):
     for batch in tqdm(dataloader_for_calibration):
         batch = move_input_to_device(batch, device=None)
-        try:
-            if isinstance(batch, tuple) or isinstance(batch, list):
-                model(batch[0])
-            elif isinstance(batch, dict):
-                model(**batch)
-            else:
-                model(batch)
-        except ValueError:
-            pass
+        if isinstance(batch, tuple) or isinstance(batch, list):
+            model(batch[0])
+        elif isinstance(batch, dict):
+            model(**batch)
+        else:
+            model(batch)
     return
 if args.double_quant_type is not None:
     double_quant_config_dict.update(
Expand Down
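For context on why the try/except wrapper could be removed: during calibration, GPTQ captures the inputs of the first transformer block by replacing that block's forward with a hook that records its arguments and then raises ValueError, so the rest of the (expensive) forward pass is skipped. The toy sketch below, with illustrative names rather than neural_compressor internals, shows that early-exit trick and the wrapper that user-side calibration code previously needed:

    import torch
    from functools import partial

    # Toy stand-ins for the model and its first transformer block.
    block = torch.nn.Linear(4, 4)
    model = torch.nn.Sequential(block, torch.nn.Linear(4, 4))

    captured_inputs = []

    def capture_and_stop(layer, *args, **kwargs):
        # Record the inputs that reach the first block ...
        captured_inputs.append((args, kwargs))
        # ... then abort the rest of the forward pass to save compute.
        raise ValueError("early exit after collecting first-block inputs")

    block.forward = partial(capture_and_stop, block)

    # Before this commit, calibration code had to swallow the ValueError itself:
    try:
        model(torch.randn(1, 4))
    except ValueError:
        pass
    assert len(captured_inputs) == 1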
13 changes: 13 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -345,6 +345,18 @@ def forward(layer, *args, **kwargs):
         self.gptq_related_blocks["transformers"][0].forward = partial(
             forward, self.gptq_related_blocks["transformers"][0]
         )
+        # Step 3: replace model_forward to avoid ValueError
+        self.orig_model_forward_cache = self.model.forward
+        model_forward_cache = self.model.forward
+
+        def model_forward(model, *args, **kwargs):
+            nonlocal model_forward_cache
+            try:
+                model_forward_cache(*args, **kwargs)
+            except ValueError:
+                pass
+
+        self.model.forward = partial(model_forward, self.model)
 
     @torch.no_grad()
     def remove_prepare_for_calibration(self):
@@ -359,6 +371,7 @@ def remove_prepare_for_calibration(self):
         logger.info("Done.")
 
         # Step 4: restore original forward function, relocate layers back to cpu.
+        self.model.forward = self.orig_model_forward_cache
         self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
         if not self.use_layer_wise:  # pragma: no cover
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
Expand Down
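The added Step 3 moves that try/except inside the library by temporarily wrapping self.model.forward, mirroring the behavior of the AWQ path, and Step 4 restores the cached original forward once the calibration inputs have been collected. A self-contained sketch of the same wrap-and-restore pattern on a toy module (the Stopper class and tensor shapes are illustrative):

    import torch
    from functools import partial

    class Stopper(torch.nn.Module):
        """Stands in for the first transformer block after GPTQ hooks it."""
        def forward(self, x):
            raise ValueError("early exit after collecting first-block inputs")

    model = torch.nn.Sequential(Stopper(), torch.nn.Linear(4, 4))

    orig_model_forward_cache = model.forward  # kept for restoration after calibration
    model_forward_cache = model.forward

    def model_forward(mod, *args, **kwargs):
        # Swallow the early-exit ValueError inside the library so callers
        # can simply run `model(batch)` during calibration.
        try:
            return model_forward_cache(*args, **kwargs)
        except ValueError:
            pass

    model.forward = partial(model_forward, model)
    model(torch.randn(1, 4))  # no exception escapes the call

    # Step 4 equivalent: put the original forward back once calibration is done
    # (in the real code the first block's hook is removed at the same point).
    model.forward = orig_model_forward_cache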
12 changes: 2 additions & 10 deletions test/3x/torch/quantization/weight_only/test_gptq.py
@@ -19,19 +19,11 @@
 device = accelerator.current_device_name()
 
 
-def run_fn_for_rtn(model):
+def run_fn(model):
     model(torch.tensor([[10, 20, 30]], dtype=torch.long).to(device))
     model(torch.tensor([[40, 50, 60]], dtype=torch.long).to(device))
 
 
-def run_fn(model):
-    # GPTQ uses ValueError to reduce computation when collecting input data of the first block
-    # It's special for UTs, no need to add this wrapper in examples.
-    with pytest.raises(ValueError):
-        model(torch.tensor([[10, 20, 30]], dtype=torch.long).to(device))
-        model(torch.tensor([[40, 50, 60]], dtype=torch.long).to(device))
-
-
 class TestGPTQQuant:
     def setup_class(self):
         self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
@@ -50,7 +42,7 @@ def test_accuracy_improvement(self):
         model = copy.deepcopy(self.tiny_gptj)
         quant_config = get_default_rtn_config()
         model = prepare(model, quant_config)
-        run_fn_for_rtn(model)
+        run_fn(model)
         model = convert(model)
         rtn_label = model(self.example_inputs)[0]
         rtn_atol = (rtn_label - self.label).amax()
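With the wrapper in place, a single plain run_fn now serves both the RTN and GPTQ tests. Below is a minimal end-to-end sketch of the prepare → calibrate → convert flow this test exercises; the import path, get_default_gptq_config, and the tiny checkpoint id are assumptions based on the neural_compressor 3.x PyTorch API, not text taken from this diff:

    import torch
    import transformers
    from neural_compressor.torch.quantization import convert, get_default_gptq_config, prepare

    def run_fn(model):
        # Plain calibration calls; no pytest.raises(ValueError) wrapper is needed
        # anymore because GPTQ suppresses the early-exit error internally.
        model(torch.tensor([[10, 20, 30]], dtype=torch.long))
        model(torch.tensor([[40, 50, 60]], dtype=torch.long))

    fp32_model = transformers.AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-GPTJForCausalLM"  # illustrative checkpoint
    )
    model = prepare(fp32_model, get_default_gptq_config())
    run_fn(model)
    q_model = convert(model)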