Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add common ut
Signed-off-by: yiliu30 <[email protected]>
  • Loading branch information
yiliu30 committed Jun 18, 2024
commit 6b316e6cb7a531430355a8e2d07a94a29968d9e6
3 changes: 2 additions & 1 deletion neural_compressor/tensorflow/quantization/autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from neural_compressor.common import logger
from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
from neural_compressor.common.utils import dump_elapsed_time
from neural_compressor.common.utils import call_counter, dump_elapsed_time
from neural_compressor.tensorflow.quantization import quantize_model
from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME, StaticQuantConfig
from neural_compressor.tensorflow.utils import BaseModel, Model, constants
Expand All @@ -36,6 +36,7 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:


@dump_elapsed_time("Pass auto-tune")
@call_counter
def autotune(
model: Union[str, tf.keras.Model, BaseModel],
tune_config: TuningConfig,
Expand Down
19 changes: 19 additions & 0 deletions test/3x/common/test_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import unittest
from unittest.mock import MagicMock, patch

import neural_compressor.common.utils.utility as inc_utils
from neural_compressor.common import options
from neural_compressor.common.utils import (
CpuInfo,
Expand Down Expand Up @@ -166,5 +167,23 @@ def __init__(self):
assert instance2.value == 1, "Singleton should return the same instance"


class TestCallCounter(unittest.TestCase):
def test_call_counter(self):
@inc_utils.call_counter
def add(a, b):
return a + b

# Initial count should be 0
self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 0)

# Call the function multiple times
add(1, 2)
add(3, 4)
add(5, 6)

# Count should be incremented accordingly
self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 3)


if __name__ == "__main__":
unittest.main()