From 6561e57662a1ebb1306833fb857d42aca2335962 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 00:36:24 +0900 Subject: [PATCH 001/482] Update unet.py ImageProjectionCustomized in unet --- src/diffusers/loaders/unet.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index c68349c36dba..10d436d27526 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -23,6 +23,7 @@ from huggingface_hub.utils import validate_hf_hub_args from ..models.embeddings import ( + ImageProjectionCustomized, ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterFaceIDPlusImageProjection, @@ -564,7 +565,20 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us image_projection = None init_context = init_empty_weights if low_cpu_mem_usage else nullcontext - if "proj.weight" in state_dict: + if state_dict is None: + # IP-Adapter-Customized + num_image_text_embeds = 1 # fixed + clip_embeddings_dim = 1024 # fixed by CLIP model + cross_attention_dim = 2048 # fixed by sdxl two text encoders + + with init_context(): + image_projection = ImageProjectionCustomized( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + + elif "proj.weight" in state_dict: # IP-Adapter num_image_text_embeds = 4 clip_embeddings_dim = state_dict["proj.weight"].shape[-1] From 8df0993bb7403667a3342fbe1032cc5e20645a78 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 00:45:40 +0900 Subject: [PATCH 002/482] customized img proj in unet --- src/diffusers/loaders/unet.py | 374 +++++++++++++++++----------------- 1 file changed, 189 insertions(+), 185 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 10d436d27526..f3676a83b9d7 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -577,195 +577,196 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us image_embed_dim=clip_embeddings_dim, num_image_text_embeds=num_image_text_embeds, ) + else: + if "proj.weight" in state_dict: + # IP-Adapter + num_image_text_embeds = 4 + clip_embeddings_dim = state_dict["proj.weight"].shape[-1] + cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 - elif "proj.weight" in state_dict: - # IP-Adapter - num_image_text_embeds = 4 - clip_embeddings_dim = state_dict["proj.weight"].shape[-1] - cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 - - with init_context(): - image_projection = ImageProjection( - cross_attention_dim=cross_attention_dim, - image_embed_dim=clip_embeddings_dim, - num_image_text_embeds=num_image_text_embeds, - ) - - for key, value in state_dict.items(): - diffusers_name = key.replace("proj", "image_embeds") - updated_state_dict[diffusers_name] = value + with init_context(): + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) - elif "proj.3.weight" in state_dict: - # IP-Adapter Full - clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] - cross_attention_dim = state_dict["proj.3.weight"].shape[0] + for key, value in state_dict.items(): + diffusers_name = key.replace("proj", "image_embeds") + updated_state_dict[diffusers_name] = value - with init_context(): - image_projection = IPAdapterFullImageProjection( - 
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim - ) + elif "proj.3.weight" in state_dict: + # IP-Adapter Full + clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] + cross_attention_dim = state_dict["proj.3.weight"].shape[0] - for key, value in state_dict.items(): - diffusers_name = key.replace("proj.0", "ff.net.0.proj") - diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") - diffusers_name = diffusers_name.replace("proj.3", "norm") - updated_state_dict[diffusers_name] = value + with init_context(): + image_projection = IPAdapterFullImageProjection( + cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim + ) - elif "perceiver_resampler.proj_in.weight" in state_dict: - # IP-Adapter Face ID Plus - id_embeddings_dim = state_dict["proj.0.weight"].shape[1] - embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0] - hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1] - output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0] - heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64 + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + diffusers_name = diffusers_name.replace("proj.3", "norm") + updated_state_dict[diffusers_name] = value - with init_context(): - image_projection = IPAdapterFaceIDPlusImageProjection( - embed_dims=embed_dims, - output_dims=output_dims, - hidden_dims=hidden_dims, - heads=heads, - id_embeddings_dim=id_embeddings_dim, - ) + elif "perceiver_resampler.proj_in.weight" in state_dict: + # IP-Adapter Face ID Plus + id_embeddings_dim = state_dict["proj.0.weight"].shape[1] + embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0] + hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1] + output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0] + heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64 - for key, value in state_dict.items(): - diffusers_name = key.replace("perceiver_resampler.", "") - diffusers_name = diffusers_name.replace("0.to", "attn.to") - diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.") - diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight") - diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight") - diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.") - diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight") - diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight") - diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.") - diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight") - diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight") - diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.") - diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight") - diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight") - diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0") - diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1") - diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0") - diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1") - diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0") - diffusers_name = 
diffusers_name.replace("layers.2.1", "layers.2.ln1") - diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0") - diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1") - - if "norm1" in diffusers_name: - updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value - elif "norm2" in diffusers_name: - updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value - elif "to_kv" in diffusers_name: - v_chunk = value.chunk(2, dim=0) - updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] - updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] - elif "to_out" in diffusers_name: - updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value - elif "proj.0.weight" == diffusers_name: - updated_state_dict["proj.net.0.proj.weight"] = value - elif "proj.0.bias" == diffusers_name: - updated_state_dict["proj.net.0.proj.bias"] = value - elif "proj.2.weight" == diffusers_name: - updated_state_dict["proj.net.2.weight"] = value - elif "proj.2.bias" == diffusers_name: - updated_state_dict["proj.net.2.bias"] = value - else: - updated_state_dict[diffusers_name] = value + with init_context(): + image_projection = IPAdapterFaceIDPlusImageProjection( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + id_embeddings_dim=id_embeddings_dim, + ) - elif "norm.weight" in state_dict: - # IP-Adapter Face ID - id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] - id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] - multiplier = id_embeddings_dim_out // id_embeddings_dim_in - norm_layer = "norm.weight" - cross_attention_dim = state_dict[norm_layer].shape[0] - num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim + for key, value in state_dict.items(): + diffusers_name = key.replace("perceiver_resampler.", "") + diffusers_name = diffusers_name.replace("0.to", "attn.to") + diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.") + diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.") + diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.") + diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.") + diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0") + diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1") + diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0") + diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1") + diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0") + diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1") + diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0") + diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1") + + if "norm1" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value + elif "norm2" in 
diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value + elif "to_kv" in diffusers_name: + v_chunk = value.chunk(2, dim=0) + updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] + updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_out" in diffusers_name: + updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value + elif "proj.0.weight" == diffusers_name: + updated_state_dict["proj.net.0.proj.weight"] = value + elif "proj.0.bias" == diffusers_name: + updated_state_dict["proj.net.0.proj.bias"] = value + elif "proj.2.weight" == diffusers_name: + updated_state_dict["proj.net.2.weight"] = value + elif "proj.2.bias" == diffusers_name: + updated_state_dict["proj.net.2.bias"] = value + else: + updated_state_dict[diffusers_name] = value - with init_context(): - image_projection = IPAdapterFaceIDImageProjection( - cross_attention_dim=cross_attention_dim, - image_embed_dim=id_embeddings_dim_in, - mult=multiplier, - num_tokens=num_tokens, - ) + elif "norm.weight" in state_dict: + # IP-Adapter Face ID + id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] + id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] + multiplier = id_embeddings_dim_out // id_embeddings_dim_in + norm_layer = "norm.weight" + cross_attention_dim = state_dict[norm_layer].shape[0] + num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim - for key, value in state_dict.items(): - diffusers_name = key.replace("proj.0", "ff.net.0.proj") - diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") - updated_state_dict[diffusers_name] = value + with init_context(): + image_projection = IPAdapterFaceIDImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=id_embeddings_dim_in, + mult=multiplier, + num_tokens=num_tokens, + ) - else: - # IP-Adapter Plus - num_image_text_embeds = state_dict["latents"].shape[1] - embed_dims = state_dict["proj_in.weight"].shape[1] - output_dims = state_dict["proj_out.weight"].shape[0] - hidden_dims = state_dict["latents"].shape[2] - attn_key_present = any("attn" in k for k in state_dict) - heads = ( - state_dict["layers.0.attn.to_q.weight"].shape[0] // 64 - if attn_key_present - else state_dict["layers.0.0.to_q.weight"].shape[0] // 64 - ) + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + updated_state_dict[diffusers_name] = value - with init_context(): - image_projection = IPAdapterPlusImageProjection( - embed_dims=embed_dims, - output_dims=output_dims, - hidden_dims=hidden_dims, - heads=heads, - num_queries=num_image_text_embeds, + else: + # IP-Adapter Plus + num_image_text_embeds = state_dict["latents"].shape[1] + embed_dims = state_dict["proj_in.weight"].shape[1] + output_dims = state_dict["proj_out.weight"].shape[0] + hidden_dims = state_dict["latents"].shape[2] + attn_key_present = any("attn" in k for k in state_dict) + heads = ( + state_dict["layers.0.attn.to_q.weight"].shape[0] // 64 + if attn_key_present + else state_dict["layers.0.0.to_q.weight"].shape[0] // 64 ) - for key, value in state_dict.items(): - diffusers_name = key.replace("0.to", "2.to") - - diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0") - diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1") - diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0") - diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1") - diffusers_name = 
diffusers_name.replace("2.0.norm1", "2.ln0") - diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1") - diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0") - diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1") - - if "to_kv" in diffusers_name: - parts = diffusers_name.split(".") - parts[2] = "attn" - diffusers_name = ".".join(parts) - v_chunk = value.chunk(2, dim=0) - updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] - updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] - elif "to_q" in diffusers_name: - parts = diffusers_name.split(".") - parts[2] = "attn" - diffusers_name = ".".join(parts) - updated_state_dict[diffusers_name] = value - elif "to_out" in diffusers_name: - parts = diffusers_name.split(".") - parts[2] = "attn" - diffusers_name = ".".join(parts) - updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value - else: - diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0") - diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj") - diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2") + with init_context(): + image_projection = IPAdapterPlusImageProjection( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + num_queries=num_image_text_embeds, + ) - diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0") - diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj") - diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2") + for key, value in state_dict.items(): + diffusers_name = key.replace("0.to", "2.to") + + diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0") + diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1") + diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0") + diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1") + diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0") + diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1") + diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0") + diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1") + + if "to_kv" in diffusers_name: + parts = diffusers_name.split(".") + parts[2] = "attn" + diffusers_name = ".".join(parts) + v_chunk = value.chunk(2, dim=0) + updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] + updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_q" in diffusers_name: + parts = diffusers_name.split(".") + parts[2] = "attn" + diffusers_name = ".".join(parts) + updated_state_dict[diffusers_name] = value + elif "to_out" in diffusers_name: + parts = diffusers_name.split(".") + parts[2] = "attn" + diffusers_name = ".".join(parts) + updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value + else: + diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0") + diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2") - diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0") - diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj") - diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2") + diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0") + diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2") - diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0") - diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj") 
- diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2") - updated_state_dict[diffusers_name] = value + diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0") + diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2") + + diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0") + diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj") + diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2") + updated_state_dict[diffusers_name] = value if not low_cpu_mem_usage: - image_projection.load_state_dict(updated_state_dict, strict=True) + if state_dict is not None: + image_projection.load_state_dict(updated_state_dict, strict=True) else: load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) @@ -826,21 +827,24 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F ) num_image_text_embeds = [] for state_dict in state_dicts: - if "proj.weight" in state_dict["image_proj"]: - # IP-Adapter - num_image_text_embeds += [4] - elif "proj.3.weight" in state_dict["image_proj"]: - # IP-Adapter Full Face - num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token - elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]: - # IP-Adapter Face ID Plus - num_image_text_embeds += [4] - elif "norm.weight" in state_dict["image_proj"]: - # IP-Adapter Face ID - num_image_text_embeds += [4] - else: - # IP-Adapter Plus - num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] + if state_dict["image_proj"] is None: + num_image_text_embeds += [1] + else: + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + num_image_text_embeds += [4] + elif "proj.3.weight" in state_dict["image_proj"]: + # IP-Adapter Full Face + num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token + elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]: + # IP-Adapter Face ID Plus + num_image_text_embeds += [4] + elif "norm.weight" in state_dict["image_proj"]: + # IP-Adapter Face ID + num_image_text_embeds += [4] + else: + # IP-Adapter Plus + num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] with init_context(): attn_procs[name] = attn_processor_class( From 671c10ba7c15422f24f178cb43d605d461758c57 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 00:50:51 +0900 Subject: [PATCH 003/482] ImageProjectionCustomized in embeddings --- src/diffusers/models/embeddings.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 390b752abe15..9bbf57e151a3 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1519,6 +1519,25 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): return torch.cat([image_text_embeds, text_embeds], dim=1) +class ImageProjectionCustomized(nn.Module): + def __init__( + self, + image_embed_dim: int = 1024, + cross_attention_dim: int = 2048, + num_image_text_embeds: int = 1, + ): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_image_text_embeds = num_image_text_embeds + + def forward(self, image_embeds: torch.Tensor): + image_embeds = image_embeds.repeat(1, 2*self.num_image_text_embeds) + image_embeds = image_embeds.reshape( + -1, self.num_image_text_embeds, self.cross_attention_dim + ) + return image_embeds + class ImageProjection(nn.Module): def 
__init__( self, From d1daa36a310344d7bcbe43721ae188238ea91d50 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:23:13 +0900 Subject: [PATCH 004/482] print size --- src/diffusers/models/embeddings.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 9bbf57e151a3..1ce3db018573 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1532,7 +1532,9 @@ def __init__( self.num_image_text_embeds = num_image_text_embeds def forward(self, image_embeds: torch.Tensor): + print(f'image_embeds={image_embeds.size()}') image_embeds = image_embeds.repeat(1, 2*self.num_image_text_embeds) + print(f'image_embeds after={image_embeds.size()}') image_embeds = image_embeds.reshape( -1, self.num_image_text_embeds, self.cross_attention_dim ) From f0faa06bc7668be210a9ed98080482f4e41a9529 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:30:34 +0900 Subject: [PATCH 005/482] print --- src/diffusers/loaders/unet.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index f3676a83b9d7..acd35ba36125 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -582,7 +582,9 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us # IP-Adapter num_image_text_embeds = 4 clip_embeddings_dim = state_dict["proj.weight"].shape[-1] + print('clip_embeddings_dim={clip_embeddings_dim}') cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 + print(f'cross_attention_dim={cross_attention_dim}') with init_context(): image_projection = ImageProjection( From 5c49f75e43ca69177f966cc0fbd333c356fe788f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:32:28 +0900 Subject: [PATCH 006/482] print --- src/diffusers/loaders/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index acd35ba36125..bf3cb2cbbcba 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -582,7 +582,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us # IP-Adapter num_image_text_embeds = 4 clip_embeddings_dim = state_dict["proj.weight"].shape[-1] - print('clip_embeddings_dim={clip_embeddings_dim}') + print(f'clip_embeddings_dim={clip_embeddings_dim}') cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 print(f'cross_attention_dim={cross_attention_dim}') From a49240d97c01eec50903935d9cff9429ffb5a25f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:37:43 +0900 Subject: [PATCH 007/482] print --- src/diffusers/models/embeddings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1ce3db018573..ee3b826cfd07 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1559,6 +1559,7 @@ def forward(self, image_embeds: torch.Tensor): # image image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype)) image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + print(f'image_embeds after reshape={image_embeds.size()}') image_embeds = self.norm(image_embeds) return image_embeds From 245a3fa5765f189d4cb56b026dab6c8e0ef3fb1c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:42:00 +0900 Subject: [PATCH 008/482] print --- 
src/diffusers/loaders/unet.py | 3 +++ src/diffusers/models/embeddings.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index bf3cb2cbbcba..b437fc181a8a 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -571,6 +571,9 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us clip_embeddings_dim = 1024 # fixed by CLIP model cross_attention_dim = 2048 # fixed by sdxl two text encoders + print(f'clip_embeddings_dim={clip_embeddings_dim}') + print(f'cross_attention_dim={cross_attention_dim}') + with init_context(): image_projection = ImageProjectionCustomized( cross_attention_dim=cross_attention_dim, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index ee3b826cfd07..9a0e5f6881ad 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1534,10 +1534,11 @@ def __init__( def forward(self, image_embeds: torch.Tensor): print(f'image_embeds={image_embeds.size()}') image_embeds = image_embeds.repeat(1, 2*self.num_image_text_embeds) - print(f'image_embeds after={image_embeds.size()}') + print(f'image_embeds after repeat={image_embeds.size()}') image_embeds = image_embeds.reshape( -1, self.num_image_text_embeds, self.cross_attention_dim ) + print(f'image_embeds after reshape={image_embeds.size()}') return image_embeds class ImageProjection(nn.Module): From a8b578c125de7c8c135a91424833077cf5332d06 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:47:08 +0900 Subject: [PATCH 009/482] print image_proj --- src/diffusers/loaders/unet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index b437fc181a8a..5ac9b8a71250 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -851,6 +851,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F # IP-Adapter Plus num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] + print(f'num_image_text_embeds={num_image_text_embeds}') with init_context(): attn_procs[name] = attn_processor_class( hidden_size=hidden_size, From 8a4f4a04f41400bc024445561f14916785252ff1 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:50:31 +0900 Subject: [PATCH 010/482] MultiIPAdapterImageProjection --- src/diffusers/models/embeddings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 9a0e5f6881ad..9c116765ab4f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2628,6 +2628,7 @@ def forward(self, image_embeds: List[torch.Tensor]): for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): batch_size, num_images = image_embed.shape[0], image_embed.shape[1] image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) + print(f'MultiIPAdapterImageProjection image_embed={image_embed}') image_embed = image_projection_layer(image_embed) image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) From f08f179513b71bda718256584264e6706a810f27 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:52:35 +0900 Subject: [PATCH 011/482] print print(f'MultiIPAdapterImageProjection image_embed after={image_embed.size()}') --- src/diffusers/models/embeddings.py | 4 +++- 1 file changed, 3 
insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 9c116765ab4f..90807b7af185 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2628,9 +2628,11 @@ def forward(self, image_embeds: List[torch.Tensor]): for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): batch_size, num_images = image_embed.shape[0], image_embed.shape[1] image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) - print(f'MultiIPAdapterImageProjection image_embed={image_embed}') + print(f'MultiIPAdapterImageProjection image_embed before={image_embed.size()}') image_embed = image_projection_layer(image_embed) + print(f'MultiIPAdapterImageProjection image_embed after={image_embed.size()}') image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) + print(f'MultiIPAdapterImageProjection image_embed reshape={image_embed.size()}') projected_image_embeds.append(image_embed) From bc3719479891e4e7dea4df57dc8bcc818241018f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 02:02:59 +0900 Subject: [PATCH 012/482] print(f'unet 2d image_embeds before={image_embeds.size()}') --- src/diffusers/models/unets/unet_2d_condition.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 5674d8ba26ec..1cf27ba06fe1 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -1031,7 +1031,9 @@ def process_encoder_hidden_states( encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) image_embeds = added_cond_kwargs.get("image_embeds") + print(f'unet 2d image_embeds before={image_embeds.size()}') image_embeds = self.encoder_hid_proj(image_embeds) + print(f'unet 2d image_embeds after={image_embeds.size()}') encoder_hidden_states = (encoder_hidden_states, image_embeds) return encoder_hidden_states From 70294a68a93de95f77c2ec5e223d6b9c91b30c97 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 02:05:27 +0900 Subject: [PATCH 013/482] fix bug --- src/diffusers/models/unets/unet_2d_condition.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 1cf27ba06fe1..25fb7c41f840 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -1031,9 +1031,11 @@ def process_encoder_hidden_states( encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) image_embeds = added_cond_kwargs.get("image_embeds") - print(f'unet 2d image_embeds before={image_embeds.size()}') + for image_embed in image_embeds: + print(f'unet 2d image_embeds before={image_embed.size()}') image_embeds = self.encoder_hid_proj(image_embeds) - print(f'unet 2d image_embeds after={image_embeds.size()}') + for image_embed in image_embeds: + print(f'unet 2d image_embeds after={image_embed.size()}') encoder_hidden_states = (encoder_hidden_states, image_embeds) return encoder_hidden_states From d531aa31d2d053c0c1ded087627462e52bcc2360 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 02:14:51 +0900 Subject: [PATCH 014/482] Encode ip_adapter_image --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 77d496cf831d..0bbb13173f45 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -583,6 +583,7 @@ def prepare_ip_adapter_image_embeds( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + print(f'output_hidden_state={output_hidden_state}') single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) @@ -1311,6 +1312,7 @@ def __call__( # 3.2 Encode ip_adapter_image if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, @@ -1318,6 +1320,8 @@ def __call__( batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) + for image_embed in image_embeds: + print(f'Encode ip_adapter_image image_embed={image_embed.size()}') # 4. Prepare image if isinstance(controlnet, ControlNetModel): From 8313c908954d5287c98303e010fe246eb84cb967 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 02:26:19 +0900 Subject: [PATCH 015/482] correct output_hidden_state --- src/diffusers/models/__init__.py | 3 ++- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 853f149fe01c..33f3882450e1 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -114,6 +114,7 @@ ConsistencyDecoderVAE, VQModel, ) + from .cache_utils import CacheMixin from .controlnets import ( ControlNetModel, @@ -130,7 +131,7 @@ SparseControlNetModel, UNetControlNetXSModel, ) - from .embeddings import ImageProjection + from .embeddings import ImageProjection, ImageProjectionCustomized from .modeling_utils import ModelMixin from .transformers import ( AllegroTransformer3DModel, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 0bbb13173f45..0b45577509cd 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -38,7 +38,9 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel + +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, ImageProjectionCustomized, MultiControlNetModel, UNet2DConditionModel + from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, @@ -582,7 +584,7 @@ def prepare_ip_adapter_image_embeds( for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) and not isinstance(image_proj_layer, ImageProjectionCustomized) print(f'output_hidden_state={output_hidden_state}') single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state From f70680fb07d8795db1893c4bdf439b927645cd01 Mon Sep 17 00:00:00 
2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 02:29:12 +0900 Subject: [PATCH 016/482] fix bug --- src/diffusers/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 33f3882450e1..8c2ecfeb0f13 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -52,7 +52,7 @@ _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"] - _import_structure["embeddings"] = ["ImageProjection"] + _import_structure["embeddings"] = ["ImageProjection", "ImageProjectionCustomized"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] From 294a85caa96405c8fdb77f694a6c8542bb335673 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 9 Aug 2024 14:37:14 +0900 Subject: [PATCH 017/482] remove print --- src/diffusers/loaders/unet.py | 8 +------- src/diffusers/models/embeddings.py | 9 +-------- src/diffusers/models/unets/unet_2d_condition.py | 4 ---- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 5 +---- 4 files changed, 3 insertions(+), 23 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 5ac9b8a71250..790416a5a276 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -571,9 +571,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us clip_embeddings_dim = 1024 # fixed by CLIP model cross_attention_dim = 2048 # fixed by sdxl two text encoders - print(f'clip_embeddings_dim={clip_embeddings_dim}') - print(f'cross_attention_dim={cross_attention_dim}') - with init_context(): image_projection = ImageProjectionCustomized( cross_attention_dim=cross_attention_dim, @@ -585,10 +582,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us # IP-Adapter num_image_text_embeds = 4 clip_embeddings_dim = state_dict["proj.weight"].shape[-1] - print(f'clip_embeddings_dim={clip_embeddings_dim}') cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 - print(f'cross_attention_dim={cross_attention_dim}') - + with init_context(): image_projection = ImageProjection( cross_attention_dim=cross_attention_dim, @@ -851,7 +846,6 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F # IP-Adapter Plus num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] - print(f'num_image_text_embeds={num_image_text_embeds}') with init_context(): attn_procs[name] = attn_processor_class( hidden_size=hidden_size, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 90807b7af185..d19b14bbe697 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1532,13 +1532,10 @@ def __init__( self.num_image_text_embeds = num_image_text_embeds def forward(self, image_embeds: torch.Tensor): - print(f'image_embeds={image_embeds.size()}') image_embeds = image_embeds.repeat(1, 2*self.num_image_text_embeds) - print(f'image_embeds after repeat={image_embeds.size()}') image_embeds = image_embeds.reshape( -1, self.num_image_text_embeds, self.cross_attention_dim ) - print(f'image_embeds after 
reshape={image_embeds.size()}') return image_embeds class ImageProjection(nn.Module): @@ -1560,7 +1557,6 @@ def forward(self, image_embeds: torch.Tensor): # image image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype)) image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) - print(f'image_embeds after reshape={image_embeds.size()}') image_embeds = self.norm(image_embeds) return image_embeds @@ -2628,12 +2624,9 @@ def forward(self, image_embeds: List[torch.Tensor]): for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): batch_size, num_images = image_embed.shape[0], image_embed.shape[1] image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) - print(f'MultiIPAdapterImageProjection image_embed before={image_embed.size()}') image_embed = image_projection_layer(image_embed) - print(f'MultiIPAdapterImageProjection image_embed after={image_embed.size()}') image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) - print(f'MultiIPAdapterImageProjection image_embed reshape={image_embed.size()}') - + projected_image_embeds.append(image_embed) return projected_image_embeds diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 25fb7c41f840..5674d8ba26ec 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -1031,11 +1031,7 @@ def process_encoder_hidden_states( encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) image_embeds = added_cond_kwargs.get("image_embeds") - for image_embed in image_embeds: - print(f'unet 2d image_embeds before={image_embed.size()}') image_embeds = self.encoder_hid_proj(image_embeds) - for image_embed in image_embeds: - print(f'unet 2d image_embeds after={image_embed.size()}') encoder_hidden_states = (encoder_hidden_states, image_embeds) return encoder_hidden_states diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 0b45577509cd..e07c628bf0f1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -585,7 +585,6 @@ def prepare_ip_adapter_image_embeds( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) and not isinstance(image_proj_layer, ImageProjectionCustomized) - print(f'output_hidden_state={output_hidden_state}') single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) @@ -1322,9 +1321,7 @@ def __call__( batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) - for image_embed in image_embeds: - print(f'Encode ip_adapter_image image_embed={image_embed.size()}') - + # 4. 
Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( From 72a86fff1f77e6588e776c5a8465294c6a5bb058 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 22 Oct 2024 13:43:25 +0900 Subject: [PATCH 018/482] safety checker for nsfw content --- .../controlnet/pipeline_controlnet_sd_xl.py | 69 +++++++++- .../pipelines/controlnet/safety_checker.py | 125 ++++++++++++++++++ .../stable_diffusion_xl/pipeline_output.py | 10 +- 3 files changed, 198 insertions(+), 6 deletions(-) create mode 100644 src/diffusers/pipelines/controlnet/safety_checker.py diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index e07c628bf0f1..5a04643674e8 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -64,6 +64,7 @@ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + from ...utils import is_torch_xla_available @@ -74,6 +75,10 @@ else: XLA_AVAILABLE = False +from .multicontrolnet import MultiControlNetModel +from .safety_checker import StableDiffusionSafetyChecker + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -225,6 +230,12 @@ class StableDiffusionXLControlNetPipeline( scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): Whether the negative prompt embeddings should always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1-0`. @@ -243,7 +254,9 @@ class StableDiffusionXLControlNetPipeline( "text_encoder_2", "feature_extractor", "image_encoder", + "safety_checker", ] + _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = [ "latents", "prompt_embeds", @@ -265,10 +278,12 @@ def __init__( unet: UNet2DConditionModel, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, ): super().__init__() @@ -284,6 +299,7 @@ def __init__( unet=unet, controlnet=controlnet, scheduler=scheduler, + safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, ) @@ -299,7 +315,23 @@ def __init__( else: self.watermark = None - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( @@ -611,6 +643,20 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -1205,6 +1251,13 @@ def __call__( [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned containing the output images. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. 
""" callback = kwargs.pop("callback", None) @@ -1592,24 +1645,30 @@ def __call__( latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] - + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents + has_nsfw_concept = None if not output_type == "latent": # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) - image = self.image_processor.postprocess(image, output_type=output_type) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() if not return_dict: - return (image,) + return (image, has_nsfw_concept) - return StableDiffusionXLPipelineOutput(images=image) + return StableDiffusionXLPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/controlnet/safety_checker.py b/src/diffusers/pipelines/controlnet/safety_checker.py new file mode 100644 index 000000000000..227101837527 --- /dev/null +++ b/src/diffusers/pipelines/controlnet/safety_checker.py @@ -0,0 +1,125 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class StableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if torch.is_tensor(images) or torch.is_tensor(images[0]): + images[idx] = torch.zeros_like(images[idx]) # black image + else: + images[idx] = np.zeros(images[idx].shape) # black image + + if any(has_nsfw_concepts): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." 
+ ) + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py index 0783f44486ee..7c1cab92ee2b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Union +from typing import List, Optional, Union import numpy as np import PIL.Image @@ -16,9 +16,13 @@ class StableDiffusionXLPipelineOutput(BaseOutput): images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. """ images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] if is_flax_available(): @@ -32,6 +36,10 @@ class FlaxStableDiffusionXLPipelineOutput(BaseOutput): Args: images (`np.ndarray`) Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline. + nsfw_content_detected (`List[bool]`): + List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content + or `None` if safety checking could not be performed. """ images: np.ndarray + nsfw_content_detected: Optional[List[bool]] From a916e1fea0fe25473f2e259f3865f7973e6aaaa4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 22 Oct 2024 14:10:38 +0900 Subject: [PATCH 019/482] fix bug --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 5a04643674e8..ce297b4a287a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -325,11 +325,11 @@ def __init__( " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
From a916e1fea0fe25473f2e259f3865f7973e6aaaa4 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 22 Oct 2024 14:10:38 +0900
Subject: [PATCH 019/482] fix bug

---
 .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 5a04643674e8..ce297b4a287a 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -325,11 +325,11 @@ def __init__(
                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
             )
 
-        if safety_checker is not None and feature_extractor is None:
-            raise ValueError(
-                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
-                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
-            )
+        #if safety_checker is not None and feature_extractor is None:
+        #    raise ValueError(
+        #        "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+        #        " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+        #    )
 
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, requires_safety_checker=requires_safety_checker)

From dec360464bcad6565435132aabd7db1fdb1cb9ab Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 22 Oct 2024 14:56:40 +0900
Subject: [PATCH 020/482] commit

---
 src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py
index 7c1cab92ee2b..f06a1e222117 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py
@@ -6,7 +6,6 @@
 from ...utils import BaseOutput, is_flax_available
 
 
-
 @dataclass
 class StableDiffusionXLPipelineOutput(BaseOutput):
     """

From b7a2b0ef44c242f9e62f9c172c9fd7aaf498d006 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 10 Dec 2024 10:03:24 +0900
Subject: [PATCH 021/482] return images without safety checker

---
 .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index ce297b4a287a..280d30cfaa37 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -28,6 +28,8 @@
     CLIPVisionModelWithProjection,
 )
 
+import copy
+
 from diffusers.utils.import_utils import is_invisible_watermark_available
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -652,10 +654,13 @@ def run_safety_checker(self, image, device, dtype):
             else:
                 feature_extractor_input = self.image_processor.numpy_to_pil(image)
             safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image_original = copy.deepcopy(image)
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
-        return image, has_nsfw_concept
+
+        # return images without safety checker
+        return image_original, has_nsfw_concept

From 905ba26118c745aa906e10508058d63f29ae62f7 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 10 Dec 2024 17:34:58 +0900
Subject: [PATCH 022/482] fix bug

---
 .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 280d30cfaa37..c396e213c52f 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -646,6 +646,7 @@ def prepare_ip_adapter_image_embeds(
         return ip_adapter_image_embeds
 
     def run_safety_checker(self, image, device, dtype):
+        image_original = copy.deepcopy(image)
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
@@ -654,7 +655,7 @@ def run_safety_checker(self, image, device, dtype):
             else:
                 feature_extractor_input = self.image_processor.numpy_to_pil(image)
             safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-            image_original = copy.deepcopy(image)
+
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
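The net effect of patches 021 and 022 is that `run_safety_checker` still computes `has_nsfw_concept` but always returns an untouched copy of the input images, snapshotted before the checker call so the "black image" replacement never reaches the caller. A condensed sketch of that pattern (the checker call itself is elided; `safety_checker` and `clip_input` stand in for whatever the pipeline was loaded with):

import copy

def run_safety_checker(image, safety_checker, clip_input):
    # Snapshot before the checker mutates/blacks out flagged images,
    # then return the originals together with the per-image flags.
    image_original = copy.deepcopy(image)
    if safety_checker is None:
        return image_original, None
    image, has_nsfw_concept = safety_checker(images=image, clip_input=clip_input)
    return image_original, has_nsfw_concept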
"proj.weight" in state_dict: # IP-Adapter num_image_text_embeds = 4 clip_embeddings_dim = state_dict["proj.weight"].shape[-1] cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 - + with init_context(): image_projection = ImageProjection( cross_attention_dim=cross_attention_dim, From 72a362df9ab98b50b1707d389bed364122f0f34d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:23:13 +0900 Subject: [PATCH 026/482] print size --- src/diffusers/models/embeddings.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d19b14bbe697..bad77d7cb398 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1532,7 +1532,9 @@ def __init__( self.num_image_text_embeds = num_image_text_embeds def forward(self, image_embeds: torch.Tensor): + print(f'image_embeds={image_embeds.size()}') image_embeds = image_embeds.repeat(1, 2*self.num_image_text_embeds) + print(f'image_embeds after={image_embeds.size()}') image_embeds = image_embeds.reshape( -1, self.num_image_text_embeds, self.cross_attention_dim ) From 00fda2ff90f9d69b77fe6688ea177db73e172caf Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:30:34 +0900 Subject: [PATCH 027/482] print --- src/diffusers/loaders/unet.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index db8b2ec65552..a418626122fc 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -596,7 +596,9 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us # IP-Adapter num_image_text_embeds = 4 clip_embeddings_dim = state_dict["proj.weight"].shape[-1] + print('clip_embeddings_dim={clip_embeddings_dim}') cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 + print(f'cross_attention_dim={cross_attention_dim}') with init_context(): image_projection = ImageProjection( From 468b16be9627171b425136deeaadefa3a670fb48 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:32:28 +0900 Subject: [PATCH 028/482] print --- src/diffusers/loaders/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index a418626122fc..012946dd8082 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -596,7 +596,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us # IP-Adapter num_image_text_embeds = 4 clip_embeddings_dim = state_dict["proj.weight"].shape[-1] - print('clip_embeddings_dim={clip_embeddings_dim}') + print(f'clip_embeddings_dim={clip_embeddings_dim}') cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 print(f'cross_attention_dim={cross_attention_dim}') From ae7fc59b6891c88cde079b4affd38443faf410a8 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:37:43 +0900 Subject: [PATCH 029/482] print --- src/diffusers/models/embeddings.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index bad77d7cb398..32f7406cd21c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1557,8 +1557,10 @@ def forward(self, image_embeds: torch.Tensor): batch_size = image_embeds.shape[0] # image + image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype)) image_embeds = image_embeds.reshape(batch_size, 
self.num_image_text_embeds, -1) + print(f'image_embeds after reshape={image_embeds.size()}') image_embeds = self.norm(image_embeds) return image_embeds From 2b6d5fd4fc29a617555c992e5578aa84e7a5b8df Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:42:00 +0900 Subject: [PATCH 030/482] print --- src/diffusers/loaders/unet.py | 3 +++ src/diffusers/models/embeddings.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 012946dd8082..a91b205a2d16 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -571,6 +571,9 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us clip_embeddings_dim = 1024 # fixed by CLIP model cross_attention_dim = 2048 # fixed by sdxl two text encoders + print(f'clip_embeddings_dim={clip_embeddings_dim}') + print(f'cross_attention_dim={cross_attention_dim}') + with init_context(): image_projection = ImageProjectionCustomized( cross_attention_dim=cross_attention_dim, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 32f7406cd21c..73b3be8d67cd 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1534,10 +1534,11 @@ def __init__( def forward(self, image_embeds: torch.Tensor): print(f'image_embeds={image_embeds.size()}') image_embeds = image_embeds.repeat(1, 2*self.num_image_text_embeds) - print(f'image_embeds after={image_embeds.size()}') + print(f'image_embeds after repeat={image_embeds.size()}') image_embeds = image_embeds.reshape( -1, self.num_image_text_embeds, self.cross_attention_dim ) + print(f'image_embeds after reshape={image_embeds.size()}') return image_embeds class ImageProjection(nn.Module): From 605c510569737d9af9318c4fc2bcf4c636a2e6dd Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:47:08 +0900 Subject: [PATCH 031/482] print image_proj --- src/diffusers/loaders/unet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index a91b205a2d16..bfd903ac9989 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -865,6 +865,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F # IP-Adapter Plus num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] + print(f'num_image_text_embeds={num_image_text_embeds}') with init_context(): attn_procs[name] = attn_processor_class( hidden_size=hidden_size, From f09146a8dee8ae6319e1e920401c47a1e88903da Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:50:31 +0900 Subject: [PATCH 032/482] MultiIPAdapterImageProjection --- src/diffusers/models/embeddings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 73b3be8d67cd..1b271060a1a0 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2629,6 +2629,7 @@ def forward(self, image_embeds: List[torch.Tensor]): for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): batch_size, num_images = image_embed.shape[0], image_embed.shape[1] image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) + print(f'MultiIPAdapterImageProjection image_embed={image_embed}') image_embed = image_projection_layer(image_embed) image_embed = image_embed.reshape((batch_size, num_images) + 
image_embed.shape[1:]) From b07224848b419f98d088ad4e416b566daf16f56c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 01:52:35 +0900 Subject: [PATCH 033/482] print print(f'MultiIPAdapterImageProjection image_embed after={image_embed.size()}') --- src/diffusers/models/embeddings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1b271060a1a0..607fb206e661 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2629,10 +2629,10 @@ def forward(self, image_embeds: List[torch.Tensor]): for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): batch_size, num_images = image_embed.shape[0], image_embed.shape[1] image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) - print(f'MultiIPAdapterImageProjection image_embed={image_embed}') + print(f'MultiIPAdapterImageProjection image_embed before={image_embed.size()}') image_embed = image_projection_layer(image_embed) + print(f'MultiIPAdapterImageProjection image_embed after={image_embed.size()}') image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) - projected_image_embeds.append(image_embed) return projected_image_embeds From a4da08109b4829375be6d9930f896c19bb3c4e8d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 02:02:59 +0900 Subject: [PATCH 034/482] print(f'unet 2d image_embeds before={image_embeds.size()}') --- src/diffusers/models/unets/unet_2d_condition.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 5674d8ba26ec..1cf27ba06fe1 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -1031,7 +1031,9 @@ def process_encoder_hidden_states( encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) image_embeds = added_cond_kwargs.get("image_embeds") + print(f'unet 2d image_embeds before={image_embeds.size()}') image_embeds = self.encoder_hid_proj(image_embeds) + print(f'unet 2d image_embeds after={image_embeds.size()}') encoder_hidden_states = (encoder_hidden_states, image_embeds) return encoder_hidden_states From bc420a04fec5300d580077f4e82d099386ac7009 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 6 Aug 2024 02:05:27 +0900 Subject: [PATCH 035/482] fix bug --- src/diffusers/models/unets/unet_2d_condition.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 1cf27ba06fe1..25fb7c41f840 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -1031,9 +1031,11 @@ def process_encoder_hidden_states( encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) image_embeds = added_cond_kwargs.get("image_embeds") - print(f'unet 2d image_embeds before={image_embeds.size()}') + for image_embed in image_embeds: + print(f'unet 2d image_embeds before={image_embed.size()}') image_embeds = self.encoder_hid_proj(image_embeds) - print(f'unet 2d image_embeds after={image_embeds.size()}') + for image_embed in image_embeds: + print(f'unet 2d image_embeds after={image_embed.size()}') encoder_hidden_states = (encoder_hidden_states, image_embeds) return encoder_hidden_states From 
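The shape trace that the print-debugging above is chasing reduces to one repeat-and-reshape: `ImageProjectionCustomized` tiles a pooled CLIP embedding of width 1024 up to 2048 (SDXL's cross-attention width, hence the factor of 2) and views it as a single image token. A standalone check of that arithmetic, with placeholder values matching the fixed dimensions used in these patches:

import torch

batch, clip_dim, cross_dim, num_tokens = 2, 1024, 2048, 1
image_embeds = torch.randn(batch, clip_dim)

# repeat(1, 2 * num_tokens) tiles 1024 -> 2048, matching cross_attention_dim
tiled = image_embeds.repeat(1, 2 * num_tokens)        # [2, 2048]
tokens = tiled.reshape(-1, num_tokens, cross_dim)     # [2, 1, 2048]
print(tokens.shape)                                   # torch.Size([2, 1, 2048])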
From a4da08109b4829375be6d9930f896c19bb3c4e8d Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 6 Aug 2024 02:02:59 +0900
Subject: [PATCH 034/482] print(f'unet 2d image_embeds before={image_embeds.size()}')

---
 src/diffusers/models/unets/unet_2d_condition.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 5674d8ba26ec..1cf27ba06fe1 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -1031,7 +1031,9 @@ def process_encoder_hidden_states(
             encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
 
             image_embeds = added_cond_kwargs.get("image_embeds")
+            print(f'unet 2d image_embeds before={image_embeds.size()}')
             image_embeds = self.encoder_hid_proj(image_embeds)
+            print(f'unet 2d image_embeds after={image_embeds.size()}')
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
         return encoder_hidden_states

From bc420a04fec5300d580077f4e82d099386ac7009 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 6 Aug 2024 02:05:27 +0900
Subject: [PATCH 035/482] fix bug

---
 src/diffusers/models/unets/unet_2d_condition.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 1cf27ba06fe1..25fb7c41f840 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -1031,9 +1031,11 @@ def process_encoder_hidden_states(
             encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
 
             image_embeds = added_cond_kwargs.get("image_embeds")
-            print(f'unet 2d image_embeds before={image_embeds.size()}')
+            for image_embed in image_embeds:
+                print(f'unet 2d image_embeds before={image_embed.size()}')
             image_embeds = self.encoder_hid_proj(image_embeds)
-            print(f'unet 2d image_embeds after={image_embeds.size()}')
+            for image_embed in image_embeds:
+                print(f'unet 2d image_embeds after={image_embed.size()}')
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
         return encoder_hidden_states

From c8cdfb1a22de1d82be1afdcbfbea0f30717d164c Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 6 Aug 2024 02:14:51 +0900
Subject: [PATCH 036/482] Encode ip_adapter_image

---
 src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index c396e213c52f..02777d1ad3a9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1380,7 +1380,7 @@ def __call__(
                 batch_size * num_images_per_prompt,
                 self.do_classifier_free_guidance,
             )
-
+
         # 4. Prepare image
         if isinstance(controlnet, ControlNetModel):
             image = self.prepare_image(

From 42c7679287f34ca8e05d2e292be5f3b44f597358 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 6 Aug 2024 02:26:19 +0900
Subject: [PATCH 037/482] correct output_hidden_state

---
 src/diffusers/models/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 8c2ecfeb0f13..0cf108dec292 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -131,6 +131,7 @@
         SparseControlNetModel,
         UNetControlNetXSModel,
     )
+    from .embeddings import ImageProjection, ImageProjectionCustomized
     from .modeling_utils import ModelMixin
     from .transformers import (

From 81a16bec06b9fd01e5a3876315accc27a9998596 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Fri, 9 Aug 2024 14:37:14 +0900
Subject: [PATCH 038/482] remove print

---
 src/diffusers/loaders/unet.py                   | 8 +-------
 src/diffusers/models/embeddings.py              | 8 ++------
 src/diffusers/models/unets/unet_2d_condition.py | 4 ----
 3 files changed, 3 insertions(+), 17 deletions(-)

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index bfd903ac9989..5d0bfa5a47ad 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -571,9 +571,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         clip_embeddings_dim = 1024  # fixed by CLIP model
         cross_attention_dim = 2048  # fixed by sdxl two text encoders
 
-        print(f'clip_embeddings_dim={clip_embeddings_dim}')
-        print(f'cross_attention_dim={cross_attention_dim}')
-
         with init_context():
             image_projection = ImageProjectionCustomized(
                 cross_attention_dim=cross_attention_dim,
@@ -599,10 +596,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
             # IP-Adapter
             num_image_text_embeds = 4
             clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
-            print(f'clip_embeddings_dim={clip_embeddings_dim}')
             cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
-            print(f'cross_attention_dim={cross_attention_dim}')
-
+
             with init_context():
                 image_projection = ImageProjection(
                     cross_attention_dim=cross_attention_dim,
@@ -865,7 +860,6 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
                 # IP-Adapter Plus
                 num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
 
-            print(f'num_image_text_embeds={num_image_text_embeds}')
             with init_context():
                 attn_procs[name] = attn_processor_class(
                     hidden_size=hidden_size,
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 607fb206e661..53c5403fe4dc 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1532,13 +1532,10 @@ def __init__(
         self.num_image_text_embeds = num_image_text_embeds
 
     def forward(self, image_embeds: torch.Tensor):
-        print(f'image_embeds={image_embeds.size()}')
         image_embeds = image_embeds.repeat(1, 2*self.num_image_text_embeds)
-        print(f'image_embeds after repeat={image_embeds.size()}')
         image_embeds = image_embeds.reshape(
             -1, self.num_image_text_embeds, self.cross_attention_dim
         )
-        print(f'image_embeds after reshape={image_embeds.size()}')
         return image_embeds
 
 class ImageProjection(nn.Module):
@@ -1559,9 +1556,10 @@ def forward(self, image_embeds: torch.Tensor):
 
         # image
+
         image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
+
         image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
-        print(f'image_embeds after reshape={image_embeds.size()}')
         image_embeds = self.norm(image_embeds)
         return image_embeds
@@ -2629,9 +2627,7 @@ def forward(self, image_embeds: List[torch.Tensor]):
         for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
             batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
             image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
-            print(f'MultiIPAdapterImageProjection image_embed before={image_embed.size()}')
             image_embed = image_projection_layer(image_embed)
-            print(f'MultiIPAdapterImageProjection image_embed after={image_embed.size()}')
             image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
             projected_image_embeds.append(image_embed)
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 25fb7c41f840..5674d8ba26ec 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -1031,11 +1031,7 @@ def process_encoder_hidden_states(
             encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
 
             image_embeds = added_cond_kwargs.get("image_embeds")
-            for image_embed in image_embeds:
-                print(f'unet 2d image_embeds before={image_embed.size()}')
             image_embeds = self.encoder_hid_proj(image_embeds)
-            for image_embed in image_embeds:
-                print(f'unet 2d image_embeds after={image_embed.size()}')
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
         return encoder_hidden_states

From aeddea5012bedc405ca013ec483562c5c4b25bf2 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 22 Oct 2024 13:43:25 +0900
Subject: [PATCH 039/482] safety checker for nsfw content

---
 src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 02777d1ad3a9..f532129acf65 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -332,6 +332,7 @@ def __init__(
         #        "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
         #        " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
         #    )
+
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, requires_safety_checker=requires_safety_checker)

From 15ec51ee2a781d8449b2bbab2425a274e8f4ce92 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 22 Oct 2024 14:10:38 +0900
Subject: [PATCH 040/482] fix bug

---
 src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index f532129acf65..02777d1ad3a9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -332,7 +332,6 @@ def __init__(
         #        "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
         #        " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
         #    )
-
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, requires_safety_checker=requires_safety_checker)

From da2427aee437fa9b03a1725cf2cdc4068591f385 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 10 Dec 2024 10:03:24 +0900
Subject: [PATCH 041/482] return images without safety checker

---
 src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 02777d1ad3a9..b203c79bf130 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -655,7 +655,7 @@ def run_safety_checker(self, image, device, dtype):
             else:
                 feature_extractor_input = self.image_processor.numpy_to_pil(image)
             safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-
+
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
             )
From ef3659144f08721ad53f925ae7b3daebb381cb3f Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 24 Dec 2024 16:51:28 +0800
Subject: [PATCH 042/482] custom IPAttentionProcessor for both text and ip-adapter masks

---
 src/diffusers/models/attention_processor.py | 59 +++++++++++++--------
 1 file changed, 36 insertions(+), 23 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 8bba5a82bc2f..32ae324186a2 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -5035,10 +5035,9 @@ def __call__(
 
         return hidden_states
 
-
-class IPAdapterAttnProcessor2_0(torch.nn.Module):
+class CustomIPAdapterAttnProcessor2_0(torch.nn.Module):
     r"""
-    Attention processor for IP-Adapter for PyTorch 2.0.
+    Custom Attention processor for Both Text and IP-Adapter Masks for PyTorch 2.0.
 
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
@@ -5136,26 +5135,6 @@ def __call__(
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
         if ip_adapter_masks is not None:
             if not isinstance(ip_adapter_masks, List):
                 # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
                 ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
             if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                 raise ValueError(
                     f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
                     f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
                     f"({len(ip_hidden_states)})"
                 )
             else:
                 for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
                     if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                         raise ValueError(
                             "Each element of the ip_adapter_masks array should be a tensor with shape "
                             "[1, num_images_for_ip_adapter, height, width]."
                             " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                         )
                     if mask.shape[1] != ip_state.shape[1]:
                         raise ValueError(
                             f"Number of masks ({mask.shape[1]}) does not match "
                             f"number of ip images ({ip_state.shape[1]}) at index {index}"
                         )
                     if isinstance(scale, list) and not len(scale) == mask.shape[1]:
                         raise ValueError(
                             f"Number of masks ({mask.shape[1]}) does not match "
                             f"number of scales ({len(scale)}) at index {index}"
                         )
         else:
             ip_adapter_masks = [None] * len(self.scale)
 
@@ -5189,6 +5168,40 @@ def __call__(
         else:
             ip_adapter_masks = [None] * len(self.scale)
 
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if ip_adapter_masks is not None:
+            # [scale] * mask.shape[1] for setting strength. disabled by default
+            scale = ip_adapter_masks[0].shape[1]
+            print(f'text prompt strength = {scale}')
+            mask_downsample = IPAdapterMaskProcessor.downsample(
+                ip_adapter_masks[0][:, i, :, :],
+                batch_size,
+                hidden_states.shape[1],
+                hidden_states.shape[2],
+            )
+
+            mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+            hidden_states = scale * (hidden_states * mask_downsample)
+
         # for ip-adapter
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks

From 32ead1be42da0a4763b3d6ff94a2df566039bb1c Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 24 Dec 2024 16:53:59 +0800
Subject: [PATCH 043/482] keep original IPAdapterAttnProcessor2_0

---
 src/diffusers/models/attention_processor.py | 229 ++++++++++++++++++++
 1 file changed, 229 insertions(+)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 32ae324186a2..361fa33060f0 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -5035,6 +5035,235 @@ def __call__(
 
         return hidden_states
 
+class IPAdapterAttnProcessor2_0(torch.nn.Module):
+    r"""
+    Attention processor for IP-Adapter for PyTorch 2.0.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+            The context length of the image features.
+        scale (`float` or `List[float]`, defaults to 1.0):
+            the weight scale of image prompt.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+        self.num_tokens = num_tokens
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+        self.scale = scale
+
+        self.to_k_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
+    ):
+        residual = hidden_states
+
+        # separate ip_hidden_states from encoder_hidden_states
+        if encoder_hidden_states is not None:
+            if isinstance(encoder_hidden_states, tuple):
+                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+            else:
+                deprecation_message = (
+                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+                )
+                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+                encoder_hidden_states, ip_hidden_states = (
+                    encoder_hidden_states[:, :end_pos, :],
+                    [encoder_hidden_states[:, end_pos:, :]],
+                )
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if ip_adapter_masks is not None:
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                raise ValueError(
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
+                )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
+        else:
+            ip_adapter_masks = [None] * len(self.scale)
+
+        # for ip-adapter
+        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+        ):
+            skip = False
+            if isinstance(scale, list):
+                if all(s == 0 for s in scale):
+                    skip = True
+            elif scale == 0:
+                skip = True
+            if not skip:
+                if mask is not None:
+                    if not isinstance(scale, list):
+                        scale = [scale] * mask.shape[1]
+
+                    current_num_images = mask.shape[1]
+                    for i in range(current_num_images):
+                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        _current_ip_hidden_states = F.scaled_dot_product_attention(
+                            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )
+
+                        _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
+                        _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+
+                        mask_downsample = IPAdapterMaskProcessor.downsample(
+                            mask[:, i, :, :],
+                            batch_size,
+                            _current_ip_hidden_states.shape[1],
+                            _current_ip_hidden_states.shape[2],
+                        )
+
+                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                else:
+                    ip_key = to_k_ip(current_ip_hidden_states)
+                    ip_value = to_v_ip(current_ip_hidden_states)
+
+                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    current_ip_hidden_states = F.scaled_dot_product_attention(
+                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                    )
+
+                    current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                        batch_size, -1, attn.heads * head_dim
+                    )
+                    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+                    hidden_states = hidden_states + scale * current_ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
 class CustomIPAdapterAttnProcessor2_0(torch.nn.Module):
     r"""
     Custom Attention processor for Both Text and IP-Adapter Masks for PyTorch 2.0.
From fc08377afe97ba964c88d8aaab557bb8b50d6a53 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sun, 12 Jan 2025 13:36:15 +0900
Subject: [PATCH 044/482] modify pipeline to support text prompt mask

---
 .../controlnet/pipeline_controlnet_sd_xl.py | 65 ++++++++++++-------
 1 file changed, 42 insertions(+), 23 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index b203c79bf130..304063a3b383 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1329,7 +1329,7 @@ def __call__(
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
+            batch_size = 1 #len(prompt), modified
         else:
             batch_size = prompt_embeds.shape[0]
 
@@ -1349,27 +1349,44 @@ def __call__(
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = self.encode_prompt(
-            prompt,
-            prompt_2,
-            device,
-            num_images_per_prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt,
-            negative_prompt_2,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-            clip_skip=self.clip_skip,
-        )
+        ## added to store multiple prompt embeds
+        prompt_embeds_list = []
+        negative_prompt_embeds_list = []
+        pooled_prompt_embeds_list = []
+        negative_pooled_prompt_embeds_list = []
+
+        for pmt in prompt:
+            (
+                prompt_embeds,
+                negative_prompt_embeds,
+                pooled_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            ) = self.encode_prompt(
+                pmt,
+                prompt_2,
+                device,
+                num_images_per_prompt,
+                self.do_classifier_free_guidance,
+                negative_prompt,
+                negative_prompt_2,
+                #prompt_embeds=prompt_embeds,
+                #negative_prompt_embeds=negative_prompt_embeds,
+                #pooled_prompt_embeds=pooled_prompt_embeds,
+                #negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+            print(f'prompt_embeds shape={prompt_embeds.shape}')
+            prompt_embeds_list.append(prompt_embeds)
+            negative_prompt_embeds_list.append(negative_prompt_embeds)
+            pooled_prompt_embeds_list.append(pooled_prompt_embeds)
+            negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds)
+        prompt_embeds_list = torch.stack(prompt_embeds_list,dim=1)
+        negative_prompt_embeds_list = torch.stack(negative_prompt_embeds_list,dim=1)
+        pooled_prompt_embeds_list = torch.stack(pooled_prompt_embeds_list,dim=1)
+        negative_pooled_prompt_embeds_list = torch.stack(negative_pooled_prompt_embeds_list,dim=1)
+        print(f'prompt_embeds_list={prompt_embeds_list.shape}')
 
         # 3.2 Encode ip_adapter_image
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
@@ -1491,10 +1508,12 @@ def __call__(
 
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
 
         prompt_embeds = prompt_embeds.to(device)
+        prompt_embeds_list = prompt_embeds_list.to(device)
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
 
@@ -1547,7 +1566,7 @@ def __call__(
                     }
                 else:
                     control_model_input = latent_model_input
-                    controlnet_prompt_embeds = prompt_embeds
+                    controlnet_prompt_embeds = prompt_embeds
                     controlnet_added_cond_kwargs = added_cond_kwargs
 
                 if isinstance(controlnet_keep[i], list):
@@ -1583,7 +1602,7 @@ def __call__(
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=prompt_embeds_list, #prompt_embeds, modified for multiple text promts
                     timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,

From d29bb4f2b75b18003d7e58f1294c8268d33ecfd5 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sun, 12 Jan 2025 13:42:17 +0900
Subject: [PATCH 045/482] modify pipeline

---
 .../controlnet/pipeline_controlnet_sd_xl.py | 61 +++++++++++++------
 1 file changed, 42 insertions(+), 19 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 304063a3b383..33c5ab233044 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1356,37 +1356,58 @@ def __call__(
         pooled_prompt_embeds_list = []
         negative_pooled_prompt_embeds_list = []
 
-        for pmt in prompt:
+        if isinstance(prompt, list):
+            for pmt in prompt:
+                (
+                    prompt_embeds,
+                    negative_prompt_embeds,
+                    pooled_prompt_embeds,
+                    negative_pooled_prompt_embeds,
+                ) = self.encode_prompt(
+                    pmt,
+                    prompt_2,
+                    device,
+                    num_images_per_prompt,
+                    self.do_classifier_free_guidance,
+                    negative_prompt,
+                    negative_prompt_2,
+                    #prompt_embeds=prompt_embeds,
+                    #negative_prompt_embeds=negative_prompt_embeds,
+                    #pooled_prompt_embeds=pooled_prompt_embeds,
+                    #negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                    lora_scale=text_encoder_lora_scale,
+                    clip_skip=self.clip_skip,
+                )
+                prompt_embeds_list.append(prompt_embeds)
+                negative_prompt_embeds_list.append(negative_prompt_embeds)
+                pooled_prompt_embeds_list.append(pooled_prompt_embeds)
+                negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds)
+            prompt_embeds_list = torch.stack(prompt_embeds_list,dim=1)
+            negative_prompt_embeds_list = torch.stack(negative_prompt_embeds_list,dim=1)
+            pooled_prompt_embeds_list = torch.stack(pooled_prompt_embeds_list,dim=1)
+            negative_pooled_prompt_embeds_list = torch.stack(negative_pooled_prompt_embeds_list,dim=1)
+        else:
             (
                 prompt_embeds,
                 negative_prompt_embeds,
                 pooled_prompt_embeds,
                 negative_pooled_prompt_embeds,
             ) = self.encode_prompt(
-                pmt,
+                prompt,
                 prompt_2,
                 device,
                 num_images_per_prompt,
                 self.do_classifier_free_guidance,
                 negative_prompt,
                 negative_prompt_2,
-                #prompt_embeds=prompt_embeds,
-                #negative_prompt_embeds=negative_prompt_embeds,
-                #pooled_prompt_embeds=pooled_prompt_embeds,
-                #negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
+                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                 lora_scale=text_encoder_lora_scale,
                 clip_skip=self.clip_skip,
             )
-            print(f'prompt_embeds shape={prompt_embeds.shape}')
-            prompt_embeds_list.append(prompt_embeds)
-            negative_prompt_embeds_list.append(negative_prompt_embeds)
-            pooled_prompt_embeds_list.append(pooled_prompt_embeds)
-            negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds)
-        prompt_embeds_list = torch.stack(prompt_embeds_list,dim=1)
-        negative_prompt_embeds_list = torch.stack(negative_prompt_embeds_list,dim=1)
-        pooled_prompt_embeds_list = torch.stack(pooled_prompt_embeds_list,dim=1)
-        negative_pooled_prompt_embeds_list = torch.stack(negative_pooled_prompt_embeds_list,dim=1)
-        print(f'prompt_embeds_list={prompt_embeds_list.shape}')
+
 
         # 3.2 Encode ip_adapter_image
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
@@ -1529,13 +1531,13 @@ def __call__(
 
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0)
+            if isinstance(prompt, list): # modifed
+                prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
 
         prompt_embeds = prompt_embeds.to(device)
-        prompt_embeds_list = prompt_embeds_list.to(device)
+        if isinstance(prompt, list): # modifed
+            prompt_embeds_list = prompt_embeds_list.to(device)
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
 
@@ -1602,7 +1625,7 @@ def __call__(
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds_list, #prompt_embeds, modified for multiple text promts
+                    encoder_hidden_states=prompt_embeds_list if isinstance(prompt, list) else prompt_embeds, #prompt_embeds, modified for multiple text promts
                     timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
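What the loop above builds is a fourth dimension over prompts: each `encode_prompt` call yields a `[batch, tokens, dim]` tensor, and stacking along dim=1 gives `[batch, num_prompts, tokens, dim]`, which is the 4-D `encoder_hidden_states` the next patch teaches the attention processor to consume. In isolation (the token count and width below assume SDXL's 77-token, 2048-wide encoder states):

import torch

# Stand-ins for per-prompt encoder states from encode_prompt (batch=2)
per_prompt = [torch.randn(2, 77, 2048) for _ in range(3)]

stacked = torch.stack(per_prompt, dim=1)
print(stacked.shape)  # torch.Size([2, 3, 77, 2048])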
From 3f5f96f57b46acff59b90e6d95511d76240600ee Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sun, 12 Jan 2025 14:26:47 +0900
Subject: [PATCH 046/482] modified attentions

---
 src/diffusers/models/attention_processor.py | 156 +++++++++++++-----
 .../models/unets/unet_2d_condition.py       |  31 +++-
 .../controlnet/pipeline_controlnet_sd_xl.py |  24 +--
 3 files changed, 155 insertions(+), 56 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 361fa33060f0..5d10662c0f8e 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -3238,6 +3238,7 @@ def __call__(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
+        text_masks: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
     ) -> torch.Tensor:
             deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
+        if (encoder_hidden_states is not None) and (encoder_hidden_states.ndim == 4):
+            hidden_states_list = []
+            if not text_masks.shape[0] == encoder_hidden_states.shape[1]:
+                raise ValueError(
+                    f"Length of text masks ({text_masks.shape[0]}) must match "
+                    f"the number of text prompts "
+                    f"({encoder_hidden_states.shape[1]})")
 
+            for index in range(encoder_hidden_states.shape[1]):
+                _hidden_states = hidden_states
+
+                if attn.spatial_norm is not None:
+                    _hidden_states = attn.spatial_norm(_hidden_states, temb)
 
-        input_ndim = hidden_states.ndim
+                input_ndim = _hidden_states.ndim
 
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+                if input_ndim == 4:
+                    batch_size, channel, height, width = _hidden_states.shape
+                    _hidden_states = _hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
 
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
+                batch_size, sequence_length, _ = (
+                    _hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states[:,index,:,:].shape
+                )
 
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+                if attention_mask is not None:
+                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+                    # scaled_dot_product_attention expects attention_mask shape to be
+                    # (batch, heads, source_length, target_length)
+                    attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+                if attn.group_norm is not None:
+                    _hidden_states = attn.group_norm(_hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states)
+                query = attn.to_q(_hidden_states)
 
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+                _encoder_hidden_states = encoder_hidden_states[:,index,:,:]
+                if attn.norm_cross:
+                    _encoder_hidden_states = attn.norm_encoder_hidden_states(_encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+                key = attn.to_k(_encoder_hidden_states)
+                value = attn.to_v(_encoder_hidden_states)
 
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
+                inner_dim = key.shape[-1]
+                head_dim = inner_dim // attn.heads
 
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+                if attn.norm_q is not None:
+                    query = attn.norm_q(query)
+                if attn.norm_k is not None:
+                    key = attn.norm_k(key)
 
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                _hidden_states = F.scaled_dot_product_attention(
+                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+                )
 
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+                _hidden_states = _hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                _hidden_states = _hidden_states.to(query.dtype)
+
+                mask_downsample = IPAdapterMaskProcessor.downsample(
+                    text_masks[index,:,:,:],
+                    batch_size,
+                    _hidden_states.shape[1],
+                    _hidden_states.shape[2],
+                )
+
+                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                hidden_states_list.append(_hidden_states * mask_downsample)
+
+            hidden_states_list = torch.stack(hidden_states_list)
+            hidden_states = torch.sum(hidden_states_list, dim=0, keepdim=False)
+        else:
+            if attn.spatial_norm is not None:
+                hidden_states = attn.spatial_norm(hidden_states, temb)
+
+            input_ndim = hidden_states.ndim
+
+            if input_ndim == 4:
+                batch_size, channel, height, width = hidden_states.shape
+                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+            batch_size, sequence_length, _ = (
+                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+            )
+
+            if attention_mask is not None:
+                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+                # scaled_dot_product_attention expects attention_mask shape to be
+                # (batch, heads, source_length, target_length)
+                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+            if attn.group_norm is not None:
+                hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+            query = attn.to_q(hidden_states)
+
+            if encoder_hidden_states is None:
+                encoder_hidden_states = hidden_states
+            elif attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+            inner_dim = key.shape[-1]
+            head_dim = inner_dim // attn.heads
+
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            if attn.norm_q is not None:
+                query = attn.norm_q(query)
+            if attn.norm_k is not None:
+                key = attn.norm_k(key)
+
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 5674d8ba26ec..e4b8434bd0b6 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -1147,9 +1147,21 @@ def forward(
         else:
             emb = emb + class_emb
 
-        aug_emb = self.get_aug_embed(
-            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
-        )
+        # thesea modifed
+        if len(encoder_hidden_states.shape) == 3:
+            aug_emb = self.get_aug_embed(
+                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+            )
+        else:
+            aug_emb_list = []
+            for index in range(encoder_hidden_states.shape[1]):
+                aug_emb = self.get_aug_embed(
+                    emb=emb, encoder_hidden_states=encoder_hidden_states[:,index,:,:], added_cond_kwargs=added_cond_kwargs
+                )
+                aug_emb_list.append(aug_emb)
+            aug_emb_list = torch.stack(aug_emb_list)
+            aug_emb = torch.mean(aug_emb_list, dim=0, keepdim=False)
+
         if self.config.addition_embed_type == "image_hint":
             aug_emb, hint = aug_emb
             sample = torch.cat([sample, hint], dim=1)
@@ -1159,9 +1171,16 @@ def forward(
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
 
-        encoder_hidden_states = self.process_encoder_hidden_states(
-            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
-        )
+        # thesea modifed
+        if len(encoder_hidden_states.shape) == 3:
+            encoder_hidden_states = self.process_encoder_hidden_states(
+                encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+            )
+        else:
+            for index in range(encoder_hidden_states.shape[1]):
+                encoder_hidden_states[:,index,:,:] = self.process_encoder_hidden_states(
+                    encoder_hidden_states=encoder_hidden_states[:,index,:,:], added_cond_kwargs=added_cond_kwargs
+                )
 
         # 2. pre-process
         sample = self.conv_in(sample)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 33c5ab233044..0bffe8488858 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -763,12 +763,14 @@ def check_inputs(
 
         # `prompt` needs more sophisticated handling when there are multiple
         # conditionings.
-        if isinstance(self.controlnet, MultiControlNetModel):
-            if isinstance(prompt, list):
-                logger.warning(
-                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
-                    " prompts. The conditionings will be fixed across the prompts."
-                )
+
+        ## thesea modified
+        #if isinstance(self.controlnet, MultiControlNetModel):
+        #    if isinstance(prompt, list):
+        #        logger.warning(
+        #            f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+        #            " prompts. The conditionings will be fixed across the prompts."
+        #            )
 
         # Check `image`
         is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
@@ -1331,7 +1331,7 @@ def __call__(
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
-            batch_size = 1 #len(prompt), modified
+            batch_size = 1 #len(prompt), thesea modified
         else:
             batch_size = prompt_embeds.shape[0]
 
@@ -1352,7 +1352,7 @@ def __call__(
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
 
-        ## added to store multiple prompt embeds
+        ## thesea modified
         prompt_embeds_list = []
         negative_prompt_embeds_list = []
         pooled_prompt_embeds_list = []
@@ -1529,13 +1531,13 @@ def __call__(
 
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            if isinstance(prompt, list): # modifed
+            if isinstance(prompt, list): # modified
                 prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
 
         prompt_embeds = prompt_embeds.to(device)
-        if isinstance(prompt, list): # modifed
+        if isinstance(prompt, list): # modified
             prompt_embeds_list = prompt_embeds_list.to(device)
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
@@ -1625,7 +1627,7 @@ def __call__(
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds_list if isinstance(prompt, list) else prompt_embeds, #prompt_embeds, modified for multiple text promts
+                    encoder_hidden_states=prompt_embeds_list if isinstance(prompt, list) else prompt_embeds, #prompt_embeds, thesea modified for multiple text promts
                     timestep_cond=timestep_cond,
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,

From 6f7ef1e5c2e1ba00c37ab6672a8ac775c9b5ec1e Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sun, 12 Jan 2025 14:53:34 +0900
Subject: [PATCH 047/482] print mask

---
 src/diffusers/models/attention_processor.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 5d10662c0f8e..b92a371f0a13 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -3311,6 +3311,7 @@ def __call__(
                 _hidden_states = _hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                 _hidden_states = _hidden_states.to(query.dtype)
 
+                print(f'mask={text_masks[index,:,:,:]}')
                 mask_downsample = IPAdapterMaskProcessor.downsample(
                     text_masks[index,:,:,:],
                     batch_size,
@@ -5298,6 +5299,7 @@ def __call__(
                         )
                         _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
 
+                        print(f'mask={mask[:, i, :, :]}')
                         mask_downsample = IPAdapterMaskProcessor.downsample(
                             mask[:, i, :, :],
                             batch_size,

From 76a705ca84b29ea62e2c0417404e12549197a866 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sun, 12 Jan 2025 15:23:37 +0900
Subject: [PATCH 048/482] modification to ctln

---
 src/diffusers/models/attention_processor.py | 56 ++++++++++---------
 .../controlnet/pipeline_controlnet_sd_xl.py |  3 +-
 2 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index b92a371f0a13..ffe840d44e7a 100644
--- a/src/diffusers/models/attention_processor.py
+++
b/src/diffusers/models/attention_processor.py @@ -14,6 +14,7 @@ import inspect import math from typing import Callable, List, Optional, Tuple, Union +import copy import torch import torch.nn.functional as F @@ -582,7 +583,7 @@ def forward( # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] @@ -3256,31 +3257,32 @@ def __call__( f"({encoder_hidden_states.shape[1]})") for index in range(encoder_hidden_states.shape[1]): - _hidden_states = hidden_states - - if attn.spatial_norm is not None: - _hidden_states = attn.spatial_norm(_hidden_states, temb) + if index == 0: + _hidden_states = copy.deepcopy(hidden_states) + + if attn.spatial_norm is not None: + _hidden_states = attn.spatial_norm(_hidden_states, temb) - input_ndim = _hidden_states.ndim + input_ndim = _hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = _hidden_states.shape - _hidden_states = _hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + if input_ndim == 4: + batch_size, channel, height, width = _hidden_states.shape + _hidden_states = _hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - batch_size, sequence_length, _ = ( - _hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states[:,index,:,:].shape - ) + batch_size, sequence_length, _ = ( + _hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states[:,index,:,:].shape + ) - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - if attn.group_norm is not None: - _hidden_states = attn.group_norm(_hidden_states.transpose(1, 2)).transpose(1, 2) + if attn.group_norm is not None: + _hidden_states = attn.group_norm(_hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(_hidden_states) + query = attn.to_q(_hidden_states) _encoder_hidden_states = encoder_hidden_states[:,index,:,:] if attn.norm_cross: @@ -3289,16 +3291,18 @@ def __call__( key = attn.to_k(_encoder_hidden_states) value = attn.to_v(_encoder_hidden_states) - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads + if index == 0: + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.norm_q is not None: + query = attn.norm_q(query) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - if attn.norm_q is not None: - query = attn.norm_q(query) if 
attn.norm_k is not None: key = attn.norm_k(key) @@ -3311,7 +3315,6 @@ def __call__( _hidden_states = _hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) _hidden_states = _hidden_states.to(query.dtype) - print(f'mask={text_masks[index,:,:,:]}') mask_downsample = IPAdapterMaskProcessor.downsample( text_masks[index,:,:,:], batch_size, @@ -5299,7 +5302,6 @@ def __call__( ) _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) - print(f'mask={mask[:, i, :, :]}') mask_downsample = IPAdapterMaskProcessor.downsample( mask[:, i, :, :], batch_size, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 0bffe8488858..42d845728d46 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1590,8 +1590,9 @@ def __call__( "time_ids": add_time_ids.chunk(2)[1], } else: + print(f'prompt_embeds_list in ctln') control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds + controlnet_prompt_embeds = prompt_embeds_list if isinstance(prompt, list) else prompt_embeds, #prompt_embeds, thesea modified controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): From f7785e289e83d865745a39eaeae2998c5a526f1f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 12 Jan 2025 15:33:09 +0900 Subject: [PATCH 049/482] fix bug --- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 42d845728d46..31e9801541bb 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1592,7 +1592,7 @@ def __call__( else: print(f'prompt_embeds_list in ctln') control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds_list if isinstance(prompt, list) else prompt_embeds, #prompt_embeds, thesea modified + controlnet_prompt_embeds = prompt_embeds_list if isinstance(prompt, list) else prompt_embeds #prompt_embeds, thesea modified controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): From a9bf1f72253c914eb1473f145881658272377948 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 12 Jan 2025 15:52:00 +0900 Subject: [PATCH 050/482] fix ctln --- src/diffusers/models/controlnets/controlnet.py | 11 ++++++++++- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 7a6ca886caed..2e0bf149b1ba 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -768,7 +768,16 @@ def forward( if self.config.addition_embed_type is not None: if self.config.addition_embed_type == "text": - aug_emb = self.add_embedding(encoder_hidden_states) + if len(encoder_hidden_states.shape) == 3: + aug_emb = self.add_embedding(encoder_hidden_states) + else: + aug_emb_list = [] + for index in range(encoder_hidden_states.shape[1]): + aug_emb = self.add_embedding(encoder_hidden_states[:,index,:,:]) + aug_emb_list.append(aug_emb) + + aug_emb_list = torch.stack(aug_emb_list) + aug_emb = torch.mean(aug_emb_list, 
dim=0, keepdim=False) elif self.config.addition_embed_type == "text_time": if "text_embeds" not in added_cond_kwargs: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 31e9801541bb..1884eb97a221 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1590,7 +1590,6 @@ def __call__( "time_ids": add_time_ids.chunk(2)[1], } else: - print(f'prompt_embeds_list in ctln') control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds_list if isinstance(prompt, list) else prompt_embeds #prompt_embeds, thesea modified controlnet_added_cond_kwargs = added_cond_kwargs @@ -1608,6 +1607,7 @@ def __call__( t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, + cross_attention_kwargs=self.cross_attention_kwargs, conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, From 42619d1a0cba316e6e746b354e5ff0972a39bd56 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 12 Jan 2025 16:16:37 +0900 Subject: [PATCH 051/482] fix bug --- src/diffusers/models/controlnets/controlnet.py | 3 ++- .../controlnet/pipeline_controlnet_sd_xl.py | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 2e0bf149b1ba..5911e21fbb2b 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -768,6 +768,7 @@ def forward( if self.config.addition_embed_type is not None: if self.config.addition_embed_type == "text": + # thesea modified if len(encoder_hidden_states.shape) == 3: aug_emb = self.add_embedding(encoder_hidden_states) else: @@ -775,7 +776,7 @@ def forward( for index in range(encoder_hidden_states.shape[1]): aug_emb = self.add_embedding(encoder_hidden_states[:,index,:,:]) aug_emb_list.append(aug_emb) - + aug_emb_list = torch.stack(aug_emb_list) aug_emb = torch.mean(aug_emb_list, dim=0, keepdim=False) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 1884eb97a221..ce958169525d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1504,7 +1504,12 @@ def __call__( original_size = original_size or image.shape[-2:] target_size = target_size or (height, width) - add_text_embeds = pooled_prompt_embeds + # thesea modified + if not isinstance(prompt, list): + add_text_embeds = pooled_prompt_embeds + else: + add_text_embeds_list = pooled_prompt_embeds_list + if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: @@ -1534,11 +1539,14 @@ def __call__( if isinstance(prompt, list): # modified prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + if isinstance(prompt, list): # modified + add_text_embeds_list = torch.cat([negative_pooled_prompt_embeds_list, add_text_embeds_list], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) if isinstance(prompt, list): # modified prompt_embeds_list = prompt_embeds_list.to(device) + add_text_embeds_list = 
add_text_embeds_list.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) @@ -1577,7 +1585,8 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + # thesea modified + added_cond_kwargs = {"text_embeds": add_text_embeds_list if isinstance(prompt, list) else add_text_embeds, "time_ids": add_time_ids} # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: @@ -1607,7 +1616,7 @@ def __call__( t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, - cross_attention_kwargs=self.cross_attention_kwargs, + cross_attention_kwargs=self.cross_attention_kwargs, # thesea modified conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, @@ -1646,6 +1655,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: + print(f'callback_on_step_end') callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] From 453f24c285e17acc2e2d481703f6a81100d9bf89 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 12 Jan 2025 16:23:42 +0900 Subject: [PATCH 052/482] fix bug --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index ce958169525d..9a4d3df2fed5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1538,16 +1538,19 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) if isinstance(prompt, list): # modified prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) if isinstance(prompt, list): # modified add_text_embeds_list = torch.cat([negative_pooled_prompt_embeds_list, add_text_embeds_list], dim=0) + else: + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) if isinstance(prompt, list): # modified prompt_embeds_list = prompt_embeds_list.to(device) add_text_embeds_list = add_text_embeds_list.to(device) - add_text_embeds = add_text_embeds.to(device) + else: + add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 8. 
Denoising loop From 36e97a726394f75415deb96128cf10fba0500fbe Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 12 Jan 2025 16:44:31 +0900 Subject: [PATCH 053/482] fix bug --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 9a4d3df2fed5..32d844b7b535 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1588,8 +1588,7 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # thesea modified - added_cond_kwargs = {"text_embeds": add_text_embeds_list if isinstance(prompt, list) else add_text_embeds, "time_ids": add_time_ids} + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: From db55f1344b3452be884066b5cb7f4b0a88c42d84 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 12 Jan 2025 16:51:50 +0900 Subject: [PATCH 054/482] fix bug --- .../controlnet/pipeline_controlnet_sd_xl.py | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 32d844b7b535..4db51e3739f0 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1504,12 +1504,8 @@ def __call__( original_size = original_size or image.shape[-2:] target_size = target_size or (height, width) - # thesea modified - if not isinstance(prompt, list): - add_text_embeds = pooled_prompt_embeds - else: - add_text_embeds_list = pooled_prompt_embeds_list - + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: @@ -1538,19 +1534,11 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) if isinstance(prompt, list): # modified prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0) - if isinstance(prompt, list): # modified - add_text_embeds_list = torch.cat([negative_pooled_prompt_embeds_list, add_text_embeds_list], dim=0) - else: - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) - if isinstance(prompt, list): # modified - prompt_embeds_list = prompt_embeds_list.to(device) - add_text_embeds_list = add_text_embeds_list.to(device) - else: - add_text_embeds = add_text_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 8. 
Denoising loop From fe60b1212db3169272bfc414b21c88d9a10f8ca5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 12 Jan 2025 17:14:22 +0900 Subject: [PATCH 055/482] print info --- src/diffusers/models/attention_processor.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ffe840d44e7a..e6be345ceb04 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3322,6 +3322,8 @@ def __call__( _hidden_states.shape[2], ) + print(f'mask_downsample shape={mask_downsample.shape}') + print(f'_hidden_states shape={_hidden_states.shape}') mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) hidden_states_list.append(_hidden_states * mask_downsample) @@ -5308,7 +5310,10 @@ def __call__( _current_ip_hidden_states.shape[1], _current_ip_hidden_states.shape[2], ) - + + print(f'mask_downsample shape={mask_downsample.shape}') + print(f'_current_ip_hidden_states={_current_ip_hidden_states.shape}') + print(f'scale[i]={scale[i]}') mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) else: From eb70a670a6d413b6320af68a4fc499c7b04c2cea Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 12 Jan 2025 17:32:44 +0900 Subject: [PATCH 056/482] remove print info --- src/diffusers/models/attention_processor.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e6be345ceb04..d7d07bd97bef 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3322,8 +3322,6 @@ def __call__( _hidden_states.shape[2], ) - print(f'mask_downsample shape={mask_downsample.shape}') - print(f'_hidden_states shape={_hidden_states.shape}') mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) hidden_states_list.append(_hidden_states * mask_downsample) @@ -5311,9 +5309,6 @@ def __call__( _current_ip_hidden_states.shape[2], ) - print(f'mask_downsample shape={mask_downsample.shape}') - print(f'_current_ip_hidden_states={_current_ip_hidden_states.shape}') - print(f'scale[i]={scale[i]}') mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) else: @@ -5507,7 +5502,6 @@ def __call__( if ip_adapter_masks is not None: # [scale] * mask.shape[1] for setting strength. 
disabled by default
             scale = ip_adapter_masks[0].shape[1]
-            print(f'text prompt strength = {scale}')
             mask_downsample = IPAdapterMaskProcessor.downsample(
                 ip_adapter_masks[0][:, i, :, :],
                 batch_size,

From 1f060925bc50fa7cd80c314efd7a14067f9ff964 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sun, 26 Jan 2025 18:37:03 +0900
Subject: [PATCH 057/482] add safety checker to sdxl img2img

---
 .../pipeline_controlnet_sd_xl_img2img.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 86588a5b3851..c597aaac1259 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -61,6 +61,7 @@

 if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker

+from .safety_checker import StableDiffusionSafetyChecker # # thesea modified for safety checker

 from ...utils import is_torch_xla_available
@@ -233,7 +234,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
         "text_encoder_2",
         "feature_extractor",
         "image_encoder",
+        "safety_checker",
     ]
+    _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
@@ -254,11 +257,13 @@ def __init__(
         unet: UNet2DConditionModel,
         controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
+        safety_checker: StableDiffusionSafetyChecker,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
         feature_extractor: CLIPImageProcessor = None,
         image_encoder: CLIPVisionModelWithProjection = None,
+        requires_safety_checker: bool = False,
     ):
         super().__init__()
@@ -274,6 +279,7 @@ def __init__(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
@@ -289,7 +295,7 @@ def __init__(
         else:
             self.watermark = None

-        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt, requires_safety_checker=requires_safety_checker)
         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)

     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
@@ -1673,7 +1679,9 @@ def __call__(
         # Offload all models
         self.maybe_free_model_hooks()

+        has_nsfw_concept = [False]*len(image)
         if not return_dict:
-            return (image,)
+            return (image,has_nsfw_concept)

-        return StableDiffusionXLPipelineOutput(images=image)
+        # thesea modified for safety checker output
+        return StableDiffusionXLPipelineOutput(images=image, nsfw_content_detected = has_nsfw_concept)

From d9b18bfcc41e4d23afaf283ca00bdb180f5ee7bb Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sun, 26 Jan 2025 19:08:58 +0900
Subject: [PATCH 058/482] add support for text prompt mask to sdxl img2img

---
 .../controlnet/pipeline_controlnet_sd_xl.py   | 14 +--
 .../pipeline_controlnet_sd_xl_img2img.py      | 95 +++++++++++++------
 2 files changed, 71 insertions(+), 38 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 4db51e3739f0..d31e4b90d87e
100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -764,7 +764,7 @@ def check_inputs( # `prompt` needs more sophisticated handling when there are multiple # conditionings. - ## thesea modified + # thesea modified for text prompt mask #if isinstance(self.controlnet, MultiControlNetModel): # if isinstance(prompt, list): # logger.warning( @@ -1331,7 +1331,7 @@ def __call__( if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): - batch_size = 1 #len(prompt), thesea modified + batch_size = 1 #len(prompt), thesea modified for text prompt mask else: batch_size = prompt_embeds.shape[0] @@ -1352,7 +1352,7 @@ def __call__( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - ## thesea modified + ## thesea modified for text prompt mask prompt_embeds_list = [] negative_prompt_embeds_list = [] pooled_prompt_embeds_list = [] @@ -1373,10 +1373,6 @@ def __call__( self.do_classifier_free_guidance, negative_prompt, negative_prompt_2, - #prompt_embeds=prompt_embeds, - #negative_prompt_embeds=negative_prompt_embeds, - #pooled_prompt_embeds=pooled_prompt_embeds, - #negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) @@ -1590,7 +1586,7 @@ def __call__( } else: control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds_list if isinstance(prompt, list) else prompt_embeds #prompt_embeds, thesea modified + controlnet_prompt_embeds = prompt_embeds_list if isinstance(prompt, list) else prompt_embeds #prompt_embeds, thesea modified for text prompt mask controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): @@ -1606,7 +1602,7 @@ def __call__( t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, - cross_attention_kwargs=self.cross_attention_kwargs, # thesea modified + cross_attention_kwargs=self.cross_attention_kwargs, # thesea modified for text prompt mask conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index c597aaac1259..a170a0ef059a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -720,12 +720,14 @@ def check_inputs( # `prompt` needs more sophisticated handling when there are multiple # conditionings. - if isinstance(self.controlnet, MultiControlNetModel): - if isinstance(prompt, list): - logger.warning( - f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" - " prompts. The conditionings will be fixed across the prompts." - ) + + # thesea modified for text prompt mask + #if isinstance(self.controlnet, MultiControlNetModel): + # if isinstance(prompt, list): + # logger.warning( + # f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + # " prompts. The conditionings will be fixed across the prompts." 
+ # ) # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( @@ -1366,7 +1368,7 @@ def __call__( if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) + batch_size = 1 #len(prompt), thesea modified for text prompt mask else: batch_size = prompt_embeds.shape[0] @@ -1386,26 +1388,60 @@ def __call__( text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt, - prompt_2, - device, - num_images_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=self.clip_skip, - ) + + ## thesea modified for text prompt mask + prompt_embeds_list = [] + negative_prompt_embeds_list = [] + pooled_prompt_embeds_list = [] + negative_pooled_prompt_embeds_list = [] + + if isinstance(prompt, list): + for pmt in prompt: + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + pmt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + prompt_embeds_list.append(prompt_embeds) + negative_prompt_embeds_list.append(negative_prompt_embeds) + pooled_prompt_embeds_list.append(pooled_prompt_embeds) + negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds) + prompt_embeds_list = torch.stack(prompt_embeds_list,dim=1) + negative_prompt_embeds_list = torch.stack(negative_prompt_embeds_list,dim=1) + pooled_prompt_embeds_list = torch.stack(pooled_prompt_embeds_list,dim=1) + negative_pooled_prompt_embeds_list = torch.stack(negative_pooled_prompt_embeds_list,dim=1) + else: + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) # 3.2 Encode ip_adapter_image if ip_adapter_image is not None or ip_adapter_image_embeds is not None: @@ -1554,7 +1590,7 @@ def __call__( } else: control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds + controlnet_prompt_embeds = prompt_embeds_list if isinstance(prompt, list) else prompt_embeds #prompt_embeds, thesea modified for text prompt mask controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): @@ -1569,6 +1605,7 @@ def __call__( t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=control_image, + cross_attention_kwargs=self.cross_attention_kwargs, # thesea modified for text prompt mask conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, @@ -1589,7 +1626,7 @@ def __call__( noise_pred = self.unet( 
latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=prompt_embeds_list if isinstance(prompt, list) else prompt_embeds, #prompt_embeds, thesea modified for multiple text prompts
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,

From 6600a47e9a1f9510d3653e2e00c811245721fb50 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sun, 26 Jan 2025 21:16:48 +0900
Subject: [PATCH 059/482] add modifications comments

---
 src/diffusers/models/controlnets/controlnet.py                  | 2 +-
 src/diffusers/models/unets/unet_2d_condition.py                 | 2 +-
 src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +-
 .../pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py   | 2 ++
 4 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py
index 5911e21fbb2b..06204f61e625 100644
--- a/src/diffusers/models/controlnets/controlnet.py
+++ b/src/diffusers/models/controlnets/controlnet.py
@@ -768,7 +768,7 @@ def forward(

         if self.config.addition_embed_type is not None:
             if self.config.addition_embed_type == "text":
-                # thesea modified
+                # thesea modified for text prompt mask
                 if len(encoder_hidden_states.shape) == 3:
                     aug_emb = self.add_embedding(encoder_hidden_states)
                 else:
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index e4b8434bd0b6..50bb679c9a16 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -1171,7 +1171,7 @@ def forward(
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)

-        # thesea modifed
+        # thesea modified for text prompt mask
         if len(encoder_hidden_states.shape) == 3:
             encoder_hidden_states = self.process_encoder_hidden_states(
                 encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index d31e4b90d87e..1367d3155ee9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1528,7 +1528,7 @@ def __call__(

         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            if isinstance(prompt, list): # modified
+            if isinstance(prompt, list): # thesea modified for text prompt mask
                 prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index a170a0ef059a..282c0c32ac29 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -1557,6 +1557,8 @@ def __call__(

         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            if isinstance(prompt, list): # thesea modified for text prompt mask
+                prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
             add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

From 51313c13597ace2589162f088582c586e7f14db9 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 27 Jan 2025 12:29:58 +0900
Subject: [PATCH 060/482] fix Cannot copy out of meta tensor; no data!

---
 src/diffusers/loaders/unet.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 5d0bfa5a47ad..088b9a60d3e8 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -781,6 +781,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         if not low_cpu_mem_usage:
             if state_dict is not None:
                 image_projection.load_state_dict(updated_state_dict, strict=True)
+            else:
+                image_projection = image_projection.to_empty(device=self.device)
         else:
             load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

From f604787447cde4356c0c1e1eef3bc15d0aaea8b9 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Thu, 16 Jan 2025 13:14:58 +0900
Subject: [PATCH 061/482] modified self.encoder_hid_proj

---
 src/diffusers/loaders/unet.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 088b9a60d3e8..14a3cf3d15f7 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -913,7 +913,14 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
             )
             image_projection_layers.append(image_projection_layer)

-        self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+        # diffuser original
+        #self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+
+        # thesea modified
+        self.encoder_hid_proj = MultiIPAdapterImageProjection.float().eval()
+        self.encoder_hid_proj = self.encoder_hid_proj.to_empty(device=self.device)
+        #self.encoder_hid_proj.load_state_dict(image_projection_layers, map_location=self.device)
+
         self.config.encoder_hid_dim_type = "ip_image_proj"

         self.to(dtype=self.dtype, device=self.device)

From 3af2a65133ad34cef7c56462ef2405cb647c9f5f Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Thu, 16 Jan 2025 13:16:01 +0900
Subject: [PATCH 062/482] fix typo

---
 src/diffusers/loaders/unet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 14a3cf3d15f7..8506b5519c3e 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -917,7 +917,7 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
         #self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)

         # thesea modified
-        self.encoder_hid_proj = MultiIPAdapterImageProjection.float().eval()
+        self.encoder_hid_proj = MultiIPAdapterImageProjection().float().eval()
         #self.encoder_hid_proj.load_state_dict(image_projection_layers, map_location=self.device)

From d8e9c83f0cae650bf0d24c7a4aaead27384505bb Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Thu, 16 Jan 2025 13:24:10 +0900
Subject: [PATCH 063/482] revert back to original diffuser _load_ip_adapter_weights

---
 src/diffusers/loaders/unet.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 8506b5519c3e..c73b81697e79 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -907,20 +907,14 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):

         # convert IP-Adapter Image Projection layers to diffusers
         image_projection_layers = []
+
         for state_dict in state_dicts:
             image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
                 state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
             )
             image_projection_layers.append(image_projection_layer)

-        # diffuser original
-        #self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
-
-        # thesea modified
-        self.encoder_hid_proj = MultiIPAdapterImageProjection().float().eval()
-        self.encoder_hid_proj = self.encoder_hid_proj.to_empty(device=self.device)
-        #self.encoder_hid_proj.load_state_dict(image_projection_layers, map_location=self.device)
-
+        self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
         self.config.encoder_hid_dim_type = "ip_image_proj"

         self.to(dtype=self.dtype, device=self.device)

From b539c13b373412738cb0c33299f3e9a3e02493a4 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Fri, 17 Jan 2025 10:19:24 +0900
Subject: [PATCH 064/482] fix bug

---
 src/diffusers/loaders/unet.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index c73b81697e79..6c8808a9e0d6 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -779,6 +779,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
             updated_state_dict[diffusers_name] = value

         if not low_cpu_mem_usage:
+            print(f'not low_cpu_mem_usage')
             if state_dict is not None:
                 image_projection.load_state_dict(updated_state_dict, strict=True)
             else:

From 8eaaee18229a764a080289c40209213afba8df72 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Fri, 17 Jan 2025 11:51:49 +0900
Subject: [PATCH 065/482] add comments to modifications

---
 src/diffusers/loaders/unet.py      | 21 +++++----------------
 src/diffusers/models/__init__.py   |  4 +++-
 src/diffusers/models/embeddings.py |  2 +-
 3 files changed, 9 insertions(+), 18 deletions(-)

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 6c8808a9e0d6..dc13474a9ab4 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -23,7 +23,7 @@
 from huggingface_hub.utils import validate_hf_hub_args

 from ..models.embeddings import (
-    ImageProjectionCustomized,
+    ImageProjectionCustomized, # thesea modified for human element model
     ImageProjection,
     IPAdapterFaceIDImageProjection,
     IPAdapterFaceIDPlusImageProjection,
@@ -565,6 +565,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         image_projection = None
         init_context = init_empty_weights if low_cpu_mem_usage else nullcontext

+        # thesea modified for human element model
         if state_dict is None:
             # IP-Adapter-Customized
             num_image_text_embeds = 1 # fixed
             clip_embeddings_dim = 1024 # fixed by CLIP model
             cross_attention_dim = 2048 # fixed by sdxl two text encoders
@@ -577,20 +578,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
                 image_embed_dim=clip_embeddings_dim,
                 num_image_text_embeds=num_image_text_embeds,
             )
-
-        elif "proj.weight" in state_dict:
-            # IP-Adapter
-            num_image_text_embeds = 4
-            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
-            cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
-
-            with init_context():
-                image_projection = ImageProjectionCustomized(
-                    cross_attention_dim=cross_attention_dim,
-                    image_embed_dim=clip_embeddings_dim,
-
num_image_text_embeds=num_image_text_embeds, - ) - + # thesea modified for human element model else: if "proj.weight" in state_dict: # IP-Adapter @@ -779,7 +767,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us updated_state_dict[diffusers_name] = value if not low_cpu_mem_usage: - print(f'not low_cpu_mem_usage') + # thesea modified for human element model if state_dict is not None: image_projection.load_state_dict(updated_state_dict, strict=True) else: @@ -844,6 +832,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F ) num_image_text_embeds = [] for state_dict in state_dicts: + # thesea modified for human element model if state_dict["image_proj"] is None: num_image_text_embeds += [1] else: diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 0cf108dec292..00ec4ad9ac09 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -52,7 +52,7 @@ _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"] - _import_structure["embeddings"] = ["ImageProjection", "ImageProjectionCustomized"] + _import_structure["embeddings"] = ["ImageProjection", "ImageProjectionCustomized"] # thesea modified for human element model _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] @@ -132,6 +132,8 @@ UNetControlNetXSModel, ) + from .embeddings import ImageProjection, ImageProjectionCustomized # thesea modified for human element model + from .embeddings import ImageProjection, ImageProjectionCustomized from .modeling_utils import ModelMixin from .transformers import ( diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 53c5403fe4dc..c6dd64e767be 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1518,7 +1518,7 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): return torch.cat([image_text_embeds, text_embeds], dim=1) - +# thesea modified for human element model class ImageProjectionCustomized(nn.Module): def __init__( self, From a22dde33c25a456a1340f01aa78fa0c5508bf3d6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 17 Jan 2025 11:59:24 +0900 Subject: [PATCH 066/482] remove CustomIPAdapterAttnProcessor2_0 --- src/diffusers/models/attention_processor.py | 245 -------------------- 1 file changed, 245 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d7d07bd97bef..37d466ec458f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5346,251 +5346,6 @@ def __call__( return hidden_states -class CustomIPAdapterAttnProcessor2_0(torch.nn.Module): - r""" - Custom Attention processor for Both Text and IP-Adapter Masks for PyTorch 2.0. - - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): - The context length of the image features. 
- scale (`float` or `List[float]`, defaults to 1.0): - the weight scale of image prompt. - """ - - def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - - if not isinstance(num_tokens, (tuple, list)): - num_tokens = [num_tokens] - self.num_tokens = num_tokens - - if not isinstance(scale, list): - scale = [scale] * len(num_tokens) - if len(scale) != len(num_tokens): - raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") - self.scale = scale - - self.to_k_ip = nn.ModuleList( - [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] - ) - self.to_v_ip = nn.ModuleList( - [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - scale: float = 1.0, - ip_adapter_masks: Optional[torch.Tensor] = None, - ): - residual = hidden_states - - # separate ip_hidden_states from encoder_hidden_states - if encoder_hidden_states is not None: - if isinstance(encoder_hidden_states, tuple): - encoder_hidden_states, ip_hidden_states = encoder_hidden_states - else: - deprecation_message = ( - "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." - " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." 
- ) - deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) - end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - [encoder_hidden_states[:, end_pos:, :]], - ) - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - if ip_adapter_masks is not None: - if not isinstance(ip_adapter_masks, List): - # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] - ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) - if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): - raise ValueError( - f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " - f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " - f"({len(ip_hidden_states)})" - ) - else: - for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): - if mask is None: - continue - if not isinstance(mask, torch.Tensor) or mask.ndim != 4: - raise ValueError( - "Each element of the ip_adapter_masks array should be a tensor with shape " - "[1, num_images_for_ip_adapter, height, width]." 
- " Please use `IPAdapterMaskProcessor` to preprocess your mask" - ) - if mask.shape[1] != ip_state.shape[1]: - raise ValueError( - f"Number of masks ({mask.shape[1]}) does not match " - f"number of ip images ({ip_state.shape[1]}) at index {index}" - ) - if isinstance(scale, list) and not len(scale) == mask.shape[1]: - raise ValueError( - f"Number of masks ({mask.shape[1]}) does not match " - f"number of scales ({len(scale)}) at index {index}" - ) - else: - ip_adapter_masks = [None] * len(self.scale) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if ip_adapter_masks is not None: - # [scale] * mask.shape[1] for setting strength. disabled by default - scale = ip_adapter_masks[0].shape[1] - mask_downsample = IPAdapterMaskProcessor.downsample( - ip_adapter_masks[0][:, i, :, :], - batch_size, - hidden_states.shape[1], - hidden_states.shape[2], - ) - - mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states = scale * (hidden_states * mask_downsample) - - # for ip-adapter - for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks - ): - skip = False - if isinstance(scale, list): - if all(s == 0 for s in scale): - skip = True - elif scale == 0: - skip = True - if not skip: - if mask is not None: - if not isinstance(scale, list): - scale = [scale] * mask.shape[1] - - current_num_images = mask.shape[1] - for i in range(current_num_images): - ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) - ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - _current_ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) - - mask_downsample = IPAdapterMaskProcessor.downsample( - mask[:, i, :, :], - batch_size, - _current_ip_hidden_states.shape[1], - _current_ip_hidden_states.shape[2], - ) - - mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) - else: - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, 
head_dim).transpose(1, 2)
-
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
-
-                hidden_states = hidden_states + scale * current_ip_hidden_states
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
 class IPAdapterXFormersAttnProcessor(torch.nn.Module):
     r"""
     Attention processor for IP-Adapter using xFormers.

From 702fb0ac907743fcfdcbee8012997ae4afb9eb97 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Fri, 17 Jan 2025 12:06:10 +0900
Subject: [PATCH 067/482] add comments for modifications

---
 .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 5 ++++-
 .../pipelines/stable_diffusion_xl/pipeline_output.py  | 4 +++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 1367d3155ee9..350d5b662de1 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -78,7 +78,7 @@
     XLA_AVAILABLE = False

 from .multicontrolnet import MultiControlNetModel
-from .safety_checker import StableDiffusionSafetyChecker
+from .safety_checker import StableDiffusionSafetyChecker  # # thesea modified for safety checker


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -645,6 +645,7 @@ def prepare_ip_adapter_image_embeds(

         return ip_adapter_image_embeds

+    # thesea modified for safety checker
     def run_safety_checker(self, image, device, dtype):
         image_original = copy.deepcopy(image)
         if self.safety_checker is None:
@@ -1692,6 +1693,7 @@ def __call__(
                 latents = latents / self.vae.config.scaling_factor

             image = self.vae.decode(latents, return_dict=False)[0]
+            # thesea modified for safety checker
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
             # cast back to fp16 if needed
             if needs_upcasting:
@@ -1705,6 +1707,7 @@ def __call__(
         if self.watermark is not None:
             image = self.watermark.apply_watermark(image)

+        # thesea modified for safety checker
         if has_nsfw_concept is None:
             do_denormalize = [True] * image.shape[0]
         else:

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py
index f06a1e222117..bf9cc68fb6f4 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import List, Optional, Union  # thesea modified for safety checker

 import numpy as np
 import PIL.Image
@@ -21,6 +21,7 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
     """

     images: Union[List[PIL.Image.Image], np.ndarray]
+    # thesea modified for safety
checker
     nsfw_content_detected: Optional[List[bool]]

@@ -41,4 +42,5 @@ class FlaxStableDiffusionXLPipelineOutput(BaseOutput):
     """

     images: np.ndarray
+    # thesea modified for safety checker
     nsfw_content_detected: Optional[List[bool]]

From 5cb39cd1b111e13600df2532d54d9375efb90e55 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sat, 18 Jan 2025 11:29:20 +0900
Subject: [PATCH 068/482] add safety_checker to img2img pipeline

---
 .../pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 282c0c32ac29..5b8a6e5c68e1 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -61,7 +61,8 @@
 if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker

-from .safety_checker import StableDiffusionSafetyChecker  # # thesea modified for safety checker
+from .safety_checker import StableDiffusionSafetyChecker  # thesea modified for safety checker
+
 from ...utils import is_torch_xla_available


@@ -73,6 +74,7 @@
 else:
     XLA_AVAILABLE = False

+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

From 036a97adb90fc02365236de760f18228d1638bca Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sat, 18 Jan 2025 11:43:12 +0900
Subject: [PATCH 069/482] modified for safety checker output

---
 .../pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 5b8a6e5c68e1..4cf300ff6c75 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -1726,3 +1726,4 @@ def __call__(

         # thesea modified for safety checker output
         return StableDiffusionXLPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+

From 785b974d956d1760935b246f33215c27dd5b23b1 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 21 Jan 2025 15:17:35 +0900
Subject: [PATCH 070/482] add global_step to logger in sdxl ctln pipeline

---
 .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 350d5b662de1..d5ec86b5cae3 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1560,6 +1560,8 @@ def __call__(
         is_unet_compiled = is_compiled_module(self.unet)
         is_controlnet_compiled = is_compiled_module(self.controlnet)
         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+
+        global_step = 0
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
@@ -1662,6 +1664,8 @@ def __call__(
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
+                    global_step += 1
+                    logger.info(f"global_step:{global_step}")
                 if callback is not None and i % callback_steps == 0:
                     step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents) From 9ac4b003d7bfea75b792d9e88c87c6298a6efbbf Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 21 Jan 2025 15:26:58 +0900 Subject: [PATCH 071/482] add comments to modifications --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index d5ec86b5cae3..adc1c3687004 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1560,7 +1560,7 @@ def __call__( is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") - + # thesea modified for inference steps info global_step = 0 with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1664,6 +1664,7 @@ def __call__( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + # thesea modified for inference steps info global_step += 1 logger.info(f"global_step:{global_step}") if callback is not None and i % callback_steps == 0: From 7dc138c849bebe6bc3b4c98cf4670121c233fa44 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Feb 2025 23:24:52 +0900 Subject: [PATCH 072/482] add modification comments --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 37d466ec458f..c254a4808d1f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3248,6 +3248,7 @@ def __call__( deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states + # thesea modified for text prompt mask if (encoder_hidden_states is not None) and (encoder_hidden_states.ndim == 4): hidden_states_list = [] if not text_masks.shape[0] == encoder_hidden_states.shape[1]: From 51f52b222b41961ceb0be51ab4b537ba83e5ab05 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 6 Feb 2025 11:15:25 +0900 Subject: [PATCH 073/482] modifications to sd35 ipadapter for color modality in multimodality model --- src/diffusers/loaders/ip_adapter.py | 4 +++- src/diffusers/loaders/transformer_sd3.py | 22 +++++++++++++++++----- src/diffusers/models/embeddings.py | 13 +++++++++++-- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 33144090cbc6..25b86ce45d04 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -676,6 +676,7 @@ def load_ip_adapter( weight_name: str = "ip-adapter.safetensors", subfolder: Optional[str] = None, image_encoder_folder: Optional[str] = "image_encoder", + color_modality: bool = False, # thesea modified for color modality in multimodality model **kwargs, ) -> None: """ @@ -820,7 +821,8 @@ def load_ip_adapter( ) # Load IP-Adapter into transformer - self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage) + # thesea modified for color modality in multimodality model + self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage, color_modality = color_modality) def set_ip_adapter_scale(self, scale: 
float) -> None:
        """

diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py
index ece17e6728fa..f9a6164fda24 100644
--- a/src/diffusers/loaders/transformer_sd3.py
+++ b/src/diffusers/loaders/transformer_sd3.py
@@ -25,7 +25,7 @@ class SD3Transformer2DLoadersMixin:

     """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
-
+
     def _convert_ip_adapter_attn_to_diffusers(
         self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
     ) -> Dict:
@@ -124,9 +124,19 @@ def _convert_ip_adapter_image_proj_to_diffusers(
             updated_state_dict[key] = value

         # Image projection parameters
-        embed_dim = updated_state_dict["proj_in.weight"].shape[1]
+        # thesea modified for color modality in multimodality model
+        color_modality = False
+        if "proj_in.weight" in updated_state_dict:
+            embed_dim = updated_state_dict["proj_in.weight"].shape[1]
+        else:
+            color_modality = True
+            embed_dim = 0
         output_dim = updated_state_dict["proj_out.weight"].shape[0]
-        hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
+        # thesea modified for color modality in multimodality model
+        if "proj_in.weight" in updated_state_dict:
+            hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
+        else:
+            hidden_dim = 0
         heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
         num_queries = updated_state_dict["latents"].shape[1]
         timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
@@ -140,6 +150,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(
             heads=heads,
             num_queries=num_queries,
             timestep_in_dim=timestep_in_dim,
+            color_modality=color_modality,  # thesea modified for color modality in multimodality model
         )

         if not low_cpu_mem_usage:
@@ -150,7 +161,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(

         return image_proj

-    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
+    # thesea modified for color modality in multimodality model
+    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT, color_modality: bool = False) -> None:
         """Sets IP-Adapter attention processors, image projection, and loads state_dict.
Args: @@ -167,4 +179,4 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage) self.set_attn_processor(attn_procs) - self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage) + self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage) \ No newline at end of file diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1e01833ff1b8..4e758773338c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2560,10 +2560,16 @@ def __init__( timestep_in_dim: int = 320, timestep_flip_sin_to_cos: bool = True, timestep_freq_shift: int = 0, + color_modality = False, # thesea modified for color modality in multimodality model ) -> None: super().__init__() self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5) - self.proj_in = nn.Linear(embed_dim, hidden_dim) + + # thesea modified for color modality in multimodality model + self.color_modality = color_modality + if not self.color_modality: + self.proj_in = nn.Linear(embed_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) self.norm_out = nn.LayerNorm(output_dim) self.layers = nn.ModuleList( @@ -2588,7 +2594,10 @@ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor latents = self.latents.repeat(x.size(0), 1, 1) - x = self.proj_in(x) + # thesea modified for color modality in multimodality model + if not self.color_modality: + x = self.proj_in(x) + x = x + timestep_emb[:, None] for block in self.layers: From 56df4ff651c8f419ac48c9ef675ab89b47aa2e8b Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 26 Feb 2025 11:20:10 +0900 Subject: [PATCH 074/482] add color_modality = False code --- src/diffusers/loaders/transformer_sd3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py index f9a6164fda24..ae1ed9661187 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/transformer_sd3.py @@ -25,7 +25,7 @@ class SD3Transformer2DLoadersMixin: """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`.""" - + def _convert_ip_adapter_attn_to_diffusers( self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT ) -> Dict: @@ -179,4 +179,4 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage) self.set_attn_processor(attn_procs) - self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage) \ No newline at end of file + self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage) From 5906193919ae6057ce669f29219358a4aa4bd2fa Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 27 Feb 2025 18:03:16 +0900 Subject: [PATCH 075/482] add comments for modification --- src/diffusers/models/unets/unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 50bb679c9a16..8e5689cc6c2b 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ 
b/src/diffusers/models/unets/unet_2d_condition.py
@@ -1147,7 +1147,7 @@ def forward(
         else:
             emb = emb + class_emb

-        # thesea modified
+        # thesea modified for text prompt mask
         if len(encoder_hidden_states.shape) == 3:
             aug_emb = self.get_aug_embed(
                 emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs

From 8e6f39f13ca057136fe7fab50f8b2541f9e19d06 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Fri, 28 Feb 2025 12:33:39 +0900
Subject: [PATCH 076/482] modifications to sd35 ctln pipeline for text prompt mask

---
 .../controlnet/pipeline_controlnet_sd_xl.py   |  4 +-
 .../pipeline_stable_diffusion_3_controlnet.py | 97 +++++++++++++------
 2 files changed, 72 insertions(+), 29 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index adc1c3687004..5ed3847a6ef6 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -78,7 +78,7 @@
     XLA_AVAILABLE = False

 from .multicontrolnet import MultiControlNetModel
-from .safety_checker import StableDiffusionSafetyChecker  # # thesea modified for safety checker
+from .safety_checker import StableDiffusionSafetyChecker  # thesea modified for safety checker


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -1605,7 +1605,7 @@ def __call__(
                         t,
                         encoder_hidden_states=controlnet_prompt_embeds,
                         controlnet_cond=image,
-                        cross_attention_kwargs=self.cross_attention_kwargs,  # thesea modified for text prompt mask
+                        cross_attention_kwargs=self.cross_attention_kwargs,  # thesea modified for text prompt mask for accepting mask param
                         conditioning_scale=cond_scale,
                         guess_mode=guess_mode,
                         added_cond_kwargs=controlnet_added_cond_kwargs,

diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index 7f7acd882b59..90fc9fd92e11 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -1016,40 +1016,83 @@ def __call__(
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
+            batch_size = 1  # len(prompt), thesea modified for text prompt mask
         else:
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
         dtype = self.transformer.dtype

-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = self.encode_prompt(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            prompt_3=prompt_3,
-            negative_prompt=negative_prompt,
-            negative_prompt_2=negative_prompt_2,
-            negative_prompt_3=negative_prompt_3,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            device=device,
-            clip_skip=self.clip_skip,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-        )
+        # 3.1 Encode input prompt
+        ## thesea modified for text prompt mask
+        prompt_embeds_list = []
+        negative_prompt_embeds_list = []
+        pooled_prompt_embeds_list = []
+        negative_pooled_prompt_embeds_list = []
+
+        if isinstance(prompt, list):
+            for pmt in prompt:
+                (
+                    prompt_embeds,
+
negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + #prompt_embeds=prompt_embeds, + #negative_prompt_embeds=negative_prompt_embeds, + #pooled_prompt_embeds=pooled_prompt_embeds, + #negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + prompt_embeds_list.append(prompt_embeds) + negative_prompt_embeds_list.append(negative_prompt_embeds) + pooled_prompt_embeds_list.append(pooled_prompt_embeds) + negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds) + prompt_embeds_list = torch.stack(prompt_embeds_list,dim=1) + negative_prompt_embeds_list = torch.stack(negative_prompt_embeds_list,dim=1) + pooled_prompt_embeds_list = torch.stack(pooled_prompt_embeds_list,dim=1) + negative_pooled_prompt_embeds_list = torch.stack(negative_pooled_prompt_embeds_list,dim=1) + else: + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) - + if isinstance(prompt, list): # thesea modified for text prompt mask + prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0) + #TBD because SDXL does not use this: pooled_prompt_embeds_list = torch.cat([negative_pooled_prompt_embeds_list, pooled_prompt_embeds_list], dim=0) + # 3. 
Prepare control image
        if controlnet_config.force_zeros_for_pooled_projection:
            # instantx sd3 controlnet does not apply shift factor
@@ -1131,7 +1174,7 @@ def __call__(
         controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds

         if controlnet_config.joint_attention_dim is not None:
-            controlnet_encoder_hidden_states = prompt_embeds
+            controlnet_encoder_hidden_states = prompt_embeds_list if isinstance(prompt, list) else prompt_embeds  # prompt_embeds, thesea modified for text prompt mask
         else:
             # SD35 official 8b controlnet does not use encoder_hidden_states
             controlnet_encoder_hidden_states = None
@@ -1185,8 +1228,8 @@ def __call__(
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds_list if isinstance(prompt, list) else prompt_embeds,  # prompt_embeds, thesea modified for multiple text prompts
+                    pooled_projections=pooled_prompt_embeds,  # TBD: pooled_prompt_embeds_list if isinstance(prompt, list) else pooled_prompt_embeds
                     block_controlnet_hidden_states=control_block_samples,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,

From d824c2a2fb0d92ae20de751aa9d0ef5560cc4390 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sat, 1 Mar 2025 17:13:02 +0900
Subject: [PATCH 077/482] additional modifications for sd35 text prompt mask

---
 src/diffusers/models/attention.py             |  28 ++++-
 src/diffusers/models/attention_processor.py   | 111 ++++++++++++++----
 .../models/transformers/transformer_sd3.py    |   8 +-
 .../pipeline_stable_diffusion_3_controlnet.py |   1 +
 4 files changed, 118 insertions(+), 30 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 93b11c2b43f0..97a01f50f03d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -202,12 +202,32 @@ def forward(
         else:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

+        print(f'attention JointTransformerBlock self.context_pre_only={self.context_pre_only}')
+        print(f'attention encoder_hidden_states shape={encoder_hidden_states.shape}')
         if self.context_pre_only:
-            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+            # thesea modified for text prompt mask
+            if len(encoder_hidden_states.shape) == 3:
+                norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+            else:
+                norm_encoder_hidden_states = []
+                for index in range(encoder_hidden_states.shape[1]):
+                    norm_encoder_hidden_states.append(self.norm1_context(encoder_hidden_states[:,index,:,:], temb))
+                norm_encoder_hidden_states = torch.stack(norm_encoder_hidden_states, dim=1)
+
         else:
-            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
-                encoder_hidden_states, emb=temb
-            )
+            if len(encoder_hidden_states.shape) == 3:
+                norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+                    encoder_hidden_states, emb=temb
+                )
+            else:
+                norm_encoder_hidden_states = []
+                for index in range(encoder_hidden_states.shape[1]):
+                    tmp_norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+                        encoder_hidden_states[:,index,:,:], emb=temb
+                    )
+                    norm_encoder_hidden_states.append(tmp_norm_encoder_hidden_states)
+                norm_encoder_hidden_states = torch.stack(norm_encoder_hidden_states, dim=1)
+
         # Attention.
attn_output, context_attn_output = self.attn( diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1eaf58f22039..ed01bfa21f3c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3242,7 +3242,7 @@ def __call__( encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, - text_masks: Optional[torch.Tensor] = None, + text_masks: Optional[torch.Tensor] = None, # thesea modification for text prompt mask *args, **kwargs, ) -> torch.Tensor: @@ -5629,6 +5629,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, ip_hidden_states: torch.FloatTensor = None, temb: torch.FloatTensor = None, + text_masks: Optional[torch.Tensor] = None, # thesea modification for text prompt mask ) -> torch.FloatTensor: """ Perform the attention computation, integrating image features (if provided) and timestep embeddings. @@ -5677,33 +5678,93 @@ def __call__( key = attn.norm_k(key) # `context` projections. + # thesea modified for text prompt mask + print(f'attention processor encoder_hidden_states.ndim={encoder_hidden_states.ndim}') + print(f'attention processor encoder_hidden_states shape={encoder_hidden_states.shape}') if encoder_hidden_states is not None: - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + if encoder_hidden_states.ndim == 3: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + elif encoder_hidden_states.ndim == 4: + if not text_masks.shape[0] == encoder_hidden_states.shape[1]: + raise ValueError( + f"Length of text masks ({text_masks.shape[0]}) must match " + f"the number 
of text prompts " + f"({encoder_hidden_states.shape[1]})") + + queries = [] + keys = [] + values = [] + for index in range(encoder_hidden_states.shape[1]): + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,index,:,:]) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,index,:,:]) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,index,:,:]) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + tmp_query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + tmp_key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + tmp_value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + queries.append(tmp_query) + keys.append(tmp_key) + values.append(tmp_value) - query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) - key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) - value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + # thesea modified for text prompt mask + if (encoder_hidden_states is not None) and (encoder_hidden_states.ndim == 4): + hidden_states_list = [] + for mask, query, key, value in zip(text_masks, queries, keys, values): + tmp_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + tmp_hidden_states = tmp_hidden_states.to(query.dtype) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) + mask_downsample = IPAdapterMaskProcessor.downsample( + mask, + batch_size, + tmp_hidden_states.shape[1], + tmp_hidden_states.shape[2], + ) + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + hidden_states_list.append(tmp_hidden_states * mask_downsample) + + hidden_states_list = torch.stack(hidden_states_list) + hidden_states = torch.sum(hidden_states_list, dim=0, keepdim=False) + else: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: # Split the attention outputs. diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index e41fad220de6..886be5b24dd7 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -372,7 +372,13 @@ def forward( hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. 
temb = self.time_text_embed(timestep, pooled_projections)

-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+        # thesea modified for text prompt mask
+        print(f'transformer_sd3 encoder_hidden_states.shape={encoder_hidden_states.shape}')
+        if len(encoder_hidden_states.shape) == 3:
+            encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+        else:
+            for index in range(encoder_hidden_states.shape[1]):
+                encoder_hidden_states[:,index,:,:] = self.context_embedder(encoder_hidden_states[:,index,:,:])

         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")

diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index 90fc9fd92e11..82febc7aae3b 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -1225,6 +1225,7 @@ def __call__(
                     return_dict=False,
                 )[0]

+                print(f'pipeline prompt_embeds_list shape={prompt_embeds_list.shape}')
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,

From 2b154a9494bf04bc3ddaff9b03a4fc4740e474a5 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 4 Mar 2025 18:48:34 +0900
Subject: [PATCH 078/482] modifications to pipeline_stable_diffusion_3.py for text prompt mask

---
 .../pipeline_stable_diffusion_3_controlnet.py |  2 +-
 .../pipeline_stable_diffusion_3.py            | 94 ++++++++++++++-----
 2 files changed, 70 insertions(+), 26 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index 82febc7aae3b..7c7428635dea 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -1038,7 +1038,7 @@ def __call__(
                     pooled_prompt_embeds,
                     negative_pooled_prompt_embeds,
                 ) = self.encode_prompt(
-                    prompt=prompt,
+                    prompt=pmt,
                     prompt_2=prompt_2,
                     prompt_3=prompt_3,
                     negative_prompt=negative_prompt,

diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 4618d384cbd7..957dff2e004b 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -951,7 +951,7 @@ def __call__(
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
+            batch_size = 1  # len(prompt), thesea modified for text prompt mask
         else:
             batch_size = prompt_embeds.shape[0]

@@ -960,29 +960,69 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = self.encode_prompt(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            prompt_3=prompt_3,
-            negative_prompt=negative_prompt,
-            negative_prompt_2=negative_prompt_2,
-            negative_prompt_3=negative_prompt_3,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
-            prompt_embeds=prompt_embeds,
-
negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - device=device, - clip_skip=self.clip_skip, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) + + # 3.1 Encode input prompt + ## thesea modified for text prompt mask + prompt_embeds_list = [] + negative_prompt_embeds_list = [] + pooled_prompt_embeds_list = [] + negative_pooled_prompt_embeds_list = [] + if isinstance(prompt, list): + for pmt in prompt: + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=pmt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + #prompt_embeds=prompt_embeds, + #negative_prompt_embeds=negative_prompt_embeds, + #pooled_prompt_embeds=pooled_prompt_embeds, + #negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + prompt_embeds_list.append(prompt_embeds) + negative_prompt_embeds_list.append(negative_prompt_embeds) + pooled_prompt_embeds_list.append(pooled_prompt_embeds) + negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds) + prompt_embeds_list = torch.stack(prompt_embeds_list,dim=1) + negative_prompt_embeds_list = torch.stack(negative_prompt_embeds_list,dim=1) + pooled_prompt_embeds_list = torch.stack(pooled_prompt_embeds_list,dim=1) + negative_pooled_prompt_embeds_list = torch.stack(negative_pooled_prompt_embeds_list,dim=1) + else: + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) if self.do_classifier_free_guidance: if skip_guidance_layers is not None: @@ -990,6 +1030,10 @@ def __call__( original_pooled_prompt_embeds = pooled_prompt_embeds prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + if isinstance(prompt, list): # thesea modified for text prompt mask + prompt_embeds_list = torch.cat([negative_prompt_embeds_list, prompt_embeds_list], dim=0) + #TBD because SDXL does not use this: pooled_prompt_embeds_list = torch.cat([negative_pooled_prompt_embeds_list, pooled_prompt_embeds_list], dim=0) + # 4. 
Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
@@ -1060,7 +1104,7 @@ def __call__(
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=prompt_embeds_list if isinstance(prompt, list) else prompt_embeds,  # prompt_embeds, thesea modified for multiple text prompts
                     pooled_projections=pooled_prompt_embeds,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,

From fc1461472bbe89b6491843ac206eed10d523c097 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 4 Mar 2025 18:56:12 +0900
Subject: [PATCH 079/482] fix bug

---
 src/diffusers/models/transformers/transformer_sd3.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 886be5b24dd7..87823da1ce31 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -377,8 +377,11 @@ def forward(
         if len(encoder_hidden_states.shape) == 3:
             encoder_hidden_states = self.context_embedder(encoder_hidden_states)
         else:
+            encoder_hidden_states_list = []
             for index in range(encoder_hidden_states.shape[1]):
-                encoder_hidden_states[:,index,:,:] = self.context_embedder(encoder_hidden_states[:,index,:,:])
+                tmp_encoder_hidden_states = self.context_embedder(encoder_hidden_states[:,index,:,:])
+                encoder_hidden_states_list.append(tmp_encoder_hidden_states)
+            encoder_hidden_states = torch.stack(encoder_hidden_states_list, dim=1)

         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")

From 21685f1640ef4725ed4898cfe04b406a893ca992 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 4 Mar 2025 19:18:46 +0900
Subject: [PATCH 080/482] print info

---
 src/diffusers/models/attention.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 97a01f50f03d..99a0d4568c30 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -260,6 +260,8 @@ def forward(
         if self.context_pre_only:
             encoder_hidden_states = None
         else:
+            print(f'attention encoder_hidden_states shape={encoder_hidden_states.shape}')
+            print(f'attention context_attn_output shape={context_attn_output.shape}')
             context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
             encoder_hidden_states = encoder_hidden_states + context_attn_output

From 93eaad6b67ca51ff70c5592fb69d45be89540fb7 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 4 Mar 2025 19:25:04 +0900
Subject: [PATCH 081/482] print info

---
 src/diffusers/models/attention_processor.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index ed01bfa21f3c..3ad09041a100 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -5809,6 +5809,8 @@ def __call__(
             hidden_states = attn.to_out[1](hidden_states)

         if encoder_hidden_states is not None:
+            print(f'attention processor hidden_states shape={hidden_states.shape}')
+            print(f'attention processor encoder_hidden_states shape={encoder_hidden_states.shape}')
             return hidden_states, encoder_hidden_states
         else:
             return hidden_states

From fb0749fb0db26e663552908ed6dc3500ea22781a Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 09:55:12 +0900 Subject: [PATCH 082/482] print info --- src/diffusers/models/attention_processor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 3ad09041a100..366bfaf05e7f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1419,6 +1419,7 @@ def __call__( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, + text_masks: Optional[torch.Tensor] = None, # thesea modification for text prompt mask *args, **kwargs, ) -> torch.FloatTensor: @@ -1487,6 +1488,8 @@ def __call__( hidden_states = attn.to_out[1](hidden_states) if encoder_hidden_states is not None: + print(f'attention processor hidden_states shape={hidden_states.shape}') + print(f'attention processor encoder_hidden_states shape={encoder_hidden_states.shape}') return hidden_states, encoder_hidden_states else: return hidden_states From f97efbca86a33e97d24d72877703a86f17984b83 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 12:11:31 +0900 Subject: [PATCH 083/482] modifications to attention processor JointAttnProcessor2_0 for text prompt mask --- src/diffusers/models/attention_processor.py | 101 +++++++++++++++----- 1 file changed, 79 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 366bfaf05e7f..717f39e95b22 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1446,32 +1446,89 @@ def __call__( # `context` projections. if encoder_hidden_states is not None: - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + if encoder_hidden_states.ndim == 3: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = 
attn.norm_added_k(encoder_hidden_states_key_proj) - query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) - key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) - value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + elif encoder_hidden_states.ndim == 4: + if not text_masks.shape[0] == encoder_hidden_states.shape[1]: + raise ValueError( + f"Length of text masks ({text_masks.shape[0]}) must match " + f"the number of text prompts " + f"({encoder_hidden_states.shape[1]})") - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) + queries = [] + keys = [] + values = [] + for index in range(encoder_hidden_states.shape[1]): + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,index,:,:]) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,index,:,:]) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,index,:,:]) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + tmp_query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + tmp_key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + tmp_value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + queries.append(tmp_query) + keys.append(tmp_key) + values.append(tmp_value) + + # thesea modified for text prompt mask + if (encoder_hidden_states is not None) and (encoder_hidden_states.ndim == 4): + hidden_states_list = [] + for mask, query, key, value in zip(text_masks, queries, keys, values): + tmp_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + tmp_hidden_states = tmp_hidden_states.to(query.dtype) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask, + batch_size, + tmp_hidden_states.shape[1], + tmp_hidden_states.shape[2], + ) + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + hidden_states_list.append(tmp_hidden_states * mask_downsample) + + hidden_states_list = torch.stack(hidden_states_list) + hidden_states = torch.sum(hidden_states_list, dim=0, keepdim=False) + else: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: # Split the attention outputs. 
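The masked joint-attention path above pairs each per-prompt embedding with its own spatial mask. A minimal usage sketch follows (not part of these patches): it assumes the patched StableDiffusion3Pipeline forwards a `text_masks` entry from `joint_attention_kwargs` down to the patched JointAttnProcessor2_0 — plumbing these diffs do not show end to end — and the checkpoint and mask layout are illustrative only.

# Hedged sketch: the `text_masks` forwarding path is an assumption; only the
# processor-side handling of per-prompt masks appears in the patches above.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.float16
).to("cuda")

height = width = 1024
# One binary mask per prompt, stacked on dim 0: the patched processor zips
# text_masks against the per-prompt query/key/value lists and downsamples
# each mask with IPAdapterMaskProcessor.downsample.
left = torch.zeros(1, height, width)
left[..., : width // 2] = 1.0
right = 1.0 - left
text_masks = torch.stack([left, right])  # [num_prompts, 1, H, W]

image = pipe(
    prompt=["a red sports car", "a snowy mountain road"],  # list -> stacked per-prompt embeds
    joint_attention_kwargs={"text_masks": text_masks},  # assumed forwarding path
    num_inference_steps=28,
    height=height,
    width=width,
).images[0]
image.save("masked_prompts.png")
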
From 7872f287f9c54c4e4ba23a47b95065781d4b9e07 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 12:46:51 +0900 Subject: [PATCH 084/482] fix bug --- src/diffusers/models/attention.py | 48 +++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 99a0d4568c30..bd545c52c928 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -263,19 +263,51 @@ def forward( print(f'attention encoder_hidden_states shape={encoder_hidden_states.shape}') print(f'attention context_attn_output shape={context_attn_output.shape}') context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output + # thesea modified for text prompt mask + if len(encoder_hidden_states.shape) == 3: + encoder_hidden_states = encoder_hidden_states + context_attn_output + elif len(encoder_hidden_states.shape) == 4: + for index in range(encoder_hidden_states.shape[1]): + encoder_hidden_states[:,index,:,:] = encoder_hidden_states[:,index,:,:] + context_attn_output + + if len(encoder_hidden_states.shape) == 3: + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + elif len(encoder_hidden_states.shape) == 4: + norm_encoder_hidden_states = [] + for index in range(encoder_hidden_states.shape[1]): + tmp_norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states[:,index,:,:]) + tmp_norm_encoder_hidden_states = tmp_norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + norm_encoder_hidden_states.append(tmp_norm_encoder_hidden_states) + norm_encoder_hidden_states = torch.stack(norm_encoder_hidden_states, dim=1) - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory - context_ff_output = _chunked_feed_forward( - self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size - ) + if len(encoder_hidden_states.shape) == 3: + context_ff_output = _chunked_feed_forward( + self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size + ) + elif len(encoder_hidden_states.shape) == 4: + context_ff_output = [] + for index in range(encoder_hidden_states.shape[1]): + tmp_context_ff_output = _chunked_feed_forward( + self.ff_context, norm_encoder_hidden_states[:,index,:,:], self._chunk_dim, self._chunk_size + ) + context_ff_output.append(tmp_context_ff_output) + context_ff_output.torch.stack(context_ff_output, dim=1) else: - context_ff_output = self.ff_context(norm_encoder_hidden_states) + if len(encoder_hidden_states.shape) == 3: + context_ff_output = self.ff_context(norm_encoder_hidden_states) + elif len(encoder_hidden_states.shape) == 4: + context_ff_output = [] + for index in range(encoder_hidden_states.shape[1]): + tmp_context_ff_output = self.ff_context(norm_encoder_hidden_states[:,index,:,:]) + context_ff_output.append(tmp_context_ff_output) + context_ff_output.stack(context_ff_output, dim=1) + + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - + return encoder_hidden_states, hidden_states From 138a51ea990cf030cfab3f7e2b702f11618f3f2a Mon Sep 17 00:00:00 2001 From: 
LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 12:49:11 +0900 Subject: [PATCH 085/482] fix bug --- src/diffusers/models/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index bd545c52c928..1001e08671dd 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -294,7 +294,7 @@ def forward( self.ff_context, norm_encoder_hidden_states[:,index,:,:], self._chunk_dim, self._chunk_size ) context_ff_output.append(tmp_context_ff_output) - context_ff_output.torch.stack(context_ff_output, dim=1) + context_ff_output = torch.stack(context_ff_output, dim=1) else: if len(encoder_hidden_states.shape) == 3: context_ff_output = self.ff_context(norm_encoder_hidden_states) @@ -303,7 +303,7 @@ def forward( for index in range(encoder_hidden_states.shape[1]): tmp_context_ff_output = self.ff_context(norm_encoder_hidden_states[:,index,:,:]) context_ff_output.append(tmp_context_ff_output) - context_ff_output.stack(context_ff_output, dim=1) + context_ff_output = torch.stack(context_ff_output, dim=1) encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output From 5b1a8b421cfa96241aaf5c223b27309e6212b7ab Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 12:51:34 +0900 Subject: [PATCH 086/482] fix bug --- src/diffusers/models/attention.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 1001e08671dd..7921e9341905 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -305,8 +305,11 @@ def forward( context_ff_output.append(tmp_context_ff_output) context_ff_output = torch.stack(context_ff_output, dim=1) - - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if len(encoder_hidden_states.shape) == 3: + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + elif len(encoder_hidden_states.shape) == 4: + for index in range(encoder_hidden_states.shape[1]): + encoder_hidden_states[:,index,:,:] = encoder_hidden_states[:,index,:,:] + c_gate_mlp.unsqueeze(1) * context_ff_output[:,index,:,:] return encoder_hidden_states, hidden_states From e8e63e1e432cd963c50a356381a2e2ea87b7e2d2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 15:13:23 +0900 Subject: [PATCH 087/482] print info --- src/diffusers/models/attention_processor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 717f39e95b22..2621462f9aee 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1509,11 +1509,13 @@ def __call__( # thesea modified for text prompt mask if (encoder_hidden_states is not None) and (encoder_hidden_states.ndim == 4): hidden_states_list = [] + print(f'text_masks shape={text_masks.shape}') for mask, query, key, value in zip(text_masks, queries, keys, values): tmp_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) tmp_hidden_states = tmp_hidden_states.to(query.dtype) + print(f'mask shape={mask.shape}') mask_downsample = IPAdapterMaskProcessor.downsample( mask, batch_size, From 395002de28be13bae7b328121adef00f0e887339 Mon Sep 17 00:00:00 
2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 16:52:58 +0900 Subject: [PATCH 088/482] add context_attn_output only to the last prompt --- src/diffusers/models/attention.py | 3 ++- src/diffusers/models/attention_processor.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7921e9341905..f87cb3ec8d8a 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -268,7 +268,8 @@ def forward( encoder_hidden_states = encoder_hidden_states + context_attn_output elif len(encoder_hidden_states.shape) == 4: for index in range(encoder_hidden_states.shape[1]): - encoder_hidden_states[:,index,:,:] = encoder_hidden_states[:,index,:,:] + context_attn_output + if index == encoder_hidden_states.shape[1] - 1: + encoder_hidden_states[:,index,:,:] = encoder_hidden_states[:,index,:,:] + context_attn_output if len(encoder_hidden_states.shape) == 3: norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 2621462f9aee..e3008e3137f6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1522,6 +1522,7 @@ def __call__( tmp_hidden_states.shape[1], tmp_hidden_states.shape[2], ) + print(f'mask_downsample shape={mask_downsample.shape}') mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) hidden_states_list.append(tmp_hidden_states * mask_downsample) From 0d99cede6c1d42d40852158cfb93d3dfe4b94af7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 17:03:27 +0900 Subject: [PATCH 089/482] fix bug --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e3008e3137f6..0ba27e25bd6a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1519,12 +1519,12 @@ def __call__( mask_downsample = IPAdapterMaskProcessor.downsample( mask, batch_size, - tmp_hidden_states.shape[1], + residual.shape[1], tmp_hidden_states.shape[2], ) print(f'mask_downsample shape={mask_downsample.shape}') mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states_list.append(tmp_hidden_states * mask_downsample) + hidden_states_list.append(tmp_hidden_states[:,:residual.shape[1],:] * mask_downsample) hidden_states_list = torch.stack(hidden_states_list) hidden_states = torch.sum(hidden_states_list, dim=0, keepdim=False)
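
Patch 089's fix hinges on what `IPAdapterMaskProcessor.downsample` returns: the mask is resized and flattened to one weight per image token, so it can only be multiplied against the image-token span of the joint attention output (the first `residual.shape[1]` tokens in this processor, with the text tokens after), not the full joint sequence. A hedged sketch of that call, assuming the helper's signature `downsample(mask, batch_size, num_queries, value_embed_dim)` as used in the patches; the token counts are illustrative:

import torch
from diffusers.image_processor import IPAdapterMaskProcessor

batch_size, num_image_tokens, num_text_tokens, inner_dim = 2, 1024, 333, 1536
mask = torch.ones(1, 64, 64)  # one binary region mask, [num_masks, H, W]
joint_out = torch.randn(batch_size, num_image_tokens + num_text_tokens, inner_dim)  # image tokens first

# resize the mask to one weight per image token (broadcast over channels)
mask_downsample = IPAdapterMaskProcessor.downsample(
    mask, batch_size, num_image_tokens, inner_dim
)
masked_image_part = joint_out[:, :num_image_tokens, :] * mask_downsample

From edc9f74875c214037f761452aa674d3d1e5ff3b7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 17:13:10 +0900 Subject: [PATCH 090/482] fix bug --- src/diffusers/models/attention.py | 4 ++-- src/diffusers/models/attention_processor.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f87cb3ec8d8a..b409cb8cec50 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -268,8 +268,8 @@ def forward( encoder_hidden_states = encoder_hidden_states + context_attn_output elif len(encoder_hidden_states.shape) == 4: for index in range(encoder_hidden_states.shape[1]): - if index == encoder_hidden_states.shape[1] - 1: - encoder_hidden_states[:,index,:,:] = encoder_hidden_states[:,index,:,:] + context_attn_output + #if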
index == encoder_hidden_states.shape[1] - 1: + encoder_hidden_states[:,index,:,:] = encoder_hidden_states[:,index,:,:] + context_attn_output if len(encoder_hidden_states.shape) == 3: norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0ba27e25bd6a..07b5670756ea 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1524,7 +1524,8 @@ def __call__( ) print(f'mask_downsample shape={mask_downsample.shape}') mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states_list.append(tmp_hidden_states[:,:residual.shape[1],:] * mask_downsample) + tmp_hidden_states[:,:residual.shape[1],:] = tmp_hidden_states[:,:residual.shape[1],:] * mask_downsample + hidden_states_list.append(tmp_hidden_states) hidden_states_list = torch.stack(hidden_states_list) hidden_states = torch.sum(hidden_states_list, dim=0, keepdim=False) From 2fc4eedb4963a59b31d94938246ae62c7ba0ee87 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 17:32:24 +0900 Subject: [PATCH 091/482] fix bug --- src/diffusers/models/attention.py | 9 ++----- src/diffusers/models/attention_processor.py | 27 +++++++++++++++------ 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b409cb8cec50..b03d05614e83 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -263,14 +263,9 @@ def forward( print(f'attention encoder_hidden_states shape={encoder_hidden_states.shape}') print(f'attention context_attn_output shape={context_attn_output.shape}') context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + # thesea modified for text prompt mask - if len(encoder_hidden_states.shape) == 3: - encoder_hidden_states = encoder_hidden_states + context_attn_output - elif len(encoder_hidden_states.shape) == 4: - for index in range(encoder_hidden_states.shape[1]): - #if index == encoder_hidden_states.shape[1] - 1: - encoder_hidden_states[:,index,:,:] = encoder_hidden_states[:,index,:,:] + context_attn_output - if len(encoder_hidden_states.shape) == 3: norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 07b5670756ea..c6920f74f00f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1509,6 +1509,7 @@ def __call__( # thesea modified for text prompt mask if (encoder_hidden_states is not None) and (encoder_hidden_states.ndim == 4): hidden_states_list = [] + context_states_list = [] print(f'text_masks shape={text_masks.shape}') for mask, query, key, value in zip(text_masks, queries, keys, values): tmp_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) @@ -1524,8 +1525,9 @@ def __call__( ) print(f'mask_downsample shape={mask_downsample.shape}') mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - tmp_hidden_states[:,:residual.shape[1],:] = tmp_hidden_states[:,:residual.shape[1],:] * mask_downsample - hidden_states_list.append(tmp_hidden_states) + 
#tmp_hidden_states[:,:residual.shape[1],:] = tmp_hidden_states[:,:residual.shape[1],:] * mask_downsample + hidden_states_list.append(tmp_hidden_states[:,:residual.shape[1],:] * mask_downsample) + context_states_list.append(tmp_hidden_states[:, residual.shape[1] :]) hidden_states_list = torch.stack(hidden_states_list) hidden_states = torch.sum(hidden_states_list, dim=0, keepdim=False) @@ -1536,12 +1538,21 @@ def __call__( if encoder_hidden_states is not None: # Split the attention outputs. - hidden_states, encoder_hidden_states = ( - hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1] :], - ) - if not attn.context_pre_only: - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + # thesea modified for text prompt mask + if encoder_hidden_states.ndim == 3: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + elif encoder_hidden_states.ndim == 4: + if not attn.context_pre_only: + encoder_hidden_states_list = [] + for index in range(encoder_hidden_states.shape[1]): + tmp_encoder_hidden_states = attn.to_add_out(context_states_list[index]) + encoder_hidden_states_list.append(tmp_encoder_hidden_states) + encoder_hidden_states = torch.stack(encoder_hidden_states_list, dim=1) # linear proj hidden_states = attn.to_out[0](hidden_states) From adc76078cb192b1dfc2f427ef52ef09029f33751 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 17:35:22 +0900 Subject: [PATCH 092/482] fix bug --- src/diffusers/models/attention.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b03d05614e83..9f89084c4f9d 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -262,9 +262,14 @@ def forward( else: print(f'attention encoder_hidden_states shape={encoder_hidden_states.shape}') print(f'attention context_attn_output shape={context_attn_output.shape}') - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + if len(encoder_hidden_states.shape) == 3: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + elif len(encoder_hidden_states.shape) == 4: + for index in range(encoder_hidden_states.shape[1]): + context_attn_output[:,index,:,:] = c_gate_msa.unsqueeze(1) * context_attn_output[:,index,:,:] + encoder_hidden_states = encoder_hidden_states + context_attn_output - + # thesea modified for text prompt mask if len(encoder_hidden_states.shape) == 3: norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) From ffe4bef40faf23f1566c3f6c65c472474c34bd33 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 5 Mar 2025 18:17:06 +0900 Subject: [PATCH 093/482] remove print info --- src/diffusers/models/attention.py | 9 +++-- src/diffusers/models/attention_processor.py | 40 ++++++++++--------- .../models/transformers/transformer_sd3.py | 1 - .../pipeline_stable_diffusion_3_controlnet.py | 1 - .../pipeline_stable_diffusion_3.py | 4 -- 5 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 9f89084c4f9d..bb4a7075f6da 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -202,8 +202,6 @@ def forward( else: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) - 
print(f'attention JointTransformerBlock self.context_pre_only={self.context_pre_only}') - print(f'attention encoder_hidden_states shape={encoder_hidden_states.shape}') if self.context_pre_only: # thesea modifed for text prompt mask if len(encoder_hidden_states.shape) == 3: @@ -215,6 +213,7 @@ def forward( norm_encoder_hidden_states = torch.stack(norm_encoder_hidden_states, dim=1) else: + # thesea modifed for text prompt mask if len(encoder_hidden_states.shape) == 3: norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb @@ -260,8 +259,7 @@ def forward( if self.context_pre_only: encoder_hidden_states = None else: - print(f'attention encoder_hidden_states shape={encoder_hidden_states.shape}') - print(f'attention context_attn_output shape={context_attn_output.shape}') + # thesea modified for text prompt mask if len(encoder_hidden_states.shape) == 3: context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output elif len(encoder_hidden_states.shape) == 4: @@ -284,6 +282,7 @@ def forward( if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory + # thesea modified for text prompt mask if len(encoder_hidden_states.shape) == 3: context_ff_output = _chunked_feed_forward( self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size @@ -297,6 +296,7 @@ def forward( context_ff_output.append(tmp_context_ff_output) context_ff_output = torch.stack(context_ff_output, dim=1) else: + # thesea modified for text prompt mask if len(encoder_hidden_states.shape) == 3: context_ff_output = self.ff_context(norm_encoder_hidden_states) elif len(encoder_hidden_states.shape) == 4: @@ -306,6 +306,7 @@ def forward( context_ff_output.append(tmp_context_ff_output) context_ff_output = torch.stack(context_ff_output, dim=1) + # thesea modified for text prompt mask if len(encoder_hidden_states.shape) == 3: encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output elif len(encoder_hidden_states.shape) == 4: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c6920f74f00f..6f33ef9ba54f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1446,6 +1446,7 @@ def __call__( # `context` projections. 
if encoder_hidden_states is not None: + # thesea modification for text prompt mask if encoder_hidden_states.ndim == 3: encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) @@ -1510,22 +1511,18 @@ def __call__( if (encoder_hidden_states is not None) and (encoder_hidden_states.ndim == 4): hidden_states_list = [] context_states_list = [] - print(f'text_masks shape={text_masks.shape}') for mask, query, key, value in zip(text_masks, queries, keys, values): tmp_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) tmp_hidden_states = tmp_hidden_states.to(query.dtype) - print(f'mask shape={mask.shape}') mask_downsample = IPAdapterMaskProcessor.downsample( mask, batch_size, residual.shape[1], tmp_hidden_states.shape[2], ) - print(f'mask_downsample shape={mask_downsample.shape}') mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - #tmp_hidden_states[:,:residual.shape[1],:] = tmp_hidden_states[:,:residual.shape[1],:] * mask_downsample hidden_states_list.append(tmp_hidden_states[:,:residual.shape[1],:] * mask_downsample) context_states_list.append(tmp_hidden_states[:, residual.shape[1] :]) @@ -1560,8 +1557,6 @@ def __call__( hidden_states = attn.to_out[1](hidden_states) if encoder_hidden_states is not None: - print(f'attention processor hidden_states shape={hidden_states.shape}') - print(f'attention processor encoder_hidden_states shape={encoder_hidden_states.shape}') return hidden_states, encoder_hidden_states else: return hidden_states @@ -5754,8 +5749,6 @@ def __call__( # `context` projections. # thesea modified for text prompt mask - print(f'attention processor encoder_hidden_states.ndim={encoder_hidden_states.ndim}') - print(f'attention processor encoder_hidden_states shape={encoder_hidden_states.shape}') if encoder_hidden_states is not None: if encoder_hidden_states.ndim == 3: encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) @@ -5820,6 +5813,7 @@ def __call__( # thesea modified for text prompt mask if (encoder_hidden_states is not None) and (encoder_hidden_states.ndim == 4): hidden_states_list = [] + context_states_list = [] for mask, query, key, value in zip(text_masks, queries, keys, values): tmp_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -5828,11 +5822,12 @@ def __call__( mask_downsample = IPAdapterMaskProcessor.downsample( mask, batch_size, - tmp_hidden_states.shape[1], + residual.shape[1], tmp_hidden_states.shape[2], ) mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states_list.append(tmp_hidden_states * mask_downsample) + hidden_states_list.append(tmp_hidden_states[:,:residual.shape[1],:] * mask_downsample) + context_states_list.append(tmp_hidden_states[:, residual.shape[1] :]) hidden_states_list = torch.stack(hidden_states_list) hidden_states = torch.sum(hidden_states_list, dim=0, keepdim=False) @@ -5843,13 +5838,22 @@ def __call__( if encoder_hidden_states is not None: # Split the attention outputs. 
- hidden_states, encoder_hidden_states = ( - hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1] :], - ) - if not attn.context_pre_only: - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - + # thesea modified for text prompt mask + if encoder_hidden_states.ndim == 3: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + elif encoder_hidden_states.ndim == 4: + if not attn.context_pre_only: + encoder_hidden_states_list = [] + for index in range(encoder_hidden_states.shape[1]): + tmp_encoder_hidden_states = attn.to_add_out(context_states_list[index]) + encoder_hidden_states_list.append(tmp_encoder_hidden_states) + encoder_hidden_states = torch.stack(encoder_hidden_states_list, dim=1) + # IP Adapter if self.scale != 0 and ip_hidden_states is not None: # Norm image features @@ -5884,8 +5888,6 @@ def __call__( hidden_states = attn.to_out[1](hidden_states) if encoder_hidden_states is not None: - print(f'attention processor hidden_states shape={hidden_states.shape}') - print(f'attention processor encoder_hidden_states shape={encoder_hidden_states.shape}') return hidden_states, encoder_hidden_states else: return hidden_states diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 87823da1ce31..7219021b3b7a 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -373,7 +373,6 @@ def forward( hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. temb = self.time_text_embed(timestep, pooled_projections) # thesea modifed for text prompt mask - print(f'transformer_sd3 encoder_hidden_states.shape={encoder_hidden_states.shape}') if len(encoder_hidden_states.shape) == 3: encoder_hidden_states = self.context_embedder(encoder_hidden_states) else: diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index 7c7428635dea..3246b5cf42de 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -1225,7 +1225,6 @@ def __call__( return_dict=False, )[0] - print(f'pipeline prompt_embeds_list shape={prompt_embeds_list.shape}') noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 957dff2e004b..e88582af9078 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -982,10 +982,6 @@ def __call__( negative_prompt_2=negative_prompt_2, negative_prompt_3=negative_prompt_3, do_classifier_free_guidance=self.do_classifier_free_guidance, - #prompt_embeds=prompt_embeds, - #negative_prompt_embeds=negative_prompt_embeds, - #pooled_prompt_embeds=pooled_prompt_embeds, - #negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, device=device, clip_skip=self.clip_skip, num_images_per_prompt=num_images_per_prompt, From b653c663b694206f1dfedbba600712b1a06403c6 Mon Sep 17 00:00:00 2001 From: 
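
With the debug prints stripped in patch 093, the SD3 side of the experiment has stabilized on one scheme for multi-prompt masking in the joint blocks: run the joint attention once per (mask, prompt) set of projections, keep each pass's text-token slice separate for the per-prompt `to_add_out` projection, and blend only the image-token slices under their downsampled masks. A compact sketch of that control flow, with plain tensors standing in for the attention module (every name and dimension here is illustrative, not the library API):

import torch
import torch.nn.functional as F

batch, heads, head_dim = 2, 4, 32
num_img, num_txt, num_prompts = 1024, 333, 3

queries = [torch.randn(batch, heads, num_img + num_txt, head_dim) for _ in range(num_prompts)]
keys = [torch.randn(batch, heads, num_img + num_txt, head_dim) for _ in range(num_prompts)]
values = [torch.randn(batch, heads, num_img + num_txt, head_dim) for _ in range(num_prompts)]
# one downsampled mask per prompt, already one weight per image token
masks = [torch.ones(batch, num_img, 1) for _ in range(num_prompts)]

hidden_states_list, context_states_list = [], []
for mask, q, k, v in zip(masks, queries, keys, values):
    out = F.scaled_dot_product_attention(q, k, v)
    out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
    hidden_states_list.append(out[:, :num_img, :] * mask)  # masked image-token slice
    context_states_list.append(out[:, num_img:, :])        # per-prompt text-token slice

# blended image stream; text streams stay separate for to_add_out
hidden_states = torch.stack(hidden_states_list).sum(dim=0)
encoder_hidden_states = torch.stack(context_states_list, dim=1)

Summing the masked image slices lets each prompt act only inside its region, while the stacked 4-D `encoder_hidden_states` keeps the prompts independent between blocks.
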
LinchuanXUTheSEAAI Date: Thu, 20 Mar 2025 23:51:00 +0900 Subject: [PATCH 094/482] add codes for flux ip mask --- src/diffusers/models/attention_processor.py | 86 +++++++++++++++++---- 1 file changed, 69 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 731e5eca1e12..a50dc2c39ae5 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2866,25 +2866,77 @@ def __call__( # IP-adapter ip_query = hidden_states_query_proj ip_attn_output = torch.zeros_like(hidden_states) + + # thesea modification for ip mask + if ip_adapter_masks is not None: + if not (len(ip_adapter_masks) == len(ip_hidden_states)): + raise ValueError( + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"number of ip_hidden_states " + f"({len(ip_hidden_states)})" + ) + else: + for index, (mask, ip_state) in enumerate(zip(ip_adapter_masks, ip_hidden_states)): + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + else: + ip_adapter_masks = [None] * len(ip_hidden_states) - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) + for current_masks, current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( + ip_adapter_masks, ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + ): + if mask is not None: + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) - ip_attn_output += scale * current_ip_hidden_states + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + _current_ip_hidden_states = F.scaled_dot_product_attention( + ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + _current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + 
_current_ip_hidden_states.shape[2], + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + ip_attn_output += scale * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + current_ip_hidden_states = F.scaled_dot_product_attention( + ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) + ip_attn_output += scale * current_ip_hidden_states return hidden_states, encoder_hidden_states, ip_attn_output else:
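
Patch 094 ports the UNet IP-adapter mask path to Flux: when masks are supplied, each reference image's tokens get their own key/value projection and attention pass against the shared `ip_query`, and the result is accumulated under that image's downsampled region mask. A reduced sketch of the accumulation loop (the linear layers and shapes are stand-ins; only `IPAdapterMaskProcessor.downsample` is the real diffusers helper):

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.image_processor import IPAdapterMaskProcessor

batch, heads, head_dim, seq = 1, 4, 32, 1024
ip_query = torch.randn(batch, heads, seq, head_dim)
ip_hidden_states = torch.randn(batch, 2, 4, heads * head_dim)  # [batch, num_images, tokens, dim]
mask = torch.ones(1, 2, 64, 64)                                # [1, num_images, H, W]
to_k_ip = nn.Linear(heads * head_dim, heads * head_dim)
to_v_ip = nn.Linear(heads * head_dim, heads * head_dim)
scale = 1.0

ip_attn_output = torch.zeros(batch, seq, heads * head_dim)
for i in range(mask.shape[1]):
    ip_key = to_k_ip(ip_hidden_states[:, i]).view(batch, -1, heads, head_dim).transpose(1, 2)
    ip_value = to_v_ip(ip_hidden_states[:, i]).view(batch, -1, heads, head_dim).transpose(1, 2)
    out = F.scaled_dot_product_attention(ip_query, ip_key, ip_value)
    out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
    # weight this image's contribution by its own region mask
    mask_downsample = IPAdapterMaskProcessor.downsample(mask[:, i], batch, out.shape[1], out.shape[2])
    ip_attn_output += scale * (out * mask_downsample)

From 52b4a3c9cf9eb78f5280dcced4059e00ba658111 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 15:11:45 +0900 Subject: [PATCH 095/482] print info --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index f53958df2ed0..f8a3d76c3678 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -443,6 +443,7 @@ def __call__( image_embeds = self.image_embedder(image_latents).image_embeds image_embeds = image_embeds.to(device=device) + print(f'image_embeds shape={image_embeds.shape}') # 3.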
Prepare (dummy) text embeddings if hasattr(self, "text_encoder") and self.text_encoder is not None: @@ -471,18 +472,22 @@ def __call__( # pooled_prompt_embeds is 768, clip text encoder hidden size pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) + print(f'prompt_embeds shape before cat={prompt_embeds.shape}') # scale & concatenate image and text embeddings prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) + print(f'prompt_embeds shape after cat={prompt_embeds.shape}') prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[ :, None ] + print(f'prompt_embeds shape before sum={prompt_embeds.shape}') # weighted sum prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True) pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True) + print(f'prompt_embeds shape after sum={prompt_embeds.shape}') # Offload all models self.maybe_free_model_hooks() From 51c7c16c3e46f539b1307b70008e2abe24809281 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 15:31:50 +0900 Subject: [PATCH 096/482] print info --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a50dc2c39ae5..8607168cec27 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2359,7 +2359,7 @@ def __call__( image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - + print(f'encoder_hidden_states shape={encoder_hidden_states.shape}') # `sample` projections. query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) From 212e4f6e61e2f7a93d434c2fc89df07a92ff84a5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 15:36:55 +0900 Subject: [PATCH 097/482] print info --- src/diffusers/models/attention_processor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8607168cec27..3c34e86083ef 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2359,7 +2359,7 @@ def __call__( image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - print(f'encoder_hidden_states shape={encoder_hidden_states.shape}') + # `sample` projections. query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) @@ -2379,6 +2379,8 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: + print(f'encoder_hidden_states shape={encoder_hidden_states.shape}') + encoder_hidden_states = encoder_hidden_states[:,0:512,:] # `context` projections. 
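
The shape prints added in patch 095 trace how the Redux prior builds its joint prompt: T5 text embeddings (512 tokens at Flux's max sequence length) are concatenated with the image embeddings along the sequence dim, scaled per reference image, then weight-summed over the batch of reference images. A standalone sketch of that arithmetic (the image-token count is an assumption taken from SigLIP's 27x27 patch grid; 4096 is Flux's joint attention width):

import torch

num_images = 2                                        # reference images fed to the Redux prior
prompt_embeds = torch.randn(num_images, 512, 4096)    # T5 text tokens
image_embeds = torch.randn(num_images, 729, 4096)     # Redux image tokens (count assumed)
prompt_embeds_scale = [1.0, 0.8]

# concatenate text and image tokens per reference image, then scale
prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
prompt_embeds = prompt_embeds * torch.tensor(prompt_embeds_scale)[:, None, None]

# weighted sum over the reference images -> a single (1, 512 + 729, 4096) joint prompt
prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)

This oversized joint prompt is what the following patches wrestle with inside the Flux attention processor.
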
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) From 305188a8d1419d8782a2958370139c49a1e83de9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 15:42:07 +0900 Subject: [PATCH 098/482] print info --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 3c34e86083ef..e2aff513d8cb 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2380,7 +2380,7 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: print(f'encoder_hidden_states shape={encoder_hidden_states.shape}') - encoder_hidden_states = encoder_hidden_states[:,0:512,:] + #encoder_hidden_states = encoder_hidden_states[:,0:512,:] # `context` projections. encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) From 518be7685fa52b675d6997b4ecd77089f1312994 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 15:48:49 +0900 Subject: [PATCH 099/482] print info --- src/diffusers/models/attention_processor.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e2aff513d8cb..9a504b223b3c 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2380,11 +2380,10 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: print(f'encoder_hidden_states shape={encoder_hidden_states.shape}') - #encoder_hidden_states = encoder_hidden_states[:,0:512,:] # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,0:512,:]) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,0:512,:]) encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim From 747d97c8ccce83d1b9f2e2edb9d794f9759efa00 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 15:51:33 +0900 Subject: [PATCH 100/482] print --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9a504b223b3c..c64904247662 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2380,6 +2380,7 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: print(f'encoder_hidden_states shape={encoder_hidden_states.shape}') + print(f'encoder_hidden_states[:,0:512,:] shape={encoder_hidden_states[:,0:512,:].shape}') # `context` projections. 
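
Patches 097 through 099 test the first workaround for that oversized context: feed only the text span, `encoder_hidden_states[:, 0:512, :]`, into the context projections so the Redux image tokens never enter the joint attention. In isolation the step is just a slice before the projection (a sketch with a stand-in linear layer; dims as above):

import torch
import torch.nn as nn

add_q_proj = nn.Linear(4096, 4096)  # stand-in for attn.add_q_proj
encoder_hidden_states = torch.randn(1, 512 + 729, 4096)  # text tokens first, Redux image tokens after

# project only the first 512 (text) tokens, dropping the image span
encoder_hidden_states_query_proj = add_q_proj(encoder_hidden_states[:, 0:512, :])
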
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,0:512,:]) From 6a14bfed5fa310076eca829379f13752a51fcdd5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 15:55:32 +0900 Subject: [PATCH 101/482] print info --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c64904247662..25d687b845f6 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2406,6 +2406,7 @@ def __call__( key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + print(f'image_rotary_emb shape={image_rotary_emb.shape}') if image_rotary_emb is not None: from .embeddings import apply_rotary_emb From 48d86143da390ee65bf8d5eff491c12904cc5ede Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 15:58:31 +0900 Subject: [PATCH 102/482] print --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 25d687b845f6..7f771c97eabf 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2406,7 +2406,7 @@ def __call__( key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - print(f'image_rotary_emb shape={image_rotary_emb.shape}') + print(f'query shape={query.shape}') if image_rotary_emb is not None: from .embeddings import apply_rotary_emb From 9424142cd1a9d50e5b8a85dad4bd4ca47a98e2fd Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 16:04:20 +0900 Subject: [PATCH 103/482] print info --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7f771c97eabf..2802721d34cb 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2385,6 +2385,7 @@ def __call__( encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,0:512,:]) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,0:512,:]) + print(f'encoder_hidden_states_query_proj shape={encoder_hidden_states_query_proj.shape}') encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim From f8863d33425eb3a39b005d7a21b31c1ccd84c2d1 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 16:29:31 +0900 Subject: [PATCH 104/482] add two copies of embeds --- src/diffusers/models/attention_processor.py | 37 +++++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 2802721d34cb..8d27b77c56d2 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2379,14 +2379,14 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: - print(f'encoder_hidden_states 
shape={encoder_hidden_states.shape}') - print(f'encoder_hidden_states[:,0:512,:] shape={encoder_hidden_states[:,0:512,:].shape}') # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,0:512,:]) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,0:512,:]) - print(f'encoder_hidden_states_query_proj shape={encoder_hidden_states_query_proj.shape}') + img_encoder_hidden_states = copy.deepcopy(encoder_hidden_states) + encoder_hidden_states[:,512:,:] = 0 + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) @@ -2402,12 +2402,35 @@ def __call__( if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + # for img embeds + img_encoder_hidden_states[:,0:512,:] = 0 + img_encoder_hidden_states_query_proj = attn.add_q_proj(img_encoder_hidden_states) + img_encoder_hidden_states_key_proj = attn.add_k_proj(img_encoder_hidden_states) + img_encoder_hidden_states_value_proj = attn.add_v_proj(img_encoder_hidden_states) + + img_encoder_hidden_states_query_proj = img_encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + img_encoder_hidden_states_key_proj = img_encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + img_encoder_hidden_states_value_proj = img_encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + img_encoder_hidden_states_query_proj = attn.norm_added_q(img_encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + img_encoder_hidden_states_key_proj = attn.norm_added_k(img_encoder_hidden_states_key_proj) + + encoder_hidden_states_query_proj += img_encoder_hidden_states_query_proj + encoder_hidden_states_key_proj += img_encoder_hidden_states_key_proj + encoder_hidden_states_value_proj += img_encoder_hidden_states_value_proj # attention query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - print(f'query shape={query.shape}') if image_rotary_emb is not None: from .embeddings import apply_rotary_emb From e0bdcdf52aee0d01f644f3b8e2ae93b7e7204261 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 16:57:05 +0900 Subject: [PATCH 105/482] fix --- src/diffusers/models/attention_processor.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8d27b77c56d2..b0468dd1d894 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2382,7 +2382,7 @@ def __call__( # `context` projections. 
img_encoder_hidden_states = copy.deepcopy(encoder_hidden_states) - encoder_hidden_states[:,512:,:] = 0 + #encoder_hidden_states[:,512:,:] = 0 encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) @@ -2403,7 +2403,7 @@ def __call__( encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) # for img embeds - img_encoder_hidden_states[:,0:512,:] = 0 + #img_encoder_hidden_states[:,0:512,:] = 0 img_encoder_hidden_states_query_proj = attn.add_q_proj(img_encoder_hidden_states) img_encoder_hidden_states_key_proj = attn.add_k_proj(img_encoder_hidden_states) img_encoder_hidden_states_value_proj = attn.add_v_proj(img_encoder_hidden_states) @@ -2423,6 +2423,14 @@ def __call__( if attn.norm_added_k is not None: img_encoder_hidden_states_key_proj = attn.norm_added_k(img_encoder_hidden_states_key_proj) + encoder_hidden_states_query_proj[:,512:,:] = 0 + encoder_hidden_states_key_proj[:,512:,:] = 0 + encoder_hidden_states_value_proj[:,512:,:] = 0 + + img_encoder_hidden_states_query_proj[:,0:512,:] = 0 + img_encoder_hidden_states_key_proj[:,0:512,:] = 0 + img_encoder_hidden_states_value_proj[:,0:512,:] = 0 + encoder_hidden_states_query_proj += img_encoder_hidden_states_query_proj encoder_hidden_states_key_proj += img_encoder_hidden_states_key_proj encoder_hidden_states_value_proj += img_encoder_hidden_states_value_proj From 9002010fd9d1a8ca50d8d56f2d8fd38accab89fe Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 17:09:16 +0900 Subject: [PATCH 106/482] fix --- src/diffusers/models/attention_processor.py | 28 +++++++++++++-------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b0468dd1d894..acd23e89d967 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2426,18 +2426,23 @@ def __call__( encoder_hidden_states_query_proj[:,512:,:] = 0 encoder_hidden_states_key_proj[:,512:,:] = 0 encoder_hidden_states_value_proj[:,512:,:] = 0 - + img_encoder_hidden_states_query_proj[:,0:512,:] = 0 img_encoder_hidden_states_key_proj[:,0:512,:] = 0 img_encoder_hidden_states_value_proj[:,0:512,:] = 0 - encoder_hidden_states_query_proj += img_encoder_hidden_states_query_proj - encoder_hidden_states_key_proj += img_encoder_hidden_states_key_proj - encoder_hidden_states_value_proj += img_encoder_hidden_states_value_proj + #encoder_hidden_states_query_proj += img_encoder_hidden_states_query_proj + #encoder_hidden_states_key_proj += img_encoder_hidden_states_key_proj + #encoder_hidden_states_value_proj += img_encoder_hidden_states_value_proj # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + img_query = torch.cat([img_encoder_hidden_states_query_proj, query], dim=2) + img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) + img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) + + query = img_query + torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = img_key + torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = img_value + torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: from 
.embeddings import apply_rotary_emb @@ -2445,9 +2450,12 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + if encoder_hidden_states is not None: + + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 6fe15960757bbd5745eaaefb3f8ed62a9e04bbc5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 17:10:51 +0900 Subject: [PATCH 107/482] fix bug --- src/diffusers/models/attention_processor.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index acd23e89d967..56cfa70f83fd 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2450,12 +2450,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - if encoder_hidden_states is not None: - - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 3068762f198e2697c2a28cf646efd4cf8283cd47 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 17:22:13 +0900 Subject: [PATCH 108/482] fix --- src/diffusers/models/attention_processor.py | 63 +++++++-------------- 1 file changed, 19 insertions(+), 44 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 56cfa70f83fd..2875b4afaa2c 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2380,9 +2380,6 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: # `context` projections. 
- img_encoder_hidden_states = copy.deepcopy(encoder_hidden_states) - - #encoder_hidden_states[:,512:,:] = 0 encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) @@ -2402,47 +2399,15 @@ def __call__( if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - # for img embeds - #img_encoder_hidden_states[:,0:512,:] = 0 - img_encoder_hidden_states_query_proj = attn.add_q_proj(img_encoder_hidden_states) - img_encoder_hidden_states_key_proj = attn.add_k_proj(img_encoder_hidden_states) - img_encoder_hidden_states_value_proj = attn.add_v_proj(img_encoder_hidden_states) - - img_encoder_hidden_states_query_proj = img_encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - img_encoder_hidden_states_key_proj = img_encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - img_encoder_hidden_states_value_proj = img_encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - img_encoder_hidden_states_query_proj = attn.norm_added_q(img_encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - img_encoder_hidden_states_key_proj = attn.norm_added_k(img_encoder_hidden_states_key_proj) - - encoder_hidden_states_query_proj[:,512:,:] = 0 - encoder_hidden_states_key_proj[:,512:,:] = 0 - encoder_hidden_states_value_proj[:,512:,:] = 0 - - img_encoder_hidden_states_query_proj[:,0:512,:] = 0 - img_encoder_hidden_states_key_proj[:,0:512,:] = 0 - img_encoder_hidden_states_value_proj[:,0:512,:] = 0 - - #encoder_hidden_states_query_proj += img_encoder_hidden_states_query_proj - #encoder_hidden_states_key_proj += img_encoder_hidden_states_key_proj - #encoder_hidden_states_value_proj += img_encoder_hidden_states_value_proj # attention - img_query = torch.cat([img_encoder_hidden_states_query_proj, query], dim=2) - img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) - img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) + #img_query = torch.cat([encoder_hidden_states_query_proj[:,512:,:], query[:,512:,:]], dim=2) + #img_key = torch.cat([encoder_hidden_states_key_proj[:,512:,:], key[:,512:,:]], dim=2) + #img_value = torch.cat([encoder_hidden_states_value_proj[:,512:,:], value[:,512:,:]], dim=2), [:,0:512,:] - query = img_query + torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = img_key + torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = img_value + torch.cat([encoder_hidden_states_value_proj, value], dim=2) + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: from .embeddings import apply_rotary_emb @@ -2450,9 +2415,19 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + if encoder_hidden_states is not None: + print(f'key={key.shape}, query={query.shape}, value={value.shape}') + img_hidden_states = F.scaled_dot_product_attention( + query[:,512:,:], key[:,512:,:], value[:,512:,:], 
attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = F.scaled_dot_product_attention( + query[:,0:512,:], key[:,0:512,:], value[:,0:512,:], attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = torch.cat([hidden_states, img_hidden_states], dim=1) + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 7344ef83e014824436befb144458d0b43d2ed15e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 17:25:39 +0900 Subject: [PATCH 109/482] print info --- src/diffusers/models/attention_processor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 2875b4afaa2c..adb163b67e9a 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2420,9 +2420,11 @@ def __call__( img_hidden_states = F.scaled_dot_product_attention( query[:,512:,:], key[:,512:,:], value[:,512:,:], attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + print(f'img_hidden_states shape={img_hidden_states.shape}') hidden_states = F.scaled_dot_product_attention( query[:,0:512,:], key[:,0:512,:], value[:,0:512,:], attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + print(f'hidden_states shape={hidden_states.shape}') hidden_states = torch.cat([hidden_states, img_hidden_states], dim=1) else: hidden_states = F.scaled_dot_product_attention( From 2f55d11f04fabd71f8d2b5a88bdb29bee751c97f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 17:32:51 +0900 Subject: [PATCH 110/482] remove info --- src/diffusers/models/attention_processor.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index adb163b67e9a..9e464c56cc92 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2383,7 +2383,7 @@ def __call__( encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) @@ -2400,11 +2400,6 @@ def __call__( encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) # attention - - #img_query = torch.cat([encoder_hidden_states_query_proj[:,512:,:], query[:,512:,:]], dim=2) - #img_key = torch.cat([encoder_hidden_states_key_proj[:,512:,:], key[:,512:,:]], dim=2) - #img_value = torch.cat([encoder_hidden_states_value_proj[:,512:,:], value[:,512:,:]], dim=2), [:,0:512,:] - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) From e736c5361f706e05a4c3c0cb42333c12a3e79270 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 17:39:23 +0900 Subject: [PATCH 111/482] fix --- src/diffusers/models/attention_processor.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/attention_processor.py 
b/src/diffusers/models/attention_processor.py index 9e464c56cc92..8869f3cb79d1 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2410,21 +2410,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - if encoder_hidden_states is not None: - print(f'key={key.shape}, query={query.shape}, value={value.shape}') - img_hidden_states = F.scaled_dot_product_attention( - query[:,512:,:], key[:,512:,:], value[:,512:,:], attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - print(f'img_hidden_states shape={img_hidden_states.shape}') - hidden_states = F.scaled_dot_product_attention( - query[:,0:512,:], key[:,0:512,:], value[:,0:512,:], attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - print(f'hidden_states shape={hidden_states.shape}') - hidden_states = torch.cat([hidden_states, img_hidden_states], dim=1) - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 9f50ce76aa97df9d47302eb91d0450ff2a448dc2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 17:42:10 +0900 Subject: [PATCH 112/482] print info --- src/diffusers/models/attention_processor.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8869f3cb79d1..58694328e0bd 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2410,9 +2410,15 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + if encoder_hidden_states is not None: + print(f'key={key.shape}, query={query.shape}, value={value.shape}') + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 5d1c553c1a207e7b3fd1f30505c8a815c3c66a25 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 17:53:41 +0900 Subject: [PATCH 113/482] fix --- src/diffusers/models/attention_processor.py | 37 +++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 58694328e0bd..2f99889f32e6 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2384,6 +2384,18 @@ def __call__( encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + img_encoder_hidden_states_query_proj = copy.deepcopy(encoder_hidden_states_query_proj) + img_encoder_hidden_states_key_proj = 
copy.deepcopy(encoder_hidden_states_key_proj) + img_encoder_hidden_states_value_proj = copy.deepcopy(encoder_hidden_states_value_proj) + + encoder_hidden_states_query_proj[:,512:,:] = 0 + encoder_hidden_states_key_proj[:,512:,:] = 0 + encoder_hidden_states_value_proj[:,512:,:] = 0 + + img_encoder_hidden_states_query_proj[:,0:512,:] = 0 + img_encoder_hidden_states_key_proj[:,0:512,:] = 0 + img_encoder_hidden_states_value_proj[:,0:512,:] = 0 + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) @@ -2394,16 +2406,32 @@ def __call__( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) + img_encoder_hidden_states_query_proj = img_encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + img_encoder_hidden_states_key_proj = img_encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + img_encoder_hidden_states_value_proj = img_encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + img_encoder_hidden_states_query_proj = attn.norm_added_q(img_encoder_hidden_states_query_proj) if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + img_encoder_hidden_states_key_proj = attn.norm_added_k(img_encoder_hidden_states_key_proj) # attention query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + img_query = torch.cat([img_encoder_hidden_states_query_proj, query], dim=2) + img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) + img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) + if image_rotary_emb is not None: from .embeddings import apply_rotary_emb @@ -2411,11 +2439,16 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb) if encoder_hidden_states is not None: - print(f'key={key.shape}, query={query.shape}, value={value.shape}') hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - else: + + img_hidden_states = F.scaled_dot_product_attention( + img_query, img_key, img_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states += img_hidden_states + else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) From aeccf412d4a23a23ee4421df44a3cddd386a4488 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 17:57:12 +0900 Subject: [PATCH 114/482] fix --- src/diffusers/models/attention_processor.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 2f99889f32e6..a84d9f8d1b32 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2424,9 +2424,9 @@ def __call__( img_encoder_hidden_states_key_proj = attn.norm_added_k(img_encoder_hidden_states_key_proj) # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + 
txt_query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + txt_key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + txt_value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) img_query = torch.cat([img_encoder_hidden_states_query_proj, query], dim=2) img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) @@ -2435,12 +2435,15 @@ def __call__( if image_rotary_emb is not None: from .embeddings import apply_rotary_emb - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) + txt_query = apply_rotary_emb(txt_query, image_rotary_emb) + txt_key = apply_rotary_emb(txt_key, image_rotary_emb) + + img_query = apply_rotary_emb(img_query, image_rotary_emb) + img_key = apply_rotary_emb(img_key, image_rotary_emb) if encoder_hidden_states is not None: hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + txt_query, txt_key, txt_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) img_hidden_states = F.scaled_dot_product_attention( From acb6d77b2bc38b5154011abd30b1a7bb36bd6afc Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 18:00:43 +0900 Subject: [PATCH 115/482] fix --- src/diffusers/models/attention_processor.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a84d9f8d1b32..4ff1121fba57 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2435,11 +2435,15 @@ def __call__( if image_rotary_emb is not None: from .embeddings import apply_rotary_emb - txt_query = apply_rotary_emb(txt_query, image_rotary_emb) - txt_key = apply_rotary_emb(txt_key, image_rotary_emb) + if encoder_hidden_states is not None: + txt_query = apply_rotary_emb(txt_query, image_rotary_emb) + txt_key = apply_rotary_emb(txt_key, image_rotary_emb) - img_query = apply_rotary_emb(img_query, image_rotary_emb) - img_key = apply_rotary_emb(img_key, image_rotary_emb) + img_query = apply_rotary_emb(img_query, image_rotary_emb) + img_key = apply_rotary_emb(img_key, image_rotary_emb) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) if encoder_hidden_states is not None: hidden_states = F.scaled_dot_product_attention( From b2450c83ea59c38825d85b72105c47a8a88e19d0 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 23:03:52 +0900 Subject: [PATCH 116/482] print info --- src/diffusers/models/attention_processor.py | 67 +++++---------------- 1 file changed, 14 insertions(+), 53 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4ff1121fba57..41d7a1177c93 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2380,21 +2380,11 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: # `context` projections. 
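
The txt_*/img_* split above is the core of the change: instead of zeroing out halves of a shared projection, the processor now runs two separate attention passes, a text stream built from the first 512 context tokens and an image stream built from the trailing Redux tokens, each concatenated with the same latent query/key/value. A minimal sketch of that idea, outside the diffusers plumbing and with hypothetical [batch, heads, seq, head_dim] tensors:

    import torch
    import torch.nn.functional as F

    def dual_stream_attention(query, key, value,
                              txt_q, txt_k, txt_v,
                              img_q, img_k, img_v):
        # Text stream: prompt tokens attend jointly with the latent tokens.
        txt_out = F.scaled_dot_product_attention(
            torch.cat([txt_q, query], dim=2),
            torch.cat([txt_k, key], dim=2),
            torch.cat([txt_v, value], dim=2),
        )
        # Image stream: Redux tokens attend jointly with the same latent tokens.
        img_out = F.scaled_dot_product_attention(
            torch.cat([img_q, query], dim=2),
            torch.cat([img_k, key], dim=2),
            torch.cat([img_v, value], dim=2),
        )
        # The latent portions of the two outputs are recombined afterwards,
        # first by summation and, in later patches, by spatial mask blending.
        return txt_out, img_out
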
+ print(f'encoder_hidden_states shape={encoder_hidden_states.shape}') encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - img_encoder_hidden_states_query_proj = copy.deepcopy(encoder_hidden_states_query_proj) - img_encoder_hidden_states_key_proj = copy.deepcopy(encoder_hidden_states_key_proj) - img_encoder_hidden_states_value_proj = copy.deepcopy(encoder_hidden_states_value_proj) - - encoder_hidden_states_query_proj[:,512:,:] = 0 - encoder_hidden_states_key_proj[:,512:,:] = 0 - encoder_hidden_states_value_proj[:,512:,:] = 0 - - img_encoder_hidden_states_query_proj[:,0:512,:] = 0 - img_encoder_hidden_states_key_proj[:,0:512,:] = 0 - img_encoder_hidden_states_value_proj[:,0:512,:] = 0 + print(f'encoder_hidden_states_query_proj before view shape={encoder_hidden_states_query_proj.shape}') encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim @@ -2406,59 +2396,30 @@ def __call__( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) - img_encoder_hidden_states_query_proj = img_encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - img_encoder_hidden_states_key_proj = img_encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - img_encoder_hidden_states_value_proj = img_encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) + print(f'encoder_hidden_states_query_proj after view shape={encoder_hidden_states_query_proj.shape}') if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - img_encoder_hidden_states_query_proj = attn.norm_added_q(img_encoder_hidden_states_query_proj) if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - img_encoder_hidden_states_key_proj = attn.norm_added_k(img_encoder_hidden_states_key_proj) # attention - txt_query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - txt_key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - txt_value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - img_query = torch.cat([img_encoder_hidden_states_query_proj, query], dim=2) - img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) - img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) + print(f'query before cat shape={query.shape}') + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + print(f'query after cat shape={query.shape}') if image_rotary_emb is not None: from .embeddings import apply_rotary_emb - if encoder_hidden_states is not None: - txt_query = apply_rotary_emb(txt_query, image_rotary_emb) - txt_key = apply_rotary_emb(txt_key, image_rotary_emb) - - img_query = apply_rotary_emb(img_query, image_rotary_emb) - img_key = apply_rotary_emb(img_key, image_rotary_emb) - else: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if encoder_hidden_states is not None: - hidden_states = F.scaled_dot_product_attention( - txt_query, txt_key, txt_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - img_hidden_states = 
F.scaled_dot_product_attention( - img_query, img_key, img_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + print(f'query after apply_rotary_emb shape={query.shape}') - hidden_states += img_hidden_states - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 367397cfc471a755383967b0032f1f6d76ad9276 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 23:39:35 +0900 Subject: [PATCH 117/482] print info --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 41d7a1177c93..250586827de2 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2367,6 +2367,7 @@ def __call__( inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads + print(f'attn.heads = {attn.heads}') query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) From 38bce7d6bf611f0ba8859a0f13a158d7d2555784 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 23:45:39 +0900 Subject: [PATCH 118/482] print info --- src/diffusers/models/attention_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 250586827de2..cc47d8367a98 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2382,9 +2382,9 @@ def __call__( if encoder_hidden_states is not None: # `context` projections. 
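
The hard-coded 512 in the slices below is the fixed T5 prompt length used by the FLUX pipelines; everything after position 512 in encoder_hidden_states is assumed to be the Redux image embedding, which contributes 729 tokens (a 27x27 grid of projected SigLIP features), while a 1024x1024 generation packs into 4096 latent tokens. A sketch of that assumed layout:

    T5_TOKENS = 512       # fixed T5 prompt length in the FLUX pipelines
    REDUX_TOKENS = 729    # 27 x 27 grid of projected SigLIP patch embeddings
    LATENT_TOKENS = 4096  # (1024 // 16) ** 2 packed latent patches at 1024x1024

    def split_context(encoder_hidden_states):
        # Text tokens come first; the Redux image tokens are appended after them.
        txt_ctx = encoder_hidden_states[:, :T5_TOKENS, :]
        img_ctx = encoder_hidden_states[:, T5_TOKENS:, :]
        return txt_ctx, img_ctx

All of the magic numbers that follow (4608, 729, 1241) are combinations of these three counts, so the slicing as written only holds for this one resolution and prompt length.
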
print(f'encoder_hidden_states shape={encoder_hidden_states.shape}') - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,0:512,:]) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,0:512,:]) print(f'encoder_hidden_states_query_proj before view shape={encoder_hidden_states_query_proj.shape}') encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( From 53c0f7fd43f7d0447c705f83dc87b82371e3382e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 21 Mar 2025 23:52:57 +0900 Subject: [PATCH 119/482] print info --- src/diffusers/models/attention_processor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cc47d8367a98..1ef45ccfd2ae 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2414,6 +2414,8 @@ def __call__( if image_rotary_emb is not None: from .embeddings import apply_rotary_emb + cos, sin = image_rotary_emb + print(f'cos shape={cos.shape}, sin shape={sin.shape}') query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) print(f'query after apply_rotary_emb shape={query.shape}') From 1e2d9f10e5b32a7dba4677c269c61cb8fc54d95e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 00:05:14 +0900 Subject: [PATCH 120/482] add codes --- src/diffusers/models/attention_processor.py | 53 +++++++++++++++++---- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1ef45ccfd2ae..ce7278755e4b 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2385,6 +2385,10 @@ def __call__( encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,0:512,:]) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,0:512,:]) + + img_encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,512:,:]) + img_encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,512:,:]) + img_encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,512:,:]) print(f'encoder_hidden_states_query_proj before view shape={encoder_hidden_states_query_proj.shape}') encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( @@ -2397,32 +2401,61 @@ def __call__( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) + img_encoder_hidden_states_query_proj = img_encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + img_encoder_hidden_states_key_proj = img_encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + img_encoder_hidden_states_value_proj = img_encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + print(f'encoder_hidden_states_query_proj after view shape={encoder_hidden_states_query_proj.shape}') if attn.norm_added_q is not None: encoder_hidden_states_query_proj = 
attn.norm_added_q(encoder_hidden_states_query_proj) + img_encoder_hidden_states_query_proj = attn.norm_added_q(img_encoder_hidden_states_query_proj) if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + img_encoder_hidden_states_key_proj = attn.norm_added_k(img_encoder_hidden_states_key_proj) # attention print(f'query before cat shape={query.shape}') - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + txt_query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + txt_key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + txt_value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + img_query = torch.cat([img_encoder_hidden_states_query_proj, query], dim=2) + img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) + img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) print(f'query after cat shape={query.shape}') if image_rotary_emb is not None: from .embeddings import apply_rotary_emb cos, sin = image_rotary_emb - print(f'cos shape={cos.shape}, sin shape={sin.shape}') - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - print(f'query after apply_rotary_emb shape={query.shape}') + #print(f'cos shape={cos.shape}, sin shape={sin.shape}') + txt_query = apply_rotary_emb(txt_query, (cos[0:4608,:], sin[0:4608,:])) #image_rotary_emb + txt_key = apply_rotary_emb(txt_key, (cos[0:4608,:], sin[0:4608,:])) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + img_query = apply_rotary_emb(img_query, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0))) #image_rotary_emb + img_key = apply_rotary_emb(img_key, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0))) + + if encoder_hidden_states is not None: + txt_hidden_states = F.scaled_dot_product_attention( + txt_query, txt_key, txt_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + img_hidden_states = F.scaled_dot_product_attention( + img_query, img_key, img_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + print(f'txt_hidden_states shape={txt_hidden_states.shape}') + print(f'img_hidden_states shape={img_hidden_states.shape}') + hidden_states = torch.cat([txt_hidden_states, img_hidden_states],dim=1) + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From faf853ee50cd1b62ebfc8197bdb4e77d6395cbc3 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 00:16:51 +0900 Subject: [PATCH 121/482] print info --- src/diffusers/models/attention_processor.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ce7278755e4b..cc790e33a2c2 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2449,16 +2449,25 @@ def __call__( img_hidden_states = F.scaled_dot_product_attention( img_query, img_key, img_value, 
attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - print(f'txt_hidden_states shape={txt_hidden_states.shape}') - print(f'img_hidden_states shape={img_hidden_states.shape}') + print(f'txt_hidden_states before shape={txt_hidden_states.shape}') + print(f'img_hidden_states before shape={img_hidden_states.shape}') + #hidden_states = torch.cat([txt_hidden_states, img_hidden_states],dim=1) + + txt_hidden_states = txt_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + txt_hidden_states = txt_hidden_states.to(query.dtype) + + img_hidden_states = img_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + img_hidden_states = img_hidden_states.to(query.dtype) + + print(f'txt_hidden_states after shape={txt_hidden_states.shape}') + print(f'img_hidden_states after shape={img_hidden_states.shape}') hidden_states = torch.cat([txt_hidden_states, img_hidden_states],dim=1) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: encoder_hidden_states, hidden_states = ( From f49463ec51e1e34b64fb2a801546a35f7a629ab2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 00:44:16 +0900 Subject: [PATCH 122/482] add codes --- src/diffusers/models/attention_processor.py | 24 ++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cc790e33a2c2..19396a772c2c 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2357,6 +2357,8 @@ def __call__( encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + txt_mask: Optional[torch.Tensor] = None, # thesea modification for text prompt mask + img_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape @@ -2461,7 +2463,27 @@ def __call__( print(f'txt_hidden_states after shape={txt_hidden_states.shape}') print(f'img_hidden_states after shape={img_hidden_states.shape}') - hidden_states = torch.cat([txt_hidden_states, img_hidden_states],dim=1) + + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_mask, + batch_size, + hidden_states.shape[1], + hidden_states.shape[2], + ) + txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + + img_mask_downsample = IPAdapterMaskProcessor.downsample( + img_mask, + batch_size, + hidden_states.shape[1], + hidden_states.shape[2], + ) + img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + + masked_txt_hidden_states = txt_hidden_states[:,0:hidden_states.shape[1],:] * txt_mask_downsample + masked_img_hidden_states = img_hidden_states[:,0:hidden_states.shape[1],:] * img_mask_downsample + + hidden_states = torch.cat([masked_txt_hidden_states + masked_img_hidden_states, txt_hidden_states[:,hidden_states.shape[1]:,:], img_hidden_states[:,hidden_states.shape[1]:,:]],dim=1) else: hidden_states = 
F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From 1021c34bc28607d22e8002c2a935eb9e92f67bdc Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 12:28:29 +0900 Subject: [PATCH 123/482] print info --- src/diffusers/models/attention_processor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 19396a772c2c..ef969052ed9d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2480,10 +2480,13 @@ def __call__( ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - masked_txt_hidden_states = txt_hidden_states[:,0:hidden_states.shape[1],:] * txt_mask_downsample - masked_img_hidden_states = img_hidden_states[:,0:hidden_states.shape[1],:] * img_mask_downsample + masked_txt_hidden_states = txt_hidden_states[:,encoder_hidden_states.shape[1]:,:] * txt_mask_downsample + masked_img_hidden_states = img_hidden_states[:,encoder_hidden_states.shape[1]:,:] * img_mask_downsample - hidden_states = torch.cat([masked_txt_hidden_states + masked_img_hidden_states, txt_hidden_states[:,hidden_states.shape[1]:,:], img_hidden_states[:,hidden_states.shape[1]:,:]],dim=1) + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) + print(f'txt_hidden_states[:,:-hidden_states.shape[1],:] shape={txt_hidden_states[:,:-hidden_states.shape[1],:].shape}') + print(f'img_hidden_states[:,:-hidden_states.shape[1],:] shape={img_hidden_states[:,:-hidden_states.shape[1],:].shape}') + print(f'hidden_states shape={hidden_states.shape}') else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From ede38af9691b60ded0c9610782f30c0ed2278e7f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 16:44:30 +0900 Subject: [PATCH 124/482] add mask codes --- src/diffusers/models/attention_processor.py | 28 +++++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ef969052ed9d..839c7a6d1d5f 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2359,6 +2359,7 @@ def __call__( image_rotary_emb: Optional[torch.Tensor] = None, txt_mask: Optional[torch.Tensor] = None, # thesea modification for text prompt mask img_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask + ip_scale: Optional[int] = None, # thesea modification for ip mask ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape @@ -2464,26 +2465,31 @@ def __call__( print(f'txt_hidden_states after shape={txt_hidden_states.shape}') print(f'img_hidden_states after shape={img_hidden_states.shape}') - txt_mask_downsample = IPAdapterMaskProcessor.downsample( - txt_mask, - batch_size, - hidden_states.shape[1], - hidden_states.shape[2], - ) - txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - + if txt_mask is not None: + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_mask[0], + batch_size, + hidden_states.shape[1], + hidden_states.shape[2], + ) + txt_mask_downsample = 
txt_mask_downsample.to(dtype=query.dtype, device=query.device) + masked_txt_hidden_states = txt_hidden_states[:,encoder_hidden_states.shape[1]:,:] * txt_mask_downsample + img_mask_downsample = IPAdapterMaskProcessor.downsample( - img_mask, + img_mask[0], batch_size, hidden_states.shape[1], hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - masked_txt_hidden_states = txt_hidden_states[:,encoder_hidden_states.shape[1]:,:] * txt_mask_downsample masked_img_hidden_states = img_hidden_states[:,encoder_hidden_states.shape[1]:,:] * img_mask_downsample - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) + if txt_mask is not None: + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + ip_scale * masked_img_hidden_states],dim=1) + else: + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], txt_hidden_states[:,encoder_hidden_states.shape[1]:,:] + ip_scale * masked_img_hidden_states],dim=1) + print(f'txt_hidden_states[:,:-hidden_states.shape[1],:] shape={txt_hidden_states[:,:-hidden_states.shape[1],:].shape}') print(f'img_hidden_states[:,:-hidden_states.shape[1],:] shape={img_hidden_states[:,:-hidden_states.shape[1],:].shape}') print(f'hidden_states shape={hidden_states.shape}') From cb6ec8b6486b5b217a0c5da6abc8ef84b3dc1f10 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 17:12:58 +0900 Subject: [PATCH 125/482] print info --- src/diffusers/models/transformers/transformer_flux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 87537890d246..45b1c46f7639 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -147,6 +147,7 @@ def forward( encoder_hidden_states, emb=temb ) joint_attention_kwargs = joint_attention_kwargs or {} + print(f'joint_attention_kwargs={joint_attention_kwargs.keys()}') # Attention. 
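
The blend introduced in PATCH 124 keeps the text-stream result where the text mask is active and adds ip_scale times the image-stream result where the image mask is active. A condensed sketch of that blend, reusing the existing IPAdapterMaskProcessor.downsample helper; the tensor names and the [batch, height, width] mask shape are illustrative assumptions:

    from diffusers.image_processor import IPAdapterMaskProcessor

    def blend_latent_streams(txt_lat, img_lat, txt_mask, img_mask, ip_scale, batch_size):
        # txt_lat / img_lat: [batch, latent_tokens, dim] latent slices of the two streams.
        num_queries, embed_dim = txt_lat.shape[1], txt_lat.shape[2]
        txt_m = IPAdapterMaskProcessor.downsample(txt_mask, batch_size, num_queries, embed_dim)
        img_m = IPAdapterMaskProcessor.downsample(img_mask, batch_size, num_queries, embed_dim)
        txt_m = txt_m.to(dtype=txt_lat.dtype, device=txt_lat.device)
        img_m = img_m.to(dtype=img_lat.dtype, device=img_lat.device)
        # Text result inside the text region, scaled image result inside the IP region.
        return txt_lat * txt_m + ip_scale * (img_lat * img_m)
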
attention_outputs = self.attn( hidden_states=norm_hidden_states, From 7d4effd77e6e795d98270b38629d8cad444ae585 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 17:16:09 +0900 Subject: [PATCH 126/482] print info --- src/diffusers/models/controlnets/controlnet_flux.py | 2 ++ src/diffusers/models/transformers/transformer_flux.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 51c34b7fe965..698c6fc64c30 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -335,11 +335,13 @@ def forward( ) else: + print(f'controlnet_flux joint_attention_kwargs={joint_attention_kwargs.keys()}') encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, ) block_samples = block_samples + (hidden_states,) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 45b1c46f7639..637f8bf6910d 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -147,7 +147,7 @@ def forward( encoder_hidden_states, emb=temb ) joint_attention_kwargs = joint_attention_kwargs or {} - print(f'joint_attention_kwargs={joint_attention_kwargs.keys()}') + print(f'transformer_flux joint_attention_kwargs={joint_attention_kwargs.keys()}') # Attention. attention_outputs = self.attn( hidden_states=norm_hidden_states, From 04d77e2b43b9a373092da10e5ad8b01ce01e5cf4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 17:20:06 +0900 Subject: [PATCH 127/482] fix bug --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 839c7a6d1d5f..f27b67a00efb 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2473,7 +2473,7 @@ def __call__( hidden_states.shape[2], ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - masked_txt_hidden_states = txt_hidden_states[:,encoder_hidden_states.shape[1]:,:] * txt_mask_downsample + masked_txt_hidden_states = txt_hidden_states[:,512,:] * txt_mask_downsample img_mask_downsample = IPAdapterMaskProcessor.downsample( img_mask[0], @@ -2483,7 +2483,7 @@ def __call__( ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - masked_img_hidden_states = img_hidden_states[:,encoder_hidden_states.shape[1]:,:] * img_mask_downsample + masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample if txt_mask is not None: hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + ip_scale * masked_img_hidden_states],dim=1) From 75cafe271e2595f3a8eb8a5906da3f16479decb8 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 17:20:57 +0900 Subject: [PATCH 128/482] fix bug --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 
f27b67a00efb..5ea538c2ad0f 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2488,7 +2488,7 @@ def __call__(
                 if txt_mask is not None:
                     hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + ip_scale * masked_img_hidden_states],dim=1)
                 else:
-                    hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], txt_hidden_states[:,encoder_hidden_states.shape[1]:,:] + ip_scale * masked_img_hidden_states],dim=1)
+                    hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], txt_hidden_states[:,512:,:] + ip_scale * masked_img_hidden_states],dim=1)
 
                 print(f'txt_hidden_states[:,:-hidden_states.shape[1],:] shape={txt_hidden_states[:,:-hidden_states.shape[1],:].shape}')
                 print(f'img_hidden_states[:,:-hidden_states.shape[1],:] shape={img_hidden_states[:,:-hidden_states.shape[1],:].shape}')

From 6f1ece8077f842f01e24fd98dc74eb56e232d3d7 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sat, 22 Mar 2025 17:23:50 +0900
Subject: [PATCH 129/482] fix bug

---
 src/diffusers/models/attention_processor.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 5ea538c2ad0f..9b4150b94378 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2436,14 +2436,18 @@ def __call__(
 
         if image_rotary_emb is not None:
             from .embeddings import apply_rotary_emb
-
-            cos, sin = image_rotary_emb
-            #print(f'cos shape={cos.shape}, sin shape={sin.shape}')
-            txt_query = apply_rotary_emb(txt_query, (cos[0:4608,:], sin[0:4608,:])) #image_rotary_emb
-            txt_key = apply_rotary_emb(txt_key, (cos[0:4608,:], sin[0:4608,:]))
-
-            img_query = apply_rotary_emb(img_query, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0))) #image_rotary_emb
-            img_key = apply_rotary_emb(img_key, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0)))
+
+            if encoder_hidden_states is not None:
+                cos, sin = image_rotary_emb
+                #print(f'cos shape={cos.shape}, sin shape={sin.shape}')
+                txt_query = apply_rotary_emb(txt_query, (cos[0:4608,:], sin[0:4608,:])) #image_rotary_emb
+                txt_key = apply_rotary_emb(txt_key, (cos[0:4608,:], sin[0:4608,:]))
+
+                img_query = apply_rotary_emb(img_query, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0))) #image_rotary_emb
+                img_key = apply_rotary_emb(img_key, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0)))
+            else:
+                query = apply_rotary_emb(query, image_rotary_emb)
+                key = apply_rotary_emb(key, image_rotary_emb)
 
         if encoder_hidden_states is not None:
             txt_hidden_states = F.scaled_dot_product_attention(

From 0ebdaae0a9fbd34464a885f130faafef3abdac02 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sat, 22 Mar 2025 17:28:22 +0900
Subject: [PATCH 130/482] print info

---
 src/diffusers/models/attention_processor.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 9b4150b94378..1d553b628d5a 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ 
-2489,6 +2489,7 @@ def __call__( masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample + print(f'hidden_states.shape={hidden_states.shape}') if txt_mask is not None: hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + ip_scale * masked_img_hidden_states],dim=1) else: From e4070d83a6afc5328016ae56c7fd6f12cbebc8d2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 17:29:16 +0900 Subject: [PATCH 131/482] print info --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1d553b628d5a..06fe05e9a72d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2490,13 +2490,13 @@ def __call__( masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample print(f'hidden_states.shape={hidden_states.shape}') + print(f'txt_hidden_states[:,:-hidden_states.shape[1],:] shape={txt_hidden_states[:,:-hidden_states.shape[1],:].shape}') + print(f'img_hidden_states[:,:-hidden_states.shape[1],:] shape={img_hidden_states[:,:-hidden_states.shape[1],:].shape}') if txt_mask is not None: hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + ip_scale * masked_img_hidden_states],dim=1) else: hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], txt_hidden_states[:,512:,:] + ip_scale * masked_img_hidden_states],dim=1) - print(f'txt_hidden_states[:,:-hidden_states.shape[1],:] shape={txt_hidden_states[:,:-hidden_states.shape[1],:].shape}') - print(f'img_hidden_states[:,:-hidden_states.shape[1],:] shape={img_hidden_states[:,:-hidden_states.shape[1],:].shape}') print(f'hidden_states shape={hidden_states.shape}') else: hidden_states = F.scaled_dot_product_attention( From 6b4f83ec96d29bc263d5e440c8206c5a161cc3e7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 17:37:30 +0900 Subject: [PATCH 132/482] fix bug --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 06fe05e9a72d..c5226cde4f46 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2477,7 +2477,7 @@ def __call__( hidden_states.shape[2], ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - masked_txt_hidden_states = txt_hidden_states[:,512,:] * txt_mask_downsample + masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample img_mask_downsample = IPAdapterMaskProcessor.downsample( img_mask[0], From 9de33ac3d7ff04ed818009fd548d4ee3e9a933a5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 17:56:58 +0900 Subject: [PATCH 133/482] remove print --- src/diffusers/models/attention_processor.py | 123 ++++++++++-------- .../models/controlnets/controlnet_flux.py | 1 - .../models/transformers/transformer_flux.py | 1 - .../flux/pipeline_flux_prior_redux.py | 6 +- 4 files changed, 67 insertions(+), 64 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 
c5226cde4f46..7455bb88fd3c 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2359,7 +2359,7 @@ def __call__( image_rotary_emb: Optional[torch.Tensor] = None, txt_mask: Optional[torch.Tensor] = None, # thesea modification for text prompt mask img_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask - ip_scale: Optional[int] = None, # thesea modification for ip mask + ip_scale: Optional[float] = None, # thesea modification for ip mask ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape @@ -2370,8 +2370,7 @@ def __call__( inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - print(f'attn.heads = {attn.heads}') - + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) @@ -2383,56 +2382,76 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: - # `context` projections. - print(f'encoder_hidden_states shape={encoder_hidden_states.shape}') - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,0:512,:]) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,0:512,:]) + if img_mask is not None + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,0:512,:]) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,0:512,:]) + + img_encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,512:,:]) + img_encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,512:,:]) + img_encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,512:,:]) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) - img_encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,512:,:]) - img_encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,512:,:]) - img_encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,512:,:]) - print(f'encoder_hidden_states_query_proj before view shape={encoder_hidden_states_query_proj.shape}') + img_encoder_hidden_states_query_proj = img_encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + img_encoder_hidden_states_key_proj = img_encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + img_encoder_hidden_states_value_proj = img_encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - 
).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + img_encoder_hidden_states_query_proj = attn.norm_added_q(img_encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + img_encoder_hidden_states_key_proj = attn.norm_added_k(img_encoder_hidden_states_key_proj) - img_encoder_hidden_states_query_proj = img_encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - img_encoder_hidden_states_key_proj = img_encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - img_encoder_hidden_states_value_proj = img_encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) + # attention + txt_query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + txt_key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + txt_value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - print(f'encoder_hidden_states_query_proj after view shape={encoder_hidden_states_query_proj.shape}') + img_query = torch.cat([img_encoder_hidden_states_query_proj, query], dim=2) + img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) + img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) + else: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - img_encoder_hidden_states_query_proj = attn.norm_added_q(img_encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - img_encoder_hidden_states_key_proj = attn.norm_added_k(img_encoder_hidden_states_key_proj) + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) - # attention - print(f'query before cat shape={query.shape}') - txt_query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - txt_key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - txt_value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - img_query = torch.cat([img_encoder_hidden_states_query_proj, query], dim=2) - img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) - img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) - print(f'query after cat shape={query.shape}') + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = 
torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: from .embeddings import apply_rotary_emb @@ -2456,19 +2475,13 @@ def __call__( img_hidden_states = F.scaled_dot_product_attention( img_query, img_key, img_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - print(f'txt_hidden_states before shape={txt_hidden_states.shape}') - print(f'img_hidden_states before shape={img_hidden_states.shape}') - #hidden_states = torch.cat([txt_hidden_states, img_hidden_states],dim=1) - + txt_hidden_states = txt_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) txt_hidden_states = txt_hidden_states.to(query.dtype) img_hidden_states = img_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) img_hidden_states = img_hidden_states.to(query.dtype) - print(f'txt_hidden_states after shape={txt_hidden_states.shape}') - print(f'img_hidden_states after shape={img_hidden_states.shape}') - if txt_mask is not None: txt_mask_downsample = IPAdapterMaskProcessor.downsample( txt_mask[0], @@ -2489,15 +2502,11 @@ def __call__( masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - print(f'hidden_states.shape={hidden_states.shape}') - print(f'txt_hidden_states[:,:-hidden_states.shape[1],:] shape={txt_hidden_states[:,:-hidden_states.shape[1],:].shape}') - print(f'img_hidden_states[:,:-hidden_states.shape[1],:] shape={img_hidden_states[:,:-hidden_states.shape[1],:].shape}') if txt_mask is not None: hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + ip_scale * masked_img_hidden_states],dim=1) else: hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], txt_hidden_states[:,512:,:] + ip_scale * masked_img_hidden_states],dim=1) - print(f'hidden_states shape={hidden_states.shape}') else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 698c6fc64c30..dd65fbc6d8f4 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -335,7 +335,6 @@ def forward( ) else: - print(f'controlnet_flux joint_attention_kwargs={joint_attention_kwargs.keys()}') encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 637f8bf6910d..87537890d246 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -147,7 +147,6 @@ def forward( encoder_hidden_states, emb=temb ) joint_attention_kwargs = joint_attention_kwargs or {} - print(f'transformer_flux joint_attention_kwargs={joint_attention_kwargs.keys()}') # Attention. 
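
After this cleanup the two stream outputs still carry their context tokens in front of the latent tokens: the text stream is [512 text tokens, latent tokens] and the image stream is [729 Redux tokens, latent tokens], which is why the blend slices txt_hidden_states[:,512:,:] and img_hidden_states[:,729:,:] to get the latent halves. A sketch of that slicing, with the counts carried over as assumptions from the fixed token layout:

    def split_stream_outputs(txt_hidden_states, img_hidden_states,
                             t5_tokens=512, redux_tokens=729):
        # Each stream output: [batch, context_tokens + latent_tokens, dim].
        txt_ctx_out = txt_hidden_states[:, :t5_tokens, :]
        img_ctx_out = img_hidden_states[:, :redux_tokens, :]
        txt_latent_out = txt_hidden_states[:, t5_tokens:, :]
        img_latent_out = img_hidden_states[:, redux_tokens:, :]
        return txt_ctx_out, img_ctx_out, txt_latent_out, img_latent_out
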
attention_outputs = self.attn( hidden_states=norm_hidden_states, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index f8a3d76c3678..4d28cf2ab004 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -472,22 +472,18 @@ def __call__( # pooled_prompt_embeds is 768, clip text encoder hidden size pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) - print(f'prompt_embeds shape before cat={prompt_embeds.shape}') # scale & concatenate image and text embeddings prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) - print(f'prompt_embeds shape after cat={prompt_embeds.shape}') - + prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[ :, None ] - print(f'prompt_embeds shape before sum={prompt_embeds.shape}') # weighted sum prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True) pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True) - print(f'prompt_embeds shape after sum={prompt_embeds.shape}') # Offload all models self.maybe_free_model_hooks() From 81e7670b1dd8aa8a26615bda6f7c34684cd60d6e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 18:01:18 +0900 Subject: [PATCH 134/482] fix bug --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7455bb88fd3c..d16a718e7a9f 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2382,7 +2382,7 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: - if img_mask is not None + if img_mask is not None: # `context` projections. encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,0:512,:]) From 4ce02fd2115ff21a56ec0931a347042960afdf2c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 18:02:57 +0900 Subject: [PATCH 135/482] remove print --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 4d28cf2ab004..575fa25f8b7a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -443,8 +443,7 @@ def __call__( image_embeds = self.image_embedder(image_latents).image_embeds image_embeds = image_embeds.to(device=device) - print(f'image_embeds shape={image_embeds.shape}') - + # 3. 
Prepare (dummy) text embeddings if hasattr(self, "text_encoder") and self.text_encoder is not None: ( From aebe334ea5ed4b54fd5e2dbe37ac671637d73b9a Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 18:10:05 +0900 Subject: [PATCH 136/482] fix bug --- src/diffusers/models/attention_processor.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d16a718e7a9f..878e038884e6 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2457,13 +2457,17 @@ def __call__( from .embeddings import apply_rotary_emb if encoder_hidden_states is not None: - cos, sin = image_rotary_emb - #print(f'cos shape={cos.shape}, sin shape={sin.shape}') - txt_query = apply_rotary_emb(txt_query, (cos[0:4608,:], sin[0:4608,:])) #image_rotary_emb - txt_key = apply_rotary_emb(txt_key, (cos[0:4608,:], sin[0:4608,:])) - - img_query = apply_rotary_emb(img_query, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0))) #image_rotary_emb - img_key = apply_rotary_emb(img_key, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0))) + if img_mask is not None: + cos, sin = image_rotary_emb + #print(f'cos shape={cos.shape}, sin shape={sin.shape}') + txt_query = apply_rotary_emb(txt_query, (cos[0:4608,:], sin[0:4608,:])) #image_rotary_emb + txt_key = apply_rotary_emb(txt_key, (cos[0:4608,:], sin[0:4608,:])) + + img_query = apply_rotary_emb(img_query, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0))) #image_rotary_emb + img_key = apply_rotary_emb(img_key, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0))) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) else: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) From 24faa4345f3cf24a1c4a725c964088bc5156c1e1 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 18:13:49 +0900 Subject: [PATCH 137/482] fix bug --- src/diffusers/models/attention_processor.py | 74 +++++++++++---------- 1 file changed, 40 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 878e038884e6..3a20d37aa79b 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2473,44 +2473,50 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb) if encoder_hidden_states is not None: - txt_hidden_states = F.scaled_dot_product_attention( - txt_query, txt_key, txt_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - img_hidden_states = F.scaled_dot_product_attention( - img_query, img_key, img_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - txt_hidden_states = txt_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - txt_hidden_states = txt_hidden_states.to(query.dtype) - - img_hidden_states = img_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - img_hidden_states = img_hidden_states.to(query.dtype) - - if txt_mask is not None: - txt_mask_downsample = IPAdapterMaskProcessor.downsample( - txt_mask[0], + if img_mask is not None: + txt_hidden_states = F.scaled_dot_product_attention( + 
txt_query, txt_key, txt_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + img_hidden_states = F.scaled_dot_product_attention( + img_query, img_key, img_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + txt_hidden_states = txt_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + txt_hidden_states = txt_hidden_states.to(query.dtype) + + img_hidden_states = img_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + img_hidden_states = img_hidden_states.to(query.dtype) + + if txt_mask is not None: + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_mask[0], + batch_size, + hidden_states.shape[1], + hidden_states.shape[2], + ) + txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample + + img_mask_downsample = IPAdapterMaskProcessor.downsample( + img_mask[0], batch_size, hidden_states.shape[1], hidden_states.shape[2], - ) - txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample - - img_mask_downsample = IPAdapterMaskProcessor.downsample( - img_mask[0], - batch_size, - hidden_states.shape[1], - hidden_states.shape[2], - ) - img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - - masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - - if txt_mask is not None: - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + ip_scale * masked_img_hidden_states],dim=1) + ) + img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + + masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample + + if txt_mask is not None: + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + ip_scale * masked_img_hidden_states],dim=1) + else: + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], txt_hidden_states[:,512:,:] + ip_scale * masked_img_hidden_states],dim=1) else: - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], txt_hidden_states[:,512:,:] + ip_scale * masked_img_hidden_states],dim=1) - + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From 72a60a2324927e62b3b222158754da0c5c1b0cc7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 20:56:58 +0900 Subject: [PATCH 138/482] print info --- src/diffusers/models/attention_processor.py | 40 ++++++++++----------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 3a20d37aa79b..54f9238ffb89 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2357,12 +2357,15 @@ def __call__( 
encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - txt_mask: Optional[torch.Tensor] = None, # thesea modification for text prompt mask img_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask - ip_scale: Optional[float] = None, # thesea modification for ip mask ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + if img_mask is not None: + txt_mask = ~ img_mask + else: + txt_mask = None + # `sample` projections. query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) @@ -2460,11 +2463,11 @@ def __call__( if img_mask is not None: cos, sin = image_rotary_emb #print(f'cos shape={cos.shape}, sin shape={sin.shape}') - txt_query = apply_rotary_emb(txt_query, (cos[0:4608,:], sin[0:4608,:])) #image_rotary_emb - txt_key = apply_rotary_emb(txt_key, (cos[0:4608,:], sin[0:4608,:])) + txt_query = apply_rotary_emb(txt_query, (torch.cat(cos[0:512,:], cos[1241:,:]), torch.cat(sin[0:512,:], sin[1241:,:]))) + txt_key = apply_rotary_emb(txt_key, (torch.cat(cos[0:512,:], cos[1241:,:]), torch.cat(sin[0:512,:], sin[1241:,:]))) - img_query = apply_rotary_emb(img_query, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0))) #image_rotary_emb - img_key = apply_rotary_emb(img_key, (torch.cat([cos[0:4096,:], cos[4608:,:]], dim=0), torch.cat([sin[0:4096,:], sin[4608:,:]], dim=0))) + img_query = apply_rotary_emb(img_query, (cos[512:,:],sin[512:,:])) + img_key = apply_rotary_emb(img_key, (cos[512:,:],sin[512:,:])) else: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) @@ -2487,15 +2490,14 @@ def __call__( img_hidden_states = img_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) img_hidden_states = img_hidden_states.to(query.dtype) - if txt_mask is not None: - txt_mask_downsample = IPAdapterMaskProcessor.downsample( - txt_mask[0], - batch_size, - hidden_states.shape[1], - hidden_states.shape[2], - ) - txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_mask[0], + batch_size, + hidden_states.shape[1], + hidden_states.shape[2], + ) + txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample img_mask_downsample = IPAdapterMaskProcessor.downsample( img_mask[0], @@ -2504,13 +2506,11 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - if txt_mask is not None: - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + ip_scale * masked_img_hidden_states],dim=1) - else: - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], txt_hidden_states[:,512:,:] + ip_scale * masked_img_hidden_states],dim=1) + print(f'txt_hidden_states[:,:-hidden_states.shape[1],:] shape={txt_hidden_states[:,:-hidden_states.shape[1],:].shape}') + print(f'img_hidden_states[:,:-hidden_states.shape[1],:] 
shape={img_hidden_states[:,:-hidden_states.shape[1],:].shape}') + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From a71be65f68b0a024afb83c1f467f8477cdf20687 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 21:03:29 +0900 Subject: [PATCH 139/482] fix bug --- src/diffusers/models/attention_processor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 54f9238ffb89..186200dfadfb 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -15,6 +15,7 @@ import math from typing import Callable, List, Optional, Tuple, Union import copy +import numpy as np import torch import torch.nn.functional as F @@ -2362,7 +2363,8 @@ def __call__( batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape if img_mask is not None: - txt_mask = ~ img_mask + txt_mask = ~ np.array(img_mask, dtype=bool) + txt_mask = txt_mask.astype(np.uint8) else: txt_mask = None From 4502f94895c4c74f1d5b719872d6fe7b548b8883 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 21:06:34 +0900 Subject: [PATCH 140/482] fix bug --- src/diffusers/models/attention_processor.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 186200dfadfb..16cf57a1094a 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2359,15 +2359,10 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, img_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask + txt_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - if img_mask is not None: - txt_mask = ~ np.array(img_mask, dtype=bool) - txt_mask = txt_mask.astype(np.uint8) - else: - txt_mask = None - # `sample` projections. 
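# PATCH 138-140 above try, then abandon, deriving `txt_mask` inside the
# processor by inverting `img_mask`: bitwise `~` only works on bool/integer
# tensors, and the numpy round-trip in PATCH 139 drops device and dtype.
# PATCH 140 settles on passing `txt_mask` in explicitly. For reference, a
# complement that stays in torch (a sketch, not the code the series adopts):
import torch

img_mask = torch.tensor([[0.0, 1.0, 1.0, 0.0]])  # float mask, 1 = image region
txt_mask = 1.0 - img_mask                        # complement without bool casts
txt_mask_bool = ~(img_mask > 0.5)                # or a strictly boolean complement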
query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) From 6d5d537ac3fbf174aaa0495f9694afbeee13ab15 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 21:12:25 +0900 Subject: [PATCH 141/482] fix bug --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 16cf57a1094a..a8010199097d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2460,8 +2460,8 @@ def __call__( if img_mask is not None: cos, sin = image_rotary_emb #print(f'cos shape={cos.shape}, sin shape={sin.shape}') - txt_query = apply_rotary_emb(txt_query, (torch.cat(cos[0:512,:], cos[1241:,:]), torch.cat(sin[0:512,:], sin[1241:,:]))) - txt_key = apply_rotary_emb(txt_key, (torch.cat(cos[0:512,:], cos[1241:,:]), torch.cat(sin[0:512,:], sin[1241:,:]))) + txt_query = apply_rotary_emb(txt_query, (torch.cat([cos[0:512,:], cos[1241:,:]], dim = 0), torch.cat([sin[0:512,:], sin[1241:,:]],dim=0))) + txt_key = apply_rotary_emb(txt_key, (torch.cat([cos[0:512,:], cos[1241:,:]], dim = 0), torch.cat([sin[0:512,:], sin[1241:,:]],dim=0))) img_query = apply_rotary_emb(img_query, (cos[512:,:],sin[512:,:])) img_key = apply_rotary_emb(img_key, (cos[512:,:],sin[512:,:])) From 15402a248eb0770c4e6ecd694377ef882ca3f962 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 22 Mar 2025 21:16:59 +0900 Subject: [PATCH 142/482] remove print --- src/diffusers/models/attention_processor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a8010199097d..a9d9d170d95a 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2505,8 +2505,6 @@ def __call__( img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - print(f'txt_hidden_states[:,:-hidden_states.shape[1],:] shape={txt_hidden_states[:,:-hidden_states.shape[1],:].shape}') - print(f'img_hidden_states[:,:-hidden_states.shape[1],:] shape={img_hidden_states[:,:-hidden_states.shape[1],:].shape}') hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) else: hidden_states = F.scaled_dot_product_attention( From fae24ab476335c5954b8181aeca3080b92c223d6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 11:14:03 +0900 Subject: [PATCH 143/482] print info --- src/diffusers/models/attention_processor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a9d9d170d95a..f6cdf54c600b 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2503,6 +2503,9 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + unique_vals = torch.unique(img_mask_downsample) + + print(f'img_mask unique_vals: {unique_vals}') masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + 
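# PATCH 141 above fixes a torch.cat misuse in the rotary-embedding slicing:
# `torch.cat(cos[0:512,:], cos[1241:,:])` passes a tensor where the `dim`
# argument belongs and raises a TypeError. torch.cat takes a list of tensors
# plus an explicit dim. A standalone sketch of the corrected form (the
# sequence lengths are illustrative):
import torch

cos = torch.randn(5337, 64)  # 512 text + 729 image + 4096 latent rows (assumed)
sin = torch.randn(5337, 64)
cos_txt = torch.cat([cos[0:512, :], cos[1241:, :]], dim=0)  # text + latent rows
sin_txt = torch.cat([sin[0:512, :], sin[1241:, :]], dim=0)
assert cos_txt.shape[0] == 512 + (5337 - 1241)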
masked_img_hidden_states],dim=1) From 92838ee1bc10ffabb66dfe5b9e2050e87903a2fa Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 11:28:30 +0900 Subject: [PATCH 144/482] remove interpolated values --- src/diffusers/models/attention_processor.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f6cdf54c600b..2a4a94cbf76d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2503,9 +2503,16 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + + unique_vals = torch.unique(img_mask_downsample) + print(f'img_mask before unique_vals: {unique_vals}') + + indices = (img_mask_downsample > 0) & (img_mask_downsample < 1) + img_mask_downsample[indices] = 0 + unique_vals = torch.unique(img_mask_downsample) + print(f'img_mask after unique_vals: {unique_vals}') - print(f'img_mask unique_vals: {unique_vals}') masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) From 438e4152231d34a1ec99711c59d935610eae271f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 11:30:45 +0900 Subject: [PATCH 145/482] remove nonzero and non-one values --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 2a4a94cbf76d..1bb512ebf277 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2507,7 +2507,7 @@ def __call__( unique_vals = torch.unique(img_mask_downsample) print(f'img_mask before unique_vals: {unique_vals}') - indices = (img_mask_downsample > 0) & (img_mask_downsample < 1) + indices = (img_mask_downsample != 0.0) & (img_mask_downsample != 1.0) img_mask_downsample[indices] = 0 unique_vals = torch.unique(img_mask_downsample) From 70fcd4a7f794dc69b16c7c88bc20a2478b76622e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 11:49:28 +0900 Subject: [PATCH 146/482] 0.5 interval --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1bb512ebf277..d78dc058359a 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2507,8 +2507,8 @@ def __call__( unique_vals = torch.unique(img_mask_downsample) print(f'img_mask before unique_vals: {unique_vals}') - indices = (img_mask_downsample != 0.0) & (img_mask_downsample != 1.0) - img_mask_downsample[indices] = 0 + img_mask_downsample[img_mask_downsample <= 0.5] = 0 + img_mask_downsample[img_mask_downsample > 0.5] = 1.0 unique_vals = torch.unique(img_mask_downsample) print(f'img_mask after unique_vals: {unique_vals}') From ac761b15c517cdaaf24c70ff88db76c774e89826 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 12:25:13 +0900 Subject: [PATCH 147/482] introduce img_ratio --- src/diffusers/models/attention_processor.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git 
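# PATCH 144-146 above are three attempts at the same cleanup: bicubic
# downsampling turns a {0, 1} mask into a smear of intermediate values, and
# the series first zeroes everything strictly between 0 and 1, then settles
# on a 0.5 threshold. A compact sketch of that thresholding (pure torch,
# values illustrative):
import torch

mask = torch.tensor([[0.0, 0.13, 0.49, 0.51, 0.97, 1.0]])  # after interpolation
binary = torch.where(mask > 0.5, torch.ones_like(mask), torch.zeros_like(mask))
# equivalent to the in-place form PATCH 146 lands on:
mask[mask <= 0.5] = 0.0
mask[mask > 0.5] = 1.0
assert torch.equal(mask, binary)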
a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d78dc058359a..d93526c65fab 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2360,6 +2360,7 @@ def __call__( image_rotary_emb: Optional[torch.Tensor] = None, img_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask txt_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask + img_ratio: Optional[float] = None, # thesea modification for ip mask ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape @@ -2503,19 +2504,9 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - - unique_vals = torch.unique(img_mask_downsample) - print(f'img_mask before unique_vals: {unique_vals}') - - img_mask_downsample[img_mask_downsample <= 0.5] = 0 - img_mask_downsample[img_mask_downsample > 0.5] = 1.0 - - unique_vals = torch.unique(img_mask_downsample) - print(f'img_mask after unique_vals: {unique_vals}') - masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states * (1.0/img_ratio)+ masked_img_hidden_states * img_ratio],dim=1) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From b579d8dae4c4f62208470b1e3e12beb3278b9adb Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 12:36:12 +0900 Subject: [PATCH 148/482] img_ratio summed 2 --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d93526c65fab..38ac9d166491 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2506,7 +2506,7 @@ def __call__( img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states * (1.0/img_ratio)+ masked_img_hidden_states * img_ratio],dim=1) + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states * (2.0 - img_ratio)+ masked_img_hidden_states * img_ratio],dim=1) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From 79603aeea82ae5547fe9e5c3847def31f0c11a77 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 12:49:22 +0900 Subject: [PATCH 149/482] img_ratio applied mask --- src/diffusers/models/attention_processor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 38ac9d166491..70ab7b652b47 100755 --- a/src/diffusers/models/attention_processor.py +++ 
b/src/diffusers/models/attention_processor.py @@ -2495,7 +2495,10 @@ def __call__( hidden_states.shape[2], ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + txt_mask_downsample[txt_mask_downsample>=0.5] *= 1.0/img_ratio + masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample + img_mask_downsample = IPAdapterMaskProcessor.downsample( img_mask[0], @@ -2504,9 +2507,10 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + txt_mask_downsample[txt_mask_downsample<=0.5] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states * (2.0 - img_ratio)+ masked_img_hidden_states * img_ratio],dim=1) + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From a4f20da39749c50278081408996f8f490968f277 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 12:50:06 +0900 Subject: [PATCH 150/482] fix bug --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 70ab7b652b47..7203e2c48cc9 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2507,7 +2507,7 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - txt_mask_downsample[txt_mask_downsample<=0.5] *= img_ratio + img_mask_downsample[img_mask_downsample<=0.5] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) From a6d2d1b5cf364cbe39325a511d633fad154cdcd2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 12:53:43 +0900 Subject: [PATCH 151/482] img ratio summed 2 --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7203e2c48cc9..a7e11680326e 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2495,7 +2495,7 @@ def __call__( hidden_states.shape[2], ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - txt_mask_downsample[txt_mask_downsample>=0.5] *= 1.0/img_ratio + txt_mask_downsample[txt_mask_downsample>=0.5] *= (2.0 - img_ratio) masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample From 999174537da3b84af7082eb832336501a358ec4c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 13:07:00 +0900 Subject: [PATCH 152/482] fix bug --- src/diffusers/models/attention_processor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 
a7e11680326e..726366b5f0a6 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2496,7 +2496,6 @@ def __call__( ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) txt_mask_downsample[txt_mask_downsample>=0.5] *= (2.0 - img_ratio) - masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample @@ -2507,7 +2506,7 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - img_mask_downsample[img_mask_downsample<=0.5] *= img_ratio + img_mask_downsample[img_mask_downsample<0.5] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) From 960529e22c12f1d218f223540463cec0c2f96cc4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 13:12:31 +0900 Subject: [PATCH 153/482] fix bug --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 726366b5f0a6..dc864bcdb873 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2495,7 +2495,7 @@ def __call__( hidden_states.shape[2], ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - txt_mask_downsample[txt_mask_downsample>=0.5] *= (2.0 - img_ratio) + txt_mask_downsample[txt_mask_downsample<=0.5] *= (2.0 - img_ratio) masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample @@ -2506,7 +2506,7 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - img_mask_downsample[img_mask_downsample<0.5] *= img_ratio + img_mask_downsample[img_mask_downsample<=0.5] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) From f50709e10137b8303875aaff8ff03eac7ffa1ae2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 13:20:27 +0900 Subject: [PATCH 154/482] 1.0 --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index dc864bcdb873..9c29bd4e3566 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2495,7 +2495,7 @@ def __call__( hidden_states.shape[2], ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - txt_mask_downsample[txt_mask_downsample<=0.5] *= (2.0 - img_ratio) + txt_mask_downsample[txt_mask_downsample < 1.0] *= (2.0 - img_ratio) masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample @@ -2506,7 +2506,7 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - img_mask_downsample[img_mask_downsample<=0.5] *= img_ratio + img_mask_downsample[img_mask_downsample < 1.0] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * 
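# PATCH 147-154 above iterate on how `img_ratio` trades text influence against
# image influence in the masked region. The scheme that survives scales the
# text mask by (2.0 - img_ratio) and the image mask by img_ratio, so the two
# weights always sum to 2.0 and img_ratio = 1.0 is neutral; the `< 1.0` guard
# leaves fully-on mask entries untouched. A sketch of the weighting in
# isolation (shapes and values are illustrative):
import torch

def weight_masks(txt_mask, img_mask, img_ratio):
    txt_mask, img_mask = txt_mask.clone(), img_mask.clone()
    txt_mask[txt_mask < 1.0] *= (2.0 - img_ratio)
    img_mask[img_mask < 1.0] *= img_ratio
    return txt_mask, img_mask

t, i = weight_masks(torch.tensor([0.4, 1.0]), torch.tensor([0.6, 1.0]), 1.5)
# t -> [0.2, 1.0], i -> [0.9, 1.0]: image regions gain weight, text regions lose it.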
img_mask_downsample hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) From 18ae9168e483297add5e2ace3fcb7bc6be9bf00f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 13:28:07 +0900 Subject: [PATCH 155/482] max value --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9c29bd4e3566..3041c2de094a 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2495,7 +2495,7 @@ def __call__( hidden_states.shape[2], ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - txt_mask_downsample[txt_mask_downsample < 1.0] *= (2.0 - img_ratio) + txt_mask_downsample[txt_mask_downsample < torch.max(txt_mask_downsample)] *= (2.0 - img_ratio) masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample @@ -2506,7 +2506,7 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - img_mask_downsample[img_mask_downsample < 1.0] *= img_ratio + img_mask_downsample[img_mask_downsample < torch.max(img_mask_downsample)] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) From 587394de9d253a21f31d65bf6b1e6b6cc331022b Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 13:34:18 +0900 Subject: [PATCH 156/482] revert to 1.0 --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 3041c2de094a..af1e725220a7 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2495,7 +2495,7 @@ def __call__( hidden_states.shape[2], ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - txt_mask_downsample[txt_mask_downsample < torch.max(txt_mask_downsample)] *= (2.0 - img_ratio) + txt_mask_downsample[txt_mask_downsample <= 1.0] *= (2.0 - img_ratio) masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample @@ -2506,7 +2506,7 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - img_mask_downsample[img_mask_downsample < torch.max(img_mask_downsample)] *= img_ratio + img_mask_downsample[img_mask_downsample <= 1.0] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) From 78819b511518d5d602961d25cee910965b4311dc Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 13:37:51 +0900 Subject: [PATCH 157/482] 1.0 not included --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 
af1e725220a7..9c29bd4e3566 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2495,7 +2495,7 @@ def __call__( hidden_states.shape[2], ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - txt_mask_downsample[txt_mask_downsample <= 1.0] *= (2.0 - img_ratio) + txt_mask_downsample[txt_mask_downsample < 1.0] *= (2.0 - img_ratio) masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample @@ -2506,7 +2506,7 @@ def __call__( hidden_states.shape[2], ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - img_mask_downsample[img_mask_downsample <= 1.0] *= img_ratio + img_mask_downsample[img_mask_downsample < 1.0] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) From f18f37a9ffab29b8f415505207eefb726374a837 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 16:19:32 +0900 Subject: [PATCH 158/482] nearest --- src/diffusers/image_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index d6913f045ad2..c24f3aca6759 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -1182,7 +1182,7 @@ def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embe mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0) mask_w = num_queries // mask_h - mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0) + mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="nearest").squeeze(0) # Repeat batch_size times if mask_downsample.shape[0] < batch_size: From 7f463688d6bc271822b3d6405bf5bc2a5951c7a9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 17:00:24 +0900 Subject: [PATCH 159/482] print info --- src/diffusers/models/attention_processor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9c29bd4e3566..eb64ac8b628a 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2508,7 +2508,8 @@ def __call__( img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) img_mask_downsample[img_mask_downsample < 1.0] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - + print(f'unique_vals={torch.unique(img_mask_downsample)}') + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) else: hidden_states = F.scaled_dot_product_attention( From cb36eb9c2b975e3319ab2c5df3029fc3d274f8f9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 17:08:55 +0900 Subject: [PATCH 160/482] print info --- src/diffusers/models/attention_processor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index eb64ac8b628a..05119e2ad4cd 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2509,6 +2509,8 
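# PATCH 158 above swaps IPAdapterMaskProcessor.downsample from bicubic to
# nearest interpolation: nearest keeps a binary mask binary, which is what the
# thresholding experiments were compensating for (PATCH 167 later reverts to
# bicubic during cleanup). A sketch of the difference (sizes illustrative):
import torch
import torch.nn.functional as F

mask = (torch.rand(1, 1, 64, 64) > 0.5).float()
near = F.interpolate(mask, size=(27, 27), mode="nearest")
cubic = F.interpolate(mask, size=(27, 27), mode="bicubic")
print(torch.unique(near))               # tensor([0., 1.]): still binary
print(torch.unique(cubic).numel() > 2)  # True: in-between (and overshoot) values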
@@ def __call__( img_mask_downsample[img_mask_downsample < 1.0] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample print(f'unique_vals={torch.unique(img_mask_downsample)}') + print(f'num_zeros = {(img_mask_downsample == 0.0).sum()}') + print(f'num_ones = {(img_mask_downsample == 1.0).sum()}') hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) else: From ac9b0ecee62fa0d52fee389b451d40fa90b8d2b3 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 17:20:15 +0900 Subject: [PATCH 161/482] mask 2 --- src/diffusers/models/attention_processor.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 05119e2ad4cd..d6f58a0cc958 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2505,14 +2505,19 @@ def __call__( hidden_states.shape[1], hidden_states.shape[2], ) + + img_mask_downsample_2 = IPAdapterMaskProcessor.downsample( + img_mask[0], + batch_size, + 729, + hidden_states.shape[2], + ) + img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) img_mask_downsample[img_mask_downsample < 1.0] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - print(f'unique_vals={torch.unique(img_mask_downsample)}') - print(f'num_zeros = {(img_mask_downsample == 0.0).sum()}') - print(f'num_ones = {(img_mask_downsample == 1.0).sum()}') - - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) + + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:] * img_mask_downsample_2, masked_txt_hidden_states + masked_img_hidden_states],dim=1) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From d55eafb45d4f7dc677a9c147e7b92920b4eb6870 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 17:24:02 +0900 Subject: [PATCH 162/482] fix bug --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d6f58a0cc958..7a4529755851 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2514,6 +2514,7 @@ def __call__( ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + img_mask_downsample_2 = img_mask_downsample_2.to(dtype=query.dtype, device=query.device) img_mask_downsample[img_mask_downsample < 1.0] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample From a89f2fcebe02c4731fcd32823e79e136f06af8b5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 21:41:36 +0900 Subject: [PATCH 163/482] introduce product ratio --- .../flux/pipeline_flux_prior_redux.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 575fa25f8b7a..6cf360b68a8a 100644 --- 
a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -14,6 +14,7 @@ from typing import List, Optional, Union +import numpy as np import torch from PIL import Image @@ -379,6 +380,7 @@ def __call__( pooled_prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, + product_ratio: Optional[float] = None, return_dict: bool = True, ): r""" @@ -426,7 +428,7 @@ def __call__( if image is not None and isinstance(image, Image.Image): batch_size = 1 elif image is not None and isinstance(image, list): - batch_size = len(image) + batch_size = 1 #len(image) else: batch_size = image.shape[0] if prompt is not None and isinstance(prompt, str): @@ -439,7 +441,27 @@ def __call__( device = self._execution_device # 3. Prepare image embeddings - image_latents = self.encode_image(image, device, 1) + if isinstance(image, list) and product_ratio is not None: + image[1] = image[1].convert('RGBA') + + rgba_np = np.array(image[1]) + product_mask = rgba_np[:, :, 3] + product_mask = product_mask > 0 + product_mask = np.stack((product_mask,)*3, axis=-1) + + img = Image.new('RGBA', image[1].size, (255, 255, 255, 255)) + img.paste(image[1], mask=image[1].split()[3]) + img = img.convert('RGB') + product_image_array = np.asarray(img) + + background_image_array = np.asarray(image[0].convert("RGB")) + background_mask = ~product_mask + + composed_image = product_image_array * product_mask * product_ratio + background_image_array * product_mask * (1.0 - product_ratio) + background_image_array * background_mask + composed_image = Image.fromarray(composed_image.astype(np.uint8)) + image_latents = self.encode_image(composed_image, device, 1) + else: + image_latents = self.encode_image(image, device, 1) image_embeds = self.image_embedder(image_latents).image_embeds image_embeds = image_embeds.to(device=device) From f1d431e6658040bc1032f4cba0673a07e5c4e7ec Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 21:47:01 +0900 Subject: [PATCH 164/482] print info --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 6cf360b68a8a..b9e1541d13d2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -464,6 +464,7 @@ def __call__( image_latents = self.encode_image(image, device, 1) image_embeds = self.image_embedder(image_latents).image_embeds + print(f'image_embeds shape={image_embeds.shape}') image_embeds = image_embeds.to(device=device) # 3. 
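# PATCH 163 above composes the Redux conditioning image itself: the product
# cutout is an RGBA image whose alpha channel defines the product region, and
# `product_ratio` cross-fades product pixels against the background inside
# that region. A trimmed sketch of the same composition (PIL + numpy; assumes
# both images already share a size, which PATCH 168 later enforces by resizing):
import numpy as np
from PIL import Image

def compose(background: Image.Image, product_rgba: Image.Image, product_ratio: float) -> Image.Image:
    rgba = np.array(product_rgba.convert('RGBA'))
    mask = np.stack([rgba[:, :, 3] > 0] * 3, axis=-1)           # product region
    product = np.asarray(product_rgba.convert('RGB')).astype(np.float64)
    bg = np.asarray(background.convert('RGB')).astype(np.float64)
    out = product * mask * product_ratio + bg * mask * (1.0 - product_ratio) + bg * ~mask
    return Image.fromarray(out.astype(np.uint8))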
Prepare (dummy) text embeddings

From 53a54c68ab03c778488ddc072368bcd320d9b674 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sun, 23 Mar 2025 21:49:50 +0900
Subject: [PATCH 165/482] print info

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index b9e1541d13d2..5a26ae13b27e 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -464,7 +464,7 @@ def __call__(
             image_latents = self.encode_image(image, device, 1)

         image_embeds = self.image_embedder(image_latents).image_embeds
-        print(f'image_embeds shape={image_embeds.shape}')
+        print(f'image_embeds shape={image_embeds.shape}, batch_size={batch_size}')
         image_embeds = image_embeds.to(device=device)

         # 3. Prepare (dummy) text embeddings

From 5223148e1820094d25f5953cfe611598b19e2945 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sun, 23 Mar 2025 22:33:59 +0900
Subject: [PATCH 166/482] fix bug

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 5a26ae13b27e..66b134628104 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -461,10 +461,10 @@ def __call__(
             composed_image = Image.fromarray(composed_image.astype(np.uint8))
             image_latents = self.encode_image(composed_image, device, 1)
         else:
+            image = image.convert('RGB')
             image_latents = self.encode_image(image, device, 1)

         image_embeds = self.image_embedder(image_latents).image_embeds
-        print(f'image_embeds shape={image_embeds.shape}, batch_size={batch_size}')
         image_embeds = image_embeds.to(device=device)

         # 3.
Prepare (dummy) text embeddings From 4fd70f5c9674a56ed0207b120ac0c41c719a7d00 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 23 Mar 2025 23:13:43 +0900 Subject: [PATCH 167/482] add comments --- src/diffusers/image_processor.py | 2 +- src/diffusers/models/attention_processor.py | 17 ++---- .../flux/pipeline_flux_prior_redux.py | 52 +++++++++++-------- 3 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index c24f3aca6759..d6913f045ad2 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -1182,7 +1182,7 @@ def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embe mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0) mask_w = num_queries // mask_h - mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="nearest").squeeze(0) + mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0) # Repeat batch_size times if mask_downsample.shape[0] < batch_size: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7a4529755851..0182c639cfbb 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2360,7 +2360,6 @@ def __call__( image_rotary_emb: Optional[torch.Tensor] = None, img_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask txt_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask - img_ratio: Optional[float] = None, # thesea modification for ip mask ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape @@ -2383,6 +2382,7 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: + # theseam modified if img_mask is not None: # `context` projections. 
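# Throughout these hunks the joint sequence is laid out as
# [512 T5 text tokens | 729 Redux image tokens | latent tokens], which is
# where the constants 512, 729 and 1241 (= 512 + 729) come from; 729 looks
# like a 27 x 27 grid of Redux patch embeddings, though that is an inference
# from the slicing, not something the diff states. The hunk below projects
# each context slice separately so the two conditioning streams can be masked
# independently. A shape-only sketch of the layout:
import torch

txt_len, img_len = 512, 729
latent_len, dim = 4096, 3072  # e.g. a 1024x1024 Flux latent (assumed)
encoder_hidden_states = torch.randn(1, txt_len + img_len, dim)
txt_ctx = encoder_hidden_states[:, :txt_len, :]
img_ctx = encoder_hidden_states[:, txt_len:, :]
assert txt_len + img_len == 1241  # matches the RoPE slice boundary above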
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) @@ -2457,6 +2457,7 @@ def __call__( if image_rotary_emb is not None: from .embeddings import apply_rotary_emb + # theseam modified if encoder_hidden_states is not None: if img_mask is not None: cos, sin = image_rotary_emb @@ -2474,6 +2475,7 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb) if encoder_hidden_states is not None: + # theseam modified if img_mask is not None: txt_hidden_states = F.scaled_dot_product_attention( txt_query, txt_key, txt_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False @@ -2495,10 +2497,8 @@ def __call__( hidden_states.shape[2], ) txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - txt_mask_downsample[txt_mask_downsample < 1.0] *= (2.0 - img_ratio) masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample - img_mask_downsample = IPAdapterMaskProcessor.downsample( img_mask[0], batch_size, @@ -2506,19 +2506,10 @@ def __call__( hidden_states.shape[2], ) - img_mask_downsample_2 = IPAdapterMaskProcessor.downsample( - img_mask[0], - batch_size, - 729, - hidden_states.shape[2], - ) - img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - img_mask_downsample_2 = img_mask_downsample_2.to(dtype=query.dtype, device=query.device) - img_mask_downsample[img_mask_downsample < 1.0] *= img_ratio masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:] * img_mask_downsample_2, masked_txt_hidden_states + masked_img_hidden_states],dim=1) + hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 66b134628104..7ec33e17a16e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -380,7 +380,7 @@ def __call__( pooled_prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, - product_ratio: Optional[float] = None, + product_ratio: Optional[float] = None, # theseam modified return_dict: bool = True, ): r""" @@ -428,7 +428,11 @@ def __call__( if image is not None and isinstance(image, Image.Image): batch_size = 1 elif image is not None and isinstance(image, list): - batch_size = 1 #len(image) + # theseam modified + if product_ratio is not None: + batch_size = 1 + else: + batch_size = len(image) else: batch_size = image.shape[0] if prompt is not None and isinstance(prompt, str): @@ -441,25 +445,31 @@ def __call__( device = self._execution_device # 3. 
Prepare image embeddings - if isinstance(image, list) and product_ratio is not None: - image[1] = image[1].convert('RGBA') - - rgba_np = np.array(image[1]) - product_mask = rgba_np[:, :, 3] - product_mask = product_mask > 0 - product_mask = np.stack((product_mask,)*3, axis=-1) - - img = Image.new('RGBA', image[1].size, (255, 255, 255, 255)) - img.paste(image[1], mask=image[1].split()[3]) - img = img.convert('RGB') - product_image_array = np.asarray(img) - - background_image_array = np.asarray(image[0].convert("RGB")) - background_mask = ~product_mask - - composed_image = product_image_array * product_mask * product_ratio + background_image_array * product_mask * (1.0 - product_ratio) + background_image_array * background_mask - composed_image = Image.fromarray(composed_image.astype(np.uint8)) - image_latents = self.encode_image(composed_image, device, 1) + # theseam modified + if isinstance(image, list): + if product_ratio is not None: + image[1] = image[1].convert('RGBA') + + rgba_np = np.array(image[1]) + product_mask = rgba_np[:, :, 3] + product_mask = product_mask > 0 + product_mask = np.stack((product_mask,)*3, axis=-1) + + img = Image.new('RGBA', image[1].size, (255, 255, 255, 255)) + img.paste(image[1], mask=image[1].split()[3]) + img = img.convert('RGB') + product_image_array = np.asarray(img) + + background_image_array = np.asarray(image[0].convert("RGB")) + background_mask = ~product_mask + + composed_image = product_image_array * product_mask * product_ratio + background_image_array * product_mask * (1.0 - product_ratio) + background_image_array * background_mask + composed_image = Image.fromarray(composed_image.astype(np.uint8)) + image_latents = self.encode_image(composed_image, device, 1) + else: + for img in image: + img = img.convert('RGB') + image_latents = self.encode_image(image, device, 1) else: image = image.convert('RGB') image_latents = self.encode_image(image, device, 1) From aa11bb7d87c7dd389485277376a41a988990a72c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 24 Mar 2025 14:01:01 +0900 Subject: [PATCH 168/482] add image resize --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 7ec33e17a16e..e1553e2904c2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -381,6 +381,7 @@ def __call__( prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, product_ratio: Optional[float] = None, # theseam modified + image_size: int = 1024, return_dict: bool = True, ): r""" @@ -449,6 +450,7 @@ def __call__( if isinstance(image, list): if product_ratio is not None: image[1] = image[1].convert('RGBA') + image[1] = image[1].resize((image_size, image_size), resample=Image.BICUBIC) rgba_np = np.array(image[1]) product_mask = rgba_np[:, :, 3] @@ -460,7 +462,10 @@ def __call__( img = img.convert('RGB') product_image_array = np.asarray(img) - background_image_array = np.asarray(image[0].convert("RGB")) + image[0] = image[0].convert("RGB") + image[0] = image[0].resize((image_size, image_size), resample=Image.BICUBIC) + + background_image_array = np.asarray(image[0]) background_mask = ~product_mask composed_image = product_image_array * product_mask * product_ratio + background_image_array * product_mask * (1.0 - 
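# PATCH 168 resizes both inputs to a square `image_size` before composing,
# since the numpy blend requires the product and background arrays to share a
# shape; PATCH 172 then makes the resize opt-in. A sketch of the guarded
# resize (PIL, names follow the diff):
from PIL import Image

def maybe_resize(img: Image.Image, image_size=None) -> Image.Image:
    if image_size is not None:
        img = img.resize((image_size, image_size), resample=Image.BICUBIC)
    return img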
product_ratio) + background_image_array * background_mask From ddf3697e139ca74b61284fec3a98718f2dee7ba5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 24 Mar 2025 14:03:54 +0900 Subject: [PATCH 169/482] return mask --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index e1553e2904c2..8bee80f151fc 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -470,6 +470,8 @@ def __call__( composed_image = product_image_array * product_mask * product_ratio + background_image_array * product_mask * (1.0 - product_ratio) + background_image_array * background_mask composed_image = Image.fromarray(composed_image.astype(np.uint8)) + + mask = Image.fromarray(background_mask.astype(np.uint8)*255) image_latents = self.encode_image(composed_image, device, 1) else: for img in image: @@ -525,6 +527,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - return (prompt_embeds, pooled_prompt_embeds) + return (prompt_embeds, pooled_prompt_embeds, mask) return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds) From ff69b1e3da88b1a3845b6db80a1f1e5cb0024a8b Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 24 Mar 2025 14:05:03 +0900 Subject: [PATCH 170/482] return composed image --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 8bee80f151fc..2abf9580d668 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -527,6 +527,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - return (prompt_embeds, pooled_prompt_embeds, mask) + return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds) From bb33939ba3fbac37d09f6cf988775e6d6802287d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 24 Mar 2025 14:05:42 +0900 Subject: [PATCH 171/482] add default logic --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 2abf9580d668..a7369f56613f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -527,6 +527,9 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) + if product_ratio is not None: + return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) + else: + return (prompt_embeds, pooled_prompt_embeds) return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds) From 1853aefd2c25f99eeccb9b3d49498ba3e71b47fb Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 24 Mar 2025 14:24:36 +0900 Subject: [PATCH 172/482] add resize default logic --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 8 +++++--- 1 file 
changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index a7369f56613f..7f4ee4f11afa 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -381,7 +381,7 @@ def __call__( prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, product_ratio: Optional[float] = None, # theseam modified - image_size: int = 1024, + image_size: Optional[int] = None, return_dict: bool = True, ): r""" @@ -450,7 +450,8 @@ def __call__( if isinstance(image, list): if product_ratio is not None: image[1] = image[1].convert('RGBA') - image[1] = image[1].resize((image_size, image_size), resample=Image.BICUBIC) + if image_size is not None: + image[1] = image[1].resize((image_size, image_size), resample=Image.BICUBIC) rgba_np = np.array(image[1]) product_mask = rgba_np[:, :, 3] @@ -463,7 +464,8 @@ def __call__( product_image_array = np.asarray(img) image[0] = image[0].convert("RGB") - image[0] = image[0].resize((image_size, image_size), resample=Image.BICUBIC) + if image_size is not None: + image[0] = image[0].resize((image_size, image_size), resample=Image.BICUBIC) background_image_array = np.asarray(image[0]) background_mask = ~product_mask From b1801e372ad05e077bd8883bad1c1d23bff68c16 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 24 Mar 2025 15:58:40 +0900 Subject: [PATCH 173/482] add prod mask to attention processor --- src/diffusers/models/attention_processor.py | 86 +++++++++++++++++++-- 1 file changed, 80 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0182c639cfbb..d6a9290aff6a 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2360,8 +2360,12 @@ def __call__( image_rotary_emb: Optional[torch.Tensor] = None, img_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask txt_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask + prod_masks: Optional[torch.Tensor] = None, # thesea modification for text mask ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + if prod_masks is None: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + else: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states[:,0,:,:].shape # `sample` projections. 
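# PATCH 173 changes the contract of `encoder_hidden_states` when `prod_masks`
# is given: instead of (batch, seq, dim) it carries one prompt per product as
# (batch, num_prompts, seq, dim), and batch_size is read off a single prompt
# slice. A shape sketch of the new layout (sizes, and the prod_masks shape,
# are assumptions):
import torch

batch, num_prompts, seq, dim = 1, 3, 512, 3072
encoder_hidden_states = torch.randn(batch, num_prompts, seq, dim)
prod_masks = torch.rand(num_prompts, 1, 1024, 1024)  # one spatial mask per prompt
batch_size = encoder_hidden_states[:, 0, :, :].shape[0]
assert prod_masks.shape[0] == encoder_hidden_states.shape[1]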
query = attn.to_q(hidden_states) @@ -2429,6 +2433,43 @@ def __call__( img_query = torch.cat([img_encoder_hidden_states_query_proj, query], dim=2) img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) + elif prod_masks is not None: + if not prod_masks.shape[0] == encoder_hidden_states.shape[1]: + raise ValueError( + f"Length of text masks ({prod_masks.shape[0]}) must match " + f"the number of text prompts " + f"({encoder_hidden_states.shape[1]})") + queries = [] + keys = [] + values = [] + for index in range(encoder_hidden_states.shape[1]): + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,index,:,:]) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,index,:,:]) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,index,:,:]) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + tmp_query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + tmp_key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + tmp_value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + queries.append(tmp_query) + keys.append(tmp_key) + values.append(tmp_value) else: encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) @@ -2467,6 +2508,10 @@ def __call__( img_query = apply_rotary_emb(img_query, (cos[512:,:],sin[512:,:])) img_key = apply_rotary_emb(img_key, (cos[512:,:],sin[512:,:])) + elif prod_masks is not None: + for tmp_query, tmp_key in zip(queries, keys): + tmp_query = apply_rotary_emb(tmp_query, image_rotary_emb) + tmp_key = apply_rotary_emb(tmp_key, image_rotary_emb) else: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) @@ -2510,6 +2555,30 @@ def __call__( masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) + elif prod_masks is not None: + hidden_states_list = [] + encoder_hidden_states_list = [] + for tmp_mask, tmp_query, tmp_key, tmp_value in zip(prod_masks, queries, keys, values): + tmp_hidden_states = F.scaled_dot_product_attention( + tmp_query, tmp_key, tmp_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + tmp_hidden_states = tmp_hidden_states.to(query.dtype) + + mask_downsample = IPAdapterMaskProcessor.downsample( + tmp_mask, + batch_size, + hidden_states.shape[1], + hidden_states.shape[2], + ) + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + hidden_states_list.append(tmp_hidden_states[:,512:,:] * 
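# The loop above runs one full joint attention pass per prompt, and the code
# after this point sums the latent halves, each weighted by its downsampled
# product mask: a spatial mixture of prompts where, if the masks partition the
# canvas, every latent token is driven by exactly one prompt. (Note the rotary
# loop in the hunk above rebinds its loop variables, so the rotated tensors
# are discarded; collecting results into new lists would be needed for the
# rotation to take effect.) A reduced sketch of the recombination, with
# assumed shapes:
import torch

outs = [torch.randn(1, 4096, 3072) for _ in range(3)]  # per-prompt latent outputs
masks = [torch.zeros(1, 4096, 1) for _ in range(3)]
masks[0][:, :1365] = 1.0
masks[1][:, 1365:2730] = 1.0
masks[2][:, 2730:] = 1.0
mixed = torch.sum(torch.stack([o * m for o, m in zip(outs, masks)]), dim=0)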
mask_downsample) + encoder_hidden_states_list.append(tmp_hidden_states[:,:512,:]) + + hidden_states_list = torch.stack(hidden_states_list) + hidden_states = torch.sum(hidden_states_list, dim=0, keepdim=False) + + encoder_hidden_states = torch.stack(encoder_hidden_states_list, dim=1) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False @@ -2524,17 +2593,22 @@ def __call__( hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) + if prod_masks is None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + if prod_masks is None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + else: + for index in range(encoder_hidden_states.shape[1]): + encoder_hidden_states[:,index,:,:] = attn.to_add_out(encoder_hidden_states[:,index,:,:]) return hidden_states, encoder_hidden_states else: From f54b5a206f28501461839c9854efa1005c3c6f95 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 10:41:05 +0900 Subject: [PATCH 174/482] add codes for multiple text prompts --- .../flux/pipeline_flux_controlnet.py | 65 +++++++++++++------ 1 file changed, 46 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index eee41b9af4d1..06597508c7a0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -847,7 +847,7 @@ def __call__( if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) + batch_size = 1 #len(prompt) #thesea modified for text prompt mask else: batch_size = prompt_embeds.shape[0] @@ -859,20 +859,47 @@ def __call__( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) + ## thesea modified for text prompt mask + prompt_embeds_list = [] + pooled_prompt_embeds_list = [] + text_ids_list = [] + if isinstance(prompt, list): + for pmt in prompt: + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=pmt, + prompt_2=prompt_2, + #prompt_embeds=prompt_embeds, + #pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + prompt_embeds_list.append(prompt_embeds) + pooled_prompt_embeds_list.append(pooled_prompt_embeds) + text_ids_list.append(text_ids) + prompt_embeds_list = torch.stack(prompt_embeds_list,dim=1) + pooled_prompt_embeds_list = torch.stack(pooled_prompt_embeds_list,dim=1) + 
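# PATCH 174 encodes each prompt in the list separately and stacks the results
# along a new dim=1, producing the (batch, num_prompts, seq, dim) layout the
# attention processor above expects. A reduced sketch of the loop, with a
# stand-in for the pipeline's encode_prompt:
import torch

def encode_prompt(prompt):  # stand-in; returns (prompt_embeds, pooled_embeds)
    return torch.randn(1, 512, 3072), torch.randn(1, 768)

prompts = ["a red sneaker on a beach", "a marble side table"]
embeds, pooled = zip(*(encode_prompt(p) for p in prompts))
prompt_embeds_list = torch.stack(embeds, dim=1)         # (1, 2, 512, 3072)
pooled_prompt_embeds_list = torch.stack(pooled, dim=1)  # (1, 2, 768)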
text_ids_list = torch.stack(text_ids_list,dim=1) + else: + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) if do_true_cfg: ( negative_prompt_embeds, @@ -1085,8 +1112,8 @@ def __call__( conditioning_scale=cond_scale, timestep=timestep / 1000, guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, # TBD: pooled_prompt_embeds_list if isinstance(prompt, list) else pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds_list if isinstance(prompt, list) else prompt_embeds, #prompt_embeds, thesea modified for text prompt mask txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, @@ -1102,8 +1129,8 @@ def __call__( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, # TBD: pooled_prompt_embeds_list if isinstance(prompt, list) else pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds_list if isinstance(prompt, list) else prompt_embeds, #prompt_embeds, thesea modified for text prompt mask controlnet_block_samples=controlnet_block_samples, controlnet_single_block_samples=controlnet_single_block_samples, txt_ids=text_ids, From d2641a3cd69ed705c56ab0d30f6a16994ce24356 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 11:57:03 +0900 Subject: [PATCH 175/482] list of images for flux redux --- .../flux/pipeline_flux_prior_redux.py | 72 ++++++++++++------- 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 7f4ee4f11afa..c08a95c8d82f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -381,7 +381,8 @@ def __call__( prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, product_ratio: Optional[float] = None, # theseam modified - image_size: Optional[int] = None, + image_width: Optional[int] = 1024, + image_height: Optional[int] = 1024, return_dict: bool = True, ): r""" @@ -446,34 +447,55 @@ def __call__( device = self._execution_device # 3. 
Prepare image embeddings - # theseam modified + # thesea modified if isinstance(image, list): if product_ratio is not None: - image[1] = image[1].convert('RGBA') - if image_size is not None: - image[1] = image[1].resize((image_size, image_size), resample=Image.BICUBIC) - - rgba_np = np.array(image[1]) - product_mask = rgba_np[:, :, 3] - product_mask = product_mask > 0 - product_mask = np.stack((product_mask,)*3, axis=-1) - - img = Image.new('RGBA', image[1].size, (255, 255, 255, 255)) - img.paste(image[1], mask=image[1].split()[3]) - img = img.convert('RGB') - product_image_array = np.asarray(img) - - image[0] = image[0].convert("RGB") - if image_size is not None: - image[0] = image[0].resize((image_size, image_size), resample=Image.BICUBIC) - - background_image_array = np.asarray(image[0]) - background_mask = ~product_mask - - composed_image = product_image_array * product_mask * product_ratio + background_image_array * product_mask * (1.0 - product_ratio) + background_image_array * background_mask + image_array_list = [] + mask_list = [] + is_product_list = [] + for img in image: + metadata = img.info + is_product = metadata.get('is_product') + is_product_list.append(is_product) + img = img.convert('RGBA') + img = img.resize((image_width, image_height), resample=Image.BICUBIC) + + rgba_np = np.array(img) + mask = rgba_np[:, :, 3] + mask = mask > 0 + mask = np.stack((mask,)*3, axis=-1) + mask_list.append(mask) + + tmp_img = Image.new('RGBA', img.size, (255, 255, 255, 255)) + tmp_img.paste(img, mask=img.split()[3]) + tmp_img = tmp_img.convert('RGB') + image_array = np.asarray(tmp_img) + image_array_list.append(image_array) + + product_mask = np.full((image_width, image_height), True, dtype=bool) + image_mask = {} + for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): + if is_product: + product_mask = product_mask & ~mask + else: + product_mask = product_mask | mask + + if index not in image_mask: + image_mask[index] = mask + else: + for k in image_mask: + image_mask[k] = image_mask[k] & ~mask + + composed_image = np.zeros((image_width, image_height, 3)) + for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): + if is_product: + composed_image += img_array * image_mask[index] * product_ratio + else: + composed_image += img_array * image_mask[index] + composed_image = Image.fromarray(composed_image.astype(np.uint8)) - mask = Image.fromarray(background_mask.astype(np.uint8)*255) + mask = Image.fromarray(~product_mask.astype(np.uint8)*255) image_latents = self.encode_image(composed_image, device, 1) else: for img in image: From e9ec3e08aceb986ae26507146dd2a48b7018be62 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 12:12:57 +0900 Subject: [PATCH 176/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index c08a95c8d82f..71b1786358df 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -472,7 +472,7 @@ def __call__( image_array = np.asarray(tmp_img) image_array_list.append(image_array) - product_mask = np.full((image_width, image_height), True, dtype=bool) + product_mask = np.full((image_width, image_height, 3), True, dtype=bool) image_mask = {} for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): if 
is_product:

From 7c384a18d03bbd4d55c525653e747174def5dee3 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 25 Mar 2025 12:24:21 +0900
Subject: [PATCH 177/482] print info

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 71b1786358df..42af79ce9166 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -457,6 +457,7 @@ def __call__(
             metadata = img.info
             is_product = metadata.get('is_product')
             is_product_list.append(is_product)
+
             img = img.convert('RGBA')
             img = img.resize((image_width, image_height), resample=Image.BICUBIC)
 
@@ -475,6 +476,7 @@ def __call__(
             product_mask = np.full((image_width, image_height, 3), True, dtype=bool)
             image_mask = {}
             for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)):
+                print(f'index={index}, is_product={is_product}')
                 if is_product:
                     product_mask = product_mask & ~mask
                 else:
                     product_mask = product_mask | mask

From 7bb5ad49a9af9214d7df3c4f130c81c01ea09ef8 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 25 Mar 2025 12:33:51 +0900
Subject: [PATCH 178/482] fix bug

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 42af79ce9166..21938f56342a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -476,8 +476,7 @@ def __call__(
             product_mask = np.full((image_width, image_height, 3), True, dtype=bool)
             image_mask = {}
             for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)):
-                print(f'index={index}, is_product={is_product}')
-                if is_product:
+                if is_product.lower() == "true":
                     product_mask = product_mask & ~mask
                 else:
                     product_mask = product_mask | mask
@@ -490,7 +489,7 @@ def __call__(
             composed_image = np.zeros((image_width, image_height, 3))
             for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)):
-                if is_product:
+                if is_product.lower() == "true":
                     composed_image += img_array * image_mask[index] * product_ratio
                 else:
                     composed_image += img_array * image_mask[index]

From 01067ec54c225925f516f302fb63ac6674fdfafd Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 25 Mar 2025 12:40:33 +0900
Subject: [PATCH 179/482] print info

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 21938f56342a..3ab919d140e2 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -476,6 +476,7 @@ def __call__(
             product_mask = np.full((image_width, image_height, 3), True, dtype=bool)
             image_mask = {}
             for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)):
+                print(f'index={index}, is_product={is_product}')
                 if is_product.lower() == "true":
                     product_mask = product_mask & ~mask
                 else:
                     product_mask = product_mask | mask

From 8bbe1d8d6c55359e2ed0f1739926ce1f3d7bbc6a Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 25 Mar 2025 12:45:53 +0900
Subject: [PATCH 180/482] return info

---
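Note: from this commit on, calling the prior-redux pipeline with return_dict=False
and a product_ratio set also hands back the composited preview plus the per-layer
debug artifacts. A minimal sketch of how a caller might unpack the tuple at this
point in the series; `pipe_prior_redux` and `layer_images` are illustrative names,
and the RGBA layers are assumed to carry the 'is_product' flag in their PIL
metadata, as the code currently expects:

    # hypothetical caller; names here are not from the diff
    out = pipe_prior_redux(
        image=layer_images,   # list of RGBA PIL images with img.info['is_product'] set
        product_ratio=0.5,    # blend weight applied to product layers
        return_dict=False,
    )
    prompt_embeds, pooled_prompt_embeds, composed_image, mask, mask_list, image_array_list = out
    composed_image.save("composed_preview.png")  # RGB composite passed to encode_image
    mask.save("background_mask.png")             # mask derived from product_mask
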
src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 3ab919d140e2..00e7db00462c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -554,7 +554,7 @@ def __call__( if not return_dict: if product_ratio is not None: - return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) + return (prompt_embeds, pooled_prompt_embeds, composed_image, mask, mask_list, image_array_list) else: return (prompt_embeds, pooled_prompt_embeds) From 08ad49df422f737e06ac890f9ed47d710df9f432 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 12:53:22 +0900 Subject: [PATCH 181/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 00e7db00462c..6efc527fce5b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -497,7 +497,7 @@ def __call__( composed_image = Image.fromarray(composed_image.astype(np.uint8)) - mask = Image.fromarray(~product_mask.astype(np.uint8)*255) + mask = Image.fromarray(product_mask.astype(np.uint8)*255) image_latents = self.encode_image(composed_image, device, 1) else: for img in image: From c0b7f4934c8c39c5f970866f80cc78739fd9536c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 13:02:45 +0900 Subject: [PATCH 182/482] add return --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 6efc527fce5b..3a90a2c1d0ad 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -554,7 +554,7 @@ def __call__( if not return_dict: if product_ratio is not None: - return (prompt_embeds, pooled_prompt_embeds, composed_image, mask, mask_list, image_array_list) + return (prompt_embeds, pooled_prompt_embeds, composed_image, mask, mask_list, image_mask, image_array_list) else: return (prompt_embeds, pooled_prompt_embeds) From 5051a5b6dc4105f79c0095f4efaf47c6b5649364 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 14:06:14 +0900 Subject: [PATCH 183/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 3a90a2c1d0ad..e9aec70d3650 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -484,8 +484,8 @@ def __call__( if index not in image_mask: image_mask[index] = mask - else: - for k in image_mask: + for k in image_mask: + if k != index: image_mask[k] = image_mask[k] & ~mask composed_image = np.zeros((image_width, image_height, 3)) From 669d638ffa3751f6b205a1507a582b796f1bf490 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 14:30:11 +0900 Subject: [PATCH 184/482] remove print --- 
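Note: with the debug print and the extra return values gone, the surviving logic is
the layer bookkeeping fixed in PATCH 183: each new layer's opaque region is carved
out of every earlier layer's mask, so the topmost layer wins wherever layers
overlap. A self-contained sketch of that rule with toy 4x4 masks (illustrative
only, not part of the diff):

    import numpy as np

    layer_masks = [
        np.array([[True, True, False, False]] * 4),  # layer 0 covers the left half
        np.array([[False, True, True, False]] * 4),  # layer 1 overlaps layer 0's right column
    ]
    image_mask = {}
    for index, mask in enumerate(layer_masks):
        if index not in image_mask:
            image_mask[index] = mask
        for k in image_mask:
            if k != index:
                image_mask[k] = image_mask[k] & ~mask  # later layer occludes earlier ones

    print(image_mask[0].astype(int)[0])  # [1 0 0 0] -> layer 0 keeps only its unoccluded column
    print(image_mask[1].astype(int)[0])  # [0 1 1 0] -> layer 1 keeps its full region
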
src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index e9aec70d3650..8ea8fa351442 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -476,7 +476,6 @@ def __call__( product_mask = np.full((image_width, image_height, 3), True, dtype=bool) image_mask = {} for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): - print(f'index={index}, is_product={is_product}') if is_product.lower() == "true": product_mask = product_mask & ~mask else: @@ -554,7 +553,7 @@ def __call__( if not return_dict: if product_ratio is not None: - return (prompt_embeds, pooled_prompt_embeds, composed_image, mask, mask_list, image_mask, image_array_list) + return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) #, mask_list, image_mask, image_array_list else: return (prompt_embeds, pooled_prompt_embeds) From 7ae8a08151bde8620f71be7a982fa6e91e09f9e5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 16:11:48 +0900 Subject: [PATCH 185/482] change from meta data to type list --- .../models/transformers/transformer_flux.py | 7 ++++++- .../pipelines/flux/pipeline_flux_prior_redux.py | 12 ++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 87537890d246..b05101dbbc70 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -455,7 +455,12 @@ def forward( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) + print(f'transformer flux encoder_hidden_states.ndim={encoder_hidden_states.ndim}') + if encoder_hidden_states.ndim == 4: + for index in range(encoder_hidden_states.ndim): + encoder_hidden_states[:,index,:,:] = self.context_embedder(encoder_hidden_states[:,index,:,:]) + else: + encoder_hidden_states = self.context_embedder(encoder_hidden_states) if txt_ids.ndim == 3: logger.warning( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 8ea8fa351442..0b01e13d64af 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -374,6 +374,7 @@ def encode_prompt( def __call__( self, image: PipelineImageInput, + layer_type: Optional[Union[str, List[str]]] = None, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, @@ -453,10 +454,13 @@ def __call__( image_array_list = [] mask_list = [] is_product_list = [] - for img in image: - metadata = img.info - is_product = metadata.get('is_product') - is_product_list.append(is_product) + for img, type in zip(image, layer_type): + #metadata = img.info + #is_product = metadata.get('is_product') + if 'product' in type: + is_product_list.append('true') + else: + is_product_list.append('false') img = img.convert('RGBA') img = img.resize((image_width, image_height), resample=Image.BICUBIC) From 44f2c6e548d5ded00e8d47590bad38a836714662 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 17:40:30 
+0900 Subject: [PATCH 186/482] add text prompt mask to flux ctln pipeline --- src/diffusers/models/attention_processor.py | 3 ++- src/diffusers/models/transformers/transformer_flux.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d6a9290aff6a..c146a01806c8 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2434,6 +2434,7 @@ def __call__( img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) elif prod_masks is not None: + print(f'attention processor encoder_hidden_states.shape={encoder_hidden_states.shape}') if not prod_masks.shape[0] == encoder_hidden_states.shape[1]: raise ValueError( f"Length of text masks ({prod_masks.shape[0]}) must match " @@ -2610,7 +2611,7 @@ def __call__( for index in range(encoder_hidden_states.shape[1]): encoder_hidden_states[:,index,:,:] = attn.to_add_out(encoder_hidden_states[:,index,:,:]) - return hidden_states, encoder_hidden_states + return hidden_states, encoder_hidden_states[:,-1,:,:] else: return hidden_states diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index b05101dbbc70..1cc42ea99070 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -455,6 +455,8 @@ def forward( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) + + # thesea modifed for text prompt mask print(f'transformer flux encoder_hidden_states.ndim={encoder_hidden_states.ndim}') if encoder_hidden_states.ndim == 4: for index in range(encoder_hidden_states.ndim): From e7ad1c4d4bf04dc0e1a5c7e288296b333406b78d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 18:32:05 +0900 Subject: [PATCH 187/482] print info --- src/diffusers/models/controlnets/controlnet_flux.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index dd65fbc6d8f4..567938f03152 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -296,8 +296,14 @@ def forward( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - + # thesea modifed for text prompt mask + print(f'ctln transformer flux encoder_hidden_states.ndim={encoder_hidden_states.ndim}') + if encoder_hidden_states.ndim == 4: + for index in range(encoder_hidden_states.ndim): + encoder_hidden_states[:,index,:,:] = self.context_embedder(encoder_hidden_states[:,index,:,:]) + else: + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if self.union: # union mode if controlnet_mode is None: @@ -344,6 +350,7 @@ def forward( ) block_samples = block_samples + (hidden_states,) + print(f'ctln encoder_hidden_states shape={encoder_hidden_states.shape}, hidden_states shape={hidden_states.shape}') hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) single_block_samples = () From 351d13001014cad40ab36827313126fdc3f28d4e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 18:43:00 +0900 Subject: [PATCH 188/482] fix bug --- 
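Note: in the masked multi-prompt path the encoder states are stacked per prompt as
[batch, num_prompts, seq_len, dim]; this commit applies the output projection
prompt by prompt and then returns only the last prompt's slice, so callers keep
receiving the usual 3-D [batch, seq_len, dim] tensor. A toy shape check of that
slicing; `to_add_out` stands in for attn.to_add_out and the sizes are illustrative:

    import torch

    batch, num_prompts, seq_len, dim = 1, 3, 512, 64
    encoder_hidden_states = torch.randn(batch, num_prompts, seq_len, dim)
    to_add_out = torch.nn.Linear(dim, dim)

    with torch.no_grad():  # inference-style, as in the pipeline
        for index in range(encoder_hidden_states.shape[1]):
            encoder_hidden_states[:, index, :, :] = to_add_out(encoder_hidden_states[:, index, :, :])

    out = encoder_hidden_states[:, -1, :, :]
    print(out.shape)  # torch.Size([1, 512, 64]) -- the 3-D layout downstream code expects
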
src/diffusers/models/attention_processor.py | 3 ++- src/diffusers/models/controlnets/controlnet_flux.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c146a01806c8..4c4c3f96b128 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2607,11 +2607,12 @@ def __call__( if prod_masks is None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + return hidden_states, encoder_hidden_states else: for index in range(encoder_hidden_states.shape[1]): encoder_hidden_states[:,index,:,:] = attn.to_add_out(encoder_hidden_states[:,index,:,:]) - return hidden_states, encoder_hidden_states[:,-1,:,:] + return hidden_states, encoder_hidden_states[:,-1,:,:] else: return hidden_states diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 567938f03152..72d13a702673 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -303,6 +303,7 @@ def forward( encoder_hidden_states[:,index,:,:] = self.context_embedder(encoder_hidden_states[:,index,:,:]) else: encoder_hidden_states = self.context_embedder(encoder_hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) if self.union: # union mode From 396f163f54e5ac2718069d34bcf28a9b1d84f6c3 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 18:53:36 +0900 Subject: [PATCH 189/482] Revert "fix bug" This reverts commit 351d13001014cad40ab36827313126fdc3f28d4e. --- src/diffusers/models/attention_processor.py | 3 +-- src/diffusers/models/controlnets/controlnet_flux.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4c4c3f96b128..c146a01806c8 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2607,12 +2607,11 @@ def __call__( if prod_masks is None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - return hidden_states, encoder_hidden_states else: for index in range(encoder_hidden_states.shape[1]): encoder_hidden_states[:,index,:,:] = attn.to_add_out(encoder_hidden_states[:,index,:,:]) - return hidden_states, encoder_hidden_states[:,-1,:,:] + return hidden_states, encoder_hidden_states[:,-1,:,:] else: return hidden_states diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 72d13a702673..567938f03152 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -303,7 +303,6 @@ def forward( encoder_hidden_states[:,index,:,:] = self.context_embedder(encoder_hidden_states[:,index,:,:]) else: encoder_hidden_states = self.context_embedder(encoder_hidden_states) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) if self.union: # union mode From 4b2428b9a12f7c9ba185afbbf79342f325bc0ce6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 18:53:54 +0900 Subject: [PATCH 190/482] Revert "print info" This reverts commit e7ad1c4d4bf04dc0e1a5c7e288296b333406b78d. 
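Together with PATCH 189, PATCH 191, and PATCH 192, this backs out the 4-D
encoder_hidden_states experiment; PATCH 193 below re-implements the same region
control as an explicit additive attention mask passed to
scaled_dot_product_attention. A minimal sketch of that masking idea, independent
of the diffusers code (region sizes and names are illustrative):

    import torch
    import torch.nn.functional as F

    heads, q_len, dim = 1, 8, 16
    query = torch.randn(1, heads, q_len, dim)
    key = torch.randn(1, heads, q_len, dim)
    value = torch.randn(1, heads, q_len, dim)

    # start from a 0/1 keep-mask: here two token regions that may only attend
    # within themselves (PATCH 193 builds this per text/image/background region)
    keep = torch.zeros(q_len, q_len)
    keep[:4, :4] = 1
    keep[4:, 4:] = 1

    # convert to the additive bias SDPA expects: 0 where allowed, -inf where blocked
    zero_index = keep == 0
    one_index = keep == 1
    attn_bias = keep.masked_fill(zero_index, float("-inf")).masked_fill(one_index, 0.0)

    out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
    print(out.shape)  # torch.Size([1, 1, 8, 16])
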
--- src/diffusers/models/controlnets/controlnet_flux.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 567938f03152..dd65fbc6d8f4 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -296,14 +296,8 @@ def forward( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) - # thesea modifed for text prompt mask - print(f'ctln transformer flux encoder_hidden_states.ndim={encoder_hidden_states.ndim}') - if encoder_hidden_states.ndim == 4: - for index in range(encoder_hidden_states.ndim): - encoder_hidden_states[:,index,:,:] = self.context_embedder(encoder_hidden_states[:,index,:,:]) - else: - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if self.union: # union mode if controlnet_mode is None: @@ -350,7 +344,6 @@ def forward( ) block_samples = block_samples + (hidden_states,) - print(f'ctln encoder_hidden_states shape={encoder_hidden_states.shape}, hidden_states shape={hidden_states.shape}') hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) single_block_samples = () From b146b94d3d0320b73bced0959b19cc6dac31a8a4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 18:54:08 +0900 Subject: [PATCH 191/482] Revert "add text prompt mask to flux ctln pipeline" This reverts commit 44f2c6e548d5ded00e8d47590bad38a836714662. --- src/diffusers/models/attention_processor.py | 3 +-- src/diffusers/models/transformers/transformer_flux.py | 2 -- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c146a01806c8..d6a9290aff6a 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2434,7 +2434,6 @@ def __call__( img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) elif prod_masks is not None: - print(f'attention processor encoder_hidden_states.shape={encoder_hidden_states.shape}') if not prod_masks.shape[0] == encoder_hidden_states.shape[1]: raise ValueError( f"Length of text masks ({prod_masks.shape[0]}) must match " @@ -2611,7 +2610,7 @@ def __call__( for index in range(encoder_hidden_states.shape[1]): encoder_hidden_states[:,index,:,:] = attn.to_add_out(encoder_hidden_states[:,index,:,:]) - return hidden_states, encoder_hidden_states[:,-1,:,:] + return hidden_states, encoder_hidden_states else: return hidden_states diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 1cc42ea99070..b05101dbbc70 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -455,8 +455,6 @@ def forward( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) - - # thesea modifed for text prompt mask print(f'transformer flux encoder_hidden_states.ndim={encoder_hidden_states.ndim}') if encoder_hidden_states.ndim == 4: for index in range(encoder_hidden_states.ndim): From 8bac12aefd2ddfda718a9d482aa10cbec6d2ef4c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 25 Mar 2025 18:56:31 +0900 Subject: [PATCH 192/482] remove 
print --- src/diffusers/models/transformers/transformer_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index b05101dbbc70..b948948c4a3a 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -455,7 +455,7 @@ def forward( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) - print(f'transformer flux encoder_hidden_states.ndim={encoder_hidden_states.ndim}') + #print(f'transformer flux encoder_hidden_states.ndim={encoder_hidden_states.ndim}') if encoder_hidden_states.ndim == 4: for index in range(encoder_hidden_states.ndim): encoder_hidden_states[:,index,:,:] = self.context_embedder(encoder_hidden_states[:,index,:,:]) From 7570c8a81aa450dedf418b5105af2e2d0246fce6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 17:50:58 +0800 Subject: [PATCH 193/482] ip-mask re-implemented by attention mask --- src/diffusers/models/attention_processor.py | 359 ++++++++---------- .../models/controlnets/controlnet_flux.py | 1 + .../models/transformers/transformer_flux.py | 8 +- .../flux/pipeline_flux_controlnet.py | 19 +- .../flux/pipeline_flux_prior_redux.py | 59 ++- 5 files changed, 203 insertions(+), 243 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0027672a7f77..f9171cce2869 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2358,15 +2358,12 @@ def __call__( encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - img_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask - txt_mask: Optional[torch.Tensor] = None, # thesea modification for ip mask - prod_masks: Optional[torch.Tensor] = None, # thesea modification for text mask + ip_img: Optional[bool] = False, # thesea modified for ip image + img_mask: Optional[torch.Tensor] = None, # thesea modified for ip mask + txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask ) -> torch.FloatTensor: - if prod_masks is None: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - else: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states[:,0,:,:].shape - + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + # `sample` projections. query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) @@ -2374,7 +2371,7 @@ def __call__( inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) @@ -2386,229 +2383,179 @@ def __call__( # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: - # theseam modified - if img_mask is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,0:512,:]) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,0:512,:]) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,0:512,:]) - - img_encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,512:,:]) - img_encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,512:,:]) - img_encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,512:,:]) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - img_encoder_hidden_states_query_proj = img_encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - img_encoder_hidden_states_key_proj = img_encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - img_encoder_hidden_states_value_proj = img_encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - img_encoder_hidden_states_query_proj = attn.norm_added_q(img_encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - img_encoder_hidden_states_key_proj = attn.norm_added_k(img_encoder_hidden_states_key_proj) - - # attention - txt_query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - txt_key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - txt_value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - img_query = torch.cat([img_encoder_hidden_states_query_proj, query], dim=2) - img_key = torch.cat([img_encoder_hidden_states_key_proj, key], dim=2) - img_value = torch.cat([img_encoder_hidden_states_value_proj, value], dim=2) - elif prod_masks is not None: - if not prod_masks.shape[0] == encoder_hidden_states.shape[1]: - raise ValueError( - f"Length of text masks ({prod_masks.shape[0]}) must match " - f"the number of text prompts " - f"({encoder_hidden_states.shape[1]})") - queries = [] - keys = [] - values = [] - for index in range(encoder_hidden_states.shape[1]): - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states[:,index,:,:]) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states[:,index,:,:]) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states[:,index,:,:]) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - tmp_query = 
torch.cat([encoder_hidden_states_query_proj, query], dim=2) - tmp_key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - tmp_value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - queries.append(tmp_query) - keys.append(tmp_key) - values.append(tmp_value) - else: - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: from .embeddings import apply_rotary_emb - - # theseam modified - if encoder_hidden_states is not None: - if img_mask is not None: - cos, sin = image_rotary_emb - #print(f'cos shape={cos.shape}, sin shape={sin.shape}') - txt_query = apply_rotary_emb(txt_query, (torch.cat([cos[0:512,:], cos[1241:,:]], dim = 0), torch.cat([sin[0:512,:], sin[1241:,:]],dim=0))) - txt_key = apply_rotary_emb(txt_key, (torch.cat([cos[0:512,:], cos[1241:,:]], dim = 0), torch.cat([sin[0:512,:], sin[1241:,:]],dim=0))) - - img_query = apply_rotary_emb(img_query, (cos[512:,:],sin[512:,:])) - img_key = apply_rotary_emb(img_key, (cos[512:,:],sin[512:,:])) - elif prod_masks is not None: - for tmp_query, tmp_key in zip(queries, keys): - tmp_query = apply_rotary_emb(tmp_query, image_rotary_emb) - tmp_key = apply_rotary_emb(tmp_key, image_rotary_emb) - else: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - else: - query = apply_rotary_emb(query, image_rotary_emb) - key = 
apply_rotary_emb(key, image_rotary_emb) - if encoder_hidden_states is not None: - # theseam modified - if img_mask is not None: - txt_hidden_states = F.scaled_dot_product_attention( - txt_query, txt_key, txt_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - img_hidden_states = F.scaled_dot_product_attention( - img_query, img_key, img_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - txt_hidden_states = txt_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - txt_hidden_states = txt_hidden_states.to(query.dtype) + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) - img_hidden_states = img_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - img_hidden_states = img_hidden_states.to(query.dtype) + if txt_masks is not None: + if ip_img: + num_of_prompts = int((query.shape[2] - 4825)/512) + if num_of_prompts == 1: + attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) + self_attend_masks = torch.zeros((4096, 4096), device=query.device) + union_masks = torch.zeros((4096, 4096), device=query.device) + + # text related attention mask + attention_mask[:512, :512] = torch.ones(512, 512) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + txt_masks[0], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 + mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(512, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[:512,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, :512] = mask_downsample_t2i_tensor_transpose + + img_size_masks = mask_downsample_t2i_tensor_transpose[:, :1].repeat(1, 4096).to(device=query.device) + img_size_masks_transpose = img_size_masks.transpose(-1, -2).to(device=query.device) + self_attend_masks = torch.logical_or(self_attend_masks, + torch.logical_and(img_size_masks, img_size_masks_transpose)) + union_masks = torch.logical_or(union_masks, + torch.logical_or(img_size_masks, img_size_masks_transpose)) + # text related attention mask + + # image related attention mask + attention_mask[512:1241, 512:1241] = torch.ones(729, 729) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + img_mask[0], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 + mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[512:1241,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, 512:1241] = mask_downsample_t2i_tensor_transpose + + img_size_masks = mask_downsample_t2i_tensor_transpose[:, :1].repeat(1, 4096).to(device=query.device) + img_size_masks_transpose = img_size_masks.transpose(-1, -2).to(device=query.device) + self_attend_masks = torch.logical_or(self_attend_masks, + torch.logical_and(img_size_masks, img_size_masks_transpose)) + union_masks = torch.logical_or(union_masks, + torch.logical_or(img_size_masks, img_size_masks_transpose)) + # image related 
attention mask + + background_masks = torch.logical_not(union_masks) + background_and_self_attend_masks = torch.logical_or(background_masks, self_attend_masks) + attention_mask[-4096:,-4096:] = background_and_self_attend_masks + + zero_index = attention_mask == 0 + one_index = attention_mask == 1 + attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) + attention_mask = attention_mask.masked_fill(one_index, 0) + attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) + + hidden_states_region = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False + ) - txt_mask_downsample = IPAdapterMaskProcessor.downsample( - txt_mask[0], - batch_size, - hidden_states.shape[1], - hidden_states.shape[2], - ) - txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - masked_txt_hidden_states = txt_hidden_states[:,512:,:] * txt_mask_downsample - - img_mask_downsample = IPAdapterMaskProcessor.downsample( - img_mask[0], - batch_size, - hidden_states.shape[1], - hidden_states.shape[2], - ) + hidden_states_txt = F.scaled_dot_product_attention( + torch.cat([query[:,:,:512,:], query[:,:,-4096:,:]], dim=2), + torch.cat([key[:,:,:512,:], key[:,:,-4096:,:]], dim=2), + torch.cat([value[:,:,:512,:], value[:,:,-4096:,:]], dim=2), + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) - img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - masked_img_hidden_states = img_hidden_states[:,729:,:] * img_mask_downsample - - hidden_states = torch.cat([txt_hidden_states[:,:-hidden_states.shape[1],:], img_hidden_states[:,:-hidden_states.shape[1],:], masked_txt_hidden_states + masked_img_hidden_states],dim=1) - elif prod_masks is not None: - hidden_states_list = [] - encoder_hidden_states_list = [] - for tmp_mask, tmp_query, tmp_key, tmp_value in zip(prod_masks, queries, keys, values): - tmp_hidden_states = F.scaled_dot_product_attention( - tmp_query, tmp_key, tmp_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states_img = F.scaled_dot_product_attention( + query[:,:,512:,:], + key[:,:,512:,:], + value[:,:,512:,:], + attn_mask=None, + dropout_p=0.0, + is_causal=False ) - tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - tmp_hidden_states = tmp_hidden_states.to(query.dtype) - mask_downsample = IPAdapterMaskProcessor.downsample( - tmp_mask, + hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_region = hidden_states_region.to(query.dtype) + + hidden_states_txt = hidden_states_txt.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_txt = hidden_states_txt.to(query.dtype) + + hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_img = hidden_states_img.to(query.dtype) + + + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_masks[0], batch_size, - hidden_states.shape[1], - hidden_states.shape[2], - ) - mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states_list.append(tmp_hidden_states[:,512:,:] * mask_downsample) - encoder_hidden_states_list.append(tmp_hidden_states[:,:512,:]) - - hidden_states_list = torch.stack(hidden_states_list) - hidden_states = torch.sum(hidden_states_list, dim=0, keepdim=False) - - encoder_hidden_states = torch.stack(encoder_hidden_states_list, dim=1) - else: - 
hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - else: + 4096, + attn.heads * head_dim, + ) + txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + + img_mask_downsample = IPAdapterMaskProcessor.downsample( + img_mask[0], + batch_size, + 4096, + attn.heads * head_dim, + ) + + hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample + hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1) + else hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: - if prod_masks is None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) - if prod_masks is None: - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - else: - for index in range(encoder_hidden_states.shape[1]): - encoder_hidden_states[:,index,:,:] = attn.to_add_out(encoder_hidden_states[:,index,:,:]) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states else: diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index dd65fbc6d8f4..4dd9b64854ed 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -361,6 +361,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, ) single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index b948948c4a3a..da7d6696c2f5 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -455,12 +455,8 @@ def forward( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) - #print(f'transformer flux encoder_hidden_states.ndim={encoder_hidden_states.ndim}') - if encoder_hidden_states.ndim == 4: - for index in range(encoder_hidden_states.ndim): - encoder_hidden_states[:,index,:,:] = self.context_embedder(encoder_hidden_states[:,index,:,:]) - else: - encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + encoder_hidden_states = self.context_embedder(encoder_hidden_states) if txt_ids.ndim == 3: logger.warning( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 263ae7b5eab5..8e3586c68494 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ 
b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -848,7 +848,7 @@ def __call__( if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): - batch_size = 1 #len(prompt) #thesea modified for text prompt mask + batch_size = 1 #len(prompt) # thesea modified for text prompt mask else: batch_size = prompt_embeds.shape[0] @@ -862,7 +862,6 @@ def __call__( do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None ## thesea modified for text prompt mask prompt_embeds_list = [] - pooled_prompt_embeds_list = [] text_ids_list = [] if isinstance(prompt, list): for pmt in prompt: @@ -881,11 +880,9 @@ def __call__( lora_scale=lora_scale, ) prompt_embeds_list.append(prompt_embeds) - pooled_prompt_embeds_list.append(pooled_prompt_embeds) text_ids_list.append(text_ids) - prompt_embeds_list = torch.stack(prompt_embeds_list,dim=1) - pooled_prompt_embeds_list = torch.stack(pooled_prompt_embeds_list,dim=1) - text_ids_list = torch.stack(text_ids_list,dim=1) + prompt_embeds = torch.cat(prompt_embeds_list, dim=1) + text_ids = torch.cat(text_ids_list, dim=0) else: ( prompt_embeds, @@ -1113,8 +1110,8 @@ def __call__( conditioning_scale=cond_scale, timestep=timestep / 1000, guidance=guidance, - pooled_projections=pooled_prompt_embeds, # TBD: pooled_prompt_embeds_list if isinstance(prompt, list) else pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds_list if isinstance(prompt, list) else prompt_embeds, #prompt_embeds, thesea modified for text prompt mask + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, @@ -1130,8 +1127,8 @@ def __call__( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, - pooled_projections=pooled_prompt_embeds, # TBD: pooled_prompt_embeds_list if isinstance(prompt, list) else pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds_list if isinstance(prompt, list) else prompt_embeds, #prompt_embeds, thesea modified for text prompt mask + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, controlnet_block_samples=controlnet_block_samples, controlnet_single_block_samples=controlnet_single_block_samples, txt_ids=text_ids, @@ -1202,4 +1199,4 @@ def __call__( if not return_dict: return (image,) - return FluxPipelineOutput(images=image) + return FluxPipelineOutput(images=image) \ No newline at end of file diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 0b01e13d64af..84489a7b60c5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -438,8 +438,9 @@ def __call__( batch_size = len(image) else: batch_size = image.shape[0] - if prompt is not None and isinstance(prompt, str): - prompt = batch_size * [prompt] + # thesea modified for ip and txt masks + #if prompt is not None and isinstance(prompt, str): + # prompt = batch_size * [prompt] if isinstance(prompt_embeds_scale, float): prompt_embeds_scale = batch_size * [prompt_embeds_scale] if isinstance(pooled_prompt_embeds_scale, float): @@ -449,15 +450,14 @@ def __call__( # 3. 
Prepare image embeddings # thesea modified + # thesea modified for ip and txt masks if isinstance(image, list): if product_ratio is not None: image_array_list = [] mask_list = [] is_product_list = [] for img, type in zip(image, layer_type): - #metadata = img.info - #is_product = metadata.get('is_product') - if 'product' in type: + if 'product' or 'Product' in type: is_product_list.append('true') else: is_product_list.append('false') @@ -514,21 +514,40 @@ def __call__( image_embeds = image_embeds.to(device=device) # 3. Prepare (dummy) text embeddings + # thesea modified for ip and txt masks if hasattr(self, "text_encoder") and self.text_encoder is not None: - ( - prompt_embeds, - pooled_prompt_embeds, - _, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=1, - max_sequence_length=512, - lora_scale=None, - ) + prompt_embeds_list = [] + if isinstance(prompt, list): + for pmt in prompt: + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=pmt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=batch_size, + max_sequence_length=512, + lora_scale=None, + ) + prompt_embeds_list.append(prompt_embeds) + prompt_embeds = torch.cat(prompt_embeds_list, dim=1) + else: + ( + prompt_embeds, + pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=batch_size, + max_sequence_length=512, + lora_scale=None, + ) else: if prompt is not None: logger.warning( @@ -557,7 +576,7 @@ def __call__( if not return_dict: if product_ratio is not None: - return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) #, mask_list, image_mask, image_array_list + return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) else: return (prompt_embeds, pooled_prompt_embeds) From dbdbbd9ef5842425b5ece838d7ee7d551215929f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 17:54:13 +0800 Subject: [PATCH 194/482] fix bug --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f9171cce2869..92b9d25c69d4 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2536,7 +2536,7 @@ def __call__( hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1) - else + else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) From 338da6eb18909303ef872f72bd6067b0b79162e4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 18:06:53 +0800 Subject: [PATCH 195/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 84489a7b60c5..5cd5a10363f3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -457,7 +457,8 @@ def __call__( mask_list = [] 
is_product_list = [] for img, type in zip(image, layer_type): - if 'product' or 'Product' in type: + if 'product' in img_type or 'Product' in img_type: + print(f'in img_type = {in img_type}') is_product_list.append('true') else: is_product_list.append('false') From 48e6fb8c6a267f67ea75f0612ff105201ffa4010 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 18:07:15 +0800 Subject: [PATCH 196/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 5cd5a10363f3..16dab63966da 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -456,7 +456,7 @@ def __call__( image_array_list = [] mask_list = [] is_product_list = [] - for img, type in zip(image, layer_type): + for img, img_type in zip(image, layer_type): if 'product' in img_type or 'Product' in img_type: print(f'in img_type = {in img_type}') is_product_list.append('true') From 5db112c11bc6b79dcd5589b14a54bd2faae09ecc Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 18:09:23 +0800 Subject: [PATCH 197/482] fix print bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 16dab63966da..a0b1795f1f66 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -458,7 +458,7 @@ def __call__( is_product_list = [] for img, img_type in zip(image, layer_type): if 'product' in img_type or 'Product' in img_type: - print(f'in img_type = {in img_type}') + print(f'img_type = {img_type}') is_product_list.append('true') else: is_product_list.append('false') From 6a0a9bc4ac877ec860603bbd0b657ea1e5d2e753 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 18:16:12 +0800 Subject: [PATCH 198/482] fix device bug --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 92b9d25c69d4..37fbdd4d11aa 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2533,6 +2533,7 @@ def __call__( 4096, attn.heads * head_dim, ) + img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1) From 33023f442269a70c7ed0396467da6533a91d3765 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 18:25:18 +0800 Subject: [PATCH 199/482] remove mask in ctln --- src/diffusers/models/controlnets/controlnet_flux.py | 4 ++-- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 4dd9b64854ed..3fe6608dd1c8 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -340,7 +340,7 @@ 
def forward(
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
+                    #joint_attention_kwargs=joint_attention_kwargs,
                 )
             block_samples = block_samples + (hidden_states,)
@@ -361,7 +361,7 @@ def forward(
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
+                    #joint_attention_kwargs=joint_attention_kwargs,
                 )
             single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index a0b1795f1f66..34f2e89171a7 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -458,7 +458,6 @@ def __call__(
             is_product_list = []
             for img, img_type in zip(image, layer_type):
                 if 'product' in img_type or 'Product' in img_type:
-                    print(f'img_type = {img_type}')
                     is_product_list.append('true')
                 else:
                     is_product_list.append('false')

From 2881819b10f539a374c0eddb4961de6acea3fce1 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 31 Mar 2025 18:34:40 +0800
Subject: [PATCH 200/482] use hidden_states_region

---
 src/diffusers/models/attention_processor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 37fbdd4d11aa..13e31e4dcf56 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2536,7 +2536,8 @@ def __call__(
                     img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device)
                     hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample
-                    hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1)
+                    #hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1)
+                    hidden_states = hidden_states_region
                 else:
                     hidden_states = F.scaled_dot_product_attention(
                         query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False

From 4fc0cd8a8bbd511dd69827d1ce986de704d51b86 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 31 Mar 2025 18:39:11 +0800
Subject: [PATCH 201/482] remove joint_attention_kwargs=self.joint_attention_kwargs in ctln

---
 src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 8e3586c68494..02809ce5f9e8 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -1114,7 +1114,7 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    #joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )

From 27f5e016c914328bf6e7fc24de33c77c88c8a5ee Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 31 Mar 2025 18:57:22 +0800
Subject: [PATCH 202/482] add pr=0.0 logic

---
 src/diffusers/models/attention_processor.py  |  4 +--
 .../flux/pipeline_flux_prior_redux.py         | 25 +++++++++++--------
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 13e31e4dcf56..98c3155aff83 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2536,8 +2536,8 @@ def __call__(
                     img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device)
                     hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample
-                    #hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1)
-                    hidden_states = hidden_states_region
+                    hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1)
+                    #hidden_states = hidden_states_region
                 else:
                     hidden_states = F.scaled_dot_product_attention(
                         query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 34f2e89171a7..528166b1b6f4 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -491,16 +491,21 @@ def __call__(
                         if k != index:
                             image_mask[k] = image_mask[k] & ~mask

-                composed_image = np.zeros((image_width, image_height, 3))
-                for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)):
-                    if is_product.lower() == "true":
-                        composed_image += img_array * image_mask[index] * product_ratio
-                    else:
-                        composed_image += img_array * image_mask[index]
-
-                composed_image = Image.fromarray(composed_image.astype(np.uint8))
-
-                mask = Image.fromarray(product_mask.astype(np.uint8)*255)
+                if product_ratio > 0.0:
+                    composed_image = np.zeros((image_width, image_height, 3))
+                    for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)):
+                        if is_product.lower() == "true":
+                            composed_image += img_array * image_mask[index] * product_ratio
+                        else:
+                            composed_image += img_array * image_mask[index]
+
+                    composed_image = Image.fromarray(composed_image.astype(np.uint8))
+                else:
+                    composed_image = image
+                    for img in composed_image:
+                        img = img.convert('RGB')
+
+                mask = Image.fromarray(product_mask.astype(np.uint8)*255).convert('RGB')
                 image_latents = self.encode_image(composed_image, device, 1)
             else:
                 for img in image:
                     img = img.convert('RGB')

From a9c8762f8827d793fb929e1cf3cc58e647ae936a Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 31 Mar 2025 19:01:30 +0800
Subject: [PATCH 203/482] fix bug

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 528166b1b6f4..1b3f412c0f89 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -501,10 +501,8 @@ def __call__(
                     composed_image = Image.fromarray(composed_image.astype(np.uint8))
                 else:
-                    composed_image = image
-                    for img in composed_image:
-                        img = img.convert('RGB')
-
+                    composed_image = image[0].convert('RGB')
+
                 mask = Image.fromarray(product_mask.astype(np.uint8)*255).convert('RGB')
                 image_latents = self.encode_image(composed_image, device, 1)
             else:

From 0f9a6ab12689c9e509f8485d5e6473b312bdd44e Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 31 Mar 2025 20:38:50 +0800
Subject: [PATCH 204/482] remove unnecessary code

---
 src/diffusers/models/attention_processor.py | 30 ++++++++++++---------
 1 file changed, 17
insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 98c3155aff83..ac54a011a6ee 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2419,8 +2419,8 @@ def __call__( num_of_prompts = int((query.shape[2] - 4825)/512) if num_of_prompts == 1: attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) - self_attend_masks = torch.zeros((4096, 4096), device=query.device) - union_masks = torch.zeros((4096, 4096), device=query.device) + #self_attend_masks = torch.zeros((4096, 4096), device=query.device) + #union_masks = torch.zeros((4096, 4096), device=query.device) # text related attention mask attention_mask[:512, :512] = torch.ones(512, 512) @@ -2432,19 +2432,21 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 - mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 + #mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 + #mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(512, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) attention_mask[:512,-4096:] = mask_downsample_t2i_tensor attention_mask[-4096:, :512] = mask_downsample_t2i_tensor_transpose - + + """ img_size_masks = mask_downsample_t2i_tensor_transpose[:, :1].repeat(1, 4096).to(device=query.device) img_size_masks_transpose = img_size_masks.transpose(-1, -2).to(device=query.device) self_attend_masks = torch.logical_or(self_attend_masks, torch.logical_and(img_size_masks, img_size_masks_transpose)) union_masks = torch.logical_or(union_masks, torch.logical_or(img_size_masks, img_size_masks_transpose)) + """ # text related attention mask # image related attention mask @@ -2457,27 +2459,30 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 - mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 + #mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 + #mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) attention_mask[512:1241,-4096:] = mask_downsample_t2i_tensor attention_mask[-4096:, 512:1241] = mask_downsample_t2i_tensor_transpose + """ img_size_masks = mask_downsample_t2i_tensor_transpose[:, :1].repeat(1, 4096).to(device=query.device) img_size_masks_transpose = img_size_masks.transpose(-1, -2).to(device=query.device) self_attend_masks = torch.logical_or(self_attend_masks, torch.logical_and(img_size_masks, img_size_masks_transpose)) union_masks = torch.logical_or(union_masks, torch.logical_or(img_size_masks, img_size_masks_transpose)) + """ # image related attention mask - background_masks = torch.logical_not(union_masks) - background_and_self_attend_masks = torch.logical_or(background_masks, self_attend_masks) - attention_mask[-4096:,-4096:] = background_and_self_attend_masks + #background_masks = torch.logical_not(union_masks) + #background_and_self_attend_masks = torch.logical_or(background_masks, self_attend_masks) + + attention_mask[-4096:,-4096:] = 1#background_and_self_attend_masks - 
zero_index = attention_mask == 0 - one_index = attention_mask == 1 + zero_index = attention_mask < 0.5 + one_index = attention_mask >= 0.5 attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) attention_mask = attention_mask.masked_fill(one_index, 0) attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) @@ -2518,7 +2523,6 @@ def __call__( hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states_img = hidden_states_img.to(query.dtype) - txt_mask_downsample = IPAdapterMaskProcessor.downsample( txt_masks[0], batch_size, From 55cacf79b2e716e7223eb1d12b06f2043009910f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 21:02:25 +0800 Subject: [PATCH 205/482] clear codes --- src/diffusers/models/attention_processor.py | 56 +++++++-------------- 1 file changed, 17 insertions(+), 39 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ac54a011a6ee..b204e82fd1ae 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2418,12 +2418,12 @@ def __call__( if ip_img: num_of_prompts = int((query.shape[2] - 4825)/512) if num_of_prompts == 1: - attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) - #self_attend_masks = torch.zeros((4096, 4096), device=query.device) - #union_masks = torch.zeros((4096, 4096), device=query.device) - + #attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) + attention_mask = torch.zeros(1241, 1241, dtype=torch.bool, device=query.device) # text related attention mask - attention_mask[:512, :512] = torch.ones(512, 512) + attention_mask[:512, :512] = torch.ones(512, 512, dtype=torch.bool, device=query.device) + attention_mask[512:1241, 512:1241] = torch.ones(729, 729, dtype=torch.bool, device=query.device) + """ mask_downsample_t2i = IPAdapterMaskProcessor.downsample( txt_masks[0], 1, @@ -2432,25 +2432,16 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() - #mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 - #mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(512, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) attention_mask[:512,-4096:] = mask_downsample_t2i_tensor attention_mask[-4096:, :512] = mask_downsample_t2i_tensor_transpose - - """ - img_size_masks = mask_downsample_t2i_tensor_transpose[:, :1].repeat(1, 4096).to(device=query.device) - img_size_masks_transpose = img_size_masks.transpose(-1, -2).to(device=query.device) - self_attend_masks = torch.logical_or(self_attend_masks, - torch.logical_and(img_size_masks, img_size_masks_transpose)) - union_masks = torch.logical_or(union_masks, - torch.logical_or(img_size_masks, img_size_masks_transpose)) """ # text related attention mask # image related attention mask - attention_mask[512:1241, 512:1241] = torch.ones(729, 729) + #attention_mask[512:1241, 512:1241] = torch.ones(729, 729) + """ mask_downsample_t2i = IPAdapterMaskProcessor.downsample( img_mask[0], 1, @@ -2459,38 +2450,25 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() - #mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 - #mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 
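# Illustration (not part of the patch series): the masked_fill logic touched
# by the surrounding hunks converts a {0, 1} region mask into the additive
# bias expected by F.scaled_dot_product_attention, where blocked positions
# become -inf and allowed positions become 0. A self-contained toy version:
#
#     import torch
#
#     m = (torch.rand(8, 8) >= 0.5).float()  # toy binary attention mask
#     bias = m.masked_fill(m < 0.5, float("-inf")).masked_fill(m >= 0.5, 0.0)
#     # `bias` can now be passed as attn_mask to scaled_dot_product_attention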
mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) attention_mask[512:1241,-4096:] = mask_downsample_t2i_tensor attention_mask[-4096:, 512:1241] = mask_downsample_t2i_tensor_transpose - - """ - img_size_masks = mask_downsample_t2i_tensor_transpose[:, :1].repeat(1, 4096).to(device=query.device) - img_size_masks_transpose = img_size_masks.transpose(-1, -2).to(device=query.device) - self_attend_masks = torch.logical_or(self_attend_masks, - torch.logical_and(img_size_masks, img_size_masks_transpose)) - union_masks = torch.logical_or(union_masks, - torch.logical_or(img_size_masks, img_size_masks_transpose)) """ # image related attention mask + + #attention_mask[-4096:,-4096:] = 1#background_and_self_attend_masks - #background_masks = torch.logical_not(union_masks) - #background_and_self_attend_masks = torch.logical_or(background_masks, self_attend_masks) - - attention_mask[-4096:,-4096:] = 1#background_and_self_attend_masks - - zero_index = attention_mask < 0.5 - one_index = attention_mask >= 0.5 - attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) - attention_mask = attention_mask.masked_fill(one_index, 0) - attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) + #zero_index = attention_mask < 0.5 + #one_index = attention_mask >= 0.5 + #attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) + #attention_mask = attention_mask.masked_fill(one_index, 0) + #attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) hidden_states_region = F.scaled_dot_product_attention( - query, - key, - value, + query[:,:,:1241,:], + key[:,:,:1241,:], + value[:,:,:1241,:], attn_mask=attention_mask, dropout_p=0.0, is_causal=False From 87eb0a9b0c6f53f3fd07cb325f87c402734b7db2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 21:07:45 +0800 Subject: [PATCH 206/482] Revert "clear codes" This reverts commit 55cacf79b2e716e7223eb1d12b06f2043009910f. 
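[Illustration, not part of the revert above.] The design this revert restores
runs one masked attention pass over the full sequence plus separate text- and
image-guided passes, then blends the last 4096 image tokens with downsampled
spatial masks. A minimal sketch of the blending step alone, with toy shapes
and a made-up weight mask (the real code derives the weights with
IPAdapterMaskProcessor.downsample, and the two masks are assumed complementary
here for simplicity):

    import torch

    B, N, D = 1, 4096, 3072          # toy batch size, image tokens, hidden dim
    out_txt = torch.randn(B, N, D)   # image tokens from the text-guided pass
    out_img = torch.randn(B, N, D)   # image tokens from the image-guided pass
    txt_w = torch.rand(B, N, 1)      # downsampled text-region weights in [0, 1]
    img_w = 1.0 - txt_w              # complementary image-region weights
    blended = out_txt * txt_w + out_img * img_w  # per-token weighted merge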
--- src/diffusers/models/attention_processor.py | 56 ++++++++++++++------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b204e82fd1ae..ac54a011a6ee 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2418,12 +2418,12 @@ def __call__( if ip_img: num_of_prompts = int((query.shape[2] - 4825)/512) if num_of_prompts == 1: - #attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) - attention_mask = torch.zeros(1241, 1241, dtype=torch.bool, device=query.device) + attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) + #self_attend_masks = torch.zeros((4096, 4096), device=query.device) + #union_masks = torch.zeros((4096, 4096), device=query.device) + # text related attention mask - attention_mask[:512, :512] = torch.ones(512, 512, dtype=torch.bool, device=query.device) - attention_mask[512:1241, 512:1241] = torch.ones(729, 729, dtype=torch.bool, device=query.device) - """ + attention_mask[:512, :512] = torch.ones(512, 512) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( txt_masks[0], 1, @@ -2432,16 +2432,25 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() + #mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 + #mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(512, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) attention_mask[:512,-4096:] = mask_downsample_t2i_tensor attention_mask[-4096:, :512] = mask_downsample_t2i_tensor_transpose + + """ + img_size_masks = mask_downsample_t2i_tensor_transpose[:, :1].repeat(1, 4096).to(device=query.device) + img_size_masks_transpose = img_size_masks.transpose(-1, -2).to(device=query.device) + self_attend_masks = torch.logical_or(self_attend_masks, + torch.logical_and(img_size_masks, img_size_masks_transpose)) + union_masks = torch.logical_or(union_masks, + torch.logical_or(img_size_masks, img_size_masks_transpose)) """ # text related attention mask # image related attention mask - #attention_mask[512:1241, 512:1241] = torch.ones(729, 729) - """ + attention_mask[512:1241, 512:1241] = torch.ones(729, 729) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( img_mask[0], 1, @@ -2450,25 +2459,38 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() + #mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 + #mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) attention_mask[512:1241,-4096:] = mask_downsample_t2i_tensor attention_mask[-4096:, 512:1241] = mask_downsample_t2i_tensor_transpose + + """ + img_size_masks = mask_downsample_t2i_tensor_transpose[:, :1].repeat(1, 4096).to(device=query.device) + img_size_masks_transpose = img_size_masks.transpose(-1, -2).to(device=query.device) + self_attend_masks = torch.logical_or(self_attend_masks, + torch.logical_and(img_size_masks, img_size_masks_transpose)) + union_masks = torch.logical_or(union_masks, + torch.logical_or(img_size_masks, img_size_masks_transpose)) """ # image related attention mask - - 
#attention_mask[-4096:,-4096:] = 1#background_and_self_attend_masks - #zero_index = attention_mask < 0.5 - #one_index = attention_mask >= 0.5 - #attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) - #attention_mask = attention_mask.masked_fill(one_index, 0) - #attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) + #background_masks = torch.logical_not(union_masks) + #background_and_self_attend_masks = torch.logical_or(background_masks, self_attend_masks) + + attention_mask[-4096:,-4096:] = 1#background_and_self_attend_masks + + zero_index = attention_mask < 0.5 + one_index = attention_mask >= 0.5 + attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) + attention_mask = attention_mask.masked_fill(one_index, 0) + attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) hidden_states_region = F.scaled_dot_product_attention( - query[:,:,:1241,:], - key[:,:,:1241,:], - value[:,:,:1241,:], + query, + key, + value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From 67fcb23e4f995160e091f9640fa6d9f4f43ff74e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 21:09:43 +0800 Subject: [PATCH 207/482] clear codes --- src/diffusers/models/attention_processor.py | 34 ++------------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ac54a011a6ee..c4b2946c234d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2419,9 +2419,6 @@ def __call__( num_of_prompts = int((query.shape[2] - 4825)/512) if num_of_prompts == 1: attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) - #self_attend_masks = torch.zeros((4096, 4096), device=query.device) - #union_masks = torch.zeros((4096, 4096), device=query.device) - # text related attention mask attention_mask[:512, :512] = torch.ones(512, 512) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( @@ -2432,23 +2429,11 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() - #mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 - #mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(512, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) attention_mask[:512,-4096:] = mask_downsample_t2i_tensor attention_mask[-4096:, :512] = mask_downsample_t2i_tensor_transpose - """ - img_size_masks = mask_downsample_t2i_tensor_transpose[:, :1].repeat(1, 4096).to(device=query.device) - img_size_masks_transpose = img_size_masks.transpose(-1, -2).to(device=query.device) - self_attend_masks = torch.logical_or(self_attend_masks, - torch.logical_and(img_size_masks, img_size_masks_transpose)) - union_masks = torch.logical_or(union_masks, - torch.logical_or(img_size_masks, img_size_masks_transpose)) - """ - # text related attention mask - # image related attention mask attention_mask[512:1241, 512:1241] = torch.ones(729, 729) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( @@ -2459,27 +2444,12 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() - #mask_downsample_t2i[mask_downsample_t2i<0.5] = 0 - #mask_downsample_t2i[mask_downsample_t2i>=0.5] = 1 mask_downsample_t2i_tensor = 
mask_downsample_t2i.repeat(729, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) attention_mask[512:1241,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, 512:1241] = mask_downsample_t2i_tensor_transpose - - """ - img_size_masks = mask_downsample_t2i_tensor_transpose[:, :1].repeat(1, 4096).to(device=query.device) - img_size_masks_transpose = img_size_masks.transpose(-1, -2).to(device=query.device) - self_attend_masks = torch.logical_or(self_attend_masks, - torch.logical_and(img_size_masks, img_size_masks_transpose)) - union_masks = torch.logical_or(union_masks, - torch.logical_or(img_size_masks, img_size_masks_transpose)) - """ - # image related attention mask - - #background_masks = torch.logical_not(union_masks) - #background_and_self_attend_masks = torch.logical_or(background_masks, self_attend_masks) + attention_mask[-4096:, 512:1241] = mask_downsample_t2i_tensor_transpose - attention_mask[-4096:,-4096:] = 1#background_and_self_attend_masks + attention_mask[-4096:,-4096:] = 1 zero_index = attention_mask < 0.5 one_index = attention_mask >= 0.5 From a93b84ac903d668dacd650f456417628a8dc61af Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 22:12:33 +0800 Subject: [PATCH 208/482] add prod ip imag --- src/diffusers/models/attention_processor.py | 283 ++++++++++++------ .../flux/pipeline_flux_prior_redux.py | 37 ++- 2 files changed, 219 insertions(+), 101 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c4b2946c234d..765eab210a09 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2359,6 +2359,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ip_img: Optional[bool] = False, # thesea modified for ip image + ip_prod_img: Optional[bool] = False, # thesea modified for ip image img_mask: Optional[torch.Tensor] = None, # thesea modified for ip mask txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask ) -> torch.FloatTensor: @@ -2416,102 +2417,200 @@ def __call__( if txt_masks is not None: if ip_img: - num_of_prompts = int((query.shape[2] - 4825)/512) - if num_of_prompts == 1: - attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) - # text related attention mask - attention_mask[:512, :512] = torch.ones(512, 512) - mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - txt_masks[0], - 1, - 4096, - 1, - ) - mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) - mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(512, 1).to(device=query.device) - mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[:512,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, :512] = mask_downsample_t2i_tensor_transpose - - # image related attention mask - attention_mask[512:1241, 512:1241] = torch.ones(729, 729) - mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - img_mask[0], - 1, - 4096, - 1, - ) - mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) - mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) - mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 
1).to(device=query.device) - attention_mask[512:1241,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, 512:1241] = mask_downsample_t2i_tensor_transpose - - attention_mask[-4096:,-4096:] = 1 - - zero_index = attention_mask < 0.5 - one_index = attention_mask >= 0.5 - attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) - attention_mask = attention_mask.masked_fill(one_index, 0) - attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) - - hidden_states_region = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False - ) + if not ip_prod_img: + num_of_prompts = int((query.shape[2] - 4825)/512) + if num_of_prompts == 1: + attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) + # text related attention mask + attention_mask[:512, :512] = torch.ones(512, 512) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + txt_masks[0], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(512, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[:512,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, :512] = mask_downsample_t2i_tensor_transpose + + # image related attention mask + attention_mask[512:1241, 512:1241] = torch.ones(729, 729) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + img_mask[0], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[512:1241,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, 512:1241] = mask_downsample_t2i_tensor_transpose + + attention_mask[-4096:,-4096:] = 1 + + zero_index = attention_mask < 0.5 + one_index = attention_mask >= 0.5 + attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) + attention_mask = attention_mask.masked_fill(one_index, 0) + attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) + + hidden_states_region = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False + ) - hidden_states_txt = F.scaled_dot_product_attention( - torch.cat([query[:,:,:512,:], query[:,:,-4096:,:]], dim=2), - torch.cat([key[:,:,:512,:], key[:,:,-4096:,:]], dim=2), - torch.cat([value[:,:,:512,:], value[:,:,-4096:,:]], dim=2), - attn_mask=None, - dropout_p=0.0, - is_causal=False - ) + hidden_states_txt = F.scaled_dot_product_attention( + torch.cat([query[:,:,:512,:], query[:,:,-4096:,:]], dim=2), + torch.cat([key[:,:,:512,:], key[:,:,-4096:,:]], dim=2), + torch.cat([value[:,:,:512,:], value[:,:,-4096:,:]], dim=2), + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) - hidden_states_img = F.scaled_dot_product_attention( - query[:,:,512:,:], - key[:,:,512:,:], - value[:,:,512:,:], - attn_mask=None, - dropout_p=0.0, - is_causal=False - ) + hidden_states_img = F.scaled_dot_product_attention( + query[:,:,512:,:], + key[:,:,512:,:], + value[:,:,512:,:], + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) - hidden_states_region = 
hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_region = hidden_states_region.to(query.dtype) - - hidden_states_txt = hidden_states_txt.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_txt = hidden_states_txt.to(query.dtype) - - hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_img = hidden_states_img.to(query.dtype) - - txt_mask_downsample = IPAdapterMaskProcessor.downsample( - txt_masks[0], - batch_size, - 4096, - attn.heads * head_dim, - ) - txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - - img_mask_downsample = IPAdapterMaskProcessor.downsample( - img_mask[0], - batch_size, - 4096, - attn.heads * head_dim, - ) - img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_region = hidden_states_region.to(query.dtype) + + hidden_states_txt = hidden_states_txt.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_txt = hidden_states_txt.to(query.dtype) + + hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_img = hidden_states_img.to(query.dtype) + + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_masks[0], + batch_size, + 4096, + attn.heads * head_dim, + ) + txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + + img_mask_downsample = IPAdapterMaskProcessor.downsample( + img_mask[0], + batch_size, + 4096, + attn.heads * head_dim, + ) + img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + + hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample + hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1) + #hidden_states = hidden_states_region + else: + num_of_prompts = int((query.shape[2] - 5554)/512) + if num_of_prompts == 1: + attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) + # text related attention mask + attention_mask[:1241, :1241] = torch.ones(1241, 1241) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + txt_masks[0], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(1241, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[:1241,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, :1241] = mask_downsample_t2i_tensor_transpose + + # image related attention mask + attention_mask[1241:1970, 1241:1970] = torch.ones(729, 729) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + img_mask[0], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[1241:1970,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, 1241:1970] = mask_downsample_t2i_tensor_transpose + + 
attention_mask[-4096:,-4096:] = 1 + + zero_index = attention_mask < 0.5 + one_index = attention_mask >= 0.5 + attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) + attention_mask = attention_mask.masked_fill(one_index, 0) + attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) + + hidden_states_region = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False + ) + + hidden_states_txt = F.scaled_dot_product_attention( + torch.cat([query[:,:,:1241,:], query[:,:,-4096:,:]], dim=2), + torch.cat([key[:,:,:1241,:], key[:,:,-4096:,:]], dim=2), + torch.cat([value[:,:,:1241,:], value[:,:,-4096:,:]], dim=2), + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + + hidden_states_img = F.scaled_dot_product_attention( + query[:,:,1241:,:], + key[:,:,1241:,:], + value[:,:,1241:,:], + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + + hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_region = hidden_states_region.to(query.dtype) + + hidden_states_txt = hidden_states_txt.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_txt = hidden_states_txt.to(query.dtype) + + hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_img = hidden_states_img.to(query.dtype) + + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_masks[0], + batch_size, + 4096, + attn.heads * head_dim, + ) + txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + + img_mask_downsample = IPAdapterMaskProcessor.downsample( + img_mask[0], + batch_size, + 4096, + attn.heads * head_dim, + ) + img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample - hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1) - #hidden_states = hidden_states_region + hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample + hidden_states = torch.cat([hidden_states_region[:,:1970,:], hidden_states_common], dim=1) + #hidden_states = hidden_states_region else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 1b3f412c0f89..b58de55da768 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -492,19 +492,26 @@ def __call__( image_mask[k] = image_mask[k] & ~mask if product_ratio > 0.0: - composed_image = np.zeros((image_width, image_height, 3)) + composed_bg_image = np.zeros((image_width, image_height, 3)) + composed_prod_image = np.zeros((image_width, image_height, 3)) for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": - composed_image += img_array * image_mask[index] * product_ratio + composed_prod_image += img_array * image_mask[index] * product_ratio else: - composed_image += img_array * image_mask[index] + composed_bg_image += img_array * image_mask[index] - composed_image = Image.fromarray(composed_image.astype(np.uint8)) + 
composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)) + composed_prod_image = Image.fromarray(composed_prod_image.astype(np.uint8)) else: composed_image = image[0].convert('RGB') mask = Image.fromarray(product_mask.astype(np.uint8)*255).convert('RGB') - image_latents = self.encode_image(composed_image, device, 1) + + if product_ratio > 0.0: + image_latents_bg = self.encode_image(composed_bg_image, device, 1) + image_latents_prod = self.encode_image(composed_prod_image, device, 1) + else: + image_latents = self.encode_image(composed_image, device, 1) else: for img in image: img = img.convert('RGB') @@ -513,9 +520,18 @@ def __call__( image = image.convert('RGB') image_latents = self.encode_image(image, device, 1) - image_embeds = self.image_embedder(image_latents).image_embeds - image_embeds = image_embeds.to(device=device) - + if product_ratio > 0.0: + image_embeds_bg = self.image_embedder(image_latents_bg).image_embeds + image_embeds_bg = image_embeds_bg.to(device=device) + + image_embeds_prod = self.image_embedder(image_latents_prod).image_embeds + image_embeds_prod = image_embeds_prod.to(device=device) + image_embeds = torch.cat([image_embeds_prod, image_embeds_bg], , dim=1) + else: + image_embeds = self.image_embedder(image_latents).image_embeds + image_embeds = image_embeds.to(device=device) + + # 3. Prepare (dummy) text embeddings # thesea modified for ip and txt masks if hasattr(self, "text_encoder") and self.text_encoder is not None: @@ -579,7 +595,10 @@ def __call__( if not return_dict: if product_ratio is not None: - return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) + if product_ratio > 0.0: + return (prompt_embeds, pooled_prompt_embeds, composed_bg_image, composed_prod_image, mask) + else: + return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) else: return (prompt_embeds, pooled_prompt_embeds) From 758bff7db8acee82f8447ab3a48c9cf36d7890ca Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 22:14:42 +0800 Subject: [PATCH 209/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index b58de55da768..570370c6587f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -526,7 +526,7 @@ def __call__( image_embeds_prod = self.image_embedder(image_latents_prod).image_embeds image_embeds_prod = image_embeds_prod.to(device=device) - image_embeds = torch.cat([image_embeds_prod, image_embeds_bg], , dim=1) + image_embeds = torch.cat([image_embeds_prod, image_embeds_bg], dim=1) else: image_embeds = self.image_embedder(image_latents).image_embeds image_embeds = image_embeds.to(device=device) From d22977810f78a6fddda6657eedc710d1b0133f2b Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 22:52:27 +0800 Subject: [PATCH 210/482] separate prod masks from bg masks --- .../flux/pipeline_flux_prior_redux.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 570370c6587f..e7adf48f6ad3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -478,18 +478,26 @@ def __call__( 
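# Illustration (not part of the patch series): the hunk below splits the
# layer-mask bookkeeping into product and background dictionaries; in both,
# each stored mask has every later layer's mask carved out so the regions
# stay mutually exclusive. A toy version of that occlusion logic with
# boolean arrays:
#
#     import numpy as np
#
#     layer_masks = [np.random.rand(64, 64, 3) > 0.5 for _ in range(3)]
#     exclusive = {}
#     for i, m in enumerate(layer_masks):
#         exclusive[i] = m
#         for k in exclusive:
#             if k != i:
#                 exclusive[k] = exclusive[k] & ~m  # later layers occlude earlier ones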
image_array_list.append(image_array) product_mask = np.full((image_width, image_height, 3), True, dtype=bool) - image_mask = {} + image_mask_prod = {} + image_mask_bg = {} for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): if is_product.lower() == "true": product_mask = product_mask & ~mask else: product_mask = product_mask | mask - if index not in image_mask: - image_mask[index] = mask - for k in image_mask: - if k != index: - image_mask[k] = image_mask[k] & ~mask + if is_product.lower() == "true": + if index not in image_mask_prod: + image_mask_prod[index] = mask + for k in image_mask_prod: + if k != index: + image_mask_prod[k] = image_mask_prod[k] & ~mask + else: + if index not in image_mask_bg: + image_mask_bg[index] = mask + for k in image_mask_bg: + if k != index: + image_mask_bg[k] = image_mask_bg[k] & ~mask if product_ratio > 0.0: composed_bg_image = np.zeros((image_width, image_height, 3)) From 0bf8f3f07dd79a4cb2404f96952234a2f0bf1864 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 31 Mar 2025 22:54:10 +0800 Subject: [PATCH 211/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index e7adf48f6ad3..d28537804f19 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -504,9 +504,9 @@ def __call__( composed_prod_image = np.zeros((image_width, image_height, 3)) for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": - composed_prod_image += img_array * image_mask[index] * product_ratio + composed_prod_image += img_array * image_mask_prod[index] * product_ratio else: - composed_bg_image += img_array * image_mask[index] + composed_bg_image += img_array * image_mask_bg[index] composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)) composed_prod_image = Image.fromarray(composed_prod_image.astype(np.uint8)) From 5c17b5f9907b3a930489118011114cf7a94a3d88 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 1 Apr 2025 11:09:59 +0800 Subject: [PATCH 212/482] add multiprod codes --- src/diffusers/models/attention_processor.py | 206 +++++++++++++----- .../flux/pipeline_flux_prior_redux.py | 75 +++++-- 2 files changed, 212 insertions(+), 69 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 765eab210a09..87e4dc72b78b 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2359,6 +2359,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ip_img: Optional[bool] = False, # thesea modified for ip image + multiprod: Optional[bool] = False, # thesea modified for ip image ip_prod_img: Optional[bool] = False, # thesea modified for ip image img_mask: Optional[torch.Tensor] = None, # thesea modified for ip mask txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask @@ -2515,26 +2516,125 @@ def __call__( hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1) #hidden_states = hidden_states_region else: - num_of_prompts = int((query.shape[2] - 5554)/512) - if num_of_prompts == 1: + if not multiprod: + num_of_prompts = int((query.shape[2] - 5554)/512) + 
if num_of_prompts == 1: + attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) + # text related attention mask + attention_mask[:1241, :1241] = torch.ones(1241, 1241) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + txt_masks[0], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(1241, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[:1241,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, :1241] = mask_downsample_t2i_tensor_transpose + + # image related attention mask + attention_mask[1241:1970, 1241:1970] = torch.ones(729, 729) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + img_mask[0], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[1241:1970,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, 1241:1970] = mask_downsample_t2i_tensor_transpose + + attention_mask[-4096:,-4096:] = 1 + + zero_index = attention_mask < 0.5 + one_index = attention_mask >= 0.5 + attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) + attention_mask = attention_mask.masked_fill(one_index, 0) + attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) + + hidden_states_region = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False + ) + + hidden_states_txt = F.scaled_dot_product_attention( + torch.cat([query[:,:,:1241,:], query[:,:,-4096:,:]], dim=2), + torch.cat([key[:,:,:1241,:], key[:,:,-4096:,:]], dim=2), + torch.cat([value[:,:,:1241,:], value[:,:,-4096:,:]], dim=2), + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + + hidden_states_img = F.scaled_dot_product_attention( + query[:,:,1241:,:], + key[:,:,1241:,:], + value[:,:,1241:,:], + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + + hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_region = hidden_states_region.to(query.dtype) + + hidden_states_txt = hidden_states_txt.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_txt = hidden_states_txt.to(query.dtype) + + hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_img = hidden_states_img.to(query.dtype) + + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_masks[0], + batch_size, + 4096, + attn.heads * head_dim, + ) + txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + + img_mask_downsample = IPAdapterMaskProcessor.downsample( + img_mask[0], + batch_size, + 4096, + attn.heads * head_dim, + ) + img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + + hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample + hidden_states = torch.cat([hidden_states_region[:,:1970,:], hidden_states_common], dim=1) + #hidden_states = hidden_states_region + else: + num_of_prods 
= int((query.shape[2] - 4825)/1241) attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) # text related attention mask - attention_mask[:1241, :1241] = torch.ones(1241, 1241) - mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - txt_masks[0], - 1, - 4096, - 1, - ) - mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) - mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(1241, 1).to(device=query.device) - mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[:1241,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, :1241] = mask_downsample_t2i_tensor_transpose + for index in range(len(txt_masks)): + attention_mask[index*1241:(index+1)*1241, index*1241:(index+1)*1241] = torch.ones(1241, 1241) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + txt_masks[index], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(1241, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[index*1241:(index+1)*1241,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, index*1241:(index+1)*1241] = mask_downsample_t2i_tensor_transpose + # image related attention mask - attention_mask[1241:1970, 1241:1970] = torch.ones(729, 729) + attention_mask[len(txt_masks)*1241:len(txt_masks)*1241 + 729, len(txt_masks)*1241:len(txt_masks)*1241 + 729] = torch.ones(729, 729) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( img_mask[0], 1, @@ -2545,11 +2645,12 @@ def __call__( mask_downsample_t2i = mask_downsample_t2i.squeeze() mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[1241:1970,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, 1241:1970] = mask_downsample_t2i_tensor_transpose - + attention_mask[len(txt_masks)*1241:len(txt_masks)*1241 + 729,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, len(txt_masks)*1241:len(txt_masks)*1241 + 729] = mask_downsample_t2i_tensor_transpose + + attention_mask[-4096:,-4096:] = 1 - + zero_index = attention_mask < 0.5 one_index = attention_mask >= 0.5 attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) @@ -2564,42 +2665,45 @@ def __call__( dropout_p=0.0, is_causal=False ) + hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_region = hidden_states_region.to(query.dtype) + + + hidden_states_prods = [] + txt_mask_downsamples = [] + for index in range(len(txt_masks)): + hidden_states_tmp = F.scaled_dot_product_attention( + torch.cat([query[:,:,index*1241:(index+1)*1241,:], query[:,:,-4096:,:]], dim=2), + torch.cat([key[:,:,index*1241:(index+1)*1241,:], key[:,:,-4096:,:]], dim=2), + torch.cat([value[:,:,index*1241:(index+1)*1241,:], value[:,:,-4096:,:]], dim=2), + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + hidden_states_tmp = hidden_states_tmp.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_tmp = hidden_states_tmp.to(query.dtype) + hidden_states_prods.append(hidden_states_tmp) - hidden_states_txt = F.scaled_dot_product_attention( - 
torch.cat([query[:,:,:1241,:], query[:,:,-4096:,:]], dim=2), - torch.cat([key[:,:,:1241,:], key[:,:,-4096:,:]], dim=2), - torch.cat([value[:,:,:1241,:], value[:,:,-4096:,:]], dim=2), - attn_mask=None, - dropout_p=0.0, - is_causal=False - ) - + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_masks[index], + batch_size, + 4096, + attn.heads * head_dim, + ) + txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + txt_mask_downsamples.append(txt_mask_downsample) + hidden_states_img = F.scaled_dot_product_attention( - query[:,:,1241:,:], - key[:,:,1241:,:], - value[:,:,1241:,:], + query[:,:,-4825:,:], + key[:,:,-4825:,:], + value[:,:,-4825:,:], attn_mask=None, dropout_p=0.0, is_causal=False ) - - hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_region = hidden_states_region.to(query.dtype) - - hidden_states_txt = hidden_states_txt.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_txt = hidden_states_txt.to(query.dtype) - hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states_img = hidden_states_img.to(query.dtype) - txt_mask_downsample = IPAdapterMaskProcessor.downsample( - txt_masks[0], - batch_size, - 4096, - attn.heads * head_dim, - ) - txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - img_mask_downsample = IPAdapterMaskProcessor.downsample( img_mask[0], batch_size, @@ -2608,9 +2712,11 @@ def __call__( ) img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample - hidden_states = torch.cat([hidden_states_region[:,:1970,:], hidden_states_common], dim=1) - #hidden_states = hidden_states_region + hidden_states_common = hidden_states_img[:,-4096:,:] * img_mask_downsample + for hidden_states_prod, txt_mask_downsample in zip(hidden_states_prods, txt_mask_downsamples): + hidden_states_common += hidden_states_prod[:,-4096:,:] * txt_mask_downsample + + hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index d28537804f19..5c276c3725f2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -381,6 +381,8 @@ def __call__( pooled_prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, + multiprod: Optional[bool] = False, # thesea modified for ip image + multiprod_prompts: Optional[Union[str, List[str]]] = None, product_ratio: Optional[float] = None, # theseam modified image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, @@ -500,24 +502,42 @@ def __call__( image_mask_bg[k] = image_mask_bg[k] & ~mask if product_ratio > 0.0: - composed_bg_image = np.zeros((image_width, image_height, 3)) - composed_prod_image = np.zeros((image_width, image_height, 3)) - for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): - if is_product.lower() == "true": - 
composed_prod_image += img_array * image_mask_prod[index] * product_ratio - else: - composed_bg_image += img_array * image_mask_bg[index] - - composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)) - composed_prod_image = Image.fromarray(composed_prod_image.astype(np.uint8)) + if not multiprod: + composed_bg_image = np.zeros((image_width, image_height, 3)) + composed_prod_image = np.zeros((image_width, image_height, 3)) + for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): + if is_product.lower() == "true": + composed_prod_image += img_array * image_mask_prod[index] * product_ratio + else: + composed_bg_image += img_array * image_mask_bg[index] + + composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)) + composed_prod_image = Image.fromarray(composed_prod_image.astype(np.uint8)) + else: + composed_bg_image = np.zeros((image_width, image_height, 3)) + composed_prod_images = [] + for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): + if is_product.lower() == "true": + composed_prod_image = img_array * image_mask_prod[index] * product_ratio + composed_prod_images.append(Image.fromarray(composed_prod_image.astype(np.uint8))) + else: + composed_bg_image += img_array * image_mask_bg[index] + + composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)) else: composed_image = image[0].convert('RGB') mask = Image.fromarray(product_mask.astype(np.uint8)*255).convert('RGB') if product_ratio > 0.0: - image_latents_bg = self.encode_image(composed_bg_image, device, 1) - image_latents_prod = self.encode_image(composed_prod_image, device, 1) + if not multiprod: + image_latents_bg = self.encode_image(composed_bg_image, device, 1) + image_latents_prod = self.encode_image(composed_prod_image, device, 1) + else: + image_latents_bg = self.encode_image(composed_bg_image, device, 1) + image_latents_prods = [] + for composed_prod_image in composed_prod_images: + image_latents_prods.append(self.encode_image(composed_prod_image, device, 1)) else: image_latents = self.encode_image(composed_image, device, 1) else: @@ -529,12 +549,22 @@ def __call__( image_latents = self.encode_image(image, device, 1) if product_ratio > 0.0: - image_embeds_bg = self.image_embedder(image_latents_bg).image_embeds - image_embeds_bg = image_embeds_bg.to(device=device) + if not multiprod: + image_embeds_bg = self.image_embedder(image_latents_bg).image_embeds + image_embeds_bg = image_embeds_bg.to(device=device) - image_embeds_prod = self.image_embedder(image_latents_prod).image_embeds - image_embeds_prod = image_embeds_prod.to(device=device) - image_embeds = torch.cat([image_embeds_prod, image_embeds_bg], dim=1) + image_embeds_prod = self.image_embedder(image_latents_prod).image_embeds + image_embeds_prod = image_embeds_prod.to(device=device) + image_embeds = torch.cat([image_embeds_prod, image_embeds_bg], dim=1) + else: + image_embeds_bg = self.image_embedder(image_latents_bg).image_embeds + image_embeds_bg = image_embeds_bg.to(device=device) + + image_embeds_prods = [] + for image_latents_prod in image_latents_prods: + image_embeds_prod = self.image_embedder(image_latents_prod).image_embeds + image_embeds_prod = image_embeds_prod.to(device=device) + image_embeds_prods.append(image_embeds_prod) else: image_embeds = self.image_embedder(image_latents).image_embeds image_embeds = image_embeds.to(device=device) @@ -559,7 +589,8 @@ def __call__( lora_scale=None, ) prompt_embeds_list.append(prompt_embeds) - prompt_embeds = 
torch.cat(prompt_embeds_list, dim=1) + if not multiprod: + prompt_embeds = torch.cat(prompt_embeds_list, dim=1) else: ( prompt_embeds, @@ -587,7 +618,13 @@ def __call__( pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) # scale & concatenate image and text embeddings - prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) + if not multiprod: + prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) + else: + prompt_embeds = image_embeds_bg + for tmp_prompt_embeds, tmp_image_embeds_prod in zip(prompt_embeds_list, image_embeds_prods): + prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod, prompt_embeds], dim=1) + print(f'prompt_embeds shape={prompt_embeds.shape}') prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[ From b6f622ff53944cb2980d76b3041e22066d43021f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 1 Apr 2025 11:23:10 +0800 Subject: [PATCH 213/482] add composed_image_all for hed extraction --- .../pipelines/flux/pipeline_flux_prior_redux.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 5c276c3725f2..dcc986edda1b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -514,16 +514,20 @@ def __call__( composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)) composed_prod_image = Image.fromarray(composed_prod_image.astype(np.uint8)) else: + composed_image_all = np.zeros((image_width, image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) composed_prod_images = [] for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": composed_prod_image = img_array * image_mask_prod[index] * product_ratio composed_prod_images.append(Image.fromarray(composed_prod_image.astype(np.uint8))) + composed_image_all += img_array * image_mask_prod[index] else: composed_bg_image += img_array * image_mask_bg[index] - - composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)) + composed_image_all += img_array * image_mask_bg[index] + + composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB') + composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') else: composed_image = image[0].convert('RGB') @@ -641,7 +645,10 @@ def __call__( if not return_dict: if product_ratio is not None: if product_ratio > 0.0: - return (prompt_embeds, pooled_prompt_embeds, composed_bg_image, composed_prod_image, mask) + if not multiprod: + return (prompt_embeds, pooled_prompt_embeds, composed_bg_image, composed_prod_image, mask) + else: + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, mask) else: return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) else: From 608694d4a4ede854e511e9e238b97b8ed20795e7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 1 Apr 2025 11:43:46 +0800 Subject: [PATCH 214/482] remove check input --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git 
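# The concatenation above builds one long conditioning sequence: a (text,
# product-image) pair per product, prepended in front of the background image
# tokens. A sketch with the token counts used in this series (512 T5 text
# tokens, 729 Redux image tokens); note that prepending reverses the product
# order unless the lists are walked back-to-front, which a later commit in
# this series ("fix order bug") addresses:
import torch

def build_prior_sequence(txt_embeds, prod_embeds, bg_embeds):
    # txt_embeds: list of (B, 512, C); prod_embeds: list of (B, 729, C)
    seq = bg_embeds                                   # (B, 729, C)
    for txt, prod in zip(reversed(txt_embeds), reversed(prod_embeds)):
        seq = torch.cat([txt, prod, seq], dim=1)      # prepend pair i
    return seq

b, c = 1, 4096
txts = [torch.randn(b, 512, c) for _ in range(2)]
prods = [torch.randn(b, 729, c) for _ in range(2)]
seq = build_prior_sequence(txts, prods, torch.randn(b, 729, c))
assert seq.shape[1] == 2 * (512 + 729) + 729          # 3211 tokens for 2 products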
a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index dcc986edda1b..8ce9a22bf9b4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -419,6 +419,7 @@ def __call__( """ # 1. Check inputs. Raise error if not correct + """ self.check_inputs( image, prompt, @@ -428,6 +429,7 @@ def __call__( prompt_embeds_scale=prompt_embeds_scale, pooled_prompt_embeds_scale=pooled_prompt_embeds_scale, ) + """ # 2. Define call parameters if image is not None and isinstance(image, Image.Image): @@ -532,6 +534,9 @@ def __call__( composed_image = image[0].convert('RGB') mask = Image.fromarray(product_mask.astype(np.uint8)*255).convert('RGB') + prod_masks = [] + for tmp_mask in image_mask_prod: + prod_masks.append(Image.fromarray(image_mask_prod[tmp_mask].astype(np.uint8)*255).convert('RGB')) if product_ratio > 0.0: if not multiprod: @@ -648,7 +653,7 @@ def __call__( if not multiprod: return (prompt_embeds, pooled_prompt_embeds, composed_bg_image, composed_prod_image, mask) else: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, mask) + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, mask) else: return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) else: From 11d41471ca1341d89e337a9cf9e2a785daad423a Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 1 Apr 2025 12:28:20 +0800 Subject: [PATCH 215/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 8ce9a22bf9b4..461e0430100d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -635,8 +635,8 @@ def __call__( prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod, prompt_embeds], dim=1) print(f'prompt_embeds shape={prompt_embeds.shape}') - prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] - pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[ + prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[:, None, None] + pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[ :, None ] From 895acd900e0bc22eaa331fae1589be2676211dff Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 1 Apr 2025 12:39:04 +0800 Subject: [PATCH 216/482] remove print --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 461e0430100d..a0e94fd59ea8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -633,7 +633,6 @@ def __call__( prompt_embeds = image_embeds_bg for tmp_prompt_embeds, tmp_image_embeds_prod in zip(prompt_embeds_list, image_embeds_prods): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod, prompt_embeds], dim=1) - print(f'prompt_embeds shape={prompt_embeds.shape}') prompt_embeds *= 
torch.tensor(prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[:, None, None] pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[ From 821c6a6bc0c0501f6a12d2740a428dce592ee150 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 1 Apr 2025 12:49:48 +0800 Subject: [PATCH 217/482] fix bug --- .../pipelines/flux/pipeline_flux_prior_redux.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index a0e94fd59ea8..4f2c5a852604 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -484,12 +484,19 @@ def __call__( product_mask = np.full((image_width, image_height, 3), True, dtype=bool) image_mask_prod = {} image_mask_bg = {} + image_mask_all = {} for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): if is_product.lower() == "true": product_mask = product_mask & ~mask else: product_mask = product_mask | mask + if index not in image_mask_all: + image_mask_all[index] = mask + for k in image_mask_all: + if k != index: + image_mask_all[k] = image_mask_all[k] & ~mask + if is_product.lower() == "true": if index not in image_mask_prod: image_mask_prod[index] = mask @@ -523,10 +530,10 @@ def __call__( if is_product.lower() == "true": composed_prod_image = img_array * image_mask_prod[index] * product_ratio composed_prod_images.append(Image.fromarray(composed_prod_image.astype(np.uint8))) - composed_image_all += img_array * image_mask_prod[index] + composed_image_all += img_array * image_mask_all[index] else: composed_bg_image += img_array * image_mask_bg[index] - composed_image_all += img_array * image_mask_bg[index] + composed_image_all += img_array * image_mask_all[index] composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB') composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') From 34c77de6e93030011b240311ef48606047dd5464 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 1 Apr 2025 13:00:07 +0800 Subject: [PATCH 218/482] fix order bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 4f2c5a852604..e33c93d433d7 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -638,7 +638,7 @@ def __call__( prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) else: prompt_embeds = image_embeds_bg - for tmp_prompt_embeds, tmp_image_embeds_prod in zip(prompt_embeds_list, image_embeds_prods): + for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod, prompt_embeds], dim=1) prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[:, None, None] From 64c2dd3bde352b64f0889fd2391edd9545e7cae7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 1 Apr 2025 15:36:43 +0800 Subject: [PATCH 219/482] remove unnecessary para --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 1 - 1 file changed, 1 deletion(-) diff --git 
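# The image_mask_all bookkeeping in the fix above keeps the per-layer masks
# mutually exclusive: each newly seen layer mask is subtracted from every mask
# recorded before it, so a pixel belongs only to the topmost layer covering
# it. The same idea in isolation:
import numpy as np

def exclusive_masks(masks):
    # masks: list of HxWx3 boolean arrays in back-to-front paint order
    out = {}
    for index, mask in enumerate(masks):
        out[index] = mask
        for k in out:
            if k != index:
                out[k] = out[k] & ~mask   # the newer layer occludes older ones
    return out

a = np.zeros((4, 4, 3), dtype=bool); a[:2] = True     # top half
b = np.zeros((4, 4, 3), dtype=bool); b[:, :2] = True  # left half, painted later
ex = exclusive_masks([a, b])
assert not (ex[0] & ex[1]).any()                      # no pixel claimed twice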
a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index e33c93d433d7..0e1bfd7f670d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -382,7 +382,6 @@ def __call__( prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, multiprod: Optional[bool] = False, # thesea modified for ip image - multiprod_prompts: Optional[Union[str, List[str]]] = None, product_ratio: Optional[float] = None, # theseam modified image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, From 4c82bb372dc5a86a7060f0ca887719635f55a056 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 1 Apr 2025 15:44:45 +0800 Subject: [PATCH 220/482] change product_ratio behavior --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 0e1bfd7f670d..d4d2ca99d30b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -515,7 +515,7 @@ def __call__( composed_prod_image = np.zeros((image_width, image_height, 3)) for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": - composed_prod_image += img_array * image_mask_prod[index] * product_ratio + composed_prod_image += img_array * image_mask_prod[index] #* product_ratio else: composed_bg_image += img_array * image_mask_bg[index] @@ -527,7 +527,7 @@ def __call__( composed_prod_images = [] for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": - composed_prod_image = img_array * image_mask_prod[index] * product_ratio + composed_prod_image = img_array * image_mask_prod[index] #* product_ratio composed_prod_images.append(Image.fromarray(composed_prod_image.astype(np.uint8))) composed_image_all += img_array * image_mask_all[index] else: @@ -638,7 +638,7 @@ def __call__( else: prompt_embeds = image_embeds_bg for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): - prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod, prompt_embeds], dim=1) + prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod * product_ratio, prompt_embeds], dim=1) prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[:, None, None] pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[ From b9966051becdf987e0594f39f97cbb2141c54018 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 1 Apr 2025 22:49:12 +0800 Subject: [PATCH 221/482] change product ratio behavior --- src/diffusers/models/attention_processor.py | 24 ++++++++++--------- .../flux/pipeline_flux_prior_redux.py | 2 +- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 87e4dc72b78b..4e394b2eac4e 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2361,6 +2361,7 @@ def __call__( ip_img: Optional[bool] = False, # thesea modified for ip image multiprod: 
Optional[bool] = False, # thesea modified for ip image ip_prod_img: Optional[bool] = False, # thesea modified for ip image + product_ratio: Optional[float] = None, # theseam modified img_mask: Optional[torch.Tensor] = None, # thesea modified for ip mask txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask ) -> torch.FloatTensor: @@ -2614,11 +2615,12 @@ def __call__( hidden_states = torch.cat([hidden_states_region[:,:1970,:], hidden_states_common], dim=1) #hidden_states = hidden_states_region else: - num_of_prods = int((query.shape[2] - 4825)/1241) + #num_of_prods = int((query.shape[2] - 4825)/1241) attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) + prod_embeds_dim = 512 + int(729 * product_ratio) # text related attention mask for index in range(len(txt_masks)): - attention_mask[index*1241:(index+1)*1241, index*1241:(index+1)*1241] = torch.ones(1241, 1241) + attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = torch.ones(prod_embeds_dim, prod_embeds_dim) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( txt_masks[index], 1, @@ -2627,14 +2629,14 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(1241, 1).to(device=query.device) + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(prod_embeds_dim, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[index*1241:(index+1)*1241,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, index*1241:(index+1)*1241] = mask_downsample_t2i_tensor_transpose + attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = mask_downsample_t2i_tensor_transpose # image related attention mask - attention_mask[len(txt_masks)*1241:len(txt_masks)*1241 + 729, len(txt_masks)*1241:len(txt_masks)*1241 + 729] = torch.ones(729, 729) + attention_mask[len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729, len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729] = torch.ones(729, 729) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( img_mask[0], 1, @@ -2645,8 +2647,8 @@ def __call__( mask_downsample_t2i = mask_downsample_t2i.squeeze() mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[len(txt_masks)*1241:len(txt_masks)*1241 + 729,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, len(txt_masks)*1241:len(txt_masks)*1241 + 729] = mask_downsample_t2i_tensor_transpose + attention_mask[len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729] = mask_downsample_t2i_tensor_transpose attention_mask[-4096:,-4096:] = 1 @@ -2673,9 +2675,9 @@ def __call__( txt_mask_downsamples = [] for index in range(len(txt_masks)): hidden_states_tmp = F.scaled_dot_product_attention( - torch.cat([query[:,:,index*1241:(index+1)*1241,:], query[:,:,-4096:,:]], dim=2), - torch.cat([key[:,:,index*1241:(index+1)*1241,:], key[:,:,-4096:,:]], dim=2), - 
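# The attention mask assembled above is block-structured: each product's token
# block attends within itself and, through its spatial mask, to the 4096
# latent tokens, which attend back symmetrically. A toy-sized sketch of the
# same construction (real token counts are 512 + int(729 * product_ratio) per
# block, 729 for the background image, 4096 latents):
import torch

def block_attention_mask(block, n_blocks, n_latent, latent_masks):
    # latent_masks: (n_blocks, n_latent) 0/1 spatial weights, one row per block
    n = n_blocks * block + n_latent
    m = torch.zeros(n, n)
    for i in range(n_blocks):
        s = i * block
        m[s:s + block, s:s + block] = 1                       # block sees itself
        m[s:s + block, -n_latent:] = latent_masks[i]          # block -> masked latents
        m[-n_latent:, s:s + block] = latent_masks[i][:, None] # latents -> masked block
    m[-n_latent:, -n_latent:] = 1                             # latents see each other
    return m

masks = torch.tensor([[1., 1., 0., 0.], [0., 0., 1., 1.]])
print(block_attention_mask(block=3, n_blocks=2, n_latent=4, latent_masks=masks))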
torch.cat([value[:,:,index*1241:(index+1)*1241,:], value[:,:,-4096:,:]], dim=2), + torch.cat([query[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], query[:,:,-4096:,:]], dim=2), + torch.cat([key[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], key[:,:,-4096:,:]], dim=2), + torch.cat([value[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], value[:,:,-4096:,:]], dim=2), attn_mask=None, dropout_p=0.0, is_causal=False diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index d4d2ca99d30b..3f2d1b4cbcbe 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -638,7 +638,7 @@ def __call__( else: prompt_embeds = image_embeds_bg for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): - prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod * product_ratio, prompt_embeds], dim=1) + prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[:, None, None] pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[ From 00f60c5b8affbb3ca1baaa15801e5e7212b0c89c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 10:45:47 +0800 Subject: [PATCH 222/482] add product ratio to single product --- .../pipelines/flux/pipeline_flux_prior_redux.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 3f2d1b4cbcbe..19dc85bf97b2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -511,14 +511,17 @@ def __call__( if product_ratio > 0.0: if not multiprod: + composed_image_all = np.zeros((image_width, image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) composed_prod_image = np.zeros((image_width, image_height, 3)) for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": - composed_prod_image += img_array * image_mask_prod[index] #* product_ratio + composed_prod_image += img_array * image_mask_prod[index] else: composed_bg_image += img_array * image_mask_bg[index] - + + composed_image_all += img_array * image_mask_all[index] + composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)) composed_prod_image = Image.fromarray(composed_prod_image.astype(np.uint8)) else: @@ -527,12 +530,12 @@ def __call__( composed_prod_images = [] for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": - composed_prod_image = img_array * image_mask_prod[index] #* product_ratio + composed_prod_image = img_array * image_mask_prod[index] composed_prod_images.append(Image.fromarray(composed_prod_image.astype(np.uint8))) - composed_image_all += img_array * image_mask_all[index] else: composed_bg_image += img_array * image_mask_bg[index] - composed_image_all += img_array * image_mask_all[index] + + composed_image_all += img_array * image_mask_all[index] composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB') composed_image_all = 
Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') @@ -634,7 +637,7 @@ def __call__( # scale & concatenate image and text embeddings if not multiprod: - prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) + prompt_embeds = torch.cat([prompt_embeds, image_embeds[:,:int(729*product_ratio),:]], dim=1) else: prompt_embeds = image_embeds_bg for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): @@ -656,7 +659,7 @@ def __call__( if product_ratio is not None: if product_ratio > 0.0: if not multiprod: - return (prompt_embeds, pooled_prompt_embeds, composed_bg_image, composed_prod_image, mask) + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_image, mask) else: return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, mask) else: From e6dd38450c36c4a0c3391e3a7390e3bbbf9461ee Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 10:50:11 +0800 Subject: [PATCH 223/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 19dc85bf97b2..59d053fb49df 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -524,6 +524,7 @@ def __call__( composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)) composed_prod_image = Image.fromarray(composed_prod_image.astype(np.uint8)) + composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') else: composed_image_all = np.zeros((image_width, image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) From 0d5ee0ab0b510d619cd02e2d1d034c3660de0f56 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 11:32:35 +0800 Subject: [PATCH 224/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 59d053fb49df..baba875755e6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -625,6 +625,7 @@ def __call__( max_sequence_length=512, lora_scale=None, ) + prompt_embeds_list.append(prompt_embeds) else: if prompt is not None: logger.warning( From 83845afca0bfc3bad192adc635ffda313ba0d334 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 12:33:39 +0800 Subject: [PATCH 225/482] clear codes --- .../flux/pipeline_flux_prior_redux.py | 234 +++++++----------- 1 file changed, 92 insertions(+), 142 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index baba875755e6..cf6e879357e6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -381,8 +381,8 @@ def __call__( pooled_prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, - multiprod: Optional[bool] = False, # thesea modified for ip image - product_ratio: Optional[float] = None, # 
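# After this change, product_ratio acts as a token budget rather than a pixel
# weight: the Redux image embedder yields 729 tokens per image (presumably a
# 27x27 patch grid), and the slice keeps only the first
# int(729 * product_ratio) of them. Assuming that convention:
import torch

image_embeds = torch.randn(1, 729, 4096)          # one product's Redux tokens
for ratio in (0.25, 0.5, 1.0):
    kept = image_embeds[:, :int(729 * ratio), :]
    print(ratio, kept.shape[1])                   # keeps 182, 364, 729 tokens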
theseam modified + is_qv: Optional[bool] = False, # thesea modified for quick validation of product shots + product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, return_dict: bool = True, @@ -418,6 +418,7 @@ def __call__( """ # 1. Check inputs. Raise error if not correct + # thesea modified """ self.check_inputs( image, @@ -452,143 +453,100 @@ def __call__( device = self._execution_device # 3. Prepare image embeddings - # thesea modified # thesea modified for ip and txt masks - if isinstance(image, list): - if product_ratio is not None: - image_array_list = [] - mask_list = [] - is_product_list = [] - for img, img_type in zip(image, layer_type): - if 'product' in img_type or 'Product' in img_type: - is_product_list.append('true') - else: - is_product_list.append('false') - - img = img.convert('RGBA') - img = img.resize((image_width, image_height), resample=Image.BICUBIC) - - rgba_np = np.array(img) - mask = rgba_np[:, :, 3] - mask = mask > 0 - mask = np.stack((mask,)*3, axis=-1) - mask_list.append(mask) - - tmp_img = Image.new('RGBA', img.size, (255, 255, 255, 255)) - tmp_img.paste(img, mask=img.split()[3]) - tmp_img = tmp_img.convert('RGB') - image_array = np.asarray(tmp_img) - image_array_list.append(image_array) - - product_mask = np.full((image_width, image_height, 3), True, dtype=bool) - image_mask_prod = {} - image_mask_bg = {} - image_mask_all = {} - for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): - if is_product.lower() == "true": - product_mask = product_mask & ~mask - else: - product_mask = product_mask | mask - - if index not in image_mask_all: - image_mask_all[index] = mask - for k in image_mask_all: + if is_qv: + image_array_list = [] + mask_list = [] + is_product_list = [] + for img, img_type in zip(image, layer_type): + if 'product' in img_type or 'Product' in img_type: + is_product_list.append('true') + else: + is_product_list.append('false') + + img = img.convert('RGBA') + img = img.resize((image_width, image_height), resample=Image.BICUBIC) + + rgba_np = np.array(img) + mask = rgba_np[:, :, 3] + mask = mask > 0 + mask = np.stack((mask,)*3, axis=-1) + mask_list.append(mask) + + tmp_img = Image.new('RGBA', img.size, (255, 255, 255, 255)) + tmp_img.paste(img, mask=img.split()[3]) + tmp_img = tmp_img.convert('RGB') + image_array = np.asarray(tmp_img) + image_array_list.append(image_array) + + product_mask = np.full((image_width, image_height, 3), True, dtype=bool) + image_mask_prod = {} + image_mask_bg = {} + image_mask_all = {} + for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): + if is_product.lower() == "true": + product_mask = product_mask & ~mask + else: + product_mask = product_mask | mask + + if index not in image_mask_all: + image_mask_all[index] = mask + for k in image_mask_all: + if k != index: + image_mask_all[k] = image_mask_all[k] & ~mask + + if is_product.lower() == "true": + if index not in image_mask_prod: + image_mask_prod[index] = mask + for k in image_mask_prod: if k != index: - image_mask_all[k] = image_mask_all[k] & ~mask - - if is_product.lower() == "true": - if index not in image_mask_prod: - image_mask_prod[index] = mask - for k in image_mask_prod: - if k != index: - image_mask_prod[k] = image_mask_prod[k] & ~mask - else: - if index not in image_mask_bg: - image_mask_bg[index] = mask - for k in image_mask_bg: - if k != index: - image_mask_bg[k] = image_mask_bg[k] & ~mask - - if 
product_ratio > 0.0: - if not multiprod: - composed_image_all = np.zeros((image_width, image_height, 3)) - composed_bg_image = np.zeros((image_width, image_height, 3)) - composed_prod_image = np.zeros((image_width, image_height, 3)) - for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): - if is_product.lower() == "true": - composed_prod_image += img_array * image_mask_prod[index] - else: - composed_bg_image += img_array * image_mask_bg[index] - - composed_image_all += img_array * image_mask_all[index] - - composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)) - composed_prod_image = Image.fromarray(composed_prod_image.astype(np.uint8)) - composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') - else: - composed_image_all = np.zeros((image_width, image_height, 3)) - composed_bg_image = np.zeros((image_width, image_height, 3)) - composed_prod_images = [] - for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): - if is_product.lower() == "true": - composed_prod_image = img_array * image_mask_prod[index] - composed_prod_images.append(Image.fromarray(composed_prod_image.astype(np.uint8))) - else: - composed_bg_image += img_array * image_mask_bg[index] - - composed_image_all += img_array * image_mask_all[index] - - composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB') - composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') + image_mask_prod[k] = image_mask_prod[k] & ~mask else: - composed_image = image[0].convert('RGB') - - mask = Image.fromarray(product_mask.astype(np.uint8)*255).convert('RGB') - prod_masks = [] - for tmp_mask in image_mask_prod: - prod_masks.append(Image.fromarray(image_mask_prod[tmp_mask].astype(np.uint8)*255).convert('RGB')) - - if product_ratio > 0.0: - if not multiprod: - image_latents_bg = self.encode_image(composed_bg_image, device, 1) - image_latents_prod = self.encode_image(composed_prod_image, device, 1) - else: - image_latents_bg = self.encode_image(composed_bg_image, device, 1) - image_latents_prods = [] - for composed_prod_image in composed_prod_images: - image_latents_prods.append(self.encode_image(composed_prod_image, device, 1)) + if index not in image_mask_bg: + image_mask_bg[index] = mask + for k in image_mask_bg: + if k != index: + image_mask_bg[k] = image_mask_bg[k] & ~mask + + composed_image_all = np.zeros((image_width, image_height, 3)) + composed_bg_image = np.zeros((image_width, image_height, 3)) + composed_prod_images = [] + for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): + if is_product.lower() == "true": + composed_prod_image = img_array * image_mask_prod[index] + composed_prod_images.append(Image.fromarray(composed_prod_image.astype(np.uint8))) else: - image_latents = self.encode_image(composed_image, device, 1) - else: - for img in image: - img = img.convert('RGB') - image_latents = self.encode_image(image, device, 1) - else: - image = image.convert('RGB') - image_latents = self.encode_image(image, device, 1) - - if product_ratio > 0.0: - if not multiprod: - image_embeds_bg = self.image_embedder(image_latents_bg).image_embeds - image_embeds_bg = image_embeds_bg.to(device=device) + composed_bg_image += img_array * image_mask_bg[index] + + composed_image_all += img_array * image_mask_all[index] + composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB') + composed_image_all = 
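# The masks handed back to the caller here (e.g. prod_masks) are converted
# from boolean arrays to 8-bit RGB images with values 0 or 255, which keeps
# them drop-in compatible with image-oriented mask processors downstream.
# Minimal illustration of that conversion:
import numpy as np
from PIL import Image

mask = np.zeros((64, 64, 3), dtype=bool)
mask[16:48, 16:48] = True                                  # a square "product"
mask_img = Image.fromarray(mask.astype(np.uint8) * 255).convert("RGB")
assert mask_img.getpixel((32, 32)) == (255, 255, 255)
assert mask_img.getpixel((0, 0)) == (0, 0, 0)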
Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') + + mask = Image.fromarray(product_mask.astype(np.uint8)*255).convert('RGB') + prod_masks = [] + for tmp_mask in image_mask_prod: + prod_masks.append(Image.fromarray(image_mask_prod[tmp_mask].astype(np.uint8)*255).convert('RGB')) + + image_latents_bg = self.encode_image(composed_bg_image, device, 1) + image_latents_prods = [] + for composed_prod_image in composed_prod_images: + image_latents_prods.append(self.encode_image(composed_prod_image, device, 1)) + + image_embeds_bg = self.image_embedder(image_latents_bg).image_embeds + image_embeds_bg = image_embeds_bg.to(device=device) + + image_embeds_prods = [] + for image_latents_prod in image_latents_prods: image_embeds_prod = self.image_embedder(image_latents_prod).image_embeds image_embeds_prod = image_embeds_prod.to(device=device) - image_embeds = torch.cat([image_embeds_prod, image_embeds_bg], dim=1) - else: - image_embeds_bg = self.image_embedder(image_latents_bg).image_embeds - image_embeds_bg = image_embeds_bg.to(device=device) - - image_embeds_prods = [] - for image_latents_prod in image_latents_prods: - image_embeds_prod = self.image_embedder(image_latents_prod).image_embeds - image_embeds_prod = image_embeds_prod.to(device=device) - image_embeds_prods.append(image_embeds_prod) + image_embeds_prods.append(image_embeds_prod) else: + image = image.convert('RGB') + image_latents = self.encode_image(image, device, 1) image_embeds = self.image_embedder(image_latents).image_embeds image_embeds = image_embeds.to(device=device) - # 3. Prepare (dummy) text embeddings # thesea modified for ip and txt masks if hasattr(self, "text_encoder") and self.text_encoder is not None: @@ -608,8 +566,6 @@ def __call__( lora_scale=None, ) prompt_embeds_list.append(prompt_embeds) - if not multiprod: - prompt_embeds = torch.cat(prompt_embeds_list, dim=1) else: ( prompt_embeds, @@ -625,7 +581,7 @@ def __call__( max_sequence_length=512, lora_scale=None, ) - prompt_embeds_list.append(prompt_embeds) + prompt_embeds_list.append(prompt_embeds) # thesea modified for ip and txt masks else: if prompt is not None: logger.warning( @@ -638,12 +594,12 @@ def __call__( pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) # scale & concatenate image and text embeddings - if not multiprod: - prompt_embeds = torch.cat([prompt_embeds, image_embeds[:,:int(729*product_ratio),:]], dim=1) - else: + if is_qv: prompt_embeds = image_embeds_bg for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) + else: + prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[:, None, None] pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[ @@ -658,14 +614,8 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - if product_ratio is not None: - if product_ratio > 0.0: - if not multiprod: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_image, mask) - else: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, mask) - else: - return (prompt_embeds, pooled_prompt_embeds, composed_image, mask) + if is_qv + return (prompt_embeds, 
pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, mask) else: return (prompt_embeds, pooled_prompt_embeds) From 16ebf1b09fc8c1add9bd03713b46d9c5a3431111 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 12:34:32 +0800 Subject: [PATCH 226/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index cf6e879357e6..53a5c9e4cae5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -614,7 +614,7 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - if is_qv + if is_qv: return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, mask) else: return (prompt_embeds, pooled_prompt_embeds) From d5c008bd6d625ee694a8fd2ce7f79436bf990350 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 12:58:43 +0800 Subject: [PATCH 227/482] clear codes --- src/diffusers/models/attention_processor.py | 396 +++++--------------- 1 file changed, 98 insertions(+), 298 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4e394b2eac4e..c9fab6d7ef23 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2358,9 +2358,7 @@ def __call__( encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - ip_img: Optional[bool] = False, # thesea modified for ip image - multiprod: Optional[bool] = False, # thesea modified for ip image - ip_prod_img: Optional[bool] = False, # thesea modified for ip image + ip_qv: Optional[bool] = False, # thesea modified for ip image product_ratio: Optional[float] = None, # theseam modified img_mask: Optional[torch.Tensor] = None, # thesea modified for ip mask txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask @@ -2417,308 +2415,110 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - if txt_masks is not None: - if ip_img: - if not ip_prod_img: - num_of_prompts = int((query.shape[2] - 4825)/512) - if num_of_prompts == 1: - attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) - # text related attention mask - attention_mask[:512, :512] = torch.ones(512, 512) - mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - txt_masks[0], - 1, - 4096, - 1, - ) - mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) - mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(512, 1).to(device=query.device) - mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[:512,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, :512] = mask_downsample_t2i_tensor_transpose - - # image related attention mask - attention_mask[512:1241, 512:1241] = torch.ones(729, 729) - mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - img_mask[0], - 1, - 4096, - 1, - ) - mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) - mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = 
mask_downsample_t2i.repeat(729, 1).to(device=query.device) - mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[512:1241,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, 512:1241] = mask_downsample_t2i_tensor_transpose - - attention_mask[-4096:,-4096:] = 1 - - zero_index = attention_mask < 0.5 - one_index = attention_mask >= 0.5 - attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) - attention_mask = attention_mask.masked_fill(one_index, 0) - attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) - - hidden_states_region = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False - ) - - hidden_states_txt = F.scaled_dot_product_attention( - torch.cat([query[:,:,:512,:], query[:,:,-4096:,:]], dim=2), - torch.cat([key[:,:,:512,:], key[:,:,-4096:,:]], dim=2), - torch.cat([value[:,:,:512,:], value[:,:,-4096:,:]], dim=2), - attn_mask=None, - dropout_p=0.0, - is_causal=False - ) - - hidden_states_img = F.scaled_dot_product_attention( - query[:,:,512:,:], - key[:,:,512:,:], - value[:,:,512:,:], - attn_mask=None, - dropout_p=0.0, - is_causal=False - ) - - hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_region = hidden_states_region.to(query.dtype) - - hidden_states_txt = hidden_states_txt.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_txt = hidden_states_txt.to(query.dtype) - - hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_img = hidden_states_img.to(query.dtype) - - txt_mask_downsample = IPAdapterMaskProcessor.downsample( - txt_masks[0], - batch_size, - 4096, - attn.heads * head_dim, - ) - txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - - img_mask_downsample = IPAdapterMaskProcessor.downsample( - img_mask[0], - batch_size, - 4096, - attn.heads * head_dim, - ) - img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - - hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample - hidden_states = torch.cat([hidden_states_region[:,:1241,:], hidden_states_common], dim=1) - #hidden_states = hidden_states_region - else: - if not multiprod: - num_of_prompts = int((query.shape[2] - 5554)/512) - if num_of_prompts == 1: - attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) - # text related attention mask - attention_mask[:1241, :1241] = torch.ones(1241, 1241) - mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - txt_masks[0], - 1, - 4096, - 1, - ) - mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) - mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(1241, 1).to(device=query.device) - mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[:1241,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, :1241] = mask_downsample_t2i_tensor_transpose - - # image related attention mask - attention_mask[1241:1970, 1241:1970] = torch.ones(729, 729) - mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - img_mask[0], - 1, - 4096, - 1, - ) - mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) - mask_downsample_t2i = 
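# IPAdapterMaskProcessor.downsample, used throughout this hunk, reduces a
# full-resolution spatial mask to one weight per latent token. Its core can
# be sketched as interpolate-then-flatten (illustrative only; the real helper
# also handles aspect ratio, padding and repeating across heads):
import torch
import torch.nn.functional as F

def downsample_mask(mask, num_tokens):
    # mask: (B, H, W) in [0, 1]; returns (B, num_tokens, 1)
    side = int(num_tokens ** 0.5)                        # e.g. 4096 -> 64x64 grid
    m = F.interpolate(mask[:, None].float(), size=(side, side), mode="nearest")
    return m.flatten(2).transpose(1, 2)

m = torch.zeros(1, 512, 512)
m[:, :, :256] = 1.0                                      # left half of the image
tokens = downsample_mask(m, 4096)
assert tokens.shape == (1, 4096, 1)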
mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) - mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[1241:1970,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, 1241:1970] = mask_downsample_t2i_tensor_transpose - - attention_mask[-4096:,-4096:] = 1 - - zero_index = attention_mask < 0.5 - one_index = attention_mask >= 0.5 - attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) - attention_mask = attention_mask.masked_fill(one_index, 0) - attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) - - hidden_states_region = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False - ) - - hidden_states_txt = F.scaled_dot_product_attention( - torch.cat([query[:,:,:1241,:], query[:,:,-4096:,:]], dim=2), - torch.cat([key[:,:,:1241,:], key[:,:,-4096:,:]], dim=2), - torch.cat([value[:,:,:1241,:], value[:,:,-4096:,:]], dim=2), - attn_mask=None, - dropout_p=0.0, - is_causal=False - ) - - hidden_states_img = F.scaled_dot_product_attention( - query[:,:,1241:,:], - key[:,:,1241:,:], - value[:,:,1241:,:], - attn_mask=None, - dropout_p=0.0, - is_causal=False - ) - - hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_region = hidden_states_region.to(query.dtype) - - hidden_states_txt = hidden_states_txt.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_txt = hidden_states_txt.to(query.dtype) - - hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_img = hidden_states_img.to(query.dtype) - - txt_mask_downsample = IPAdapterMaskProcessor.downsample( - txt_masks[0], - batch_size, - 4096, - attn.heads * head_dim, - ) - txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + if is_qv is not None: + attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) + prod_embeds_dim = 512 + int(729 * product_ratio) + # text related attention mask + for index in range(len(txt_masks)): + attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = torch.ones(prod_embeds_dim, prod_embeds_dim) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + txt_masks[index], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(prod_embeds_dim, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = mask_downsample_t2i_tensor_transpose + - img_mask_downsample = IPAdapterMaskProcessor.downsample( - img_mask[0], - batch_size, - 4096, - attn.heads * head_dim, - ) - img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + # image related attention mask + attention_mask[len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729, len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729] = torch.ones(729, 729) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + 
img_mask[0], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729] = mask_downsample_t2i_tensor_transpose - hidden_states_common = hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_img[:,-4096:,:] * img_mask_downsample - hidden_states = torch.cat([hidden_states_region[:,:1970,:], hidden_states_common], dim=1) - #hidden_states = hidden_states_region - else: - #num_of_prods = int((query.shape[2] - 4825)/1241) - attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) - prod_embeds_dim = 512 + int(729 * product_ratio) - # text related attention mask - for index in range(len(txt_masks)): - attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = torch.ones(prod_embeds_dim, prod_embeds_dim) - mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - txt_masks[index], - 1, - 4096, - 1, - ) - mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) - mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(prod_embeds_dim, 1).to(device=query.device) - mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = mask_downsample_t2i_tensor_transpose - - # image related attention mask - attention_mask[len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729, len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729] = torch.ones(729, 729) - mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - img_mask[0], - 1, - 4096, - 1, - ) - mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) - mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) - mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729] = mask_downsample_t2i_tensor_transpose - - - attention_mask[-4096:,-4096:] = 1 - - zero_index = attention_mask < 0.5 - one_index = attention_mask >= 0.5 - attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) - attention_mask = attention_mask.masked_fill(one_index, 0) - attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) - - hidden_states_region = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False - ) - hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_region = hidden_states_region.to(query.dtype) - - - hidden_states_prods = [] - txt_mask_downsamples = [] - for index in 
range(len(txt_masks)): - hidden_states_tmp = F.scaled_dot_product_attention( - torch.cat([query[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], query[:,:,-4096:,:]], dim=2), - torch.cat([key[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], key[:,:,-4096:,:]], dim=2), - torch.cat([value[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], value[:,:,-4096:,:]], dim=2), - attn_mask=None, - dropout_p=0.0, - is_causal=False - ) - hidden_states_tmp = hidden_states_tmp.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_tmp = hidden_states_tmp.to(query.dtype) - hidden_states_prods.append(hidden_states_tmp) + attention_mask[-4096:,-4096:] = 1 + + zero_index = attention_mask < 0.5 + one_index = attention_mask >= 0.5 + attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) + attention_mask = attention_mask.masked_fill(one_index, 0) + attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) + + hidden_states_region = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False + ) + hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_region = hidden_states_region.to(query.dtype) + - txt_mask_downsample = IPAdapterMaskProcessor.downsample( - txt_masks[index], - batch_size, - 4096, - attn.heads * head_dim, - ) - txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - txt_mask_downsamples.append(txt_mask_downsample) - - hidden_states_img = F.scaled_dot_product_attention( - query[:,:,-4825:,:], - key[:,:,-4825:,:], - value[:,:,-4825:,:], - attn_mask=None, - dropout_p=0.0, - is_causal=False - ) - hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_img = hidden_states_img.to(query.dtype) + hidden_states_prods = [] + txt_mask_downsamples = [] + for index in range(len(txt_masks)): + hidden_states_tmp = F.scaled_dot_product_attention( + torch.cat([query[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], query[:,:,-4096:,:]], dim=2), + torch.cat([key[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], key[:,:,-4096:,:]], dim=2), + torch.cat([value[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], value[:,:,-4096:,:]], dim=2), + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + hidden_states_tmp = hidden_states_tmp.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_tmp = hidden_states_tmp.to(query.dtype) + hidden_states_prods.append(hidden_states_tmp) - img_mask_downsample = IPAdapterMaskProcessor.downsample( - img_mask[0], - batch_size, - 4096, - attn.heads * head_dim, - ) - img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_masks[index], + batch_size, + 4096, + attn.heads * head_dim, + ) + txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + txt_mask_downsamples.append(txt_mask_downsample) + + hidden_states_img = F.scaled_dot_product_attention( + query[:,:,-4825:,:], + key[:,:,-4825:,:], + value[:,:,-4825:,:], + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_img = hidden_states_img.to(query.dtype) + + img_mask_downsample = IPAdapterMaskProcessor.downsample( + img_mask[0], + batch_size, + 4096, + attn.heads 
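# F.scaled_dot_product_attention takes an *additive* mask, so the 0/1 matrix
# assembled in this hunk is converted to 0 where attention is allowed and
# -inf where it is blocked (the masked_fill pair). The same conversion in
# isolation:
import torch
import torch.nn.functional as F

allow = torch.tensor([[1., 1., 0.],
                      [1., 1., 0.],
                      [0., 0., 1.]])
bias = allow.masked_fill(allow < 0.5, float("-inf")).masked_fill(allow >= 0.5, 0.0)

q = k = v = torch.eye(3)[None, None]             # one head, three tokens, dim 3
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
assert torch.allclose(out[0, 0, 2], v[0, 0, 2])  # token 2 only attends to itself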
* head_dim, + ) + img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states_common = hidden_states_img[:,-4096:,:] * img_mask_downsample - for hidden_states_prod, txt_mask_downsample in zip(hidden_states_prods, txt_mask_downsamples): - hidden_states_common += hidden_states_prod[:,-4096:,:] * txt_mask_downsample + hidden_states_common = hidden_states_img[:,-4096:,:] * img_mask_downsample + for hidden_states_prod, txt_mask_downsample in zip(hidden_states_prods, txt_mask_downsamples): + hidden_states_common += hidden_states_prod[:,-4096:,:] * txt_mask_downsample - hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) + hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) else: hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False From ca33a91711061ca8a5b92d2b4cb8b8ebe019660a Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 13:03:40 +0800 Subject: [PATCH 228/482] fix bug --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c9fab6d7ef23..945d4dde6e91 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2358,7 +2358,7 @@ def __call__( encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - ip_qv: Optional[bool] = False, # thesea modified for ip image + is_qv: Optional[bool] = False, # thesea modified for ip image product_ratio: Optional[float] = None, # theseam modified img_mask: Optional[torch.Tensor] = None, # thesea modified for ip mask txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask @@ -2415,7 +2415,7 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - if is_qv is not None: + if is_qv: attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) # text related attention mask From 4cf49ce068779fd9a6b2b0ce26fba2432471cf97 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 13:46:53 +0800 Subject: [PATCH 229/482] fix minor codes --- src/diffusers/models/attention_processor.py | 6 +++++ .../flux/pipeline_flux_prior_redux.py | 25 ++++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 945d4dde6e91..5a022f8affdf 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2418,6 +2418,12 @@ def __call__( if is_qv: attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) + num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim) + if num_of_prompts != len(txt_masks) + raise ValueError( + f"Length of txt_masks ({len(txt_masks)}) must match number of prompts {num_of_prompts}" + ) + # text related attention mask for index in range(len(txt_masks)): attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = torch.ones(prod_embeds_dim, prod_embeds_dim) diff --git 
a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 53a5c9e4cae5..de0e6c4bfcb9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -479,15 +479,15 @@ def __call__( image_array = np.asarray(tmp_img) image_array_list.append(image_array) - product_mask = np.full((image_width, image_height, 3), True, dtype=bool) + bg_mask = np.full((image_width, image_height, 3), True, dtype=bool) image_mask_prod = {} image_mask_bg = {} image_mask_all = {} for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): if is_product.lower() == "true": - product_mask = product_mask & ~mask + bg_mask = bg_mask & ~mask else: - product_mask = product_mask | mask + bg_mask = bg_mask | mask if index not in image_mask_all: image_mask_all[index] = mask @@ -495,13 +495,14 @@ def __call__( if k != index: image_mask_all[k] = image_mask_all[k] & ~mask - if is_product.lower() == "true": - if index not in image_mask_prod: - image_mask_prod[index] = mask - for k in image_mask_prod: - if k != index: - image_mask_prod[k] = image_mask_prod[k] & ~mask - else: + #if is_product.lower() == "true": + if index not in image_mask_prod: + image_mask_prod[index] = mask + for k in image_mask_prod: + if k != index: + image_mask_prod[k] = image_mask_prod[k] & ~mask + + if is_product.lower() != "true": if index not in image_mask_bg: image_mask_bg[index] = mask for k in image_mask_bg: @@ -523,7 +524,7 @@ def __call__( composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB') composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') - mask = Image.fromarray(product_mask.astype(np.uint8)*255).convert('RGB') + bg_mask = Image.fromarray(bg_mask.astype(np.uint8)*255).convert('RGB') prod_masks = [] for tmp_mask in image_mask_prod: prod_masks.append(Image.fromarray(image_mask_prod[tmp_mask].astype(np.uint8)*255).convert('RGB')) @@ -615,7 +616,7 @@ def __call__( if not return_dict: if is_qv: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, mask) + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, bg_mask) else: return (prompt_embeds, pooled_prompt_embeds) From 7322ed227ce51803e98ed081c81fbdb6948f9b37 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 13:50:37 +0800 Subject: [PATCH 230/482] fix bug --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5a022f8affdf..90ff1cf48d03 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2419,11 +2419,11 @@ def __call__( attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim) - if num_of_prompts != len(txt_masks) + if num_of_prompts != len(txt_masks): raise ValueError( f"Length of txt_masks ({len(txt_masks)}) must match number of prompts {num_of_prompts}" ) - + # text related attention mask for index in range(len(txt_masks)): attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = torch.ones(prod_embeds_dim, 
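# The consistency check introduced above inverts the sequence layout: total
# tokens = n_products * prod_embeds_dim + 729 background tokens + 4096
# latents, with prod_embeds_dim = 512 + int(729 * product_ratio). Worked
# through for two products at ratio 0.5:
product_ratio = 0.5
prod_embeds_dim = 512 + int(729 * product_ratio)        # 512 + 364 = 876
n_products = 2
seq_len = n_products * prod_embeds_dim + 729 + 4096     # 6577 tokens
recovered = int((seq_len - 729 - 4096) / prod_embeds_dim)
assert recovered == n_products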
prod_embeds_dim) From 33af960579b42f8232c4ffa9ce5f6cf062521b56 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 13:56:54 +0800 Subject: [PATCH 231/482] fix mask bug --- .../pipelines/flux/pipeline_flux_prior_redux.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index de0e6c4bfcb9..16e7f96e6a21 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -382,6 +382,7 @@ def __call__( prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, is_qv: Optional[bool] = False, # thesea modified for quick validation of product shots + is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, @@ -495,20 +496,20 @@ def __call__( if k != index: image_mask_all[k] = image_mask_all[k] & ~mask - #if is_product.lower() == "true": - if index not in image_mask_prod: - image_mask_prod[index] = mask - for k in image_mask_prod: - if k != index: - image_mask_prod[k] = image_mask_prod[k] & ~mask - - if is_product.lower() != "true": + if is_product.lower() == "true": + if index not in image_mask_prod: + image_mask_prod[index] = mask + else: if index not in image_mask_bg: image_mask_bg[index] = mask for k in image_mask_bg: if k != index: image_mask_bg[k] = image_mask_bg[k] & ~mask + for k in image_mask_prod: + if k != index: + image_mask_prod[k] = image_mask_prod[k] & ~mask + composed_image_all = np.zeros((image_width, image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) composed_prod_images = [] From fa55567d768dbb23997cf3a322b015b562d7a3d8 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 14:08:05 +0800 Subject: [PATCH 232/482] print info --- .../flux/pipeline_flux_prior_redux.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 16e7f96e6a21..64202c080d20 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -482,6 +482,7 @@ def __call__( bg_mask = np.full((image_width, image_height, 3), True, dtype=bool) image_mask_prod = {} + image_mask_prod_for_ip = {} image_mask_bg = {} image_mask_all = {} for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): @@ -499,6 +500,13 @@ def __call__( if is_product.lower() == "true": if index not in image_mask_prod: image_mask_prod[index] = mask + + if index not in image_mask_prod_for_ip: + image_mask_prod_for_ip[index] = mask + + for k in image_mask_prod_for_ip: + if k != index: + image_mask_prod_for_ip[k] = image_mask_prod_for_ip[k] & ~mask else: if index not in image_mask_bg: image_mask_bg[index] = mask @@ -509,13 +517,19 @@ def __call__( for k in image_mask_prod: if k != index: image_mask_prod[k] = image_mask_prod[k] & ~mask - + + if len(image_mask_prod) > 1: + if not is_multiprod: + prompt=[prompt*len(image_mask_prod)] + for index, pmt enumerate(prompt): + print(f'index={index}, prompt={pmt}') + composed_image_all = np.zeros((image_width, 
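# Note: np.asarray(PIL.Image) yields (height, width, 3) arrays, so this
# (image_width, image_height, 3) buffer ordering (also used for the masks
# above) only lines up when width == height, as with the 1024x1024 defaults.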
image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) composed_prod_images = [] for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": - composed_prod_image = img_array * image_mask_prod[index] + composed_prod_image = img_array * image_mask_prod_for_ip[index] composed_prod_images.append(Image.fromarray(composed_prod_image.astype(np.uint8))) else: composed_bg_image += img_array * image_mask_bg[index] From edf43f2ef5d1bd42ec709148bef6cb79114c13c1 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 14:11:20 +0800 Subject: [PATCH 233/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 64202c080d20..96308d9dea43 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -521,8 +521,10 @@ def __call__( if len(image_mask_prod) > 1: if not is_multiprod: prompt=[prompt*len(image_mask_prod)] - for index, pmt enumerate(prompt): - print(f'index={index}, prompt={pmt}') + + print(f'number of prompts={len(prompt)}') + for pmt in prompt: + print(f'prompt={pmt}') composed_image_all = np.zeros((image_width, image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) From cb83ce149d8104809c24eeed82d1cd1216c2d2cf Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 14:18:59 +0800 Subject: [PATCH 234/482] fix bug --- .../pipelines/flux/pipeline_flux_prior_redux.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 96308d9dea43..55a0db2ced6a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -482,7 +482,6 @@ def __call__( bg_mask = np.full((image_width, image_height, 3), True, dtype=bool) image_mask_prod = {} - image_mask_prod_for_ip = {} image_mask_bg = {} image_mask_all = {} for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): @@ -500,13 +499,6 @@ def __call__( if is_product.lower() == "true": if index not in image_mask_prod: image_mask_prod[index] = mask - - if index not in image_mask_prod_for_ip: - image_mask_prod_for_ip[index] = mask - - for k in image_mask_prod_for_ip: - if k != index: - image_mask_prod_for_ip[k] = image_mask_prod_for_ip[k] & ~mask else: if index not in image_mask_bg: image_mask_bg[index] = mask @@ -522,7 +514,7 @@ def __call__( if not is_multiprod: prompt=[prompt*len(image_mask_prod)] - print(f'number of prompts={len(prompt)}') + print(f'number of prompts={len(image_mask_prod)}') for pmt in prompt: print(f'prompt={pmt}') @@ -531,8 +523,7 @@ def __call__( composed_prod_images = [] for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": - composed_prod_image = img_array * image_mask_prod_for_ip[index] - composed_prod_images.append(Image.fromarray(composed_prod_image.astype(np.uint8))) + composed_prod_images.append(Image.fromarray(img_array.astype(np.uint8))) else: composed_bg_image += img_array * image_mask_bg[index] From 2625d5289400000c7641bad50433b47ad9835ef1 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 
Apr 2025 14:29:08 +0800 Subject: [PATCH 235/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 55a0db2ced6a..0846fdb8e681 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -512,7 +512,7 @@ def __call__( if len(image_mask_prod) > 1: if not is_multiprod: - prompt=[prompt*len(image_mask_prod)] + prompt=[prompt]*len(image_mask_prod) print(f'number of prompts={len(image_mask_prod)}') for pmt in prompt: From d88f4d913a434f8dd6737328794af27eb436f60f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 14:31:24 +0800 Subject: [PATCH 236/482] remove print --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 0846fdb8e681..a5fbd1288ed0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -513,10 +513,6 @@ def __call__( if len(image_mask_prod) > 1: if not is_multiprod: prompt=[prompt]*len(image_mask_prod) - - print(f'number of prompts={len(image_mask_prod)}') - for pmt in prompt: - print(f'prompt={pmt}') composed_image_all = np.zeros((image_width, image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) From d3418ba4f945f112081859740c90ddf49b77e49e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 14:34:15 +0800 Subject: [PATCH 237/482] remove joint_attention_kwargs in transformer single blocks --- src/diffusers/models/transformers/transformer_flux.py | 4 ++-- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index da7d6696c2f5..5df4de6c7a56 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -455,7 +455,7 @@ def forward( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) - + encoder_hidden_states = self.context_embedder(encoder_hidden_states) if txt_ids.ndim == 3: @@ -525,7 +525,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + #joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 02809ce5f9e8..4c7d099df313 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1114,7 +1114,6 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, - #joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, ) From 066c7e334bd4ae08cd6dca1a65d54a4af4bb4474 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 15:03:02 +0800 Subject: [PATCH 238/482] revise mask. 
names --- src/diffusers/models/attention_processor.py | 44 ++++++++++----------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 90ff1cf48d03..3573d1c3e249 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2360,8 +2360,8 @@ def __call__( image_rotary_emb: Optional[torch.Tensor] = None, is_qv: Optional[bool] = False, # thesea modified for ip image product_ratio: Optional[float] = None, # theseam modified - img_mask: Optional[torch.Tensor] = None, # thesea modified for ip mask - txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask + bg_mask: Optional[torch.Tensor] = None, # thesea modified for ip mask + prod_masks: Optional[torch.Tensor] = None, # thesea modified for text mask ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape @@ -2419,16 +2419,16 @@ def __call__( attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim) - if num_of_prompts != len(txt_masks): + if num_of_prompts != len(prod_masks): raise ValueError( - f"Length of txt_masks ({len(txt_masks)}) must match number of prompts {num_of_prompts}" + f"Length of prod_masks ({len(prod_masks)}) must match number of prompts {num_of_prompts}" ) # text related attention mask - for index in range(len(txt_masks)): + for index in range(len(prod_masks)): attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = torch.ones(prod_embeds_dim, prod_embeds_dim) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - txt_masks[index], + prod_masks[index], 1, 4096, 1, @@ -2442,9 +2442,9 @@ def __call__( # image related attention mask - attention_mask[len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729, len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729] = torch.ones(729, 729) + attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = torch.ones(729, 729) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( - img_mask[0], + bg_mask[0], 1, 4096, 1, @@ -2453,8 +2453,8 @@ def __call__( mask_downsample_t2i = mask_downsample_t2i.squeeze() mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, len(txt_masks)*prod_embeds_dim:len(txt_masks)*prod_embeds_dim + 729] = mask_downsample_t2i_tensor_transpose + attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = mask_downsample_t2i_tensor_transpose attention_mask[-4096:,-4096:] = 1 @@ -2478,8 +2478,8 @@ def __call__( hidden_states_prods = [] - txt_mask_downsamples = [] - for index in range(len(txt_masks)): + prod_mask_downsamples = [] + for index in range(len(prod_masks)): hidden_states_tmp = F.scaled_dot_product_attention( 
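# Sequence layout assumed by the slicing below (sketch):
#   [ prod_0 | ... | prod_{n-1} | 729 bg image tokens | 4096 latents ]
# where each prod block spans prod_embeds_dim = 512 + int(729 * product_ratio)
# tokens, so product i is re-attended together with only the shared latents: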
torch.cat([query[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], query[:,:,-4096:,:]], dim=2), torch.cat([key[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], key[:,:,-4096:,:]], dim=2), @@ -2492,14 +2492,14 @@ def __call__( hidden_states_tmp = hidden_states_tmp.to(query.dtype) hidden_states_prods.append(hidden_states_tmp) - txt_mask_downsample = IPAdapterMaskProcessor.downsample( - txt_masks[index], + prod_mask_downsample = IPAdapterMaskProcessor.downsample( + prod_masks[index], batch_size, 4096, attn.heads * head_dim, ) - txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) - txt_mask_downsamples.append(txt_mask_downsample) + prod_mask_downsample = prod_mask_downsample.to(dtype=query.dtype, device=query.device) + prod_mask_downsamples.append(prod_mask_downsample) hidden_states_img = F.scaled_dot_product_attention( query[:,:,-4825:,:], @@ -2512,17 +2512,17 @@ def __call__( hidden_states_img = hidden_states_img.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states_img = hidden_states_img.to(query.dtype) - img_mask_downsample = IPAdapterMaskProcessor.downsample( - img_mask[0], + bg_mask_downsample = IPAdapterMaskProcessor.downsample( + bg_mask[0], batch_size, 4096, attn.heads * head_dim, ) - img_mask_downsample = img_mask_downsample.to(dtype=query.dtype, device=query.device) + bg_mask_downsample = bg_mask_downsample.to(dtype=query.dtype, device=query.device) - hidden_states_common = hidden_states_img[:,-4096:,:] * img_mask_downsample - for hidden_states_prod, txt_mask_downsample in zip(hidden_states_prods, txt_mask_downsamples): - hidden_states_common += hidden_states_prod[:,-4096:,:] * txt_mask_downsample + hidden_states_common = hidden_states_img[:,-4096:,:] * bg_mask_downsample + for hidden_states_prod, prod_mask_downsample in zip(hidden_states_prods, prod_mask_downsamples): + hidden_states_common += hidden_states_prod[:,-4096:,:] * prod_mask_downsample hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) else: From 9c2cf8873036e33b6b078203095436f87cbab59a Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 16:41:57 +0800 Subject: [PATCH 239/482] add codes for txt prompt masks --- src/diffusers/models/attention_processor.py | 91 +++++++++++++++++++-- 1 file changed, 84 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 3573d1c3e249..927fffa177e8 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2358,10 +2358,11 @@ def __call__( encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - is_qv: Optional[bool] = False, # thesea modified for ip image - product_ratio: Optional[float] = None, # theseam modified - bg_mask: Optional[torch.Tensor] = None, # thesea modified for ip mask - prod_masks: Optional[torch.Tensor] = None, # thesea modified for text mask + is_qv: Optional[bool] = False, # thesea modified for quick validation + product_ratio: Optional[float] = None, # theseam modified for quick validation + bg_mask: Optional[torch.Tensor] = None, # thesea modified for quick validation + prod_masks: Optional[torch.Tensor] = None, # thesea modified for quick validation + txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask ) -> torch.FloatTensor: batch_size, _, _ = hidden_states.shape if encoder_hidden_states 
is None else encoder_hidden_states.shape @@ -2415,7 +2416,83 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - if is_qv: + # thesea modified for txt prompt masks + if txt_masks is not None: + prod_embeds_dim = 512 + num_of_prompts = int ((query.size(-2) - 4096)/prod_embeds_dim) + if num_of_prompts != len(txt_masks): + raise ValueError( + f"Length of txt_masks ({len(txt_masks)}) must match number of prompts {num_of_prompts}" + ) + + attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) + + # prod related attention mask + for index in range(len(txt_masks)): + attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = torch.ones(prod_embeds_dim, prod_embeds_dim) + mask_downsample_t2i = IPAdapterMaskProcessor.downsample( + txt_masks[index], + 1, + 4096, + 1, + ) + mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) + mask_downsample_t2i = mask_downsample_t2i.squeeze() + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(prod_embeds_dim, 1).to(device=query.device) + mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) + attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = mask_downsample_t2i_tensor_transpose + + attention_mask[-4096:,-4096:] = 1 + + zero_index = attention_mask < 0.5 + one_index = attention_mask >= 0.5 + attention_mask = attention_mask.masked_fill(zero_index, float('-inf')) + attention_mask = attention_mask.masked_fill(one_index, 0) + attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) + + hidden_states_region = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False + ) + hidden_states_region = hidden_states_region.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_region = hidden_states_region.to(query.dtype) + + hidden_states_txts = [] + txt_mask_downsamples = [] + for index in range(len(txt_masks)): + hidden_states_tmp = F.scaled_dot_product_attention( + torch.cat([query[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], query[:,:,-4096:,:]], dim=2), + torch.cat([key[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], key[:,:,-4096:,:]], dim=2), + torch.cat([value[:,:,index*prod_embeds_dim:(index+1)*prod_embeds_dim,:], value[:,:,-4096:,:]], dim=2), + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + hidden_states_tmp = hidden_states_tmp.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_tmp = hidden_states_tmp.to(query.dtype) + hidden_states_txts.append(hidden_states_tmp) + + txt_mask_downsample = IPAdapterMaskProcessor.downsample( + txt_masks[index], + batch_size, + 4096, + attn.heads * head_dim, + ) + txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) + txt_mask_downsamples.append(txt_mask_downsample) + + hidden_states_commom = torch.zeros(4096, 4096, dtype=query.dtype, device=query.device) + for hidden_states_txt, txt_mask_downsample in zip(hidden_states_txts, txt_mask_downsamples): + hidden_states_common += hidden_states_txt[:,-4096:,:] * txt_mask_downsample + + hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) + # thesea modified for quick validation of product shots + elif is_qv: attention_mask = 
torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim) @@ -2424,7 +2501,7 @@ def __call__( f"Length of prod_masks ({len(prod_masks)}) must match number of prompts {num_of_prompts}" ) - # text related attention mask + # prod related attention mask for index in range(len(prod_masks)): attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = torch.ones(prod_embeds_dim, prod_embeds_dim) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( @@ -2441,7 +2518,7 @@ def __call__( attention_mask[-4096:, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = mask_downsample_t2i_tensor_transpose - # image related attention mask + # bg related attention mask attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = torch.ones(729, 729) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( bg_mask[0], From d04443581d4255a9210d27ed11648a2d056252d6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 17:50:51 +0800 Subject: [PATCH 240/482] add joint_attention_kwargs to transformer single blocks --- src/diffusers/models/transformers/transformer_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 5df4de6c7a56..a6484ab1e059 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -525,7 +525,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - #joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual From b24ae7b9ba8c0a08867d2806dad7926584ddb736 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 20:11:43 +0800 Subject: [PATCH 241/482] fix typo --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 927fffa177e8..9eb733717552 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2486,7 +2486,7 @@ def __call__( txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) txt_mask_downsamples.append(txt_mask_downsample) - hidden_states_commom = torch.zeros(4096, 4096, dtype=query.dtype, device=query.device) + hidden_states_common = torch.zeros(4096, 4096, dtype=query.dtype, device=query.device) for hidden_states_txt, txt_mask_downsample in zip(hidden_states_txts, txt_mask_downsamples): hidden_states_common += hidden_states_txt[:,-4096:,:] * txt_mask_downsample From 6158ca5bfbdbf67e0112fe58520b316be92f7c2a Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 20:15:36 +0800 Subject: [PATCH 242/482] fix bug --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9eb733717552..47bc4438f68a 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2486,7 +2486,7 @@ def __call__( txt_mask_downsample = 
txt_mask_downsample.to(dtype=query.dtype, device=query.device) txt_mask_downsamples.append(txt_mask_downsample) - hidden_states_common = torch.zeros(4096, 4096, dtype=query.dtype, device=query.device) + hidden_states_common = torch.zeros_like(hidden_states_txt[:,-4096:,:]) for hidden_states_txt, txt_mask_downsample in zip(hidden_states_txts, txt_mask_downsamples): hidden_states_common += hidden_states_txt[:,-4096:,:] * txt_mask_downsample From 6f32c6080d932b5cd47142d69a90cecb65a28247 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 20:18:22 +0800 Subject: [PATCH 243/482] fix bug --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 47bc4438f68a..6c5e4020591d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2486,7 +2486,7 @@ def __call__( txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) txt_mask_downsamples.append(txt_mask_downsample) - hidden_states_common = torch.zeros_like(hidden_states_txt[:,-4096:,:]) + hidden_states_common = torch.zeros_like(hidden_states_txts[0][:,-4096:,:]) for hidden_states_txt, txt_mask_downsample in zip(hidden_states_txts, txt_mask_downsamples): hidden_states_common += hidden_states_txt[:,-4096:,:] * txt_mask_downsample From 8197babf4b7c71b350f37ffea67e1c6ce96a2e59 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 20:43:38 +0800 Subject: [PATCH 244/482] remove joint_attention_kwargs in transformer single blocks --- src/diffusers/models/transformers/transformer_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index a6484ab1e059..5df4de6c7a56 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -525,7 +525,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + #joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual From 9f42b7002fd444ab9364375c3271a78c33ad64eb Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 20:50:43 +0800 Subject: [PATCH 245/482] Revert "remove joint_attention_kwargs in transformer single blocks" This reverts commit 8197babf4b7c71b350f37ffea67e1c6ce96a2e59. 
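Context note: forwarding joint_attention_kwargs into the single transformer blocks is toggled back and forth several times in this series (patches 237, 240, 244, 245) before patches 255-261 settle on conditional gating. A minimal sketch of that gating, assuming joint_attention_kwargs is a plain dict; the helper name is illustrative only, not part of the diffusers API:

    def kwargs_for_single_block(index_block, joint_attention_kwargs):
        # Nothing was supplied: forward nothing.
        if joint_attention_kwargs is None:
            return None
        # 'first_N_blocks' (patch 258) restricts the masked attention to the
        # early single blocks; later blocks fall back to plain attention.
        first_n = joint_attention_kwargs.get("first_N_blocks")
        if first_n is not None and index_block >= first_n:
            return None
        return joint_attention_kwargs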
--- src/diffusers/models/transformers/transformer_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 5df4de6c7a56..a6484ab1e059 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -525,7 +525,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - #joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual From 8361a013e05b38e6db7c0c2a7607466e8ee3ba2e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 21:08:32 +0800 Subject: [PATCH 246/482] fix bug --- src/diffusers/models/attention_processor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6c5e4020591d..d1aaf72c1ee2 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2429,7 +2429,7 @@ def __call__( # prod related attention mask for index in range(len(txt_masks)): - attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = torch.ones(prod_embeds_dim, prod_embeds_dim) + attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = 1 mask_downsample_t2i = IPAdapterMaskProcessor.downsample( txt_masks[index], 1, @@ -2486,11 +2486,11 @@ def __call__( txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) txt_mask_downsamples.append(txt_mask_downsample) - hidden_states_common = torch.zeros_like(hidden_states_txts[0][:,-4096:,:]) - for hidden_states_txt, txt_mask_downsample in zip(hidden_states_txts, txt_mask_downsamples): - hidden_states_common += hidden_states_txt[:,-4096:,:] * txt_mask_downsample + hidden_states_common = hidden_states_txts[0][:,-4096:,:] * txt_mask_downsample[0] + for index in range(1, len(hidden_states_txts)): + hidden_states_common += hidden_states_txt[index][:,-4096:,:] * txt_mask_downsample[index] - hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) + hidden_states = torch.cat([hidden_states-_region[:,:-4096,:], hidden_states_common], dim=1) # thesea modified for quick validation of product shots elif is_qv: attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) @@ -2503,7 +2503,7 @@ def __call__( # prod related attention mask for index in range(len(prod_masks)): - attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = torch.ones(prod_embeds_dim, prod_embeds_dim) + attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = 1 mask_downsample_t2i = IPAdapterMaskProcessor.downsample( prod_masks[index], 1, From 0b792593bc7c234ba4a1fd28b3c7a0c5b9f8c82f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 21:10:45 +0800 Subject: [PATCH 247/482] fix bug --- src/diffusers/models/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d1aaf72c1ee2..ee678a677b53 100755 --- a/src/diffusers/models/attention_processor.py +++ 
b/src/diffusers/models/attention_processor.py @@ -2486,9 +2486,9 @@ def __call__( txt_mask_downsample = txt_mask_downsample.to(dtype=query.dtype, device=query.device) txt_mask_downsamples.append(txt_mask_downsample) - hidden_states_common = hidden_states_txts[0][:,-4096:,:] * txt_mask_downsample[0] + hidden_states_common = hidden_states_txts[0][:,-4096:,:] * txt_mask_downsamples[0] for index in range(1, len(hidden_states_txts)): - hidden_states_common += hidden_states_txt[index][:,-4096:,:] * txt_mask_downsample[index] + hidden_states_common += hidden_states_txts[index][:,-4096:,:] * txt_mask_downsamples[index] hidden_states = torch.cat([hidden_states-_region[:,:-4096,:], hidden_states_common], dim=1) # thesea modified for quick validation of product shots From 3279e4f982e76bb40a337a1df7de732a729b2bf5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 21:14:49 +0800 Subject: [PATCH 248/482] fix typo --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ee678a677b53..b592f66640ab 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2490,7 +2490,7 @@ def __call__( for index in range(1, len(hidden_states_txts)): hidden_states_common += hidden_states_txts[index][:,-4096:,:] * txt_mask_downsamples[index] - hidden_states = torch.cat([hidden_states-_region[:,:-4096,:], hidden_states_common], dim=1) + hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) # thesea modified for quick validation of product shots elif is_qv: attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) From e0f655eeb7feae444734c2815a2d47bbc321656a Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 21:39:50 +0800 Subject: [PATCH 249/482] fix bug --- src/diffusers/models/attention_processor.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b592f66640ab..4ec20a28758e 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2420,11 +2420,17 @@ def __call__( if txt_masks is not None: prod_embeds_dim = 512 num_of_prompts = int ((query.size(-2) - 4096)/prod_embeds_dim) - if num_of_prompts != len(txt_masks): + if num_of_prompts != len(txt_masks) + 1: raise ValueError( - f"Length of txt_masks ({len(txt_masks)}) must match number of prompts {num_of_prompts}" + f"Length of txt_masks ({len(txt_masks)}) + 1 must match number of prompts {num_of_prompts}" ) + last_mask = torch.ones_like(txt_masks[0]) + for tmp_mask in txt_masks: + one_index = tmp_mask == 1.0 + last_mask = last_mask.masked_fill(one_index, 0) + txt_masks.append(last_mask) + attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) # prod related attention mask From 4924ed84e3e63a09bf857c8bb9643cf8926d99b6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 21:44:48 +0800 Subject: [PATCH 250/482] fix bug --- src/diffusers/models/attention_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4ec20a28758e..9e23b0c00f19 100755 --- a/src/diffusers/models/attention_processor.py +++ 
b/src/diffusers/models/attention_processor.py @@ -2426,10 +2426,10 @@ def __call__( ) last_mask = torch.ones_like(txt_masks[0]) - for tmp_mask in txt_masks: - one_index = tmp_mask == 1.0 + for index in range(len(txt_masks)): + one_index = txt_masks[index] == 1.0 last_mask = last_mask.masked_fill(one_index, 0) - txt_masks.append(last_mask) + txt_masks = torch.cat([txt_masks, last_mask], dim=0) attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) From 8a91bf623b8499ba255e0d810e7300fb44bba2ff Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Apr 2025 21:47:54 +0800 Subject: [PATCH 251/482] fix bug --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9e23b0c00f19..b317f6329abb 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2429,7 +2429,7 @@ def __call__( for index in range(len(txt_masks)): one_index = txt_masks[index] == 1.0 last_mask = last_mask.masked_fill(one_index, 0) - txt_masks = torch.cat([txt_masks, last_mask], dim=0) + txt_masks = torch.cat([txt_masks, last_mask.unsqueeze(0)], dim=0) attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) From bc59111dac96af351cc4b594d0ee693206df7689 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 3 Apr 2025 09:42:46 +0800 Subject: [PATCH 252/482] add joint_attention_kwargs to ctln --- src/diffusers/models/controlnets/controlnet_flux.py | 4 ++-- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 3fe6608dd1c8..4dd9b64854ed 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -340,7 +340,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - #joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) block_samples = block_samples + (hidden_states,) @@ -361,7 +361,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - #joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 4c7d099df313..8e3586c68494 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1114,6 +1114,7 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, ) From 9a3862b84ff20809be15a116635782165c349705 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 3 Apr 2025 10:09:58 +0800 Subject: [PATCH 253/482] remove joint_attention_kwargs in single blocks --- src/diffusers/models/controlnets/controlnet_flux.py | 3 ++- src/diffusers/models/transformers/transformer_flux.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py 
b/src/diffusers/models/controlnets/controlnet_flux.py index 4dd9b64854ed..5569ffc8846d 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -323,6 +323,7 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) + print(f'joint_attention_kwargs keys={joint_attention_kwargs.keys()}') block_samples = () for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -361,7 +362,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + #joint_attention_kwargs=joint_attention_kwargs, ) single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index a6484ab1e059..5df4de6c7a56 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -525,7 +525,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + #joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual From 28010baa9bee8cfb150328a8e064e29f4856e1ab Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 3 Apr 2025 11:04:54 +0800 Subject: [PATCH 254/482] Revert "remove joint_attention_kwargs in single blocks" This reverts commit 9a3862b84ff20809be15a116635782165c349705. --- src/diffusers/models/controlnets/controlnet_flux.py | 3 +-- src/diffusers/models/transformers/transformer_flux.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 5569ffc8846d..4dd9b64854ed 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -323,7 +323,6 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) - print(f'joint_attention_kwargs keys={joint_attention_kwargs.keys()}') block_samples = () for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -362,7 +361,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - #joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 5df4de6c7a56..a6484ab1e059 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -525,7 +525,7 @@ def forward( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - #joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual From 004f1e8c318c4c1a8c552404420c0253993a503f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 5 Apr 2025 14:17:02 +0800 Subject: [PATCH 255/482] remove txt masks in single blocks --- src/diffusers/models/attention_processor.py | 2 +- .../models/controlnets/controlnet_flux.py | 43 +++++++++++++------ 
.../models/transformers/transformer_flux.py | 20 ++++++--- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b317f6329abb..90a7146dbd4e 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -594,7 +594,7 @@ def forward( # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states"} + quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states", "is_single_prod"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 4dd9b64854ed..35cf18802ace 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -335,13 +335,22 @@ def forward( ) else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, - ) + if 'is_single_prod' in joint_attention_kwargs: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + #joint_attention_kwargs=joint_attention_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) block_samples = block_samples + (hidden_states,) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -357,12 +366,20 @@ def forward( ) else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, - ) + if 'is_single_prod' in joint_attention_kwargs: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + #joint_attention_kwargs=joint_attention_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) # controlnet block diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index a6484ab1e059..465e8e41246a 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -521,12 +521,20 @@ def forward( ) else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, - ) + if 'is_single_prod' in joint_attention_kwargs: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + #joint_attention_kwargs=joint_attention_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + 
joint_attention_kwargs=joint_attention_kwargs, + ) # controlnet residual if controlnet_single_block_samples is not None: From d01b68dcc6cd44174dc2adc2a8b19393c4460f78 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 5 Apr 2025 17:30:19 +0800 Subject: [PATCH 256/482] add ctln logic --- src/diffusers/models/controlnets/controlnet_flux.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 35cf18802ace..e25ae333bf7a 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -335,13 +335,13 @@ def forward( ) else: - if 'is_single_prod' in joint_attention_kwargs: + if 'has_ctln' in joint_attention_kwargs: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - #joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) else: encoder_hidden_states, hidden_states = block( @@ -349,7 +349,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + #joint_attention_kwargs=joint_attention_kwargs, ) block_samples = block_samples + (hidden_states,) @@ -366,19 +366,19 @@ def forward( ) else: - if 'is_single_prod' in joint_attention_kwargs: + if 'has_ctln' in joint_attention_kwargs: hidden_states = block( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - #joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) else: hidden_states = block( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + #joint_attention_kwargs=joint_attention_kwargs, ) single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) From 63e5031d02a5841fc60b77e86cf2bc29025554d8 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 5 Apr 2025 17:54:50 +0800 Subject: [PATCH 257/482] change ctln logic --- src/diffusers/models/attention_processor.py | 2 +- .../models/controlnets/controlnet_flux.py | 30 +++++++------------ 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 90a7146dbd4e..0e8b692c4f25 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -594,7 +594,7 @@ def forward( # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states", "is_single_prod"} + quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states", "is_single_prod", "has_ctln"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index e25ae333bf7a..7f6dab53bb5e 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -335,22 +335,14 @@ def forward( ) else: - if 'has_ctln' in joint_attention_kwargs: - 
encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, - ) - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - #joint_attention_kwargs=joint_attention_kwargs, - ) + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + block_samples = block_samples + (hidden_states,) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -366,19 +358,19 @@ def forward( ) else: - if 'has_ctln' in joint_attention_kwargs: + if 'is_single_prod' in joint_attention_kwargs: hidden_states = block( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + #joint_attention_kwargs=joint_attention_kwargs, ) else: hidden_states = block( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, - #joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) From 5136e0a93f2633d3e3d7db43473a1989388787e4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 5 Apr 2025 19:13:10 +0800 Subject: [PATCH 258/482] add first_N_blocks kwargs --- src/diffusers/models/attention_processor.py | 2 +- .../models/controlnets/controlnet_flux.py | 20 +++++++++++++------ .../models/transformers/transformer_flux.py | 20 +++++++++++++------ 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0e8b692c4f25..df744d307018 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -594,7 +594,7 @@ def forward( # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states", "is_single_prod", "has_ctln"} + quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states", "is_single_prod", "first_N_blocks"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 7f6dab53bb5e..e82af4e358c1 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -366,12 +366,20 @@ def forward( #joint_attention_kwargs=joint_attention_kwargs, ) else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, - ) + if 'first_N_blocks' in joint_attention_kwargs: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs if index_block < joint_attention_kwargs['first_N_blocks'] else None, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + 
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                        joint_attention_kwargs=joint_attention_kwargs,
+                    )

            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)

        # controlnet block
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 465e8e41246a..ae68fcf214b3 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -529,12 +529,20 @@ def forward(
                     #joint_attention_kwargs=joint_attention_kwargs,
                 )
             else:
-                hidden_states = block(
-                    hidden_states=hidden_states,
-                    temb=temb,
-                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                )
+                if 'first_N_blocks' in joint_attention_kwargs:
+                    hidden_states = block(
+                        hidden_states=hidden_states,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                        joint_attention_kwargs=joint_attention_kwargs if index_block < joint_attention_kwargs['first_N_blocks'] else None,
+                    )
+                else:
+                    hidden_states = block(
+                        hidden_states=hidden_states,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                        joint_attention_kwargs=joint_attention_kwargs,
+                    )
 
         # controlnet residual
         if controlnet_single_block_samples is not None:

From e0a53113887e56c5e7b8564f50a695c53bbc71af Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 7 Apr 2025 12:22:07 +0800
Subject: [PATCH 259/482] merge blend, qv, single multiple logic

---
 .../models/controlnets/controlnet_flux.py     | 29 +++++++++++-------
 .../models/transformers/transformer_flux.py   | 30 ++++++++++++-------
 2 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py
index e82af4e358c1..667b897b0006 100644
--- a/src/diffusers/models/controlnets/controlnet_flux.py
+++ b/src/diffusers/models/controlnets/controlnet_flux.py
@@ -358,28 +358,37 @@ def forward(
                 )
 
             else:
-                if 'is_single_prod' in joint_attention_kwargs:
-                    hidden_states = block(
-                        hidden_states=hidden_states,
-                        temb=temb,
-                        image_rotary_emb=image_rotary_emb,
-                        #joint_attention_kwargs=joint_attention_kwargs,
-                    )
-                else:
-                    if 'first_N_blocks' in joint_attention_kwargs:
+                if 'is_qv' in joint_attention_kwargs:
+                    if 'is_multiprod' in joint_attention_kwargs:
                         hidden_states = block(
                             hidden_states=hidden_states,
                             temb=temb,
                             image_rotary_emb=image_rotary_emb,
-                            joint_attention_kwargs=joint_attention_kwargs if index_block < joint_attention_kwargs['first_N_blocks'] else None,
+                            joint_attention_kwargs=joint_attention_kwargs,
                         )
                     else:
+                        hidden_states = block(
+                            hidden_states=hidden_states,
+                            temb=temb,
+                            image_rotary_emb=image_rotary_emb,
+                            #joint_attention_kwargs=joint_attention_kwargs,
+                        )
+                else:
+                    if 'is_multiprod' in joint_attention_kwargs:
                         hidden_states = block(
                             hidden_states=hidden_states,
                             temb=temb,
                             image_rotary_emb=image_rotary_emb,
                             joint_attention_kwargs=joint_attention_kwargs,
                         )
+                    else:
+                        hidden_states = block(
+                            hidden_states=hidden_states,
+                            temb=temb,
+                            image_rotary_emb=image_rotary_emb,
+                            #joint_attention_kwargs=joint_attention_kwargs,
+                        )
 
             single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
 
         # controlnet block
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index ae68fcf214b3..280a48485200 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -521,12 +521,20 @@ def forward(
                 )
             else:
-                if 'is_single_prod' in joint_attention_kwargs:
-                    hidden_states = block(
-                        hidden_states=hidden_states,
-                        temb=temb,
-                        image_rotary_emb=image_rotary_emb,
-                        #joint_attention_kwargs=joint_attention_kwargs,
-                    )
-                else:
-                    if 'first_N_blocks' in joint_attention_kwargs:
+                if 'is_qv' in joint_attention_kwargs:
+                    if 'is_multiprod' in joint_attention_kwargs:
                         hidden_states = block(
                             hidden_states=hidden_states,
                             temb=temb,
                             image_rotary_emb=image_rotary_emb,
-                            joint_attention_kwargs=joint_attention_kwargs if index_block < joint_attention_kwargs['first_N_blocks'] else None,
+                            joint_attention_kwargs=joint_attention_kwargs,
                         )
                     else:
                         hidden_states = block(
                             hidden_states=hidden_states,
                             temb=temb,
                             image_rotary_emb=image_rotary_emb,
-                            joint_attention_kwargs=joint_attention_kwargs,
+                            #joint_attention_kwargs=joint_attention_kwargs,
+                        )
+                else:
+                    if 'is_multiprod' in joint_attention_kwargs:
+                        hidden_states = block(
+                            hidden_states=hidden_states,
+                            temb=temb,
+                            image_rotary_emb=image_rotary_emb,
+                            joint_attention_kwargs=joint_attention_kwargs,
+                        )
+                    else:
+                        hidden_states = block(
+                            hidden_states=hidden_states,
+                            temb=temb,
+                            image_rotary_emb=image_rotary_emb,
+                            #joint_attention_kwargs=joint_attention_kwargs,
                         )
 
         # controlnet residual

From 0d3d466faaacfe279ed5be1a6209873d16f918bf Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 7 Apr 2025 12:40:45 +0800
Subject: [PATCH 260/482] mute kwargs

---
 src/diffusers/models/attention_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index df744d307018..488bde208c3e 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -594,7 +594,7 @@ def forward(
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty
         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states", "is_single_prod", "first_N_blocks"}
+        quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states", "is_single_prod", "is_multiprod", "is_qv" "first_N_blocks"}
         unused_kwargs = [
             k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
         ]

From 6d38f908c8d5a7f6bd0769257bd60636ef4ecdd5 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 7 Apr 2025 17:33:44 +0800
Subject: [PATCH 261/482] add fix_bg to joint kwargs

---
 src/diffusers/models/attention_processor.py           | 2 +-
 src/diffusers/models/controlnets/controlnet_flux.py   | 2 +-
 src/diffusers/models/transformers/transformer_flux.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 488bde208c3e..e599f7ce2f2f 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -594,7 +594,7 @@ def forward(
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty
         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states", "is_single_prod", "is_multiprod", "is_qv" "first_N_blocks"}
+        quiet_attn_parameters = {"text_masks", "ip_adapter_masks", "ip_hidden_states", "is_single_prod", "is_multiprod", "is_qv", "fix_bg", "first_N_blocks"}
         unused_kwargs = [
             k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
         ]
diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py
index 667b897b0006..316b970db159 100644
--- a/src/diffusers/models/controlnets/controlnet_flux.py
+++ b/src/diffusers/models/controlnets/controlnet_flux.py
@@ -359,7 +359,7 @@ def forward(
 
             else:
                 if 'is_qv' in joint_attention_kwargs:
-                    if 'is_multiprod' in joint_attention_kwargs:
+                    if ('is_multiprod' in joint_attention_kwargs) or ('fix_bg' in joint_attention_kwargs):
                         hidden_states = block(
                             hidden_states=hidden_states,
                             temb=temb,
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 280a48485200..9ad598ef9688 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -522,7 +522,7 @@ def forward(
             else:
                 if 'is_qv' in joint_attention_kwargs:
-                    if 'is_multiprod' in joint_attention_kwargs:
+                    if ('is_multiprod' in joint_attention_kwargs) or ('fix_bg' in joint_attention_kwargs):
                         hidden_states = block(
                             hidden_states=hidden_states,
                             temb=temb,

From d7657a89664a8517342a39b95274ef4cac5d32c9 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 8 Apr 2025 18:00:00 +0800
Subject: [PATCH 262/482] add bg embeds repeat

---
 .../pipelines/flux/pipeline_flux_prior_redux.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index a5fbd1288ed0..65d982183385 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -386,6 +386,7 @@ def __call__(
         product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots
         image_width: Optional[int] = 1024,
         image_height: Optional[int] = 1024,
+        repeat: Optional[int] = 0,
         return_dict: bool = True,
     ):
         r"""
@@ -600,7 +601,13 @@ def __call__(
 
         # scale & concatenate image and text embeddings
         if is_qv:
-            prompt_embeds = image_embeds_bg
+            if repeat == 0:
+                prompt_embeds = image_embeds_bg
+            else:
+                prompt_embeds = image_embeds_bg
+                for index in range(repeat):
+                    prompt_embeds = torch.cat([prompt_embeds, image_embeds_bg],dim=1)
+
             for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)):
                 prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1)
         else:

From 6ffac812ba54617918f6fbd49d961ee26f4c3a70 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 8 Apr 2025 18:18:49 +0800
Subject: [PATCH 263/482] add repeat to attention processor

---
 src/diffusers/models/attention_processor.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index e599f7ce2f2f..6359edea22c4 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2360,6 +2360,7 @@ def __call__(
         image_rotary_emb: Optional[torch.Tensor] = None,
         is_qv: Optional[bool] = False, # thesea modified for quick validation
         product_ratio: Optional[float] = None, # theseam modified for quick validation
+        repeat: Optional[int] = 0, # theseam modified for quick validation
         bg_mask: Optional[torch.Tensor] = None, # thesea modified for quick validation
         prod_masks: Optional[torch.Tensor] = None, # thesea modified for quick validation
         txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask
@@ -2501,7 +2502,8 @@ def __call__(
         elif is_qv:
             attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device)
             prod_embeds_dim = 512 + int(729 * product_ratio)
-            num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim)
+            img_embeds_dim = int(729*(1+repeat))
+            num_of_prompts = int ((query.size(-2) - img_embeds_dim - 4096)/prod_embeds_dim)
             if num_of_prompts != len(prod_masks):
                 raise ValueError(
                     f"Length of prod_masks ({len(prod_masks)}) must match number of prompts {num_of_prompts}"
@@ -2525,7 +2527,7 @@ def __call__(
 
             # bg related attention mask
-            attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = torch.ones(729, 729)
+            attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embeds_dim, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embeds_dim] = torch.ones(img_embeds_dim, img_embeds_dim)
             mask_downsample_t2i = IPAdapterMaskProcessor.downsample(
                 bg_mask[0],
                 1,
@@ -2534,11 +2536,10 @@ def __call__(
             )
             mask_downsample_t2i = mask_downsample_t2i.to(device=query.device)
             mask_downsample_t2i = mask_downsample_t2i.squeeze()
-            mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device)
+            mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(img_embeds_dim , 1).to(device=query.device)
             mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device)
-            attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729,-4096:] = mask_downsample_t2i_tensor
-            attention_mask[-4096:, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = mask_downsample_t2i_tensor_transpose
-
+            attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embeds_dim,-4096:] = mask_downsample_t2i_tensor
+            attention_mask[-4096:, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embeds_dim] = mask_downsample_t2i_tensor_transpose
 
             attention_mask[-4096:,-4096:] = 1
 
@@ -2585,9 +2586,9 @@ def __call__(
                 prod_mask_downsamples.append(prod_mask_downsample)
 
             hidden_states_img = F.scaled_dot_product_attention(
-                query[:,:,-4825:,:],
-                key[:,:,-4825:,:],
-                value[:,:,-4825:,:],
+                query[:,:,prod_embeds_dim:,:],
+                key[:,:,prod_embeds_dim:,:],
+                value[:,:,prod_embeds_dim:,:],
                 attn_mask=None,
                 dropout_p=0.0,
                 is_causal=False

From fb7579f0d2a00a18ebe13801265022f6243854b3 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Tue, 8 Apr 2025 18:21:33 +0800
Subject: [PATCH 264/482] fix dim

---
 src/diffusers/models/attention_processor.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 6359edea22c4..8d022a0ce536 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2586,9 +2586,9 @@ def __call__(
                 prod_mask_downsamples.append(prod_mask_downsample)
 
             hidden_states_img = F.scaled_dot_product_attention(
-                query[:,:,prod_embeds_dim:,:],
-                key[:,:,prod_embeds_dim:,:],
-                value[:,:,prod_embeds_dim:,:],
+                query[:,:,len(prod_masks)*prod_embeds_dim:,:],
+                key[:,:,len(prod_masks)*prod_embeds_dim:,:],
+                value[:,:,len(prod_masks)*prod_embeds_dim:,:],
                 attn_mask=None,
                 dropout_p=0.0,
                 is_causal=False
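
Note: the pair of commits above ("add repeat to attention processor", "fix dim") replaces the hard-coded -4825 slice (729 background tokens plus 4096 latent tokens) with offsets computed from prod_embeds_dim and img_embeds_dim, and the follow-up then corrects the start of the background/latent region so that all product blocks are skipped, not just one. The offset arithmetic these diffs rely on can be summarized as follows; this is an illustrative sketch (qv_offsets is not a function in the tree), assuming the token layout used throughout these patches:

    # Token layout assumed by the is_qv attention path:
    #   [prod_0 | prod_1 | ... | prod_{n-1} | bg | latent]
    # Each product block holds 512 text tokens plus int(729 * product_ratio)
    # image tokens, the background block holds 729 * (1 + repeat) image tokens,
    # and the latent block holds 4096 image-patch tokens (64x64 patches for a
    # 1024x1024 image).
    def qv_offsets(num_prods: int, product_ratio: float, repeat: int = 0):
        prod_embeds_dim = 512 + int(729 * product_ratio)
        img_embeds_dim = 729 * (1 + repeat)
        bg_start = num_prods * prod_embeds_dim  # == len(prod_masks) * prod_embeds_dim
        latent_start = bg_start + img_embeds_dim
        return prod_embeds_dim, bg_start, latent_start

    # With repeat == 0 the background-plus-latent region is always the last
    # 729 + 4096 = 4825 tokens, which is where the old query[:, :, -4825:, :]
    # slice came from.
    prod_dim, bg_start, latent_start = qv_offsets(2, 0.65)
    assert (2 * prod_dim + 729 + 4096) - bg_start == 4825
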
From 18b39edc8f8b77ef5828ae8c0c5ac7c4cb31f492 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 8 Apr 2025 19:07:24 +0800 Subject: [PATCH 265/482] Revert "fix dim" This reverts commit fb7579f0d2a00a18ebe13801265022f6243854b3. --- src/diffusers/models/attention_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8d022a0ce536..6359edea22c4 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2586,9 +2586,9 @@ def __call__( prod_mask_downsamples.append(prod_mask_downsample) hidden_states_img = F.scaled_dot_product_attention( - query[:,:,len(prod_masks)*prod_embeds_dim:,:], - key[:,:,len(prod_masks)*prod_embeds_dim:,:], - value[:,:,len(prod_masks)*prod_embeds_dim:,:], + query[:,:,prod_embeds_dim:,:], + key[:,:,prod_embeds_dim:,:], + value[:,:,prod_embeds_dim:,:], attn_mask=None, dropout_p=0.0, is_causal=False From 1664ea4ce515ae8ee2941d1535096a4a9be6028e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 8 Apr 2025 19:07:27 +0800 Subject: [PATCH 266/482] Revert "add repeat to attention processor" This reverts commit 6ffac812ba54617918f6fbd49d961ee26f4c3a70. --- src/diffusers/models/attention_processor.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6359edea22c4..e599f7ce2f2f 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2360,7 +2360,6 @@ def __call__( image_rotary_emb: Optional[torch.Tensor] = None, is_qv: Optional[bool] = False, # thesea modified for quick validation product_ratio: Optional[float] = None, # theseam modified for quick validation - repeat: Optional[int] = 0, # theseam modified for quick validation bg_mask: Optional[torch.Tensor] = None, # thesea modified for quick validation prod_masks: Optional[torch.Tensor] = None, # thesea modified for quick validation txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask @@ -2502,8 +2501,7 @@ def __call__( elif is_qv: attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) - img_embeds_dim = int(729*(1+repeat)) - num_of_prompts = int ((query.size(-2) - img_embeds_dim - 4096)/prod_embeds_dim) + num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim) if num_of_prompts != len(prod_masks): raise ValueError( f"Length of prod_masks ({len(prod_masks)}) must match number of prompts {num_of_prompts}" @@ -2527,7 +2525,7 @@ def __call__( # bg related attention mask - attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embeds_dim, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embeds_dim] = torch.ones(img_embeds_dim, img_embeds_dim) + attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = torch.ones(729, 729) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( bg_mask[0], 1, @@ -2536,10 +2534,11 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(img_embeds_dim , 1).to(device=query.device) + mask_downsample_t2i_tensor = 
mask_downsample_t2i.repeat(729, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embeds_dim,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embeds_dim] = mask_downsample_t2i_tensor_transpose + attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = mask_downsample_t2i_tensor_transpose + attention_mask[-4096:,-4096:] = 1 @@ -2586,9 +2585,9 @@ def __call__( prod_mask_downsamples.append(prod_mask_downsample) hidden_states_img = F.scaled_dot_product_attention( - query[:,:,prod_embeds_dim:,:], - key[:,:,prod_embeds_dim:,:], - value[:,:,prod_embeds_dim:,:], + query[:,:,-4825:,:], + key[:,:,-4825:,:], + value[:,:,-4825:,:], attn_mask=None, dropout_p=0.0, is_causal=False From dab951487f9cbe7b5251586dcebf4bb3acfe0b25 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 8 Apr 2025 19:07:33 +0800 Subject: [PATCH 267/482] Revert "add bg embeds repeat" This reverts commit d7657a89664a8517342a39b95274ef4cac5d32c9. --- .../pipelines/flux/pipeline_flux_prior_redux.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 65d982183385..a5fbd1288ed0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -386,7 +386,6 @@ def __call__( product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, - repeat: Optional[int] = 0, return_dict: bool = True, ): r""" @@ -601,13 +600,7 @@ def __call__( # scale & concatenate image and text embeddings if is_qv: - if repeat == 0: - prompt_embeds = image_embeds_bg - else: - prompt_embeds = image_embeds_bg - for index in range(repeat): - prompt_embeds = torch.cat([prompt_embeds, image_embeds_bg],dim=1) - + prompt_embeds = image_embeds_bg for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: From 4e76d9d34567a09d9950e3cc7e5da9f6188c992f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 9 Apr 2025 10:52:52 +0800 Subject: [PATCH 268/482] add checking input logic --- .../pipelines/flux/pipeline_flux_prior_redux.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index a5fbd1288ed0..bf146e51fc94 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -459,6 +459,12 @@ def __call__( image_array_list = [] mask_list = [] is_product_list = [] + + if len(image) != len(layer_type): + raise ValueError( + f"number of images ({len(image)}) must match the number of layers {len(layer_type)}" + ) + for img, img_type in zip(image, layer_type): if 'product' in img_type or 'Product' in img_type: is_product_list.append('true') @@ -600,6 +606,12 @@ 
def __call__( # scale & concatenate image and text embeddings if is_qv: + num_of_prod_imgs = int(image_embeds_prods.shape[1]/729) + if len(prompt_embeds_list) != num_of_prod_imgs: + raise ValueError( + f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {num_of_prod_imgs}" + ) + prompt_embeds = image_embeds_bg for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) From 8f044284cf8aa51a5d2f20c453f1858e64291137 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 9 Apr 2025 10:58:48 +0800 Subject: [PATCH 269/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index bf146e51fc94..128308ae59a8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -606,10 +606,9 @@ def __call__( # scale & concatenate image and text embeddings if is_qv: - num_of_prod_imgs = int(image_embeds_prods.shape[1]/729) - if len(prompt_embeds_list) != num_of_prod_imgs: + if len(prompt_embeds_list) != len(image_embeds_prods): raise ValueError( - f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {num_of_prod_imgs}" + f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" ) prompt_embeds = image_embeds_bg From 5c863ded3f1536fc5527879cde7ad28ab8a29e0d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 11 Apr 2025 11:31:36 +0800 Subject: [PATCH 270/482] print info --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e599f7ce2f2f..9b78a9438a8a 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2509,6 +2509,7 @@ def __call__( # prod related attention mask for index in range(len(prod_masks)): + print(f'mask data type={type(prod_masks[index])}') attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = 1 mask_downsample_t2i = IPAdapterMaskProcessor.downsample( prod_masks[index], From fa270eb2317e480fc6bfa3cd58c5b81f5a49cf88 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 11 Apr 2025 11:39:00 +0800 Subject: [PATCH 271/482] remove print info --- src/diffusers/models/attention_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9b78a9438a8a..e599f7ce2f2f 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2509,7 +2509,6 @@ def __call__( # prod related attention mask for index in range(len(prod_masks)): - print(f'mask data type={type(prod_masks[index])}') attention_mask[index*prod_embeds_dim:(index+1)*prod_embeds_dim, index*prod_embeds_dim:(index+1)*prod_embeds_dim] = 1 mask_downsample_t2i = IPAdapterMaskProcessor.downsample( prod_masks[index], From 7d2370deb47eb9291380d42206e498bd283c4256 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 15 Apr 2025 14:46:57 +0800 Subject: [PATCH 
272/482] add default logic --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 546a225aa999..d2af2061a19f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -923,6 +923,9 @@ def __call__( ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + # 5. Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents, latent_image_ids = self.prepare_latents( From 8b09bb4537254ac171674522f5c1196d1be15207 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 15 Apr 2025 16:14:25 +0800 Subject: [PATCH 273/482] default value --- src/diffusers/models/attention_processor.py | 1 + src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e599f7ce2f2f..fe13eee82fab 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2499,6 +2499,7 @@ def __call__( hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) # thesea modified for quick validation of product shots elif is_qv: + product_ratio = 0.65 attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 128308ae59a8..39dd6b68456b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -612,6 +612,7 @@ def __call__( ) prompt_embeds = image_embeds_bg + product_ratio = 0.65 for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: From 32891f5f95fd792a48f5b15685a720e1ff193fcf Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 15 Apr 2025 17:04:11 +0800 Subject: [PATCH 274/482] default value --- src/diffusers/models/attention_processor.py | 2 +- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index fe13eee82fab..854f40350d5f 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2499,7 +2499,7 @@ def __call__( hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) # thesea modified for quick validation of product shots elif is_qv: - product_ratio = 0.65 + product_ratio = 0.7 attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 39dd6b68456b..68b10bfb31ed 100644 --- 
a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -612,7 +612,7 @@ def __call__( ) prompt_embeds = image_embeds_bg - product_ratio = 0.65 + product_ratio = 0.7 for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: From 75609f1b7a6ee725271d970bc7d1ad09f3ded0d1 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 23 Apr 2025 23:40:48 +0800 Subject: [PATCH 275/482] add 'first_N_blocks' in joint_attention_kwargs --- src/diffusers/models/transformers/transformer_flux.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 9ad598ef9688..95365da3d1bb 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -529,6 +529,13 @@ def forward( image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) + elif 'first_N_blocks' in joint_attention_kwargs: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs if index_block <= joint_attention_kwargs['first_N_blocks'] else None, + ) else: hidden_states = block( hidden_states=hidden_states, From 4d03fc18a855c83e7fcf42de3f779e6b4bc43bda Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 24 Apr 2025 00:16:38 +0800 Subject: [PATCH 276/482] improve blend bg enhance --- src/diffusers/models/attention_processor.py | 22 +++++++++++-------- .../flux/pipeline_flux_prior_redux.py | 13 +++++++---- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 854f40350d5f..a3d2240d4df8 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2359,6 +2359,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, is_qv: Optional[bool] = False, # thesea modified for quick validation + is_blend_bg_enhance: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation bg_mask: Optional[torch.Tensor] = None, # thesea modified for quick validation prod_masks: Optional[torch.Tensor] = None, # thesea modified for quick validation @@ -2499,10 +2500,13 @@ def __call__( hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) # thesea modified for quick validation of product shots elif is_qv: - product_ratio = 0.7 attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) - num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim) + if is_blend_bg_enhance: + img_embbeds_dim = 729 + 512 + else: + img_embbeds_dim = 729 + num_of_prompts = int ((query.size(-2) - img_embbeds_dim - 4096)/prod_embeds_dim) if num_of_prompts != len(prod_masks): raise ValueError( f"Length of prod_masks ({len(prod_masks)}) must match number of prompts {num_of_prompts}" @@ -2526,7 +2530,7 @@ def __call__( # bg related attention mask - 
attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = torch.ones(729, 729) + attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embbeds_dim, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embbeds_dim] = torch.ones(img_embbeds_dim, img_embbeds_dim) mask_downsample_t2i = IPAdapterMaskProcessor.downsample( bg_mask[0], 1, @@ -2535,10 +2539,10 @@ def __call__( ) mask_downsample_t2i = mask_downsample_t2i.to(device=query.device) mask_downsample_t2i = mask_downsample_t2i.squeeze() - mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device) + mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(img_embbeds_dim, 1).to(device=query.device) mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device) - attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729,-4096:] = mask_downsample_t2i_tensor - attention_mask[-4096:, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = mask_downsample_t2i_tensor_transpose + attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embbeds_dim,-4096:] = mask_downsample_t2i_tensor + attention_mask[-4096:, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embbeds_dim] = mask_downsample_t2i_tensor_transpose attention_mask[-4096:,-4096:] = 1 @@ -2586,9 +2590,9 @@ def __call__( prod_mask_downsamples.append(prod_mask_downsample) hidden_states_img = F.scaled_dot_product_attention( - query[:,:,-4825:,:], - key[:,:,-4825:,:], - value[:,:,-4825:,:], + query[:,:,-(4096+img_embbeds_dim):,:], + key[:,:,-(4096+img_embbeds_dim):,:], + value[:,:,-(4096+img_embbeds_dim):,:], attn_mask=None, dropout_p=0.0, is_causal=False diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 68b10bfb31ed..570127d09e56 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -382,6 +382,7 @@ def __call__( prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, is_qv: Optional[bool] = False, # thesea modified for quick validation of product shots + is_blend_bg_enhance: Optional[bool] = False, # thesea modified for quick validation of product shots is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots image_width: Optional[int] = 1024, @@ -611,10 +612,14 @@ def __call__( f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" ) - prompt_embeds = image_embeds_bg - product_ratio = 0.7 - for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): - prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) + if is_blend_bg_enhance: + prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg], dim=1) + for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods[:-1])): + prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], 
prompt_embeds], dim=1) + else: + prompt_embeds = image_embeds_bg + for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): + prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) From cae03909ade47f88fa20dfea731fa0e70d0c839d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 24 Apr 2025 00:27:25 +0800 Subject: [PATCH 277/482] fix bug --- .../pipelines/flux/pipeline_flux_prior_redux.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 570127d09e56..072668633c44 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -607,16 +607,20 @@ def __call__( # scale & concatenate image and text embeddings if is_qv: - if len(prompt_embeds_list) != len(image_embeds_prods): - raise ValueError( - f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" - ) - if is_blend_bg_enhance: + if len(prompt_embeds_list) != len(image_embeds_prods): + raise ValueError( + f"number of prompts ({len(prompt_embeds_list)-1}) must match the number of product images {len(image_embeds_prods)}" + ) prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg], dim=1) for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods[:-1])): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: + if len(prompt_embeds_list) != len(image_embeds_prods): + raise ValueError( + f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" + ) + prompt_embeds = image_embeds_bg for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) From 6b531d78a2b497b6a7032ef35f09a742387f68bc Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 24 Apr 2025 00:31:00 +0800 Subject: [PATCH 278/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 072668633c44..dfe7a55a0907 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -608,7 +608,7 @@ def __call__( # scale & concatenate image and text embeddings if is_qv: if is_blend_bg_enhance: - if len(prompt_embeds_list) != len(image_embeds_prods): + if (len(prompt_embeds_list) - 1) != len(image_embeds_prods): raise ValueError( f"number of prompts ({len(prompt_embeds_list)-1}) must match the number of product images {len(image_embeds_prods)}" ) From b981a93b190a25f3f8462bb450b7bb04a1c9091d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 24 Apr 2025 00:37:01 +0800 Subject: [PATCH 279/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index dfe7a55a0907..4f7c8f3f58f0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -613,7 +613,7 @@ def __call__( f"number of prompts ({len(prompt_embeds_list)-1}) must match the number of product images {len(image_embeds_prods)}" ) prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg], dim=1) - for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods[:-1])): + for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: if len(prompt_embeds_list) != len(image_embeds_prods): From 2621920cec17584ddee35cd6c14e88a91d922681 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 24 Apr 2025 09:49:55 +0800 Subject: [PATCH 280/482] reduce bg embeds dim --- src/diffusers/models/attention_processor.py | 4 ++-- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a3d2240d4df8..21335caf3c16 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2503,9 +2503,9 @@ def __call__( attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) if is_blend_bg_enhance: - img_embbeds_dim = 729 + 512 + img_embbeds_dim = int(729 * product_ratio) + 512 else: - img_embbeds_dim = 729 + img_embbeds_dim = int(729 * product_ratio) num_of_prompts = int ((query.size(-2) - img_embbeds_dim - 4096)/prod_embeds_dim) if num_of_prompts != len(prod_masks): raise ValueError( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 4f7c8f3f58f0..c1e842135c0c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -612,7 +612,7 @@ def __call__( raise ValueError( f"number of prompts ({len(prompt_embeds_list)-1}) must match the number of product images {len(image_embeds_prods)}" ) - prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg], dim=1) + prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg[:,:int(729*product_ratio),:]], dim=1) for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: @@ -621,7 +621,7 @@ def __call__( f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" ) - prompt_embeds = image_embeds_bg + prompt_embeds = image_embeds_bg[:,:int(729*product_ratio),:] for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: From 5844bcd26a213910cacd3153e4413ceda6859697 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 24 Apr 2025 10:49:15 +0800 Subject: 
[PATCH 281/482] add bg_ratio --- src/diffusers/models/attention_processor.py | 5 +++-- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 21335caf3c16..4c41e3a43233 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2361,6 +2361,7 @@ def __call__( is_qv: Optional[bool] = False, # thesea modified for quick validation is_blend_bg_enhance: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation + bg_ratio: Optional[float] = None, # theseam modified for quick validation of product shots bg_mask: Optional[torch.Tensor] = None, # thesea modified for quick validation prod_masks: Optional[torch.Tensor] = None, # thesea modified for quick validation txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask @@ -2503,9 +2504,9 @@ def __call__( attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) if is_blend_bg_enhance: - img_embbeds_dim = int(729 * product_ratio) + 512 + img_embbeds_dim = int(729 * bg_ratio) + 512 else: - img_embbeds_dim = int(729 * product_ratio) + img_embbeds_dim = int(729 * bg_ratio) num_of_prompts = int ((query.size(-2) - img_embbeds_dim - 4096)/prod_embeds_dim) if num_of_prompts != len(prod_masks): raise ValueError( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index c1e842135c0c..686f748028bb 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -385,6 +385,7 @@ def __call__( is_blend_bg_enhance: Optional[bool] = False, # thesea modified for quick validation of product shots is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots + bg_ratio: Optional[float] = None, # theseam modified for quick validation of product shots image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, return_dict: bool = True, @@ -612,7 +613,7 @@ def __call__( raise ValueError( f"number of prompts ({len(prompt_embeds_list)-1}) must match the number of product images {len(image_embeds_prods)}" ) - prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg[:,:int(729*product_ratio),:]], dim=1) + prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg[:,:int(729*bg_ratio),:]], dim=1) for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: @@ -621,7 +622,7 @@ def __call__( f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" ) - prompt_embeds = image_embeds_bg[:,:int(729*product_ratio),:] + prompt_embeds = image_embeds_bg[:,:int(729*bg_ratio),:] for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: From 
3fe06f7b7ba7f13c8864988289bf138e036297f5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 28 Apr 2025 15:17:12 +0800 Subject: [PATCH 282/482] Revert "add bg_ratio" This reverts commit 5844bcd26a213910cacd3153e4413ceda6859697. --- src/diffusers/models/attention_processor.py | 5 ++--- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4c41e3a43233..21335caf3c16 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2361,7 +2361,6 @@ def __call__( is_qv: Optional[bool] = False, # thesea modified for quick validation is_blend_bg_enhance: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation - bg_ratio: Optional[float] = None, # theseam modified for quick validation of product shots bg_mask: Optional[torch.Tensor] = None, # thesea modified for quick validation prod_masks: Optional[torch.Tensor] = None, # thesea modified for quick validation txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask @@ -2504,9 +2503,9 @@ def __call__( attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) if is_blend_bg_enhance: - img_embbeds_dim = int(729 * bg_ratio) + 512 + img_embbeds_dim = int(729 * product_ratio) + 512 else: - img_embbeds_dim = int(729 * bg_ratio) + img_embbeds_dim = int(729 * product_ratio) num_of_prompts = int ((query.size(-2) - img_embbeds_dim - 4096)/prod_embeds_dim) if num_of_prompts != len(prod_masks): raise ValueError( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 686f748028bb..c1e842135c0c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -385,7 +385,6 @@ def __call__( is_blend_bg_enhance: Optional[bool] = False, # thesea modified for quick validation of product shots is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots - bg_ratio: Optional[float] = None, # theseam modified for quick validation of product shots image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, return_dict: bool = True, @@ -613,7 +612,7 @@ def __call__( raise ValueError( f"number of prompts ({len(prompt_embeds_list)-1}) must match the number of product images {len(image_embeds_prods)}" ) - prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg[:,:int(729*bg_ratio),:]], dim=1) + prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg[:,:int(729*product_ratio),:]], dim=1) for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: @@ -622,7 +621,7 @@ def __call__( f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" ) - prompt_embeds = image_embeds_bg[:,:int(729*bg_ratio),:] + prompt_embeds = image_embeds_bg[:,:int(729*product_ratio),:] for tmp_prompt_embeds, tmp_image_embeds_prod in 
zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: From a6c9549e3868f2fe73375b62a877be7a065632ed Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 28 Apr 2025 15:17:17 +0800 Subject: [PATCH 283/482] Revert "reduce bg embeds dim" This reverts commit 2621920cec17584ddee35cd6c14e88a91d922681. --- src/diffusers/models/attention_processor.py | 4 ++-- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 21335caf3c16..a3d2240d4df8 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2503,9 +2503,9 @@ def __call__( attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device) prod_embeds_dim = 512 + int(729 * product_ratio) if is_blend_bg_enhance: - img_embbeds_dim = int(729 * product_ratio) + 512 + img_embbeds_dim = 729 + 512 else: - img_embbeds_dim = int(729 * product_ratio) + img_embbeds_dim = 729 num_of_prompts = int ((query.size(-2) - img_embbeds_dim - 4096)/prod_embeds_dim) if num_of_prompts != len(prod_masks): raise ValueError( diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index c1e842135c0c..4f7c8f3f58f0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -612,7 +612,7 @@ def __call__( raise ValueError( f"number of prompts ({len(prompt_embeds_list)-1}) must match the number of product images {len(image_embeds_prods)}" ) - prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg[:,:int(729*product_ratio),:]], dim=1) + prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg], dim=1) for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: @@ -621,7 +621,7 @@ def __call__( f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" ) - prompt_embeds = image_embeds_bg[:,:int(729*product_ratio),:] + prompt_embeds = image_embeds_bg for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: From 001b06aa3b863c24a525abdbde93ef40662f5f77 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 28 Apr 2025 15:17:21 +0800 Subject: [PATCH 284/482] Revert "fix bug" This reverts commit b981a93b190a25f3f8462bb450b7bb04a1c9091d. 
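
Note: the revert chain that starts here (patches 284 through 290) unwinds the blend-bg-enhance experiment one commit at a time: the zip fix (279), the off-by-one count check (278), the restructured validation (277), the feature itself (276), the first_N_blocks gating (275), and the hard-coded product_ratio defaults (273/274). The count invariant the intermediate fixes were converging on is simple to state; this is an illustrative summary only (expected_num_products is not in the tree):

    def expected_num_products(prompt_embeds_list, is_blend_bg_enhance):
        # In blend-bg-enhance mode the last prompt is paired with the background
        # embedding rather than a product image, so one fewer product is expected.
        n = len(prompt_embeds_list)
        return n - 1 if is_blend_bg_enhance else n
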
---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 4f7c8f3f58f0..dfe7a55a0907 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -613,7 +613,7 @@ def __call__(
                     f"number of prompts ({len(prompt_embeds_list)-1}) must match the number of product images {len(image_embeds_prods)}"
                 )
                 prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg], dim=1)
-                for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods)):
+                for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods[:-1])):
                     prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1)
             else:
                 if len(prompt_embeds_list) != len(image_embeds_prods):

From 8d503743d2ba6cc6be46398b9e1582f1028a0462 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 28 Apr 2025 15:17:26 +0800
Subject: [PATCH 285/482] Revert "fix bug"

This reverts commit 6b531d78a2b497b6a7032ef35f09a742387f68bc.

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index dfe7a55a0907..072668633c44 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -608,7 +608,7 @@ def __call__(
         # scale & concatenate image and text embeddings
         if is_qv:
             if is_blend_bg_enhance:
-                if (len(prompt_embeds_list) - 1) != len(image_embeds_prods):
+                if len(prompt_embeds_list) != len(image_embeds_prods):
                     raise ValueError(
                         f"number of prompts ({len(prompt_embeds_list)-1}) must match the number of product images {len(image_embeds_prods)}"
                     )

From 216ff6403fce3c14cbf722279356f8e0f0c3c9d6 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 28 Apr 2025 15:17:28 +0800
Subject: [PATCH 286/482] Revert "fix bug"

This reverts commit cae03909ade47f88fa20dfea731fa0e70d0c839d.
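
Note: with this revert the count validation moves back out of the two branches, and the restored code below again builds the joint embedding by walking prompts and product embeddings in reverse. As a reference while reading the diff, the layout being assembled is sketched here; build_qv_embeds is an illustrative helper, not code from the tree:

    import torch

    # Final sequence: [txt_0 | prod_0 | txt_1 | prod_1 | ... | bg]
    # The loop prepends pairs while iterating in reverse, so the first
    # prompt/product pair ends up at the front and the background at the back,
    # matching the ordering the attention mask in attention_processor.py uses.
    def build_qv_embeds(prompt_embeds_list, image_embeds_prods, image_embeds_bg, product_ratio):
        keep = int(729 * product_ratio)  # tokens kept per product image
        out = image_embeds_bg
        for txt, prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)):
            out = torch.cat([txt, prod[:, :keep, :], out], dim=1)
        return out
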
---
 .../pipelines/flux/pipeline_flux_prior_redux.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 072668633c44..570127d09e56 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -607,20 +607,16 @@ def __call__(
 
         # scale & concatenate image and text embeddings
         if is_qv:
+            if len(prompt_embeds_list) != len(image_embeds_prods):
+                raise ValueError(
+                    f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}"
+                )
+
             if is_blend_bg_enhance:
-                if len(prompt_embeds_list) != len(image_embeds_prods):
-                    raise ValueError(
-                        f"number of prompts ({len(prompt_embeds_list)-1}) must match the number of product images {len(image_embeds_prods)}"
-                    )
                 prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg], dim=1)
                 for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods[:-1])):
                     prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1)
             else:
-                if len(prompt_embeds_list) != len(image_embeds_prods):
-                    raise ValueError(
-                        f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}"
-                    )
-
                 prompt_embeds = image_embeds_bg
                 for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)):
                     prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1)

From 2ba93f2a3a477f67b6db4477aa1d26a2f0e54880 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 28 Apr 2025 15:17:32 +0800
Subject: [PATCH 287/482] Revert "improve blend bg enhance"

This reverts commit 4d03fc18a855c83e7fcf42de3f779e6b4bc43bda.
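
Note: the revert below removes the last uses of img_embbeds_dim (the experiment's spelling, with a doubled "b") and returns the mask construction to fixed sizes: 729 background tokens, 4096 latent tokens, and the -4825 slices. For orientation, the is_qv attention mask has the block structure sketched here. This is a simplified illustration with invented names, and it omits the spatially gated background/latent and product/latent blocks that the real code fills from downsampled masks via IPAdapterMaskProcessor.downsample:

    import torch

    def qv_attention_mask(n_prod, prod_dim, bg_dim=729, latent_dim=4096):
        total = n_prod * prod_dim + bg_dim + latent_dim
        mask = torch.zeros(total, total)
        # Product blocks are block-diagonal: each product attends only to itself.
        for i in range(n_prod):
            lo, hi = i * prod_dim, (i + 1) * prod_dim
            mask[lo:hi, lo:hi] = 1
        # Background tokens attend to each other.
        bg0 = n_prod * prod_dim
        mask[bg0:bg0 + bg_dim, bg0:bg0 + bg_dim] = 1
        # Latent patches attend to each other; their cross terms with products
        # and background are gated per-pixel in the real implementation.
        mask[-latent_dim:, -latent_dim:] = 1
        return mask
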
---
 src/diffusers/models/attention_processor.py     | 22 ++++++++-----------
 .../flux/pipeline_flux_prior_redux.py           | 13 ++++-------
 2 files changed, 13 insertions(+), 22 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index a3d2240d4df8..854f40350d5f 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2359,7 +2359,6 @@ def __call__(
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
         is_qv: Optional[bool] = False, # thesea modified for quick validation
-        is_blend_bg_enhance: Optional[bool] = False, # thesea modified for quick validation of product shots
         product_ratio: Optional[float] = None, # theseam modified for quick validation
         bg_mask: Optional[torch.Tensor] = None, # thesea modified for quick validation
         prod_masks: Optional[torch.Tensor] = None, # thesea modified for quick validation
         txt_masks: Optional[torch.Tensor] = None, # thesea modified for text mask
@@ -2500,13 +2499,10 @@ def __call__(
             hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) # thesea modified for quick validation of product shots
 
         elif is_qv:
+            product_ratio = 0.7
             attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device)
             prod_embeds_dim = 512 + int(729 * product_ratio)
-            if is_blend_bg_enhance:
-                img_embbeds_dim = 729 + 512
-            else:
-                img_embbeds_dim = 729
-            num_of_prompts = int ((query.size(-2) - img_embbeds_dim - 4096)/prod_embeds_dim)
+            num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim)
             if num_of_prompts != len(prod_masks):
                 raise ValueError(
                     f"Length of prod_masks ({len(prod_masks)}) must match number of prompts {num_of_prompts}"
@@ -2530,7 +2526,7 @@ def __call__(
 
             # bg related attention mask
-            attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embbeds_dim, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embbeds_dim] = torch.ones(img_embbeds_dim, img_embbeds_dim)
+            attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = torch.ones(729, 729)
             mask_downsample_t2i = IPAdapterMaskProcessor.downsample(
                 bg_mask[0],
                 1,
@@ -2539,10 +2535,10 @@ def __call__(
             )
             mask_downsample_t2i = mask_downsample_t2i.to(device=query.device)
             mask_downsample_t2i = mask_downsample_t2i.squeeze()
-            mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(img_embbeds_dim, 1).to(device=query.device)
+            mask_downsample_t2i_tensor = mask_downsample_t2i.repeat(729, 1).to(device=query.device)
             mask_downsample_t2i_tensor_transpose = mask_downsample_t2i_tensor.transpose(0, 1).to(device=query.device)
-            attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embbeds_dim,-4096:] = mask_downsample_t2i_tensor
-            attention_mask[-4096:, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + img_embbeds_dim] = mask_downsample_t2i_tensor_transpose
+            attention_mask[len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729,-4096:] = mask_downsample_t2i_tensor
+            attention_mask[-4096:, len(prod_masks)*prod_embeds_dim:len(prod_masks)*prod_embeds_dim + 729] = mask_downsample_t2i_tensor_transpose
 
             attention_mask[-4096:,-4096:] = 1
 
@@ -2590,9 +2586,9 @@ def __call__(
                 prod_mask_downsamples.append(prod_mask_downsample)
 
             hidden_states_img = F.scaled_dot_product_attention(
-                query[:,:,-(4096+img_embbeds_dim):,:],
-                key[:,:,-(4096+img_embbeds_dim):,:],
-                value[:,:,-(4096+img_embbeds_dim):,:],
+                query[:,:,-4825:,:],
+                key[:,:,-4825:,:],
+                value[:,:,-4825:,:],
                 attn_mask=None,
                 dropout_p=0.0,
                 is_causal=False
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 570127d09e56..68b10bfb31ed 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -382,7 +382,6 @@ def __call__(
         prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
         pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
         is_qv: Optional[bool] = False, # thesea modified for quick validation of product shots
-        is_blend_bg_enhance: Optional[bool] = False, # thesea modified for quick validation of product shots
         is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots
         product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots
         image_width: Optional[int] = 1024,
@@ -611,14 +611,10 @@ def __call__(
                     f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}"
                 )
 
-            if is_blend_bg_enhance:
-                prompt_embeds = torch.cat([prompt_embeds_list[-1], image_embeds_bg], dim=1)
-                for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list[:-1]), reversed(image_embeds_prods[:-1])):
-                    prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1)
-            else:
-                prompt_embeds = image_embeds_bg
-                for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)):
-                    prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1)
+            prompt_embeds = image_embeds_bg
+            product_ratio = 0.7
+            for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)):
+                prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1)
         else:
             prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)

From ca104b97842a503c3784112081f081928caf9b43 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 28 Apr 2025 15:17:35 +0800
Subject: [PATCH 288/482] Revert "add 'first_N_blocks' in
 joint_attention_kwargs"

This reverts commit 75609f1b7a6ee725271d970bc7d1ad09f3ded0d1.

---
 src/diffusers/models/transformers/transformer_flux.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 95365da3d1bb..9ad598ef9688 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -529,13 +529,6 @@ def forward(
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
-            elif 'first_N_blocks' in joint_attention_kwargs:
-                hidden_states = block(
-                    hidden_states=hidden_states,
-                    temb=temb,
-                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs if index_block <= joint_attention_kwargs['first_N_blocks'] else None,
-                )
             else:
                 hidden_states = block(
                     hidden_states=hidden_states,

From 3c223731b79414bed9d2f655049c4024bea7a458 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 28 Apr 2025 15:18:50 +0800
Subject: [PATCH 289/482] Revert "default value"

This reverts commit 32891f5f95fd792a48f5b15685a720e1ff193fcf.
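
Note: this revert and the next drop the hard-coded product_ratio defaults (0.7, then 0.65) that had been duplicated in attention_processor.py and the Redux pipeline; keeping two copies of the same constant is what allowed them to drift in the first place. If a default is ever reintroduced, threading one value through joint_attention_kwargs keeps both call sites in sync. Purely as a sketch (the constant name is invented):

    # One place owns the default; the pipeline and the attention processor
    # both read it from joint_attention_kwargs instead of hard-coding it.
    DEFAULT_PRODUCT_RATIO = 0.7  # illustrative value

    joint_attention_kwargs = {"is_qv": True, "product_ratio": DEFAULT_PRODUCT_RATIO}
    ratio = joint_attention_kwargs.get("product_ratio", DEFAULT_PRODUCT_RATIO)
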
---
 src/diffusers/models/attention_processor.py               | 2 +-
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 854f40350d5f..fe13eee82fab 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2499,7 +2499,7 @@ def __call__(
             hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) # thesea modified for quick validation of product shots
 
         elif is_qv:
-            product_ratio = 0.7
+            product_ratio = 0.65
             attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device)
             prod_embeds_dim = 512 + int(729 * product_ratio)
             num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 68b10bfb31ed..39dd6b68456b 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -612,7 +612,7 @@ def __call__(
                 )
 
             prompt_embeds = image_embeds_bg
-            product_ratio = 0.7
+            product_ratio = 0.65
             for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)):
                 prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1)
         else:

From ce5d7683d6d32ba7dbf02067a683bf7e44558790 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 28 Apr 2025 15:19:00 +0800
Subject: [PATCH 290/482] Revert "default value"

This reverts commit 8b09bb4537254ac171674522f5c1196d1be15207.

---
 src/diffusers/models/attention_processor.py               | 1 -
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 1 -
 2 files changed, 2 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index fe13eee82fab..e599f7ce2f2f 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2499,7 +2499,6 @@ def __call__(
             hidden_states = torch.cat([hidden_states_region[:,:-4096,:], hidden_states_common], dim=1) # thesea modified for quick validation of product shots
 
         elif is_qv:
-            product_ratio = 0.65
             attention_mask = torch.zeros(query.size(-2), key.size(-2), device=query.device)
             prod_embeds_dim = 512 + int(729 * product_ratio)
             num_of_prompts = int ((query.size(-2) - 729 - 4096)/prod_embeds_dim)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 39dd6b68456b..128308ae59a8 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -612,7 +612,6 @@ def __call__(
                 )
 
             prompt_embeds = image_embeds_bg
-            product_ratio = 0.65
             for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)):
                 prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1)
         else:

From 11d6f08a60610c52a6b2429421610f0d184576ff Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Wed, 30 Apr 2025 17:14:26 +0800
Subject: [PATCH 291/482] add masked_bg

---
 .../pipelines/flux/pipeline_flux_prior_redux.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 128308ae59a8..4e7424456ef2 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -15,6 +15,7 @@
 from typing import List, Optional, Union
 
 import numpy as np
+import cv2
 import torch
 from PIL import Image
 
@@ -369,6 +370,15 @@ def encode_prompt(
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    def apply_dilate_to_mask(self, mask,iterations = 10):
+
+        kernel = np.ones((5, 5), np.uint8)
+        mask = mask.astype(np.uint8)
+        mask = cv2.dilate(mask, kernel, iterations=iterations)
+        mask = np.array(mask, dtype=bool)
+
+        return mask
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -384,6 +394,7 @@ def __call__(
         is_qv: Optional[bool] = False, # thesea modified for quick validation of product shots
         is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots
         product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots
+        iterations: Optional[int] = 10, # controlnet inpainting
         image_width: Optional[int] = 1024,
         image_height: Optional[int] = 1024,
         return_dict: bool = True,
@@ -521,6 +532,7 @@ def __call__(
             prompt=[prompt]*len(image_mask_prod)
 
         composed_image_all = np.zeros((image_width, image_height, 3))
+        masked_bg = np.zeros((image_width, image_height, 3))
         composed_bg_image = np.zeros((image_width, image_height, 3))
         composed_prod_images = []
         for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)):
@@ -530,9 +542,11 @@ def __call__(
 
                 composed_bg_image += img_array * image_mask_bg[index]
             composed_image_all += img_array * image_mask_all[index]
+            masked_bg += np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations)
 
         composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB')
         composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB')
+        masked_bg = Image.fromarray(masked_bg.astype(np.uint8)).convert('RGB')
         bg_mask = Image.fromarray(bg_mask.astype(np.uint8)*255).convert('RGB')
 
         prod_masks = []
@@ -631,7 +645,7 @@ def __call__(
 
         if not return_dict:
             if is_qv:
-                return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, bg_mask)
+                return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, composed_bg_image, composed_prod_images, prod_masks, bg_mask)
             else:
                 return (prompt_embeds, pooled_prompt_embeds)

From 45edf8992164deaf8e1c90620df1cdff904c6cdc Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Wed, 30 Apr 2025 17:21:39 +0800
Subject: [PATCH 292/482] fix mask bug

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 4e7424456ef2..1079c319bc0b 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -542,8 +542,11 @@ def __call__(
 
                 composed_bg_image += img_array * image_mask_bg[index]
             composed_image_all += img_array * image_mask_all[index]
-            masked_bg += np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations)
-
+            if is_product.lower() == "true":
+                masked_bg += 255*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations)
+            else:
+                masked_bg += img_array * image_mask_bg[index]
+
         composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB')
         composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB')
         masked_bg = Image.fromarray(masked_bg.astype(np.uint8)).convert('RGB')

From efd9aefd5d554da294026435a5882a72998ac5a4 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Wed, 30 Apr 2025 17:25:25 +0800
Subject: [PATCH 293/482] fix mask bug

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 1079c319bc0b..ee2154abedf3 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -545,8 +545,8 @@ def __call__(
             if is_product.lower() == "true":
                 masked_bg += 255*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations)
             else:
-                masked_bg += img_array * image_mask_bg[index]
-
+                masked_bg += img_array * image_mask_all[index]
+
         composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB')

From 4819d91426f0b6608ff199df4c2d34a6961e8088 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Wed, 30 Apr 2025 17:37:53 +0800
Subject: [PATCH 294/482] fix bug

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index ee2154abedf3..ab582ad99535 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -544,8 +544,6 @@ def __call__(
             composed_image_all += img_array * image_mask_all[index]
             if is_product.lower() == "true":
                 masked_bg += 255*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations)
-            else:
-                masked_bg += img_array * image_mask_all[index]
 
         composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB')

From ff837a18b4fce35dc1a9a9148a4e11c4d8ad07ec Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Wed, 30 Apr 2025 17:42:50 +0800
Subject: [PATCH 295/482] fix bug

---
 .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index 64cc8e13f33f..a680acf5a440 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -1035,6 +1035,9 @@ def __call__(
         )
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
         # 7.
Prepare latent variables latents, noise, image_latents, latent_image_ids = self.prepare_latents( From 8a61297ef09804a08a44085d865794b991058da0 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 30 Apr 2025 18:12:36 +0800 Subject: [PATCH 296/482] Update pipeline_flux_prior_redux.py --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index ab582ad99535..53a8d12acad8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -395,6 +395,7 @@ def __call__( is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots iterations: Optional[int] = 10, # controlnet inpainting + mask_value: Optional[int] = 255, # controlnet inpainting image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, return_dict: bool = True, @@ -543,7 +544,7 @@ def __call__( composed_image_all += img_array * image_mask_all[index] if is_product.lower() == "true": - masked_bg += 255*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) + masked_bg += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB') composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') From 012038fd579752fa4aa2fa2c06081974244a225d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 4 May 2025 13:29:20 +0800 Subject: [PATCH 297/482] Update pipeline_flux_control_inpaint.py --- src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py index af7e8b53fad3..71e5c0f4f9f8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -1004,6 +1004,9 @@ def __call__( ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + # 5. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 8 From e010c8fb11a70c4164ea04bce15f64da92b5e61e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 4 May 2025 14:20:32 +0800 Subject: [PATCH 298/482] Update pipeline_flux_controlnet_inpainting.py --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index a680acf5a440..b6a88436006f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1075,6 +1075,8 @@ def __call__( generator, ) + print(f'mask size={mask.size()}') + controlnet_keep = [] for i in range(len(timesteps)): keeps = [ From 3e55fb9b61fb70485e4b7ff3bb1ea07ca3bac11c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 4 May 2025 14:27:33 +0800 Subject: [PATCH 299/482] print info --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index b6a88436006f..7bb0da49eb75 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1076,6 +1076,7 @@ def __call__( ) print(f'mask size={mask.size()}') + print(f'mask value={mask}') controlnet_keep = [] for i in range(len(timesteps)): From 1bffbf0b204c86cd9a07ffa08c92bc68299f016c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 4 May 2025 14:31:34 +0800 Subject: [PATCH 300/482] return mask --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 7bb0da49eb75..793f637e59b7 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1075,9 +1075,6 @@ def __call__( generator, ) - print(f'mask size={mask.size()}') - print(f'mask value={mask}') - controlnet_keep = [] for i in range(len(timesteps)): keeps = [ @@ -1205,6 +1202,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - return (image,) + return (image,mask) return FluxPipelineOutput(images=image) From 8201e7e8fe1a261ce505e5de1f41e62e8b81852c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 4 May 2025 18:39:20 +0800 Subject: [PATCH 301/482] Update pipeline_flux_controlnet_inpainting.py --- .../pipeline_flux_controlnet_inpainting.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 793f637e59b7..eb2d5a6d7eec 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1,6 +1,7 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import cv2 import numpy as np import PIL import torch @@ -767,6 +768,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 
callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + iterations: Optional[int] = 1, ): """ Function invoked when calling the pipeline for generation. @@ -1075,6 +1077,65 @@ def __call__( generator, ) + def apply_dilate_to_mask_image(mask,iterations = 8): + kernel = np.ones((3, 3), np.uint8) + mask = np.array(mask, dtype=bool) + mask = mask.astype(np.uint8) + mask = cv2.dilate(mask, kernel, iterations=iterations) + mask = mask.astype(np.uint8) + #mask = Image.fromarray(mask * 255) + + return mask + + # input np array size: 4096 x 3 + # output np array size: 64 x 64 x 3 + def mask_gradienting(mask, iterations=1): + mask_array = mask.reshape(64,64,-1) + + # gradent 1 + dilated_mask_image = apply_dilate_to_mask_image(mask_array,iterations = iterations) + original_mask_array_bool = np.array(mask_array,dtype=bool) + dilated_mask_array_bool = np.array(dilated_mask_image, dtype=bool) + gray_array = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * 200 + gray_array = gray_array * (~original_mask_array_bool) + gray_array = gray_array * dilated_mask_array_bool + mask1_array = np.where(mask_array>0, mask_array* 255, gray_array) + + # gradent 2 + dilated_mask1_image = apply_dilate_to_mask_image(mask1_array,iterations = iterations) + mask1_array_bool = np.array(mask1_array,dtype=bool) + dilated_mask1_image_bool = np.array(dilated_mask1_image, dtype=bool) + gray1_array = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * 150 + gray1_array = gray1_array * (~mask1_array_bool) + gray1_array = gray1_array * dilated_mask1_image_bool + mask2_array = np.where(mask1_array>0,mask1_array,gray1_array) + + # gradent 3 + dilated_mask2_image = apply_dilate_to_mask_image(mask2_array,iterations = iterations) + mask2_array_bool = np.array(mask2_array,dtype=bool) + dilated_mask2_image_bool = np.array(dilated_mask2_image, dtype=bool) + gray2_array = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * 100 + gray2_array = gray2_array * (~mask2_array_bool) + gray2_array = gray2_array * dilated_mask2_image_bool + mask3_array = np.where(mask2_array>0,mask2_array,gray2_array) + + # rescale to 1.0, 0.75, 0.5, 0.25, 0.0 + final_mask_array = np.zeros_like(mask_array) + final_mask_array = np.where(mask3_array==255, 1, final_mask_array) + final_mask_array = np.where(mask3_array==200, 0.75, final_mask_array) + final_mask_array = np.where(mask3_array==150, 0.5, final_mask_array) + final_mask_array = np.where(mask3_array==100, 0.25, final_mask_array) + + return final_mask_array + + # mask graident + tmp_mask = mask[0,:,:].cpu().numpy() + mask_gradient = mask_gradienting(tmp_mask, iterations=iterations) + tmp_tensor = torch.Tensor(mask_gradient) + tmp_tensor = tmp_tensor.view(-1,64) + tmp_tensor = torch.stack([tmp_tensor, tmp_tensor, tmp_tensor], dim=0) + mask = tmp_tensor.to(device=mask.device, dtype=mask.dtype) + controlnet_keep = [] for i in range(len(timesteps)): keeps = [ From 5aa6d4ab2c410c29a56dd05278f1e88382fb0bc3 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 4 May 2025 19:12:49 +0800 Subject: [PATCH 302/482] Update pipeline_flux_controlnet_inpainting.py --- .../pipeline_flux_controlnet_inpainting.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index eb2d5a6d7eec..ece524dcfeef 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ 
b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -768,7 +768,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - iterations: Optional[int] = 1, + iterations: Optional[int] = 3, ): """ Function invoked when calling the pipeline for generation. @@ -1119,12 +1119,22 @@ def mask_gradienting(mask, iterations=1): gray2_array = gray2_array * dilated_mask2_image_bool mask3_array = np.where(mask2_array>0,mask2_array,gray2_array) - # rescale to 1.0, 0.75, 0.5, 0.25, 0.0 + # gradent 4 + dilated_mask3_image = apply_dilate_to_mask_image(mask3_array,iterations = iterations) + mask3_array_bool = np.array(mask3_array,dtype=bool) + dilated_mask3_image_bool = np.array(dilated_mask3_image, dtype=bool) + gray3_array = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * 50 + gray3_array = gray3_array * (~mask3_array_bool) + gray3_array = gray3_array * dilated_mask3_image_bool + mask4_array = np.where(mask3_array>0,mask3_array,gray3_array) + + # rescale to 1.0, 0.8, 0.6, 0.4, 0.2, 0.0 final_mask_array = np.zeros_like(mask_array) - final_mask_array = np.where(mask3_array==255, 1, final_mask_array) - final_mask_array = np.where(mask3_array==200, 0.75, final_mask_array) - final_mask_array = np.where(mask3_array==150, 0.5, final_mask_array) - final_mask_array = np.where(mask3_array==100, 0.25, final_mask_array) + final_mask_array = np.where(mask4_array==255, 1.0, final_mask_array) + final_mask_array = np.where(mask4_array==200, 0.8, final_mask_array) + final_mask_array = np.where(mask4_array==150, 0.6, final_mask_array) + final_mask_array = np.where(mask4_array==100, 0.4, final_mask_array) + final_mask_array = np.where(mask4_array==50, 0.2, final_mask_array) return final_mask_array From e35d764696f661064a64e0bef53e671915d760ae Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 4 May 2025 19:21:57 +0800 Subject: [PATCH 303/482] Update pipeline_flux_controlnet_inpainting.py --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index ece524dcfeef..51ca85a4b4db 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1273,6 +1273,6 @@ def mask_gradienting(mask, iterations=1): self.maybe_free_model_hooks() if not return_dict: - return (image,mask) + return (image,) return FluxPipelineOutput(images=image) From 67ee6221f2aaadd744a913bd7df81f85c241a4e3 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 4 May 2025 19:23:47 +0800 Subject: [PATCH 304/482] Update pipeline_flux_controlnet_inpainting.py --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 51ca85a4b4db..a7e4d5079bf3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1077,7 +1077,7 @@ def __call__( generator, ) - def apply_dilate_to_mask_image(mask,iterations = 8): + def apply_dilate_to_mask_image(mask,iterations): kernel = np.ones((3, 3), np.uint8) 
mask = np.array(mask, dtype=bool) mask = mask.astype(np.uint8) @@ -1089,7 +1089,7 @@ def apply_dilate_to_mask_image(mask,iterations = 8): # input np array size: 4096 x 3 # output np array size: 64 x 64 x 3 - def mask_gradienting(mask, iterations=1): + def mask_gradienting(mask, iterations=3): mask_array = mask.reshape(64,64,-1) # gradent 1 From 888068bfa1107be1caf1e85a611f1170882a5187 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 4 May 2025 19:26:26 +0800 Subject: [PATCH 305/482] iterations=20 to redux prior --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 53a8d12acad8..ceffb8735c0b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -394,7 +394,7 @@ def __call__( is_qv: Optional[bool] = False, # thesea modified for quick validation of product shots is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots - iterations: Optional[int] = 10, # controlnet inpainting + iterations: Optional[int] = 20, # controlnet inpainting mask_value: Optional[int] = 255, # controlnet inpainting image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, From b2d4e246ed1a321a79d18c0dbbcd2f1ec60dfbf8 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 5 May 2025 21:09:50 +0800 Subject: [PATCH 306/482] modify return logic for controlnet inpainting --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index ceffb8735c0b..b844a870362e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -394,6 +394,7 @@ def __call__( is_qv: Optional[bool] = False, # thesea modified for quick validation of product shots is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots + is_inpainting: Optional[bool] = False, # controlnet inpainting iterations: Optional[int] = 20, # controlnet inpainting mask_value: Optional[int] = 255, # controlnet inpainting image_width: Optional[int] = 1024, @@ -647,7 +648,10 @@ def __call__( if not return_dict: if is_qv: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, composed_bg_image, composed_prod_images, prod_masks, bg_mask) + if is_inpainting: + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, composed_bg_image, composed_prod_images, prod_masks, bg_mask) + else: + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, bg_mask) else: return (prompt_embeds, pooled_prompt_embeds) From 7c1cb35b907114f25b409c24bc397f28b4581af0 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 7 May 2025 16:23:28 +0800 Subject: [PATCH 307/482] fix batch size bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git 
a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index a7e4d5079bf3..d5d7b3181fdb 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1143,8 +1143,11 @@ def mask_gradienting(mask, iterations=3): mask_gradient = mask_gradienting(tmp_mask, iterations=iterations) tmp_tensor = torch.Tensor(mask_gradient) tmp_tensor = tmp_tensor.view(-1,64) - tmp_tensor = torch.stack([tmp_tensor, tmp_tensor, tmp_tensor], dim=0) - mask = tmp_tensor.to(device=mask.device, dtype=mask.dtype) + mask_tensor = torch.zeros_like(mask) + for index in range(mask_tensor.shape[0]): + mask_tensor[index,:,:] = tmp_tensor + #tmp_tensor = torch.stack([tmp_tensor, tmp_tensor, tmp_tensor], dim=0) + mask = mask_tensor.to(device=mask.device, dtype=mask.dtype) controlnet_keep = [] for i in range(len(timesteps)): From 6a9e25fc41f53de3e16d0d87621cb5adf40386b1 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 10 May 2025 18:57:56 +0800 Subject: [PATCH 308/482] return mask_image --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index d2af2061a19f..cce407440bab 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1035,6 +1035,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - return (image,) + return (image,mask_image) return FluxPipelineOutput(images=image) From 5bbf836863ef7f7542ee2640090ad1a3a84635cb Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 10 May 2025 20:04:49 +0800 Subject: [PATCH 309/482] Update pipeline_flux_fill.py --- .../pipelines/flux/pipeline_flux_fill.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index cce407440bab..2d9a2629b70f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -15,6 +15,7 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union +import cv2 import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast @@ -750,6 +751,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + iterations: Optional[int] = 10, ): r""" Function invoked when calling the pipeline for generation. 
@@ -941,11 +943,77 @@ def __call__( latents, ) + ## prepare mask gradient + def apply_dilate_to_mask_image(mask,iterations): + kernel = np.ones((3, 3), np.uint8) + mask = np.array(mask, dtype=bool) + mask = mask.astype(np.uint8) + mask = cv2.dilate(mask, kernel, iterations=iterations) + mask = mask.astype(np.uint8) + #mask = Image.fromarray(mask * 255) + + return mask + + def mask_gradienting(mask_array, iterations=3): + # gradent 1 + dilated_mask_image = apply_dilate_to_mask_image(mask_array,iterations = iterations) + original_mask_array_bool = np.array(mask_array,dtype=bool) + dilated_mask_array_bool = np.array(dilated_mask_image, dtype=bool) + gray_array = np.ones_like(mask_array) * 200 + gray_array = gray_array * (~original_mask_array_bool) + gray_array = gray_array * dilated_mask_array_bool + mask1_array = np.where(mask_array>0, mask_array* 255, gray_array) + + # gradent 2 + dilated_mask1_image = apply_dilate_to_mask_image(mask1_array,iterations = iterations) + mask1_array_bool = np.array(mask1_array,dtype=bool) + dilated_mask1_image_bool = np.array(dilated_mask1_image, dtype=bool) + gray1_array = np.ones_like(mask_array) * 150 + gray1_array = gray1_array * (~mask1_array_bool) + gray1_array = gray1_array * dilated_mask1_image_bool + mask2_array = np.where(mask1_array>0,mask1_array,gray1_array) + + # gradent 3 + dilated_mask2_image = apply_dilate_to_mask_image(mask2_array,iterations = iterations) + mask2_array_bool = np.array(mask2_array,dtype=bool) + dilated_mask2_image_bool = np.array(dilated_mask2_image, dtype=bool) + gray2_array = np.ones_like(mask_array) * 100 + gray2_array = gray2_array * (~mask2_array_bool) + gray2_array = gray2_array * dilated_mask2_image_bool + mask3_array = np.where(mask2_array>0,mask2_array,gray2_array) + + # gradent 4 + dilated_mask3_image = apply_dilate_to_mask_image(mask3_array,iterations = iterations) + mask3_array_bool = np.array(mask3_array,dtype=bool) + dilated_mask3_image_bool = np.array(dilated_mask3_image, dtype=bool) + gray3_array = np.ones_like(mask_array) * 50 + gray3_array = gray3_array * (~mask3_array_bool) + gray3_array = gray3_array * dilated_mask3_image_bool + mask4_array = np.where(mask3_array>0,mask3_array,gray3_array) + + # rescale to 1.0, 0.8, 0.6, 0.4, 0.2, 0.0 + final_mask_array = np.zeros_like(mask_array) + final_mask_array = np.where(mask4_array==255, 1.0, final_mask_array) + final_mask_array = np.where(mask4_array==200, 0.8, final_mask_array) + final_mask_array = np.where(mask4_array==150, 0.6, final_mask_array) + final_mask_array = np.where(mask4_array==100, 0.4, final_mask_array) + final_mask_array = np.where(mask4_array==50, 0.2, final_mask_array) + + return final_mask_array + # 6. 
Prepare mask and masked image latents if masked_image_latents is not None: masked_image_latents = masked_image_latents.to(latents.device) else: mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) + print(f'mask_image shape = {mask_image.shape}') + + mask_gradient = mask_gradienting(mask_image.squeeze(), iterations=iterations) + mask_gradient = torch.Tensor(mask_gradient) + mask_gradient = mask_gradient.unsqueeze(0) + mask_gradient = mask_gradient.unsqueeze(0) + print(f'mask_gradient shape = {mask_gradient.shape}') + mask_image = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) masked_image = init_image * (1 - mask_image) masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype) From 7cbd10d1d249d8316eded015bad4746687ea8738 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 10 May 2025 20:06:24 +0800 Subject: [PATCH 310/482] Update pipeline_flux_fill.py --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 2d9a2629b70f..8c3ad9904859 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1009,7 +1009,7 @@ def mask_gradienting(mask_array, iterations=3): print(f'mask_image shape = {mask_image.shape}') mask_gradient = mask_gradienting(mask_image.squeeze(), iterations=iterations) - mask_gradient = torch.Tensor(mask_gradient) + mask_gradient = torch.from_numpy(mask_gradient) mask_gradient = mask_gradient.unsqueeze(0) mask_gradient = mask_gradient.unsqueeze(0) print(f'mask_gradient shape = {mask_gradient.shape}') From d9cb8d3e078ddc7810b49d560146e27fd7b2907a Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 10 May 2025 20:12:38 +0800 Subject: [PATCH 311/482] remove print --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 8c3ad9904859..200a42930469 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1006,13 +1006,12 @@ def mask_gradienting(mask_array, iterations=3): masked_image_latents = masked_image_latents.to(latents.device) else: mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) - print(f'mask_image shape = {mask_image.shape}') - + mask_gradient = mask_gradienting(mask_image.squeeze(), iterations=iterations) mask_gradient = torch.from_numpy(mask_gradient) mask_gradient = mask_gradient.unsqueeze(0) mask_gradient = mask_gradient.unsqueeze(0) - print(f'mask_gradient shape = {mask_gradient.shape}') + mask_image = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) masked_image = init_image * (1 - mask_image) From da83fddde0374f4d95cabea786f3fbf3ec0f1316 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 00:50:31 +0800 Subject: [PATCH 312/482] print info --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 200a42930469..b5febe574541 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1007,12 
+1007,11 @@ def mask_gradienting(mask_array, iterations=3): else: mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) - mask_gradient = mask_gradienting(mask_image.squeeze(), iterations=iterations) - mask_gradient = torch.from_numpy(mask_gradient) - mask_gradient = mask_gradient.unsqueeze(0) - mask_gradient = mask_gradient.unsqueeze(0) - - mask_image = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) + #mask_gradient = mask_gradienting(mask_image.squeeze(), iterations=iterations) + #mask_gradient = torch.from_numpy(mask_gradient) + #mask_gradient = mask_gradient.unsqueeze(0) + #mask_gradient = mask_gradient.unsqueeze(0) + #mask_image = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) masked_image = init_image * (1 - mask_image) masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype) @@ -1066,6 +1065,7 @@ def mask_gradienting(mask_array, iterations=3): # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + print(f'latents size = {latents.shape}') if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From fa071768b1eaba8a0da19310be627f192ee8b402 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 17:54:05 +0800 Subject: [PATCH 313/482] print info --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index b5febe574541..7442c5e4eb4e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1050,6 +1050,8 @@ def mask_gradienting(mask_array, iterations=3): # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) + print(f'before prediction latents size = {latents.shape}') + print(f'masked_image_latents size={masked_image_latents.shape}') noise_pred = self.transformer( hidden_states=torch.cat((latents, masked_image_latents), dim=2), timestep=timestep / 1000, @@ -1065,7 +1067,7 @@ def mask_gradienting(mask_array, iterations=3): # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - print(f'latents size = {latents.shape}') + print(f'after prediction latents size = {latents.shape}') if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From 1d1a4d09a062a1fce0f0b510aa210df4b82a4b42 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 18:03:43 +0800 Subject: [PATCH 314/482] print info --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 7442c5e4eb4e..ba6dd899d1ca 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1012,7 +1012,8 @@ def mask_gradienting(mask_array, iterations=3): #mask_gradient = mask_gradient.unsqueeze(0) #mask_gradient = mask_gradient.unsqueeze(0) #mask_image = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) - + print(f'masked_image size = {masked_image.shape}') + print(f'mask_image size={mask_image.shape}') masked_image = init_image * (1 - 
mask_image) masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype) @@ -1029,7 +1030,9 @@ def mask_gradienting(mask_array, iterations=3): device, generator, ) + print(f'before cat masked_image_latents size={masked_image_latents.shape}') masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) + print(f'mask size={mask.shape}') num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) From 3baaf07e27b00d694ce0d07da24208244af40c8e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 18:06:43 +0800 Subject: [PATCH 315/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index ba6dd899d1ca..e5d9747775c5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1012,10 +1012,11 @@ def mask_gradienting(mask_array, iterations=3): #mask_gradient = mask_gradient.unsqueeze(0) #mask_gradient = mask_gradient.unsqueeze(0) #mask_image = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) - print(f'masked_image size = {masked_image.shape}') + print(f'mask_image size={mask_image.shape}') masked_image = init_image * (1 - mask_image) masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype) + print(f'masked_image size = {masked_image.shape}') height, width = init_image.shape[-2:] mask, masked_image_latents = self.prepare_mask_latents( From 18f80bb6aa60ad88730508d3f859c689c34a69aa Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 18:37:44 +0800 Subject: [PATCH 316/482] print info --- .../pipelines/flux/pipeline_flux_fill.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index e5d9747775c5..c07e76b523bc 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -414,7 +414,7 @@ def prepare_mask_latents( ) mask = mask.to(device=device, dtype=dtype) - return mask, masked_image_latents + return mask, masked_image_latents, height, width # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt def encode_prompt( @@ -1007,11 +1007,11 @@ def mask_gradienting(mask_array, iterations=3): else: mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) - #mask_gradient = mask_gradienting(mask_image.squeeze(), iterations=iterations) - #mask_gradient = torch.from_numpy(mask_gradient) - #mask_gradient = mask_gradient.unsqueeze(0) - #mask_gradient = mask_gradient.unsqueeze(0) - #mask_image = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) + mask_gradient = mask_gradienting(mask_image.squeeze(), iterations=iterations) + mask_gradient = torch.from_numpy(mask_gradient) + mask_gradient = mask_gradient.unsqueeze(0) + mask_gradient = mask_gradient.unsqueeze(0) + mask_gradient = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) print(f'mask_image size={mask_image.shape}') masked_image = init_image * (1 - mask_image) @@ -1019,7 +1019,7 @@ def mask_gradienting(mask_array, iterations=3): print(f'masked_image size = {masked_image.shape}') height, width = init_image.shape[-2:] - mask, masked_image_latents = self.prepare_mask_latents( + 
mask, masked_image_latents, height, width = self.prepare_mask_latents( mask_image, masked_image, batch_size, @@ -1031,6 +1031,7 @@ def mask_gradienting(mask_array, iterations=3): device, generator, ) + print(f'height={height}, width={width}') print(f'before cat masked_image_latents size={masked_image_latents.shape}') masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) print(f'mask size={mask.shape}') From e4d0f9db4d1fa7371e6a9ddb12513788a1767a38 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 20:21:40 +0800 Subject: [PATCH 317/482] Update pipeline_flux_fill.py --- .../pipelines/flux/pipeline_flux_fill.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index c07e76b523bc..b905aa60b1b5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1019,7 +1019,7 @@ def mask_gradienting(mask_array, iterations=3): print(f'masked_image size = {masked_image.shape}') height, width = init_image.shape[-2:] - mask, masked_image_latents, height, width = self.prepare_mask_latents( + mask, masked_image_latents, height_latent, width_latent = self.prepare_mask_latents( mask_image, masked_image, batch_size, @@ -1031,7 +1031,9 @@ def mask_gradienting(mask_array, iterations=3): device, generator, ) + height_latent, width_latent = height_latent // 2, width_latent // 2, print(f'height={height}, width={width}') + print(f'height_latent={height_latent}, width_latent={width_latent}') print(f'before cat masked_image_latents size={masked_image_latents.shape}') masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) print(f'mask size={mask.shape}') @@ -1073,6 +1075,28 @@ def mask_gradienting(mask_array, iterations=3): latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] print(f'after prediction latents size = {latents.shape}') + latents = latents.view(latents.shape[0], height_latent, width_latent, -1) + if height_latent > 64: + tmp_dim = (height_latent - 64) // 2 + tmp_mean = (latents[:,tmp_dim-1,:,:] + latents[:,tmp_dim,:,:]) / 2.0 + latents[:,tmp_dim-1,:,:] = tmp_mean + latents[:,tmp_dim,:,:] = tmp_mean + + tmp_mean = (latents[:,64+tmp_dim-1,:,:] + latents[:,64+tmp_dim,:,:]) / 2.0 + latents[:,64+tmp_dim-1,:,:] = tmp_mean + latents[:,64+tmp_dim,:,:] = tmp_mean + + if width_latent > 64: + tmp_dim = (width_latent - 64) // 2 + tmp_mean = (latents[:,:,tmp_dim-1,:] + latents[:,:,tmp_dim,:]) / 2.0 + latents[:,:,tmp_dim-1,:] = tmp_mean + latents[:,:,tmp_dim,:] = tmp_mean + + tmp_mean = (latents[:,:,64+tmp_dim-1,:] + latents[:,:,64+tmp_dim,:]) / 2.0 + latents[:,:,64+tmp_dim-1,:] = tmp_mean + latents[:,:,64+tmp_dim,:] = tmp_mean + + latents = latents.view(latents.shape[0], height_latent * width_latent, -1) if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From fd3718f0e8a93db29cffbf95a023c3dd89d8faa4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 20:28:38 +0800 Subject: [PATCH 318/482] first N steps --- .../pipelines/flux/pipeline_flux_fill.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index b905aa60b1b5..6ec4d64ed23d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ 
b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -752,6 +752,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, iterations: Optional[int] = 10, + first_N_steps: Optional[int] = 15, ): r""" Function invoked when calling the pipeline for generation. @@ -1053,7 +1054,8 @@ def mask_gradienting(mask_array, iterations=3): for i, t in enumerate(timesteps): if self.interrupt: continue - + + print(f'timesteps={i}') # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -1076,27 +1078,28 @@ def mask_gradienting(mask_array, iterations=3): latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] print(f'after prediction latents size = {latents.shape}') latents = latents.view(latents.shape[0], height_latent, width_latent, -1) - if height_latent > 64: - tmp_dim = (height_latent - 64) // 2 - tmp_mean = (latents[:,tmp_dim-1,:,:] + latents[:,tmp_dim,:,:]) / 2.0 - latents[:,tmp_dim-1,:,:] = tmp_mean - latents[:,tmp_dim,:,:] = tmp_mean - - tmp_mean = (latents[:,64+tmp_dim-1,:,:] + latents[:,64+tmp_dim,:,:]) / 2.0 - latents[:,64+tmp_dim-1,:,:] = tmp_mean - latents[:,64+tmp_dim,:,:] = tmp_mean - - if width_latent > 64: - tmp_dim = (width_latent - 64) // 2 - tmp_mean = (latents[:,:,tmp_dim-1,:] + latents[:,:,tmp_dim,:]) / 2.0 - latents[:,:,tmp_dim-1,:] = tmp_mean - latents[:,:,tmp_dim,:] = tmp_mean - - tmp_mean = (latents[:,:,64+tmp_dim-1,:] + latents[:,:,64+tmp_dim,:]) / 2.0 - latents[:,:,64+tmp_dim-1,:] = tmp_mean - latents[:,:,64+tmp_dim,:] = tmp_mean - - latents = latents.view(latents.shape[0], height_latent * width_latent, -1) + if i < first_N_steps: + if height_latent > 64: + tmp_dim = (height_latent - 64) // 2 + tmp_mean = (latents[:,tmp_dim-1,:,:] + latents[:,tmp_dim,:,:]) / 2.0 + latents[:,tmp_dim-1,:,:] = tmp_mean + latents[:,tmp_dim,:,:] = tmp_mean + + tmp_mean = (latents[:,64+tmp_dim-1,:,:] + latents[:,64+tmp_dim,:,:]) / 2.0 + latents[:,64+tmp_dim-1,:,:] = tmp_mean + latents[:,64+tmp_dim,:,:] = tmp_mean + + if width_latent > 64: + tmp_dim = (width_latent - 64) // 2 + tmp_mean = (latents[:,:,tmp_dim-1,:] + latents[:,:,tmp_dim,:]) / 2.0 + latents[:,:,tmp_dim-1,:] = tmp_mean + latents[:,:,tmp_dim,:] = tmp_mean + + tmp_mean = (latents[:,:,64+tmp_dim-1,:] + latents[:,:,64+tmp_dim,:]) / 2.0 + latents[:,:,64+tmp_dim-1,:] = tmp_mean + latents[:,:,64+tmp_dim,:] = tmp_mean + + latents = latents.view(latents.shape[0], height_latent * width_latent, -1) if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From 16c4d90d7b687ac6e6f0e1282e3b31b98bc8c5df Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 20:34:25 +0800 Subject: [PATCH 319/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 6ec4d64ed23d..decd2d394bf4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1077,8 +1077,8 @@ def mask_gradienting(mask_array, iterations=3): latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] print(f'after prediction latents size = {latents.shape}') - latents = latents.view(latents.shape[0], height_latent, width_latent, -1) if i < first_N_steps: + latents = latents.view(latents.shape[0], height_latent, 
width_latent, -1) if height_latent > 64: tmp_dim = (height_latent - 64) // 2 tmp_mean = (latents[:,tmp_dim-1,:,:] + latents[:,tmp_dim,:,:]) / 2.0 From e207cd5207ce1cd8d3ff6677ef61f6f499a50709 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 21:01:35 +0800 Subject: [PATCH 320/482] apply gradient to mask --- .../pipelines/flux/pipeline_flux_fill.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index decd2d394bf4..2291f2079cbc 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1008,11 +1008,11 @@ def mask_gradienting(mask_array, iterations=3): else: mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) - mask_gradient = mask_gradienting(mask_image.squeeze(), iterations=iterations) - mask_gradient = torch.from_numpy(mask_gradient) - mask_gradient = mask_gradient.unsqueeze(0) - mask_gradient = mask_gradient.unsqueeze(0) - mask_gradient = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) + #mask_gradient = mask_gradienting(mask_image.squeeze(), iterations=iterations) + #mask_gradient = torch.from_numpy(mask_gradient) + #mask_gradient = mask_gradient.unsqueeze(0) + #mask_gradient = mask_gradient.unsqueeze(0) + #mask_gradient = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) print(f'mask_image size={mask_image.shape}') masked_image = init_image * (1 - mask_image) @@ -1036,6 +1036,17 @@ def mask_gradienting(mask_array, iterations=3): print(f'height={height}, width={width}') print(f'height_latent={height_latent}, width_latent={width_latent}') print(f'before cat masked_image_latents size={masked_image_latents.shape}') + + mask = mask.view(mask.shape[0], height_latent, width_latent, -1) + + mask_gradient = mask_gradienting(mask[0,:,:,:], iterations=iterations) + mask_gradient = torch.from_numpy(mask_gradient) + mask_gradient = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) + + mask = [mask_gradient] * mask.shape[0] + mask = torch.stack(mask,dim=0) + mask = mask.view(mask.shape[0], height_latent*width_latent, -1) + masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) print(f'mask size={mask.shape}') From 756d144192c6a82232d50dfcf1c34c9503230b72 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 21:06:52 +0800 Subject: [PATCH 321/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 2291f2079cbc..947a5ab97f3a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1039,9 +1039,9 @@ def mask_gradienting(mask_array, iterations=3): mask = mask.view(mask.shape[0], height_latent, width_latent, -1) - mask_gradient = mask_gradienting(mask[0,:,:,:], iterations=iterations) + mask_gradient = mask_gradienting(mask[0,:,:,:].cpu().numpy(), iterations=iterations) mask_gradient = torch.from_numpy(mask_gradient) - mask_gradient = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) + mask_gradient = mask_gradient.to(device=mask.device, dtype=mask.dtype) mask = [mask_gradient] * mask.shape[0] mask = torch.stack(mask,dim=0) @@ -1147,6 +1147,6 @@ def mask_gradienting(mask_array, 
iterations=3): self.maybe_free_model_hooks() if not return_dict: - return (image,mask_image) + return (image,mask_gradient) return FluxPipelineOutput(images=image) From 30b91b90ee6b07afcb596ee67aa4f0853c5fc8ee Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 21:11:39 +0800 Subject: [PATCH 322/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 947a5ab97f3a..8e148319c2db 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1038,10 +1038,10 @@ def mask_gradienting(mask_array, iterations=3): print(f'before cat masked_image_latents size={masked_image_latents.shape}') mask = mask.view(mask.shape[0], height_latent, width_latent, -1) - + mask = mask.to(dtype=torch.float16) mask_gradient = mask_gradienting(mask[0,:,:,:].cpu().numpy(), iterations=iterations) mask_gradient = torch.from_numpy(mask_gradient) - mask_gradient = mask_gradient.to(device=mask.device, dtype=mask.dtype) + mask_gradient = mask_gradient.to(device=masked_image_latents.device, dtype=masked_image_latents.dtype) mask = [mask_gradient] * mask.shape[0] mask = torch.stack(mask,dim=0) From 0ad12d97dbed02e0f4453bc1b985c2d528b55202 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 21:19:31 +0800 Subject: [PATCH 323/482] return mask --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 8e148319c2db..51ab712609fa 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -1147,6 +1147,6 @@ def mask_gradienting(mask_array, iterations=3): self.maybe_free_model_hooks() if not return_dict: - return (image,mask_gradient) + return (image,mask,mask_gradient) return FluxPipelineOutput(images=image) From 643252783a8c1965d2ff5d6423b4b483d5afee3b Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 21:21:28 +0800 Subject: [PATCH 324/482] return mask_original --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 51ab712609fa..c24fdab4c64f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -16,6 +16,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import cv2 +import copy import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast @@ -1037,6 +1038,8 @@ def mask_gradienting(mask_array, iterations=3): print(f'height_latent={height_latent}, width_latent={width_latent}') print(f'before cat masked_image_latents size={masked_image_latents.shape}') + mask_original = copy.deepcopy(mask) + mask = mask.view(mask.shape[0], height_latent, width_latent, -1) mask = mask.to(dtype=torch.float16) mask_gradient = mask_gradienting(mask[0,:,:,:].cpu().numpy(), iterations=iterations) @@ -1147,6 +1150,6 @@ def mask_gradienting(mask_array, iterations=3): self.maybe_free_model_hooks() if not return_dict: - return (image,mask,mask_gradient) + return (image,mask_original,mask_gradient) 
return FluxPipelineOutput(images=image) From aaf6599e5e5e9116ad73d93c54f899afcbc1edfd Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 11 May 2025 22:37:39 +0800 Subject: [PATCH 325/482] remove print --- .../pipelines/flux/pipeline_flux_fill.py | 48 ++----------------- 1 file changed, 3 insertions(+), 45 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index c24fdab4c64f..9cd465f7e0ba 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -752,8 +752,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - iterations: Optional[int] = 10, - first_N_steps: Optional[int] = 15, + iterations: Optional[int] = 30, ): r""" Function invoked when calling the pipeline for generation. @@ -1009,17 +1008,9 @@ def mask_gradienting(mask_array, iterations=3): else: mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) - #mask_gradient = mask_gradienting(mask_image.squeeze(), iterations=iterations) - #mask_gradient = torch.from_numpy(mask_gradient) - #mask_gradient = mask_gradient.unsqueeze(0) - #mask_gradient = mask_gradient.unsqueeze(0) - #mask_gradient = mask_gradient.to(device=mask_image.device, dtype=mask_image.dtype) - - print(f'mask_image size={mask_image.shape}') masked_image = init_image * (1 - mask_image) masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype) - print(f'masked_image size = {masked_image.shape}') - + height, width = init_image.shape[-2:] mask, masked_image_latents, height_latent, width_latent = self.prepare_mask_latents( mask_image, @@ -1034,11 +1025,6 @@ def mask_gradienting(mask_array, iterations=3): generator, ) height_latent, width_latent = height_latent // 2, width_latent // 2, - print(f'height={height}, width={width}') - print(f'height_latent={height_latent}, width_latent={width_latent}') - print(f'before cat masked_image_latents size={masked_image_latents.shape}') - - mask_original = copy.deepcopy(mask) mask = mask.view(mask.shape[0], height_latent, width_latent, -1) mask = mask.to(dtype=torch.float16) @@ -1051,7 +1037,6 @@ def mask_gradienting(mask_array, iterations=3): mask = mask.view(mask.shape[0], height_latent*width_latent, -1) masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) - print(f'mask size={mask.shape}') num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -1069,12 +1054,9 @@ def mask_gradienting(mask_array, iterations=3): if self.interrupt: continue - print(f'timesteps={i}') # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - print(f'before prediction latents size = {latents.shape}') - print(f'masked_image_latents size={masked_image_latents.shape}') noise_pred = self.transformer( hidden_states=torch.cat((latents, masked_image_latents), dim=2), timestep=timestep / 1000, @@ -1090,30 +1072,6 @@ def mask_gradienting(mask_array, iterations=3): # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - print(f'after prediction latents size = {latents.shape}') - if i < first_N_steps: - latents = latents.view(latents.shape[0], height_latent, width_latent, -1) - if 
height_latent > 64: - tmp_dim = (height_latent - 64) // 2 - tmp_mean = (latents[:,tmp_dim-1,:,:] + latents[:,tmp_dim,:,:]) / 2.0 - latents[:,tmp_dim-1,:,:] = tmp_mean - latents[:,tmp_dim,:,:] = tmp_mean - - tmp_mean = (latents[:,64+tmp_dim-1,:,:] + latents[:,64+tmp_dim,:,:]) / 2.0 - latents[:,64+tmp_dim-1,:,:] = tmp_mean - latents[:,64+tmp_dim,:,:] = tmp_mean - - if width_latent > 64: - tmp_dim = (width_latent - 64) // 2 - tmp_mean = (latents[:,:,tmp_dim-1,:] + latents[:,:,tmp_dim,:]) / 2.0 - latents[:,:,tmp_dim-1,:] = tmp_mean - latents[:,:,tmp_dim,:] = tmp_mean - - tmp_mean = (latents[:,:,64+tmp_dim-1,:] + latents[:,:,64+tmp_dim,:]) / 2.0 - latents[:,:,64+tmp_dim-1,:] = tmp_mean - latents[:,:,64+tmp_dim,:] = tmp_mean - - latents = latents.view(latents.shape[0], height_latent * width_latent, -1) if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): @@ -1150,6 +1108,6 @@ def mask_gradienting(mask_array, iterations=3): self.maybe_free_model_hooks() if not return_dict: - return (image,mask_original,mask_gradient) + return (image,) return FluxPipelineOutput(images=image) From 8bbaedf7a9fcf503f459d51444d87c7096b3edf2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 12 May 2025 15:52:26 +0800 Subject: [PATCH 326/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index d5d7b3181fdb..283d38228a21 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1139,6 +1139,7 @@ def mask_gradienting(mask, iterations=3): return final_mask_array # mask graident + mask = mask.to(dtype=torch.float16) tmp_mask = mask[0,:,:].cpu().numpy() mask_gradient = mask_gradienting(tmp_mask, iterations=iterations) tmp_tensor = torch.Tensor(mask_gradient) From de5aa7f5571daf905609d3883ba5f99db3c8ec52 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 12 May 2025 16:00:13 +0800 Subject: [PATCH 327/482] self._execution_device = text_encoder.device --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index b844a870362e..9075da938eaa 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -143,7 +143,8 @@ def __init__( self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - + self._execution_device = text_encoder.device + def check_inputs( self, image, From a802f47a8581d3c9842bd13029635d434779cd53 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 12 May 2025 16:02:26 +0800 Subject: [PATCH 328/482] Revert "self._execution_device = text_encoder.device" This reverts commit de5aa7f5571daf905609d3883ba5f99db3c8ec52. 
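Note on this revert (an inference, not stated in the original message): DiffusionPipeline exposes _execution_device as a read-only property that resolves the device through any accelerate offload hooks, so the instance assignment added in the previous commit raises an error instead of taking effect. A minimal sketch of the failure mode, assuming that property definition (illustrative names, not the pipeline source):

    class Pipe:
        @property
        def _execution_device(self):
            return "cuda"  # normally derived from hooks / module placement

    Pipe()._execution_device = "cpu"  # AttributeError: property has no setter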
--- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 9075da938eaa..b844a870362e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -143,8 +143,7 @@ def __init__( self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self._execution_device = text_encoder.device - + def check_inputs( self, image, From f54c1579fb9fc9e6b9bdffccd16dd2be3270d156 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 12 May 2025 16:06:59 +0800 Subject: [PATCH 329/482] use self.text_encoder.device --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index b844a870362e..123abea73874 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -204,7 +204,7 @@ def _get_t5_prompt_embeds( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): - device = device or self._execution_device + device = self.text_encoder.device #device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -252,7 +252,7 @@ def _get_clip_prompt_embeds( num_images_per_prompt: int = 1, device: Optional[torch.device] = None, ): - device = device or self._execution_device + device = self.text_encoder.device#device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -323,7 +323,7 @@ def encode_prompt( lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ - device = device or self._execution_device + device = self.text_encoder.device #device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it @@ -464,7 +464,7 @@ def __call__( if isinstance(pooled_prompt_embeds_scale, float): pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale] - device = self._execution_device + device = self.text_encoder.device#self._execution_device # 3. Prepare image embeddings # thesea modified for ip and txt masks From 66acb2d667e8cc1affa5450a86a98a5631857505 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 12 May 2025 16:08:08 +0800 Subject: [PATCH 330/482] Revert "use self.text_encoder.device" This reverts commit f54c1579fb9fc9e6b9bdffccd16dd2be3270d156. 
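Note on this revert (an inference): reading self.text_encoder.device directly breaks under CPU/model offloading, where the text encoder may still sit on "cpu" until its forward pass is dispatched, and it also discards the caller's explicit device argument. A standalone sketch of the resolution order this revert restores, mirroring `device = device or self._execution_device` (names illustrative):

    def resolve_device(requested=None, execution_device="cuda"):
        # prefer the caller's device; otherwise fall back to the pipeline's
        # offload-aware execution device
        return requested if requested is not None else execution_device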
--- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 123abea73874..b844a870362e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -204,7 +204,7 @@ def _get_t5_prompt_embeds( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): - device = self.text_encoder.device #device or self._execution_device + device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -252,7 +252,7 @@ def _get_clip_prompt_embeds( num_images_per_prompt: int = 1, device: Optional[torch.device] = None, ): - device = self.text_encoder.device#device or self._execution_device + device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -323,7 +323,7 @@ def encode_prompt( lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ - device = self.text_encoder.device #device or self._execution_device + device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it @@ -464,7 +464,7 @@ def __call__( if isinstance(pooled_prompt_embeds_scale, float): pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale] - device = self.text_encoder.device#self._execution_device + device = self._execution_device # 3. Prepare image embeddings # thesea modified for ip and txt masks From 03476ba924e82fc56158ffdc10c44d1a881174cd Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 12 May 2025 16:30:35 +0800 Subject: [PATCH 331/482] Revert "fix bug" This reverts commit 8bbaedf7a9fcf503f459d51444d87c7096b3edf2. 
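Note on this revert (an inference): the reverted line pinned the mask to torch.float16 unconditionally, which only suits half-precision runs. The same cast returns in PATCH 334 below, and PATCH 335 ties the final cast to the working latents instead; a dtype-agnostic sketch of that final form, assuming mask_tensor and latents are in scope as in the pipeline:

    # follow the latents rather than hard-coding float16
    mask = mask_tensor.to(device=latents.device, dtype=latents.dtype)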
--- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 283d38228a21..d5d7b3181fdb 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1139,7 +1139,6 @@ def mask_gradienting(mask, iterations=3): return final_mask_array # mask graident - mask = mask.to(dtype=torch.float16) tmp_mask = mask[0,:,:].cpu().numpy() mask_gradient = mask_gradienting(tmp_mask, iterations=iterations) tmp_tensor = torch.Tensor(mask_gradient) From 261c595e4ca8a2a3da1fc1eac6cd6c3db103e7a4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 15 May 2025 17:33:20 +0900 Subject: [PATCH 332/482] default iterations = 15 --- src/diffusers/pipelines/flux/pipeline_flux_fill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 9cd465f7e0ba..3d2692e24d19 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -752,7 +752,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - iterations: Optional[int] = 30, + iterations: Optional[int] = 15, ): r""" Function invoked when calling the pipeline for generation. From b5f8e02df9c6f314684063ae28a7dfa9a3188524 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 22 May 2025 15:07:00 +0900 Subject: [PATCH 333/482] update for element --- .../pipelines/flux/pipeline_flux_prior_redux.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index b844a870362e..a25edf832837 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -395,6 +395,7 @@ def __call__( is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots is_inpainting: Optional[bool] = False, # controlnet inpainting + contains_element: Optional[bool] = False, # controlnet inpainting for element iterations: Optional[int] = 20, # controlnet inpainting mask_value: Optional[int] = 255, # controlnet inpainting image_width: Optional[int] = 1024, @@ -535,6 +536,7 @@ def __call__( composed_image_all = np.zeros((image_width, image_height, 3)) masked_bg = np.zeros((image_width, image_height, 3)) + masked_bg_with_element = np.zeros((image_width, image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) composed_prod_images = [] for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): @@ -547,9 +549,13 @@ def __call__( if is_product.lower() == "true": masked_bg += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) + if index > 0: + masked_bg_with_element += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) + composed_bg_image = 
Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB') composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') masked_bg = Image.fromarray(masked_bg.astype(np.uint8)).convert('RGB') + masked_bg_with_element = Image.fromarray(masked_bg_with_element.astype(np.uint8)).convert('RGB') bg_mask = Image.fromarray(bg_mask.astype(np.uint8)*255).convert('RGB') prod_masks = [] @@ -649,7 +655,10 @@ def __call__( if not return_dict: if is_qv: if is_inpainting: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, composed_bg_image, composed_prod_images, prod_masks, bg_mask) + if contains_element: + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, masked_bg_with_element, composed_bg_image, composed_prod_images, prod_masks, bg_mask) + else: + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, composed_bg_image, composed_prod_images, prod_masks, bg_mask) else: return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, bg_mask) else: From 011a0b478b9b1a349940333739e8e6d39eba237e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 23 May 2025 12:04:21 +0900 Subject: [PATCH 334/482] fix dtype bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index d5d7b3181fdb..283d38228a21 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1139,6 +1139,7 @@ def mask_gradienting(mask, iterations=3): return final_mask_array # mask graident + mask = mask.to(dtype=torch.float16) tmp_mask = mask[0,:,:].cpu().numpy() mask_gradient = mask_gradienting(tmp_mask, iterations=iterations) tmp_tensor = torch.Tensor(mask_gradient) From e13df7bf5f75de6ccc299fc47a4623d7c76127d4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 23 May 2025 12:06:47 +0900 Subject: [PATCH 335/482] fix dtype bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 283d38228a21..3f011e039554 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1148,7 +1148,7 @@ def mask_gradienting(mask, iterations=3): for index in range(mask_tensor.shape[0]): mask_tensor[index,:,:] = tmp_tensor #tmp_tensor = torch.stack([tmp_tensor, tmp_tensor, tmp_tensor], dim=0) - mask = mask_tensor.to(device=mask.device, dtype=mask.dtype) + mask = mask_tensor.to(device=latents.device, dtype=latents.dtype) controlnet_keep = [] for i in range(len(timesteps)): From e56722d3315351f929e3f2f33ce13362236efae0 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 15:36:30 +0900 Subject: [PATCH 336/482] add prod reference img --- .../pipeline_flux_controlnet_inpainting.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 3f011e039554..2371882c8ac4 100644 --- 
a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -743,6 +743,8 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, + image_ref_prod: PipelineImageInput = None, + ratio_ref_prod: Optional[float] = 0.1, mask_image: PipelineImageInput = None, masked_image_latents: PipelineImageInput = None, control_image: PipelineImageInput = None, @@ -925,6 +927,12 @@ def __call__( ) init_image = init_image.to(dtype=torch.float32) + if image_ref_prod is not None: + init_image_ref = self.image_processor.preprocess( + image_ref_prod, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image_ref = init_image_ref.to(dtype=torch.float32) + # 5. Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, FluxControlNetModel): @@ -1055,6 +1063,20 @@ def __call__( latents, ) + if image_ref_prod is not None: + _, _, image_latents_ref, _ = self.prepare_latents( + init_image_ref, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + # 8. Prepare mask latents mask_condition = self.mask_processor.preprocess( mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords @@ -1230,15 +1252,25 @@ def mask_gradienting(mask, iterations=3): # For inpainting, we need to apply the mask and add the masked image latents init_latents_proper = image_latents + if image_ref_prod is not None: + init_latents_proper_ref = image_latents_ref init_mask = mask + if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.scale_noise( init_latents_proper, torch.tensor([noise_timestep]), noise ) - - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + if image_ref_prod is not None: + init_latents_proper_ref = self.scheduler.scale_noise( + init_latents_proper_ref, torch.tensor([noise_timestep]), noise + ) + + if image_ref_prod is not None: + latents = (1 - init_mask) * init_latents_proper + (1.0 - ratio_ref_prod) * init_mask * latents + ratio_ref_prod * init_mask * init_latents_proper_ref + else: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From 61ded8cd2ac457d0f9a102ac62b1e911c05f9622 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 15:44:14 +0900 Subject: [PATCH 337/482] return composed_prod_images_all --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index a25edf832837..9c041c7fa16c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -539,9 +539,11 @@ def __call__( masked_bg_with_element = np.zeros((image_width, image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) composed_prod_images = [] + composed_prod_images_all = np.zeros((image_width, image_height, 3)) for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": 
composed_prod_images.append(Image.fromarray(img_array.astype(np.uint8))) + composed_prod_images_all += img_array * image_mask_bg[index] else: composed_bg_image += img_array * image_mask_bg[index] @@ -553,6 +555,7 @@ def __call__( masked_bg_with_element += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) composed_bg_image = Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB') + composed_prod_images_all = Image.fromarray(composed_prod_images_all.astype(np.uint8)).convert('RGB') composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') masked_bg = Image.fromarray(masked_bg.astype(np.uint8)).convert('RGB') masked_bg_with_element = Image.fromarray(masked_bg_with_element.astype(np.uint8)).convert('RGB') @@ -656,9 +659,9 @@ def __call__( if is_qv: if is_inpainting: if contains_element: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, masked_bg_with_element, composed_bg_image, composed_prod_images, prod_masks, bg_mask) + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, masked_bg_with_element, composed_bg_image, composed_prod_images, composed_prod_images_all, prod_masks, bg_mask) else: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, composed_bg_image, composed_prod_images, prod_masks, bg_mask) + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, composed_bg_image, composed_prod_images, composed_prod_images_all, prod_masks, bg_mask) else: return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, bg_mask) else: From 669e3327f94d20ae8e9920f194aa7a3f1c908cd6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 15:46:04 +0900 Subject: [PATCH 338/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 9c041c7fa16c..3b3d5bd504da 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -543,7 +543,7 @@ def __call__( for index, (is_product, img_array) in enumerate(zip(is_product_list, image_array_list)): if is_product.lower() == "true": composed_prod_images.append(Image.fromarray(img_array.astype(np.uint8))) - composed_prod_images_all += img_array * image_mask_bg[index] + composed_prod_images_all += img_array * image_mask_prod[index] else: composed_bg_image += img_array * image_mask_bg[index] From 76d0f8713f707ad1fea54e7d7aac9fb014b04159 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 15:52:27 +0900 Subject: [PATCH 339/482] use original mask for ref prod img --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 2371882c8ac4..402de98632ec 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -5,6 +5,7 @@ import numpy as np import PIL import torch +import copy from transformers import ( CLIPTextModel, CLIPTokenizer, @@ -1161,6 +1162,7 @@ def mask_gradienting(mask, 
iterations=3): return final_mask_array # mask graident + mask_original = copy.deepcopy(mask) mask = mask.to(dtype=torch.float16) tmp_mask = mask[0,:,:].cpu().numpy() mask_gradient = mask_gradienting(tmp_mask, iterations=iterations) @@ -1255,7 +1257,7 @@ def mask_gradienting(mask, iterations=3): if image_ref_prod is not None: init_latents_proper_ref = image_latents_ref init_mask = mask - + init_mask_original = mask_original if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] @@ -1268,7 +1270,7 @@ def mask_gradienting(mask, iterations=3): ) if image_ref_prod is not None: - latents = (1 - init_mask) * init_latents_proper + (1.0 - ratio_ref_prod) * init_mask * latents + ratio_ref_prod * init_mask * init_latents_proper_ref + latents = (1 - init_mask) * init_latents_proper + (1.0 - ratio_ref_prod) * init_mask * latents + ratio_ref_prod * init_mask_original * init_latents_proper_ref else: latents = (1 - init_mask) * init_latents_proper + init_mask * latents From eaf17e32e42081c7963e3912598c1411afc7f57d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 16:06:10 +0900 Subject: [PATCH 340/482] fix bug --- .../pipeline_flux_controlnet_inpainting.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 402de98632ec..6e325bddf779 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -772,6 +772,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, iterations: Optional[int] = 3, + iterations_prod_erosion: Optional[int] = 3, ): """ Function invoked when calling the pipeline for generation. 
@@ -1110,6 +1111,15 @@ def apply_dilate_to_mask_image(mask,iterations):
 
             return mask
 
+        def apply_erosion_to_mask_image(mask,iterations):
+            kernel = np.ones((3, 3), np.uint8)
+            mask = np.array(mask, dtype=bool)
+            mask = mask.astype(np.uint8)
+            mask = cv2.erode(mask, kernel, iterations=iterations)
+            mask = mask.astype(np.uint8)
+
+            return mask
+
         # input np array size: 4096 x 3
         # output np array size: 64 x 64 x 3
         def mask_gradienting(mask, iterations=3):
@@ -1162,7 +1172,6 @@ def mask_gradienting(mask, iterations=3):
             return final_mask_array
 
         # mask graident
-        mask_original = copy.deepcopy(mask)
         mask = mask.to(dtype=torch.float16)
         tmp_mask = mask[0,:,:].cpu().numpy()
         mask_gradient = mask_gradienting(tmp_mask, iterations=iterations)
@@ -1174,6 +1183,16 @@ def mask_gradienting(mask, iterations=3):
         #tmp_tensor = torch.stack([tmp_tensor, tmp_tensor, tmp_tensor], dim=0)
         mask = mask_tensor.to(device=latents.device, dtype=latents.dtype)
 
+        # mask for prod reference img:
+        tmp_mask_ref_prod = apply_dilate_to_mask_image(tmp_mask.reshape(64,64,-1), iterations=iterations_prod_erosion)
+        tmp_mask_ref_prod = torch.Tensor(tmp_mask_ref_prod)
+        tmp_mask_ref_prod = tmp_mask_ref_prod.view(-1,64)
+
+        mask_ref_prod = torch.zeros_like(mask)
+        for index in range(mask_ref_prod.shape[0]):
+            mask_ref_prod[index,:,:] = tmp_tensor
+        mask_ref_prod = mask_ref_prod.to(device=latents.device, dtype=latents.dtype)
+
         controlnet_keep = []
         for i in range(len(timesteps)):
             keeps = [
@@ -1257,7 +1276,7 @@ def mask_gradienting(mask, iterations=3):
             if image_ref_prod is not None:
                 init_latents_proper_ref = image_latents_ref
             init_mask = mask
-            init_mask_original = mask_original
+            init_mask_ref_prod = mask_ref_prod
             if i < len(timesteps) - 1:
                 noise_timestep = timesteps[i + 1]
 
@@ -1270,7 +1289,7 @@ def mask_gradienting(mask, iterations=3):
                 )
 
             if image_ref_prod is not None:
-                latents = (1 - init_mask) * init_latents_proper + (1.0 - ratio_ref_prod) * init_mask * latents + ratio_ref_prod * init_mask_original * init_latents_proper_ref
+                latents = (1 - init_mask) * init_latents_proper + (1.0 - ratio_ref_prod) * init_mask * latents + ratio_ref_prod * init_mask_ref_prod * init_latents_proper_ref
             else:
                 latents = (1 - init_mask) * init_latents_proper + init_mask * latents
 

From 965bd970210493909eb3afb5d357f2f5c2a28ede Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Thu, 29 May 2025 16:10:42 +0900
Subject: [PATCH 341/482] fix bug

---
 .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index 6e325bddf779..ec72f1c5cc65 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -1184,7 +1184,7 @@ def mask_gradienting(mask, iterations=3):
         mask = mask_tensor.to(device=latents.device, dtype=latents.dtype)
 
         # mask for prod reference img:
-        tmp_mask_ref_prod = apply_dilate_to_mask_image(tmp_mask.reshape(64,64,-1), iterations=iterations_prod_erosion)
+        tmp_mask_ref_prod = apply_erosion_to_mask_image(tmp_mask.reshape(64,64,-1), iterations=iterations_prod_erosion)
         tmp_mask_ref_prod = torch.Tensor(tmp_mask_ref_prod)
         tmp_mask_ref_prod = tmp_mask_ref_prod.view(-1,64)
 

From 577b535e62698f021b9c4cd51ae6c80557218315 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Thu, 29 May 2025 16:15:39 +0900
Subject: [PATCH 342/482] fix reference bug

---
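Note: the loop introduced in PATCH 340 filled mask_ref_prod with tmp_tensor, the gradiented inpaint mask computed just above it, instead of tmp_mask_ref_prod, the eroded product mask; the hunk below swaps in the intended tensor. The per-channel loop is also equivalent to a single broadcast assignment, sketched here with assumed shapes:

    import torch

    mask_ref_prod = torch.zeros(16, 64, 64)  # (channels, h, w) -- shapes assumed
    tmp_mask_ref_prod = torch.ones(64, 64)
    mask_ref_prod[:] = tmp_mask_ref_prod     # broadcasts across the channel dim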
.../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index ec72f1c5cc65..ae5ec5fb80be 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1190,7 +1190,7 @@ def mask_gradienting(mask, iterations=3): mask_ref_prod = torch.zeros_like(mask) for index in range(mask_ref_prod.shape[0]): - mask_ref_prod[index,:,:] = tmp_tensor + mask_ref_prod[index,:,:] = tmp_mask_ref_prod mask_ref_prod = mask_ref_prod.to(device=latents.device, dtype=latents.dtype) controlnet_keep = [] From 5b4506eba55f1807301cf771bd28402510f1f85b Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 16:25:34 +0900 Subject: [PATCH 343/482] improve logic --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index ae5ec5fb80be..958dd1b87360 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1289,7 +1289,11 @@ def mask_gradienting(mask, iterations=3): ) if image_ref_prod is not None: - latents = (1 - init_mask) * init_latents_proper + (1.0 - ratio_ref_prod) * init_mask * latents + ratio_ref_prod * init_mask_ref_prod * init_latents_proper_ref + latents_1 = (1 - init_mask) * init_latents_proper + latents_2 = (init_mask - init_mask_ref_prod) * latents + latents_3 = (1.0 - ratio_ref_prod) * init_mask_ref_prod * latents + ratio_ref_prod * init_mask_ref_prod * init_latents_proper_ref + + latents = latents_1 + latents_2 + latents_3 else: latents = (1 - init_mask) * init_latents_proper + init_mask * latents From 773b4424a5c94497df61b894daf3c8a33d50a204 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 16:35:38 +0900 Subject: [PATCH 344/482] adjust kernel size --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 958dd1b87360..fd7f060741ec 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1112,7 +1112,7 @@ def apply_dilate_to_mask_image(mask,iterations): return mask def apply_erosion_to_mask_image(mask,iterations): - kernel = np.ones((3, 3), np.uint8) + kernel = np.ones((4, 4), np.uint8) mask = np.array(mask, dtype=bool) mask = mask.astype(np.uint8) mask = cv2.erode(mask, kernel, iterations=iterations) From b483f5a91a1a4ec725f625f8842e10f68048813a Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 16:39:30 +0900 Subject: [PATCH 345/482] adjust kernel size --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index fd7f060741ec..381efebef260 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py 
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1112,7 +1112,7 @@ def apply_dilate_to_mask_image(mask,iterations): return mask def apply_erosion_to_mask_image(mask,iterations): - kernel = np.ones((4, 4), np.uint8) + kernel = np.ones((5, 5), np.uint8) mask = np.array(mask, dtype=bool) mask = mask.astype(np.uint8) mask = cv2.erode(mask, kernel, iterations=iterations) From a65e5e6f2e01152625bdc6a0c721853c1576fc1f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 16:49:33 +0900 Subject: [PATCH 346/482] return masked_bg_original --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 3b3d5bd504da..d1fe8a74f5e9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -536,6 +536,7 @@ def __call__( composed_image_all = np.zeros((image_width, image_height, 3)) masked_bg = np.zeros((image_width, image_height, 3)) + masked_bg_original = np.zeros((image_width, image_height, 3)) masked_bg_with_element = np.zeros((image_width, image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) composed_prod_images = [] @@ -550,6 +551,7 @@ def __call__( composed_image_all += img_array * image_mask_all[index] if is_product.lower() == "true": masked_bg += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) + masked_bg_original += mask_value*np.ones((image_width, image_height, 3)) * image_mask_all[index] if index > 0: masked_bg_with_element += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) @@ -558,6 +560,7 @@ def __call__( composed_prod_images_all = Image.fromarray(composed_prod_images_all.astype(np.uint8)).convert('RGB') composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') masked_bg = Image.fromarray(masked_bg.astype(np.uint8)).convert('RGB') + masked_bg_original = Image.fromarray(masked_bg_original.astype(np.uint8)).convert('RGB') masked_bg_with_element = Image.fromarray(masked_bg_with_element.astype(np.uint8)).convert('RGB') bg_mask = Image.fromarray(bg_mask.astype(np.uint8)*255).convert('RGB') @@ -659,9 +662,9 @@ def __call__( if is_qv: if is_inpainting: if contains_element: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, masked_bg_with_element, composed_bg_image, composed_prod_images, composed_prod_images_all, prod_masks, bg_mask) + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, masked_bg_original, masked_bg_with_element, composed_bg_image, composed_prod_images, composed_prod_images_all, prod_masks, bg_mask) else: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, composed_bg_image, composed_prod_images, composed_prod_images_all, prod_masks, bg_mask) + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, masked_bg_original, composed_bg_image, composed_prod_images, composed_prod_images_all, prod_masks, bg_mask) else: return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, bg_mask) else: From f2fee4b7ee5df8b6eef6870e72fe34a6bd41f3e9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 
17:09:09 +0900 Subject: [PATCH 347/482] add codes --- .../pipeline_flux_controlnet_inpainting.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 381efebef260..9db1f34a4c83 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -747,6 +747,7 @@ def __call__( image_ref_prod: PipelineImageInput = None, ratio_ref_prod: Optional[float] = 0.1, mask_image: PipelineImageInput = None, + mask_image_original: PipelineImageInput = None, masked_image_latents: PipelineImageInput = None, control_image: PipelineImageInput = None, height: Optional[int] = None, @@ -773,6 +774,7 @@ def __call__( max_sequence_length: int = 512, iterations: Optional[int] = 3, iterations_prod_erosion: Optional[int] = 3, + kernel_prod_erosion: Optional[int] = 3, ): """ Function invoked when calling the pipeline for generation. @@ -1087,6 +1089,15 @@ def __call__( masked_image = init_image * (mask_condition < 0.5) else: masked_image = masked_image_latents + + if mask_image_original is not None: + mask_condition_original = self.mask_processor.preprocess( + mask_image_original, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image_original = init_image * (mask_condition_original < 0.5) + else: + masked_image_original = masked_image_latents mask, masked_image_latents = self.prepare_mask_latents( mask_condition, @@ -1101,6 +1112,20 @@ def __call__( generator, ) + if mask_image_original is not None: + mask_original, _ = self.prepare_mask_latents( + mask_condition_original, + masked_image_original, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + def apply_dilate_to_mask_image(mask,iterations): kernel = np.ones((3, 3), np.uint8) mask = np.array(mask, dtype=bool) @@ -1111,8 +1136,8 @@ def apply_dilate_to_mask_image(mask,iterations): return mask - def apply_erosion_to_mask_image(mask,iterations): - kernel = np.ones((5, 5), np.uint8) + def apply_erosion_to_mask_image(mask, iterations, kernel=kernel_prod_erosion): + kernel = np.ones((kernel_prod_erosion, kernel_prod_erosion), np.uint8) mask = np.array(mask, dtype=bool) mask = mask.astype(np.uint8) mask = cv2.erode(mask, kernel, iterations=iterations) @@ -1184,11 +1209,13 @@ def mask_gradienting(mask, iterations=3): mask = mask_tensor.to(device=latents.device, dtype=latents.dtype) # mask for prod reference img: + mask_original = mask_original.to(dtype=torch.float16) + tmp_mask = mask_original[0,:,:].cpu().numpy() tmp_mask_ref_prod = apply_erosion_to_mask_image(tmp_mask.reshape(64,64,-1), iterations=iterations_prod_erosion) tmp_mask_ref_prod = torch.Tensor(tmp_mask_ref_prod) tmp_mask_ref_prod = tmp_mask_ref_prod.view(-1,64) - mask_ref_prod = torch.zeros_like(mask) + mask_ref_prod = torch.zeros_like(mask_original) for index in range(mask_ref_prod.shape[0]): mask_ref_prod[index,:,:] = tmp_mask_ref_prod mask_ref_prod = mask_ref_prod.to(device=latents.device, dtype=latents.dtype) From bd65e21bf856000fe04ff0d2fe2ad3f3edc3f528 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 17:09:20 +0900 Subject: [PATCH 348/482] Revert "add codes" This reverts commit f2fee4b7ee5df8b6eef6870e72fe34a6bd41f3e9. 
--- .../pipeline_flux_controlnet_inpainting.py | 33 ++----------------- 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 9db1f34a4c83..381efebef260 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -747,7 +747,6 @@ def __call__( image_ref_prod: PipelineImageInput = None, ratio_ref_prod: Optional[float] = 0.1, mask_image: PipelineImageInput = None, - mask_image_original: PipelineImageInput = None, masked_image_latents: PipelineImageInput = None, control_image: PipelineImageInput = None, height: Optional[int] = None, @@ -774,7 +773,6 @@ def __call__( max_sequence_length: int = 512, iterations: Optional[int] = 3, iterations_prod_erosion: Optional[int] = 3, - kernel_prod_erosion: Optional[int] = 3, ): """ Function invoked when calling the pipeline for generation. @@ -1089,15 +1087,6 @@ def __call__( masked_image = init_image * (mask_condition < 0.5) else: masked_image = masked_image_latents - - if mask_image_original is not None: - mask_condition_original = self.mask_processor.preprocess( - mask_image_original, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords - ) - if masked_image_latents is None: - masked_image_original = init_image * (mask_condition_original < 0.5) - else: - masked_image_original = masked_image_latents mask, masked_image_latents = self.prepare_mask_latents( mask_condition, @@ -1112,20 +1101,6 @@ def __call__( generator, ) - if mask_image_original is not None: - mask_original, _ = self.prepare_mask_latents( - mask_condition_original, - masked_image_original, - batch_size, - num_channels_latents, - num_images_per_prompt, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - ) - def apply_dilate_to_mask_image(mask,iterations): kernel = np.ones((3, 3), np.uint8) mask = np.array(mask, dtype=bool) @@ -1136,8 +1111,8 @@ def apply_dilate_to_mask_image(mask,iterations): return mask - def apply_erosion_to_mask_image(mask, iterations, kernel=kernel_prod_erosion): - kernel = np.ones((kernel_prod_erosion, kernel_prod_erosion), np.uint8) + def apply_erosion_to_mask_image(mask,iterations): + kernel = np.ones((5, 5), np.uint8) mask = np.array(mask, dtype=bool) mask = mask.astype(np.uint8) mask = cv2.erode(mask, kernel, iterations=iterations) @@ -1209,13 +1184,11 @@ def mask_gradienting(mask, iterations=3): mask = mask_tensor.to(device=latents.device, dtype=latents.dtype) # mask for prod reference img: - mask_original = mask_original.to(dtype=torch.float16) - tmp_mask = mask_original[0,:,:].cpu().numpy() tmp_mask_ref_prod = apply_erosion_to_mask_image(tmp_mask.reshape(64,64,-1), iterations=iterations_prod_erosion) tmp_mask_ref_prod = torch.Tensor(tmp_mask_ref_prod) tmp_mask_ref_prod = tmp_mask_ref_prod.view(-1,64) - mask_ref_prod = torch.zeros_like(mask_original) + mask_ref_prod = torch.zeros_like(mask) for index in range(mask_ref_prod.shape[0]): mask_ref_prod[index,:,:] = tmp_mask_ref_prod mask_ref_prod = mask_ref_prod.to(device=latents.device, dtype=latents.dtype) From 08f1e965fed9f6f4f76e7755d887b4d625a66937 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 17:10:01 +0900 Subject: [PATCH 349/482] Reapply "add codes" This reverts commit bd65e21bf856000fe04ff0d2fe2ad3f3edc3f528. 
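Note (an inference from the surrounding commits): the reapplied change threads an un-dilated mask_image_original through prepare_mask_latents, so the reference-product latents are blended inside a mask eroded from the true product silhouette rather than the dilated, gradiented inpaint mask. A standalone sketch of the erosion step it relies on, mirroring apply_erosion_to_mask_image with assumed kernel size and iteration count:

    import cv2
    import numpy as np

    def erode_mask(mask01: np.ndarray, kernel_size: int = 3, iterations: int = 3) -> np.ndarray:
        # shrink a {0,1} mask so blending stays strictly inside the product region
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        return cv2.erode(mask01.astype(np.uint8), kernel, iterations=iterations)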
--- .../pipeline_flux_controlnet_inpainting.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 381efebef260..9db1f34a4c83 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -747,6 +747,7 @@ def __call__( image_ref_prod: PipelineImageInput = None, ratio_ref_prod: Optional[float] = 0.1, mask_image: PipelineImageInput = None, + mask_image_original: PipelineImageInput = None, masked_image_latents: PipelineImageInput = None, control_image: PipelineImageInput = None, height: Optional[int] = None, @@ -773,6 +774,7 @@ def __call__( max_sequence_length: int = 512, iterations: Optional[int] = 3, iterations_prod_erosion: Optional[int] = 3, + kernel_prod_erosion: Optional[int] = 3, ): """ Function invoked when calling the pipeline for generation. @@ -1087,6 +1089,15 @@ def __call__( masked_image = init_image * (mask_condition < 0.5) else: masked_image = masked_image_latents + + if mask_image_original is not None: + mask_condition_original = self.mask_processor.preprocess( + mask_image_original, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image_original = init_image * (mask_condition_original < 0.5) + else: + masked_image_original = masked_image_latents mask, masked_image_latents = self.prepare_mask_latents( mask_condition, @@ -1101,6 +1112,20 @@ def __call__( generator, ) + if mask_image_original is not None: + mask_original, _ = self.prepare_mask_latents( + mask_condition_original, + masked_image_original, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + def apply_dilate_to_mask_image(mask,iterations): kernel = np.ones((3, 3), np.uint8) mask = np.array(mask, dtype=bool) @@ -1111,8 +1136,8 @@ def apply_dilate_to_mask_image(mask,iterations): return mask - def apply_erosion_to_mask_image(mask,iterations): - kernel = np.ones((5, 5), np.uint8) + def apply_erosion_to_mask_image(mask, iterations, kernel=kernel_prod_erosion): + kernel = np.ones((kernel_prod_erosion, kernel_prod_erosion), np.uint8) mask = np.array(mask, dtype=bool) mask = mask.astype(np.uint8) mask = cv2.erode(mask, kernel, iterations=iterations) @@ -1184,11 +1209,13 @@ def mask_gradienting(mask, iterations=3): mask = mask_tensor.to(device=latents.device, dtype=latents.dtype) # mask for prod reference img: + mask_original = mask_original.to(dtype=torch.float16) + tmp_mask = mask_original[0,:,:].cpu().numpy() tmp_mask_ref_prod = apply_erosion_to_mask_image(tmp_mask.reshape(64,64,-1), iterations=iterations_prod_erosion) tmp_mask_ref_prod = torch.Tensor(tmp_mask_ref_prod) tmp_mask_ref_prod = tmp_mask_ref_prod.view(-1,64) - mask_ref_prod = torch.zeros_like(mask) + mask_ref_prod = torch.zeros_like(mask_original) for index in range(mask_ref_prod.shape[0]): mask_ref_prod[index,:,:] = tmp_mask_ref_prod mask_ref_prod = mask_ref_prod.to(device=latents.device, dtype=latents.dtype) From 8aac8d3a0d2873b462dc68016ea6c1154f1d7eaa Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 17:16:22 +0900 Subject: [PATCH 350/482] apply erode to mask --- .../pipelines/flux/pipeline_flux_prior_redux.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff 
--git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index d1fe8a74f5e9..0cd95dd2592f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -379,6 +379,15 @@ def apply_dilate_to_mask(self, mask,iterations = 10): return mask + def apply_erosion_to_mask(self, mask,iterations = 10): + + kernel = np.ones((2, 2), np.uint8) + mask = mask.astype(np.uint8) + mask = cv2.erode(mask, kernel, iterations=iterations) + mask = np.array(mask, dtype=bool) + + return mask + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -397,6 +406,7 @@ def __call__( is_inpainting: Optional[bool] = False, # controlnet inpainting contains_element: Optional[bool] = False, # controlnet inpainting for element iterations: Optional[int] = 20, # controlnet inpainting + iterations_erosion: Optional[int] = 20, # controlnet inpainting mask_value: Optional[int] = 255, # controlnet inpainting image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, @@ -551,7 +561,7 @@ def __call__( composed_image_all += img_array * image_mask_all[index] if is_product.lower() == "true": masked_bg += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) - masked_bg_original += mask_value*np.ones((image_width, image_height, 3)) * image_mask_all[index] + masked_bg_original += mask_value*np.ones((image_width, image_height, 3)) * self.apply_erosion_to_mask(image_mask_all[index], iterations=iterations_erosion) if index > 0: masked_bg_with_element += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) From 4b3574a5d4927d6bfd6bd77a4c08e6fad2d090d0 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 17:28:03 +0900 Subject: [PATCH 351/482] clear codes --- .../pipeline_flux_controlnet_inpainting.py | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 9db1f34a4c83..2116db35ce55 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -772,9 +772,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - iterations: Optional[int] = 3, - iterations_prod_erosion: Optional[int] = 3, - kernel_prod_erosion: Optional[int] = 3, + iterations: Optional[int] = 4, ): """ Function invoked when calling the pipeline for generation. 
@@ -1136,15 +1134,6 @@ def apply_dilate_to_mask_image(mask,iterations): return mask - def apply_erosion_to_mask_image(mask, iterations, kernel=kernel_prod_erosion): - kernel = np.ones((kernel_prod_erosion, kernel_prod_erosion), np.uint8) - mask = np.array(mask, dtype=bool) - mask = mask.astype(np.uint8) - mask = cv2.erode(mask, kernel, iterations=iterations) - mask = mask.astype(np.uint8) - - return mask - # input np array size: 4096 x 3 # output np array size: 64 x 64 x 3 def mask_gradienting(mask, iterations=3): @@ -1208,18 +1197,6 @@ def mask_gradienting(mask, iterations=3): #tmp_tensor = torch.stack([tmp_tensor, tmp_tensor, tmp_tensor], dim=0) mask = mask_tensor.to(device=latents.device, dtype=latents.dtype) - # mask for prod reference img: - mask_original = mask_original.to(dtype=torch.float16) - tmp_mask = mask_original[0,:,:].cpu().numpy() - tmp_mask_ref_prod = apply_erosion_to_mask_image(tmp_mask.reshape(64,64,-1), iterations=iterations_prod_erosion) - tmp_mask_ref_prod = torch.Tensor(tmp_mask_ref_prod) - tmp_mask_ref_prod = tmp_mask_ref_prod.view(-1,64) - - mask_ref_prod = torch.zeros_like(mask_original) - for index in range(mask_ref_prod.shape[0]): - mask_ref_prod[index,:,:] = tmp_mask_ref_prod - mask_ref_prod = mask_ref_prod.to(device=latents.device, dtype=latents.dtype) - controlnet_keep = [] for i in range(len(timesteps)): keeps = [ @@ -1303,7 +1280,7 @@ def mask_gradienting(mask, iterations=3): if image_ref_prod is not None: init_latents_proper_ref = image_latents_ref init_mask = mask - init_mask_ref_prod = mask_ref_prod + init_mask_ref_prod = mask_original if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] From 8d6d073e079f092b29d6dea8f82ba229892d83d6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 17:51:52 +0900 Subject: [PATCH 352/482] change kernel size --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 0cd95dd2592f..7be33f076442 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -381,7 +381,7 @@ def apply_dilate_to_mask(self, mask,iterations = 10): def apply_erosion_to_mask(self, mask,iterations = 10): - kernel = np.ones((2, 2), np.uint8) + kernel = np.ones((3, 3), np.uint8) mask = mask.astype(np.uint8) mask = cv2.erode(mask, kernel, iterations=iterations) mask = np.array(mask, dtype=bool) From 5deb3f7af8f95a7b9e32f4b0fb12190d5d1e083c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 29 May 2025 18:24:22 +0900 Subject: [PATCH 353/482] clear codes --- .../flux/pipeline_flux_controlnet_inpainting.py | 16 ++++++++-------- .../pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 2116db35ce55..64fc28d08409 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -744,10 +744,10 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, - image_ref_prod: PipelineImageInput = None, - ratio_ref_prod: Optional[float] = 0.1, + image_ref_prod: PipelineImageInput = None, # 
original prod image
+        ratio_ref_prod: Optional[float] = 0.1, # modified for injecting prod image
         mask_image: PipelineImageInput = None,
-        mask_image_original: PipelineImageInput = None,
+        mask_image_original: PipelineImageInput = None, # modified for injecting original prod images
         masked_image_latents: PipelineImageInput = None,
         control_image: PipelineImageInput = None,
         height: Optional[int] = None,
@@ -772,7 +772,7 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        iterations: Optional[int] = 4,
+        iterations: Optional[int] = 4, # modified for applying gradient to mask_image
     ):
         """
         Function invoked when calling the pipeline for generation.
@@ -1088,7 +1088,7 @@ def __call__(
         else:
             masked_image = masked_image_latents
 
-        if mask_image_original is not None:
+        if image_ref_prod is not None:
             mask_condition_original = self.mask_processor.preprocess(
                 mask_image_original, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords
             )
@@ -1110,7 +1110,7 @@ def __call__(
             generator,
         )
 
-        if mask_image_original is not None:
+        if image_ref_prod is not None:
             mask_original, _ = self.prepare_mask_latents(
                 mask_condition_original,
                 masked_image_original,
@@ -1194,7 +1194,6 @@ def mask_gradienting(mask, iterations=3):
         mask_tensor = torch.zeros_like(mask)
         for index in range(mask_tensor.shape[0]):
             mask_tensor[index,:,:] = tmp_tensor
-            #tmp_tensor = torch.stack([tmp_tensor, tmp_tensor, tmp_tensor], dim=0)
         mask = mask_tensor.to(device=latents.device, dtype=latents.dtype)
 
         controlnet_keep = []
@@ -1280,7 +1279,8 @@ def mask_gradienting(mask, iterations=3):
             if image_ref_prod is not None:
                 init_latents_proper_ref = image_latents_ref
             init_mask = mask
-            init_mask_ref_prod = mask_original
+            if image_ref_prod is not None:
+                init_mask_ref_prod = mask_original
             if i < len(timesteps) - 1:
                 noise_timestep = timesteps[i + 1]
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index 7be33f076442..a40f09415949 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -406,7 +406,7 @@ def __call__(
         is_inpainting: Optional[bool] = False, # controlnet inpainting
         contains_element: Optional[bool] = False, # controlnet inpainting for element
         iterations: Optional[int] = 20, # controlnet inpainting
-        iterations_erosion: Optional[int] = 20, # controlnet inpainting
+        iterations_erosion: Optional[int] = 10, # modified for injecting original prod image
         mask_value: Optional[int] = 255, # controlnet inpainting
         image_width: Optional[int] = 1024,
         image_height: Optional[int] = 1024,

From c7054fd79b2e42df944c1f762db76ab4305a11ab Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Thu, 29 May 2025 18:25:26 +0900
Subject: [PATCH 354/482] change default value

---
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index a40f09415949..b9f80c97ee3a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -406,7 +406,7 @@ def __call__(
         is_inpainting: Optional[bool] = False, # controlnet inpainting
         contains_element: Optional[bool] = False, # controlnet inpainting for 
element
         iterations: Optional[int] = 20, # controlnet inpainting
-        iterations_erosion: Optional[int] = 10, # modified for injecting original prod image
+        iterations_erosion: Optional[int] = 7, # modified for injecting original prod image
         mask_value: Optional[int] = 255, # controlnet inpainting
         image_width: Optional[int] = 1024,
         image_height: Optional[int] = 1024,

From 02ce8f94941e0cd9bd80d734c05231b83abe9546 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Thu, 29 May 2025 18:50:20 +0900
Subject: [PATCH 355/482] adjust default values

---
 .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +-
 src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index 64fc28d08409..91c8c63f8621 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -745,7 +745,7 @@ def __call__(
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
         image_ref_prod: PipelineImageInput = None, # original prod image
-        ratio_ref_prod: Optional[float] = 0.1, # modified for injecting prod image
+        ratio_ref_prod: Optional[float] = 0.125, # modified for injecting prod image
         mask_image: PipelineImageInput = None,
         mask_image_original: PipelineImageInput = None, # modified for injecting original prod images
         masked_image_latents: PipelineImageInput = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index b9f80c97ee3a..da9853a03994 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -405,8 +405,8 @@ def __call__(
         product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots
         is_inpainting: Optional[bool] = False, # controlnet inpainting
         contains_element: Optional[bool] = False, # controlnet inpainting for element
-        iterations: Optional[int] = 20, # controlnet inpainting
-        iterations_erosion: Optional[int] = 7, # modified for injecting original prod image
+        iterations: Optional[int] = 10, # controlnet inpainting
+        iterations_erosion: Optional[int] = 8, # modified for injecting original prod image
         mask_value: Optional[int] = 255, # controlnet inpainting
         image_width: Optional[int] = 1024,
         image_height: Optional[int] = 1024,

From 9d1c74ca987b0c96f29a3b107d67170bf54bac97 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Fri, 30 May 2025 12:53:52 +0900
Subject: [PATCH 356/482] add enhanced blend script

---
 .../flux/pipeline_flux_controlnet_enhanced.py | 1380 +++++++++++++++++
 1 file changed, 1380 insertions(+)
 create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py
new file mode 100644
index 000000000000..1899a961ab1b
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py
@@ -0,0 +1,1380 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.utils import load_image + >>> from diffusers import FluxControlNetPipeline + >>> from diffusers import FluxControlNetModel + + >>> base_model = "black-forest-labs/FLUX.1-dev" + >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny" + >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) + >>> pipe = FluxControlNetPipeline.from_pretrained( + ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> prompt = "A girl in city, 25 years old, cool, futuristic" + >>> image = pipe( + ... prompt, + ... control_image=control_image, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, + ... controlnet_conditioning_scale=1.0, + ... num_inference_steps=28, + ... guidance_scale=3.5, + ... 
).images[0] + >>> image.save("flux.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: Union[ + FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel + ], + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = FluxMultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each 
generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if image is not None: + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + image_latents = None + noise = None + + return latents.to(device=device, dtype=dtype), noise, image_latents, latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if image is not None: + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + else: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + image_latents = None + noise = None + + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + ratio_ref: Optional[float] = 0.1, + mask_image: PipelineImageInput = None, + true_cfg_scale: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + strength: Optional[float] = 0.6, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + padding_mask_crop: Optional[int] = None, + guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + control_image: PipelineImageInput = None, + control_mode: Optional[Union[int, List[int]]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] 
= None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_mode (`int` or `List[int]`,, *optional*, defaults to None): + The control mode when applying ControlNet-Union. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 
+ + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + global_height = height + global_width = width + + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = 1 #len(prompt) # thesea modified for text prompt mask + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + # 3. 
Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + ## thesea modified for text prompt mask + prompt_embeds_list = [] + text_ids_list = [] + if isinstance(prompt, list): + for pmt in prompt: + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=pmt, + prompt_2=prompt_2, + #prompt_embeds=prompt_embeds, + #pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + prompt_embeds_list.append(prompt_embeds) + text_ids_list.append(text_ids) + prompt_embeds = torch.cat(prompt_embeds_list, dim=1) + text_ids = torch.cat(text_ids_list, dim=0) + else: + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # Preprocess mask and image + if image is not None: + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region( + mask_image, global_width, global_height, pad=padding_mask_crop + ) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + init_image = self.image_processor.preprocess( + image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Prepare control image + num_channels_latents = self.transformer.config.in_channels // 4 + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True + if self.controlnet.input_hint_block is None: + # vae encode + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + # Here we ensure that `control_mode` has the same length as the control_image. 
+ if control_mode is not None: + if not isinstance(control_mode, int): + raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`") + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True + for i, control_image_ in enumerate(control_image): + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image_.shape[-2:] + + if self.controlnet.nets[0].input_hint_block is None: + # vae encode + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + control_images.append(control_image_) + + control_image = control_images + + # Here we ensure that `control_mode` has the same length as the control_image. + if isinstance(control_mode, list) and len(control_mode) != len(control_image): + raise ValueError( + "For Multi-ControlNet, `control_mode` must be a list of the same " + + " length as the number of controlnets (control images) specified" + ) + if not isinstance(control_mode, list): + control_mode = [control_mode] * len(control_image) + # set control mode + control_modes = [] + for cmode in control_mode: + if cmode is None: + cmode = -1 + control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long) + control_modes.append(control_mode) + control_mode = control_modes + + # Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * ( + int(global_width) // self.vae_scale_factor // 2 + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if image is not None: + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # Prepare mask latents + if image is not None: + mask_condition = self.mask_processor.preprocess( + mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + + guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # controlnet + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=cond_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, + )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # additional codes for injecting original prod images into latents + if image is not None: + init_mask = mask + init_latents_proper = image_latents + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents_1 = init_mask * (ratio_ref * 
init_latents_proper + (1.0 - ratio_ref) * latents)
+                        latents_2 = (1.0 - init_mask) * latents
+
+                        latents = latents_1 + latents_2
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    control_image = callback_outputs.pop("control_image", control_image)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        if output_type == "latent":
+            image = latents
+
+        else:
+            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return FluxPipelineOutput(images=image)
\ No newline at end of file
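The enhanced script reuses Flux's 2x2 latent packing verbatim (`_pack_latents` / `_unpack_latents` above): each 2x2 spatial patch of the VAE latents is folded into the channel dimension before the transformer sees it. As a standalone sanity check of that layout, the following round-trip sketch mirrors those two static methods with torch only; the shapes are illustrative.

```py
import torch

def pack(latents):
    # (B, C, H, W) -> (B, (H/2)*(W/2), C*4): fold each 2x2 spatial patch into channels.
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

def unpack(latents, h, w):
    # Inverse of pack: (B, N, C*4) -> (B, C, H, W).
    b, _, ch = latents.shape
    x = latents.view(b, h // 2, w // 2, ch // 4, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, ch // 4, h, w)

x = torch.randn(1, 16, 64, 64)
assert torch.equal(unpack(pack(x), 64, 64), x)  # exact round trip
```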
composed_prod_images, composed_prod_images_all, prod_masks, prod_masks_original, bg_mask) else: - return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, masked_bg_original, composed_bg_image, composed_prod_images, composed_prod_images_all, prod_masks, bg_mask) + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, masked_bg_original, composed_bg_image, composed_prod_images, composed_prod_images_all, prod_masks, prod_masks_original, bg_mask) else: return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, bg_mask) else: From da5fb2ccefb6ed571795ae5328aa78bffebb7a3a Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 17:43:36 +0900 Subject: [PATCH 358/482] add averaging logic for multiple copies of the same product --- .../pipeline_flux_controlnet_inpainting.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 91c8c63f8621..b00a70484c6e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -773,6 +773,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, iterations: Optional[int] = 4, # modified for applying gradient to mask_image + averaging_steps: Optional[int] = 0, # modified for applying averaging latents for multiple copies of same product ): """ Function invoked when calling the pipeline for generation. @@ -1088,6 +1089,15 @@ def __call__( else: masked_image = masked_image_latents + if "prod_masks" in joint_attention_kwargs: + mask_conditions_prod = [] + for prod_mask in joint_attention_kwargs["prod_masks"]: + tmp_mask_condition = self.mask_processor.preprocess( + prod_mask, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + mask_conditions_prod.append(tmp_mask_condition) + + if image_ref_prod is not None: mask_condition_original = self.mask_processor.preprocess( mask_image_original, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords @@ -1110,6 +1120,25 @@ def __call__( generator, ) + if "prod_masks" in joint_attention_kwargs: + masks_prod = [] + for tmp_mask_condition in mask_conditions_prod: + tmp_mask, _ = self.prepare_mask_latents( + tmp_mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + masks_prod.append(tmp_mask) + prod_pixel_num = torch.sum(tmp_mask[0]) + print(f'tmp_mask shape = {tmp_mask.shape}, prod_pixel_num = {prod_pixel_num}') + if image_ref_prod is not None: mask_original, _ = self.prepare_mask_latents( mask_condition_original, @@ -1279,6 +1308,9 @@ def mask_gradienting(mask, iterations=3): if image_ref_prod is not None: init_latents_proper_ref = image_latents_ref init_mask = mask + if "prod_masks" in joint_attention_kwargs: + init_masks_prod = masks_prod + if image_ref_prod is not None: init_mask_ref_prod = mask_original @@ -1292,6 +1324,15 @@ def mask_gradienting(mask, iterations=3): init_latents_proper_ref, torch.tensor([noise_timestep]), noise ) + if "prod_masks" in joint_attention_kwargs: + if i < averaging_steps: + masked_regions = torch.stack([ + v[m.bool()].reshape(-1) for v, m in 
zip([latents]*len(init_masks_prod), init_masks_prod) + ]) + avg_region = masked_regions.mean(dim=0) + for i in range(len(init_masks_prod)): + latents[init_masks_prod[i].bool()] = avg_region + if image_ref_prod is not None: latents_1 = (1 - init_mask) * init_latents_proper latents_2 = (init_mask - init_mask_ref_prod) * latents From 2fee918abf1a6bde9e6774826b70a102b2b387a4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:04:31 +0900 Subject: [PATCH 359/482] solve different size error --- .../flux/pipeline_flux_controlnet_inpainting.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index b00a70484c6e..248d49435c0a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1326,9 +1326,14 @@ def mask_gradienting(mask, iterations=3): if "prod_masks" in joint_attention_kwargs: if i < averaging_steps: - masked_regions = torch.stack([ - v[m.bool()].reshape(-1) for v, m in zip([latents]*len(init_masks_prod), init_masks_prod) - ]) + masked_regions = [] + for tmp_init_mask in init_masks_prod: + masked_regions.append(latents[tmp_init_mask.bool()].reshape(-1)) + + min_length = min(t.size(0) for t in masked_regions) + truncated = [t[:min_length] for t in masked_regions] + masked_regions = torch.stack(truncated) + avg_region = masked_regions.mean(dim=0) for i in range(len(init_masks_prod)): latents[init_masks_prod[i].bool()] = avg_region From 2e3c149fb7f8a8050fe3a3c7e71ef5a405a450f9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:22:04 +0900 Subject: [PATCH 360/482] fix dim bug --- .../flux/pipeline_flux_controlnet_inpainting.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 248d49435c0a..52d2892d1131 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1325,9 +1325,11 @@ def mask_gradienting(mask, iterations=3): ) if "prod_masks" in joint_attention_kwargs: + latents = latents.view(-1) if i < averaging_steps: masked_regions = [] for tmp_init_mask in init_masks_prod: + tmp_init_mask = tmp_init_mask.view(-1) masked_regions.append(latents[tmp_init_mask.bool()].reshape(-1)) min_length = min(t.size(0) for t in masked_regions) @@ -1335,9 +1337,18 @@ def mask_gradienting(mask, iterations=3): masked_regions = torch.stack(truncated) avg_region = masked_regions.mean(dim=0) - for i in range(len(init_masks_prod)): - latents[init_masks_prod[i].bool()] = avg_region - + for tmp_init_mask in init_masks_prod: + tmp_init_mask = tmp_init_mask.view(-1) + if torch.sum(tmp_init_mask) == avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region + elif torch.sum(tmp_init_mask) < avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region[0:torch.sum(tmp_init_mask)] + else: + add_dim = torch.sum(tmp_init_mask) - avg_region.shape[0] + tmp_tensor = torch.ones(add_dim) * torch.mean(avg_region) + latents[tmp_init_mask.bool()] = torch.cat((avg_region, tmp_tensor), dim = 0) + + latents = latents.view(batch_size, 4096, -1) if image_ref_prod is not None: latents_1 = (1 - init_mask) * init_latents_proper latents_2 = 
(init_mask - init_mask_ref_prod) * latents From be57cf7b642d9a42f4acf0f61fcc8d2eea3c8df9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:26:46 +0900 Subject: [PATCH 361/482] print shape --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 52d2892d1131..90236b011266 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1338,6 +1338,7 @@ def mask_gradienting(mask, iterations=3): avg_region = masked_regions.mean(dim=0) for tmp_init_mask in init_masks_prod: + print(f'avg_region shape = {avg_region.shape}') tmp_init_mask = tmp_init_mask.view(-1) if torch.sum(tmp_init_mask) == avg_region.shape[0]: latents[tmp_init_mask.bool()] = avg_region From 933191d46e0bfde178005b1e1b909ecd9ca171fc Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:29:31 +0900 Subject: [PATCH 362/482] print info --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 90236b011266..87a6efe23abf 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1340,6 +1340,7 @@ def mask_gradienting(mask, iterations=3): for tmp_init_mask in init_masks_prod: print(f'avg_region shape = {avg_region.shape}') tmp_init_mask = tmp_init_mask.view(-1) + print(f'torch.sum(tmp_init_mask)={torch.sum(tmp_init_mask)}, avg_region.shape[0]={avg_region.shape[0]}') if torch.sum(tmp_init_mask) == avg_region.shape[0]: latents[tmp_init_mask.bool()] = avg_region elif torch.sum(tmp_init_mask) < avg_region.shape[0]: From e9292b1453e4a5c2b79ec330eb5ceb19c850cdca Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:34:21 +0900 Subject: [PATCH 363/482] fix bug --- .../flux/pipeline_flux_controlnet_inpainting.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 87a6efe23abf..219b48ace2c9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1340,13 +1340,14 @@ def mask_gradienting(mask, iterations=3): for tmp_init_mask in init_masks_prod: print(f'avg_region shape = {avg_region.shape}') tmp_init_mask = tmp_init_mask.view(-1) - print(f'torch.sum(tmp_init_mask)={torch.sum(tmp_init_mask)}, avg_region.shape[0]={avg_region.shape[0]}') - if torch.sum(tmp_init_mask) == avg_region.shape[0]: + length = torch.sum(tmp_init_mask) + print(f'length={length}, avg_region.shape[0]={avg_region.shape[0]}') + if length == avg_region.shape[0]: latents[tmp_init_mask.bool()] = avg_region - elif torch.sum(tmp_init_mask) < avg_region.shape[0]: - latents[tmp_init_mask.bool()] = avg_region[0:torch.sum(tmp_init_mask)] + elif length < avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region[0:length] else: - add_dim = torch.sum(tmp_init_mask) - avg_region.shape[0] + add_dim = length - avg_region.shape[0] tmp_tensor = torch.ones(add_dim) * 
torch.mean(avg_region) latents[tmp_init_mask.bool()] = torch.cat((avg_region, tmp_tensor), dim = 0) From 3c1f8f4f34855a97e9dd71919a1057e61e104665 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:36:36 +0900 Subject: [PATCH 364/482] print info --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 219b48ace2c9..60458a39cccd 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1341,6 +1341,7 @@ def mask_gradienting(mask, iterations=3): print(f'avg_region shape = {avg_region.shape}') tmp_init_mask = tmp_init_mask.view(-1) length = torch.sum(tmp_init_mask) + print(f'tmp_init_mask={tmp_init_mask}') print(f'length={length}, avg_region.shape[0]={avg_region.shape[0]}') if length == avg_region.shape[0]: latents[tmp_init_mask.bool()] = avg_region From ab102f5137f5b970a583e50181e1214ccaf5c16b Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:38:46 +0900 Subject: [PATCH 365/482] print info --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 60458a39cccd..39116b44dde5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1341,7 +1341,7 @@ def mask_gradienting(mask, iterations=3): print(f'avg_region shape = {avg_region.shape}') tmp_init_mask = tmp_init_mask.view(-1) length = torch.sum(tmp_init_mask) - print(f'tmp_init_mask={tmp_init_mask}') + print(f'tmp_init_mask={torch.unique(tmp_init_mask)}') print(f'length={length}, avg_region.shape[0]={avg_region.shape[0]}') if length == avg_region.shape[0]: latents[tmp_init_mask.bool()] = avg_region From d8b4a5522a701f672a0f7fb3f5be8a718b8a20a1 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:40:48 +0900 Subject: [PATCH 366/482] print --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 39116b44dde5..ce80d409d482 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1341,7 +1341,7 @@ def mask_gradienting(mask, iterations=3): print(f'avg_region shape = {avg_region.shape}') tmp_init_mask = tmp_init_mask.view(-1) length = torch.sum(tmp_init_mask) - print(f'tmp_init_mask={torch.unique(tmp_init_mask)}') + print(f'tmp_init_mask={torch.unique(tmp_init_mask)}, dim = {tmp_init_mask.shape}') print(f'length={length}, avg_region.shape[0]={avg_region.shape[0]}') if length == avg_region.shape[0]: latents[tmp_init_mask.bool()] = avg_region From a2d6bd7eb848feda01a813e773b0455fb6d120a6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:43:24 +0900 Subject: [PATCH 367/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index ce80d409d482..13e8d758f69e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1340,7 +1340,7 @@ def mask_gradienting(mask, iterations=3): for tmp_init_mask in init_masks_prod: print(f'avg_region shape = {avg_region.shape}') tmp_init_mask = tmp_init_mask.view(-1) - length = torch.sum(tmp_init_mask) + length = torch.sum(tmp_init_mask == 1) print(f'tmp_init_mask={torch.unique(tmp_init_mask)}, dim = {tmp_init_mask.shape}') print(f'length={length}, avg_region.shape[0]={avg_region.shape[0]}') if length == avg_region.shape[0]: From 1191dc389221cfd7b879a996899bf80e9dd44ef2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:46:42 +0900 Subject: [PATCH 368/482] fix device bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 13e8d758f69e..85f0eb8fbe9b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1349,7 +1349,8 @@ def mask_gradienting(mask, iterations=3): latents[tmp_init_mask.bool()] = avg_region[0:length] else: add_dim = length - avg_region.shape[0] - tmp_tensor = torch.ones(add_dim) * torch.mean(avg_region) + one_tensor = torch.ones(add_dim).to(device=latents.device, dtype=latents.dtype) + tmp_tensor = one_tensor * torch.mean(avg_region) latents[tmp_init_mask.bool()] = torch.cat((avg_region, tmp_tensor), dim = 0) latents = latents.view(batch_size, 4096, -1) From ac4216fe2731849cf3d072d413023c15cc38cd99 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:49:39 +0900 Subject: [PATCH 369/482] print info --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 85f0eb8fbe9b..33ca6b1e5d59 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1325,6 +1325,7 @@ def mask_gradienting(mask, iterations=3): ) if "prod_masks" in joint_attention_kwargs: + print(f'latents shape = {latents.shape}, batch_size={batch_size}') latents = latents.view(-1) if i < averaging_steps: masked_regions = [] @@ -1354,6 +1355,7 @@ def mask_gradienting(mask, iterations=3): latents[tmp_init_mask.bool()] = torch.cat((avg_region, tmp_tensor), dim = 0) latents = latents.view(batch_size, 4096, -1) + print(f'latents shape = {latents.shape}, batch_size={batch_size}') if image_ref_prod is not None: latents_1 = (1 - init_mask) * init_latents_proper latents_2 = (init_mask - init_mask_ref_prod) * latents From e402313dc467f17b0c10b5290d5964626e1e1a43 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:53:10 +0900 Subject: [PATCH 370/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py 
index 33ca6b1e5d59..112dfcb55bfd 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1325,6 +1325,7 @@ def mask_gradienting(mask, iterations=3): ) if "prod_masks" in joint_attention_kwargs: + batch_size = latents.shape[0] print(f'latents shape = {latents.shape}, batch_size={batch_size}') latents = latents.view(-1) if i < averaging_steps: From e2c999f7e41717477e5682b434085f9ede629d11 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 31 May 2025 18:55:46 +0900 Subject: [PATCH 371/482] remove print --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 112dfcb55bfd..8b20c02204cb 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1137,7 +1137,6 @@ def __call__( ) masks_prod.append(tmp_mask) prod_pixel_num = torch.sum(tmp_mask[0]) - print(f'tmp_mask shape = {tmp_mask.shape}, prod_pixel_num = {prod_pixel_num}') if image_ref_prod is not None: mask_original, _ = self.prepare_mask_latents( @@ -1326,7 +1325,6 @@ def mask_gradienting(mask, iterations=3): if "prod_masks" in joint_attention_kwargs: batch_size = latents.shape[0] - print(f'latents shape = {latents.shape}, batch_size={batch_size}') latents = latents.view(-1) if i < averaging_steps: masked_regions = [] @@ -1340,11 +1338,8 @@ def mask_gradienting(mask, iterations=3): avg_region = masked_regions.mean(dim=0) for tmp_init_mask in init_masks_prod: - print(f'avg_region shape = {avg_region.shape}') tmp_init_mask = tmp_init_mask.view(-1) length = torch.sum(tmp_init_mask == 1) - print(f'tmp_init_mask={torch.unique(tmp_init_mask)}, dim = {tmp_init_mask.shape}') - print(f'length={length}, avg_region.shape[0]={avg_region.shape[0]}') if length == avg_region.shape[0]: latents[tmp_init_mask.bool()] = avg_region elif length < avg_region.shape[0]: @@ -1356,7 +1351,6 @@ def mask_gradienting(mask, iterations=3): latents[tmp_init_mask.bool()] = torch.cat((avg_region, tmp_tensor), dim = 0) latents = latents.view(batch_size, 4096, -1) - print(f'latents shape = {latents.shape}, batch_size={batch_size}') if image_ref_prod is not None: latents_1 = (1 - init_mask) * init_latents_proper latents_2 = (init_mask - init_mask_ref_prod) * latents From 4cbf3af633d5e4ff42b327e4d739a37b71d326c6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 1 Jun 2025 23:33:37 +0900 Subject: [PATCH 372/482] clear codes --- .../pipeline_flux_controlnet_inpainting.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 8b20c02204cb..5733a3f305f3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -744,10 +744,11 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, - image_ref_prod: PipelineImageInput = None, # original pord image - ratio_ref_prod: Optional[float] = 0.125, # modified for injecting pord image + image_ref_prod: Optional[PipelineImageInput] = None, # original prod image + 
ratio_ref_prod: Optional[float] = 0.125, # modified for injecting prod image mask_image: PipelineImageInput = None, - mask_image_original: PipelineImageInput = None, # modified for injecting original pord images + mask_image_original: Optional[PipelineImageInput] = None, # modified for injecting original prod images + prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images masked_image_latents: PipelineImageInput = None, control_image: PipelineImageInput = None, height: Optional[int] = None, @@ -773,7 +774,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, iterations: Optional[int] = 4, # modified for applying gradient to mask_image - averaging_steps: Optional[int] = 0, # modified for applying averaging latents for multiple copies of same product + averaging_steps: Optional[int] = 1, # modified for applying averaging latents for multiple copies of same product ): """ Function invoked when calling the pipeline for generation. @@ -1089,9 +1090,10 @@ def __call__( else: masked_image = masked_image_latents - if "prod_masks" in joint_attention_kwargs: + # handel prod masks + if prod_masks_original is not None: mask_conditions_prod = [] - for prod_mask in joint_attention_kwargs["prod_masks"]: + for prod_mask in prod_masks_original: tmp_mask_condition = self.mask_processor.preprocess( prod_mask, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords ) @@ -1120,7 +1122,7 @@ def __call__( generator, ) - if "prod_masks" in joint_attention_kwargs: + if prod_masks_original is not None: masks_prod = [] for tmp_mask_condition in mask_conditions_prod: tmp_mask, _ = self.prepare_mask_latents( @@ -1136,7 +1138,6 @@ def __call__( generator, ) masks_prod.append(tmp_mask) - prod_pixel_num = torch.sum(tmp_mask[0]) if image_ref_prod is not None: mask_original, _ = self.prepare_mask_latents( @@ -1307,7 +1308,7 @@ def mask_gradienting(mask, iterations=3): if image_ref_prod is not None: init_latents_proper_ref = image_latents_ref init_mask = mask - if "prod_masks" in joint_attention_kwargs: + if prod_masks_original is not None: init_masks_prod = masks_prod if image_ref_prod is not None: @@ -1323,10 +1324,11 @@ def mask_gradienting(mask, iterations=3): init_latents_proper_ref, torch.tensor([noise_timestep]), noise ) - if "prod_masks" in joint_attention_kwargs: - batch_size = latents.shape[0] - latents = latents.view(-1) + # average multiple product latents + if prod_masks_original is not None: if i < averaging_steps: + batch_size = latents.shape[0] + latents = latents.view(-1) masked_regions = [] for tmp_init_mask in init_masks_prod: tmp_init_mask = tmp_init_mask.view(-1) @@ -1351,6 +1353,8 @@ def mask_gradienting(mask, iterations=3): latents[tmp_init_mask.bool()] = torch.cat((avg_region, tmp_tensor), dim = 0) latents = latents.view(batch_size, 4096, -1) + + # inject product raw images into latents if image_ref_prod is not None: latents_1 = (1 - init_mask) * init_latents_proper latents_2 = (init_mask - init_mask_ref_prod) * latents From 27e3c7cc1e8f023e772854e2cce87f2a559128c7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 2 Jun 2025 11:09:30 +0900 Subject: [PATCH 373/482] fix cat bug --- .../flux/pipeline_flux_controlnet_inpainting.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py 
index 5733a3f305f3..1f6c9a8fd998 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1348,9 +1348,13 @@ def mask_gradienting(mask, iterations=3): latents[tmp_init_mask.bool()] = avg_region[0:length] else: add_dim = length - avg_region.shape[0] - one_tensor = torch.ones(add_dim).to(device=latents.device, dtype=latents.dtype) - tmp_tensor = one_tensor * torch.mean(avg_region) - latents[tmp_init_mask.bool()] = torch.cat((avg_region, tmp_tensor), dim = 0) + add_dim1 = add_dim // 2 + add_dim2 = add_dim - add_dim1 + one_tensor1 = torch.ones(add_dim1).to(device=latents.device, dtype=latents.dtype) + one_tensor2 = torch.ones(add_dim2).to(device=latents.device, dtype=latents.dtype) + tmp_tensor1 = one_tensor1 * torch.mean(avg_region) + tmp_tensor2 = one_tensor2 * torch.mean(avg_region) + latents[tmp_init_mask.bool()] = torch.cat((tmp_tensor1, avg_region, tmp_tensor2), dim = 0) latents = latents.view(batch_size, 4096, -1) From 3dfb6c88ee374ab6c431c46574bc93e3e2cda603 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 2 Jun 2025 11:24:43 +0900 Subject: [PATCH 374/482] fix bug --- .../flux/pipeline_flux_controlnet_inpainting.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 1f6c9a8fd998..6535835a816f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1335,7 +1335,19 @@ def mask_gradienting(mask, iterations=3): masked_regions.append(latents[tmp_init_mask.bool()].reshape(-1)) min_length = min(t.size(0) for t in masked_regions) - truncated = [t[:min_length] for t in masked_regions] + #truncated = [t[:min_length] for t in masked_regions] + truncated = [] + for tmp_mask_region in masked_regions: + if tmp_mask_region.size(0) == min_length: + truncated.append(tmp_mask_region) + else: + extra_dim = tmp_mask_region.size(0) - min_length + half = extra_dim // 2 + if extra_dim % 2 == 0: + truncated.append(tmp_mask_region[half:(min_length-half+extra_dim)]) + else: + truncated.append(tmp_mask_region[half:(min_length-half+extra_dim-1)]) + masked_regions = torch.stack(truncated) avg_region = masked_regions.mean(dim=0) From 1050bb3c20ee5ff5aef6877b741aaba9e467423e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 2 Jun 2025 14:01:30 +0900 Subject: [PATCH 375/482] change default value --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 6535835a816f..f774a905d2ef 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -774,7 +774,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, iterations: Optional[int] = 4, # modified for applying gradient to mask_image - averaging_steps: Optional[int] = 1, # modified for applying averaging latents for multiple copies of same product + averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product ): """ Function invoked when calling the pipeline for 
generation. From b78b181ac57d205c26d3b3da932232d9489a6ccd Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 3 Jun 2025 16:24:48 +0900 Subject: [PATCH 376/482] add ref_prod_injection_steps para --- .../flux/pipeline_flux_controlnet_enhanced.py | 36 ++++++++++++------- .../pipeline_flux_controlnet_inpainting.py | 14 +++++--- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py index 1899a961ab1b..623a514a2c08 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py @@ -240,6 +240,14 @@ def __init__( # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) @@ -816,6 +824,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images ): r""" Function invoked when calling the pipeline for generation. 
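A minimal sketch of the step-gated injection that the next hunk introduces, written as a bare loop with the same tensor names as the patch (mask, image_latents, noise, ratio_ref); this is commentary, not part of the diff. It shows the conventional masked blend; note that the patch's own accumulation adds latents_1 twice (latents = latents_1 + latents_2 where latents_2 already contains latents_1), an arithmetic slip that PATCH 384 later in this series corrects for pipeline_flux_controlnet.py.

# Sketch only. Assumes `scheduler`, `timesteps`, `noise_pred`, `mask`,
# `image_latents`, `noise`, `ratio_ref`, and `ref_prod_injection_steps`
# carry the same meanings they have in the patch.
for i, t in enumerate(timesteps):
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    if i < ref_prod_injection_steps:  # inject the reference only for the first K steps
        ref = image_latents
        if i < len(timesteps) - 1:  # re-noise the clean reference to the next noise level
            ref = scheduler.scale_noise(ref, torch.tensor([timesteps[i + 1]]), noise)
        # masked region: weighted mix of reference and prediction; elsewhere: prediction
        latents = mask * (ratio_ref * ref + (1.0 - ratio_ref) * latents) + (1.0 - mask) * latents

Once i reaches ref_prod_injection_steps the sampler refines the whole canvas freely, which lets the injected product settle into the generated background instead of being pasted in until the final step.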
@@ -1325,19 +1334,20 @@ def __call__( # additional codes for injecting original prod images into latents if image is not None: - init_mask = mask - init_latents_proper = image_latents - - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.scale_noise( - init_latents_proper, torch.tensor([noise_timestep]), noise - ) - - latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) - latents_2 = latents_1 + (1.0 - init_mask) * latents - - latents = latents_1 + latents_2 + if i < ref_prod_injection_steps: + init_mask = mask + init_latents_proper = image_latents + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) + latents_2 = latents_1 + (1.0 - init_mask) * latents + + latents = latents_1 + latents_2 if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index f774a905d2ef..cfa059395905 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -775,6 +775,7 @@ def __call__( max_sequence_length: int = 512, iterations: Optional[int] = 4, # modified for applying gradient to mask_image averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product + ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images ): """ Function invoked when calling the pipeline for generation. 
@@ -1372,11 +1373,14 @@ def mask_gradienting(mask, iterations=3): # inject product raw images into latents if image_ref_prod is not None: - latents_1 = (1 - init_mask) * init_latents_proper - latents_2 = (init_mask - init_mask_ref_prod) * latents - latents_3 = (1.0 - ratio_ref_prod) * init_mask_ref_prod * latents + ratio_ref_prod * init_mask_ref_prod * init_latents_proper_ref - - latents = latents_1 + latents_2 + latents_3 + if i < ref_prod_injection_steps: + latents_1 = (1 - init_mask) * init_latents_proper + latents_2 = (init_mask - init_mask_ref_prod) * latents + latents_3 = (1.0 - ratio_ref_prod) * init_mask_ref_prod * latents + ratio_ref_prod * init_mask_ref_prod * init_latents_proper_ref + + latents = latents_1 + latents_2 + latents_3 + else: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents else: latents = (1 - init_mask) * init_latents_proper + init_mask * latents From a0ff3c5af5b33a8ad6a285c4fe3c4275904047c3 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 3 Jun 2025 17:12:57 +0900 Subject: [PATCH 377/482] add averaging logic --- .../flux/pipeline_flux_controlnet_enhanced.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py index 623a514a2c08..f8e5a6cfa0f5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py @@ -794,6 +794,7 @@ def __call__( image: PipelineImageInput = None, ratio_ref: Optional[float] = 0.1, mask_image: PipelineImageInput = None, + prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images true_cfg_scale: float = 1.0, height: Optional[int] = None, width: Optional[int] = None, @@ -824,6 +825,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images ): r""" @@ -1207,6 +1209,30 @@ def __call__( device, generator, ) + + if prod_masks_original is not None: + mask_conditions_prod = [] + for prod_mask in prod_masks_original: + tmp_mask_condition = self.mask_processor.preprocess( + prod_mask, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + mask_conditions_prod.append(tmp_mask_condition) + + masks_prod = [] + for tmp_mask_condition in mask_conditions_prod: + tmp_mask, _ = self.prepare_mask_latents( + tmp_mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + masks_prod.append(tmp_mask) # 6. 
Create tensor stating which controlnets to keep controlnet_keep = [] @@ -1332,6 +1358,53 @@ def __call__( latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + # average multiple product latents + if prod_masks_original is not None: + init_masks_prod = masks_prod + + if i < averaging_steps: + batch_size = latents.shape[0] + latents = latents.view(-1) + masked_regions = [] + for tmp_init_mask in init_masks_prod: + tmp_init_mask = tmp_init_mask.view(-1) + masked_regions.append(latents[tmp_init_mask.bool()].reshape(-1)) + + min_length = min(t.size(0) for t in masked_regions) + #truncated = [t[:min_length] for t in masked_regions] + truncated = [] + for tmp_mask_region in masked_regions: + if tmp_mask_region.size(0) == min_length: + truncated.append(tmp_mask_region) + else: + extra_dim = tmp_mask_region.size(0) - min_length + half = extra_dim // 2 + if extra_dim % 2 == 0: + truncated.append(tmp_mask_region[half:(min_length-half+extra_dim)]) + else: + truncated.append(tmp_mask_region[half:(min_length-half+extra_dim-1)]) + + masked_regions = torch.stack(truncated) + + avg_region = masked_regions.mean(dim=0) + for tmp_init_mask in init_masks_prod: + tmp_init_mask = tmp_init_mask.view(-1) + length = torch.sum(tmp_init_mask == 1) + if length == avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region + elif length < avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region[0:length] + else: + add_dim = length - avg_region.shape[0] + add_dim1 = add_dim // 2 + add_dim2 = add_dim - add_dim1 + one_tensor1 = torch.ones(add_dim1).to(device=latents.device, dtype=latents.dtype) + one_tensor2 = torch.ones(add_dim2).to(device=latents.device, dtype=latents.dtype) + tmp_tensor1 = one_tensor1 * torch.mean(avg_region) + tmp_tensor2 = one_tensor2 * torch.mean(avg_region) + latents[tmp_init_mask.bool()] = torch.cat((tmp_tensor1, avg_region, tmp_tensor2), dim = 0) + + latents = latents.view(batch_size, 4096, -1) # additional codes for injecting original prod images into latents if image is not None: if i < ref_prod_injection_steps: From 22c853f9535da422493ba768522686fd04831257 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 21:15:14 +0800 Subject: [PATCH 378/482] add comments --- .../pipelines/flux/pipeline_flux_controlnet_enhanced.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py index f8e5a6cfa0f5..b6a6ce6bac8f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py @@ -791,10 +791,10 @@ def __call__( prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt: Union[str, List[str]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, - ratio_ref: Optional[float] = 0.1, - mask_image: PipelineImageInput = None, - prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images + image: PipelineImageInput = None, # original prod image + ratio_ref: Optional[float] = 0.1, # ratio of original prod images + mask_image: PipelineImageInput = None, # modified for injecting original prod images + prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for averaging multiple copies true_cfg_scale: float = 1.0, height: Optional[int] = None, width: Optional[int] 
= None, @@ -1405,6 +1405,7 @@ def __call__( latents[tmp_init_mask.bool()] = torch.cat((tmp_tensor1, avg_region, tmp_tensor2), dim = 0) latents = latents.view(batch_size, 4096, -1) + # additional codes for injecting original prod images into latents if image is not None: if i < ref_prod_injection_steps: From 70a8951830d3972d7f1544cb233ae48e6baae7b3 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 21:28:56 +0800 Subject: [PATCH 379/482] add pipeline --- src/diffusers/pipelines/flux/__init__.py | 4 +++- .../pipelines/flux/pipeline_flux_controlnet_enhanced.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 72e1b578f2ca..ab3271b62a46 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -28,6 +28,7 @@ _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"] _import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"] _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] + _import_structure["pipeline_flux_controlnet_enhanced"] = ["FluxControlNetPipelineEnhanced"] _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"] @@ -47,10 +48,11 @@ from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline from .pipeline_flux_controlnet import FluxControlNetPipeline + from .pipeline_flux_controlnet_enhanced import FluxControlNetPipelineEnhanced from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline from .pipeline_flux_fill import FluxFillPipeline - from .pipeline_flux_img2img import FluxImg2ImgPipeline + from .pipeline_flux_img2img import FluxImg2ImgPipeline from .pipeline_flux_inpaint import FluxInpaintPipeline from .pipeline_flux_prior_redux import FluxPriorReduxPipeline else: diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py index b6a6ce6bac8f..fed40dc32273 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py @@ -174,7 +174,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin): +class FluxControlNetPipelineEnhanced(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin): r""" The Flux pipeline for text-to-image generation. 
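With the class renamed and exported from diffusers.pipelines.flux (see the __init__.py hunk above), the enhanced pipeline can be driven like the stock Flux ControlNet pipeline plus the injection and averaging knobs added in the preceding patches. A hypothetical invocation, where the checkpoint ids and input files are placeholders rather than anything taken from this series:

# Hypothetical usage sketch; model names and file paths are placeholders.
import torch
from PIL import Image
from diffusers import FluxControlNetModel
from diffusers.pipelines.flux import FluxControlNetPipelineEnhanced

controlnet = FluxControlNetModel.from_pretrained("some/flux-controlnet", torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipelineEnhanced.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="product photo on a marble table",
    control_image=Image.open("control.png"),        # ControlNet conditioning image
    image=Image.open("product.png"),                # original prod image to re-inject
    ratio_ref=0.1,                                  # injection strength for the product region
    mask_image=Image.open("prod_union_mask.png"),   # union mask over all product pixels
    prod_masks_original=[Image.open("copy_a.png"),  # one mask per copy of the product,
                         Image.open("copy_b.png")], # consumed by the latent-averaging step
    averaging_steps=2,
    ref_prod_injection_steps=22,
).images[0]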
@@ -1405,7 +1405,7 @@ def __call__( latents[tmp_init_mask.bool()] = torch.cat((tmp_tensor1, avg_region, tmp_tensor2), dim = 0) latents = latents.view(batch_size, 4096, -1) - + # additional codes for injecting original prod images into latents if image is not None: if i < ref_prod_injection_steps: From d2453b4a31ea741c3a69ca1cff60c238253fc13c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 21:34:51 +0800 Subject: [PATCH 380/482] Update pipeline_flux_controlnet.py --- .../flux/pipeline_flux_controlnet.py | 300 ++++++++++++++++-- 1 file changed, 281 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 8e3586c68494..35bd23a90687 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -240,6 +240,14 @@ def __init__( # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) @@ -590,6 +598,8 @@ def _unpack_latents(latents, height, width, vae_scale_factor): # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents def prepare_latents( self, + image, + timestep, batch_size, num_channels_latents, height, @@ -605,24 +615,123 @@ def prepare_latents( width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if image is not None: + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) if latents is not None: latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + image_latents = None + noise = None + + return latents.to(device=device, dtype=dtype), noise, image_latents, latent_image_ids if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) + + if image is not None: + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + else: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + image_latents = None + noise = None - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt - return latents, latent_image_ids + masked_image = masked_image.to(device=device, dtype=dtype) + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image def prepare_image( self, @@ -682,11 +791,17 @@ def __call__( prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt: Union[str, List[str]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, # original prod image + ratio_ref: Optional[float] = 0.1, # ratio of original prod images + mask_image: PipelineImageInput = None, # modified for injecting original prod images + prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for averaging multiple copies true_cfg_scale: float = 1.0, height: Optional[int] = None, width: Optional[int] = None, + strength: Optional[float] = 0.6, num_inference_steps: int = 28, sigmas: Optional[List[float]] = None, + padding_mask_crop: Optional[int] = None, guidance_scale: float = 7.0, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, @@ -710,6 +825,8 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product + ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images ): r""" Function invoked when calling the pipeline for generation. @@ -813,6 +930,9 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + global_height = height + global_width = width + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): @@ -914,6 +1034,22 @@ def __call__( lora_scale=lora_scale, ) + # Preprocess mask and image + if image is not None: + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region( + mask_image, global_width, global_height, pad=padding_mask_crop + ) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + init_image = self.image_processor.preprocess( + image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + # 3. Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, FluxControlNetModel): @@ -1003,22 +1139,11 @@ def __call__( control_modes.append(control_mode) control_mode = control_modes - # 4. 
Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 5. Prepare timesteps + # Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - image_seq_len = latents.shape[1] + image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * ( + int(global_width) // self.vae_scale_factor // 2 + ) mu = calculate_shift( image_seq_len, self.scheduler.config.get("base_image_seq_len", 256), @@ -1037,6 +1162,78 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + if image is not None: + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # Prepare mask latents + if image is not None: + mask_condition = self.mask_processor.preprocess( + mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + + if prod_masks_original is not None: + mask_conditions_prod = [] + for prod_mask in prod_masks_original: + tmp_mask_condition = self.mask_processor.preprocess( + prod_mask, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + mask_conditions_prod.append(tmp_mask_condition) + + masks_prod = [] + for tmp_mask_condition in mask_conditions_prod: + tmp_mask, _ = self.prepare_mask_latents( + tmp_mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + masks_prod.append(tmp_mask) + # 6. 
Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): @@ -1161,6 +1358,71 @@ def __call__( latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + # average multiple product latents + if prod_masks_original is not None: + init_masks_prod = masks_prod + + if i < averaging_steps: + batch_size = latents.shape[0] + latents = latents.view(-1) + masked_regions = [] + for tmp_init_mask in init_masks_prod: + tmp_init_mask = tmp_init_mask.view(-1) + masked_regions.append(latents[tmp_init_mask.bool()].reshape(-1)) + + min_length = min(t.size(0) for t in masked_regions) + #truncated = [t[:min_length] for t in masked_regions] + truncated = [] + for tmp_mask_region in masked_regions: + if tmp_mask_region.size(0) == min_length: + truncated.append(tmp_mask_region) + else: + extra_dim = tmp_mask_region.size(0) - min_length + half = extra_dim // 2 + if extra_dim % 2 == 0: + truncated.append(tmp_mask_region[half:(min_length-half+extra_dim)]) + else: + truncated.append(tmp_mask_region[half:(min_length-half+extra_dim-1)]) + + masked_regions = torch.stack(truncated) + + avg_region = masked_regions.mean(dim=0) + for tmp_init_mask in init_masks_prod: + tmp_init_mask = tmp_init_mask.view(-1) + length = torch.sum(tmp_init_mask == 1) + if length == avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region + elif length < avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region[0:length] + else: + add_dim = length - avg_region.shape[0] + add_dim1 = add_dim // 2 + add_dim2 = add_dim - add_dim1 + one_tensor1 = torch.ones(add_dim1).to(device=latents.device, dtype=latents.dtype) + one_tensor2 = torch.ones(add_dim2).to(device=latents.device, dtype=latents.dtype) + tmp_tensor1 = one_tensor1 * torch.mean(avg_region) + tmp_tensor2 = one_tensor2 * torch.mean(avg_region) + latents[tmp_init_mask.bool()] = torch.cat((tmp_tensor1, avg_region, tmp_tensor2), dim = 0) + + latents = latents.view(batch_size, 4096, -1) + + # additional codes for injecting original prod images into latents + if image is not None: + if i < ref_prod_injection_steps: + init_mask = mask + init_latents_proper = image_latents + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) + latents_2 = latents_1 + (1.0 - init_mask) * latents + + latents = latents_1 + latents_2 + if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 From dc669ef506277f1b807ffa87a70f4591fc73d18c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 21:41:12 +0800 Subject: [PATCH 381/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet.py | 12 ++++++++++++ .../flux/pipeline_flux_controlnet_enhanced.py | 13 +++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 35bd23a90687..7b21cc68ed8f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -477,6 +477,18 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + def check_inputs( self, prompt, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py index fed40dc32273..166aa8a30595 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_enhanced.py @@ -477,6 +477,19 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( self, prompt, From d304457afa230efab1fa204450c6791ea00e61be Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 21:44:32 +0800 Subject: [PATCH 382/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 7b21cc68ed8f..802d28456624 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -477,6 +477,22 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 
1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep From 1c367c658cbcfda7a6455a45fc4780d296230ca2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 21:48:33 +0800 Subject: [PATCH 383/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 802d28456624..72b359b561ab 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -823,6 +823,7 @@ def __call__( ratio_ref: Optional[float] = 0.1, # ratio of original prod images mask_image: PipelineImageInput = None, # modified for injecting original prod images prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for averaging multiple copies + masked_image_latents: PipelineImageInput = None, true_cfg_scale: float = 1.0, height: Optional[int] = None, width: Optional[int] = None, From 441b168bfad605343732cfbd2bf8e6ad87ee6324 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 21:53:44 +0800 Subject: [PATCH 384/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 72b359b561ab..cf8a6bf356b3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1448,7 +1448,7 @@ def __call__( ) latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) - latents_2 = latents_1 + (1.0 - init_mask) * latents + latents_2 = (1.0 - init_mask) * latents latents = latents_1 + latents_2 From 0752deafb3d601d9e6aa673bc78367b11a4c1c28 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 22:22:39 +0800 Subject: [PATCH 385/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index cf8a6bf356b3..6a053df996f9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1078,7 +1078,8 @@ def __call__( image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode ) init_image = init_image.to(dtype=torch.float32) - + else: + init_image = None # 3. 
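A quick check of why the latents_2 change in the fix above matters: with a binary mask m and blend ratio r, the intended composite is m * (r * init + (1 - r) * cur) + (1 - m) * cur, but the old expression counted latents_1 twice. Scalar stand-ins, names hypothetical:

    import torch

    m, r = torch.tensor(1.0), torch.tensor(0.1)        # mask on, 10% reference weight
    init, cur = torch.tensor(2.0), torch.tensor(4.0)   # injected vs. current latents

    latents_1 = m * (r * init + (1.0 - r) * cur)       # 3.8
    old = latents_1 + (latents_1 + (1.0 - m) * cur)    # 7.6, latents_1 added twice
    new = latents_1 + (1.0 - m) * cur                  # 3.8, the corrected composite
    assert torch.isclose(new, torch.tensor(3.8)) and not torch.isclose(old, new)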
Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, FluxControlNetModel): From c9d5675c83ed12ae921721a3cfc8001024c33eec Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 22:25:13 +0800 Subject: [PATCH 386/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 6a053df996f9..616748c20f7d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1201,6 +1201,8 @@ def __call__( f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + else: + latent_timestep = None # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 From 743d1bc61f41b478499fccdf1aaa68f6b1182962 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 22:47:52 +0800 Subject: [PATCH 387/482] return init_mask --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 616748c20f7d..b831f17c9863 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1491,6 +1491,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - return (image,) + return (image,init_mask) return FluxPipelineOutput(images=image) \ No newline at end of file From 4f2ce9daef97290ca5e51939edc8920477b71062 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 22:56:34 +0800 Subject: [PATCH 388/482] test --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index b831f17c9863..5689b417ad48 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1450,7 +1450,8 @@ def __call__( init_latents_proper, torch.tensor([noise_timestep]), noise ) - latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) + #latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) + latents_1 = init_mask * (0 * init_latents_proper + (1.0 - 0) * latents) latents_2 = (1.0 - init_mask) * latents latents = latents_1 + latents_2 From 539819035429d500579d4d73fde64f11d4a415ce Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 4 Jun 2025 23:01:12 +0800 Subject: [PATCH 389/482] remove test --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 5689b417ad48..b831f17c9863 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1450,8 +1450,7 @@ def __call__( init_latents_proper, torch.tensor([noise_timestep]), noise ) - #latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - 
ratio_ref) * latents) - latents_1 = init_mask * (0 * init_latents_proper + (1.0 - 0) * latents) + latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) latents_2 = (1.0 - init_mask) * latents latents = latents_1 + latents_2 From a060a49ffb16adddce64eccf4993283061f049fd Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 6 Jun 2025 11:35:03 +0800 Subject: [PATCH 390/482] add inpainting_starting_step --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index cfa059395905..5aedfe8d1d32 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -776,6 +776,7 @@ def __call__( iterations: Optional[int] = 4, # modified for applying gradient to mask_image averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images + inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting ): """ Function invoked when calling the pipeline for generation. @@ -1381,8 +1382,9 @@ def mask_gradienting(mask, iterations=3): latents = latents_1 + latents_2 + latents_3 else: latents = (1 - init_mask) * init_latents_proper + init_mask * latents - else: - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + else: + if i >= inpainting_starting_step: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From b76dc594c5f64290c1313aa1e863cdb276e727d5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 6 Jun 2025 15:35:22 +0800 Subject: [PATCH 391/482] add inpainting logic to blend mode --- .../flux/pipeline_flux_controlnet.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index b831f17c9863..34d66a88980f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -820,8 +820,10 @@ def __call__( negative_prompt: Union[str, List[str]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, # original prod image + image_bg: PipelineImageInput = None, # original bg image ratio_ref: Optional[float] = 0.1, # ratio of original prod images mask_image: PipelineImageInput = None, # modified for injecting original prod images + mask_image_bg: PipelineImageInput = None, # modified for injecting bg image prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for averaging multiple copies masked_image_latents: PipelineImageInput = None, true_cfg_scale: float = 1.0, @@ -856,6 +858,7 @@ def __call__( max_sequence_length: int = 512, averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images + inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting ): r""" Function invoked when calling the pipeline for generation. 
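For orientation, a hypothetical invocation combining the arguments this series has been adding (pipe, control_image, and the PIL inputs are assumed to exist and appear nowhere in the diffs). Note how strength feeds get_timesteps: with num_inference_steps=28 and strength=0.6, init_timestep is 16.8, t_start becomes int(28 - 16.8) = 11, and 17 denoising steps remain.

    # Hypothetical usage sketch; argument names mirror the signature above.
    result = pipe(
        prompt="product photo on a marble table",
        control_image=control_image,
        image=prod_image,              # original product photo to inject
        mask_image=prod_mask,          # where the product sits
        image_bg=bg_image,             # reference background to preserve
        mask_image_bg=bg_mask,         # 0 = keep background, 1 = regenerate
        ratio_ref=0.1,                 # weight of the injected product latents
        strength=0.6,
        num_inference_steps=28,
        ref_prod_injection_steps=22,
        inpainting_starting_step=0,
        height=1024,
        width=1024,
    )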
@@ -1080,6 +1083,24 @@ def __call__( init_image = init_image.to(dtype=torch.float32) else: init_image = None + + if image_bg is not None: + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region( + mask_image_bg, global_width, global_height, pad=padding_mask_crop + ) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + init_image_bg = self.image_processor.preprocess( + image_bg, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image_bg = init_image_bg.to(dtype=torch.float32) + else: + init_image_bg = None + # 3. Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, FluxControlNetModel): @@ -1219,6 +1240,22 @@ def __call__( latents, ) + if image_bg is not None: + latents = None + latents_bg, noise_bg, image_latents_bg, latent_image_ids = self.prepare_latents( + init_image_bg, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents = latents_bg + # Prepare mask latents if image is not None: mask_condition = self.mask_processor.preprocess( @@ -1242,6 +1279,28 @@ def __call__( generator, ) + if image_bg is not None: + mask_condition_bg = self.mask_processor.preprocess( + mask_image_bg, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image_bg = init_image_bg * (mask_condition_bg < 0.5) + else: + masked_image_bg = masked_image_latents + + mask_bg, masked_image_latents = self.prepare_mask_latents( + mask_condition_bg, + masked_image_bg, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + if prod_masks_original is not None: mask_conditions_prod = [] for prod_mask in prod_masks_original: @@ -1455,6 +1514,20 @@ def __call__( latents = latents_1 + latents_2 + if image_bg is not None: + if i >= inpainting_starting_step: + init_mask_bg = mask_bg + init_latents_proper_bg = image_latents_bg + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper_bg = self.scheduler.scale_noise( + init_latents_proper_bg, torch.tensor([noise_timestep]), noise_bg + ) + + latents = (1 - init_mask_bg) * init_latents_proper_bg + init_mask_bg * latents + + if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 From 36193cd5ab9ad9587c512db1b8524c0713b0b44f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 8 Jun 2025 10:29:18 +0800 Subject: [PATCH 392/482] add inpainting_ending_step para --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 34d66a88980f..20534849665d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -859,6 +859,7 @@ def __call__( averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting + inpainting_ending_step: Optional[int] = 22, # modified for ending inpainting ): r""" Function invoked when calling the pipeline for generation. @@ -1515,7 +1516,7 @@ def __call__( latents = latents_1 + latents_2 if image_bg is not None: - if i >= inpainting_starting_step: + if i >= inpainting_starting_step and i <= inpainting_ending_step: init_mask_bg = mask_bg init_latents_proper_bg = image_latents_bg From 20c93588d2a902f7639ef7dd234b9a5cbd735c6c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 8 Jun 2025 11:06:58 +0800 Subject: [PATCH 393/482] add steps for using redux output --- .../flux/pipeline_flux_controlnet.py | 108 ++++++++++++------ 1 file changed, 75 insertions(+), 33 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 20534849665d..badcaa1d5c0e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -860,6 +860,8 @@ def __call__( ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting inpainting_ending_step: Optional[int] = 22, # modified for ending inpainting + redux_starting_step: Optional[int] = 0, # modified for starting using redux output + redux_ending_step: Optional[int] = 22, # modified for ending using redux output ): r""" Function invoked when calling the pipeline for generation. 
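The hunk below gates which conditioning the controlnet and transformer receive per step. A standalone sketch of the windowing, under the assumption that a Flux Redux prior appends image tokens after the 512 T5 text tokens (the 729 below is illustrative, not taken from the diffs):

    import torch

    prompt_embeds = torch.zeros(1, 512 + 729, 4096)  # (batch, text + image tokens, dim)
    text_ids = torch.zeros(512 + 729, 3)             # per-token ids, no batch axis
    redux_starting_step, redux_ending_step = 0, 22

    for i in range(28):
        in_window = redux_starting_step <= i <= redux_ending_step
        embeds = prompt_embeds if in_window else prompt_embeds[:, 0:512, :]
        ids = text_ids if in_window else text_ids[0:512, :]
        assert ids.shape[0] == embeds.shape[1]  # ids stay aligned with the embeds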
@@ -1391,41 +1393,81 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - # controlnet - controlnet_block_samples, controlnet_single_block_samples = self.controlnet( - hidden_states=latents, - controlnet_cond=control_image, - controlnet_mode=control_mode, - conditioning_scale=cond_scale, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - ) + if i >= redux_starting_step and i <= redux_ending_step: + print(f'pooled_prompt_embeds shape = {pooled_prompt_embeds.shape}') + print(f'prompt_embeds = {prompt_embeds.shape}') + print(f'text_ids shape = {text_ids.shape}') + # controlnet + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=cond_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) - guidance = ( - torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None - ) - guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, + )[0] + else: + # controlnet + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=cond_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds[:, 0:512, :], + txt_ids=text_ids[:, 0:512], + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - controlnet_block_samples=controlnet_block_samples, - controlnet_single_block_samples=controlnet_single_block_samples, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - controlnet_blocks_repeat=controlnet_blocks_repeat, - )[0] + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + 
pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds[:, 0:512, :], + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids[:, 0:512], + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, + )[0] if do_true_cfg: if negative_image_embeds is not None: From e7c10e55ff2e494ec1cfde8e91b3af539a481ea4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 8 Jun 2025 16:19:22 +0800 Subject: [PATCH 394/482] add ratio bg para --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index badcaa1d5c0e..c48413fe0277 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -821,7 +821,8 @@ def __call__( negative_prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, # original prod image image_bg: PipelineImageInput = None, # original bg image - ratio_ref: Optional[float] = 0.1, # ratio of original prod images + ratio_prod: Optional[float] = 0.1, # ratio of original prod images + ratio_bg: Optional[float] = 0.1, # ratio of original bg images mask_image: PipelineImageInput = None, # modified for injecting original prod images mask_image_bg: PipelineImageInput = None, # modified for injecting bg image prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for averaging multiple copies @@ -1552,7 +1553,7 @@ def __call__( init_latents_proper, torch.tensor([noise_timestep]), noise ) - latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) + latents_1 = init_mask * (ratio_prod * init_latents_proper + (1.0 - ratio_prod) * latents) latents_2 = (1.0 - init_mask) * latents latents = latents_1 + latents_2 @@ -1568,7 +1569,7 @@ def __call__( init_latents_proper_bg, torch.tensor([noise_timestep]), noise_bg ) - latents = (1 - init_mask_bg) * init_latents_proper_bg + init_mask_bg * latents + latents = (1 - init_mask_bg) * (ratio_bg * init_latents_proper_bg + (1.0 - ratio_bg) * latents) + init_mask_bg * latents if latents.dtype != latents_dtype: From fc60099d044b527a15f74c5b4301c9ca6455fb97 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 9 Jun 2025 14:38:56 +0800 Subject: [PATCH 395/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index c48413fe0277..0d2565559c24 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1220,6 +1220,15 @@ def __call__( if image is not None: timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." 
+                    )
+                latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+            elif image_bg is not None:
+                timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+                if num_inference_steps < 1:
+                    raise ValueError(
+                        f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                        f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+                    )
+                latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
             else:
                 latent_timestep = None

From d3a33e54a550c0924de7df9f5bca9d8bfce8f95d Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 9 Jun 2025 18:25:41 +0800
Subject: [PATCH 396/482] fix bug

---
 src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 0d2565559c24..78b3e5c69efa 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -1453,7 +1453,7 @@ def __call__(
                         guidance=guidance,
                         pooled_projections=pooled_prompt_embeds,
                         encoder_hidden_states=prompt_embeds[:, 0:512, :],
-                        txt_ids=text_ids[:, 0:512],
+                        txt_ids=text_ids[0:512, :],
                         img_ids=latent_image_ids,
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
@@ -1472,7 +1472,7 @@ def __call__(
                         encoder_hidden_states=prompt_embeds[:, 0:512, :],
                         controlnet_block_samples=controlnet_block_samples,
                         controlnet_single_block_samples=controlnet_single_block_samples,
-                        txt_ids=text_ids[:, 0:512],
+                        txt_ids=text_ids[0:512, :],
                         img_ids=latent_image_ids,
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,

From 4a7461d718e95421b3646800f46f5512665debc8 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 9 Jun 2025 18:48:07 +0800
Subject: [PATCH 397/482] unify blend and qv in qv

---
 .../pipeline_flux_controlnet_inpainting.py         | 53 +++++++++++--------
 1 file changed, 32 insertions(+), 21 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index 5aedfe8d1d32..83e65d4cacc9 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -777,6 +777,7 @@ def __call__(
         averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product
         ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images
         inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting
+        inpainting_ending_step: Optional[int] = 0, # modified for ending inpainting
     ):
         """
         Function invoked when calling the pipeline for generation.
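The txt_ids fix a couple of commits above hinges on a shape detail: in the Flux pipelines, text_ids carries per-token coordinates without a batch axis, so trimming the token count must slice dim 0. A quick check (token counts illustrative):

    import torch

    prompt_embeds = torch.zeros(2, 512 + 729, 4096)  # (batch, seq, dim)
    text_ids = torch.zeros(512 + 729, 3)             # (seq, 3)

    assert prompt_embeds[:, 0:512, :].shape == (2, 512, 4096)
    assert text_ids[0:512, :].shape == (512, 3)        # matches the trimmed embeds
    assert text_ids[:, 0:512].shape == (512 + 729, 3)  # the old slice was a no-op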
@@ -1068,6 +1069,8 @@ def __call__(
             generator,
             latents,
         )
+        if inpainting_starting_step < 0:
+            latents = None
 
         if image_ref_prod is not None:
             _, _, image_latents_ref, _ = self.prepare_latents(
@@ -1307,27 +1310,17 @@ def mask_gradienting(mask, iterations=3):
 
                     # For inpainting, we need to apply the mask and add the masked image latents
                     init_latents_proper = image_latents
-                    if image_ref_prod is not None:
-                        init_latents_proper_ref = image_latents_ref
                     init_mask = mask
-                    if prod_masks_original is not None:
-                        init_masks_prod = masks_prod
-
-                    if image_ref_prod is not None:
-                        init_mask_ref_prod = mask_original
 
                     if i < len(timesteps) - 1:
                         noise_timestep = timesteps[i + 1]
                         init_latents_proper = self.scheduler.scale_noise(
                             init_latents_proper, torch.tensor([noise_timestep]), noise
                         )
-                        if image_ref_prod is not None:
-                            init_latents_proper_ref = self.scheduler.scale_noise(
-                                init_latents_proper_ref, torch.tensor([noise_timestep]), noise
-                            )
-
+
+                    # average multiple product latents
                     if prod_masks_original is not None:
+                        init_masks_prod = masks_prod
                         if i < averaging_steps:
                             batch_size = latents.shape[0]
                             latents = latents.view(-1)
@@ -1373,18 +1366,36 @@ def mask_gradienting(mask, iterations=3):
                         latents = latents.view(batch_size, 4096, -1)
 
                     # inject product raw images into latents
+                    #if image_ref_prod is not None:
+                    #    if i < ref_prod_injection_steps:
+                    #        latents_1 = (1 - init_mask) * init_latents_proper
+                    #        latents_2 = (init_mask - init_mask_ref_prod) * latents
+                    #        latents_3 = (1.0 - ratio_ref_prod) * init_mask_ref_prod * latents + ratio_ref_prod * init_mask_ref_prod * init_latents_proper_ref
+
+                    #        latents = latents_1 + latents_2 + latents_3
+
                     if image_ref_prod is not None:
                         if i < ref_prod_injection_steps:
-                            latents_1 = (1 - init_mask) * init_latents_proper
-                            latents_2 = (init_mask - init_mask_ref_prod) * latents
-                            latents_3 = (1.0 - ratio_ref_prod) * init_mask_ref_prod * latents + ratio_ref_prod * init_mask_ref_prod * init_latents_proper_ref
+                            init_mask_ref_prod = mask_original
+                            init_latents_proper_ref = image_latents_ref
+
+                            #init_latents_proper_ref = self.scheduler.scale_noise(
+                            #    init_latents_proper_ref, torch.tensor([noise_timestep]), noise
+                            #)
+
+                            if i < len(timesteps) - 1:
+                                noise_timestep = timesteps[i + 1]
+                                init_latents_proper_ref = self.scheduler.scale_noise(
+                                    init_latents_proper_ref, torch.tensor([noise_timestep]), noise
+                                )
+
+                            latents_1 = init_mask_ref_prod * (ratio_ref_prod * init_latents_proper_ref + (1.0 - ratio_ref_prod) * latents)
+                            latents_2 = (1.0 - init_mask_ref_prod) * latents
 
-                            latents = latents_1 + latents_2 + latents_3
-                        else:
-                            latents = (1 - init_mask) * init_latents_proper + init_mask * latents
-                    else:
-                        if i >= inpainting_starting_step:
-                            latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+                            latents = latents_1 + latents_2
+
+                    if i >= inpainting_starting_step and i < inpainting_ending_step:
+                        latents = (1 - init_mask) * init_latents_proper + init_mask * latents
 
                     if latents.dtype != latents_dtype:
                         if torch.backends.mps.is_available():

From 0172575ec8cf098c6d83549060795b06f2675523 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Mon, 9 Jun 2025 18:58:16 +0800
Subject: [PATCH 398/482] fix bug

---
 .../flux/pipeline_flux_controlnet_inpainting.py   | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -575,6 +575,7 @@ def prepare_latents( device, generator, latents=None, + inpainting_starting_step = 0, ): if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -609,6 +610,11 @@ def prepare_latents( else: noise = latents.to(device) latents = noise + + if inpainting_starting_step < 0: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = latents.to(device) + latents = noise noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) @@ -1068,9 +1074,8 @@ def __call__( device, generator, latents, + inpainting_starting_step, ) - if inpainting_starting_step < 0: - latents = None if image_ref_prod is not None: _, _, image_latents_ref, _ = self.prepare_latents( From cdbb772c38f0ebcad28245432d54b410b4612789 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 9 Jun 2025 23:09:16 +0800 Subject: [PATCH 399/482] revert to original controlnet pipeline --- .../flux/pipeline_flux_controlnet.py | 526 ++---------------- 1 file changed, 53 insertions(+), 473 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 78b3e5c69efa..8e3586c68494 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -240,14 +240,6 @@ def __init__( # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) - latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 - self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor * 2, - vae_latent_channels=latent_channels, - do_normalize=False, - do_binarize=True, - do_convert_grayscale=True, - ) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) @@ -477,34 +469,6 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - return image_latents - - - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(num_inference_steps * strength, num_inference_steps) - - t_start = int(max(num_inference_steps - init_timestep, 0)) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, 
"set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) - - return timesteps, num_inference_steps - t_start - def check_inputs( self, prompt, @@ -626,8 +590,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor): # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents def prepare_latents( self, - image, - timestep, batch_size, num_channels_latents, height, @@ -643,123 +605,24 @@ def prepare_latents( width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - - if image is not None: - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) - - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." - ) - else: - image_latents = torch.cat([image_latents], dim=0) if latents is not None: latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - image_latents = None - noise = None - - return latents.to(device=device, dtype=dtype), noise, image_latents, latent_image_ids + return latents.to(device=device, dtype=dtype), latent_image_ids if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - - if image is not None: - if latents is None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) - - noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) - image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) - else: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - image_latents = None - noise = None + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - return latents, noise, image_latents, latent_image_ids - - def prepare_mask_latents( - self, - mask, - masked_image, - batch_size, - num_channels_latents, - num_images_per_prompt, - height, - width, - dtype, - device, - generator, - ): - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. 
- height = 2 * (int(height) // (self.vae_scale_factor * 2)) - width = 2 * (int(width) // (self.vae_scale_factor * 2)) - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate(mask, size=(height, width)) - mask = mask.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - masked_image = masked_image.to(device=device, dtype=dtype) - - if masked_image.shape[1] == 16: - masked_image_latents = masked_image - else: - masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) - - masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." 
- ) - masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - masked_image_latents = self._pack_latents( - masked_image_latents, - batch_size, - num_channels_latents, - height, - width, - ) - mask = self._pack_latents( - mask.repeat(1, num_channels_latents, 1, 1), - batch_size, - num_channels_latents, - height, - width, - ) + return latents, latent_image_ids - return mask, masked_image_latents - # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image def prepare_image( self, @@ -819,21 +682,11 @@ def __call__( prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt: Union[str, List[str]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, # original prod image - image_bg: PipelineImageInput = None, # original bg image - ratio_prod: Optional[float] = 0.1, # ratio of original prod images - ratio_bg: Optional[float] = 0.1, # ratio of original bg images - mask_image: PipelineImageInput = None, # modified for injecting original prod images - mask_image_bg: PipelineImageInput = None, # modified for injecting bg image - prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for averaging multiple copies - masked_image_latents: PipelineImageInput = None, true_cfg_scale: float = 1.0, height: Optional[int] = None, width: Optional[int] = None, - strength: Optional[float] = 0.6, num_inference_steps: int = 28, sigmas: Optional[List[float]] = None, - padding_mask_crop: Optional[int] = None, guidance_scale: float = 7.0, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, @@ -857,12 +710,6 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product - ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images - inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting - inpainting_ending_step: Optional[int] = 22, # modified for ending inpainting - redux_starting_step: Optional[int] = 0, # modified for starting using redux output - redux_ending_step: Optional[int] = 22, # modified for ending using redux output ): r""" Function invoked when calling the pipeline for generation. 
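The revert below also swaps the explicit image_seq_len formula back to latents.shape[1]; the two agree for packed Flux latents, since an (H, W) image yields (H // vae_scale // 2) * (W // vae_scale // 2) tokens after 2x2 patch packing. A quick check with the usual vae_scale of 8:

    height, width, vae_scale = 1024, 1024, 8
    packed_seq_len = (height // vae_scale // 2) * (width // vae_scale // 2)
    assert packed_seq_len == 4096  # equals latents.shape[1] for a 1024x1024 run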
@@ -966,9 +813,6 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor - global_height = height - global_width = width - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): @@ -1070,41 +914,6 @@ def __call__( lora_scale=lora_scale, ) - # Preprocess mask and image - if image is not None: - if padding_mask_crop is not None: - crops_coords = self.mask_processor.get_crop_region( - mask_image, global_width, global_height, pad=padding_mask_crop - ) - resize_mode = "fill" - else: - crops_coords = None - resize_mode = "default" - - init_image = self.image_processor.preprocess( - image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode - ) - init_image = init_image.to(dtype=torch.float32) - else: - init_image = None - - if image_bg is not None: - if padding_mask_crop is not None: - crops_coords = self.mask_processor.get_crop_region( - mask_image_bg, global_width, global_height, pad=padding_mask_crop - ) - resize_mode = "fill" - else: - crops_coords = None - resize_mode = "default" - - init_image_bg = self.image_processor.preprocess( - image_bg, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode - ) - init_image_bg = init_image_bg.to(dtype=torch.float32) - else: - init_image_bg = None - # 3. Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, FluxControlNetModel): @@ -1194,11 +1003,22 @@ def __call__( control_modes.append(control_mode) control_mode = control_modes - # Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * ( - int(global_width) // self.vae_scale_factor // 2 + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, self.scheduler.config.get("base_image_seq_len", 256), @@ -1217,127 +1037,6 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - if image is not None: - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - - if num_inference_steps < 1: - raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." 
- ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - elif image_bg is not None: - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - - if num_inference_steps < 1: - raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." - ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - else: - latent_timestep = None - - # 4. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - latents, noise, image_latents, latent_image_ids = self.prepare_latents( - init_image, - latent_timestep, - batch_size * num_images_per_prompt, - num_channels_latents, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - if image_bg is not None: - latents = None - latents_bg, noise_bg, image_latents_bg, latent_image_ids = self.prepare_latents( - init_image_bg, - latent_timestep, - batch_size * num_images_per_prompt, - num_channels_latents, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - latents = latents_bg - - # Prepare mask latents - if image is not None: - mask_condition = self.mask_processor.preprocess( - mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords - ) - if masked_image_latents is None: - masked_image = init_image * (mask_condition < 0.5) - else: - masked_image = masked_image_latents - - mask, masked_image_latents = self.prepare_mask_latents( - mask_condition, - masked_image, - batch_size, - num_channels_latents, - num_images_per_prompt, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - ) - - if image_bg is not None: - mask_condition_bg = self.mask_processor.preprocess( - mask_image_bg, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords - ) - if masked_image_latents is None: - masked_image_bg = init_image_bg * (mask_condition_bg < 0.5) - else: - masked_image_bg = masked_image_latents - - mask_bg, masked_image_latents = self.prepare_mask_latents( - mask_condition_bg, - masked_image_bg, - batch_size, - num_channels_latents, - num_images_per_prompt, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - ) - - if prod_masks_original is not None: - mask_conditions_prod = [] - for prod_mask in prod_masks_original: - tmp_mask_condition = self.mask_processor.preprocess( - prod_mask, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords - ) - mask_conditions_prod.append(tmp_mask_condition) - - masks_prod = [] - for tmp_mask_condition in mask_conditions_prod: - tmp_mask, _ = self.prepare_mask_latents( - tmp_mask_condition, - masked_image, - batch_size, - num_channels_latents, - num_images_per_prompt, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - ) - masks_prod.append(tmp_mask) - # 6. 
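On the controlnet_keep schedule restored below: it applies the control_guidance_start/end window per step. A standalone sketch of the gating rule as it exists in the upstream diffusers controlnet pipelines, shown here only for reference:

    num_steps, start, end = 10, 0.0, 0.5
    keep = [
        1.0 - float(i / num_steps < start or (i + 1) / num_steps > end)
        for i in range(num_steps)
    ]
    assert keep == [1.0] * 5 + [0.0] * 5  # control active only for the first half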
Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): @@ -1403,81 +1102,41 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] - if i >= redux_starting_step and i <= redux_ending_step: - print(f'pooled_prompt_embeds shape = {pooled_prompt_embeds.shape}') - print(f'prompt_embeds = {prompt_embeds.shape}') - print(f'text_ids shape = {text_ids.shape}') - # controlnet - controlnet_block_samples, controlnet_single_block_samples = self.controlnet( - hidden_states=latents, - controlnet_cond=control_image, - controlnet_mode=control_mode, - conditioning_scale=cond_scale, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - ) - - guidance = ( - torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None - ) - guidance = guidance.expand(latents.shape[0]) if guidance is not None else None - - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - controlnet_block_samples=controlnet_block_samples, - controlnet_single_block_samples=controlnet_single_block_samples, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - controlnet_blocks_repeat=controlnet_blocks_repeat, - )[0] - else: - # controlnet - controlnet_block_samples, controlnet_single_block_samples = self.controlnet( - hidden_states=latents, - controlnet_cond=control_image, - controlnet_mode=control_mode, - conditioning_scale=cond_scale, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds[:, 0:512, :], - txt_ids=text_ids[0:512, :], - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - ) + # controlnet + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=cond_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) - guidance = ( - torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None - ) - guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds[:, 0:512, :], - controlnet_block_samples=controlnet_block_samples, - controlnet_single_block_samples=controlnet_single_block_samples, - txt_ids=text_ids[0:512, :], - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - controlnet_blocks_repeat=controlnet_blocks_repeat, - )[0] + noise_pred = 
self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, + )[0] if do_true_cfg: if negative_image_embeds is not None: @@ -1502,85 +1161,6 @@ def __call__( latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - # average multiple product latents - if prod_masks_original is not None: - init_masks_prod = masks_prod - - if i < averaging_steps: - batch_size = latents.shape[0] - latents = latents.view(-1) - masked_regions = [] - for tmp_init_mask in init_masks_prod: - tmp_init_mask = tmp_init_mask.view(-1) - masked_regions.append(latents[tmp_init_mask.bool()].reshape(-1)) - - min_length = min(t.size(0) for t in masked_regions) - #truncated = [t[:min_length] for t in masked_regions] - truncated = [] - for tmp_mask_region in masked_regions: - if tmp_mask_region.size(0) == min_length: - truncated.append(tmp_mask_region) - else: - extra_dim = tmp_mask_region.size(0) - min_length - half = extra_dim // 2 - if extra_dim % 2 == 0: - truncated.append(tmp_mask_region[half:(min_length-half+extra_dim)]) - else: - truncated.append(tmp_mask_region[half:(min_length-half+extra_dim-1)]) - - masked_regions = torch.stack(truncated) - - avg_region = masked_regions.mean(dim=0) - for tmp_init_mask in init_masks_prod: - tmp_init_mask = tmp_init_mask.view(-1) - length = torch.sum(tmp_init_mask == 1) - if length == avg_region.shape[0]: - latents[tmp_init_mask.bool()] = avg_region - elif length < avg_region.shape[0]: - latents[tmp_init_mask.bool()] = avg_region[0:length] - else: - add_dim = length - avg_region.shape[0] - add_dim1 = add_dim // 2 - add_dim2 = add_dim - add_dim1 - one_tensor1 = torch.ones(add_dim1).to(device=latents.device, dtype=latents.dtype) - one_tensor2 = torch.ones(add_dim2).to(device=latents.device, dtype=latents.dtype) - tmp_tensor1 = one_tensor1 * torch.mean(avg_region) - tmp_tensor2 = one_tensor2 * torch.mean(avg_region) - latents[tmp_init_mask.bool()] = torch.cat((tmp_tensor1, avg_region, tmp_tensor2), dim = 0) - - latents = latents.view(batch_size, 4096, -1) - - # additional codes for injecting original prod images into latents - if image is not None: - if i < ref_prod_injection_steps: - init_mask = mask - init_latents_proper = image_latents - - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.scale_noise( - init_latents_proper, torch.tensor([noise_timestep]), noise - ) - - latents_1 = init_mask * (ratio_prod * init_latents_proper + (1.0 - ratio_prod) * latents) - latents_2 = (1.0 - init_mask) * latents - - latents = latents_1 + latents_2 - - if image_bg is not None: - if i >= inpainting_starting_step and i <= inpainting_ending_step: - init_mask_bg = mask_bg - init_latents_proper_bg = image_latents_bg - - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper_bg = self.scheduler.scale_noise( - init_latents_proper_bg, torch.tensor([noise_timestep]), noise_bg - ) - - latents = (1 - init_mask_bg) * (ratio_bg * init_latents_proper_bg + (1.0 - ratio_bg) * latents) + init_mask_bg * latents - - if latents.dtype != latents_dtype: if 
torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -1617,6 +1197,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - return (image,init_mask) + return (image,) return FluxPipelineOutput(images=image) \ No newline at end of file From e7313284f197658dfb494e5366679eeaab8dc004 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 9 Jun 2025 23:24:20 +0800 Subject: [PATCH 400/482] Update pipeline_flux_controlnet.py --- .../flux/pipeline_flux_controlnet.py | 407 +++++++++++++++++- 1 file changed, 387 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 8e3586c68494..3944fb8aedfb 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1,4 +1,4 @@ -# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.More actions # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -240,6 +240,14 @@ def __init__( # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) @@ -469,6 +477,34 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + def check_inputs( self, prompt, @@ -590,6 +626,8 @@ def _unpack_latents(latents, height, 
width, vae_scale_factor): # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents def prepare_latents( self, + image, + timestep, batch_size, num_channels_latents, height, @@ -605,10 +643,29 @@ def prepare_latents( width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if image is not None: + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) if latents is not None: latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + image_latents = None + noise = None + + return latents.to(device=device, dtype=dtype), noise, image_latents, latent_image_ids if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -616,12 +673,92 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if image is not None: + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + else: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + image_latents = None + noise = None + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents, noise, image_latents, latent_image_ids - return latents, latent_image_ids + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
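# For reference, the 2x2 packing that the comment above refers to amounts to the
# reshape below: a minimal sketch assuming the usual Flux setup (16 latent
# channels, 8x VAE compression); `pack_latents` is an illustrative name, not the
# library function itself.
#
#     import torch
#
#     def pack_latents(latents: torch.Tensor) -> torch.Tensor:
#         b, c, h, w = latents.shape  # h and w must be even, hence the divisibility requirement above
#         latents = latents.view(b, c, h // 2, 2, w // 2, 2)
#         latents = latents.permute(0, 2, 4, 1, 3, 5)
#         return latents.reshape(b, (h // 2) * (w // 2), c * 4)
#
#     x = torch.randn(1, 16, 128, 128)  # a 1024x1024 image gives a 128x128 latent
#     print(pack_latents(x).shape)      # torch.Size([1, 4096, 64])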
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image def prepare_image( @@ -682,11 +819,20 @@ def __call__( prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt: Union[str, List[str]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, # original prod image + + ratio_ref: Optional[float] = 0.1, # ratio of original prod images + mask_image: PipelineImageInput = None, # modified for injecting original prod images + + prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for averaging multiple copies + masked_image_latents: PipelineImageInput = None, true_cfg_scale: float = 1.0, height: Optional[int] = None, width: Optional[int] = None, + strength: Optional[float] = 0.6, num_inference_steps: int = 28, sigmas: Optional[List[float]] = None, + padding_mask_crop: Optional[int] = None, guidance_scale: float = 7.0, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, @@ -710,6 +856,9 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product + ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images + ): r""" Function invoked when calling the pipeline for generation. @@ -813,6 +962,9 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + global_height = height + global_width = width + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): @@ -914,6 +1066,41 @@ def __call__( lora_scale=lora_scale, ) + # Preprocess mask and image + if image is not None: + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region( + mask_image, global_width, global_height, pad=padding_mask_crop + ) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + init_image = self.image_processor.preprocess( + image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + else: + init_image = None + + + + + + + + + + + + + + + + + + # 3. Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 if isinstance(self.controlnet, FluxControlNetModel): @@ -1003,22 +1190,11 @@ def __call__( control_modes.append(control_mode) control_mode = control_modes - # 4. 
Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 5. Prepare timesteps + # Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - image_seq_len = latents.shape[1] + image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * ( + int(global_width) // self.vae_scale_factor // 2 + ) mu = calculate_shift( image_seq_len, self.scheduler.config.get("base_image_seq_len", 256), @@ -1037,6 +1213,118 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + if image is not None: + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + else: + latent_timestep = None + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + + + + + + + + + + + + + + + + + # Prepare mask latents + if image is not None: + mask_condition = self.mask_processor.preprocess( + mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + + + + + + + + + + + + + + + + + + + + + + + + if prod_masks_original is not None: + mask_conditions_prod = [] + for prod_mask in prod_masks_original: + tmp_mask_condition = self.mask_processor.preprocess( + prod_mask, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + mask_conditions_prod.append(tmp_mask_condition) + + masks_prod = [] + for tmp_mask_condition in mask_conditions_prod: + tmp_mask, _ = self.prepare_mask_latents( + tmp_mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + masks_prod.append(tmp_mask) + # 6. 
Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): @@ -1161,6 +1449,85 @@ def __call__( latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + # average multiple product latents + if prod_masks_original is not None: + init_masks_prod = masks_prod + + if i < averaging_steps: + batch_size = latents.shape[0] + latents = latents.view(-1) + masked_regions = [] + for tmp_init_mask in init_masks_prod: + tmp_init_mask = tmp_init_mask.view(-1) + masked_regions.append(latents[tmp_init_mask.bool()].reshape(-1)) + + min_length = min(t.size(0) for t in masked_regions) + #truncated = [t[:min_length] for t in masked_regions] + truncated = [] + for tmp_mask_region in masked_regions: + if tmp_mask_region.size(0) == min_length: + truncated.append(tmp_mask_region) + else: + extra_dim = tmp_mask_region.size(0) - min_length + half = extra_dim // 2 + if extra_dim % 2 == 0: + truncated.append(tmp_mask_region[half:(min_length-half+extra_dim)]) + else: + truncated.append(tmp_mask_region[half:(min_length-half+extra_dim-1)]) + + masked_regions = torch.stack(truncated) + + avg_region = masked_regions.mean(dim=0) + for tmp_init_mask in init_masks_prod: + tmp_init_mask = tmp_init_mask.view(-1) + length = torch.sum(tmp_init_mask == 1) + if length == avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region + elif length < avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region[0:length] + else: + add_dim = length - avg_region.shape[0] + add_dim1 = add_dim // 2 + add_dim2 = add_dim - add_dim1 + one_tensor1 = torch.ones(add_dim1).to(device=latents.device, dtype=latents.dtype) + one_tensor2 = torch.ones(add_dim2).to(device=latents.device, dtype=latents.dtype) + tmp_tensor1 = one_tensor1 * torch.mean(avg_region) + tmp_tensor2 = one_tensor2 * torch.mean(avg_region) + latents[tmp_init_mask.bool()] = torch.cat((tmp_tensor1, avg_region, tmp_tensor2), dim = 0) + + latents = latents.view(batch_size, 4096, -1) + + # additional codes for injecting original prod images into latents + if image is not None: + if i < ref_prod_injection_steps: + init_mask = mask + init_latents_proper = image_latents + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) + latents_2 = (1.0 - init_mask) * latents + + latents = latents_1 + latents_2 + + + + + + + + + + + + + + + if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 From 96c37e65e6ee12f5c7fb2f80ba578dece9d0b7de Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 11:16:20 +0800 Subject: [PATCH 401/482] print info --- src/diffusers/utils/torch_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index a5df07e4a3c2..a946a5f8d8bc 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -80,6 +80,7 @@ def randn_tensor( else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + print(f'noise norm = {latents.norm()}') return latents From 757ef7923d6194df4316cf96f0ba8ad8cf893f02 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:03:18 +0800 Subject: [PATCH 402/482] norm must less than 650 --- src/diffusers/utils/torch_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index a946a5f8d8bc..923d3932497c 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -79,6 +79,19 @@ def randn_tensor( latents = torch.cat(latents, dim=0).to(device) else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + mean_norm = torch.mean(latents.norm()) + while(mean_norm >=650): + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + mean_norm = torch.mean(latents.norm()) print(f'noise norm = {latents.norm()}') return latents From 175978fc84d1504ac872711282035b01a6bb01f7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:08:19 +0800 Subject: [PATCH 403/482] remove generator in noise generation --- src/diffusers/utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 923d3932497c..46cf40acb790 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -90,7 +90,7 @@ def randn_tensor( ] latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + latents = torch.randn(shape, device=rand_device, dtype=dtype, layout=layout).to(device) mean_norm = torch.mean(latents.norm()) print(f'noise norm = {latents.norm()}') From dbae261d12e8a308206e2e4599483ef3a67b9b6f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:11:52 +0800 Subject: [PATCH 404/482] remove generator in noise generation --- src/diffusers/utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 46cf40acb790..7e5ad8ead79f 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -85,7 +85,7 @@ def randn_tensor( if isinstance(generator, list): shape = (1,) + shape[1:] latents = [ - torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + 
torch.randn(shape, device=rand_device, dtype=dtype, layout=layout) for i in range(batch_size) ] latents = torch.cat(latents, dim=0).to(device) From 2ba4c1eb0f728b34641af79846d59ac1d7af0e72 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:12:24 +0800 Subject: [PATCH 405/482] print info --- src/diffusers/utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 7e5ad8ead79f..78136517769a 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -93,7 +93,7 @@ def randn_tensor( latents = torch.randn(shape, device=rand_device, dtype=dtype, layout=layout).to(device) mean_norm = torch.mean(latents.norm()) - print(f'noise norm = {latents.norm()}') + print(f'noise norm = {latents.norm()}') return latents From 553bc2c9c39a43533f0266af7c6fee00b6345f42 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:20:40 +0800 Subject: [PATCH 406/482] use random generator behind the scene --- src/diffusers/utils/torch_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 78136517769a..3986e1c47cbb 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -82,15 +82,16 @@ def randn_tensor( mean_norm = torch.mean(latents.norm()) while(mean_norm >=650): + generator = torch.Generator(device==rand_device).manual_seed(random.randint(0, 10000000)) if isinstance(generator, list): shape = (1,) + shape[1:] latents = [ - torch.randn(shape, device=rand_device, dtype=dtype, layout=layout) + torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout) for i in range(batch_size) ] latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, device=rand_device, dtype=dtype, layout=layout).to(device) + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) mean_norm = torch.mean(latents.norm()) print(f'noise norm = {latents.norm()}') From 622caf5274e7fcc1f028e7dfcad453dda691143f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:21:41 +0800 Subject: [PATCH 407/482] import random --- src/diffusers/utils/torch_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 3986e1c47cbb..0f86294df564 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -16,6 +16,7 @@ """ from typing import List, Optional, Tuple, Union +import random from . 
import logging from .import_utils import is_torch_available, is_torch_version From 1aa761852e753b142ea43f891ccaed78ae7915b2 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:24:57 +0800 Subject: [PATCH 408/482] fix bug --- src/diffusers/utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 0f86294df564..c43ed5679f27 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -83,7 +83,7 @@ def randn_tensor( mean_norm = torch.mean(latents.norm()) while(mean_norm >=650): - generator = torch.Generator(device==rand_device).manual_seed(random.randint(0, 10000000)) + generator = torch.Generator(device=rand_device).manual_seed(random.randint(0, 10000000)) if isinstance(generator, list): shape = (1,) + shape[1:] latents = [ From f06073a0888ad6817e80f283135817f0ea2849ea Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:28:37 +0800 Subject: [PATCH 409/482] print info --- src/diffusers/utils/torch_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index c43ed5679f27..26d58b19c187 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -81,6 +81,8 @@ def randn_tensor( else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + print(f'noise shape = {latents.shape}') + """ mean_norm = torch.mean(latents.norm()) while(mean_norm >=650): generator = torch.Generator(device=rand_device).manual_seed(random.randint(0, 10000000)) @@ -94,8 +96,8 @@ def randn_tensor( else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) mean_norm = torch.mean(latents.norm()) - - print(f'noise norm = {latents.norm()}') + """ + print(f'noise norm = {latents.norm()}') return latents From b2f7a3ec7a9844795c4fd3b897bec179d5959c21 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:34:17 +0800 Subject: [PATCH 410/482] print info --- src/diffusers/utils/torch_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 26d58b19c187..28ef0d0b634a 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -97,7 +97,8 @@ def randn_tensor( latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) mean_norm = torch.mean(latents.norm()) """ - print(f'noise norm = {latents.norm()}') + for index in range(latents.shape[0]): + print(f'noise norm {index} = {latents[index].norm()}') return latents From 4419450dbfab01a8e6e9d03f7bb7ce43dd75a37c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:47:59 +0800 Subject: [PATCH 411/482] zero mean --- src/diffusers/utils/torch_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 28ef0d0b634a..c35523629a3f 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -82,6 +82,7 @@ def randn_tensor( latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) print(f'noise shape = {latents.shape}') + """ mean_norm = torch.mean(latents.norm()) while(mean_norm >=650): @@ -97,6 
+98,10 @@ def randn_tensor( latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) mean_norm = torch.mean(latents.norm()) """ + for index in range(latents.shape[0]): + print(f'noise norm {index} = {latents[index].norm()}') + + latents = latents - latents.mean() for index in range(latents.shape[0]): print(f'noise norm {index} = {latents[index].norm()}') return latents From a35c1b8655571792381af9a54c852c6211946cba Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 12:55:52 +0800 Subject: [PATCH 412/482] print info --- src/diffusers/utils/torch_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index c35523629a3f..4e913e1db910 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -101,6 +101,7 @@ def randn_tensor( for index in range(latents.shape[0]): print(f'noise norm {index} = {latents[index].norm()}') + print(f'latents.mean() = {latents.mean()}') latents = latents - latents.mean() for index in range(latents.shape[0]): print(f'noise norm {index} = {latents[index].norm()}') From d424207d28fd470ba10b2e97d18c0fdd8ff41bc9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 13:04:52 +0800 Subject: [PATCH 413/482] mean noise --- .../pipeline_flux_controlnet_inpainting.py | 8 ++++++- src/diffusers/utils/torch_utils.py | 24 ------------------- 2 files changed, 7 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index cc62f38791f1..74c13776c456 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -605,7 +605,13 @@ def prepare_latents( image_latents = torch.cat([image_latents], dim=0) if latents is None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noises = [] + for index in range(5): + tmp_noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noises.append(tmp_noise) + noises = torch.stack(noises,dim=0) + noise = torch.mean(noises, dim=0, keepdim=False) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) else: noise = latents.to(device) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 4e913e1db910..491cc16e39e6 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -81,30 +81,6 @@ def randn_tensor( else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) - print(f'noise shape = {latents.shape}') - - """ - mean_norm = torch.mean(latents.norm()) - while(mean_norm >=650): - generator = torch.Generator(device=rand_device).manual_seed(random.randint(0, 10000000)) - if isinstance(generator, list): - shape = (1,) + shape[1:] - latents = [ - torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout) - for i in range(batch_size) - ] - latents = torch.cat(latents, dim=0).to(device) - else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) - mean_norm = torch.mean(latents.norm()) - """ - for index in range(latents.shape[0]): - print(f'noise norm {index} = {latents[index].norm()}') - - print(f'latents.mean() = {latents.mean()}') - latents = latents - latents.mean() 
- for index in range(latents.shape[0]): - print(f'noise norm {index} = {latents[index].norm()}') return latents From 31cd15e1383cff4428ec86f6f772ad52d38f057d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 14:57:29 +0800 Subject: [PATCH 414/482] Revert "mean noise" This reverts commit d424207d28fd470ba10b2e97d18c0fdd8ff41bc9. --- .../pipeline_flux_controlnet_inpainting.py | 8 +------ src/diffusers/utils/torch_utils.py | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 74c13776c456..cc62f38791f1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -605,13 +605,7 @@ def prepare_latents( image_latents = torch.cat([image_latents], dim=0) if latents is None: - noises = [] - for index in range(5): - tmp_noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - noises.append(tmp_noise) - noises = torch.stack(noises,dim=0) - noise = torch.mean(noises, dim=0, keepdim=False) - + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self.scheduler.scale_noise(image_latents, timestep, noise) else: noise = latents.to(device) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 491cc16e39e6..4e913e1db910 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -81,6 +81,30 @@ def randn_tensor( else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + print(f'noise shape = {latents.shape}') + + """ + mean_norm = torch.mean(latents.norm()) + while(mean_norm >=650): + generator = torch.Generator(device=rand_device).manual_seed(random.randint(0, 10000000)) + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + mean_norm = torch.mean(latents.norm()) + """ + for index in range(latents.shape[0]): + print(f'noise norm {index} = {latents[index].norm()}') + + print(f'latents.mean() = {latents.mean()}') + latents = latents - latents.mean() + for index in range(latents.shape[0]): + print(f'noise norm {index} = {latents[index].norm()}') return latents From 5c9404fd9b143a6887fcc0cc2285b275c5338415 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 14:57:57 +0800 Subject: [PATCH 415/482] Reapply "mean noise" This reverts commit 31cd15e1383cff4428ec86f6f772ad52d38f057d. 
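A note on why this experiment is delicate: averaging k independent standard normal
draws, as the `noises` loop in PATCH 413 does with k = 5, yields noise whose standard
deviation is 1/sqrt(k), roughly 0.45 instead of the unit scale that `scale_noise`
expects. A quick check, illustrative only and not part of any patch:

    import torch

    k = 5
    noises = torch.stack([torch.randn(1, 16, 128, 128) for _ in range(k)])
    mean_noise = noises.mean(dim=0)

    print(mean_noise.std())               # ~0.447, i.e. 1 / 5 ** 0.5
    print((mean_noise * k ** 0.5).std())  # ~1.0; rescaling restores unit variance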
--- .../pipeline_flux_controlnet_inpainting.py | 8 ++++++- src/diffusers/utils/torch_utils.py | 24 ------------------- 2 files changed, 7 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index cc62f38791f1..74c13776c456 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -605,7 +605,13 @@ def prepare_latents( image_latents = torch.cat([image_latents], dim=0) if latents is None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noises = [] + for index in range(5): + tmp_noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noises.append(tmp_noise) + noises = torch.stack(noises,dim=0) + noise = torch.mean(noises, dim=0, keepdim=False) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) else: noise = latents.to(device) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 4e913e1db910..491cc16e39e6 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -81,30 +81,6 @@ def randn_tensor( else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) - print(f'noise shape = {latents.shape}') - - """ - mean_norm = torch.mean(latents.norm()) - while(mean_norm >=650): - generator = torch.Generator(device=rand_device).manual_seed(random.randint(0, 10000000)) - if isinstance(generator, list): - shape = (1,) + shape[1:] - latents = [ - torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout) - for i in range(batch_size) - ] - latents = torch.cat(latents, dim=0).to(device) - else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) - mean_norm = torch.mean(latents.norm()) - """ - for index in range(latents.shape[0]): - print(f'noise norm {index} = {latents[index].norm()}') - - print(f'latents.mean() = {latents.mean()}') - latents = latents - latents.mean() - for index in range(latents.shape[0]): - print(f'noise norm {index} = {latents[index].norm()}') return latents From 0bf26c49ca92f41fe84888abbb23844888e23acf Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 10 Jun 2025 14:59:13 +0800 Subject: [PATCH 416/482] remove mean noise --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 74c13776c456..cc62f38791f1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -605,13 +605,7 @@ def prepare_latents( image_latents = torch.cat([image_latents], dim=0) if latents is None: - noises = [] - for index in range(5): - tmp_noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - noises.append(tmp_noise) - noises = torch.stack(noises,dim=0) - noise = torch.mean(noises, dim=0, keepdim=False) - + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self.scheduler.scale_noise(image_latents, timestep, noise) else: noise = latents.to(device) From 530442b196248872a1005871a3b609448e38fa16 Mon Sep 17 00:00:00 2001 
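For context on the `scale_noise` calls that the next patch corrects: FlowMatchEuler-style
schedulers blend clean latents and noise linearly at each sigma, so re-noising the
reference latents with a different tensor than the one they were prepared with moves
them off their original trajectory. A sketch of that blend, assuming the flow matching
convention rather than quoting the library implementation:

    import torch

    def scale_noise(sample: torch.Tensor, sigma: float, noise: torch.Tensor) -> torch.Tensor:
        # Flow matching forward process: linear interpolation between data and noise,
        # so the result only denoises back to `sample` if the same `noise` is reused.
        return (1.0 - sigma) * sample + sigma * noise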
From: LinchuanXUTheSEAAI Date: Fri, 13 Jun 2025 14:39:15 +0800 Subject: [PATCH 417/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index cc62f38791f1..37aa947b9099 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -752,7 +752,7 @@ def __call__( image: PipelineImageInput = None, image_ref_prod: Optional[PipelineImageInput] = None, # original prod image ratio_ref_prod: Optional[float] = 0.125, # modified for injecting prod image - mask_image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, # original mask image for inpainting mask_image_original: Optional[PipelineImageInput] = None, # modified for injecting original prod images prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images masked_image_latents: PipelineImageInput = None, @@ -1078,7 +1078,7 @@ def __call__( ) if image_ref_prod is not None: - _, _, image_latents_ref, _ = self.prepare_latents( + _, noise_ref, image_latents_ref, _ = self.prepare_latents( init_image_ref, latent_timestep, batch_size * num_images_per_prompt, @@ -1385,13 +1385,13 @@ def mask_gradienting(mask, iterations=3): init_latents_proper_ref = image_latents_ref init_latents_proper_ref = self.scheduler.scale_noise( - init_latents_proper_ref, torch.tensor([noise_timestep]), noise + init_latents_proper_ref, torch.tensor([noise_timestep]), noise_ref ) if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper_ref = self.scheduler.scale_noise( - init_latents_proper_ref, torch.tensor([noise_timestep]), noise + init_latents_proper_ref, torch.tensor([noise_timestep]), noise_ref ) latents_1 = init_mask_ref_prod * (ratio_ref_prod * init_latents_proper_ref + (1.0 - ratio_ref_prod) * latents) From d5b49a9085ebc857bb0391fd884cc00b526befd8 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 13 Jun 2025 14:45:13 +0800 Subject: [PATCH 418/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 37aa947b9099..e97e8487d4c4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1384,10 +1384,6 @@ def mask_gradienting(mask, iterations=3): init_mask_ref_prod = mask_original init_latents_proper_ref = image_latents_ref - init_latents_proper_ref = self.scheduler.scale_noise( - init_latents_proper_ref, torch.tensor([noise_timestep]), noise_ref - ) - if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper_ref = self.scheduler.scale_noise( From 429e6cc1d095a919c2bd5c0b563b092d3172011c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sat, 14 Jun 2025 20:36:55 +0800 Subject: [PATCH 419/482] return mask_gradient --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index e97e8487d4c4..8c05571b78de 
100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -1435,6 +1435,6 @@ def mask_gradienting(mask, iterations=3):
         self.maybe_free_model_hooks()
 
         if not return_dict:
-            return (image,)
+            return (image,mask_gradient)
 
         return FluxPipelineOutput(images=image)

From 377bdfce5ca34ba3ae70e98b750d27a1d1555b4c Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Sat, 14 Jun 2025 21:09:38 +0800
Subject: [PATCH 420/482] refine mask gradient

---
 .../pipeline_flux_controlnet_inpainting.py    | 60 +++++++++++++++++--
 1 file changed, 55 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index 8c05571b78de..938ca112d1f2 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -1214,13 +1214,63 @@ def mask_gradienting(mask, iterations=3):
         gray3_array = gray3_array * dilated_mask3_image_bool
         mask4_array = np.where(mask3_array>0,mask3_array,gray3_array)
 
+        # gradient 5
+        dilated_mask4_image = apply_dilate_to_mask_image(mask4_array,iterations = iterations)
+        mask4_array_bool = np.array(mask4_array,dtype=bool)
+        dilated_mask4_image_bool = np.array(dilated_mask4_image, dtype=bool)
+        gray4_array = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * 40
+        gray4_array = gray4_array * (~mask4_array_bool)
+        gray4_array = gray4_array * dilated_mask4_image_bool
+        mask5_array = np.where(mask4_array>0,mask4_array,gray4_array)
+
+        # gradient 6
+        dilated_mask5_image = apply_dilate_to_mask_image(mask5_array,iterations = iterations)
+        mask5_array_bool = np.array(mask5_array,dtype=bool)
+        dilated_mask5_image_bool = np.array(dilated_mask5_image, dtype=bool)
+        gray5_array = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * 30
+        gray5_array = gray5_array * (~mask5_array_bool)
+        gray5_array = gray5_array * dilated_mask5_image_bool
+        mask6_array = np.where(mask5_array>0,mask5_array,gray5_array)
+
+        # gradient 7
+        dilated_mask6_image = apply_dilate_to_mask_image(mask6_array,iterations = iterations)
+        mask6_array_bool = np.array(mask6_array,dtype=bool)
+        dilated_mask6_image_bool = np.array(dilated_mask6_image, dtype=bool)
+        gray6_array = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * 20
+        gray6_array = gray6_array * (~mask6_array_bool)
+        gray6_array = gray6_array * dilated_mask6_image_bool
+        mask7_array = np.where(mask6_array>0,mask6_array,gray6_array)
+
+        # gradient 8
+        dilated_mask7_image = apply_dilate_to_mask_image(mask7_array,iterations = iterations)
+        mask7_array_bool = np.array(mask7_array,dtype=bool)
+        dilated_mask7_image_bool = np.array(dilated_mask7_image, dtype=bool)
+        gray7_array = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * 10
+        gray7_array = gray7_array * (~mask7_array_bool)
+        gray7_array = gray7_array * dilated_mask7_image_bool
+        mask8_array = np.where(mask7_array>0,mask7_array,gray7_array)
+
+        # gradient 9
+        dilated_mask8_image = apply_dilate_to_mask_image(mask8_array,iterations = iterations)
+        mask8_array_bool = np.array(mask8_array,dtype=bool)
+        dilated_mask8_image_bool = np.array(dilated_mask8_image, dtype=bool)
+        gray8_array = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * 5
+        gray8_array = gray8_array * (~mask8_array_bool)
+        gray8_array = gray8_array * dilated_mask8_image_bool
+        mask9_array = np.where(mask8_array>0,mask8_array,gray8_array)
+
         # 
rescale to 1.0, 0.8, 0.6, 0.4, 0.2, 0.0 final_mask_array = np.zeros_like(mask_array) - final_mask_array = np.where(mask4_array==255, 1.0, final_mask_array) - final_mask_array = np.where(mask4_array==200, 0.8, final_mask_array) - final_mask_array = np.where(mask4_array==150, 0.6, final_mask_array) - final_mask_array = np.where(mask4_array==100, 0.4, final_mask_array) - final_mask_array = np.where(mask4_array==50, 0.2, final_mask_array) + final_mask_array = np.where(mask9_array==255, 1.0, final_mask_array) + final_mask_array = np.where(mask9_array==200, 0.9, final_mask_array) + final_mask_array = np.where(mask9_array==150, 0.8, final_mask_array) + final_mask_array = np.where(mask9_array==100, 0.7, final_mask_array) + final_mask_array = np.where(mask9_array==50, 0.6, final_mask_array) + final_mask_array = np.where(mask9_array==40, 0.5, final_mask_array) + final_mask_array = np.where(mask9_array==30, 0.4, final_mask_array) + final_mask_array = np.where(mask9_array==20, 0.3, final_mask_array) + final_mask_array = np.where(mask9_array==10, 0.2, final_mask_array) + final_mask_array = np.where(mask9_array==5, 0.1, final_mask_array) return final_mask_array From c61d2a0e063f04aa00c11cbf20a23ea8d2454790 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Sun, 15 Jun 2025 11:37:20 +0800 Subject: [PATCH 421/482] adjust default values --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 938ca112d1f2..69bdb195c01f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -779,7 +779,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - iterations: Optional[int] = 4, # modified for applying gradient to mask_image + iterations: Optional[int] = 1, # modified for applying gradient to mask_image averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting @@ -1445,7 +1445,7 @@ def mask_gradienting(mask, iterations=3): latents = latents_1 + latents_2 - if i >= inpainting_starting_step and i < inpainting_ending_step: + if i > inpainting_starting_step and i < inpainting_ending_step: latents = (1 - init_mask) * init_latents_proper + init_mask * latents if latents.dtype != latents_dtype: From ddeb1148e6634c242e3864f3d9a70b55cfc87633 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 18 Jun 2025 11:40:18 +0800 Subject: [PATCH 422/482] add blend preprocessing logic --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 11b1ac76b8dc..96aace2044ed 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -401,6 +401,7 @@ def __call__( prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, is_qv: 
Optional[bool] = False, # thesea modified for quick validation of product shots + is_blend: Optional[bool] = False, # thesea modified for quick validation of product shots is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots is_inpainting: Optional[bool] = False, # controlnet inpainting @@ -659,6 +660,13 @@ def __call__( prompt_embeds = image_embeds_bg for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) + elif is_blend: + if (len(prompt_embeds_list) - 1) != len(image_embeds_prods): + raise ValueError( + f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" + ) + + prompt_embeds = image_embeds_bg else: prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) From 67c4c0f1aaca986b7e1ba491478ffb530b78904d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 18 Jun 2025 11:46:20 +0800 Subject: [PATCH 423/482] add support for blend preprocessing --- .../flux/pipeline_flux_prior_redux.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 96aace2044ed..cd4ab269b337 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -401,7 +401,7 @@ def __call__( prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, is_qv: Optional[bool] = False, # thesea modified for quick validation of product shots - is_blend: Optional[bool] = False, # thesea modified for quick validation of product shots + for_blend_preprocessing: Optional[bool] = False, # thesea modified for quick validation of product shots is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots is_inpainting: Optional[bool] = False, # controlnet inpainting @@ -652,21 +652,22 @@ def __call__( # scale & concatenate image and text embeddings if is_qv: - if len(prompt_embeds_list) != len(image_embeds_prods): - raise ValueError( - f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" - ) - - prompt_embeds = image_embeds_bg - for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): - prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) - elif is_blend: - if (len(prompt_embeds_list) - 1) != len(image_embeds_prods): - raise ValueError( - f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" - ) - - prompt_embeds = image_embeds_bg + if for_blend_preprocessing: + if (len(prompt_embeds_list) - 1) != len(image_embeds_prods): + raise ValueError( + f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" + ) + + prompt_embeds = image_embeds_bg + else: + if len(prompt_embeds_list) != len(image_embeds_prods): + raise ValueError( + 
f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" + ) + + prompt_embeds = image_embeds_bg + for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): + prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) else: prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) From 3ad51aae40c1f2e654c8ebcfdb61f115a51bdca3 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 18 Jun 2025 14:44:22 +0800 Subject: [PATCH 424/482] adjust latents when ref_prod_injection_steps=0 --- .../flux/pipeline_flux_controlnet.py | 87 +++---------------- 1 file changed, 12 insertions(+), 75 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 3944fb8aedfb..777ef88c653b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -636,6 +636,7 @@ def prepare_latents( device, generator, latents=None, + ref_prod_injection_steps = 0, ): # VAE applies 8x compression on images but we must also account for packing which requires # latent height and width to be divisible by 2. @@ -674,12 +675,17 @@ def prepare_latents( ) if image is not None: - if latents is None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) + if ref_prod_injection_steps == 0: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + image_latents = None + noise = None + else: + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) - noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) - image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) else: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) image_latents = None @@ -858,7 +864,6 @@ def __call__( max_sequence_length: int = 512, averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images - ): r""" Function invoked when calling the pipeline for generation. @@ -1083,23 +1088,6 @@ def __call__( init_image = init_image.to(dtype=torch.float32) else: init_image = None - - - - - - - - - - - - - - - - - # 3. 
Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 @@ -1238,24 +1226,9 @@ def __call__( device, generator, latents, + ref_prod_injection_steps, ) - - - - - - - - - - - - - - - - # Prepare mask latents if image is not None: mask_condition = self.mask_processor.preprocess( @@ -1279,28 +1252,6 @@ def __call__( generator, ) - - - - - - - - - - - - - - - - - - - - - - if prod_masks_original is not None: mask_conditions_prod = [] for prod_mask in prod_masks_original: @@ -1514,20 +1465,6 @@ def __call__( latents = latents_1 + latents_2 - - - - - - - - - - - - - - if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 From 7a6f48336351c141c2dd59e86b8f218d03d73c77 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 19 Jun 2025 11:31:50 +0800 Subject: [PATCH 425/482] support multiple prompt in ctln inpaint --- .../pipeline_flux_controlnet_inpainting.py | 48 +++++++++++++++---- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 69bdb195c01f..a319ad419e54 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -913,16 +913,44 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) - prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) + + prompt_embeds_list = [] + text_ids_list = [] + if isinstance(prompt, list): + for pmt in prompt: + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=pmt, + prompt_2=prompt_2, + #prompt_embeds=prompt_embeds, + #pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + prompt_embeds_list.append(prompt_embeds) + text_ids_list.append(text_ids) + prompt_embeds = torch.cat(prompt_embeds_list, dim=1) + text_ids = torch.cat(text_ids_list, dim=0) + else: + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) # 4. 
Preprocess mask and image if padding_mask_crop is not None: From 4e2a03a63844c1e672a9ad1815c36c668f2aede6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 19 Jun 2025 12:03:04 +0800 Subject: [PATCH 426/482] print info --- src/diffusers/models/controlnets/controlnet_flux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 2d318aefa28b..8a8958012b93 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -291,6 +291,7 @@ def forward( guidance = guidance.to(hidden_states.dtype) * 1000 else: guidance = None + print(f'pooled_projections shape = {pooled_projections.shape}') temb = ( self.time_text_embed(timestep, pooled_projections) if guidance is None From 070c87e117d21361398db7c26768ff08c3506509 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 19 Jun 2025 12:57:00 +0800 Subject: [PATCH 427/482] print info --- src/diffusers/models/controlnets/controlnet_flux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 8a8958012b93..ddbabd7e473a 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -292,6 +292,7 @@ def forward( else: guidance = None print(f'pooled_projections shape = {pooled_projections.shape}') + print(f'timestep = {timestep}') temb = ( self.time_text_embed(timestep, pooled_projections) if guidance is None From 91108970fbe85e95d810ac56e24c667a8eca8fe5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 19 Jun 2025 13:07:23 +0800 Subject: [PATCH 428/482] print info --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index a319ad419e54..1a3c93f0e83b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1329,7 +1329,8 @@ def mask_gradienting(mask, iterations=3): for i, t in enumerate(timesteps): if self.interrupt: continue - + + print(f'latents.shape[0] = {latents.shape[0]}') timestep = t.expand(latents.shape[0]).to(latents.dtype) # predict the noise residual From 945179a55fa2ceb699b3e6f9d637f98336238549 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 19 Jun 2025 13:11:54 +0800 Subject: [PATCH 429/482] fix bug --- src/diffusers/models/controlnets/controlnet_flux.py | 3 +-- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index ddbabd7e473a..1a0151a19662 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -291,8 +291,7 @@ def forward( guidance = guidance.to(hidden_states.dtype) * 1000 else: guidance = None - print(f'pooled_projections shape = {pooled_projections.shape}') - print(f'timestep = {timestep}') + temb = ( self.time_text_embed(timestep, pooled_projections) if guidance is None diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py 
b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 1a3c93f0e83b..940f588385f0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -902,7 +902,7 @@ def __call__( if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) + batch_size = 1 #len(prompt) # thesea modified for text prompt mask else: batch_size = prompt_embeds.shape[0] @@ -1330,7 +1330,6 @@ def mask_gradienting(mask, iterations=3): if self.interrupt: continue - print(f'latents.shape[0] = {latents.shape[0]}') timestep = t.expand(latents.shape[0]).to(latents.dtype) # predict the noise residual From dfafb1a8830ea4ddcd1d463c30fefab905c70207 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 27 Jun 2025 20:22:42 +0800 Subject: [PATCH 430/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 940f588385f0..54044e8cd24d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1116,7 +1116,7 @@ def __call__( prompt_embeds.dtype, device, generator, - latents, + latents=None, ) # 8. Prepare mask latents From c427fbbfe7bc9b0d2796570db9d564534022d79d Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 27 Jun 2025 20:40:27 +0800 Subject: [PATCH 431/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 54044e8cd24d..b40c726bac0a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1463,7 +1463,7 @@ def mask_gradienting(mask, iterations=3): init_latents_proper_ref = image_latents_ref if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] + noise_timestep = timesteps[i + 2] init_latents_proper_ref = self.scheduler.scale_noise( init_latents_proper_ref, torch.tensor([noise_timestep]), noise_ref ) From 6b82e0bfbf33b0177afcf80bb7433f77deb61cf7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 27 Jun 2025 20:43:42 +0800 Subject: [PATCH 432/482] use bg noise for ref img --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index b40c726bac0a..fe7d14c45ef8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1463,9 +1463,9 @@ def mask_gradienting(mask, iterations=3): init_latents_proper_ref = image_latents_ref if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 2] + noise_timestep = timesteps[i + 1] init_latents_proper_ref = self.scheduler.scale_noise( - init_latents_proper_ref, torch.tensor([noise_timestep]), noise_ref + init_latents_proper_ref, torch.tensor([noise_timestep]), noise ) latents_1 = init_mask_ref_prod * 
(ratio_ref_prod * init_latents_proper_ref + (1.0 - ratio_ref_prod) * latents) From d234c53711cac80ddd4473dc72ef0f93a2d64581 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 27 Jun 2025 21:05:35 +0800 Subject: [PATCH 433/482] Revert "use bg noise for ref img" This reverts commit 6b82e0bfbf33b0177afcf80bb7433f77deb61cf7. --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index fe7d14c45ef8..b40c726bac0a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1463,9 +1463,9 @@ def mask_gradienting(mask, iterations=3): init_latents_proper_ref = image_latents_ref if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] + noise_timestep = timesteps[i + 2] init_latents_proper_ref = self.scheduler.scale_noise( - init_latents_proper_ref, torch.tensor([noise_timestep]), noise + init_latents_proper_ref, torch.tensor([noise_timestep]), noise_ref ) latents_1 = init_mask_ref_prod * (ratio_ref_prod * init_latents_proper_ref + (1.0 - ratio_ref_prod) * latents) From e6a462aee252364e220ba41e38aee51cd3c105f9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 27 Jun 2025 21:05:51 +0800 Subject: [PATCH 434/482] Revert "fix bug" This reverts commit c427fbbfe7bc9b0d2796570db9d564534022d79d. --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index b40c726bac0a..54044e8cd24d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1463,7 +1463,7 @@ def mask_gradienting(mask, iterations=3): init_latents_proper_ref = image_latents_ref if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 2] + noise_timestep = timesteps[i + 1] init_latents_proper_ref = self.scheduler.scale_noise( init_latents_proper_ref, torch.tensor([noise_timestep]), noise_ref ) From d1e3a2186e08441c5a25b46b1382dee7111372c0 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 27 Jun 2025 21:05:59 +0800 Subject: [PATCH 435/482] Revert "fix bug" This reverts commit dfafb1a8830ea4ddcd1d463c30fefab905c70207. --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 54044e8cd24d..940f588385f0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1116,7 +1116,7 @@ def __call__( prompt_embeds.dtype, device, generator, - latents=None, + latents, ) # 8. 
Prepare mask latents From 191ec055ea59746c1fa90338fcdd40579cba96f4 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 27 Jun 2025 21:13:05 +0800 Subject: [PATCH 436/482] add fixed script --- ...peline_flux_controlnet_inpainting_fixed.py | 1520 +++++++++++++++++ 1 file changed, 1520 insertions(+) create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting_fixed.py diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting_fixed.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting_fixed.py new file mode 100644 index 000000000000..a47abfda68ff --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting_fixed.py @@ -0,0 +1,1520 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import PIL +import torch +import copy +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxControlNetInpaintPipeline + >>> from diffusers.models import FluxControlNetModel + >>> from diffusers.utils import load_image + + >>> controlnet = FluxControlNetModel.from_pretrained( + ... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = FluxControlNetInpaintPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> control_image = load_image( + ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" + ... ) + >>> init_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + ... ) + >>> mask_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + ... ) + + >>> prompt = "A girl holding a sign that says InstantX" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, + ... controlnet_conditioning_scale=0.7, + ... strength=0.7, + ... num_inference_steps=28, + ... guidance_scale=3.5, + ... 
).images[0] + >>> image.save("flux_controlnet_inpaint.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux controlnet pipeline for inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image", "mask", "masked_image_latents"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: Union[ + FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel + ], + ): + super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = FluxMultiControlNetModel(controlnet) + + self.register_modules( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + controlnet=controlnet, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + 
logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
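+
+        Example (a minimal illustrative sketch, not part of the upstream examples; assumes a
+        `FluxControlNetInpaintPipeline` instance `pipe` built as in the class-level example):
+
+        ```py
+        >>> prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
+        ...     prompt="A girl holding a sign that says InstantX",
+        ...     prompt_2=None,
+        ...     num_images_per_prompt=1,
+        ...     max_sequence_length=512,
+        ... )
+        >>> # prompt_embeds: (1, 512, 4096) for the default T5 encoder;
+        >>> # pooled_prompt_embeds: (1, 768) for the default CLIP encoder
+        ```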
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + 
logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." 
+ ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + inpainting_starting_step = 0, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
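+        # Worked example (illustrative): with the default vae_scale_factor of 8 and a
+        # 1024x1024 input, height = 2 * (1024 // 16) = 128 and width = 128, so the
+        # packing below yields (128 // 2) * (128 // 2) = 4096 tokens, each with
+        # num_channels_latents * 4 = 64 features.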
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + if inpainting_starting_step < 0: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." 
+ ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + image_ref_prod: Optional[PipelineImageInput] = None, # original prod image + ratio_ref_prod: Optional[float] = 0.125, # modified for injecting prod image + mask_image: PipelineImageInput = None, # original mask image for inpainting + mask_image_original: Optional[PipelineImageInput] = None, # modified for injecting original prod images + prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images + masked_image_latents: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.6, + padding_mask_crop: Optional[int] = None, + sigmas: Optional[List[float]] = None, + num_inference_steps: int = 28, + guidance_scale: float = 7.0, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, 
List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + iterations: Optional[int] = 1, # modified for applying gradient to mask_image + averaging_steps: Optional[int] = 2, # modified for averaging latents over multiple copies of the same product + ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images + inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting + inpainting_ending_step: Optional[int] = 0, # modified for ending inpainting + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image(s) to inpaint. + mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels + will be preserved. + masked_image_latents (`torch.FloatTensor`, *optional*): + Pre-generated masked image latents. + control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The ControlNet input condition. Image to control the generation. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.6): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. + padding_mask_crop (`int`, *optional*): + The size of the padding to use when cropping the mask. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_mode (`int` or `List[int]`, *optional*): + The mode for the ControlNet. If multiple ControlNets are used, this should be a list. 
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original transformer. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to + make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + Additional keyword arguments to be passed to the joint attention mechanism. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising step during the inference. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum length of the sequence to be generated. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + global_height = height + global_width = width + + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = 1 #len(prompt) # thesea modified for text prompt mask + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + # 3. Encode input prompt + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + prompt_embeds_list = [] + text_ids_list = [] + if isinstance(prompt, list): + for pmt in prompt: + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=pmt, + prompt_2=prompt_2, + #prompt_embeds=prompt_embeds, + #pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + prompt_embeds_list.append(prompt_embeds) + text_ids_list.append(text_ids) + prompt_embeds = torch.cat(prompt_embeds_list, dim=1) + text_ids = torch.cat(text_ids_list, dim=0) + else: + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region( + mask_image, global_width, global_height, pad=padding_mask_crop + ) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + if image_ref_prod is not None: + init_image_ref = self.image_processor.preprocess( + image_ref_prod, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image_ref = init_image_ref.to(dtype=torch.float32) + + # 5. 
Prepare control image + num_channels_latents = self.transformer.config.in_channels // 4 + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True + if self.controlnet.input_hint_block is None: + # vae encode + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + # set control mode + if control_mode is not None: + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True + for i, control_image_ in enumerate(control_image): + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image_.shape[-2:] + + if self.controlnet.nets[0].input_hint_block is None: + # vae encode + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) + + control_image = control_images + + # set control mode + control_mode_ = [] + if isinstance(control_mode, list): + for cmode in control_mode: + if cmode is None: + control_mode_.append(-1) + else: + control_mode_.append(cmode) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + # 6. 
Prepare timesteps
+
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (
+            int(global_width) // self.vae_scale_factor // 2
+        )
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        # 7. Prepare latent variables
+        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+            init_image,
+            latent_timestep,
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            global_height,
+            global_width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+            inpainting_starting_step,
+        )
+
+        if image_ref_prod is not None:
+            _, noise_ref, image_latents_ref, _ = self.prepare_latents(
+                init_image_ref,
+                latent_timestep,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                global_height,
+                global_width,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                latents,
+            )
+
+        # 8. Prepare mask latents
+        mask_condition = self.mask_processor.preprocess(
+            mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords
+        )
+        if masked_image_latents is None:
+            masked_image = init_image * (mask_condition < 0.5)
+        else:
+            masked_image = masked_image_latents
+
+        # handle prod masks
+        if prod_masks_original is not None:
+            mask_conditions_prod = []
+            for prod_mask in prod_masks_original:
+                tmp_mask_condition = self.mask_processor.preprocess(
+                    prod_mask, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords
+                )
+                mask_conditions_prod.append(tmp_mask_condition)
+
+        if image_ref_prod is not None:
+            mask_condition_original = self.mask_processor.preprocess(
+                mask_image_original, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords
+            )
+            if masked_image_latents is None:
+                masked_image_original = init_image * (mask_condition_original < 0.5)
+            else:
+                masked_image_original = masked_image_latents
+
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask_condition,
+            masked_image,
+            batch_size,
+            num_channels_latents,
+            num_images_per_prompt,
+            global_height,
+            global_width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+        )
+
+        if prod_masks_original is not None:
+            masks_prod = []
+            for tmp_mask_condition in mask_conditions_prod:
+                tmp_mask, _ = self.prepare_mask_latents(
+                    tmp_mask_condition,
+                    masked_image,
+                    batch_size,
+                    num_channels_latents,
+                    num_images_per_prompt,
+                    global_height,
+                    global_width,
+                    prompt_embeds.dtype,
+                    device,
+                    generator,
+                )
+                masks_prod.append(tmp_mask)
+
+        if image_ref_prod is not None:
+            mask_original, _ = self.prepare_mask_latents(
+                mask_condition_original,
+                masked_image_original,
+                batch_size,
+                num_channels_latents,
+                num_images_per_prompt,
+                global_height,
+                global_width,
+                prompt_embeds.dtype,
+                device,
+                generator,
+            )
+
+        def apply_dilate_to_mask_image(mask, iterations):
+            # binary dilation with a 3x3 kernel
+            kernel = np.ones((3, 3), np.uint8)
+            mask = np.array(mask, dtype=bool)
+            mask = mask.astype(np.uint8)
+            mask = cv2.dilate(mask, kernel, iterations=iterations)
+            mask = mask.astype(np.uint8)
+            return mask
+
+        # input np array shape: (4096, C); output np array shape: (64, 64, C)
+        def mask_gradienting(mask, iterations=3):
+            # Build a stepped falloff around the mask: each pass dilates the current
+            # mask and fills the newly covered ring with the next, lower gray level.
+            mask_array = mask.reshape(64, 64, -1)
+            gray_levels = [200, 150, 100, 50, 40, 30, 20, 10, 5]
+
+            current = np.where(mask_array > 0, mask_array * 255, 0)
+            for gray in gray_levels:
+                dilated = apply_dilate_to_mask_image(current, iterations=iterations)
+                ring = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * gray
+                ring = ring * (~np.array(current, dtype=bool))
+                ring = ring * np.array(dilated, dtype=bool)
+                current = np.where(current > 0, current, ring)
+
+            # rescale the gray levels to blend weights 1.0, 0.9, ..., 0.1
+            level_to_weight = {255: 1.0, 200: 0.9, 150: 0.8, 100: 0.7, 50: 0.6,
+                               40: 0.5, 30: 0.4, 20: 0.3, 10: 0.2, 5: 0.1}
+            final_mask_array = np.zeros_like(mask_array)
+            for level, weight in level_to_weight.items():
+                final_mask_array = np.where(current == level, weight, final_mask_array)
+
+            return final_mask_array
+
+        # mask gradient
+        mask = mask.to(dtype=torch.float16)
+        tmp_mask = mask[0, :, :].cpu().numpy()
+        mask_gradient = mask_gradienting(tmp_mask, iterations=iterations)
+        tmp_tensor = torch.Tensor(mask_gradient)
+        tmp_tensor = tmp_tensor.view(-1, 64)
+        mask_tensor = torch.zeros_like(mask)
+        for index in range(mask_tensor.shape[0]):
+            mask_tensor[index, :, :] = tmp_tensor
+        mask = mask_tensor.to(device=latents.device, dtype=latents.dtype)
+
+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
+
+        # 9.
# 9. Denoising loop +
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) +
self._num_timesteps = len(timesteps) +
+
with self.progress_bar(total=num_inference_steps) as progress_bar: +
for i, t in enumerate(timesteps): +
if self.interrupt: +
continue +
+
timestep = t.expand(latents.shape[0]).to(latents.dtype) +
+
# predict the noise residual +
if isinstance(self.controlnet, FluxMultiControlNetModel): +
use_guidance = self.controlnet.nets[0].config.guidance_embeds +
else: +
use_guidance = self.controlnet.config.guidance_embeds +
if use_guidance: +
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) +
guidance = guidance.expand(latents.shape[0]) +
else: +
guidance = None +
+
if isinstance(controlnet_keep[i], list): +
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] +
else: +
controlnet_cond_scale = controlnet_conditioning_scale +
if isinstance(controlnet_cond_scale, list): +
controlnet_cond_scale = controlnet_cond_scale[0] +
cond_scale = controlnet_cond_scale * controlnet_keep[i] +
+
controlnet_block_samples, controlnet_single_block_samples = self.controlnet( +
hidden_states=latents, +
controlnet_cond=control_image, +
controlnet_mode=control_mode, +
conditioning_scale=cond_scale, +
timestep=timestep / 1000, +
guidance=guidance, +
pooled_projections=pooled_prompt_embeds, +
encoder_hidden_states=prompt_embeds, +
txt_ids=text_ids, +
img_ids=latent_image_ids, +
joint_attention_kwargs=self.joint_attention_kwargs, +
return_dict=False, +
) +
+
if self.transformer.config.guidance_embeds: +
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) +
guidance = guidance.expand(latents.shape[0]) +
else: +
guidance = None +
+
noise_pred = self.transformer( +
hidden_states=latents, +
timestep=timestep / 1000, +
guidance=guidance, +
pooled_projections=pooled_prompt_embeds, +
encoder_hidden_states=prompt_embeds, +
controlnet_block_samples=controlnet_block_samples, +
controlnet_single_block_samples=controlnet_single_block_samples, +
txt_ids=text_ids, +
img_ids=latent_image_ids, +
joint_attention_kwargs=self.joint_attention_kwargs, +
return_dict=False, +
controlnet_blocks_repeat=controlnet_blocks_repeat, +
)[0] +
+
# compute the previous noisy sample x_t -> x_t-1 +
latents_dtype = latents.dtype +
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] +
+
# For inpainting, we need to apply the mask and add the masked image latents +
init_latents_proper = image_latents +
init_mask = mask +
+
if i < len(timesteps) - 1: +
noise_timestep = timesteps[i + 1] +
init_latents_proper = self.scheduler.scale_noise( +
init_latents_proper, torch.tensor([noise_timestep]), noise +
) +
+
# average multiple product latents +
if prod_masks_original is not None: +
init_masks_prod = masks_prod +
if i < averaging_steps: +
batch_size = latents.shape[0] +
latents = latents.view(-1) +
masked_regions = [] +
for tmp_init_mask in init_masks_prod: +
tmp_init_mask = tmp_init_mask.view(-1) +
masked_regions.append(latents[tmp_init_mask.bool()].reshape(-1)) +
+
# note: `t` in this generator expression is scoped locally and does not shadow the timestep `t` +
min_length = min(t.size(0) for t in masked_regions) +
#truncated = [t[:min_length] for t in masked_regions] +
# crop every masked region to the shortest length so they can be stacked and averaged +
truncated = [] +
for tmp_mask_region in masked_regions: +
if tmp_mask_region.size(0) == min_length: +
truncated.append(tmp_mask_region) +
else: +
extra_dim = tmp_mask_region.size(0) - min_length +
half = extra_dim // 2 +
if extra_dim % 2 == 0: +
truncated.append(tmp_mask_region[half:(min_length-half+extra_dim)]) +
else: +
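# odd surplus: keep a centered window of min_length elements, dropping +
# half from the front and half + 1 from the back +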
truncated.append(tmp_mask_region[half:(min_length-half+extra_dim-1)]) + + masked_regions = torch.stack(truncated) + + avg_region = masked_regions.mean(dim=0) + for tmp_init_mask in init_masks_prod: + tmp_init_mask = tmp_init_mask.view(-1) + length = torch.sum(tmp_init_mask == 1) + if length == avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region + elif length < avg_region.shape[0]: + latents[tmp_init_mask.bool()] = avg_region[0:length] + else: + add_dim = length - avg_region.shape[0] + add_dim1 = add_dim // 2 + add_dim2 = add_dim - add_dim1 + one_tensor1 = torch.ones(add_dim1).to(device=latents.device, dtype=latents.dtype) + one_tensor2 = torch.ones(add_dim2).to(device=latents.device, dtype=latents.dtype) + tmp_tensor1 = one_tensor1 * torch.mean(avg_region) + tmp_tensor2 = one_tensor2 * torch.mean(avg_region) + latents[tmp_init_mask.bool()] = torch.cat((tmp_tensor1, avg_region, tmp_tensor2), dim = 0) + + latents = latents.view(batch_size, 4096, -1) + + # inject product raw images into latents + #if image_ref_prod is not None: + # if i < ref_prod_injection_steps: + # latents_1 = (1 - init_mask) * init_latents_proper + # latents_2 = (init_mask - init_mask_ref_prod) * latents + # latents_3 = (1.0 - ratio_ref_prod) * init_mask_ref_prod * latents + ratio_ref_prod * init_mask_ref_prod * init_latents_proper_ref + + # latents = latents_1 + latents_2 + latents_3 + + if image_ref_prod is not None: + if i < ref_prod_injection_steps: + init_mask_ref_prod = mask_original + init_latents_proper_ref = image_latents_ref + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper_ref = self.scheduler.scale_noise( + init_latents_proper_ref, torch.tensor([noise_timestep]), noise_ref + ) + + latents_1 = init_mask_ref_prod * (ratio_ref_prod * init_latents_proper_ref + (1.0 - ratio_ref_prod) * latents) + latents_2 = (1.0 - init_mask_ref_prod) * latents + + latents = latents_1 + latents_2 + + if i < ref_prod_injection_steps: + latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents + else: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # Post-processing + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,mask_gradient) + + return FluxPipelineOutput(images=image) From 7b5d8ad773f971194e47da759a55df089cbe3a3e Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 27 Jun 2025 21:19:13 +0800 Subject: [PATCH 437/482] ctln inpaint difference logic --- .../pipeline_flux_controlnet_inpainting.py | 10 +- ...peline_flux_controlnet_inpainting_fixed.py | 1520 ----------------- 2 files changed, 7 insertions(+), 1523 deletions(-) delete mode 100644 src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting_fixed.py diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 940f588385f0..475decd68003 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1456,7 +1456,8 @@ def mask_gradienting(mask, iterations=3): # latents_3 = (1.0 - ratio_ref_prod) * init_mask_ref_prod * latents + ratio_ref_prod * init_mask_ref_prod * init_latents_proper_ref # latents = latents_1 + latents_2 + latents_3 - + + """ if image_ref_prod is not None: if i < ref_prod_injection_steps: init_mask_ref_prod = mask_original @@ -1472,8 +1473,11 @@ def mask_gradienting(mask, iterations=3): latents_2 = (1.0 - init_mask_ref_prod) * latents latents = latents_1 + latents_2 - - if i > inpainting_starting_step and i < inpainting_ending_step: + """ + if i < ref_prod_injection_steps: + init_mask_ref_prod = mask_original + latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents + else: latents = (1 - init_mask) * init_latents_proper + init_mask * latents if latents.dtype != latents_dtype: diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting_fixed.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting_fixed.py deleted file mode 100644 index a47abfda68ff..000000000000 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting_fixed.py +++ /dev/null @@ -1,1520 +0,0 @@ -import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import cv2 -import 
numpy as np -import PIL -import torch -import copy -from transformers import ( - CLIPTextModel, - CLIPTokenizer, - T5EncoderModel, - T5TokenizerFast, -) - -from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel -from ...models.transformers import FluxTransformer2DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - USE_PEFT_BACKEND, - is_torch_xla_available, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import FluxPipelineOutput - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - -logger = logging.get_logger(__name__) - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import FluxControlNetInpaintPipeline - >>> from diffusers.models import FluxControlNetModel - >>> from diffusers.utils import load_image - - >>> controlnet = FluxControlNetModel.from_pretrained( - ... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16 - ... ) - >>> pipe = FluxControlNetInpaintPipeline.from_pretrained( - ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 - ... ) - >>> pipe.to("cuda") - - >>> control_image = load_image( - ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" - ... ) - >>> init_image = load_image( - ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" - ... ) - >>> mask_image = load_image( - ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" - ... ) - - >>> prompt = "A girl holding a sign that says InstantX" - >>> image = pipe( - ... prompt, - ... image=init_image, - ... mask_image=mask_image, - ... control_image=control_image, - ... control_guidance_start=0.2, - ... control_guidance_end=0.8, - ... controlnet_conditioning_scale=0.7, - ... strength=0.7, - ... num_inference_steps=28, - ... guidance_scale=3.5, - ... 
).images[0] - >>> image.save("flux_controlnet_inpaint.png") - ``` -""" - - -# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.15, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. 
Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): - r""" - The Flux controlnet pipeline for inpainting. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Args: - transformer ([`FluxTransformer2DModel`]): - Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - text_encoder_2 ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_2 (`T5TokenizerFast`): - Second Tokenizer of class - [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). - """ - - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = [] - _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image", "mask", "masked_image_latents"] - - def __init__( - self, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - text_encoder_2: T5EncoderModel, - tokenizer_2: T5TokenizerFast, - transformer: FluxTransformer2DModel, - controlnet: Union[ - FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel - ], - ): - super().__init__() - if isinstance(controlnet, (list, tuple)): - controlnet = FluxMultiControlNetModel(controlnet) - - self.register_modules( - scheduler=scheduler, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - transformer=transformer, - controlnet=controlnet, - ) - - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 - # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible - # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) - latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 - self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor * 2, - vae_latent_channels=latent_channels, - do_normalize=False, - do_binarize=True, - do_convert_grayscale=True, - ) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 - ) - self.default_sample_size = 128 - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - 
logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
- """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) - - image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - return image_latents - - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(num_inference_steps * strength, num_inference_steps) - - t_start = int(max(num_inference_steps - init_timestep, 0)) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) - - return timesteps, num_inference_steps - t_start - - def check_inputs( - self, - prompt, - prompt_2, - image, - mask_image, - strength, - height, - width, - output_type, - prompt_embeds=None, - pooled_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - padding_mask_crop=None, - max_sequence_length=None, - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: - 
logger.warning( - f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" - ) - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if padding_mask_crop is not None: - if not isinstance(image, PIL.Image.Image): - raise ValueError( - f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." - ) - if not isinstance(mask_image, PIL.Image.Image): - raise ValueError( - f"The mask image should be a PIL image when inpainting mask crop, but is of type" - f" {type(mask_image)}." 
- ) - if output_type != "pil": - raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - - def prepare_latents( - self, - image, - timestep, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - inpainting_starting_step = 0, - ): - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. 
- height = 2 * (int(height) // (self.vae_scale_factor * 2)) - width = 2 * (int(width) // (self.vae_scale_factor * 2)) - shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_vae_image(image=image, generator=generator) - - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." - ) - else: - image_latents = torch.cat([image_latents], dim=0) - - if latents is None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) - else: - noise = latents.to(device) - latents = noise - - if inpainting_starting_step < 0: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - noise = latents.to(device) - latents = noise - - noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) - image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - return latents, noise, image_latents, latent_image_ids - - def prepare_mask_latents( - self, - mask, - masked_image, - batch_size, - num_channels_latents, - num_images_per_prompt, - height, - width, - dtype, - device, - generator, - ): - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (self.vae_scale_factor * 2)) - width = 2 * (int(width) // (self.vae_scale_factor * 2)) - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate(mask, size=(height, width)) - mask = mask.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - masked_image = masked_image.to(device=device, dtype=dtype) - - if masked_image.shape[1] == 16: - masked_image_latents = masked_image - else: - masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) - - masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." 
- ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - masked_image_latents = self._pack_latents( - masked_image_latents, - batch_size, - num_channels_latents, - height, - width, - ) - mask = self._pack_latents( - mask.repeat(1, num_channels_latents, 1, 1), - batch_size, - num_channels_latents, - height, - width, - ) - - return mask, masked_image_latents - - # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - if isinstance(image, torch.Tensor): - pass - else: - image = self.image_processor.preprocess(image, height=height, width=width) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - image: PipelineImageInput = None, - image_ref_prod: Optional[PipelineImageInput] = None, # original prod image - ratio_ref_prod: Optional[float] = 0.125, # modified for injecting prod image - mask_image: PipelineImageInput = None, # original mask image for inpainting - mask_image_original: Optional[PipelineImageInput] = None, # modified for injecting original prod images - prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images - masked_image_latents: PipelineImageInput = None, - control_image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - strength: float = 0.6, - padding_mask_crop: Optional[int] = None, - sigmas: Optional[List[float]] = None, - num_inference_steps: int = 28, - guidance_scale: float = 7.0, - control_guidance_start: Union[float, List[float]] = 0.0, - control_guidance_end: Union[float, List[float]] = 1.0, - control_mode: Optional[Union[int, List[int]]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, 
List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, - iterations: Optional[int] = 1, # modified for applying gradient to mask_image - averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product - ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images - inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting - inpainting_ending_step: Optional[int] = 0, # modified for starting inpainting - ): - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. - image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): - The image(s) to inpaint. - mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): - The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels - will be preserved. - masked_image_latents (`torch.FloatTensor`, *optional*): - Pre-generated masked image latents. - control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): - The ControlNet input condition. Image to control the generation. - height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - strength (`float`, *optional*, defaults to 0.6): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. - padding_mask_crop (`int`, *optional*): - The size of the padding to use when cropping the mask. - num_inference_steps (`int`, *optional*, defaults to 28): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - sigmas (`List[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): - The percentage of total steps at which the ControlNet starts applying. - control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): - The percentage of total steps at which the ControlNet stops applying. - control_mode (`int` or `List[int]`, *optional*): - The mode for the ControlNet. If multiple ControlNets are used, this should be a list. 
- controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original transformer. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to - make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): - Additional keyword arguments to be passed to the joint attention mechanism. - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising step during the inference. - callback_on_step_end_tensor_inputs (`List[str]`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. - max_sequence_length (`int`, *optional*, defaults to 512): - The maximum length of the sequence to be generated. - - Examples: - - Returns: - [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. - """ - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - global_height = height - global_width = width - - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = ( - mult * [control_guidance_start], - mult * [control_guidance_end], - ) - - # 1. Check inputs - self.check_inputs( - prompt, - prompt_2, - image, - mask_image, - strength, - height, - width, - output_type=output_type, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - padding_mask_crop=padding_mask_crop, - max_sequence_length=max_sequence_length, - ) - - self._guidance_scale = guidance_scale - self._joint_attention_kwargs = joint_attention_kwargs - self._interrupt = False - - # 2. 
Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = 1 #len(prompt) # thesea modified for text prompt mask - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - dtype = self.transformer.dtype - - # 3. Encode input prompt - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - - prompt_embeds_list = [] - text_ids_list = [] - if isinstance(prompt, list): - for pmt in prompt: - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( - prompt=pmt, - prompt_2=prompt_2, - #prompt_embeds=prompt_embeds, - #pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - prompt_embeds_list.append(prompt_embeds) - text_ids_list.append(text_ids) - prompt_embeds = torch.cat(prompt_embeds_list, dim=1) - text_ids = torch.cat(text_ids_list, dim=0) - else: - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - - # 4. Preprocess mask and image - if padding_mask_crop is not None: - crops_coords = self.mask_processor.get_crop_region( - mask_image, global_width, global_height, pad=padding_mask_crop - ) - resize_mode = "fill" - else: - crops_coords = None - resize_mode = "default" - - original_image = image - init_image = self.image_processor.preprocess( - image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode - ) - init_image = init_image.to(dtype=torch.float32) - - if image_ref_prod is not None: - init_image_ref = self.image_processor.preprocess( - image_ref_prod, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode - ) - init_image_ref = init_image_ref.to(dtype=torch.float32) - - # 5. 
Prepare control image - num_channels_latents = self.transformer.config.in_channels // 4 - if isinstance(self.controlnet, FluxControlNetModel): - control_image = self.prepare_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.vae.dtype, - ) - height, width = control_image.shape[-2:] - - # xlab controlnet has a input_hint_block and instantx controlnet does not - controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True - if self.controlnet.input_hint_block is None: - # vae encode - control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) - control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # pack - height_control_image, width_control_image = control_image.shape[2:] - control_image = self._pack_latents( - control_image, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) - - # set control mode - if control_mode is not None: - control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) - - elif isinstance(self.controlnet, FluxMultiControlNetModel): - control_images = [] - - # xlab controlnet has a input_hint_block and instantx controlnet does not - controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True - for i, control_image_ in enumerate(control_image): - control_image_ = self.prepare_image( - image=control_image_, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.vae.dtype, - ) - height, width = control_image_.shape[-2:] - - if self.controlnet.nets[0].input_hint_block is None: - # vae encode - control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) - control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # pack - height_control_image, width_control_image = control_image_.shape[2:] - control_image_ = self._pack_latents( - control_image_, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) - - control_images.append(control_image_) - - control_image = control_images - - # set control mode - control_mode_ = [] - if isinstance(control_mode, list): - for cmode in control_mode: - if cmode is None: - control_mode_.append(-1) - else: - control_mode_.append(cmode) - control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) - - # 6. 
Prepare timesteps - - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * ( - int(global_width) // self.vae_scale_factor // 2 - ) - mu = calculate_shift( - image_seq_len, - self.scheduler.config.get("base_image_seq_len", 256), - self.scheduler.config.get("max_image_seq_len", 4096), - self.scheduler.config.get("base_shift", 0.5), - self.scheduler.config.get("max_shift", 1.15), - ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - sigmas=sigmas, - mu=mu, - ) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - - if num_inference_steps < 1: - raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." - ) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - if self.joint_attention_kwargs is None: - self._joint_attention_kwargs = {} - - # 7. Prepare latent variables - - latents, noise, image_latents, latent_image_ids = self.prepare_latents( - init_image, - latent_timestep, - batch_size * num_images_per_prompt, - num_channels_latents, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - latents, - inpainting_starting_step, - ) - - if image_ref_prod is not None: - _, noise_ref, image_latents_ref, _ = self.prepare_latents( - init_image_ref, - latent_timestep, - batch_size * num_images_per_prompt, - num_channels_latents, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 8. Prepare mask latents - mask_condition = self.mask_processor.preprocess( - mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords - ) - if masked_image_latents is None: - masked_image = init_image * (mask_condition < 0.5) - else: - masked_image = masked_image_latents - - # handel prod masks - if prod_masks_original is not None: - mask_conditions_prod = [] - for prod_mask in prod_masks_original: - tmp_mask_condition = self.mask_processor.preprocess( - prod_mask, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords - ) - mask_conditions_prod.append(tmp_mask_condition) - - - if image_ref_prod is not None: - mask_condition_original = self.mask_processor.preprocess( - mask_image_original, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords - ) - if masked_image_latents is None: - masked_image_original = init_image * (mask_condition_original < 0.5) - else: - masked_image_original = masked_image_latents - - mask, masked_image_latents = self.prepare_mask_latents( - mask_condition, - masked_image, - batch_size, - num_channels_latents, - num_images_per_prompt, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - ) - - if prod_masks_original is not None: - masks_prod = [] - for tmp_mask_condition in mask_conditions_prod: - tmp_mask, _ = self.prepare_mask_latents( - tmp_mask_condition, - masked_image, - batch_size, - num_channels_latents, - num_images_per_prompt, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - ) - masks_prod.append(tmp_mask) - - if image_ref_prod is not None: - mask_original, _ = self.prepare_mask_latents( - mask_condition_original, - masked_image_original, - 
        if image_ref_prod is not None:
            mask_original, _ = self.prepare_mask_latents(
                mask_condition_original,
                masked_image_original,
                batch_size,
                num_channels_latents,
                num_images_per_prompt,
                global_height,
                global_width,
                prompt_embeds.dtype,
                device,
                generator,
            )

        def apply_dilate_to_mask_image(mask, iterations):
            kernel = np.ones((3, 3), np.uint8)
            mask = np.array(mask, dtype=bool)
            mask = mask.astype(np.uint8)
            mask = cv2.dilate(mask, kernel, iterations=iterations)
            mask = mask.astype(np.uint8)
            return mask

        # input np array size: 4096 x C
        # output np array size: 64 x 64 x C
        def mask_gradienting(mask, iterations=3):
            """Feather the packed latent mask into concentric rings of decaying blend weight."""
            mask_array = mask.reshape(64, 64, -1)

            # Grow the mask ring by ring: each pass dilates the current mask and
            # paints the newly covered band with a sentinel gray value.
            ring_values = [200, 150, 100, 50, 40, 30, 20, 10, 5]
            current = mask_array
            for step, ring_value in enumerate(ring_values):
                dilated = apply_dilate_to_mask_image(current, iterations=iterations)
                ring = np.ones((64, 64, mask_array.shape[-1]), dtype=np.uint8) * ring_value
                ring = ring * (~np.array(current, dtype=bool))  # only outside the current mask
                ring = ring * np.array(dilated, dtype=bool)  # and only inside the dilated band
                if step == 0:
                    # scale the raw 0/1 mask into the 0-255 sentinel range once
                    current = np.where(current > 0, current * 255, ring)
                else:
                    current = np.where(current > 0, current, ring)

            # rescale the sentinel values to blend weights 1.0, 0.9, ..., 0.1
            sentinel_to_weight = {255: 1.0, 200: 0.9, 150: 0.8, 100: 0.7, 50: 0.6, 40: 0.5, 30: 0.4, 20: 0.3, 10: 0.2, 5: 0.1}
            final_mask_array = np.zeros_like(mask_array)
            for sentinel, weight in sentinel_to_weight.items():
                final_mask_array = np.where(current == sentinel, weight, final_mask_array)

            return final_mask_array

        # mask gradient: feather the inpainting mask so the transition between
        # preserved and regenerated latents falls off gradually
        mask = mask.to(dtype=torch.float16)
        tmp_mask = mask[0, :, :].cpu().numpy()
        mask_gradient = mask_gradienting(tmp_mask, iterations=iterations)
        tmp_tensor = torch.Tensor(mask_gradient)
        tmp_tensor = tmp_tensor.view(-1, 64)
        mask_tensor = torch.zeros_like(mask)
        for index in range(mask_tensor.shape[0]):
            mask_tensor[index, :, :] = tmp_tensor
        mask = mask_tensor.to(device=latents.device, dtype=latents.dtype)

        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
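        # For context, a vectorized sketch of the same feathering idea: a single
        # distance transform can replace the ring-by-ring dilation above. This is
        # only an illustrative approximation (assuming a single-channel binary
        # mask), not the code path the pipeline uses.
        def feather_mask_by_distance(mask_64x64, ring_px=3, num_rings=9):
            inside = (mask_64x64 > 0).astype(np.uint8)
            # distance of every background pixel to the nearest masked pixel
            dist = cv2.distanceTransform(1 - inside, cv2.DIST_L2, 3)
            rings = np.ceil(dist / ring_px)
            # ring 1 -> 0.9, ring 2 -> 0.8, ..., ring 9 -> 0.1, farther out -> 0.0
            weights = np.clip(1.0 - rings / (num_rings + 1), 0.0, 1.0)
            return np.where(inside.astype(bool), 1.0, weights)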
        # 9. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                # predict the noise residual
                if isinstance(self.controlnet, FluxMultiControlNetModel):
                    use_guidance = self.controlnet.nets[0].config.guidance_embeds
                else:
                    use_guidance = self.controlnet.config.guidance_embeds
                if use_guidance:
                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
                    guidance = guidance.expand(latents.shape[0])
                else:
                    guidance = None

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
                    hidden_states=latents,
                    controlnet_cond=control_image,
                    controlnet_mode=control_mode,
                    conditioning_scale=cond_scale,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )

                if self.transformer.config.guidance_embeds:
                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
                    guidance = guidance.expand(latents.shape[0])
                else:
                    guidance = None

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_block_samples=controlnet_block_samples,
                    controlnet_single_block_samples=controlnet_single_block_samples,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    controlnet_blocks_repeat=controlnet_blocks_repeat,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # For inpainting, we need to apply the mask and add the masked image latents
                init_latents_proper = image_latents
                init_mask = mask

                if i < len(timesteps) - 1:
                    noise_timestep = timesteps[i + 1]
                    init_latents_proper = self.scheduler.scale_noise(
                        init_latents_proper, torch.tensor([noise_timestep]), noise
                    )
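                # NOTE: the block below keeps repeated copies of the same product
                # consistent by averaging the latent tokens each product mask
                # selects and writing the mean back into every copy; since the
                # regions rarely cover exactly the same number of tokens, they
                # are first center-cropped to the shortest region.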
                # average multiple product latents
                if prod_masks_original is not None:
                    init_masks_prod = masks_prod
                    if i < averaging_steps:
                        batch_size = latents.shape[0]
                        latents = latents.view(-1)
                        masked_regions = []
                        for tmp_init_mask in init_masks_prod:
                            tmp_init_mask = tmp_init_mask.view(-1)
                            masked_regions.append(latents[tmp_init_mask.bool()].reshape(-1))

                        # center-crop every region to the shortest one; a plain head
                        # truncation (t[:min_length]) would bias the crop to one side
                        min_length = min(t.size(0) for t in masked_regions)
                        truncated = []
                        for tmp_mask_region in masked_regions:
                            if tmp_mask_region.size(0) == min_length:
                                truncated.append(tmp_mask_region)
                            else:
                                extra_dim = tmp_mask_region.size(0) - min_length
                                half = extra_dim // 2
                                if extra_dim % 2 == 0:
                                    truncated.append(tmp_mask_region[half:(min_length-half+extra_dim)])
                                else:
                                    truncated.append(tmp_mask_region[half:(min_length-half+extra_dim-1)])

                        masked_regions = torch.stack(truncated)

                        avg_region = masked_regions.mean(dim=0)
                        for tmp_init_mask in init_masks_prod:
                            tmp_init_mask = tmp_init_mask.view(-1)
                            length = torch.sum(tmp_init_mask == 1)
                            if length == avg_region.shape[0]:
                                latents[tmp_init_mask.bool()] = avg_region
                            elif length < avg_region.shape[0]:
                                latents[tmp_init_mask.bool()] = avg_region[0:length]
                            else:
                                add_dim = length - avg_region.shape[0]
                                add_dim1 = add_dim // 2
                                add_dim2 = add_dim - add_dim1
                                one_tensor1 = torch.ones(add_dim1).to(device=latents.device, dtype=latents.dtype)
                                one_tensor2 = torch.ones(add_dim2).to(device=latents.device, dtype=latents.dtype)
                                tmp_tensor1 = one_tensor1 * torch.mean(avg_region)
                                tmp_tensor2 = one_tensor2 * torch.mean(avg_region)
                                latents[tmp_init_mask.bool()] = torch.cat((tmp_tensor1, avg_region, tmp_tensor2), dim=0)

                        latents = latents.view(batch_size, 4096, -1)

                # inject product raw images into latents
                #if image_ref_prod is not None:
                #    if i < ref_prod_injection_steps:
                #        latents_1 = (1 - init_mask) * init_latents_proper
                #        latents_2 = (init_mask - init_mask_ref_prod) * latents
                #        latents_3 = (1.0 - ratio_ref_prod) * init_mask_ref_prod * latents + ratio_ref_prod * init_mask_ref_prod * init_latents_proper_ref

                #        latents = latents_1 + latents_2 + latents_3

                if image_ref_prod is not None:
                    if i < ref_prod_injection_steps:
                        init_mask_ref_prod = mask_original
                        init_latents_proper_ref = image_latents_ref

                        if i < len(timesteps) - 1:
                            noise_timestep = timesteps[i + 1]
                            init_latents_proper_ref = self.scheduler.scale_noise(
                                init_latents_proper_ref, torch.tensor([noise_timestep]), noise_ref
                            )

                        latents_1 = init_mask_ref_prod * (ratio_ref_prod * init_latents_proper_ref + (1.0 - ratio_ref_prod) * latents)
                        latents_2 = (1.0 - init_mask_ref_prod) * latents

                        latents = latents_1 + latents_2

                    if i < ref_prod_injection_steps:
                        latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents
                    else:
                        latents = (1 - init_mask) * init_latents_proper + init_mask * latents

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - # call the callback, if provided - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - control_image = callback_outputs.pop("control_image", control_image) - mask = callback_outputs.pop("mask", mask) - masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - # Post-processing - if output_type == "latent": - image = latents - else: - latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,mask_gradient) - - return FluxPipelineOutput(images=image) From a0c34720905b6829c6181811abb75f61ed264c3f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Jul 2025 12:33:50 +0800 Subject: [PATCH 438/482] add bg ref image to ctln pipeline --- .../flux/pipeline_flux_controlnet.py | 61 ++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 777ef88c653b..7523346f5a82 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -826,10 +826,10 @@ def __call__( negative_prompt: Union[str, List[str]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, # original prod image - + image_bg: PipelineImageInput = None, # original bg image ratio_ref: Optional[float] = 0.1, # ratio of original prod images mask_image: PipelineImageInput = None, # modified for injecting original prod images - + mask_bg: PipelineImageInput = None, # modified for injecting original bg images prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for averaging multiple copies masked_image_latents: PipelineImageInput = None, true_cfg_scale: float = 1.0, @@ -864,6 +864,7 @@ def __call__( max_sequence_length: int = 512, averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images + ref_bg_injection_steps: Optional[int] = 0, # modified for injecting ref product images ): r""" Function invoked when calling the pipeline for generation. @@ -1088,6 +1089,12 @@ def __call__( init_image = init_image.to(dtype=torch.float32) else: init_image = None + + if image_bg is not None: + init_image_bg = self.image_processor.preprocess( + image_bg, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image_bg = init_image_bg.to(dtype=torch.float32) # 3. 
Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 @@ -1229,6 +1236,21 @@ def __call__( ref_prod_injection_steps, ) + if image_bg is not None: + _, noise_bg, image_latents_bg, _ = self.prepare_latents( + init_image_bg, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + latents=None, + ref_prod_injection_steps=ref_prod_injection_steps, + ) + # Prepare mask latents if image is not None: mask_condition = self.mask_processor.preprocess( @@ -1251,6 +1273,28 @@ def __call__( device, generator, ) + + if image_bg is not None: + mask_condition_bg = self.mask_processor.preprocess( + mask_bg, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image_bg = init_image_bg * (mask_condition_bg < 0.5) + else: + masked_image_bg = masked_image_latents + + mask_bg, masked_image_latents_bg = self.prepare_mask_latents( + mask_condition_bg, + masked_image_bg, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) if prod_masks_original is not None: mask_conditions_prod = [] @@ -1465,6 +1509,19 @@ def __call__( latents = latents_1 + latents_2 + if image_bg is not None: + if i < ref_bg_injection_steps: + init_mask_bg = mask_bg + init_latents_proper_bg = image_latents_bg + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper_bg = self.scheduler.scale_noise( + init_latents_proper_bg, torch.tensor([noise_timestep]), noise_bg + ) + + latents = (1.0 - init_mask_bg) * init_latents_proper_bg + init_mask_bg * latents + if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 From 9a9d768c726a20354c01f5bbd6e549e5d2bf439c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Jul 2025 14:33:16 +0800 Subject: [PATCH 439/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 7523346f5a82..d5d766a839c3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1278,10 +1278,7 @@ def __call__( mask_condition_bg = self.mask_processor.preprocess( mask_bg, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords ) - if masked_image_latents is None: - masked_image_bg = init_image_bg * (mask_condition_bg < 0.5) - else: - masked_image_bg = masked_image_latents + masked_image_bg = init_image_bg * (mask_condition_bg < 0.5) mask_bg, masked_image_latents_bg = self.prepare_mask_latents( mask_condition_bg, From ec41665571fa3b9dff561b447c8e8d6d7b3a12c5 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Jul 2025 17:11:32 +0800 Subject: [PATCH 440/482] add bg_flexible_strength para --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 475decd68003..4a7a6be789a8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -752,6 +752,7 @@ def __call__( image: PipelineImageInput = None, image_ref_prod: Optional[PipelineImageInput] = None, # original prod image ratio_ref_prod: Optional[float] = 0.125, # modified for injecting prod image + bg_flexible_strength: Optional[float] = 0.0, # modified for injecting prod image mask_image: PipelineImageInput = None, # original mask image for inpainting mask_image_original: Optional[PipelineImageInput] = None, # modified for injecting original prod images prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images @@ -1476,9 +1477,9 @@ def mask_gradienting(mask, iterations=3): """ if i < ref_prod_injection_steps: init_mask_ref_prod = mask_original - latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents + latents = (1.0-bg_flexible_strength) * ((1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents) + bg_flexible_strength * latents else: - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + latents = (1.0-bg_flexible_strength) * ((1 - init_mask) * init_latents_proper + init_mask * latents) + bg_flexible_strength * latents if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From 4785208baa3ced3581af581cf34b1333be53eb2b Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 2 Jul 2025 18:50:11 +0800 Subject: [PATCH 441/482] use masked_bg.convert('RGB') as mask for inpaint redo --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py 
b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 4a7a6be789a8..d6431317e315 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1477,7 +1477,11 @@ def mask_gradienting(mask, iterations=3): """ if i < ref_prod_injection_steps: init_mask_ref_prod = mask_original - latents = (1.0-bg_flexible_strength) * ((1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents) + bg_flexible_strength * latents + # use Image.new("RGB", (image_size, image_size), color=(255, 255, 255)) as mask + #latents = (1.0-bg_flexible_strength) * ((1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents) + bg_flexible_strength * latents + + # use masked_bg.convert('RGB') as mask + latents = (1 - init_mask) * ((1.0-bg_flexible_strength) * init_latents_proper + bg_flexible_strength * latents) + init_mask * init_latents_proper else: latents = (1.0-bg_flexible_strength) * ((1 - init_mask) * init_latents_proper + init_mask * latents) + bg_flexible_strength * latents From 564d6e3da0ac40220885583439b9ffccc93a117c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 3 Jul 2025 12:53:27 +0800 Subject: [PATCH 442/482] adjust refer image paras for multiple images --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index d6431317e315..08fee52f4269 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -750,11 +750,11 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, - image_ref_prod: Optional[PipelineImageInput] = None, # original prod image + image_ref_prod: Optional[Union[PipelineImageInput, list[PipelineImageInput]]] = None, # original prod image ratio_ref_prod: Optional[float] = 0.125, # modified for injecting prod image bg_flexible_strength: Optional[float] = 0.0, # modified for injecting prod image mask_image: PipelineImageInput = None, # original mask image for inpainting - mask_image_original: Optional[PipelineImageInput] = None, # modified for injecting original prod images + mask_image_original: Optional[Union[PipelineImageInput, list[PipelineImageInput]]] = None, # modified for injecting original prod images prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images masked_image_latents: PipelineImageInput = None, control_image: PipelineImageInput = None, @@ -782,7 +782,7 @@ def __call__( max_sequence_length: int = 512, iterations: Optional[int] = 1, # modified for applying gradient to mask_image averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product - ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images + ref_prod_injection_steps: Optional[Union[int, List[int]]] = None, # modified for injecting ref product images inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting inpainting_ending_step: Optional[int] = 0, # modified for starting inpainting ): From afca700d4d6763a24745bdb0a32ece50c2e450e5 Mon Sep 17 00:00:00 2001 
From: LinchuanXUTheSEAAI Date: Thu, 3 Jul 2025 12:53:40 +0800 Subject: [PATCH 443/482] Revert "adjust refer image paras for multiple images" This reverts commit 564d6e3da0ac40220885583439b9ffccc93a117c. --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 08fee52f4269..d6431317e315 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -750,11 +750,11 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, - image_ref_prod: Optional[Union[PipelineImageInput, list[PipelineImageInput]]] = None, # original prod image + image_ref_prod: Optional[PipelineImageInput] = None, # original prod image ratio_ref_prod: Optional[float] = 0.125, # modified for injecting prod image bg_flexible_strength: Optional[float] = 0.0, # modified for injecting prod image mask_image: PipelineImageInput = None, # original mask image for inpainting - mask_image_original: Optional[Union[PipelineImageInput, list[PipelineImageInput]]] = None, # modified for injecting original prod images + mask_image_original: Optional[PipelineImageInput] = None, # modified for injecting original prod images prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images masked_image_latents: PipelineImageInput = None, control_image: PipelineImageInput = None, @@ -782,7 +782,7 @@ def __call__( max_sequence_length: int = 512, iterations: Optional[int] = 1, # modified for applying gradient to mask_image averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product - ref_prod_injection_steps: Optional[Union[int, List[int]]] = None, # modified for injecting ref product images + ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting inpainting_ending_step: Optional[int] = 0, # modified for starting inpainting ): From d250774982b1c4a99e0be874e46dc47f2de73d7c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 3 Jul 2025 12:54:20 +0800 Subject: [PATCH 444/482] Revert "use masked_bg.convert('RGB') as mask for inpaint redo" This reverts commit 4785208baa3ced3581af581cf34b1333be53eb2b. 
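
For reference, the blend being reverted used masked_bg.convert('RGB') as the
mask and pinned the masked region to the original latents, letting only the
background drift with s = bg_flexible_strength:

    latents = (1 - init_mask) * ((1 - s) * init_latents_proper + s * latents)
              + init_mask * init_latents_proper

This restores the earlier behavior, where s blends across the whole latent.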
--- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index d6431317e315..4a7a6be789a8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1477,11 +1477,7 @@ def mask_gradienting(mask, iterations=3): """ if i < ref_prod_injection_steps: init_mask_ref_prod = mask_original - # use Image.new("RGB", (image_size, image_size), color=(255, 255, 255)) as mask - #latents = (1.0-bg_flexible_strength) * ((1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents) + bg_flexible_strength * latents - - # use masked_bg.convert('RGB') as mask - latents = (1 - init_mask) * ((1.0-bg_flexible_strength) * init_latents_proper + bg_flexible_strength * latents) + init_mask * init_latents_proper + latents = (1.0-bg_flexible_strength) * ((1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents) + bg_flexible_strength * latents else: latents = (1.0-bg_flexible_strength) * ((1 - init_mask) * init_latents_proper + init_mask * latents) + bg_flexible_strength * latents From 5638abe245ca8dcadd73f4e042d464277a42ddfc Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 3 Jul 2025 12:54:32 +0800 Subject: [PATCH 445/482] Revert "add bg_flexible_strength para" This reverts commit ec41665571fa3b9dff561b447c8e8d6d7b3a12c5. --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 4a7a6be789a8..475decd68003 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -752,7 +752,6 @@ def __call__( image: PipelineImageInput = None, image_ref_prod: Optional[PipelineImageInput] = None, # original prod image ratio_ref_prod: Optional[float] = 0.125, # modified for injecting prod image - bg_flexible_strength: Optional[float] = 0.0, # modified for injecting prod image mask_image: PipelineImageInput = None, # original mask image for inpainting mask_image_original: Optional[PipelineImageInput] = None, # modified for injecting original prod images prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images @@ -1477,9 +1476,9 @@ def mask_gradienting(mask, iterations=3): """ if i < ref_prod_injection_steps: init_mask_ref_prod = mask_original - latents = (1.0-bg_flexible_strength) * ((1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents) + bg_flexible_strength * latents + latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents else: - latents = (1.0-bg_flexible_strength) * ((1 - init_mask) * init_latents_proper + init_mask * latents) + bg_flexible_strength * latents + latents = (1 - init_mask) * init_latents_proper + init_mask * latents if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From a9c082b93ba049420dfae8a1d21c71993ab48373 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 3 Jul 2025 12:54:48 +0800 Subject: [PATCH 
446/482] Revert "fix bug" This reverts commit 9a9d768c726a20354c01f5bbd6e549e5d2bf439c. --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index d5d766a839c3..7523346f5a82 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1278,7 +1278,10 @@ def __call__( mask_condition_bg = self.mask_processor.preprocess( mask_bg, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords ) - masked_image_bg = init_image_bg * (mask_condition_bg < 0.5) + if masked_image_latents is None: + masked_image_bg = init_image_bg * (mask_condition_bg < 0.5) + else: + masked_image_bg = masked_image_latents mask_bg, masked_image_latents_bg = self.prepare_mask_latents( mask_condition_bg, From 329e420d9df89848468fe1a35bc0c6228361ac84 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 3 Jul 2025 12:54:58 +0800 Subject: [PATCH 447/482] Revert "add bg ref image to ctln pipeline" This reverts commit a0c34720905b6829c6181811abb75f61ed264c3f. --- .../flux/pipeline_flux_controlnet.py | 61 +------------------ 1 file changed, 2 insertions(+), 59 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 7523346f5a82..777ef88c653b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -826,10 +826,10 @@ def __call__( negative_prompt: Union[str, List[str]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, # original prod image - image_bg: PipelineImageInput = None, # original bg image + ratio_ref: Optional[float] = 0.1, # ratio of original prod images mask_image: PipelineImageInput = None, # modified for injecting original prod images - mask_bg: PipelineImageInput = None, # modified for injecting original bg images + prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for averaging multiple copies masked_image_latents: PipelineImageInput = None, true_cfg_scale: float = 1.0, @@ -864,7 +864,6 @@ def __call__( max_sequence_length: int = 512, averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images - ref_bg_injection_steps: Optional[int] = 0, # modified for injecting ref product images ): r""" Function invoked when calling the pipeline for generation. @@ -1089,12 +1088,6 @@ def __call__( init_image = init_image.to(dtype=torch.float32) else: init_image = None - - if image_bg is not None: - init_image_bg = self.image_processor.preprocess( - image_bg, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode - ) - init_image_bg = init_image_bg.to(dtype=torch.float32) # 3. 
Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 @@ -1236,21 +1229,6 @@ def __call__( ref_prod_injection_steps, ) - if image_bg is not None: - _, noise_bg, image_latents_bg, _ = self.prepare_latents( - init_image_bg, - latent_timestep, - batch_size * num_images_per_prompt, - num_channels_latents, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - latents=None, - ref_prod_injection_steps=ref_prod_injection_steps, - ) - # Prepare mask latents if image is not None: mask_condition = self.mask_processor.preprocess( @@ -1273,28 +1251,6 @@ def __call__( device, generator, ) - - if image_bg is not None: - mask_condition_bg = self.mask_processor.preprocess( - mask_bg, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords - ) - if masked_image_latents is None: - masked_image_bg = init_image_bg * (mask_condition_bg < 0.5) - else: - masked_image_bg = masked_image_latents - - mask_bg, masked_image_latents_bg = self.prepare_mask_latents( - mask_condition_bg, - masked_image_bg, - batch_size, - num_channels_latents, - num_images_per_prompt, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - ) if prod_masks_original is not None: mask_conditions_prod = [] @@ -1509,19 +1465,6 @@ def __call__( latents = latents_1 + latents_2 - if image_bg is not None: - if i < ref_bg_injection_steps: - init_mask_bg = mask_bg - init_latents_proper_bg = image_latents_bg - - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper_bg = self.scheduler.scale_noise( - init_latents_proper_bg, torch.tensor([noise_timestep]), noise_bg - ) - - latents = (1.0 - init_mask_bg) * init_latents_proper_bg + init_mask_bg * latents - if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 From 6f0943d2d2ce55daa6c97822330a6f20ae90abe9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 3 Jul 2025 22:51:40 +0800 Subject: [PATCH 448/482] individual masks for multiple products --- .../pipeline_flux_controlnet_inpainting.py | 93 ++++++++++++++----- 1 file changed, 69 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 475decd68003..1b020af6ceee 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -753,7 +753,7 @@ def __call__( image_ref_prod: Optional[PipelineImageInput] = None, # original prod image ratio_ref_prod: Optional[float] = 0.125, # modified for injecting prod image mask_image: PipelineImageInput = None, # original mask image for inpainting - mask_image_original: Optional[PipelineImageInput] = None, # modified for injecting original prod images + mask_image_original: Optional[Union[PipelineImageInput, List[PipelineImageInput]]] = None, # modified for injecting original prod images prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for injecting original prod images masked_image_latents: PipelineImageInput = None, control_image: PipelineImageInput = None, @@ -781,7 +781,7 @@ def __call__( max_sequence_length: int = 512, iterations: Optional[int] = 1, # modified for applying gradient to mask_image averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product - ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images + ref_prod_injection_steps: Optional[Union[int, List[int]]] = 22, # modified for injecting ref product images inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting inpainting_ending_step: Optional[int] = 0, # modified for starting inpainting ): @@ -1139,13 +1139,29 @@ def __call__( if image_ref_prod is not None: - mask_condition_original = self.mask_processor.preprocess( - mask_image_original, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords - ) - if masked_image_latents is None: - masked_image_original = init_image * (mask_condition_original < 0.5) + if len(mask_image_original) == 1: + mask_condition_original = self.mask_processor.preprocess( + mask_image_original[0], height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image_original = init_image * (mask_condition_original < 0.5) + else: + masked_image_original = masked_image_latents else: - masked_image_original = masked_image_latents + if len(ref_prod_injection_steps) != len(mask_image_original): + raise ValueError(f"The number of ref_prod_injection_steps ({ref_prod_injection_steps}) has to be equal to the number of masks ({mask_image_original})") + + mask_condition_original = [] + for tmp_mask_image_original in mask_image_original: + tmp_mask_condition_original = self.mask_processor.preprocess( + tmp_mask_image_original, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + mask_condition_original.append(tmp_mask_condition_original) + if masked_image_latents is None: + masked_image_original = init_image * (tmp_mask_condition_original < 0.5) + else: + 
masked_image_original = masked_image_latents + mask, masked_image_latents = self.prepare_mask_latents( mask_condition, @@ -1178,18 +1194,36 @@ def __call__( masks_prod.append(tmp_mask) if image_ref_prod is not None: - mask_original, _ = self.prepare_mask_latents( - mask_condition_original, - masked_image_original, - batch_size, - num_channels_latents, - num_images_per_prompt, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - ) + if len(mask_image_original) == 1: + mask_original, _ = self.prepare_mask_latents( + mask_condition_original, + masked_image_original, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + else: + mask_original = [] + for tmp_mask_condition_original in mask_condition_original: + tmp_mask_original, _ = self.prepare_mask_latents( + tmp_mask_condition_original, + masked_image_original, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + mask_original.append(tmp_mask_original) + def apply_dilate_to_mask_image(mask,iterations): kernel = np.ones((3, 3), np.uint8) @@ -1474,11 +1508,22 @@ def mask_gradienting(mask, iterations=3): latents = latents_1 + latents_2 """ - if i < ref_prod_injection_steps: - init_mask_ref_prod = mask_original - latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents + init_mask_ref_prod = mask_original + if len(mask_image_original) == 1: + if i < ref_prod_injection_steps: + latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents + else: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents else: - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + init_mask_ref_prod_all = torch.zeros_like(init_mask) + for tmp_ref_prod_injection_steps, tmp_init_mask_ref_prod in (ref_prod_injection_steps,init_mask_ref_prod): + if i < tmp_ref_prod_injection_steps: + init_mask_ref_prod_all += tmp_init_mask_ref_prod + + latents = (1 - (init_mask - init_mask_ref_prod_all)) * init_latents_proper + (init_mask - init_mask_ref_prod_all) * latents + + + if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From 119545fcb17c0c6c450ff3785e35a716244a121f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 4 Jul 2025 12:02:23 +0800 Subject: [PATCH 449/482] individual mask for multiple products --- .../flux/pipeline_flux_controlnet.py | 112 ++++++++++++------ 1 file changed, 75 insertions(+), 37 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 777ef88c653b..f5357502b058 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -826,10 +826,8 @@ def __call__( negative_prompt: Union[str, List[str]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, # original prod image - ratio_ref: Optional[float] = 0.1, # ratio of original prod images - mask_image: PipelineImageInput = None, # modified for injecting original prod images - + mask_image: Optional[Union[PipelineImageInput, List[PipelineImageInput]]] = None,# modified for injecting original prod images prod_masks_original: Optional[List[PipelineImageInput]] = None, # modified for averaging multiple copies 
masked_image_latents: PipelineImageInput = None, true_cfg_scale: float = 1.0, @@ -863,7 +861,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, averaging_steps: Optional[int] = 2, # modified for applying averaging latents for multiple copies of same product - ref_prod_injection_steps: Optional[int] = 22, # modified for injecting ref product images + ref_prod_injection_steps: Optional[Union[int, List[int]]] = 22, # modified for injecting ref product images ): r""" Function invoked when calling the pipeline for generation. @@ -1075,7 +1073,7 @@ def __call__( if image is not None: if padding_mask_crop is not None: crops_coords = self.mask_processor.get_crop_region( - mask_image, global_width, global_height, pad=padding_mask_crop + mask_image[0], global_width, global_height, pad=padding_mask_crop ) resize_mode = "fill" else: @@ -1086,6 +1084,10 @@ def __call__( image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode ) init_image = init_image.to(dtype=torch.float32) + + if len(mask_image) > 1: + if len(ref_prod_injection_steps) != len(mask_image): + raise ValueError(f"The number of ref_prod_injection_steps ({ref_prod_injection_steps}) has to be equal to the number of masks ({mask_image})") else: init_image = None @@ -1231,26 +1233,51 @@ def __call__( # Prepare mask latents if image is not None: - mask_condition = self.mask_processor.preprocess( - mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords - ) - if masked_image_latents is None: - masked_image = init_image * (mask_condition < 0.5) + if len(mask_image) == 1: + mask_condition = self.mask_processor.preprocess( + mask_image[0], height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) else: - masked_image = masked_image_latents - - mask, masked_image_latents = self.prepare_mask_latents( - mask_condition, - masked_image, - batch_size, - num_channels_latents, - num_images_per_prompt, - global_height, - global_width, - prompt_embeds.dtype, - device, - generator, - ) + mask = [] + for tmp_mask_image in mask_image: + mask_condition = self.mask_processor.preprocess( + tmp_mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + tmp_mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + mask.append(tmp_mask) if prod_masks_original is not None: mask_conditions_prod = [] @@ -1450,21 +1477,32 @@ def __call__( # additional codes for injecting original prod images into latents if image is not None: - if i < ref_prod_injection_steps: - init_mask = mask - init_latents_proper = image_latents - - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.scale_noise( - init_latents_proper, 
torch.tensor([noise_timestep]), noise - ) - - latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) - latents_2 = (1.0 - init_mask) * latents + init_mask = mask + init_latents_proper = image_latents + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + if len(mask_image) == 1: + if i < ref_prod_injection_steps: + latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) + latents_2 = (1.0 - init_mask) * latents + + latents = latents_1 + latents_2 + else: + init_mask_all = torch.zeros_like(init_mask) + for tmp_ref_prod_injection_steps, tmp_init_mask in (ref_prod_injection_steps, init_mask): + if i < tmp_ref_prod_injection_steps: + init_mask_all += tmp_init_mask + + latents_1 = init_mask_all * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) + latents_2 = (1.0 - init_mask_all) * latents latents = latents_1 + latents_2 - + if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 From 7b6008da0aa24237c76d6ef00d0fcb1d5c939d22 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 4 Jul 2025 12:11:42 +0800 Subject: [PATCH 450/482] change masked_bg_original to array --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index cd4ab269b337..f604506669d5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -549,7 +549,7 @@ def __call__( composed_image_all = np.zeros((image_width, image_height, 3)) masked_bg = np.zeros((image_width, image_height, 3)) - masked_bg_original = np.zeros((image_width, image_height, 3)) + masked_bg_original_array = [] masked_bg_with_element = np.zeros((image_width, image_height, 3)) composed_bg_image = np.zeros((image_width, image_height, 3)) composed_prod_images = [] @@ -564,7 +564,7 @@ def __call__( composed_image_all += img_array * image_mask_all[index] if is_product.lower() == "true": masked_bg += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) - masked_bg_original += mask_value*np.ones((image_width, image_height, 3)) * self.apply_erosion_to_mask(image_mask_all[index], iterations=iterations_erosion) + masked_bg_original_array.append(mask_value*np.ones((image_width, image_height, 3)) * self.apply_erosion_to_mask(image_mask_all[index], iterations=iterations_erosion)) if index > 0: masked_bg_with_element += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) @@ -573,7 +573,9 @@ def __call__( composed_prod_images_all = Image.fromarray(composed_prod_images_all.astype(np.uint8)).convert('RGB') composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') masked_bg = Image.fromarray(masked_bg.astype(np.uint8)).convert('RGB') - masked_bg_original = Image.fromarray(masked_bg_original.astype(np.uint8)).convert('RGB') + masked_bg_original = [] + for tmp_masked_bg_original in masked_bg_original_array: + 
masked_bg_original.append(Image.fromarray(tmp_masked_bg_original.astype(np.uint8)).convert('RGB')) masked_bg_with_element = Image.fromarray(masked_bg_with_element.astype(np.uint8)).convert('RGB') bg_mask = Image.fromarray(bg_mask.astype(np.uint8)*255).convert('RGB') From 6e7bef02dc89e0a6129ffc9379948bc56e74c6d7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 4 Jul 2025 12:18:14 +0800 Subject: [PATCH 451/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index f5357502b058..d5d9ce3c49e9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1265,7 +1265,7 @@ def __call__( else: masked_image = masked_image_latents - tmp_mask, masked_image_latents = self.prepare_mask_latents( + tmp_mask, _ = self.prepare_mask_latents( mask_condition, masked_image, batch_size, From 0e9db8e5e90f575e5c20215d73596d7385b64b40 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 4 Jul 2025 12:20:35 +0800 Subject: [PATCH 452/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index d5d9ce3c49e9..7aa091a772aa 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1493,7 +1493,7 @@ def __call__( latents = latents_1 + latents_2 else: - init_mask_all = torch.zeros_like(init_mask) + init_mask_all = torch.zeros_like(init_mask[0]) for tmp_ref_prod_injection_steps, tmp_init_mask in (ref_prod_injection_steps, init_mask): if i < tmp_ref_prod_injection_steps: init_mask_all += tmp_init_mask From 68fd3f0f879d432181cfa9b19ee251590db0f4b7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 4 Jul 2025 12:22:49 +0800 Subject: [PATCH 453/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 7aa091a772aa..9a9895a1591d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1494,7 +1494,7 @@ def __call__( latents = latents_1 + latents_2 else: init_mask_all = torch.zeros_like(init_mask[0]) - for tmp_ref_prod_injection_steps, tmp_init_mask in (ref_prod_injection_steps, init_mask): + for tmp_ref_prod_injection_steps, tmp_init_mask in zip(ref_prod_injection_steps, init_mask): if i < tmp_ref_prod_injection_steps: init_mask_all += tmp_init_mask From dcff0c311e642ce70dc37be1483459464c76a640 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 4 Jul 2025 12:24:58 +0800 Subject: [PATCH 454/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 1b020af6ceee..f9464076e8ee 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1516,7 
+1516,7 @@ def mask_gradienting(mask, iterations=3): latents = (1 - init_mask) * init_latents_proper + init_mask * latents else: init_mask_ref_prod_all = torch.zeros_like(init_mask) - for tmp_ref_prod_injection_steps, tmp_init_mask_ref_prod in (ref_prod_injection_steps,init_mask_ref_prod): + for tmp_ref_prod_injection_steps, tmp_init_mask_ref_prod in zip(ref_prod_injection_steps,init_mask_ref_prod): if i < tmp_ref_prod_injection_steps: init_mask_ref_prod_all += tmp_init_mask_ref_prod From 9d217ce7d011cd0fdfaa1d2cf11dd070796d3c98 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 4 Jul 2025 14:18:44 +0800 Subject: [PATCH 455/482] fix bug --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index f9464076e8ee..561cd7fc6d7c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1510,7 +1510,7 @@ def mask_gradienting(mask, iterations=3): """ init_mask_ref_prod = mask_original if len(mask_image_original) == 1: - if i < ref_prod_injection_steps: + if i < ref_prod_injection_steps[0]: latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents else: latents = (1 - init_mask) * init_latents_proper + init_mask * latents From 75f9dab31e14f25689e3306edb267c4b6352f86b Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 4 Jul 2025 14:22:27 +0800 Subject: [PATCH 456/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 9a9895a1591d..23e40a1bac7e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1487,7 +1487,7 @@ def __call__( ) if len(mask_image) == 1: - if i < ref_prod_injection_steps: + if i < ref_prod_injection_steps[0]: latents_1 = init_mask * (ratio_ref * init_latents_proper + (1.0 - ratio_ref) * latents) latents_2 = (1.0 - init_mask) * latents From 41ca72f274dd4aefb99414612aa55dfb935fbab3 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 4 Jul 2025 14:27:30 +0800 Subject: [PATCH 457/482] check number of steps and masks --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 561cd7fc6d7c..48407e8f8334 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1139,6 +1139,9 @@ def __call__( if image_ref_prod is not None: + if len(ref_prod_injection_steps) != len(mask_image_original): + raise ValueError(f"The number of ref_prod_injection_steps ({ref_prod_injection_steps}) has to be equal to the number of masks ({mask_image_original})") + if len(mask_image_original) == 1: mask_condition_original = self.mask_processor.preprocess( mask_image_original[0], height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords @@ -1148,9 +1151,6 @@ def __call__( else: 
masked_image_original = masked_image_latents else: - if len(ref_prod_injection_steps) != len(mask_image_original): - raise ValueError(f"The number of ref_prod_injection_steps ({ref_prod_injection_steps}) has to be equal to the number of masks ({mask_image_original})") - mask_condition_original = [] for tmp_mask_image_original in mask_image_original: tmp_mask_condition_original = self.mask_processor.preprocess( From 9f927f35b69482ca990caece4b396c086416da61 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 4 Jul 2025 14:43:09 +0800 Subject: [PATCH 458/482] check number of steps and masks --- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 5 ++--- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 23e40a1bac7e..670c699bf048 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1085,9 +1085,8 @@ def __call__( ) init_image = init_image.to(dtype=torch.float32) - if len(mask_image) > 1: - if len(ref_prod_injection_steps) != len(mask_image): - raise ValueError(f"The number of ref_prod_injection_steps ({ref_prod_injection_steps}) has to be equal to the number of masks ({mask_image})") + if len(ref_prod_injection_steps) != len(mask_image): + raise ValueError(f"The number of ref_prod_injection_steps ({len(ref_prod_injection_steps)}) has to be equal to the number of masks ({len(mask_image)})") else: init_image = None diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 48407e8f8334..7fed3c07964f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1140,7 +1140,7 @@ def __call__( if image_ref_prod is not None: if len(ref_prod_injection_steps) != len(mask_image_original): - raise ValueError(f"The number of ref_prod_injection_steps ({ref_prod_injection_steps}) has to be equal to the number of masks ({mask_image_original})") + raise ValueError(f"The number of ref_prod_injection_steps ({len(ref_prod_injection_steps)}) has to be equal to the number of masks ({len(mask_image_original)})") if len(mask_image_original) == 1: mask_condition_original = self.mask_processor.preprocess( From a5f19a218debc1d7025c345b31fee1f6a6a559a1 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Mon, 7 Jul 2025 16:48:17 +0800 Subject: [PATCH 459/482] add inpaint default logic --- .../pipeline_flux_controlnet_inpainting.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 7fed3c07964f..0e6067ca5786 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1508,21 +1508,22 @@ def mask_gradienting(mask, iterations=3): latents = latents_1 + latents_2 """ - init_mask_ref_prod = mask_original - if len(mask_image_original) == 1: - if i < ref_prod_injection_steps[0]: - latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents + if mask_image_original is not None: + init_mask_ref_prod = 
mask_original + if len(mask_image_original) == 1: + if i < ref_prod_injection_steps[0]: + latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents + else: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents else: - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + init_mask_ref_prod_all = torch.zeros_like(init_mask) + for tmp_ref_prod_injection_steps, tmp_init_mask_ref_prod in zip(ref_prod_injection_steps,init_mask_ref_prod): + if i < tmp_ref_prod_injection_steps: + init_mask_ref_prod_all += tmp_init_mask_ref_prod + + latents = (1 - (init_mask - init_mask_ref_prod_all)) * init_latents_proper + (init_mask - init_mask_ref_prod_all) * latents else: - init_mask_ref_prod_all = torch.zeros_like(init_mask) - for tmp_ref_prod_injection_steps, tmp_init_mask_ref_prod in zip(ref_prod_injection_steps,init_mask_ref_prod): - if i < tmp_ref_prod_injection_steps: - init_mask_ref_prod_all += tmp_init_mask_ref_prod - - latents = (1 - (init_mask - init_mask_ref_prod_all)) * init_latents_proper + (init_mask - init_mask_ref_prod_all) * latents - - + latents = (1 - init_mask) * init_latents_proper + init_mask * latents if latents.dtype != latents_dtype: From 0d74ee9fa711c5120b2150ac99be4eaa03377c3a Mon Sep 17 00:00:00 2001 From: YangMI-YM Date: Mon, 7 Jul 2025 11:28:36 +0000 Subject: [PATCH 460/482] added Kontext --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/flux/__init__.py | 3 + .../flux/pipeline_flux_kontext_inpaint.py | 1459 +++++++++++++++++ 3 files changed, 1463 insertions(+) create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c4a3a5bffcbc..b26e2f3ddf7f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -368,6 +368,7 @@ "FluxFillPipeline", "FluxImg2ImgPipeline", "FluxInpaintPipeline", + "FluxKontextInpaintPipeline", "FluxPipeline", "FluxPriorReduxPipeline", "HiDreamImagePipeline", diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index ab3271b62a46..72488b720a75 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -35,6 +35,8 @@ _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"] + #_import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"] + _import_structure["pipeline_flux_kontext_inpaint"] = ["FluxKontextInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -55,6 +57,7 @@ from .pipeline_flux_img2img import FluxImg2ImgPipeline from .pipeline_flux_inpaint import FluxInpaintPipeline from .pipeline_flux_prior_redux import FluxPriorReduxPipeline + from .pipeline_flux_kontext_inpaint import FluxKontextInpaintPipeline else: import sys diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py new file mode 100644 index 000000000000..1a79f17215d4 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py @@ -0,0 +1,1459 @@ +# Copyright 2025 ZenAI. All rights reserved. 
+# author: @vuongminh1907
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+    T5EncoderModel,
+    T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        # Inpainting with text only
+        ```py
+        >>> import torch
+        >>> from diffusers import FluxKontextInpaintPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> prompt = "Change the yellow dinosaur to a green one"
+        >>> img_url = (
+        ...     "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
+        ... )
+        >>> mask_url = (
+        ...     "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
+        ... )
+
+        >>> source = load_image(img_url)
+        >>> mask = load_image(mask_url)
+
+        >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+        ...     "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+        ... )
+        >>> pipe.to("cuda")
+
+        >>> image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
+        >>> image.save("kontext_inpainting_normal.png")
+        ```
+
+        # Inpainting with image conditioning
+        ```py
+        >>> import torch
+        >>> from diffusers import FluxKontextInpaintPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+        ...     "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+        ... )
+        >>> pipe.to("cuda")
+
+        >>> prompt = "Replace this ball"
+        >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
+        >>> mask_url = (
+        ...     "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
+        ... )
+        >>> image_reference_url = (
+        ...     "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
+        ... )
+
+        >>> source = load_image(img_url)
+        >>> mask = load_image(mask_url)
+        >>> image_reference = load_image(image_reference_url)
+
+        >>> mask = pipe.mask_processor.blur(mask, blur_factor=12)
+        >>> image = pipe(
+        ...     prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
+        ... 
).images[0] + >>> image.save("kontext_inpainting_ref.png") + ``` +""" + +PREFERRED_KONTEXT_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FluxKontextInpaintPipeline(
+    DiffusionPipeline,
+    FluxLoraLoaderMixin,
+    FromSingleFileMixin,
+    TextualInversionLoaderMixin,
+    FluxIPAdapterMixin,
+):
+    r"""
+    The Flux Kontext pipeline for mask-based image inpainting, optionally conditioned on a reference image.
+
+    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+    Args:
+        transformer ([`FluxTransformer2DModel`]):
+            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`T5EncoderModel`]):
+            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`T5TokenizerFast`):
+            Second Tokenizer of class
+            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+    """
+
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+    def __init__(
+        self,
+        scheduler: FlowMatchEulerDiscreteScheduler,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        text_encoder_2: T5EncoderModel,
+        tokenizer_2: T5TokenizerFast,
+        transformer: FluxTransformer2DModel,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            transformer=transformer,
+            scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+        # Flux latents are turned into 2x2 patches and packed. 
This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + 
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." 
+ ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. 
When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image: Optional[torch.Tensor], + timestep: int, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + image_reference: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + + # Prepare image latents + image_latents = image_ids = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + # Prepare image reference latents + image_reference_latents = image_reference_ids = None + if image_reference is not None: + image_reference = image_reference.to(device=device, dtype=dtype) + if image_reference.shape[1] != self.latent_channels: + image_reference_latents = self._encode_vae_image(image=image_reference, generator=generator) + else: + image_reference_latents = image_reference + if batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_reference_latents.shape[0] + image_reference_latents = torch.cat([image_reference_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image_reference` of batch size {image_reference_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_reference_latents = torch.cat([image_reference_latents], dim=0) + + latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device=device, dtype=dtype) + latents = noise + + image_latent_height, image_latent_width = image_latents.shape[2:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + image_ids = self._prepare_latent_image_ids( + batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + + if image_reference_latents is not None: + image_reference_latent_height, image_reference_latent_width = image_reference_latents.shape[2:] + image_reference_latents = self._pack_latents( + image_reference_latents, + batch_size, + num_channels_latents, + image_reference_latent_height, + image_reference_latent_width, + ) + image_reference_ids = self._prepare_latent_image_ids( + batch_size, image_reference_latent_height // 2, image_reference_latent_width // 2, device, dtype + ) + # image_reference_ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_reference_ids[..., 0] = 1 + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise + + # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = ( + masked_image_latents - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." 
+ ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[PipelineImageInput] = None, + image_reference: Optional[PipelineImageInput] = None, + mask_image: PipelineImageInput = None, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + padding_mask_crop: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + max_area: int = 1024**2, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. 
+
+        Args:
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be inpainted (the parts of the image
+                masked out by `mask_image` are repainted according to `prompt` and `image_reference`). For both
+                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
+                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+                latents as `image`, but if passing latents directly it is not encoded again.
+            image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be used as the starting point for the
+                masked area. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If
+                it's a tensor or a list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is
+                a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can
+                also accept image latents as `image_reference`, but if passing latents directly it is not encoded again.
+            mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+                are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+                single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+                color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`, `(B,
+                H, W)`, `(1, H, W)`, or `(H, W)`. For a numpy array, the expected shape would be `(B, H, W, 1)`, `(B, H,
+                W)`, `(H, W, 1)`, or `(H, W)`.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+                will be used instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+                not greater than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+            true_cfg_scale (`float`, *optional*, defaults to 1.0):
+                When > 1.0 and `negative_prompt` is provided, enables true classifier-free guidance.
+            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. 
Must be between 0 and 1. `image` is used as a
+                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+                essentially ignores `image`.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+                image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+                with the same aspect ratio as the image that contains all masked areas, and then expand that area based
+                on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+                resizing to the original image size for inpainting. This is useful when the masked area is small while
+                the image is large and contains information irrelevant for inpainting, such as background.
+            num_inference_steps (`int`, *optional*, defaults to 28):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            guidance_scale (`float`, *optional*, defaults to 3.5):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely
+                linked to the text `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. 
If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 512):
+                Maximum sequence length to use with the `prompt`.
+            max_area (`int`, defaults to `1024 ** 2`):
+                The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
+                area while maintaining the aspect ratio.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+            images. 
+        """
+
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+
+        original_height, original_width = height, width
+        aspect_ratio = width / height
+        width = round((max_area * aspect_ratio) ** 0.5)
+        height = round((max_area / aspect_ratio) ** 0.5)
+
+        multiple_of = self.vae_scale_factor * 2
+        width = width // multiple_of * multiple_of
+        height = height // multiple_of * multiple_of
+
+        if height != original_height or width != original_width:
+            logger.warning(
+                f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+            )
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            image,
+            mask_image,
+            strength,
+            height,
+            width,
+            output_type=output_type,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            padding_mask_crop=padding_mask_crop,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._joint_attention_kwargs = joint_attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+
+        # 2. Preprocess image
+        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+            if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+                image = torch.cat(image, dim=0)
+            img = image[0] if isinstance(image, list) else image
+            image_height, image_width = self.image_processor.get_default_height_width(img)
+            aspect_ratio = image_width / image_height
+            if _auto_resize:
+                # Kontext is trained on specific resolutions; using one of them is recommended
+                _, image_width, image_height = min(
+                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+                )
+            image_width = image_width // multiple_of * multiple_of
+            image_height = image_height // multiple_of * multiple_of
+            image = self.image_processor.resize(image, image_height, image_width)
+
+            # Match the generation resolution to the (resized) input image
+            width = image_width
+            height = image_height
+
+            # 2.1 Preprocess mask
+            if padding_mask_crop is not None:
+                crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+                resize_mode = "fill"
+            else:
+                crops_coords = None
+                resize_mode = "default"
+
+            image = self.image_processor.preprocess(
+                image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode
+            )
+        else:
+            raise ValueError("`image` must be provided for inpainting")
+
+        init_image = image.to(dtype=torch.float32)
+
+        # 2.2 Preprocess image_reference
+        if image_reference is not None and not (
+            isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels
+        ):
+            if (
+                isinstance(image_reference, list)
+                and isinstance(image_reference[0], torch.Tensor)
+                and image_reference[0].ndim == 4
+            ):
+                image_reference = torch.cat(image_reference, dim=0)
+            img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference
+            image_reference_height, image_reference_width = self.image_processor.get_default_height_width(
+                img_reference
+            )
+            aspect_ratio = image_reference_width / image_reference_height
+            if _auto_resize:
+                # 
Kontext is trained on specific resolutions; using one of them is recommended
+                _, image_reference_width, image_reference_height = min(
+                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+                )
+            image_reference_width = image_reference_width // multiple_of * multiple_of
+            image_reference_height = image_reference_height // multiple_of * multiple_of
+            image_reference = self.image_processor.resize(
+                image_reference, image_reference_height, image_reference_width
+            )
+            image_reference = self.image_processor.preprocess(
+                image_reference,
+                image_reference_height,
+                image_reference_width,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
+            )
+        else:
+            image_reference = None
+
+        # 3. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+        has_neg_prompt = negative_prompt is not None or (
+            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+        )
+        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+        (
+            prompt_embeds,
+            pooled_prompt_embeds,
+            text_ids,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
+        )
+        if do_true_cfg:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                negative_text_ids,
+            ) = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )
+
+        # 4. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
+        # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise = ( + self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image_reference, + ) + ) + + if image_reference_ids is not None: + latent_ids = torch.cat([latent_ids, image_reference_ids], dim=0) # dim 0 is sequence dimension + elif image_ids is not None: + latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * (mask_condition < 0.5) + + mask, _ = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + latent_model_input = latents + if image_reference_latents is not None: + latent_model_input = torch.cat([latents, image_reference_latents], dim=1) + elif image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) \ No newline at end of file From 1373ed10dbe16d903bd03e7f30f557ad84738bcb Mon Sep 17 00:00:00 2001 From: YangMI-YM Date: Mon, 7 Jul 2025 16:40:12 +0000 Subject: [PATCH 461/482] fixed import error --- src/diffusers/__init__.py | 3 + src/diffusers/pipelines/__init__.py | 4 + src/diffusers/pipelines/flux/__init__.py | 3 +- .../pipelines/flux/pipeline_flux_kontext.py | 1134 +++++++++++++++++ 4 files changed, 1143 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_kontext.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b26e2f3ddf7f..74e32897f703 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -369,6 +369,7 @@ "FluxImg2ImgPipeline", "FluxInpaintPipeline", "FluxKontextInpaintPipeline", + "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", "HiDreamImagePipeline", @@ -943,6 +944,8 @@ FluxFillPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, + FluxKontextInpaintPipeline, + FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, HiDreamImagePipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 011f23ed371c..d28f3529fb88 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -140,6 +140,8 @@ "FluxFillPipeline", "FluxPriorReduxPipeline", "ReduxImageEncoder", + "FluxKontextPipeline", + "FluxKontextInpaintPipeline", ] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ @@ -582,6 +584,8 @@ FluxFillPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, + FluxKontextInpaintPipeline, + FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, ReduxImageEncoder, diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py index 72488b720a75..f08d04473094 100644 --- a/src/diffusers/pipelines/flux/__init__.py +++ b/src/diffusers/pipelines/flux/__init__.py @@ -35,7 +35,7 @@ _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"] - #_import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"] + _import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"] 
_import_structure["pipeline_flux_kontext_inpaint"] = ["FluxKontextInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -58,6 +58,7 @@ from .pipeline_flux_inpaint import FluxInpaintPipeline from .pipeline_flux_prior_redux import FluxPriorReduxPipeline from .pipeline_flux_kontext_inpaint import FluxKontextInpaintPipeline + from .pipeline_flux_kontext import FluxKontextPipeline else: import sys diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py new file mode 100644 index 000000000000..3ddca2a760dd --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -0,0 +1,1134 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxKontextPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxKontextPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGB") + >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors" + >>> image = pipe( + ... image=image, + ... prompt=prompt, + ... guidance_scale=2.5, + ... generator=torch.Generator().manual_seed(42), + ... 
).images[0] + >>> image.save("output.png") + ``` +""" + +PREFERRED_KONTEXT_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class FluxKontextPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Flux Kontext pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. 
This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" 
{self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
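+ Returns:
+ `Tuple[torch.FloatTensor, torch.FloatTensor, torch.Tensor]`: the T5 sequence embeddings
+ (`prompt_embeds`), the pooled CLIP embeddings (`pooled_prompt_embeds`), and the `text_ids`
+ position tensor that is later passed to the transformer as `txt_ids`.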
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
+ ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image: Optional[torch.Tensor], + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + + image_latents = image_ids = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[2:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + image_ids = self._prepare_latent_image_ids( + batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + + latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents, latent_ids, image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + max_area: int = 1024**2, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. 
If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but if latents are passed directly they are not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+ of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from the `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from the `prompt` input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of
+ IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of
+ IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled `negative_prompt_embeds` will be generated from the `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+ the `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+ max_area (`int`, *optional*, defaults to `1024 ** 2`):
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
+ area while maintaining the aspect ratio.
+ _auto_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input image to the preferred Kontext resolution whose aspect ratio is closest
+ to the input's (see `PREFERRED_KONTEXT_RESOLUTIONS`) before encoding it.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images. 
+ """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_height, original_width = height, width + aspect_ratio = width / height + width = round((max_area * aspect_ratio) ** 0.5) + height = round((max_area / aspect_ratio) ** 0.5) + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + if height != original_height or width != original_width: + logger.warning( + f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + img = image[0] if isinstance(image, list) else image + image_height, image_width = self.image_processor.get_default_height_width(img) + aspect_ratio = image_width / image_height + if _auto_resize: + # Kontext is trained on specific resolutions, using one of them is recommended + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + image = self.image_processor.resize(image, image_height, image_width) + image = self.image_processor.preprocess(image, image_height, image_width) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents, latent_ids, image_ids = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if image_ids is not None: + latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. 
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) \ No newline at end of file From e7a0adb785faa4f08259f17839b0c0c076bb1baf Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 8 Jul 2025 17:02:25 +0800 Subject: [PATCH 462/482] add shadow_step para --- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 0e6067ca5786..afbb5353a889 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -784,6 +784,7 @@ def __call__( ref_prod_injection_steps: Optional[Union[int, List[int]]] = 22, # modified for injecting ref product images inpainting_starting_step: Optional[int] = 0, # modified for starting inpainting inpainting_ending_step: Optional[int] = 0, # modified for starting inpainting + shadow_step: Optional[int] = 0, # modified for shadow ): """ Function invoked when calling the pipeline for generation. 
@@ -1511,6 +1512,8 @@ def mask_gradienting(mask, iterations=3): if mask_image_original is not None: init_mask_ref_prod = mask_original if len(mask_image_original) == 1: + if i < shadow_step: + init_mask_ref_prod = init_mask if i < ref_prod_injection_steps[0]: latents = (1 - (init_mask -init_mask_ref_prod)) * init_latents_proper + (init_mask -init_mask_ref_prod) * latents else: @@ -1521,6 +1524,8 @@ def mask_gradienting(mask, iterations=3): if i < tmp_ref_prod_injection_steps: init_mask_ref_prod_all += tmp_init_mask_ref_prod + if i < shadow_step: + init_mask_ref_prod_all = init_mask latents = (1 - (init_mask - init_mask_ref_prod_all)) * init_latents_proper + (init_mask - init_mask_ref_prod_all) * latents else: latents = (1 - init_mask) * init_latents_proper + init_mask * latents From cec559f2f9e0d3280c47cec0abf2517e3ebab60f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Tue, 8 Jul 2025 20:01:03 +0800 Subject: [PATCH 463/482] Create pipeline_flux_prior_redux_transparency.py --- .../pipeline_flux_prior_redux_transparency.py | 718 ++++++++++++++++++ 1 file changed, 718 insertions(+) create mode 100644 src/diffusers/pipelines/flux/pipeline_flux_prior_redux_transparency.py diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux_transparency.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux_transparency.py new file mode 100644 index 000000000000..91461e6080e7 --- /dev/null +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux_transparency.py @@ -0,0 +1,718 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Union +import numpy as np +import cv2 + +import torch +from PIL import Image +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ..pipeline_utils import DiffusionPipeline +from .modeling_flux import ReduxImageEncoder +from .pipeline_output import FluxPriorReduxPipelineOutput + + +if is_torch_xla_available(): + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPriorReduxPipeline, FluxPipeline + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> dtype = torch.bfloat16 + + >>> repo_redux = "black-forest-labs/FLUX.1-Redux-dev" + >>> repo_base = "black-forest-labs/FLUX.1-dev" + >>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device) + >>> pipe = FluxPipeline.from_pretrained( + ... 
repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16 + ... ).to(device) + + >>> image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png" + ... ) + >>> pipe_prior_output = pipe_prior_redux(image) + >>> images = pipe( + ... guidance_scale=2.5, + ... num_inference_steps=50, + ... generator=torch.Generator("cpu").manual_seed(0), + ... **pipe_prior_output, + ... ).images + >>> images[0].save("flux-redux.png") + ``` +""" + + +class FluxPriorReduxPipeline(DiffusionPipeline): + r""" + The Flux Redux pipeline for image-to-image generation. + + Reference: https://blackforestlabs.ai/flux-1-tools/ + + Args: + image_encoder ([`SiglipVisionModel`]): + SIGLIP vision model to encode the input image. + feature_extractor ([`SiglipImageProcessor`]): + Image processor for preprocessing images for the SIGLIP model. + image_embedder ([`ReduxImageEncoder`]): + Redux image encoder to process the SIGLIP embeddings. + text_encoder ([`CLIPTextModel`], *optional*): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`], *optional*): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`, *optional*): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`, *optional*): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "image_encoder->image_embedder" + _optional_components = [ + "text_encoder", + "tokenizer", + "text_encoder_2", + "tokenizer_2", + ] + _callback_tensor_inputs = [] + + def __init__( + self, + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + image_embedder: ReduxImageEncoder, + text_encoder: CLIPTextModel = None, + tokenizer: CLIPTokenizer = None, + text_encoder_2: T5EncoderModel = None, + tokenizer_2: T5TokenizerFast = None, + ): + super().__init__() + + self.register_modules( + image_encoder=image_encoder, + feature_extractor=feature_extractor, + image_embedder=image_embedder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + + def check_inputs( + self, + image, + prompt, + prompt_2, + prompt_embeds=None, + pooled_prompt_embeds=None, + prompt_embeds_scale=1.0, + pooled_prompt_embeds_scale=1.0, + ): + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)): + raise ValueError( + f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images" + ) + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if isinstance(prompt_embeds_scale, list) and ( + isinstance(image, list) and len(prompt_embeds_scale) != len(image) + ): + raise ValueError( + f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images" + ) + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + image = self.feature_extractor.preprocess( + images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True + ) + image = image.to(device=device, dtype=dtype) + + image_enc_hidden_states = self.image_encoder(**image).last_hidden_state + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + + return image_enc_hidden_states + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + 
return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def apply_dilate_to_mask(self, mask,iterations = 10): + + kernel = np.ones((5, 5), np.uint8) + mask = mask.astype(np.uint8) + mask = cv2.dilate(mask, kernel, iterations=iterations) + mask = np.array(mask, dtype=bool) + + return mask + + def apply_erosion_to_mask(self, mask,iterations = 10): + + kernel = np.ones((3, 3), np.uint8) + mask = mask.astype(np.uint8) + mask = cv2.erode(mask, kernel, iterations=iterations) + mask = np.array(mask, dtype=bool) + + return mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + layer_type: Optional[Union[str, List[str]]] = None, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, + pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0, + is_qv: Optional[bool] = False, # thesea modified for quick validation of product shots + for_blend_preprocessing: Optional[bool] = False, # thesea modified for quick validation of product shots + is_multiprod: Optional[bool] = False, # thesea modified for quick validation of product shots + product_ratio: Optional[float] = None, # theseam modified for quick validation of product shots + is_inpainting: Optional[bool] = False, # controlnet inpainting + contains_element: Optional[bool] = False, # controlnet inpainting for element + iterations: Optional[int] = 10, # controlnet inpainting + iterations_erosion: Optional[int] = 8, # modified for injecting original prod image + mask_value: Optional[int] = 255, # controlnet inpainting + transparency_list: Optional[list[float]] = None, # 
for controlling transparency + image_width: Optional[int] = 1024, + image_height: Optional[int] = 1024, + return_dict: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. **experimental feature**: to use this feature, + make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders + are not loaded. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`: + [`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + # 1. Check inputs. Raise error if not correct + # thesea modified + """ + self.check_inputs( + image, + prompt, + prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_embeds_scale=prompt_embeds_scale, + pooled_prompt_embeds_scale=pooled_prompt_embeds_scale, + ) + """ + + # 2. Define call parameters + if image is not None and isinstance(image, Image.Image): + batch_size = 1 + elif image is not None and isinstance(image, list): + # theseam modified + if product_ratio is not None: + batch_size = 1 + else: + batch_size = len(image) + else: + batch_size = image.shape[0] + # thesea modified for ip and txt masks + #if prompt is not None and isinstance(prompt, str): + # prompt = batch_size * [prompt] + if isinstance(prompt_embeds_scale, float): + prompt_embeds_scale = batch_size * [prompt_embeds_scale] + if isinstance(pooled_prompt_embeds_scale, float): + pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale] + + device = self._execution_device + + # 3. 
Prepare image embeddings + # thesea modified for ip and txt masks + if is_qv: + image_array_list = [] + mask_list = [] + is_product_list = [] + + if len(image) != len(layer_type): + raise ValueError( + f"number of images ({len(image)}) must match the number of layers {len(layer_type)}" + ) + + for img, img_type in zip(image, layer_type): + if 'product' in img_type or 'Product' in img_type: + is_product_list.append('true') + else: + is_product_list.append('false') + + img = img.convert('RGBA') + img = img.resize((image_width, image_height), resample=Image.BICUBIC) + + rgba_np = np.array(img) + mask = rgba_np[:, :, 3] + mask = mask > 0 + mask = np.stack((mask,)*3, axis=-1) + mask_list.append(mask) + + tmp_img = Image.new('RGBA', img.size, (255, 255, 255, 255)) + tmp_img.paste(img, mask=img.split()[3]) + tmp_img = tmp_img.convert('RGB') + image_array = np.asarray(tmp_img) / 255.0 + image_array_list.append(image_array) + + bg_mask = np.full((image_width, image_height, 3), True, dtype=bool) + image_mask_prod = {} + image_mask_prod_original = {} + image_mask_bg = {} + image_mask_all = {} + for index, (is_product, mask) in enumerate(zip(is_product_list, mask_list)): + if is_product.lower() == "true": + bg_mask = bg_mask & ~mask + else: + bg_mask = bg_mask | mask + + if index not in image_mask_all: + image_mask_all[index] = mask + for k in image_mask_all: + if k != index: + image_mask_all[k] = image_mask_all[k] & ~mask + + if is_product.lower() == "true": + if index not in image_mask_prod: + image_mask_prod[index] = mask + image_mask_prod_original[index] = mask + else: + if index not in image_mask_bg: + image_mask_bg[index] = mask + for k in image_mask_bg: + if k != index: + image_mask_bg[k] = image_mask_bg[k] & ~mask + + for k in image_mask_prod: + if k != index: + image_mask_prod[k] = image_mask_prod[k] & ~mask + + if len(image_mask_prod) > 1: + if not is_multiprod: + prompt=[prompt]*len(image_mask_prod) + + composed_image_all = np.zeros((image_width, image_height, 3)) + masked_bg = np.zeros((image_width, image_height, 3)) + masked_bg_original_array = [] + masked_bg_with_element = np.zeros((image_width, image_height, 3)) + composed_bg_image = np.zeros((image_width, image_height, 3)) + composed_prod_images = [] + composed_prod_images_all = np.zeros((image_width, image_height, 3)) + + bg_array = None + bg_mask = None + product_masks = [] + product_transparencies = [] + for index, (is_product, img_array, transparency) in enumerate(zip(is_product_list, image_array_list, transparency_list)): + if is_product.lower() == "true": + composed_prod_images.append(Image.fromarray((img_array*255).astype(np.uint8))) + composed_prod_images_all += img_array * image_mask_prod[index] + else: + composed_bg_image += img_array * image_mask_bg[index] + + if is_product.lower() == "true": + composed_image_all += img_array * image_mask_all[index] * (1.0 - transparency) + product_masks.append(image_mask_all[index]) + product_transparencies.append(transparency) + else: + composed_image_all += img_array * image_mask_all[index] + bg_array = img_array + bg_mask = image_mask_all[index] + + + if is_product.lower() == "true": + masked_bg += mask_value*np.ones((image_width, image_height, 3)) * self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) + masked_bg_original_array.append(mask_value*np.ones((image_width, image_height, 3)) * self.apply_erosion_to_mask(image_mask_all[index], iterations=iterations_erosion)) + + if index > 0: + masked_bg_with_element += mask_value*np.ones((image_width, image_height, 3)) * 
self.apply_dilate_to_mask(image_mask_all[index], iterations=iterations) + + for (product_mask, product_transparency) in zip(product_masks, product_transparencies): + composed_image_all += bg_array * product_mask * product_transparency + + composed_bg_image = Image.fromarray((composed_bg_image*255).astype(np.uint8)).convert('RGB') + composed_prod_images_all = Image.fromarray((composed_prod_images_all*255).astype(np.uint8)).convert('RGB') + composed_image_all = Image.fromarray((composed_image_all*255).astype(np.uint8)).convert('RGB') + masked_bg = Image.fromarray(masked_bg.astype(np.uint8)).convert('RGB') + masked_bg_original = [] + for tmp_masked_bg_original in masked_bg_original_array: + masked_bg_original.append(Image.fromarray(tmp_masked_bg_original.astype(np.uint8)).convert('RGB')) + masked_bg_with_element = Image.fromarray(masked_bg_with_element.astype(np.uint8)).convert('RGB') + + bg_mask = Image.fromarray(bg_mask.astype(np.uint8)*255).convert('RGB') + prod_masks = [] + for tmp_mask in image_mask_prod: + prod_masks.append(Image.fromarray(image_mask_prod[tmp_mask].astype(np.uint8)*255).convert('RGB')) + + prod_masks_original = [] + for tmp_mask in image_mask_prod_original: + prod_masks_original.append(Image.fromarray(image_mask_prod_original[tmp_mask].astype(np.uint8)*255).convert('RGB')) + + image_latents_bg = self.encode_image(composed_bg_image, device, 1) + image_latents_prods = [] + for composed_prod_image in composed_prod_images: + image_latents_prods.append(self.encode_image(composed_prod_image, device, 1)) + + image_embeds_bg = self.image_embedder(image_latents_bg).image_embeds + image_embeds_bg = image_embeds_bg.to(device=device) + + image_embeds_prods = [] + for image_latents_prod in image_latents_prods: + image_embeds_prod = self.image_embedder(image_latents_prod).image_embeds + image_embeds_prod = image_embeds_prod.to(device=device) + image_embeds_prods.append(image_embeds_prod) + else: + image = image.convert('RGB') + image_latents = self.encode_image(image, device, 1) + image_embeds = self.image_embedder(image_latents).image_embeds + image_embeds = image_embeds.to(device=device) + + # 3. Prepare (dummy) text embeddings + # thesea modified for ip and txt masks + if hasattr(self, "text_encoder") and self.text_encoder is not None: + prompt_embeds_list = [] + if isinstance(prompt, list): + for pmt in prompt: + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=pmt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=batch_size, + max_sequence_length=512, + lora_scale=None, + ) + prompt_embeds_list.append(prompt_embeds) + else: + ( + prompt_embeds, + pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=batch_size, + max_sequence_length=512, + lora_scale=None, + ) + prompt_embeds_list.append(prompt_embeds) # thesea modified for ip and txt masks + else: + if prompt is not None: + logger.warning( + "prompt input is ignored when text encoders are not loaded to the pipeline. " + "Make sure to explicitly load the text encoders to enable prompt input. 
" + ) + # max_sequence_length is 512, t5 encoder hidden size is 4096 + prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype) + # pooled_prompt_embeds is 768, clip text encoder hidden size + pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) + + # scale & concatenate image and text embeddings + if is_qv: + if for_blend_preprocessing: + if (len(prompt_embeds_list) - 1) != len(image_embeds_prods): + raise ValueError( + f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" + ) + + prompt_embeds = image_embeds_bg + else: + if len(prompt_embeds_list) != len(image_embeds_prods): + raise ValueError( + f"number of prompts ({len(prompt_embeds_list)}) must match the number of product images {len(image_embeds_prods)}" + ) + + prompt_embeds = image_embeds_bg + for tmp_prompt_embeds, tmp_image_embeds_prod in zip(reversed(prompt_embeds_list), reversed(image_embeds_prods)): + prompt_embeds = torch.cat([tmp_prompt_embeds, tmp_image_embeds_prod[:,:int(729*product_ratio),:], prompt_embeds], dim=1) + else: + prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) + + prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[:, None, None] + pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=prompt_embeds.dtype)[ + :, None + ] + + # weighted sum + prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True) + pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + if is_qv: + if is_inpainting: + if contains_element: + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, masked_bg_original, masked_bg_with_element, composed_bg_image, composed_prod_images, composed_prod_images_all, prod_masks, prod_masks_original, bg_mask) + else: + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, masked_bg, masked_bg_original, composed_bg_image, composed_prod_images, composed_prod_images_all, prod_masks, prod_masks_original, bg_mask) + else: + return (prompt_embeds, pooled_prompt_embeds, composed_image_all, composed_bg_image, composed_prod_images, prod_masks, bg_mask) + else: + return (prompt_embeds, pooled_prompt_embeds) + + return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds) From e5578d66be5369765283c6cf042c50030f4a5a35 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 11 Jul 2025 15:56:35 +0800 Subject: [PATCH 464/482] add fake shadow to image --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index f604506669d5..fbd977134db6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -409,6 +409,7 @@ def __call__( iterations: Optional[int] = 10, # controlnet inpainting iterations_erosion: Optional[int] = 8, # modified for injecting original prod image mask_value: Optional[int] = 255, # controlnet inpainting + shadow_mask: Optional[PipelineImageInput]=None, # added for enhancing shadow, image_width: Optional[int] = 1024, image_height: Optional[int] = 1024, return_dict: bool = True, @@ -571,6 +572,13 @@ def __call__( composed_bg_image = 
Image.fromarray(composed_bg_image.astype(np.uint8)).convert('RGB') composed_prod_images_all = Image.fromarray(composed_prod_images_all.astype(np.uint8)).convert('RGB') + + if shadow_mask is not None: + shadow_mask = np.all(shadow_mask == 255, axis=-1) + shadow_mask = np.stack((shadow_mask,)*3, axis=-1) + tmp_img = Image.new('RGB', img.size, (0, 0, 0)) + composed_image_all = shadow_mask * tmp_img * 0.5 + shadow_mask * composed_image_all * 0.5 + (~shadow_mask) * composed_image_all + composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') masked_bg = Image.fromarray(masked_bg.astype(np.uint8)).convert('RGB') masked_bg_original = [] From eccb75345742e07ed043d72b65ce0f5ca7df76b1 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 11 Jul 2025 16:16:52 +0800 Subject: [PATCH 465/482] revise shadow mask --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index fbd977134db6..6a1d3de02d54 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -574,9 +574,9 @@ def __call__( composed_prod_images_all = Image.fromarray(composed_prod_images_all.astype(np.uint8)).convert('RGB') if shadow_mask is not None: - shadow_mask = np.all(shadow_mask == 255, axis=-1) + shadow_mask = np.all(shadow_mask > 128, axis=-1) shadow_mask = np.stack((shadow_mask,)*3, axis=-1) - tmp_img = Image.new('RGB', img.size, (0, 0, 0)) + tmp_img = Image.new('RGB', (1024,1024), (0, 0, 0)) composed_image_all = shadow_mask * tmp_img * 0.5 + shadow_mask * composed_image_all * 0.5 + (~shadow_mask) * composed_image_all composed_image_all = Image.fromarray(composed_image_all.astype(np.uint8)).convert('RGB') From 2c23f93aa9e5d05a49c014e50dfe65799e1a148c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Fri, 11 Jul 2025 16:19:27 +0800 Subject: [PATCH 466/482] fix bug --- src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 6a1d3de02d54..75f2c34167fb 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -574,7 +574,7 @@ def __call__( composed_prod_images_all = Image.fromarray(composed_prod_images_all.astype(np.uint8)).convert('RGB') if shadow_mask is not None: - shadow_mask = np.all(shadow_mask > 128, axis=-1) + shadow_mask = np.all(np.array(shadow_mask) > 128, axis=-1) shadow_mask = np.stack((shadow_mask,)*3, axis=-1) tmp_img = Image.new('RGB', (1024,1024), (0, 0, 0)) composed_image_all = shadow_mask * tmp_img * 0.5 + shadow_mask * composed_image_all * 0.5 + (~shadow_mask) * composed_image_all From e270eaaf4acc8e97267d5660de20991c8a20d59f Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 16 Jul 2025 12:48:57 +0900 Subject: [PATCH 467/482] Support Fal Kontext LoRA --- src/diffusers/loaders/lora_base.py | 8 +- .../loaders/lora_conversion_utils.py | 220 ++++++++++++++++++ src/diffusers/loaders/lora_pipeline.py | 23 +- 3 files changed, 248 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 280a9fa6e73f..473840fe38d8 100644 --- a/src/diffusers/loaders/lora_base.py +++ 
b/src/diffusers/loaders/lora_base.py @@ -25,6 +25,7 @@ from huggingface_hub.constants import HF_HUB_OFFLINE from ..models.modeling_utils import ModelMixin, load_state_dict +from ..utils.state_dict_utils import _load_sft_state_dict_metadata from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -206,6 +207,7 @@ def _fetch_state_dict( subfolder, user_agent, allow_pickle, + metadata=None, ): model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): @@ -236,11 +238,14 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") + metadata = _load_sft_state_dict_metadata(model_file) + except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e # try loading non-safetensors weights model_file = None + metadata = None pass if model_file is None: @@ -261,10 +266,11 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = load_state_dict(model_file) + metadata = None else: state_dict = pretrained_model_name_or_path_or_dict - return state_dict + return state_dict, metadata def _best_guess_weight_name( diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 7fec3299eeac..c0fb910abcf9 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1503,6 +1503,226 @@ def remap_single_transformer_blocks_(key, state_dict): return converted_state_dict +def _convert_fal_kontext_lora_to_diffusers(original_state_dict): + converted_state_dict = {} + original_state_dict_keys = list(original_state_dict.keys()) + num_layers = 19 + num_single_layers = 38 + inner_dim = 3072 + mlp_ratio = 4.0 + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + original_block_prefix = "base_model.model." 
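+        # NOTE (editor's annotation, not part of the original patch): for a LoRA pair,
+        # the fused update is lora_B @ lora_A, so splitting a fused qkv projection into
+        # separate q/k/v entries only requires chunking lora_B along dim 0; lora_A is
+        # shared across the three projections, which is why the lora_A branch below
+        # copies one weight three times while lora_B goes through
+        # torch.chunk(..., 3, dim=0). torch.cat([x]) on a single tensor is simply a copy.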
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norms
+            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+            )
+
+            # Q, K, V
+            if lora_key == "lora_A":
+                sample_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                context_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+            if f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                    original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+            if f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                    original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                )
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                )
+            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
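+        # NOTE (editor's annotation, not part of the original patch): in Flux single
+        # blocks, linear1 packs q, k, v and the MLP input projection into one matrix,
+        # so lora_B is split with sizes (inner_dim, inner_dim, inner_dim,
+        # inner_dim * mlp_ratio) = (3072, 3072, 3072, 12288), while the shared lora_A
+        # is copied to all four targets.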
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norm.linear  <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V, mlp
+            mlp_hidden_dim = int(inner_dim * mlp_ratio)
+            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+            if lora_key == "lora_A":
+                lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    lora_bias = original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias")
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+            else:
+                q, k, v, mlp = torch.split(
+                    original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+                    split_size,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                        original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+                        split_size,
+                        dim=0,
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+            # output projections.
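+            # (editor's annotation) linear2 is the single block's fused output
+            # projection (attention output and MLP activation concatenated), which
+            # maps one-to-one onto diffusers' proj_out, so no chunking is needed here.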
+            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+                )
+
+    for lora_key in ["lora_A", "lora_B"]:
+        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+            f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+        )
+        if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
 
 def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
     # Remove "diffusion_model." prefix from keys.
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index aa508cf87f40..a73e8cdec13f 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -40,6 +40,7 @@
 )
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
+    _convert_fal_kontext_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_musubi_wan_lora_to_diffusers,
@@ -2003,6 +2004,7 @@ def lora_state_dict(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -2014,7 +2016,7 @@ def lora_state_dict(
             "framework": "pytorch",
         }
 
-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -2050,8 +2052,25 @@ def lora_state_dict(
         is_bfl_control = any("query_norm.scale" in k for k in state_dict)
         if is_bfl_control:
             state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
-        return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
 
+        is_fal_kontext = any("base_model" in k for k in state_dict)
+        if is_fal_kontext:
+            state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
+
         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
         keys = list(state_dict.keys())

From 9c44182800fc9b0cb9ff5e5df7a0b109d648d346 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Wed, 16 Jul 2025 12:54:25 +0900
Subject: [PATCH 468/482] fix bug

---
 src/diffusers/utils/state_dict_utils.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index 3682c5bfacd6..bfcdfbde78e0 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -16,7 +16,7 @@
 """
 
 import enum
-
+import json
 from .import_utils import is_torch_available
 from .logging import get_logger
 
@@ -347,3 +347,18 @@ def state_dict_all_zero(state_dict, filter_str=None):
         state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
 
     return all(torch.all(param == 0).item() for param in state_dict.values())
+
+def _load_sft_state_dict_metadata(model_file: str):
+    import safetensors.torch
+
+    from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
+    with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
+        metadata = f.metadata() or {}
+
+    metadata.pop("format", None)
+    if metadata:
+        raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+        return json.loads(raw) if raw else None
+    else:
+        return None

From b5495e16203903a1ebb2fcee8982fa285f46b04d Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Wed, 16 Jul 2025 12:56:57 +0900
Subject: [PATCH 469/482] fix bug

---
 src/diffusers/loaders/lora_base.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 473840fe38d8..fe5f52c7149a 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -63,6 +63,7 @@
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
 
 
 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):

From ad94d113f6dba3ed3e26e5199320497312455770 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Wed, 16 Jul 2025 13:05:05 +0900
Subject: [PATCH 470/482] updates

---
 src/diffusers/loaders/lora_base.py            |  375 ++--
 .../loaders/lora_conversion_utils.py          |  582 ++++--
 src/diffusers/loaders/lora_pipeline.py        | 1686 ++++++++++-------
 src/diffusers/utils/state_dict_utils.py       |   10 +-
 4 files changed, 1666 insertions(+), 987 deletions(-)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index fe5f52c7149a..ff2cb6a54c1c 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
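
[Editor's note] PATCH 470 syncs the vendored loader files with upstream diffusers'
metadata-aware LoRA loading. Combined with the `return_lora_metadata` flag introduced
in PATCH 467, callers can opt into receiving the adapter metadata stored in the
safetensors header. A hedged sketch of the intended call pattern (the repo id is a
placeholder, and the exact tuple shape also depends on `return_alphas`):

    # sketch only: opt into LoRA adapter metadata alongside the state dict
    state_dict, network_alphas, metadata = FluxPipeline.lora_state_dict(
        "some-user/some-flux-lora",  # placeholder repo id
        return_alphas=True,
        return_lora_metadata=True,
    )
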
@@ -14,6 +14,7 @@ import copy import inspect +import json import os from pathlib import Path from typing import Callable, Dict, List, Optional, Union @@ -25,7 +26,6 @@ from huggingface_hub.constants import HF_HUB_OFFLINE from ..models.modeling_utils import ModelMixin, load_state_dict -from ..utils.state_dict_utils import _load_sft_state_dict_metadata from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -34,7 +34,6 @@ delete_adapter_layers, deprecate, get_adapter_name, - get_peft_kwargs, is_accelerate_available, is_peft_available, is_peft_version, @@ -46,13 +45,13 @@ set_adapter_layers, set_weights_and_activate_adapters, ) +from ..utils.peft_utils import _create_lora_config +from ..utils.state_dict_utils import _load_sft_state_dict_metadata if is_transformers_available(): from transformers import PreTrainedModel - from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules - if is_peft_available(): from peft.tuners.tuners_utils import BaseTunerLayer @@ -306,13 +305,18 @@ def _best_guess_weight_name( targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) if len(targeted_files) > 1: - raise ValueError( - f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." + logger.warning( + f"Provided path contains more than one weights file in the {file_extension} format. `{targeted_files[0]}` is going to be loaded, for precise control, specify a `weight_name` in `load_lora_weights`." ) weight_name = targeted_files[0] return weight_name +def _pack_dict_with_prefix(state_dict, prefix): + sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()} + return sd_with_prefix + + def _load_lora_into_text_encoder( state_dict, network_alphas, @@ -324,10 +328,16 @@ def _load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): + from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading + if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + if network_alphas and metadata: + raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.") + peft_kwargs = {} if low_cpu_mem_usage: if not is_peft_version(">=", "0.13.1"): @@ -342,8 +352,6 @@ def _load_lora_into_text_encoder( ) peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - from peft import LoraConfig - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as # their prefixes. @@ -355,7 +363,9 @@ def _load_lora_into_text_encoder( # Load the layers corresponding to text encoder and make necessary adjustments. 
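
[Editor's note] The lines below swap the manual slice `k[len(f"{prefix}.") :]` for
`str.removeprefix` (Python >= 3.9) and filter the optional `metadata` dict with the
same prefix rule, keeping its keys aligned with the state-dict keys. A pre-3.9
equivalent, for reference:

    # editor's sketch: str.removeprefix fallback for older Pythons
    def _removeprefix(s: str, prefix: str) -> str:
        return s[len(prefix):] if s.startswith(prefix) else s
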
if prefix is not None: - state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + if metadata is not None: + metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: logger.info(f"Loading {prefix}.") @@ -365,54 +375,27 @@ def _load_lora_into_text_encoder( # convert state dict state_dict = convert_state_dict_to_peft(state_dict) - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in state_dict: - continue - rank[rank_key] = state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in state_dict: - continue - rank[rank_key] = state_dict[rank_key].shape[1] + for name, _ in text_encoder.named_modules(): + if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")): + rank_key = f"{name}.lora_B.weight" + if rank_key in state_dict: + rank[rank_key] = state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") + network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys} - lora_config = LoraConfig(**lora_config_kwargs) + # create `LoraConfig` + lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False) # adapter_name if adapter_name is None: adapter_name = get_adapter_name(text_encoder) - is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline) - + # if prefix is not None and not state_dict: + model_class_name = text_encoder.__class__.__name__ logger.warning( - f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. " + f"No LoRA keys associated to {model_class_name} found with the {prefix=}. " "This is safe to ignore if LoRA state dict didn't originally have any " - f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` " + f"{model_class_name} related params. You can also try specifying `prefix=None` " "to resolve the warning. 
Otherwise, open an issue if you think it's unexpected: "
             "https://github.com/huggingface/diffusers/issues/new"
         )
 
 
 def _func_optionally_disable_offloading(_pipeline):
+    """
+    Optionally removes offloading in case the pipeline has already been sequentially offloaded to CPU.
+
+    Args:
+        _pipeline (`DiffusionPipeline`):
+            The pipeline to disable offloading for.
+
+    Returns:
+        tuple:
+            A tuple indicating if `is_model_cpu_offload`, `is_sequential_cpu_offload` or `is_group_offload` is True.
+    """
+    from ..hooks.group_offloading import _is_group_offload_enabled
+
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
+    is_group_offload = False
+
     if _pipeline is not None and _pipeline.hf_device_map is None:
         for _, component in _pipeline.components.items():
-            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                if not is_model_cpu_offload:
-                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = (
-                        isinstance(component._hf_hook, AlignDevicesHook)
-                        or hasattr(component._hf_hook, "hooks")
-                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                    )
+            if not isinstance(component, nn.Module):
+                continue
+            is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+            if not hasattr(component, "_hf_hook"):
+                continue
+            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+            is_sequential_cpu_offload = is_sequential_cpu_offload or (
+                isinstance(component._hf_hook, AlignDevicesHook)
+                or hasattr(component._hf_hook, "hooks")
+                and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+            )
 
-            logger.info(
-                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-            )
+        if is_sequential_cpu_offload or is_model_cpu_offload:
+            logger.info(
+                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+            )
+            for _, component in _pipeline.components.items():
+                if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+                    continue
                 remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
-    return (is_model_cpu_offload, is_sequential_cpu_offload)
+    return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
 
 class LoraBaseMixin:
     """Utility class for handling LoRAs."""
 
     _lora_loadable_modules = []
-    num_fused_loras = 0
+    _merged_adapters = set()
+
+    @property
+    def lora_scale(self) -> float:
+        """
+        Returns the LoRA scale, which can be set at runtime by the pipeline. If `_lora_scale` has not been set,
+        returns 1.0.
+ """ + return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 + + @property + def num_fused_loras(self): + """Returns the number of LoRAs that have been fused.""" + return len(self._merged_adapters) + + @property + def fused_loras(self): + """Returns names of the LoRAs that have been fused.""" + return self._merged_adapters def load_lora_weights(self, **kwargs): raise NotImplementedError("`load_lora_weights()` is not implemented.") @@ -485,33 +510,6 @@ def save_lora_weights(cls, **kwargs): def lora_state_dict(cls, **kwargs): raise NotImplementedError("`lora_state_dict()` is not implemented.") - @classmethod - def _optionally_disable_offloading(cls, _pipeline): - """ - Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. - - Args: - _pipeline (`DiffusionPipeline`): - The pipeline to disable offloading for. - - Returns: - tuple: - A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. - """ - return _func_optionally_disable_offloading(_pipeline=_pipeline) - - @classmethod - def _fetch_state_dict(cls, *args, **kwargs): - deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`." - deprecate("_fetch_state_dict", "0.35.0", deprecation_message) - return _fetch_state_dict(*args, **kwargs) - - @classmethod - def _best_guess_weight_name(cls, *args, **kwargs): - deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`." - deprecate("_best_guess_weight_name", "0.35.0", deprecation_message) - return _best_guess_weight_name(*args, **kwargs) - def unload_lora_weights(self): """ Unloads the LoRA parameters. @@ -599,6 +597,9 @@ def fuse_lora( if len(components) == 0: raise ValueError("`components` cannot be an empty list.") + # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it + # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`. + merged_adapter_names = set() for fuse_component in components: if fuse_component not in self._lora_loadable_modules: raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.") @@ -608,13 +609,19 @@ def fuse_lora( # check if diffusers model if issubclass(model.__class__, ModelMixin): model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + merged_adapter_names.update(set(module.merged_adapters)) # handle transformers models. 
if issubclass(model.__class__, PreTrainedModel): fuse_text_encoder_lora( model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + merged_adapter_names.update(set(module.merged_adapters)) - self.num_fused_loras += 1 + self._merged_adapters = self._merged_adapters | merged_adapter_names def unfuse_lora(self, components: List[str] = [], **kwargs): r""" @@ -668,15 +675,42 @@ def unfuse_lora(self, components: List[str] = [], **kwargs): if issubclass(model.__class__, (ModelMixin, PreTrainedModel)): for module in model.modules(): if isinstance(module, BaseTunerLayer): + for adapter in set(module.merged_adapters): + if adapter and adapter in self._merged_adapters: + self._merged_adapters = self._merged_adapters - {adapter} module.unmerge() - self.num_fused_loras -= 1 - def set_adapters( self, adapter_names: Union[List[str], str], adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None, ): + """ + Set the currently active adapters for use in the pipeline. + + Args: + adapter_names (`List[str]` or `str`): + The names of the adapters to use. + adapter_weights (`Union[List[float], float]`, *optional*): + The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the + adapters. + + Example: + + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights( + "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" + ) + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5]) + ``` + """ if isinstance(adapter_weights, dict): components_passed = set(adapter_weights.keys()) lora_components = set(self._lora_loadable_modules) @@ -746,6 +780,24 @@ def set_adapters( set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component]) def disable_lora(self): + """ + Disables the active LoRA layers of the pipeline. + + Example: + + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights( + "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" + ) + pipeline.disable_lora() + ``` + """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -758,6 +810,24 @@ def disable_lora(self): disable_lora_for_text_encoder(model) def enable_lora(self): + """ + Enables the active LoRA layers of the pipeline. 
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.enable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
@@ -771,10 +841,26 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
         """
+        Delete an adapter's LoRA layers from the pipeline.
+
         Args:
-        Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
             adapter_names (`Union[List[str], str]`):
-                The names of the adapter to delete. Can be a single string or a list of strings
+                The names of the adapters to delete.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.delete_adapters("cinematic")
+        ```
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
@@ -851,6 +937,27 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
         Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in
         case you want to load multiple adapters and free some GPU memory.
 
+        After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
+        can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
+        GPU before using those LoRA adapters for inference.
+
+        ```python
+        >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+        >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+        >>> pipe.set_adapters("adapter-1")
+        >>> image_1 = pipe(**kwargs)
+        >>> # switch to adapter-2, offload adapter-1
+        >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+        >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-2")
+        >>> image_2 = pipe(**kwargs)
+        >>> # switch back to adapter-1, offload adapter-2
+        >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+        >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-1")
+        >>> ...
+        ```
+
         Args:
             adapter_names (`List[str]`):
                 List of adapters to send device to.
@@ -866,6 +973,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
                 for module in model.modules():
                     if isinstance(module, BaseTunerLayer):
                         for adapter_name in adapter_names:
+                            if adapter_name not in module.lora_A:
+                                # it is sufficient to check lora_A
+                                continue
+
                             module.lora_A[adapter_name].to(device)
                             module.lora_B[adapter_name].to(device)
                             # this is a param, not a module, so device placement is not in-place -> re-assign
@@ -875,11 +986,28 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
                                     adapter_name
                                 ].to(device)
 
+    def enable_lora_hotswap(self, **kwargs) -> None:
+        """
+        Hotswap adapters without triggering recompilation of the model, even if the ranks of the loaded adapters
+        differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle a model that is already compiled. The options are:
+                  - "error" (default): raise an error
+                  - "warn": issue a warning
+                  - "ignore": do nothing
+        """
+        for key, component in self.components.items():
+            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+                component.enable_lora_hotswap(**kwargs)
+
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
-        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
-        return layers_state_dict
+        return _pack_dict_with_prefix(layers_weights, prefix)
 
     @staticmethod
     def write_lora_layers(
@@ -889,16 +1017,33 @@ def write_lora_layers(
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        lora_adapter_metadata: Optional[dict] = None,
     ):
+        """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        if lora_adapter_metadata and not safe_serialization:
+            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+        if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+            raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
         if save_function is None:
             if safe_serialization:
 
                 def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                    # Inject framework format.
+                    metadata = {"format": "pt"}
+                    if lora_adapter_metadata:
+                        for key, value in lora_adapter_metadata.items():
+                            if isinstance(value, set):
+                                lora_adapter_metadata[key] = list(value)
+                        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+                            lora_adapter_metadata, indent=2, sort_keys=True
+                        )
+
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)
 
             else:
                 save_function = torch.save
@@ -915,28 +1060,6 @@ def save_function(weights, filename):
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")
 
-    @property
-    def lora_scale(self) -> float:
-        # property function that returns the lora scale which can be set at run time by the pipeline.
-        # if _lora_scale has not been set, return 1
-        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
-
-    def enable_lora_hotswap(self, **kwargs) -> None:
-        """Enables the possibility to hotswap LoRA adapters.
-
-        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
-        the loaded adapters differ.
-
-        Args:
-            target_rank (`int`):
-                The highest rank among all the adapters that will be loaded.
-            check_compiled (`str`, *optional*, defaults to `"error"`):
-                How to handle the case when the model is already compiled, which should generally be avoided.
The - options are: - - "error" (default): raise an error - - "warn": issue a warning - - "ignore": do nothing - """ - for key, component in self.components.items(): - if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules): - component.enable_lora_hotswap(**kwargs) + @classmethod + def _optionally_disable_offloading(cls, _pipeline): + return _func_optionally_disable_offloading(_pipeline=_pipeline) \ No newline at end of file diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index c0fb910abcf9..1ef7f49a8e9b 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -433,7 +433,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] if not is_sparse: # down_weight is copied to each split - ait_sd.update({k: down_weight for k in ait_down_keys}) + ait_sd.update(dict.fromkeys(ait_down_keys, down_weight)) # up_weight is split to each split ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 @@ -727,8 +727,25 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict): elif k.startswith("lora_te1_"): has_te_keys = True continue + elif k.startswith("lora_transformer_context_embedder"): + diffusers_key = "context_embedder" + elif k.startswith("lora_transformer_norm_out_linear"): + diffusers_key = "norm_out.linear" + elif k.startswith("lora_transformer_proj_out"): + diffusers_key = "proj_out" + elif k.startswith("lora_transformer_x_embedder"): + diffusers_key = "x_embedder" + elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}" + elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.text_embedder.linear_{i}" + elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}" else: - raise NotImplementedError + raise NotImplementedError(f"Handling for key ({k}) is not implemented.") if "attn_" in k: if "_to_out_0" in k: @@ -819,7 +836,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict): if zero_status_pe: logger.info( "The `position_embedding` LoRA params are all zeros which make them ineffective. " - "So, we will purge them out of the curret state dict to make loading possible." + "So, we will purge them out of the current state dict to make loading possible." ) else: @@ -835,7 +852,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict): if zero_status_t5: logger.info( "The `t5xxl` LoRA params are all zeros which make them ineffective. " - "So, we will purge them out of the curret state dict to make loading possible." + "So, we will purge them out of the current state dict to make loading possible." 
) else: logger.info( @@ -850,7 +867,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict): if zero_status_diff_b: logger.info( "The `diff_b` LoRA params are all zeros which make them ineffective. " - "So, we will purge them out of the curret state dict to make loading possible." + "So, we will purge them out of the current state dict to make loading possible." ) else: logger.info( @@ -866,7 +883,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict): if zero_status_diff: logger.info( "The `diff` LoRA params are all zeros which make them ineffective. " - "So, we will purge them out of the curret state dict to make loading possible." + "So, we will purge them out of the current state dict to make loading possible." ) else: logger.info( @@ -923,7 +940,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] # down_weight is copied to each split - ait_sd.update({k: down_weight for k in ait_down_keys}) + ait_sd.update(dict.fromkeys(ait_down_keys, down_weight)) # up_weight is split to each split ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 @@ -1237,7 +1254,7 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): f"double_blocks.{i}.txt_attn.norm.key_norm.scale" ) - # single transfomer blocks + # single transformer blocks for i in range(num_single_layers): block_prefix = f"single_transformer_blocks.{i}." @@ -1329,180 +1346,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): return converted_state_dict -def _convert_hunyuan_video_lora_to_diffusers(original_state_dict): - converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())} - - def remap_norm_scale_shift_(key, state_dict): - weight = state_dict.pop(key) - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight - - def remap_txt_in_(key, state_dict): - def rename_key(key): - new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") - new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") - new_key = new_key.replace("txt_in", "context_embedder") - new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") - new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") - new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") - new_key = new_key.replace("mlp", "ff") - return new_key - - if "self_attn_qkv" in key: - weight = state_dict.pop(key) - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v - else: - state_dict[rename_key(key)] = state_dict.pop(key) - - def remap_img_attn_qkv_(key, state_dict): - weight = state_dict.pop(key) - if "lora_A" in key: - state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight - state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight - state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight - else: - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q - state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k - state_dict[key.replace("img_attn_qkv", 
"attn.to_v")] = to_v - - def remap_txt_attn_qkv_(key, state_dict): - weight = state_dict.pop(key) - if "lora_A" in key: - state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight - state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight - state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight - else: - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q - state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k - state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v - - def remap_single_transformer_blocks_(key, state_dict): - hidden_size = 3072 - - if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key: - linear1_weight = state_dict.pop(key) - if "lora_A" in key: - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( - ".linear1.lora_A.weight" - ) - state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight - state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight - state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight - state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight - else: - split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) - q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( - ".linear1.lora_B.weight" - ) - state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q - state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k - state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v - state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp - - elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key: - linear1_bias = state_dict.pop(key) - if "lora_A" in key: - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( - ".linear1.lora_A.bias" - ) - state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias - state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias - state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias - state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias - else: - split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) - q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( - ".linear1.lora_B.bias" - ) - state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias - state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias - state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias - state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias - - else: - new_key = key.replace("single_blocks", "single_transformer_blocks") - new_key = new_key.replace("linear2", "proj_out") - new_key = new_key.replace("q_norm", "attn.norm_q") - new_key = new_key.replace("k_norm", "attn.norm_k") - state_dict[new_key] = state_dict.pop(key) - - TRANSFORMER_KEYS_RENAME_DICT = { - "img_in": "x_embedder", - "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", - "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", - "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", - "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", - "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", - "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", - "double_blocks": "transformer_blocks", - "img_attn_q_norm": "attn.norm_q", - 
"img_attn_k_norm": "attn.norm_k", - "img_attn_proj": "attn.to_out.0", - "txt_attn_q_norm": "attn.norm_added_q", - "txt_attn_k_norm": "attn.norm_added_k", - "txt_attn_proj": "attn.to_add_out", - "img_mod.linear": "norm1.linear", - "img_norm1": "norm1.norm", - "img_norm2": "norm2", - "img_mlp": "ff", - "txt_mod.linear": "norm1_context.linear", - "txt_norm1": "norm1.norm", - "txt_norm2": "norm2_context", - "txt_mlp": "ff_context", - "self_attn_proj": "attn.to_out.0", - "modulation.linear": "norm.linear", - "pre_norm": "norm.norm", - "final_layer.norm_final": "norm_out.norm", - "final_layer.linear": "proj_out", - "fc1": "net.0.proj", - "fc2": "net.2", - "input_embedder": "proj_in", - } - - TRANSFORMER_SPECIAL_KEYS_REMAP = { - "txt_in": remap_txt_in_, - "img_attn_qkv": remap_img_attn_qkv_, - "txt_attn_qkv": remap_txt_attn_qkv_, - "single_blocks": remap_single_transformer_blocks_, - "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, - } - - # Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys - # and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make - # sure that both follow the same initial format by stripping off the "transformer." prefix. - for key in list(converted_state_dict.keys()): - if key.startswith("transformer."): - converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key) - if key.startswith("diffusion_model."): - converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key) - - # Rename and remap the state dict keys - for key in list(converted_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - converted_state_dict[new_key] = converted_state_dict.pop(key) - - for key in list(converted_state_dict.keys()): - for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, converted_state_dict) - - # Add back the "transformer." 
prefix - for key in list(converted_state_dict.keys()): - converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) - - return converted_state_dict - def _convert_fal_kontext_lora_to_diffusers(original_state_dict): converted_state_dict = {} original_state_dict_keys = list(original_state_dict.keys()) @@ -1724,6 +1567,182 @@ def _convert_fal_kontext_lora_to_diffusers(original_state_dict): return converted_state_dict + +def _convert_hunyuan_video_lora_to_diffusers(original_state_dict): + converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())} + + def remap_norm_scale_shift_(key, state_dict): + weight = state_dict.pop(key) + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight + + def remap_txt_in_(key, state_dict): + def rename_key(key): + new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") + new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") + new_key = new_key.replace("txt_in", "context_embedder") + new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") + new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") + new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") + new_key = new_key.replace("mlp", "ff") + return new_key + + if "self_attn_qkv" in key: + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v + else: + state_dict[rename_key(key)] = state_dict.pop(key) + + def remap_img_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + if "lora_A" in key: + state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight + state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight + state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight + else: + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q + state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k + state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v + + def remap_txt_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + if "lora_A" in key: + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight + else: + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v + + def remap_single_transformer_blocks_(key, state_dict): + hidden_size = 3072 + + if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key: + linear1_weight = state_dict.pop(key) + if "lora_A" in key: + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( + ".linear1.lora_A.weight" + ) + state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight + state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight + state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight + state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight + 
else: + split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) + q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( + ".linear1.lora_B.weight" + ) + state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q + state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k + state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v + state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp + + elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key: + linear1_bias = state_dict.pop(key) + if "lora_A" in key: + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( + ".linear1.lora_A.bias" + ) + state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias + state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias + state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias + state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias + else: + split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) + q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( + ".linear1.lora_B.bias" + ) + state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias + state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias + state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias + state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias + + else: + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("linear2", "proj_out") + new_key = new_key.replace("q_norm", "attn.norm_q") + new_key = new_key.replace("k_norm", "attn.norm_k") + state_dict[new_key] = state_dict.pop(key) + + TRANSFORMER_KEYS_RENAME_DICT = { + "img_in": "x_embedder", + "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", + "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", + "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", + "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "double_blocks": "transformer_blocks", + "img_attn_q_norm": "attn.norm_q", + "img_attn_k_norm": "attn.norm_k", + "img_attn_proj": "attn.to_out.0", + "txt_attn_q_norm": "attn.norm_added_q", + "txt_attn_k_norm": "attn.norm_added_k", + "txt_attn_proj": "attn.to_add_out", + "img_mod.linear": "norm1.linear", + "img_norm1": "norm1.norm", + "img_norm2": "norm2", + "img_mlp": "ff", + "txt_mod.linear": "norm1_context.linear", + "txt_norm1": "norm1.norm", + "txt_norm2": "norm2_context", + "txt_mlp": "ff_context", + "self_attn_proj": "attn.to_out.0", + "modulation.linear": "norm.linear", + "pre_norm": "norm.norm", + "final_layer.norm_final": "norm_out.norm", + "final_layer.linear": "proj_out", + "fc1": "net.0.proj", + "fc2": "net.2", + "input_embedder": "proj_in", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = { + "txt_in": remap_txt_in_, + "img_attn_qkv": remap_img_attn_qkv_, + "txt_attn_qkv": remap_txt_attn_qkv_, + "single_blocks": remap_single_transformer_blocks_, + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, + } + + # Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys + # and use their custom code. 
To make sure both "original" and "attempted diffusers" loras work as expected, we make + # sure that both follow the same initial format by stripping off the "transformer." prefix. + for key in list(converted_state_dict.keys()): + if key.startswith("transformer."): + converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key) + if key.startswith("diffusion_model."): + converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key) + + # Rename and remap the state dict keys + for key in list(converted_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + # Add back the "transformer." prefix + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict + + def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict): # Remove "diffusion_model." prefix from keys. state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} @@ -1799,48 +1818,175 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict = {} original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict}) + block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")} + min_block = min(block_numbers) + max_block = max(block_numbers) + is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) + lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down" + lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up" + has_time_projection_weight = any( + k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict + ) - for i in range(num_blocks): + for key in list(original_state_dict.keys()): + if key.endswith((".diff", ".diff_b")) and "norm" in key: + # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it + # in future if needed and they are not zeroed. + original_state_dict.pop(key) + logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.") + + if "time_projection" in key and not has_time_projection_weight: + # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where + # our lora config adds the time proj lora layers, but we don't have the weights for them. + # CausVid lora has the weight keys and the bias keys. + original_state_dict.pop(key) + + # For the `diff_b` keys, we treat them as lora_bias. 
+ # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias + + for i in range(min_block, max_block + 1): # Self-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): - converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.lora_A.weight" - ) - converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.lora_B.weight" - ) + original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight" + converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight" + converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.self_attn.{o}.diff_b" + converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) # Cross-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_A.weight" - ) - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_B.weight" - ) + original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.diff_b" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) if is_i2v_lora: for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_A.weight" - ) - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_B.weight" - ) + original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.cross_attn.{o}.diff_b" + converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) # FFN for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): - converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop( - 
f"blocks.{i}.{o}.lora_A.weight" - ) - converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.{o}.lora_B.weight" - ) + original_key = f"blocks.{i}.{o}.{lora_down_key}.weight" + converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.{o}.{lora_up_key}.weight" + converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"blocks.{i}.{o}.diff_b" + converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + # Remaining. + if original_state_dict: + if any("time_projection" in k for k in original_state_dict): + original_key = f"time_projection.1.{lora_down_key}.weight" + converted_key = "condition_embedder.time_proj.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"time_projection.1.{lora_up_key}.weight" + converted_key = "condition_embedder.time_proj.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + if "time_projection.1.diff_b" in original_state_dict: + converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop( + "time_projection.1.diff_b" + ) + + if any("head.head" in k for k in state_dict): + converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop( + f"head.head.{lora_down_key}.weight" + ) + converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight") + if "head.head.diff_b" in original_state_dict: + converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b") + + for text_time in ["text_embedding", "time_embedding"]: + if any(text_time in k for k in original_state_dict): + for b_n in [0, 2]: + diffusers_b_n = 1 if b_n == 0 else 2 + diffusers_name = ( + "condition_embedder.text_embedder" + if text_time == "text_embedding" + else "condition_embedder.time_embedder" + ) + if any(f"{text_time}.{b_n}" in k for k in original_state_dict): + converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = ( + original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight") + ) + converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = ( + original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight") + ) + if f"{text_time}.{b_n}.diff_b" in original_state_dict: + converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = ( + original_state_dict.pop(f"{text_time}.{b_n}.diff_b") + ) + + for img_ours, img_theirs in [ + ("ff.net.0.proj", "img_emb.proj.1"), + ("ff.net.2", "img_emb.proj.3"), + ]: + original_key = f"{img_theirs}.{lora_down_key}.weight" + converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = original_state_dict.pop(original_key) + + original_key = f"{img_theirs}.{lora_up_key}.weight" + converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight" + if original_key in original_state_dict: + converted_state_dict[converted_key] = 
original_state_dict.pop(original_key)
 
     if len(original_state_dict) > 0:
-        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+        diff = all(".diff" in k for k in original_state_dict)
+        if diff:
+            diff_keys = {k for k in original_state_dict if k.endswith(".diff")}
+            if any("lora" in k for k in diff_keys):
+                raise ValueError(f"Found unexpected LoRA params among the remaining `diff` keys: {diff_keys}.")
+            logger.info(
+                "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: "
+                "https://github.com/huggingface/diffusers/issues/new"
+            )
+        else:
+            raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
 
     for key in list(converted_state_dict.keys()):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
@@ -1907,3 +2053,19 @@ def get_alpha_scales(down_weight, key):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
+    if not all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict):
+        raise ValueError("Invalid LoRA state dict for HiDream.")
+    converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
+
+
+def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
+    if not all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict):
+        raise ValueError("Invalid LoRA state dict for LTX-Video.")
+    converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
\ No newline at end of file
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index a73e8cdec13f..f645ef16eb7c 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
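The HiDream and LTX-Video converters added above both reduce to the same two-step rewrite: validate and strip the `diffusion_model.` prefix, then re-root every key under `transformer.`. A minimal sketch of that rewrite, using an illustrative key rather than one taken from a real checkpoint:

```python
# Illustrative key only; a real checkpoint holds many such LoRA entries.
state_dict = {"diffusion_model.blocks.0.attn.to_q.lora_A.weight": 0.0}

prefix = "diffusion_model"
if not all(k.startswith(f"{prefix}.") for k in state_dict):
    # Raising early keeps malformed checkpoints from silently producing
    # an empty converted state dict.
    raise ValueError("Invalid non-diffusers LoRA state dict.")

converted = {f"transformer.{k.removeprefix(f'{prefix}.')}": v for k, v in state_dict.items()}
assert "transformer.blocks.0.attn.to_q.lora_A.weight" in converted
```

Because the whole dict is rebuilt in a single comprehension, the conversion never mutates its input and stays linear in the number of keys.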
@@ -37,6 +37,7 @@ LoraBaseMixin, _fetch_state_dict, _load_lora_into_text_encoder, + _pack_dict_with_prefix, ) from .lora_conversion_utils import ( _convert_bfl_flux_control_lora_to_diffusers, @@ -44,7 +45,9 @@ _convert_hunyuan_video_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, + _convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, + _convert_non_diffusers_ltxv_lora_to_diffusers, _convert_non_diffusers_lumina2_lora_to_diffusers, _convert_non_diffusers_wan_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, @@ -80,30 +83,36 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module): from ..quantizers.gguf.utils import dequantize_gguf_tensor is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit" + is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params" is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter" if is_bnb_4bit_quantized and not is_bitsandbytes_available(): raise ValueError( "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints." ) + if is_bnb_8bit_quantized and not is_bitsandbytes_available(): + raise ValueError( + "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints." + ) if is_gguf_quantized and not is_gguf_available(): raise ValueError( "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints." ) weight_on_cpu = False - if not module.weight.is_cuda: + if module.weight.device.type == "cpu": weight_on_cpu = True - if is_bnb_4bit_quantized: + device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" + if is_bnb_4bit_quantized or is_bnb_8bit_quantized: module_weight = dequantize_bnb_weight( - module.weight.cuda() if weight_on_cpu else module.weight, - state=module.weight.quant_state, + module.weight.to(device) if weight_on_cpu else module.weight, + state=module.weight.quant_state if is_bnb_4bit_quantized else module.state, dtype=model.dtype, ).data elif is_gguf_quantized: module_weight = dequantize_gguf_tensor( - module.weight.cuda() if weight_on_cpu else module.weight, + module.weight.to(device) if weight_on_cpu else module.weight, ) module_weight = module_weight.to(model.dtype) else: @@ -128,7 +137,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name=None, + adapter_name: Optional[str] = None, hotswap: bool = False, **kwargs, ): @@ -155,7 +164,7 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) + hotswap (`bool`, *optional*): Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. This means that, instead of loading an additional adapter, this will take the existing adapter weights and replace them with the weights of the new adapter. This can be faster and more @@ -195,7 +204,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
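The hunks below switch `load_lora_weights` to the three-tuple form of `lora_state_dict`. A sketch of the resulting calling convention, reusing the repositories from the docstring examples above; treat it as illustrative rather than canonical:

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# With the new flag, the usual (state_dict, network_alphas) 2-tuple becomes a 3-tuple.
state_dict, network_alphas, metadata = pipe.lora_state_dict(
    "nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", return_lora_metadata=True
)

# `metadata` (None for checkpoints saved without it) is then forwarded to
# `load_lora_into_unet(..., metadata=metadata)`, where `_create_lora_config`
# builds the peft `LoraConfig` from it instead of re-deriving ranks and
# alphas from tensor shapes.
```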
- state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -206,6 +216,7 @@ def load_lora_weights( network_alphas=network_alphas, unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -219,6 +230,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + metadata=metadata, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -275,6 +287,8 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -288,18 +302,16 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -336,7 +348,8 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - return state_dict, network_alphas + out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas) + return out @classmethod def load_lora_into_unet( @@ -348,6 +361,7 @@ def load_lora_into_unet( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -369,29 +383,11 @@ def load_lora_into_unet( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. 
scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -410,6 +406,7 @@ def load_lora_into_unet( prefix=cls.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -427,6 +424,7 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -452,29 +450,11 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" _load_lora_into_text_encoder( state_dict=state_dict, @@ -484,6 +464,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -499,6 +480,8 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + unet_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -521,8 +504,13 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + unet_lora_adapter_metadata: + LoRA adapter metadata associated with the unet to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not (unet_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") @@ -533,6 +521,14 @@ def save_lora_weights( if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + if unet_lora_adapter_metadata: + lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -541,6 +537,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -626,6 +623,7 @@ def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name: Optional[str] = None, + hotswap: bool = False, **kwargs, ): """ @@ -652,6 +650,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -673,7 +673,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict, network_alphas = self.lora_state_dict( + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, **kwargs, @@ -688,8 +689,10 @@ def load_lora_weights( network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) self.load_lora_into_text_encoder( state_dict, @@ -698,8 +701,10 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) self.load_lora_into_text_encoder( state_dict, @@ -708,8 +713,10 @@ def load_lora_weights( prefix=f"{self.text_encoder_name}_2", lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -765,6 +772,8 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -778,18 +787,16 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -826,7 +833,8 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - return state_dict, network_alphas + out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas) + return out @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet @@ -839,6 +847,7 @@ def load_lora_into_unet( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -860,29 +869,11 @@ def load_lora_into_unet( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. 
However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -901,6 +892,7 @@ def load_lora_into_unet( prefix=cls.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -919,6 +911,7 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -944,29 +937,11 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" _load_lora_into_text_encoder( state_dict=state_dict, @@ -976,6 +951,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -992,6 +968,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + unet_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, + text_encoder_2_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -1017,8 +996,15 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + unet_lora_adapter_metadata: + LoRA adapter metadata associated with the unet to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. + text_encoder_2_lora_adapter_metadata: + LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -1034,6 +1020,19 @@ def save_lora_weights( if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + if unet_lora_adapter_metadata is not None: + lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) + + if text_encoder_2_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") + ) + cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, @@ -1041,6 +1040,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -1174,6 +1174,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. 
""" # Load the main state dict first which has the LoRA layers for either of @@ -1187,18 +1189,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1219,7 +1219,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out def load_lora_weights( self, @@ -1249,29 +1250,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -1289,7 +1269,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -1299,6 +1280,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1310,6 +1292,7 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1321,6 +1304,7 @@ def load_lora_weights( prefix=f"{self.text_encoder_name}_2", lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1328,7 +1312,14 @@ def load_lora_weights( @classmethod def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1346,29 +1337,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
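The `save_lora_weights` variants in this patch namespace each component's metadata with `_pack_dict_with_prefix` before serialization. The helper's implementation is not shown in this diff; a sketch of its assumed behavior (simple key prefixing):

```py
# Assumed behavior of the `_pack_dict_with_prefix` helper used below; not the
# actual implementation from the library.
def _pack_dict_with_prefix(d: dict, prefix: str) -> dict:
    return {f"{prefix}.{key}": value for key, value in d.items()}

_pack_dict_with_prefix({"r": 16}, "transformer")  # -> {"transformer.r": 16}
```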
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1381,6 +1354,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1399,6 +1373,7 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1424,29 +1399,11 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -1456,6 +1413,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1473,6 +1431,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, + text_encoder_2_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -1498,8 +1459,15 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. 
+ text_encoder_2_lora_adapter_metadata: + LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -1515,6 +1483,21 @@ def save_lora_weights( if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) + + if text_encoder_2_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") + ) + cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, @@ -1522,6 +1505,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer @@ -1653,6 +1637,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -1666,18 +1652,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1698,11 +1682,16 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -1720,6 +1709,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -1737,7 +1728,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -1747,14 +1739,23 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1772,29 +1773,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1807,6 +1790,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1822,9 +1806,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -1841,14 +1826,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -1858,6 +1850,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -1991,7 +1984,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -2011,10 +2005,7 @@ def lora_state_dict( use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -2041,13 +2032,25 @@ def lora_state_dict( if is_kohya: state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) # Kohya already takes care of scaling the LoRA parameters with alpha. - return (state_dict, None) if return_alphas else state_dict + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) is_xlabs = any("processor" in k for k in state_dict) if is_xlabs: state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) # xlabs doesn't use `alpha`. 
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
 
         is_bfl_control = any("query_norm.scale" in k for k in state_dict)
         if is_bfl_control:
@@ -2070,7 +2073,7 @@ def lora_state_dict(
                 return_alphas=return_alphas,
                 return_metadata=return_lora_metadata,
             )
-
+
         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
         keys = list(state_dict.keys())
@@ -2087,15 +2090,21 @@ def lora_state_dict(
                     f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue."
                 )
 
-        if return_alphas:
-            return state_dict, network_alphas
-        else:
+        if return_alphas or return_lora_metadata:
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=network_alphas,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
+        else:
             return state_dict
 
     def load_lora_weights(
         self,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
-        adapter_name=None,
+        adapter_name: Optional[str] = None,
         hotswap: bool = False,
         **kwargs,
     ):
@@ -2114,34 +2123,16 @@ def load_lora_weights(
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
-            kwargs (`dict`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
-            hotswap : (`bool`, *optional*)
-                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
-                in-place. This means that, instead of loading an additional adapter, this will take the existing
-                adapter weights and replace them with the weights of the new adapter. This can be faster and more
-                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
-                torch.compile, loading the new adapter does not require recompilation of the model. When using
-                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new
-                adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an
-                additional method before loading the adapter:
-                ```py
-                pipeline = ...  # load diffusers pipeline
-                max_rank = ...  # the highest rank among all LoRAs that you want to load
-                # call *before* compiling and loading the LoRA adapter
-                pipeline.enable_lora_hotswap(target_rank=max_rank)
-                pipeline.load_lora_weights(file_name)
-                # optionally compile the model now
-                ```
-                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
-                limitations to this technique, which are documented here:
-                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
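With the long inline explanation now consolidated under `StableDiffusionLoraLoaderMixin.load_lora_weights`, a compact hotswap sketch for reference (repo ids and the adapter name are placeholders):

```py
# Load an adapter, then swap its weights in place instead of adding a second one;
# with torch.compile this avoids recompiling the model.
pipe.load_lora_weights("user/flux-lora-a", adapter_name="default_0")
pipe.load_lora_weights("user/flux-lora-b", adapter_name="default_0", hotswap=True)
```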
""" if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -2157,7 +2148,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict( + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) @@ -2208,6 +2200,7 @@ def load_lora_weights( network_alphas=network_alphas, transformer=transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2227,6 +2220,7 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2239,6 +2233,7 @@ def load_lora_into_transformer( network_alphas, transformer, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -2263,29 +2258,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2298,6 +2275,7 @@ def load_lora_into_transformer( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2315,7 +2293,7 @@ def _load_norm_into_transformer( prefix = prefix or cls.transformer_name for key in list(state_dict.keys()): if key.split(".")[0] == prefix: - state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) + state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key) # Find invalid keys transformer_state_dict = transformer.state_dict() @@ -2370,6 +2348,7 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2395,29 +2374,11 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -2427,6 +2388,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2443,6 +2405,8 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -2465,8 +2429,13 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
+ transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") @@ -2477,6 +2446,16 @@ def save_lora_weights( if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -2485,6 +2464,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -2646,7 +2626,7 @@ def _maybe_expand_transformer_param_shape_or_error_( ) -> bool: """ Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and - generalizes things a bit so that any parameter that needs expansion receives appropriate treatement. + generalizes things a bit so that any parameter that needs expansion receives appropriate treatment. """ state_dict = {} if lora_state_dict is not None: @@ -2658,7 +2638,7 @@ def _maybe_expand_transformer_param_shape_or_error_( prefix = prefix or cls.transformer_name for key in list(state_dict.keys()): if key.split(".")[0] == prefix: - state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) + state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key) # Expand transformer parameter shapes if they don't match lora has_param_with_shape_update = False @@ -2776,14 +2756,13 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): if unexpected_modules: logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") - is_peft_loaded = getattr(transformer, "peft_config", None) is not None for k in lora_module_names: if k in unexpected_modules: continue base_param_name = ( f"{k.replace(prefix, '')}.base_layer.weight" - if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict + if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict else f"{k.replace(prefix, '')}.weight" ) base_weight_param = transformer_state_dict[base_param_name] @@ -2837,6 +2816,15 @@ def _get_weight_shape(weight: torch.Tensor): raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.") + @staticmethod + def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False): + outputs = [state_dict] + if return_alphas: + outputs.append(alphas) + if return_metadata: + outputs.append(metadata) + return tuple(outputs) if (return_alphas or return_metadata) else state_dict + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. 
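The `_prepare_outputs` staticmethod added above centralizes the variable-arity returns of `lora_state_dict`. A quick sketch of the resulting shapes, assuming it lands on the Flux mixin (constructed for illustration, not taken from the patch):

```py
sd = {"transformer.x.lora_A.weight": None}  # placeholder state dict
FluxLoraLoaderMixin._prepare_outputs(sd, metadata=None)                                 # -> sd
FluxLoraLoaderMixin._prepare_outputs(sd, metadata=None, alphas={}, return_alphas=True)  # -> (sd, {})
FluxLoraLoaderMixin._prepare_outputs(sd, metadata={"r": 8}, return_metadata=True)       # -> (sd, {"r": 8})
```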
@@ -2853,6 +2841,7 @@ def load_lora_into_transformer( network_alphas, transformer, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -2877,29 +2866,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2912,6 +2883,7 @@ def load_lora_into_transformer( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2930,6 +2902,7 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2955,29 +2928,11 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... 
# load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -2987,6 +2942,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3106,6 +3062,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -3119,18 +3077,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3151,10 +3107,15 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -3172,6 +3133,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -3189,7 +3152,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3199,14 +3163,23 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3224,29 +3197,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3259,6 +3214,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3274,9 +3230,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. 
Arguments: save_directory (`str` or `os.PathLike`): @@ -3293,14 +3250,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -3310,6 +3274,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -3436,6 +3401,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -3449,18 +3416,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3481,11 +3446,16 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -3503,6 +3473,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. 
""" @@ -3520,7 +3492,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3530,14 +3503,23 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3555,29 +3537,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3590,6 +3554,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3605,9 +3570,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -3624,14 +3590,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -3641,6 +3614,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -3720,7 +3694,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -3769,7 +3742,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. 
@@ -3782,18 +3756,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3814,11 +3786,20 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict) + if is_non_diffusers_format: + state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict) + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -3836,6 +3817,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -3853,7 +3836,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3863,14 +3847,23 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. 
@@ -3888,29 +3881,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3923,6 +3898,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3938,9 +3914,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -3957,14 +3934,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
""" state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -3974,6 +3958,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4102,6 +4087,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -4115,18 +4102,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4147,11 +4132,16 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -4169,6 +4159,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -4186,7 +4178,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4196,14 +4189,23 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4221,29 +4223,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4256,6 +4240,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4271,9 +4256,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. 
Arguments: save_directory (`str` or `os.PathLike`): @@ -4290,14 +4276,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -4307,6 +4300,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4434,7 +4428,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -4447,18 +4442,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4483,11 +4476,16 @@ def lora_state_dict( if is_original_hunyuan_video: state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -4505,6 +4503,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. 
""" @@ -4522,7 +4522,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4532,14 +4533,23 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4557,29 +4567,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4592,6 +4584,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4607,9 +4600,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -4626,14 +4620,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -4643,6 +4644,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4770,7 +4772,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. 
@@ -4783,18 +4786,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4820,11 +4821,16 @@ def lora_state_dict( if non_diffusers: state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -4842,6 +4848,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -4859,7 +4867,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4869,14 +4878,23 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4894,29 +4912,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. 
This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4929,6 +4929,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4944,9 +4945,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -4963,14 +4965,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -4980,6 +4989,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -5107,7 +5117,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
- + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -5120,18 +5131,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5156,7 +5165,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out @classmethod def _maybe_expand_t2v_lora_for_i2v( @@ -5167,26 +5177,51 @@ def _maybe_expand_t2v_lora_for_i2v( if transformer.config.image_dim is None: return state_dict + target_device = transformer.device + if any(k.startswith("transformer.blocks.") for k in state_dict): - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict}) + num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k}) is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict) + has_bias = any(".lora_B.bias" in k for k in state_dict) if is_i2v_lora: return state_dict for i in range(num_blocks): for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): + # These keys should exist if the block `i` was part of the T2V LoRA. + ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight" + ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight" + + if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict: + continue + state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like( - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"] + state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device ) state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like( - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"] + state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device ) + # If the original LoRA had biases (indicated by has_bias) + # AND the specific reference bias key exists for this block. 
+ + ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias" + if has_bias and ref_key_lora_B_bias in state_dict: + ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias] + state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like( + ref_lora_B_bias_tensor, + device=target_device, + ) + return state_dict def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -5204,6 +5239,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -5221,7 +5258,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers state_dict = self._maybe_expand_t2v_lora_for_i2v( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, @@ -5235,14 +5273,23 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5260,29 +5307,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. 
- - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5295,6 +5324,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5310,9 +5340,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -5329,14 +5360,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -5346,6 +5384,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -5474,6 +5513,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. 
""" # Load the main state dict first which has the LoRA layers for either of @@ -5487,18 +5528,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5519,11 +5558,16 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -5541,6 +5585,8 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -5558,7 +5604,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -5568,14 +5615,23 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5593,29 +5649,11 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) - Defaults to `False`. 
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5628,6 +5666,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5643,9 +5682,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -5662,14 +5702,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
""" state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -5679,6 +5726,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -5748,8 +5796,352 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class HiDreamImageLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. 
+ return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + is_non_diffusers_format = any("diffusion_model" in k for k in state_dict) + if is_non_diffusers_format: + state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict) + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
+ kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`HiDreamImageTransformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. 
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata (`dict`, *optional*):
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+        """
+        state_dict = {}
+        lora_adapter_metadata = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
         deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
\ No newline at end of file
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index bfcdfbde78e0..b1e804cb4089 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
 import enum
 import json
 
+
 from .import_utils import is_torch_available
 from .logging import get_logger
 
@@ -219,7 +220,7 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
         kwargs (`dict`, *args*):
             Additional arguments to pass to the method.
 
-            - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
+            - **adapter_name**: For example, in case of PEFT, some keys will be prepended
               with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
               `get_peft_model_state_dict` method:
              https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
@@ -290,7 +291,7 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
         kwargs (`dict`, *args*):
             Additional arguments to pass to the method.
 
-            - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
+            - **adapter_name**: For example, in case of PEFT, some keys will be prepended
               with the adapter name, therefore needs a special handling.
By default PEFT also takes care of that in `get_peft_model_state_dict` method: https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 @@ -348,6 +349,7 @@ def state_dict_all_zero(state_dict, filter_str=None): return all(torch.all(param == 0).item() for param in state_dict.values()) + def _load_sft_state_dict_metadata(model_file: str): import safetensors.torch @@ -361,4 +363,4 @@ def _load_sft_state_dict_metadata(model_file: str): raw = metadata.get(LORA_ADAPTER_METADATA_KEY) return json.loads(raw) if raw else None else: - return None + return None \ No newline at end of file From f349254fb3b6aa78a6cebe9a3181b86d041dd2f7 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 16 Jul 2025 13:07:35 +0900 Subject: [PATCH 471/482] updates --- src/diffusers/utils/peft_utils.py | 144 +++++++++++++++++++++++++++--- 1 file changed, 134 insertions(+), 10 deletions(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index d1269fbc5f20..94aa224e1a4c 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,9 +21,13 @@ from packaging import version -from .import_utils import is_peft_available, is_torch_available +from . import logging +from .import_utils import is_peft_available, is_peft_version, is_torch_available +from .torch_utils import empty_device_cache +logger = logging.get_logger(__name__) + if is_torch_available(): import torch @@ -95,8 +99,7 @@ def recurse_remove_peft_layers(model): setattr(model, name, new_module) del module - if torch.cuda.is_available(): - torch.cuda.empty_cache() + empty_device_cache() return model @@ -147,25 +150,27 @@ def unscale_lora_layers(model, weight: Optional[float] = None): module.set_scale(adapter_name, 1.0) -def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): +def get_peft_kwargs( + rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None +): rank_pattern = {} alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] if len(set(rank_dict.values())) > 1: - # get the rank occuring the most number of times + # get the rank occurring the most number of times r = collections.Counter(rank_dict.values()).most_common()[0][0] - # for modules with rank different from the most occuring rank, add it to the `rank_pattern` + # for modules with rank different from the most occurring rank, add it to the `rank_pattern` rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} if network_alpha_dict is not None and len(network_alpha_dict) > 0: if len(set(network_alpha_dict.values())) > 1: - # get the alpha occuring the most number of times + # get the alpha occurring the most number of times lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] - # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern` + # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern` alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) if is_unet: alpha_pattern = { @@ -177,7 +182,6 @@ 
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True else: lora_alpha = set(network_alpha_dict.values()).pop() - # layer names without the Diffusers specific target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) # for now we know that the "bias" keys are only associated with `lora_B`. @@ -192,6 +196,21 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True "use_dora": use_dora, "lora_bias": lora_bias, } + + # Example: try load FusionX LoRA into Wan VACE + exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name) + if exclude_modules: + if not is_peft_version(">=", "0.14.0"): + msg = """ +It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft` +version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U +peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue - +https://github.com/huggingface/diffusers/issues/new + """ + logger.debug(msg) + else: + lora_config_kwargs.update({"exclude_modules": exclude_modules}) + return lora_config_kwargs @@ -288,3 +307,108 @@ def check_peft_version(min_version: str) -> None: f"The version of PEFT you are using is not compatible, please use a version that is greater" f" than {min_version}" ) + + +def _create_lora_config( + state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None +): + from peft import LoraConfig + + if metadata is not None: + lora_config_kwargs = metadata + else: + lora_config_kwargs = get_peft_kwargs( + rank_pattern_dict, + network_alpha_dict=network_alphas, + peft_state_dict=state_dict, + is_unet=is_unet, + model_state_dict=model_state_dict, + adapter_name=adapter_name, + ) + + _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs) + + # Version checks for DoRA and lora_bias + if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.") + + if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.") + + try: + return LoraConfig(**lora_config_kwargs) + except TypeError as e: + raise TypeError("`LoraConfig` class could not be instantiated.") from e + + +def _maybe_raise_error_for_ambiguous_keys(config): + rank_pattern = config["rank_pattern"].copy() + target_modules = config["target_modules"] + + for key in list(rank_pattern.keys()): + # try to detect ambiguity + # `target_modules` can also be a str, in which case this loop would loop + # over the chars of the str. The technically correct way to match LoRA keys + # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key). + # But this cuts it for now. + exact_matches = [mod for mod in target_modules if mod == key] + substring_matches = [mod for mod in target_modules if key in mod and mod != key] + + if exact_matches and substring_matches: + if is_peft_version("<", "0.14.1"): + raise ValueError( + "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`." 
+ ) + + +def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): + warn_msg = "" + if incompatible_keys is not None: + # Check only for unexpected keys. + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] + if lora_unexpected_keys: + warn_msg = ( + f"Loading adapter weights from state_dict led to unexpected keys found in the model:" + f" {', '.join(lora_unexpected_keys)}. " + ) + + # Filter missing keys specific to the current adapter. + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model:" + f" {', '.join(lora_missing_keys)}." + ) + + if warn_msg: + logger.warning(warn_msg) + + +def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None): + """ + Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the + `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it + doesn't exist in `peft_state_dict`. + """ + if model_state_dict is None: + return + all_modules = set() + string_to_replace = f"{adapter_name}." if adapter_name else "" + + for name in model_state_dict.keys(): + if string_to_replace: + name = name.replace(string_to_replace, "") + if "." in name: + module_name = name.rsplit(".", 1)[0] + all_modules.add(module_name) + + target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()} + exclude_modules = list(all_modules - target_modules_set) + + return exclude_modules \ No newline at end of file From 7af96db697362704932dedf944eed887f0adbba6 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 16 Jul 2025 13:08:52 +0900 Subject: [PATCH 472/482] updates --- src/diffusers/utils/torch_utils.py | 40 ++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 491cc16e39e6..d1bd1f6e4b3d 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,10 +16,9 @@ """ from typing import List, Optional, Tuple, Union -import random from . import logging -from .import_utils import is_torch_available, is_torch_version +from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version if is_torch_available(): @@ -39,7 +38,7 @@ def maybe_allow_in_graph(cls): def randn_tensor( shape: Union[Tuple, List], generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, - device: Optional["torch.device"] = None, + device: Optional[Union[str, "torch.device"]] = None, dtype: Optional["torch.dtype"] = None, layout: Optional["torch.layout"] = None, ): @@ -48,6 +47,8 @@ def randn_tensor( is always created on the CPU. 
""" # device on which tensor is created defaults to device + if isinstance(device, str): + device = torch.device(device) rand_device = device batch_size = shape[0] @@ -62,7 +63,7 @@ def randn_tensor( logger.info( f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" - f" slighly speed up this function by passing a generator that was created on the {device} device." + f" slightly speed up this function by passing a generator that was created on the {device} device." ) elif gen_device_type != device.type and gen_device_type == "cuda": raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") @@ -80,7 +81,7 @@ def randn_tensor( latents = torch.cat(latents, dim=0).to(device) else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) - + return latents @@ -91,8 +92,13 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) +def unwrap_module(module): + """Unwraps a module if it was compiled with torch.compile()""" + return module._orig_mod if is_compiled_module(module) else module + + def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": - """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). + """Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497). This version of the method comes from here: https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706 @@ -165,7 +171,27 @@ def get_torch_cuda_device_capability(): def get_device(): if torch.cuda.is_available(): return "cuda" + elif is_torch_npu_available(): + return "npu" elif hasattr(torch, "xpu") and torch.xpu.is_available(): return "xpu" + elif torch.backends.mps.is_available(): + return "mps" else: return "cpu" + + +def empty_device_cache(device_type: Optional[str] = None): + if device_type is None: + device_type = get_device() + if device_type in ["cpu"]: + return + device_mod = getattr(torch, device_type, torch.cuda) + device_mod.empty_cache() + + +def device_synchronize(device_type: Optional[str] = None): + if device_type is None: + device_type = get_device() + device_mod = getattr(torch, device_type, torch.cuda) + device_mod.synchronize() \ No newline at end of file From d7256478f4e454f4e04a4c40212ef1067ce89636 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 16 Jul 2025 13:10:53 +0900 Subject: [PATCH 473/482] updates --- src/diffusers/loaders/peft.py | 186 ++++++++++++++-------------------- 1 file changed, 74 insertions(+), 112 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 1d990e81458d..822c25994725 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import inspect +import json import os from functools import partial from pathlib import Path @@ -28,13 +29,13 @@ convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, - get_peft_kwargs, is_peft_available, is_peft_version, logging, set_adapter_layers, set_weights_and_activate_adapters, ) +from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading from .unet_loader_utils import _maybe_expand_lora_scales @@ -56,29 +57,13 @@ "Lumina2Transformer2DModel": lambda model_cls, weights: weights, "WanTransformer3DModel": lambda model_cls, weights: weights, "CogView4Transformer2DModel": lambda model_cls, weights: weights, + "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights, + "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights, + "WanVACETransformer3DModel": lambda model_cls, weights: weights, + "ChromaTransformer2DModel": lambda model_cls, weights: weights, } -def _maybe_raise_error_for_ambiguity(config): - rank_pattern = config["rank_pattern"].copy() - target_modules = config["target_modules"] - - for key in list(rank_pattern.keys()): - # try to detect ambiguity - # `target_modules` can also be a str, in which case this loop would loop - # over the chars of the str. The technically correct way to match LoRA keys - # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key). - # But this cuts it for now. - exact_matches = [mod for mod in target_modules if mod == key] - substring_matches = [mod for mod in target_modules if key in mod and mod != key] - - if exact_matches and substring_matches: - if is_peft_version("<", "0.14.1"): - raise ValueError( - "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`." - ) - - class PeftAdapterMixin: """ A class containing all functions for loading and using adapters weights that are supported in PEFT library. For @@ -100,17 +85,6 @@ class PeftAdapterMixin: @classmethod # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading def _optionally_disable_offloading(cls, _pipeline): - """ - Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. - - Args: - _pipeline (`DiffusionPipeline`): - The pipeline to disable offloading for. - - Returns: - tuple: - A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. - """ return _func_optionally_disable_offloading(_pipeline=_pipeline) def load_lora_adapter( @@ -182,10 +156,15 @@ def load_lora_adapter( Note that hotswapping adapters of the text encoder is not yet supported. There are some further limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap + metadata: + LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to + initialize `LoraConfig`. 
""" - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + from peft import inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer + from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading + cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -199,6 +178,7 @@ def load_lora_adapter( network_alphas = kwargs.pop("network_alphas", None) _pipeline = kwargs.pop("_pipeline", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) + metadata = kwargs.pop("metadata", None) allow_pickle = False if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): @@ -206,12 +186,8 @@ def load_lora_adapter( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -224,12 +200,17 @@ def load_lora_adapter( subfolder=subfolder, user_agent=user_agent, allow_pickle=allow_pickle, + metadata=metadata, ) if network_alphas is not None and prefix is None: raise ValueError("`network_alphas` cannot be None when `prefix` is None.") + if network_alphas and metadata: + raise ValueError("Both `network_alphas` and `metadata` cannot be specified.") if prefix is not None: - state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + if metadata is not None: + metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: @@ -249,7 +230,7 @@ def load_lora_adapter( rank = {} for key, val in state_dict.items(): - # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. + # Cannot figure out rank from lora layers that don't have at least 2 dimensions. # Bias layers in LoRA only have a single dimension if "lora_B" in key and val.ndim > 1: # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. @@ -260,44 +241,33 @@ def load_lora_adapter( if network_alphas is not None and len(network_alphas) >= 1: alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) - _maybe_raise_error_for_ambiguity(lora_config_kwargs) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. 
Please upgrade your installation of `peft`."
-                    )
-                else:
-                    if is_peft_version("<=", "0.13.2"):
-                        lora_config_kwargs.pop("lora_bias")
+                network_alphas = {
+                    k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
+                }
 
-            lora_config = LoraConfig(**lora_config_kwargs)
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)
 
+            # create LoraConfig
+            lora_config = _create_lora_config(
+                state_dict,
+                network_alphas,
+                metadata,
+                rank,
+                model_state_dict=self.state_dict(),
+                adapter_name=adapter_name,
+            )
+
+            # <Unsafe code
+            is_model_cpu_offload = False
+            is_sequential_cpu_offload = False
+            is_group_offload = False
+            if _pipeline is not None and _pipeline.hf_device_map is None:
+                is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                    _pipeline
+                )
+
+            peft_kwargs = {}
+            if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -329,7 +299,7 @@ def map_state_dict_for_hotswap(sd):
                 new_sd[k] = v
             return new_sd
 
-        # To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
+        # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
         # we should also delete the `peft_config` associated to the `adapter_name`.
         try:
             if hotswap:
@@ -343,7 +313,7 @@ def map_state_dict_for_hotswap(sd):
                         config=lora_config,
                     )
                 except Exception as e:
-                    logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
+                    logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
                     raise
                 # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
                 # it to None
@@ -378,46 +348,28 @@ def map_state_dict_for_hotswap(sd):
                         module.delete_adapter(adapter_name)
 
                 self.peft_config.pop(adapter_name)
-                logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
+                logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
                 raise
 
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
+            _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
 
             # Offload back.
             if is_model_cpu_offload:
                 _pipeline.enable_model_cpu_offload()
             elif is_sequential_cpu_offload:
                 _pipeline.enable_sequential_cpu_offload()
+            elif is_group_offload:
+                for component in _pipeline.components.values():
+                    if isinstance(component, torch.nn.Module):
+                        _maybe_remove_and_reapply_group_offloading(component)
             # Unsafe code />
 
         if prefix is not None and not state_dict:
+            model_class_name = self.__class__.__name__
             logger.warning(
-                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
+                f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
                 "This is safe to ignore if LoRA state dict didn't originally have any "
-                f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
+                f"{model_class_name} related params. You can also try specifying `prefix=None` "
                 "to resolve the warning. 
Otherwise, open an issue if you think it's unexpected: " "https://github.com/huggingface/diffusers/issues/new" ) @@ -440,17 +392,13 @@ def save_lora_adapter( underlying model has multiple adapters loaded. upcast_before_saving (`bool`, defaults to `False`): Whether to cast the underlying model to `torch.float32` before serialization. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with. """ from peft.utils import get_peft_model_state_dict - from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE if adapter_name is None: adapter_name = get_adapter_name(self) @@ -458,6 +406,8 @@ def save_lora_adapter( if adapter_name not in getattr(self, "peft_config", {}): raise ValueError(f"Adapter name {adapter_name} not found in the model.") + lora_adapter_metadata = self.peft_config[adapter_name].to_dict() + lora_layers_to_save = get_peft_model_state_dict( self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name ) @@ -467,7 +417,15 @@ def save_lora_adapter( if safe_serialization: def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + # Inject framework format. + metadata = {"format": "pt"} + if lora_adapter_metadata is not None: + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + + return safetensors.torch.save_file(weights, filename, metadata=metadata) else: save_function = torch.save @@ -480,7 +438,6 @@ def save_function(weights, filename): else: weight_name = LORA_WEIGHT_NAME - # TODO: we could consider saving the `peft_config` as well. save_path = Path(save_directory, weight_name).as_posix() save_function(lora_layers_to_save, save_path) logger.info(f"Model weights saved in {save_path}") @@ -491,7 +448,7 @@ def set_adapters( weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None, ): """ - Set the currently active adapters for use in the UNet. + Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.). 
Args: adapter_names (`List[str]` or `str`): @@ -513,7 +470,7 @@ def set_adapters( "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" ) pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5]) + pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5]) ``` """ if not USE_PEFT_BACKEND: @@ -711,7 +668,7 @@ def _fuse_lora_apply(self, module, adapter_names=None): if self.lora_scale != 1.0: module.scale_layer(self.lora_scale) - # For BC with prevous PEFT versions, we need to check the signature + # For BC with previous PEFT versions, we need to check the signature # of the `merge` method to see if it supports the `adapter_names` argument. supported_merge_kwargs = list(inspect.signature(module.merge).parameters) if "adapter_names" in supported_merge_kwargs: @@ -739,11 +696,16 @@ def unload_lora(self): if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for `unload_lora()`.") + from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading from ..utils import recurse_remove_peft_layers recurse_remove_peft_layers(self) if hasattr(self, "peft_config"): del self.peft_config + if hasattr(self, "_hf_peft_config_loaded"): + self._hf_peft_config_loaded = None + + _maybe_remove_and_reapply_group_offloading(self) def disable_lora(self): """ @@ -761,7 +723,7 @@ def disable_lora(self): pipeline.load_lora_weights( "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" ) - pipeline.disable_lora() + pipeline.unet.disable_lora() ``` """ if not USE_PEFT_BACKEND: @@ -784,7 +746,7 @@ def enable_lora(self): pipeline.load_lora_weights( "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" ) - pipeline.enable_lora() + pipeline.unet.enable_lora() ``` """ if not USE_PEFT_BACKEND: @@ -811,7 +773,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): pipeline.load_lora_weights( "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic" ) - pipeline.delete_adapters("cinematic") + pipeline.unet.delete_adapters("cinematic") ``` """ if not USE_PEFT_BACKEND: @@ -858,4 +820,4 @@ def enable_lora_hotswap( f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead." ) - self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled} + self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled} \ No newline at end of file From ac7197b9d9a8f1334b8e788b2d65c7fa8aedf6b9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Wed, 16 Jul 2025 13:12:44 +0900 Subject: [PATCH 474/482] updates --- src/diffusers/hooks/group_offloading.py | 509 +++++++++++++++--------- 1 file changed, 316 insertions(+), 193 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index ac6cf653641b..6eb824646561 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
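Before the group offloading patch continues, here is an editorial sketch of the hotswap path that the peft.py changes above wire up: opt in on the model first, then reload under the same adapter name with `hotswap=True`. The pipeline, checkpoint names, and rank below are illustrative and not part of the diff:

```py
import torch

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Opt in before the first LoRA is loaded; `target_rank` pads the LoRA ranks so
# differently sized adapters can later be swapped without re-tracing a
# torch.compile()d model.
pipeline.unet.enable_lora_hotswap(target_rank=64)

pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Hotswap: reuse the same adapter name to replace its weights in place.
pipeline.load_lora_weights(
    "TheLastBen/Papercut_SDXL", weight_name="papercut.safetensors", adapter_name="pixel", hotswap=True
)
```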
@@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import os from contextlib import contextmanager, nullcontext -from typing import Dict, List, Optional, Set, Tuple +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional, Set, Tuple, Union +import safetensors.torch import torch from ..utils import get_logger, is_accelerate_available @@ -33,7 +38,7 @@ _GROUP_OFFLOADING = "group_offloading" _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading" - +_GROUP_ID_LAZY_LEAF = "lazy_leafs" _SUPPORTED_PYTORCH_LAYERS = ( torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, @@ -44,6 +49,24 @@ # fmt: on +class GroupOffloadingType(str, Enum): + BLOCK_LEVEL = "block_level" + LEAF_LEVEL = "leaf_level" + + +@dataclass +class GroupOffloadingConfig: + onload_device: torch.device + offload_device: torch.device + offload_type: GroupOffloadingType + non_blocking: bool + record_stream: bool + low_cpu_mem_usage: bool + num_blocks_per_group: Optional[int] = None + offload_to_disk_path: Optional[str] = None + stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None + + class ModuleGroup: def __init__( self, @@ -55,10 +78,12 @@ def __init__( parameters: Optional[List[torch.nn.Parameter]] = None, buffers: Optional[List[torch.Tensor]] = None, non_blocking: bool = False, - stream: Optional[torch.cuda.Stream] = None, + stream: Union[torch.cuda.Stream, torch.Stream, None] = None, record_stream: Optional[bool] = False, - low_cpu_mem_usage=False, + low_cpu_mem_usage: bool = False, onload_self: bool = True, + offload_to_disk_path: Optional[str] = None, + group_id: Optional[int] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -72,10 +97,29 @@ def __init__( self.record_stream = record_stream self.onload_self = onload_self self.low_cpu_mem_usage = low_cpu_mem_usage - self.cpu_param_dict = self._init_cpu_param_dict() - if self.stream is None and self.record_stream: - raise ValueError("`record_stream` cannot be True when `stream` is None.") + self.offload_to_disk_path = offload_to_disk_path + self._is_offloaded_to_disk = False + + if self.offload_to_disk_path: + # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well. 
+ self.group_id = group_id if group_id is not None else str(id(self)) + short_hash = _compute_group_hash(self.group_id) + self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors") + + all_tensors = [] + for module in self.modules: + all_tensors.extend(list(module.parameters())) + all_tensors.extend(list(module.buffers())) + all_tensors.extend(self.parameters) + all_tensors.extend(self.buffers) + all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates + + self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} + self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} + self.cpu_param_dict = {} + else: + self.cpu_param_dict = self._init_cpu_param_dict() def _init_cpu_param_dict(self): cpu_param_dict = {} @@ -113,58 +157,127 @@ def _pinned_memory_tensors(self): finally: pinned_dict = None - def onload_(self): - r"""Onloads the group of modules to the onload_device.""" - context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream) - current_stream = torch.cuda.current_stream() if self.record_stream else None + def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None): + tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream and current_stream is not None: + tensor.data.record_stream(current_stream) + + def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None): + for group_module in self.modules: + for param in group_module.parameters(): + source = pinned_memory[param] if pinned_memory else param.data + self._transfer_tensor_to_device(param, source, current_stream) + for buffer in group_module.buffers(): + source = pinned_memory[buffer] if pinned_memory else buffer.data + self._transfer_tensor_to_device(buffer, source, current_stream) + for param in self.parameters: + source = pinned_memory[param] if pinned_memory else param.data + self._transfer_tensor_to_device(param, source, current_stream) + + for buffer in self.buffers: + source = pinned_memory[buffer] if pinned_memory else buffer.data + self._transfer_tensor_to_device(buffer, source, current_stream) + + def _onload_from_disk(self, current_stream): if self.stream is not None: - # Wait for previous Host->Device transfer to complete - self.stream.synchronize() + loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") - with context: - if self.stream is not None: - with self._pinned_memory_tensors() as pinned_memory: - for group_module in self.modules: - for param in group_module.parameters(): - param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - param.data.record_stream(current_stream) - for buffer in group_module.buffers(): - buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - buffer.data.record_stream(current_stream) - - for param in self.parameters: - param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - param.data.record_stream(current_stream) + for key, tensor_obj in self.key_to_tensor.items(): + self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key] - for buffer in self.buffers: - buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - buffer.data.record_stream(current_stream) + with self._pinned_memory_tensors() as pinned_memory: + for 
key, tensor_obj in self.key_to_tensor.items(): + self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream) - else: - for group_module in self.modules: - for param in group_module.parameters(): - param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) - for buffer in group_module.buffers(): - buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) + self.cpu_param_dict.clear() + + else: + onload_device = ( + self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device + ) + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) + for key, tensor_obj in self.key_to_tensor.items(): + tensor_obj.data = loaded_tensors[key] - for param in self.parameters: - param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) + def _onload_from_memory(self, current_stream): + if self.stream is not None: + with self._pinned_memory_tensors() as pinned_memory: + self._process_tensors_from_modules(pinned_memory, current_stream) + else: + self._process_tensors_from_modules(None, current_stream) - for buffer in self.buffers: - buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - buffer.data.record_stream(current_stream) + @torch.compiler.disable() + def onload_(self): + torch_accelerator_module = ( + getattr(torch, torch.accelerator.current_accelerator().type) + if hasattr(torch, "accelerator") + else torch.cuda + ) + context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream) + current_stream = torch_accelerator_module.current_stream() if self.record_stream else None + + if self.offload_to_disk_path: + if self.stream is not None: + # Wait for previous Host->Device transfer to complete + self.stream.synchronize() + + with context: + if self.stream is not None: + # Load to CPU, pin, and async copy to device for overlapping transfer and compute + loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") + for key, tensor_obj in self.key_to_tensor.items(): + pinned_tensor = loaded_cpu_tensors[key].pin_memory() + tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + tensor_obj.data.record_stream(current_stream) + else: + # Load directly to the target device (synchronous) + onload_device = ( + self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device + ) + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) + for key, tensor_obj in self.key_to_tensor.items(): + tensor_obj.data = loaded_tensors[key] + return - def offload_(self): - r"""Offloads the group of modules to the offload_device.""" + if self.stream is not None: + # Wait for previous Host->Device transfer to complete + self.stream.synchronize() + + with context: + if self.offload_to_disk_path: + self._onload_from_disk(current_stream) + else: + self._onload_from_memory(current_stream) + + def _offload_to_disk(self): + # TODO: we can potentially optimize this code path by checking if the _all_ the desired + # safetensor files exist on the disk and if so, skip this step entirely, reducing IO + # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not + # we perform a write. + # Check if the file has been saved in this session or if it already exists on disk. 
+ if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): + os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) + tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) + + # The group is now considered offloaded to disk for the rest of the session. + self._is_offloaded_to_disk = True + + # We do this to free up the RAM which is still holding the up tensor data. + for tensor_obj in self.tensor_to_key.keys(): + tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device) + + def _offload_to_memory(self): + torch_accelerator_module = ( + getattr(torch, torch.accelerator.current_accelerator().type) + if hasattr(torch, "accelerator") + else torch.cuda + ) if self.stream is not None: if not self.record_stream: - torch.cuda.current_stream().synchronize() + torch_accelerator_module.current_stream().synchronize() for group_module in self.modules: for param in group_module.parameters(): param.data = self.cpu_param_dict[param] @@ -181,6 +294,14 @@ def offload_(self): for buffer in self.buffers: buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking) + @torch.compiler.disable() + def offload_(self): + r"""Offloads the group of modules to the offload_device.""" + if self.offload_to_disk_path: + self._offload_to_disk() + else: + self._offload_to_memory() + class GroupOffloadingHook(ModelHook): r""" @@ -193,12 +314,11 @@ class GroupOffloadingHook(ModelHook): _is_stateful = False def __init__( - self, - group: ModuleGroup, - next_group: Optional[ModuleGroup] = None, + self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig ) -> None: self.group = group self.next_group = next_group + self.config = config def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.group.offload_leader == module: @@ -232,7 +352,7 @@ def post_forward(self, module: torch.nn.Module, output): class LazyPrefetchGroupOffloadingHook(ModelHook): r""" - A hook, used in conjuction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module. + A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module. This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows prefetching groups in the correct order. @@ -344,12 +464,13 @@ def apply_group_offloading( module: torch.nn.Module, onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), - offload_type: str = "block_level", + offload_type: Union[str, GroupOffloadingType] = "block_level", num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, record_stream: bool = False, low_cpu_mem_usage: bool = False, + offload_to_disk_path: Optional[str] = None, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -385,9 +506,12 @@ def apply_group_offloading( The device to which the group of modules are onloaded. offload_device (`torch.device`, defaults to `torch.device("cpu")`): The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU. 
- offload_type (`str`, defaults to "block_level"): + offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"): The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is "block_level". + offload_to_disk_path (`str`, *optional*, defaults to `None`): + The path to the directory where parameters will be offloaded. Setting this option can be useful in limited + RAM environment settings where a reasonable speed-memory trade-off is desired. num_blocks_per_group (`int`, *optional*): The number of blocks per group when using offload_type="block_level". This is required when using offload_type="block_level". @@ -425,80 +549,59 @@ def apply_group_offloading( ``` """ + offload_type = GroupOffloadingType(offload_type) + stream = None if use_stream: if torch.cuda.is_available(): stream = torch.cuda.Stream() + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + stream = torch.Stream() else: - raise ValueError("Using streams for data transfer requires a CUDA device.") + raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.") + + if not use_stream and record_stream: + raise ValueError("`record_stream` cannot be True when `use_stream=False`.") + if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None: + raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.") _raise_error_if_accelerate_model_or_sequential_hook_present(module) - if offload_type == "block_level": - if num_blocks_per_group is None: - raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") - - _apply_group_offloading_block_level( - module=module, - num_blocks_per_group=num_blocks_per_group, - offload_device=offload_device, - onload_device=onload_device, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - elif offload_type == "leaf_level": - _apply_group_offloading_leaf_level( - module=module, - offload_device=offload_device, - onload_device=onload_device, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + config = GroupOffloadingConfig( + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, + offload_to_disk_path=offload_to_disk_path, + ) + _apply_group_offloading(module, config) + + +def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: + if config.offload_type == GroupOffloadingType.BLOCK_LEVEL: + _apply_group_offloading_block_level(module, config) + elif config.offload_type == GroupOffloadingType.LEAF_LEVEL: + _apply_group_offloading_leaf_level(module, config) else: - raise ValueError(f"Unsupported offload_type: {offload_type}") + assert False -def _apply_group_offloading_block_level( - module: torch.nn.Module, - num_blocks_per_group: int, - offload_device: torch.device, - onload_device: torch.device, - non_blocking: bool, - stream: Optional[torch.cuda.Stream] = None, - record_stream: Optional[bool] = False, - low_cpu_mem_usage: bool = False, -) -> None: +def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. 
In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. - - Args: - module (`torch.nn.Module`): - The module to which group offloading is applied. - offload_device (`torch.device`): - The device to which the group of modules are offloaded. This should typically be the CPU. - onload_device (`torch.device`): - The device to which the group of modules are onloaded. - non_blocking (`bool`): - If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation - and data transfer. - stream (`torch.cuda.Stream`, *optional*): - If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful - for overlapping computation and data transfer. - record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor - as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the - [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more - details. - low_cpu_mem_usage (`bool`, defaults to `False`): - If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This - option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when - the CPU memory is a bottleneck but may counteract the benefits of using streams. """ + if config.stream is not None and config.num_blocks_per_group != 1: + logger.warning( + f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1." + ) + config.num_blocks_per_group = 1 + # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() unmatched_modules = [] @@ -509,19 +612,22 @@ def _apply_group_offloading_block_level( modules_with_group_offloading.add(name) continue - for i in range(0, len(submodule), num_blocks_per_group): - current_modules = submodule[i : i + num_blocks_per_group] + for i in range(0, len(submodule), config.num_blocks_per_group): + current_modules = submodule[i : i + config.num_blocks_per_group] + group_id = f"{name}_{i}_{i + len(current_modules) - 1}" group = ModuleGroup( modules=current_modules, - offload_device=offload_device, - onload_device=onload_device, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=current_modules[-1], onload_leader=current_modules[0], - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, - onload_self=stream is None, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=group_id, ) matched_module_groups.append(group) for j in range(i, i + len(current_modules)): @@ -529,12 +635,8 @@ def _apply_group_offloading_block_level( # Apply group offloading hooks to the module groups for i, group in enumerate(matched_module_groups): - next_group = ( - matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None - ) - for group_module in group.modules: - _apply_group_offloading_hook(group_module, group, next_group) + _apply_group_offloading_hook(group_module, group, None, config=config) # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately # when the 
forward pass of this module is called. This is because the top-level module is not @@ -549,8 +651,9 @@ def _apply_group_offloading_block_level( unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] unmatched_group = ModuleGroup( modules=unmatched_modules, - offload_device=offload_device, - onload_device=onload_device, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=module, onload_leader=module, parameters=parameters, @@ -559,49 +662,21 @@ def _apply_group_offloading_block_level( stream=None, record_stream=False, onload_self=True, + group_id=f"{module.__class__.__name__}_unmatched_group", ) - next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None - _apply_group_offloading_hook(module, unmatched_group, next_group) + if config.stream is None: + _apply_group_offloading_hook(module, unmatched_group, None, config=config) + else: + _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config) -def _apply_group_offloading_leaf_level( - module: torch.nn.Module, - offload_device: torch.device, - onload_device: torch.device, - non_blocking: bool, - stream: Optional[torch.cuda.Stream] = None, - record_stream: Optional[bool] = False, - low_cpu_mem_usage: bool = False, -) -> None: +def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory requirements. However, it can be slower compared to other offloading methods due to the excessive number of device synchronizations. When using devices that support streams to overlap data transfer and computation, this method can reduce memory usage without any performance degradation. - - Args: - module (`torch.nn.Module`): - The module to which group offloading is applied. - offload_device (`torch.device`): - The device to which the group of modules are offloaded. This should typically be the CPU. - onload_device (`torch.device`): - The device to which the group of modules are onloaded. - non_blocking (`bool`): - If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation - and data transfer. - stream (`torch.cuda.Stream`, *optional*): - If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful - for overlapping computation and data transfer. - record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor - as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the - [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more - details. - low_cpu_mem_usage (`bool`, defaults to `False`): - If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This - option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when - the CPU memory is a bottleneck but may counteract the benefits of using streams. 
""" - # Create module groups for leaf modules and apply group offloading hooks modules_with_group_offloading = set() for name, submodule in module.named_modules(): @@ -609,17 +684,19 @@ def _apply_group_offloading_leaf_level( continue group = ModuleGroup( modules=[submodule], - offload_device=offload_device, - onload_device=onload_device, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=submodule, onload_leader=submodule, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, + group_id=name, ) - _apply_group_offloading_hook(submodule, group, None) + _apply_group_offloading_hook(submodule, group, None, config=config) modules_with_group_offloading.add(name) # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass @@ -650,31 +727,33 @@ def _apply_group_offloading_leaf_level( parameters = parent_to_parameters.get(name, []) buffers = parent_to_buffers.get(name, []) parent_module = module_dict[name] - assert getattr(parent_module, "_diffusers_hook", None) is None group = ModuleGroup( modules=[], - offload_device=offload_device, - onload_device=onload_device, + offload_device=config.offload_device, + onload_device=config.onload_device, offload_leader=parent_module, onload_leader=parent_module, + offload_to_disk_path=config.offload_to_disk_path, parameters=parameters, buffers=buffers, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, + group_id=name, ) - _apply_group_offloading_hook(parent_module, group, None) + _apply_group_offloading_hook(parent_module, group, None, config=config) - if stream is not None: + if config.stream is not None: # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the # execution order and apply prefetching in the correct order. unmatched_group = ModuleGroup( modules=[], - offload_device=offload_device, - onload_device=onload_device, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, offload_leader=module, onload_leader=module, parameters=None, @@ -682,23 +761,26 @@ def _apply_group_offloading_leaf_level( non_blocking=False, stream=None, record_stream=False, - low_cpu_mem_usage=low_cpu_mem_usage, + low_cpu_mem_usage=config.low_cpu_mem_usage, onload_self=True, + group_id=_GROUP_ID_LAZY_LEAF, ) - _apply_lazy_group_offloading_hook(module, unmatched_group, None) + _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config) def _apply_group_offloading_hook( module: torch.nn.Module, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, + *, + config: GroupOffloadingConfig, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(module) # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. 
In such cases, we don't want to overwrite the existing group offloading hook. if registry.get_hook(_GROUP_OFFLOADING) is None: - hook = GroupOffloadingHook(group, next_group) + hook = GroupOffloadingHook(group, next_group, config=config) registry.register_hook(hook, _GROUP_OFFLOADING) @@ -706,13 +788,15 @@ def _apply_lazy_group_offloading_hook( module: torch.nn.Module, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, + *, + config: GroupOffloadingConfig, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(module) # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. if registry.get_hook(_GROUP_OFFLOADING) is None: - hook = GroupOffloadingHook(group, next_group) + hook = GroupOffloadingHook(group, next_group, config=config) registry.register_hook(hook, _GROUP_OFFLOADING) lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() @@ -779,15 +863,54 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn ) -def _is_group_offload_enabled(module: torch.nn.Module) -> bool: +def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]: for submodule in module.modules(): - if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: - return True - return False + if hasattr(submodule, "_diffusers_hook"): + group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) + if group_offloading_hook is not None: + return group_offloading_hook + return None + + +def _is_group_offload_enabled(module: torch.nn.Module) -> bool: + top_level_group_offload_hook = _get_top_level_group_offload_hook(module) + return top_level_group_offload_hook is not None def _get_group_onload_device(module: torch.nn.Module) -> torch.device: - for submodule in module.modules(): - if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: - return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device + top_level_group_offload_hook = _get_top_level_group_offload_hook(module) + if top_level_group_offload_hook is not None: + return top_level_group_offload_hook.config.onload_device raise ValueError("Group offloading is not enabled for the provided module.") + + +def _compute_group_hash(group_id): + hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest() + # first 16 characters for a reasonably short but unique name + return hashed_id[:16] + + +def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None: + r""" + Removes the group offloading hook from the module and re-applies it. This is useful when the module has been + modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place + modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly. + + In this implementation, we make an assumption that group offloading has only been applied at the top-level module, + and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the + case where user has applied group offloading at multiple levels, this function will not work as expected. 
+ + There is some performance penalty associated with doing this when non-default streams are used, because we need to + retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`. + """ + top_level_group_offload_hook = _get_top_level_group_offload_hook(module) + + if top_level_group_offload_hook is None: + return + + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.remove_hook(_GROUP_OFFLOADING, recurse=True) + registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True) + registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True) + + _apply_group_offloading(module, top_level_group_offload_hook.config) \ No newline at end of file From ae4cc6a156bbd5ee209deb175b1c359ee9828fd9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 17 Jul 2025 14:46:40 +0900 Subject: [PATCH 475/482] Revert "updates" This reverts commit ac7197b9d9a8f1334b8e788b2d65c7fa8aedf6b9. --- src/diffusers/hooks/group_offloading.py | 509 +++++++++--------------- 1 file changed, 193 insertions(+), 316 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 6eb824646561..ac6cf653641b 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib -import os from contextlib import contextmanager, nullcontext -from dataclasses import dataclass -from enum import Enum -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Dict, List, Optional, Set, Tuple -import safetensors.torch import torch from ..utils import get_logger, is_accelerate_available @@ -38,7 +33,7 @@ _GROUP_OFFLOADING = "group_offloading" _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading" -_GROUP_ID_LAZY_LEAF = "lazy_leafs" + _SUPPORTED_PYTORCH_LAYERS = ( torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, @@ -49,24 +44,6 @@ # fmt: on -class GroupOffloadingType(str, Enum): - BLOCK_LEVEL = "block_level" - LEAF_LEVEL = "leaf_level" - - -@dataclass -class GroupOffloadingConfig: - onload_device: torch.device - offload_device: torch.device - offload_type: GroupOffloadingType - non_blocking: bool - record_stream: bool - low_cpu_mem_usage: bool - num_blocks_per_group: Optional[int] = None - offload_to_disk_path: Optional[str] = None - stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None - - class ModuleGroup: def __init__( self, @@ -78,12 +55,10 @@ def __init__( parameters: Optional[List[torch.nn.Parameter]] = None, buffers: Optional[List[torch.Tensor]] = None, non_blocking: bool = False, - stream: Union[torch.cuda.Stream, torch.Stream, None] = None, + stream: Optional[torch.cuda.Stream] = None, record_stream: Optional[bool] = False, - low_cpu_mem_usage: bool = False, + low_cpu_mem_usage=False, onload_self: bool = True, - offload_to_disk_path: Optional[str] = None, - group_id: Optional[int] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -97,29 +72,10 @@ def __init__( self.record_stream = record_stream 
self.onload_self = onload_self self.low_cpu_mem_usage = low_cpu_mem_usage + self.cpu_param_dict = self._init_cpu_param_dict() - self.offload_to_disk_path = offload_to_disk_path - self._is_offloaded_to_disk = False - - if self.offload_to_disk_path: - # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well. - self.group_id = group_id if group_id is not None else str(id(self)) - short_hash = _compute_group_hash(self.group_id) - self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors") - - all_tensors = [] - for module in self.modules: - all_tensors.extend(list(module.parameters())) - all_tensors.extend(list(module.buffers())) - all_tensors.extend(self.parameters) - all_tensors.extend(self.buffers) - all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates - - self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} - self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} - self.cpu_param_dict = {} - else: - self.cpu_param_dict = self._init_cpu_param_dict() + if self.stream is None and self.record_stream: + raise ValueError("`record_stream` cannot be True when `stream` is None.") def _init_cpu_param_dict(self): cpu_param_dict = {} @@ -157,127 +113,58 @@ def _pinned_memory_tensors(self): finally: pinned_dict = None - def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None): - tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream and current_stream is not None: - tensor.data.record_stream(current_stream) - - def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None): - for group_module in self.modules: - for param in group_module.parameters(): - source = pinned_memory[param] if pinned_memory else param.data - self._transfer_tensor_to_device(param, source, current_stream) - for buffer in group_module.buffers(): - source = pinned_memory[buffer] if pinned_memory else buffer.data - self._transfer_tensor_to_device(buffer, source, current_stream) - - for param in self.parameters: - source = pinned_memory[param] if pinned_memory else param.data - self._transfer_tensor_to_device(param, source, current_stream) - - for buffer in self.buffers: - source = pinned_memory[buffer] if pinned_memory else buffer.data - self._transfer_tensor_to_device(buffer, source, current_stream) - - def _onload_from_disk(self, current_stream): - if self.stream is not None: - loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") - - for key, tensor_obj in self.key_to_tensor.items(): - self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key] - - with self._pinned_memory_tensors() as pinned_memory: - for key, tensor_obj in self.key_to_tensor.items(): - self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream) - - self.cpu_param_dict.clear() - - else: - onload_device = ( - self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device - ) - loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) - for key, tensor_obj in self.key_to_tensor.items(): - tensor_obj.data = loaded_tensors[key] - - def _onload_from_memory(self, current_stream): - if self.stream is not None: - with self._pinned_memory_tensors() as pinned_memory: - self._process_tensors_from_modules(pinned_memory, current_stream) - else: - self._process_tensors_from_modules(None, current_stream) - - 
@torch.compiler.disable() def onload_(self): - torch_accelerator_module = ( - getattr(torch, torch.accelerator.current_accelerator().type) - if hasattr(torch, "accelerator") - else torch.cuda - ) - context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream) - current_stream = torch_accelerator_module.current_stream() if self.record_stream else None - - if self.offload_to_disk_path: - if self.stream is not None: - # Wait for previous Host->Device transfer to complete - self.stream.synchronize() - - with context: - if self.stream is not None: - # Load to CPU, pin, and async copy to device for overlapping transfer and compute - loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") - for key, tensor_obj in self.key_to_tensor.items(): - pinned_tensor = loaded_cpu_tensors[key].pin_memory() - tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - tensor_obj.data.record_stream(current_stream) - else: - # Load directly to the target device (synchronous) - onload_device = ( - self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device - ) - loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) - for key, tensor_obj in self.key_to_tensor.items(): - tensor_obj.data = loaded_tensors[key] - return + r"""Onloads the group of modules to the onload_device.""" + context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream) + current_stream = torch.cuda.current_stream() if self.record_stream else None if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() with context: - if self.offload_to_disk_path: - self._onload_from_disk(current_stream) + if self.stream is not None: + with self._pinned_memory_tensors() as pinned_memory: + for group_module in self.modules: + for param in group_module.parameters(): + param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + param.data.record_stream(current_stream) + for buffer in group_module.buffers(): + buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + buffer.data.record_stream(current_stream) + + for param in self.parameters: + param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + param.data.record_stream(current_stream) + + for buffer in self.buffers: + buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + buffer.data.record_stream(current_stream) + else: - self._onload_from_memory(current_stream) - - def _offload_to_disk(self): - # TODO: we can potentially optimize this code path by checking if the _all_ the desired - # safetensor files exist on the disk and if so, skip this step entirely, reducing IO - # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not - # we perform a write. - # Check if the file has been saved in this session or if it already exists on disk. 
-        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
-            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
-            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-
-            # The group is now considered offloaded to disk for the rest of the session.
-            self._is_offloaded_to_disk = True
-
-        # We do this to free up the RAM which is still holding the up tensor data.
-        for tensor_obj in self.tensor_to_key.keys():
-            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
-
-    def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    for buffer in group_module.buffers():
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+                for param in self.parameters:
+                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+                for buffer in self.buffers:
+                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        buffer.data.record_stream(current_stream)
+
+    def offload_(self):
+        r"""Offloads the group of modules to the offload_device."""
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                torch.cuda.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -294,14 +181,6 @@ def _offload_to_memory(self):
         for buffer in self.buffers:
             buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
-    @torch.compiler.disable()
-    def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
-        if self.offload_to_disk_path:
-            self._offload_to_disk()
-        else:
-            self._offload_to_memory()
-
 
 class GroupOffloadingHook(ModelHook):
     r"""
@@ -314,11 +193,12 @@ class GroupOffloadingHook(ModelHook):
     _is_stateful = False
 
     def __init__(
-        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
+        self,
+        group: ModuleGroup,
+        next_group: Optional[ModuleGroup] = None,
     ) -> None:
         self.group = group
         self.next_group = next_group
-        self.config = config
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -352,7 +232,7 @@ def post_forward(self, module: torch.nn.Module, output):
 
 class LazyPrefetchGroupOffloadingHook(ModelHook):
     r"""
     A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
     This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
     invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
     prefetching groups in the correct order.
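
For orientation between the two hunks above and below: the pre-2025 code path restored by this revert is driven through `apply_group_offloading`, whose signature follows in the next hunk. A minimal usage sketch, assuming the function is re-exported from `diffusers.hooks` as in the upstream package (the pipeline checkpoint and group size here are illustrative, not taken from this patch):

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_group_offloading

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Keep the UNet's weights on CPU and onload one group of blocks at a time to CUDA,
# overlapping host-to-device copies with compute on a dedicated stream.
apply_group_offloading(
    pipe.unet,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,  # the reverted code requires a CUDA device for streams
)
```

With `use_stream=True`, the restored block-level path prefetches `matched_module_groups[i + 1]` while group `i` executes, which is why the restored code below sets `onload_self=stream is None` for matched groups.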
@@ -464,13 +344,12 @@ def apply_group_offloading( module: torch.nn.Module, onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), - offload_type: Union[str, GroupOffloadingType] = "block_level", + offload_type: str = "block_level", num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, record_stream: bool = False, low_cpu_mem_usage: bool = False, - offload_to_disk_path: Optional[str] = None, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -506,12 +385,9 @@ def apply_group_offloading( The device to which the group of modules are onloaded. offload_device (`torch.device`, defaults to `torch.device("cpu")`): The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU. - offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"): + offload_type (`str`, defaults to "block_level"): The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is "block_level". - offload_to_disk_path (`str`, *optional*, defaults to `None`): - The path to the directory where parameters will be offloaded. Setting this option can be useful in limited - RAM environment settings where a reasonable speed-memory trade-off is desired. num_blocks_per_group (`int`, *optional*): The number of blocks per group when using offload_type="block_level". This is required when using offload_type="block_level". @@ -549,58 +425,79 @@ def apply_group_offloading( ``` """ - offload_type = GroupOffloadingType(offload_type) - stream = None if use_stream: if torch.cuda.is_available(): stream = torch.cuda.Stream() - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - stream = torch.Stream() else: - raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.") - - if not use_stream and record_stream: - raise ValueError("`record_stream` cannot be True when `use_stream=False`.") - if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None: - raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.") + raise ValueError("Using streams for data transfer requires a CUDA device.") _raise_error_if_accelerate_model_or_sequential_hook_present(module) - config = GroupOffloadingConfig( - onload_device=onload_device, - offload_device=offload_device, - offload_type=offload_type, - num_blocks_per_group=num_blocks_per_group, - non_blocking=non_blocking, - stream=stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, - offload_to_disk_path=offload_to_disk_path, - ) - _apply_group_offloading(module, config) - - -def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: - if config.offload_type == GroupOffloadingType.BLOCK_LEVEL: - _apply_group_offloading_block_level(module, config) - elif config.offload_type == GroupOffloadingType.LEAF_LEVEL: - _apply_group_offloading_leaf_level(module, config) + if offload_type == "block_level": + if num_blocks_per_group is None: + raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") + + _apply_group_offloading_block_level( + module=module, + num_blocks_per_group=num_blocks_per_group, + offload_device=offload_device, + onload_device=onload_device, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + elif 
offload_type == "leaf_level": + _apply_group_offloading_leaf_level( + module=module, + offload_device=offload_device, + onload_device=onload_device, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, + ) else: - assert False + raise ValueError(f"Unsupported offload_type: {offload_type}") -def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: +def _apply_group_offloading_block_level( + module: torch.nn.Module, + num_blocks_per_group: int, + offload_device: torch.device, + onload_device: torch.device, + non_blocking: bool, + stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False, + low_cpu_mem_usage: bool = False, +) -> None: r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. - """ - if config.stream is not None and config.num_blocks_per_group != 1: - logger.warning( - f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1." - ) - config.num_blocks_per_group = 1 + Args: + module (`torch.nn.Module`): + The module to which group offloading is applied. + offload_device (`torch.device`): + The device to which the group of modules are offloaded. This should typically be the CPU. + onload_device (`torch.device`): + The device to which the group of modules are onloaded. + non_blocking (`bool`): + If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation + and data transfer. + stream (`torch.cuda.Stream`, *optional*): + If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful + for overlapping computation and data transfer. + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. + low_cpu_mem_usage (`bool`, defaults to `False`): + If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This + option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when + the CPU memory is a bottleneck but may counteract the benefits of using streams. 
+ """ # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() @@ -612,22 +509,19 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf modules_with_group_offloading.add(name) continue - for i in range(0, len(submodule), config.num_blocks_per_group): - current_modules = submodule[i : i + config.num_blocks_per_group] - group_id = f"{name}_{i}_{i + len(current_modules) - 1}" + for i in range(0, len(submodule), num_blocks_per_group): + current_modules = submodule[i : i + num_blocks_per_group] group = ModuleGroup( modules=current_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, + offload_device=offload_device, + onload_device=onload_device, offload_leader=current_modules[-1], onload_leader=current_modules[0], - non_blocking=config.non_blocking, - stream=config.stream, - record_stream=config.record_stream, - low_cpu_mem_usage=config.low_cpu_mem_usage, - onload_self=True, - group_id=group_id, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, + onload_self=stream is None, ) matched_module_groups.append(group) for j in range(i, i + len(current_modules)): @@ -635,8 +529,12 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf # Apply group offloading hooks to the module groups for i, group in enumerate(matched_module_groups): + next_group = ( + matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None + ) + for group_module in group.modules: - _apply_group_offloading_hook(group_module, group, None, config=config) + _apply_group_offloading_hook(group_module, group, next_group) # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately # when the forward pass of this module is called. This is because the top-level module is not @@ -651,9 +549,8 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] unmatched_group = ModuleGroup( modules=unmatched_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, + offload_device=offload_device, + onload_device=onload_device, offload_leader=module, onload_leader=module, parameters=parameters, @@ -662,21 +559,49 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf stream=None, record_stream=False, onload_self=True, - group_id=f"{module.__class__.__name__}_unmatched_group", ) - if config.stream is None: - _apply_group_offloading_hook(module, unmatched_group, None, config=config) - else: - _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config) + next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None + _apply_group_offloading_hook(module, unmatched_group, next_group) -def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: +def _apply_group_offloading_leaf_level( + module: torch.nn.Module, + offload_device: torch.device, + onload_device: torch.device, + non_blocking: bool, + stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False, + low_cpu_mem_usage: bool = False, +) -> None: r""" This function applies offloading to groups of leaf modules in a torch.nn.Module. 
This method has minimal memory requirements. However, it can be slower compared to other offloading methods due to the excessive number of device synchronizations. When using devices that support streams to overlap data transfer and computation, this method can reduce memory usage without any performance degradation. + + Args: + module (`torch.nn.Module`): + The module to which group offloading is applied. + offload_device (`torch.device`): + The device to which the group of modules are offloaded. This should typically be the CPU. + onload_device (`torch.device`): + The device to which the group of modules are onloaded. + non_blocking (`bool`): + If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation + and data transfer. + stream (`torch.cuda.Stream`, *optional*): + If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful + for overlapping computation and data transfer. + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. + low_cpu_mem_usage (`bool`, defaults to `False`): + If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This + option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when + the CPU memory is a bottleneck but may counteract the benefits of using streams. """ + # Create module groups for leaf modules and apply group offloading hooks modules_with_group_offloading = set() for name, submodule in module.named_modules(): @@ -684,19 +609,17 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff continue group = ModuleGroup( modules=[submodule], - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, + offload_device=offload_device, + onload_device=onload_device, offload_leader=submodule, onload_leader=submodule, - non_blocking=config.non_blocking, - stream=config.stream, - record_stream=config.record_stream, - low_cpu_mem_usage=config.low_cpu_mem_usage, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - group_id=name, ) - _apply_group_offloading_hook(submodule, group, None, config=config) + _apply_group_offloading_hook(submodule, group, None) modules_with_group_offloading.add(name) # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass @@ -727,33 +650,31 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff parameters = parent_to_parameters.get(name, []) buffers = parent_to_buffers.get(name, []) parent_module = module_dict[name] + assert getattr(parent_module, "_diffusers_hook", None) is None group = ModuleGroup( modules=[], - offload_device=config.offload_device, - onload_device=config.onload_device, + offload_device=offload_device, + onload_device=onload_device, offload_leader=parent_module, onload_leader=parent_module, - offload_to_disk_path=config.offload_to_disk_path, parameters=parameters, buffers=buffers, - non_blocking=config.non_blocking, - stream=config.stream, - record_stream=config.record_stream, - 
low_cpu_mem_usage=config.low_cpu_mem_usage, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - group_id=name, ) - _apply_group_offloading_hook(parent_module, group, None, config=config) + _apply_group_offloading_hook(parent_module, group, None) - if config.stream is not None: + if stream is not None: # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the # execution order and apply prefetching in the correct order. unmatched_group = ModuleGroup( modules=[], - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, + offload_device=offload_device, + onload_device=onload_device, offload_leader=module, onload_leader=module, parameters=None, @@ -761,26 +682,23 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff non_blocking=False, stream=None, record_stream=False, - low_cpu_mem_usage=config.low_cpu_mem_usage, + low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - group_id=_GROUP_ID_LAZY_LEAF, ) - _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config) + _apply_lazy_group_offloading_hook(module, unmatched_group, None) def _apply_group_offloading_hook( module: torch.nn.Module, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, - *, - config: GroupOffloadingConfig, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(module) # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. if registry.get_hook(_GROUP_OFFLOADING) is None: - hook = GroupOffloadingHook(group, next_group, config=config) + hook = GroupOffloadingHook(group, next_group) registry.register_hook(hook, _GROUP_OFFLOADING) @@ -788,15 +706,13 @@ def _apply_lazy_group_offloading_hook( module: torch.nn.Module, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, - *, - config: GroupOffloadingConfig, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(module) # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. 
if registry.get_hook(_GROUP_OFFLOADING) is None: - hook = GroupOffloadingHook(group, next_group, config=config) + hook = GroupOffloadingHook(group, next_group) registry.register_hook(hook, _GROUP_OFFLOADING) lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() @@ -863,54 +779,15 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn ) -def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]: - for submodule in module.modules(): - if hasattr(submodule, "_diffusers_hook"): - group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) - if group_offloading_hook is not None: - return group_offloading_hook - return None - - def _is_group_offload_enabled(module: torch.nn.Module) -> bool: - top_level_group_offload_hook = _get_top_level_group_offload_hook(module) - return top_level_group_offload_hook is not None + for submodule in module.modules(): + if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: + return True + return False def _get_group_onload_device(module: torch.nn.Module) -> torch.device: - top_level_group_offload_hook = _get_top_level_group_offload_hook(module) - if top_level_group_offload_hook is not None: - return top_level_group_offload_hook.config.onload_device + for submodule in module.modules(): + if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: + return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device raise ValueError("Group offloading is not enabled for the provided module.") - - -def _compute_group_hash(group_id): - hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest() - # first 16 characters for a reasonably short but unique name - return hashed_id[:16] - - -def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None: - r""" - Removes the group offloading hook from the module and re-applies it. This is useful when the module has been - modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place - modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly. - - In this implementation, we make an assumption that group offloading has only been applied at the top-level module, - and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the - case where user has applied group offloading at multiple levels, this function will not work as expected. - - There is some performance penalty associated with doing this when non-default streams are used, because we need to - retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`. - """ - top_level_group_offload_hook = _get_top_level_group_offload_hook(module) - - if top_level_group_offload_hook is None: - return - - registry = HookRegistry.check_if_exists_or_initialize(module) - registry.remove_hook(_GROUP_OFFLOADING, recurse=True) - registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True) - registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True) - - _apply_group_offloading(module, top_level_group_offload_hook.config) \ No newline at end of file From 043fdd294a3a61fc2bf0940196a45bcab5c933e9 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 17 Jul 2025 14:46:57 +0900 Subject: [PATCH 476/482] Revert "updates" This reverts commit d7256478f4e454f4e04a4c40212ef1067ce89636. 
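
The hunks below re-inline the `LoraConfig` construction and restore `_maybe_raise_error_for_ambiguity`. As a toy illustration of the key ambiguity that check guards against (module names are hypothetical, not from this patch):

```py
rank_pattern = {"proj": 16}          # per-module rank override, keyed by "proj"
target_modules = ["proj", "q_proj"]  # "proj" matches exactly AND as a substring of "q_proj"

key = "proj"
exact_matches = [mod for mod in target_modules if mod == key]                      # ["proj"]
substring_matches = [mod for mod in target_modules if key in mod and mod != key]  # ["q_proj"]
# Both lists are non-empty, so on peft < 0.14.1 the restored check raises a
# ValueError instead of letting peft resolve the overlapping keys inconsistently.
```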
--- src/diffusers/loaders/peft.py | 186 ++++++++++++++++++++-------------- 1 file changed, 112 insertions(+), 74 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 822c25994725..1d990e81458d 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -import json import os from functools import partial from pathlib import Path @@ -29,13 +28,13 @@ convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, + get_peft_kwargs, is_peft_available, is_peft_version, logging, set_adapter_layers, set_weights_and_activate_adapters, ) -from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading from .unet_loader_utils import _maybe_expand_lora_scales @@ -57,13 +56,29 @@ "Lumina2Transformer2DModel": lambda model_cls, weights: weights, "WanTransformer3DModel": lambda model_cls, weights: weights, "CogView4Transformer2DModel": lambda model_cls, weights: weights, - "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights, - "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights, - "WanVACETransformer3DModel": lambda model_cls, weights: weights, - "ChromaTransformer2DModel": lambda model_cls, weights: weights, } +def _maybe_raise_error_for_ambiguity(config): + rank_pattern = config["rank_pattern"].copy() + target_modules = config["target_modules"] + + for key in list(rank_pattern.keys()): + # try to detect ambiguity + # `target_modules` can also be a str, in which case this loop would loop + # over the chars of the str. The technically correct way to match LoRA keys + # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key). + # But this cuts it for now. + exact_matches = [mod for mod in target_modules if mod == key] + substring_matches = [mod for mod in target_modules if key in mod and mod != key] + + if exact_matches and substring_matches: + if is_peft_version("<", "0.14.1"): + raise ValueError( + "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`." + ) + + class PeftAdapterMixin: """ A class containing all functions for loading and using adapters weights that are supported in PEFT library. For @@ -85,6 +100,17 @@ class PeftAdapterMixin: @classmethod # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading def _optionally_disable_offloading(cls, _pipeline): + """ + Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + """ return _func_optionally_disable_offloading(_pipeline=_pipeline) def load_lora_adapter( @@ -156,15 +182,10 @@ def load_lora_adapter( Note that hotswapping adapters of the text encoder is not yet supported. There are some further limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap - metadata: - LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to - initialize `LoraConfig`. 
""" - from peft import inject_adapter_in_model, set_peft_model_state_dict + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer - from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading - cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -178,7 +199,6 @@ def load_lora_adapter( network_alphas = kwargs.pop("network_alphas", None) _pipeline = kwargs.pop("_pipeline", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) - metadata = kwargs.pop("metadata", None) allow_pickle = False if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): @@ -186,8 +206,12 @@ def load_lora_adapter( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict, metadata = _fetch_state_dict( + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -200,17 +224,12 @@ def load_lora_adapter( subfolder=subfolder, user_agent=user_agent, allow_pickle=allow_pickle, - metadata=metadata, ) if network_alphas is not None and prefix is None: raise ValueError("`network_alphas` cannot be None when `prefix` is None.") - if network_alphas and metadata: - raise ValueError("Both `network_alphas` and `metadata` cannot be specified.") if prefix is not None: - state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - if metadata is not None: - metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} + state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: @@ -230,7 +249,7 @@ def load_lora_adapter( rank = {} for key, val in state_dict.items(): - # Cannot figure out rank from lora layers that don't have at least 2 dimensions. + # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. # Bias layers in LoRA only have a single dimension if "lora_B" in key and val.ndim > 1: # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. @@ -241,33 +260,44 @@ def load_lora_adapter( if network_alphas is not None and len(network_alphas) >= 1: alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] - network_alphas = { - k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys - } + network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) + _maybe_raise_error_for_ambiguity(lora_config_kwargs) + + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
+            lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)
 
-            # create LoraConfig
-            lora_config = _create_lora_config(
-                state_dict,
-                network_alphas,
-                metadata,
-                rank,
-                model_state_dict=self.state_dict(),
-                adapter_name=adapter_name,
-            )
-
-            # <Unsafe code
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -299,7 +329,7 @@ def map_state_dict_for_hotswap(sd):
                     new_sd[k] = v
                 return new_sd
 
             # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
             # we should also delete the `peft_config` associated to the `adapter_name`.
             try:
                 if hotswap:
@@ -313,7 +343,7 @@ def map_state_dict_for_hotswap(sd):
                         config=lora_config,
                     )
             except Exception as e:
                 logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
                 raise
             # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
             # it to None
@@ -348,28 +378,46 @@ def map_state_dict_for_hotswap(sd):
                     module.delete_adapter(adapter_name)
                 self.peft_config.pop(adapter_name)
                 logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
                 raise
 
-            _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
+            warn_msg = ""
+            if incompatible_keys is not None:
+                # Check only for unexpected keys.
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                    if lora_unexpected_keys:
+                        warn_msg = (
+                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                            f" {', '.join(lora_unexpected_keys)}. "
+                        )
+
+                # Filter missing keys specific to the current adapter.
+                missing_keys = getattr(incompatible_keys, "missing_keys", None)
+                if missing_keys:
+                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                    if lora_missing_keys:
+                        warn_msg += (
+                            f"Loading adapter weights from state_dict led to missing keys in the model:"
+                            f" {', '.join(lora_missing_keys)}."
+                        )
+
+            if warn_msg:
+                logger.warning(warn_msg)
 
             # Offload back.
             if is_model_cpu_offload:
                 _pipeline.enable_model_cpu_offload()
             elif is_sequential_cpu_offload:
                 _pipeline.enable_sequential_cpu_offload()
-            elif is_group_offload:
-                for component in _pipeline.components.values():
-                    if isinstance(component, torch.nn.Module):
-                        _maybe_remove_and_reapply_group_offloading(component)
             # Unsafe code />
 
         if prefix is not None and not state_dict:
-            model_class_name = self.__class__.__name__
             logger.warning(
-                f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
+                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. 
" "This is safe to ignore if LoRA state dict didn't originally have any " - f"{model_class_name} related params. You can also try specifying `prefix=None` " + f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` " "to resolve the warning. Otherwise, open an issue if you think it's unexpected: " "https://github.com/huggingface/diffusers/issues/new" ) @@ -392,13 +440,17 @@ def save_lora_adapter( underlying model has multiple adapters loaded. upcast_before_saving (`bool`, defaults to `False`): Whether to cast the underlying model to `torch.float32` before serialization. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with. """ from peft.utils import get_peft_model_state_dict - from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE if adapter_name is None: adapter_name = get_adapter_name(self) @@ -406,8 +458,6 @@ def save_lora_adapter( if adapter_name not in getattr(self, "peft_config", {}): raise ValueError(f"Adapter name {adapter_name} not found in the model.") - lora_adapter_metadata = self.peft_config[adapter_name].to_dict() - lora_layers_to_save = get_peft_model_state_dict( self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name ) @@ -417,15 +467,7 @@ def save_lora_adapter( if safe_serialization: def save_function(weights, filename): - # Inject framework format. - metadata = {"format": "pt"} - if lora_adapter_metadata is not None: - for key, value in lora_adapter_metadata.items(): - if isinstance(value, set): - lora_adapter_metadata[key] = list(value) - metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) - - return safetensors.torch.save_file(weights, filename, metadata=metadata) + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) else: save_function = torch.save @@ -438,6 +480,7 @@ def save_function(weights, filename): else: weight_name = LORA_WEIGHT_NAME + # TODO: we could consider saving the `peft_config` as well. save_path = Path(save_directory, weight_name).as_posix() save_function(lora_layers_to_save, save_path) logger.info(f"Model weights saved in {save_path}") @@ -448,7 +491,7 @@ def set_adapters( weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None, ): """ - Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.). + Set the currently active adapters for use in the UNet. 
 
         Args:
             adapter_names (`List[str]` or `str`):
                 The names of the adapters to use.
             adapter_weights (`Union[List[float], float]`, *optional*):
                 The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
                 adapters.
 
         Example:
 
         ```py
         from diffusers import AutoPipelineForText2Image
         import torch
 
         pipeline = AutoPipelineForText2Image.from_pretrained(
             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
         ).to("cuda")
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
         pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
-        pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -668,7 +711,7 @@ def _fuse_lora_apply(self, module, adapter_names=None):
         if self.lora_scale != 1.0:
             module.scale_layer(self.lora_scale)
 
         # For BC with previous PEFT versions, we need to check the signature
         # of the `merge` method to see if it supports the `adapter_names` argument.
         supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
         if "adapter_names" in supported_merge_kwargs:
@@ -696,16 +739,11 @@ def unload_lora(self):
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for `unload_lora()`.")
 
-        from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
         from ..utils import recurse_remove_peft_layers
 
         recurse_remove_peft_layers(self)
         if hasattr(self, "peft_config"):
             del self.peft_config
-        if hasattr(self, "_hf_peft_config_loaded"):
-            self._hf_peft_config_loaded = None
-
-        _maybe_remove_and_reapply_group_offloading(self)
 
     def disable_lora(self):
         """
@@ -723,7 +761,7 @@ def disable_lora(self):
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
-        pipeline.unet.disable_lora()
+        pipeline.disable_lora()
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -746,7 +784,7 @@ def enable_lora(self):
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
-        pipeline.unet.enable_lora()
+        pipeline.enable_lora()
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -773,7 +811,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
         )
-        pipeline.unet.delete_adapters("cinematic")
+        pipeline.delete_adapters("cinematic")
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -820,4 +858,4 @@ def enable_lora_hotswap(
                 f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
             )
 
-        self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
\ No newline at end of file
+        self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}

From 1c1bb8223c785b39b8566740909411407f534dff Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Thu, 17 Jul 2025 14:47:10 +0900
Subject: [PATCH 477/482] Revert "updates"

This reverts commit 7af96db697362704932dedf944eed887f0adbba6.

---
 src/diffusers/utils/torch_utils.py | 40 ++++++------------------------
 1 file changed, 7 insertions(+), 33 deletions(-)

diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index d1bd1f6e4b3d..491cc16e39e6 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,9 +16,10 @@
 """

 from typing import List, Optional, Tuple, Union
+import random

 from . import logging
-from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version
+from .import_utils import is_torch_available, is_torch_version


 if is_torch_available():
@@ -38,7 +39,7 @@ def maybe_allow_in_graph(cls):
 def randn_tensor(
     shape: Union[Tuple, List],
     generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
-    device: Optional[Union[str, "torch.device"]] = None,
+    device: Optional["torch.device"] = None,
     dtype: Optional["torch.dtype"] = None,
     layout: Optional["torch.layout"] = None,
 ):
@@ -47,8 +48,6 @@ def randn_tensor(
     is always created on the CPU.
     """
     # device on which tensor is created defaults to device
-    if isinstance(device, str):
-        device = torch.device(device)
     rand_device = device
     batch_size = shape[0]

@@ -63,7 +62,7 @@ def randn_tensor(
             logger.info(
                 f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                 f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
-                f" slightly speed up this function by passing a generator that was created on the {device} device."
+                f" slightly speed up this function by passing a generator that was created on the {device} device."
             )
         elif gen_device_type != device.type and gen_device_type == "cuda":
             raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
@@ -81,7 +80,7 @@ def randn_tensor(
         latents = torch.cat(latents, dim=0).to(device)
     else:
         latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
-
+ 
     return latents


@@ -92,13 +91,8 @@ def is_compiled_module(module) -> bool:
     return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)


-def unwrap_module(module):
-    """Unwraps a module if it was compiled with torch.compile()"""
-    return module._orig_mod if is_compiled_module(module) else module
-
-
 def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
-    """Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).
+    """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).

     This version of the method comes from here:
     https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
@@ -171,27 +165,7 @@ def get_torch_cuda_device_capability():
 def get_device():
     if torch.cuda.is_available():
         return "cuda"
-    elif is_torch_npu_available():
-        return "npu"
     elif hasattr(torch, "xpu") and torch.xpu.is_available():
         return "xpu"
-    elif torch.backends.mps.is_available():
-        return "mps"
     else:
         return "cpu"
-
-
-def empty_device_cache(device_type: Optional[str] = None):
-    if device_type is None:
-        device_type = get_device()
-    if device_type in ["cpu"]:
-        return
-    device_mod = getattr(torch, device_type, torch.cuda)
-    device_mod.empty_cache()
-
-
-def device_synchronize(device_type: Optional[str] = None):
-    if device_type is None:
-        device_type = get_device()
-    device_mod = getattr(torch, device_type, torch.cuda)
-    device_mod.synchronize()
\ No newline at end of file

From 5bf927a1d4cbec5af638c4d652478d6560bc5779 Mon Sep 17 00:00:00 2001
From: LinchuanXUTheSEAAI
Date: Thu, 17 Jul 2025 14:47:20 +0900
Subject: [PATCH 478/482] Revert "updates"

This reverts commit f349254fb3b6aa78a6cebe9a3181b86d041dd2f7.
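The `randn_tensor` hunks in the diff above revert its string-device handling and restore the older log message. A minimal usage sketch of the behavior that message describes; the CUDA target is illustrative and assumes a GPU is available:

    # Sketch only: a CPU generator with a CUDA target triggers the info log above;
    # the noise is drawn on CPU for reproducibility and then moved to the target device.
    import torch

    from diffusers.utils.torch_utils import randn_tensor

    generator = torch.Generator("cpu").manual_seed(0)
    latents = randn_tensor(
        (1, 4, 64, 64), generator=generator, device=torch.device("cuda"), dtype=torch.float16
    )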
---
 src/diffusers/utils/peft_utils.py | 144 +++---------------------
 1 file changed, 10 insertions(+), 134 deletions(-)

diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 94aa224e1a4c..d1269fbc5f20 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,13 +21,9 @@

 from packaging import version

-from . import logging
-from .import_utils import is_peft_available, is_peft_version, is_torch_available
-from .torch_utils import empty_device_cache
+from .import_utils import is_peft_available, is_torch_available

-logger = logging.get_logger(__name__)
-
 if is_torch_available():
     import torch
@@ -99,7 +95,8 @@ def recurse_remove_peft_layers(model):
                 setattr(model, name, new_module)
                 del module

-    empty_device_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
     return model
@@ -150,27 +147,25 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
             module.set_scale(adapter_name, 1.0)


-def get_peft_kwargs(
-    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
-):
+def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]

     if len(set(rank_dict.values())) > 1:
-        # get the rank occurring the most number of times
+        # get the rank occurring the most number of times
         r = collections.Counter(rank_dict.values()).most_common()[0][0]

-        # for modules with rank different from the most occurring rank, add it to the `rank_pattern`
+        # for modules with rank different from the most occurring rank, add it to the `rank_pattern`
         rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
         rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}

     if network_alpha_dict is not None and len(network_alpha_dict) > 0:
         if len(set(network_alpha_dict.values())) > 1:
-            # get the alpha occurring the most number of times
+            # get the alpha occurring the most number of times
             lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

-            # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
+            # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
             alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
             if is_unet:
                 alpha_pattern = {
@@ -182,6 +177,7 @@ def get_peft_kwargs(
     else:
         lora_alpha = set(network_alpha_dict.values()).pop()

+    # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
     use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
     # for now we know that the "bias" keys are only associated with `lora_B`.
@@ -196,21 +192,6 @@ def get_peft_kwargs(
         "use_dora": use_dora,
         "lora_bias": lora_bias,
     }
-
-    # Example: try load FusionX LoRA into Wan VACE
-    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
-    if exclude_modules:
-        if not is_peft_version(">=", "0.14.0"):
-            msg = """
-It seems like there are certain modules that need to be excluded when initializing `LoraConfig`.
Your current `peft` -version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U -peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue - -https://github.com/huggingface/diffusers/issues/new - """ - logger.debug(msg) - else: - lora_config_kwargs.update({"exclude_modules": exclude_modules}) - return lora_config_kwargs @@ -307,108 +288,3 @@ def check_peft_version(min_version: str) -> None: f"The version of PEFT you are using is not compatible, please use a version that is greater" f" than {min_version}" ) - - -def _create_lora_config( - state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None -): - from peft import LoraConfig - - if metadata is not None: - lora_config_kwargs = metadata - else: - lora_config_kwargs = get_peft_kwargs( - rank_pattern_dict, - network_alpha_dict=network_alphas, - peft_state_dict=state_dict, - is_unet=is_unet, - model_state_dict=model_state_dict, - adapter_name=adapter_name, - ) - - _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs) - - # Version checks for DoRA and lora_bias - if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.") - - if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.") - - try: - return LoraConfig(**lora_config_kwargs) - except TypeError as e: - raise TypeError("`LoraConfig` class could not be instantiated.") from e - - -def _maybe_raise_error_for_ambiguous_keys(config): - rank_pattern = config["rank_pattern"].copy() - target_modules = config["target_modules"] - - for key in list(rank_pattern.keys()): - # try to detect ambiguity - # `target_modules` can also be a str, in which case this loop would loop - # over the chars of the str. The technically correct way to match LoRA keys - # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key). - # But this cuts it for now. - exact_matches = [mod for mod in target_modules if mod == key] - substring_matches = [mod for mod in target_modules if key in mod and mod != key] - - if exact_matches and substring_matches: - if is_peft_version("<", "0.14.1"): - raise ValueError( - "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`." - ) - - -def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. " - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." 
- ) - - if warn_msg: - logger.warning(warn_msg) - - -def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None): - """ - Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the - `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it - doesn't exist in `peft_state_dict`. - """ - if model_state_dict is None: - return - all_modules = set() - string_to_replace = f"{adapter_name}." if adapter_name else "" - - for name in model_state_dict.keys(): - if string_to_replace: - name = name.replace(string_to_replace, "") - if "." in name: - module_name = name.rsplit(".", 1)[0] - all_modules.add(module_name) - - target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()} - exclude_modules = list(all_modules - target_modules_set) - - return exclude_modules \ No newline at end of file From cce1b1e1ef28e70210b4a6a7847bbbfe245cc9cb Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 17 Jul 2025 14:47:37 +0900 Subject: [PATCH 479/482] Revert "updates" This reverts commit ad94d113f6dba3ed3e26e5199320497312455770. --- src/diffusers/loaders/lora_base.py | 375 ++-- .../loaders/lora_conversion_utils.py | 582 ++---- src/diffusers/loaders/lora_pipeline.py | 1684 +++++++---------- src/diffusers/utils/state_dict_utils.py | 10 +- 4 files changed, 986 insertions(+), 1665 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index ff2cb6a54c1c..fe5f52c7149a 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ import copy import inspect -import json import os from pathlib import Path from typing import Callable, Dict, List, Optional, Union @@ -26,6 +25,7 @@ from huggingface_hub.constants import HF_HUB_OFFLINE from ..models.modeling_utils import ModelMixin, load_state_dict +from ..utils.state_dict_utils import _load_sft_state_dict_metadata from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -34,6 +34,7 @@ delete_adapter_layers, deprecate, get_adapter_name, + get_peft_kwargs, is_accelerate_available, is_peft_available, is_peft_version, @@ -45,13 +46,13 @@ set_adapter_layers, set_weights_and_activate_adapters, ) -from ..utils.peft_utils import _create_lora_config -from ..utils.state_dict_utils import _load_sft_state_dict_metadata if is_transformers_available(): from transformers import PreTrainedModel + from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules + if is_peft_available(): from peft.tuners.tuners_utils import BaseTunerLayer @@ -305,18 +306,13 @@ def _best_guess_weight_name( targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) if len(targeted_files) > 1: - logger.warning( - f"Provided path contains more than one weights file in the {file_extension} format. `{targeted_files[0]}` is going to be loaded, for precise control, specify a `weight_name` in `load_lora_weights`." + raise ValueError( + f"Provided path contains more than one weights file in the {file_extension} format. 
Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." ) weight_name = targeted_files[0] return weight_name -def _pack_dict_with_prefix(state_dict, prefix): - sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()} - return sd_with_prefix - - def _load_lora_into_text_encoder( state_dict, network_alphas, @@ -328,16 +324,10 @@ def _load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, - metadata=None, ): - from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading - if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") - if network_alphas and metadata: - raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.") - peft_kwargs = {} if low_cpu_mem_usage: if not is_peft_version(">=", "0.13.1"): @@ -352,6 +342,8 @@ def _load_lora_into_text_encoder( ) peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as # their prefixes. @@ -363,9 +355,7 @@ def _load_lora_into_text_encoder( # Load the layers corresponding to text encoder and make necessary adjustments. if prefix is not None: - state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - if metadata is not None: - metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} + state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: logger.info(f"Loading {prefix}.") @@ -375,27 +365,54 @@ def _load_lora_into_text_encoder( # convert state dict state_dict = convert_state_dict_to_peft(state_dict) - for name, _ in text_encoder.named_modules(): - if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")): - rank_key = f"{name}.lora_B.weight" - if rank_key in state_dict: - rank[rank_key] = state_dict[rank_key].shape[1] + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in state_dict: + continue + rank[rank_key] = state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in state_dict: + continue + rank[rank_key] = state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys} + network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) + + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
+ ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") - # create `LoraConfig` - lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False) + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: adapter_name = get_adapter_name(text_encoder) - # if prefix is not None and not state_dict: - model_class_name = text_encoder.__class__.__name__ logger.warning( - f"No LoRA keys associated to {model_class_name} found with the {prefix=}. " + f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. " "This is safe to ignore if LoRA state dict didn't originally have any " - f"{model_class_name} related params. You can also try specifying `prefix=None` " + f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` " "to resolve the warning. Otherwise, open an issue if you think it's unexpected: " "https://github.com/huggingface/diffusers/issues/new" ) def _func_optionally_disable_offloading(_pipeline): - """ - Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. - - Args: - _pipeline (`DiffusionPipeline`): - The pipeline to disable offloading for. - - Returns: - tuple: - A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True. - """ - from ..hooks.group_offloading import _is_group_offload_enabled - is_model_cpu_offload = False is_sequential_cpu_offload = False - is_group_offload = False if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): - if not isinstance(component, nn.Module): - continue - is_group_offload = is_group_offload or _is_group_offload_enabled(component) - if not hasattr(component, "_hf_hook"): - continue - is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload) - is_sequential_cpu_offload = is_sequential_cpu_offload or ( - isinstance(component._hf_hook, AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if not is_model_cpu_offload: + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) + if not is_sequential_cpu_offload: + is_sequential_cpu_offload = ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) - if is_sequential_cpu_offload or is_model_cpu_offload: - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - for _, component in _pipeline.components.items(): - if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"): - continue + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
+ ) remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload) + return (is_model_cpu_offload, is_sequential_cpu_offload) class LoraBaseMixin: """Utility class for handling LoRAs.""" _lora_loadable_modules = [] - _merged_adapters = set() - - @property - def lora_scale(self) -> float: - """ - Returns the lora scale which can be set at run time by the pipeline. # if `_lora_scale` has not been set, - return 1. - """ - return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 - - @property - def num_fused_loras(self): - """Returns the number of LoRAs that have been fused.""" - return len(self._merged_adapters) - - @property - def fused_loras(self): - """Returns names of the LoRAs that have been fused.""" - return self._merged_adapters + num_fused_loras = 0 def load_lora_weights(self, **kwargs): raise NotImplementedError("`load_lora_weights()` is not implemented.") @@ -510,6 +485,33 @@ def save_lora_weights(cls, **kwargs): def lora_state_dict(cls, **kwargs): raise NotImplementedError("`lora_state_dict()` is not implemented.") + @classmethod + def _optionally_disable_offloading(cls, _pipeline): + """ + Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + """ + return _func_optionally_disable_offloading(_pipeline=_pipeline) + + @classmethod + def _fetch_state_dict(cls, *args, **kwargs): + deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`." + deprecate("_fetch_state_dict", "0.35.0", deprecation_message) + return _fetch_state_dict(*args, **kwargs) + + @classmethod + def _best_guess_weight_name(cls, *args, **kwargs): + deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`." + deprecate("_best_guess_weight_name", "0.35.0", deprecation_message) + return _best_guess_weight_name(*args, **kwargs) + def unload_lora_weights(self): """ Unloads the LoRA parameters. @@ -597,9 +599,6 @@ def fuse_lora( if len(components) == 0: raise ValueError("`components` cannot be an empty list.") - # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it - # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`. - merged_adapter_names = set() for fuse_component in components: if fuse_component not in self._lora_loadable_modules: raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.") @@ -609,19 +608,13 @@ def fuse_lora( # check if diffusers model if issubclass(model.__class__, ModelMixin): model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) - for module in model.modules(): - if isinstance(module, BaseTunerLayer): - merged_adapter_names.update(set(module.merged_adapters)) # handle transformers models. 
if issubclass(model.__class__, PreTrainedModel): fuse_text_encoder_lora( model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - for module in model.modules(): - if isinstance(module, BaseTunerLayer): - merged_adapter_names.update(set(module.merged_adapters)) - self._merged_adapters = self._merged_adapters | merged_adapter_names + self.num_fused_loras += 1 def unfuse_lora(self, components: List[str] = [], **kwargs): r""" @@ -675,42 +668,15 @@ def unfuse_lora(self, components: List[str] = [], **kwargs): if issubclass(model.__class__, (ModelMixin, PreTrainedModel)): for module in model.modules(): if isinstance(module, BaseTunerLayer): - for adapter in set(module.merged_adapters): - if adapter and adapter in self._merged_adapters: - self._merged_adapters = self._merged_adapters - {adapter} module.unmerge() + self.num_fused_loras -= 1 + def set_adapters( self, adapter_names: Union[List[str], str], adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None, ): - """ - Set the currently active adapters for use in the pipeline. - - Args: - adapter_names (`List[str]` or `str`): - The names of the adapters to use. - adapter_weights (`Union[List[float], float]`, *optional*): - The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the - adapters. - - Example: - - ```py - from diffusers import AutoPipelineForText2Image - import torch - - pipeline = AutoPipelineForText2Image.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights( - "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" - ) - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5]) - ``` - """ if isinstance(adapter_weights, dict): components_passed = set(adapter_weights.keys()) lora_components = set(self._lora_loadable_modules) @@ -780,24 +746,6 @@ def set_adapters( set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component]) def disable_lora(self): - """ - Disables the active LoRA layers of the pipeline. - - Example: - - ```py - from diffusers import AutoPipelineForText2Image - import torch - - pipeline = AutoPipelineForText2Image.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights( - "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" - ) - pipeline.disable_lora() - ``` - """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -810,24 +758,6 @@ def disable_lora(self): disable_lora_for_text_encoder(model) def enable_lora(self): - """ - Enables the active LoRA layers of the pipeline. 
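The `fuse_lora`/`unfuse_lora` hunks above swap the merged-adapter bookkeeping for a simple counter. For orientation, a hedged sketch of the fuse/unfuse cycle they manage; the LoRA repository id and adapter name are placeholders:

    # Sketch only: merge a loaded LoRA into the base weights for inference, then unmerge.
    import torch

    from diffusers import AutoPipelineForText2Image

    pipe = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("some-user/some-lora", adapter_name="style")
    pipe.fuse_lora(lora_scale=0.8)  # merge the adapter into the base weights
    image = pipe("a prompt").images[0]
    pipe.unfuse_lora()  # unmerge and restore the original base weights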
- - Example: - - ```py - from diffusers import AutoPipelineForText2Image - import torch - - pipeline = AutoPipelineForText2Image.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights( - "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" - ) - pipeline.enable_lora() - ``` - """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -841,26 +771,10 @@ def enable_lora(self): def delete_adapters(self, adapter_names: Union[List[str], str]): """ - Delete an adapter's LoRA layers from the pipeline. - Args: + Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s). adapter_names (`Union[List[str], str]`): - The names of the adapters to delete. - - Example: - - ```py - from diffusers import AutoPipelineForText2Image - import torch - - pipeline = AutoPipelineForText2Image.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights( - "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic" - ) - pipeline.delete_adapters("cinematic") - ``` + The names of the adapter to delete. Can be a single string or a list of strings """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -937,27 +851,6 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case you want to load multiple adapters and free some GPU memory. - After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters - can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to - GPU before using those LoRA adapters for inference. - - ```python - >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1") - >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2") - >>> pipe.set_adapters("adapter-1") - >>> image_1 = pipe(**kwargs) - >>> # switch to adapter-2, offload adapter-1 - >>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cpu") - >>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cuda:0") - >>> pipe.set_adapters("adapter-2") - >>> image_2 = pipe(**kwargs) - >>> # switch back to adapter-1, offload adapter-2 - >>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cpu") - >>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cuda:0") - >>> pipe.set_adapters("adapter-1") - >>> ... - ``` - Args: adapter_names (`List[str]`): List of adapters to send device to. @@ -973,10 +866,6 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, for module in model.modules(): if isinstance(module, BaseTunerLayer): for adapter_name in adapter_names: - if adapter_name not in module.lora_A: - # it is sufficient to check lora_A - continue - module.lora_A[adapter_name].to(device) module.lora_B[adapter_name].to(device) # this is a param, not a module, so device placement is not in-place -> re-assign @@ -986,28 +875,11 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, adapter_name ].to(device) - def enable_lora_hotswap(self, **kwargs) -> None: - """ - Hotswap adapters without triggering recompilation of a model or if the ranks of the loaded adapters are - different. 
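A hedged sketch of the hotswap flow this method prepares for; the repository ids are placeholders, `pipe` continues from a pipeline like the one sketched earlier, and passing `hotswap=True` to `load_lora_weights` is assumed to be available on this branch:

    # Sketch only: pre-size LoRA ranks, compile once, then swap adapters in place
    # without triggering a recompilation of the UNet.
    import torch

    pipe.enable_lora_hotswap(target_rank=64)
    pipe.load_lora_weights("user/lora-a", adapter_name="default_0")
    pipe.unet = torch.compile(pipe.unet)
    _ = pipe("warm-up prompt")
    pipe.load_lora_weights("user/lora-b", adapter_name="default_0", hotswap=True)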
- - Args: - target_rank (`int`): - The highest rank among all the adapters that will be loaded. - check_compiled (`str`, *optional*, defaults to `"error"`): - How to handle a model that is already compiled. The check can return the following messages: - - "error" (default): raise an error - - "warn": issue a warning - - "ignore": do nothing - """ - for key, component in self.components.items(): - if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules): - component.enable_lora_hotswap(**kwargs) - @staticmethod def pack_weights(layers, prefix): layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - return _pack_dict_with_prefix(layers_weights, prefix) + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict @staticmethod def write_lora_layers( @@ -1017,33 +889,16 @@ def write_lora_layers( weight_name: str, save_function: Callable, safe_serialization: bool, - lora_adapter_metadata: Optional[dict] = None, ): - """Writes the state dict of the LoRA layers (optionally with metadata) to disk.""" if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return - if lora_adapter_metadata and not safe_serialization: - raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.") - if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict): - raise TypeError("`lora_adapter_metadata` must be of type `dict`.") - if save_function is None: if safe_serialization: def save_function(weights, filename): - # Inject framework format. - metadata = {"format": "pt"} - if lora_adapter_metadata: - for key, value in lora_adapter_metadata.items(): - if isinstance(value, set): - lora_adapter_metadata[key] = list(value) - metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps( - lora_adapter_metadata, indent=2, sort_keys=True - ) - - return safetensors.torch.save_file(weights, filename, metadata=metadata) + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) else: save_function = torch.save @@ -1060,6 +915,28 @@ def save_function(weights, filename): save_function(state_dict, save_path) logger.info(f"Model weights saved in {save_path}") - @classmethod - def _optionally_disable_offloading(cls, _pipeline): - return _func_optionally_disable_offloading(_pipeline=_pipeline) \ No newline at end of file + @property + def lora_scale(self) -> float: + # property function that returns the lora scale which can be set at run time by the pipeline. + # if _lora_scale has not been set, return 1 + return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 + + def enable_lora_hotswap(self, **kwargs) -> None: + """Enables the possibility to hotswap LoRA adapters. + + Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of + the loaded adapters differ. + + Args: + target_rank (`int`): + The highest rank among all the adapters that will be loaded. + check_compiled (`str`, *optional*, defaults to `"error"`): + How to handle the case when the model is already compiled, which should generally be avoided. 
The
+                options are:
+                    - "error" (default): raise an error
+                    - "warn": issue a warning
+                    - "ignore": do nothing
+        """
+        for key, component in self.components.items():
+            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+                component.enable_lora_hotswap(**kwargs)
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 1ef7f49a8e9b..c0fb910abcf9 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -433,7 +433,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
     if not is_sparse:
         # down_weight is copied to each split
-        ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
+        ait_sd.update({k: down_weight for k in ait_down_keys})

         # up_weight is split to each split
         ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
@@ -727,25 +727,8 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
         elif k.startswith("lora_te1_"):
             has_te_keys = True
             continue
-        elif k.startswith("lora_transformer_context_embedder"):
-            diffusers_key = "context_embedder"
-        elif k.startswith("lora_transformer_norm_out_linear"):
-            diffusers_key = "norm_out.linear"
-        elif k.startswith("lora_transformer_proj_out"):
-            diffusers_key = "proj_out"
-        elif k.startswith("lora_transformer_x_embedder"):
-            diffusers_key = "x_embedder"
-        elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"):
-            i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
-            diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
-        elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"):
-            i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1])
-            diffusers_key = f"time_text_embed.text_embedder.linear_{i}"
-        elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"):
-            i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1])
-            diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}"
         else:
-            raise NotImplementedError(f"Handling for key ({k}) is not implemented.")
+            raise NotImplementedError

         if "attn_" in k:
             if "_to_out_0" in k:
@@ -836,7 +819,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
             if zero_status_pe:
                 logger.info(
                     "The `position_embedding` LoRA params are all zeros which make them ineffective. "
-                    "So, we will purge them out of the current state dict to make loading possible."
+                    "So, we will purge them out of the current state dict to make loading possible."
                 )
             else:
@@ -852,7 +835,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
             if zero_status_t5:
                 logger.info(
                     "The `t5xxl` LoRA params are all zeros which make them ineffective. "
-                    "So, we will purge them out of the current state dict to make loading possible."
+                    "So, we will purge them out of the current state dict to make loading possible."
                )
            else:
                logger.info(
@@ -867,7 +850,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
             if zero_status_diff_b:
                 logger.info(
                     "The `diff_b` LoRA params are all zeros which make them ineffective. "
-                    "So, we will purge them out of the current state dict to make loading possible."
+                    "So, we will purge them out of the current state dict to make loading possible."
                 )
             else:
                 logger.info(
@@ -883,7 +866,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
             if zero_status_diff:
                 logger.info(
                     "The `diff` LoRA params are all zeros which make them ineffective. "
-                    "So, we will purge them out of the current state dict to make loading possible."
+                    "So, we will purge them out of the current state dict to make loading possible."
                 )
             else:
                 logger.info(
@@ -940,7 +923,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
         ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]

         # down_weight is copied to each split
-        ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
+        ait_sd.update({k: down_weight for k in ait_down_keys})

         # up_weight is split to each split
         ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
@@ -1254,7 +1237,7 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
                 f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
             )

-    # single transformer blocks
+    # single transformer blocks
     for i in range(num_single_layers):
         block_prefix = f"single_transformer_blocks.{i}."

@@ -1346,6 +1329,180 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     return converted_state_dict


+def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
+
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+            state_dict[key.replace("img_attn_qkv",
"attn.to_v")] = to_v + + def remap_txt_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + if "lora_A" in key: + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight + else: + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v + + def remap_single_transformer_blocks_(key, state_dict): + hidden_size = 3072 + + if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key: + linear1_weight = state_dict.pop(key) + if "lora_A" in key: + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( + ".linear1.lora_A.weight" + ) + state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight + state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight + state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight + state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight + else: + split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) + q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( + ".linear1.lora_B.weight" + ) + state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q + state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k + state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v + state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp + + elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key: + linear1_bias = state_dict.pop(key) + if "lora_A" in key: + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( + ".linear1.lora_A.bias" + ) + state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias + state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias + state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias + state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias + else: + split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) + q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( + ".linear1.lora_B.bias" + ) + state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias + state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias + state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias + state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias + + else: + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("linear2", "proj_out") + new_key = new_key.replace("q_norm", "attn.norm_q") + new_key = new_key.replace("k_norm", "attn.norm_k") + state_dict[new_key] = state_dict.pop(key) + + TRANSFORMER_KEYS_RENAME_DICT = { + "img_in": "x_embedder", + "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", + "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", + "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", + "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "double_blocks": "transformer_blocks", + "img_attn_q_norm": "attn.norm_q", + 
"img_attn_k_norm": "attn.norm_k", + "img_attn_proj": "attn.to_out.0", + "txt_attn_q_norm": "attn.norm_added_q", + "txt_attn_k_norm": "attn.norm_added_k", + "txt_attn_proj": "attn.to_add_out", + "img_mod.linear": "norm1.linear", + "img_norm1": "norm1.norm", + "img_norm2": "norm2", + "img_mlp": "ff", + "txt_mod.linear": "norm1_context.linear", + "txt_norm1": "norm1.norm", + "txt_norm2": "norm2_context", + "txt_mlp": "ff_context", + "self_attn_proj": "attn.to_out.0", + "modulation.linear": "norm.linear", + "pre_norm": "norm.norm", + "final_layer.norm_final": "norm_out.norm", + "final_layer.linear": "proj_out", + "fc1": "net.0.proj", + "fc2": "net.2", + "input_embedder": "proj_in", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = { + "txt_in": remap_txt_in_, + "img_attn_qkv": remap_img_attn_qkv_, + "txt_attn_qkv": remap_txt_attn_qkv_, + "single_blocks": remap_single_transformer_blocks_, + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, + } + + # Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys + # and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make + # sure that both follow the same initial format by stripping off the "transformer." prefix. + for key in list(converted_state_dict.keys()): + if key.startswith("transformer."): + converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key) + if key.startswith("diffusion_model."): + converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key) + + # Rename and remap the state dict keys + for key in list(converted_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + # Add back the "transformer." 
prefix + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict + def _convert_fal_kontext_lora_to_diffusers(original_state_dict): converted_state_dict = {} original_state_dict_keys = list(original_state_dict.keys()) @@ -1567,182 +1724,6 @@ def _convert_fal_kontext_lora_to_diffusers(original_state_dict): return converted_state_dict - -def _convert_hunyuan_video_lora_to_diffusers(original_state_dict): - converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())} - - def remap_norm_scale_shift_(key, state_dict): - weight = state_dict.pop(key) - shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight - - def remap_txt_in_(key, state_dict): - def rename_key(key): - new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") - new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") - new_key = new_key.replace("txt_in", "context_embedder") - new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") - new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") - new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") - new_key = new_key.replace("mlp", "ff") - return new_key - - if "self_attn_qkv" in key: - weight = state_dict.pop(key) - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k - state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v - else: - state_dict[rename_key(key)] = state_dict.pop(key) - - def remap_img_attn_qkv_(key, state_dict): - weight = state_dict.pop(key) - if "lora_A" in key: - state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight - state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight - state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight - else: - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q - state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k - state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v - - def remap_txt_attn_qkv_(key, state_dict): - weight = state_dict.pop(key) - if "lora_A" in key: - state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight - state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight - state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight - else: - to_q, to_k, to_v = weight.chunk(3, dim=0) - state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q - state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k - state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v - - def remap_single_transformer_blocks_(key, state_dict): - hidden_size = 3072 - - if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key: - linear1_weight = state_dict.pop(key) - if "lora_A" in key: - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( - ".linear1.lora_A.weight" - ) - state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight - state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight - state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight - state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight - 
else: - split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) - q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( - ".linear1.lora_B.weight" - ) - state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q - state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k - state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v - state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp - - elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key: - linear1_bias = state_dict.pop(key) - if "lora_A" in key: - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( - ".linear1.lora_A.bias" - ) - state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias - state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias - state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias - state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias - else: - split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) - q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) - new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix( - ".linear1.lora_B.bias" - ) - state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias - state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias - state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias - state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias - - else: - new_key = key.replace("single_blocks", "single_transformer_blocks") - new_key = new_key.replace("linear2", "proj_out") - new_key = new_key.replace("q_norm", "attn.norm_q") - new_key = new_key.replace("k_norm", "attn.norm_k") - state_dict[new_key] = state_dict.pop(key) - - TRANSFORMER_KEYS_RENAME_DICT = { - "img_in": "x_embedder", - "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", - "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", - "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", - "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", - "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", - "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", - "double_blocks": "transformer_blocks", - "img_attn_q_norm": "attn.norm_q", - "img_attn_k_norm": "attn.norm_k", - "img_attn_proj": "attn.to_out.0", - "txt_attn_q_norm": "attn.norm_added_q", - "txt_attn_k_norm": "attn.norm_added_k", - "txt_attn_proj": "attn.to_add_out", - "img_mod.linear": "norm1.linear", - "img_norm1": "norm1.norm", - "img_norm2": "norm2", - "img_mlp": "ff", - "txt_mod.linear": "norm1_context.linear", - "txt_norm1": "norm1.norm", - "txt_norm2": "norm2_context", - "txt_mlp": "ff_context", - "self_attn_proj": "attn.to_out.0", - "modulation.linear": "norm.linear", - "pre_norm": "norm.norm", - "final_layer.norm_final": "norm_out.norm", - "final_layer.linear": "proj_out", - "fc1": "net.0.proj", - "fc2": "net.2", - "input_embedder": "proj_in", - } - - TRANSFORMER_SPECIAL_KEYS_REMAP = { - "txt_in": remap_txt_in_, - "img_attn_qkv": remap_img_attn_qkv_, - "txt_attn_qkv": remap_txt_attn_qkv_, - "single_blocks": remap_single_transformer_blocks_, - "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, - } - - # Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys - # and use their custom code. 
To make sure both "original" and "attempted diffusers" loras work as expected, we make - # sure that both follow the same initial format by stripping off the "transformer." prefix. - for key in list(converted_state_dict.keys()): - if key.startswith("transformer."): - converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key) - if key.startswith("diffusion_model."): - converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key) - - # Rename and remap the state dict keys - for key in list(converted_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - converted_state_dict[new_key] = converted_state_dict.pop(key) - - for key in list(converted_state_dict.keys()): - for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): - if special_key not in key: - continue - handler_fn_inplace(key, converted_state_dict) - - # Add back the "transformer." prefix - for key in list(converted_state_dict.keys()): - converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) - - return converted_state_dict - - def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict): # Remove "diffusion_model." prefix from keys. state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} @@ -1818,175 +1799,48 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): converted_state_dict = {} original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} - block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")} - min_block = min(block_numbers) - max_block = max(block_numbers) - + num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict}) is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) - lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down" - lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up" - has_time_projection_weight = any( - k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict - ) - - for key in list(original_state_dict.keys()): - if key.endswith((".diff", ".diff_b")) and "norm" in key: - # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it - # in future if needed and they are not zeroed. - original_state_dict.pop(key) - logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.") - if "time_projection" in key and not has_time_projection_weight: - # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where - # our lora config adds the time proj lora layers, but we don't have the weights for them. - # CausVid lora has the weight keys and the bias keys. - original_state_dict.pop(key) - - # For the `diff_b` keys, we treat them as lora_bias. 
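The removed branch above surfaces Wan `diff_b` tensors through PEFT's `lora_bias` support. A minimal sketch of that mapping for a single projection; the key names are illustrative and follow the conversion below:

    # Sketch only: a "*.diff_b" bias delta becomes the bias of the matching lora_B
    # layer, which is then loaded with lora_bias=True in the LoraConfig (see the link below).
    import torch

    original_state_dict = {"blocks.0.self_attn.q.diff_b": torch.zeros(8)}
    converted_state_dict = {}
    converted_state_dict["blocks.0.attn1.to_q.lora_B.bias"] = original_state_dict.pop(
        "blocks.0.self_attn.q.diff_b"
    )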
- # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias - - for i in range(min_block, max_block + 1): + for i in range(num_blocks): # Self-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): - original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight" - converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight" - converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.self_attn.{o}.diff_b" - converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) + converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.self_attn.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.self_attn.{o}.lora_B.weight" + ) # Cross-attention for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): - original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" - converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" - converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.cross_attn.{o}.diff_b" - converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_B.weight" + ) if is_i2v_lora: for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): - original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" - converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" - converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.cross_attn.{o}.diff_b" - converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_B.weight" + ) # FFN for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): - original_key = f"blocks.{i}.{o}.{lora_down_key}.weight" - 
converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.{o}.{lora_up_key}.weight" - converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.{o}.diff_b" - converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - # Remaining. - if original_state_dict: - if any("time_projection" in k for k in original_state_dict): - original_key = f"time_projection.1.{lora_down_key}.weight" - converted_key = "condition_embedder.time_proj.lora_A.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"time_projection.1.{lora_up_key}.weight" - converted_key = "condition_embedder.time_proj.lora_B.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - if "time_projection.1.diff_b" in original_state_dict: - converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop( - "time_projection.1.diff_b" - ) - - if any("head.head" in k for k in state_dict): - converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop( - f"head.head.{lora_down_key}.weight" - ) - converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight") - if "head.head.diff_b" in original_state_dict: - converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b") - - for text_time in ["text_embedding", "time_embedding"]: - if any(text_time in k for k in original_state_dict): - for b_n in [0, 2]: - diffusers_b_n = 1 if b_n == 0 else 2 - diffusers_name = ( - "condition_embedder.text_embedder" - if text_time == "text_embedding" - else "condition_embedder.time_embedder" - ) - if any(f"{text_time}.{b_n}" in k for k in original_state_dict): - converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = ( - original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight") - ) - converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = ( - original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight") - ) - if f"{text_time}.{b_n}.diff_b" in original_state_dict: - converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = ( - original_state_dict.pop(f"{text_time}.{b_n}.diff_b") - ) - - for img_ours, img_theirs in [ - ("ff.net.0.proj", "img_emb.proj.1"), - ("ff.net.2", "img_emb.proj.3"), - ]: - original_key = f"{img_theirs}.{lora_down_key}.weight" - converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"{img_theirs}.{lora_up_key}.weight" - converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) + converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop( + 
f"blocks.{i}.{o}.lora_B.weight" + ) if len(original_state_dict) > 0: - diff = all(".diff" in k for k in original_state_dict) - if diff: - diff_keys = {k for k in original_state_dict if k.endswith(".diff")} - if not all("lora" not in k for k in diff_keys): - raise ValueError - logger.info( - "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: " - "https://github.com/huggingface/diffusers//issues/new" - ) - else: - raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") + raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") for key in list(converted_state_dict.keys()): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) @@ -2053,19 +1907,3 @@ def get_alpha_scales(down_weight, key): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) return converted_state_dict - - -def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"): - if not all(k.startswith(non_diffusers_prefix) for k in state_dict): - raise ValueError("Invalid LoRA state dict for HiDream.") - converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()} - converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()} - return converted_state_dict - - -def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"): - if not all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict): - raise ValueError("Invalid LoRA state dict for LTX-Video.") - converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()} - converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()} - return converted_state_dict \ No newline at end of file diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index f645ef16eb7c..a73e8cdec13f 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -37,7 +37,6 @@ LoraBaseMixin, _fetch_state_dict, _load_lora_into_text_encoder, - _pack_dict_with_prefix, ) from .lora_conversion_utils import ( _convert_bfl_flux_control_lora_to_diffusers, @@ -45,9 +44,7 @@ _convert_hunyuan_video_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, - _convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, - _convert_non_diffusers_ltxv_lora_to_diffusers, _convert_non_diffusers_lumina2_lora_to_diffusers, _convert_non_diffusers_wan_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, @@ -83,36 +80,30 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module): from ..quantizers.gguf.utils import dequantize_gguf_tensor is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit" - is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params" is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter" if is_bnb_4bit_quantized and not is_bitsandbytes_available(): raise ValueError( "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints." ) - if is_bnb_8bit_quantized and not is_bitsandbytes_available(): - raise ValueError( - "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints." - ) if is_gguf_quantized and not is_gguf_available(): raise ValueError( "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints." ) weight_on_cpu = False - if module.weight.device.type == "cpu": + if not module.weight.is_cuda: weight_on_cpu = True - device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" - if is_bnb_4bit_quantized or is_bnb_8bit_quantized: + if is_bnb_4bit_quantized: module_weight = dequantize_bnb_weight( - module.weight.to(device) if weight_on_cpu else module.weight, - state=module.weight.quant_state if is_bnb_4bit_quantized else module.state, + module.weight.cuda() if weight_on_cpu else module.weight, + state=module.weight.quant_state, dtype=model.dtype, ).data elif is_gguf_quantized: module_weight = dequantize_gguf_tensor( - module.weight.to(device) if weight_on_cpu else module.weight, + module.weight.cuda() if weight_on_cpu else module.weight, ) module_weight = module_weight.to(model.dtype) else: @@ -137,7 +128,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, + adapter_name=None, hotswap: bool = False, **kwargs, ): @@ -164,7 +155,7 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): + hotswap : (`bool`, *optional*) Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. This means that, instead of loading an additional adapter, this will take the existing adapter weights and replace them with the weights of the new adapter. This can be faster and more @@ -204,8 +195,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- kwargs["return_lora_metadata"] = True - state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -216,7 +206,6 @@ def load_lora_weights( network_alphas=network_alphas, unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -230,7 +219,6 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, - metadata=metadata, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -287,8 +275,6 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -302,16 +288,18 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -348,8 +336,7 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas) - return out + return state_dict, network_alphas @classmethod def load_lora_into_unet( @@ -361,7 +348,6 @@ def load_lora_into_unet( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, - metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -383,11 +369,29 @@ def load_lora_into_unet( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. 
However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -406,7 +410,6 @@ def load_lora_into_unet( prefix=cls.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -424,7 +427,6 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, - metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -450,11 +452,29 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -464,7 +484,6 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -480,8 +499,6 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - unet_lora_adapter_metadata=None, - text_encoder_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -504,13 +521,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - unet_lora_adapter_metadata: - LoRA adapter metadata associated with the unet to be serialized with the state dict. - text_encoder_lora_adapter_metadata: - LoRA adapter metadata associated with the text encoder to be serialized with the state dict. """ state_dict = {} - lora_adapter_metadata = {} if not (unet_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") @@ -521,14 +533,6 @@ def save_lora_weights( if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - if unet_lora_adapter_metadata: - lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) - - if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) - ) - # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -537,7 +541,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -623,7 +626,6 @@ def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name: Optional[str] = None, - hotswap: bool = False, **kwargs, ): """ @@ -650,8 +652,6 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -673,8 +673,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
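The hotswap snippet repeated in the docstrings above stops right before compilation; a fuller, hedged sketch of the workflow it describes (the checkpoint id and LoRA file names are placeholders, and per the docstrings, text encoder adapters cannot be hotswapped yet):

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Reserve capacity for the highest LoRA rank *before* compiling.
pipe.enable_lora_hotswap(target_rank=64)
pipe.load_lora_weights("lora_a.safetensors", adapter_name="default_0")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")

# Later: replace the adapter weights in place; because ranks were padded up
# to target_rank, torch.compile does not need to retrace the model.
pipe.load_lora_weights("lora_b.safetensors", adapter_name="default_0", hotswap=True)
```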
- kwargs["return_lora_metadata"] = True - state_dict, network_alphas, metadata = self.lora_state_dict( + state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, **kwargs, @@ -689,10 +688,8 @@ def load_lora_weights( network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) self.load_lora_into_text_encoder( state_dict, @@ -701,10 +698,8 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) self.load_lora_into_text_encoder( state_dict, @@ -713,10 +708,8 @@ def load_lora_weights( prefix=f"{self.text_encoder_name}_2", lora_scale=self.lora_scale, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) @classmethod @@ -772,8 +765,6 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -787,16 +778,18 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -833,8 +826,7 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas) - return out + return state_dict, network_alphas @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet @@ -847,7 +839,6 @@ def load_lora_into_unet( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, - metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -869,11 +860,29 @@ def load_lora_into_unet( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. 
This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -892,7 +901,6 @@ def load_lora_into_unet( prefix=cls.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -911,7 +919,6 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, - metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -937,11 +944,29 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -951,7 +976,6 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -968,9 +992,6 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - unet_lora_adapter_metadata=None, - text_encoder_lora_adapter_metadata=None, - text_encoder_2_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -996,15 +1017,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - unet_lora_adapter_metadata: - LoRA adapter metadata associated with the unet to be serialized with the state dict. - text_encoder_lora_adapter_metadata: - LoRA adapter metadata associated with the text encoder to be serialized with the state dict. - text_encoder_2_lora_adapter_metadata: - LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ state_dict = {} - lora_adapter_metadata = {} if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -1020,19 +1034,6 @@ def save_lora_weights( if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - if unet_lora_adapter_metadata is not None: - lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) - - if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) - ) - - if text_encoder_2_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") - ) - cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, @@ -1040,7 +1041,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -1174,8 +1174,6 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. 
""" # Load the main state dict first which has the LoRA layers for either of @@ -1189,16 +1187,18 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1219,8 +1219,7 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out + return state_dict def load_lora_weights( self, @@ -1250,8 +1249,29 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -1269,8 +1289,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -1280,7 +1299,6 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1292,7 +1310,6 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1304,7 +1321,6 @@ def load_lora_weights( prefix=f"{self.text_encoder_name}_2", lora_scale=self.lora_scale, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1312,14 +1328,7 @@ def load_lora_weights( @classmethod def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1337,11 +1346,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1354,7 +1381,6 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1373,7 +1399,6 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, - metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1399,11 +1424,29 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -1413,7 +1456,6 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1431,9 +1473,6 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata=None, - text_encoder_lora_adapter_metadata=None, - text_encoder_2_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -1459,15 +1498,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
- text_encoder_lora_adapter_metadata: - LoRA adapter metadata associated with the text encoder to be serialized with the state dict. - text_encoder_2_lora_adapter_metadata: - LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ state_dict = {} - lora_adapter_metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -1483,21 +1515,6 @@ def save_lora_weights( if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) - ) - - if text_encoder_2_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") - ) - cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, @@ -1505,7 +1522,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer @@ -1637,8 +1653,6 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. 
""" # Load the main state dict first which has the LoRA layers for either of @@ -1652,16 +1666,18 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1682,16 +1698,11 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out + return state_dict # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -1709,8 +1720,6 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -1728,8 +1737,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -1739,23 +1747,14 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1773,11 +1772,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. 
- hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1790,7 +1807,6 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1806,10 +1822,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the transformer. + Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): @@ -1826,21 +1841,14 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} - lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) # Save the model cls.write_lora_layers( @@ -1850,7 +1858,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -1984,8 +1991,7 @@ def lora_state_dict( allowed by Git. 
subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. + """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -2005,7 +2011,10 @@ def lora_state_dict( use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -2032,25 +2041,13 @@ def lora_state_dict( if is_kohya: state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) # Kohya already takes care of scaling the LoRA parameters with alpha. - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) + return (state_dict, None) if return_alphas else state_dict is_xlabs = any("processor" in k for k in state_dict) if is_xlabs: state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) # xlabs doesn't use `alpha`. - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) + return (state_dict, None) if return_alphas else state_dict is_bfl_control = any("query_norm.scale" in k for k in state_dict) if is_bfl_control: @@ -2073,7 +2070,7 @@ def lora_state_dict( return_alphas=return_alphas, return_metadata=return_lora_metadata, ) - + # For state dicts like # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA keys = list(state_dict.keys()) @@ -2090,21 +2087,15 @@ def lora_state_dict( f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." ) - if return_alphas or return_lora_metadata: - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=network_alphas, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) + if return_alphas: + return state_dict, network_alphas else: return state_dict def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, + adapter_name=None, hotswap: bool = False, **kwargs, ): @@ -2123,16 +2114,34 @@ def load_lora_weights( Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. low_cpu_mem_usage (`bool`, *optional*): `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. 
This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new + adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an + additional method before loading the adapter: + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -2148,8 +2157,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, network_alphas, metadata = self.lora_state_dict( + state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) @@ -2200,7 +2208,6 @@ def load_lora_weights( network_alphas=network_alphas, transformer=transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2220,7 +2227,6 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2233,7 +2239,6 @@ def load_lora_into_transformer( network_alphas, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -2258,11 +2263,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. 
scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2275,7 +2298,6 @@ def load_lora_into_transformer( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2293,7 +2315,7 @@ def _load_norm_into_transformer( prefix = prefix or cls.transformer_name for key in list(state_dict.keys()): if key.split(".")[0] == prefix: - state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key) + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) # Find invalid keys transformer_state_dict = transformer.state_dict() @@ -2348,7 +2370,6 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, - metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2374,11 +2395,29 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -2388,7 +2427,6 @@ def load_lora_into_text_encoder(
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
- metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
@@ -2405,8 +2443,6 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
- transformer_lora_adapter_metadata=None,
- text_encoder_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -2429,13 +2465,8 @@ def save_lora_weights(
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
- text_encoder_lora_adapter_metadata:
- LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
state_dict = {}
- lora_adapter_metadata = {}
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
- if transformer_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- if text_encoder_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
- )
-
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -2626,7 +2646,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
) -> bool:
"""
Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
- generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
+ generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
"""
state_dict = {}
if lora_state_dict is not None:
@@ -2638,7 +2658,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
prefix = prefix or cls.transformer_name
for key in list(state_dict.keys()):
if key.split(".")[0] == prefix:
- state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
+ state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
@@ -2756,13 +2776,14 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
if unexpected_modules:
logger.debug(f"Found unexpected modules: {unexpected_modules}. 
These will be ignored.") + is_peft_loaded = getattr(transformer, "peft_config", None) is not None for k in lora_module_names: if k in unexpected_modules: continue base_param_name = ( f"{k.replace(prefix, '')}.base_layer.weight" - if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict + if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict else f"{k.replace(prefix, '')}.weight" ) base_weight_param = transformer_state_dict[base_param_name] @@ -2816,15 +2837,6 @@ def _get_weight_shape(weight: torch.Tensor): raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.") - @staticmethod - def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False): - outputs = [state_dict] - if return_alphas: - outputs.append(alphas) - if return_metadata: - outputs.append(metadata) - return tuple(outputs) if (return_alphas or return_metadata) else state_dict - # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. @@ -2841,7 +2853,6 @@ def load_lora_into_transformer( network_alphas, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -2866,11 +2877,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2883,7 +2912,6 @@ def load_lora_into_transformer( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2902,7 +2930,6 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, - metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2928,11 +2955,29 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -2942,7 +2987,6 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3062,8 +3106,6 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. 
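Since `lora_state_dict` is a `@classmethod`, it can be called without instantiating the full pipeline, which makes it handy for sanity-checking a checkpoint before loading. A minimal sketch of that pattern; the repo id is a placeholder and the rank inspection is illustrative, not part of the API:

```py
# Hypothetical sketch: peek at a LoRA checkpoint before loading it.
from diffusers import CogVideoXPipeline

# "some-user/some-cogvideox-lora" is a placeholder repo id.
state_dict = CogVideoXPipeline.lora_state_dict("some-user/some-cogvideox-lora")

# A lora_A weight has shape (rank, in_features), so its first dimension is the rank.
ranks = {v.shape[0] for k, v in state_dict.items() if k.endswith("lora_A.weight")}
print(f"{len(state_dict)} tensors, ranks found: {sorted(ranks)}")
```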
""" # Load the main state dict first which has the LoRA layers for either of @@ -3077,16 +3119,18 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3107,15 +3151,10 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out + return state_dict def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -3133,8 +3172,6 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -3152,8 +3189,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3163,23 +3199,14 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3197,11 +3224,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. 
When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3214,7 +3259,6 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3230,10 +3274,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the transformer. + Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): @@ -3250,21 +3293,14 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} - lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) # Save the model cls.write_lora_layers( @@ -3274,7 +3310,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -3401,8 +3436,6 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
- return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -3416,16 +3449,18 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3446,16 +3481,11 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out + return state_dict # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -3473,8 +3503,6 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -3492,8 +3520,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3503,23 +3530,14 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. 
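The hotswap docstring repeated throughout these hunks compresses the workflow into a fragment; the following is a fuller, hedged sketch of the same flow, assuming an SDXL pipeline whose loader exposes the `hotswap` argument, two placeholder LoRA files, and a maximum rank of 64 across them:

```py
# Hedged sketch of the hotswap workflow; file names and target_rank are placeholders.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Must run *before* loading the first adapter and before compiling; target_rank is
# the largest rank among all adapters that will ever be hotswapped in.
pipe.enable_lora_hotswap(target_rank=64)

pipe.load_lora_weights("lora_a.safetensors", adapter_name="base")
pipe.unet = torch.compile(pipe.unet)  # optional; hotswapping avoids recompilation later
image_a = pipe("a prompt").images[0]

# Replace the adapter's weights in place: reuse the already-loaded adapter's name.
pipe.load_lora_weights("lora_b.safetensors", adapter_name="base", hotswap=True)
image_b = pipe("a prompt").images[0]
```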
@@ -3537,11 +3555,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3554,7 +3590,6 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3570,10 +3605,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the transformer. + Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): @@ -3590,21 +3624,14 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
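In training code, the `transformer_lora_layers` argument documented above is typically produced with peft's `get_peft_model_state_dict`. A hedged sketch under those assumptions (the adapter config, target modules, and output directory are placeholders, and the training loop is elided):

```py
# Hedged sketch: attach a LoRA adapter to the transformer and serialize only its layers.
import torch
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.transformer.add_adapter(
    LoraConfig(r=16, lora_alpha=16, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
)

# ... fine-tuning would happen here ...

MochiPipeline.save_lora_weights(
    save_directory="./my-mochi-lora",
    transformer_lora_layers=get_peft_model_state_dict(pipe.transformer),
)
```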
""" state_dict = {} - lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) # Save the model cls.write_lora_layers( @@ -3614,7 +3641,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -3694,6 +3720,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -3742,8 +3769,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. + """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -3756,16 +3782,18 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3786,20 +3814,11 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict) - if is_non_diffusers_format: - state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict) - - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out + return state_dict # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -3817,8 +3836,6 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -3836,8 +3853,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3847,23 +3863,14 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3881,11 +3888,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3898,7 +3923,6 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3914,10 +3938,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the transformer. + Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): @@ -3934,21 +3957,14 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} - lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) # Save the model cls.write_lora_layers( @@ -3958,7 +3974,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4087,8 +4102,6 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. 
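Once a LoRA is loaded, the `fuse_lora`/`unfuse_lora` pair copied into each of these mixins folds the adapter deltas into the base weights and back out again, removing the adapter indirection at inference time. A minimal sketch; the LoRA repo id is a placeholder:

```py
# Hedged sketch: fuse a loaded LoRA into the base weights, run inference, then undo it.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-sana-lora")  # placeholder repo id

pipe.fuse_lora(lora_scale=0.8)  # fold the scaled LoRA deltas into the base weights
image = pipe("a prompt").images[0]
pipe.unfuse_lora()  # restore the original base weights
```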
""" # Load the main state dict first which has the LoRA layers for either of @@ -4102,16 +4115,18 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4132,16 +4147,11 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out + return state_dict # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -4159,8 +4169,6 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -4178,8 +4186,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4189,23 +4196,14 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4223,11 +4221,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. 
- hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4240,7 +4256,6 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4256,10 +4271,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the transformer. + Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): @@ -4276,21 +4290,14 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} - lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) # Save the model cls.write_lora_layers( @@ -4300,7 +4307,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4428,8 +4434,7 @@ def lora_state_dict( allowed by Git. 
subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. + """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -4442,16 +4447,18 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4476,16 +4483,11 @@ def lora_state_dict( if is_original_hunyuan_video: state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out + return state_dict # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -4503,8 +4505,6 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -4522,8 +4522,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4533,23 +4532,14 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4567,11 +4557,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4584,7 +4592,6 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4600,10 +4607,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the transformer. 
+ Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): @@ -4620,21 +4626,14 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} - lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) # Save the model cls.write_lora_layers( @@ -4644,7 +4643,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4772,8 +4770,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. + """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -4786,16 +4783,18 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4821,16 +4820,11 @@ def lora_state_dict( if non_diffusers: state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out + return state_dict # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -4848,8 +4842,6 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -4867,8 +4859,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4878,23 +4869,14 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4912,11 +4894,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4929,7 +4929,6 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4945,10 +4944,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the transformer. + Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): @@ -4965,21 +4963,14 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} - lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) # Save the model cls.write_lora_layers( @@ -4989,7 +4980,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -5117,8 +5107,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. + """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. 
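Every `save_lora_weights` body above funnels through `cls.pack_weights(...)`, which namespaces each key with the component name so a single file can carry adapters for several components; the prefix-stripping loops in the loading hunks (`key[len(f"{prefix}.") :]`) are its inverse. A conceptual, self-contained sketch of that re-keying (the real helper also accepting `torch.nn.Module` inputs is an assumption here):

```py
# Conceptual sketch of pack_weights-style namespacing: prefix each LoRA key with
# the component name ("transformer", "text_encoder", ...) before serialization.
from typing import Dict

import torch


def pack_weights(layers: Dict[str, torch.Tensor], prefix: str) -> Dict[str, torch.Tensor]:
    return {f"{prefix}.{name}": param for name, param in layers.items()}


packed = pack_weights({"blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 64)}, "transformer")
assert list(packed) == ["transformer.blocks.0.attn.to_q.lora_A.weight"]
```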
@@ -5131,16 +5120,18 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5165,8 +5156,7 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out + return state_dict @classmethod def _maybe_expand_t2v_lora_for_i2v( @@ -5177,51 +5167,26 @@ def _maybe_expand_t2v_lora_for_i2v( if transformer.config.image_dim is None: return state_dict - target_device = transformer.device - if any(k.startswith("transformer.blocks.") for k in state_dict): - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k}) + num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict}) is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict) - has_bias = any(".lora_B.bias" in k for k in state_dict) if is_i2v_lora: return state_dict for i in range(num_blocks): for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): - # These keys should exist if the block `i` was part of the T2V LoRA. - ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight" - ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight" - - if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict: - continue - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like( - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device + state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"] ) state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like( - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device + state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"] ) - # If the original LoRA had biases (indicated by has_bias) - # AND the specific reference bias key exists for this block. - - ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias" - if has_bias and ref_key_lora_B_bias in state_dict: - ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias] - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like( - ref_lora_B_bias_tensor, - device=target_device, - ) - return state_dict def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -5239,8 +5204,6 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. 
- hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -5258,8 +5221,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers state_dict = self._maybe_expand_t2v_lora_for_i2v( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, @@ -5273,23 +5235,14 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5307,11 +5260,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5324,7 +5295,6 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5340,10 +5310,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the transformer. + Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): @@ -5360,21 +5329,14 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} - lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) # Save the model cls.write_lora_layers( @@ -5384,7 +5346,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -5513,8 +5474,6 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. 
""" # Load the main state dict first which has the LoRA layers for either of @@ -5528,16 +5487,18 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5558,16 +5519,11 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out + return state_dict # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -5585,8 +5541,6 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -5604,8 +5558,7 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -5615,23 +5568,14 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, - metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5649,11 +5593,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. 
- hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5666,7 +5628,6 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, - metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5682,10 +5643,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the transformer. + Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): @@ -5702,21 +5662,14 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. 
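To make the hotswap workflow described in the docstring above concrete, a hedged end-to-end sketch (model id, file names, and rank are illustrative; `enable_lora_hotswap` must be called before compiling, as noted):

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Reserve capacity for the highest LoRA rank that will ever be swapped in.
pipe.enable_lora_hotswap(target_rank=16)
pipe.load_lora_weights("first_lora.safetensors", adapter_name="default")

pipe.unet = torch.compile(pipe.unet)  # optional; hotswapping avoids recompilation

# Replace the adapter weights in place; `adapter_name` must name a loaded adapter.
pipe.load_lora_weights("second_lora.safetensors", adapter_name="default", hotswap=True)
```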
""" state_dict = {} - lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) # Save the model cls.write_lora_layers( @@ -5726,7 +5679,6 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -5796,352 +5748,8 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) -class HiDreamImageLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. 
- return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - - state_dict, metadata = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - is_non_diffusers_format = any("diffusion_model" in k for k in state_dict) - if is_non_diffusers_format: - state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict) - - out = (state_dict, metadata) if return_lora_metadata else state_dict - return out - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. 
- kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel - def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`HiDreamImageTransformer2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. 
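The `is_correct_format` guard that recurs throughout these mixins is simple enough to show in isolation: every key of a compatible checkpoint is expected to mention "lora". A minimal sketch with a toy state dict:

```py
import torch

state_dict = {"transformer.blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 64)}

# Same check as in `load_lora_weights`: reject state dicts without LoRA keys.
if not all("lora" in key for key in state_dict):
    raise ValueError("Invalid LoRA checkpoint.")
```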
- logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - transformer_lora_adapter_metadata: Optional[dict] = None, - ): - r""" - Save the LoRA parameters corresponding to the transformer. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer to be serialized with the state dict. - """ - state_dict = {} - lora_adapter_metadata = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, - ) - - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. 
- - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) - - class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." deprecate("LoraLoaderMixin", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) \ No newline at end of file + super().__init__(*args, **kwargs) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index b1e804cb4089..bfcdfbde78e0 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ import enum import json - from .import_utils import is_torch_available from .logging import get_logger @@ -220,7 +219,7 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): kwargs (`dict`, *args*): Additional arguments to pass to the method. - - **adapter_name**: For example, in case of PEFT, some keys will be prepended + - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in `get_peft_model_state_dict` method: https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 @@ -291,7 +290,7 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs): kwargs (`dict`, *args*): Additional arguments to pass to the method. - - **adapter_name**: For example, in case of PEFT, some keys will be prepended + - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended with the adapter name, therefore needs a special handling. 
By default PEFT also takes care of that in `get_peft_model_state_dict` method: https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 @@ -349,7 +348,6 @@ def state_dict_all_zero(state_dict, filter_str=None): return all(torch.all(param == 0).item() for param in state_dict.values()) - def _load_sft_state_dict_metadata(model_file: str): import safetensors.torch @@ -363,4 +361,4 @@ def _load_sft_state_dict_metadata(model_file: str): raw = metadata.get(LORA_ADAPTER_METADATA_KEY) return json.loads(raw) if raw else None else: - return None \ No newline at end of file + return None From 3597fca2979f650b59546f0253969ca8895ea064 Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 17 Jul 2025 14:47:46 +0900 Subject: [PATCH 480/482] Revert "fix bug" This reverts commit b5495e16203903a1ebb2fcee8982fa285f46b04d. --- src/diffusers/loaders/lora_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index fe5f52c7149a..473840fe38d8 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -63,7 +63,6 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" -LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): From f3dc53689525e94633b70ea44c85d7d735978c5c Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 17 Jul 2025 14:47:55 +0900 Subject: [PATCH 481/482] Revert "fix bu" This reverts commit 9c44182800fc9b0cb9ff5e5df7a0b109d648d346. --- src/diffusers/utils/state_dict_utils.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index bfcdfbde78e0..3682c5bfacd6 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -16,7 +16,7 @@ """ import enum -import json + from .import_utils import is_torch_available from .logging import get_logger @@ -347,18 +347,3 @@ def state_dict_all_zero(state_dict, filter_str=None): state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)} return all(torch.all(param == 0).item() for param in state_dict.values()) - -def _load_sft_state_dict_metadata(model_file: str): - import safetensors.torch - - from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY - - with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f: - metadata = f.metadata() or {} - - metadata.pop("format", None) - if metadata: - raw = metadata.get(LORA_ADAPTER_METADATA_KEY) - return json.loads(raw) if raw else None - else: - return None From 1358f473a131b0e82751cbafcc14232d5b15bfaf Mon Sep 17 00:00:00 2001 From: LinchuanXUTheSEAAI Date: Thu, 17 Jul 2025 14:48:27 +0900 Subject: [PATCH 482/482] Revert "Support Fal Kontext LoRA" This reverts commit e270eaaf4acc8e97267d5660de20991c8a20d59f. 
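The conversion being reverted below mapped Fal's PEFT-style checkpoint keys, prefixed with `base_model.model.`, onto diffusers' Flux-style transformer layout. A tiny sketch of the remapping idea for one modulation layer (key names follow the converter shown below; tensors and shapes are toy values):

```py
import torch

rank, dim = 4, 3072
original = {
    "base_model.model.double_blocks.0.img_mod.lin.lora_A.weight": torch.randn(rank, dim),
    "base_model.model.double_blocks.0.img_mod.lin.lora_B.weight": torch.randn(6 * dim, rank),
}

converted = {}
for lora_key in ("lora_A", "lora_B"):
    # img_mod.lin in double block i becomes norm1.linear in transformer_blocks.i.
    converted[f"transformer.transformer_blocks.0.norm1.linear.{lora_key}.weight"] = original.pop(
        f"base_model.model.double_blocks.0.img_mod.lin.{lora_key}.weight"
    )

assert not original  # the converter requires every source key to be consumed
```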
--- src/diffusers/loaders/lora_base.py | 8 +- .../loaders/lora_conversion_utils.py | 220 ------------------ src/diffusers/loaders/lora_pipeline.py | 23 +- 3 files changed, 3 insertions(+), 248 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 473840fe38d8..280a9fa6e73f 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -25,7 +25,6 @@ from huggingface_hub.constants import HF_HUB_OFFLINE from ..models.modeling_utils import ModelMixin, load_state_dict -from ..utils.state_dict_utils import _load_sft_state_dict_metadata from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -207,7 +206,6 @@ def _fetch_state_dict( subfolder, user_agent, allow_pickle, - metadata=None, ): model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): @@ -238,14 +236,11 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") - metadata = _load_sft_state_dict_metadata(model_file) - except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e # try loading non-safetensors weights model_file = None - metadata = None pass if model_file is None: @@ -266,11 +261,10 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = load_state_dict(model_file) - metadata = None else: state_dict = pretrained_model_name_or_path_or_dict - return state_dict, metadata + return state_dict def _best_guess_weight_name( diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index c0fb910abcf9..7fec3299eeac 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1503,226 +1503,6 @@ def remap_single_transformer_blocks_(key, state_dict): return converted_state_dict -def _convert_fal_kontext_lora_to_diffusers(original_state_dict): - converted_state_dict = {} - original_state_dict_keys = list(original_state_dict.keys()) - num_layers = 19 - num_single_layers = 38 - inner_dim = 3072 - mlp_ratio = 4.0 - - # double transformer blocks - for i in range(num_layers): - block_prefix = f"transformer_blocks.{i}." - original_block_prefix = "base_model.model." 
- - for lora_key in ["lora_A", "lora_B"]: - # norms - converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight" - ) - if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias" - ) - - converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight" - ) - - # Q, K, V - if lora_key == "lora_A": - sample_lora_weight = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight" - ) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight]) - - context_lora_weight = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight" - ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat( - [context_lora_weight] - ) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat( - [context_lora_weight] - ) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat( - [context_lora_weight] - ) - else: - sample_q, sample_k, sample_v = torch.chunk( - original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight" - ), - 3, - dim=0, - ) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v]) - - context_q, context_k, context_v = torch.chunk( - original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight" - ), - 3, - dim=0, - ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q]) - converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v]) - - if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys: - sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( - original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), - 3, - dim=0, - ) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias]) - - if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys: - context_q_bias, context_k_bias, context_v_bias = torch.chunk( - original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), - 3, - dim=0, - ) - converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias]) - 
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias]) - converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias]) - - # ff img_mlp - converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight" - ) - if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" - ) - - converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight" - ) - if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" - ) - - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight" - ) - if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" - ) - - converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight" - ) - if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" - ) - - # output projections. - converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight" - ) - if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" - ) - converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight" - ) - if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop( - f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" - ) - - # single transformer blocks - for i in range(num_single_layers): - block_prefix = f"single_transformer_blocks.{i}." 
- - for lora_key in ["lora_A", "lora_B"]: - # norm.linear <- single_blocks.0.modulation.lin - converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight" - ) - if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop( - f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" - ) - - # Q, K, V, mlp - mlp_hidden_dim = int(inner_dim * mlp_ratio) - split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) - - if lora_key == "lora_A": - lora_weight = original_state_dict.pop( - f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight" - ) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight]) - converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight]) - - if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: - lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias") - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias]) - else: - q, k, v, mlp = torch.split( - original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"), - split_size, - dim=0, - ) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v]) - converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp]) - - if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: - q_bias, k_bias, v_bias, mlp_bias = torch.split( - original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"), - split_size, - dim=0, - ) - converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias]) - converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias]) - converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias]) - converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias]) - - # output projections. 
- converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight" - ) - if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop( - f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" - ) - - for lora_key in ["lora_A", "lora_B"]: - converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop( - f"{original_block_prefix}final_layer.linear.{lora_key}.weight" - ) - if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys: - converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop( - f"{original_block_prefix}final_layer.linear.{lora_key}.bias" - ) - - if len(original_state_dict) > 0: - raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") - - for key in list(converted_state_dict.keys()): - converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) - - return converted_state_dict def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict): # Remove "diffusion_model." prefix from keys. diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a73e8cdec13f..aa508cf87f40 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -40,7 +40,6 @@ ) from .lora_conversion_utils import ( _convert_bfl_flux_control_lora_to_diffusers, - _convert_fal_kontext_lora_to_diffusers, _convert_hunyuan_video_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, @@ -2004,7 +2003,6 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: @@ -2016,7 +2014,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict, metadata = _fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -2052,25 +2050,8 @@ def lora_state_dict( is_bfl_control = any("query_norm.scale" in k for k in state_dict) if is_bfl_control: state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) + return (state_dict, None) if return_alphas else state_dict - is_fal_kontext = any("base_model" in k for k in state_dict) - if is_fal_kontext: - state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict) - return cls._prepare_outputs( - state_dict, - metadata=metadata, - alphas=None, - return_alphas=return_alphas, - return_metadata=return_lora_metadata, - ) - # For state dicts like # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA keys = list(state_dict.keys())
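One detail of the removed converter worth noting: for fused QKV weights it shares the LoRA down-projection (`lora_A`) across Q, K, and V unchanged, while the fused up-projection (`lora_B`) is chunked row-wise into the three targets. A toy sketch with illustrative shapes:

```py
import torch

rank, dim = 4, 3072
qkv_lora_A = torch.randn(rank, dim)      # one shared down-projection
qkv_lora_B = torch.randn(3 * dim, rank)  # fused up-projection covering Q, K, V

q_B, k_B, v_B = torch.chunk(qkv_lora_B, 3, dim=0)
converted = {
    "attn.to_q.lora_A.weight": qkv_lora_A,
    "attn.to_k.lora_A.weight": qkv_lora_A,
    "attn.to_v.lora_A.weight": qkv_lora_A,
    "attn.to_q.lora_B.weight": q_B,
    "attn.to_k.lora_B.weight": k_B,
    "attn.to_v.lora_B.weight": v_B,
}
```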