2 changes: 1 addition & 1 deletion neural_compressor/torch/algorithms/pt2e_quant/__init__.py
@@ -14,4 +14,4 @@


from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
from .save_load import save, load
from .save_load import save, load
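
With the re-export above, the PT2E serialization helpers sit next to the quantizer at the package level. A minimal import sketch, using only names visible in this file:

from neural_compressor.torch.algorithms.pt2e_quant import W8A8PT2EQuantizer, load, save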
2 changes: 1 addition & 1 deletion neural_compressor/torch/quantization/load_entry.py
@@ -84,7 +84,7 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"):
from neural_compressor.torch.algorithms import static_quant

return static_quant.load(model_name_or_path)
elif "static_quant" in per_op_qconfig.keys() or "pt2e_dynamic_quant" in per_op_qconfig.keys(): # PT2E
elif "static_quant" in per_op_qconfig.keys() or "pt2e_dynamic_quant" in per_op_qconfig.keys(): # PT2E
from neural_compressor.torch.algorithms import pt2e_quant

return pt2e_quant.load(model_name_or_path)
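For context, the branch above is part of load()'s checkpoint dispatch: the per-op qconfig recorded at save time decides which algorithm's loader handles the saved results. A condensed sketch of that pattern, assuming per_op_qconfig has already been read from the checkpoint; the helper name and the first branch's condition are illustrative placeholders, and only the PT2E check is taken from the diff:

def _route_load(model_name_or_path, per_op_qconfig, is_ipex_static_quant):
    # is_ipex_static_quant stands in for the condition cut off above this hunk.
    if is_ipex_static_quant:
        from neural_compressor.torch.algorithms import static_quant
        return static_quant.load(model_name_or_path)
    elif "static_quant" in per_op_qconfig.keys() or "pt2e_dynamic_quant" in per_op_qconfig.keys():  # PT2E
        from neural_compressor.torch.algorithms import pt2e_quant
        return pt2e_quant.load(model_name_or_path)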
13 changes: 9 additions & 4 deletions test/3x/torch/quantization/test_pt2e_quant.py
@@ -1,5 +1,6 @@
import pytest
import shutil

import pytest
import torch
import torch.testing._internal.common_quantization as torch_test_quant_common

@@ -113,14 +114,18 @@ def calib_fn(model):
config.freezing = True
q_model_out = q_model(*example_inputs)
assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!"

# test save and load
q_model.save(example_inputs=example_inputs, output_dir="./saved_results",)
q_model.save(
example_inputs=example_inputs,
output_dir="./saved_results",
)
from neural_compressor.torch.quantization import load

loaded_quantized_model = load("./saved_results")
loaded_q_model_out = loaded_quantized_model(*example_inputs)
assert torch.allclose(loaded_q_model_out, q_model_out)

opt_model = torch.compile(q_model)
out = opt_model(*example_inputs)
logger.warning("out shape is %s", out.shape)
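
The newly added import shutil suggests the test module cleans up the ./saved_results directory that q_model.save() writes. A hedged sketch of such a teardown; the fixture itself is an assumption for illustration and is not part of this diff:

import shutil

import pytest


@pytest.fixture(autouse=True)
def _clean_saved_results():
    # Hypothetical cleanup: drop the directory written by q_model.save()
    # once a test finishes; ignore_errors covers tests that never save.
    yield
    shutil.rmtree("./saved_results", ignore_errors=True)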