Merged
Changes from 1 commit (of the 23 commits in this pull request). Commit history:
0658a83  add inc woq and remove itrex dependency (changwangss, Aug 27, 2024)
4955b8a  Update optimum/intel/neural_compressor/modeling_base.py (changwangss, Aug 29, 2024)
7fe5ac5  Update optimum/intel/neural_compressor/modeling_base.py (changwangss, Aug 29, 2024)
1d6797c  Update optimum/intel/neural_compressor/modeling_base.py (changwangss, Aug 29, 2024)
ab178e9  Update optimum/intel/neural_compressor/modeling_base.py (changwangss, Aug 29, 2024)
c078ca2  fix code according comment (changwangss, Aug 29, 2024)
c257101  add logger setting (changwangss, Aug 29, 2024)
d55004b  improve ut (changwangss, Aug 29, 2024)
fcadbac  move woq quantization to quantization.py (changwangss, Sep 5, 2024)
8cf22de  Update examples/neural_compressor/language-modeling/run_clm.py (changwangss, Sep 5, 2024)
a31fc6a  Update examples/neural_compressor/language-modeling/run_clm.py (changwangss, Sep 5, 2024)
3b5f228  remove dependency (changwangss, Sep 5, 2024)
7f8c2a2  Update examples/neural_compressor/language-modeling/run_clm.py (IlyasMoutawwakil, Sep 5, 2024)
6eba7c4  add woq saving and loading ut and logger info (changwangss, Sep 5, 2024)
2683608  Merge branch 'main' into wangchang/inc_woq (changwangss, Sep 5, 2024)
1401c89  set transformers version limit (changwangss, Sep 5, 2024)
bc3b95a  fix installation neural_compressor[pt] (changwangss, Sep 6, 2024)
99f797d  improve ut (changwangss, Sep 6, 2024)
8321a24  refactoring (echarlaix, Sep 6, 2024)
08091bc  Refactor (echarlaix, Sep 6, 2024)
09acbd9  revert (echarlaix, Sep 6, 2024)
28a10d9  fix datasets loading issue (changwangss, Sep 9, 2024)
1ad67f1  fix (echarlaix, Sep 9, 2024)
add woq saving and loading ut and logger info
Signed-off-by: changwangss <[email protected]>
changwangss committed Sep 5, 2024
commit 6eba7c4ef94b8d930a4d17cbf38a1ea1442a047d
6 changes: 6 additions & 0 deletions optimum/intel/neural_compressor/modeling_base.py
@@ -143,6 +143,9 @@ def _from_pretrained(
                "Weight-only quantization model loading provided by intel_extension_for_transformers is deprecated; it is now provided by INC.",
                DeprecationWarning,
            )
+            logger.info(
+                "Loading a weight-only quantized model supports only the same format as GPTQ, e.g. https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ/tree/main."
+            )
            _BaseINCAutoModelClass.ORIG_MODEL = cls.auto_model_class
            model = _BaseINCAutoModelClass.load_low_bit(
                model_id,
@@ -165,6 +168,9 @@ def _from_pretrained(
                "Weight-only quantization provided by intel_extension_for_transformers is deprecated; it is now provided by INC.",
                DeprecationWarning,
            )
+            logger.info(
+                "The quantized model parameters will be saved in the same format as GPTQ; see https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ/tree/main for a sample model."
+            )
            model = weight_only_quantization(
                cls.auto_model_class,
                model_id,
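For downstream users, the effect of this branch is that a checkpoint stored in the GPTQ format can be loaded directly through `INCModelForCausalLM`. A minimal sketch of that flow (the checkpoint id is illustrative, and it assumes an environment with the neural-compressor extra installed):

```python
from transformers import AutoTokenizer
from optimum.intel import INCModelForCausalLM

# Illustrative GPTQ-format checkpoint; any checkpoint in the same format should load the same way.
model_id = "TheBloke/Llama-2-7B-Chat-GPTQ"
model = INCModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```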
2 changes: 1 addition & 1 deletion setup.py
@@ -59,7 +59,7 @@
QUALITY_REQUIRE = ["black~=23.1", "ruff==0.4.4"]

EXTRAS_REQUIRE = {
-    "neural-compressor": ["neural-compressor>3.0", "accelerate", "transformers"],
+    "neural-compressor": ["neural-compressor>3.0", "accelerate", "transformers<4.43"],
    "openvino": ["openvino>=2023.3", "nncf>=2.11.0", "openvino-tokenizers[transformers]"],
    "nncf": ["nncf>=2.11.0"],
    "ipex": ["intel-extension-for-pytorch", "transformers>=4.39.0,<4.44.0"],
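The tightened pin follows the earlier commit "set transformers version limit", which suggests the INC weight-only path had not been validated against transformers 4.43 at the time. A defensive check of the same constraint might look like this sketch (not part of the PR; `packaging` is assumed available):

```python
# Illustrative guard mirroring the new "transformers<4.43" pin; not part of this PR.
import transformers
from packaging import version

if version.parse(transformers.__version__) >= version.parse("4.43"):
    raise RuntimeError(
        f"transformers {transformers.__version__} exceeds the <4.43 pin "
        "declared by the neural-compressor extra."
    )
```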
29 changes: 28 additions & 1 deletion tests/neural_compressor/test_modeling.py
@@ -21,6 +21,7 @@
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, pipeline, set_seed
+from transformers.utils import SAFE_WEIGHTS_NAME

from optimum.exporters import TasksManager
from optimum.intel import ( # noqa
@@ -38,7 +39,7 @@
    INCStableDiffusionPipeline,
    INCTrainer,
)
-from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, WEIGHTS_NAME
+from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, QUANTIZATION_CONFIG_NAME, WEIGHTS_NAME


os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -145,6 +146,32 @@ def test_compare_with_and_without_past_key_values(self):
        self.assertEqual(outputs_without_pkv.shape[1], self.GENERATION_LENGTH)
        self.assertTrue(torch.equal(outputs_with_pkv, outputs_without_pkv))

+    def test_saving_loading_inc_woq_model(self):
+        model_name = "TheBlokeAI/Mixtral-tiny-GPTQ"
+        subfolder = "inc"
+        model = INCModelForCausalLM.from_pretrained(model_name, revision="inc", subfolder=subfolder)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, revision="inc")
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        tokens = tokenizer("This is a sample output", return_tensors="pt")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_save_dir = Path(tmp_dir) / subfolder
+            model.save_pretrained(model_save_dir)
+            folder_contents = os.listdir(model_save_dir)
+            self.assertIn(SAFE_WEIGHTS_NAME, folder_contents)
+            self.assertIn(QUANTIZATION_CONFIG_NAME, folder_contents)
+            loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir, subfolder=subfolder)
+
+        with torch.no_grad():
+            outputs = model(**tokens)
+            loaded_outputs = loaded_model(**tokens)
+
+        self.assertTrue("logits" in loaded_outputs)
+        self.assertIsInstance(loaded_outputs.logits, torch.Tensor)
+        self.assertTrue("past_key_values" in loaded_outputs)
+        self.assertIsInstance(loaded_outputs.past_key_values, tuple)
+        self.assertTrue(torch.allclose(outputs.logits, loaded_outputs.logits, atol=1e-5))
+
    def test_saving_loading_inc_model(self):
        model_name = "echarlaix/tiny-random-PhiForCausalLM"
        subfolder = "inc"
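Outside the test suite, the round trip exercised by `test_saving_loading_inc_woq_model` would look roughly like the following sketch (the checkpoint id mirrors the test, and the local save path is illustrative):

```python
import torch
from transformers import AutoTokenizer
from optimum.intel import INCModelForCausalLM

# Checkpoint id taken from the test above; the revision/subfolder layout is specific to it.
model_id = "TheBlokeAI/Mixtral-tiny-GPTQ"
model = INCModelForCausalLM.from_pretrained(model_id, revision="inc", subfolder="inc")
tokenizer = AutoTokenizer.from_pretrained(model_id, revision="inc")

save_dir = "./woq_model"  # illustrative local path
model.save_pretrained(save_dir)  # writes the safetensors weights plus the quantization config checked above
reloaded = INCModelForCausalLM.from_pretrained(save_dir)

inputs = tokenizer("This is a sample output", return_tensors="pt")
with torch.no_grad():
    assert torch.allclose(model(**inputs).logits, reloaded(**inputs).logits, atol=1e-5)
```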