From e704cbc0fef140061cd8876bc04a135278d29c43 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Tue, 28 Nov 2023 16:35:01 -0800 Subject: [PATCH] Automatically add the keras framework to kaggle handles --- keras_nlp/utils/preset_utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index 1657d8e746..04ca3a39cd 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -30,17 +30,25 @@ def get_file(preset, path): """Download a preset file in necessary and return the local path.""" if preset.startswith(KAGGLE_PREFIX): - kaggle_handle = preset.removeprefix(KAGGLE_PREFIX) if kagglehub is None: raise ImportError( "`from_preset()` requires the `kagglehub` package. " "Please install with `pip install kagglehub`." ) - if len(kaggle_handle.split("/")) not in (4, 5): + segments = preset.removeprefix(KAGGLE_PREFIX).split("/") + # Insert the kaggle framework into the handle. + if len(segments) == 3: + org, model, variant = segments + kaggle_handle = f"{org}/{model}/keras/{variant}/1" + elif len(segments) == 4: + org, model, variant, version = segments + kaggle_handle = f"{org}/{model}/keras/{variant}/{version}" + else: raise ValueError( - "Unexpected kaggle preset handle. Kaggle model handles should have " - "the form kaggle://{org}/{model}/keras/{variant}[/{version}]. For " - "example, kaggle://keras-nlp/albert/keras/bert_base_en_uncased." + "Unexpected kaggle preset handle. Kaggle model handles should " + "have the form kaggle://{org}/{model}/{variant}[/{version}]. " + "For example, 'kaggle://keras/bert/bert_base_en'. " + f"Received: preset={preset}" ) return kagglehub.model_download(kaggle_handle, path) return os.path.join(preset, path)