fix bug
Signed-off-by: xin3he <[email protected]>
xin3he committed Jun 20, 2024
commit 7073400a61c430f81697a61bbfd8afa63277142f
2 changes: 1 addition & 1 deletion neural_compressor/common/base_config.py
@@ -199,7 +199,7 @@ def local_config(self, config):
         self._local_config = config
 
     def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig:
-        if hasattr(operator_name_or_list, "__iter__"):
+        if isinstance(operator_name_or_list, list):
             for operator_name in operator_name_or_list:
                 if operator_name in self.local_config:
                     logger.warning("The configuration for %s has already been set, update it.", operator_name)
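The one-line fix addresses a classic Python pitfall: strings also implement `__iter__`, so the old `hasattr` check would send a single operator name down the list branch and iterate over its characters. A minimal sketch of the difference (the operator names are illustrative, not from the INC test suite):

# Strings are iterable, so hasattr(x, "__iter__") cannot tell one name from many.
name = "lm_head"
print(hasattr(name, "__iter__"))         # True  -- old check wrongly takes the list branch
print(isinstance(name, list))            # False -- new check routes a single name correctly
print(isinstance(["fc1", "fc2"], list))  # True  -- only real lists are expanded op by op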
13 changes: 12 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py
@@ -85,6 +85,7 @@ def __init__(self, quant_config: ConfigMappingType) -> None:
         Args:
             quant_config (ConfigMappingType): quantization config for ops.
         """
+        quant_config = self._parse_hqq_configs_mapping(quant_config)
         super().__init__(quant_config=quant_config)
 
     @torch.no_grad()
@@ -118,7 +119,8 @@ def save(self, model, path):
         pass
 
     def _convert_hqq_module_config(self, config) -> HQQModuleConfig:
-        # * 3.x API use `bits` for woq while HQQ internal API use `nbits`
+        # TODO: (Yi) Note that the configuration defined by INC should be kept separate from the algorithm.
+        # * The 3.x API uses `bits` for WOQ while the HQQ internal API uses `nbits`; we should change this in algorithm_entry.py.
         nbits = config.bits
         group_size = config.group_size
         quant_zero = config.quant_zero
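The updated comment flags a naming mismatch worth spelling out: INC's 3.x weight-only config exposes the field as `bits`, while HQQ's internal API calls the same quantity `nbits`, so the conversion layer has to rename it. A hypothetical dict-based translator (not the INC API) showing just that rename:

def to_hqq_kwargs(inc_cfg: dict) -> dict:
    # Rename INC's `bits` to HQQ's `nbits`; group_size passes through unchanged.
    return {"nbits": inc_cfg["bits"], "group_size": inc_cfg["group_size"]}

assert to_hqq_kwargs({"bits": 4, "group_size": 64}) == {"nbits": 4, "group_size": 64}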
@@ -141,3 +143,12 @@ def _convert_hqq_module_config(self, config) -> HQQModuleConfig:
         hqq_module_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig)
         logger.debug(hqq_module_config)
         return hqq_module_config
+
+    def _parse_hqq_configs_mapping(self, configs_mapping):
+        qconfig_mapping = {}
+        for (op_name, op_type), quant_config in configs_mapping.items():
+            if quant_config is not None and quant_config.dtype == "fp32":
+                logger.warning("Fallback %s to fp32.", op_name)
+                continue
+            qconfig_mapping[op_name] = self._convert_hqq_module_config(quant_config)
+        return qconfig_mapping
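In effect, `_parse_hqq_configs_mapping` drops every op whose config requests fp32 (a fallback) before the mapping reaches the base quantizer, and re-keys the result by op name alone, discarding the op type. A runnable sketch of that filtering, with illustrative op names and SimpleNamespace stand-ins for the real config objects:

from types import SimpleNamespace

# Stand-in op-level configs; the filter only reads the `dtype` field.
int4_cfg = SimpleNamespace(dtype="int4")
fp32_cfg = SimpleNamespace(dtype="fp32")

configs_mapping = {
    ("model.layer1", "Linear"): int4_cfg,  # kept: will be converted to an HQQModuleConfig
    ("lm_head", "Linear"): fp32_cfg,       # dropped: fp32 means "fall back, do not quantize"
}

kept = {
    op_name
    for (op_name, op_type), cfg in configs_mapping.items()
    if not (cfg is not None and cfg.dtype == "fp32")
}
assert kept == {"model.layer1"}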