|
51 | 51 | )
|
52 | 52 | from diffusers.optimization import get_scheduler
|
53 | 53 | from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
| 54 | +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
54 | 55 | from diffusers.utils.import_utils import is_xformers_available
|
55 | 56 | from diffusers.utils.torch_utils import is_compiled_module
|
56 | 57 |
|
@@ -199,28 +200,32 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
|
199 | 200 | make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
|
200 | 201 | img_str += f"\n"
|
201 | 202 |
|
202 |
| - yaml = f""" |
203 |
| ---- |
204 |
| -license: openrail++ |
205 |
| -base_model: {base_model} |
206 |
| -tags: |
207 |
| -- stable-diffusion-xl |
208 |
| -- stable-diffusion-xl-diffusers |
209 |
| -- text-to-image |
210 |
| -- diffusers |
211 |
| -- controlnet |
212 |
| -inference: true |
213 |
| ---- |
214 |
| - """ |
215 |
| - model_card = f""" |
| 203 | + model_description = f""" |
216 | 204 | # controlnet-{repo_id}
|
217 | 205 |
|
218 | 206 | These are controlnet weights trained on {base_model} with new type of conditioning.
|
219 | 207 | {img_str}
|
220 | 208 | """
|
221 | 209 |
|
222 |
| - with open(os.path.join(repo_folder, "README.md"), "w") as f: |
223 |
| - f.write(yaml + model_card) |
| 210 | + model_card = load_or_create_model_card( |
| 211 | + repo_id_or_path=repo_id, |
| 212 | + from_training=True, |
| 213 | + license="openrail++", |
| 214 | + base_model=base_model, |
| 215 | + model_description=model_description, |
| 216 | + inference=True, |
| 217 | + ) |
| 218 | + |
| 219 | + tags = [ |
| 220 | + "stable-diffusion-xl", |
| 221 | + "stable-diffusion-xl-diffusers", |
| 222 | + "text-to-image", |
| 223 | + "diffusers", |
| 224 | + "controlnet", |
| 225 | + ] |
| 226 | + model_card = populate_model_card(model_card, tags=tags) |
| 227 | + |
| 228 | + model_card.save(os.path.join(repo_folder, "README.md")) |
224 | 229 |
|
225 | 230 |
|
226 | 231 | def parse_args(input_args=None):
|
|
0 commit comments