Skip to content

Commit 64909f1

Browse files
yiyixuxuyiyixuxu
andauthored
update IP-adapter code in UNetMotionModel (huggingface#6828)
fix Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent f09ca90 commit 64909f1

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -792,17 +792,17 @@ def forward(
792792

793793
emb = self.time_embedding(t_emb, timestep_cond)
794794
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
795+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
795796

796797
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
797798
if "image_embeds" not in added_cond_kwargs:
798799
raise ValueError(
799800
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
800801
)
801802
image_embeds = added_cond_kwargs.get("image_embeds")
802-
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
803-
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
804-
805-
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
803+
image_embeds = self.encoder_hid_proj(image_embeds)
804+
image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
805+
encoder_hidden_states = (encoder_hidden_states, image_embeds)
806806

807807
# 2. pre-process
808808
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])

0 commit comments

Comments
 (0)