99import PIL .Image
1010import torch
1111from packaging import version
12- from transformers import CLIPFeatureExtractor , CLIPTextModel , CLIPTokenizer
12+ from transformers import CLIPFeatureExtractor , CLIPTextModel , CLIPTokenizer , CLIPVisionModelWithProjection
1313
1414import diffusers
1515from diffusers import SchedulerMixin , StableDiffusionPipeline
@@ -520,6 +520,7 @@ def __init__(
520520 safety_checker : StableDiffusionSafetyChecker ,
521521 feature_extractor : CLIPFeatureExtractor ,
522522 requires_safety_checker : bool = True ,
523+ image_encoder : CLIPVisionModelWithProjection = None ,
523524 clip_skip : int = 1 ,
524525 ):
525526 super ().__init__ (
@@ -531,32 +532,11 @@ def __init__(
531532 safety_checker = safety_checker ,
532533 feature_extractor = feature_extractor ,
533534 requires_safety_checker = requires_safety_checker ,
535+ image_encoder = image_encoder ,
534536 )
535- self .clip_skip = clip_skip
537+ self .custom_clip_skip = clip_skip
536538 self .__init__additional__ ()
537539
538- # else:
539- # def __init__(
540- # self,
541- # vae: AutoencoderKL,
542- # text_encoder: CLIPTextModel,
543- # tokenizer: CLIPTokenizer,
544- # unet: UNet2DConditionModel,
545- # scheduler: SchedulerMixin,
546- # safety_checker: StableDiffusionSafetyChecker,
547- # feature_extractor: CLIPFeatureExtractor,
548- # ):
549- # super().__init__(
550- # vae=vae,
551- # text_encoder=text_encoder,
552- # tokenizer=tokenizer,
553- # unet=unet,
554- # scheduler=scheduler,
555- # safety_checker=safety_checker,
556- # feature_extractor=feature_extractor,
557- # )
558- # self.__init__additional__()
559-
560540 def __init__additional__ (self ):
561541 if not hasattr (self , "vae_scale_factor" ):
562542 setattr (self , "vae_scale_factor" , 2 ** (len (self .vae .config .block_out_channels ) - 1 ))
@@ -624,7 +604,7 @@ def _encode_prompt(
624604 prompt = prompt ,
625605 uncond_prompt = negative_prompt if do_classifier_free_guidance else None ,
626606 max_embeddings_multiples = max_embeddings_multiples ,
627- clip_skip = self .clip_skip ,
607+ clip_skip = self .custom_clip_skip ,
628608 )
629609 bs_embed , seq_len , _ = text_embeddings .shape
630610 text_embeddings = text_embeddings .repeat (1 , num_images_per_prompt , 1 )
0 commit comments