
Commit b461291

fix bug
Signed-off-by: xin3he <[email protected]>
1 parent 131f2cb commit b461291

File tree: 3 files changed (+19, -9 lines)

neural_compressor/common/base_config.py

Lines changed: 10 additions & 4 deletions
@@ -198,10 +198,16 @@ def local_config(self):
     def local_config(self, config):
         self._local_config = config
 
-    def set_local(self, operator_name: Union[str, Callable], config: BaseConfig) -> BaseConfig:
-        if operator_name in self.local_config:
-            logger.warning("The configuration for %s has already been set, update it.", operator_name)
-        self.local_config[operator_name] = config
+    def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig:
+        if hasattr(operator_name_or_list, "__iter__"):
+            for operator_name in operator_name_or_list:
+                if operator_name in self.local_config:
+                    logger.warning("The configuration for %s has already been set, update it.", operator_name)
+                self.local_config[operator_name] = config
+        else:
+            if operator_name_or_list in self.local_config:
+                logger.warning("The configuration for %s has already been set, update it.", operator_name_or_list)
+            self.local_config[operator_name_or_list] = config
         return self
 
     def to_dict(self):
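
With this change, set_local accepts either a single operator name/pattern or a list of them. Below is a minimal, self-contained sketch of that dispatch; ToyConfig is an illustrative stand-in and is not part of neural_compressor. The sketch checks isinstance(..., list) rather than hasattr(..., "__iter__") because Python strings are themselves iterable and would otherwise be looped over character by character.

import logging
from typing import Callable, List, Union

logger = logging.getLogger(__name__)

class ToyConfig:
    """Illustrative stand-in for BaseConfig; only the local-config plumbing is modeled."""

    def __init__(self):
        self.local_config = {}

    def set_local(self, operator_name_or_list: Union[List, str, Callable], config: dict) -> "ToyConfig":
        # A list applies the same config to every name/pattern it contains.
        if isinstance(operator_name_or_list, list):
            for operator_name in operator_name_or_list:
                if operator_name in self.local_config:
                    logger.warning("The configuration for %s has already been set, update it.", operator_name)
                self.local_config[operator_name] = config
        else:
            # A single name (or callable) is registered directly.
            if operator_name_or_list in self.local_config:
                logger.warning("The configuration for %s has already been set, update it.", operator_name_or_list)
            self.local_config[operator_name_or_list] = config
        return self

cfg = ToyConfig()
cfg.set_local([".*lm_head", ".*output_layer", ".*embed_out"], {"dtype": "fp32"})  # list form
cfg.set_local(".*lm_head", {"dtype": "fp32"})  # single-name form; triggers the update warning
print(sorted(cfg.local_config))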

neural_compressor/torch/quantization/config.py

Lines changed: 9 additions & 4 deletions
@@ -169,10 +169,6 @@ def __init__(
         self.double_quant_group_size = double_quant_group_size
         self.quant_lm_head = quant_lm_head
         self._post_init()  # initialize global & local configuration
-        if not self.quant_lm_head:
-            # use .* for re.match
-            usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
-            self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32"))
 
     @classmethod
     def register_supported_configs(cls) -> List[OperatorConfig]:

@@ -209,6 +205,15 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
         supported_configs.append(OperatorConfig(config=linear_rtn_config, operators=operators))
         cls.supported_configs = supported_configs
 
+    def to_config_mapping(
+        self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
+    ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
+        if not self.quant_lm_head:
+            usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
+            self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32"))
+        config_mapping = super().to_config_mapping(config_list, model_info)
+        return config_mapping
+
     @staticmethod
     def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         filter_result = []
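
The commit moves the lm_head/output_layer/embed_out fp32 fallback from __init__ into to_config_mapping, so the patterns are registered when the config is mapped onto a model rather than at construction time. The leading ".*" in each pattern matters because re.match only anchors at the start of the string. A short sketch of the intended matching, with hypothetical module names (the patterns are taken from the diff above):

import re

usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
module_names = [
    "lm_head",                          # top-level head
    "model.lm_head",                    # nested head; needs the leading .* to match
    "transformer.output_layer",
    "gpt_neox.embed_out",
    "model.layers.0.self_attn.q_proj",  # ordinary linear layer, should stay quantized
]

for name in module_names:
    keep_fp32 = any(re.match(pattern, name) for pattern in usual_lm_head_names)
    print(f"{name:<35} -> {'fp32 (skip quantization)' if keep_fp32 else 'quantize'}")

Because the fallback now runs inside to_config_mapping, constructing RTNConfig(quant_lm_head=False) no longer mutates local_config up front; the lm-head patterns are added only when the configuration mapping is built.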

neural_compressor/torch/utils/utility.py

Lines changed: 0 additions & 1 deletion
@@ -225,7 +225,6 @@ def dump_model_op_stats(mode, tune_cfg):
     for op, config in tune_cfg.items():
         op_type = op[1]
         config = config.to_dict()
-        # import pdb; pdb.set_trace()
         if not config["dtype"] == "fp32":
             num_bits = config["bits"]
             group_size = config["group_size"]
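
The removed line was a leftover debugger breakpoint. For context, a minimal sketch of the per-op statistics walk in dump_model_op_stats, using a hand-built tune_cfg-shaped dict; the real values are config objects exposing .to_dict(), so plain dicts stand in for them here and the keys are hypothetical:

from collections import Counter

tune_cfg = {
    ("model.layers.0.mlp.down_proj", "Linear"): {"dtype": "int4", "bits": 4, "group_size": 32},
    ("model.layers.0.self_attn.q_proj", "Linear"): {"dtype": "int4", "bits": 4, "group_size": 32},
    ("model.lm_head", "Linear"): {"dtype": "fp32"},
}

stats = Counter()
for op, config in tune_cfg.items():
    op_type = op[1]  # keys are (module name, op type) pairs, as in the loop above
    if config["dtype"] != "fp32":
        label = f"{op_type} {config['bits']}-bit, group_size={config['group_size']}"
    else:
        label = f"{op_type} fp32"
    stats[label] += 1

for label, count in stats.items():
    print(f"{label}: {count}")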
