diff --git a/diffengine/models/editors/ip_adapter/ip_adapter_xl.py b/diffengine/models/editors/ip_adapter/ip_adapter_xl.py index 4fc00bb..d4e78c3 100644 --- a/diffengine/models/editors/ip_adapter/ip_adapter_xl.py +++ b/diffengine/models/editors/ip_adapter/ip_adapter_xl.py @@ -3,6 +3,7 @@ import numpy as np import torch from diffusers import DiffusionPipeline +from diffusers.models.embeddings import MultiIPAdapterImageProjection from diffusers.utils import load_image from PIL import Image from torch import nn @@ -167,9 +168,21 @@ def infer(self, ) adapter_state_dict = process_ip_adapter_state_dict( self.unet, self.image_projection) - pipeline.load_ip_adapter( - pretrained_model_name_or_path_or_dict=adapter_state_dict, - subfolder="", weight_name="") + + # convert IP-Adapter Image Projection layers to diffusers + image_projection_layers = [] + for state_dict in [adapter_state_dict]: + image_projection_layer = ( + pipeline.unet._convert_ip_adapter_image_proj_to_diffusers( # noqa + state_dict["image_proj"])) + image_projection_layer.to( + device=pipeline.unet.device, dtype=pipeline.unet.dtype) + image_projection_layers.append(image_projection_layer) + + pipeline.unet.encoder_hid_proj = MultiIPAdapterImageProjection( + image_projection_layers) + pipeline.unet.config.encoder_hid_dim_type = "ip_image_proj" + if self.prediction_type is not None: # set prediction_type of scheduler if defined scheduler_args = {"prediction_type": self.prediction_type} @@ -276,12 +289,11 @@ def forward( # TODO(takuoko): drop image # noqa ip_tokens = self.image_projection(image_embeds) - prompt_embeds = torch.cat([prompt_embeds, ip_tokens], dim=1) model_pred = self.unet( noisy_latents, timesteps, - prompt_embeds, + (prompt_embeds, ip_tokens), added_cond_kwargs=unet_added_conditions).sample return self.loss(model_pred, noise, latents, timesteps, weight) @@ -379,12 +391,11 @@ def forward( # TODO(takuoko): drop image # noqa ip_tokens = self.image_projection(image_embeds) - prompt_embeds = torch.cat([prompt_embeds, ip_tokens], dim=1) model_pred = self.unet( noisy_latents, timesteps, - prompt_embeds, + (prompt_embeds, ip_tokens), added_cond_kwargs=unet_added_conditions).sample return self.loss(model_pred, noise, latents, timesteps, weight)