Merged
Add unit tests
crusaderky committed Dec 11, 2025
commit f7aebf654b8756e391f93d2a6a815ef156073f98
62 changes: 62 additions & 0 deletions .github/workflows/pytest.yml
@@ -0,0 +1,62 @@
name: Test

on:
  push:
    branches: [main]
  pull_request:
    branches: ["*"]
  workflow_dispatch: # allows you to trigger manually

# When this workflow is queued, automatically cancel any previous running
# or pending jobs from the same branch
concurrency:
  group: pytest-${{ github.ref }}
  cancel-in-progress: true

defaults:
  run:
    shell: bash -l {0}

jobs:
  test:
    name: Python ${{ matrix.python-version }} NumPy ${{ matrix.numpy-version }}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.11", "3.14"]
        numpy-version: ["2.3.0", latest]
        include:
          - python-version: "3.8"
            numpy-version: "1.18.0"
          - python-version: "3.14"
            numpy-version: "nightly"

    steps:
      - name: Checkout
        uses: actions/checkout@v6

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install pinned NumPy
        if: matrix.numpy-version != 'latest' && matrix.numpy-version != 'nightly'
        run: python -m pip install numpy==${{ matrix.numpy-version }}

      - name: Install nightly NumPy wheels
        if: matrix.numpy-version == 'nightly'
        run: pip install --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple/ numpy

      - name: Install package
        run: pip install .

      - name: Smoke test
        run: python -c "import ml_datasets"

      - name: Install test dependencies
        run: pip install pytest

      - name: Run tests
        run: pytest
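
For readers checking how many jobs this matrix produces, here is a minimal Python sketch of the expansion. It assumes a simplified model of `include`: an entry that sets every matrix key to a combination outside the base product is appended as an extra job. That holds for the two entries in this workflow, but it is not a full reimplementation of the GitHub Actions `include` rules. The `expand` helper is hypothetical and exists only for illustration.

```python
from itertools import product

# Values copied from the workflow's strategy.matrix section.
matrix = {
    "python-version": ["3.11", "3.14"],
    "numpy-version": ["2.3.0", "latest"],
}
include = [
    {"python-version": "3.8", "numpy-version": "1.18.0"},
    {"python-version": "3.14", "numpy-version": "nightly"},
]


def expand(matrix, include):
    """Return the list of job configurations (hypothetical helper)."""
    keys = list(matrix)
    # Base jobs: the Cartesian product of all matrix axes.
    jobs = [dict(zip(keys, values)) for values in product(*matrix.values())]
    # Simplifying assumption: each include entry names a full combination
    # that is not in the product, so it becomes an additional job.
    for extra in include:
        if extra not in jobs:
            jobs.append(dict(extra))
    return jobs


jobs = expand(matrix, include)
for job in jobs:
    print(f"Python {job['python-version']} NumPy {job['numpy-version']}")
```

Under that assumption the workflow schedules six jobs: the four base combinations plus the oldest-supported and nightly configurations.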
2 changes: 2 additions & 0 deletions ml_datasets/__init__.py
@@ -8,3 +8,5 @@
 from .loaders.universal_dependencies import ud_ancora_pos_tags, ud_ewtb_pos_tags
 from .loaders.dbpedia import dbpedia
 from .loaders.cmu import cmu
+from .loaders.cifar import cifar
+from .loaders.wikiner import wikiner
1 change: 1 addition & 0 deletions ml_datasets/loaders/__init__.py
@@ -1,3 +1,4 @@
+from .cifar import cifar
 from .cmu import cmu
 from .dbpedia import dbpedia
 from .imdb import imdb
8 changes: 4 additions & 4 deletions ml_datasets/loaders/reuters.py
@@ -4,8 +4,8 @@
 from ..util import get_file
 from .._registry import register_loader
 
-
-URL = "https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl"
+URL = "https://s3.amazonaws.com/text-datasets/reuters.pkl"
+WORD_INDEX_URL = "https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl"
 
 
 @register_loader("reuters")
@@ -15,7 +15,7 @@ def reuters():
 
 
 def get_word_index(path="reuters_word_index.pkl"):
-    path = get_file(path, origin=URL)
+    path = get_file(path, origin=WORD_INDEX_URL)
     f = open(path, "rb")
     data = pickle.load(f, encoding="latin1")
     f.close()
@@ -60,7 +60,7 @@ def load_reuters(
     # https://raw.githubusercontent.com/fchollet/keras/master/keras/datasets/mnist.py
     # Copyright Francois Chollet, Google, others (2015)
     # Under MIT license
-    path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl")
+    path = get_file(path, origin=URL)
     f = open(path, "rb")
     X, labels = pickle.load(f)
     f.close()
Empty file added ml_datasets/test/__init__.py
Empty file.
87 changes: 87 additions & 0 deletions ml_datasets/test/test_datasets.py
@@ -0,0 +1,87 @@
import pytest
import numpy as np

import ml_datasets

NP_VERSION = tuple(int(x) for x in np.__version__.split(".")[:2])

# FIXME warning on NumPy 2.4 when downloading pre-computed pickles:
# Python or NumPy boolean but got `align=0`.
# Did you mean to pass a tuple to create a subarray type? (Deprecated NumPy 2.4)

I think just re-saving the pickle files with numpy 2.4 will fix this. But then you'd need to replace the files in the AWS bucket...

Contributor Author

Probably. I'd open a follow-up ticket but the issues tab has been disabled for this repo.
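
A minimal sketch of the re-saving step suggested in this thread, for whoever picks it up. It assumes the pickles load with `encoding="latin1"`, as `get_word_index` does. `resave_pickle` and the file names are hypothetical. Whether a re-saved file actually silences the NumPy 2.4 deprecation has not been verified here, and replacing the copies in the bucket is a separate step.

```python
import pickle


def resave_pickle(src, dst, encoding="latin1"):
    """Load a pickle in the current environment and write it back out.

    Hypothetical helper: run it with the NumPy version whose format you
    want the new file to carry.
    """
    with open(src, "rb") as f:
        data = pickle.load(f, encoding=encoding)
    with open(dst, "wb") as f:
        # Protocol 4 is readable by every Python version this PR supports (3.8+).
        pickle.dump(data, f, protocol=4)
    return data


# Example call with placeholder paths:
# resave_pickle("reuters.pkl", "reuters_resaved.pkl")
```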

if NP_VERSION >= (2, 4):
    np_24_deprecation = pytest.mark.filterwarnings(
        "ignore::numpy.exceptions.VisibleDeprecationWarning",
    )
else:
    # Note: can't use `condition=NP_VERSION >= (2, 4)` on the decorator directly
    # as numpy.exceptions did not exist in old NumPy versions.
    np_24_deprecation = lambda x: x
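
The version gate above can be illustrated without NumPy or pytest. This is a sketch under stated assumptions: `major_minor`, `identity`, and `choose_decorator` are hypothetical names, and `make_marker` stands in for the `pytest.mark.filterwarnings(...)` call. The point is that the marker is only built on versions where the name it refers to exists; older versions get a no-op decorator instead.

```python
def major_minor(version):
    """Return (major, minor) from a dotted version string such as "2.4.0.dev0"."""
    return tuple(int(part) for part in version.split(".")[:2])


def identity(func):
    """No-op decorator used when the marker is not needed."""
    return func


def choose_decorator(version, make_marker, minimum=(2, 4)):
    """Call make_marker() only when version >= minimum; otherwise return identity.

    make_marker is a hypothetical zero-argument callable. Deferring the call
    means it never runs on versions where it would fail.
    """
    if major_minor(version) >= minimum:
        return make_marker()
    return identity


def unavailable():
    raise AttributeError("not available on this version")


# On an old version the factory is never called, so its failure is harmless.
decorator = choose_decorator("1.18.0", unavailable)
```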


@np_24_deprecation
def test_cifar():
    (X_train, y_train), (X_test, y_test) = ml_datasets.cifar()
    # TODO test output contents


@pytest.mark.skip(reason="very slow download")

if it's always going to be skipped, why add the test?

Contributor Author

To inform whoever reads it next that today the whole feature is functionally unusable.

Can you make that clear in the message or with a FIXME comment?

def test_cmu():
    train, dev = ml_datasets.cmu()
    # TODO test output contents


def test_dbpedia():
    train, dev = ml_datasets.dbpedia()
    # TODO test output contents


def test_imdb():
    train, dev = ml_datasets.imdb()
    # TODO test output contents


@np_24_deprecation
def test_mnist():
    (X_train, y_train), (X_test, y_test) = ml_datasets.mnist()
    # TODO test output contents


@pytest.mark.xfail(reason="403 Forbidden")
def test_quora_questions():
    train, dev = ml_datasets.quora_questions()
    # TODO test output contents


@np_24_deprecation
def test_reuters():
    (X_train, y_train), (X_test, y_test) = ml_datasets.reuters()
    # TODO test output contents


def test_snli():
    train, dev = ml_datasets.snli()
    # TODO test output contents


@pytest.mark.xfail(reason="no default path")
def test_stack_exchange():
    train, dev = ml_datasets.stack_exchange()
    # TODO test output contents


def test_ud_ancora_pos_tags():
    (train_X, train_y), (dev_X, dev_y) = ml_datasets.ud_ancora_pos_tags()
    # TODO test output contents


@pytest.mark.xfail(reason="str column where int expected")
def test_ud_ewtb_pos_tags():
    (train_X, train_y), (dev_X, dev_y) = ml_datasets.ud_ewtb_pos_tags()
    # TODO test output contents


@pytest.mark.xfail(reason="no default path")
def test_wikiner():
    train, dev = ml_datasets.wikiner()
    # TODO test output contents
9 changes: 9 additions & 0 deletions pyproject.toml
@@ -0,0 +1,9 @@
[tool.pytest.ini_options]
addopts = "--strict-markers --strict-config -v -r sxfE --color=yes --durations=10"
xfail_strict = true
filterwarnings = [
"error",
# FIXME spurious random download warnings; will cause trouble in downstream CI
"ignore:Implicitly cleaning up <HTTPError 403:ResourceWarning",
"ignore:unclosed <socket.socket:ResourceWarning",
]
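
The `filterwarnings` list above turns every warning into an error and then carves out two exceptions; in pytest, when several entries match, the last matching one wins. The following sketch mimics that behaviour with the standard `warnings` module only, so it runs without pytest. `classify` is a hypothetical helper, and the sample messages in the usage lines are made up for illustration.

```python
import warnings

# Patterns copied from the filterwarnings list above. Each one is a regular
# expression matched against the start of the warning message.
IGNORED_RESOURCE_WARNINGS = [
    "Implicitly cleaning up <HTTPError 403",
    "unclosed <socket.socket",
]


def classify(message, category):
    """Return "ignored" or "error" for one warning under the filters above."""
    with warnings.catch_warnings():
        warnings.simplefilter("error")  # baseline: every warning raises
        for pattern in IGNORED_RESOURCE_WARNINGS:
            # Each call inserts its filter at the front of the list,
            # so it takes precedence over the "error" baseline.
            warnings.filterwarnings(
                "ignore", message=pattern, category=ResourceWarning
            )
        try:
            warnings.warn(message, category)
        except Warning:
            return "error"
    return "ignored"


print(classify("unclosed <socket.socket fd=3>", ResourceWarning))
print(classify("unclosed file <demo.txt>", ResourceWarning))
```

Only `ResourceWarning` messages that start with one of the two patterns are let through; any other warning still fails the test run.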
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-numpy>=1.7.0
+numpy>=1.18
 scipy>=1.7.0
 tqdm>=4.10.0,<5.0.0
 # Our libraries
4 changes: 2 additions & 2 deletions setup.cfg
@@ -11,9 +11,9 @@ long_description_content_type = text/markdown
 [options]
 zip_safe = true
 include_package_data = true
-python_requires = >=3.6
+python_requires = >=3.8
 install_requires =
-    numpy>=1.7.0
+    numpy>=1.18
Comment on lines 14 to 17
Contributor Author

@crusaderky, Dec 11, 2025

Bump to lowest tested versions
python 3.8 is the lowest available for actions/setup-python
numpy 1.18 is the lowest available for python 3.8

     tqdm>=4.10.0,<5.0.0
     # Our libraries
     srsly>=1.0.1,<3.0.0