def to_config_mapping(
    self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
    """Build the operator-to-config mapping for RTN quantization.

    Args:
        config_list: optional configs to merge into the mapping.
        model_info: optional (op_name, op_type) pairs describing the model.

    Returns:
        Ordered mapping from operator identifier to its BaseConfig.
    """
    if not self.quant_lm_head:
        # Keep the LM head in fp32. Forward use_layer_wise/model_path so the
        # fallback config can still locate weights on disk when layer-wise
        # loading is enabled.
        fp32_cfg = RTNConfig(
            dtype="fp32",
            use_layer_wise=self.use_layer_wise,
            model_path=self.model_path,
        )
        self.set_local(LM_HEAD_NAMES, fp32_cfg)
    return super().to_config_mapping(config_list, model_info)
def to_config_mapping(
    self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
    """Build the operator-to-config mapping for GPTQ quantization.

    Args:
        config_list: optional configs to merge into the mapping.
        model_info: optional (op_name, op_type) pairs describing the model.

    Returns:
        Ordered mapping from operator identifier to its BaseConfig.
    """
    if not self.quant_lm_head:
        # Keep the LM head in fp32. Forward use_layer_wise/model_path so the
        # fallback config can still locate weights on disk when layer-wise
        # loading is enabled.
        fp32_cfg = GPTQConfig(
            dtype="fp32",
            use_layer_wise=self.use_layer_wise,
            model_path=self.model_path,
        )
        self.set_local(LM_HEAD_NAMES, fp32_cfg)
    return super().to_config_mapping(config_list, model_info)
@classmethod
def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "GPTQConfig"]:
    """Return predefined GPTQ configurations keyed by processor type.

    Client processors default to layer-wise quantization to bound peak
    memory; servers use the default whole-model configuration.

    Returns:
        Dict mapping each ProcessorType to a ready-made GPTQConfig.
    """
    pre_defined_configs: Dict[torch_utils.ProcessorType, GPTQConfig] = {}
    # NOTE: `self` is unavailable in a classmethod, so a model_path cannot be
    # forwarded here; callers needing layer-wise loading from disk must set
    # model_path on the returned config themselves.
    pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True)
    pre_defined_configs[torch_utils.ProcessorType.Server] = cls()
    return pre_defined_configs
def to_config_mapping(
    self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
    """Build the operator-to-config mapping for AWQ quantization.

    Args:
        config_list: optional configs to merge into the mapping.
        model_info: optional (op_name, op_type) pairs describing the model.

    Returns:
        Ordered mapping from operator identifier to its BaseConfig.
    """
    if not self.quant_lm_head:
        # Keep the LM head in fp32. Forward use_layer_wise/model_path so the
        # fallback config can still locate weights on disk when layer-wise
        # loading is enabled.
        fp32_cfg = AWQConfig(
            dtype="fp32",
            use_layer_wise=self.use_layer_wise,
            model_path=self.model_path,
        )
        self.set_local(LM_HEAD_NAMES, fp32_cfg)
    return super().to_config_mapping(config_list, model_info)
@classmethod
def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "AutoRoundConfig"]:
    """Return predefined AutoRound configurations keyed by processor type.

    Client processors default to layer-wise quantization to bound peak
    memory; servers use the default whole-model configuration.

    Returns:
        Dict mapping each ProcessorType to a ready-made AutoRoundConfig.
    """
    pre_defined_configs: Dict[torch_utils.ProcessorType, AutoRoundConfig] = {}
    # BUGFIX: the previous version passed model_path=self.model_path, but
    # `self` does not exist inside a classmethod and raised NameError.
    # A model_path cannot be forwarded here; callers needing layer-wise
    # loading from disk must set model_path on the returned config.
    pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True)
    pre_defined_configs[torch_utils.ProcessorType.Server] = cls()
    return pre_defined_configs
0 commit comments