Skip to content
Previous commit
Next commit
fix save
Signed-off-by: gta <[email protected]>
  • Loading branch information
gta authored and gta committed Jul 15, 2024
commit facde1875297643b8d6cae0d5e7ba75d3390c728
6 changes: 2 additions & 4 deletions neural_compressor/torch/algorithms/static_quant/save_load.py
Original file line number Diff line number Diff line change
@@ -32,17 +32,15 @@ def save(model, output_dir="./saved_results"):

qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
if next(model.parameters()).device.type == "cpu":
if next(model.parameters()).device.type == "cpu": # pragma: no cover
model.ori_save(qmodel_file_path)
with open(qconfig_file_path, "w") as f:
json.dump(model.tune_cfg, f, indent=4)
else:
from neural_compressor.common.utils import save_config_mapping

model.ori_save(qmodel_file_path)
save_config_mapping(model.qconfig, qconfig_file_path)
# MethodType 'save' not in state_dict
del model.save
torch.save(model.state_dict(), qmodel_file_path)

logger.info("Save quantized model to {}.".format(qmodel_file_path))
logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))
Expand Down
Original file line number Diff line number Diff line change
@@ -161,8 +161,8 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):

dump_model_op_stats(self.user_cfg)

model.ori_save = model.save
model.save = MethodType(save, model)
model.ori_save = model.save
model.save = MethodType(save, model)

logger.info("Static quantization done.")
return model
Expand Down