|
30 | 30 | def get_file(preset, path): |
31 | 31 | """Download a preset file in necessary and return the local path.""" |
32 | 32 | if preset.startswith(KAGGLE_PREFIX): |
33 | | - kaggle_handle = preset.removeprefix(KAGGLE_PREFIX) |
34 | 33 | if kagglehub is None: |
35 | 34 | raise ImportError( |
36 | 35 | "`from_preset()` requires the `kagglehub` package. " |
37 | 36 | "Please install with `pip install kagglehub`." |
38 | 37 | ) |
39 | | - if len(kaggle_handle.split("/")) not in (4, 5): |
| 38 | + segments = preset.removeprefix(KAGGLE_PREFIX).split("/") |
| 39 | + # Insert the kaggle framework into the handle. |
| 40 | + if len(segments) == 3: |
| 41 | + org, model, variant = segments |
| 42 | + kaggle_handle = f"{org}/{model}/keras/{variant}/1" |
| 43 | + elif len(segments) == 4: |
| 44 | + org, model, variant, version = segments |
| 45 | + kaggle_handle = f"{org}/{model}/keras/{variant}/{version}" |
| 46 | + else: |
40 | 47 | raise ValueError( |
41 | | - "Unexpected kaggle preset handle. Kaggle model handles should have " |
42 | | - "the form kaggle://{org}/{model}/keras/{variant}[/{version}]. For " |
43 | | - "example, kaggle://keras-nlp/albert/keras/bert_base_en_uncased." |
| 48 | + "Unexpected kaggle preset handle. Kaggle model handles should " |
| 49 | + "have the form kaggle://{org}/{model}/{variant}[/{version}]. " |
| 50 | + "For example, 'kaggle://keras/bert/bert_base_en'. " |
| 51 | + f"Received: preset={preset}" |
44 | 52 | ) |
45 | 53 | return kagglehub.model_download(kaggle_handle, path) |
46 | 54 | return os.path.join(preset, path) |
|
0 commit comments