support save and load
Signed-off-by: changwangss <[email protected]>
changwangss committed Aug 27, 2024
commit b7da9199a26f8aed46a55104f8a43b1f13e70201
85 changes: 8 additions & 77 deletions neural_compressor/transformers/models/modeling_auto.py
@@ -31,12 +31,8 @@
# limitations under the License.

import copy
import json
import os
import re
import types
from threading import Thread
from typing import Union

import transformers
from accelerate import init_empty_weights
@@ -52,8 +48,8 @@
is_safetensors_available,
)

from neural_compressor.adaptor.torch_utils.util import set_module
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
from neural_compressor.torch.utils import is_ipex_available
from neural_compressor.transformers import GPTQConfig, RtnConfig
from neural_compressor.transformers.quantization.utils import convert_dtype_torch2str, replace_linear, save_low_bit
from neural_compressor.utils import logger
@@ -63,8 +59,6 @@


def build_woq_model(model, quantization_config):
from neural_compressor.adaptor.torch_utils.util import set_module

bits = quantization_config.bits
for n, m in model.named_modules():
if n in quantization_config.modules_to_not_convert:
@@ -560,70 +554,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

if is_ipex_available():
model = replace_linear(
model,
quantization_config=quantization_config,
device="cpu" if device_map == "auto" else device_map,
empty_weights=True,
)
# if (device_map == "cpu" or device_map == torch.device("cpu")):
# import intel_extension_for_pytorch as ipex
# from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_linear

# def replace_ipex_cpu_woq_linear(model, current_name=[]):
# for name, module in model.named_children():
# current_name.append(name)
# if isinstance(module, INCWeightOnlyLinear):
# weight_dtype = {
# 4: ipex.quantization.WoqWeightDtype.INT4,
# 8: ipex.quantization.WoqWeightDtype.INT8,
# }
# compute_dtype = {
# "fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype.
# "bf16": ipex.quantization.WoqLowpMode.BF16,
# "fp16": ipex.quantization.WoqLowpMode.FP16,
# "int8": ipex.quantization.WoqLowpMode.INT8,
# }

# ipex_qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(
# weight_dtype=weight_dtype[quantization_config.bits],
# lowp_mode=compute_dtype[quantization_config.compute_dtype],
# act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
# group_size=quantization_config.group_size,
# )
# tmp_linear = torch.nn.Linear(
# module.in_features,
# module.out_features,
# True if hasattr(module, "bias") else False,
# )
# tmp_linear.qconfig = ipex_qconfig_mapping.global_qconfig
# target_linear = ipex_linear.from_float_and_int4_weight(
# mod=tmp_linear,
# qweight=state_dict.pop(".".join(current_name) + ".qweight"),
# scales=state_dict.pop(".".join(current_name) + ".scales"),
# zero_points=state_dict.pop(".".join(current_name) + ".qzeros"),
# bias=(
# state_dict.pop(".".join(current_name) + ".bias")
# if ".".join(current_name) + ".bias" in state_dict
# else None
# ),
# group_size=quantization_config.group_size,
# g_idx=(
# state_dict.pop(".".join(current_name) + ".g_idx")
# if ".".join(current_name) + ".g_idx" in state_dict
# else None
# ),
# )
# setattr(model, name, target_linear)
# else:
# replace_ipex_cpu_woq_linear(module, current_name)
# current_name.pop()

# replace_ipex_cpu_woq_linear(model)
# model.load_state_dict(state_dict, strict=False, assign=True)
# else:
(
model,
missing_keys,
@@ -645,6 +575,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
dtype=torch_dtype,
keep_in_fp32_modules=[],
)

else:
raise AssertionError("Please install intel_extension_for_pytorch.")

@@ -654,12 +585,12 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()

# model = replace_linear(
# model,
# quantization_config=quantization_config,
# device="cpu" if device_map == "auto" else device_map,
# empty_weights=True,
# )
model = replace_linear(
model,
quantization_config=quantization_config,
device="cpu" if device_map == "auto" else device_map,
empty_weights=True,
)

if (not use_xpu and torch_dtype == torch.float16) or (
not use_xpu and not CpuInfo().bf16 and torch_dtype == torch.bfloat16
123 changes: 84 additions & 39 deletions neural_compressor/transformers/quantization/utils.py
@@ -170,10 +173,13 @@ def _replace_linear(
qweight=qweight,
scales=scales,
zero_points=qzeros,
bias=(module.bias if hasattr(module, "bias") else None),
# bias=(module.bias if (hasattr(module, "bias") and not torch.all(module.bias.eq(0))) else None),
bias=(module.bias.float() if hasattr(module, "bias") else None),
group_size=quantization_config.group_size,
g_idx=(module.g_idx if hasattr(module, "g_idx") else None),
)
# print(current_key_name)
# print(module.bias.float())

elif device == "xpu" or device == torch.device("xpu"):
from intel_extension_for_pytorch.nn.utils._quantize_convert import (
@@ -403,41 +406,84 @@ def convert_to_quantized_model(model, config, device="cpu"):
return q_model.to(device)


# def save_linear_parameters(model, save_directory):

# weights_file = os.path.join(
# os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME
# )
# linear_parameters = {}
# from intel_extension_for_pytorch.nn.modules import (
# WeightOnlyQuantizedLinear as ipex_cpu_linear,
# )

# for name, module in model.named_modules():
# if isinstance(module, ipex_cpu_linear):
# linear_parameters[name + ".qweight"] = (
# module._op_context.to_public(
# module._op_context.get_weight()
# ).contiguous()
# )
# linear_parameters[name + ".scales"] = (
# module._op_context.get_scales().contiguous()
# )
# linear_parameters[name + ".qzeros"] = (
# module._op_context.get_zero_points().contiguous()
# )
# if module._op_context.get_bias() is not None:
# linear_parameters[name + ".bias"] = (
# module._op_context.get_bias().contiguous()
# )
# if module._op_context.get_g_idx() is not None:
# linear_parameters[name + ".g_idx"] = (
# module._op_context.get_g_idx().contiguous()
# )

# others_parameters = model.state_dict()
# linear_parameters.update(others_parameters)
# torch.save(linear_parameters, weights_file)
def pack_tensor_with_torch(raw_tensor, bits, compression_dtype=torch.int32):
"""Pack the tensor with torch.

Args:
raw_tensor (tensor): raw tensor.

Returns:
tensor: packed tensor.
"""
n_pack = 32 // bits
target_len = math.ceil(raw_tensor.shape[1] / n_pack)
packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=compression_dtype).to(raw_tensor.device)
mask = torch.tensor(2**bits - 1, dtype=compression_dtype).to(raw_tensor.device)
for j in range(packed_tensor.shape[1]):
start = n_pack * j
end = n_pack * (j + 1)
tmp = raw_tensor[:, start:end].type(compression_dtype)
tmp &= mask
for e in range(tmp.shape[1]):
tmp[:, e] = tmp[:, e] << (bits * e)
packed_tensor[:, j] |= tmp[:, e]

return packed_tensor
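
# Illustrative sketch (not part of this change): with bits=4, eight 4-bit values
# fit into one int32 column, so a (2, 8) tensor of values in [0, 16) packs into
# shape (2, 1). The tensor below is made up for demonstration.
#
#     example = torch.randint(0, 16, (2, 8), dtype=torch.int32)
#     packed = pack_tensor_with_torch(example, bits=4)  # -> torch.Size([2, 1])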


def convert_to_GPTQ_checkpoints(model, quantization_config):
from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_cpu_linear

from neural_compressor.adaptor.torch_utils.util import set_module
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear

dtype = "int4" if quantization_config.bits == 4 else "int8"
bits = quantization_config.bits
group_size = quantization_config.group_size
zp = False if quantization_config.sym else True
scale_dtype = quantization_config.scale_dtype
desc_act = True if hasattr(quantization_config, "desc_act") else False
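# Walk the modules and swap each ipex WeightOnlyQuantizedLinear for an
# INCWeightOnlyLinear that holds the same tensors in the optimum/GPTQ layout,
# so the quantized model can later be serialized with save_pretrained().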

for name, module in model.named_modules():
if isinstance(module, ipex_cpu_linear):
in_features = module.in_features
out_features = module.out_features
new_module = INCWeightOnlyLinear(
in_features,
out_features,
dtype=dtype,
bits=bits,
group_size=group_size,
zp=zp,
bias=True if hasattr(module, "bias") else False,
scale_dtype=scale_dtype,
g_idx=desc_act,
use_optimum_format=True,
)

new_module.bits = 8
new_module.n_pack = 32 // 8
qweight = (
new_module.pack_tensor_with_numpy(module._op_context.to_public(module._op_context.get_weight()))
.t()
.contiguous()
)
new_module.bits = bits
new_module.n_pack = 32 // bits
scales = module._op_context.get_scales().t().contiguous()
bias = module._op_context.get_bias()
qzeros = new_module.pack_tensor_with_numpy(module._op_context.get_zero_points().t()).contiguous()
g_idx = module._op_context.get_g_idx()

new_module.qweight = qweight
new_module.scales = scales
new_module.qzeros = qzeros
if g_idx is not None:
new_module.g_idx = g_idx.contiguous()
if bias is not None:
new_module.bias = bias.contiguous()
set_module(model, name, new_module)
return model


def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
@@ -452,10 +498,9 @@ def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: boo
# use transformers original `save_pretrained` function
del self.save_pretrained

if self.device == "cpu" or self.device == torch.device("cpu") or self.device == "auto":
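# Re-pack the ipex weight-only linears into the GPTQ checkpoint layout so the
# quantized weights can be serialized by the original transformers save_pretrained.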
convert_to_GPTQ_checkpoints(self, self.quantization_config)
self.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

# if self.device == "cpu" or self.device == torch.device("cpu") or self.device == "auto":
# save_linear_parameters(self, save_directory)
self.save_pretrained = types.MethodType(save_low_bit, self)
# We conveniently save all the keys of the model to have them on hand,
# so that when using 'low_cpumem load',
38 changes: 38 additions & 0 deletions test/3x/torch/test_transformers.py
@@ -0,0 +1,38 @@
import unittest
import torch
import pytest
import shutil
from transformers import AutoTokenizer
from optimum.intel import INCModelForCausalLM
from neural_compressor.transformers import GPTQConfig, RtnConfig
class TestQuantizationConfig(unittest.TestCase):
@classmethod
def setUpClass(self):
self.model_name = "TheBlokeAI/Mixtral-tiny-GPTQ"
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.prompt = "One day, the little girl"
self.input_ids = self.tokenizer(self.prompt, return_tensors="pt")["input_ids"]

@classmethod
def tearDownClass(self):
shutil.rmtree("tmp_gptq")
shutil.rmtree("tmp_rtn")
def test_gptq(self):
quantization_config = GPTQConfig(
bits=4, sym=True, damp_percent=0.01,desc_act=True
)
user_model = INCModelForCausalLM.from_pretrained(self.model_name, quantization_config=quantization_config)
output = user_model(self.input_ids)
user_model.save_pretrained("tmp_gptq")
loaded_model = INCModelForCausalLM.from_pretrained("tmp_gptq")
loaded_output = loaded_model(self.input_ids)
assert torch.allclose(output.logits, loaded_output.logits, atol=1e-2), "Compare failed!"

def test_rtn(self):
quantization_config = RtnConfig(bits=4)
user_model = INCModelForCausalLM.from_pretrained(self.model_name, quantization_config=quantization_config)
output = user_model(self.input_ids)
user_model.save_pretrained("tmp_rtn")
loaded_model = INCModelForCausalLM.from_pretrained("tmp_rtn")
loaded_output = loaded_model(self.input_ids)
assert torch.allclose(output.logits, loaded_output.logits, atol=1e-2), "Compare failed!"