Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
name: Test

on:
push:
branches: [main]
pull_request:
branches: ["*"]
workflow_dispatch: # allows you to trigger manually

# When this workflow is queued, automatically cancel any previous running
# or pending jobs from the same branch
concurrency:
group: pytest-${{ github.ref }}
cancel-in-progress: true

defaults:
run:
shell: bash -l {0}

jobs:
test:
name: ${{ matrix.os }} Python ${{ matrix.python-version }} NumPy ${{ matrix.numpy-version}}
runs-on: ${{ matrix.os}}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.11", "3.14"]
numpy-version: ["2.3.0", latest]
include:
# Test oldest supported Python and NumPy versions
- os: ubuntu-latest
python-version: "3.8"
numpy-version: "1.18.0"
# Test vs. NumPy nightly wheels
- os: ubuntu-latest
python-version: "3.14"
numpy-version: "nightly"
# Test issues re. preinstalled SSL certificates on different OSes
- os: windows-latest
python-version: "3.14"
numpy-version: latest
- os: macos-latest
python-version: "3.14"
numpy-version: latest

steps:
- name: Checkout
uses: actions/checkout@v6

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install pinned NumPy
if: matrix.numpy-version != 'latest' && matrix.numpy-version != 'nightly'
run: python -m pip install numpy==${{ matrix.numpy-version }}

- name: Install nightly NumPy wheels
if: matrix.numpy-version == 'nightly'
run: pip install --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple/ numpy

- name: Install package
run: pip install .

- name: Smoke test
run: python -c "import ml_datasets"

- name: Install test dependencies
run: pip install pytest

- name: Run tests
run: pytest
2 changes: 2 additions & 0 deletions ml_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
from .loaders.universal_dependencies import ud_ancora_pos_tags, ud_ewtb_pos_tags
from .loaders.dbpedia import dbpedia
from .loaders.cmu import cmu
from .loaders.cifar import cifar
from .loaders.wikiner import wikiner
1 change: 1 addition & 0 deletions ml_datasets/loaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .cifar import cifar
from .cmu import cmu
from .dbpedia import dbpedia
from .imdb import imdb
Expand Down
8 changes: 4 additions & 4 deletions ml_datasets/loaders/reuters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from ..util import get_file
from .._registry import register_loader


URL = "https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl"
URL = "https://s3.amazonaws.com/text-datasets/reuters.pkl"
WORD_INDEX_URL = "https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl"


@register_loader("reuters")
Expand All @@ -15,7 +15,7 @@ def reuters():


def get_word_index(path="reuters_word_index.pkl"):
path = get_file(path, origin=URL)
path = get_file(path, origin=WORD_INDEX_URL)
f = open(path, "rb")
data = pickle.load(f, encoding="latin1")
f.close()
Expand Down Expand Up @@ -60,7 +60,7 @@ def load_reuters(
# https://raw.githubusercontent.com/fchollet/keras/master/keras/datasets/mnist.py
# Copyright Francois Chollet, Google, others (2015)
# Under MIT license
path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl")
path = get_file(path, origin=URL)
f = open(path, "rb")
X, labels = pickle.load(f)
f.close()
Expand Down
Empty file added ml_datasets/test/__init__.py
Empty file.
81 changes: 81 additions & 0 deletions ml_datasets/test/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# TODO the tests below only verify that the various functions don't crash.
# Expand them to test the actual output contents.

import platform

import pytest
import numpy as np

import ml_datasets

NP_VERSION = tuple(int(x) for x in np.__version__.split(".")[:2])

# FIXME warning on NumPy 2.4 when downloading pre-computed pickles:
# Python or NumPy boolean but got `align=0`.
# Did you mean to pass a tuple to create a subarray type? (Deprecated NumPy 2.4)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think just re-saving the pickle files with numpy 2.4 will fix this. But then you'd need to replace the files in the AWS bucket...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably. I'd open a follow-up ticket but the issues tab has been disabled for this repo.

if NP_VERSION >= (2, 4):
np_24_deprecation = pytest.mark.filterwarnings(
"ignore::numpy.exceptions.VisibleDeprecationWarning",

)
else:
# Note: can't use `condition=NP_VERSION >= (2, 4)` on the decorator directly
# as numpy.exceptions did not exist in old NumPy versions.
np_24_deprecation = lambda x: x


@np_24_deprecation
def test_cifar():
(X_train, y_train), (X_test, y_test) = ml_datasets.cifar()


@pytest.mark.skip(reason="very slow download")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if it's always going to be skipped, why add the test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To inform whoever reads it next that today the whole feature is functionally unusable.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make that clear in the message or with a FIXME comment?

def test_cmu():
train, dev = ml_datasets.cmu()


def test_dbpedia():
train, dev = ml_datasets.dbpedia()


def test_imdb():
train, dev = ml_datasets.imdb()


@np_24_deprecation
def test_mnist():
(X_train, y_train), (X_test, y_test) = ml_datasets.mnist()


@pytest.mark.xfail(reason="403 Forbidden")
def test_quora_questions():
train, dev = ml_datasets.quora_questions()


@np_24_deprecation
def test_reuters():
(X_train, y_train), (X_test, y_test) = ml_datasets.reuters()


@pytest.mark.xfail(platform.system() == "Windows", reason="path issues")
def test_snli():
train, dev = ml_datasets.snli()


@pytest.mark.xfail(reason="no default path")
def test_stack_exchange():
train, dev = ml_datasets.stack_exchange()


def test_ud_ancora_pos_tags():
(train_X, train_y), (dev_X, dev_y) = ml_datasets.ud_ancora_pos_tags()


@pytest.mark.xfail(reason="str column where int expected")
def test_ud_ewtb_pos_tags():
(train_X, train_y), (dev_X, dev_y) = ml_datasets.ud_ewtb_pos_tags()


@pytest.mark.xfail(reason="no default path")
def test_wikiner():
train, dev = ml_datasets.wikiner()
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[tool.pytest.ini_options]
addopts = "--strict-markers --strict-config -v -r sxfE --color=yes --durations=10"
xfail_strict = true
filterwarnings = [
"error",
# FIXME spurious random download warnings; will cause trouble in downstream CI
"ignore:Implicitly cleaning up <HTTPError 403:ResourceWarning",
"ignore:unclosed <socket.socket:ResourceWarning",
]
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cloudpickle>=2.2
numpy>=1.7.0
numpy>=1.18
scipy>=1.7.0
tqdm>=4.10.0,<5.0.0
# Our libraries
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ long_description_content_type = text/markdown
[options]
zip_safe = true
include_package_data = true
python_requires = >=3.6
python_requires = >=3.8
install_requires =
cloudpickle>=2.2
numpy>=1.7.0
numpy>=1.18
tqdm>=4.10.0,<5.0.0
# Our libraries
srsly>=1.0.1,<4.0.0
Expand Down