@@ -507,32 +507,38 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
         return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
-    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
-        if not isinstance(ip_adapter_image, list):
-            ip_adapter_image = [ip_adapter_image]
-
-        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
-            raise ValueError(
-                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-            )
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]

-        image_embeds = []
-        for single_ip_adapter_image, image_proj_layer in zip(
-            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
-        ):
-            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-            single_image_embeds, single_negative_image_embeds = self.encode_image(
-                single_ip_adapter_image, device, 1, output_hidden_state
-            )
-            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )

-            if self.do_classifier_free_guidance:
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                single_image_embeds = single_image_embeds.to(device)
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )

-            image_embeds.append(single_image_embeds)
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)

+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
         return image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
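The new `else` branch returns any pre-computed embeddings untouched, so the CLIP image encoder only runs when `ip_adapter_image_embeds` is `None`. Below is a minimal sketch of the two call paths, not part of this diff: the checkpoint names, the placeholder image URL, and the use of `StableDiffusionControlNetPipeline` are illustrative assumptions, and the same pattern applies to the pipeline class touched by this diff.

```python
# Sketch only: checkpoint ids and the image URL are assumptions, not prescribed by this PR.
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

ip_image = load_image("https://example.com/style.png")  # placeholder reference image

# First path: no precomputed embeds, so encode_image() runs once per IP-Adapter
# and the negative/positive embeddings are stacked (and concatenated under CFG).
image_embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=ip_image,
    ip_adapter_image_embeds=None,
    device=pipe.device,
    num_images_per_prompt=1,
)

# Second path: the cached list is returned as-is by the new `else` branch,
# skipping the CLIP image encoder entirely on later calls.
reused_embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=None,
    ip_adapter_image_embeds=image_embeds,
    device=pipe.device,
    num_images_per_prompt=1,
)
```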
@@ -588,6 +594,8 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
@@ -726,6 +734,11 @@ def check_inputs(
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
@@ -910,6 +923,7 @@ def __call__(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -974,6 +988,9 @@ def __call__(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -1060,6 +1077,8 @@ def __call__(
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
@@ -1111,9 +1130,9 @@ def __call__(
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

-        if ip_adapter_image is not None:
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, device, batch_size * num_images_per_prompt
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
             )

         # 4. Prepare image
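A hedged sketch of the `__call__`-level change follows, continuing the example above: the new `ip_adapter_image_embeds` keyword can replace `ip_adapter_image`, and passing both now raises a `ValueError` via `check_inputs`. `control_image` and its URL are illustrative placeholders, not part of this diff.

```python
# Continues the earlier sketch; `pipe` and `image_embeds` come from the previous block.
control_image = load_image("https://example.com/canny.png")  # placeholder conditioning image

result = pipe(
    prompt="a photo of a cat",
    image=control_image,
    ip_adapter_image_embeds=image_embeds,  # new keyword introduced by this diff; do not also pass ip_adapter_image
    num_inference_steps=30,
)
result.images[0].save("out.png")
```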