fix bias issue
Signed-off-by: changwangss <[email protected]>
changwangss committed Aug 28, 2024
commit c02db4683ee05e78971953b0c503c1f691d4a431
20 changes: 10 additions & 10 deletions neural_compressor/transformers/quantization/utils.py
@@ -22,11 +22,8 @@
import types

from datasets import load_dataset
-from transformers import AutoTokenizer
-from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME

-from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
+from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear as WeightOnlyLinear
from neural_compressor.torch.quantization import GPTQConfig, RTNConfig, convert, prepare
from neural_compressor.torch.utils import is_ipex_available
from neural_compressor.utils.utility import CpuInfo, LazyImport
@@ -162,21 +159,21 @@ def _replace_linear(
tmp_linear = torch.nn.Linear(
in_features,
out_features,
-True if hasattr(module, "bias") else False,
+True if hasattr(module, "bias") and module.bias is not None else False,
)
if tmp_linear.bias is not None and module.bias is not None:
tmp_linear.bias = torch.nn.Parameter(module.bias.float())

tmp_linear.qconfig = ipex_qconfig_mapping.global_qconfig
model._modules[name] = ipex_linear.from_float_and_int4_weight(
mod=tmp_linear,
qweight=qweight,
scales=scales,
zero_points=qzeros,
# bias=(module.bias if (hasattr(module, "bias") and not torch.all(module.bias.eq(0))) else None),
-bias=(module.bias.float() if hasattr(module, "bias") else None),
+bias=(module.bias.float() if hasattr(module, "bias") and module.bias is not None else None),
group_size=quantization_config.group_size,
g_idx=(module.g_idx if hasattr(module, "g_idx") else None),
)
# print(current_key_name)
# print(module.bias.float())

elif device == "xpu" or device == torch.device("xpu"):
from intel_extension_for_pytorch.nn.utils._quantize_convert import (
@@ -363,7 +360,7 @@ def convert_to_quantized_model(model, config, device="cpu"):
bits=config.bits,
use_sym=config.sym,
group_size=config.group_size,
-use_layer_wise=config.layer_wise,
+use_layer_wise=config.use_layer_wise,
act_order=config.desc_act,
percdamp=config.damp_percent,
block_size=config.blocksize,
@@ -472,7 +469,9 @@ def convert_to_GPTQ_checkpoints(model, quantization_config):
new_module.n_pack = 32 // bits
scales = module._op_context.get_scales().t().contiguous()
bias = module._op_context.get_bias()
-qzeros = new_module.pack_tensor_with_numpy(module._op_context.get_zero_points().t()).contiguous()
+qzeros = new_module.pack_tensor_with_numpy(
+    module._op_context.get_zero_points().t().to(torch.uint8) - 1
+).contiguous()
g_idx = module._op_context.get_g_idx()

new_module.qweight = qweight
@@ -482,6 +481,7 @@ def convert_to_GPTQ_checkpoints(model, quantization_config):
new_module.g_idx = g_idx.contiguous()
if bias is not None:
new_module.bias = bias.contiguous()

set_module(model, name, new_module)
return model
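A note on the bias check introduced in _replace_linear above (a minimal standalone sketch, not part of the patch): torch.nn.Linear registers a bias attribute even when constructed with bias=False — the attribute is simply None — so hasattr(module, "bias") alone cannot tell whether a real bias tensor exists, which is why the patch also tests module.bias is not None.

import torch

linear_no_bias = torch.nn.Linear(4, 2, bias=False)
linear_with_bias = torch.nn.Linear(4, 2, bias=True)

# hasattr() is True in both cases, because nn.Linear always registers the attribute.
print(hasattr(linear_no_bias, "bias"), linear_no_bias.bias)                  # True None
print(hasattr(linear_with_bias, "bias"), linear_with_bias.bias is not None)  # True True

# The combined check from the patch only treats the bias as present when it is a real tensor.
has_bias = hasattr(linear_no_bias, "bias") and linear_no_bias.bias is not None
print(has_bias)  # False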

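On the zero-point change in convert_to_GPTQ_checkpoints: GPTQ-format checkpoints conventionally store the packed zero points offset by one (dequantization adds the one back), while the IPEX op context returns the raw values, so they are cast to uint8 and shifted by -1 before packing. A rough sketch of the idea, with a simplified stand-in for the packer that packs eight 4-bit values per 32-bit word (the helper name and example values are illustrative, not the library's API):

import torch

def pack_int4_to_int32(values: torch.Tensor) -> torch.Tensor:
    # Pack eight 4-bit values into each 32-bit word (n_pack = 32 // bits = 8).
    vals = values.to(torch.int64).reshape(-1, 8)
    packed = torch.zeros(vals.shape[0], dtype=torch.int64)
    for i in range(8):
        packed |= (vals[:, i] & 0xF) << (4 * i)
    return packed.to(torch.int32)

# Assumed example: int4 zero points reported as 8 by the op context;
# the GPTQ layout expects zero_point - 1, hence the shift before packing.
ipex_zero_points = torch.full((16,), 8, dtype=torch.int32)
qzeros = pack_int4_to_int32(ipex_zero_points.to(torch.uint8).to(torch.int64) - 1)
print(qzeros)  # two words, each holding eight packed values of 7 (0x77777777)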
32 changes: 25 additions & 7 deletions test/3x/torch/test_transformers.py
@@ -1,3 +1,4 @@
+import os
import shutil
import unittest

@@ -12,30 +13,47 @@
class TestQuantizationConfig(unittest.TestCase):
@classmethod
def setUpClass(self):
self.model_name = "TheBlokeAI/Mixtral-tiny-GPTQ"
self.model_name = "hf-internal-testing/tiny-random-gptj"
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.prompt = "One day, the little girl"
self.input_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"]

@classmethod
def tearDownClass(self):
shutil.rmtree("tmp_gptq")
shutil.rmtree("tmp_rtn")
if os.path.exists("tmp_gptq") and os.path.isdir("tmp_gptq"):
shutil.rmtree("tmp_gptq")
if os.path.exists("tmp_rtn") and os.path.isdir("tmp_rtn"):
shutil.rmtree("tmp_rtn")

def test_gptq(self):
-quantization_config = GPTQConfig(bits=4, sym=True, damp_percent=0.01, desc_act=True)
+quantization_config = GPTQConfig(
+    bits=4,
+    sym=True,
+    damp_percent=0.01,
+    desc_act=True,
+    tokenizer=self.tokenizer,
+    n_samples=20,
+    group_size=8,
+    batch_size=5,
+    seq_len=32,
+    block_size=16,
+)
user_model = INCModelForCausalLM.from_pretrained(self.model_name, quantization_config=quantization_config)
output = user_model(self.input_ids)
user_model.save_pretrained("tmp_gptq")
loaded_model = INCModelForCausalLM.from_pretrained("tmp_gptq")
loaded_output = loaded_model(self.input_ids)
-assert torch.allclose(output, loaded_output, atol=1e-2), "Compare failed!"
+assert torch.allclose(output.logits, loaded_output.logits, atol=1e-2), "Compare failed!"

def test_rtn(self):
-quantization_config = RtnConfig(bits=4)
+quantization_config = RtnConfig(bits=4, group_size=8, sym=False)
user_model = INCModelForCausalLM.from_pretrained(self.model_name, quantization_config=quantization_config)
output = user_model(self.input_ids)
user_model.save_pretrained("tmp_rtn")
loaded_model = INCModelForCausalLM.from_pretrained("tmp_rtn")
loaded_output = loaded_model(self.input_ids)
-assert torch.allclose(output, loaded_output, atol=1e-2), "Compare failed!"
+assert torch.allclose(output.logits, loaded_output.logits, atol=1e-2), "Compare failed!"


if __name__ == "__main__":
unittest.main()
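One remark on the assertion change in the tests: a causal-LM forward pass returns a Hugging Face ModelOutput (e.g. CausalLMOutputWithPast) rather than a bare tensor, so torch.allclose has to compare the .logits field instead of the output object itself. A minimal sketch of the pattern with a plain transformers model (same tiny checkpoint as in the test, used here purely for illustration):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "hf-internal-testing/tiny-random-gptj"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer("One day, the little girl", return_tensors="pt")["input_ids"]
with torch.no_grad():
    out_a = model(input_ids)
    out_b = model(input_ids)

# out_a is a CausalLMOutputWithPast, not a tensor; allclose must see the logits tensor.
assert torch.allclose(out_a.logits, out_b.logits, atol=1e-2)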