Skip to content

Commit aec862a

Browse files
Add Model Card for Hugging Face Upload (#1578)
* Add model card for Hugging Face upload. * Add task type. * Add README_FILE constant. * Improve error handling for delete_model_card. * Address reviews. * Address reviews. * Add pipeline_tag for tasks. * Fix model name to model link conversion. * Don't create model card if it already exists.
1 parent 16d3ebb commit aec862a

File tree

1 file changed

+89
-2
lines changed

1 file changed

+89
-2
lines changed

keras_nlp/utils/preset_utils.py

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import inspect
1818
import json
1919
import os
20+
import re
2021

2122
from absl import logging
2223

@@ -50,6 +51,8 @@
5051
PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
5152
METADATA_FILE = "metadata.json"
5253

54+
README_FILE = "README.md"
55+
5356
# Weight file names.
5457
MODEL_WEIGHTS_FILE = "model.weights.h5"
5558
TASK_WEIGHTS_FILE = "task.weights.h5"
@@ -333,6 +336,78 @@ def _validate_backbone(preset):
333336
)
334337

335338

339+
def get_snake_case(name):
    """Convert a CamelCase identifier to snake_case.

    The conversion runs in two passes: an underscore is first inserted before
    every capitalized word (an uppercase letter followed by lowercase
    letters) that has a preceding character, then between any lowercase
    letter or digit and a following uppercase letter. The result is
    lowercased.

    Args:
        name: string. The CamelCase name to convert, e.g. `"DistilBert"`.

    Returns:
        The snake_case form of `name`, e.g. `"distil_bert"`.
    """
    # First pass keeps acronyms together: "XLMRoberta" -> "XLM_Roberta".
    name = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Second pass splits remaining lower/digit -> upper transitions:
    # "DebertaV3" -> "Deberta_V3".
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name).lower()
342+
343+
344+
def create_model_card(preset):
    """Generate a basic Markdown model card inside the `preset` directory.

    The card is written to `README_FILE` under `preset`. It contains a YAML
    front-matter header (library name and, for recognized task types, a
    `pipeline_tag`), a short description linking to the keras.io page of the
    model, the backbone config as a bullet list, and a closing note asking
    the model author to complete the card.

    Args:
        preset: string. Path of the local preset directory. The backbone
            config is read with `load_config(preset, CONFIG_FILE)` and, when
            `check_file_exists(preset, TASK_CONFIG_FILE)` is true, the task
            config with `load_config(preset, TASK_CONFIG_FILE)`.
    """
    model_card_path = os.path.join(preset, README_FILE)

    config = load_config(preset, CONFIG_FILE)
    class_name = config["class_name"]
    # Strip only the trailing "Backbone". `str.replace` would also remove
    # any other occurrence of the word inside the class name.
    if class_name.endswith("Backbone"):
        model_name = class_name[: -len("Backbone")]
    else:
        model_name = class_name

    task_type = None
    if check_file_exists(preset, TASK_CONFIG_FILE):
        task_class_name = load_config(preset, TASK_CONFIG_FILE)["class_name"]
        # Strip only the leading model name, for the same reason as above.
        if task_class_name.startswith(model_name):
            task_type = task_class_name[len(model_name) :]
        else:
            task_type = task_class_name

    # YAML front matter.
    lines = ["---", "library_name: keras-nlp"]
    if task_type == "CausalLM":
        lines.append("pipeline_tag: text-generation")
    elif task_type == "Classifier":
        lines.append("pipeline_tag: text-classification")
    lines.append("---")

    model_link = (
        f"https://keras.io/api/keras_nlp/models/{get_snake_case(model_name)}"
    )
    lines.append(
        f"This is a [`{model_name}` model]({model_link}) "
        "uploaded using the KerasNLP library."
    )
    if task_type:
        lines.append(f"This model is related to a `{task_type}` task.")
    # Always end the intro paragraph with a blank line so that
    # "Model config:" is not merged into it when there is no task sentence.
    lines.append("")

    lines.append("Model config:")
    lines.extend(f"* **{k}:** {v}" for k, v in config["config"].items())
    lines.append("")
    lines.append(
        "This model card has been generated automatically and should be completed "
        "by the model author. See [Model Cards documentation]"
        "(https://huggingface.co/docs/hub/model-cards) for more information."
    )
    markdown_content = "\n".join(lines) + "\n"

    # Explicit UTF-8 so the card does not depend on the platform's locale
    # encoding (config values may contain non-ASCII text).
    with open(model_card_path, "w", encoding="utf-8") as md_file:
        md_file.write(markdown_content)
398+
399+
400+
def delete_model_card(preset):
    """Remove the model card file (`README_FILE`) from the `preset` directory.

    A missing file is not treated as an error: the `FileNotFoundError` is
    caught and a warning is logged instead.

    Args:
        preset: string. Path of the local preset directory.
    """
    readme_path = os.path.join(preset, README_FILE)
    try:
        os.remove(readme_path)
    except FileNotFoundError:
        missing_file_message = (
            f"There was an attempt to delete file `{readme_path}` but this"
            " file doesn't exist."
        )
        logging.warning(missing_file_message)
409+
410+
336411
@keras_nlp_export("keras_nlp.upload_preset")
337412
def upload_preset(
338413
uri,
@@ -382,9 +457,21 @@ def upload_preset(
382457
"'hf://username/bert_base_en' or 'hf://bert_case_en' to implicitly"
383458
f"upload to your user account. Received: URI={uri}."
384459
) from e
385-
huggingface_hub.upload_folder(
386-
repo_id=repo_url.repo_id, folder_path=preset
460+
has_model_card = huggingface_hub.file_exists(
461+
repo_id=repo_url.repo_id, filename=README_FILE
387462
)
463+
if not has_model_card:
464+
# Remote repo doesn't have a model card so a basic model card is automatically generated.
465+
create_model_card(preset)
466+
try:
467+
huggingface_hub.upload_folder(
468+
repo_id=repo_url.repo_id, folder_path=preset
469+
)
470+
finally:
471+
if not has_model_card:
472+
# Clean up the preset directory in case user attempts to upload the
473+
# preset directory into Kaggle hub as well.
474+
delete_model_card(preset)
388475
else:
389476
raise ValueError(
390477
"Unknown URI. An URI must be a one of:\n"

0 commit comments

Comments
 (0)