Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions neural_compressor/tensorflow/algorithms/smoother/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class SmoothQuant:
def __init__(
self,
config: SmoothQuantConfig,
calib_dataloader: Callable=None,
calib_dataloader: Callable = None,
calib_iteration: int = 1,
calib_func: Callable=None,
calib_func: Callable = None,
):
"""Convert the model by smooth quant.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def quantize(
model: the fp32 model to be quantized.
calib_dataloader: a data loader for calibration.
calib_iteration: the iteration of calibration.
calib_func: the function used for calibration, should be a substitution for calib_dataloader
calib_func: the function used for calibration, should be a substitution for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
Expand Down Expand Up @@ -759,7 +759,7 @@ def quantize(
model: the fp32 model to be quantized.
calib_dataloader: a data loader for calibration.
calib_iteration: the iteration of calibration.
calib_func: the function used for calibration, should be a substitution for calib_dataloader
calib_func: the function used for calibration, should be a substitution for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
Expand Down
4 changes: 2 additions & 2 deletions neural_compressor/tensorflow/quantization/algorithm_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def static_quant_entry(
quant_config: a quantization configuration.
calib_dataloader: a data loader for calibration.
calib_iteration: the iteration of calibration.
calib_func: the function used for calibration, should be a substitution for calib_dataloader
calib_func: the function used for calibration, should be a substitution for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
Expand Down Expand Up @@ -73,7 +73,7 @@ def smooth_quant_entry(
quant_config: a quantization configuration.
calib_dataloader: a data loader for calibration.
calib_iteration: the iteration of calibration.
calib_func: the function used for calibration, should be a substitution for calib_dataloader
calib_func: the function used for calibration, should be a substitution for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
Expand Down
4 changes: 3 additions & 1 deletion neural_compressor/tensorflow/quantization/autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def autotune(
logger.info("Re-quantizing with best quantization config...")
del q_model
best_quant_config: BaseConfig = best_trial_record.quant_config
best_quant_model = quantize_model(model, best_quant_config, calib_dataloader, calib_iteration, calib_func)
best_quant_model = quantize_model(
model, best_quant_config, calib_dataloader, calib_iteration, calib_func
)
else:
best_quant_model = q_model
break
Expand Down
8 changes: 5 additions & 3 deletions neural_compressor/tensorflow/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def quantize_model(
quant_config: single or lists of quantization configuration.
calib_dataloader: a data loader for calibration.
calib_iteration: the iteration of calibration.
calib_func: the function used for calibration, should be a substitution for calib_dataloader
calib_func: the function used for calibration, should be a substitution for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
Expand All @@ -52,7 +52,9 @@ def quantize_model(
for config in quant_config:
q_model = quantize_model_with_single_config(q_model, config, calib_dataloader, calib_iteration, calib_func)
else:
q_model = quantize_model_with_single_config(q_model, quant_config, calib_dataloader, calib_iteration, calib_func)
q_model = quantize_model_with_single_config(
q_model, quant_config, calib_dataloader, calib_iteration, calib_func
)

return q_model

Expand All @@ -71,7 +73,7 @@ def quantize_model_with_single_config(
quant_config: a quantization configuration.
calib_dataloader: a data loader for calibration.
calib_iteration: the iteration of calibration.
calib_func: the function used for calibration, should be a substitution for calib_dataloader
calib_func: the function used for calibration, should be a substitution for calib_dataloader
when the built-in calibration function of INC does not work for model inference.

Returns:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _inference(self, model):
if self.calib_func:
self.calib_func(model)
return

if model.model_type == "llm_saved_model":
self._inference_llm(model)
return
Expand Down
8 changes: 4 additions & 4 deletions test/3x/tensorflow/test_quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from neural_compressor.common import logger
from neural_compressor.tensorflow.utils import version1_gte_version2


def build_model():
# Load MNIST dataset
mnist = keras.datasets.mnist
Expand Down Expand Up @@ -110,8 +111,7 @@ def __len__(self):

def evaluate(model):
input_tensor = model.input_tensor
output_tensor = model.output_tensor if len(model.output_tensor)>1 else \
model.output_tensor[0]
output_tensor = model.output_tensor if len(model.output_tensor) > 1 else model.output_tensor[0]

iteration = -1
calib_dataloader = MyDataloader(dataset=Dataset())
Expand Down Expand Up @@ -152,9 +152,9 @@ def test_calib_func(self):
if "Quantized" in node.op:
quantized = True
break

self.assertEqual(quantized, True)


if __name__ == "__main__":
unittest.main()
unittest.main()