13 changes: 13 additions & 0 deletions .gitignore
@@ -102,3 +102,16 @@ venv.bak/

# mypy
.mypy_cache/

# Claude settings
.claude/*

# IDE files
.idea/
.vscode/
*.swp
*.swo
*~

# Poetry
# poetry.lock is NOT ignored - it should be committed
282 changes: 282 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

81 changes: 81 additions & 0 deletions pyproject.toml
@@ -0,0 +1,81 @@
[tool.poetry]
name = "warmup-scheduler"
version = "0.3.2"
description = "Gradually Warm-up LR Scheduler for Pytorch"
authors = ["Your Name <[email protected]>"]
license = "MIT"
readme = "README.md"
homepage = "https://github.com/ildoonet/pytorch-gradual-warmup-lr"
repository = "https://github.com/ildoonet/pytorch-gradual-warmup-lr"
packages = [{include = "warmup_scheduler"}]

[tool.poetry.dependencies]
python = "^3.8"

[tool.poetry.group.dev.dependencies]
pytest = "^8.0.0"
pytest-cov = "^5.0.0"
pytest-mock = "^3.14.0"

[tool.poetry.scripts]
test = "pytest:main"
tests = "pytest:main"

[tool.pytest.ini_options]
minversion = "8.0"
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = [
    "-ra",
    "--strict-markers",
    "--cov=warmup_scheduler",
    "--cov-report=term-missing",
    "--cov-report=html",
    "--cov-report=xml",
    "-vv"
]
markers = [
    "unit: Unit tests",
    "integration: Integration tests",
    "slow: Slow tests"
]

[tool.coverage.run]
source = ["warmup_scheduler"]
omit = [
    "*/tests/*",
    "*/__pycache__/*",
    "*/venv/*",
    "*/virtualenv/*",
    "*/.venv/*",
    "*/site-packages/*"
]

[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "def __repr__",
    "if self.debug:",
    "if settings.DEBUG",
    "raise AssertionError",
    "raise NotImplementedError",
    "if 0:",
    "if __name__ == .__main__.:",
    "class .*\\bProtocol\\):",
    "@(abc\\.)?abstractmethod"
]
skip_empty = true
precision = 2
fail_under = 80

[tool.coverage.html]
directory = "htmlcov"

[tool.coverage.xml]
output = "coverage.xml"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Empty file added tests/__init__.py
Empty file.
130 changes: 130 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,130 @@
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import Mock, patch
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))


@pytest.fixture
def temp_dir():
    """Create a temporary directory that is cleaned up after the test."""
    temp_path = tempfile.mkdtemp()
    yield Path(temp_path)
    shutil.rmtree(temp_path)


@pytest.fixture
def mock_config():
    """Provide a mock configuration object."""
    config = Mock()
    config.warmup_epochs = 5
    config.base_lr = 0.001
    config.warmup_factor = 0.1
    config.total_epochs = 100
    return config


@pytest.fixture
def mock_optimizer():
    """Provide a mock optimizer for testing schedulers."""
    optimizer = Mock()
    optimizer.param_groups = [{'lr': 0.001}, {'lr': 0.01}]
    return optimizer


@pytest.fixture
def mock_scheduler():
    """Provide a mock base scheduler."""
    scheduler = Mock()
    scheduler.base_lrs = [0.001, 0.01]
    scheduler.last_epoch = -1
    scheduler.get_lr = Mock(return_value=[0.001, 0.01])
    scheduler.get_last_lr = Mock(return_value=[0.001, 0.01])
    scheduler.step = Mock()
    return scheduler


@pytest.fixture
def capture_output():
    """Capture stdout and stderr for testing print statements."""
    import io
    from contextlib import redirect_stdout, redirect_stderr

    class OutputCapture:
        def __init__(self):
            self.stdout = io.StringIO()
            self.stderr = io.StringIO()

        def __enter__(self):
            self._stdout_context = redirect_stdout(self.stdout)
            self._stderr_context = redirect_stderr(self.stderr)
            self._stdout_context.__enter__()
            self._stderr_context.__enter__()
            return self

        def __exit__(self, *args):
            self._stdout_context.__exit__(*args)
            self._stderr_context.__exit__(*args)

        def get_stdout(self):
            return self.stdout.getvalue()

        def get_stderr(self):
            return self.stderr.getvalue()

    return OutputCapture()


@pytest.fixture
def mock_torch_module():
    """Mock torch module for testing without PyTorch dependency."""
    with patch.dict('sys.modules', {'torch': Mock(), 'torch.optim': Mock(), 'torch.optim.lr_scheduler': Mock()}):
        yield


@pytest.fixture(autouse=True)
def reset_imports():
    """Reset imports between tests to ensure clean state."""
    modules_to_remove = [m for m in sys.modules if m.startswith('warmup_scheduler')]
    for module in modules_to_remove:
        del sys.modules[module]
    yield


@pytest.fixture
def sample_data():
    """Provide sample data for testing."""
    return {
        'learning_rates': [0.001, 0.01, 0.1],
        'epochs': list(range(100)),
        'warmup_epochs': 5,
        'multiplier': 8,
        'total_epoch': 100
    }


@pytest.fixture
def environment_vars():
    """Temporarily set environment variables for testing."""
    original_env = os.environ.copy()

    def _set_env(**kwargs):
        os.environ.update(kwargs)
        return os.environ

    yield _set_env

    os.environ.clear()
    os.environ.update(original_env)


@pytest.fixture
def mock_time():
    """Mock time-related functions for deterministic testing."""
    with patch('time.time', return_value=1234567890):
        with patch('time.sleep'):
            yield
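
These fixtures are torch-free, so scheduler math can be unit-tested without PyTorch installed. A hypothetical test using `mock_optimizer` and `sample_data`, assuming the linear warm-up rule lr = base_lr * ((multiplier - 1) * epoch / total_epoch + 1) used by the upstream GradualWarmupScheduler:

import pytest


@pytest.mark.unit
def test_linear_warmup_progression(mock_optimizer, sample_data):
    multiplier = sample_data['multiplier']        # 8
    total_epoch = sample_data['warmup_epochs']    # 5
    base_lrs = [group['lr'] for group in mock_optimizer.param_groups]

    def warmup_lr(base_lr, epoch):
        # Linear interpolation from base_lr up to base_lr * multiplier.
        return base_lr * ((multiplier - 1.0) * epoch / total_epoch + 1.0)

    # Epoch 0 starts at the base LR; total_epoch reaches base_lr * multiplier.
    assert [warmup_lr(lr, 0) for lr in base_lrs] == base_lrs
    assert [warmup_lr(lr, total_epoch) for lr in base_lrs] == [lr * multiplier for lr in base_lrs]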
Empty file added tests/integration/__init__.py
Empty file.
131 changes: 131 additions & 0 deletions tests/test_setup_validation.py
@@ -0,0 +1,131 @@
import pytest
import sys
import os
from pathlib import Path


class TestSetupValidation:
    """Validation tests to ensure the testing infrastructure is properly configured."""

    def test_pytest_import(self):
        """Test that pytest can be imported."""
        import pytest
        assert pytest is not None

    def test_project_importable(self):
        """Test that the main package can be imported."""
        try:
            import warmup_scheduler
            assert warmup_scheduler.__version__ == '0.3.2'
        except ModuleNotFoundError as e:
            if "torch" in str(e):
                pytest.skip("PyTorch not installed - this is expected for infrastructure setup")
            else:
                raise

    def test_gradual_warmup_scheduler_importable(self):
        """Test that GradualWarmupScheduler can be imported."""
        try:
            from warmup_scheduler import GradualWarmupScheduler
            assert GradualWarmupScheduler is not None
        except ModuleNotFoundError as e:
            if "torch" in str(e):
                pytest.skip("PyTorch not installed - this is expected for infrastructure setup")
            else:
                raise

    @pytest.mark.unit
    def test_unit_marker(self):
        """Test that unit test marker works."""
        assert True

    @pytest.mark.integration
    def test_integration_marker(self):
        """Test that integration test marker works."""
        assert True

    @pytest.mark.slow
    def test_slow_marker(self):
        """Test that slow test marker works."""
        assert True

    def test_temp_dir_fixture(self, temp_dir):
        """Test that temp_dir fixture works correctly."""
        assert temp_dir.exists()
        assert temp_dir.is_dir()

        test_file = temp_dir / "test.txt"
        test_file.write_text("Hello, World!")
        assert test_file.exists()
        assert test_file.read_text() == "Hello, World!"

    def test_mock_config_fixture(self, mock_config):
        """Test that mock_config fixture provides expected attributes."""
        assert mock_config.warmup_epochs == 5
        assert mock_config.base_lr == 0.001
        assert mock_config.warmup_factor == 0.1
        assert mock_config.total_epochs == 100

    def test_mock_optimizer_fixture(self, mock_optimizer):
        """Test that mock_optimizer fixture works correctly."""
        assert len(mock_optimizer.param_groups) == 2
        assert mock_optimizer.param_groups[0]['lr'] == 0.001
        assert mock_optimizer.param_groups[1]['lr'] == 0.01

    def test_capture_output_fixture(self, capture_output):
        """Test that capture_output fixture works correctly."""
        with capture_output as capture:
            print("Hello stdout")
            print("Hello stderr", file=sys.stderr)

        assert "Hello stdout" in capture.get_stdout()
        assert "Hello stderr" in capture.get_stderr()

    def test_sample_data_fixture(self, sample_data):
        """Test that sample_data fixture provides expected data."""
        assert 'learning_rates' in sample_data
        assert 'epochs' in sample_data
        assert len(sample_data['learning_rates']) == 3
        assert len(sample_data['epochs']) == 100

    def test_environment_vars_fixture(self, environment_vars):
        """Test that environment_vars fixture works correctly."""
        original_value = os.environ.get('TEST_VAR', None)

        environment_vars(TEST_VAR='test_value')
        assert os.environ['TEST_VAR'] == 'test_value'

        # After test, environment should be restored
        # (this will be checked after the fixture cleanup)

    def test_coverage_configured(self):
        """Test that coverage is properly configured."""
        try:
            import coverage
            assert coverage is not None
        except ImportError:
            pytest.skip("Coverage not yet installed")

    def test_test_directories_exist(self):
        """Test that test directory structure exists."""
        test_root = Path(__file__).parent
        assert test_root.exists()
        assert (test_root / "unit").exists()
        assert (test_root / "integration").exists()
        assert (test_root / "__init__.py").exists()
        assert (test_root / "unit" / "__init__.py").exists()
        assert (test_root / "integration" / "__init__.py").exists()

    def test_conftest_loaded(self):
        """Test that conftest.py is properly loaded."""
        # If we can use the fixtures, conftest is loaded
        assert True

    @pytest.mark.parametrize("value,expected", [
        (1, 1),
        (2, 2),
        (3, 3),
    ])
    def test_parametrize_works(self, value, expected):
        """Test that pytest parametrize decorator works."""
        assert value == expected
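
The restore-after-teardown behavior that `test_environment_vars_fixture` alludes to can only be observed from a second test. A sketch of that check, assuming default file-order execution and that TEST_VAR is unset in the surrounding environment:

import os
import pytest


@pytest.mark.unit
def test_env_var_set(environment_vars):
    environment_vars(TEST_VAR='test_value')
    assert os.environ['TEST_VAR'] == 'test_value'


@pytest.mark.unit
def test_env_var_restored():
    # Runs after the test above; the fixture teardown has already
    # restored os.environ, so the variable must be gone again.
    assert 'TEST_VAR' not in os.environ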
Empty file added tests/unit/__init__.py
Empty file.