Skip to content

Commit aa68bd4

Browse files
kohya-ss authored and Disty0 committed
Update dependencies ref kohya-ss#1024
1 parent b9995d4 commit aa68bd4

4 files changed

Lines changed: 19 additions & 40 deletions

File tree

library/lpw_stable_diffusion.py

Lines changed: 5 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import PIL.Image
1010
import torch
1111
from packaging import version
12-
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
12+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
1313

1414
import diffusers
1515
from diffusers import SchedulerMixin, StableDiffusionPipeline
@@ -520,6 +520,7 @@ def __init__(
520520
safety_checker: StableDiffusionSafetyChecker,
521521
feature_extractor: CLIPFeatureExtractor,
522522
requires_safety_checker: bool = True,
523+
image_encoder: CLIPVisionModelWithProjection = None,
523524
clip_skip: int = 1,
524525
):
525526
super().__init__(
@@ -531,32 +532,11 @@ def __init__(
531532
safety_checker=safety_checker,
532533
feature_extractor=feature_extractor,
533534
requires_safety_checker=requires_safety_checker,
535+
image_encoder=image_encoder,
534536
)
535-
self.clip_skip = clip_skip
537+
self.custom_clip_skip = clip_skip
536538
self.__init__additional__()
537539

538-
# else:
539-
# def __init__(
540-
# self,
541-
# vae: AutoencoderKL,
542-
# text_encoder: CLIPTextModel,
543-
# tokenizer: CLIPTokenizer,
544-
# unet: UNet2DConditionModel,
545-
# scheduler: SchedulerMixin,
546-
# safety_checker: StableDiffusionSafetyChecker,
547-
# feature_extractor: CLIPFeatureExtractor,
548-
# ):
549-
# super().__init__(
550-
# vae=vae,
551-
# text_encoder=text_encoder,
552-
# tokenizer=tokenizer,
553-
# unet=unet,
554-
# scheduler=scheduler,
555-
# safety_checker=safety_checker,
556-
# feature_extractor=feature_extractor,
557-
# )
558-
# self.__init__additional__()
559-
560540
def __init__additional__(self):
561541
if not hasattr(self, "vae_scale_factor"):
562542
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
@@ -624,7 +604,7 @@ def _encode_prompt(
624604
prompt=prompt,
625605
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
626606
max_embeddings_multiples=max_embeddings_multiples,
627-
clip_skip=self.clip_skip,
607+
clip_skip=self.custom_clip_skip,
628608
)
629609
bs_embed, seq_len, _ = text_embeddings.shape
630610
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)

library/model_util.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import math
55
import os
66
import torch
7+
78
try:
89
import intel_extension_for_pytorch as ipex
10+
911
if torch.xpu.is_available():
1012
from library.ipex import ipex_init
13+
1114
ipex_init()
1215
except Exception:
1316
pass
@@ -571,9 +574,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
571574
if key.startswith("cond_stage_model.transformer"):
572575
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
573576

574-
# support checkpoint without position_ids (invalid checkpoint)
575-
if "text_model.embeddings.position_ids" not in text_model_dict:
576-
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
577+
# remove position_ids for newer transformer, which causes error :(
578+
if "text_model.embeddings.position_ids" in text_model_dict:
579+
text_model_dict.pop("text_model.embeddings.position_ids")
577580

578581
return text_model_dict
579582

library/sdxl_model_util.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def convert_key(key):
100100
key = key.replace(".ln_final", ".final_layer_norm")
101101
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
102102
elif ".embeddings.position_ids" in key:
103-
key = None # remove this key: make position_ids by ourselves
103+
key = None # remove this key: position_ids is not used in newer transformers
104104
return key
105105

106106
keys = list(checkpoint.keys())
@@ -126,10 +126,6 @@ def convert_key(key):
126126
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
127127
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
128128

129-
# original SD にはないので、position_idsを追加
130-
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
131-
new_sd["text_model.embeddings.position_ids"] = position_ids
132-
133129
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
134130
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
135131

@@ -265,9 +261,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
265261
elif k.startswith("conditioner.embedders.1.model."):
266262
te2_sd[k] = state_dict.pop(k)
267263

268-
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
269-
if "text_model.embeddings.position_ids" not in te1_sd:
270-
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
264+
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
265+
if "text_model.embeddings.position_ids" in te1_sd:
266+
te1_sd.pop("text_model.embeddings.position_ids")
271267

272268
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
273269
print("text encoder 1:", info1)

requirements.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
accelerate==0.23.0
2-
transformers==4.30.2
3-
diffusers[torch]==0.21.2
1+
accelerate==0.25.0
2+
transformers==4.36.2
3+
diffusers[torch]==0.25.0
44
ftfy==6.1.1
55
# albumentations==1.3.0
66
opencv-python==4.7.0.68
@@ -14,7 +14,7 @@ altair==4.2.2
1414
easygui==0.98.3
1515
toml==0.10.2
1616
voluptuous==0.13.1
17-
huggingface-hub==0.15.1
17+
huggingface-hub==0.20.1
1818
# for BLIP captioning
1919
# requests==2.28.2
2020
# timm==0.6.12

0 commit comments

Comments
 (0)