Skip to content
Prev Previous commit
Next Next commit
Update per review: deduplicate the repeated lm-head regex list (".*lm_head", ".*output_layer", ".*embed_out") into a shared LM_HEAD_NAMES constant in neural_compressor/torch/utils/constants.py, used by the RTN/GPTQ/AWQ/TEQ/HQQ config mappings
Signed-off-by: xin3he <[email protected]>
  • Loading branch information
xin3he committed Jun 25, 2024
commit 2351abbb3f9fab541a84c7aec569d50096a58852
16 changes: 6 additions & 10 deletions neural_compressor/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, is_transformers_imported, logger
from neural_compressor.torch.utils.constants import (
LM_HEAD_NAMES,
PRIORITY_AUTOROUND,
PRIORITY_AWQ,
PRIORITY_GPTQ,
Expand Down Expand Up @@ -198,8 +199,7 @@ def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
if not self.quant_lm_head:
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32"))
self.set_local(LM_HEAD_NAMES, RTNConfig(dtype="fp32"))
config_mapping = super().to_config_mapping(config_list, model_info)
return config_mapping

Expand Down Expand Up @@ -359,8 +359,7 @@ def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
if not self.quant_lm_head:
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
self.set_local(usual_lm_head_names, GPTQConfig(dtype="fp32"))
self.set_local(LM_HEAD_NAMES, GPTQConfig(dtype="fp32"))
config_mapping = super().to_config_mapping(config_list, model_info)
return config_mapping

Expand Down Expand Up @@ -502,8 +501,7 @@ def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
if not self.quant_lm_head:
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
self.set_local(usual_lm_head_names, AWQConfig(dtype="fp32"))
self.set_local(LM_HEAD_NAMES, AWQConfig(dtype="fp32"))
config_mapping = super().to_config_mapping(config_list, model_info)
return config_mapping

Expand Down Expand Up @@ -641,8 +639,7 @@ def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
if not self.quant_lm_head:
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
self.set_local(usual_lm_head_names, TEQConfig(dtype="fp32"))
self.set_local(LM_HEAD_NAMES, TEQConfig(dtype="fp32"))
config_mapping = super().to_config_mapping(config_list, model_info)
return config_mapping

Expand Down Expand Up @@ -1269,8 +1266,7 @@ def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
if not self.quant_lm_head:
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
self.set_local(usual_lm_head_names, HQQConfig(dtype="fp32"))
self.set_local(LM_HEAD_NAMES, HQQConfig(dtype="fp32"))
config_mapping = super().to_config_mapping(config_list, model_info)
return config_mapping

Expand Down
3 changes: 3 additions & 0 deletions neural_compressor/torch/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,6 @@
class LoadFormat(Enum):
DEFAULT = "default"
HUGGINGFACE = "huggingface"


LM_HEAD_NAMES = [".*lm_head", ".*output_layer", ".*embed_out"]