
Commit 491a933

[I2VGenXL] attention_head_dim in the UNet (huggingface#6872)
* attention_head_dim
* debug
* print more info
* correct num_attention_heads behaviour
* down_block_num_attention_heads -> num_attention_heads.
* correct the image link in doc.
* add: deprecation for num_attention_head
* fix: test argument to use attention_head_dim
* more fixes.
* quality
* address comments.
* remove deprecation.
1 parent aa82df5 commit 491a933

4 files changed: 15 additions & 3 deletions

src/diffusers/models/attention.py
Lines changed: 1 addition & 0 deletions

@@ -158,6 +158,7 @@ def __init__(
        super().__init__()
        self.only_cross_attention = only_cross_attention

+       # We keep these boolean flags for backward-compatibility.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
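
For orientation, here is a minimal standalone sketch (a hypothetical helper, not diffusers code) of how these backward-compatibility flags resolve from `norm_type`:

# Hypothetical helper mirroring the flag logic in the hunk above.
def resolve_norm_flags(norm_type: str, num_embeds_ada_norm=None):
    use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
    use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
    use_ada_layer_norm_single = norm_type == "ada_norm_single"
    return use_ada_layer_norm, use_ada_layer_norm_zero, use_ada_layer_norm_single

print(resolve_norm_flags("ada_norm", num_embeds_ada_norm=8))  # (True, False, False)
print(resolve_norm_flags("ada_norm_single"))                  # (False, False, True)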

src/diffusers/models/unets/unet_i2vgen_xl.py
Lines changed: 11 additions & 1 deletion

@@ -120,6 +120,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers is skipped in post-processing.
        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+       attention_head_dim (`int`, *optional*, defaults to 64): Attention head dim.
        num_attention_heads (`int`, *optional*): The number of attention heads.
    """

@@ -147,10 +148,19 @@ def __init__(
        layers_per_block: int = 2,
        norm_num_groups: Optional[int] = 32,
        cross_attention_dim: int = 1024,
-       num_attention_heads: Optional[Union[int, Tuple[int]]] = 64,
+       attention_head_dim: Union[int, Tuple[int]] = 64,
+       num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    ):
        super().__init__()

+       # When we first integrated the UNet into the library, we didn't have `attention_head_dim`. As a consequence
+       # of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This
+       # is why we ignore `num_attention_heads` and calculate it from `attention_head_dim` below.
+       # This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it
+       # without running proper deprecation cycles for the {down,mid,up} blocks which are a
+       # part of the public API.
+       num_attention_heads = attention_head_dim
+
        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
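
The comment above is the core of the change: the constructor now treats `attention_head_dim` as the source of truth and ignores whatever `num_attention_heads` a caller passes. A tiny illustrative sketch of that effective behaviour (not the real class):

def effective_num_attention_heads(attention_head_dim, num_attention_heads=None):
    # Mirrors `num_attention_heads = attention_head_dim` in the __init__ above:
    # the caller-supplied num_attention_heads is ignored.
    return attention_head_dim

assert effective_num_attention_heads(64) == 64
assert effective_num_attention_heads(64, num_attention_heads=8) == 64  # caller value is ignored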

src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@
        >>> pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
        >>> pipeline.enable_model_cpu_offload()

-       >>> image_url = "https://github.com/ali-vilab/i2vgen-xl/blob/main/data/test_images/img_0009.png?raw=true"
+       >>> image_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
        >>> image = load_image(image_url).convert("RGB")

        >>> prompt = "Papers were floating in the air on a table in the library"
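
Only the example image URL changes here. For context, a hedged sketch of how a docstring example like this typically finishes (argument values and the output path are illustrative, not quoted from the file):

        >>> import torch
        >>> from diffusers.utils import export_to_gif
        >>> generator = torch.manual_seed(0)
        >>> frames = pipeline(prompt=prompt, image=image, generator=generator).frames[0]
        >>> export_to_gif(frames, "i2v.gif")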

tests/pipelines/i2vgen_xl/test_i2vgenxl.py
Lines changed: 2 additions & 1 deletion

@@ -80,7 +80,8 @@ def get_dummy_components(self):
            down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
            up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
            cross_attention_dim=4,
-           num_attention_heads=4,
+           attention_head_dim=4,
+           num_attention_heads=None,
            norm_num_groups=2,
        )
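
Caller-side migration implied by this test change, sketched with the other constructor arguments elided:

# Before this commit, the head *dimension* was passed under the old keyword:
#     I2VGenXLUNet(..., num_attention_heads=4, ...)
# After it, the same value goes to `attention_head_dim`; `num_attention_heads`
# can be left as None since the constructor ignores it anyway:
#     I2VGenXLUNet(..., attention_head_dim=4, num_attention_heads=None, ...)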
