adapt fx
Signed-off-by: changwangss <[email protected]>
changwangss committed Aug 30, 2024
commit 20f5152066f09a274e0acd8c3f01840b44598160
105 changes: 42 additions & 63 deletions neural_compressor/transformers/models/modeling_auto.py
@@ -35,6 +35,7 @@
import types

from accelerate import init_empty_weights
+from accelerate.utils import is_xpu_available

from neural_compressor.adaptor.torch_utils.util import set_module
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
@@ -168,7 +169,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
revision = kwargs.pop("revision", "main")
commit_hash = kwargs.pop("_commit_hash", None)
_fast_init = kwargs.pop("_fast_init", True)
-device_map = kwargs.pop("device_map", "auto")
+device_map = kwargs.pop("device_map", "xpu" if is_xpu_available() else "cpu")
use_safetensors = kwargs.pop("use_safetensors", None)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
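
Note on the hunk above: the default device_map no longer falls back to "auto" but follows hardware availability. A minimal sketch of the new default, assuming only that accelerate is installed (is_xpu_available is the helper imported at the top of this diff):

from accelerate.utils import is_xpu_available

def default_device_map() -> str:
    # Prefer Intel XPU when the runtime can see one, otherwise fall back to CPU.
    return "xpu" if is_xpu_available() else "cpu"

print(default_device_map())  # e.g. "cpu" on a machine without an XPU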

@@ -210,6 +211,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
quantization_config = RtnConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "gptq":
quantization_config = GPTQConfig.from_dict(quantization_config)

assert quantization_config is not None, "Detect this model is not a low-bit model."

if commit_hash is None:
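
For orientation, a hedged sketch of the re-hydration step above: the quant_method stored in the checkpoint's quantization_config selects which config class rebuilds it. The import path is assumed, and the dict is a trimmed illustration rather than a complete saved config:

from neural_compressor.transformers import GPTQConfig, RtnConfig  # assumed export location

cfg_dict = {"quant_method": "rtn", "bits": 4, "group_size": 128}  # illustrative subset of config.json
quantization_config = None
if cfg_dict["quant_method"] == "rtn":
    quantization_config = RtnConfig.from_dict(cfg_dict)
elif cfg_dict["quant_method"] == "gptq":
    quantization_config = GPTQConfig.from_dict(cfg_dict)
assert quantization_config is not None, "Detect this model is not a low-bit model."
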
Expand Down Expand Up @@ -501,47 +503,27 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
logger.warning("fp32 scale_dtype is used, please change the config.json if you don't want to use it.")

# weight dtype is higher priority than bits in config.json when both existed.
-if quantization_config.weight_dtype is None:
-    if quantization_config.bits == 4:
-        if use_xpu:
-            quantization_config.weight_dtype = "int4_fullrange"
-        else:
-            quantization_config.weight_dtype = "int4"
-        logger.info(
-            "{} quantization weight_dtype is used due to bits is 4 in config.json.".format(
-                quantization_config.weight_dtype
-            )
-        )
-    elif quantization_config.bits == 8:
-        quantization_config.weight_dtype = "int8"
-        logger.info(
-            "{} quantization weight_dtype is used due to bits is 8 in config.json.".format(
-                quantization_config.weight_dtype
-            )
-        )
-    else:
-        logger.warning("bits number only supports 4, 8.")
-        quantization_config.weight_dtype = "int4"
-        logger.warning("int4 weight_dtype is used, please change the config.json if you don't want to use it.")
-else:
-    if quantization_config.weight_dtype not in [
-        "int4_fullrange",
-        "int4",
-        "int8",
-        "fp8_e5m2",
-        "fp8_e4m3",
-        "nf4",
-        "fp4_e2m1_bnb",
-        "fp4_e2m1",
-    ]:
-        logger.warning("Please provide the correct bits number or weight_dtype in config.json.")
-        raise ValueError(
-            "weight_dtype must be a string in "
-            "'int8', 'int4', 'int4_fullrange', 'int4', 'nf4', "
-            "'fp4', 'fp4_e2m1', 'fp8', 'fp8_e5m2, fp8_e4m3'"
-        )
-    else:
-        logger.info("{} quantization weight_dtype is used.".format(quantization_config.weight_dtype))
+if quantization_config.bits == 4:
+    if use_xpu:
+        quantization_config.weight_dtype = "int4_fullrange"
+    else:
+        quantization_config.weight_dtype = "int4"
+    logger.info(
+        "{} quantization weight_dtype is used due to bits is 4 in config.json.".format(
+            quantization_config.weight_dtype
+        )
+    )
+elif quantization_config.bits == 8:
+    quantization_config.weight_dtype = "int8"
+    logger.info(
+        "{} quantization weight_dtype is used due to bits is 8 in config.json.".format(
+            quantization_config.weight_dtype
+        )
+    )
+else:
+    logger.warning("bits number only supports 4, 8.")
+    quantization_config.weight_dtype = "int4"
+    logger.warning("int4 weight_dtype is used, please change the config.json if you don't want to use it.")

init_contexts = [no_init_weights(_enable=_fast_init)]
init_contexts.append(init_empty_weights())
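
The rewritten block above drops the old weight_dtype validation branch and derives the dtype purely from bits. A compact restatement of the new behavior as a hypothetical helper (not part of the patch):

def resolve_weight_dtype(bits: int, use_xpu: bool) -> str:
    # bits drives the choice; XPU gets the full-range int4 variant, and any
    # unsupported bits value falls back to int4 (the real code also logs a warning).
    if bits == 4:
        return "int4_fullrange" if use_xpu else "int4"
    if bits == 8:
        return "int8"
    return "int4"

assert resolve_weight_dtype(4, use_xpu=True) == "int4_fullrange"
assert resolve_weight_dtype(8, use_xpu=False) == "int8"
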
@@ -561,30 +543,27 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
-    (
-        model,
-        missing_keys,
-        unexpected_keys,
-        mismatched_keys,
-        offload_index,
-        error_msgs,
-    ) = model_class._load_pretrained_model(
-        model,
-        None,
-        loaded_state_dict_keys,  # XXX: rename?
-        resolved_archive_file,
-        pretrained_model_name_or_path,
-        sharded_metadata=sharded_metadata,
-        _fast_init=_fast_init,
-        low_cpu_mem_usage=True,
-        offload_folder=offload_folder,
-        offload_state_dict=offload_state_dict,
-        dtype=torch_dtype,
-        keep_in_fp32_modules=[],
-    )
-
-else:
-    raise AssertionError("Please install intel_extension_for_pytorch.")
+(
+    model,
+    missing_keys,
+    unexpected_keys,
+    mismatched_keys,
+    offload_index,
+    error_msgs,
+) = model_class._load_pretrained_model(
+    model,
+    None,
+    loaded_state_dict_keys,  # XXX: rename?
+    resolved_archive_file,
+    pretrained_model_name_or_path,
+    sharded_metadata=sharded_metadata,
+    _fast_init=_fast_init,
+    low_cpu_mem_usage=True,
+    offload_folder=offload_folder,
+    offload_state_dict=offload_state_dict,
+    dtype=torch_dtype,
+    keep_in_fp32_modules=[],
+)

# make sure token embedding weights are still tied if needed
model.tie_weights()
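
With the AssertionError branch gone, loading no longer hard-requires intel_extension_for_pytorch. A hedged reload sketch for context; the exported wrapper class and the directory name are assumptions, not taken from this patch:

from neural_compressor.transformers import AutoModelForCausalLM  # assumed export location

# "./saved_low_bit_model" is a hypothetical directory produced by save_low_bit().
model = AutoModelForCausalLM.load_low_bit("./saved_low_bit_model")
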
21 changes: 15 additions & 6 deletions neural_compressor/transformers/quantization/utils.py
@@ -229,7 +229,6 @@ def _replace_linear(
model._modules[name].requires_grad_(False)

if device == "xpu" or device == torch.device("xpu"):

if not hasattr(module, "qweight"):
n_pack = 32 // quantization_config.bits

@@ -343,10 +342,9 @@ def convert_to_quantized_model(model, config, device="cpu"):
break

# mapping to INC config
dtype = "int4" if config.weight_dtype == "int4_fullrange" else config.weight_dtype
if config.quant_method.value == "rtn":
quant_config = RTNConfig(
dtype=config.weight_dtype, bits=config.bits, use_sym=config.sym, group_size=config.group_size
)
quant_config = RTNConfig(dtype=dtype, bits=config.bits, use_sym=config.sym, group_size=config.group_size)
if config.use_layer_wise:
quant_config.user_layer_wise = config.use_layer_wise
quant_config.model_path = config.model_path
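
The new dtype variable normalizes the transformers-level name before it reaches INC's algorithm config. A small sketch, assuming RTNConfig is the class imported in this module (its keyword arguments are the ones used above):

from neural_compressor.torch.quantization import RTNConfig  # assumed import path

def to_inc_rtn_config(weight_dtype: str, bits: int, sym: bool, group_size: int) -> RTNConfig:
    # "int4_fullrange" is accepted at the transformers API surface, but INC's RTN
    # algorithm expects plain "int4", so it is mapped down before building the config.
    dtype = "int4" if weight_dtype == "int4_fullrange" else weight_dtype
    return RTNConfig(dtype=dtype, bits=bits, use_sym=sym, group_size=group_size)
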
@@ -360,7 +358,7 @@ def convert_to_quantized_model(model, config, device="cpu"):
elif config.quant_method.value == "gptq":
model.seqlen = config.seq_len
quant_config = GPTQConfig(
-dtype=config.weight_dtype,
+dtype=dtype,
bits=config.bits,
use_sym=config.sym,
group_size=config.group_size,
@@ -399,6 +397,7 @@ def convert_to_quantized_model(model, config, device="cpu"):
logger.warning("The recommended ipex version is higher than 2.3.10 for xpu device.")

model.eval()

q_model = replace_linear(model, None, None, config, device=device)

if orig_dtype != torch.float32:
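
For context, a hedged end-to-end sketch of the transformers-style weight-only flow these utilities back. The class and config names are assumed to be exported from neural_compressor.transformers, and the model id and output directory are placeholders:

from neural_compressor.transformers import AutoModelForCausalLM, RtnConfig  # assumed export location

woq_config = RtnConfig(bits=4, group_size=128)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=woq_config)
model.save_pretrained("./opt-125m-woq")  # routed through the patched save_low_bit in the next hunk
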
@@ -477,8 +476,18 @@ def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: boo
# use transformers original `save_pretrained` function
del self.save_pretrained

if self.device == "cpu" or self.device == torch.device("cpu") or self.device == "auto":
if self.device == "cpu" or self.device == torch.device("cpu"):
convert_to_GPTQ_checkpoints(self, self.quantization_config)
if self.device == "xpu" or (isinstance(self.device, torch.device) and self.device.type == "xpu"):
from intel_extension_for_pytorch.nn.utils._quantize_convert import WeightOnlyQuantizedLinear

for name, module in self.named_modules():
if isinstance(module, WeightOnlyQuantizedLinear):
if module.weight_transposed:
module.qweight.data = module.qweight.t_().contiguous()
module.scales.data = module.scales.t_().contiguous()
module.weight_transposed = False

self.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
self.save_pretrained = types.MethodType(save_low_bit, self)
# We conveniently save all the keys of the model to have them on hand,
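
A sketch of the XPU-specific normalization added above, factored into a helper; it mirrors the loop in the diff and assumes the same IPEX module path shown there:

def untranspose_woq_linears(model):
    # IPEX's WeightOnlyQuantizedLinear can hold qweight/scales in transposed form;
    # flip them back and clear the flag before serializing, as the patch above does.
    from intel_extension_for_pytorch.nn.utils._quantize_convert import WeightOnlyQuantizedLinear

    for _, module in model.named_modules():
        if isinstance(module, WeightOnlyQuantizedLinear) and module.weight_transposed:
            module.qweight.data = module.qweight.t_().contiguous()
            module.scales.data = module.scales.t_().contiguous()
            module.weight_transposed = False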