Add tests for downloading versioned data sets with target path
BuechlerA committed Jun 15, 2025
commit dd8face33086945374104c1fa491bada39c40508
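For context, the usage pattern these tests exercise, as a minimal sketch: `target_path` is the new keyword under test, the handle below is a placeholder, and the copy-versus-ignore split is inferred from the assertions in the diffs that follow.

```python
import kagglehub

# Placeholder handle; VERSIONED_DATASET_HANDLE plays this role in the tests.
HANDLE = "owner/dataset"

# Default behavior: resolve into kagglehub's managed cache and return that path.
cache_path = kagglehub.dataset_download(HANDLE)

# With target_path, the HTTP resolver is expected to copy the download under
# the given directory, while the Colab and Kaggle cache resolvers ignore it
# and still return the cache path.
copied_path = kagglehub.dataset_download(HANDLE, target_path="custom_target")
```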
16 changes: 16 additions & 0 deletions tests/test_colab_cache_dataset_download.py
@@ -1,4 +1,5 @@
import os
import shutil
from unittest import mock

import requests
@@ -73,6 +74,21 @@ def test_versioned_dataset_download_bad_handle_raises(self) -> None:
        with self.assertRaises(ValueError):
            kagglehub.dataset_download("bad handle")

    def test_versioned_dataset_download_with_target_path(self) -> None:
        with stub.create_env():
            target_dir = os.path.join(os.getcwd(), "custom_target")
            os.makedirs(target_dir, exist_ok=True)
            try:
                dataset_path = kagglehub.dataset_download(VERSIONED_DATASET_HANDLE, target_path=target_dir)
                # Colab cache resolver ignores target_path, so it should return the original path
                self.assertNotEqual(target_dir, os.path.dirname(dataset_path))
                # Check that the original dataset path has the expected version suffix
                self.assertTrue(dataset_path.endswith("/1"))
            finally:
                # Clean up
                if os.path.exists(target_dir):
                    shutil.rmtree(target_dir)


class TestNoInternetColabCacheModelDownload(BaseTestCase):
    def test_colab_resolver_skipped_when_dataset_not_present(self) -> None:
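Both cache-resolver tests in this PR (this one and the Kaggle-cache twin below) pin down the same contract: resolvers backed by a managed cache accept `target_path` but do not act on it. A minimal sketch of that contract, with illustrative names and path layout (`resolve_from_cache` is an assumption, not an actual kagglehub internal):

```python
import os
from typing import Optional


def resolve_from_cache(handle: str, target_path: Optional[str] = None) -> str:
    """Illustrative cache-backed resolver: target_path is accepted but ignored."""
    # A real resolver would look the handle up in its managed cache; this
    # placeholder only mirrors the ".../1" version suffix the test asserts on.
    cache_path = os.path.join("/kaggle/input", handle, "1")
    # target_path is deliberately unused, which is why the test expects the
    # returned path's parent directory to differ from target_dir.
    return cache_path
```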
55 changes: 49 additions & 6 deletions tests/test_http_dataset_download.py
@@ -48,26 +48,46 @@ def _download_dataset_and_assert_downloaded(
        dataset_handle: str,
        expected_subdir_or_subpath: str,
        expected_files: Optional[list[str]] = None,
        expected_target_path: Optional[str] = None,
        **kwargs,  # noqa: ANN003
    ) -> None:
        # Download the full dataset and ensure all files are there.
        dataset_path = kagglehub.dataset_download(dataset_handle, **kwargs)

        self.assertEqual(os.path.join(d, expected_subdir_or_subpath), dataset_path)
        # If target_path was specified, check that the files were copied there
        if expected_target_path:
            self.assertEqual(expected_target_path, dataset_path)
            self.assertTrue(os.path.exists(expected_target_path))
        else:
            self.assertEqual(os.path.join(d, expected_subdir_or_subpath), dataset_path)

        path_to_check = dataset_path

        if not expected_files:
            expected_files = ["foo.txt"]
        self.assertEqual(sorted(expected_files), sorted(os.listdir(dataset_path)))
        self.assertEqual(sorted(expected_files), sorted(os.listdir(path_to_check)))

        # Assert that the archive file has been deleted
        archive_path = get_cached_archive_path(parse_dataset_handle(dataset_handle))
        self.assertFalse(os.path.exists(archive_path))

    def _download_test_file_and_assert_downloaded(self, d: str, dataset_handle: str, **kwargs) -> None:  # noqa: ANN003
    def _download_test_file_and_assert_downloaded(
        self,
        d: str,
        dataset_handle: str,
        expected_target_path: Optional[str] = None,
        **kwargs,  # noqa: ANN003
    ) -> None:
        dataset_path = kagglehub.dataset_download(dataset_handle, path=TEST_FILEPATH, **kwargs)
        self.assertEqual(os.path.join(d, EXPECTED_DATASET_SUBPATH, TEST_FILEPATH), dataset_path)
        with open(dataset_path) as dataset_file:
            self.assertEqual(TEST_CONTENTS, dataset_file.read())

        if expected_target_path:
            self.assertEqual(expected_target_path, dataset_path)
            with open(dataset_path) as dataset_file:
                self.assertEqual(TEST_CONTENTS, dataset_file.read())
        else:
            self.assertEqual(os.path.join(d, EXPECTED_DATASET_SUBPATH, TEST_FILEPATH), dataset_path)
            with open(dataset_path) as dataset_file:
                self.assertEqual(TEST_CONTENTS, dataset_file.read())

    def _download_test_file_and_assert_downloaded_auto_compressed(
        self,
@@ -133,3 +153,26 @@ def test_unversioned_dataset_full_download_with_file_already_cached(self) -> None:
        # Download a single file first
        kagglehub.dataset_download(UNVERSIONED_DATASET_HANDLE, path=TEST_FILEPATH)
        self._download_dataset_and_assert_downloaded(d, UNVERSIONED_DATASET_HANDLE, EXPECTED_DATASET_SUBDIR)

    def test_versioned_dataset_download_with_target_path(self) -> None:
        with create_test_cache() as d:
            target_dir = os.path.join(d, "custom_target")
            os.makedirs(target_dir, exist_ok=True)
            self._download_dataset_and_assert_downloaded(
                d,
                VERSIONED_DATASET_HANDLE,
                EXPECTED_DATASET_SUBDIR,
                target_path=target_dir,
                expected_target_path=os.path.join(target_dir, os.path.basename(EXPECTED_DATASET_SUBPATH)),
            )

    def test_versioned_dataset_download_with_path_and_target_path(self) -> None:
        with create_test_cache() as d:
            target_dir = os.path.join(d, "custom_target")
            os.makedirs(target_dir, exist_ok=True)
            self._download_test_file_and_assert_downloaded(
                d,
                VERSIONED_DATASET_HANDLE,
                target_path=target_dir,
                expected_target_path=os.path.join(target_dir, TEST_FILEPATH),
            )
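Read together, the two new tests encode the copy semantics `target_path` presumably has in the HTTP resolver: a full download is expected at `target_path/<basename of the cached version directory>`, a single-file download at `target_path/<file>`. A minimal sketch of that copy step, inferred from the assertions rather than taken from the resolver code:

```python
import os
import shutil


def copy_to_target(cached_path: str, target_dir: str) -> str:
    """Sketch of the copy behavior the expected_target_path assertions imply."""
    destination = os.path.join(target_dir, os.path.basename(cached_path))
    if os.path.isdir(cached_path):
        # Full dataset: mirror the cached version directory,
        # yielding <target_dir>/<basename(EXPECTED_DATASET_SUBPATH)>.
        shutil.copytree(cached_path, destination, dirs_exist_ok=True)
    else:
        # Single file (path=... was passed): yields <target_dir>/<TEST_FILEPATH>.
        shutil.copy2(cached_path, destination)
    return destination
```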
16 changes: 16 additions & 0 deletions tests/test_kaggle_cache_dataset_download.py
@@ -1,4 +1,5 @@
import os
import shutil
from unittest import mock

import requests
@@ -84,3 +85,18 @@ def test_versioned_dataset_download_with_force_download_explicitly_false(self) -> None:
        with stub.create_env():
            dataset_path = kagglehub.dataset_download(VERSIONED_DATASET_HANDLE, force_download=False)
            self.assertEqual(["foo.txt"], sorted(os.listdir(dataset_path)))

    def test_versioned_dataset_download_with_target_path(self) -> None:
        with stub.create_env():
            target_dir = os.path.join(os.getcwd(), "custom_target")
            os.makedirs(target_dir, exist_ok=True)
            try:
                dataset_path = kagglehub.dataset_download(VERSIONED_DATASET_HANDLE, target_path=target_dir)
                # Kaggle cache resolver ignores target_path, so it should return the original path
                self.assertNotEqual(target_dir, os.path.dirname(dataset_path))
                # Check that the original dataset path contains the expected files
                self.assertEqual(["foo.txt"], sorted(os.listdir(dataset_path)))
            finally:
                # Clean up
                if os.path.exists(target_dir):
                    shutil.rmtree(target_dir)
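Since both managed-cache resolvers return the cache path unchanged, a caller who needs the files at a specific location in those environments would copy them out manually. A hedged example (handle and directory are placeholders):

```python
import os
import shutil

import kagglehub

# In Colab/Kaggle environments, dataset_download resolves into the managed
# cache and, per the tests above, leaves target_path unapplied.
dataset_path = kagglehub.dataset_download("owner/dataset")
target_dir = os.path.join(os.getcwd(), "custom_target")
# Copy the resolved files to where they are actually needed.
shutil.copytree(dataset_path, target_dir, dirs_exist_ok=True)
```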