Skip to content

Commit 777063e

Browse files
Update textual_inversion.py (#6952)
* Update textual_inversion.py * Apply suggestions from code review * Update textual_inversion.py * Update textual_inversion.py * Update textual_inversion.py * Update textual_inversion.py * Update examples/textual_inversion/textual_inversion.py Co-authored-by: Sayak Paul <[email protected]> * Update textual_inversion.py * styling --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 104afbc commit 777063e

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

examples/textual_inversion/textual_inversion.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
)
5454
from diffusers.optimization import get_scheduler
5555
from diffusers.utils import check_min_version, is_wandb_available
56+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
5657
from diffusers.utils.import_utils import is_xformers_available
5758

5859

@@ -84,32 +85,30 @@
8485
logger = get_logger(__name__)
8586

8687

87-
def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None):
88+
def save_model_card(repo_id: str, images: list = None, base_model: str = None, repo_folder: str = None):
8889
img_str = ""
89-
for i, image in enumerate(images):
90-
image.save(os.path.join(repo_folder, f"image_{i}.png"))
91-
img_str += f"![img_{i}](./image_{i}.png)\n"
92-
93-
yaml = f"""
94-
---
95-
license: creativeml-openrail-m
96-
base_model: {base_model}
97-
tags:
98-
- stable-diffusion
99-
- stable-diffusion-diffusers
100-
- text-to-image
101-
- diffusers
102-
- textual_inversion
103-
inference: true
104-
---
105-
"""
106-
model_card = f"""
90+
if images is not None:
91+
for i, image in enumerate(images):
92+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
93+
img_str += f"![img_{i}](./image_{i}.png)\n"
94+
model_description = f"""
10795
# Textual inversion text2image fine-tuning - {repo_id}
10896
These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n
10997
{img_str}
11098
"""
111-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
112-
f.write(yaml + model_card)
99+
model_card = load_or_create_model_card(
100+
repo_id_or_path=repo_id,
101+
from_training=True,
102+
license="creativeml-openrail-m",
103+
base_model=base_model,
104+
model_description=model_description,
105+
inference=True,
106+
)
107+
108+
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "textual_inversion"]
109+
model_card = populate_model_card(model_card, tags=tags)
110+
111+
model_card.save(os.path.join(repo_folder, "README.md"))
113112

114113

115114
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):

0 commit comments

Comments
 (0)