|
53 | 53 | )
|
54 | 54 | from diffusers.optimization import get_scheduler
|
55 | 55 | from diffusers.utils import check_min_version, is_wandb_available
|
| 56 | +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
56 | 57 | from diffusers.utils.import_utils import is_xformers_available
|
57 | 58 |
|
58 | 59 |
|
|
84 | 85 | logger = get_logger(__name__)
|
85 | 86 |
|
86 | 87 |
|
87 |
| -def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None): |
| 88 | +def save_model_card(repo_id: str, images: list = None, base_model: str = None, repo_folder: str = None): |
88 | 89 | img_str = ""
|
89 |
| - for i, image in enumerate(images): |
90 |
| - image.save(os.path.join(repo_folder, f"image_{i}.png")) |
91 |
| - img_str += f"\n" |
92 |
| - |
93 |
| - yaml = f""" |
94 |
| ---- |
95 |
| -license: creativeml-openrail-m |
96 |
| -base_model: {base_model} |
97 |
| -tags: |
98 |
| -- stable-diffusion |
99 |
| -- stable-diffusion-diffusers |
100 |
| -- text-to-image |
101 |
| -- diffusers |
102 |
| -- textual_inversion |
103 |
| -inference: true |
104 |
| ---- |
105 |
| - """ |
106 |
| - model_card = f""" |
| 90 | + if images is not None: |
| 91 | + for i, image in enumerate(images): |
| 92 | + image.save(os.path.join(repo_folder, f"image_{i}.png")) |
| 93 | + img_str += f"\n" |
| 94 | + model_description = f""" |
107 | 95 | # Textual inversion text2image fine-tuning - {repo_id}
|
108 | 96 | These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n
|
109 | 97 | {img_str}
|
110 | 98 | """
|
111 |
| - with open(os.path.join(repo_folder, "README.md"), "w") as f: |
112 |
| - f.write(yaml + model_card) |
| 99 | + model_card = load_or_create_model_card( |
| 100 | + repo_id_or_path=repo_id, |
| 101 | + from_training=True, |
| 102 | + license="creativeml-openrail-m", |
| 103 | + base_model=base_model, |
| 104 | + model_description=model_description, |
| 105 | + inference=True, |
| 106 | + ) |
| 107 | + |
| 108 | + tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "textual_inversion"] |
| 109 | + model_card = populate_model_card(model_card, tags=tags) |
| 110 | + |
| 111 | + model_card.save(os.path.join(repo_folder, "README.md")) |
113 | 112 |
|
114 | 113 |
|
115 | 114 | def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
|
|
0 commit comments