Skip to content

Commit f07899a

Browse files
authored
Standardize model card for Controlnet SDXL (huggingface#6908)
controlnet-sdxl
1 parent a83cc0c commit f07899a

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
)
5252
from diffusers.optimization import get_scheduler
5353
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
54+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
5455
from diffusers.utils.import_utils import is_xformers_available
5556
from diffusers.utils.torch_utils import is_compiled_module
5657

@@ -199,28 +200,32 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
199200
make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
200201
img_str += f"![images_{i})](./images_{i}.png)\n"
201202

202-
yaml = f"""
203-
---
204-
license: openrail++
205-
base_model: {base_model}
206-
tags:
207-
- stable-diffusion-xl
208-
- stable-diffusion-xl-diffusers
209-
- text-to-image
210-
- diffusers
211-
- controlnet
212-
inference: true
213-
---
214-
"""
215-
model_card = f"""
203+
model_description = f"""
216204
# controlnet-{repo_id}
217205
218206
These are controlnet weights trained on {base_model} with new type of conditioning.
219207
{img_str}
220208
"""
221209

222-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
223-
f.write(yaml + model_card)
210+
model_card = load_or_create_model_card(
211+
repo_id_or_path=repo_id,
212+
from_training=True,
213+
license="openrail++",
214+
base_model=base_model,
215+
model_description=model_description,
216+
inference=True,
217+
)
218+
219+
tags = [
220+
"stable-diffusion-xl",
221+
"stable-diffusion-xl-diffusers",
222+
"text-to-image",
223+
"diffusers",
224+
"controlnet",
225+
]
226+
model_card = populate_model_card(model_card, tags=tags)
227+
228+
model_card.save(os.path.join(repo_folder, "README.md"))
224229

225230

226231
def parse_args(input_args=None):

0 commit comments

Comments
 (0)