Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat(models): add output-dir option to model_download
  • Loading branch information
KeijiBranshi committed Nov 4, 2024
commit 546959f2d8627cb78189b5e1d5c8e6afd9f5369b
29 changes: 26 additions & 3 deletions src/kagglehub/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from shutil import copytree
from typing import Optional, Union

from kagglehub import registry
Expand All @@ -13,22 +14,44 @@
DEFAULT_IGNORE_PATTERNS = [".git/", "*/.git/", ".cache/", ".huggingface/"]


def model_download(
    handle: str,
    path: Optional[str] = None,
    *,
    force_download: Optional[bool] = False,
    output_dir: Optional[str] = None,
) -> str:
    """Download model files.

    Args:
        handle: (string) the model handle.
        path: (string) Optional path to a file within the model bundle.
        force_download: (bool) Optional flag to force download a model, even if it's cached.
        output_dir: (str) Optional path to copy model files to after successful download.
            The downloaded files are copied, not moved, so the resolver's copy is kept.

    Returns:
        A string representing the path to the requested model files. When `output_dir`
        is given and the copy succeeds this is the copied location; if the copy fails,
        a warning is logged and the path returned by the resolver is returned instead.
    """
    # Function-scope imports so this block does not depend on the module's import list.
    import os
    import shutil

    # Design note (from review discussion): options are kept as plain keyword arguments
    # rather than a settings object to keep the call simple, and `output_dir` has no
    # explicit default because default cache placement is reportedly handled by the
    # resolver -- confirm against the resolver implementation.
    h = parse_model_handle(handle)
    logger.info(f"Downloading Model: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
    cached_path = registry.model_resolver(h, path, force_download=force_download)

    if output_dir is None:
        return cached_path

    logger.info(
        f"Copying model files to requested directory: {output_dir} ...",
        extra={**EXTRA_CONSOLE_BLOCK},
    )
    try:
        # only copying so that we can maintain the cached files
        if os.path.isfile(cached_path):
            # Presumably the case when `path` selects a single file; copytree() only
            # accepts directories, so copy the file into output_dir instead.
            os.makedirs(output_dir, exist_ok=True)
            return shutil.copy2(cached_path, output_dir)
        return copytree(cached_path, output_dir, dirs_exist_ok=True)
    except Exception as e:
        # Deliberately broad: the download itself succeeded, so any copy failure is
        # reported and the already-downloaded location is returned as a fallback.
        logger.warning(
            f"Successfully downloaded {handle}, but failed to copy from {cached_path} "
            f"to requested output directory {output_dir}. Encountered error: {e}"
        )
        return cached_path

def model_upload(
handle: str,
Expand Down
24 changes: 23 additions & 1 deletion tests/test_http_model_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from kagglehub.cache import MODELS_CACHE_SUBFOLDER, get_cached_archive_path
from kagglehub.handle import parse_model_handle
from tests.fixtures import BaseTestCase
from unittest import mock

from .server_stubs import model_download_stub as stub
from .server_stubs import serv
Expand Down Expand Up @@ -147,6 +148,28 @@ def test_versioned_model_download_with_path_with_force_download(self) -> None:
with create_test_cache() as d:
self._download_test_file_and_assert_downloaded(d, VERSIONED_MODEL_HANDLE, force_download=True)

def test_versioned_model_download_with_output_dir(self) -> None:
    """Download the versioned model with `output_dir` set and expect that directory back."""
    # Function-scope import so this block does not depend on the module's import list.
    import tempfile

    # A fresh temporary directory keeps the test portable and prevents files left by
    # an earlier run from leaking into this one (the copy uses dirs_exist_ok=True).
    with create_test_cache() as d, tempfile.TemporaryDirectory() as expected_output_dir:
        self._download_model_and_assert_downloaded(
            d,
            VERSIONED_MODEL_HANDLE,
            expected_output_dir,
            output_dir=expected_output_dir,
        )

def test_versioned_model_download_with_bad_output_dir(self) -> None:
    """Make the copy step fail and expect the default model directory to be returned."""
    bad_output_dir = "/bad/path/that/fails"
    expected_output_dir = EXPECTED_MODEL_SUBDIR  # falls back to default
    # mock.patch() only takes effect while entered as a context manager (or started);
    # a bare call builds the patcher and leaves copytree untouched.
    with create_test_cache() as d, mock.patch("kagglehub.models.copytree", side_effect=Exception()):
        self._download_model_and_assert_downloaded(
            d,
            VERSIONED_MODEL_HANDLE,
            expected_output_dir,
            output_dir=bad_output_dir,
        )

def test_unversioned_model_download_with_path_with_force_download(self) -> None:
    """Run the test-file download helper for the unversioned handle with force_download=True."""
    with create_test_cache() as cache_dir:
        self._download_test_file_and_assert_downloaded(
            cache_dir,
            UNVERSIONED_MODEL_HANDLE,
            force_download=True,
        )
Expand Down Expand Up @@ -188,7 +211,6 @@ def test_versioned_model_download_with_path_already_cached_with_force_download_e

self.assertEqual(os.path.join(d, EXPECTED_MODEL_SUBPATH), model_path)


class TestHttpNoInternet(BaseTestCase):
def test_versioned_model_download_already_cached_with_force_download(self) -> None:
with create_test_cache():
Expand Down