Skip to content

Commit aa82df5

Browse files
sayakpaul and yiyixuxu authored
[IP Adapters] introduce ip_adapter_image_embeds in the SD pipeline call (huggingface#6868)
* add: support for passing ip adapter image embeddings * debugging * make feature_extractor unloading conditioned on safety_checker * better condition * type annotation * index to look into value slices * more debugging * debugging * serialize embeddings dict * better conditioning * remove unnecessary prints. * Update src/diffusers/loaders/ip_adapter.py Co-authored-by: YiYi Xu <[email protected]> * make fix-copies and styling. * styling and further copy fixing. * fix: check_inputs call in controlnet sdxl img2img pipeline --------- Co-authored-by: YiYi Xu <[email protected]>
1 parent a11b0f8 commit aa82df5

File tree

33 files changed

+779
-439
lines changed

33 files changed

+779
-439
lines changed

src/diffusers/loaders/ip_adapter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,12 @@ def unload_ip_adapter(self):
210210
self.image_encoder = None
211211
self.register_to_config(image_encoder=[None, None])
212212

213-
# remove feature extractor
214-
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
215-
self.feature_extractor = None
216-
self.register_to_config(feature_extractor=[None, None])
213+
# remove feature extractor only when safety_checker is None as safety_checker uses
214+
# the feature_extractor later
215+
if not hasattr(self, "safety_checker"):
216+
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
217+
self.feature_extractor = None
218+
self.register_to_config(feature_extractor=[None, None])
217219

218220
# remove hidden encoder
219221
self.unet.encoder_hid_proj = None

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -427,32 +427,38 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
427427
return image_embeds, uncond_image_embeds
428428

429429
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
430-
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
431-
if not isinstance(ip_adapter_image, list):
432-
ip_adapter_image = [ip_adapter_image]
433-
434-
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
435-
raise ValueError(
436-
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
437-
)
430+
def prepare_ip_adapter_image_embeds(
431+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
432+
):
433+
if ip_adapter_image_embeds is None:
434+
if not isinstance(ip_adapter_image, list):
435+
ip_adapter_image = [ip_adapter_image]
438436

439-
image_embeds = []
440-
for single_ip_adapter_image, image_proj_layer in zip(
441-
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
442-
):
443-
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
444-
single_image_embeds, single_negative_image_embeds = self.encode_image(
445-
single_ip_adapter_image, device, 1, output_hidden_state
446-
)
447-
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
448-
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
437+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
438+
raise ValueError(
439+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
440+
)
449441

450-
if self.do_classifier_free_guidance:
451-
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
452-
single_image_embeds = single_image_embeds.to(device)
442+
image_embeds = []
443+
for single_ip_adapter_image, image_proj_layer in zip(
444+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
445+
):
446+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
447+
single_image_embeds, single_negative_image_embeds = self.encode_image(
448+
single_ip_adapter_image, device, 1, output_hidden_state
449+
)
450+
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
451+
single_negative_image_embeds = torch.stack(
452+
[single_negative_image_embeds] * num_images_per_prompt, dim=0
453+
)
453454

454-
image_embeds.append(single_image_embeds)
455+
if self.do_classifier_free_guidance:
456+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
457+
single_image_embeds = single_image_embeds.to(device)
455458

459+
image_embeds.append(single_image_embeds)
460+
else:
461+
image_embeds = ip_adapter_image_embeds
456462
return image_embeds
457463

458464
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
@@ -620,6 +626,8 @@ def check_inputs(
620626
negative_prompt=None,
621627
prompt_embeds=None,
622628
negative_prompt_embeds=None,
629+
ip_adapter_image=None,
630+
ip_adapter_image_embeds=None,
623631
callback_on_step_end_tensor_inputs=None,
624632
):
625633
if height % 8 != 0 or width % 8 != 0:
@@ -663,6 +671,11 @@ def check_inputs(
663671
f" {negative_prompt_embeds.shape}."
664672
)
665673

674+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
675+
raise ValueError(
676+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
677+
)
678+
666679
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
667680
def prepare_latents(
668681
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
@@ -882,6 +895,7 @@ def __call__(
882895
prompt_embeds: Optional[torch.FloatTensor] = None,
883896
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
884897
ip_adapter_image: Optional[PipelineImageInput] = None,
898+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
885899
output_type: Optional[str] = "pil",
886900
return_dict: bool = True,
887901
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -931,6 +945,9 @@ def __call__(
931945
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
932946
ip_adapter_image: (`PipelineImageInput`, *optional*):
933947
Optional image input to work with IP Adapters.
948+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
949+
Pre-generated image embeddings for IP-Adapter. If not
950+
provided, embeddings are computed from the `ip_adapter_image` input argument.
934951
output_type (`str`, *optional*, defaults to `"pil"`):
935952
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
936953
`np.array`.
@@ -992,6 +1009,8 @@ def __call__(
9921009
negative_prompt,
9931010
prompt_embeds,
9941011
negative_prompt_embeds,
1012+
ip_adapter_image,
1013+
ip_adapter_image_embeds,
9951014
callback_on_step_end_tensor_inputs,
9961015
)
9971016

@@ -1030,9 +1049,9 @@ def __call__(
10301049
if self.do_classifier_free_guidance:
10311050
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
10321051

1033-
if ip_adapter_image is not None:
1052+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
10341053
image_embeds = self.prepare_ip_adapter_image_embeds(
1035-
ip_adapter_image, device, batch_size * num_videos_per_prompt
1054+
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
10361055
)
10371056

10381057
# 4. Prepare timesteps

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -507,32 +507,38 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
507507
return image_embeds, uncond_image_embeds
508508

509509
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
510-
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
511-
if not isinstance(ip_adapter_image, list):
512-
ip_adapter_image = [ip_adapter_image]
513-
514-
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
515-
raise ValueError(
516-
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
517-
)
510+
def prepare_ip_adapter_image_embeds(
511+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
512+
):
513+
if ip_adapter_image_embeds is None:
514+
if not isinstance(ip_adapter_image, list):
515+
ip_adapter_image = [ip_adapter_image]
518516

519-
image_embeds = []
520-
for single_ip_adapter_image, image_proj_layer in zip(
521-
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
522-
):
523-
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
524-
single_image_embeds, single_negative_image_embeds = self.encode_image(
525-
single_ip_adapter_image, device, 1, output_hidden_state
526-
)
527-
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
528-
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
517+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
518+
raise ValueError(
519+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
520+
)
529521

530-
if self.do_classifier_free_guidance:
531-
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
532-
single_image_embeds = single_image_embeds.to(device)
522+
image_embeds = []
523+
for single_ip_adapter_image, image_proj_layer in zip(
524+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
525+
):
526+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
527+
single_image_embeds, single_negative_image_embeds = self.encode_image(
528+
single_ip_adapter_image, device, 1, output_hidden_state
529+
)
530+
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
531+
single_negative_image_embeds = torch.stack(
532+
[single_negative_image_embeds] * num_images_per_prompt, dim=0
533+
)
533534

534-
image_embeds.append(single_image_embeds)
535+
if self.do_classifier_free_guidance:
536+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
537+
single_image_embeds = single_image_embeds.to(device)
535538

539+
image_embeds.append(single_image_embeds)
540+
else:
541+
image_embeds = ip_adapter_image_embeds
536542
return image_embeds
537543

538544
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
@@ -588,6 +594,8 @@ def check_inputs(
588594
negative_prompt=None,
589595
prompt_embeds=None,
590596
negative_prompt_embeds=None,
597+
ip_adapter_image=None,
598+
ip_adapter_image_embeds=None,
591599
controlnet_conditioning_scale=1.0,
592600
control_guidance_start=0.0,
593601
control_guidance_end=1.0,
@@ -726,6 +734,11 @@ def check_inputs(
726734
if end > 1.0:
727735
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
728736

737+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
738+
raise ValueError(
739+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
740+
)
741+
729742
def check_image(self, image, prompt, prompt_embeds):
730743
image_is_pil = isinstance(image, PIL.Image.Image)
731744
image_is_tensor = isinstance(image, torch.Tensor)
@@ -910,6 +923,7 @@ def __call__(
910923
prompt_embeds: Optional[torch.FloatTensor] = None,
911924
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
912925
ip_adapter_image: Optional[PipelineImageInput] = None,
926+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
913927
output_type: Optional[str] = "pil",
914928
return_dict: bool = True,
915929
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -974,6 +988,9 @@ def __call__(
974988
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
975989
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
976990
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
991+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
992+
Pre-generated image embeddings for IP-Adapter. If not
993+
provided, embeddings are computed from the `ip_adapter_image` input argument.
977994
output_type (`str`, *optional*, defaults to `"pil"`):
978995
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
979996
return_dict (`bool`, *optional*, defaults to `True`):
@@ -1060,6 +1077,8 @@ def __call__(
10601077
negative_prompt,
10611078
prompt_embeds,
10621079
negative_prompt_embeds,
1080+
ip_adapter_image,
1081+
ip_adapter_image_embeds,
10631082
controlnet_conditioning_scale,
10641083
control_guidance_start,
10651084
control_guidance_end,
@@ -1111,9 +1130,9 @@ def __call__(
11111130
if self.do_classifier_free_guidance:
11121131
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
11131132

1114-
if ip_adapter_image is not None:
1133+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
11151134
image_embeds = self.prepare_ip_adapter_image_embeds(
1116-
ip_adapter_image, device, batch_size * num_images_per_prompt
1135+
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
11171136
)
11181137

11191138
# 4. Prepare image

0 commit comments

Comments (0)