fix param
Signed-off-by: zehao-intel <[email protected]>
zehao-intel committed Jul 17, 2024
commit 29404385714713e67438937f0ea491d8432b4f5c
13 changes: 7 additions & 6 deletions docs/3x/TF_Quant.md
@@ -2,12 +2,13 @@
TensorFlow Quantization
===============

1. [Introduction](#introduction)
2. [Usage](#usage)
2.1 [Without Accuracy Aware Tuning](#without-accuracy-aware-tuning)
2.2 [With Accuracy Aware Tuning](#with-accuracy-aware-tuning)
2.3 [Specify Quantization Rules](#specify-quantization-rules)
3. [Examples](#examples)
- [TensorFlow Quantization](#tensorflow-quantization)
- [Introduction](#introduction)
- [Get Started](#get-started)
- [Without Accuracy Aware Tuning](#without-accuracy-aware-tuning)
- [With Accuracy Aware Tuning](#with-accuracy-aware-tuning)
- [Specify Quantization Rules](#specify-quantization-rules)
- [Examples](#examples)

## Introduction

12 changes: 8 additions & 4 deletions neural_compressor/tensorflow/algorithms/smoother/core.py
@@ -37,19 +37,23 @@ class SmoothQuant:
def __init__(
self,
config: SmoothQuantConfig,
calib_dataloader: Callable,
calib_dataloader: Callable = None,
calib_iteration: int = 1,
calib_func: Callable = None,
):
"""Convert the model by smooth quant.

Args:
config: the SmoothQuantConfig class used to set this class
calibdataloader: the calibration dataloader
calib_iteration: how many steps of iterations on the dataloader to move forward
config: the SmoothQuantConfig class used to set this class.
calib_dataloader: the calibration dataloader.
calib_iteration: how many steps of iterations on the dataloader to move forward.
calib_func: the function used for calibration; it should be a substitute for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
model: a smoothed TensorFlow model.
"""
assert calib_func is None, "calibration function is not supported for smooth quant."
self.config = config
self.calib_dataloader = calib_dataloader
self.calib_iteration = calib_iteration
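With this change, `SmoothQuant.__init__` accepts an optional `calib_func` argument but immediately asserts that it is `None`, so smooth quant still requires a calibration dataloader. Below is a minimal sketch of constructing the class under that constraint; the `SmoothQuantConfig` import path, `my_calib_dataloader`, and `fp32_model` are assumptions for illustration, not part of this diff:

```python
from neural_compressor.tensorflow.algorithms import SmoothQuant
from neural_compressor.tensorflow import SmoothQuantConfig  # import path assumed

sq_config = SmoothQuantConfig()               # default smooth-quant settings
converter = SmoothQuant(
    config=sq_config,
    calib_dataloader=my_calib_dataloader,     # placeholder calibration dataloader
    calib_iteration=1,
)
smoothed_model = converter(fp32_model)        # fp32_model is a placeholder TF model

# Supplying a calib_func instead would trip the new assertion:
# SmoothQuant(sq_config, calib_func=my_calib_func)  # AssertionError
```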
5 changes: 4 additions & 1 deletion neural_compressor/tensorflow/quantization/algorithm_entry.py
@@ -64,6 +64,7 @@ def smooth_quant_entry(
smooth_quant_config: SmoothQuantConfig,
calib_dataloader: Callable = None,
calib_iteration: int = 100,
calib_func: Callable = None,
):
"""The main entry to apply smooth quantization.

@@ -72,6 +73,8 @@
quant_config: a quantization configuration.
calib_dataloader: a data loader for calibration.
calib_iteration: the iteration of calibration.
calib_func: the function used for calibration; it should be a substitute for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
q_model: the quantized model.
@@ -80,7 +83,7 @@

from neural_compressor.tensorflow.algorithms import SmoothQuant

converter = SmoothQuant(smooth_quant_config, calib_dataloader, calib_iteration)
converter = SmoothQuant(smooth_quant_config, calib_dataloader, calib_iteration, calib_func)
sq_model = converter(model)

return sq_model
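Here `smooth_quant_entry` simply forwards the new argument to `SmoothQuant`. Conceptually, a `calib_func` is a user-supplied callable that runs inference on the model over representative data so calibration statistics can be collected when a dataloader is impractical. A hedged sketch of what such a function might look like; `representative_dataset` and the step count are placeholders, not defined by this PR:

```python
def my_calib_func(model):
    """Run a few forward passes so the quantizer can observe activations.

    representative_dataset is assumed to be an iterable (e.g. a tf.data.Dataset)
    yielding preprocessed input batches; it is not part of this diff.
    """
    for step, batch in enumerate(representative_dataset):
        model(batch)      # forward pass only; outputs are discarded
        if step >= 10:    # arbitrary small number of calibration steps
            break
```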
12 changes: 9 additions & 3 deletions neural_compressor/tensorflow/quantization/quantize.py
@@ -32,6 +32,7 @@ def quantize_model(
quant_config: Union[BaseConfig, list],
calib_dataloader: Callable = None,
calib_iteration: int = 100,
calib_func: Callable = None,
):
"""The main entry to quantize model.

@@ -40,16 +41,18 @@
quant_config: single or lists of quantization configuration.
calib_dataloader: a data loader for calibration.
calib_iteration: the iteration of calibration.
calib_func: the function used for calibration; it should be a substitute for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
q_model: the quantized model.
"""
q_model = Model(model)
if isinstance(quant_config, list):
for config in quant_config:
q_model = quantize_model_with_single_config(q_model, config, calib_dataloader, calib_iteration)
q_model = quantize_model_with_single_config(q_model, config, calib_dataloader, calib_iteration, calib_func)
else:
q_model = quantize_model_with_single_config(q_model, quant_config, calib_dataloader, calib_iteration)
q_model = quantize_model_with_single_config(q_model, quant_config, calib_dataloader, calib_iteration, calib_func)

return q_model

@@ -59,6 +62,7 @@ def quantize_model_with_single_config(
quant_config: BaseConfig,
calib_dataloader: Callable = None,
calib_iteration: int = 100,
calib_func: Callable = None,
):
"""Quantize model using single config.

@@ -67,6 +71,8 @@
quant_config: a quantization configuration.
calib_dataloader: a data loader for calibration.
calib_iteration: the iteration of calibration.
calib_func: the function used for calibration; it should be a substitute for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
q_model: the quantized model.
@@ -89,5 +95,5 @@
for algo_name, algo_func in algos_mapping.items():
if need_apply(configs_mapping, algo_name):
logger.info(f"Start to apply {algo_name} on the model.")
q_model = algo_func(q_model, configs_mapping, calib_dataloader, calib_iteration)
q_model = algo_func(q_model, configs_mapping, calib_dataloader, calib_iteration, calib_func)
return q_model
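End to end, the new parameter flows from `quantize_model` through `quantize_model_with_single_config` into each registered algorithm entry. A minimal usage sketch follows; the import paths, `StaticQuantConfig`, `fp32_model`, and `my_calib_func` (from the sketch above) are assumptions rather than part of this diff:

```python
from neural_compressor.tensorflow import StaticQuantConfig, quantize_model  # paths assumed

quant_config = StaticQuantConfig()   # assumed default static quantization config
q_model = quantize_model(
    model=fp32_model,                # placeholder FP32 TensorFlow model
    quant_config=quant_config,
    calib_func=my_calib_func,        # used in place of a calib_dataloader
)
```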