diff --git a/docs/3x/TensorFlow.md b/docs/3x/TensorFlow.md
index 5634a524f14..dd58c389699 100644
--- a/docs/3x/TensorFlow.md
+++ b/docs/3x/TensorFlow.md
@@ -23,7 +23,7 @@ Intel(R) Neural Compressor provides `quantize_model` and `autotune` as main inte
 
 **quantize_model**
 
-The design philosophy of the `quantize_model` interface is easy-of-use. With minimal parameters requirement, including `model`, `quant_config`, `calib_dataloader` and `calib_iteration`, it offers a straightforward choice of quantizing TF model in one-shot.
+The design philosophy of the `quantize_model` interface is ease of use. With a minimal set of parameters, including `model`, `quant_config`, `calib_dataloader`, `calib_iteration` and `calib_func`, it offers a straightforward way to quantize a TF model in one shot.
 
 ```python
 def quantize_model(
@@ -31,6 +31,7 @@ def quantize_model(
     quant_config: Union[BaseConfig, list],
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ):
 ```
 `model` should be a string of the model's location, the object of Keras model or INC TF model wrapper class.
@@ -41,6 +42,9 @@ def quantize_model(
 
 `calib_iteration` is used to decide how many iterations the calibration process will be run.
 
+`calib_func` is a substitute for `calib_dataloader`; use it when the built-in calibration function of INC does not work for model inference.
+
+
 Here is a simple example of using `quantize_model` interface with a dummy calibration dataloader and the default `StaticQuantConfig`:
 ```python
 from neural_compressor.tensorflow import StaticQuantConfig, quantize_model
@@ -68,6 +72,7 @@ def autotune(
     eval_args: Optional[Tuple[Any]] = None,
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ) -> Optional[BaseModel]:
 ```
 `model` should be a string of the model's location, the object of Keras model or INC TF model wrapper class.
@@ -82,6 +87,8 @@ def autotune(
 
 `calib_iteration` is used to decide how many iterations the calibration process will be run.
 
+`calib_func` is a substitute for `calib_dataloader`; use it when the built-in calibration function of INC does not work for model inference.
+
 Here is a simple example of using `autotune` interface with different quantization rules defined by a list of `StaticQuantConfig`:
 ```python
 from neural_compressor.common.base_tuning import TuningConfig
diff --git a/neural_compressor/tensorflow/algorithms/smoother/core.py b/neural_compressor/tensorflow/algorithms/smoother/core.py
index d8c3af164f5..187539ee6eb 100644
--- a/neural_compressor/tensorflow/algorithms/smoother/core.py
+++ b/neural_compressor/tensorflow/algorithms/smoother/core.py
@@ -37,19 +37,23 @@ class SmoothQuant:
     def __init__(
         self,
         config: SmoothQuantConfig,
-        calib_dataloader: Callable,
+        calib_dataloader: Callable = None,
         calib_iteration: int = 1,
+        calib_func: Callable = None,
     ):
         """Convert the model by smooth quant.
 
         Args:
-            config: the SmoothQuantConfig class used to set this class
-            calibdataloader: the calibration dataloader
-            calib_iteration: how many steps of iterations on the dataloader to move forward
+            config: the SmoothQuantConfig class used to set this class.
+            calib_dataloader: the calibration dataloader.
+            calib_iteration: how many steps of iterations on the dataloader to move forward.
+            calib_func: the function used for calibration; a substitute for calib_dataloader
+                when the built-in calibration function of INC does not work for model inference.
 
         Returns:
             model: A smoothed Tensorflow model
         """
+        assert calib_func is None, "Calibration function is not supported for smooth quant."
         self.config = config
         self.calib_dataloader = calib_dataloader
         self.calib_iteration = calib_iteration
diff --git a/neural_compressor/tensorflow/algorithms/static_quant/keras.py b/neural_compressor/tensorflow/algorithms/static_quant/keras.py
index 004393c8c27..83d9a54609d 100644
--- a/neural_compressor/tensorflow/algorithms/static_quant/keras.py
+++ b/neural_compressor/tensorflow/algorithms/static_quant/keras.py
@@ -314,16 +314,18 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5):
         return bn_fused_model
 
     @dump_elapsed_time("Pass quantize model")
-    def quantize(self, quant_config, model, dataloader, iteration, q_func=None):
+    def quantize(self, quant_config, model, dataloader, iteration, calib_func=None):
         """Execute the quantize process on the specified model.
 
         Args:
-            tune_cfg(dict): The user defined 'StaticQuantConfig' class.
+            quant_config (StaticQuantConfig): The user-defined quantization configuration.
             model (object): The model to do quantization.
             dataloader(object): The calibration dataloader used to load quantization dataset.
             iteration(int): The iteration of calibration.
-            q_func (optional): training function for quantization aware training mode.
+            calib_func (optional): the function used for calibration; a substitute for the calibration
+                dataloader when the built-in calibration function of INC does not work for model inference.
         """
+        assert calib_func is None, "The calibration function is not supported on the Keras backend yet."
         self.query_fw_capability(model)
         converter = KerasConfigConverter(quant_config, iteration)
         tune_cfg = converter.parse_to_tune_cfg()
@@ -367,15 +369,13 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None):
 
         return quantized_model
 
-    def _calibrate(self, model, dataloader, calib_interation):
+    def _calibrate(self, model, dataloader=None, calib_interation=None):
         """Apply calibration.
 
         Args:
             model (tf.keras.Model): The model inserted with FakeQuant layers for calibration.
             dataloader(object): The calibration dataloader used to load quantization dataset.
             iteration(int): The iteration of calibration.
-            fq_output_layers (dict): A dict mapping from names of FakeQuant layers to
-                names of their output layers.
         """
         # run eagerly to fetch the numpy min/max
         results = {}
diff --git a/neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py b/neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py
index 160cdb01e44..3bf9cff80af 100644
--- a/neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py
+++ b/neural_compressor/tensorflow/algorithms/static_quant/tensorflow.py
@@ -172,7 +172,7 @@ def quantize(
         model: BaseModel,
         calib_dataloader: Callable = None,
         calib_iteration: int = 100,
-        q_func=None,
+        calib_func: Callable = None,
     ):
         """Execute the quantize process on the specified model.
 
@@ -181,11 +181,11 @@ def quantize(
             model: the fp32 model to be quantized.
             calib_dataloader: a data loader for calibration.
             calib_iteration: the iteration of calibration.
-            q_func: training function for quantization aware training mode,
-                which not enabled for tensorflow yet.
+            calib_func: the function used for calibration; a substitute for calib_dataloader
+                when the built-in calibration function of INC does not work for model inference.
 
         Returns:
-            tf.compat.v1.GraphDef: the quantized model
+            converted_model: the quantized INC model wrapper.
""" assert ( self.approach != "post_training_dynamic_quant" @@ -195,7 +195,7 @@ def quantize( self.approach != "quant_aware_training" ), "Quantize Aware Training is not supported on TensorFlow framework now!" - self.calib_sampling_size = calib_dataloader.batch_size * calib_iteration + self.calib_sampling_size = calib_dataloader.batch_size * calib_iteration if calib_dataloader else 100 tune_cfg = self.parse_quant_config(quant_config, model, calib_iteration) self._tuning_cfg_to_fw(tune_cfg) self.bf16_ops.extend(self.smooth_quant_mul_ops) @@ -228,7 +228,7 @@ def quantize( fp32_ops=self.fp32_ops, bf16_ops=self.bf16_ops, data_loader=calib_dataloader, - calib_func=q_func, + calib_func=calib_func, qdq_enabled=self.qdq_enabled, new_api=self.new_api, performance_only=self.performance_only, @@ -251,7 +251,7 @@ def quantize( fp32_ops=self.fp32_ops, bf16_ops=self.bf16_ops, data_loader=calib_dataloader, - calib_func=q_func, + calib_func=calib_func, qdq_enabled=self.qdq_enabled, new_api=self.new_api, performance_only=self.performance_only, @@ -275,7 +275,7 @@ def quantize( fp32_ops=self.fp32_ops, bf16_ops=self.bf16_ops, data_loader=calib_dataloader, - calib_func=q_func, + calib_func=calib_func, qdq_enabled=self.qdq_enabled, new_api=self.new_api, performance_only=self.performance_only, @@ -750,21 +750,21 @@ def quantize( model: BaseModel, calib_dataloader: Callable = None, calib_iteration: int = 100, - q_func=None, + calib_func: Callable = None, ): """Execute the quantize process on the specified model. Args: - tune_cfg (dict): quantization configuration - model (tf.compat.v1.GraphDef): fp32 model - data_loader (generator): generator the data and labels - q_func (optional): training function for quantization aware training mode, - which not enabled for tensorflow yet. + quant_config: a quantization configuration. + model: the fp32 model to be quantized. + calib_dataloader: a data loader for calibration. + calib_iteration: the iteration of calibration. + calib_func: the function used for calibration, should be a substitution for calib_dataloader + when the built-in calibration function of INC does not work for model inference. Returns: - tf.compat.v1.GraphDef: the quantized model + converted_model: the quantized INC model wrapper. """ - assert q_func is None, "quantization aware training mode is not support on tensorflow" self.calib_sampling_size = calib_dataloader.batch_size * calib_iteration tune_cfg = self.parse_quant_config(quant_config, model, calib_iteration) self._tuning_cfg_to_fw(tune_cfg) @@ -798,7 +798,7 @@ def quantize( fp32_ops=self.fp32_ops, bf16_ops=self.bf16_ops, data_loader=calib_dataloader, - calib_func=q_func, + calib_func=calib_func, itex_mode=self.itex_mode, qdq_enabled=self.qdq_enabled, new_api=self.new_api, @@ -846,7 +846,7 @@ def quantize( fp32_ops=self.fp32_ops, bf16_ops=self.bf16_ops, data_loader=calib_dataloader, - calib_func=q_func, + calib_func=calib_func, itex_mode=self.itex_mode, qdq_enabled=self.qdq_enabled, new_api=self.new_api, diff --git a/neural_compressor/tensorflow/quantization/algorithm_entry.py b/neural_compressor/tensorflow/quantization/algorithm_entry.py index 4b40a2f39a1..e3530bc5e28 100644 --- a/neural_compressor/tensorflow/quantization/algorithm_entry.py +++ b/neural_compressor/tensorflow/quantization/algorithm_entry.py @@ -28,6 +28,7 @@ def static_quant_entry( quant_config: BaseConfig, calib_dataloader: Callable = None, calib_iteration: int = 100, + calib_func: Callable = None, ): """The main entry to apply static quantization. 
 
@@ -36,6 +37,8 @@ def static_quant_entry(
         quant_config: a quantization configuration.
         calib_dataloader: a data loader for calibration.
         calib_iteration: the iteration of calibration.
+        calib_func: the function used for calibration; a substitute for calib_dataloader
+            when the built-in calibration function of INC does not work for model inference.
 
     Returns:
         q_model: the quantized model.
@@ -49,7 +52,7 @@ def static_quant_entry(
         framework = TensorFlowAdaptor
 
     quantizer = framework(TFConfig.global_config)
-    q_model = quantizer.quantize(quant_config, model, calib_dataloader, calib_iteration)
+    q_model = quantizer.quantize(quant_config, model, calib_dataloader, calib_iteration, calib_func)
     TFConfig.reset_global_config()
 
     return q_model
@@ -61,12 +64,26 @@ def smooth_quant_entry(
     smooth_quant_config: SmoothQuantConfig,
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ):
+    """The main entry to apply smooth quantization.
+
+    Args:
+        model: an fp32 model to be quantized.
+        smooth_quant_config: a smooth quantization configuration.
+        calib_dataloader: a data loader for calibration.
+        calib_iteration: the iteration of calibration.
+        calib_func: the function used for calibration; a substitute for calib_dataloader when the
+            built-in calibration function of INC does not work. Not supported for smooth quant yet.
+
+    Returns:
+        q_model: the quantized model.
+    """
     assert not isinstance(model, KerasModel), "INC don't support smooth quantization for Keras models now."
 
     from neural_compressor.tensorflow.algorithms import SmoothQuant
 
-    converter = SmoothQuant(smooth_quant_config, calib_dataloader, calib_iteration)
+    converter = SmoothQuant(smooth_quant_config, calib_dataloader, calib_iteration, calib_func)
     sq_model = converter(model)
 
     return sq_model
diff --git a/neural_compressor/tensorflow/quantization/autotune.py b/neural_compressor/tensorflow/quantization/autotune.py
index 55b089b923c..847557b0b8a 100644
--- a/neural_compressor/tensorflow/quantization/autotune.py
+++ b/neural_compressor/tensorflow/quantization/autotune.py
@@ -44,6 +44,7 @@ def autotune(
     eval_args: Optional[Tuple[Any]] = None,
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ) -> Optional[BaseModel]:
     """The main entry of auto-tune."""
     model = Model(model)
@@ -57,7 +58,7 @@ def autotune(
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
-        q_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration)
+        q_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration, calib_func)
         tuning_logger.execution_end()
         tuning_logger.evaluation_start()
         eval_result: float = eval_func_wrapper.evaluate(q_model)
@@ -71,7 +72,9 @@ def autotune(
             logger.info("Re-quantizing with best quantization config...")
             del q_model
             best_quant_config: BaseConfig = best_trial_record.quant_config
-            best_quant_model = quantize_model(model, best_quant_config, calib_dataloader, calib_iteration)
+            best_quant_model = quantize_model(
+                model, best_quant_config, calib_dataloader, calib_iteration, calib_func
+            )
         else:
             best_quant_model = q_model
         break
diff --git a/neural_compressor/tensorflow/quantization/quantize.py b/neural_compressor/tensorflow/quantization/quantize.py
index fa613759515..6cfd24225b7 100644
--- a/neural_compressor/tensorflow/quantization/quantize.py
+++ b/neural_compressor/tensorflow/quantization/quantize.py
@@ -32,6 +32,7 @@ def quantize_model(
     quant_config: Union[BaseConfig, list],
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ):
     """The main entry to quantize model.
 
@@ -40,6 +41,8 @@ def quantize_model(
         quant_config: single or lists of quantization configuration.
         calib_dataloader: a data loader for calibration.
         calib_iteration: the iteration of calibration.
+        calib_func: the function used for calibration; a substitute for calib_dataloader
+            when the built-in calibration function of INC does not work for model inference.
 
     Returns:
         q_model: the quantized model.
@@ -47,9 +50,11 @@ def quantize_model(
     q_model = Model(model)
     if isinstance(quant_config, list):
         for config in quant_config:
-            q_model = quantize_model_with_single_config(q_model, config, calib_dataloader, calib_iteration)
+            q_model = quantize_model_with_single_config(q_model, config, calib_dataloader, calib_iteration, calib_func)
     else:
-        q_model = quantize_model_with_single_config(q_model, quant_config, calib_dataloader, calib_iteration)
+        q_model = quantize_model_with_single_config(
+            q_model, quant_config, calib_dataloader, calib_iteration, calib_func
+        )
 
     return q_model
 
@@ -59,6 +64,7 @@ def quantize_model_with_single_config(
     quant_config: BaseConfig,
     calib_dataloader: Callable = None,
     calib_iteration: int = 100,
+    calib_func: Callable = None,
 ):
     """Quantize model using single config.
 
@@ -67,6 +73,8 @@ def quantize_model_with_single_config(
         quant_config: a quantization configuration.
         calib_dataloader: a data loader for calibration.
         calib_iteration: the iteration of calibration.
+        calib_func: the function used for calibration; a substitute for calib_dataloader
+            when the built-in calibration function of INC does not work for model inference.
 
     Returns:
         q_model: the quantized model.
@@ -89,5 +97,5 @@ def quantize_model_with_single_config(
     for algo_name, algo_func in algos_mapping.items():
         if need_apply(configs_mapping, algo_name):
             logger.info(f"Start to apply {algo_name} on the model.")
-            q_model = algo_func(q_model, configs_mapping, calib_dataloader, calib_iteration)
+            q_model = algo_func(q_model, configs_mapping, calib_dataloader, calib_iteration, calib_func)
     return q_model
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_converter.py b/neural_compressor/tensorflow/quantization/utils/graph_converter.py
index 302bfe13717..e3c1c640c86 100644
--- a/neural_compressor/tensorflow/quantization/utils/graph_converter.py
+++ b/neural_compressor/tensorflow/quantization/utils/graph_converter.py
@@ -231,6 +231,10 @@ def _inference(self, model):
         Args:
             model(TensorflowBaseModel): input TensorflowBaseModel
         """
+        if self.calib_func:
+            self.calib_func(model)
+            return
+
         if model.model_type == "llm_saved_model":
             self._inference_llm(model)
             return
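
For reference, here is a minimal end-to-end sketch of how the new `calib_func` argument can be exercised, in the spirit of the `quantize_model` example already in `docs/3x/TensorFlow.md`. The frozen-graph path, tensor accessors, input shape, and iteration count below are illustrative assumptions rather than part of this patch; the only contract the diff establishes is that `GraphConverter._inference` dispatches to `calib_func(model)` before any dataloader-driven run, so the function merely has to execute enough forward passes for calibration statistics to be collected.

```python
import numpy as np

from neural_compressor.tensorflow import StaticQuantConfig, quantize_model


def calib_func(model):
    """Run a few forward passes so min/max calibration statistics can be collected.

    `model` is the INC TF model wrapper that GraphConverter._inference passes in.
    The session-based inference below and the random data are assumptions for
    illustration; real calibration should feed representative samples.
    """
    input_tensor = model.input_tensor
    output_tensor = model.output_tensor if len(model.output_tensor) > 1 else model.output_tensor[0]
    for _ in range(10):  # plays the role of calib_iteration
        dummy_batch = np.random.random((1, 224, 224, 3)).astype(np.float32)
        model.sess.run(output_tensor, {input_tensor[0]: dummy_batch})


quant_config = StaticQuantConfig()
# A frozen pb model is assumed here: the Keras backend and smooth quant still reject calib_func.
qmodel = quantize_model("./fp32_model.pb", quant_config, calib_func=calib_func)
```

Because `_inference` returns right after calling `calib_func`, `calib_dataloader` can be omitted entirely in this mode; the TF adaptor then falls back to its default calibration sampling size of 100.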