|
17 | 17 | import inspect |
18 | 18 | import json |
19 | 19 | import os |
| 20 | +import re |
20 | 21 |
|
21 | 22 | from absl import logging |
22 | 23 |
|
|
50 | 51 | PREPROCESSOR_CONFIG_FILE = "preprocessor.json" |
51 | 52 | METADATA_FILE = "metadata.json" |
52 | 53 |
|
| 54 | +README_FILE = "README.md" |
| 55 | + |
53 | 56 | # Weight file names. |
54 | 57 | MODEL_WEIGHTS_FILE = "model.weights.h5" |
55 | 58 | TASK_WEIGHTS_FILE = "task.weights.h5" |
@@ -333,6 +336,78 @@ def _validate_backbone(preset): |
333 | 336 | ) |
334 | 337 |
|
335 | 338 |
|
| 339 | +def get_snake_case(name): |
| 340 | + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) |
| 341 | + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() |
| 342 | + |
| 343 | + |
| 344 | +def create_model_card(preset): |
| 345 | + model_card_path = os.path.join(preset, README_FILE) |
| 346 | + markdown_content = "" |
| 347 | + |
| 348 | + config = load_config(preset, CONFIG_FILE) |
| 349 | + model_name = ( |
| 350 | + config["class_name"].replace("Backbone", "") |
| 351 | + if config["class_name"].endswith("Backbone") |
| 352 | + else config["class_name"] |
| 353 | + ) |
| 354 | + |
| 355 | + task_type = None |
| 356 | + if check_file_exists(preset, TASK_CONFIG_FILE): |
| 357 | + task_config = load_config(preset, TASK_CONFIG_FILE) |
| 358 | + task_type = ( |
| 359 | + task_config["class_name"].replace(model_name, "") |
| 360 | + if task_config["class_name"].startswith(model_name) |
| 361 | + else task_config["class_name"] |
| 362 | + ) |
| 363 | + |
| 364 | + # YAML |
| 365 | + markdown_content += "---\n" |
| 366 | + markdown_content += "library_name: keras-nlp\n" |
| 367 | + if task_type == "CausalLM": |
| 368 | + markdown_content += "pipeline_tag: text-generation\n" |
| 369 | + elif task_type == "Classifier": |
| 370 | + markdown_content += "pipeline_tag: text-classification\n" |
| 371 | + markdown_content += "---\n" |
| 372 | + |
| 373 | + model_link = ( |
| 374 | + f"https://keras.io/api/keras_nlp/models/{get_snake_case(model_name)}" |
| 375 | + ) |
| 376 | + markdown_content += ( |
| 377 | + f"This is a [`{model_name}` model]({model_link}) " |
| 378 | + "uploaded using the KerasNLP library.\n" |
| 379 | + ) |
| 380 | + if task_type: |
| 381 | + markdown_content += ( |
| 382 | + f"This model is related to a `{task_type}` task.\n\n" |
| 383 | + ) |
| 384 | + |
| 385 | + backbone_config = config["config"] |
| 386 | + markdown_content += "Model config:\n" |
| 387 | + for k, v in backbone_config.items(): |
| 388 | + markdown_content += f"* **{k}:** {v}\n" |
| 389 | + markdown_content += "\n" |
| 390 | + markdown_content += ( |
| 391 | + "This model card has been generated automatically and should be completed " |
| 392 | + "by the model author. See [Model Cards documentation]" |
| 393 | + "(https://huggingface.co/docs/hub/model-cards) for more information.\n" |
| 394 | + ) |
| 395 | + |
| 396 | + with open(model_card_path, "w") as md_file: |
| 397 | + md_file.write(markdown_content) |
| 398 | + |
| 399 | + |
| 400 | +def delete_model_card(preset): |
| 401 | + model_card_path = os.path.join(preset, README_FILE) |
| 402 | + try: |
| 403 | + os.remove(model_card_path) |
| 404 | + except FileNotFoundError: |
| 405 | + logging.warning( |
| 406 | + f"There was an attempt to delete file `{model_card_path}` but this" |
| 407 | + " file doesn't exist." |
| 408 | + ) |
| 409 | + |
| 410 | + |
336 | 411 | @keras_nlp_export("keras_nlp.upload_preset") |
337 | 412 | def upload_preset( |
338 | 413 | uri, |
@@ -382,9 +457,21 @@ def upload_preset( |
382 | 457 | "'hf://username/bert_base_en' or 'hf://bert_case_en' to implicitly" |
383 | 458 | f"upload to your user account. Received: URI={uri}." |
384 | 459 | ) from e |
385 | | - huggingface_hub.upload_folder( |
386 | | - repo_id=repo_url.repo_id, folder_path=preset |
| 460 | + has_model_card = huggingface_hub.file_exists( |
| 461 | + repo_id=repo_url.repo_id, filename=README_FILE |
387 | 462 | ) |
| 463 | + if not has_model_card: |
| 464 | + # Remote repo doesn't have a model card so a basic model card is automatically generated. |
| 465 | + create_model_card(preset) |
| 466 | + try: |
| 467 | + huggingface_hub.upload_folder( |
| 468 | + repo_id=repo_url.repo_id, folder_path=preset |
| 469 | + ) |
| 470 | + finally: |
| 471 | + if not has_model_card: |
| 472 | + # Clean up the preset directory in case user attempts to upload the |
| 473 | + # preset directory into Kaggle hub as well. |
| 474 | + delete_model_card(preset) |
388 | 475 | else: |
389 | 476 | raise ValueError( |
390 | 477 | "Unknown URI. An URI must be a one of:\n" |
|
0 commit comments