add save and load
Signed-off-by: gta <[email protected]>
gta authored and gta committed Jul 12, 2024
commit 66506623934ab6349729a99a3ede664aad2a7061
14 changes: 11 additions & 3 deletions neural_compressor/torch/algorithms/static_quant/save_load.py
@@ -32,9 +32,17 @@ def save(model, output_dir="./saved_results"):

     qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
-    model.ori_save(qmodel_file_path)
-    with open(qconfig_file_path, "w") as f:
-        json.dump(model.tune_cfg, f, indent=4)
+    if next(model.parameters()).device.type == "cpu":
+        model.ori_save(qmodel_file_path)
+        with open(qconfig_file_path, "w") as f:
+            json.dump(model.tune_cfg, f, indent=4)
+    else:
+        from neural_compressor.common.utils import save_config_mapping
+
+        save_config_mapping(model.qconfig, qconfig_file_path)
+        # the bound method 'save' is not part of the state_dict, so drop it before saving
+        del model.save
+        torch.save(model.state_dict(), qmodel_file_path)
 
     logger.info("Save quantized model to {}.".format(qmodel_file_path))
     logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))
11 changes: 5 additions & 6 deletions neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -58,6 +58,7 @@ def __init__(self, quant_config: OrderedDict = {}):
"""
super().__init__(quant_config)
self.user_cfg = OrderedDict()
self.device = auto_detect_accelerator().current_device()

def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
"""Prepares a given model for quantization.
@@ -70,10 +71,9 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         Returns:
             A prepared model.
         """
-        device = auto_detect_accelerator().current_device()
         assert example_inputs is not None, "Please provide example_inputs for static quantization."
 
-        if device == "cpu":
+        if self.device == "cpu":
             _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
                 model, example_inputs
             )
@@ -89,7 +89,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):  # pragma: no cover
             from torch.ao.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, QConfig
 
-            if device != "cpu":  # pragma: no cover
+            if self.device != "cpu":  # pragma: no cover
                 from torch.quantization.quantize_jit import prepare_jit
 
                 with torch.no_grad():
@@ -125,7 +125,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
                         model, static_qconfig, example_inputs=example_inputs, inplace=inplace
                     )
 
-        if device == "cpu":
+        if self.device == "cpu":
             model.load_qconf_summary(qconf_summary=ipex_config_path)
 
         return model
@@ -141,12 +141,11 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
         Returns:
             A quantized model.
         """
-        device = auto_detect_accelerator().current_device()
         use_bf16 = self.quant_config.get("use_bf16", None)
 
         from neural_compressor.torch.algorithms.static_quant import save
 
-        if device != "cpu":  # pragma: no cover
+        if self.device != "cpu":  # pragma: no cover
             from torch.quantization.quantize_jit import convert_jit
 
             model = convert_jit(model, inplace)
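Taken together, the static_quant.py hunks move device detection into __init__ so that prepare and convert dispatch on the cached self.device instead of re-detecting it. A condensed sketch of that pattern follows; the class name, the auto_accelerator import path, and the elided method bodies are assumptions, since they are not shown in this excerpt:

    # condensed sketch of the dispatch pattern introduced here, not the full implementation
    from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator  # assumed import path

    class StaticQuantQuantizer:  # assumed class name
        def __init__(self, quant_config=None):
            self.quant_config = quant_config or {}
            # detected once here, then reused by prepare() and convert()
            self.device = auto_detect_accelerator().current_device()

        def prepare(self, model, example_inputs, inplace=True):
            if self.device == "cpu":
                ...  # IPEX path: qconf summary + ipex.quantization.prepare
            else:
                ...  # TorchScript path: torch.jit.trace + prepare_jit
            return model

        def convert(self, model, example_inputs, inplace=True):
            if self.device != "cpu":
                ...  # TorchScript path: convert_jit
            else:
                ...  # IPEX path: save_qconf_summary + IPEX convert
            return model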