Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 89 additions & 2 deletions keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import inspect
import json
import os
import re

from absl import logging

Expand Down Expand Up @@ -50,6 +51,8 @@
PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
METADATA_FILE = "metadata.json"

README_FILE = "README.md"

# Weight file names.
MODEL_WEIGHTS_FILE = "model.weights.h5"
TASK_WEIGHTS_FILE = "task.weights.h5"
Expand Down Expand Up @@ -333,6 +336,78 @@ def _validate_backbone(preset):
)


def get_snake_case(name):
    """Convert a CamelCase class name to snake_case.

    Args:
        name: string. A CamelCase identifier, e.g. `"GPTNeoXBackbone"`.

    Returns:
        The snake_case form of `name`, e.g. `"gpt_neo_x_backbone"`.
    """
    # Insert an underscore before each capitalized word (an uppercase
    # letter followed by lowercase letters), e.g. "XLMRoberta" -> "XLM_Roberta".
    # Raw strings are used for the regex patterns per Python convention.
    name = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Split the remaining lowercase/digit-to-uppercase boundaries and
    # lowercase everything, e.g. "GPT2Model" -> "gpt2_model".
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name).lower()


def create_model_card(preset):
    """Write a basic `README.md` model card into the `preset` directory.

    The generated card contains a YAML front matter block (library name,
    plus a `pipeline_tag` when the preset includes a task config), a link
    to the model's page on keras.io, a dump of the backbone config, and a
    note asking the author to complete the card by hand.
    """
    card_path = os.path.join(preset, README_FILE)
    parts = []

    preset_config = load_config(preset, CONFIG_FILE)
    class_name = preset_config["class_name"]
    if class_name.endswith("Backbone"):
        model_name = class_name.replace("Backbone", "")
    else:
        model_name = class_name

    # High-level task type (e.g. "CausalLM", "Classifier"), available only
    # when a task config was saved alongside the preset.
    task_type = None
    if check_file_exists(preset, TASK_CONFIG_FILE):
        task_class = load_config(preset, TASK_CONFIG_FILE)["class_name"]
        if task_class.startswith(model_name):
            task_type = task_class.replace(model_name, "")
        else:
            task_type = task_class

    # YAML front matter. The `pipeline_tag` makes the model searchable by
    # task on the Hugging Face Hub.
    parts.append("---\n")
    parts.append("library_name: keras-nlp\n")
    if task_type == "CausalLM":
        parts.append("pipeline_tag: text-generation\n")
    elif task_type == "Classifier":
        parts.append("pipeline_tag: text-classification\n")
    parts.append("---\n")

    model_link = (
        f"https://keras.io/api/keras_nlp/models/{get_snake_case(model_name)}"
    )
    parts.append(
        f"This is a [`{model_name}` model]({model_link}) "
        "uploaded using the KerasNLP library.\n"
    )
    if task_type:
        parts.append(f"This model is related to a `{task_type}` task.\n\n")

    parts.append("Model config:\n")
    for key, value in preset_config["config"].items():
        parts.append(f"* **{key}:** {value}\n")
    parts.append("\n")
    parts.append(
        "This model card has been generated automatically and should be completed "
        "by the model author. See [Model Cards documentation]"
        "(https://huggingface.co/docs/hub/model-cards) for more information.\n"
    )

    with open(card_path, "w") as md_file:
        md_file.write("".join(parts))


def delete_model_card(preset):
    """Remove the generated `README.md` model card from `preset`, if present.

    This is a best-effort cleanup: a missing file is logged as a warning
    rather than raised, since there is nothing left to delete in that case.
    """
    model_card_path = os.path.join(preset, README_FILE)
    try:
        os.remove(model_card_path)
    except FileNotFoundError:
        # Use lazy %-style logging args so the message is only formatted
        # if the warning is actually emitted.
        logging.warning(
            "There was an attempt to delete file `%s` but this"
            " file doesn't exist.",
            model_card_path,
        )


@keras_nlp_export("keras_nlp.upload_preset")
def upload_preset(
uri,
Expand Down Expand Up @@ -382,9 +457,21 @@ def upload_preset(
"'hf://username/bert_base_en' or 'hf://bert_case_en' to implicitly"
f"upload to your user account. Received: URI={uri}."
) from e
huggingface_hub.upload_folder(
repo_id=repo_url.repo_id, folder_path=preset
has_model_card = huggingface_hub.file_exists(
repo_id=repo_url.repo_id, filename=README_FILE
)
if not has_model_card:
# Remote repo doesn't have a model card so a basic model card is automatically generated.
create_model_card(preset)
try:
huggingface_hub.upload_folder(
repo_id=repo_url.repo_id, folder_path=preset
)
finally:
if not has_model_card:
# Clean up the preset directory in case user attempts to upload the
# preset directory into Kaggle hub as well.
delete_model_card(preset)
else:
raise ValueError(
"Unknown URI. An URI must be a one of:\n"
Expand Down