Merged
Changes from 1 commit
40 commits
7e1bf9b
allow list
Jan 14, 2024
3b55a1e
update
Jan 14, 2024
345b4d6
save
Jan 14, 2024
f6bae6e
add
Jan 15, 2024
baa7b83
remove print lines
Jan 15, 2024
9024698
style
Jan 15, 2024
27fc796
support multi-image
Jan 17, 2024
67908bf
Merge branch 'main' into multi-ipadapter
yiyixuxu Jan 17, 2024
afd91e3
fix a typo
Jan 17, 2024
7d42455
Merge branch 'multi-ipadapter' of github.com:huggingface/diffusers in…
Jan 17, 2024
1fdcd7b
fix
Jan 17, 2024
45fb582
fix a bug!
yiyixuxu Jan 18, 2024
1a4c6b1
fix
Jan 19, 2024
d924c47
update
Jan 19, 2024
cc2aa1b
fix
Jan 19, 2024
2c86534
Merge branch 'main' into multi-ipadapter
yiyixuxu Jan 19, 2024
0049e44
Apply suggestions from code review
yiyixuxu Jan 24, 2024
4a3df90
ImageProjectionLayers -> image_projection_layers
Jan 24, 2024
ff96407
merge
Jan 24, 2024
193d6e8
fix
Jan 24, 2024
f7f2465
fix-copies
Jan 24, 2024
efa704a
update test
Jan 25, 2024
9abdcf9
add test
Jan 25, 2024
5e47ceb
add prepare_ip_adapter_image_embeds method so pipelines can copy from
Jan 25, 2024
c6670de
update all pipelines support ip-adapter
Jan 25, 2024
711387e
deprecate image_embeds as 3d tensor
Jan 25, 2024
bce309f
corrent num_images_per_prompt behavior
Jan 25, 2024
fae861e
fix batching behavior
Jan 26, 2024
21da205
revert for lcm and sd safe
Jan 26, 2024
accee6b
correct tests
Jan 26, 2024
816578f
update attention processer so backward compatible
Jan 29, 2024
e57103f
revert changes made to ipadapter attention processor to follow the de…
Jan 30, 2024
6cfa34b
Merge branch 'main' into multi-ipadapter
yiyixuxu Jan 30, 2024
1e68c64
add doc
Jan 30, 2024
2cc1561
update doc
Jan 30, 2024
475046e
update docstring
Jan 30, 2024
98fa0c2
add a slow test for multi
Jan 30, 2024
3a52ecb
remove ddim config
Jan 30, 2024
e742cf4
style
Jan 31, 2024
dcdde9c
Merge branch 'main' into multi-ipadapter
yiyixuxu Jan 31, 2024
update
yiyixuxu committed Jan 19, 2024
commit d924c47bb1ffe8c842df5fc20d8bf76218f5bc11
45 changes: 16 additions & 29 deletions src/diffusers/models/attention_processor.py
@@ -720,7 +720,6 @@ def __call__(
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
**kwargs,
) -> torch.Tensor:
residual = hidden_states

@@ -1197,7 +1196,6 @@ def __call__(
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
@@ -2125,12 +2123,15 @@ def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
encoder_hidden_states,
Contributor

I'd keep encoder_hidden_states optional here to avoid a backwards-breaking change.
Collaborator Author
@yiyixuxu Jan 29, 2024

OK. But I don't think encoder_hidden_states was actually an optional argument in the previous code. It followed the same code structure as the default attention processor, so it would run when encoder_hidden_states is None, but that makes no sense: the IP-Adapter attention processor cannot be used for self-attention. So I think the =None in the signature was a mistake. If that's the case, it would not be a breaking change to remove =None, no?

Collaborator Author

Updated to follow the default attention processor more closely.

attention_mask=None,
temb=None,
**kwargs,
):
residual = hidden_states

encoder_hidden_states, ip_hidden_states = encoder_hidden_states
Contributor

Suggested change:

-encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+if isinstance(encoder_hidden_states, tuple):
+    encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+else:
+    end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+    encoder_hidden_states, ip_hidden_states = (
+        encoder_hidden_states[:, :end_pos, :],
+        encoder_hidden_states[:, end_pos:, :],
+    )

Let's try to make the class backwards compatible.

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -2140,28 +2141,18 @@ def __call__(
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
batch_size, sequence_length, _ = encoder_hidden_states.shape

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

# split hidden states
end_pos = encoder_hidden_states.shape[1] - sum(self.num_tokens)
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

@@ -2174,11 +2165,9 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)

# for ip-adapter
for num_token, scale, to_k_ip, to_v_ip in zip(self.num_tokens, self.scale, self.to_k_ip, self.to_v_ip):
current_ip_hidden_states, ip_hidden_states = (
ip_hidden_states[:, :num_token, :],
ip_hidden_states[:, num_token:, :],
)
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

@@ -2254,13 +2243,15 @@ def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
encoder_hidden_states,
attention_mask=None,
temb=None,
ip_hidden_states=None,
**kwargs,
Contributor

Let's not add **kwargs here unnecessarily.

Contributor

IMO it's not unnecessary, and all processors should use this layout, as it is much more flexible and less error-prone if cross_attention_kwargs are actually used. At the moment it's not possible to use a processor with a custom signature without awkward cross_attention_kwargs dict restructuring in attention.py, a core file, because cross_attention_kwargs has to match every receiving function signature.

Collaborator Author

I agree that we don't need **kwargs here, but I'm slightly in favor of adding it: I think it's more readable and makes it easier to mix and match attention processors.

For example, the scale argument is not used in the IP-Adapter attention calculation at all. However, we have to add it to the signature; otherwise it won't work, because the IP-Adapter attention processor is used together with the default attention processors, which take a scale argument.

Likewise, if we want to add a new argument to the IP-Adapter attention processor, we would have to add that argument to the default attention processor's signature too (e.g. if we are going to support an IP-Adapter attention mask, we would have to add a new ip-attention-mask argument). I think that is more confusing than simply appending **kwargs to the signature.
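A minimal sketch of the constraint discussed in this thread, with hypothetical processor and dispatcher names rather than the actual diffusers classes: the same cross_attention_kwargs dict is forwarded to whichever processor is installed, so a processor that neither declares scale nor accepts **kwargs raises even though it never uses scale.

# Hypothetical sketch: the same kwargs dict is forwarded to every installed
# processor, so each one must accept every key, even keys it never uses.
class DefaultAttnProcessorSketch:
    def __call__(self, hidden_states, scale=1.0):
        return hidden_states * scale  # default attention uses `scale`

class IPAdapterProcessorStrict:
    def __call__(self, hidden_states):
        return hidden_states  # does not declare `scale`

class IPAdapterProcessorKwargs:
    def __call__(self, hidden_states, **kwargs):
        return hidden_states  # silently ignores `scale` and anything else

def run_attention(processor, hidden_states, cross_attention_kwargs):
    # core code forwards the same dict regardless of which processor is set
    return processor(hidden_states, **cross_attention_kwargs)

kwargs = {"scale": 0.8}
run_attention(DefaultAttnProcessorSketch(), 1.0, kwargs)   # works
run_attention(IPAdapterProcessorKwargs(), 1.0, kwargs)     # works, `scale` ignored
# run_attention(IPAdapterProcessorStrict(), 1.0, kwargs)   # TypeError: unexpected keyword argument 'scale'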

Member

I agree with Yiyi here.

Contributor

If you feel strongly here @yiyixuxu, go for it! I still don't think we should add it, though, because adding **kwargs makes it much harder for people to spot errors: if a function has a **kwargs input, it swallows everything.

E.g. if you have:

def forward(input_values, attention_scale=1.0, **kwargs):
    ...

and someone writes forward(input_values, atention_scale=5.0) (note the typo in attention_scale), the code still executes, but with attention_scale left at 1.0, because atention_scale is swallowed by kwargs.

Most programming languages don't allow something like **kwargs, for good reason in my opinion. It also sets a bad precedent and incentivizes people to add kwargs everywhere. And it significantly hurts readability: having a kwargs argument means you can never be sure what can be passed to the function, which makes the code harder to understand.

But again, I do see the upsides you mention above. Another possibility we could start to explore is validation function decorators, which are used quite heavily in huggingface_hub: https://github.com/huggingface/huggingface_hub/blob/c528f78fa7404869f32a2b04eac0f9964a5c6f9e/src/huggingface_hub/utils/_validators.py#L46. We could use such a validator to remove an argument like scale.
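A runnable sketch of the failure mode described above, using the illustrative function and argument names from the comment: the misspelled keyword disappears into **kwargs and the default value is used silently.

def forward(input_values, attention_scale=1.0, **kwargs):
    # **kwargs silently absorbs any keyword argument it does not recognize
    return input_values * attention_scale

# Misspelled keyword: `atention_scale` lands in kwargs, so attention_scale stays 1.0.
print(forward(2.0, atention_scale=5.0))  # prints 2.0, not 10.0

def forward_strict(input_values, attention_scale=1.0):
    return input_values * attention_scale

# Without **kwargs the same typo fails loudly:
# forward_strict(2.0, atention_scale=5.0)  # TypeError: unexpected keyword argument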

Member

Using validation decorators would be very nice to have. I am happy to take a look at it next week.
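A hypothetical sketch of what such a validation decorator could look like. This is not the huggingface_hub implementation linked above, just an illustration of the idea: keyword arguments the wrapped processor does not declare, such as scale, are dropped with a visible warning instead of being silently swallowed.

import functools
import inspect
import warnings

def filter_unexpected_kwargs(fn):
    # Hypothetical validator: drop keyword arguments the wrapped function
    # does not declare, and warn instead of failing or hiding them.
    params = inspect.signature(fn).parameters
    allowed = {
        name for name, p in params.items()
        if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
    }

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        accepted = {k: v for k, v in kwargs.items() if k in allowed}
        dropped = sorted(set(kwargs) - set(accepted))
        if dropped:
            warnings.warn(f"Ignoring unexpected keyword arguments: {dropped}")
        return fn(*args, **accepted)

    return wrapper

@filter_unexpected_kwargs
def ip_adapter_call(hidden_states, attention_mask=None):
    return hidden_states

# `scale` is not in the signature, so it is dropped (with a warning) rather than raising.
ip_adapter_call(1.0, scale=0.5)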

Collaborator Author

Thanks for the explanation! Makes sense! I will not add **kwargs for now :)

Collaborator

I posted a comment here to complement @patrickvonplaten's comment above. IMO using a decorator to avoid **kwargs is legitimate for some use cases but more debatable in others (just to keep in mind).

):
residual = hidden_states

encoder_hidden_states, ip_hidden_states = encoder_hidden_states

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -2270,9 +2261,7 @@ def __call__(
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
batch_size, sequence_length, _ = encoder_hidden_states.shape

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -2285,9 +2274,7 @@ def __call__(

query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

# split hidden states
4 changes: 1 addition & 3 deletions src/diffusers/models/unet_2d_condition.py
@@ -1075,9 +1075,7 @@ def forward(
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds)
if cross_attention_kwargs is None:
cross_attention_kwargs = {}
cross_attention_kwargs["ip_hidden_states"] = image_embeds
encoder_hidden_states = (encoder_hidden_states, image_embeds)

# 2. pre-process
sample = self.conv_in(sample)