fix bug: accept operator-name lists in set_local and apply the lm_head fp32 exclusion in to_config_mapping
Signed-off-by: xin3he <[email protected]>
xin3he committed Jun 19, 2024
commit b4612910469a60ae0accf13c4677c6c4c2041a55
14 changes: 10 additions & 4 deletions neural_compressor/common/base_config.py
@@ -198,10 +198,16 @@ def local_config(self):
     def local_config(self, config):
         self._local_config = config
 
-    def set_local(self, operator_name: Union[str, Callable], config: BaseConfig) -> BaseConfig:
-        if operator_name in self.local_config:
-            logger.warning("The configuration for %s has already been set, update it.", operator_name)
-        self.local_config[operator_name] = config
+    def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig:
+        if isinstance(operator_name_or_list, list):
+            for operator_name in operator_name_or_list:
+                if operator_name in self.local_config:
+                    logger.warning("The configuration for %s has already been set, update it.", operator_name)
+                self.local_config[operator_name] = config
+        else:
+            if operator_name_or_list in self.local_config:
+                logger.warning("The configuration for %s has already been set, update it.", operator_name_or_list)
+            self.local_config[operator_name_or_list] = config
         return self
 
     def to_dict(self):
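For reference, a minimal sketch of how the updated set_local can be called after this change; the operator names and bit widths below are illustrative only, and the import path assumes the 3.x torch quantization API:

from neural_compressor.torch.quantization import RTNConfig

base = RTNConfig(bits=4)
# Single operator name or pattern, as before this commit.
base.set_local(".*fc1", RTNConfig(bits=8))
# New in this commit: a list of names/patterns that share one local override.
base.set_local([".*lm_head", ".*output_layer", ".*embed_out"], RTNConfig(dtype="fp32"))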
13 changes: 9 additions & 4 deletions neural_compressor/torch/quantization/config.py
@@ -169,10 +169,6 @@ def __init__(
         self.double_quant_group_size = double_quant_group_size
         self.quant_lm_head = quant_lm_head
         self._post_init() # initialize global & local configuration
-        if not self.quant_lm_head:
-            # use .* for re.match
-            usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
-            self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32"))
 
     @classmethod
     def register_supported_configs(cls) -> List[OperatorConfig]:
@@ -209,6 +205,15 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
         supported_configs.append(OperatorConfig(config=linear_rtn_config, operators=operators))
         cls.supported_configs = supported_configs
 
+    def to_config_mapping(
+        self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
+    ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
+        if not self.quant_lm_head:
+            usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
+            self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32"))
+        config_mapping = super().to_config_mapping(config_list, model_info)
+        return config_mapping
+
     @staticmethod
     def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         filter_result = []
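The patterns above keep the leading .* from the removed __init__ block because the operator-name matching is assumed to use re.match (as the deleted "use .* for re.match" comment indicates), which only matches from the start of the module name. A quick illustration with a made-up module name:

import re

patterns = [".*lm_head", ".*output_layer", ".*embed_out"]
name = "model.lm_head"  # hypothetical fully qualified module name
print(any(re.match(p, name) for p in patterns))  # True
print(bool(re.match("lm_head", name)))           # False: re.match anchors at the start of the string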
1 change: 0 additions & 1 deletion neural_compressor/torch/utils/utility.py
@@ -225,7 +225,6 @@ def dump_model_op_stats(mode, tune_cfg):
     for op, config in tune_cfg.items():
         op_type = op[1]
         config = config.to_dict()
-        # import pdb; pdb.set_trace()
         if not config["dtype"] == "fp32":
             num_bits = config["bits"]
             group_size = config["group_size"]