Merged
Changes from 1 commit
29 commits
db16753
enable auto_round format export
WeiweiZhang1 Sep 12, 2024
1eceb6d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
26fe175
Update auto_round dependency to commit 5dd16fc34a974a8c2f5a4288ce72e6…
XuehaoSun Sep 12, 2024
2e67cd5
fix docscan issues
WeiweiZhang1 Sep 12, 2024
b99140c
Merge branch 'enable_autoround_format_quantization' of https://github…
WeiweiZhang1 Sep 12, 2024
a7d1431
fixtypos
WeiweiZhang1 Sep 12, 2024
8e78efc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
0adc4ef
fix self.quantization_config
Kaihui-intel Sep 12, 2024
73d8c2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
dc49120
Merge branch 'master' into enable_autoround_format_quantization
xin3he Sep 13, 2024
27b4f43
rm ar ut
Kaihui-intel Sep 13, 2024
46f3c76
fixtypos
WeiweiZhang1 Sep 13, 2024
28e4878
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2024
8bb25c9
Merge branch 'enable_autoround_format_quantization' of https://github…
Kaihui-intel Sep 13, 2024
c744130
revert ar ut
WeiweiZhang1 Sep 14, 2024
39d66e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
79f44f4
refine UT
WeiweiZhang1 Sep 14, 2024
16a296e
refine UT
WeiweiZhang1 Sep 14, 2024
91f7985
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
01136d7
fix unit test
XuehaoSun Sep 14, 2024
07ae762
against code coverage issue
WeiweiZhang1 Sep 14, 2024
d3c3f39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
461379a
fixtypo
WeiweiZhang1 Sep 14, 2024
7fbf186
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
41bfca5
fixtypo
WeiweiZhang1 Sep 14, 2024
7a72f52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
f3bf7fb
fixtypo
WeiweiZhang1 Sep 14, 2024
a280b10
fixtypo
WeiweiZhang1 Sep 14, 2024
7f41ff0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2024
enable auto_round format export
Signed-off-by: Zhang, Weiwei1 <[email protected]>
WeiweiZhang1 committed Sep 12, 2024
commit db16753cc651507553103cbd9f9be12764bd5241
11 changes: 9 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -61,6 +61,7 @@ def __init__(
act_sym: bool = None,
act_dynamic: bool = True,
low_cpu_mem_usage: bool = False,
export_format: str = "itrex",
**kwargs,
):
"""Init a AutQRoundQuantizer object.
@@ -152,7 +153,7 @@ def __init__(
self.act_sym = act_sym
self.act_dynamic = act_dynamic
self.low_cpu_mem_usage = low_cpu_mem_usage

self.export_format = export_format
def prepare(self, model: torch.nn.Module, *args, **kwargs):
"""Prepares a given model for quantization.

@@ -211,7 +212,11 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
)
model, weight_config = rounder.quantize()
model.autoround_config = weight_config
model = pack_model(model, weight_config, device=self.device, inplace=True)
if 'itrex' in self.export_format:
model = pack_model(model, weight_config, device=self.device, inplace=True)
else:
model = rounder.save_quantized(output_dir=None, format=self.export_format, device=self.device, inplace=True)

return model


@@ -238,3 +243,5 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42
tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, nsamples=nsamples
)
return dataloader


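The new export_format argument defaults to "itrex", which keeps the existing pack_model() path; any other value (for example "auto_round" or "auto_round:gptq") makes convert() delegate packing to rounder.save_quantized(). Below is a minimal sketch of the intended flow through the high-level prepare/convert API, assuming get_dataloader is imported from this module; the checkpoint name and calibration sizes are illustrative only.

import transformers
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare
from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"  # illustrative checkpoint
fp32_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Calibration data built with the helper defined at the end of this file.
dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=10)

# export_format="auto_round:gptq" routes convert() through rounder.save_quantized()
# instead of the default itrex pack_model() path.
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32",
                               export_format="auto_round:gptq")
prepared = prepare(model=fp32_model, quant_config=quant_config)
for data in dataloader:  # calibration forward passes
    if isinstance(data, dict):
        prepared(**data)
    else:
        prepared(data)
q_model = convert(prepared)

With the default export_format="itrex" the same flow still produces the INC weight-only linear modules, so existing callers are unaffected.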
30 changes: 24 additions & 6 deletions neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -40,14 +40,27 @@
device_woqlinear_mapping = {"cpu": INCWeightOnlyLinear, "hpu": HPUWeightOnlyLinear}


def save(model, output_dir="./saved_results"):
def save(model, output_dir="./saved_results", format="default", **kwargs):
"""Save the quantized model and config to the output path.

Args:
model (torch.nn.module): raw fp32 model or prepared model.
output_dir (str, optional): output path to save.
"""
os.makedirs(output_dir, exist_ok=True)
if format == "huggingface":
config = model.config
quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
if "backend" in quantization_config and 'auto_round' in quantization_config['backend']:
safe_serialization = kwargs.get("safe_serialization", True)
tokenizer = kwargs.get("tokenizer", None)
max_shard_size = kwargs.get("max_shard_size", "5GB")
if tokenizer is not None:
tokenizer.save_pretrained(output_dir)
del model.save
model.save_pretrained(output_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
return

qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
# saving process
@@ -122,7 +135,7 @@ def load_woq_model(self):
"""
if self.format == LoadFormat.HUGGINGFACE:
assert self.model_name_or_path is not None, "'model_name_or_path' can't be None."

model = self.load_hf_format_woq_model()
logger.info("Loading HuggingFace weight-only quantization model successfully.")
elif self.format == LoadFormat.DEFAULT:
@@ -195,16 +208,20 @@ def load_hf_format_woq_model(self):
"""
# check required package
from neural_compressor.torch.utils import is_package_available

if not is_package_available("transformers"):
raise ImportError("Loading huggingface model requires transformers: `pip install transformers`")
if not is_package_available("accelerate"):
raise ImportError("Loading huggingface model requires accelerate: `pip install accelerate`")

# get model class and config
model_class, config = self._get_model_class_and_config()
self.quantization_config = config.quantization_config

quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
if "backend" in quantization_config and 'auto_round' in quantization_config['backend']:
# load autoround format quantized model
from auto_round import AutoRoundConfig
model = model_class.from_pretrained(self.model_name_or_path)
return model
# get loaded state_dict
self.loaded_state_dict = self._get_loaded_state_dict(config)
self.loaded_state_dict_keys = list(set(self.loaded_state_dict.keys()))
@@ -400,7 +417,7 @@ def _get_model_class_and_config(self):
trust_remote_code = self.kwargs.pop("trust_remote_code", None)
kwarg_attn_imp = self.kwargs.pop("attn_implementation", None)

config = AutoConfig.from_pretrained(self.model_name_or_path)
config = AutoConfig.from_pretrained(self.model_name_or_path, trust_remote_code=trust_remote_code)
# quantization_config = config.quantization_config

if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: # pragma: no cover
@@ -866,3 +883,4 @@ def _use_hpu_module(self): # pragma: no cover
if os.path.exists(os.path.join(self._model_local_dir, HPU_WEIGHT_NAME)):
return True
return False

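With the huggingface branch above, a model whose quantization_config carries an auto_round backend is written out via save_pretrained()-style sharding, optionally together with its tokenizer, while every other format still falls through to the default WEIGHT_NAME/QCONFIG_NAME serialization. A minimal round-trip sketch, assuming q_model and tokenizer come from an AutoRound conversion like the one sketched earlier; the output path and shard size are illustrative.

from neural_compressor.torch.algorithms.weight_only.save_load import save
from neural_compressor.torch.quantization import load

output_dir = "saved_results_autoround"  # illustrative path

# format="huggingface" exports the tokenizer (if given) and shards the checkpoint via
# model.save_pretrained(); safe_serialization and max_shard_size are forwarded as-is.
save(q_model, output_dir=output_dir, format="huggingface",
     tokenizer=tokenizer, safe_serialization=True, max_shard_size="5GB")

# Loading back goes through load_hf_format_woq_model(); for an auto_round backend it
# imports auto_round's AutoRoundConfig and defers to model_class.from_pretrained().
loaded_model = load(output_dir, format="huggingface", trust_remote_code=True)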
3 changes: 3 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -609,6 +609,7 @@ def autoround_quantize_entry(
scale_dtype = quant_config.scale_dtype
quant_block_list = quant_config.quant_block_list
low_cpu_mem_usage = quant_config.use_layer_wise
export_format = quant_config.export_format

kwargs.pop("example_inputs")

@@ -636,6 +637,7 @@ def autoround_quantize_entry(
scale_dtype=scale_dtype,
quant_block_list=quant_block_list,
low_cpu_mem_usage=low_cpu_mem_usage,
export_format=export_format,
)
model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
model.qconfig = configs_mapping
@@ -752,3 +754,4 @@ def mixed_precision_entry(
mixed_precision_model = half_precision_converter.convert(model)

return mixed_precision_model

3 changes: 3 additions & 0 deletions neural_compressor/torch/quantization/config.py
@@ -939,6 +939,7 @@ def __init__(
scale_dtype: str = "fp16",
use_layer_wise: bool = False,
quant_block_list: list = None,
export_format: str = "itrex",
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init AUTOROUND weight-only quantization config.
@@ -1005,6 +1006,7 @@ def __init__(
self.scale_dtype = scale_dtype
self.use_layer_wise = use_layer_wise
self.quant_block_list = quant_block_list
self.export_format = export_format
self._post_init()

@classmethod
@@ -2058,3 +2060,4 @@ def get_woq_tuning_config() -> list:
GPTQ_G32ASYM = GPTQConfig(use_sym=False, group_size=32)
AWQ_G32ASYM = AWQConfig(use_sym=False, group_size=32)
return [RTN_G32ASYM, AUTO_ROUND_CONFIG, GPTQ_G32ASYM, AWQ_G32ASYM]

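On the config side, export_format is simply stored with the other AutoRound options and forwarded by autoround_quantize_entry(); the default "itrex" preserves the previous packing behavior. A small sketch of both settings follows; the non-default value is the one used by the new unit test, and the full set of accepted format strings is defined by auto_round rather than by this config class.

from neural_compressor.torch.quantization import AutoRoundConfig

# Default: keep the existing itrex packing into INC weight-only linear modules.
cfg_default = AutoRoundConfig()
assert cfg_default.export_format == "itrex"

# Native auto_round export in a GPTQ-compatible layout, as exercised by
# test_autoround_format_export in the test changes below.
cfg_autoround = AutoRoundConfig(nsamples=32, seqlen=10, iters=10,
                                scale_dtype="fp32", export_format="auto_round:gptq")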
27 changes: 24 additions & 3 deletions test/3x/torch/quantization/weight_only/test_autoround.py
@@ -1,11 +1,9 @@
import copy
import shutil

import pytest
import torch
import transformers
from packaging.version import Version

from neural_compressor.torch.quantization import (
AutoRoundConfig,
convert,
@@ -21,8 +19,11 @@
try:
import auto_round
from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear
from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear


auto_round_installed = True
auto_gptq_installed = True
except ImportError:
auto_round_installed = False

@@ -40,6 +41,7 @@ def run_fn(model, dataloader):

@pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
class TestAutoRound:
@classmethod
def setup_class(self):
self.gptj = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
@@ -51,7 +53,8 @@ def setup_class(self):
)
self.dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=10)
self.label = self.gptj(self.inp)[0]


@classmethod
def teardown_class(self):
shutil.rmtree("saved_results", ignore_errors=True)

@@ -143,6 +146,7 @@ def test_save_and_load(self):
loaded_model.transformer.h[0].attn.k_proj, INCWeightOnlyLinear
), "loading compressed model failed."


def test_conv1d(self):
input = torch.randn(1, 32)
from transformers import GPT2Model, GPT2Tokenizer
@@ -159,3 +163,20 @@
out2 = q_model(**encoded_input)[0]
assert torch.allclose(out2, out1, atol=0.01), "Accuracy gap atol > 0.01 is unexpected."
assert isinstance(q_model.h[0].attn.c_attn, WeightOnlyLinear), "loading compressed model failed."


@pytest.mark.skipif(not auto_gptq_installed, reason="auto_gptq module is not installed")
def test_autoround_format_export(self):
from neural_compressor.torch.quantization import load
gpt_j_model = copy.deepcopy(self.gptj)
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32", export_format="auto_round:gptq")
logger.info(f"Test AutoRound with config {quant_config}")
model = prepare(model=gpt_j_model, quant_config=quant_config)
run_fn(model, self.dataloader)
q_model = convert(model)
out = q_model(self.inp)[0]
assert torch.allclose(out, self.label, atol=1e-1)
assert isinstance(q_model.transformer.h[0].attn.k_proj, QuantLinear), "packing model failed."
q_model.save(output_dir="saved_results_tiny-random-GPTJForCausalLM", format="huggingface")
loaded_model = load("saved_results_tiny-random-GPTJForCausalLM", format="huggingface", trust_remote_code=True)
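
For completeness, the exported checkpoint can presumably also be consumed directly through transformers, mirroring the auto_round branch added in load_hf_format_woq_model() above; importing AutoRoundConfig from auto_round appears to be what makes from_pretrained() aware of the auto_round quantization backend. A minimal sketch, with the path taken from the test above:

import transformers
from auto_round import AutoRoundConfig  # noqa: F401  (registers the auto_round quantization backend)

loaded = transformers.AutoModelForCausalLM.from_pretrained(
    "saved_results_tiny-random-GPTJForCausalLM", trust_remote_code=True
)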