Skip to content

Commit d7da19c

Browse files
committed
Automatically add the keras framework to kaggle handles (#1331)
1 parent da0842c commit d7da19c

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

keras_nlp/utils/preset_utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,25 @@
3030
def get_file(preset, path):
3131
"""Download a preset file in necessary and return the local path."""
3232
if preset.startswith(KAGGLE_PREFIX):
33-
kaggle_handle = preset.removeprefix(KAGGLE_PREFIX)
3433
if kagglehub is None:
3534
raise ImportError(
3635
"`from_preset()` requires the `kagglehub` package. "
3736
"Please install with `pip install kagglehub`."
3837
)
39-
if len(kaggle_handle.split("/")) not in (4, 5):
38+
segments = preset.removeprefix(KAGGLE_PREFIX).split("/")
39+
# Insert the kaggle framework into the handle.
40+
if len(segments) == 3:
41+
org, model, variant = segments
42+
kaggle_handle = f"{org}/{model}/keras/{variant}/1"
43+
elif len(segments) == 4:
44+
org, model, variant, version = segments
45+
kaggle_handle = f"{org}/{model}/keras/{variant}/{version}"
46+
else:
4047
raise ValueError(
41-
"Unexpected kaggle preset handle. Kaggle model handles should have "
42-
"the form kaggle://{org}/{model}/keras/{variant}[/{version}]. For "
43-
"example, kaggle://keras-nlp/albert/keras/bert_base_en_uncased."
48+
"Unexpected kaggle preset handle. Kaggle model handles should "
49+
"have the form kaggle://{org}/{model}/{variant}[/{version}]. "
50+
"For example, 'kaggle://keras/bert/bert_base_en'. "
51+
f"Received: preset={preset}"
4452
)
4553
return kagglehub.model_download(kaggle_handle, path)
4654
return os.path.join(preset, path)

0 commit comments

Comments
 (0)