Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix(pr): lint + extra equality checks
  • Loading branch information
KeijiBranshi committed Nov 4, 2024
commit 0997d406b84c90d0ebad80ac10d44c99358cfc6f
4 changes: 2 additions & 2 deletions src/kagglehub/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def model_download(
handle: (string) the model handle.
path: (string) Optional path to a file within the model bundle.
force_download: (bool) Optional flag to force download a model, even if it's cached.
output_dir: (str) Optional path to copy model files to after successful download.
output_dir: (string) Optional path to copy model files to after successful download.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the user's expectation would be that files are downloaded directly to this output_dir and the cache folder is skipped entirely.


Returns:
A string representing the path to the requested model files.
Expand All @@ -35,7 +35,7 @@ def model_download(
logger.info(f"Downloading Model: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
cached_dir = registry.model_resolver(h, path, force_download=force_download)

if output_dir is None:
if output_dir is None or output_dir == cached_dir:
return cached_dir

try:
Expand Down
29 changes: 17 additions & 12 deletions tests/test_http_model_download.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
from tempfile import TemporaryDirectory
from typing import Optional
from unittest import mock

import requests

import kagglehub
from kagglehub.cache import MODELS_CACHE_SUBFOLDER, get_cached_archive_path
from kagglehub.handle import parse_model_handle
from tests.fixtures import BaseTestCase
from unittest import mock

from .server_stubs import model_download_stub as stub
from .server_stubs import serv
Expand Down Expand Up @@ -150,24 +151,28 @@ def test_versioned_model_download_with_path_with_force_download(self) -> None:

def test_versioned_model_download_with_output_dir(self) -> None:
with create_test_cache() as d:
expected_ouput_dir = "/tmp/downloaded_model"
self._download_model_and_assert_downloaded(
d,
VERSIONED_MODEL_HANDLE,
expected_ouput_dir,
output_dir=expected_ouput_dir
)
with TemporaryDirectory() as expected_output_dir:
self._download_model_and_assert_downloaded(
d,
VERSIONED_MODEL_HANDLE,
expected_output_dir,
output_dir=expected_output_dir
)

def test_versioned_model_download_with_bad_output_dir(self) -> None:
with create_test_cache() as d:
mock.patch("kagglehub.models.copytree", side_effect=Exception())
bad_output_dir = "/bad/path/that/fails"
with (
create_test_cache() as d,
TemporaryDirectory() as placeholder_dir,
mock.patch("kagglehub.models.copytree") as mock_copytree
):
mock_copytree.side_effect = Exception("Mock exception")
expected_output_dir = EXPECTED_MODEL_SUBDIR # falls back to default
self._download_model_and_assert_downloaded(
d,
VERSIONED_MODEL_HANDLE,
expected_output_dir,
output_dir=bad_output_dir
# note: placeholder name is irrelevant since copytree is mocked to throw
output_dir=placeholder_dir
)

def test_unversioned_model_download_with_path_with_force_download(self) -> None:
Expand Down