diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
new file mode 100644
index 0000000..2c86851
--- /dev/null
+++ b/.github/workflows/pytest.yml
@@ -0,0 +1,74 @@
+name: Test
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: ["*"]
+  workflow_dispatch:  # allows you to trigger manually
+
+# When this workflow is queued, automatically cancel any previous running
+# or pending jobs from the same branch
+concurrency:
+  group: pytest-${{ github.ref }}
+  cancel-in-progress: true
+
+defaults:
+  run:
+    shell: bash -l {0}
+
+jobs:
+  test:
+    name: ${{ matrix.os }} Python ${{ matrix.python-version }} NumPy ${{ matrix.numpy-version }}
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.11", "3.14"]
+        numpy-version: ["2.3.0", latest]
+        include:
+          # Test oldest supported Python and NumPy versions
+          - os: ubuntu-latest
+            python-version: "3.8"
+            numpy-version: "1.18.0"
+          # Test vs. NumPy nightly wheels
+          - os: ubuntu-latest
+            python-version: "3.14"
+            numpy-version: "nightly"
+          # Test issues re. preinstalled SSL certificates on different OSes
+          - os: windows-latest
+            python-version: "3.14"
+            numpy-version: latest
+          - os: macos-latest
+            python-version: "3.14"
+            numpy-version: latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v6
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install pinned NumPy
+        if: matrix.numpy-version != 'latest' && matrix.numpy-version != 'nightly'
+        run: python -m pip install numpy==${{ matrix.numpy-version }}
+
+      - name: Install nightly NumPy wheels
+        if: matrix.numpy-version == 'nightly'
+        run: pip install --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple/ numpy
+
+      - name: Install package
+        run: pip install .
+
+      - name: Smoke test
+        run: python -c "import ml_datasets"
+
+      - name: Install test dependencies
+        run: pip install pytest
+
+      - name: Run tests
+        run: pytest
diff --git a/ml_datasets/__init__.py b/ml_datasets/__init__.py
index b4d6d89..56fe965 100644
--- a/ml_datasets/__init__.py
+++ b/ml_datasets/__init__.py
@@ -8,3 +8,5 @@
 from .loaders.universal_dependencies import ud_ancora_pos_tags, ud_ewtb_pos_tags
 from .loaders.dbpedia import dbpedia
 from .loaders.cmu import cmu
+from .loaders.cifar import cifar
+from .loaders.wikiner import wikiner
diff --git a/ml_datasets/loaders/__init__.py b/ml_datasets/loaders/__init__.py
index 747d385..07add6b 100644
--- a/ml_datasets/loaders/__init__.py
+++ b/ml_datasets/loaders/__init__.py
@@ -1,3 +1,4 @@
+from .cifar import cifar
 from .cmu import cmu
 from .dbpedia import dbpedia
 from .imdb import imdb
diff --git a/ml_datasets/loaders/reuters.py b/ml_datasets/loaders/reuters.py
index 367abfc..20428ef 100644
--- a/ml_datasets/loaders/reuters.py
+++ b/ml_datasets/loaders/reuters.py
@@ -4,8 +4,8 @@
 from ..util import get_file
 from .._registry import register_loader
 
-
-URL = "https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl"
+URL = "https://s3.amazonaws.com/text-datasets/reuters.pkl"
+WORD_INDEX_URL = "https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl"
 
 
 @register_loader("reuters")
@@ -15,7 +15,7 @@ def reuters():
 
 
 def get_word_index(path="reuters_word_index.pkl"):
-    path = get_file(path, origin=URL)
+    path = get_file(path, origin=WORD_INDEX_URL)
     f = open(path, "rb")
     data = pickle.load(f, encoding="latin1")
     f.close()
@@ -60,7 +60,7 @@ def load_reuters(
     # https://raw.githubusercontent.com/fchollet/keras/master/keras/datasets/mnist.py
     # Copyright Francois Chollet, Google, others (2015)
     # Under MIT license
-    path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl")
+    path = get_file(path, origin=URL)
     f = open(path, "rb")
     X, labels = pickle.load(f)
     f.close()
diff --git a/ml_datasets/test/__init__.py b/ml_datasets/test/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ml_datasets/test/test_datasets.py b/ml_datasets/test/test_datasets.py
new file mode 100644
index 0000000..a27dbc0
--- /dev/null
+++ b/ml_datasets/test/test_datasets.py
@@ -0,0 +1,81 @@
+# TODO the tests below only verify that the various functions don't crash.
+# Expand them to test the actual output contents.
+
+import platform
+
+import pytest
+import numpy as np
+
+import ml_datasets
+
+NP_VERSION = tuple(int(x) for x in np.__version__.split(".")[:2])
+
+# FIXME warning on NumPy 2.4 when downloading pre-computed pickles:
+# Python or NumPy boolean but got `align=0`.
+# Did you mean to pass a tuple to create a subarray type? (Deprecated NumPy 2.4)
+if NP_VERSION >= (2, 4):
+    np_24_deprecation = pytest.mark.filterwarnings(
+        "ignore::numpy.exceptions.VisibleDeprecationWarning",
+
+    )
+else:
+    # Note: can't use `condition=NP_VERSION >= (2, 4)` on the decorator directly
+    # as numpy.exceptions did not exist in old NumPy versions.
+    np_24_deprecation = lambda x: x
+
+
+@np_24_deprecation
+def test_cifar():
+    (X_train, y_train), (X_test, y_test) = ml_datasets.cifar()
+
+
+@pytest.mark.skip(reason="very slow download")
+def test_cmu():
+    train, dev = ml_datasets.cmu()
+
+
+def test_dbpedia():
+    train, dev = ml_datasets.dbpedia()
+
+
+def test_imdb():
+    train, dev = ml_datasets.imdb()
+
+
+@np_24_deprecation
+def test_mnist():
+    (X_train, y_train), (X_test, y_test) = ml_datasets.mnist()
+
+
+@pytest.mark.xfail(reason="403 Forbidden")
+def test_quora_questions():
+    train, dev = ml_datasets.quora_questions()
+
+
+@np_24_deprecation
+def test_reuters():
+    (X_train, y_train), (X_test, y_test) = ml_datasets.reuters()
+
+
+@pytest.mark.xfail(platform.system() == "Windows", reason="path issues")
+def test_snli():
+    train, dev = ml_datasets.snli()
+
+
+@pytest.mark.xfail(reason="no default path")
+def test_stack_exchange():
+    train, dev = ml_datasets.stack_exchange()
+
+
+def test_ud_ancora_pos_tags():
+    (train_X, train_y), (dev_X, dev_y) = ml_datasets.ud_ancora_pos_tags()
+
+
+@pytest.mark.xfail(reason="str column where int expected")
+def test_ud_ewtb_pos_tags():
+    (train_X, train_y), (dev_X, dev_y) = ml_datasets.ud_ewtb_pos_tags()
+
+
+@pytest.mark.xfail(reason="no default path")
+def test_wikiner():
+    train, dev = ml_datasets.wikiner()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..12e2298
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,9 @@
+[tool.pytest.ini_options]
+addopts = "--strict-markers --strict-config -v -r sxfE --color=yes --durations=10"
+xfail_strict = true
+filterwarnings = [
+    "error",
+    # FIXME spurious random download warnings; will cause trouble in downstream CI
+    "ignore:Implicitly cleaning up <TemporaryDirectory:ResourceWarning",
+]
diff --git a/requirements.txt b/requirements.txt
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 cloudpickle>=2.2
-numpy>=1.7.0
+numpy>=1.18
 scipy>=1.7.0
 tqdm>=4.10.0,<5.0.0
 # Our libraries
diff --git a/setup.cfg b/setup.cfg
index 9d40de5..30ea6a7 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -11,10 +11,10 @@ long_description_content_type = text/markdown
 [options]
 zip_safe = true
 include_package_data = true
-python_requires = >=3.6
+python_requires = >=3.8
 install_requires =
     cloudpickle>=2.2
-    numpy>=1.7.0
+    numpy>=1.18
     tqdm>=4.10.0,<5.0.0
     # Our libraries
     srsly>=1.0.1,<4.0.0