Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
add target_path variable to resolver classes
  • Loading branch information
BuechlerA committed Jun 15, 2025
commit 9bd7d0988a97c45d77098478717c1aa43f929133
19 changes: 17 additions & 2 deletions src/kagglehub/colab_cache_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,23 @@ def is_supported(self, handle: ModelHandle, *_, **__) -> bool: # noqa: ANN002,
return True

def _resolve(
self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: ModelHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Colab notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
if target_path:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this ignored inside Colab, since it optional and not set by default is puts not really nice restriction on a user

logger.info(
"Ignoring `target_path` argument when running inside the Colab notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)

api_client = ColabClient()
data = {
Expand Down Expand Up @@ -118,7 +128,12 @@ def is_supported(self, handle: DatasetHandle, *_, **__) -> bool: # noqa: ANN002
return True

def _resolve(
self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False, target_path: Optional[str] = None
self,
h: DatasetHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
Expand Down
44 changes: 40 additions & 4 deletions src/kagglehub/http_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from kagglehub.clients import KaggleApiV1Client
from kagglehub.exceptions import UnauthenticatedError
from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, NotebookHandle, ResourceHandle
from kagglehub.logger import EXTRA_CONSOLE_BLOCK
from kagglehub.packages import PackageScope
from kagglehub.resolver import Resolver

Expand All @@ -36,8 +37,18 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return True

def _resolve(
self, h: CompetitionHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: CompetitionHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if target_path:
logger.info(
"Ignoring `target_path` argument for competition downloads.",
extra={**EXTRA_CONSOLE_BLOCK},
)
api_client = KaggleApiV1Client()

cached_path = load_from_cache(h, path)
Expand Down Expand Up @@ -100,7 +111,12 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return True

def _resolve(
self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False, target_path: Optional[str] = None
self,
h: DatasetHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
api_client = KaggleApiV1Client()

Expand Down Expand Up @@ -154,8 +170,18 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return True

def _resolve(
self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: ModelHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if target_path:
logger.info(
"Ignoring `target_path` argument for model downloads.",
extra={**EXTRA_CONSOLE_BLOCK},
)
api_client = KaggleApiV1Client()

if not h.is_versioned():
Expand Down Expand Up @@ -216,8 +242,18 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return True

def _resolve(
self, h: NotebookHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: NotebookHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if target_path:
logger.info(
"Ignoring `target_path` argument for notebook output downloads.",
extra={**EXTRA_CONSOLE_BLOCK},
)
api_client = KaggleApiV1Client()

if not h.is_versioned():
Expand Down
43 changes: 39 additions & 4 deletions src/kagglehub/kaggle_cache_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,24 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return False

def _resolve(
self, h: CompetitionHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: CompetitionHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
client = KaggleJwtClient()
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
if target_path:
logger.info(
"Ignoring `target_path` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)

competition_ref = {
"CompetitionSlug": h.competition,
Expand Down Expand Up @@ -102,7 +112,12 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return False

def _resolve(
self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False, target_path: Optional[str] = None
self,
h: DatasetHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
Expand Down Expand Up @@ -182,13 +197,23 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return False

def _resolve(
self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: ModelHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
if target_path:
logger.info(
"Ignoring `target_path` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
client = KaggleJwtClient()
model_ref = {
"OwnerSlug": h.owner,
Expand Down Expand Up @@ -259,13 +284,23 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003
return False

def _resolve(
self, h: NotebookHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
self,
h: NotebookHandle,
path: Optional[str] = None,
*,
force_download: Optional[bool] = False,
target_path: Optional[str] = None,
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
if target_path:
logger.info(
"Ignoring `target_path` argument when running inside the Kaggle notebook environment.",
extra={**EXTRA_CONSOLE_BLOCK},
)
client = KaggleJwtClient()
kernel_ref = {
"OwnerSlug": h.owner,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_http_dataset_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _download_test_file_and_assert_downloaded(
**kwargs, # noqa: ANN003
) -> None:
dataset_path = kagglehub.dataset_download(dataset_handle, path=TEST_FILEPATH, **kwargs)

if expected_target_path:
self.assertEqual(expected_target_path, dataset_path)
with open(dataset_path) as dataset_file:
Expand Down Expand Up @@ -163,7 +163,7 @@ def test_versioned_dataset_download_with_target_path(self) -> None:
VERSIONED_DATASET_HANDLE,
EXPECTED_DATASET_SUBDIR,
target_path=target_dir,
expected_target_path=os.path.join(target_dir, os.path.basename(EXPECTED_DATASET_SUBPATH))
expected_target_path=os.path.join(target_dir, os.path.basename(EXPECTED_DATASET_SUBPATH)),
)

def test_versioned_dataset_download_with_path_and_target_path(self) -> None:
Expand All @@ -174,5 +174,5 @@ def test_versioned_dataset_download_with_path_and_target_path(self) -> None:
d,
VERSIONED_DATASET_HANDLE,
target_path=target_dir,
expected_target_path=os.path.join(target_dir, TEST_FILEPATH)
expected_target_path=os.path.join(target_dir, TEST_FILEPATH),
)